summaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
Diffstat (limited to 'modules')
-rw-r--r--modules/async/c++/async.hpp10
-rw-r--r--modules/async/c++/async.tmpl.hpp4
-rw-r--r--modules/remote-sycl/c++/data.hpp10
-rw-r--r--modules/remote-sycl/c++/device.hpp10
-rw-r--r--modules/remote-sycl/tests/data.cpp58
5 files changed, 69 insertions, 23 deletions
diff --git a/modules/async/c++/async.hpp b/modules/async/c++/async.hpp
index 0f58536..ba994fd 100644
--- a/modules/async/c++/async.hpp
+++ b/modules/async/c++/async.hpp
@@ -287,7 +287,7 @@ public:
* If no sink() or detach() is used you have to take elements out of the
* chain yourself.
*/
- error_or<fix_void<T>> take();
+ error_or<T> take();
/** @todo implement
* Specifically pump elements through this chain with the provided
@@ -758,10 +758,12 @@ public:
func_, std::move(dep_eov.get_value()));
} catch (const std::bad_alloc &) {
eov = make_error<err::out_of_memory>("Out of memory");
- } catch (const std::exception &) {
+ } catch (const std::exception & e) {
eov = make_error<err::invalid_state>(
- "Exception in chain occured. Return ErrorOr<T> if you "
- "want to handle errors which are recoverable");
+ "Exception in chain occured. Return error_or<T> if you "
+ "want to handle errors which are recoverable."
+ "You might have thrown an exception in your code"
+ "which you haven't caught. Don't do that.");
}
} else if (dep_eov.is_error()) {
eov = error_func_(std::move(dep_eov.get_error()));
diff --git a/modules/async/c++/async.tmpl.hpp b/modules/async/c++/async.tmpl.hpp
index 68489ad..7016283 100644
--- a/modules/async/c++/async.tmpl.hpp
+++ b/modules/async/c++/async.tmpl.hpp
@@ -160,14 +160,14 @@ own<conveyor_node> conveyor<T>::from_conveyor(conveyor<T> conveyor) {
return std::move(conveyor.node_);
}
-template <typename T> error_or<fix_void<T>> conveyor<T>::take() {
+template <typename T> error_or<T> conveyor<T>::take() {
SAW_ASSERT(node_) {
make_error<err::invalid_state>("conveyor in invalid state");
}
conveyor_storage *storage = node_->next_storage();
if (storage) {
if (storage->queued() > 0) {
- error_or<fix_void<T>> result;
+ error_or<T> result;
node_->get_result(result);
return result;
} else {
diff --git a/modules/remote-sycl/c++/data.hpp b/modules/remote-sycl/c++/data.hpp
index 42bc3b1..e436e72 100644
--- a/modules/remote-sycl/c++/data.hpp
+++ b/modules/remote-sycl/c++/data.hpp
@@ -21,10 +21,18 @@ public:
return data_;
}
+ const auto& get_handle() const {
+ return data_;
+ }
+
template<cl::sycl::access::mode AccessMode>
auto access(cl::sycl::handler& h){
return data_.template get_access<AccessMode>(h);
}
+
+ template<cl::sycl::access::mode AccessMode>
+ auto access(cl::sycl::handler& h) const {
+ return data_.template get_access<AccessMode>(h);
+ }
};
-
}
diff --git a/modules/remote-sycl/c++/device.hpp b/modules/remote-sycl/c++/device.hpp
index 46644b4..5e63a1d 100644
--- a/modules/remote-sycl/c++/device.hpp
+++ b/modules/remote-sycl/c++/device.hpp
@@ -28,8 +28,14 @@ public:
* Copy data to host
*/
template<typename Schema, typename Encoding, typename Storage>
- error_or<data<Schema, Encoding, Storage>> copy_to_host(const data<Schema, Encoding, rmt::Sycl>& dev_data){
- return {};
+ error_or<data<Schema, Encoding, Storage>> copy_to_host(data<Schema, Encoding, rmt::Sycl>& dev_data){
+ data<Schema,Encoding, Storage> host_data;
+ cmd_queue_.submit([&](cl::sycl::handler& h){
+ auto acc_buff = dev_data.template access<cl::sycl::access::mode::read>(h);
+ h.copy(acc_buff, &host_data);
+ });
+ cmd_queue_.wait();
+ return host_data;
}
/**
diff --git a/modules/remote-sycl/tests/data.cpp b/modules/remote-sycl/tests/data.cpp
index a2bb506..dff19fb 100644
--- a/modules/remote-sycl/tests/data.cpp
+++ b/modules/remote-sycl/tests/data.cpp
@@ -17,8 +17,10 @@ SAW_TEST("SYCL Data Management"){
using namespace saw;
data<schema::TestStruct> host_data;
- host_data.template get<"foo">() = 321u;
- host_data.template get<"bra">() = 123;
+ auto& foo = host_data.template get<"foo">();
+ foo = 321u;
+ auto& bra = host_data.template get<"bra">();
+ bra = 123;
auto& baz = host_data.template get<"baz">();
baz = {1024};
for(uint64_t i = 0; i < baz.size(); ++i){
@@ -52,28 +54,56 @@ SAW_TEST("SYCL Data Management"){
bool ran = false;
bool error_ran = false;
+ bool expected_values = true;
+ std::string err_msg;
- auto conv = data_cl.receive(val).then([&](auto dat){
- auto& foo = dat.template get<"foo">();
- auto& bra = dat.template get<"bra">();
- auto& baz = dat.template get<"baz">();
- SAW_EXPECT(foo == host_data.template get<"foo">(), "Data sent back wasn't equal");
- SAW_EXPECT(bra == host_data.template get<"bra">(), "Data sent back wasn't equal");
+ auto conv = data_cl.receive(val).then([&](auto dat) {
+ ran = true;
+
+ auto& foo_b = dat.template get<"foo">();
- for(uint64_t i = 0u; i < baz.size(); ++i){
- SAW_EXPECT(baz.at(i) == host_data.template get<"baz">().at(i), "Data sent back wasn't equal");
+ if(foo != foo_b){
+ expected_values = false;
+ err_msg = "foo not equal. ";
+ err_msg += std::to_string(foo.get());
+ err_msg += " - ";
+ err_msg += std::to_string(foo_b.get());
+ return;
+ }
+
+ auto& bra_b = dat.template get<"bra">();
+ if(bra != bra_b){
+ expected_values = false;
+ err_msg = "bra not equal. ";
+ err_msg += std::to_string(bra.get());
+ err_msg += " - ";
+ err_msg += std::to_string(bra_b.get());
+ return;
+ }
+
+ auto& baz_b = dat.template get<"baz">();
+ if(baz.size() != baz_b.size()){
+ expected_values = false;
+ err_msg = "baz not equal. ";
+ err_msg += std::to_string(baz.size());
+ err_msg += " - ";
+ err_msg += std::to_string(baz_b.size());
+ return;
}
- ran = true;
}, [&](auto err){
error_ran = true;
return std::move(err);
});
auto eob = conv.take();
- SAW_EXPECT(eob.is_value(), "conv value doesn't exist");
+ if(eob.is_error()){
+ auto& err = eob.get_error();
+ SAW_EXPECT(false, (std::string{"Conv value doesn't exist: "} + std::string{err.get_category()} + std::string{" - "} + std::string{err.get_message()}));
+ }
auto& bval = eob.get_value();
- SAW_EXPECT(!error_ran, "conveyor ran, but we got an error");
- SAW_EXPECT(ran, "conveyor didn't run");
+ SAW_EXPECT(!error_ran, "Conveyor ran, but we got an error.");
+ SAW_EXPECT(ran, "Conveyor didn't run.");
+ SAW_EXPECT(expected_values, std::string{"Values are not equal. "} + err_msg);
}
}