summaryrefslogtreecommitdiff
path: root/modules/remote-sycl/c++/remote.hpp
diff options
context:
space:
mode:
authorClaudius "keldu" Holeksa <mail@keldu.de>2024-06-18 16:35:39 +0200
committerClaudius "keldu" Holeksa <mail@keldu.de>2024-06-18 16:35:39 +0200
commit80b706332a48f54ae289093ee11b17f20ab2dc2e (patch)
treedc5d54e9aa583e9d2c7519958b3735a73d60f85b /modules/remote-sycl/c++/remote.hpp
parentdee0184f2bedfb3919309c9f372bd0cdec520e9d (diff)
Changed some setup designs
Diffstat (limited to 'modules/remote-sycl/c++/remote.hpp')
-rw-r--r--modules/remote-sycl/c++/remote.hpp93
1 files changed, 50 insertions, 43 deletions
diff --git a/modules/remote-sycl/c++/remote.hpp b/modules/remote-sycl/c++/remote.hpp
index 4a383d1..1873669 100644
--- a/modules/remote-sycl/c++/remote.hpp
+++ b/modules/remote-sycl/c++/remote.hpp
@@ -1,19 +1,16 @@
#pragma once
-#include <forstio/io_codec/rpc.hpp>
-#include <forstio/codec/data.hpp>
-#include <forstio/codec/id_map.hpp>
-
-#include <AdaptiveCpp/CL/sycl.hpp>
+#include "common.hpp"
+#include "data.hpp"
namespace saw {
-namespace rmt {
-struct Sycl {};
-}
template<>
class remote<rmt::Sycl>;
+template<typename T>
+class device;
+
/**
* Remote data class for the Sycl backend.
*/
@@ -65,32 +62,24 @@ public:
};
/**
- *
-template<typename T, uint64_t N>
-class data<schema::Primitive<T,N>, encode::Native, rmt::Sycl> {
-public:
- using Schema = schema::Primitive<T,N>;
- using NativeType = typename native_data_type<Schema>::type;
-private:
- NativeType val_;
-public:
- data(NativeType val__):
- val_{val__}
- {}
-
- NativeType get(){
- return val_;
- }
-};
+ * Sycl data class for handling the array Schema.
*/
-
template<typename T, uint64_t D>
class data<schema::Array<T,D>, encode::Native, rmt::Sycl> {
public:
using Schema = schema::Array<T,D>;
private:
+ /**
+ * Absolute size of the stored elementes.
+ */
uint64_t total_length_;
+ /**
+ * The data itself.
+ */
data<T,encode::Native,storage::Default>* device_data_;
+ /**
+ * Referenced sycl queue
+ */
cl::sycl::queue* queue_;
static_assert(is_primitive<T>::value, "Only supports primitives for now");
@@ -181,6 +170,10 @@ public:
return data_;
}
+ template<typename Storage>
+ static error_or<data<Schema, encode::Native, rmt::Sycl>> copy_to_device(const data<Schema, encode::Native, Storage>& host_data, device<rmt::Sycl>& dev);
+
+
data<T, encode::Native, storage::Default>& at(uint64_t i){
return device_data_[i];
}
@@ -202,34 +195,41 @@ struct rpc_id_map_helper<schema::Interface<Members...>, Encoding, Storage> {
};
}
-template<typename T>
-class device;
-
/**
* Represents a remote Sycl device.
*
*/
template<>
-class device<rmt::Sycl> {
+class device<rmt::Sycl> final {
private:
cl::sycl::queue cmd_queue_;
public:
+ device() = default;
+
+ SAW_FORBID_COPY(device);
+ SAW_FORBID_MOVE(device);
+
/**
* Copy data to device
*/
- template<typename Schema, typename Encoding, typename Storage, typename EncodingHost, typename StorageHost>
- error_or<data<Schema, Encoding, Storage>> copy_to_device(const data<Schema, EncodingHost, StorageHost>>& host_data){
- (void) host_data;
- return make_error<err::not_implemented>();
+ template<typename Schema, typename Encoding, typename Storage>
+ error_or<data<Schema, Encoding, rmt::Sycl>> copy_to_device(const data<Schema, Encoding, Storage>& host_data){
+ return data<Schema, Encoding, rmt::Sycl>::copy_to_device(host_data);
}
/**
* Copy data to host
*/
- template<typename Schema, typename Encoding, typename Storage, typename EncodingDevice, typename StorageDevice>
- error_or<data<Schema, Encoding, Storage>> copy_to_host(const data<Schema, EncodingDevice, StorageDevice>>& device_data){
- (void) device_data;
- return make_error<err::not_implemented>();
+ template<typename Schema, typename Encoding, typename Storage>
+ error_or<data<Schema, Encoding, Storage>> copy_to_host(const data<Schema, Encoding, rmt::Sycl>& device_data){
+ return device_data.copy_to_host();
+ }
+
+ /**
+ * Get a reference to the handle
+ */
+ cl::sycl::queue& get_handle(){
+ return cmd_queue_;
}
};
@@ -317,7 +317,7 @@ public:
template<typename IdT, typename Storage>
remote_data<IdT, Encoding, Storage, rmt::Sycl> request_data(id<IdT> dat_id){
- return {dat_id, std::get<id_map<IdT,Encoding,rmt::Sycl>>(storage_.maps), device_.cmd_queue_};
+ return {dat_id, std::get<id_map<IdT,Encoding,rmt::Sycl>>(storage_.maps), device_->get_handle()};
}
/**
@@ -350,8 +350,8 @@ public:
return eov.get_value();
} else {
auto& client_data = input.get_data();
- dev_tmp_inp = heap<data<typename FuncT::RequestT, Encoding, rmt::Sycl>>(client_data, cmd_queue_);
- cmd_queue_.wait();
+ dev_tmp_inp = heap<data<typename FuncT::RequestT, Encoding, rmt::Sycl>>(client_data, device_->get_handle());
+ device_->get_handle().wait();
return dev_tmp_inp.get();
}
}();
@@ -360,7 +360,7 @@ public:
}
auto& inp = *(eoinp.get_value());
- auto eod = cl_interface_.template call<Name>(inp, &cmd_queue_);
+ auto eod = cl_interface_.template call<Name>(inp, &(device_->get_handle()));
if(eod.is_error()){
return std::move(eod.get_error());
@@ -424,11 +424,18 @@ public:
* Spin up a rpc server
*/
template<typename Iface, typename Encoding>
- rpc_server<Iface, Encoding, rmt::Sycl> listen(const device<rmt::Sycl>& dev, typename rpc_server<Iface, Encoding, rmt::Sycl>::InterfaceT iface){
+ rpc_server<Iface, Encoding, rmt::Sycl> listen(device<rmt::Sycl>& dev, typename rpc_server<Iface, Encoding, rmt::Sycl>::InterfaceT iface){
using RpcServerT = rpc_server<Iface, Encoding, rmt::Sycl>;
using InterfaceT = typename RpcServerT::InterfaceT;
return {dev, std::move(iface)};
}
};
+template<typename T, uint64_t D>
+template<typename Storage>
+error_or<data<schema::Array<T,D>, encode::Native, rmt::Sycl>> data<schema::Array<T,D>, encode::Native, rmt::Sycl>::copy_to_device(const data<schema::Array<T,D>, encode::Native, Storage>& host_data, device<rmt::Sycl>& dev){
+ data<schema::Array<T,D>, encode::Native, rmt::Sycl> device_data{host_data.size(), dev.get_handle()};
+ return make_error<err::not_implemented>();
+}
+
}