diff options
Diffstat (limited to 'modules/remote-hip/c++/transfer.hpp')
-rw-r--r-- | modules/remote-hip/c++/transfer.hpp | 238 |
1 files changed, 238 insertions, 0 deletions
diff --git a/modules/remote-hip/c++/transfer.hpp b/modules/remote-hip/c++/transfer.hpp new file mode 100644 index 0000000..8c2cc02 --- /dev/null +++ b/modules/remote-hip/c++/transfer.hpp @@ -0,0 +1,238 @@ +#pragma once + +#include "common.hpp" +#include "data.hpp" +#include "device.hpp" + +#include <forstio/error.hpp> +#include <forstio/reduce_templates.hpp> +#include <forstio/remote/transfer.hpp> + +namespace saw { + +template<typename Schema, typename Encoding> +class data_server<Schema, Encoding, rmt::Hip> final : public i_data_server<rmt::Hip> { +private: + our<device<rmt::Hip>> device_; + + std::map<uint64_t, data<Schema, encode::Sycl<Encoding>>> values_; +public: + data_server(our<device<rmt::Hip>> device__): + device_{std::move(device__)} + {} + + error_or<void> send(const data<Schema,Encoding>& dat, id<Schema> store_id){ + auto eo_val = device_->template copy_to_device(dat); + if(eo_val.is_error()){ + auto& err = eo_val.get_error(); + return std::move(err); + } + auto& val = eo_val.get_value(); + + try { + auto insert_res = values_.emplace(std::make_pair(store_id.get_value(), std::move(val))); + if(!insert_res.second){ + return make_error<err::already_exists>(); + } + }catch(const std::exception&){ + return make_error<err::out_of_memory>(); + } + return make_void(); + } + + error_or<void> allocate(const data<typename meta_schema<Schema>::MetaSchema, Encoding>& dat, id<Schema> store_id){ + return make_error<err::not_implemented>("Allocate not implemented"); + return make_void(); + } + + error_or<data<Schema,Encoding>> receive(id<Schema> store_id){ + return make_error<err::not_implemented>("Receive not implemented"); + } + + error_or<void> erase(id<Schema> store_id){ + return make_error<err::not_implemented>("Erase not implemented"); + return make_void(); + } + + error_or<ptr<data<Schema, encode::Sycl<Encoding>>>> find(id<Schema> store_id){ + auto find_res = values_.find(store_id.get_value()); + if(find_res == values_.end()){ + return make_error<err::not_found>(); + } + + return {(find_res.second)}; + } +}; + +template<typename... Schema, typename Encoding> +class data_server<tmpl_group<Schema...>, Encoding, rmt::Hip> { +private: + /** + * Device context class + */ + our<device<rmt::Hip>> device_; + + /** + * Store for the data the server manages. + */ + typename impl::data_server_redux<encode::Sycl<Encoding>, typename tmpl_reduce<tmpl_group<Schema...>>::type >::type values_; +public: + /** + * Main constructor + */ + data_server(our<device<rmt::Hip>> device__): + device_{std::move(device__)} + {} + + /** + * Get data which we will store. + */ + template<typename Sch> + error_or<void> send(const data<Sch, Encoding>& dat, id<Sch> store_id){ + auto& vals = std::get<std::unordered_map<uint64_t, data<Sch,encode::Sycl<Encoding>>>>(values_); + auto eoval = device_->template copy_to_device<Sch, Encoding>(dat); + if(eoval.is_error()){ + auto& err = eoval.get_error(); + return std::move(err); + } + auto& val = eoval.get_value(); + try { + auto insert_res = vals.insert(std::make_pair(store_id.get_value(), std::move(val))); + if(!insert_res.second){ + return make_error<err::already_exists>(); + } + }catch ( std::exception& ){ + return make_error<err::out_of_memory>(); + } + return void_t{}; + } + + template<typename Sch> + error_or<void> allocate(const data<typename meta_schema<Sch>::MetaSchema, Encoding>& dat, id<Sch> store_id){ + auto& vals = std::get<std::unordered_map<uint64_t, data<Sch,encode::Sycl<Encoding>>>>(values_); + auto eoval = device_->template allocate_on_device<Sch, Encoding>(dat); + if(eoval.is_error()){ + auto& err = eoval.get_error(); + return std::move(err); + } + auto& val = eoval.get_value(); + try { + auto insert_res = vals.insert(std::make_pair(store_id.get_value(), std::move(val))); + if(!insert_res.second){ + return make_error<err::already_exists>(); + } + }catch ( std::exception& ){ + return make_error<err::out_of_memory>(); + } + return void_t{}; + } + + /** + * Requests data from the server + */ + template<typename Sch> + error_or<data<Sch, Encoding>> receive(id<Sch> store_id){ + auto& vals = std::get<std::unordered_map<uint64_t, data<Sch,encode::Sycl<Encoding>>>>(values_); + auto find_res = vals.find(store_id.get_value()); + if(find_res == vals.end()){ + return make_error<err::not_found>(); + } + auto& dat = find_res->second; + + auto eoval = device_->template copy_to_host<Sch, Encoding>(dat); + return eoval; + } + + /** + * Request an erase of the stored data + */ + template<typename Sch> + error_or<void> erase(id<Sch> store_id){ + auto& vals = std::get<std::unordered_map<uint64_t, data<Sch,Encoding>>>(values_); + auto erase_op = vals.erase(store_id.get_value()); + if(erase_op == 0u){ + return make_error<err::not_found>(); + } + return void_t{}; + } + + /** + * Get the stored data on the server side for immediate use. + * Insert operations may invalidate the pointer. + */ + template<typename Sch> + error_or<data<Sch, encode::Sycl<Encoding>>*> find(id<Sch> store_id){ + auto& vals = std::get<std::unordered_map<uint64_t, data<Sch,Encoding>>>(values_); + auto find_res = vals.find(store_id.get_value()); + if(find_res == vals.end()){ + return make_error<err::not_found>(); + } + + return &(find_res.second); + } +}; + +/** + * Client for transporting data to remote and receiving data back + */ +template<typename... Schema, typename Encoding> +class data_client<tmpl_group<Schema...>, Encoding, rmt::Hip> { +private: + /** + * Corresponding server for this client + */ + data_server<tmpl_group<Schema...>, Encoding, rmt::Hip>* srv_; + + /** + * The next id for identifying issues on the remote side. + */ + uint64_t next_id_; +public: + /** + * Main constructor + */ + data_client(data_server<tmpl_group<Schema...>, Encoding, rmt::Hip>& srv__): + srv_{&srv__}, + next_id_{0u} + {} + + /** + * Send data to the remote. + */ + template<typename Sch> + error_or<id<Sch>> send(const data<Sch, Encoding>& dat){ + id<Sch> dat_id{next_id_}; + auto eov = srv_->send(dat, dat_id); + if(eov.is_error()){ + auto& err = eov.get_error(); + return std::move(err); + } + + ++next_id_; + return dat_id; + } + + /** + * Receive data + */ + template<typename Sch> + conveyor<data<Sch, Encoding>> receive(id<Sch> dat_id){ + auto eov = srv_->receive(dat_id); + if(eov.is_error()){ + auto& err = eov.get_error(); + return std::move(err); + } + + auto& val = eov.get_value(); + return std::move(val); + } + + /** + * Erase data + */ + template<typename Sch> + error_or<void> erase(id<Sch> dat_id){ + return srv_->erase(dat_id); + } +}; +} |