diff options
author | Claudius "keldu" Holeksa <mail@keldu.de> | 2024-06-13 17:34:22 +0200 |
---|---|---|
committer | Claudius "keldu" Holeksa <mail@keldu.de> | 2024-06-13 17:34:22 +0200 |
commit | 57f6eacfcdbdba31185eb66b9a573a8923eecf16 (patch) | |
tree | 1683da4209744fabbe87a949134701d617c0d5f9 /modules/remote-sycl | |
parent | 0f317186de9fb11d336e564f808e4732386c4074 (diff) |
Possible fix for transferring primitives to device without dropping STL
Diffstat (limited to 'modules/remote-sycl')
-rw-r--r-- | modules/remote-sycl/c++/remote.hpp | 83 | ||||
-rw-r--r-- | modules/remote-sycl/examples/sycl_basic.cpp | 13 | ||||
-rw-r--r-- | modules/remote-sycl/examples/sycl_basic_kernel.cpp | 2 |
3 files changed, 67 insertions, 31 deletions
diff --git a/modules/remote-sycl/c++/remote.hpp b/modules/remote-sycl/c++/remote.hpp index 677a427..bcc8a3c 100644 --- a/modules/remote-sycl/c++/remote.hpp +++ b/modules/remote-sycl/c++/remote.hpp @@ -22,16 +22,25 @@ class remote_data<T, Encoding, Storage, rmt::Sycl> { private: id<T> id_; id_map<T,Encoding,rmt::Sycl>* map_; + cl::sycl::queue* queue_; public: /** * Main constructor */ - remote_data(const id<T>& id, id_map<T, Encoding, rmt::Sycl>& map): + remote_data(const id<T>& id, id_map<T, Encoding, rmt::Sycl>& map, cl::sycl::queue& queue__): id_{id}, - map_{&map} + map_{&map}, + queue_{&queue__} {} /** + * Wait for the data + */ + error_or<data<T,Encoding,Storage>> wait(){ + + } + + /** * Request data asynchronously */ conveyor<data<T,Encoding,Storage>> on_receive(); /// Stopped here @@ -39,21 +48,14 @@ public: /** * - */ template<typename T, uint64_t N> class data<schema::Primitive<T,N>, encode::Native, rmt::Sycl> { public: using Schema = schema::Primitive<T,N>; using NativeType = typename native_data_type<Schema>::type; private: - /** - * - */ NativeType val_; public: - /** - * - */ data(NativeType val__): val_{val__} {} @@ -62,6 +64,7 @@ public: return val_; } }; + */ template<typename T, uint64_t D> class data<schema::Array<T,D>, encode::Native, rmt::Sycl> { @@ -69,8 +72,8 @@ public: using Schema = schema::Array<T,D>; private: uint64_t total_length_; - typename native_data_type<T>::type* device_data_; - // data<T>* device_data_; + // typename native_data_type<T>::type* device_data_; + data<T,encode::Native,storage::Default>* device_data_; cl::sycl::queue* queue_; static_assert(is_primitive<T>::value, "Only supports primitives for now"); @@ -78,7 +81,8 @@ private: public: data(uint64_t size, cl::sycl::queue& q__): total_length_{size}, - device_data_{cl::sycl::malloc_device<typename native_data_type<T>::type>(size, q__)}, + device_data_{cl::sycl::malloc_device<data<T,encode::Native,storage::Default>>(size, q__)}, + //device_data_{cl::sycl::malloc_device<typename native_data_type<T>::type>(size, q__)}, queue_{&q__} { if(!device_data_){ @@ -89,12 +93,15 @@ public: template<typename Encoding, typename Storage> data(const data<Schema, Encoding, Storage>& from, cl::sycl::queue& q__): total_length_{from.size()}, - device_data_{cl::sycl::malloc_device<typename native_data_type<T>::type>(from.size(), q__)}, + device_data_{cl::sycl::malloc_device<data<T,encode::Native,storage::Default>>(from.size(), q__)}, + //device_data_{cl::sycl::malloc_device<typename native_data_type<T>::type>(from.size(), q__)}, queue_{&q__} { if(!device_data_){ total_length_ = 0u; + return; } + queue_->template copy<data<T,encode::Native,storage::Default>>(&from.at(0), device_data_, total_length_); } data(const data<Schema, encode::Native, rmt::Sycl>& from): @@ -105,11 +112,23 @@ public: if(total_length_ == 0u || !queue_){ return; } - device_data_ = cl::sycl::malloc_device<typename native_data_type<T>::type>(from.size(), *queue_); + device_data_ = cl::sycl::malloc_device<data<T,encode::Native,storage::Default>>(from.size(), *queue_); + // device_data_ = cl::sycl::malloc_device<typename native_data_type<T>::type>(from.size(), *queue_); if(!device_data_){ total_length_ = 0u; } } + + data<Schema, encode::Native, rmt::Sycl>& operator=(const data<Schema, encode::Native, rmt::Sycl>& rhs) { + total_length_ = rhs.total_length_; + device_data_ = cl::sycl::malloc_device<data<T,encode::Native,storage::Default>>(rhs.size(), *rhs.queue_); + // device_data_ = cl::sycl::malloc_device<typename native_data_type<T>::type>(rhs.size(), *rhs.queue_); + if(!device_data_){ + total_length_ = 0u; + } + queue_ = rhs.queue_; + return *this; + } data(data<Schema, encode::Native, rmt::Sycl>&& rhs): total_length_{rhs.total_length_}, @@ -139,8 +158,8 @@ public: } } - // data<T,encode::Native,rmt::Sycl>& at(uint64_t i){ - typename native_data_type<T>::type& at(uint64_t i){ + data<T, encode::Native, saw::storage::Default>& at(uint64_t i){ + //typename native_data_type<T>::type& at(uint64_t i){ return device_data_[i]; } @@ -160,6 +179,7 @@ struct rpc_id_map_helper<schema::Interface<Members...>, Encoding, Storage> { std::tuple<id_map<typename Members::ValueType::ResponseT, Encoding, Storage>...> maps; }; } + /** * Rpc Client class for the Sycl backend. */ @@ -171,6 +191,10 @@ private: * Server this client is tied to */ rpc_server<Iface, Encoding, rmt::Sycl>* srv_; + + /** + * Generated some sort of id for the request. + */ public: rpc_client(rpc_server<Iface, Encoding, rmt::Sycl>& srv): srv_{&srv} @@ -184,9 +208,9 @@ public: id< typename schema_member_type<Name, Iface>::type::ResponseT > - > call(data_or_id<typename schema_member_type<Name, Iface>::type::RequestT, Encoding, Storage> input){ - (void) input; - return make_error<err::not_implemented>("RpcClient side is not implemented"); + > call(const data_or_id<typename schema_member_type<Name, Iface>::type::RequestT, Encoding, Storage>& input){ + auto next_free_id = srv_->template next_free_id<typename schema_member_type<Name, Iface>::type::ResponseT>(); + return srv_->template call<Name, Storage>(input, next_free_id); } }; @@ -215,6 +239,14 @@ private: */ impl::rpc_id_map_helper<Iface, Encoding, rmt::Sycl> storage_; public: + /** + * Ask which id the server prefers as the next one. Only available for fast requests on no roundtrip setups. + */ + template<typename T> + id<T> next_free_id() const { + return std::get<id_map<T,Encoding,rmt::Sycl>>(storage_.maps).next_free_id(); + } + rpc_server(interface<Iface, Encoding, rmt::Sycl, InterfaceCtxT> cl_iface): cmd_queue_{}, cl_interface_{std::move(cl_iface)}, @@ -222,9 +254,9 @@ public: {} template<typename IdT, typename Storage> - remote_data<IdT, Encoding, Storage, rmt::Sycl> request_data(id<IdT> dat){ + remote_data<IdT, Encoding, Storage, rmt::Sycl> request_data(id<IdT> dat_id){ /// @TODO Fix so I can receive data - return {dat, std::get<id_map<IdT, Encoding,rmt::Sycl>>(storage_.maps)}; + return {dat_id, std::get<id_map<IdT, Encoding,rmt::Sycl>>(storage_.maps)}; } /** @@ -235,7 +267,7 @@ public: id< typename schema_member_type<Name, Iface>::type::ResponseT > - > call(data_or_id<typename schema_member_type<Name, Iface>::type::RequestT, Encoding, ClientAllocation> input){ + > call(data_or_id<typename schema_member_type<Name, Iface>::type::RequestT, Encoding, ClientAllocation> input, id<typename schema_member_type<Name,Iface>::type::ResponseT> rpc_id){ using FuncT = typename schema_member_type<Name, Iface>::type; /** @@ -258,6 +290,7 @@ public: } else { auto& client_data = input.get_data(); dev_tmp_inp = heap<data<typename FuncT::RequestT, Encoding, rmt::Sycl>>(client_data, cmd_queue_); + cmd_queue_.wait(); return dev_tmp_inp.get(); } }(); @@ -272,16 +305,16 @@ public: return std::move(eod.get_error()); } + auto& val = eod.get_value(); /** * Store returned data in rpc storage */ - auto& val = eod.get_value(); auto& inner_map = std::get<id_map<typename schema_member_type<Name, Iface>::type::RequestT, Encoding,rmt::Sycl>> (storage_.maps); - auto eoid = inner_map.insert(std::move(val)); + auto eoid = inner_map.insert_as(std::move(val), rpc_id); if(eoid.is_error()){ return std::move(eoid.get_error()); } - return eoid.get_value(); + return rpc_id; } }; diff --git a/modules/remote-sycl/examples/sycl_basic.cpp b/modules/remote-sycl/examples/sycl_basic.cpp index 677fd29..2e9a4f8 100644 --- a/modules/remote-sycl/examples/sycl_basic.cpp +++ b/modules/remote-sycl/examples/sycl_basic.cpp @@ -14,25 +14,25 @@ int main(){ }).detach(); wait.poll(); - if(!rmt_addr){ return -1; } auto rpc_server = listen_basic_sycl(remote_ctx, *rmt_addr); + saw::rpc_client<schema::BasicInterface, saw::encode::Native, saw::storage::Default, saw::rmt::Sycl> client{rpc_server}; - saw::id<schema::Array<schema::UInt64>> next_id{0u}; + saw::id<schema::Array<schema::UInt64>> id_zero{0u}; { - auto eov = rpc_server.template call<"increment", saw::storage::Default>(saw::data<schema::Array<schema::UInt64>, saw::encode::Native>{1u}); + auto eov = client.template call<"increment">(saw::data<schema::Array<schema::UInt64>, saw::encode::Native>{1u}); if(eov.is_error()){ auto& err = eov.get_error(); std::cerr<<"Error: "<<err.get_category()<<" : "<<err.get_message()<<std::endl; return -2; } - next_id = eov.get_value(); + id_zero = eov.get_value(); } { - auto eov = rpc_server.template call<"increment", saw::storage::Default>(next_id); + auto eov = client.template call<"increment">(id_zero); if(eov.is_error()){ auto& err = eov.get_error(); std::cerr<<"Error: "<<err.get_category()<<" : "<<err.get_message()<<std::endl; @@ -41,6 +41,9 @@ int main(){ auto& val = eov.get_value(); std::cout<<"Value: "<<val.get_value()<<std::endl; } + { + // auto eo_rd = rpc_server.request_data(id_one); + } return 0; } diff --git a/modules/remote-sycl/examples/sycl_basic_kernel.cpp b/modules/remote-sycl/examples/sycl_basic_kernel.cpp index 94583b9..888f905 100644 --- a/modules/remote-sycl/examples/sycl_basic_kernel.cpp +++ b/modules/remote-sycl/examples/sycl_basic_kernel.cpp @@ -7,7 +7,7 @@ saw::rpc_server<schema::BasicInterface, saw::encode::Native, saw::rmt::Sycl> lis q->submit([&](cl::sycl::handler& h){ h.parallel_for(cl::sycl::range<1>(1u), [&] (cl::sycl::id<1> it){ - in.at(0u) += 1u; + in.at(0u).set(in.at(0u).get() + 1u); }); }); q->wait(); |