summaryrefslogtreecommitdiff
path: root/modules/remote-sycl/tests
diff options
context:
space:
mode:
Diffstat (limited to 'modules/remote-sycl/tests')
-rw-r--r--modules/remote-sycl/tests/calculator.foo (renamed from modules/remote-sycl/tests/calculator.cpp)2
-rw-r--r--modules/remote-sycl/tests/sycl_basics.cpp96
2 files changed, 97 insertions, 1 deletions
diff --git a/modules/remote-sycl/tests/calculator.cpp b/modules/remote-sycl/tests/calculator.foo
index 6d061ad..745bd3d 100644
--- a/modules/remote-sycl/tests/calculator.cpp
+++ b/modules/remote-sycl/tests/calculator.foo
@@ -15,7 +15,7 @@ using Calculator = Interface<
>;
}
-SAW_TEST("Sycl Interface Calculator"){
+SAW_TEST("SYCL Interface Calculator"){
using namespace saw;
cl::sycl::queue cmd_queue;
diff --git a/modules/remote-sycl/tests/sycl_basics.cpp b/modules/remote-sycl/tests/sycl_basics.cpp
new file mode 100644
index 0000000..bf41983
--- /dev/null
+++ b/modules/remote-sycl/tests/sycl_basics.cpp
@@ -0,0 +1,96 @@
+#include <forstio/test/suite.hpp>
+
+#include "../c++/remote.hpp"
+
+namespace {
+namespace schema {
+using namespace saw::schema;
+
+using TestStruct = Struct<
+ Member<UInt64, "foo">,
+ Member<Float64, "bar">,
+ Member<Array<Float64>, "doubles">
+>;
+
+using Foo = Interface<
+ Member<Function<TestStruct, Void>, "foo">
+>;
+
+using Calculator = Interface<
+ Member<
+ Function<Tuple<Int64, Int64>, Int64>, "add"
+ >
+, Member<
+ Function<Tuple<Int64, Int64>, Int64>, "multiply"
+ >
+>;
+}
+SAW_TEST("SYCL Test Setup"){
+ using namespace saw;
+
+ data<schema::TestStruct> host_data;
+ host_data.template get<"foo">() = 321;
+ host_data.template get<"bar">() = 50.0;
+ host_data.template get<"doubles">() = {1024u};
+
+ saw::event_loop loop;
+ saw::wait_scope wait{loop};
+
+ remote<rmt::Sycl> rmt;
+ saw::own<saw::remote_address<saw::rmt::Sycl>> rmt_addr{};
+
+ rmt.resolve_address().then([&](auto addr){
+ rmt_addr = std::move(addr);
+ }).detach();
+
+ wait.poll();
+ SAW_EXPECT(rmt_addr, "Remote Address class hasn't been filled");
+
+ auto device = rmt.connect_device(*rmt_addr);
+
+ data<schema::TestStruct, encode::Native, rmt::Sycl> device_data{host_data};
+
+ interface<schema::Foo, encode::Native,rmt::Sycl, cl::sycl::queue*> cl_iface {
+[&](data<schema::TestStruct, encode::Native, rmt::Sycl>& in, cl::sycl::queue* cmd) -> error_or<void> {
+
+ cmd->submit([&](cl::sycl::handler& h){
+
+ auto acc_buff = in.template access<cl::sycl::access::mode::write>(h);
+
+ uint64_t si = host_data.template get<"doubles">().size();
+
+ h.parallel_for(cl::sycl::range<1>(si), [=] (cl::sycl::id<1> it){
+ acc_buff[0u].template get<"foo">() = acc_buff[0u].template get<"doubles">().size();
+ auto& dbls = acc_buff[0u].template get<"doubles">();
+ dbls.at(it[0u]) = it[0u] * 2.0;
+ });
+ });
+ /*
+ cmd->submit([&](cl::sycl::handler& h){
+ auto acc_buff = in.template access<cl::sycl::access::mode::read>(h);
+ h.copy(acc_buff, &host_data);
+ });
+ */
+
+ /**
+ cl::sycl::host_accessor result{in.get_handle()};
+ std::cout<<result[0u].template get<"foo">().get()<<std::endl;
+ std::cout<<result[0u].template get<"bar">().get()<<std::endl;
+ **/
+ return saw::void_t{};
+ }
+ };
+
+ std::cout<<"Running on:\n"<<device.get_handle().get_device().get_info<cl::sycl::info::device::name>()<<std::endl;
+
+ std::cout<<host_data.template get<"foo">().get()<<std::endl;
+ std::cout<<host_data.template get<"bar">().get()<<std::endl;
+ std::cout<<std::endl;
+ cl_iface.template call <"foo">(device_data, &device.get_handle());
+ device.get_handle().wait();
+ std::cout<<host_data.template get<"foo">().get()<<std::endl;
+ std::cout<<host_data.template get<"bar">().get()<<std::endl;
+ auto& dbls = host_data.template get<"doubles">();
+ std::cout<<std::endl;
+}
+}