diff options
Diffstat (limited to 'lib/sycl/c++')
| -rw-r--r-- | lib/sycl/c++/data.hpp | 205 |
1 files changed, 172 insertions, 33 deletions
diff --git a/lib/sycl/c++/data.hpp b/lib/sycl/c++/data.hpp index bb8b4bf..d9976dd 100644 --- a/lib/sycl/c++/data.hpp +++ b/lib/sycl/c++/data.hpp @@ -11,19 +11,115 @@ struct Sycl { }; } +/* namespace impl { template<typename Schema> struct struct_has_only_equal_dimension_array{}; } +*/ } } namespace saw { +template<typename Sch, uint64_t... Dims, typename Encode> +class data<schema::FixedArray<Sch,Dims...>, kel::lbm::encode::Sycl<Encode>> final { +public: + using Schema = schema::FixedArray<Sch,Dims...>; +private: + acpp::sycl::queue* q_; + data<Sch>* values_; + + SAW_FORBID_COPY(data); +public: + data(acpp::sycl::queue& q__): + q_{&q__}, + values_{nullptr} + { + values_ = acpp::sycl::malloc<data<Sch>>(ct_multiply<uint64_t,Dims...>::value,q_); + } + + ~data(){ + if(not values_){ + return; + } + + acpp::sycl::free(values_,q_); + } + + static constexpr data<schema::FixedArray<schema::UInt64, sizeof...(Dims)>> get_dims() { + return {std::array<uint64_t, sizeof...(Dims)>{Dims...}}; + } + + data<Sch>& at(data<schema::FixedArray<schema::UInt64,sizeof...(Dims)>>& index){ + return values_[kel::lbm::flatten_index<schema::UInt64,sizeof...(Dims)>::apply(index,get_dims()).get()]; + } + + constexpr data<Sch,encode::Native>* flat_data() { + return values_; + } +}; + +template<typename Sch, uint64_t Ghost, uint64_t... Sides, typename Encode> +class data<kel::lbm::sch::Chunk<Sch,Ghost,Sides...>,kel::lbm::encode::Sycl<Encode>> final { +public: + using Schema = kel::lbm::sch::Chunk<Sch,Ghost,Sides...>; +private: + using InnerSchema = typename Schema::InnerSchema; + using ValueSchema = typename InnerSchema::ValueType; + + data<InnerSchema, kel::lbm::encode::Sycl<Encode>> values_; +public: + data(acpp::sycl::queue& q__): + values_{q__} + {} + + data<ValueSchema, kel::lbm::encode::Sycl<Encode>>& ghost_at(const data<schema::FixedArray<schema::UInt64,sizeof...(Sides)>>& index){ + return values_.at(index); + } + + const data<ValueSchema, kel::lbm::encode::Sycl<Encode>>& ghost_at(const data<schema::FixedArray<schema::UInt64,sizeof...(Sides)>>& index) const { + return values_.at(index); + } + + static constexpr auto get_ghost_dims() { + return data<InnerSchema,kel::lbm::encode::Sycl<Encode>>::get_dims(); + } + + data<ValueSchema, kel::lbm::encode::Sycl<Encode>>& at(const data<schema::FixedArray<schema::UInt64,sizeof...(Sides)>>& index){ + std::decay_t<decltype(index)> ind; + for(uint64_t i = 0u; i < sizeof...(Sides); ++i){ + ind.at({i}) = index.at({i}) + Ghost; + } + return values_.at(ind); + } + + const data<ValueSchema, kel::lbm::encode::Sycl<Encode>>& at(const data<schema::FixedArray<schema::UInt64,sizeof...(Sides)>>& index) const { + std::decay_t<decltype(index)> ind; + for(uint64_t i = 0u; i < sizeof...(Sides); ++i){ + ind.at({i}) = index.at({i}) + Ghost; + } + return values_.at(ind); + } + + static constexpr auto get_dims(){ + return data<schema::FixedArray<schema::UInt64, sizeof...(Sides)>,kel::lbm::encode::Sycl<Encode>>{{Sides...}}; + } + + auto flat_data(){ + return values_.flat_data(); + } + + static constexpr auto flat_size() { + return data<InnerSchema,kel::lbm::encode::Sycl<Encode>>::flat_size(); + } +}; + template<uint64_t Ghost, uint64_t... Meta, typename... Sch, string_literal... Keys, typename Encode> class data<schema::Struct<schema::Member<kel::lbm::sch::Chunk<Sch,Ghost,Meta...>, Keys>...>, kel::lbm::encode::Sycl<Encode> > final { public: static constexpr data<schema::FixedArray<schema::UInt64,sizeof...(Meta)>> meta = {{Meta...}}; - using StorageT = std::tuple<data<Sch,Encode>*...>; + using StorageT = std::tuple<data<kel::lbm::sch::Chunk<Sch,Ghost,Meta...>,kel::lbm::encode::Sycl<Encode>>...>; + using Schema = schema::Struct<schema::Member<kel::lbm::sch::Chunk<Sch,Ghost,Meta...>, Keys>...>; private: /** @@ -31,43 +127,18 @@ private: * Do it here by specializing. */ StorageT members_; - kel::lbm::sycl::queue* q_; public: - data(): - members_{}, - q_{nullptr} - {} - - data(StorageT members__, kel::lbm::sycl::queue& q__): - members_{members__}, - q_{&q__} + data(acpp::sycl::queue& q__): + members_{{data<kel::lbm::sch::Chunk<Sch,Ghost,Meta...>,kel::lbm::encode::Sycl<Encode>>{q__}...}} {} - ~data(){ - SAW_ASSERT(q_){ - return; - } - std::visit([this](auto arg){ - if(not arg){ - return; - } - acpp::sycl::free(arg,*q_); - arg = nullptr; - },members_); - } - - template<saw::string_literal K> - auto* get_ptr(){ - return std::get<parameter_key_pack_index<K, Keys...>::value>(members_); - } - template<saw::string_literal K> auto& get(){ - auto ptr = get_ptr<K>(); - SAW_ASSERT(ptr); - return *ptr; + return std::get<parameter_key_pack_index<K, Keys...>::value>(members_); } }; + + } namespace kel { @@ -81,7 +152,7 @@ struct sycl_malloc_struct_helper<sch::Struct<Members...>, Encode> final { using Schema = sch::Struct<Members...>; template<uint64_t i> - static saw::error_or<void> allocate_on_device_member(typename saw::data<typename saw::parameter_pack_type<i,Members...>::type::ValueType,encode::Sycl<Encode>>::StorageT& storage, sycl::queue& q){ + static saw::error_or<void> allocate_on_device_member(typename saw::data<Schema,encode::Sycl<Encode>>::StorageT& storage, sycl::queue& q){ if constexpr (i < sizeof...(Members)){ using M = typename saw::parameter_pack_type<i,Members...>::type; auto& ptr = std::get<i>(storage); @@ -103,6 +174,60 @@ struct sycl_malloc_struct_helper<sch::Struct<Members...>, Encode> final { return eov; } }; + +template<typename Sch, typename Encode> +struct sycl_copy_helper; + +template<typename... Members, typename Encode> +struct sycl_copy_helper<sch::Struct<Members...>, Encode> final { + using Schema = sch::Struct<Members...>; + + template<uint64_t i> + static saw::error_or<void> copy_to_device_member(saw::data<Schema,Encode>& host_data, saw::data<Schema,encode::Sycl<Encode>>& sycl_data, sycl::queue& q){ + if constexpr (i < sizeof...(Members)){ + using M = typename saw::parameter_pack_type<i,Members...>::type; + auto& host_member_data = host_data.template get<i>(); + auto& sycl_member_data = sycl_data.template get<i>(); + + auto host_ptr = host_member_data.flat_data(); + auto sycl_ptr = sycl_member_data.flat_data(); + + q.memcpy(host_ptr, sycl_ptr, sizeof(std::decay_t<decltype(host_ptr)>) * host_member_data.flat_size() ); + + return copy_to_device_member<i+1u>(host_data,sycl_data,q); + } + + return saw::make_void(); + } + + static saw::error_or<void> copy_to_device(saw::data<Schema,Encode>& host_data, saw::data<Schema, encode::Sycl<Encode>>& sycl_data, sycl::queue& q){ + + return copy_to_device_member<0u>(host_data, sycl_data, q); + } + + template<uint64_t i> + static saw::error_or<void> copy_to_host_member(saw::data<Schema,encode::Sycl<Encode>>& sycl_data, saw::data<Schema,Encode>& host_data, sycl::queue& q){ + if constexpr (i < sizeof...(Members)){ + using M = typename saw::parameter_pack_type<i,Members...>::type; + auto& host_member_data = host_data.template get<i>(); + auto& sycl_member_data = sycl_data.template get<i>(); + + auto host_ptr = host_member_data.flat_data(); + auto sycl_ptr = sycl_member_data.flat_data(); + + q.memcpy(sycl_ptr, host_ptr, sizeof(std::decay_t<decltype(host_ptr)>) * host_member_data.flat_size() ); + + return copy_to_host_member<i+1u>(sycl_data,host_data,q); + } + + return saw::make_void(); + } + + + static saw::error_or<void> copy_to_host(saw::data<Schema,Encode>& sycl_data, saw::data<Schema,encode::Sycl<Encode>>& host_data, sycl::queue& q){ + return copy_to_host_member<0u>(sycl_data, host_data, q); + } +}; } class device final { private: @@ -120,7 +245,21 @@ public: if(eov.is_error()){ return eov; } - sycl_data.set_queue(q_); + return saw::make_void(); + } + + template<typename Sch, typename Encode> + saw::error_or<void> copy_to_device(saw::data<Sch,Encode>& host_data, saw::data<Sch,encode::Sycl<Encode>>& sycl_data){ + return impl::sycl_copy_helper<Sch,Encode>::copy_to_device(host_data, sycl_data, q_); + } + + template<typename Sch, typename Encode> + saw::error_or<void> copy_to_host(saw::data<Sch,encode::Sycl<Encode>>& sycl_data, saw::data<Sch,Encode>& host_data){ + return impl::sycl_copy_helper<Sch,Encode>::copy_to_host(sycl_data, host_data, q_); + } + + auto& get_handle(){ + return q_; } }; } |
