summaryrefslogtreecommitdiff
path: root/modules/remote-hip/c++/device.tmpl.hpp
blob: 2ee991d507f8eea660eebc4e826677ab51e99bdb (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
namespace saw {
namespace impl {
template<typename Schema, typename Encoding>
struct hip_copy_to_device {
	static error_or<void> apply(data<Schema, Encoding>& from, data<Schema, Encoding>** to){
		static_assert(always_false<Schema,Encoding>, "Unsupported case.");
		return make_void();
	}
};

template<typename T, uint64_t N, typename Encoding>
struct hip_copy_to_device<schema::Primitive<T,N>, Encoding> {
	using Schema = schema::Primitive<T,N>;

	/**
	 * Allocates one data<Schema,Encoding> in device memory and copies the
	 * bytes of `from` into it with hipMemcpy (host -> device).
	 *
	 * NOTE(review): this is a raw byte copy of the host object, so it assumes
	 * data<Primitive> is trivially copyable - confirm.
	 *
	 * @param from host-side value to copy.
	 * @param to   receives the device pointer on success. The allocation is not
	 *             freed here, so presumably the caller releases it with hipFree.
	 *             Set to nullptr on failure.
	 * @return void on success. On failure, an error is returned and no device
	 *         memory allocated by this call remains.
	 */
	static error_or<void> apply(data<Schema, Encoding>& from, data<Schema,Encoding>** to){
		if(hipMalloc(to, sizeof(data<Schema,Encoding>)) != hipSuccess){
			*to = nullptr;
			return make_error<err::critical>("hipMalloc failed while copying primitive to device");
		}

		if(hipMemcpy(*to, &from, sizeof(data<Schema,Encoding>), hipMemcpyHostToDevice) != hipSuccess){
			// Do not leak the allocation we just made.
			(void) hipFree(*to);
			*to = nullptr;
			return make_error<err::critical>("hipMemcpy failed while copying primitive to device");
		}

		return make_void();
	}
};

template<typename T, uint64_t Dim, typename Encoding>
struct hip_copy_to_device<schema::Array<T,Dim>, Encoding> {
	static_assert(Dim == 1u, "Only 1D arrays are supported for now.");
	static_assert(is_primitive<T>::value, "Arrays can only handle primitives for now.");

	using Schema = schema::Array<T,Dim>;

	/**
	 * Copies a 1D array of primitives to the device.
	 *
	 * Two device allocations are made:
	 *  1. The element buffer, skipped when the array is empty.
	 *  2. A data<Schema,Encoding> header.
	 *
	 * The header is prepared on the host in a temporary that adopt()s the device
	 * element pointer. Its bytes are then copied to the device. After that,
	 * extract() releases the pointer from the temporary, presumably so the
	 * temporary's destructor does not try to free device memory on the host.
	 *
	 * @param from host-side array to copy.
	 * @param to   receives the device pointer to the header on success. Both
	 *             allocations outlive this call, so presumably the caller frees
	 *             them. Set to nullptr on failure.
	 * @return void on success. On failure, returns the error from adopt() or a
	 *         critical error for a failed HIP call. Device memory allocated by
	 *         this call is released before returning.
	 */
	static error_or<void> apply(data<Schema, Encoding>& from, data<Schema,Encoding>** to){
		using native_type = typename native_data_type<T>::type;

		const auto count = from.size();

		// Element buffer on the device. It stays nullptr for an empty array.
		native_type* dat{};
		if(count > 0u){
			const auto bytes = sizeof(native_type) * count;
			if(hipMalloc(&dat, bytes) != hipSuccess){
				*to = nullptr;
				return make_error<err::critical>("hipMalloc failed for array elements");
			}
			if(hipMemcpy(dat, from.get_raw_data(), bytes, hipMemcpyHostToDevice) != hipSuccess){
				(void) hipFree(dat);
				*to = nullptr;
				return make_error<err::critical>("hipMemcpy failed for array elements");
			}
		}

		// Host-side header whose internal pointer refers to the device buffer.
		data<Schema,Encoding> tmp_fake_dat;
		{
			auto eov = tmp_fake_dat.adopt(dat, count);
			if(eov.is_error()){
				// NOTE(review): assumes a failed adopt() leaves dat unowned - confirm.
				// hipFree(nullptr) is a no-op for the empty-array case.
				(void) hipFree(dat);
				*to = nullptr;
				return eov;
			}
		}

		const hipError_t malloc_err = hipMalloc(to, sizeof(data<Schema,Encoding>));
		hipError_t copy_err = hipSuccess;
		if(malloc_err == hipSuccess){
			copy_err = hipMemcpy(*to, &tmp_fake_dat, sizeof(data<Schema,Encoding>), hipMemcpyHostToDevice);
		}

		// Release the adopted device pointer on every path, before any return.
		tmp_fake_dat.extract();

		if(malloc_err != hipSuccess){
			*to = nullptr;
			(void) hipFree(dat);
			return make_error<err::critical>("hipMalloc failed for array header");
		}
		if(copy_err != hipSuccess){
			(void) hipFree(*to);
			*to = nullptr;
			(void) hipFree(dat);
			return make_error<err::critical>("hipMemcpy failed for array header");
		}

		return make_void();
	}
};
}
}