summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
Diffstat (limited to 'lib')
-rw-r--r--lib/core/c++/abstract/data.hpp26
-rw-r--r--lib/sycl/c++/data.hpp112
-rw-r--r--lib/sycl/tests/data.cpp27
3 files changed, 157 insertions, 8 deletions
diff --git a/lib/core/c++/abstract/data.hpp b/lib/core/c++/abstract/data.hpp
index e8f1757..0075718 100644
--- a/lib/core/c++/abstract/data.hpp
+++ b/lib/core/c++/abstract/data.hpp
@@ -4,28 +4,48 @@
namespace kel {
namespace sch {
+struct Void {};
+
struct UnsignedInteger {};
struct SignedInteger {};
struct FloatingPoint {};
template<typename StorageT, typename InterfaceT>
struct MixedPrecision {
+ using Meta = Void;
using StorageType = StorageT;
using InterfaceType = InterfaceT;
};
template<typename PrimType, uint64_t N>
struct Primitive {
- using PrimitiveType = PrimType;
+ using Meta = Void;
+ using Type = PrimType;
static constexpr uint64_t Bytes = N;
};
template<typename T, uint64_t... Dims>
-struct Array {
- using InnerType = T;
+struct FixedArray {
+ using Meta = Void;
+ using Inner = T;
static constexpr std::array<uint64_t,sizeof...(Dims)> Dimensions{Dims...};
};
+template<typename T, uint64_t Dims>
+struct Array {
+ using Meta = FixedArray<UInt64,Dims>;
+ using Inner = T;
+ static constexpr std::array<uint64_t,sizeof...(Dims)> Dimensions{Dims};
+};
+
+template<typename... T>
+struct Tuple {
+};
}
+
+template<typename Sch>
+struct schema {
+ using Type = Sch;
+};
}
diff --git a/lib/sycl/c++/data.hpp b/lib/sycl/c++/data.hpp
index 4e5129b..3ac51e0 100644
--- a/lib/sycl/c++/data.hpp
+++ b/lib/sycl/c++/data.hpp
@@ -29,8 +29,17 @@ private:
SAW_FORBID_COPY(data);
SAW_FORBID_MOVE(data);
-
public:
+ data(const data<typename meta_schema<Schema>::MetaSchema>& meta__, acpp::sycl::queue& q__):
+ q_{&q__},
+ values_{nullptr}
+ {
+ (void) meta__;
+ SAW_ASSERT(q_);
+ values_ = acpp::sycl::malloc_device<data<Sch,Encode>>(ct_multiply<uint64_t,Dims...>::value,*q_);
+ SAW_ASSERT(values_);
+ }
+
data(acpp::sycl::queue& q__):
q_{&q__},
values_{nullptr}
@@ -65,7 +74,6 @@ public:
constexpr data<Sch,Encode>* flat_data() const {
return values_;
}
-
};
template<typename Sch, uint64_t... Dims, typename Encode>
@@ -122,6 +130,20 @@ private:
SAW_FORBID_COPY(data);
SAW_FORBID_MOVE(data);
public:
+ data(const data<typename meta_schema<Schema>::MetaSchema>& meta__, acpp::sycl::queue& q__):
+ q_{&q__},
+ values_{nullptr}
+ {
+ SAW_ASSERT(q_);
+ /// TODO use meta
+ data<schema::UInt64> m{1u};
+ for(uint64_t i = 0u; i < Dims; ++i){
+ m = m * meta__.at({i});
+ }
+ values_ = acpp::sycl::malloc_device<data<Sch,Encode>>(m.get(),*q_);
+ SAW_ASSERT(values_);
+ }
+
data(acpp::sycl::queue& q__):
values_{nullptr},
meta_{},
@@ -198,8 +220,7 @@ public:
data(const data<schema::Array<Sch,Dims>, kel::lbm::encode::Sycl<Encode>>& values__):
values_{values__.flat_data()},
meta_{values__.meta()}
- {
- }
+ {}
constexpr data<schema::FixedArray<schema::UInt64, Dims>, Encode> meta() const {
return meta_;
@@ -238,6 +259,10 @@ private:
data<InnerSchema, kel::lbm::encode::Sycl<Encode>> values_;
public:
+ data(const data<typename meta_schema<Schema>::MetaSchema>& meta__, acpp::sycl::queue& q__):
+ values_{meta__,q__}
+ {}
+
data(acpp::sycl::queue& q__):
values_{q__}
{}
@@ -366,6 +391,11 @@ private:
members_{(static_cast<void>(Is), q)...}
{}
public:
+ data(data<typename meta_schema<Schema>::MetaSchema>& meta__, acpp::sycl::queue& q__):
+ data{q__, std::make_index_sequence<sizeof...(Members)>{}}
+ {
+ }
+
data(acpp::sycl::queue& q__):
data{q__, std::make_index_sequence<sizeof...(Members)>{}}
{
@@ -546,6 +576,29 @@ struct sycl_copy_helper<sch::Struct<Members...>, Encode> final {
static saw::error_or<void> copy_to_host(saw::data<Schema,encode::Sycl<Encode>>& sycl_data, saw::data<Schema,Encode>& host_data, sycl::queue& q){
return copy_to_host_member<0u>(sycl_data, host_data, q);
}
+
+ template<uint64_t i>
+ static saw::error_or<void> malloc_on_device_member(saw::data<Schema,Encode>& host_data, saw::data<Schema,encode::Sycl<Encode>>& sycl_data, sycl::queue& q){
+
+ if constexpr (i < sizeof...(Members)){
+ using M = typename saw::parameter_pack_type<i,Members...>::type;
+ auto& host_member_data = host_data.template get<i>();
+ auto& sycl_member_data = sycl_data.template get<i>();
+
+ auto eov = sycl_copy_helper<typename M::ValueType,Encode>::malloc_on_device(host_member_data,sycl_member_data,q);
+ if(eov.is_error()){
+ return eov;
+ }
+
+ return malloc_on_device_member<i+1u>(host_data,sycl_data,q);
+ }
+
+ return saw::make_void();
+ }
+
+ static saw::error_or<void> malloc_on_device(saw::data<Schema,Encode>& host_data, saw::data<Schema,encode::Sycl<Encode>>& sycl_data, sycl::queue& q){
+ return malloc_on_device_member<0u>(host_data,sycl_data,q);
+ }
};
template<typename... Members, typename Encode>
@@ -594,9 +647,36 @@ struct sycl_copy_helper<sch::Tuple<Members...>, Encode> final {
}
- static saw::error_or<void> copy_to_host(saw::data<Schema,encode::Sycl<Encode>>& sycl_data, saw::data<Schema,Encode>& host_data, sycl::queue& q){
+ static saw::error_or<void> copy_to_host(
+ saw::data<Schema,Encode>& host_data,
+ saw::data<Schema,encode::Sycl<Encode>>& sycl_data,
+ sycl::queue& q
+ ){
return copy_to_host_member<0u>(sycl_data, host_data, q);
}
+
+ template<uint64_t i>
+ static saw::error_or<void> malloc_on_device_member(saw::data<Schema,Encode>& host_data, saw::data<Schema,encode::Sycl<Encode>>& sycl_data, sycl::queue& q){
+
+ if constexpr (i < sizeof...(Members)){
+ using M = typename saw::parameter_pack_type<i,Members...>::type;
+ auto& host_member_data = host_data.template get<i>();
+ auto& sycl_member_data = sycl_data.template get<i>();
+
+ auto eov = sycl_copy_helper<M,Encode>::malloc_on_device(host_member_data,sycl_member_data,q);
+ if(eov.is_error()){
+ return eov;
+ }
+
+ return malloc_on_device_member<i+1u>(host_data,sycl_data,q);
+ }
+
+ return saw::make_void();
+ }
+
+ static saw::error_or<void> malloc_on_device(saw::data<Schema,Encode>& host_data, saw::data<Schema,encode::Sycl<Encode>>& sycl_data, sycl::queue& q){
+ return malloc_on_device_member<0u>(host_data,sycl_data,q);
+ }
};
template<typename Sch, uint64_t... Dims, typename Encode>
@@ -626,6 +706,13 @@ struct sycl_copy_helper<sch::FixedArray<Sch,Dims...>, Encode> final {
}).wait();
return saw::make_void();
}
+
+ static saw::error_or<void> malloc_on_device(saw::data<Schema,Encode>& host_data, saw::data<Schema,encode::Sycl<Encode>>& sycl_data, sycl::queue& q){
+ (void) host_data;
+ (void) sycl_data;
+ (void) q;
+ return saw::make_void();
+ }
};
template<typename Sch, uint64_t Ghost, uint64_t... Dims, typename Encode>
@@ -684,6 +771,11 @@ struct sycl_copy_helper<sch::Array<Sch,Dims>, Encode> final {
}).wait();
return saw::make_void();
}
+
+ static saw::error_or<void> malloc_on_device(saw::data<Schema,Encode>& host_data, saw::data<Schema,encode::Sycl<Encode>>& sycl_data, sycl::queue& q){
+ sycl_data = {host_data.meta(),q};
+ return saw::make_void();
+ }
};
@@ -817,6 +909,16 @@ public:
return impl::sycl_copy_helper<Sch,Encode>::copy_to_host(sycl_data, host_data, q_);
}
+ template<typename Sch, typename Encode>
+ saw::error_or<void> malloc_on_device(
+ saw::data<Sch,Encode>& host_data,
+ saw::data<Sch,encode::Sycl<Encode>>& sycl_data
+ ){
+ auto eov = impl::sycl_copy_helper<Sch,Encode>::malloc_on_device(host_data, sycl_data, q_);
+ q_.wait();
+ return eov;
+ }
+
auto& get_handle(){
return q_;
}
diff --git a/lib/sycl/tests/data.cpp b/lib/sycl/tests/data.cpp
index 3073a22..6b17622 100644
--- a/lib/sycl/tests/data.cpp
+++ b/lib/sycl/tests/data.cpp
@@ -2,6 +2,24 @@
#include "../c++/lbm.hpp"
+namespace {
+
+namespace sch {
+using namespace kel::lbm::sch;
+using TestObjSchema = Tuple<
+ Member<FixedArray<UInt64,2u,2u>, "foo">,
+ Member<Array<Float32>, "bar">,
+ Member<
+ Array<
+ Struct<
+ Member<FixedArray<Float32,2u>,"pos">
+ >
+ >,
+ "baz"
+ >
+>;
+}
+
SAW_TEST("Sycl Data Compilation"){
acpp::sycl::queue q;
saw::data<
@@ -19,3 +37,12 @@ SAW_TEST("Sycl Data Compilation"){
// test_f.at({}).set(1);
// SAW_EXPECT(test_f.at({}).get() == 1, "Value check failed");
}
+
+SAW_TEST("Sycl Data Compilation for Particle Similacrum"){
+ acpp::sycl::queue q;
+
+ saw::data<
+ sch::TestObjSchema
+ > a;
+}
+}