summaryrefslogtreecommitdiff
path: root/modules/remote-sycl
diff options
context:
space:
mode:
authorClaudius "keldu" Holeksa <mail@keldu.de>2024-06-18 16:35:39 +0200
committerClaudius "keldu" Holeksa <mail@keldu.de>2024-06-18 16:35:39 +0200
commit80b706332a48f54ae289093ee11b17f20ab2dc2e (patch)
treedc5d54e9aa583e9d2c7519958b3735a73d60f85b /modules/remote-sycl
parentdee0184f2bedfb3919309c9f372bd0cdec520e9d (diff)
Changed some setup designs
Diffstat (limited to 'modules/remote-sycl')
-rw-r--r--modules/remote-sycl/c++/common.hpp14
-rw-r--r--modules/remote-sycl/c++/data.hpp7
-rw-r--r--modules/remote-sycl/c++/device.hpp5
-rw-r--r--modules/remote-sycl/c++/remote.hpp93
-rw-r--r--modules/remote-sycl/examples/sycl_basic.cpp4
-rw-r--r--modules/remote-sycl/examples/sycl_basic.hpp2
-rw-r--r--modules/remote-sycl/examples/sycl_basic_kernel.cpp4
7 files changed, 82 insertions, 47 deletions
diff --git a/modules/remote-sycl/c++/common.hpp b/modules/remote-sycl/c++/common.hpp
new file mode 100644
index 0000000..c1f9ddf
--- /dev/null
+++ b/modules/remote-sycl/c++/common.hpp
@@ -0,0 +1,14 @@
+#pragma once
+
+#include <forstio/io_codec/rpc.hpp>
+#include <forstio/codec/data.hpp>
+#include <forstio/codec/id_map.hpp>
+
+#include <AdaptiveCpp/CL/sycl.hpp>
+
+namespace saw {
+namespace rmt {
+struct Sycl {};
+}
+
+}
diff --git a/modules/remote-sycl/c++/data.hpp b/modules/remote-sycl/c++/data.hpp
new file mode 100644
index 0000000..7ecf1ae
--- /dev/null
+++ b/modules/remote-sycl/c++/data.hpp
@@ -0,0 +1,7 @@
+#pragma once
+
+#include "common.hpp"
+
+namespace saw {
+
+}
diff --git a/modules/remote-sycl/c++/device.hpp b/modules/remote-sycl/c++/device.hpp
new file mode 100644
index 0000000..30eed2f
--- /dev/null
+++ b/modules/remote-sycl/c++/device.hpp
@@ -0,0 +1,5 @@
+#pragma once
+
+namespace saw {
+
+}
diff --git a/modules/remote-sycl/c++/remote.hpp b/modules/remote-sycl/c++/remote.hpp
index 4a383d1..1873669 100644
--- a/modules/remote-sycl/c++/remote.hpp
+++ b/modules/remote-sycl/c++/remote.hpp
@@ -1,19 +1,16 @@
#pragma once
-#include <forstio/io_codec/rpc.hpp>
-#include <forstio/codec/data.hpp>
-#include <forstio/codec/id_map.hpp>
-
-#include <AdaptiveCpp/CL/sycl.hpp>
+#include "common.hpp"
+#include "data.hpp"
namespace saw {
-namespace rmt {
-struct Sycl {};
-}
template<>
class remote<rmt::Sycl>;
+template<typename T>
+class device;
+
/**
* Remote data class for the Sycl backend.
*/
@@ -65,32 +62,24 @@ public:
};
/**
- *
-template<typename T, uint64_t N>
-class data<schema::Primitive<T,N>, encode::Native, rmt::Sycl> {
-public:
- using Schema = schema::Primitive<T,N>;
- using NativeType = typename native_data_type<Schema>::type;
-private:
- NativeType val_;
-public:
- data(NativeType val__):
- val_{val__}
- {}
-
- NativeType get(){
- return val_;
- }
-};
+ * Sycl data class for handling the array Schema.
*/
-
template<typename T, uint64_t D>
class data<schema::Array<T,D>, encode::Native, rmt::Sycl> {
public:
using Schema = schema::Array<T,D>;
private:
+ /**
+ * Absolute size of the stored elementes.
+ */
uint64_t total_length_;
+ /**
+ * The data itself.
+ */
data<T,encode::Native,storage::Default>* device_data_;
+ /**
+ * Referenced sycl queue
+ */
cl::sycl::queue* queue_;
static_assert(is_primitive<T>::value, "Only supports primitives for now");
@@ -181,6 +170,10 @@ public:
return data_;
}
+ template<typename Storage>
+ static error_or<data<Schema, encode::Native, rmt::Sycl>> copy_to_device(const data<Schema, encode::Native, Storage>& host_data, device<rmt::Sycl>& dev);
+
+
data<T, encode::Native, storage::Default>& at(uint64_t i){
return device_data_[i];
}
@@ -202,34 +195,41 @@ struct rpc_id_map_helper<schema::Interface<Members...>, Encoding, Storage> {
};
}
-template<typename T>
-class device;
-
/**
* Represents a remote Sycl device.
*
*/
template<>
-class device<rmt::Sycl> {
+class device<rmt::Sycl> final {
private:
cl::sycl::queue cmd_queue_;
public:
+ device() = default;
+
+ SAW_FORBID_COPY(device);
+ SAW_FORBID_MOVE(device);
+
/**
* Copy data to device
*/
- template<typename Schema, typename Encoding, typename Storage, typename EncodingHost, typename StorageHost>
- error_or<data<Schema, Encoding, Storage>> copy_to_device(const data<Schema, EncodingHost, StorageHost>>& host_data){
- (void) host_data;
- return make_error<err::not_implemented>();
+ template<typename Schema, typename Encoding, typename Storage>
+ error_or<data<Schema, Encoding, rmt::Sycl>> copy_to_device(const data<Schema, Encoding, Storage>& host_data){
+ return data<Schema, Encoding, rmt::Sycl>::copy_to_device(host_data);
}
/**
* Copy data to host
*/
- template<typename Schema, typename Encoding, typename Storage, typename EncodingDevice, typename StorageDevice>
- error_or<data<Schema, Encoding, Storage>> copy_to_host(const data<Schema, EncodingDevice, StorageDevice>>& device_data){
- (void) device_data;
- return make_error<err::not_implemented>();
+ template<typename Schema, typename Encoding, typename Storage>
+ error_or<data<Schema, Encoding, Storage>> copy_to_host(const data<Schema, Encoding, rmt::Sycl>& device_data){
+ return device_data.copy_to_host();
+ }
+
+ /**
+ * Get a reference to the handle
+ */
+ cl::sycl::queue& get_handle(){
+ return cmd_queue_;
}
};
@@ -317,7 +317,7 @@ public:
template<typename IdT, typename Storage>
remote_data<IdT, Encoding, Storage, rmt::Sycl> request_data(id<IdT> dat_id){
- return {dat_id, std::get<id_map<IdT,Encoding,rmt::Sycl>>(storage_.maps), device_.cmd_queue_};
+ return {dat_id, std::get<id_map<IdT,Encoding,rmt::Sycl>>(storage_.maps), device_->get_handle()};
}
/**
@@ -350,8 +350,8 @@ public:
return eov.get_value();
} else {
auto& client_data = input.get_data();
- dev_tmp_inp = heap<data<typename FuncT::RequestT, Encoding, rmt::Sycl>>(client_data, cmd_queue_);
- cmd_queue_.wait();
+ dev_tmp_inp = heap<data<typename FuncT::RequestT, Encoding, rmt::Sycl>>(client_data, device_->get_handle());
+ device_->get_handle().wait();
return dev_tmp_inp.get();
}
}();
@@ -360,7 +360,7 @@ public:
}
auto& inp = *(eoinp.get_value());
- auto eod = cl_interface_.template call<Name>(inp, &cmd_queue_);
+ auto eod = cl_interface_.template call<Name>(inp, &(device_->get_handle()));
if(eod.is_error()){
return std::move(eod.get_error());
@@ -424,11 +424,18 @@ public:
* Spin up a rpc server
*/
template<typename Iface, typename Encoding>
- rpc_server<Iface, Encoding, rmt::Sycl> listen(const device<rmt::Sycl>& dev, typename rpc_server<Iface, Encoding, rmt::Sycl>::InterfaceT iface){
+ rpc_server<Iface, Encoding, rmt::Sycl> listen(device<rmt::Sycl>& dev, typename rpc_server<Iface, Encoding, rmt::Sycl>::InterfaceT iface){
using RpcServerT = rpc_server<Iface, Encoding, rmt::Sycl>;
using InterfaceT = typename RpcServerT::InterfaceT;
return {dev, std::move(iface)};
}
};
+template<typename T, uint64_t D>
+template<typename Storage>
+error_or<data<schema::Array<T,D>, encode::Native, rmt::Sycl>> data<schema::Array<T,D>, encode::Native, rmt::Sycl>::copy_to_device(const data<schema::Array<T,D>, encode::Native, Storage>& host_data, device<rmt::Sycl>& dev){
+ data<schema::Array<T,D>, encode::Native, rmt::Sycl> device_data{host_data.size(), dev.get_handle()};
+ return make_error<err::not_implemented>();
+}
+
}
diff --git a/modules/remote-sycl/examples/sycl_basic.cpp b/modules/remote-sycl/examples/sycl_basic.cpp
index 486aca1..bc4d997 100644
--- a/modules/remote-sycl/examples/sycl_basic.cpp
+++ b/modules/remote-sycl/examples/sycl_basic.cpp
@@ -18,7 +18,9 @@ int main(){
return -1;
}
- auto rpc_server = listen_basic_sycl(remote_ctx, *rmt_addr);
+ auto device = remote_ctx.connect_device(*rmt_addr);
+
+ auto rpc_server = listen_basic_sycl(remote_ctx, device, *rmt_addr);
saw::rpc_client<schema::BasicInterface, saw::encode::Native, saw::storage::Default, saw::rmt::Sycl> client{rpc_server};
saw::id<schema::Array<schema::UInt64>> id_zero{0u};
diff --git a/modules/remote-sycl/examples/sycl_basic.hpp b/modules/remote-sycl/examples/sycl_basic.hpp
index 6932184..b250d8c 100644
--- a/modules/remote-sycl/examples/sycl_basic.hpp
+++ b/modules/remote-sycl/examples/sycl_basic.hpp
@@ -10,4 +10,4 @@ using BasicInterface = Interface<
>;
}
-saw::rpc_server<schema::BasicInterface, saw::encode::Native, saw::rmt::Sycl> listen_basic_sycl(saw::remote<saw::rmt::Sycl>& ctx, saw::remote_address<saw::rmt::Sycl>& addr);
+saw::rpc_server<schema::BasicInterface, saw::encode::Native, saw::rmt::Sycl> listen_basic_sycl(saw::remote<saw::rmt::Sycl>& ctx, saw::device<saw::rmt::Sycl>& dev, saw::remote_address<saw::rmt::Sycl>& addr);
diff --git a/modules/remote-sycl/examples/sycl_basic_kernel.cpp b/modules/remote-sycl/examples/sycl_basic_kernel.cpp
index 03f0bac..6481eb9 100644
--- a/modules/remote-sycl/examples/sycl_basic_kernel.cpp
+++ b/modules/remote-sycl/examples/sycl_basic_kernel.cpp
@@ -1,6 +1,6 @@
#include "sycl_basic.hpp"
-saw::rpc_server<schema::BasicInterface, saw::encode::Native, saw::rmt::Sycl> listen_basic_sycl(saw::remote<saw::rmt::Sycl>& ctx, saw::remote_address<saw::rmt::Sycl>& addr){
+saw::rpc_server<schema::BasicInterface, saw::encode::Native, saw::rmt::Sycl> listen_basic_sycl(saw::remote<saw::rmt::Sycl>& ctx, saw::device<saw::rmt::Sycl>& dev, saw::remote_address<saw::rmt::Sycl>& addr){
saw::interface<schema::BasicInterface, saw::encode::Native, saw::rmt::Sycl, cl::sycl::queue*> iface{
[](saw::data<saw::schema::Array<saw::schema::UInt64>, saw::encode::Native, saw::rmt::Sycl> in, cl::sycl::queue* q) -> saw::data<saw::schema::Array<saw::schema::UInt64>, saw::encode::Native, saw::rmt::Sycl> {
@@ -13,7 +13,7 @@ saw::rpc_server<schema::BasicInterface, saw::encode::Native, saw::rmt::Sycl> lis
return in;
}
};
- auto rpc_server = ctx.template listen<schema::BasicInterface, saw::encode::Native>(addr, std::move(iface));
+ auto rpc_server = ctx.template listen<schema::BasicInterface, saw::encode::Native>(dev, std::move(iface));
return rpc_server;
}