From 668e53e42e210d2cedf29281eb187e8d7f129651 Mon Sep 17 00:00:00 2001 From: "Claudius \"keldu\" Holeksa" Date: Tue, 18 Nov 2025 17:46:04 +0100 Subject: Working on tests in sycl --- modules/remote-sycl/c++/common.hpp | 1 + modules/remote-sycl/c++/data.hpp | 63 ++++++++++++++++++++++++++++++++++++-- 2 files changed, 62 insertions(+), 2 deletions(-) (limited to 'modules/remote-sycl/c++') diff --git a/modules/remote-sycl/c++/common.hpp b/modules/remote-sycl/c++/common.hpp index 287075f..54a09d1 100644 --- a/modules/remote-sycl/c++/common.hpp +++ b/modules/remote-sycl/c++/common.hpp @@ -11,6 +11,7 @@ namespace saw { namespace rmt { struct Sycl {}; } + namespace encode { template struct Sycl {}; diff --git a/modules/remote-sycl/c++/data.hpp b/modules/remote-sycl/c++/data.hpp index 11dfbf2..e057766 100644 --- a/modules/remote-sycl/c++/data.hpp +++ b/modules/remote-sycl/c++/data.hpp @@ -57,7 +57,7 @@ public: return &(data_[0u]); } - const auto& get_internal_size() const { + auto get_internal_size() const { return size_; } @@ -112,4 +112,63 @@ private: return s; } }; -} + +template +class data>, encode::Sycl> { +public: + using Schema = schema::Ref>; +private: + data* internal_data_ptr_; + data, Encode> dims_; + data size_; + + uint64_t get_full_size() const { + uint64_t s = 1; + + for(uint64_t iter = 0; iter < Dim; ++iter){ + auto& dim_iter = dims_.at(data{iter}); + s *= dim_iter.get(); + } + + return s; + } +public: + data() = delete; + + data(ref, Encode>> ref_data__): + internal_data_ptr_{ref_data__().get_internal_data()}, + dims_{ref_data__().dims()}, + size_{ref_data__().size()} + {} +private: + template + uint64_t get_flat_index(const U& i) const { + static_assert( + std::is_same_v, Encode>> or + std::is_same_v>, + "Unsupported type" + ); + assert(data_.size() == get_full_size()); + uint64_t s = 0; + + uint64_t stride = 1; + + for(uint64_t iter = 0; iter < Dim; ++iter){ + uint64_t ind = [](auto val) -> uint64_t { + using V = std::decay_t; + if constexpr (std::is_same_v>){ + return val.get(); + }else if constexpr (std::is_same_v){ + return val; + }else{ + static_assert(always_false, "Cases exhausted"); + } + }(i.at(iter)); + assert(ind < dims_.at({iter}).get() ); + s += ind * stride; + stride *= dims_.at(iter).get(); + } + + return s; + } +}; -- cgit v1.2.3