#pragma once #include "common.hpp" namespace saw { /** * Generic wrapper class which stores data on the sycl side. * Most of the times this will be a root object. */ template class data, encode::Sycl> { public: using Schema = schema::Array; private: // cl::sycl::buffer> data_; using sycl_usm_allocator = acpp::sycl::usm_allocator, acpp::sycl::usm::alloc::shared>; sycl_usm_allocator sycl_alloc_; data, Encode> dims_; data size_; std::vector, sycl_usm_allocator> data_; uint64_t get_full_size() const { uint64_t s = 1; for(uint64_t iter = 0; iter < Dim; ++iter){ auto& dim_iter = dims_.at(data{iter}); s *= dim_iter.get(); } return s; } public: data(acpp::sycl::queue& q__): sycl_alloc_{q__}, dims_{}, size_{0u}, data_{0u,sycl_alloc_} { for(uint64_t iter = 0; iter < Dim; ++iter){ dims_.at({iter}) = 0u; } } data(const data, Encode>& dims__, acpp::sycl::queue& q__): sycl_alloc_{q__}, dims_{dims__}, size_{get_full_size()}, data_{size_.get(),sycl_alloc_} {} data* get_internal_data() { if(data_.empty()){ return nullptr; } return &(data_[0u]); } auto get_internal_size() const { return size_; } data size() const { return size_; } data, Encode> dims() const { return dims_; } constexpr data& at(const data, Encode>& i){ return data_.at(this->get_flat_index(i)); } constexpr const data& at(const data, Encode>& i)const{ return data_.at(this->get_flat_index(i)); } data internal_flat_index(const data, Encode>& i) const { return {this->get_flat_index(i)}; } private: template uint64_t get_flat_index(const U& i) const { static_assert( std::is_same_v, Encode>> or std::is_same_v>, "Unsupported type" ); assert(data_.size() == get_full_size()); uint64_t s = 0; uint64_t stride = 1; for(uint64_t iter = 0; iter < Dim; ++iter){ uint64_t ind = [](auto val) -> uint64_t { using V = std::decay_t; if constexpr (std::is_same_v>){ return val.get(); }else if constexpr (std::is_same_v){ return val; }else{ static_assert(always_false, "Cases exhausted"); } }(i.at(iter)); assert(ind < dims_.at({iter}).get() ); s += ind * stride; stride *= dims_.at(iter).get(); } return s; } }; template class data>, encode::Sycl> { public: using Schema = schema::Ref>; private: data* internal_data_ptr_; data, Encode> dims_; data size_; uint64_t get_full_size() const { uint64_t s = 1; for(uint64_t iter = 0; iter < Dim; ++iter){ auto& dim_iter = dims_.at(data{iter}); s *= dim_iter.get(); } return s; } public: data() = delete; data(ref, encode::Sycl>> ref_data__): internal_data_ptr_{ref_data__().get_internal_data()}, dims_{ref_data__().dims()}, size_{ref_data__().size()} {} auto* get_internal_data() { return internal_data_ptr_; } constexpr data& at(const data, Encode>& i){ return internal_data_ptr_[this->get_flat_index(i)]; } constexpr const data& at(const data, Encode>& i)const{ return internal_data_ptr_[this->get_flat_index(i)]; } private: template uint64_t get_flat_index(const U& i) const { static_assert( std::is_same_v, Encode>> or std::is_same_v>, "Unsupported type" ); assert(size_ == get_full_size()); uint64_t s = 0; uint64_t stride = 1; for(uint64_t iter = 0; iter < Dim; ++iter){ uint64_t ind = [](auto val) -> uint64_t { using V = std::decay_t; if constexpr (std::is_same_v>){ return val.get(); }else if constexpr (std::is_same_v){ return val; }else{ static_assert(always_false, "Cases exhausted"); } }(i.at(iter)); assert(ind < dims_.at({iter}).get() ); s += ind * stride; stride *= dims_.at(iter).get(); } return s; } }; }