summaryrefslogtreecommitdiff
path: root/modules/remote-sycl/tests/sycl_basics.foo
diff options
context:
space:
mode:
Diffstat (limited to 'modules/remote-sycl/tests/sycl_basics.foo')
-rw-r--r--modules/remote-sycl/tests/sycl_basics.foo77
1 files changed, 77 insertions, 0 deletions
diff --git a/modules/remote-sycl/tests/sycl_basics.foo b/modules/remote-sycl/tests/sycl_basics.foo
new file mode 100644
index 0000000..970f4d6
--- /dev/null
+++ b/modules/remote-sycl/tests/sycl_basics.foo
@@ -0,0 +1,77 @@
+#include <forstio/test/suite.hpp>
+
+#include "../c++/device.hpp"
+#include "../c++/data.hpp"
+#include "../c++/remote.hpp"
+
+namespace {
+namespace schema {
+using namespace saw::schema;
+
+using TestStruct = Struct<
+ Member<UInt64, "foo">,
+ Member<Float64, "bar">,
+ Member<Array<Float64>, "doubles">
+>;
+
+using Foo = Interface<
+ Member<Function<TestStruct, Void>, "foo">
+>;
+}
+SAW_TEST("SYCL Basics"){
+ using namespace saw;
+
+ acpp::sycl::queue q;
+ data<schema::TestStruct,encode::Sycl<encode::Native>> host_data;
+}
+
+/*
+SAW_TEST("SYCL Test Setup"){
+ using namespace saw;
+
+ data<schema::TestStruct> host_data;
+ host_data.template get<"foo">() = 321;
+ host_data.template get<"bar">() = 50.0;
+ host_data.template get<"doubles">() = {1024u};
+
+ saw::event_loop loop;
+ saw::wait_scope wait{loop};
+
+ remote<rmt::Sycl> rmt;
+ saw::own<saw::remote_address<saw::rmt::Sycl>> rmt_addr{};
+
+ rmt.resolve_address().then([&](auto addr){
+ rmt_addr = std::move(addr);
+ }).detach();
+
+ wait.poll();
+ SAW_EXPECT(rmt_addr, "Remote Address class hasn't been filled");
+
+ data<schema::TestStruct, encode::Sycl<encode::Native>> device_data{host_data};
+
+ interface<schema::Foo, encode::Sycl<encode::Native>,acpp::sycl::queue*> cl_iface {
+[&](data<schema::TestStruct, encode::Sycl<encode::Native>>& in, acpp::sycl::queue* cmd) -> error_or<void> {
+
+ cmd->submit([&](acpp::sycl::handler& h){
+
+ auto acc_buff = in.template access<acpp::sycl::access::mode::write>(h);
+
+ auto si = host_data.template get<"doubles">().size();
+
+ h.parallel_for(acpp::sycl::range<1>(si.get()), [=] (acpp::sycl::id<1> it){
+ acc_buff[0u].template get<"foo">() = acc_buff[0u].template get<"doubles">().size();
+ auto& dbls = acc_buff[0u].template get<"doubles">();
+ dbls.at(it[0u]) = it[0u] * 2.0;
+ });
+ });
+ return saw::void_t{};
+ }
+ };
+ auto our_device = share<device<rmt::Sycl>>();
+ auto& device = *our_device;
+
+ cl_iface.template call <"foo">(device_data, &(device.get_handle()));
+ device.get_handle().wait();
+}
+*/
+}