summaryrefslogtreecommitdiff
path: root/modules/remote-sycl/c++/data.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'modules/remote-sycl/c++/data.hpp')
-rw-r--r--modules/remote-sycl/c++/data.hpp63
1 files changed, 61 insertions, 2 deletions
diff --git a/modules/remote-sycl/c++/data.hpp b/modules/remote-sycl/c++/data.hpp
index 11dfbf2..e057766 100644
--- a/modules/remote-sycl/c++/data.hpp
+++ b/modules/remote-sycl/c++/data.hpp
@@ -57,7 +57,7 @@ public:
return &(data_[0u]);
}
- const auto& get_internal_size() const {
+ auto get_internal_size() const {
return size_;
}
@@ -112,4 +112,63 @@ private:
return s;
}
};
-}
+
+template<typename Sch, uint64_t Dim, typename Encode>
+class data<schema::Ref<schema::Array<Sch, Dim>>, encode::Sycl<Encode>> {
+public:
+ using Schema = schema::Ref<schema::Array<Sch,Dim>>;
+private:
+ data<schema::Sch, Encode>* internal_data_ptr_;
+ data<schema::FixedArray<schema::UInt64, Dim>, Encode> dims_;
+ data<schema::UInt64, Encode> size_;
+
+ uint64_t get_full_size() const {
+ uint64_t s = 1;
+
+ for(uint64_t iter = 0; iter < Dim; ++iter){
+ auto& dim_iter = dims_.at(data<schema::UInt64>{iter});
+ s *= dim_iter.get();
+ }
+
+ return s;
+ }
+public:
+ data() = delete;
+
+ data(ref<data<schema::FixedArray<schema::UInt64, Dim>, Encode>> ref_data__):
+ internal_data_ptr_{ref_data__().get_internal_data()},
+ dims_{ref_data__().dims()},
+ size_{ref_data__().size()}
+ {}
+private:
+ template<typename U>
+ uint64_t get_flat_index(const U& i) const {
+ static_assert(
+ std::is_same_v<U,data<schema::FixedArray<schema::UInt64,Dim>, Encode>> or
+ std::is_same_v<U,std::array<uint64_t,Dim>>,
+ "Unsupported type"
+ );
+ assert(data_.size() == get_full_size());
+ uint64_t s = 0;
+
+ uint64_t stride = 1;
+
+ for(uint64_t iter = 0; iter < Dim; ++iter){
+ uint64_t ind = [](auto val) -> uint64_t {
+ using V = std::decay_t<decltype(val)>;
+ if constexpr (std::is_same_v<V,data<schema::UInt64>>){
+ return val.get();
+ }else if constexpr (std::is_same_v<V, uint64_t>){
+ return val;
+ }else{
+ static_assert(always_false<V>, "Cases exhausted");
+ }
+ }(i.at(iter));
+ assert(ind < dims_.at({iter}).get() );
+ s += ind * stride;
+ stride *= dims_.at(iter).get();
+ }
+
+ return s;
+ }
+};