From 57f6eacfcdbdba31185eb66b9a573a8923eecf16 Mon Sep 17 00:00:00 2001 From: "Claudius \"keldu\" Holeksa" Date: Thu, 13 Jun 2024 17:34:22 +0200 Subject: Possible fix for transferring primitives to device without dropping STL --- modules/codec/c++/id_map.hpp | 33 +++++++++ modules/remote-sycl/c++/remote.hpp | 83 +++++++++++++++------- modules/remote-sycl/examples/sycl_basic.cpp | 13 ++-- modules/remote-sycl/examples/sycl_basic_kernel.cpp | 2 +- 4 files changed, 100 insertions(+), 31 deletions(-) diff --git a/modules/codec/c++/id_map.hpp b/modules/codec/c++/id_map.hpp index 84c04e7..18a331f 100644 --- a/modules/codec/c++/id_map.hpp +++ b/modules/codec/c++/id_map.hpp @@ -87,6 +87,29 @@ public: return make_error(); } + /** + * Insert as data with associated id. This can fail when it doesn't adhere to the standard approach. + */ + error_or insert_as(data val, id id) noexcept { + if(free_ids_.empty()){ + if( id.get_value() != data_.size() ){ + return make_error("Can't insert_as with provided ID. Doesn't match."); + } + try { + data_.emplace_back(std::move(val)); + }catch(std::exception& e){ + return make_error(); + } + return void_t{}; + } + + if(free_ids_.back() != id){ + return make_error("Can't insert_as with provided ID. Doesn't match next id."); + } + data_.at(id.get_value()) = std::move(val); + return void_t{}; + } + /** * Erase a value at this id. If this id isn't in the map, then it returns an error. */ @@ -140,6 +163,16 @@ public: return void_t{}; } + /** + * Tries to return the next free id + */ + id next_free_id() const { + if(free_ids_.empty()){ + return {data_.size()}; + } + return free_ids_.back(); + } + /** * Tries to find a value based on an id. * Returns an error on failure and returns diff --git a/modules/remote-sycl/c++/remote.hpp b/modules/remote-sycl/c++/remote.hpp index 677a427..bcc8a3c 100644 --- a/modules/remote-sycl/c++/remote.hpp +++ b/modules/remote-sycl/c++/remote.hpp @@ -22,15 +22,24 @@ class remote_data { private: id id_; id_map* map_; + cl::sycl::queue* queue_; public: /** * Main constructor */ - remote_data(const id& id, id_map& map): + remote_data(const id& id, id_map& map, cl::sycl::queue& queue__): id_{id}, - map_{&map} + map_{&map}, + queue_{&queue__} {} + /** + * Wait for the data + */ + error_or> wait(){ + + } + /** * Request data asynchronously */ @@ -39,21 +48,14 @@ public: /** * - */ template class data, encode::Native, rmt::Sycl> { public: using Schema = schema::Primitive; using NativeType = typename native_data_type::type; private: - /** - * - */ NativeType val_; public: - /** - * - */ data(NativeType val__): val_{val__} {} @@ -62,6 +64,7 @@ public: return val_; } }; + */ template class data, encode::Native, rmt::Sycl> { @@ -69,8 +72,8 @@ public: using Schema = schema::Array; private: uint64_t total_length_; - typename native_data_type::type* device_data_; - // data* device_data_; + // typename native_data_type::type* device_data_; + data* device_data_; cl::sycl::queue* queue_; static_assert(is_primitive::value, "Only supports primitives for now"); @@ -78,7 +81,8 @@ private: public: data(uint64_t size, cl::sycl::queue& q__): total_length_{size}, - device_data_{cl::sycl::malloc_device::type>(size, q__)}, + device_data_{cl::sycl::malloc_device>(size, q__)}, + //device_data_{cl::sycl::malloc_device::type>(size, q__)}, queue_{&q__} { if(!device_data_){ @@ -89,12 +93,15 @@ public: template data(const data& from, cl::sycl::queue& q__): total_length_{from.size()}, - device_data_{cl::sycl::malloc_device::type>(from.size(), q__)}, + device_data_{cl::sycl::malloc_device>(from.size(), q__)}, + //device_data_{cl::sycl::malloc_device::type>(from.size(), q__)}, queue_{&q__} { if(!device_data_){ total_length_ = 0u; + return; } + queue_->template copy>(&from.at(0), device_data_, total_length_); } data(const data& from): @@ -105,11 +112,23 @@ public: if(total_length_ == 0u || !queue_){ return; } - device_data_ = cl::sycl::malloc_device::type>(from.size(), *queue_); + device_data_ = cl::sycl::malloc_device>(from.size(), *queue_); + // device_data_ = cl::sycl::malloc_device::type>(from.size(), *queue_); if(!device_data_){ total_length_ = 0u; } } + + data& operator=(const data& rhs) { + total_length_ = rhs.total_length_; + device_data_ = cl::sycl::malloc_device>(rhs.size(), *rhs.queue_); + // device_data_ = cl::sycl::malloc_device::type>(rhs.size(), *rhs.queue_); + if(!device_data_){ + total_length_ = 0u; + } + queue_ = rhs.queue_; + return *this; + } data(data&& rhs): total_length_{rhs.total_length_}, @@ -139,8 +158,8 @@ public: } } - // data& at(uint64_t i){ - typename native_data_type::type& at(uint64_t i){ + data& at(uint64_t i){ + //typename native_data_type::type& at(uint64_t i){ return device_data_[i]; } @@ -160,6 +179,7 @@ struct rpc_id_map_helper, Encoding, Storage> { std::tuple...> maps; }; } + /** * Rpc Client class for the Sycl backend. */ @@ -171,6 +191,10 @@ private: * Server this client is tied to */ rpc_server* srv_; + + /** + * Generated some sort of id for the request. + */ public: rpc_client(rpc_server& srv): srv_{&srv} @@ -184,9 +208,9 @@ public: id< typename schema_member_type::type::ResponseT > - > call(data_or_id::type::RequestT, Encoding, Storage> input){ - (void) input; - return make_error("RpcClient side is not implemented"); + > call(const data_or_id::type::RequestT, Encoding, Storage>& input){ + auto next_free_id = srv_->template next_free_id::type::ResponseT>(); + return srv_->template call(input, next_free_id); } }; @@ -215,6 +239,14 @@ private: */ impl::rpc_id_map_helper storage_; public: + /** + * Ask which id the server prefers as the next one. Only available for fast requests on no roundtrip setups. + */ + template + id next_free_id() const { + return std::get>(storage_.maps).next_free_id(); + } + rpc_server(interface cl_iface): cmd_queue_{}, cl_interface_{std::move(cl_iface)}, @@ -222,9 +254,9 @@ public: {} template - remote_data request_data(id dat){ + remote_data request_data(id dat_id){ /// @TODO Fix so I can receive data - return {dat, std::get>(storage_.maps)}; + return {dat_id, std::get>(storage_.maps)}; } /** @@ -235,7 +267,7 @@ public: id< typename schema_member_type::type::ResponseT > - > call(data_or_id::type::RequestT, Encoding, ClientAllocation> input){ + > call(data_or_id::type::RequestT, Encoding, ClientAllocation> input, id::type::ResponseT> rpc_id){ using FuncT = typename schema_member_type::type; /** @@ -258,6 +290,7 @@ public: } else { auto& client_data = input.get_data(); dev_tmp_inp = heap>(client_data, cmd_queue_); + cmd_queue_.wait(); return dev_tmp_inp.get(); } }(); @@ -272,16 +305,16 @@ public: return std::move(eod.get_error()); } + auto& val = eod.get_value(); /** * Store returned data in rpc storage */ - auto& val = eod.get_value(); auto& inner_map = std::get::type::RequestT, Encoding,rmt::Sycl>> (storage_.maps); - auto eoid = inner_map.insert(std::move(val)); + auto eoid = inner_map.insert_as(std::move(val), rpc_id); if(eoid.is_error()){ return std::move(eoid.get_error()); } - return eoid.get_value(); + return rpc_id; } }; diff --git a/modules/remote-sycl/examples/sycl_basic.cpp b/modules/remote-sycl/examples/sycl_basic.cpp index 677fd29..2e9a4f8 100644 --- a/modules/remote-sycl/examples/sycl_basic.cpp +++ b/modules/remote-sycl/examples/sycl_basic.cpp @@ -14,25 +14,25 @@ int main(){ }).detach(); wait.poll(); - if(!rmt_addr){ return -1; } auto rpc_server = listen_basic_sycl(remote_ctx, *rmt_addr); + saw::rpc_client client{rpc_server}; - saw::id> next_id{0u}; + saw::id> id_zero{0u}; { - auto eov = rpc_server.template call<"increment", saw::storage::Default>(saw::data, saw::encode::Native>{1u}); + auto eov = client.template call<"increment">(saw::data, saw::encode::Native>{1u}); if(eov.is_error()){ auto& err = eov.get_error(); std::cerr<<"Error: "<(next_id); + auto eov = client.template call<"increment">(id_zero); if(eov.is_error()){ auto& err = eov.get_error(); std::cerr<<"Error: "< lis q->submit([&](cl::sycl::handler& h){ h.parallel_for(cl::sycl::range<1>(1u), [&] (cl::sycl::id<1> it){ - in.at(0u) += 1u; + in.at(0u).set(in.at(0u).get() + 1u); }); }); q->wait(); -- cgit v1.2.3