diff options
| author | Claudius "keldu" Holeksa <mail@keldu.de> | 2025-10-29 17:47:39 +0100 |
|---|---|---|
| committer | Claudius "keldu" Holeksa <mail@keldu.de> | 2025-10-29 17:47:39 +0100 |
| commit | 53b1bc01474a0612e8039485cd4e33fc441f673a (patch) | |
| tree | 385c05031ab94065cad79b620f365322723fda40 | |
| parent | 7ea2e439dded1baa11c4c12207eee8e1033ae104 (diff) | |
| download | forstio-forstio-53b1bc01474a0612e8039485cd4e33fc441f673a.tar.gz | |
Adding some proper USM features :)
| -rw-r--r-- | modules/remote-sycl/c++/data.hpp | 126 |
1 files changed, 79 insertions, 47 deletions
diff --git a/modules/remote-sycl/c++/data.hpp b/modules/remote-sycl/c++/data.hpp index 91f74f8..2d5893f 100644 --- a/modules/remote-sycl/c++/data.hpp +++ b/modules/remote-sycl/c++/data.hpp @@ -8,73 +8,105 @@ namespace saw { * Generic wrapper class which stores data on the sycl side. * Most of the times this will be a root object. */ -template<typename Schema> -class data<Schema, encode::Sycl<encode::Native>> { + +template<typename Sch, uint64_t Dim, typename Encode> +class data<schema::Array<Sch, Dim>, encode::Sycl<Encode>> { +public: + using Schema = schema::Array<Sch,Dim>; private: - cl::sycl::buffer<data<Schema, encode::Native>> data_; - data<schema::UInt64, encode::Native> size_; + // cl::sycl::buffer<data<Sch, encode::Native>> data_; + using sycl_usm_allocator = acpp::sycl::usm_allocator<data<Sch,Encode>, sycl::usm::alloc::shared>; + data<schema::FixedArray<schema::UInt64, sizeof...(D)>, Encode> dims_; + data<schema::UInt64, Encode> size_; + std::vector<data<Sch,Encode>, sycl_usm_allocator> data_; + + uint64_t get_full_size() const { + uint64_t s = 1; + + for(uint64_t iter = 0; iter < Dim; ++iter){ + auto& dim_iter = dims_.at(data<schema::UInt64>{iter}); + s *= dim_iter.get(); + } + + return s; + } public: - data(const data<Schema, encode::Native>& data__): - data_{&data__, 1u}, - size_{data__.size()} + data(): + dims_{}, + size_{0u}, + data_{} + { + for(uint64_t iter = 0; iter < Dim; ++iter){ + dims_.at({iter}) = 0u; + } + } + + data(const data<schema::FixedArray<schema::UInt64, sizeof...(D)>, Encode>& dims__): + dims_{dims__}, + size_{get_full_size()}, + data_{size_} {} - auto& get_handle() { - return data_; + auto* get_internal_data() { + if(data_.empty()){ + return nullptr; + } + return &(data_[0u]); } - const auto& get_handle() const { - return data_; + const auto& get_internal_size() const { + return size_; } - data<schema::UInt64, encode::Native> size() const { + data<schema::UInt64, Encode> size() const { return size_; } - template<cl::sycl::access::mode AccessMode> - auto access(cl::sycl::handler& h){ - return data_.template get_access<AccessMode>(h); - } - - template<cl::sycl::access::mode AccessMode> - auto access(cl::sycl::handler& h) const { - return data_.template get_access<AccessMode>(h); + data<schema::FixedArray<schema::UInt64, sizeof...(D)>, Encode> dims() const { + return dims_; } -}; -template<typename Sch, uint64_t Dim> -class data<schema::Array<Sch, Dim>, encode::Sycl<encode::Native>> { -public: - using Schema = schema::Array<Sch,Dim>; -private: - cl::sycl::buffer<data<Sch, encode::Native>> data_; - data<schema::UInt64, encode::Native> size_; -public: - data(const data<Schema, encode::Native>& host_data__): - data_{&host_data__.at({0u}),host_data__.size().get()}, - size_{host_data__.size()} - {} - - auto& get_handle() { - return data_; + constexpr data<T, Encode>& at(const data<schema::FixedArray<schema::UInt64, sizeof...(D)>, Encode>& i){ + return value_.at(this->get_flat_index(i)); } - const auto& get_handle() const { - return data_; + constexpr const data<T, Encode>& at(const data<schema::FixedArray<schema::UInt64, sizeof...(D)>, Encode>& i)const{ + return value_.at(this->get_flat_index(i)); } - data<schema::UInt64, encode::Native> size() const { - return size_; + data<schema::UInt64,Encode> internal_flat_index(const data<schema::FixedArray<schema::UInt64, sizeof...(D)>, Encode>& i) const { + return {this->get_flat_index(i)}; } +private: + template<typename U> + uint64_t get_flat_index(const U& i) const { + static_assert( + std::is_same_v<U,data<schema::FixedArray<schema::UInt64,Dim>, Encode>> or + std::is_same_v<U,std::array<uint64_t,Dim>>, + "Unsupported type" + ); + assert(value_.size() == get_full_size()); + uint64_t s = 0; - template<cl::sycl::access::mode AccessMode> - auto access(cl::sycl::handler& h){ - return data_.template get_access<AccessMode>(h); - } + uint64_t stride = 1; + + for(uint64_t iter = 0; iter < Dim; ++iter){ + uint64_t ind = [](auto val) -> uint64_t { + using V = std::decay_t<decltype(val)>; + if constexpr (std::is_same_v<V,data<schema::UInt64>>){ + return val.get(); + }else if constexpr (std::is_same_v<V, uint64_t>){ + return val; + }else{ + static_assert(always_false<V>, "Cases exhausted"); + } + }(i.at(iter)); + assert(ind < dims_.at({iter}).get() ); + s += ind * stride; + stride *= dims_.at(iter).get(); + } - template<cl::sycl::access::mode AccessMode> - auto access(cl::sycl::handler& h) const { - return data_.template get_access<AccessMode>(h); + return s; } }; } |
