diff options
Diffstat (limited to 'modules/remote-sycl/tests/calculator.cpp')
-rw-r--r-- | modules/remote-sycl/tests/calculator.cpp | 64 |
1 files changed, 63 insertions, 1 deletions
diff --git a/modules/remote-sycl/tests/calculator.cpp b/modules/remote-sycl/tests/calculator.cpp index 9ee1583..730838d 100644 --- a/modules/remote-sycl/tests/calculator.cpp +++ b/modules/remote-sycl/tests/calculator.cpp @@ -2,6 +2,68 @@ #include "../c++/remote.hpp" -SAW_TEST("OpenCl Calculator"){ +namespace { +namespace schema { +using namespace saw::schema; +using Calculator = Interface< + Member< + Function<Tuple<Int64, Int64>, Int64>, "add" + > +, Member< + Function<Tuple<Int64, Int64>, Int64>, "multiply" + > +>; +} + +SAW_TEST("Sycl Interface Calculator"){ + using namespace saw; + + cl::sycl::queue cmd_queue; + + interface<schema::Calculator, encode::Native<storage::Default>, cl::sycl::queue*> cl_iface { +[](data<schema::Tuple<schema::Int64, schema::Int64>> in, cl::sycl::queue* cmd) -> data<schema::Int64> { + std::array<int64_t,2> h_xy{in.get<0>().get(), in.get<1>().get()}; + int64_t res{}; + cl::sycl::buffer<int64_t,1> d_xy { h_xy.data(), h_xy.size() }; + cl::sycl::buffer<int64_t,1> d_z { &res, 1u }; + cmd->submit([&](cl::sycl::handler& h){ + auto a_xy = d_xy.get_access<cl::sycl::access::mode::read>(h); + auto a_z = d_z.get_access<cl::sycl::access::mode::write>(h); + + h.parallel_for(cl::sycl::range<1>(1u), [=] (cl::sycl::id<1> it){ + a_z[0] = a_xy[0] + a_xy[1]; + }); + }); + cmd->wait(); + return data<schema::Int64>{res}; + }, + [](data<schema::Tuple<schema::Int64, schema::Int64>> in, cl::sycl::queue* cmd) -> data<schema::Int64> { + return data<schema::Int64>{in.get<0>().get() * in.get<1>().get()}; + } + }; + + int64_t x = 1; + int64_t y = -2; + int64_t sum = x + y; + int64_t mult = x * y; + + + data<schema::Tuple<schema::Int64, schema::Int64>> input; + input.template get<0>().set(x); + input.template get<1>().set(y); + + { + auto eov = cl_iface.template call<"add">(input, &cmd_queue); + SAW_EXPECT(eov.is_value(), "Returned error on add"); + + SAW_EXPECT(eov.get_value().get() == sum, "Addition was incorrect"); + } + { + auto eov = cl_iface.template call<"multiply">(input, &cmd_queue); + SAW_EXPECT(eov.is_value(), "Returned error on add"); + + SAW_EXPECT(eov.get_value().get() == mult, "Addition was incorrect"); + } +} } |