summaryrefslogtreecommitdiff
path: root/modules/remote-sycl/tests/data.foo
diff options
context:
space:
mode:
Diffstat (limited to 'modules/remote-sycl/tests/data.foo')
-rw-r--r--modules/remote-sycl/tests/data.foo110
1 files changed, 110 insertions, 0 deletions
diff --git a/modules/remote-sycl/tests/data.foo b/modules/remote-sycl/tests/data.foo
new file mode 100644
index 0000000..798b7a5
--- /dev/null
+++ b/modules/remote-sycl/tests/data.foo
@@ -0,0 +1,110 @@
+#include <forstio/test/suite.hpp>
+
+#include "../c++/transfer.hpp"
+#include "../c++/remote.hpp"
+
+namespace {
+namespace schema {
+using namespace saw::schema;
+
+using TestStruct = Struct<
+ Member<UInt64, "foo">,
+ Member<Int32, "bra">,
+ Member<Array<Float64>, "baz">
+>;
+}
+
+SAW_TEST("SYCL Data Management"){
+ using namespace saw;
+
+ data<schema::TestStruct> host_data;
+ auto& foo = host_data.template get<"foo">();
+ foo = 321u;
+ auto& bra = host_data.template get<"bra">();
+ bra = 123;
+ auto& baz = host_data.template get<"baz">();
+ baz = {1024};
+ for(data<schema::UInt64> i = 0; i < baz.size(); ++i){
+ baz.at(i) = static_cast<double>(i.get()*3);
+ }
+
+ saw::event_loop loop;
+ saw::wait_scope wait{loop};
+
+ remote<rmt::Sycl> rmt;
+
+ own<remote_address<rmt::Sycl>> rmt_addr{};
+
+ rmt.resolve_address().then([&](auto addr){
+ rmt_addr = std::move(addr);
+ }).detach();
+
+ wait.poll();
+ SAW_EXPECT(rmt_addr, "Remote address hasn't been filled");
+
+ auto our_device = share<device<rmt::Sycl>>();
+ auto& device = *our_device;
+
+ auto data_srv = data_server<tmpl_group<schema::TestStruct>, encode::Native, rmt::Sycl>{our_device};
+ auto data_cl = data_client<tmpl_group<schema::TestStruct>, encode::Native, rmt::Sycl>{data_srv};
+
+ auto eov = data_cl.send(host_data);
+ SAW_EXPECT(eov.is_value(), "Couldn't send data to SYCL");
+
+ auto& val = eov.get_value();
+
+ bool ran = false;
+ bool error_ran = false;
+ bool expected_values = true;
+ std::string err_msg;
+
+ auto conv = data_cl.receive(val).then([&](auto dat) {
+ ran = true;
+
+ auto& foo_b = dat.template get<"foo">();
+
+ if(foo != foo_b){
+ expected_values = false;
+ err_msg = "foo not equal. ";
+ err_msg += std::to_string(foo.get());
+ err_msg += " - ";
+ err_msg += std::to_string(foo_b.get());
+ return;
+ }
+
+ auto& bra_b = dat.template get<"bra">();
+ if(bra != bra_b){
+ expected_values = false;
+ err_msg = "bra not equal. ";
+ err_msg += std::to_string(bra.get());
+ err_msg += " - ";
+ err_msg += std::to_string(bra_b.get());
+ return;
+ }
+
+ auto& baz_b = dat.template get<"baz">();
+ if(baz.size() != baz_b.size()){
+ expected_values = false;
+ err_msg = "baz not equal. ";
+ err_msg += std::to_string(baz.size().get());
+ err_msg += " - ";
+ err_msg += std::to_string(baz_b.size().get());
+ return;
+ }
+ }, [&](auto err){
+ error_ran = true;
+ return std::move(err);
+ });
+
+ auto eob = conv.take();
+ if(eob.is_error()){
+ auto& err = eob.get_error();
+ SAW_EXPECT(false, (std::string{"Conv value doesn't exist: "} + std::string{err.get_category()} + std::string{" - "} + std::string{err.get_message()}));
+ }
+ auto& bval = eob.get_value();
+
+ SAW_EXPECT(!error_ran, "Conveyor ran, but we got an error.");
+ SAW_EXPECT(ran, "Conveyor didn't run.");
+ SAW_EXPECT(expected_values, std::string{"Values are not equal. "} + err_msg);
+}
+}