// path: root/modules/remote-sycl/c++/data.hpp
// blob: 11dfbf2705508ffd7c30ad512d2eea87453c10aa
#pragma once

#include "common.hpp"

namespace saw {

/**
 * Generic wrapper class which stores array data on the SYCL side.
 * Most of the time this will be a root object.
 *
 * Elements are held in a std::vector whose allocator is a SYCL
 * usm_allocator using usm::alloc::shared, bound to the queue passed to the
 * constructor. Elements are addressed with a Dim-dimensional index that is
 * flattened so that the first dimension varies fastest.
 *
 * @tparam Sch    Element schema.
 * @tparam Dim    Number of dimensions.
 * @tparam Encode Encoding used for the elements and for the dims/size
 *                bookkeeping values.
 */
template<typename Sch, uint64_t Dim, typename Encode>
class data<schema::Array<Sch, Dim>, encode::Sycl<Encode>> {
public:
	using Schema = schema::Array<Sch,Dim>;
private:
	// cl::sycl::buffer<data<Sch, encode::Native>> data_;
	using sycl_usm_allocator = acpp::sycl::usm_allocator<data<Sch,Encode>, acpp::sycl::usm::alloc::shared>;

	// NOTE: declaration order matters. The dims-taking constructor computes
	// size_ from dims_ and sizes data_ from size_ inside its member
	// initializer list, which relies on members being initialized in this
	// exact order.
	sycl_usm_allocator sycl_alloc_;
	data<schema::FixedArray<schema::UInt64, Dim>, Encode> dims_;
	data<schema::UInt64, Encode> size_;
	std::vector<data<Sch,Encode>, sycl_usm_allocator> data_;

	/**
	 * Product of all extents stored in dims_, i.e. the number of elements
	 * the array is expected to hold.
	 */
	uint64_t get_full_size() const {
		uint64_t s = 1;

		for(uint64_t iter = 0; iter < Dim; ++iter){
			// Same index form as every other dims_ access in this class.
			s *= dims_.at({iter}).get();
		}

		return s;
	}
public:
	/**
	 * Creates an empty array: every extent is set to 0, size is 0 and no
	 * element storage is allocated.
	 *
	 * @param q__ Queue the shared USM allocator is bound to.
	 */
	data(acpp::sycl::queue& q__):
		sycl_alloc_{q__},
		dims_{},
		size_{0u},
		data_{0u,sycl_alloc_}
	{
		for(uint64_t iter = 0; iter < Dim; ++iter){
			dims_.at({iter}) = 0u;
		}
	}

	/**
	 * Creates an array with the given extents and allocates room for the
	 * product of all extents (elements are default-inserted by
	 * std::vector(count, allocator)).
	 *
	 * @param dims__ Extent of each dimension.
	 * @param q__    Queue the shared USM allocator is bound to.
	 */
	data(const data<schema::FixedArray<schema::UInt64, Dim>, Encode>& dims__, acpp::sycl::queue& q__):
		sycl_alloc_{q__},
		dims_{dims__},
		size_{get_full_size()},
		data_{size_.get(),sycl_alloc_}
	{}

	/**
	 * Raw pointer to the first element, or nullptr if the array holds no
	 * elements.
	 *
	 * The return type is spelled out because `auto*` cannot be deduced from
	 * `return nullptr;` (std::nullptr_t is not a pointer type).
	 */
	data<Sch, Encode>* get_internal_data() {
		if(data_.empty()){
			return nullptr;
		}
		return data_.data();
	}

	/// Reference to the stored element count (product of all extents).
	const auto& get_internal_size() const {
		return size_;
	}

	/// Copy of the stored element count (product of all extents).
	data<schema::UInt64, Encode> size() const {
		return size_;
	}

	/// Copy of the per-dimension extents.
	data<schema::FixedArray<schema::UInt64, Dim>, Encode> dims() const {
		return dims_;
	}

	/**
	 * Element access by multi-dimensional index.
	 *
	 * Per-dimension bounds are only asserted (debug builds); the resulting
	 * flat index is range-checked by std::vector::at, which throws
	 * std::out_of_range. Not constexpr: it goes through a std::vector with a
	 * USM allocator and the non-constexpr get_flat_index.
	 */
	data<Sch, Encode>& at(const data<schema::FixedArray<schema::UInt64, Dim>, Encode>& i){
		return data_.at(this->get_flat_index(i));
	}

	/// Const overload of at(); same bounds behavior.
	const data<Sch, Encode>& at(const data<schema::FixedArray<schema::UInt64, Dim>, Encode>& i)const{
		return data_.at(this->get_flat_index(i));
	}

	/// Flat offset into the internal storage for the given index.
	data<schema::UInt64,Encode> internal_flat_index(const data<schema::FixedArray<schema::UInt64, Dim>, Encode>& i) const {
		return {this->get_flat_index(i)};
	}
private:
	/**
	 * Converts a Dim-dimensional index into an offset into data_.
	 *
	 * The first dimension has stride 1; each following dimension's stride is
	 * the product of the preceding extents. Out-of-range components are
	 * caught by assert in debug builds only.
	 *
	 * @tparam U Either data<FixedArray<UInt64,Dim>, Encode> or
	 *           std::array<uint64_t, Dim>.
	 */
	template<typename U>
	uint64_t get_flat_index(const U& i) const {
		static_assert(
			std::is_same_v<U,data<schema::FixedArray<schema::UInt64,Dim>, Encode>> or
			std::is_same_v<U,std::array<uint64_t,Dim>>,
			"Unsupported type"
		);
		assert(data_.size() == get_full_size());
		uint64_t s = 0;

		uint64_t stride = 1;

		for(uint64_t iter = 0; iter < Dim; ++iter){
			uint64_t ind = [](auto val) -> uint64_t {
				using V = std::decay_t<decltype(val)>;
				// Components of data<FixedArray<UInt64,Dim>, Encode> are
				// data<UInt64, Encode>; the default-encoded data<UInt64> is
				// accepted as well so existing instantiations keep working.
				if constexpr (
					std::is_same_v<V,data<schema::UInt64, Encode>> or
					std::is_same_v<V,data<schema::UInt64>>
				){
					return val.get();
				}else if constexpr (std::is_same_v<V, uint64_t>){
					return val;
				}else{
					static_assert(always_false<V>, "Cases exhausted");
				}
			}(i.at(iter));
			// Look the extent up once; it is needed for the check and the stride.
			const uint64_t dim = dims_.at({iter}).get();
			assert(ind < dim);
			s += ind * stride;
			stride *= dim;
		}

		return s;
	}
};
}