summaryrefslogtreecommitdiff
path: root/modules/remote-hip/examples/hip_transfer_data.cpp
blob: b44a291e2e93d92b90a6116b30be57d1d5cecbf6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
#include <forstio/codec/data_raw.hpp>

#include "../c++/remote.hpp"
#include "../c++/transfer.hpp"

#include <iostream>
#include <string_view>

// Debug kernel: prints the Int16 value stored behind `val` from the device.
// There is no thread-index guard, so every launched thread prints once.
// `val` is dereferenced on the device and must point to device-accessible memory.
__global__ void print_value(saw::data<saw::schema::Int16,saw::encode::NativeRaw>* val){
	const int value = static_cast<int>(val->get());
	printf("Hello world: %d\n", value);
}

// Debug kernel: prints the length of an Int16 array and then its elements twice.
// The first pass goes through the typed at(i).get() accessor; the second pass
// indexes the pointer returned by get_raw_data(). There is no thread-index guard,
// so every launched thread prints the full output.
// `val` is dereferenced on the device and must point to device-accessible memory.
__global__ void print_array_vals(saw::data<saw::schema::Array<saw::schema::Int16>, saw::encode::NativeRaw>* val){
	const uint64_t count = val->size();
	printf("Array size: %ld\n", static_cast<long>(count));

	// Pass 1: typed element accessor.
	for(uint64_t idx = 0; idx < count; ++idx){
		printf("%d ", static_cast<int>(val->at(idx).get()));
	}
	printf("\n");

	// Pass 2: raw pointer from get_raw_data().
	// NOTE(review): assumes the raw element type promotes to int for "%d" — confirm.
	auto raw = val->get_raw_data();
	for(uint64_t idx = 0; idx < count; ++idx){
		printf("%d ", raw[idx]);
	}
	printf("\n");
}

namespace sch {
using namespace saw::schema;
}

// Round-trips two values through the HIP remote and prints them from the device.
//
// The same steps run for an Int16 scalar and for an Array<Int16>:
//   1. open a data server on the address parsed from 0;
//   2. send a host value under id 0;
//   3. look the value up again by that id;
//   4. launch a debug kernel on the device copy.
// Each kernel launch is checked for errors, and the host waits for the kernel
// to finish before the enclosing scope (and its data server) ends.
//
// Returns the first error from address parsing, data_listen, send, find, or a
// failed kernel launch/execution.
saw::error_or<void> real_main(){
	using namespace saw;

	// Verifies that the most recent kernel launch was accepted, then blocks until
	// the device is idle. Waiting before each scope below ends keeps the device
	// data alive while the kernel runs (the data server presumably owns it). It
	// also flushes device-side printf output and surfaces asynchronous faults.
	auto finish_kernel = [](const char* what) -> error_or<void> {
		// Invalid launch configurations are only reported via hipGetLastError().
		hipError_t status = hipGetLastError();
		if(status == hipSuccess){
			// Faults during kernel execution surface at the next synchronization.
			status = hipDeviceSynchronize();
		}
		if(status != hipSuccess){
			std::cerr<<what<<": "<<hipGetErrorString(status)<<std::endl;
			// `what` is always a string literal here, so the view outlives the error.
			return make_error<err::critical>(std::string_view{what});
		}
		return make_void();
	};

	remote<rmt::Hip> rmt;

	auto eo_addr = rmt.parse_address(0);
	if(eo_addr.is_error()){
		return std::move(eo_addr.get_error());
	}
	auto& addr = eo_addr.get_value();

	{
		auto eo_dat_srv = rmt.data_listen<sch::Int16, encode::NativeRaw>(*addr);
		if(eo_dat_srv.is_error()){
			return std::move(eo_dat_srv.get_error());
		}
		auto& dat_srv = eo_dat_srv.get_value();

		data<sch::Int16,encode::NativeRaw> val{42};
		id<sch::Int16> id_val{0u};
		auto eo_send = dat_srv->send(val, id_val);
		if(eo_send.is_error()){
			return std::move(eo_send.get_error());
		}

		auto eo_dfind = dat_srv->find(id_val);
		if(eo_dfind.is_error()){
			return std::move(eo_dfind.get_error());
		}
		auto dfind = eo_dfind.get_value();

		auto& v = dfind();

		// 2 blocks x 2 threads, so print_value's output appears four times.
		// The 4th launch argument is the null stream (0); hipStreamDefault is a
		// stream-creation flag, not a stream handle.
		print_value<<<dim3(2),dim3(2),0,0>>>(*(v.get_device_data()));
		auto eo_kernel = finish_kernel("print_value kernel failed");
		if(eo_kernel.is_error()){
			return std::move(eo_kernel.get_error());
		}
	}

	{
		auto eo_dat_srv = rmt.data_listen<sch::Array<sch::Int16>, encode::NativeRaw>(*addr);
		if(eo_dat_srv.is_error()){
			return std::move(eo_dat_srv.get_error());
		}
		auto& dat_srv = eo_dat_srv.get_value();

		data<sch::Array<sch::Int16>,encode::NativeRaw> val{4};
		val.at(0u).set(5);
		val.at(1u).set(3);
		val.at(2u).set(-6);
		val.at(3u).set(1);
		id<sch::Array<sch::Int16>> id_val{0u};
		auto eo_send = dat_srv->send(val, id_val);
		if(eo_send.is_error()){
			return std::move(eo_send.get_error());
		}

		auto eo_dfind = dat_srv->find(id_val);
		if(eo_dfind.is_error()){
			return std::move(eo_dfind.get_error());
		}
		auto dfind = eo_dfind.get_value();

		auto& v = dfind();

		// Single thread: the whole array is printed exactly once.
		print_array_vals<<<dim3(1),dim3(1),0,0>>>(*(v.get_device_data()));
		auto eo_kernel = finish_kernel("print_array_vals kernel failed");
		if(eo_kernel.is_error()){
			return std::move(eo_kernel.get_error());
		}
	}

	return make_void();
}

// Entry point: runs real_main() and maps its result to a process exit status.
// Success returns 0. On failure, the error category (plus the message, if one
// is set) is written to stderr, and the error's id is returned as the exit status.
int main(){
	auto result = real_main();
	if(!result.is_error()){
		return 0;
	}

	auto& failure = result.get_error();
	std::cerr<<"Error: "<<failure.get_category();

	auto message = failure.get_message();
	if(message.size() > 0u){
		std::cerr<<" - "<<message;
	}
	std::cerr<<"\n"<<std::endl;

	return failure.get_id();
}