diff options
author | Claudius "keldu" Holeksa <mail@keldu.de> | 2024-06-21 19:44:34 +0200 |
---|---|---|
committer | Claudius "keldu" Holeksa <mail@keldu.de> | 2024-06-21 19:44:34 +0200 |
commit | 86b06a3fee2cd7635a9ab486e2a35bdf1e81ce38 (patch) | |
tree | 5485b323cdce1c1347f1a20c7f33e8f772c73dbf /modules/remote-sycl/tests/sycl_basics.cpp | |
parent | 601113a445658d8b15273dd91c66cf20daf50d30 (diff) |
Moving forward with basic test for sycl
Diffstat (limited to 'modules/remote-sycl/tests/sycl_basics.cpp')
-rw-r--r-- | modules/remote-sycl/tests/sycl_basics.cpp | 96 |
1 files changed, 96 insertions, 0 deletions
diff --git a/modules/remote-sycl/tests/sycl_basics.cpp b/modules/remote-sycl/tests/sycl_basics.cpp new file mode 100644 index 0000000..bf41983 --- /dev/null +++ b/modules/remote-sycl/tests/sycl_basics.cpp @@ -0,0 +1,96 @@ +#include <forstio/test/suite.hpp> + +#include "../c++/remote.hpp" + +namespace { +namespace schema { +using namespace saw::schema; + +using TestStruct = Struct< + Member<UInt64, "foo">, + Member<Float64, "bar">, + Member<Array<Float64>, "doubles"> +>; + +using Foo = Interface< + Member<Function<TestStruct, Void>, "foo"> +>; + +using Calculator = Interface< + Member< + Function<Tuple<Int64, Int64>, Int64>, "add" + > +, Member< + Function<Tuple<Int64, Int64>, Int64>, "multiply" + > +>; +} +SAW_TEST("SYCL Test Setup"){ + using namespace saw; + + data<schema::TestStruct> host_data; + host_data.template get<"foo">() = 321; + host_data.template get<"bar">() = 50.0; + host_data.template get<"doubles">() = {1024u}; + + saw::event_loop loop; + saw::wait_scope wait{loop}; + + remote<rmt::Sycl> rmt; + saw::own<saw::remote_address<saw::rmt::Sycl>> rmt_addr{}; + + rmt.resolve_address().then([&](auto addr){ + rmt_addr = std::move(addr); + }).detach(); + + wait.poll(); + SAW_EXPECT(rmt_addr, "Remote Address class hasn't been filled"); + + auto device = rmt.connect_device(*rmt_addr); + + data<schema::TestStruct, encode::Native, rmt::Sycl> device_data{host_data}; + + interface<schema::Foo, encode::Native,rmt::Sycl, cl::sycl::queue*> cl_iface { +[&](data<schema::TestStruct, encode::Native, rmt::Sycl>& in, cl::sycl::queue* cmd) -> error_or<void> { + + cmd->submit([&](cl::sycl::handler& h){ + + auto acc_buff = in.template access<cl::sycl::access::mode::write>(h); + + uint64_t si = host_data.template get<"doubles">().size(); + + h.parallel_for(cl::sycl::range<1>(si), [=] (cl::sycl::id<1> it){ + acc_buff[0u].template get<"foo">() = acc_buff[0u].template get<"doubles">().size(); + auto& dbls = acc_buff[0u].template get<"doubles">(); + dbls.at(it[0u]) = it[0u] * 2.0; + }); + }); + /* + cmd->submit([&](cl::sycl::handler& h){ + auto acc_buff = in.template access<cl::sycl::access::mode::read>(h); + h.copy(acc_buff, &host_data); + }); + */ + + /** + cl::sycl::host_accessor result{in.get_handle()}; + std::cout<<result[0u].template get<"foo">().get()<<std::endl; + std::cout<<result[0u].template get<"bar">().get()<<std::endl; + **/ + return saw::void_t{}; + } + }; + + std::cout<<"Running on:\n"<<device.get_handle().get_device().get_info<cl::sycl::info::device::name>()<<std::endl; + + std::cout<<host_data.template get<"foo">().get()<<std::endl; + std::cout<<host_data.template get<"bar">().get()<<std::endl; + std::cout<<std::endl; + cl_iface.template call <"foo">(device_data, &device.get_handle()); + device.get_handle().wait(); + std::cout<<host_data.template get<"foo">().get()<<std::endl; + std::cout<<host_data.template get<"bar">().get()<<std::endl; + auto& dbls = host_data.template get<"doubles">(); + std::cout<<std::endl; +} +} |