diff options
author | Claudius "keldu" Holeksa <mail@keldu.de> | 2024-06-26 09:39:34 +0200 |
---|---|---|
committer | Claudius "keldu" Holeksa <mail@keldu.de> | 2024-06-26 09:39:34 +0200 |
commit | 729307460e77f62a532ee9841dcaed9c47f46419 (patch) | |
tree | 0b52ddbfa47d9d148907de90e7a2987d72ed7d73 /modules/remote-sycl | |
parent | 51b50882d2906b83c5275c732a56ff333ae6696f (diff) |
Added better structure for the data server
Diffstat (limited to 'modules/remote-sycl')
-rw-r--r-- | modules/remote-sycl/c++/data.hpp | 23 | ||||
-rw-r--r-- | modules/remote-sycl/c++/device.hpp | 38 | ||||
-rw-r--r-- | modules/remote-sycl/c++/remote.hpp | 138 | ||||
-rw-r--r-- | modules/remote-sycl/c++/transfer.hpp | 118 |
4 files changed, 204 insertions, 113 deletions
diff --git a/modules/remote-sycl/c++/data.hpp b/modules/remote-sycl/c++/data.hpp index 7ecf1ae..3ad1d9c 100644 --- a/modules/remote-sycl/c++/data.hpp +++ b/modules/remote-sycl/c++/data.hpp @@ -4,4 +4,27 @@ namespace saw { +/** + * Generic wrapper class which stores data on the sycl side. + * Most of the times this will be a root object. + */ +template<typename Schema> +class data<Schema, encode::Native, rmt::Sycl> { +private: + cl::sycl::buffer<data<Schema, encode::Native, storage::Default>> data_; +public: + data(data<Schema, encode::Native, storage::Default>& data__): + data_{&data__, 1u} + {} + + auto& get_handle() { + return data_; + } + + template<cl::sycl::access::mode AccessMode> + auto access(cl::sycl::handler& h){ + return data_.template get_access<AccessMode>(h); + } +}; + } diff --git a/modules/remote-sycl/c++/device.hpp b/modules/remote-sycl/c++/device.hpp index 30eed2f..6d133ae 100644 --- a/modules/remote-sycl/c++/device.hpp +++ b/modules/remote-sycl/c++/device.hpp @@ -1,5 +1,43 @@ #pragma once +#include "common.hpp" + namespace saw { +/** + * Represents a remote Sycl device. + */ +template<> +class device<rmt::Sycl> final { +private: + cl::sycl::queue cmd_queue_; +public: + device() = default; + + SAW_FORBID_COPY(device); + SAW_FORBID_MOVE(device); + + /** + * Copy data to device + */ + template<typename Schema, typename Encoding, typename Storage> + error_or<data<Schema, Encoding, rmt::Sycl>> copy_to_device(const data<Schema, Encoding, Storage>& host_data){ + return data<Schema, Encoding, rmt::Sycl>::copy_to_device(host_data, *this); + } + + /** + * Copy data to host + */ + template<typename Schema, typename Encoding, typename Storage> + error_or<data<Schema, Encoding, Storage>> copy_to_host(const data<Schema, Encoding, rmt::Sycl>& dev_data){ + return dev_data.copy_to_host(); + } + + /** + * Get a reference to the handle + */ + cl::sycl::queue& get_handle(){ + return cmd_queue_; + } +}; } diff --git a/modules/remote-sycl/c++/remote.hpp b/modules/remote-sycl/c++/remote.hpp index 2d0eaea..1ae3103 100644 --- a/modules/remote-sycl/c++/remote.hpp +++ b/modules/remote-sycl/c++/remote.hpp @@ -2,28 +2,10 @@ #include "common.hpp" #include "data.hpp" +#include "device.hpp" +#include "transfer.hpp" namespace saw { - -template<typename Schema> -class data<Schema, encode::Native, rmt::Sycl> { -private: - cl::sycl::buffer<data<Schema, encode::Native, storage::Default>> data_; -public: - data(data<Schema, encode::Native, storage::Default>& data__): - data_{&data__, 1u} - {} - - auto& get_handle() { - return data_; - } - - template<cl::sycl::access::mode AccessMode> - auto access(cl::sycl::handler& h){ - return data_.template get_access<AccessMode>(h); - } -}; - /** * Remote data class for the Sycl backend. */ @@ -78,35 +60,22 @@ public: /** * Meant to be a helper object which holds the allocated data on the sycl side */ -template<typename Schema, typename Encoding, typename Backend> -class device_data; +//template<typename Schema, typename Encoding, typename Backend> +//class device_data; /** * This class helps in regards to the ownership on the server side - */ template<typename Schema, typename Encoding> class device_data<Schema, Encoding, rmt::Sycl> { private: - /** - * The actual data - */ data<Schema,Encoding,storage::Default>* device_data_; - /** - * The sycl queue object - */ cl::sycl::queue* queue_; public: - /** - * Main constructor - */ device_data(data<Schema,Encoding,storage::Default>& device_data__, cl::sycl::queue& queue__): device_data_{&device_data__}, queue_{&queue__} {} - /** - * Destructor specifically designed to deallocate on the device. - */ ~device_data(){ if(device_data_){ cl::sycl::free(device_data_,queue_); @@ -117,64 +86,33 @@ public: SAW_FORBID_COPY(device_data); SAW_FORBID_MOVE(device_data); }; - -namespace impl { -template<typename Schema, typename Encoding, typename Backend> -struct device_id_map { - std::vector<device_data<Schema, Encoding, Backend>> data; -}; - -template<typename Iface, typename Encoding, typename Storage> -struct rpc_id_map_helper { - static_assert(always_false<Iface, Encoding,Storage>, "Only supports Interface schema types."); -}; - -template<typename... Members, typename Encoding, typename Storage> -struct rpc_id_map_helper<schema::Interface<Members...>, Encoding, Storage> { - std::tuple<id_map<typename Members::ValueType::ResponseT, Encoding, Storage>...> maps; -}; -} + */ } // Maybe a helper impl tmpl file? namespace saw { -/** - * Represents a remote Sycl device. - */ -template<> -class device<rmt::Sycl> final { -private: - cl::sycl::queue cmd_queue_; -public: - device() = default; - SAW_FORBID_COPY(device); - SAW_FORBID_MOVE(device); +namespace impl { +template<typename Func> +struct rpc_func_type_helper; - /** - * Copy data to device - */ - template<typename Schema, typename Encoding, typename Storage> - error_or<data<Schema, Encoding, rmt::Sycl>> copy_to_device(const data<Schema, Encoding, Storage>& host_data){ - return data<Schema, Encoding, rmt::Sycl>::copy_to_device(host_data, *this); - } - - /** - * Copy data to host - */ - template<typename Schema, typename Encoding, typename Storage> - error_or<data<Schema, Encoding, Storage>> copy_to_host(const data<Schema, Encoding, rmt::Sycl>& device_data){ - return device_data.copy_to_host(); - } - - /** - * Get a reference to the handle - */ - cl::sycl::queue& get_handle(){ - return cmd_queue_; - } +template<typename Response, typename Request> +struct rpc_func_type_helper<schema::Function<Request, Response>>{ + using type = tmpl_group<Response, Request>; +}; + +template<typename Iface> +struct rpc_iface_type_helper { + using type = tmpl_group<>; +}; + +template<typename Func, string_literal K, typename... Functions, string_literal... Keys> +struct rpc_iface_type_helper<schema::Interface<schema::Member<Func,K>,schema::Member<Functions,Keys>...>> { + using inner_type = typename rpc_func_type_helper<Func>::type; + using type = typename tmpl_concat<inner_type, typename rpc_iface_type_helper<schema::Interface<schema::Member<Functions,Keys>...>>::type>::type; }; +} /** * Rpc Client class for the Sycl backend. @@ -189,11 +127,17 @@ private: rpc_server<Iface, Encoding, rmt::Sycl>* srv_; /** + * TransferClient created from the internal RPC data server + */ + data_client<typename impl::rpc_iface_type_helper<Iface>::type, Encoding, rmt::Sycl> data_client_; + + /** * Generated some sort of id for the request. */ public: rpc_client(rpc_server<Iface, Encoding, rmt::Sycl>& srv): - srv_{&srv} + srv_{&srv}, + data_client_{srv_->data_server} {} /** @@ -219,6 +163,7 @@ class rpc_server<Iface, Encoding, rmt::Sycl> { public: using InterfaceCtxT = cl::sycl::queue*; using InterfaceT = interface<Iface, Encoding, storage::Default, InterfaceCtxT>; + private: /** * Device instance enabling the use of the remote device. @@ -226,14 +171,15 @@ private: device<rmt::Sycl>* device_; /** - * The interface including the relevant context class. + * Data server storing the relevant data */ - interface<Iface, Encoding, storage::Default, InterfaceCtxT> cl_interface_; + data_server<typename impl::rpc_iface_type_helper<Iface>::type, Encoding, rmt::Sycl> data_server_; /** - * Basic storage for response data. + * The interface including the relevant context class. */ - impl::rpc_id_map_helper<Iface, Encoding, rmt::Sycl> storage_; + interface<Iface, Encoding, storage::Default, InterfaceCtxT> cl_interface_; + public: /** @@ -241,8 +187,8 @@ public: */ rpc_server(device<rmt::Sycl>& dev__, InterfaceT cl_iface): device_{&dev__}, - cl_interface_{std::move(cl_iface)}, - storage_{} + data_server_{}, + cl_interface_{std::move(cl_iface)} {} /** @@ -276,7 +222,7 @@ public: /** * Object needed if and only if the provided data type is not an id */ - own<device_data<typename FuncT::RequestT, Encoding, rmt::Sycl>> dev_tmp_inp = nullptr; + own<data<typename FuncT::RequestT, Encoding, rmt::Sycl>> dev_tmp_inp = nullptr; /** * First check if it's data or an id. * If it's an id, check if it's registered within the storage and retrieve it. @@ -284,8 +230,7 @@ public: auto eoinp = [&,this]() -> error_or<data<typename FuncT::RequestT, Encoding, rmt::Sycl>* > { if(input.is_id()){ // storage_.maps - auto& inner_map = std::get<id_map<typename FuncT::RequestT, Encoding, rmt::Sycl>> (storage_.maps); - auto eov = inner_map.find(input.get_id()); + auto eov = data_server_.template find<typename FuncT::RequestT>(input.get_id()); if(eov.is_error()){ return std::move(eov.get_error()); } @@ -319,8 +264,7 @@ public: /** * Store returned data in rpc storage */ - auto& inner_map = std::get<id_map<typename schema_member_type<Name, Iface>::type::RequestT, Encoding,rmt::Sycl>> (storage_.maps); - auto eoid = inner_map.insert_as(std::move(val), rpc_id); + auto eoid = data_server_.template insert<typename schema_member_type<Name, Iface>::type::RequestT>(std::move(val), rpc_id); if(eoid.is_error()){ return std::move(eoid.get_error()); } diff --git a/modules/remote-sycl/c++/transfer.hpp b/modules/remote-sycl/c++/transfer.hpp index 6849caa..65a9b9e 100644 --- a/modules/remote-sycl/c++/transfer.hpp +++ b/modules/remote-sycl/c++/transfer.hpp @@ -2,11 +2,26 @@ #include "common.hpp" #include "data.hpp" +#include "device.hpp" + #include <forstio/error.hpp> +#include <forstio/reduce_templates.hpp> namespace saw { -template<typename Schema, typename Encoding> -class data_server<Schema, Encoding, rmt::Sycl> { +namespace impl { +template<typename Encoding, typename T> +struct data_server_redux { + using type = std::tuple<>; +}; + +template<typename Encoding, typename... Schema> +struct data_server_redux<Encoding, tmpl_group<Schema...>> { + using type = std::tuple<std::unordered_map<uint64_t, data<Schema, Encoding, rmt::Sycl>>...>; +}; +} + +template<typename... Schema, typename Encoding> +class data_server<tmpl_group<Schema...>, Encoding, rmt::Sycl> { private: /** * Device context class @@ -16,7 +31,7 @@ private: /** * Store for the data the server manages. */ - std::unordered_map<uint64_t, data<Schema, Encoding, rmt::Sycl>> values_; + impl::data_server_redux<Encoding, typename tmpl_reduce<tmpl_group<Schema...>>::type >::type values_; public: /** * Main constructor @@ -26,29 +41,83 @@ public: {} /** - * Receive data which we will store. + * Get data which we will store. */ - error_or<void> send(const data<Schema, Encoding, storage::Default>& dat, id<Schema> store_id){ - auto eoval = device_->copy_to_device(dat); + template<typename Sch> + error_or<void> send(const data<Sch, Encoding, storage::Default>& dat, id<Sch> store_id){ + auto& vals = std::get<Sch>(values_); + auto eoval = device_->template copy_to_device<Sch, Encoding, storage::Default>(dat); if(eoval.is_error()){ auto& err = eoval.get_error(); return std::move(err); } - return make_error<err::not_implemented>(); + auto& val = eoval.get_value(); + try { + auto insert_res = vals.insert(std::make_pair(store_id.get_value(), std::move(val))); + if(!insert_res.second){ + return make_error<err::already_exists>(); + } + }catch ( std::exception& ){ + return make_error<err::out_of_memory>(); + } + return void_t{}; + } + + /** + * Requests data from the server + */ + template<typename Sch> + error_or<data<Sch, Encoding, storage::Default>> receive(id<Sch> store_id){ + auto& vals = std::get<Sch>(values_); + auto find_res = vals.find(store_id.get_value()); + if(find_res == vals.end()){ + return make_error<err::not_found>(); + } + auto& dat = find_res->second; + + auto eoval = device_->copy_to_host(dat); + return eoval; + } + + /** + * Request an erase of the stored data + */ + template<typename Sch> + error_or<void> erase(id<Sch> store_id){ + auto& vals = std::get<Sch>(values_); + auto erase_op = vals.erase(store_id.get_value()); + if(erase_op == 0u){ + return make_error<err::not_found>(); + } + return void_t{}; } - error_or<data<Schema, Encoding, storage::Default>> receive(id<Schema> store_id){ - return make_error<err::not_implemented>(); + /** + * Get the stored data on the server side for immediate use. + * Insert operations may invalidate the pointer. + */ + template<typename Sch> + error_or<data<Sch, Encoding, rmt::Sycl>*> find(id<Sch> store_id){ + auto& vals = std::get<Sch>(values_); + auto find_res = vals.find(store_id.get_value()); + if(find_res == vals.end()){ + return make_error<err::not_found>(); + } + + return &(find_res.second); } }; -template<typename Schema, typename Encoding> -class data_client<Schema, Encoding, rmt::Sycl> { +/** + * Client for transporting data to remote and receiving data back + */ +template<typename... Schema, typename Encoding> +class data_client<tmpl_group<Schema...>, Encoding, rmt::Sycl> { private: /** * Corresponding server for this client */ - data_server<Schema, Encoding, rmt::Sycl>* srv_; + data_server<tmpl_group<Schema...>, Encoding, rmt::Sycl>* srv_; /** * The next id for identifying issues on the remote side. @@ -58,16 +127,17 @@ public: /** * Main constructor */ - data_client(data_server<Schema, Encoding, rmt::Sycl>& srv__): + data_client(data_server<tmpl_group<Schema...>, Encoding, rmt::Sycl>& srv__): srv_{&srv__}, next_id_{0u} {} /** - * Send data to. + * Send data to the remote. */ - error_or<id<Schema>> send(const data<Schema, Encoding, storage::Default>& dat){ - id<Schema> dat_id{next_id_}; + template<typename Sch> + error_or<id<Sch>> send(const data<Sch, Encoding, storage::Default>& dat){ + id<Sch> dat_id{next_id_}; auto eov = srv_->send(dat, dat_id); if(eov.is_error()){ auto& err = eov.get_error(); @@ -77,5 +147,21 @@ public: ++next_id_; return dat_id; } + + /** + * Receive data + */ + template<typename Sch> + conveyor<data<Sch, Encoding, storage::Default>> receive(id<Sch> dat_id){ + return srv_->receive(dat_id); + } + + /** + * Erase data + */ + template<typename Sch> + error_or<void> erase(id<Sch> dat_id){ + return srv_->erase(dat_id); + } }; } |