From 57f6eacfcdbdba31185eb66b9a573a8923eecf16 Mon Sep 17 00:00:00 2001 From: "Claudius \"keldu\" Holeksa" Date: Thu, 13 Jun 2024 17:34:22 +0200 Subject: Possible fix for transferring primitives to device without dropping STL --- modules/remote-sycl/c++/remote.hpp | 83 ++++++++++++++++++++++++++------------ 1 file changed, 58 insertions(+), 25 deletions(-) (limited to 'modules/remote-sycl/c++/remote.hpp') diff --git a/modules/remote-sycl/c++/remote.hpp b/modules/remote-sycl/c++/remote.hpp index 677a427..bcc8a3c 100644 --- a/modules/remote-sycl/c++/remote.hpp +++ b/modules/remote-sycl/c++/remote.hpp @@ -22,15 +22,24 @@ class remote_data { private: id id_; id_map* map_; + cl::sycl::queue* queue_; public: /** * Main constructor */ - remote_data(const id& id, id_map& map): + remote_data(const id& id, id_map& map, cl::sycl::queue& queue__): id_{id}, - map_{&map} + map_{&map}, + queue_{&queue__} {} + /** + * Wait for the data + */ + error_or> wait(){ + + } + /** * Request data asynchronously */ @@ -39,21 +48,14 @@ public: /** * - */ template class data, encode::Native, rmt::Sycl> { public: using Schema = schema::Primitive; using NativeType = typename native_data_type::type; private: - /** - * - */ NativeType val_; public: - /** - * - */ data(NativeType val__): val_{val__} {} @@ -62,6 +64,7 @@ public: return val_; } }; + */ template class data, encode::Native, rmt::Sycl> { @@ -69,8 +72,8 @@ public: using Schema = schema::Array; private: uint64_t total_length_; - typename native_data_type::type* device_data_; - // data* device_data_; + // typename native_data_type::type* device_data_; + data* device_data_; cl::sycl::queue* queue_; static_assert(is_primitive::value, "Only supports primitives for now"); @@ -78,7 +81,8 @@ private: public: data(uint64_t size, cl::sycl::queue& q__): total_length_{size}, - device_data_{cl::sycl::malloc_device::type>(size, q__)}, + device_data_{cl::sycl::malloc_device>(size, q__)}, + //device_data_{cl::sycl::malloc_device::type>(size, q__)}, queue_{&q__} { if(!device_data_){ @@ -89,12 +93,15 @@ public: template data(const data& from, cl::sycl::queue& q__): total_length_{from.size()}, - device_data_{cl::sycl::malloc_device::type>(from.size(), q__)}, + device_data_{cl::sycl::malloc_device>(from.size(), q__)}, + //device_data_{cl::sycl::malloc_device::type>(from.size(), q__)}, queue_{&q__} { if(!device_data_){ total_length_ = 0u; + return; } + queue_->template copy>(&from.at(0), device_data_, total_length_); } data(const data& from): @@ -105,11 +112,23 @@ public: if(total_length_ == 0u || !queue_){ return; } - device_data_ = cl::sycl::malloc_device::type>(from.size(), *queue_); + device_data_ = cl::sycl::malloc_device>(from.size(), *queue_); + // device_data_ = cl::sycl::malloc_device::type>(from.size(), *queue_); if(!device_data_){ total_length_ = 0u; } } + + data& operator=(const data& rhs) { + total_length_ = rhs.total_length_; + device_data_ = cl::sycl::malloc_device>(rhs.size(), *rhs.queue_); + // device_data_ = cl::sycl::malloc_device::type>(rhs.size(), *rhs.queue_); + if(!device_data_){ + total_length_ = 0u; + } + queue_ = rhs.queue_; + return *this; + } data(data&& rhs): total_length_{rhs.total_length_}, @@ -139,8 +158,8 @@ public: } } - // data& at(uint64_t i){ - typename native_data_type::type& at(uint64_t i){ + data& at(uint64_t i){ + //typename native_data_type::type& at(uint64_t i){ return device_data_[i]; } @@ -160,6 +179,7 @@ struct rpc_id_map_helper, Encoding, Storage> { std::tuple...> maps; }; } + /** * Rpc Client class for the Sycl backend. */ @@ -171,6 +191,10 @@ private: * Server this client is tied to */ rpc_server* srv_; + + /** + * Generated some sort of id for the request. + */ public: rpc_client(rpc_server& srv): srv_{&srv} @@ -184,9 +208,9 @@ public: id< typename schema_member_type::type::ResponseT > - > call(data_or_id::type::RequestT, Encoding, Storage> input){ - (void) input; - return make_error("RpcClient side is not implemented"); + > call(const data_or_id::type::RequestT, Encoding, Storage>& input){ + auto next_free_id = srv_->template next_free_id::type::ResponseT>(); + return srv_->template call(input, next_free_id); } }; @@ -215,6 +239,14 @@ private: */ impl::rpc_id_map_helper storage_; public: + /** + * Ask which id the server prefers as the next one. Only available for fast requests on no roundtrip setups. + */ + template + id next_free_id() const { + return std::get>(storage_.maps).next_free_id(); + } + rpc_server(interface cl_iface): cmd_queue_{}, cl_interface_{std::move(cl_iface)}, @@ -222,9 +254,9 @@ public: {} template - remote_data request_data(id dat){ + remote_data request_data(id dat_id){ /// @TODO Fix so I can receive data - return {dat, std::get>(storage_.maps)}; + return {dat_id, std::get>(storage_.maps)}; } /** @@ -235,7 +267,7 @@ public: id< typename schema_member_type::type::ResponseT > - > call(data_or_id::type::RequestT, Encoding, ClientAllocation> input){ + > call(data_or_id::type::RequestT, Encoding, ClientAllocation> input, id::type::ResponseT> rpc_id){ using FuncT = typename schema_member_type::type; /** @@ -258,6 +290,7 @@ public: } else { auto& client_data = input.get_data(); dev_tmp_inp = heap>(client_data, cmd_queue_); + cmd_queue_.wait(); return dev_tmp_inp.get(); } }(); @@ -272,16 +305,16 @@ public: return std::move(eod.get_error()); } + auto& val = eod.get_value(); /** * Store returned data in rpc storage */ - auto& val = eod.get_value(); auto& inner_map = std::get::type::RequestT, Encoding,rmt::Sycl>> (storage_.maps); - auto eoid = inner_map.insert(std::move(val)); + auto eoid = inner_map.insert_as(std::move(val), rpc_id); if(eoid.is_error()){ return std::move(eoid.get_error()); } - return eoid.get_value(); + return rpc_id; } }; -- cgit v1.2.3