From e2a7609028346c3b776a424c9be848e49d3a0e2e Mon Sep 17 00:00:00 2001 From: "Claudius \"keldu\" Holeksa" Date: Wed, 12 Jun 2024 17:15:34 +0200 Subject: Working on reference types in std::function deduction --- modules/remote-sycl/c++/remote.hpp | 79 ++++++++++++++++++++++---------------- 1 file changed, 46 insertions(+), 33 deletions(-) (limited to 'modules/remote-sycl/c++') diff --git a/modules/remote-sycl/c++/remote.hpp b/modules/remote-sycl/c++/remote.hpp index d956314..a2b5a87 100644 --- a/modules/remote-sycl/c++/remote.hpp +++ b/modules/remote-sycl/c++/remote.hpp @@ -17,16 +17,16 @@ class remote; /** * Remote data class for the Sycl backend. */ -template -class remote_data { +template +class remote_data { private: id id_; - id_map* map_; + id_map* map_; public: /** * Main constructor */ - remote_data(const id& id, id_map& map): + remote_data(const id& id, id_map& map): id_{id}, map_{&map} {} @@ -34,14 +34,14 @@ public: /** * Request data asynchronously */ - conveyor> on_receive(); /// Stopped here + conveyor> on_receive(); /// Stopped here }; /** * */ template -class data, encode::Native> { +class data, encode::Native, rmt::Sycl> { public: using Schema = schema::Primitive; using NativeType = typename native_data_type::type; @@ -64,7 +64,7 @@ public: }; template -class data, encode::Native> { +class data, encode::Native, rmt::Sycl> { public: using Schema = schema::Array; private: @@ -86,8 +86,8 @@ public: } } - template - data(const data& from, cl::sycl::queue& q__): + template + data(const data& from, cl::sycl::queue& q__): total_length_{from.size()}, device_data_{cl::sycl::malloc_device::type>(from.size(), q__)}, queue_{&q__} @@ -105,28 +105,32 @@ public: } } - data>& at(uint64_t i){ + // data& at(uint64_t i){ + typename native_data_type::type& at(uint64_t i){ return device_data_[i]; } + + uint64_t size() const { + return total_length_; + } }; namespace impl { - -template +template struct rpc_id_map_helper { - static_assert(always_false, "Only support Interface schema types."); + static_assert(always_false, "Only supports Interface schema types."); }; -template -struct rpc_id_map_helper, Encoding> { - std::tuple...> maps; +template +struct rpc_id_map_helper, Encoding, Storage> { + std::tuple...> maps; }; } /** * Rpc Client class for the Sycl backend. */ -template -class rpc_client { +template +class rpc_client { public: private: /** @@ -141,12 +145,13 @@ public: /** * Rpc call */ - template + template error_or< id< typename schema_member_type::type::ResponseT > - > call(data_or_id::type::RequestT, ClientEncoding> input){ + > call(data_or_id::type::RequestT, Encoding, Storage> input){ + (void) input; return make_error("RpcClient side is not implemented"); } @@ -159,7 +164,7 @@ template class rpc_server { public: using InterfaceCtxT = cl::sycl::queue*; - using InterfaceT = interface; + using InterfaceT = interface; private: /** * Command queue for the sycl backend @@ -169,22 +174,23 @@ private: /** * The interface including the relevant context class. */ - interface cl_interface_; + interface cl_interface_; /** * Basic storage for response data. */ - impl::rpc_id_map_helper storage_; + impl::rpc_id_map_helper storage_; public: - rpc_server(interface cl_iface): + rpc_server(interface cl_iface): cmd_queue_{}, cl_interface_{std::move(cl_iface)}, storage_{} {} - template - remote_data request_data(id dat){ - return {dat, std::get>(storage_.maps)}; + template + remote_data request_data(id dat){ + /// @TODO Fix so I can receive data + return {dat, std::get>(storage_.maps)}; } /** @@ -195,23 +201,30 @@ public: id< typename schema_member_type::type::ResponseT > - > call(data_or_id::type::RequestT, ClientAllocation> input){ + > call(data_or_id::type::RequestT, Encoding, ClientAllocation> input){ + using FuncT = typename schema_member_type::type; + /** + * Object needed if and only if the provided data type is not an id + */ + own> dev_tmp_inp = nullptr; /** * First check if it's data or an id. * If it's an id, check if it's registered within the storage and retrieve it. */ - auto eoinp = [&,this]() -> error_or::type::RequestT, Encoding>* > { + auto eoinp = [&,this]() -> error_or* > { if(input.is_id()){ // storage_.maps - auto& inner_map = std::get::type::RequestT, Encoding >> (storage_.maps); + auto& inner_map = std::get> (storage_.maps); auto eov = inner_map.find(input.get_id()); if(eov.is_error()){ return std::move(eov.get_error()); } return eov.get_value(); } else { - return &input.get_data(); + auto& client_data = input.get_data(); + dev_tmp_inp = heap>(client_data, cmd_queue_); + return dev_tmp_inp.get(); } }(); if(eoinp.is_error()){ @@ -219,7 +232,7 @@ public: } auto& inp = *(eoinp.get_value()); - auto eod = cl_interface_.template call(std::move(inp), &cmd_queue_); + auto eod = cl_interface_.template call(inp, &cmd_queue_); if(eod.is_error()){ return std::move(eod.get_error()); @@ -229,7 +242,7 @@ public: * Store returned data in rpc storage */ auto& val = eod.get_value(); - auto& inner_map = std::get::type::RequestT, Encoding >> (storage_.maps); + auto& inner_map = std::get::type::RequestT, Encoding,rmt::Sycl>> (storage_.maps); auto eoid = inner_map.insert(std::move(val)); if(eoid.is_error()){ return std::move(eoid.get_error()); -- cgit v1.2.3