diff options
-rw-r--r-- | modules/io_codec/c++/rpc.hpp | 10 | ||||
-rw-r--r-- | modules/remote-sycl/c++/remote.hpp | 46 | ||||
-rw-r--r-- | modules/remote-sycl/c++/transfer.hpp | 6 | ||||
-rw-r--r-- | modules/remote-sycl/tests/data.cpp | 2 | ||||
-rw-r--r-- | modules/remote-sycl/tests/sycl_basics.cpp | 5 |
5 files changed, 40 insertions, 29 deletions
diff --git a/modules/io_codec/c++/rpc.hpp b/modules/io_codec/c++/rpc.hpp index f01ebd5..2c97d6b 100644 --- a/modules/io_codec/c++/rpc.hpp +++ b/modules/io_codec/c++/rpc.hpp @@ -121,12 +121,12 @@ class rpc_client { /** * Implementation of a remote server on the backend */ -template<typename Iface, typename Encoding, typename Remote> +template<typename Iface, typename Encoding, typename Storage, typename Remote> class rpc_server { private: - interface<Iface, Encoding> iface_; + interface<Iface, Encoding, Storage> iface_; public: - rpc_server(interface<Iface, Encoding> iface): + rpc_server(interface<Iface, Encoding, Storage> iface): iface_{std::move(iface)} {} }; @@ -163,7 +163,7 @@ class remote { /** * Start listening */ - template<typename Iface, typename Encode> - rpc_server<Iface, Encode, Remote> listen(); + template<typename Iface, typename Encode, typename Storage> + rpc_server<Iface, Encode, Storage, Remote> listen(); }; } diff --git a/modules/remote-sycl/c++/remote.hpp b/modules/remote-sycl/c++/remote.hpp index 7e77ec9..8ec4667 100644 --- a/modules/remote-sycl/c++/remote.hpp +++ b/modules/remote-sycl/c++/remote.hpp @@ -124,7 +124,7 @@ private: /** * Server this client is tied to */ - rpc_server<Iface, Encoding, rmt::Sycl>* srv_; + rpc_server<Iface, Encoding, Storage, rmt::Sycl>* srv_; /** * TransferClient created from the internal RPC data server @@ -135,7 +135,7 @@ private: * Generated some sort of id for the request. */ public: - rpc_client(rpc_server<Iface, Encoding, rmt::Sycl>& srv): + rpc_client(rpc_server<Iface, Encoding, Storage, rmt::Sycl>& srv): srv_{&srv}, data_client_{srv_->data_server} {} @@ -159,7 +159,7 @@ public: * Rpc Server class for the Sycl backend. */ template<typename Iface, typename Encoding> -class rpc_server<Iface, Encoding, rmt::Sycl> { +class rpc_server<Iface, Encoding, storage::Default, rmt::Sycl> { public: using InterfaceCtxT = cl::sycl::queue*; using InterfaceT = interface<Iface, Encoding, storage::Default, InterfaceCtxT>; @@ -168,7 +168,7 @@ private: /** * Device instance enabling the use of the remote device. */ - device<rmt::Sycl>* device_; + our<device<rmt::Sycl>> device_; using DataServerT = data_server<typename impl::rpc_iface_type_helper<Iface>::type, Encoding, rmt::Sycl>; /** @@ -186,8 +186,8 @@ public: /** * Main constructor */ - rpc_server(device<rmt::Sycl>& dev__, DataServerT& data_server__, InterfaceT cl_iface): - device_{&dev__}, + rpc_server(our<device<rmt::Sycl>> dev__, DataServerT& data_server__, InterfaceT cl_iface): + device_{std::move(dev__)}, data_server_{&data_server__}, cl_interface_{std::move(cl_iface)} {} @@ -277,13 +277,23 @@ template<> struct remote_address<rmt::Sycl> { private: remote<rmt::Sycl>* ctx_; + our<device<rmt::Sycl>> device_; SAW_FORBID_COPY(remote_address); SAW_FORBID_MOVE(remote_address); public: - remote_address(remote<rmt::Sycl>& r_ctx): - ctx_{&r_ctx} + remote_address(remote<rmt::Sycl>& r_ctx, our<device<rmt::Sycl>> d_ctx): + ctx_{&r_ctx}, + device_{std::move(d_ctx)} {} + + our<device<rmt::Sycl>> copy_device_reference() const { + return device_; + } + + device<rmt::Sycl>& get_device(){ + return *device_; + } }; template<> @@ -302,24 +312,26 @@ public: * we just create a default. */ conveyor<own<remote_address<rmt::Sycl>>> resolve_address(){ - return heap<remote_address<rmt::Sycl>>(*this); + auto dev = std::make_shared<device<rmt::Sycl>>(); + return heap<remote_address<rmt::Sycl>>(*this, std::move(dev)); } /** - * Connect to a device + * Parse address, but don't resolve it. */ - device<rmt::Sycl> connect_device(const remote_address<rmt::Sycl>&){ - return {}; + error_or<own<remote_address<rmt::Sycl>>> parse_address(){ + auto dev = std::make_shared<device<rmt::Sycl>>(); + return heap<remote_address<rmt::Sycl>>(*this, std::move(dev)); } /** * Spin up a rpc server */ - template<typename Iface, typename Encoding> - rpc_server<Iface, Encoding, rmt::Sycl> listen(device<rmt::Sycl>& dev, typename rpc_server<Iface, Encoding, rmt::Sycl>::InterfaceT iface){ - using RpcServerT = rpc_server<Iface, Encoding, rmt::Sycl>; - using InterfaceT = typename RpcServerT::InterfaceT; - return {dev, std::move(iface)}; + template<typename Iface, typename Encoding, typename Storage> + rpc_server<Iface, Encoding, Storage, rmt::Sycl> listen(remote_address<rmt::Sycl>& dev, typename rpc_server<Iface, Encoding, Storage, rmt::Sycl>::InterfaceT iface){ + //using RpcServerT = rpc_server<Iface, Encoding, rmt::Sycl>; + //using InterfaceT = typename RpcServerT::InterfaceT; + return {dev.copy_device_reference(), std::move(iface)}; } }; diff --git a/modules/remote-sycl/c++/transfer.hpp b/modules/remote-sycl/c++/transfer.hpp index 8987de9..f535751 100644 --- a/modules/remote-sycl/c++/transfer.hpp +++ b/modules/remote-sycl/c++/transfer.hpp @@ -26,7 +26,7 @@ private: /** * Device context class */ - device<rmt::Sycl>* device_; + our<device<rmt::Sycl>> device_; /** * Store for the data the server manages. @@ -36,8 +36,8 @@ public: /** * Main constructor */ - data_server(device<rmt::Sycl>& device__): - device_{&device__} + data_server(our<device<rmt::Sycl>> device__): + device_{std::move(device__)} {} /** diff --git a/modules/remote-sycl/tests/data.cpp b/modules/remote-sycl/tests/data.cpp index dff19fb..de09c92 100644 --- a/modules/remote-sycl/tests/data.cpp +++ b/modules/remote-sycl/tests/data.cpp @@ -41,7 +41,7 @@ SAW_TEST("SYCL Data Management"){ wait.poll(); SAW_EXPECT(rmt_addr, "Remote address hasn't been filled"); - auto device = rmt.connect_device(*rmt_addr); + auto device = rmt_addr->copy_device_reference(); auto data_srv = data_server<tmpl_group<schema::TestStruct>, encode::Native, rmt::Sycl>{device}; diff --git a/modules/remote-sycl/tests/sycl_basics.cpp b/modules/remote-sycl/tests/sycl_basics.cpp index 90da299..61e0d87 100644 --- a/modules/remote-sycl/tests/sycl_basics.cpp +++ b/modules/remote-sycl/tests/sycl_basics.cpp @@ -46,8 +46,6 @@ SAW_TEST("SYCL Test Setup"){ wait.poll(); SAW_EXPECT(rmt_addr, "Remote Address class hasn't been filled"); - auto device = rmt.connect_device(*rmt_addr); - data<schema::TestStruct, encode::Native, rmt::Sycl> device_data{host_data}; interface<schema::Foo, encode::Native,rmt::Sycl, cl::sycl::queue*> cl_iface { @@ -80,8 +78,9 @@ SAW_TEST("SYCL Test Setup"){ return saw::void_t{}; } }; + auto& device = rmt_addr->get_device(); - cl_iface.template call <"foo">(device_data, &device.get_handle()); + cl_iface.template call <"foo">(device_data, &(device.get_handle())); device.get_handle().wait(); } } |