summaryrefslogtreecommitdiff
path: root/modules/remote-sycl
diff options
context:
space:
mode:
authorClaudius "keldu" Holeksa <mail@keldu.de>2024-06-26 09:39:34 +0200
committerClaudius "keldu" Holeksa <mail@keldu.de>2024-06-26 09:39:34 +0200
commit729307460e77f62a532ee9841dcaed9c47f46419 (patch)
tree0b52ddbfa47d9d148907de90e7a2987d72ed7d73 /modules/remote-sycl
parent51b50882d2906b83c5275c732a56ff333ae6696f (diff)
Added better structure for the data server
Diffstat (limited to 'modules/remote-sycl')
-rw-r--r--modules/remote-sycl/c++/data.hpp23
-rw-r--r--modules/remote-sycl/c++/device.hpp38
-rw-r--r--modules/remote-sycl/c++/remote.hpp138
-rw-r--r--modules/remote-sycl/c++/transfer.hpp118
4 files changed, 204 insertions, 113 deletions
diff --git a/modules/remote-sycl/c++/data.hpp b/modules/remote-sycl/c++/data.hpp
index 7ecf1ae..3ad1d9c 100644
--- a/modules/remote-sycl/c++/data.hpp
+++ b/modules/remote-sycl/c++/data.hpp
@@ -4,4 +4,27 @@
namespace saw {
+/**
+ * Generic wrapper class which stores data on the sycl side.
+ * Most of the times this will be a root object.
+ */
+template<typename Schema>
+class data<Schema, encode::Native, rmt::Sycl> {
+private:
+ cl::sycl::buffer<data<Schema, encode::Native, storage::Default>> data_;
+public:
+ data(data<Schema, encode::Native, storage::Default>& data__):
+ data_{&data__, 1u}
+ {}
+
+ auto& get_handle() {
+ return data_;
+ }
+
+ template<cl::sycl::access::mode AccessMode>
+ auto access(cl::sycl::handler& h){
+ return data_.template get_access<AccessMode>(h);
+ }
+};
+
}
diff --git a/modules/remote-sycl/c++/device.hpp b/modules/remote-sycl/c++/device.hpp
index 30eed2f..6d133ae 100644
--- a/modules/remote-sycl/c++/device.hpp
+++ b/modules/remote-sycl/c++/device.hpp
@@ -1,5 +1,43 @@
#pragma once
+#include "common.hpp"
+
namespace saw {
+/**
+ * Represents a remote Sycl device.
+ */
+template<>
+class device<rmt::Sycl> final {
+private:
+ cl::sycl::queue cmd_queue_;
+public:
+ device() = default;
+
+ SAW_FORBID_COPY(device);
+ SAW_FORBID_MOVE(device);
+
+ /**
+ * Copy data to device
+ */
+ template<typename Schema, typename Encoding, typename Storage>
+ error_or<data<Schema, Encoding, rmt::Sycl>> copy_to_device(const data<Schema, Encoding, Storage>& host_data){
+ return data<Schema, Encoding, rmt::Sycl>::copy_to_device(host_data, *this);
+ }
+
+ /**
+ * Copy data to host
+ */
+ template<typename Schema, typename Encoding, typename Storage>
+ error_or<data<Schema, Encoding, Storage>> copy_to_host(const data<Schema, Encoding, rmt::Sycl>& dev_data){
+ return dev_data.copy_to_host();
+ }
+
+ /**
+ * Get a reference to the handle
+ */
+ cl::sycl::queue& get_handle(){
+ return cmd_queue_;
+ }
+};
}
diff --git a/modules/remote-sycl/c++/remote.hpp b/modules/remote-sycl/c++/remote.hpp
index 2d0eaea..1ae3103 100644
--- a/modules/remote-sycl/c++/remote.hpp
+++ b/modules/remote-sycl/c++/remote.hpp
@@ -2,28 +2,10 @@
#include "common.hpp"
#include "data.hpp"
+#include "device.hpp"
+#include "transfer.hpp"
namespace saw {
-
-template<typename Schema>
-class data<Schema, encode::Native, rmt::Sycl> {
-private:
- cl::sycl::buffer<data<Schema, encode::Native, storage::Default>> data_;
-public:
- data(data<Schema, encode::Native, storage::Default>& data__):
- data_{&data__, 1u}
- {}
-
- auto& get_handle() {
- return data_;
- }
-
- template<cl::sycl::access::mode AccessMode>
- auto access(cl::sycl::handler& h){
- return data_.template get_access<AccessMode>(h);
- }
-};
-
/**
* Remote data class for the Sycl backend.
*/
@@ -78,35 +60,22 @@ public:
/**
* Meant to be a helper object which holds the allocated data on the sycl side
*/
-template<typename Schema, typename Encoding, typename Backend>
-class device_data;
+//template<typename Schema, typename Encoding, typename Backend>
+//class device_data;
/**
* This class helps in regards to the ownership on the server side
- */
template<typename Schema, typename Encoding>
class device_data<Schema, Encoding, rmt::Sycl> {
private:
- /**
- * The actual data
- */
data<Schema,Encoding,storage::Default>* device_data_;
- /**
- * The sycl queue object
- */
cl::sycl::queue* queue_;
public:
- /**
- * Main constructor
- */
device_data(data<Schema,Encoding,storage::Default>& device_data__, cl::sycl::queue& queue__):
device_data_{&device_data__},
queue_{&queue__}
{}
- /**
- * Destructor specifically designed to deallocate on the device.
- */
~device_data(){
if(device_data_){
cl::sycl::free(device_data_,queue_);
@@ -117,64 +86,33 @@ public:
SAW_FORBID_COPY(device_data);
SAW_FORBID_MOVE(device_data);
};
-
-namespace impl {
-template<typename Schema, typename Encoding, typename Backend>
-struct device_id_map {
- std::vector<device_data<Schema, Encoding, Backend>> data;
-};
-
-template<typename Iface, typename Encoding, typename Storage>
-struct rpc_id_map_helper {
- static_assert(always_false<Iface, Encoding,Storage>, "Only supports Interface schema types.");
-};
-
-template<typename... Members, typename Encoding, typename Storage>
-struct rpc_id_map_helper<schema::Interface<Members...>, Encoding, Storage> {
- std::tuple<id_map<typename Members::ValueType::ResponseT, Encoding, Storage>...> maps;
-};
-}
+ */
}
// Maybe a helper impl tmpl file?
namespace saw {
-/**
- * Represents a remote Sycl device.
- */
-template<>
-class device<rmt::Sycl> final {
-private:
- cl::sycl::queue cmd_queue_;
-public:
- device() = default;
- SAW_FORBID_COPY(device);
- SAW_FORBID_MOVE(device);
+namespace impl {
+template<typename Func>
+struct rpc_func_type_helper;
- /**
- * Copy data to device
- */
- template<typename Schema, typename Encoding, typename Storage>
- error_or<data<Schema, Encoding, rmt::Sycl>> copy_to_device(const data<Schema, Encoding, Storage>& host_data){
- return data<Schema, Encoding, rmt::Sycl>::copy_to_device(host_data, *this);
- }
-
- /**
- * Copy data to host
- */
- template<typename Schema, typename Encoding, typename Storage>
- error_or<data<Schema, Encoding, Storage>> copy_to_host(const data<Schema, Encoding, rmt::Sycl>& device_data){
- return device_data.copy_to_host();
- }
-
- /**
- * Get a reference to the handle
- */
- cl::sycl::queue& get_handle(){
- return cmd_queue_;
- }
+template<typename Response, typename Request>
+struct rpc_func_type_helper<schema::Function<Request, Response>>{
+ using type = tmpl_group<Response, Request>;
+};
+
+template<typename Iface>
+struct rpc_iface_type_helper {
+ using type = tmpl_group<>;
+};
+
+template<typename Func, string_literal K, typename... Functions, string_literal... Keys>
+struct rpc_iface_type_helper<schema::Interface<schema::Member<Func,K>,schema::Member<Functions,Keys>...>> {
+ using inner_type = typename rpc_func_type_helper<Func>::type;
+ using type = typename tmpl_concat<inner_type, typename rpc_iface_type_helper<schema::Interface<schema::Member<Functions,Keys>...>>::type>::type;
};
+}
/**
* Rpc Client class for the Sycl backend.
@@ -189,11 +127,17 @@ private:
rpc_server<Iface, Encoding, rmt::Sycl>* srv_;
/**
+ * TransferClient created from the internal RPC data server
+ */
+ data_client<typename impl::rpc_iface_type_helper<Iface>::type, Encoding, rmt::Sycl> data_client_;
+
+ /**
* Generated some sort of id for the request.
*/
public:
rpc_client(rpc_server<Iface, Encoding, rmt::Sycl>& srv):
- srv_{&srv}
+ srv_{&srv},
+ data_client_{srv_->data_server}
{}
/**
@@ -219,6 +163,7 @@ class rpc_server<Iface, Encoding, rmt::Sycl> {
public:
using InterfaceCtxT = cl::sycl::queue*;
using InterfaceT = interface<Iface, Encoding, storage::Default, InterfaceCtxT>;
+
private:
/**
* Device instance enabling the use of the remote device.
@@ -226,14 +171,15 @@ private:
device<rmt::Sycl>* device_;
/**
- * The interface including the relevant context class.
+ * Data server storing the relevant data
*/
- interface<Iface, Encoding, storage::Default, InterfaceCtxT> cl_interface_;
+ data_server<typename impl::rpc_iface_type_helper<Iface>::type, Encoding, rmt::Sycl> data_server_;
/**
- * Basic storage for response data.
+ * The interface including the relevant context class.
*/
- impl::rpc_id_map_helper<Iface, Encoding, rmt::Sycl> storage_;
+ interface<Iface, Encoding, storage::Default, InterfaceCtxT> cl_interface_;
+
public:
/**
@@ -241,8 +187,8 @@ public:
*/
rpc_server(device<rmt::Sycl>& dev__, InterfaceT cl_iface):
device_{&dev__},
- cl_interface_{std::move(cl_iface)},
- storage_{}
+ data_server_{},
+ cl_interface_{std::move(cl_iface)}
{}
/**
@@ -276,7 +222,7 @@ public:
/**
* Object needed if and only if the provided data type is not an id
*/
- own<device_data<typename FuncT::RequestT, Encoding, rmt::Sycl>> dev_tmp_inp = nullptr;
+ own<data<typename FuncT::RequestT, Encoding, rmt::Sycl>> dev_tmp_inp = nullptr;
/**
* First check if it's data or an id.
* If it's an id, check if it's registered within the storage and retrieve it.
@@ -284,8 +230,7 @@ public:
auto eoinp = [&,this]() -> error_or<data<typename FuncT::RequestT, Encoding, rmt::Sycl>* > {
if(input.is_id()){
// storage_.maps
- auto& inner_map = std::get<id_map<typename FuncT::RequestT, Encoding, rmt::Sycl>> (storage_.maps);
- auto eov = inner_map.find(input.get_id());
+ auto eov = data_server_.template find<typename FuncT::RequestT>(input.get_id());
if(eov.is_error()){
return std::move(eov.get_error());
}
@@ -319,8 +264,7 @@ public:
/**
* Store returned data in rpc storage
*/
- auto& inner_map = std::get<id_map<typename schema_member_type<Name, Iface>::type::RequestT, Encoding,rmt::Sycl>> (storage_.maps);
- auto eoid = inner_map.insert_as(std::move(val), rpc_id);
+ auto eoid = data_server_.template insert<typename schema_member_type<Name, Iface>::type::RequestT>(std::move(val), rpc_id);
if(eoid.is_error()){
return std::move(eoid.get_error());
}
diff --git a/modules/remote-sycl/c++/transfer.hpp b/modules/remote-sycl/c++/transfer.hpp
index 6849caa..65a9b9e 100644
--- a/modules/remote-sycl/c++/transfer.hpp
+++ b/modules/remote-sycl/c++/transfer.hpp
@@ -2,11 +2,26 @@
#include "common.hpp"
#include "data.hpp"
+#include "device.hpp"
+
#include <forstio/error.hpp>
+#include <forstio/reduce_templates.hpp>
namespace saw {
-template<typename Schema, typename Encoding>
-class data_server<Schema, Encoding, rmt::Sycl> {
+namespace impl {
+template<typename Encoding, typename T>
+struct data_server_redux {
+ using type = std::tuple<>;
+};
+
+template<typename Encoding, typename... Schema>
+struct data_server_redux<Encoding, tmpl_group<Schema...>> {
+ using type = std::tuple<std::unordered_map<uint64_t, data<Schema, Encoding, rmt::Sycl>>...>;
+};
+}
+
+template<typename... Schema, typename Encoding>
+class data_server<tmpl_group<Schema...>, Encoding, rmt::Sycl> {
private:
/**
* Device context class
@@ -16,7 +31,7 @@ private:
/**
* Store for the data the server manages.
*/
- std::unordered_map<uint64_t, data<Schema, Encoding, rmt::Sycl>> values_;
+ impl::data_server_redux<Encoding, typename tmpl_reduce<tmpl_group<Schema...>>::type >::type values_;
public:
/**
* Main constructor
@@ -26,29 +41,83 @@ public:
{}
/**
- * Receive data which we will store.
+ * Get data which we will store.
*/
- error_or<void> send(const data<Schema, Encoding, storage::Default>& dat, id<Schema> store_id){
- auto eoval = device_->copy_to_device(dat);
+ template<typename Sch>
+ error_or<void> send(const data<Sch, Encoding, storage::Default>& dat, id<Sch> store_id){
+ auto& vals = std::get<Sch>(values_);
+ auto eoval = device_->template copy_to_device<Sch, Encoding, storage::Default>(dat);
if(eoval.is_error()){
auto& err = eoval.get_error();
return std::move(err);
}
- return make_error<err::not_implemented>();
+ auto& val = eoval.get_value();
+ try {
+ auto insert_res = vals.insert(std::make_pair(store_id.get_value(), std::move(val)));
+ if(!insert_res.second){
+ return make_error<err::already_exists>();
+ }
+ }catch ( std::exception& ){
+ return make_error<err::out_of_memory>();
+ }
+ return void_t{};
+ }
+
+ /**
+ * Requests data from the server
+ */
+ template<typename Sch>
+ error_or<data<Sch, Encoding, storage::Default>> receive(id<Sch> store_id){
+ auto& vals = std::get<Sch>(values_);
+ auto find_res = vals.find(store_id.get_value());
+ if(find_res == vals.end()){
+ return make_error<err::not_found>();
+ }
+ auto& dat = find_res->second;
+
+ auto eoval = device_->copy_to_host(dat);
+ return eoval;
+ }
+
+ /**
+ * Request an erase of the stored data
+ */
+ template<typename Sch>
+ error_or<void> erase(id<Sch> store_id){
+ auto& vals = std::get<Sch>(values_);
+ auto erase_op = vals.erase(store_id.get_value());
+ if(erase_op == 0u){
+ return make_error<err::not_found>();
+ }
+ return void_t{};
}
- error_or<data<Schema, Encoding, storage::Default>> receive(id<Schema> store_id){
- return make_error<err::not_implemented>();
+ /**
+ * Get the stored data on the server side for immediate use.
+ * Insert operations may invalidate the pointer.
+ */
+ template<typename Sch>
+ error_or<data<Sch, Encoding, rmt::Sycl>*> find(id<Sch> store_id){
+ auto& vals = std::get<Sch>(values_);
+ auto find_res = vals.find(store_id.get_value());
+ if(find_res == vals.end()){
+ return make_error<err::not_found>();
+ }
+
+ return &(find_res.second);
}
};
-template<typename Schema, typename Encoding>
-class data_client<Schema, Encoding, rmt::Sycl> {
+/**
+ * Client for transporting data to remote and receiving data back
+ */
+template<typename... Schema, typename Encoding>
+class data_client<tmpl_group<Schema...>, Encoding, rmt::Sycl> {
private:
/**
* Corresponding server for this client
*/
- data_server<Schema, Encoding, rmt::Sycl>* srv_;
+ data_server<tmpl_group<Schema...>, Encoding, rmt::Sycl>* srv_;
/**
* The next id for identifying issues on the remote side.
@@ -58,16 +127,17 @@ public:
/**
* Main constructor
*/
- data_client(data_server<Schema, Encoding, rmt::Sycl>& srv__):
+ data_client(data_server<tmpl_group<Schema...>, Encoding, rmt::Sycl>& srv__):
srv_{&srv__},
next_id_{0u}
{}
/**
- * Send data to.
+ * Send data to the remote.
*/
- error_or<id<Schema>> send(const data<Schema, Encoding, storage::Default>& dat){
- id<Schema> dat_id{next_id_};
+ template<typename Sch>
+ error_or<id<Sch>> send(const data<Sch, Encoding, storage::Default>& dat){
+ id<Sch> dat_id{next_id_};
auto eov = srv_->send(dat, dat_id);
if(eov.is_error()){
auto& err = eov.get_error();
@@ -77,5 +147,21 @@ public:
++next_id_;
return dat_id;
}
+
+ /**
+ * Receive data
+ */
+ template<typename Sch>
+ conveyor<data<Sch, Encoding, storage::Default>> receive(id<Sch> dat_id){
+ return srv_->receive(dat_id);
+ }
+
+ /**
+ * Erase data
+ */
+ template<typename Sch>
+ error_or<void> erase(id<Sch> dat_id){
+ return srv_->erase(dat_id);
+ }
};
}