From 729307460e77f62a532ee9841dcaed9c47f46419 Mon Sep 17 00:00:00 2001 From: "Claudius \"keldu\" Holeksa" Date: Wed, 26 Jun 2024 09:39:34 +0200 Subject: Added better structure for the data server --- modules/core/c++/reduce_templates.hpp | 41 ++++++++++ modules/core/c++/templates.hpp | 14 ++++ modules/core/tests/core.cpp | 23 ++++++ modules/remote-sycl/c++/data.hpp | 23 ++++++ modules/remote-sycl/c++/device.hpp | 38 ++++++++++ modules/remote-sycl/c++/remote.hpp | 138 ++++++++++------------------------ modules/remote-sycl/c++/transfer.hpp | 118 +++++++++++++++++++++++++---- 7 files changed, 282 insertions(+), 113 deletions(-) create mode 100644 modules/core/c++/reduce_templates.hpp (limited to 'modules') diff --git a/modules/core/c++/reduce_templates.hpp b/modules/core/c++/reduce_templates.hpp new file mode 100644 index 0000000..ef5fa4c --- /dev/null +++ b/modules/core/c++/reduce_templates.hpp @@ -0,0 +1,41 @@ +#pragma once + +#include "templates.hpp" + +namespace saw { + +namespace impl { +template +struct tmpl_group_reduce_match { + static constexpr bool has_type = false; + using type = saw::tmpl_group; +}; + +template +struct tmpl_group_reduce_match> { + static constexpr bool has_type = std::is_same_v or tmpl_group_reduce_match>::has_type; + + using type = typename std::conditional, tmpl_group>::type; +}; + +template +struct tmpl_group_reduce { + using reduced_type = T; +}; + +/** + * Reducing in outer loop + */ +template +struct tmpl_group_reduce> { + using reduced_inner_list = typename tmpl_group_reduce>::reduced_type; + + using reduced_type = typename tmpl_group_reduce_match::type; +}; +} + +template +struct tmpl_reduce { + using type = typename impl::tmpl_group_reduce::reduced_type; +}; +} diff --git a/modules/core/c++/templates.hpp b/modules/core/c++/templates.hpp index acbaeb0..70836ae 100644 --- a/modules/core/c++/templates.hpp +++ b/modules/core/c++/templates.hpp @@ -5,6 +5,20 @@ namespace saw { +/** + * This type is meant for grouping of template types + */ +template +struct tmpl_group {}; + +template +struct tmpl_concat; + +template +struct tmpl_concat, tmpl_group> { + using type = tmpl_group; +}; + template struct parameter_pack_index; template struct parameter_pack_index { diff --git a/modules/core/tests/core.cpp b/modules/core/tests/core.cpp index b1ce741..f418a07 100644 --- a/modules/core/tests/core.cpp +++ b/modules/core/tests/core.cpp @@ -1,6 +1,7 @@ #include "../c++/test/suite.hpp" #include "../c++/id.hpp" #include "../c++/string_literal.hpp" +#include "../c++/reduce_templates.hpp" namespace { SAW_TEST("ID functionality") { @@ -37,4 +38,26 @@ SAW_TEST("String Literal Append"){ SAW_EXPECT(c == "foobar", "CT String sum is not \"foobar\""); } + +SAW_TEST("Template Group Reduction"){ + using namespace saw; + + struct foo { + std::string name = "foo"; + }; + struct bar { + std::string name = "bar"; + }; + struct baz { + std::string name = "baz"; + }; + + using grp = tmpl_group; + using red_grp = tmpl_group; + + using alg_red_grp = tmpl_reduce::type; + + static_assert(std::is_same_v, "Should be same type"); + SAW_EXPECT((std::is_same_v), "Should be same type"); +} } diff --git a/modules/remote-sycl/c++/data.hpp b/modules/remote-sycl/c++/data.hpp index 7ecf1ae..3ad1d9c 100644 --- a/modules/remote-sycl/c++/data.hpp +++ b/modules/remote-sycl/c++/data.hpp @@ -4,4 +4,27 @@ namespace saw { +/** + * Generic wrapper class which stores data on the sycl side. + * Most of the times this will be a root object. + */ +template +class data { +private: + cl::sycl::buffer> data_; +public: + data(data& data__): + data_{&data__, 1u} + {} + + auto& get_handle() { + return data_; + } + + template + auto access(cl::sycl::handler& h){ + return data_.template get_access(h); + } +}; + } diff --git a/modules/remote-sycl/c++/device.hpp b/modules/remote-sycl/c++/device.hpp index 30eed2f..6d133ae 100644 --- a/modules/remote-sycl/c++/device.hpp +++ b/modules/remote-sycl/c++/device.hpp @@ -1,5 +1,43 @@ #pragma once +#include "common.hpp" + namespace saw { +/** + * Represents a remote Sycl device. + */ +template<> +class device final { +private: + cl::sycl::queue cmd_queue_; +public: + device() = default; + + SAW_FORBID_COPY(device); + SAW_FORBID_MOVE(device); + + /** + * Copy data to device + */ + template + error_or> copy_to_device(const data& host_data){ + return data::copy_to_device(host_data, *this); + } + + /** + * Copy data to host + */ + template + error_or> copy_to_host(const data& dev_data){ + return dev_data.copy_to_host(); + } + + /** + * Get a reference to the handle + */ + cl::sycl::queue& get_handle(){ + return cmd_queue_; + } +}; } diff --git a/modules/remote-sycl/c++/remote.hpp b/modules/remote-sycl/c++/remote.hpp index 2d0eaea..1ae3103 100644 --- a/modules/remote-sycl/c++/remote.hpp +++ b/modules/remote-sycl/c++/remote.hpp @@ -2,28 +2,10 @@ #include "common.hpp" #include "data.hpp" +#include "device.hpp" +#include "transfer.hpp" namespace saw { - -template -class data { -private: - cl::sycl::buffer> data_; -public: - data(data& data__): - data_{&data__, 1u} - {} - - auto& get_handle() { - return data_; - } - - template - auto access(cl::sycl::handler& h){ - return data_.template get_access(h); - } -}; - /** * Remote data class for the Sycl backend. */ @@ -78,35 +60,22 @@ public: /** * Meant to be a helper object which holds the allocated data on the sycl side */ -template -class device_data; +//template +//class device_data; /** * This class helps in regards to the ownership on the server side - */ template class device_data { private: - /** - * The actual data - */ data* device_data_; - /** - * The sycl queue object - */ cl::sycl::queue* queue_; public: - /** - * Main constructor - */ device_data(data& device_data__, cl::sycl::queue& queue__): device_data_{&device_data__}, queue_{&queue__} {} - /** - * Destructor specifically designed to deallocate on the device. - */ ~device_data(){ if(device_data_){ cl::sycl::free(device_data_,queue_); @@ -117,64 +86,33 @@ public: SAW_FORBID_COPY(device_data); SAW_FORBID_MOVE(device_data); }; - -namespace impl { -template -struct device_id_map { - std::vector> data; -}; - -template -struct rpc_id_map_helper { - static_assert(always_false, "Only supports Interface schema types."); -}; - -template -struct rpc_id_map_helper, Encoding, Storage> { - std::tuple...> maps; -}; -} + */ } // Maybe a helper impl tmpl file? namespace saw { -/** - * Represents a remote Sycl device. - */ -template<> -class device final { -private: - cl::sycl::queue cmd_queue_; -public: - device() = default; - SAW_FORBID_COPY(device); - SAW_FORBID_MOVE(device); +namespace impl { +template +struct rpc_func_type_helper; - /** - * Copy data to device - */ - template - error_or> copy_to_device(const data& host_data){ - return data::copy_to_device(host_data, *this); - } - - /** - * Copy data to host - */ - template - error_or> copy_to_host(const data& device_data){ - return device_data.copy_to_host(); - } - - /** - * Get a reference to the handle - */ - cl::sycl::queue& get_handle(){ - return cmd_queue_; - } +template +struct rpc_func_type_helper>{ + using type = tmpl_group; +}; + +template +struct rpc_iface_type_helper { + using type = tmpl_group<>; +}; + +template +struct rpc_iface_type_helper,schema::Member...>> { + using inner_type = typename rpc_func_type_helper::type; + using type = typename tmpl_concat...>>::type>::type; }; +} /** * Rpc Client class for the Sycl backend. @@ -188,12 +126,18 @@ private: */ rpc_server* srv_; + /** + * TransferClient created from the internal RPC data server + */ + data_client::type, Encoding, rmt::Sycl> data_client_; + /** * Generated some sort of id for the request. */ public: rpc_client(rpc_server& srv): - srv_{&srv} + srv_{&srv}, + data_client_{srv_->data_server} {} /** @@ -219,6 +163,7 @@ class rpc_server { public: using InterfaceCtxT = cl::sycl::queue*; using InterfaceT = interface; + private: /** * Device instance enabling the use of the remote device. @@ -226,14 +171,15 @@ private: device* device_; /** - * The interface including the relevant context class. + * Data server storing the relevant data */ - interface cl_interface_; + data_server::type, Encoding, rmt::Sycl> data_server_; /** - * Basic storage for response data. + * The interface including the relevant context class. */ - impl::rpc_id_map_helper storage_; + interface cl_interface_; + public: /** @@ -241,8 +187,8 @@ public: */ rpc_server(device& dev__, InterfaceT cl_iface): device_{&dev__}, - cl_interface_{std::move(cl_iface)}, - storage_{} + data_server_{}, + cl_interface_{std::move(cl_iface)} {} /** @@ -276,7 +222,7 @@ public: /** * Object needed if and only if the provided data type is not an id */ - own> dev_tmp_inp = nullptr; + own> dev_tmp_inp = nullptr; /** * First check if it's data or an id. * If it's an id, check if it's registered within the storage and retrieve it. @@ -284,8 +230,7 @@ public: auto eoinp = [&,this]() -> error_or* > { if(input.is_id()){ // storage_.maps - auto& inner_map = std::get> (storage_.maps); - auto eov = inner_map.find(input.get_id()); + auto eov = data_server_.template find(input.get_id()); if(eov.is_error()){ return std::move(eov.get_error()); } @@ -319,8 +264,7 @@ public: /** * Store returned data in rpc storage */ - auto& inner_map = std::get::type::RequestT, Encoding,rmt::Sycl>> (storage_.maps); - auto eoid = inner_map.insert_as(std::move(val), rpc_id); + auto eoid = data_server_.template insert::type::RequestT>(std::move(val), rpc_id); if(eoid.is_error()){ return std::move(eoid.get_error()); } diff --git a/modules/remote-sycl/c++/transfer.hpp b/modules/remote-sycl/c++/transfer.hpp index 6849caa..65a9b9e 100644 --- a/modules/remote-sycl/c++/transfer.hpp +++ b/modules/remote-sycl/c++/transfer.hpp @@ -2,11 +2,26 @@ #include "common.hpp" #include "data.hpp" +#include "device.hpp" + #include +#include namespace saw { -template -class data_server { +namespace impl { +template +struct data_server_redux { + using type = std::tuple<>; +}; + +template +struct data_server_redux> { + using type = std::tuple>...>; +}; +} + +template +class data_server, Encoding, rmt::Sycl> { private: /** * Device context class @@ -16,7 +31,7 @@ private: /** * Store for the data the server manages. */ - std::unordered_map> values_; + impl::data_server_redux>::type >::type values_; public: /** * Main constructor @@ -26,29 +41,83 @@ public: {} /** - * Receive data which we will store. + * Get data which we will store. */ - error_or send(const data& dat, id store_id){ - auto eoval = device_->copy_to_device(dat); + template + error_or send(const data& dat, id store_id){ + auto& vals = std::get(values_); + auto eoval = device_->template copy_to_device(dat); if(eoval.is_error()){ auto& err = eoval.get_error(); return std::move(err); } - return make_error(); + auto& val = eoval.get_value(); + try { + auto insert_res = vals.insert(std::make_pair(store_id.get_value(), std::move(val))); + if(!insert_res.second){ + return make_error(); + } + }catch ( std::exception& ){ + return make_error(); + } + return void_t{}; + } + + /** + * Requests data from the server + */ + template + error_or> receive(id store_id){ + auto& vals = std::get(values_); + auto find_res = vals.find(store_id.get_value()); + if(find_res == vals.end()){ + return make_error(); + } + auto& dat = find_res->second; + + auto eoval = device_->copy_to_host(dat); + return eoval; + } + + /** + * Request an erase of the stored data + */ + template + error_or erase(id store_id){ + auto& vals = std::get(values_); + auto erase_op = vals.erase(store_id.get_value()); + if(erase_op == 0u){ + return make_error(); + } + return void_t{}; } - error_or> receive(id store_id){ - return make_error(); + /** + * Get the stored data on the server side for immediate use. + * Insert operations may invalidate the pointer. + */ + template + error_or*> find(id store_id){ + auto& vals = std::get(values_); + auto find_res = vals.find(store_id.get_value()); + if(find_res == vals.end()){ + return make_error(); + } + + return &(find_res.second); } }; -template -class data_client { +/** + * Client for transporting data to remote and receiving data back + */ +template +class data_client, Encoding, rmt::Sycl> { private: /** * Corresponding server for this client */ - data_server* srv_; + data_server, Encoding, rmt::Sycl>* srv_; /** * The next id for identifying issues on the remote side. @@ -58,16 +127,17 @@ public: /** * Main constructor */ - data_client(data_server& srv__): + data_client(data_server, Encoding, rmt::Sycl>& srv__): srv_{&srv__}, next_id_{0u} {} /** - * Send data to. + * Send data to the remote. */ - error_or> send(const data& dat){ - id dat_id{next_id_}; + template + error_or> send(const data& dat){ + id dat_id{next_id_}; auto eov = srv_->send(dat, dat_id); if(eov.is_error()){ auto& err = eov.get_error(); @@ -77,5 +147,21 @@ public: ++next_id_; return dat_id; } + + /** + * Receive data + */ + template + conveyor> receive(id dat_id){ + return srv_->receive(dat_id); + } + + /** + * Erase data + */ + template + error_or erase(id dat_id){ + return srv_->erase(dat_id); + } }; } -- cgit v1.2.3