summaryrefslogtreecommitdiff
path: root/modules/remote-sycl/tests/data.cpp
blob: 027f2774ee5fa63a2e887ccec14fcde93345d368 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
#include <forstio/test/suite.hpp>

#include "../c++/transfer.hpp"
#include "../c++/remote.hpp"

namespace {
namespace schema {
using namespace saw::schema;

/**
 * Schema used for the host <-> SYCL round-trip test: an unsigned 64-bit
 * scalar, a signed 32-bit scalar and a dynamically sized array of doubles.
 */
using TestStruct = Struct<
	Member<UInt64, "foo">,
	Member<Int32, "bra">,
	Member<Array<Float64>, "baz">
>;
}

/**
 * Round-trips a TestStruct through the SYCL remote:
 *  1. fills host data (foo = 321, bra = 123, baz[i] = 3*i; `baz = {1024}`
 *     presumably sizes the array to 1024 elements),
 *  2. resolves a SYCL remote address,
 *  3. sends the data to a device-backed data_server via data_client::send,
 *  4. receives it back via data_client::receive and compares every member,
 *     including each element of "baz", against the host copy.
 */
SAW_TEST("SYCL Data Management"){
	using namespace saw;

	data<schema::TestStruct> host_data;
	auto& foo = host_data.template get<"foo">();
	foo = 321u;
	auto& bra = host_data.template get<"bra">();
	bra = 123;
	auto& baz = host_data.template get<"baz">();
	baz = {1024};
	for(uint64_t i = 0; i < baz.size(); ++i){
		baz.at(i) = static_cast<double>(i*3);
	}

	saw::event_loop loop;
	saw::wait_scope wait{loop};

	remote<rmt::Sycl> rmt;

	own<remote_address<rmt::Sycl>> rmt_addr{};

	rmt.resolve_address().then([&](auto addr){
		rmt_addr = std::move(addr);
	}).detach();

	// Drive the event loop once so the address resolution callback can run.
	wait.poll();
	SAW_EXPECT(rmt_addr, "Remote address hasn't been filled");

	auto our_device = share<device<rmt::Sycl>>();

	auto data_srv = data_server<tmpl_group<schema::TestStruct>, encode::Native, rmt::Sycl>{our_device};
	auto data_cl = data_client<tmpl_group<schema::TestStruct>, encode::Native, rmt::Sycl>{data_srv};

	auto eov = data_cl.send(host_data);
	SAW_EXPECT(eov.is_value(), "Couldn't send data to SYCL");

	auto& val = eov.get_value();

	// Results of the receive callbacks, checked after the conveyor completes.
	bool ran = false;
	bool error_ran = false;
	bool expected_values = true;
	std::string err_msg;

	auto conv = data_cl.receive(val).then([&](auto dat) {
		ran = true;

		auto& foo_b = dat.template get<"foo">();

		if(foo != foo_b){
			expected_values = false;
			err_msg = "foo not equal. ";
			err_msg += std::to_string(foo.get());
			err_msg += " - ";
			err_msg += std::to_string(foo_b.get());
			return;
		}
		
		auto& bra_b = dat.template get<"bra">();
		if(bra != bra_b){
			expected_values = false;
			err_msg = "bra not equal. ";
			err_msg += std::to_string(bra.get());
			err_msg += " - ";
			err_msg += std::to_string(bra_b.get());
			return;
		}
		
		auto& baz_b = dat.template get<"baz">();
		if(baz.size() != baz_b.size()){
			expected_values = false;
			err_msg = "baz not equal. ";
			err_msg += std::to_string(baz.size());
			err_msg += " - ";
			err_msg += std::to_string(baz_b.size());
			return;
		}

		// Matching sizes are not enough: verify the payload actually survived
		// the round-trip. Values are small integers stored as doubles, so an
		// exact comparison is appropriate.
		for(uint64_t i = 0; i < baz.size(); ++i){
			if(baz.at(i) != baz_b.at(i)){
				expected_values = false;
				err_msg = "baz element not equal at index ";
				err_msg += std::to_string(i);
				err_msg += ". ";
				err_msg += std::to_string(baz.at(i).get());
				err_msg += " - ";
				err_msg += std::to_string(baz_b.at(i).get());
				return;
			}
		}
	}, [&](auto err){
		error_ran = true;
		return err;
	});

	auto eob = conv.take();
	if(eob.is_error()){
		auto& err = eob.get_error();
		SAW_EXPECT(false, (std::string{"Conv value doesn't exist: "} + std::string{err.get_category()} + std::string{" - "} + std::string{err.get_message()}));
	}

	SAW_EXPECT(!error_ran, "Conveyor ran, but we got an error.");
	SAW_EXPECT(ran, "Conveyor didn't run.");
	SAW_EXPECT(expected_values, std::string{"Values are not equal. "} + err_msg);
}
}