From 0a8dd2541e20f59812db21e8bad069b50cf8ebaf Mon Sep 17 00:00:00 2001 From: "Claudius \"keldu\" Holeksa" Date: Fri, 23 Jan 2026 12:04:35 +0100 Subject: Preparing for SYCL accesors --- lib/sycl/c++/data.hpp | 205 ++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 172 insertions(+), 33 deletions(-) (limited to 'lib/sycl/c++/data.hpp') diff --git a/lib/sycl/c++/data.hpp b/lib/sycl/c++/data.hpp index bb8b4bf..d9976dd 100644 --- a/lib/sycl/c++/data.hpp +++ b/lib/sycl/c++/data.hpp @@ -11,19 +11,115 @@ struct Sycl { }; } +/* namespace impl { template struct struct_has_only_equal_dimension_array{}; } +*/ } } namespace saw { +template +class data, kel::lbm::encode::Sycl> final { +public: + using Schema = schema::FixedArray; +private: + acpp::sycl::queue* q_; + data* values_; + + SAW_FORBID_COPY(data); +public: + data(acpp::sycl::queue& q__): + q_{&q__}, + values_{nullptr} + { + values_ = acpp::sycl::malloc>(ct_multiply::value,q_); + } + + ~data(){ + if(not values_){ + return; + } + + acpp::sycl::free(values_,q_); + } + + static constexpr data> get_dims() { + return {std::array{Dims...}}; + } + + data& at(data>& index){ + return values_[kel::lbm::flatten_index::apply(index,get_dims()).get()]; + } + + constexpr data* flat_data() { + return values_; + } +}; + +template +class data,kel::lbm::encode::Sycl> final { +public: + using Schema = kel::lbm::sch::Chunk; +private: + using InnerSchema = typename Schema::InnerSchema; + using ValueSchema = typename InnerSchema::ValueType; + + data> values_; +public: + data(acpp::sycl::queue& q__): + values_{q__} + {} + + data>& ghost_at(const data>& index){ + return values_.at(index); + } + + const data>& ghost_at(const data>& index) const { + return values_.at(index); + } + + static constexpr auto get_ghost_dims() { + return data>::get_dims(); + } + + data>& at(const data>& index){ + std::decay_t ind; + for(uint64_t i = 0u; i < sizeof...(Sides); ++i){ + ind.at({i}) = index.at({i}) + Ghost; + } + return values_.at(ind); + } + + const data>& at(const data>& index) const { + std::decay_t ind; + for(uint64_t i = 0u; i < sizeof...(Sides); ++i){ + ind.at({i}) = index.at({i}) + Ghost; + } + return values_.at(ind); + } + + static constexpr auto get_dims(){ + return data,kel::lbm::encode::Sycl>{{Sides...}}; + } + + auto flat_data(){ + return values_.flat_data(); + } + + static constexpr auto flat_size() { + return data>::flat_size(); + } +}; + template class data, Keys>...>, kel::lbm::encode::Sycl > final { public: static constexpr data> meta = {{Meta...}}; - using StorageT = std::tuple*...>; + using StorageT = std::tuple,kel::lbm::encode::Sycl>...>; + using Schema = schema::Struct, Keys>...>; private: /** @@ -31,43 +127,18 @@ private: * Do it here by specializing. */ StorageT members_; - kel::lbm::sycl::queue* q_; public: - data(): - members_{}, - q_{nullptr} - {} - - data(StorageT members__, kel::lbm::sycl::queue& q__): - members_{members__}, - q_{&q__} + data(acpp::sycl::queue& q__): + members_{{data,kel::lbm::encode::Sycl>{q__}...}} {} - ~data(){ - SAW_ASSERT(q_){ - return; - } - std::visit([this](auto arg){ - if(not arg){ - return; - } - acpp::sycl::free(arg,*q_); - arg = nullptr; - },members_); - } - - template - auto* get_ptr(){ - return std::get::value>(members_); - } - template auto& get(){ - auto ptr = get_ptr(); - SAW_ASSERT(ptr); - return *ptr; + return std::get::value>(members_); } }; + + } namespace kel { @@ -81,7 +152,7 @@ struct sycl_malloc_struct_helper, Encode> final { using Schema = sch::Struct; template - static saw::error_or allocate_on_device_member(typename saw::data::type::ValueType,encode::Sycl>::StorageT& storage, sycl::queue& q){ + static saw::error_or allocate_on_device_member(typename saw::data>::StorageT& storage, sycl::queue& q){ if constexpr (i < sizeof...(Members)){ using M = typename saw::parameter_pack_type::type; auto& ptr = std::get(storage); @@ -103,6 +174,60 @@ struct sycl_malloc_struct_helper, Encode> final { return eov; } }; + +template +struct sycl_copy_helper; + +template +struct sycl_copy_helper, Encode> final { + using Schema = sch::Struct; + + template + static saw::error_or copy_to_device_member(saw::data& host_data, saw::data>& sycl_data, sycl::queue& q){ + if constexpr (i < sizeof...(Members)){ + using M = typename saw::parameter_pack_type::type; + auto& host_member_data = host_data.template get(); + auto& sycl_member_data = sycl_data.template get(); + + auto host_ptr = host_member_data.flat_data(); + auto sycl_ptr = sycl_member_data.flat_data(); + + q.memcpy(host_ptr, sycl_ptr, sizeof(std::decay_t) * host_member_data.flat_size() ); + + return copy_to_device_member(host_data,sycl_data,q); + } + + return saw::make_void(); + } + + static saw::error_or copy_to_device(saw::data& host_data, saw::data>& sycl_data, sycl::queue& q){ + + return copy_to_device_member<0u>(host_data, sycl_data, q); + } + + template + static saw::error_or copy_to_host_member(saw::data>& sycl_data, saw::data& host_data, sycl::queue& q){ + if constexpr (i < sizeof...(Members)){ + using M = typename saw::parameter_pack_type::type; + auto& host_member_data = host_data.template get(); + auto& sycl_member_data = sycl_data.template get(); + + auto host_ptr = host_member_data.flat_data(); + auto sycl_ptr = sycl_member_data.flat_data(); + + q.memcpy(sycl_ptr, host_ptr, sizeof(std::decay_t) * host_member_data.flat_size() ); + + return copy_to_host_member(sycl_data,host_data,q); + } + + return saw::make_void(); + } + + + static saw::error_or copy_to_host(saw::data& sycl_data, saw::data>& host_data, sycl::queue& q){ + return copy_to_host_member<0u>(sycl_data, host_data, q); + } +}; } class device final { private: @@ -120,7 +245,21 @@ public: if(eov.is_error()){ return eov; } - sycl_data.set_queue(q_); + return saw::make_void(); + } + + template + saw::error_or copy_to_device(saw::data& host_data, saw::data>& sycl_data){ + return impl::sycl_copy_helper::copy_to_device(host_data, sycl_data, q_); + } + + template + saw::error_or copy_to_host(saw::data>& sycl_data, saw::data& host_data){ + return impl::sycl_copy_helper::copy_to_host(sycl_data, host_data, q_); + } + + auto& get_handle(){ + return q_; } }; } -- cgit v1.2.3