diff options
Diffstat (limited to 'modules/remote-sycl/c++/data.hpp')
| -rw-r--r-- | modules/remote-sycl/c++/data.hpp | 63 |
1 files changed, 61 insertions, 2 deletions
diff --git a/modules/remote-sycl/c++/data.hpp b/modules/remote-sycl/c++/data.hpp index 11dfbf2..e057766 100644 --- a/modules/remote-sycl/c++/data.hpp +++ b/modules/remote-sycl/c++/data.hpp @@ -57,7 +57,7 @@ public: return &(data_[0u]); } - const auto& get_internal_size() const { + auto get_internal_size() const { return size_; } @@ -112,4 +112,63 @@ private: return s; } }; -} + +template<typename Sch, uint64_t Dim, typename Encode> +class data<schema::Ref<schema::Array<Sch, Dim>>, encode::Sycl<Encode>> { +public: + using Schema = schema::Ref<schema::Array<Sch,Dim>>; +private: + data<schema::Sch, Encode>* internal_data_ptr_; + data<schema::FixedArray<schema::UInt64, Dim>, Encode> dims_; + data<schema::UInt64, Encode> size_; + + uint64_t get_full_size() const { + uint64_t s = 1; + + for(uint64_t iter = 0; iter < Dim; ++iter){ + auto& dim_iter = dims_.at(data<schema::UInt64>{iter}); + s *= dim_iter.get(); + } + + return s; + } +public: + data() = delete; + + data(ref<data<schema::FixedArray<schema::UInt64, Dim>, Encode>> ref_data__): + internal_data_ptr_{ref_data__().get_internal_data()}, + dims_{ref_data__().dims()}, + size_{ref_data__().size()} + {} +private: + template<typename U> + uint64_t get_flat_index(const U& i) const { + static_assert( + std::is_same_v<U,data<schema::FixedArray<schema::UInt64,Dim>, Encode>> or + std::is_same_v<U,std::array<uint64_t,Dim>>, + "Unsupported type" + ); + assert(data_.size() == get_full_size()); + uint64_t s = 0; + + uint64_t stride = 1; + + for(uint64_t iter = 0; iter < Dim; ++iter){ + uint64_t ind = [](auto val) -> uint64_t { + using V = std::decay_t<decltype(val)>; + if constexpr (std::is_same_v<V,data<schema::UInt64>>){ + return val.get(); + }else if constexpr (std::is_same_v<V, uint64_t>){ + return val; + }else{ + static_assert(always_false<V>, "Cases exhausted"); + } + }(i.at(iter)); + assert(ind < dims_.at({iter}).get() ); + s += ind * stride; + stride *= dims_.at(iter).get(); + } + + return s; + } +}; |
