diff options
Diffstat (limited to 'modules/remote-sycl/c++')
-rw-r--r-- | modules/remote-sycl/c++/data.hpp | 2 | ||||
-rw-r--r-- | modules/remote-sycl/c++/device.hpp | 4 | ||||
-rw-r--r-- | modules/remote-sycl/c++/remote.hpp | 11 | ||||
-rw-r--r-- | modules/remote-sycl/c++/transfer.hpp | 15 |
4 files changed, 20 insertions, 12 deletions
diff --git a/modules/remote-sycl/c++/data.hpp b/modules/remote-sycl/c++/data.hpp index 3ad1d9c..42bc3b1 100644 --- a/modules/remote-sycl/c++/data.hpp +++ b/modules/remote-sycl/c++/data.hpp @@ -13,7 +13,7 @@ class data<Schema, encode::Native, rmt::Sycl> { private: cl::sycl::buffer<data<Schema, encode::Native, storage::Default>> data_; public: - data(data<Schema, encode::Native, storage::Default>& data__): + data(const data<Schema, encode::Native, storage::Default>& data__): data_{&data__, 1u} {} diff --git a/modules/remote-sycl/c++/device.hpp b/modules/remote-sycl/c++/device.hpp index 6d133ae..46644b4 100644 --- a/modules/remote-sycl/c++/device.hpp +++ b/modules/remote-sycl/c++/device.hpp @@ -21,7 +21,7 @@ public: */ template<typename Schema, typename Encoding, typename Storage> error_or<data<Schema, Encoding, rmt::Sycl>> copy_to_device(const data<Schema, Encoding, Storage>& host_data){ - return data<Schema, Encoding, rmt::Sycl>::copy_to_device(host_data, *this); + return data<Schema, Encoding, rmt::Sycl>{host_data}; } /** @@ -29,7 +29,7 @@ public: */ template<typename Schema, typename Encoding, typename Storage> error_or<data<Schema, Encoding, Storage>> copy_to_host(const data<Schema, Encoding, rmt::Sycl>& dev_data){ - return dev_data.copy_to_host(); + return {}; } /** diff --git a/modules/remote-sycl/c++/remote.hpp b/modules/remote-sycl/c++/remote.hpp index 1ae3103..7e77ec9 100644 --- a/modules/remote-sycl/c++/remote.hpp +++ b/modules/remote-sycl/c++/remote.hpp @@ -170,10 +170,11 @@ private: */ device<rmt::Sycl>* device_; + using DataServerT = data_server<typename impl::rpc_iface_type_helper<Iface>::type, Encoding, rmt::Sycl>; /** * Data server storing the relevant data */ - data_server<typename impl::rpc_iface_type_helper<Iface>::type, Encoding, rmt::Sycl> data_server_; + DataServerT* data_server_; /** * The interface including the relevant context class. @@ -185,9 +186,9 @@ public: /** * Main constructor */ - rpc_server(device<rmt::Sycl>& dev__, InterfaceT cl_iface): + rpc_server(device<rmt::Sycl>& dev__, DataServerT& data_server__, InterfaceT cl_iface): device_{&dev__}, - data_server_{}, + data_server_{&data_server__}, cl_interface_{std::move(cl_iface)} {} @@ -230,7 +231,7 @@ public: auto eoinp = [&,this]() -> error_or<data<typename FuncT::RequestT, Encoding, rmt::Sycl>* > { if(input.is_id()){ // storage_.maps - auto eov = data_server_.template find<typename FuncT::RequestT>(input.get_id()); + auto eov = data_server_->template find<typename FuncT::RequestT>(input.get_id()); if(eov.is_error()){ return std::move(eov.get_error()); } @@ -264,7 +265,7 @@ public: /** * Store returned data in rpc storage */ - auto eoid = data_server_.template insert<typename schema_member_type<Name, Iface>::type::RequestT>(std::move(val), rpc_id); + auto eoid = data_server_->template insert<typename schema_member_type<Name, Iface>::type::RequestT>(std::move(val), rpc_id); if(eoid.is_error()){ return std::move(eoid.get_error()); } diff --git a/modules/remote-sycl/c++/transfer.hpp b/modules/remote-sycl/c++/transfer.hpp index 65a9b9e..8987de9 100644 --- a/modules/remote-sycl/c++/transfer.hpp +++ b/modules/remote-sycl/c++/transfer.hpp @@ -45,7 +45,7 @@ public: */ template<typename Sch> error_or<void> send(const data<Sch, Encoding, storage::Default>& dat, id<Sch> store_id){ - auto& vals = std::get<Sch>(values_); + auto& vals = std::get<std::unordered_map<uint64_t, data<Sch,Encoding,rmt::Sycl>>>(values_); auto eoval = device_->template copy_to_device<Sch, Encoding, storage::Default>(dat); if(eoval.is_error()){ auto& err = eoval.get_error(); @@ -68,14 +68,14 @@ public: */ template<typename Sch> error_or<data<Sch, Encoding, storage::Default>> receive(id<Sch> store_id){ - auto& vals = std::get<Sch>(values_); + auto& vals = std::get<std::unordered_map<uint64_t, data<Sch,Encoding,rmt::Sycl>>>(values_); auto find_res = vals.find(store_id.get_value()); if(find_res == vals.end()){ return make_error<err::not_found>(); } auto& dat = find_res->second; - auto eoval = device_->copy_to_host(dat); + auto eoval = device_->template copy_to_host<Sch, Encoding, storage::Default>(dat); return eoval; } @@ -153,7 +153,14 @@ public: */ template<typename Sch> conveyor<data<Sch, Encoding, storage::Default>> receive(id<Sch> dat_id){ - return srv_->receive(dat_id); + auto eov = srv_->receive(dat_id); + if(eov.is_error()){ + auto& err = eov.get_error(); + return std::move(err); + } + + auto& val = eov.get_value(); + return std::move(val); } /** |