summaryrefslogtreecommitdiff
path: root/modules/remote-sycl
diff options
context:
space:
mode:
authorClaudius "keldu" Holeksa <mail@keldu.de>2024-06-26 14:41:18 +0200
committerClaudius "keldu" Holeksa <mail@keldu.de>2024-06-26 14:41:18 +0200
commitdf7789cbef7ffa9658c61525edf75bebaa6398ff (patch)
tree8b004a208910a9c38f20ec312fbb0e74a5a2c279 /modules/remote-sycl
parenteda37df9c399b23dc5bdb668730101a87f4770ce (diff)
Got double free :/
Diffstat (limited to 'modules/remote-sycl')
-rw-r--r--modules/remote-sycl/c++/data.hpp10
-rw-r--r--modules/remote-sycl/c++/device.hpp10
-rw-r--r--modules/remote-sycl/tests/data.cpp58
3 files changed, 61 insertions, 17 deletions
diff --git a/modules/remote-sycl/c++/data.hpp b/modules/remote-sycl/c++/data.hpp
index 42bc3b1..e436e72 100644
--- a/modules/remote-sycl/c++/data.hpp
+++ b/modules/remote-sycl/c++/data.hpp
@@ -21,10 +21,18 @@ public:
return data_;
}
+ const auto& get_handle() const {
+ return data_;
+ }
+
template<cl::sycl::access::mode AccessMode>
auto access(cl::sycl::handler& h){
return data_.template get_access<AccessMode>(h);
}
+
+ template<cl::sycl::access::mode AccessMode>
+ auto access(cl::sycl::handler& h) const {
+ return data_.template get_access<AccessMode>(h);
+ }
};
-
}
diff --git a/modules/remote-sycl/c++/device.hpp b/modules/remote-sycl/c++/device.hpp
index 46644b4..5e63a1d 100644
--- a/modules/remote-sycl/c++/device.hpp
+++ b/modules/remote-sycl/c++/device.hpp
@@ -28,8 +28,14 @@ public:
* Copy data to host
*/
template<typename Schema, typename Encoding, typename Storage>
- error_or<data<Schema, Encoding, Storage>> copy_to_host(const data<Schema, Encoding, rmt::Sycl>& dev_data){
- return {};
+ error_or<data<Schema, Encoding, Storage>> copy_to_host(data<Schema, Encoding, rmt::Sycl>& dev_data){
+ data<Schema,Encoding, Storage> host_data;
+ cmd_queue_.submit([&](cl::sycl::handler& h){
+ auto acc_buff = dev_data.template access<cl::sycl::access::mode::read>(h);
+ h.copy(acc_buff, &host_data);
+ });
+ cmd_queue_.wait();
+ return host_data;
}
/**
diff --git a/modules/remote-sycl/tests/data.cpp b/modules/remote-sycl/tests/data.cpp
index a2bb506..dff19fb 100644
--- a/modules/remote-sycl/tests/data.cpp
+++ b/modules/remote-sycl/tests/data.cpp
@@ -17,8 +17,10 @@ SAW_TEST("SYCL Data Management"){
using namespace saw;
data<schema::TestStruct> host_data;
- host_data.template get<"foo">() = 321u;
- host_data.template get<"bra">() = 123;
+ auto& foo = host_data.template get<"foo">();
+ foo = 321u;
+ auto& bra = host_data.template get<"bra">();
+ bra = 123;
auto& baz = host_data.template get<"baz">();
baz = {1024};
for(uint64_t i = 0; i < baz.size(); ++i){
@@ -52,28 +54,56 @@ SAW_TEST("SYCL Data Management"){
bool ran = false;
bool error_ran = false;
+ bool expected_values = true;
+ std::string err_msg;
- auto conv = data_cl.receive(val).then([&](auto dat){
- auto& foo = dat.template get<"foo">();
- auto& bra = dat.template get<"bra">();
- auto& baz = dat.template get<"baz">();
- SAW_EXPECT(foo == host_data.template get<"foo">(), "Data sent back wasn't equal");
- SAW_EXPECT(bra == host_data.template get<"bra">(), "Data sent back wasn't equal");
+ auto conv = data_cl.receive(val).then([&](auto dat) {
+ ran = true;
+
+ auto& foo_b = dat.template get<"foo">();
- for(uint64_t i = 0u; i < baz.size(); ++i){
- SAW_EXPECT(baz.at(i) == host_data.template get<"baz">().at(i), "Data sent back wasn't equal");
+ if(foo != foo_b){
+ expected_values = false;
+ err_msg = "foo not equal. ";
+ err_msg += std::to_string(foo.get());
+ err_msg += " - ";
+ err_msg += std::to_string(foo_b.get());
+ return;
+ }
+
+ auto& bra_b = dat.template get<"bra">();
+ if(bra != bra_b){
+ expected_values = false;
+ err_msg = "bra not equal. ";
+ err_msg += std::to_string(bra.get());
+ err_msg += " - ";
+ err_msg += std::to_string(bra_b.get());
+ return;
+ }
+
+ auto& baz_b = dat.template get<"baz">();
+ if(baz.size() != baz_b.size()){
+ expected_values = false;
+ err_msg = "baz not equal. ";
+ err_msg += std::to_string(baz.size());
+ err_msg += " - ";
+ err_msg += std::to_string(baz_b.size());
+ return;
}
- ran = true;
}, [&](auto err){
error_ran = true;
return std::move(err);
});
auto eob = conv.take();
- SAW_EXPECT(eob.is_value(), "conv value doesn't exist");
+ if(eob.is_error()){
+ auto& err = eob.get_error();
+ SAW_EXPECT(false, (std::string{"Conv value doesn't exist: "} + std::string{err.get_category()} + std::string{" - "} + std::string{err.get_message()}));
+ }
auto& bval = eob.get_value();
- SAW_EXPECT(!error_ran, "conveyor ran, but we got an error");
- SAW_EXPECT(ran, "conveyor didn't run");
+ SAW_EXPECT(!error_ran, "Conveyor ran, but we got an error.");
+ SAW_EXPECT(ran, "Conveyor didn't run.");
+ SAW_EXPECT(expected_values, std::string{"Values are not equal. "} + err_msg);
}
}