From e2a7609028346c3b776a424c9be848e49d3a0e2e Mon Sep 17 00:00:00 2001 From: "Claudius \"keldu\" Holeksa" Date: Wed, 12 Jun 2024 17:15:34 +0200 Subject: Working on reference types in std::function deduction --- modules/codec/c++/id_map.hpp | 8 +-- modules/codec/c++/interface.hpp | 11 ++- modules/io_codec/c++/rpc.hpp | 30 ++++---- modules/remote-sycl/c++/remote.hpp | 79 +++++++++++++--------- modules/remote-sycl/examples/sycl_basic.cpp | 4 +- modules/remote-sycl/examples/sycl_basic.hpp | 4 +- modules/remote-sycl/examples/sycl_basic_kernel.cpp | 20 +++--- 7 files changed, 88 insertions(+), 68 deletions(-) diff --git a/modules/codec/c++/id_map.hpp b/modules/codec/c++/id_map.hpp index bb31846..84c04e7 100644 --- a/modules/codec/c++/id_map.hpp +++ b/modules/codec/c++/id_map.hpp @@ -13,13 +13,13 @@ namespace saw { * Insert - O(1) * Erase - O(n) ? Dunno */ -template +template class id_map final { private: /** * Container which stores the primary data */ - std::vector> data_; + std::vector> data_; /** * Container which tracks free'd/fragmented elements within the * main container @@ -65,7 +65,7 @@ public: * Inserts an element into the container and returns either an id on success * or an error on failure. */ - error_or> insert(data val) noexcept { + error_or> insert(data val) noexcept { /// @todo Fix size_t and id base type if(free_ids_.empty()){ try { @@ -145,7 +145,7 @@ public: * Returns an error on failure and returns * a value pointer on success. */ - error_or*> find(const id& val){ + error_or*> find(const id& val){ if(val.get_value() >= data_.size()){ return make_error("ID is too large"); } diff --git a/modules/codec/c++/interface.hpp b/modules/codec/c++/interface.hpp index d55b415..0186f09 100644 --- a/modules/codec/c++/interface.hpp +++ b/modules/codec/c++/interface.hpp @@ -33,7 +33,16 @@ public: func_{std::move(func)} {} - error_or> call(data req, Context ctx = {}){ + error_or> call(data& req, Context ctx = {}){ + if constexpr (std::is_same_v){ + (void) ctx; + return func_(req); + } else { + return func_(req, ctx); + } + } + + error_or> call(data&& req, Context ctx = {}){ if constexpr (std::is_same_v){ (void) ctx; return func_(std::move(req)); diff --git a/modules/io_codec/c++/rpc.hpp b/modules/io_codec/c++/rpc.hpp index 04293cd..f01ebd5 100644 --- a/modules/io_codec/c++/rpc.hpp +++ b/modules/io_codec/c++/rpc.hpp @@ -13,13 +13,13 @@ namespace saw { /** * This class acts as a helper for rpc calls and representing data on the remote. */ -template +template class data_or_id { private: /** * Variant representing the either id or data class. */ - std::variant, data> doi_; + std::variant, data> doi_; public: /** * Constructor for instantiating. @@ -31,7 +31,7 @@ public: /** * Constructor for instantiating. */ - data_or_id(data val): + data_or_id(data val): doi_{std::move(val)} {} @@ -46,7 +46,7 @@ public: * Check if this class holds data. */ bool is_data() const { - return std::holds_alternative>(doi_); + return std::holds_alternative>(doi_); } /** @@ -59,22 +59,22 @@ public: /** * Return a data reference. */ - data& get_data(){ - return std::get>(doi_); + data& get_data(){ + return std::get>(doi_); } /** * Return a data reference. */ - const data& get_data() const { - return std::get>(doi_); + const data& get_data() const { + return std::get>(doi_); } }; /** * Representing data on the remote */ -template +template class remote_data { private: id id_; @@ -86,24 +86,24 @@ public: /** * Wait until data arrives */ - error_or> wait(wait_scope& wait); + error_or> wait(wait_scope& wait); /** * Asynchronously wait for a result */ - conveyor> on_receive(); + conveyor> on_receive(); }; /** * Client RPC reference structure */ -template +template class rpc_client { /** * request the data from the remote */ template - remote_data request_data(id data); + remote_data request_data(id data); /** @todo * Determine type based on Name @@ -157,8 +157,8 @@ class remote { /** * Connect to a remote */ - template - conveyor> connect(const remote_address& addr); + template + conveyor> connect(const remote_address& addr); /** * Start listening diff --git a/modules/remote-sycl/c++/remote.hpp b/modules/remote-sycl/c++/remote.hpp index d956314..a2b5a87 100644 --- a/modules/remote-sycl/c++/remote.hpp +++ b/modules/remote-sycl/c++/remote.hpp @@ -17,16 +17,16 @@ class remote; /** * Remote data class for the Sycl backend. */ -template -class remote_data { +template +class remote_data { private: id id_; - id_map* map_; + id_map* map_; public: /** * Main constructor */ - remote_data(const id& id, id_map& map): + remote_data(const id& id, id_map& map): id_{id}, map_{&map} {} @@ -34,14 +34,14 @@ public: /** * Request data asynchronously */ - conveyor> on_receive(); /// Stopped here + conveyor> on_receive(); /// Stopped here }; /** * */ template -class data, encode::Native> { +class data, encode::Native, rmt::Sycl> { public: using Schema = schema::Primitive; using NativeType = typename native_data_type::type; @@ -64,7 +64,7 @@ public: }; template -class data, encode::Native> { +class data, encode::Native, rmt::Sycl> { public: using Schema = schema::Array; private: @@ -86,8 +86,8 @@ public: } } - template - data(const data& from, cl::sycl::queue& q__): + template + data(const data& from, cl::sycl::queue& q__): total_length_{from.size()}, device_data_{cl::sycl::malloc_device::type>(from.size(), q__)}, queue_{&q__} @@ -105,28 +105,32 @@ public: } } - data>& at(uint64_t i){ + // data& at(uint64_t i){ + typename native_data_type::type& at(uint64_t i){ return device_data_[i]; } + + uint64_t size() const { + return total_length_; + } }; namespace impl { - -template +template struct rpc_id_map_helper { - static_assert(always_false, "Only support Interface schema types."); + static_assert(always_false, "Only supports Interface schema types."); }; -template -struct rpc_id_map_helper, Encoding> { - std::tuple...> maps; +template +struct rpc_id_map_helper, Encoding, Storage> { + std::tuple...> maps; }; } /** * Rpc Client class for the Sycl backend. */ -template -class rpc_client { +template +class rpc_client { public: private: /** @@ -141,12 +145,13 @@ public: /** * Rpc call */ - template + template error_or< id< typename schema_member_type::type::ResponseT > - > call(data_or_id::type::RequestT, ClientEncoding> input){ + > call(data_or_id::type::RequestT, Encoding, Storage> input){ + (void) input; return make_error("RpcClient side is not implemented"); } @@ -159,7 +164,7 @@ template class rpc_server { public: using InterfaceCtxT = cl::sycl::queue*; - using InterfaceT = interface; + using InterfaceT = interface; private: /** * Command queue for the sycl backend @@ -169,22 +174,23 @@ private: /** * The interface including the relevant context class. */ - interface cl_interface_; + interface cl_interface_; /** * Basic storage for response data. */ - impl::rpc_id_map_helper storage_; + impl::rpc_id_map_helper storage_; public: - rpc_server(interface cl_iface): + rpc_server(interface cl_iface): cmd_queue_{}, cl_interface_{std::move(cl_iface)}, storage_{} {} - template - remote_data request_data(id dat){ - return {dat, std::get>(storage_.maps)}; + template + remote_data request_data(id dat){ + /// @TODO Fix so I can receive data + return {dat, std::get>(storage_.maps)}; } /** @@ -195,23 +201,30 @@ public: id< typename schema_member_type::type::ResponseT > - > call(data_or_id::type::RequestT, ClientAllocation> input){ + > call(data_or_id::type::RequestT, Encoding, ClientAllocation> input){ + using FuncT = typename schema_member_type::type; + /** + * Object needed if and only if the provided data type is not an id + */ + own> dev_tmp_inp = nullptr; /** * First check if it's data or an id. * If it's an id, check if it's registered within the storage and retrieve it. */ - auto eoinp = [&,this]() -> error_or::type::RequestT, Encoding>* > { + auto eoinp = [&,this]() -> error_or* > { if(input.is_id()){ // storage_.maps - auto& inner_map = std::get::type::RequestT, Encoding >> (storage_.maps); + auto& inner_map = std::get> (storage_.maps); auto eov = inner_map.find(input.get_id()); if(eov.is_error()){ return std::move(eov.get_error()); } return eov.get_value(); } else { - return &input.get_data(); + auto& client_data = input.get_data(); + dev_tmp_inp = heap>(client_data, cmd_queue_); + return dev_tmp_inp.get(); } }(); if(eoinp.is_error()){ @@ -219,7 +232,7 @@ public: } auto& inp = *(eoinp.get_value()); - auto eod = cl_interface_.template call(std::move(inp), &cmd_queue_); + auto eod = cl_interface_.template call(inp, &cmd_queue_); if(eod.is_error()){ return std::move(eod.get_error()); @@ -229,7 +242,7 @@ public: * Store returned data in rpc storage */ auto& val = eod.get_value(); - auto& inner_map = std::get::type::RequestT, Encoding >> (storage_.maps); + auto& inner_map = std::get::type::RequestT, Encoding,rmt::Sycl>> (storage_.maps); auto eoid = inner_map.insert(std::move(val)); if(eoid.is_error()){ return std::move(eoid.get_error()); diff --git a/modules/remote-sycl/examples/sycl_basic.cpp b/modules/remote-sycl/examples/sycl_basic.cpp index 41aa2a1..677fd29 100644 --- a/modules/remote-sycl/examples/sycl_basic.cpp +++ b/modules/remote-sycl/examples/sycl_basic.cpp @@ -23,7 +23,7 @@ int main(){ saw::id> next_id{0u}; { - auto eov = rpc_server.template call<"increment">(saw::data, saw::encode::Native>{1u}); + auto eov = rpc_server.template call<"increment", saw::storage::Default>(saw::data, saw::encode::Native>{1u}); if(eov.is_error()){ auto& err = eov.get_error(); std::cerr<<"Error: "<(next_id); + auto eov = rpc_server.template call<"increment", saw::storage::Default>(next_id); if(eov.is_error()){ auto& err = eov.get_error(); std::cerr<<"Error: "<, UInt64>, "increment"> + Member, Array>, "increment"> >; } -saw::rpc_server, saw::rmt::Sycl> listen_basic_sycl(saw::remote& ctx, saw::remote_address& addr); +saw::rpc_server listen_basic_sycl(saw::remote& ctx, saw::remote_address& addr); diff --git a/modules/remote-sycl/examples/sycl_basic_kernel.cpp b/modules/remote-sycl/examples/sycl_basic_kernel.cpp index 86e73b5..94583b9 100644 --- a/modules/remote-sycl/examples/sycl_basic_kernel.cpp +++ b/modules/remote-sycl/examples/sycl_basic_kernel.cpp @@ -1,22 +1,20 @@ #include "sycl_basic.hpp" -saw::rpc_server, saw::rmt::Sycl> listen_basic_sycl(saw::remote& ctx, saw::remote_address& addr){ - saw::interface, cl::sycl::queue*> iface{ - [](saw::data, saw::encode::Native> in, cl::sycl::queue* q) -> saw::data> { - uint64_t inr = in.size(); - cl::sycl::buffer d_inc{ &inr, 1u }; - q->submit([&](cl::sycl::handler& h){ - auto a_inc = d_inc.get_access(h); +saw::rpc_server listen_basic_sycl(saw::remote& ctx, saw::remote_address& addr){ + saw::interface iface{ + + [](saw::data, saw::encode::Native, saw::rmt::Sycl> in, cl::sycl::queue* q) -> saw::data, saw::encode::Native, saw::rmt::Sycl> { - h.parallel_for(cl::sycl::range<1>(1u), [=] (cl::sycl::id<1> it){ - a_inc[0] += 1u; + q->submit([&](cl::sycl::handler& h){ + h.parallel_for(cl::sycl::range<1>(1u), [&] (cl::sycl::id<1> it){ + in.at(0u) += 1u; }); }); q->wait(); - return {inr}; + return in; } }; - auto rpc_server = ctx.template listen>(addr, std::move(iface)); + auto rpc_server = ctx.template listen(addr, std::move(iface)); return rpc_server; } -- cgit v1.2.3