diff options
Diffstat (limited to 'modules/remote-sycl/tests/data.foo')
| -rw-r--r-- | modules/remote-sycl/tests/data.foo | 110 |
1 files changed, 110 insertions, 0 deletions
diff --git a/modules/remote-sycl/tests/data.foo b/modules/remote-sycl/tests/data.foo new file mode 100644 index 0000000..798b7a5 --- /dev/null +++ b/modules/remote-sycl/tests/data.foo @@ -0,0 +1,110 @@ +#include <forstio/test/suite.hpp> + +#include "../c++/transfer.hpp" +#include "../c++/remote.hpp" + +namespace { +namespace schema { +using namespace saw::schema; + +using TestStruct = Struct< + Member<UInt64, "foo">, + Member<Int32, "bra">, + Member<Array<Float64>, "baz"> +>; +} + +SAW_TEST("SYCL Data Management"){ + using namespace saw; + + data<schema::TestStruct> host_data; + auto& foo = host_data.template get<"foo">(); + foo = 321u; + auto& bra = host_data.template get<"bra">(); + bra = 123; + auto& baz = host_data.template get<"baz">(); + baz = {1024}; + for(data<schema::UInt64> i = 0; i < baz.size(); ++i){ + baz.at(i) = static_cast<double>(i.get()*3); + } + + saw::event_loop loop; + saw::wait_scope wait{loop}; + + remote<rmt::Sycl> rmt; + + own<remote_address<rmt::Sycl>> rmt_addr{}; + + rmt.resolve_address().then([&](auto addr){ + rmt_addr = std::move(addr); + }).detach(); + + wait.poll(); + SAW_EXPECT(rmt_addr, "Remote address hasn't been filled"); + + auto our_device = share<device<rmt::Sycl>>(); + auto& device = *our_device; + + auto data_srv = data_server<tmpl_group<schema::TestStruct>, encode::Native, rmt::Sycl>{our_device}; + auto data_cl = data_client<tmpl_group<schema::TestStruct>, encode::Native, rmt::Sycl>{data_srv}; + + auto eov = data_cl.send(host_data); + SAW_EXPECT(eov.is_value(), "Couldn't send data to SYCL"); + + auto& val = eov.get_value(); + + bool ran = false; + bool error_ran = false; + bool expected_values = true; + std::string err_msg; + + auto conv = data_cl.receive(val).then([&](auto dat) { + ran = true; + + auto& foo_b = dat.template get<"foo">(); + + if(foo != foo_b){ + expected_values = false; + err_msg = "foo not equal. "; + err_msg += std::to_string(foo.get()); + err_msg += " - "; + err_msg += std::to_string(foo_b.get()); + return; + } + + auto& bra_b = dat.template get<"bra">(); + if(bra != bra_b){ + expected_values = false; + err_msg = "bra not equal. "; + err_msg += std::to_string(bra.get()); + err_msg += " - "; + err_msg += std::to_string(bra_b.get()); + return; + } + + auto& baz_b = dat.template get<"baz">(); + if(baz.size() != baz_b.size()){ + expected_values = false; + err_msg = "baz not equal. "; + err_msg += std::to_string(baz.size().get()); + err_msg += " - "; + err_msg += std::to_string(baz_b.size().get()); + return; + } + }, [&](auto err){ + error_ran = true; + return std::move(err); + }); + + auto eob = conv.take(); + if(eob.is_error()){ + auto& err = eob.get_error(); + SAW_EXPECT(false, (std::string{"Conv value doesn't exist: "} + std::string{err.get_category()} + std::string{" - "} + std::string{err.get_message()})); + } + auto& bval = eob.get_value(); + + SAW_EXPECT(!error_ran, "Conveyor ran, but we got an error."); + SAW_EXPECT(ran, "Conveyor didn't run."); + SAW_EXPECT(expected_values, std::string{"Values are not equal. "} + err_msg); +} +} |
