summaryrefslogtreecommitdiff
path: root/modules/remote-sycl
diff options
context:
space:
mode:
authorClaudius "keldu" Holeksa <mail@keldu.de>2024-06-26 12:30:30 +0200
committerClaudius "keldu" Holeksa <mail@keldu.de>2024-06-26 12:30:30 +0200
commiteda37df9c399b23dc5bdb668730101a87f4770ce (patch)
tree1c8272cf2e724617f144aed8a9cd185408f02ef3 /modules/remote-sycl
parent729307460e77f62a532ee9841dcaed9c47f46419 (diff)
Attempting to fix async errors
Diffstat (limited to 'modules/remote-sycl')
-rw-r--r--modules/remote-sycl/c++/data.hpp2
-rw-r--r--modules/remote-sycl/c++/device.hpp4
-rw-r--r--modules/remote-sycl/c++/remote.hpp11
-rw-r--r--modules/remote-sycl/c++/transfer.hpp15
-rw-r--r--modules/remote-sycl/tests/data.cpp79
5 files changed, 99 insertions, 12 deletions
diff --git a/modules/remote-sycl/c++/data.hpp b/modules/remote-sycl/c++/data.hpp
index 3ad1d9c..42bc3b1 100644
--- a/modules/remote-sycl/c++/data.hpp
+++ b/modules/remote-sycl/c++/data.hpp
@@ -13,7 +13,7 @@ class data<Schema, encode::Native, rmt::Sycl> {
private:
cl::sycl::buffer<data<Schema, encode::Native, storage::Default>> data_;
public:
- data(data<Schema, encode::Native, storage::Default>& data__):
+ data(const data<Schema, encode::Native, storage::Default>& data__):
data_{&data__, 1u}
{}
diff --git a/modules/remote-sycl/c++/device.hpp b/modules/remote-sycl/c++/device.hpp
index 6d133ae..46644b4 100644
--- a/modules/remote-sycl/c++/device.hpp
+++ b/modules/remote-sycl/c++/device.hpp
@@ -21,7 +21,7 @@ public:
*/
template<typename Schema, typename Encoding, typename Storage>
error_or<data<Schema, Encoding, rmt::Sycl>> copy_to_device(const data<Schema, Encoding, Storage>& host_data){
- return data<Schema, Encoding, rmt::Sycl>::copy_to_device(host_data, *this);
+ return data<Schema, Encoding, rmt::Sycl>{host_data};
}
/**
@@ -29,7 +29,7 @@ public:
*/
template<typename Schema, typename Encoding, typename Storage>
error_or<data<Schema, Encoding, Storage>> copy_to_host(const data<Schema, Encoding, rmt::Sycl>& dev_data){
- return dev_data.copy_to_host();
+ return {};
}
/**
diff --git a/modules/remote-sycl/c++/remote.hpp b/modules/remote-sycl/c++/remote.hpp
index 1ae3103..7e77ec9 100644
--- a/modules/remote-sycl/c++/remote.hpp
+++ b/modules/remote-sycl/c++/remote.hpp
@@ -170,10 +170,11 @@ private:
*/
device<rmt::Sycl>* device_;
+ using DataServerT = data_server<typename impl::rpc_iface_type_helper<Iface>::type, Encoding, rmt::Sycl>;
/**
* Data server storing the relevant data
*/
- data_server<typename impl::rpc_iface_type_helper<Iface>::type, Encoding, rmt::Sycl> data_server_;
+ DataServerT* data_server_;
/**
* The interface including the relevant context class.
@@ -185,9 +186,9 @@ public:
/**
* Main constructor
*/
- rpc_server(device<rmt::Sycl>& dev__, InterfaceT cl_iface):
+ rpc_server(device<rmt::Sycl>& dev__, DataServerT& data_server__, InterfaceT cl_iface):
device_{&dev__},
- data_server_{},
+ data_server_{&data_server__},
cl_interface_{std::move(cl_iface)}
{}
@@ -230,7 +231,7 @@ public:
auto eoinp = [&,this]() -> error_or<data<typename FuncT::RequestT, Encoding, rmt::Sycl>* > {
if(input.is_id()){
// storage_.maps
- auto eov = data_server_.template find<typename FuncT::RequestT>(input.get_id());
+ auto eov = data_server_->template find<typename FuncT::RequestT>(input.get_id());
if(eov.is_error()){
return std::move(eov.get_error());
}
@@ -264,7 +265,7 @@ public:
/**
* Store returned data in rpc storage
*/
- auto eoid = data_server_.template insert<typename schema_member_type<Name, Iface>::type::RequestT>(std::move(val), rpc_id);
+ auto eoid = data_server_->template insert<typename schema_member_type<Name, Iface>::type::RequestT>(std::move(val), rpc_id);
if(eoid.is_error()){
return std::move(eoid.get_error());
}
diff --git a/modules/remote-sycl/c++/transfer.hpp b/modules/remote-sycl/c++/transfer.hpp
index 65a9b9e..8987de9 100644
--- a/modules/remote-sycl/c++/transfer.hpp
+++ b/modules/remote-sycl/c++/transfer.hpp
@@ -45,7 +45,7 @@ public:
*/
template<typename Sch>
error_or<void> send(const data<Sch, Encoding, storage::Default>& dat, id<Sch> store_id){
- auto& vals = std::get<Sch>(values_);
+ auto& vals = std::get<std::unordered_map<uint64_t, data<Sch,Encoding,rmt::Sycl>>>(values_);
auto eoval = device_->template copy_to_device<Sch, Encoding, storage::Default>(dat);
if(eoval.is_error()){
auto& err = eoval.get_error();
@@ -68,14 +68,14 @@ public:
*/
template<typename Sch>
error_or<data<Sch, Encoding, storage::Default>> receive(id<Sch> store_id){
- auto& vals = std::get<Sch>(values_);
+ auto& vals = std::get<std::unordered_map<uint64_t, data<Sch,Encoding,rmt::Sycl>>>(values_);
auto find_res = vals.find(store_id.get_value());
if(find_res == vals.end()){
return make_error<err::not_found>();
}
auto& dat = find_res->second;
- auto eoval = device_->copy_to_host(dat);
+ auto eoval = device_->template copy_to_host<Sch, Encoding, storage::Default>(dat);
return eoval;
}
@@ -153,7 +153,14 @@ public:
*/
template<typename Sch>
conveyor<data<Sch, Encoding, storage::Default>> receive(id<Sch> dat_id){
- return srv_->receive(dat_id);
+ auto eov = srv_->receive(dat_id);
+ if(eov.is_error()){
+ auto& err = eov.get_error();
+ return std::move(err);
+ }
+
+ auto& val = eov.get_value();
+ return std::move(val);
}
/**
diff --git a/modules/remote-sycl/tests/data.cpp b/modules/remote-sycl/tests/data.cpp
new file mode 100644
index 0000000..a2bb506
--- /dev/null
+++ b/modules/remote-sycl/tests/data.cpp
@@ -0,0 +1,79 @@
+#include <forstio/test/suite.hpp>
+
+#include "../c++/remote.hpp"
+
+namespace {
+namespace schema {
+using namespace saw::schema;
+
+using TestStruct = Struct<
+ Member<UInt64, "foo">,
+ Member<Int32, "bra">,
+ Member<Array<Float64>, "baz">
+>;
+}
+
+SAW_TEST("SYCL Data Management"){
+ using namespace saw;
+
+ data<schema::TestStruct> host_data;
+ host_data.template get<"foo">() = 321u;
+ host_data.template get<"bra">() = 123;
+ auto& baz = host_data.template get<"baz">();
+ baz = {1024};
+ for(uint64_t i = 0; i < baz.size(); ++i){
+ baz.at(i) = static_cast<double>(i*3);
+ }
+
+ saw::event_loop loop;
+ saw::wait_scope wait{loop};
+
+ remote<rmt::Sycl> rmt;
+
+ own<remote_address<rmt::Sycl>> rmt_addr{};
+
+ rmt.resolve_address().then([&](auto addr){
+ rmt_addr = std::move(addr);
+ }).detach();
+
+ wait.poll();
+ SAW_EXPECT(rmt_addr, "Remote address hasn't been filled");
+
+ auto device = rmt.connect_device(*rmt_addr);
+
+ auto data_srv = data_server<tmpl_group<schema::TestStruct>, encode::Native, rmt::Sycl>{device};
+
+ auto data_cl = data_client<tmpl_group<schema::TestStruct>, encode::Native, rmt::Sycl>{data_srv};
+
+ auto eov = data_cl.send(host_data);
+ SAW_EXPECT(eov.is_value(), "Couldn't send data to SYCL");
+
+ auto& val = eov.get_value();
+
+ bool ran = false;
+ bool error_ran = false;
+
+ auto conv = data_cl.receive(val).then([&](auto dat){
+ auto& foo = dat.template get<"foo">();
+ auto& bra = dat.template get<"bra">();
+ auto& baz = dat.template get<"baz">();
+ SAW_EXPECT(foo == host_data.template get<"foo">(), "Data sent back wasn't equal");
+ SAW_EXPECT(bra == host_data.template get<"bra">(), "Data sent back wasn't equal");
+
+ for(uint64_t i = 0u; i < baz.size(); ++i){
+ SAW_EXPECT(baz.at(i) == host_data.template get<"baz">().at(i), "Data sent back wasn't equal");
+ }
+ ran = true;
+ }, [&](auto err){
+ error_ran = true;
+ return std::move(err);
+ });
+
+ auto eob = conv.take();
+ SAW_EXPECT(eob.is_value(), "conv value doesn't exist");
+ auto& bval = eob.get_value();
+
+ SAW_EXPECT(!error_ran, "conveyor ran, but we got an error");
+ SAW_EXPECT(ran, "conveyor didn't run");
+}
+}