From 86b06a3fee2cd7635a9ab486e2a35bdf1e81ce38 Mon Sep 17 00:00:00 2001 From: "Claudius \"keldu\" Holeksa" Date: Fri, 21 Jun 2024 19:44:34 +0200 Subject: Moving forward with basic test for sycl --- modules/remote-sycl/tests/calculator.cpp | 69 ---------------------- modules/remote-sycl/tests/calculator.foo | 69 ++++++++++++++++++++++ modules/remote-sycl/tests/sycl_basics.cpp | 96 +++++++++++++++++++++++++++++++ 3 files changed, 165 insertions(+), 69 deletions(-) delete mode 100644 modules/remote-sycl/tests/calculator.cpp create mode 100644 modules/remote-sycl/tests/calculator.foo create mode 100644 modules/remote-sycl/tests/sycl_basics.cpp (limited to 'modules/remote-sycl/tests') diff --git a/modules/remote-sycl/tests/calculator.cpp b/modules/remote-sycl/tests/calculator.cpp deleted file mode 100644 index 6d061ad..0000000 --- a/modules/remote-sycl/tests/calculator.cpp +++ /dev/null @@ -1,69 +0,0 @@ -#include - -#include "../c++/remote.hpp" - -namespace { -namespace schema { -using namespace saw::schema; -using Calculator = Interface< - Member< - Function, Int64>, "add" - > -, Member< - Function, Int64>, "multiply" - > ->; -} - -SAW_TEST("Sycl Interface Calculator"){ - using namespace saw; - - cl::sycl::queue cmd_queue; - - interface, cl::sycl::queue*> cl_iface { -[](data>& in, cl::sycl::queue* cmd) -> data { - std::array h_xy{in.get<0>().get(), in.get<1>().get()}; - int64_t res{}; - cl::sycl::buffer d_xy { h_xy.data(), h_xy.size() }; - cl::sycl::buffer d_z { &res, 1u }; - cmd->submit([&](cl::sycl::handler& h){ - auto a_xy = d_xy.get_access(h); - auto a_z = d_z.get_access(h); - - h.parallel_for(cl::sycl::range<1>(1u), [=] (cl::sycl::id<1> it){ - a_z[0] = a_xy[0] + a_xy[1]; - }); - }); - cmd->wait(); - return data{res}; - }, - [](data,encode::Native,rmt::Sycl>& in, cl::sycl::queue* cmd) -> data { - return data{in.get<0>().get() * in.get<1>().get()}; - } - }; - - int64_t x = 1; - int64_t y = -2; - - int64_t sum = x + y; - int64_t mult = x * y; - - - data> input; - input.template get<0>().set(x); - input.template get<1>().set(y); - - { - auto eov = cl_iface.template call<"add">(input, &cmd_queue); - SAW_EXPECT(eov.is_value(), "Returned error on add"); - - SAW_EXPECT(eov.get_value().get() == sum, "Addition was incorrect"); - } - { - auto eov = cl_iface.template call<"multiply">(input, &cmd_queue); - SAW_EXPECT(eov.is_value(), "Returned error on add"); - - SAW_EXPECT(eov.get_value().get() == mult, "Addition was incorrect"); - } -} -} diff --git a/modules/remote-sycl/tests/calculator.foo b/modules/remote-sycl/tests/calculator.foo new file mode 100644 index 0000000..745bd3d --- /dev/null +++ b/modules/remote-sycl/tests/calculator.foo @@ -0,0 +1,69 @@ +#include + +#include "../c++/remote.hpp" + +namespace { +namespace schema { +using namespace saw::schema; +using Calculator = Interface< + Member< + Function, Int64>, "add" + > +, Member< + Function, Int64>, "multiply" + > +>; +} + +SAW_TEST("SYCL Interface Calculator"){ + using namespace saw; + + cl::sycl::queue cmd_queue; + + interface, cl::sycl::queue*> cl_iface { +[](data>& in, cl::sycl::queue* cmd) -> data { + std::array h_xy{in.get<0>().get(), in.get<1>().get()}; + int64_t res{}; + cl::sycl::buffer d_xy { h_xy.data(), h_xy.size() }; + cl::sycl::buffer d_z { &res, 1u }; + cmd->submit([&](cl::sycl::handler& h){ + auto a_xy = d_xy.get_access(h); + auto a_z = d_z.get_access(h); + + h.parallel_for(cl::sycl::range<1>(1u), [=] (cl::sycl::id<1> it){ + a_z[0] = a_xy[0] + a_xy[1]; + }); + }); + cmd->wait(); + return data{res}; + }, + [](data,encode::Native,rmt::Sycl>& in, cl::sycl::queue* cmd) -> data { + return data{in.get<0>().get() * in.get<1>().get()}; + } + }; + + int64_t x = 1; + int64_t y = -2; + + int64_t sum = x + y; + int64_t mult = x * y; + + + data> input; + input.template get<0>().set(x); + input.template get<1>().set(y); + + { + auto eov = cl_iface.template call<"add">(input, &cmd_queue); + SAW_EXPECT(eov.is_value(), "Returned error on add"); + + SAW_EXPECT(eov.get_value().get() == sum, "Addition was incorrect"); + } + { + auto eov = cl_iface.template call<"multiply">(input, &cmd_queue); + SAW_EXPECT(eov.is_value(), "Returned error on add"); + + SAW_EXPECT(eov.get_value().get() == mult, "Addition was incorrect"); + } +} +} diff --git a/modules/remote-sycl/tests/sycl_basics.cpp b/modules/remote-sycl/tests/sycl_basics.cpp new file mode 100644 index 0000000..bf41983 --- /dev/null +++ b/modules/remote-sycl/tests/sycl_basics.cpp @@ -0,0 +1,96 @@ +#include + +#include "../c++/remote.hpp" + +namespace { +namespace schema { +using namespace saw::schema; + +using TestStruct = Struct< + Member, + Member, + Member, "doubles"> +>; + +using Foo = Interface< + Member, "foo"> +>; + +using Calculator = Interface< + Member< + Function, Int64>, "add" + > +, Member< + Function, Int64>, "multiply" + > +>; +} +SAW_TEST("SYCL Test Setup"){ + using namespace saw; + + data host_data; + host_data.template get<"foo">() = 321; + host_data.template get<"bar">() = 50.0; + host_data.template get<"doubles">() = {1024u}; + + saw::event_loop loop; + saw::wait_scope wait{loop}; + + remote rmt; + saw::own> rmt_addr{}; + + rmt.resolve_address().then([&](auto addr){ + rmt_addr = std::move(addr); + }).detach(); + + wait.poll(); + SAW_EXPECT(rmt_addr, "Remote Address class hasn't been filled"); + + auto device = rmt.connect_device(*rmt_addr); + + data device_data{host_data}; + + interface cl_iface { +[&](data& in, cl::sycl::queue* cmd) -> error_or { + + cmd->submit([&](cl::sycl::handler& h){ + + auto acc_buff = in.template access(h); + + uint64_t si = host_data.template get<"doubles">().size(); + + h.parallel_for(cl::sycl::range<1>(si), [=] (cl::sycl::id<1> it){ + acc_buff[0u].template get<"foo">() = acc_buff[0u].template get<"doubles">().size(); + auto& dbls = acc_buff[0u].template get<"doubles">(); + dbls.at(it[0u]) = it[0u] * 2.0; + }); + }); + /* + cmd->submit([&](cl::sycl::handler& h){ + auto acc_buff = in.template access(h); + h.copy(acc_buff, &host_data); + }); + */ + + /** + cl::sycl::host_accessor result{in.get_handle()}; + std::cout<().get()<().get()<()<().get()<().get()<(device_data, &device.get_handle()); + device.get_handle().wait(); + std::cout<().get()<().get()<(); + std::cout<