diff options
Diffstat (limited to 'modules/remote-sycl/c++')
-rw-r--r-- | modules/remote-sycl/c++/remote.hpp | 83 |
1 files changed, 58 insertions, 25 deletions
diff --git a/modules/remote-sycl/c++/remote.hpp b/modules/remote-sycl/c++/remote.hpp index 677a427..bcc8a3c 100644 --- a/modules/remote-sycl/c++/remote.hpp +++ b/modules/remote-sycl/c++/remote.hpp @@ -22,16 +22,25 @@ class remote_data<T, Encoding, Storage, rmt::Sycl> { private: id<T> id_; id_map<T,Encoding,rmt::Sycl>* map_; + cl::sycl::queue* queue_; public: /** * Main constructor */ - remote_data(const id<T>& id, id_map<T, Encoding, rmt::Sycl>& map): + remote_data(const id<T>& id, id_map<T, Encoding, rmt::Sycl>& map, cl::sycl::queue& queue__): id_{id}, - map_{&map} + map_{&map}, + queue_{&queue__} {} /** + * Wait for the data + */ + error_or<data<T,Encoding,Storage>> wait(){ + + } + + /** * Request data asynchronously */ conveyor<data<T,Encoding,Storage>> on_receive(); /// Stopped here @@ -39,21 +48,14 @@ public: /** * - */ template<typename T, uint64_t N> class data<schema::Primitive<T,N>, encode::Native, rmt::Sycl> { public: using Schema = schema::Primitive<T,N>; using NativeType = typename native_data_type<Schema>::type; private: - /** - * - */ NativeType val_; public: - /** - * - */ data(NativeType val__): val_{val__} {} @@ -62,6 +64,7 @@ public: return val_; } }; + */ template<typename T, uint64_t D> class data<schema::Array<T,D>, encode::Native, rmt::Sycl> { @@ -69,8 +72,8 @@ public: using Schema = schema::Array<T,D>; private: uint64_t total_length_; - typename native_data_type<T>::type* device_data_; - // data<T>* device_data_; + // typename native_data_type<T>::type* device_data_; + data<T,encode::Native,storage::Default>* device_data_; cl::sycl::queue* queue_; static_assert(is_primitive<T>::value, "Only supports primitives for now"); @@ -78,7 +81,8 @@ private: public: data(uint64_t size, cl::sycl::queue& q__): total_length_{size}, - device_data_{cl::sycl::malloc_device<typename native_data_type<T>::type>(size, q__)}, + device_data_{cl::sycl::malloc_device<data<T,encode::Native,storage::Default>>(size, q__)}, + //device_data_{cl::sycl::malloc_device<typename native_data_type<T>::type>(size, q__)}, queue_{&q__} { if(!device_data_){ @@ -89,12 +93,15 @@ public: template<typename Encoding, typename Storage> data(const data<Schema, Encoding, Storage>& from, cl::sycl::queue& q__): total_length_{from.size()}, - device_data_{cl::sycl::malloc_device<typename native_data_type<T>::type>(from.size(), q__)}, + device_data_{cl::sycl::malloc_device<data<T,encode::Native,storage::Default>>(from.size(), q__)}, + //device_data_{cl::sycl::malloc_device<typename native_data_type<T>::type>(from.size(), q__)}, queue_{&q__} { if(!device_data_){ total_length_ = 0u; + return; } + queue_->template copy<data<T,encode::Native,storage::Default>>(&from.at(0), device_data_, total_length_); } data(const data<Schema, encode::Native, rmt::Sycl>& from): @@ -105,11 +112,23 @@ public: if(total_length_ == 0u || !queue_){ return; } - device_data_ = cl::sycl::malloc_device<typename native_data_type<T>::type>(from.size(), *queue_); + device_data_ = cl::sycl::malloc_device<data<T,encode::Native,storage::Default>>(from.size(), *queue_); + // device_data_ = cl::sycl::malloc_device<typename native_data_type<T>::type>(from.size(), *queue_); if(!device_data_){ total_length_ = 0u; } } + + data<Schema, encode::Native, rmt::Sycl>& operator=(const data<Schema, encode::Native, rmt::Sycl>& rhs) { + total_length_ = rhs.total_length_; + device_data_ = cl::sycl::malloc_device<data<T,encode::Native,storage::Default>>(rhs.size(), *rhs.queue_); + // device_data_ = cl::sycl::malloc_device<typename native_data_type<T>::type>(rhs.size(), *rhs.queue_); + if(!device_data_){ + total_length_ = 0u; + } + queue_ = rhs.queue_; + return *this; + } data(data<Schema, encode::Native, rmt::Sycl>&& rhs): total_length_{rhs.total_length_}, @@ -139,8 +158,8 @@ public: } } - // data<T,encode::Native,rmt::Sycl>& at(uint64_t i){ - typename native_data_type<T>::type& at(uint64_t i){ + data<T, encode::Native, saw::storage::Default>& at(uint64_t i){ + //typename native_data_type<T>::type& at(uint64_t i){ return device_data_[i]; } @@ -160,6 +179,7 @@ struct rpc_id_map_helper<schema::Interface<Members...>, Encoding, Storage> { std::tuple<id_map<typename Members::ValueType::ResponseT, Encoding, Storage>...> maps; }; } + /** * Rpc Client class for the Sycl backend. */ @@ -171,6 +191,10 @@ private: * Server this client is tied to */ rpc_server<Iface, Encoding, rmt::Sycl>* srv_; + + /** + * Generated some sort of id for the request. + */ public: rpc_client(rpc_server<Iface, Encoding, rmt::Sycl>& srv): srv_{&srv} @@ -184,9 +208,9 @@ public: id< typename schema_member_type<Name, Iface>::type::ResponseT > - > call(data_or_id<typename schema_member_type<Name, Iface>::type::RequestT, Encoding, Storage> input){ - (void) input; - return make_error<err::not_implemented>("RpcClient side is not implemented"); + > call(const data_or_id<typename schema_member_type<Name, Iface>::type::RequestT, Encoding, Storage>& input){ + auto next_free_id = srv_->template next_free_id<typename schema_member_type<Name, Iface>::type::ResponseT>(); + return srv_->template call<Name, Storage>(input, next_free_id); } }; @@ -215,6 +239,14 @@ private: */ impl::rpc_id_map_helper<Iface, Encoding, rmt::Sycl> storage_; public: + /** + * Ask which id the server prefers as the next one. Only available for fast requests on no roundtrip setups. + */ + template<typename T> + id<T> next_free_id() const { + return std::get<id_map<T,Encoding,rmt::Sycl>>(storage_.maps).next_free_id(); + } + rpc_server(interface<Iface, Encoding, rmt::Sycl, InterfaceCtxT> cl_iface): cmd_queue_{}, cl_interface_{std::move(cl_iface)}, @@ -222,9 +254,9 @@ public: {} template<typename IdT, typename Storage> - remote_data<IdT, Encoding, Storage, rmt::Sycl> request_data(id<IdT> dat){ + remote_data<IdT, Encoding, Storage, rmt::Sycl> request_data(id<IdT> dat_id){ /// @TODO Fix so I can receive data - return {dat, std::get<id_map<IdT, Encoding,rmt::Sycl>>(storage_.maps)}; + return {dat_id, std::get<id_map<IdT, Encoding,rmt::Sycl>>(storage_.maps)}; } /** @@ -235,7 +267,7 @@ public: id< typename schema_member_type<Name, Iface>::type::ResponseT > - > call(data_or_id<typename schema_member_type<Name, Iface>::type::RequestT, Encoding, ClientAllocation> input){ + > call(data_or_id<typename schema_member_type<Name, Iface>::type::RequestT, Encoding, ClientAllocation> input, id<typename schema_member_type<Name,Iface>::type::ResponseT> rpc_id){ using FuncT = typename schema_member_type<Name, Iface>::type; /** @@ -258,6 +290,7 @@ public: } else { auto& client_data = input.get_data(); dev_tmp_inp = heap<data<typename FuncT::RequestT, Encoding, rmt::Sycl>>(client_data, cmd_queue_); + cmd_queue_.wait(); return dev_tmp_inp.get(); } }(); @@ -272,16 +305,16 @@ public: return std::move(eod.get_error()); } + auto& val = eod.get_value(); /** * Store returned data in rpc storage */ - auto& val = eod.get_value(); auto& inner_map = std::get<id_map<typename schema_member_type<Name, Iface>::type::RequestT, Encoding,rmt::Sycl>> (storage_.maps); - auto eoid = inner_map.insert(std::move(val)); + auto eoid = inner_map.insert_as(std::move(val), rpc_id); if(eoid.is_error()){ return std::move(eoid.get_error()); } - return eoid.get_value(); + return rpc_id; } }; |