summaryrefslogtreecommitdiff
path: root/modules/remote-sycl/c++
diff options
context:
space:
mode:
authorClaudius "keldu" Holeksa <mail@keldu.de>2024-06-26 12:30:30 +0200
committerClaudius "keldu" Holeksa <mail@keldu.de>2024-06-26 12:30:30 +0200
commiteda37df9c399b23dc5bdb668730101a87f4770ce (patch)
tree1c8272cf2e724617f144aed8a9cd185408f02ef3 /modules/remote-sycl/c++
parent729307460e77f62a532ee9841dcaed9c47f46419 (diff)
Attempting to fix async errors
Diffstat (limited to 'modules/remote-sycl/c++')
-rw-r--r--modules/remote-sycl/c++/data.hpp2
-rw-r--r--modules/remote-sycl/c++/device.hpp4
-rw-r--r--modules/remote-sycl/c++/remote.hpp11
-rw-r--r--modules/remote-sycl/c++/transfer.hpp15
4 files changed, 20 insertions, 12 deletions
diff --git a/modules/remote-sycl/c++/data.hpp b/modules/remote-sycl/c++/data.hpp
index 3ad1d9c..42bc3b1 100644
--- a/modules/remote-sycl/c++/data.hpp
+++ b/modules/remote-sycl/c++/data.hpp
@@ -13,7 +13,7 @@ class data<Schema, encode::Native, rmt::Sycl> {
private:
cl::sycl::buffer<data<Schema, encode::Native, storage::Default>> data_;
public:
- data(data<Schema, encode::Native, storage::Default>& data__):
+ data(const data<Schema, encode::Native, storage::Default>& data__):
data_{&data__, 1u}
{}
diff --git a/modules/remote-sycl/c++/device.hpp b/modules/remote-sycl/c++/device.hpp
index 6d133ae..46644b4 100644
--- a/modules/remote-sycl/c++/device.hpp
+++ b/modules/remote-sycl/c++/device.hpp
@@ -21,7 +21,7 @@ public:
*/
template<typename Schema, typename Encoding, typename Storage>
error_or<data<Schema, Encoding, rmt::Sycl>> copy_to_device(const data<Schema, Encoding, Storage>& host_data){
- return data<Schema, Encoding, rmt::Sycl>::copy_to_device(host_data, *this);
+ return data<Schema, Encoding, rmt::Sycl>{host_data};
}
/**
@@ -29,7 +29,7 @@ public:
*/
template<typename Schema, typename Encoding, typename Storage>
error_or<data<Schema, Encoding, Storage>> copy_to_host(const data<Schema, Encoding, rmt::Sycl>& dev_data){
- return dev_data.copy_to_host();
+ return {};
}
/**
diff --git a/modules/remote-sycl/c++/remote.hpp b/modules/remote-sycl/c++/remote.hpp
index 1ae3103..7e77ec9 100644
--- a/modules/remote-sycl/c++/remote.hpp
+++ b/modules/remote-sycl/c++/remote.hpp
@@ -170,10 +170,11 @@ private:
*/
device<rmt::Sycl>* device_;
+ using DataServerT = data_server<typename impl::rpc_iface_type_helper<Iface>::type, Encoding, rmt::Sycl>;
/**
* Data server storing the relevant data
*/
- data_server<typename impl::rpc_iface_type_helper<Iface>::type, Encoding, rmt::Sycl> data_server_;
+ DataServerT* data_server_;
/**
* The interface including the relevant context class.
@@ -185,9 +186,9 @@ public:
/**
* Main constructor
*/
- rpc_server(device<rmt::Sycl>& dev__, InterfaceT cl_iface):
+ rpc_server(device<rmt::Sycl>& dev__, DataServerT& data_server__, InterfaceT cl_iface):
device_{&dev__},
- data_server_{},
+ data_server_{&data_server__},
cl_interface_{std::move(cl_iface)}
{}
@@ -230,7 +231,7 @@ public:
auto eoinp = [&,this]() -> error_or<data<typename FuncT::RequestT, Encoding, rmt::Sycl>* > {
if(input.is_id()){
// storage_.maps
- auto eov = data_server_.template find<typename FuncT::RequestT>(input.get_id());
+ auto eov = data_server_->template find<typename FuncT::RequestT>(input.get_id());
if(eov.is_error()){
return std::move(eov.get_error());
}
@@ -264,7 +265,7 @@ public:
/**
* Store returned data in rpc storage
*/
- auto eoid = data_server_.template insert<typename schema_member_type<Name, Iface>::type::RequestT>(std::move(val), rpc_id);
+ auto eoid = data_server_->template insert<typename schema_member_type<Name, Iface>::type::RequestT>(std::move(val), rpc_id);
if(eoid.is_error()){
return std::move(eoid.get_error());
}
diff --git a/modules/remote-sycl/c++/transfer.hpp b/modules/remote-sycl/c++/transfer.hpp
index 65a9b9e..8987de9 100644
--- a/modules/remote-sycl/c++/transfer.hpp
+++ b/modules/remote-sycl/c++/transfer.hpp
@@ -45,7 +45,7 @@ public:
*/
template<typename Sch>
error_or<void> send(const data<Sch, Encoding, storage::Default>& dat, id<Sch> store_id){
- auto& vals = std::get<Sch>(values_);
+ auto& vals = std::get<std::unordered_map<uint64_t, data<Sch,Encoding,rmt::Sycl>>>(values_);
auto eoval = device_->template copy_to_device<Sch, Encoding, storage::Default>(dat);
if(eoval.is_error()){
auto& err = eoval.get_error();
@@ -68,14 +68,14 @@ public:
*/
template<typename Sch>
error_or<data<Sch, Encoding, storage::Default>> receive(id<Sch> store_id){
- auto& vals = std::get<Sch>(values_);
+ auto& vals = std::get<std::unordered_map<uint64_t, data<Sch,Encoding,rmt::Sycl>>>(values_);
auto find_res = vals.find(store_id.get_value());
if(find_res == vals.end()){
return make_error<err::not_found>();
}
auto& dat = find_res->second;
- auto eoval = device_->copy_to_host(dat);
+ auto eoval = device_->template copy_to_host<Sch, Encoding, storage::Default>(dat);
return eoval;
}
@@ -153,7 +153,14 @@ public:
*/
template<typename Sch>
conveyor<data<Sch, Encoding, storage::Default>> receive(id<Sch> dat_id){
- return srv_->receive(dat_id);
+ auto eov = srv_->receive(dat_id);
+ if(eov.is_error()){
+ auto& err = eov.get_error();
+ return std::move(err);
+ }
+
+ auto& val = eov.get_value();
+ return std::move(val);
}
/**