summaryrefslogtreecommitdiff
path: root/modules/remote-sycl/c++/transfer.hpp
diff options
context:
space:
mode:
authorClaudius "keldu" Holeksa <mail@keldu.de>2024-06-26 09:39:34 +0200
committerClaudius "keldu" Holeksa <mail@keldu.de>2024-06-26 09:39:34 +0200
commit729307460e77f62a532ee9841dcaed9c47f46419 (patch)
tree0b52ddbfa47d9d148907de90e7a2987d72ed7d73 /modules/remote-sycl/c++/transfer.hpp
parent51b50882d2906b83c5275c732a56ff333ae6696f (diff)
Added better structure for the data server
Diffstat (limited to 'modules/remote-sycl/c++/transfer.hpp')
-rw-r--r--modules/remote-sycl/c++/transfer.hpp118
1 files changed, 102 insertions, 16 deletions
diff --git a/modules/remote-sycl/c++/transfer.hpp b/modules/remote-sycl/c++/transfer.hpp
index 6849caa..65a9b9e 100644
--- a/modules/remote-sycl/c++/transfer.hpp
+++ b/modules/remote-sycl/c++/transfer.hpp
@@ -2,11 +2,26 @@
#include "common.hpp"
#include "data.hpp"
+#include "device.hpp"
+
#include <forstio/error.hpp>
+#include <forstio/reduce_templates.hpp>
namespace saw {
-template<typename Schema, typename Encoding>
-class data_server<Schema, Encoding, rmt::Sycl> {
+namespace impl {
+template<typename Encoding, typename T>
+struct data_server_redux {
+ using type = std::tuple<>;
+};
+
+template<typename Encoding, typename... Schema>
+struct data_server_redux<Encoding, tmpl_group<Schema...>> {
+ using type = std::tuple<std::unordered_map<uint64_t, data<Schema, Encoding, rmt::Sycl>>...>;
+};
+}
+
+template<typename... Schema, typename Encoding>
+class data_server<tmpl_group<Schema...>, Encoding, rmt::Sycl> {
private:
/**
* Device context class
@@ -16,7 +31,7 @@ private:
/**
* Store for the data the server manages.
*/
- std::unordered_map<uint64_t, data<Schema, Encoding, rmt::Sycl>> values_;
+ impl::data_server_redux<Encoding, typename tmpl_reduce<tmpl_group<Schema...>>::type >::type values_;
public:
/**
* Main constructor
@@ -26,29 +41,83 @@ public:
{}
/**
- * Receive data which we will store.
+ * Get data which we will store.
*/
- error_or<void> send(const data<Schema, Encoding, storage::Default>& dat, id<Schema> store_id){
- auto eoval = device_->copy_to_device(dat);
+ template<typename Sch>
+ error_or<void> send(const data<Sch, Encoding, storage::Default>& dat, id<Sch> store_id){
+ auto& vals = std::get<Sch>(values_);
+ auto eoval = device_->template copy_to_device<Sch, Encoding, storage::Default>(dat);
if(eoval.is_error()){
auto& err = eoval.get_error();
return std::move(err);
}
- return make_error<err::not_implemented>();
+ auto& val = eoval.get_value();
+ try {
+ auto insert_res = vals.insert(std::make_pair(store_id.get_value(), std::move(val)));
+ if(!insert_res.second){
+ return make_error<err::already_exists>();
+ }
+ }catch ( std::exception& ){
+ return make_error<err::out_of_memory>();
+ }
+ return void_t{};
+ }
+
+ /**
+ * Requests data from the server
+ */
+ template<typename Sch>
+ error_or<data<Sch, Encoding, storage::Default>> receive(id<Sch> store_id){
+ auto& vals = std::get<Sch>(values_);
+ auto find_res = vals.find(store_id.get_value());
+ if(find_res == vals.end()){
+ return make_error<err::not_found>();
+ }
+ auto& dat = find_res->second;
+
+ auto eoval = device_->copy_to_host(dat);
+ return eoval;
+ }
+
+ /**
+ * Request an erase of the stored data
+ */
+ template<typename Sch>
+ error_or<void> erase(id<Sch> store_id){
+ auto& vals = std::get<Sch>(values_);
+ auto erase_op = vals.erase(store_id.get_value());
+ if(erase_op == 0u){
+ return make_error<err::not_found>();
+ }
+ return void_t{};
}
- error_or<data<Schema, Encoding, storage::Default>> receive(id<Schema> store_id){
- return make_error<err::not_implemented>();
+ /**
+ * Get the stored data on the server side for immediate use.
+ * Insert operations may invalidate the pointer.
+ */
+ template<typename Sch>
+ error_or<data<Sch, Encoding, rmt::Sycl>*> find(id<Sch> store_id){
+ auto& vals = std::get<Sch>(values_);
+ auto find_res = vals.find(store_id.get_value());
+ if(find_res == vals.end()){
+ return make_error<err::not_found>();
+ }
+
+ return &(find_res.second);
}
};
-template<typename Schema, typename Encoding>
-class data_client<Schema, Encoding, rmt::Sycl> {
+/**
+ * Client for transporting data to remote and receiving data back
+ */
+template<typename... Schema, typename Encoding>
+class data_client<tmpl_group<Schema...>, Encoding, rmt::Sycl> {
private:
/**
* Corresponding server for this client
*/
- data_server<Schema, Encoding, rmt::Sycl>* srv_;
+ data_server<tmpl_group<Schema...>, Encoding, rmt::Sycl>* srv_;
/**
* The next id for identifying issues on the remote side.
@@ -58,16 +127,17 @@ public:
/**
* Main constructor
*/
- data_client(data_server<Schema, Encoding, rmt::Sycl>& srv__):
+ data_client(data_server<tmpl_group<Schema...>, Encoding, rmt::Sycl>& srv__):
srv_{&srv__},
next_id_{0u}
{}
/**
- * Send data to.
+ * Send data to the remote.
*/
- error_or<id<Schema>> send(const data<Schema, Encoding, storage::Default>& dat){
- id<Schema> dat_id{next_id_};
+ template<typename Sch>
+ error_or<id<Sch>> send(const data<Sch, Encoding, storage::Default>& dat){
+ id<Sch> dat_id{next_id_};
auto eov = srv_->send(dat, dat_id);
if(eov.is_error()){
auto& err = eov.get_error();
@@ -77,5 +147,21 @@ public:
++next_id_;
return dat_id;
}
+
+ /**
+ * Receive data
+ */
+ template<typename Sch>
+ conveyor<data<Sch, Encoding, storage::Default>> receive(id<Sch> dat_id){
+ return srv_->receive(dat_id);
+ }
+
+ /**
+ * Erase data
+ */
+ template<typename Sch>
+ error_or<void> erase(id<Sch> dat_id){
+ return srv_->erase(dat_id);
+ }
};
}