diff options
-rw-r--r-- | modules/async/c++/async.hpp | 10 | ||||
-rw-r--r-- | modules/async/c++/async.tmpl.hpp | 4 | ||||
-rw-r--r-- | modules/remote-sycl/c++/data.hpp | 10 | ||||
-rw-r--r-- | modules/remote-sycl/c++/device.hpp | 10 | ||||
-rw-r--r-- | modules/remote-sycl/tests/data.cpp | 58 |
5 files changed, 69 insertions, 23 deletions
diff --git a/modules/async/c++/async.hpp b/modules/async/c++/async.hpp index 0f58536..ba994fd 100644 --- a/modules/async/c++/async.hpp +++ b/modules/async/c++/async.hpp @@ -287,7 +287,7 @@ public: * If no sink() or detach() is used you have to take elements out of the * chain yourself. */ - error_or<fix_void<T>> take(); + error_or<T> take(); /** @todo implement * Specifically pump elements through this chain with the provided @@ -758,10 +758,12 @@ public: func_, std::move(dep_eov.get_value())); } catch (const std::bad_alloc &) { eov = make_error<err::out_of_memory>("Out of memory"); - } catch (const std::exception &) { + } catch (const std::exception & e) { eov = make_error<err::invalid_state>( - "Exception in chain occured. Return ErrorOr<T> if you " - "want to handle errors which are recoverable"); + "Exception in chain occured. Return error_or<T> if you " + "want to handle errors which are recoverable." + "You might have thrown an exception in your code" + "which you haven't caught. Don't do that."); } } else if (dep_eov.is_error()) { eov = error_func_(std::move(dep_eov.get_error())); diff --git a/modules/async/c++/async.tmpl.hpp b/modules/async/c++/async.tmpl.hpp index 68489ad..7016283 100644 --- a/modules/async/c++/async.tmpl.hpp +++ b/modules/async/c++/async.tmpl.hpp @@ -160,14 +160,14 @@ own<conveyor_node> conveyor<T>::from_conveyor(conveyor<T> conveyor) { return std::move(conveyor.node_); } -template <typename T> error_or<fix_void<T>> conveyor<T>::take() { +template <typename T> error_or<T> conveyor<T>::take() { SAW_ASSERT(node_) { make_error<err::invalid_state>("conveyor in invalid state"); } conveyor_storage *storage = node_->next_storage(); if (storage) { if (storage->queued() > 0) { - error_or<fix_void<T>> result; + error_or<T> result; node_->get_result(result); return result; } else { diff --git a/modules/remote-sycl/c++/data.hpp b/modules/remote-sycl/c++/data.hpp index 42bc3b1..e436e72 100644 --- a/modules/remote-sycl/c++/data.hpp +++ b/modules/remote-sycl/c++/data.hpp @@ -21,10 +21,18 @@ public: return data_; } + const auto& get_handle() const { + return data_; + } + template<cl::sycl::access::mode AccessMode> auto access(cl::sycl::handler& h){ return data_.template get_access<AccessMode>(h); } + + template<cl::sycl::access::mode AccessMode> + auto access(cl::sycl::handler& h) const { + return data_.template get_access<AccessMode>(h); + } }; - } diff --git a/modules/remote-sycl/c++/device.hpp b/modules/remote-sycl/c++/device.hpp index 46644b4..5e63a1d 100644 --- a/modules/remote-sycl/c++/device.hpp +++ b/modules/remote-sycl/c++/device.hpp @@ -28,8 +28,14 @@ public: * Copy data to host */ template<typename Schema, typename Encoding, typename Storage> - error_or<data<Schema, Encoding, Storage>> copy_to_host(const data<Schema, Encoding, rmt::Sycl>& dev_data){ - return {}; + error_or<data<Schema, Encoding, Storage>> copy_to_host(data<Schema, Encoding, rmt::Sycl>& dev_data){ + data<Schema,Encoding, Storage> host_data; + cmd_queue_.submit([&](cl::sycl::handler& h){ + auto acc_buff = dev_data.template access<cl::sycl::access::mode::read>(h); + h.copy(acc_buff, &host_data); + }); + cmd_queue_.wait(); + return host_data; } /** diff --git a/modules/remote-sycl/tests/data.cpp b/modules/remote-sycl/tests/data.cpp index a2bb506..dff19fb 100644 --- a/modules/remote-sycl/tests/data.cpp +++ b/modules/remote-sycl/tests/data.cpp @@ -17,8 +17,10 @@ SAW_TEST("SYCL Data Management"){ using namespace saw; data<schema::TestStruct> host_data; - host_data.template get<"foo">() = 321u; - host_data.template get<"bra">() = 123; + auto& foo = host_data.template get<"foo">(); + foo = 321u; + auto& bra = host_data.template get<"bra">(); + bra = 123; auto& baz = host_data.template get<"baz">(); baz = {1024}; for(uint64_t i = 0; i < baz.size(); ++i){ @@ -52,28 +54,56 @@ SAW_TEST("SYCL Data Management"){ bool ran = false; bool error_ran = false; + bool expected_values = true; + std::string err_msg; - auto conv = data_cl.receive(val).then([&](auto dat){ - auto& foo = dat.template get<"foo">(); - auto& bra = dat.template get<"bra">(); - auto& baz = dat.template get<"baz">(); - SAW_EXPECT(foo == host_data.template get<"foo">(), "Data sent back wasn't equal"); - SAW_EXPECT(bra == host_data.template get<"bra">(), "Data sent back wasn't equal"); + auto conv = data_cl.receive(val).then([&](auto dat) { + ran = true; + + auto& foo_b = dat.template get<"foo">(); - for(uint64_t i = 0u; i < baz.size(); ++i){ - SAW_EXPECT(baz.at(i) == host_data.template get<"baz">().at(i), "Data sent back wasn't equal"); + if(foo != foo_b){ + expected_values = false; + err_msg = "foo not equal. "; + err_msg += std::to_string(foo.get()); + err_msg += " - "; + err_msg += std::to_string(foo_b.get()); + return; + } + + auto& bra_b = dat.template get<"bra">(); + if(bra != bra_b){ + expected_values = false; + err_msg = "bra not equal. "; + err_msg += std::to_string(bra.get()); + err_msg += " - "; + err_msg += std::to_string(bra_b.get()); + return; + } + + auto& baz_b = dat.template get<"baz">(); + if(baz.size() != baz_b.size()){ + expected_values = false; + err_msg = "baz not equal. "; + err_msg += std::to_string(baz.size()); + err_msg += " - "; + err_msg += std::to_string(baz_b.size()); + return; } - ran = true; }, [&](auto err){ error_ran = true; return std::move(err); }); auto eob = conv.take(); - SAW_EXPECT(eob.is_value(), "conv value doesn't exist"); + if(eob.is_error()){ + auto& err = eob.get_error(); + SAW_EXPECT(false, (std::string{"Conv value doesn't exist: "} + std::string{err.get_category()} + std::string{" - "} + std::string{err.get_message()})); + } auto& bval = eob.get_value(); - SAW_EXPECT(!error_ran, "conveyor ran, but we got an error"); - SAW_EXPECT(ran, "conveyor didn't run"); + SAW_EXPECT(!error_ran, "Conveyor ran, but we got an error."); + SAW_EXPECT(ran, "Conveyor didn't run."); + SAW_EXPECT(expected_values, std::string{"Values are not equal. "} + err_msg); } } |