summaryrefslogtreecommitdiff
path: root/modules/remote-sycl/c++
diff options
context:
space:
mode:
Diffstat (limited to 'modules/remote-sycl/c++')
-rw-r--r--modules/remote-sycl/c++/remote.hpp46
-rw-r--r--modules/remote-sycl/c++/transfer.hpp6
2 files changed, 32 insertions, 20 deletions
diff --git a/modules/remote-sycl/c++/remote.hpp b/modules/remote-sycl/c++/remote.hpp
index 7e77ec9..8ec4667 100644
--- a/modules/remote-sycl/c++/remote.hpp
+++ b/modules/remote-sycl/c++/remote.hpp
@@ -124,7 +124,7 @@ private:
/**
* Server this client is tied to
*/
- rpc_server<Iface, Encoding, rmt::Sycl>* srv_;
+ rpc_server<Iface, Encoding, Storage, rmt::Sycl>* srv_;
/**
* TransferClient created from the internal RPC data server
@@ -135,7 +135,7 @@ private:
* Generated some sort of id for the request.
*/
public:
- rpc_client(rpc_server<Iface, Encoding, rmt::Sycl>& srv):
+ rpc_client(rpc_server<Iface, Encoding, Storage, rmt::Sycl>& srv):
srv_{&srv},
data_client_{srv_->data_server}
{}
@@ -159,7 +159,7 @@ public:
* Rpc Server class for the Sycl backend.
*/
template<typename Iface, typename Encoding>
-class rpc_server<Iface, Encoding, rmt::Sycl> {
+class rpc_server<Iface, Encoding, storage::Default, rmt::Sycl> {
public:
using InterfaceCtxT = cl::sycl::queue*;
using InterfaceT = interface<Iface, Encoding, storage::Default, InterfaceCtxT>;
@@ -168,7 +168,7 @@ private:
/**
* Device instance enabling the use of the remote device.
*/
- device<rmt::Sycl>* device_;
+ our<device<rmt::Sycl>> device_;
using DataServerT = data_server<typename impl::rpc_iface_type_helper<Iface>::type, Encoding, rmt::Sycl>;
/**
@@ -186,8 +186,8 @@ public:
/**
* Main constructor
*/
- rpc_server(device<rmt::Sycl>& dev__, DataServerT& data_server__, InterfaceT cl_iface):
- device_{&dev__},
+ rpc_server(our<device<rmt::Sycl>> dev__, DataServerT& data_server__, InterfaceT cl_iface):
+ device_{std::move(dev__)},
data_server_{&data_server__},
cl_interface_{std::move(cl_iface)}
{}
@@ -277,13 +277,23 @@ template<>
struct remote_address<rmt::Sycl> {
private:
remote<rmt::Sycl>* ctx_;
+ our<device<rmt::Sycl>> device_;
SAW_FORBID_COPY(remote_address);
SAW_FORBID_MOVE(remote_address);
public:
- remote_address(remote<rmt::Sycl>& r_ctx):
- ctx_{&r_ctx}
+ remote_address(remote<rmt::Sycl>& r_ctx, our<device<rmt::Sycl>> d_ctx):
+ ctx_{&r_ctx},
+ device_{std::move(d_ctx)}
{}
+
+ our<device<rmt::Sycl>> copy_device_reference() const {
+ return device_;
+ }
+
+ device<rmt::Sycl>& get_device(){
+ return *device_;
+ }
};
template<>
@@ -302,24 +312,26 @@ public:
* we just create a default.
*/
conveyor<own<remote_address<rmt::Sycl>>> resolve_address(){
- return heap<remote_address<rmt::Sycl>>(*this);
+ auto dev = std::make_shared<device<rmt::Sycl>>();
+ return heap<remote_address<rmt::Sycl>>(*this, std::move(dev));
}
/**
- * Connect to a device
+ * Parse address, but don't resolve it.
*/
- device<rmt::Sycl> connect_device(const remote_address<rmt::Sycl>&){
- return {};
+ error_or<own<remote_address<rmt::Sycl>>> parse_address(){
+ auto dev = std::make_shared<device<rmt::Sycl>>();
+ return heap<remote_address<rmt::Sycl>>(*this, std::move(dev));
}
/**
* Spin up a rpc server
*/
- template<typename Iface, typename Encoding>
- rpc_server<Iface, Encoding, rmt::Sycl> listen(device<rmt::Sycl>& dev, typename rpc_server<Iface, Encoding, rmt::Sycl>::InterfaceT iface){
- using RpcServerT = rpc_server<Iface, Encoding, rmt::Sycl>;
- using InterfaceT = typename RpcServerT::InterfaceT;
- return {dev, std::move(iface)};
+ template<typename Iface, typename Encoding, typename Storage>
+ rpc_server<Iface, Encoding, Storage, rmt::Sycl> listen(remote_address<rmt::Sycl>& dev, typename rpc_server<Iface, Encoding, Storage, rmt::Sycl>::InterfaceT iface){
+ //using RpcServerT = rpc_server<Iface, Encoding, rmt::Sycl>;
+ //using InterfaceT = typename RpcServerT::InterfaceT;
+ return {dev.copy_device_reference(), std::move(iface)};
}
};
diff --git a/modules/remote-sycl/c++/transfer.hpp b/modules/remote-sycl/c++/transfer.hpp
index 8987de9..f535751 100644
--- a/modules/remote-sycl/c++/transfer.hpp
+++ b/modules/remote-sycl/c++/transfer.hpp
@@ -26,7 +26,7 @@ private:
/**
* Device context class
*/
- device<rmt::Sycl>* device_;
+ our<device<rmt::Sycl>> device_;
/**
* Store for the data the server manages.
@@ -36,8 +36,8 @@ public:
/**
* Main constructor
*/
- data_server(device<rmt::Sycl>& device__):
- device_{&device__}
+ data_server(our<device<rmt::Sycl>> device__):
+ device_{std::move(device__)}
{}
/**