summaryrefslogtreecommitdiff
path: root/modules/remote-sycl/c++/remote.hpp
diff options
context:
space:
mode:
authorClaudius "keldu" Holeksa <mail@keldu.de>2024-06-26 09:39:34 +0200
committerClaudius "keldu" Holeksa <mail@keldu.de>2024-06-26 09:39:34 +0200
commit729307460e77f62a532ee9841dcaed9c47f46419 (patch)
tree0b52ddbfa47d9d148907de90e7a2987d72ed7d73 /modules/remote-sycl/c++/remote.hpp
parent51b50882d2906b83c5275c732a56ff333ae6696f (diff)
Added better structure for the data server
Diffstat (limited to 'modules/remote-sycl/c++/remote.hpp')
-rw-r--r--modules/remote-sycl/c++/remote.hpp138
1 files changed, 41 insertions, 97 deletions
diff --git a/modules/remote-sycl/c++/remote.hpp b/modules/remote-sycl/c++/remote.hpp
index 2d0eaea..1ae3103 100644
--- a/modules/remote-sycl/c++/remote.hpp
+++ b/modules/remote-sycl/c++/remote.hpp
@@ -2,28 +2,10 @@
#include "common.hpp"
#include "data.hpp"
+#include "device.hpp"
+#include "transfer.hpp"
namespace saw {
-
-template<typename Schema>
-class data<Schema, encode::Native, rmt::Sycl> {
-private:
- cl::sycl::buffer<data<Schema, encode::Native, storage::Default>> data_;
-public:
- data(data<Schema, encode::Native, storage::Default>& data__):
- data_{&data__, 1u}
- {}
-
- auto& get_handle() {
- return data_;
- }
-
- template<cl::sycl::access::mode AccessMode>
- auto access(cl::sycl::handler& h){
- return data_.template get_access<AccessMode>(h);
- }
-};
-
/**
* Remote data class for the Sycl backend.
*/
@@ -78,35 +60,22 @@ public:
/**
* Meant to be a helper object which holds the allocated data on the sycl side
*/
-template<typename Schema, typename Encoding, typename Backend>
-class device_data;
+//template<typename Schema, typename Encoding, typename Backend>
+//class device_data;
/**
* This class helps in regards to the ownership on the server side
- */
template<typename Schema, typename Encoding>
class device_data<Schema, Encoding, rmt::Sycl> {
private:
- /**
- * The actual data
- */
data<Schema,Encoding,storage::Default>* device_data_;
- /**
- * The sycl queue object
- */
cl::sycl::queue* queue_;
public:
- /**
- * Main constructor
- */
device_data(data<Schema,Encoding,storage::Default>& device_data__, cl::sycl::queue& queue__):
device_data_{&device_data__},
queue_{&queue__}
{}
- /**
- * Destructor specifically designed to deallocate on the device.
- */
~device_data(){
if(device_data_){
cl::sycl::free(device_data_,queue_);
@@ -117,64 +86,33 @@ public:
SAW_FORBID_COPY(device_data);
SAW_FORBID_MOVE(device_data);
};
-
-namespace impl {
-template<typename Schema, typename Encoding, typename Backend>
-struct device_id_map {
- std::vector<device_data<Schema, Encoding, Backend>> data;
-};
-
-template<typename Iface, typename Encoding, typename Storage>
-struct rpc_id_map_helper {
- static_assert(always_false<Iface, Encoding,Storage>, "Only supports Interface schema types.");
-};
-
-template<typename... Members, typename Encoding, typename Storage>
-struct rpc_id_map_helper<schema::Interface<Members...>, Encoding, Storage> {
- std::tuple<id_map<typename Members::ValueType::ResponseT, Encoding, Storage>...> maps;
-};
-}
+ */
}
// Maybe a helper impl tmpl file?
namespace saw {
-/**
- * Represents a remote Sycl device.
- */
-template<>
-class device<rmt::Sycl> final {
-private:
- cl::sycl::queue cmd_queue_;
-public:
- device() = default;
- SAW_FORBID_COPY(device);
- SAW_FORBID_MOVE(device);
+namespace impl {
+template<typename Func>
+struct rpc_func_type_helper;
- /**
- * Copy data to device
- */
- template<typename Schema, typename Encoding, typename Storage>
- error_or<data<Schema, Encoding, rmt::Sycl>> copy_to_device(const data<Schema, Encoding, Storage>& host_data){
- return data<Schema, Encoding, rmt::Sycl>::copy_to_device(host_data, *this);
- }
-
- /**
- * Copy data to host
- */
- template<typename Schema, typename Encoding, typename Storage>
- error_or<data<Schema, Encoding, Storage>> copy_to_host(const data<Schema, Encoding, rmt::Sycl>& device_data){
- return device_data.copy_to_host();
- }
-
- /**
- * Get a reference to the handle
- */
- cl::sycl::queue& get_handle(){
- return cmd_queue_;
- }
+template<typename Response, typename Request>
+struct rpc_func_type_helper<schema::Function<Request, Response>>{
+ using type = tmpl_group<Response, Request>;
+};
+
+template<typename Iface>
+struct rpc_iface_type_helper {
+ using type = tmpl_group<>;
+};
+
+template<typename Func, string_literal K, typename... Functions, string_literal... Keys>
+struct rpc_iface_type_helper<schema::Interface<schema::Member<Func,K>,schema::Member<Functions,Keys>...>> {
+ using inner_type = typename rpc_func_type_helper<Func>::type;
+ using type = typename tmpl_concat<inner_type, typename rpc_iface_type_helper<schema::Interface<schema::Member<Functions,Keys>...>>::type>::type;
};
+}
/**
* Rpc Client class for the Sycl backend.
@@ -189,11 +127,17 @@ private:
rpc_server<Iface, Encoding, rmt::Sycl>* srv_;
/**
+ * TransferClient created from the internal RPC data server
+ */
+ data_client<typename impl::rpc_iface_type_helper<Iface>::type, Encoding, rmt::Sycl> data_client_;
+
+ /**
* Generated some sort of id for the request.
*/
public:
rpc_client(rpc_server<Iface, Encoding, rmt::Sycl>& srv):
- srv_{&srv}
+ srv_{&srv},
+ data_client_{srv_->data_server}
{}
/**
@@ -219,6 +163,7 @@ class rpc_server<Iface, Encoding, rmt::Sycl> {
public:
using InterfaceCtxT = cl::sycl::queue*;
using InterfaceT = interface<Iface, Encoding, storage::Default, InterfaceCtxT>;
+
private:
/**
* Device instance enabling the use of the remote device.
@@ -226,14 +171,15 @@ private:
device<rmt::Sycl>* device_;
/**
- * The interface including the relevant context class.
+ * Data server storing the relevant data
*/
- interface<Iface, Encoding, storage::Default, InterfaceCtxT> cl_interface_;
+ data_server<typename impl::rpc_iface_type_helper<Iface>::type, Encoding, rmt::Sycl> data_server_;
/**
- * Basic storage for response data.
+ * The interface including the relevant context class.
*/
- impl::rpc_id_map_helper<Iface, Encoding, rmt::Sycl> storage_;
+ interface<Iface, Encoding, storage::Default, InterfaceCtxT> cl_interface_;
+
public:
/**
@@ -241,8 +187,8 @@ public:
*/
rpc_server(device<rmt::Sycl>& dev__, InterfaceT cl_iface):
device_{&dev__},
- cl_interface_{std::move(cl_iface)},
- storage_{}
+ data_server_{},
+ cl_interface_{std::move(cl_iface)}
{}
/**
@@ -276,7 +222,7 @@ public:
/**
* Object needed if and only if the provided data type is not an id
*/
- own<device_data<typename FuncT::RequestT, Encoding, rmt::Sycl>> dev_tmp_inp = nullptr;
+ own<data<typename FuncT::RequestT, Encoding, rmt::Sycl>> dev_tmp_inp = nullptr;
/**
* First check if it's data or an id.
* If it's an id, check if it's registered within the storage and retrieve it.
@@ -284,8 +230,7 @@ public:
auto eoinp = [&,this]() -> error_or<data<typename FuncT::RequestT, Encoding, rmt::Sycl>* > {
if(input.is_id()){
// storage_.maps
- auto& inner_map = std::get<id_map<typename FuncT::RequestT, Encoding, rmt::Sycl>> (storage_.maps);
- auto eov = inner_map.find(input.get_id());
+ auto eov = data_server_.template find<typename FuncT::RequestT>(input.get_id());
if(eov.is_error()){
return std::move(eov.get_error());
}
@@ -319,8 +264,7 @@ public:
/**
* Store returned data in rpc storage
*/
- auto& inner_map = std::get<id_map<typename schema_member_type<Name, Iface>::type::RequestT, Encoding,rmt::Sycl>> (storage_.maps);
- auto eoid = inner_map.insert_as(std::move(val), rpc_id);
+ auto eoid = data_server_.template insert<typename schema_member_type<Name, Iface>::type::RequestT>(std::move(val), rpc_id);
if(eoid.is_error()){
return std::move(eoid.get_error());
}