From eda37df9c399b23dc5bdb668730101a87f4770ce Mon Sep 17 00:00:00 2001 From: "Claudius \"keldu\" Holeksa" Date: Wed, 26 Jun 2024 12:30:30 +0200 Subject: Attempting to fix async errors --- modules/async/c++/async.tmpl.hpp | 9 ++-- modules/remote-sycl/c++/data.hpp | 2 +- modules/remote-sycl/c++/device.hpp | 4 +- modules/remote-sycl/c++/remote.hpp | 11 ++--- modules/remote-sycl/c++/transfer.hpp | 15 +++++-- modules/remote-sycl/tests/data.cpp | 79 ++++++++++++++++++++++++++++++++++++ 6 files changed, 102 insertions(+), 18 deletions(-) create mode 100644 modules/remote-sycl/tests/data.cpp diff --git a/modules/async/c++/async.tmpl.hpp b/modules/async/c++/async.tmpl.hpp index ec8d3fc..68489ad 100644 --- a/modules/async/c++/async.tmpl.hpp +++ b/modules/async/c++/async.tmpl.hpp @@ -162,8 +162,7 @@ own conveyor::from_conveyor(conveyor conveyor) { template error_or> conveyor::take() { SAW_ASSERT(node_) { - return error_or>{ - make_error("conveyor in invalid state")}; + make_error("conveyor in invalid state"); } conveyor_storage *storage = node_->next_storage(); if (storage) { @@ -172,12 +171,10 @@ template error_or> conveyor::take() { node_->get_result(result); return result; } else { - return error_or>{ - make_error("conveyor buffer has no elements")}; + return make_error("conveyor buffer has no elements"); } } else { - return error_or>{ - make_error("conveyor node has no child storage")}; + return make_error("conveyor node has no child storage"); } } diff --git a/modules/remote-sycl/c++/data.hpp b/modules/remote-sycl/c++/data.hpp index 3ad1d9c..42bc3b1 100644 --- a/modules/remote-sycl/c++/data.hpp +++ b/modules/remote-sycl/c++/data.hpp @@ -13,7 +13,7 @@ class data { private: cl::sycl::buffer> data_; public: - data(data& data__): + data(const data& data__): data_{&data__, 1u} {} diff --git a/modules/remote-sycl/c++/device.hpp b/modules/remote-sycl/c++/device.hpp index 6d133ae..46644b4 100644 --- a/modules/remote-sycl/c++/device.hpp +++ b/modules/remote-sycl/c++/device.hpp @@ -21,7 +21,7 @@ public: */ template error_or> copy_to_device(const data& host_data){ - return data::copy_to_device(host_data, *this); + return data{host_data}; } /** @@ -29,7 +29,7 @@ public: */ template error_or> copy_to_host(const data& dev_data){ - return dev_data.copy_to_host(); + return {}; } /** diff --git a/modules/remote-sycl/c++/remote.hpp b/modules/remote-sycl/c++/remote.hpp index 1ae3103..7e77ec9 100644 --- a/modules/remote-sycl/c++/remote.hpp +++ b/modules/remote-sycl/c++/remote.hpp @@ -170,10 +170,11 @@ private: */ device* device_; + using DataServerT = data_server::type, Encoding, rmt::Sycl>; /** * Data server storing the relevant data */ - data_server::type, Encoding, rmt::Sycl> data_server_; + DataServerT* data_server_; /** * The interface including the relevant context class. @@ -185,9 +186,9 @@ public: /** * Main constructor */ - rpc_server(device& dev__, InterfaceT cl_iface): + rpc_server(device& dev__, DataServerT& data_server__, InterfaceT cl_iface): device_{&dev__}, - data_server_{}, + data_server_{&data_server__}, cl_interface_{std::move(cl_iface)} {} @@ -230,7 +231,7 @@ public: auto eoinp = [&,this]() -> error_or* > { if(input.is_id()){ // storage_.maps - auto eov = data_server_.template find(input.get_id()); + auto eov = data_server_->template find(input.get_id()); if(eov.is_error()){ return std::move(eov.get_error()); } @@ -264,7 +265,7 @@ public: /** * Store returned data in rpc storage */ - auto eoid = data_server_.template insert::type::RequestT>(std::move(val), rpc_id); + auto eoid = data_server_->template insert::type::RequestT>(std::move(val), rpc_id); if(eoid.is_error()){ return std::move(eoid.get_error()); } diff --git a/modules/remote-sycl/c++/transfer.hpp b/modules/remote-sycl/c++/transfer.hpp index 65a9b9e..8987de9 100644 --- a/modules/remote-sycl/c++/transfer.hpp +++ b/modules/remote-sycl/c++/transfer.hpp @@ -45,7 +45,7 @@ public: */ template error_or send(const data& dat, id store_id){ - auto& vals = std::get(values_); + auto& vals = std::get>>(values_); auto eoval = device_->template copy_to_device(dat); if(eoval.is_error()){ auto& err = eoval.get_error(); @@ -68,14 +68,14 @@ public: */ template error_or> receive(id store_id){ - auto& vals = std::get(values_); + auto& vals = std::get>>(values_); auto find_res = vals.find(store_id.get_value()); if(find_res == vals.end()){ return make_error(); } auto& dat = find_res->second; - auto eoval = device_->copy_to_host(dat); + auto eoval = device_->template copy_to_host(dat); return eoval; } @@ -153,7 +153,14 @@ public: */ template conveyor> receive(id dat_id){ - return srv_->receive(dat_id); + auto eov = srv_->receive(dat_id); + if(eov.is_error()){ + auto& err = eov.get_error(); + return std::move(err); + } + + auto& val = eov.get_value(); + return std::move(val); } /** diff --git a/modules/remote-sycl/tests/data.cpp b/modules/remote-sycl/tests/data.cpp new file mode 100644 index 0000000..a2bb506 --- /dev/null +++ b/modules/remote-sycl/tests/data.cpp @@ -0,0 +1,79 @@ +#include + +#include "../c++/remote.hpp" + +namespace { +namespace schema { +using namespace saw::schema; + +using TestStruct = Struct< + Member, + Member, + Member, "baz"> +>; +} + +SAW_TEST("SYCL Data Management"){ + using namespace saw; + + data host_data; + host_data.template get<"foo">() = 321u; + host_data.template get<"bra">() = 123; + auto& baz = host_data.template get<"baz">(); + baz = {1024}; + for(uint64_t i = 0; i < baz.size(); ++i){ + baz.at(i) = static_cast(i*3); + } + + saw::event_loop loop; + saw::wait_scope wait{loop}; + + remote rmt; + + own> rmt_addr{}; + + rmt.resolve_address().then([&](auto addr){ + rmt_addr = std::move(addr); + }).detach(); + + wait.poll(); + SAW_EXPECT(rmt_addr, "Remote address hasn't been filled"); + + auto device = rmt.connect_device(*rmt_addr); + + auto data_srv = data_server, encode::Native, rmt::Sycl>{device}; + + auto data_cl = data_client, encode::Native, rmt::Sycl>{data_srv}; + + auto eov = data_cl.send(host_data); + SAW_EXPECT(eov.is_value(), "Couldn't send data to SYCL"); + + auto& val = eov.get_value(); + + bool ran = false; + bool error_ran = false; + + auto conv = data_cl.receive(val).then([&](auto dat){ + auto& foo = dat.template get<"foo">(); + auto& bra = dat.template get<"bra">(); + auto& baz = dat.template get<"baz">(); + SAW_EXPECT(foo == host_data.template get<"foo">(), "Data sent back wasn't equal"); + SAW_EXPECT(bra == host_data.template get<"bra">(), "Data sent back wasn't equal"); + + for(uint64_t i = 0u; i < baz.size(); ++i){ + SAW_EXPECT(baz.at(i) == host_data.template get<"baz">().at(i), "Data sent back wasn't equal"); + } + ran = true; + }, [&](auto err){ + error_ran = true; + return std::move(err); + }); + + auto eob = conv.take(); + SAW_EXPECT(eob.is_value(), "conv value doesn't exist"); + auto& bval = eob.get_value(); + + SAW_EXPECT(!error_ran, "conveyor ran, but we got an error"); + SAW_EXPECT(ran, "conveyor didn't run"); +} +} -- cgit v1.2.3