summaryrefslogtreecommitdiff
path: root/modules/remote-hip/c++/transfer.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'modules/remote-hip/c++/transfer.hpp')
-rw-r--r--modules/remote-hip/c++/transfer.hpp238
1 files changed, 238 insertions, 0 deletions
diff --git a/modules/remote-hip/c++/transfer.hpp b/modules/remote-hip/c++/transfer.hpp
new file mode 100644
index 0000000..8c2cc02
--- /dev/null
+++ b/modules/remote-hip/c++/transfer.hpp
@@ -0,0 +1,238 @@
+#pragma once
+
+#include "common.hpp"
+#include "data.hpp"
+#include "device.hpp"
+
+#include <forstio/error.hpp>
+#include <forstio/reduce_templates.hpp>
+#include <forstio/remote/transfer.hpp>
+
+namespace saw {
+
+template<typename Schema, typename Encoding>
+class data_server<Schema, Encoding, rmt::Hip> final : public i_data_server<rmt::Hip> {
+private:
+ our<device<rmt::Hip>> device_;
+
+ std::map<uint64_t, data<Schema, encode::Sycl<Encoding>>> values_;
+public:
+ data_server(our<device<rmt::Hip>> device__):
+ device_{std::move(device__)}
+ {}
+
+ error_or<void> send(const data<Schema,Encoding>& dat, id<Schema> store_id){
+ auto eo_val = device_->template copy_to_device(dat);
+ if(eo_val.is_error()){
+ auto& err = eo_val.get_error();
+ return std::move(err);
+ }
+ auto& val = eo_val.get_value();
+
+ try {
+ auto insert_res = values_.emplace(std::make_pair(store_id.get_value(), std::move(val)));
+ if(!insert_res.second){
+ return make_error<err::already_exists>();
+ }
+ }catch(const std::exception&){
+ return make_error<err::out_of_memory>();
+ }
+ return make_void();
+ }
+
+ error_or<void> allocate(const data<typename meta_schema<Schema>::MetaSchema, Encoding>& dat, id<Schema> store_id){
+ return make_error<err::not_implemented>("Allocate not implemented");
+ return make_void();
+ }
+
+ error_or<data<Schema,Encoding>> receive(id<Schema> store_id){
+ return make_error<err::not_implemented>("Receive not implemented");
+ }
+
+ error_or<void> erase(id<Schema> store_id){
+ return make_error<err::not_implemented>("Erase not implemented");
+ return make_void();
+ }
+
+ error_or<ptr<data<Schema, encode::Sycl<Encoding>>>> find(id<Schema> store_id){
+ auto find_res = values_.find(store_id.get_value());
+ if(find_res == values_.end()){
+ return make_error<err::not_found>();
+ }
+
+ return {(find_res.second)};
+ }
+};
+
+template<typename... Schema, typename Encoding>
+class data_server<tmpl_group<Schema...>, Encoding, rmt::Hip> {
+private:
+ /**
+ * Device context class
+ */
+ our<device<rmt::Hip>> device_;
+
+ /**
+ * Store for the data the server manages.
+ */
+ typename impl::data_server_redux<encode::Sycl<Encoding>, typename tmpl_reduce<tmpl_group<Schema...>>::type >::type values_;
+public:
+ /**
+ * Main constructor
+ */
+ data_server(our<device<rmt::Hip>> device__):
+ device_{std::move(device__)}
+ {}
+
+ /**
+ * Get data which we will store.
+ */
+ template<typename Sch>
+ error_or<void> send(const data<Sch, Encoding>& dat, id<Sch> store_id){
+ auto& vals = std::get<std::unordered_map<uint64_t, data<Sch,encode::Sycl<Encoding>>>>(values_);
+ auto eoval = device_->template copy_to_device<Sch, Encoding>(dat);
+ if(eoval.is_error()){
+ auto& err = eoval.get_error();
+ return std::move(err);
+ }
+ auto& val = eoval.get_value();
+ try {
+ auto insert_res = vals.insert(std::make_pair(store_id.get_value(), std::move(val)));
+ if(!insert_res.second){
+ return make_error<err::already_exists>();
+ }
+ }catch ( std::exception& ){
+ return make_error<err::out_of_memory>();
+ }
+ return void_t{};
+ }
+
+ template<typename Sch>
+ error_or<void> allocate(const data<typename meta_schema<Sch>::MetaSchema, Encoding>& dat, id<Sch> store_id){
+ auto& vals = std::get<std::unordered_map<uint64_t, data<Sch,encode::Sycl<Encoding>>>>(values_);
+ auto eoval = device_->template allocate_on_device<Sch, Encoding>(dat);
+ if(eoval.is_error()){
+ auto& err = eoval.get_error();
+ return std::move(err);
+ }
+ auto& val = eoval.get_value();
+ try {
+ auto insert_res = vals.insert(std::make_pair(store_id.get_value(), std::move(val)));
+ if(!insert_res.second){
+ return make_error<err::already_exists>();
+ }
+ }catch ( std::exception& ){
+ return make_error<err::out_of_memory>();
+ }
+ return void_t{};
+ }
+
+ /**
+ * Requests data from the server
+ */
+ template<typename Sch>
+ error_or<data<Sch, Encoding>> receive(id<Sch> store_id){
+ auto& vals = std::get<std::unordered_map<uint64_t, data<Sch,encode::Sycl<Encoding>>>>(values_);
+ auto find_res = vals.find(store_id.get_value());
+ if(find_res == vals.end()){
+ return make_error<err::not_found>();
+ }
+ auto& dat = find_res->second;
+
+ auto eoval = device_->template copy_to_host<Sch, Encoding>(dat);
+ return eoval;
+ }
+
+ /**
+ * Request an erase of the stored data
+ */
+ template<typename Sch>
+ error_or<void> erase(id<Sch> store_id){
+ auto& vals = std::get<std::unordered_map<uint64_t, data<Sch,Encoding>>>(values_);
+ auto erase_op = vals.erase(store_id.get_value());
+ if(erase_op == 0u){
+ return make_error<err::not_found>();
+ }
+ return void_t{};
+ }
+
+ /**
+ * Get the stored data on the server side for immediate use.
+ * Insert operations may invalidate the pointer.
+ */
+ template<typename Sch>
+ error_or<data<Sch, encode::Sycl<Encoding>>*> find(id<Sch> store_id){
+ auto& vals = std::get<std::unordered_map<uint64_t, data<Sch,Encoding>>>(values_);
+ auto find_res = vals.find(store_id.get_value());
+ if(find_res == vals.end()){
+ return make_error<err::not_found>();
+ }
+
+ return &(find_res.second);
+ }
+};
+
+/**
+ * Client for transporting data to remote and receiving data back
+ */
+template<typename... Schema, typename Encoding>
+class data_client<tmpl_group<Schema...>, Encoding, rmt::Hip> {
+private:
+ /**
+ * Corresponding server for this client
+ */
+ data_server<tmpl_group<Schema...>, Encoding, rmt::Hip>* srv_;
+
+ /**
+ * The next id for identifying issues on the remote side.
+ */
+ uint64_t next_id_;
+public:
+ /**
+ * Main constructor
+ */
+ data_client(data_server<tmpl_group<Schema...>, Encoding, rmt::Hip>& srv__):
+ srv_{&srv__},
+ next_id_{0u}
+ {}
+
+ /**
+ * Send data to the remote.
+ */
+ template<typename Sch>
+ error_or<id<Sch>> send(const data<Sch, Encoding>& dat){
+ id<Sch> dat_id{next_id_};
+ auto eov = srv_->send(dat, dat_id);
+ if(eov.is_error()){
+ auto& err = eov.get_error();
+ return std::move(err);
+ }
+
+ ++next_id_;
+ return dat_id;
+ }
+
+ /**
+ * Receive data
+ */
+ template<typename Sch>
+ conveyor<data<Sch, Encoding>> receive(id<Sch> dat_id){
+ auto eov = srv_->receive(dat_id);
+ if(eov.is_error()){
+ auto& err = eov.get_error();
+ return std::move(err);
+ }
+
+ auto& val = eov.get_value();
+ return std::move(val);
+ }
+
+ /**
+ * Erase data
+ */
+ template<typename Sch>
+ error_or<void> erase(id<Sch> dat_id){
+ return srv_->erase(dat_id);
+ }
+};
+}