From 53b1bc01474a0612e8039485cd4e33fc441f673a Mon Sep 17 00:00:00 2001 From: "Claudius \"keldu\" Holeksa" Date: Wed, 29 Oct 2025 17:47:39 +0100 Subject: Adding some proper USM features :) --- modules/remote-sycl/c++/data.hpp | 126 ++++++++++++++++++++++++--------------- 1 file changed, 79 insertions(+), 47 deletions(-) diff --git a/modules/remote-sycl/c++/data.hpp b/modules/remote-sycl/c++/data.hpp index 91f74f8..2d5893f 100644 --- a/modules/remote-sycl/c++/data.hpp +++ b/modules/remote-sycl/c++/data.hpp @@ -8,73 +8,105 @@ namespace saw { * Generic wrapper class which stores data on the sycl side. * Most of the times this will be a root object. */ -template -class data> { + +template +class data, encode::Sycl> { +public: + using Schema = schema::Array; private: - cl::sycl::buffer> data_; - data size_; + // cl::sycl::buffer> data_; + using sycl_usm_allocator = acpp::sycl::usm_allocator, sycl::usm::alloc::shared>; + data, Encode> dims_; + data size_; + std::vector, sycl_usm_allocator> data_; + + uint64_t get_full_size() const { + uint64_t s = 1; + + for(uint64_t iter = 0; iter < Dim; ++iter){ + auto& dim_iter = dims_.at(data{iter}); + s *= dim_iter.get(); + } + + return s; + } public: - data(const data& data__): - data_{&data__, 1u}, - size_{data__.size()} + data(): + dims_{}, + size_{0u}, + data_{} + { + for(uint64_t iter = 0; iter < Dim; ++iter){ + dims_.at({iter}) = 0u; + } + } + + data(const data, Encode>& dims__): + dims_{dims__}, + size_{get_full_size()}, + data_{size_} {} - auto& get_handle() { - return data_; + auto* get_internal_data() { + if(data_.empty()){ + return nullptr; + } + return &(data_[0u]); } - const auto& get_handle() const { - return data_; + const auto& get_internal_size() const { + return size_; } - data size() const { + data size() const { return size_; } - template - auto access(cl::sycl::handler& h){ - return data_.template get_access(h); - } - - template - auto access(cl::sycl::handler& h) const { - return data_.template get_access(h); + data, Encode> dims() const { + return dims_; } -}; -template -class data, encode::Sycl> { -public: - using Schema = schema::Array; -private: - cl::sycl::buffer> data_; - data size_; -public: - data(const data& host_data__): - data_{&host_data__.at({0u}),host_data__.size().get()}, - size_{host_data__.size()} - {} - - auto& get_handle() { - return data_; + constexpr data& at(const data, Encode>& i){ + return value_.at(this->get_flat_index(i)); } - const auto& get_handle() const { - return data_; + constexpr const data& at(const data, Encode>& i)const{ + return value_.at(this->get_flat_index(i)); } - data size() const { - return size_; + data internal_flat_index(const data, Encode>& i) const { + return {this->get_flat_index(i)}; } +private: + template + uint64_t get_flat_index(const U& i) const { + static_assert( + std::is_same_v, Encode>> or + std::is_same_v>, + "Unsupported type" + ); + assert(value_.size() == get_full_size()); + uint64_t s = 0; - template - auto access(cl::sycl::handler& h){ - return data_.template get_access(h); - } + uint64_t stride = 1; + + for(uint64_t iter = 0; iter < Dim; ++iter){ + uint64_t ind = [](auto val) -> uint64_t { + using V = std::decay_t; + if constexpr (std::is_same_v>){ + return val.get(); + }else if constexpr (std::is_same_v){ + return val; + }else{ + static_assert(always_false, "Cases exhausted"); + } + }(i.at(iter)); + assert(ind < dims_.at({iter}).get() ); + s += ind * stride; + stride *= dims_.at(iter).get(); + } - template - auto access(cl::sycl::handler& h) const { - return data_.template get_access(h); + return s; } }; } -- cgit v1.2.3