diff options
author | Claudius "keldu" Holeksa <mail@keldu.de> | 2024-06-26 09:39:34 +0200 |
---|---|---|
committer | Claudius "keldu" Holeksa <mail@keldu.de> | 2024-06-26 09:39:34 +0200 |
commit | 729307460e77f62a532ee9841dcaed9c47f46419 (patch) | |
tree | 0b52ddbfa47d9d148907de90e7a2987d72ed7d73 /modules/remote-sycl/c++/transfer.hpp | |
parent | 51b50882d2906b83c5275c732a56ff333ae6696f (diff) |
Added better structure for the data server
Diffstat (limited to 'modules/remote-sycl/c++/transfer.hpp')
-rw-r--r-- | modules/remote-sycl/c++/transfer.hpp | 118 |
1 files changed, 102 insertions, 16 deletions
diff --git a/modules/remote-sycl/c++/transfer.hpp b/modules/remote-sycl/c++/transfer.hpp index 6849caa..65a9b9e 100644 --- a/modules/remote-sycl/c++/transfer.hpp +++ b/modules/remote-sycl/c++/transfer.hpp @@ -2,11 +2,26 @@ #include "common.hpp" #include "data.hpp" +#include "device.hpp" + #include <forstio/error.hpp> +#include <forstio/reduce_templates.hpp> namespace saw { -template<typename Schema, typename Encoding> -class data_server<Schema, Encoding, rmt::Sycl> { +namespace impl { +template<typename Encoding, typename T> +struct data_server_redux { + using type = std::tuple<>; +}; + +template<typename Encoding, typename... Schema> +struct data_server_redux<Encoding, tmpl_group<Schema...>> { + using type = std::tuple<std::unordered_map<uint64_t, data<Schema, Encoding, rmt::Sycl>>...>; +}; +} + +template<typename... Schema, typename Encoding> +class data_server<tmpl_group<Schema...>, Encoding, rmt::Sycl> { private: /** * Device context class @@ -16,7 +31,7 @@ private: /** * Store for the data the server manages. */ - std::unordered_map<uint64_t, data<Schema, Encoding, rmt::Sycl>> values_; + impl::data_server_redux<Encoding, typename tmpl_reduce<tmpl_group<Schema...>>::type >::type values_; public: /** * Main constructor @@ -26,29 +41,83 @@ public: {} /** - * Receive data which we will store. + * Get data which we will store. */ - error_or<void> send(const data<Schema, Encoding, storage::Default>& dat, id<Schema> store_id){ - auto eoval = device_->copy_to_device(dat); + template<typename Sch> + error_or<void> send(const data<Sch, Encoding, storage::Default>& dat, id<Sch> store_id){ + auto& vals = std::get<Sch>(values_); + auto eoval = device_->template copy_to_device<Sch, Encoding, storage::Default>(dat); if(eoval.is_error()){ auto& err = eoval.get_error(); return std::move(err); } - return make_error<err::not_implemented>(); + auto& val = eoval.get_value(); + try { + auto insert_res = vals.insert(std::make_pair(store_id.get_value(), std::move(val))); + if(!insert_res.second){ + return make_error<err::already_exists>(); + } + }catch ( std::exception& ){ + return make_error<err::out_of_memory>(); + } + return void_t{}; + } + + /** + * Requests data from the server + */ + template<typename Sch> + error_or<data<Sch, Encoding, storage::Default>> receive(id<Sch> store_id){ + auto& vals = std::get<Sch>(values_); + auto find_res = vals.find(store_id.get_value()); + if(find_res == vals.end()){ + return make_error<err::not_found>(); + } + auto& dat = find_res->second; + + auto eoval = device_->copy_to_host(dat); + return eoval; + } + + /** + * Request an erase of the stored data + */ + template<typename Sch> + error_or<void> erase(id<Sch> store_id){ + auto& vals = std::get<Sch>(values_); + auto erase_op = vals.erase(store_id.get_value()); + if(erase_op == 0u){ + return make_error<err::not_found>(); + } + return void_t{}; } - error_or<data<Schema, Encoding, storage::Default>> receive(id<Schema> store_id){ - return make_error<err::not_implemented>(); + /** + * Get the stored data on the server side for immediate use. + * Insert operations may invalidate the pointer. + */ + template<typename Sch> + error_or<data<Sch, Encoding, rmt::Sycl>*> find(id<Sch> store_id){ + auto& vals = std::get<Sch>(values_); + auto find_res = vals.find(store_id.get_value()); + if(find_res == vals.end()){ + return make_error<err::not_found>(); + } + + return &(find_res.second); } }; -template<typename Schema, typename Encoding> -class data_client<Schema, Encoding, rmt::Sycl> { +/** + * Client for transporting data to remote and receiving data back + */ +template<typename... Schema, typename Encoding> +class data_client<tmpl_group<Schema...>, Encoding, rmt::Sycl> { private: /** * Corresponding server for this client */ - data_server<Schema, Encoding, rmt::Sycl>* srv_; + data_server<tmpl_group<Schema...>, Encoding, rmt::Sycl>* srv_; /** * The next id for identifying issues on the remote side. @@ -58,16 +127,17 @@ public: /** * Main constructor */ - data_client(data_server<Schema, Encoding, rmt::Sycl>& srv__): + data_client(data_server<tmpl_group<Schema...>, Encoding, rmt::Sycl>& srv__): srv_{&srv__}, next_id_{0u} {} /** - * Send data to. + * Send data to the remote. */ - error_or<id<Schema>> send(const data<Schema, Encoding, storage::Default>& dat){ - id<Schema> dat_id{next_id_}; + template<typename Sch> + error_or<id<Sch>> send(const data<Sch, Encoding, storage::Default>& dat){ + id<Sch> dat_id{next_id_}; auto eov = srv_->send(dat, dat_id); if(eov.is_error()){ auto& err = eov.get_error(); @@ -77,5 +147,21 @@ public: ++next_id_; return dat_id; } + + /** + * Receive data + */ + template<typename Sch> + conveyor<data<Sch, Encoding, storage::Default>> receive(id<Sch> dat_id){ + return srv_->receive(dat_id); + } + + /** + * Erase data + */ + template<typename Sch> + error_or<void> erase(id<Sch> dat_id){ + return srv_->erase(dat_id); + } }; } |