diff options
Diffstat (limited to 'lib/sycl')
| -rw-r--r-- | lib/sycl/c++/data.hpp | 112 | ||||
| -rw-r--r-- | lib/sycl/tests/data.cpp | 27 |
2 files changed, 134 insertions, 5 deletions
diff --git a/lib/sycl/c++/data.hpp b/lib/sycl/c++/data.hpp index 4e5129b..3ac51e0 100644 --- a/lib/sycl/c++/data.hpp +++ b/lib/sycl/c++/data.hpp @@ -29,8 +29,17 @@ private: SAW_FORBID_COPY(data); SAW_FORBID_MOVE(data); - public: + data(const data<typename meta_schema<Schema>::MetaSchema>& meta__, acpp::sycl::queue& q__): + q_{&q__}, + values_{nullptr} + { + (void) meta__; + SAW_ASSERT(q_); + values_ = acpp::sycl::malloc_device<data<Sch,Encode>>(ct_multiply<uint64_t,Dims...>::value,*q_); + SAW_ASSERT(values_); + } + data(acpp::sycl::queue& q__): q_{&q__}, values_{nullptr} @@ -65,7 +74,6 @@ public: constexpr data<Sch,Encode>* flat_data() const { return values_; } - }; template<typename Sch, uint64_t... Dims, typename Encode> @@ -122,6 +130,20 @@ private: SAW_FORBID_COPY(data); SAW_FORBID_MOVE(data); public: + data(const data<typename meta_schema<Schema>::MetaSchema>& meta__, acpp::sycl::queue& q__): + q_{&q__}, + values_{nullptr} + { + SAW_ASSERT(q_); + /// TODO use meta + data<schema::UInt64> m{1u}; + for(uint64_t i = 0u; i < Dims; ++i){ + m = m * meta__.at({i}); + } + values_ = acpp::sycl::malloc_device<data<Sch,Encode>>(m.get(),*q_); + SAW_ASSERT(values_); + } + data(acpp::sycl::queue& q__): values_{nullptr}, meta_{}, @@ -198,8 +220,7 @@ public: data(const data<schema::Array<Sch,Dims>, kel::lbm::encode::Sycl<Encode>>& values__): values_{values__.flat_data()}, meta_{values__.meta()} - { - } + {} constexpr data<schema::FixedArray<schema::UInt64, Dims>, Encode> meta() const { return meta_; @@ -238,6 +259,10 @@ private: data<InnerSchema, kel::lbm::encode::Sycl<Encode>> values_; public: + data(const data<typename meta_schema<Schema>::MetaSchema>& meta__, acpp::sycl::queue& q__): + values_{meta__,q__} + {} + data(acpp::sycl::queue& q__): values_{q__} {} @@ -366,6 +391,11 @@ private: members_{(static_cast<void>(Is), q)...} {} public: + data(data<typename meta_schema<Schema>::MetaSchema>& meta__, acpp::sycl::queue& q__): + data{q__, std::make_index_sequence<sizeof...(Members)>{}} + { + } + data(acpp::sycl::queue& q__): data{q__, std::make_index_sequence<sizeof...(Members)>{}} { @@ -546,6 +576,29 @@ struct sycl_copy_helper<sch::Struct<Members...>, Encode> final { static saw::error_or<void> copy_to_host(saw::data<Schema,encode::Sycl<Encode>>& sycl_data, saw::data<Schema,Encode>& host_data, sycl::queue& q){ return copy_to_host_member<0u>(sycl_data, host_data, q); } + + template<uint64_t i> + static saw::error_or<void> malloc_on_device_member(saw::data<Schema,Encode>& host_data, saw::data<Schema,encode::Sycl<Encode>>& sycl_data, sycl::queue& q){ + + if constexpr (i < sizeof...(Members)){ + using M = typename saw::parameter_pack_type<i,Members...>::type; + auto& host_member_data = host_data.template get<i>(); + auto& sycl_member_data = sycl_data.template get<i>(); + + auto eov = sycl_copy_helper<typename M::ValueType,Encode>::malloc_on_device(host_member_data,sycl_member_data,q); + if(eov.is_error()){ + return eov; + } + + return malloc_on_device_member<i+1u>(host_data,sycl_data,q); + } + + return saw::make_void(); + } + + static saw::error_or<void> malloc_on_device(saw::data<Schema,Encode>& host_data, saw::data<Schema,encode::Sycl<Encode>>& sycl_data, sycl::queue& q){ + return malloc_on_device_member<0u>(host_data,sycl_data,q); + } }; template<typename... Members, typename Encode> @@ -594,9 +647,36 @@ struct sycl_copy_helper<sch::Tuple<Members...>, Encode> final { } - static saw::error_or<void> copy_to_host(saw::data<Schema,encode::Sycl<Encode>>& sycl_data, saw::data<Schema,Encode>& host_data, sycl::queue& q){ + static saw::error_or<void> copy_to_host( + saw::data<Schema,Encode>& host_data, + saw::data<Schema,encode::Sycl<Encode>>& sycl_data, + sycl::queue& q + ){ return copy_to_host_member<0u>(sycl_data, host_data, q); } + + template<uint64_t i> + static saw::error_or<void> malloc_on_device_member(saw::data<Schema,Encode>& host_data, saw::data<Schema,encode::Sycl<Encode>>& sycl_data, sycl::queue& q){ + + if constexpr (i < sizeof...(Members)){ + using M = typename saw::parameter_pack_type<i,Members...>::type; + auto& host_member_data = host_data.template get<i>(); + auto& sycl_member_data = sycl_data.template get<i>(); + + auto eov = sycl_copy_helper<M,Encode>::malloc_on_device(host_member_data,sycl_member_data,q); + if(eov.is_error()){ + return eov; + } + + return malloc_on_device_member<i+1u>(host_data,sycl_data,q); + } + + return saw::make_void(); + } + + static saw::error_or<void> malloc_on_device(saw::data<Schema,Encode>& host_data, saw::data<Schema,encode::Sycl<Encode>>& sycl_data, sycl::queue& q){ + return malloc_on_device_member<0u>(host_data,sycl_data,q); + } }; template<typename Sch, uint64_t... Dims, typename Encode> @@ -626,6 +706,13 @@ struct sycl_copy_helper<sch::FixedArray<Sch,Dims...>, Encode> final { }).wait(); return saw::make_void(); } + + static saw::error_or<void> malloc_on_device(saw::data<Schema,Encode>& host_data, saw::data<Schema,encode::Sycl<Encode>>& sycl_data, sycl::queue& q){ + (void) host_data; + (void) sycl_data; + (void) q; + return saw::make_void(); + } }; template<typename Sch, uint64_t Ghost, uint64_t... Dims, typename Encode> @@ -684,6 +771,11 @@ struct sycl_copy_helper<sch::Array<Sch,Dims>, Encode> final { }).wait(); return saw::make_void(); } + + static saw::error_or<void> malloc_on_device(saw::data<Schema,Encode>& host_data, saw::data<Schema,encode::Sycl<Encode>>& sycl_data, sycl::queue& q){ + sycl_data = {host_data.meta(),q}; + return saw::make_void(); + } }; @@ -817,6 +909,16 @@ public: return impl::sycl_copy_helper<Sch,Encode>::copy_to_host(sycl_data, host_data, q_); } + template<typename Sch, typename Encode> + saw::error_or<void> malloc_on_device( + saw::data<Sch,Encode>& host_data, + saw::data<Sch,encode::Sycl<Encode>>& sycl_data + ){ + auto eov = impl::sycl_copy_helper<Sch,Encode>::malloc_on_device(host_data, sycl_data, q_); + q_.wait(); + return eov; + } + auto& get_handle(){ return q_; } diff --git a/lib/sycl/tests/data.cpp b/lib/sycl/tests/data.cpp index 3073a22..6b17622 100644 --- a/lib/sycl/tests/data.cpp +++ b/lib/sycl/tests/data.cpp @@ -2,6 +2,24 @@ #include "../c++/lbm.hpp" +namespace { + +namespace sch { +using namespace kel::lbm::sch; +using TestObjSchema = Tuple< + Member<FixedArray<UInt64,2u,2u>, "foo">, + Member<Array<Float32>, "bar">, + Member< + Array< + Struct< + Member<FixedArray<Float32,2u>,"pos"> + > + >, + "baz" + > +>; +} + SAW_TEST("Sycl Data Compilation"){ acpp::sycl::queue q; saw::data< @@ -19,3 +37,12 @@ SAW_TEST("Sycl Data Compilation"){ // test_f.at({}).set(1); // SAW_EXPECT(test_f.at({}).get() == 1, "Value check failed"); } + +SAW_TEST("Sycl Data Compilation for Particle Similacrum"){ + acpp::sycl::queue q; + + saw::data< + sch::TestObjSchema + > a; +} +} |
