// Source: modules/remote-sycl/tests/data.cpp
// (blob a2bb50648c04fe514d4f3ea19edf70a7322a0fa6)
#include <forstio/test/suite.hpp>

#include "../c++/remote.hpp"

namespace {
namespace schema {
using namespace saw::schema;

/**
 * Record schema used by the SYCL data test in this file.
 *
 * Three named members:
 *  - "foo": UInt64
 *  - "bra": Int32
 *  - "baz": Array<Float64>
 */
using TestStruct = Struct<
	Member<UInt64,         "foo">,
	Member<Int32,          "bra">,
	Member<Array<Float64>, "baz">
>;
} // namespace schema

/**
 * Round-trip test for the SYCL data path.
 *
 * Builds a host-side TestStruct, sends it through a data_client to a
 * data_server bound to a SYCL device, receives it back and checks that
 * every member matches the value that was sent.
 */
SAW_TEST("SYCL Data Management"){
	using namespace saw;

	// Host-side reference value; this is what gets sent and later compared against.
	data<schema::TestStruct> host_data;
	host_data.template get<"foo">() = 321u;
	host_data.template get<"bra">() = 123;
	auto& baz = host_data.template get<"baz">();
	// NOTE(review): presumably sizes the array to 1024 elements - confirm against data<Array>.
	baz = {1024};
	for(uint64_t i = 0; i < baz.size(); ++i){
		baz.at(i) = static_cast<double>(i*3);
	}

	saw::event_loop loop;
	saw::wait_scope wait{loop};

	remote<rmt::Sycl> rmt;

	own<remote_address<rmt::Sycl>> rmt_addr{};

	// The callback stores the resolved address; the test expects it to have
	// run by the time wait.poll() returns.
	rmt.resolve_address().then([&](auto addr){
		rmt_addr = std::move(addr);
	}).detach();

	wait.poll();
	SAW_EXPECT(rmt_addr, "Remote address hasn't been filled");

	auto device = rmt.connect_device(*rmt_addr);

	auto data_srv = data_server<tmpl_group<schema::TestStruct>, encode::Native, rmt::Sycl>{device};

	auto data_cl = data_client<tmpl_group<schema::TestStruct>, encode::Native, rmt::Sycl>{data_srv};

	// Send the host value; on success the result is used below to request it back.
	auto eov = data_cl.send(host_data);
	SAW_EXPECT(eov.is_value(), "Couldn't send data to SYCL");

	auto& val = eov.get_value();

	// Flags recording which of the two continuations was invoked.
	bool ran = false;
	bool error_ran = false;

	auto conv = data_cl.receive(val).then([&](auto dat){
		auto& recv_foo = dat.template get<"foo">();
		auto& recv_bra = dat.template get<"bra">();
		auto& recv_baz = dat.template get<"baz">();
		auto& host_baz = host_data.template get<"baz">();

		SAW_EXPECT(recv_foo == host_data.template get<"foo">(), "Data sent back wasn't equal");
		SAW_EXPECT(recv_bra == host_data.template get<"bra">(), "Data sent back wasn't equal");

		// Without this check an empty or truncated array would pass the
		// element loop below vacuously.
		SAW_EXPECT(recv_baz.size() == host_baz.size(), "Array sent back has a different size");

		// Exact comparison is intended: the values are expected to come back unchanged.
		for(uint64_t i = 0u; i < recv_baz.size(); ++i){
			SAW_EXPECT(recv_baz.at(i) == host_baz.at(i), "Data sent back wasn't equal");
		}
		ran = true;
	}, [&](auto err){
		error_ran = true;
		// A by-value parameter is moved implicitly on return.
		return err;
	});

	auto eob = conv.take();
	SAW_EXPECT(eob.is_value(), "conv value doesn't exist");

	SAW_EXPECT(!error_ran, "conveyor ran, but we got an error");
	SAW_EXPECT(ran, "conveyor didn't run");
}
}