// path: modules/remote-sycl/tests/sycl_basics.cpp
// blob: f6e62b493099b79a6b362551f682440fcf195724
#include <forstio/test/suite.hpp>

#include "../c++/remote.hpp"

namespace {
/**
 * Schemas exercised by this test.
 */
namespace schema {
using namespace saw::schema;

/**
 * Test payload: an unsigned 64-bit scalar "foo", a double scalar "bar" and a
 * dynamically sized array of doubles "doubles".
 */
using TestStruct = Struct<
	Member<UInt64, "foo">,
	Member<Float64, "bar">,
	Member<Array<Float64>, "doubles">
>;

/**
 * Interface exposing a single function "foo" that takes a TestStruct and
 * returns Void.
 */
using Foo = Interface<
	Member<Function<TestStruct, Void>, "foo">
>;
}

/**
 * Smoke test for the SYCL remote: resolves a remote address, copies a
 * TestStruct into its SYCL representation, runs a kernel through the "foo"
 * interface call, waits for the device queue and checks the call succeeded.
 */
SAW_TEST("SYCL Test Setup"){
	using namespace saw;

	data<schema::TestStruct> host_data;
	host_data.template get<"foo">() = 321;
	host_data.template get<"bar">() = 50.0;
	// NOTE(review): presumably sizes the array to 1024 elements; confirm
	// against the constructors of data<Array<Float64>>.
	host_data.template get<"doubles">() = {1024u};

	saw::event_loop loop;
	saw::wait_scope wait{loop};

	remote<rmt::Sycl> rmt;
	saw::own<saw::remote_address<saw::rmt::Sycl>> rmt_addr{};

	// Address resolution is asynchronous; the test expects a single poll of
	// the event loop to deliver the resolved address.
	rmt.resolve_address().then([&](auto addr){
		rmt_addr = std::move(addr);
	}).detach();

	wait.poll();
	SAW_EXPECT(rmt_addr, "Remote Address class hasn't been filled");

	data<schema::TestStruct, rmt::Sycl<encode::Native>> device_data{host_data};

	interface<schema::Foo, rmt::Sycl<encode::Native>, rmt::Sycl, cl::sycl::queue*> cl_iface {
		[&](data<schema::TestStruct, rmt::Sycl<encode::Native>>& in, cl::sycl::queue* cmd) -> error_or<void> {
			cmd->submit([&](cl::sycl::handler& h){
				auto acc_buff = in.template access<cl::sycl::access::mode::write>(h);

				// Launch one work-item per element of the host-side array.
				uint64_t si = host_data.template get<"doubles">().size();

				h.parallel_for(cl::sycl::range<1>(si), [=] (cl::sycl::id<1> it){
					// The scalar "foo" is shared by all work-items; letting every
					// work-item store to it is a data race, so only item 0 writes.
					if(it[0u] == 0u){
						acc_buff[0u].template get<"foo">() = acc_buff[0u].template get<"doubles">().size();
					}
					auto& dbls = acc_buff[0u].template get<"doubles">();
					dbls.at(it[0u]) = it[0u] * 2.0;
				});
			});
			// Disabled read-back experiments, kept for reference:
			/*
			cmd->submit([&](cl::sycl::handler& h){
				auto acc_buff = in.template access<cl::sycl::access::mode::read>(h);
				h.copy(acc_buff, &host_data);
			});
			*/

			/**
			cl::sycl::host_accessor result{in.get_handle()};
			std::cout<<result[0u].template get<"foo">().get()<<std::endl;
			std::cout<<result[0u].template get<"bar">().get()<<std::endl;
			**/
			return saw::void_t{};
		}
	};
	auto& device = rmt_addr->get_device();

	auto eov = cl_iface.template call<"foo">(device_data, &(device.get_handle()));
	// Drain the queue first so no submitted kernel is still in flight when the
	// result is checked (and possibly the test aborts).
	device.get_handle().wait();
	// Previously the call's result was discarded, hiding failures.
	SAW_EXPECT(!eov.is_error(), "Calling \"foo\" on the SYCL interface failed");
}
}