summaryrefslogtreecommitdiff
path: root/modules/remote-sycl/tests
diff options
context:
space:
mode:
Diffstat (limited to 'modules/remote-sycl/tests')
-rw-r--r--modules/remote-sycl/tests/calculator.cpp64
1 files changed, 63 insertions, 1 deletions
diff --git a/modules/remote-sycl/tests/calculator.cpp b/modules/remote-sycl/tests/calculator.cpp
index 9ee1583..730838d 100644
--- a/modules/remote-sycl/tests/calculator.cpp
+++ b/modules/remote-sycl/tests/calculator.cpp
@@ -2,6 +2,68 @@
#include "../c++/remote.hpp"
-SAW_TEST("OpenCl Calculator"){
+namespace {
+namespace schema {
+using namespace saw::schema;
+using Calculator = Interface<
+ Member<
+ Function<Tuple<Int64, Int64>, Int64>, "add"
+ >
+, Member<
+ Function<Tuple<Int64, Int64>, Int64>, "multiply"
+ >
+>;
+}
+
+SAW_TEST("Sycl Interface Calculator"){
+ using namespace saw;
+
+ cl::sycl::queue cmd_queue;
+
+ interface<schema::Calculator, encode::Native<storage::Default>, cl::sycl::queue*> cl_iface {
+[](data<schema::Tuple<schema::Int64, schema::Int64>> in, cl::sycl::queue* cmd) -> data<schema::Int64> {
+ std::array<int64_t,2> h_xy{in.get<0>().get(), in.get<1>().get()};
+ int64_t res{};
+ cl::sycl::buffer<int64_t,1> d_xy { h_xy.data(), h_xy.size() };
+ cl::sycl::buffer<int64_t,1> d_z { &res, 1u };
+ cmd->submit([&](cl::sycl::handler& h){
+ auto a_xy = d_xy.get_access<cl::sycl::access::mode::read>(h);
+ auto a_z = d_z.get_access<cl::sycl::access::mode::write>(h);
+
+ h.parallel_for(cl::sycl::range<1>(1u), [=] (cl::sycl::id<1> it){
+ a_z[0] = a_xy[0] + a_xy[1];
+ });
+ });
+ cmd->wait();
+ return data<schema::Int64>{res};
+ },
+ [](data<schema::Tuple<schema::Int64, schema::Int64>> in, cl::sycl::queue* cmd) -> data<schema::Int64> {
+ return data<schema::Int64>{in.get<0>().get() * in.get<1>().get()};
+ }
+ };
+
+ int64_t x = 1;
+ int64_t y = -2;
+ int64_t sum = x + y;
+ int64_t mult = x * y;
+
+
+ data<schema::Tuple<schema::Int64, schema::Int64>> input;
+ input.template get<0>().set(x);
+ input.template get<1>().set(y);
+
+ {
+ auto eov = cl_iface.template call<"add">(input, &cmd_queue);
+ SAW_EXPECT(eov.is_value(), "Returned error on add");
+
+ SAW_EXPECT(eov.get_value().get() == sum, "Addition was incorrect");
+ }
+ {
+ auto eov = cl_iface.template call<"multiply">(input, &cmd_queue);
+ SAW_EXPECT(eov.is_value(), "Returned error on add");
+
+ SAW_EXPECT(eov.get_value().get() == mult, "Addition was incorrect");
+ }
+}
}