summaryrefslogtreecommitdiff
path: root/lib/sycl/c++
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sycl/c++')
-rw-r--r--lib/sycl/c++/data.hpp205
1 files changed, 172 insertions, 33 deletions
diff --git a/lib/sycl/c++/data.hpp b/lib/sycl/c++/data.hpp
index bb8b4bf..d9976dd 100644
--- a/lib/sycl/c++/data.hpp
+++ b/lib/sycl/c++/data.hpp
@@ -11,19 +11,115 @@ struct Sycl {
};
}
+/*
namespace impl {
template<typename Schema>
struct struct_has_only_equal_dimension_array{};
}
+*/
}
}
namespace saw {
+template<typename Sch, uint64_t... Dims, typename Encode>
+class data<schema::FixedArray<Sch,Dims...>, kel::lbm::encode::Sycl<Encode>> final {
+public:
+ using Schema = schema::FixedArray<Sch,Dims...>;
+private:
+ acpp::sycl::queue* q_;
+ data<Sch>* values_;
+
+ SAW_FORBID_COPY(data);
+public:
+ data(acpp::sycl::queue& q__):
+ q_{&q__},
+ values_{nullptr}
+ {
+ values_ = acpp::sycl::malloc<data<Sch>>(ct_multiply<uint64_t,Dims...>::value,q_);
+ }
+
+ ~data(){
+ if(not values_){
+ return;
+ }
+
+ acpp::sycl::free(values_,q_);
+ }
+
+ static constexpr data<schema::FixedArray<schema::UInt64, sizeof...(Dims)>> get_dims() {
+ return {std::array<uint64_t, sizeof...(Dims)>{Dims...}};
+ }
+
+ data<Sch>& at(data<schema::FixedArray<schema::UInt64,sizeof...(Dims)>>& index){
+ return values_[kel::lbm::flatten_index<schema::UInt64,sizeof...(Dims)>::apply(index,get_dims()).get()];
+ }
+
+ constexpr data<Sch,encode::Native>* flat_data() {
+ return values_;
+ }
+};
+
+template<typename Sch, uint64_t Ghost, uint64_t... Sides, typename Encode>
+class data<kel::lbm::sch::Chunk<Sch,Ghost,Sides...>,kel::lbm::encode::Sycl<Encode>> final {
+public:
+ using Schema = kel::lbm::sch::Chunk<Sch,Ghost,Sides...>;
+private:
+ using InnerSchema = typename Schema::InnerSchema;
+ using ValueSchema = typename InnerSchema::ValueType;
+
+ data<InnerSchema, kel::lbm::encode::Sycl<Encode>> values_;
+public:
+ data(acpp::sycl::queue& q__):
+ values_{q__}
+ {}
+
+ data<ValueSchema, kel::lbm::encode::Sycl<Encode>>& ghost_at(const data<schema::FixedArray<schema::UInt64,sizeof...(Sides)>>& index){
+ return values_.at(index);
+ }
+
+ const data<ValueSchema, kel::lbm::encode::Sycl<Encode>>& ghost_at(const data<schema::FixedArray<schema::UInt64,sizeof...(Sides)>>& index) const {
+ return values_.at(index);
+ }
+
+ static constexpr auto get_ghost_dims() {
+ return data<InnerSchema,kel::lbm::encode::Sycl<Encode>>::get_dims();
+ }
+
+ data<ValueSchema, kel::lbm::encode::Sycl<Encode>>& at(const data<schema::FixedArray<schema::UInt64,sizeof...(Sides)>>& index){
+ std::decay_t<decltype(index)> ind;
+ for(uint64_t i = 0u; i < sizeof...(Sides); ++i){
+ ind.at({i}) = index.at({i}) + Ghost;
+ }
+ return values_.at(ind);
+ }
+
+ const data<ValueSchema, kel::lbm::encode::Sycl<Encode>>& at(const data<schema::FixedArray<schema::UInt64,sizeof...(Sides)>>& index) const {
+ std::decay_t<decltype(index)> ind;
+ for(uint64_t i = 0u; i < sizeof...(Sides); ++i){
+ ind.at({i}) = index.at({i}) + Ghost;
+ }
+ return values_.at(ind);
+ }
+
+ static constexpr auto get_dims(){
+ return data<schema::FixedArray<schema::UInt64, sizeof...(Sides)>,kel::lbm::encode::Sycl<Encode>>{{Sides...}};
+ }
+
+ auto flat_data(){
+ return values_.flat_data();
+ }
+
+ static constexpr auto flat_size() {
+ return data<InnerSchema,kel::lbm::encode::Sycl<Encode>>::flat_size();
+ }
+};
+
template<uint64_t Ghost, uint64_t... Meta, typename... Sch, string_literal... Keys, typename Encode>
class data<schema::Struct<schema::Member<kel::lbm::sch::Chunk<Sch,Ghost,Meta...>, Keys>...>, kel::lbm::encode::Sycl<Encode> > final {
public:
static constexpr data<schema::FixedArray<schema::UInt64,sizeof...(Meta)>> meta = {{Meta...}};
- using StorageT = std::tuple<data<Sch,Encode>*...>;
+ using StorageT = std::tuple<data<kel::lbm::sch::Chunk<Sch,Ghost,Meta...>,kel::lbm::encode::Sycl<Encode>>...>;
+ using Schema = schema::Struct<schema::Member<kel::lbm::sch::Chunk<Sch,Ghost,Meta...>, Keys>...>;
private:
/**
@@ -31,43 +127,18 @@ private:
* Do it here by specializing.
*/
StorageT members_;
- kel::lbm::sycl::queue* q_;
public:
- data():
- members_{},
- q_{nullptr}
- {}
-
- data(StorageT members__, kel::lbm::sycl::queue& q__):
- members_{members__},
- q_{&q__}
+ data(acpp::sycl::queue& q__):
+ members_{{data<kel::lbm::sch::Chunk<Sch,Ghost,Meta...>,kel::lbm::encode::Sycl<Encode>>{q__}...}}
{}
- ~data(){
- SAW_ASSERT(q_){
- return;
- }
- std::visit([this](auto arg){
- if(not arg){
- return;
- }
- acpp::sycl::free(arg,*q_);
- arg = nullptr;
- },members_);
- }
-
- template<saw::string_literal K>
- auto* get_ptr(){
- return std::get<parameter_key_pack_index<K, Keys...>::value>(members_);
- }
-
template<saw::string_literal K>
auto& get(){
- auto ptr = get_ptr<K>();
- SAW_ASSERT(ptr);
- return *ptr;
+ return std::get<parameter_key_pack_index<K, Keys...>::value>(members_);
}
};
+
+
}
namespace kel {
@@ -81,7 +152,7 @@ struct sycl_malloc_struct_helper<sch::Struct<Members...>, Encode> final {
using Schema = sch::Struct<Members...>;
template<uint64_t i>
- static saw::error_or<void> allocate_on_device_member(typename saw::data<typename saw::parameter_pack_type<i,Members...>::type::ValueType,encode::Sycl<Encode>>::StorageT& storage, sycl::queue& q){
+ static saw::error_or<void> allocate_on_device_member(typename saw::data<Schema,encode::Sycl<Encode>>::StorageT& storage, sycl::queue& q){
if constexpr (i < sizeof...(Members)){
using M = typename saw::parameter_pack_type<i,Members...>::type;
auto& ptr = std::get<i>(storage);
@@ -103,6 +174,60 @@ struct sycl_malloc_struct_helper<sch::Struct<Members...>, Encode> final {
return eov;
}
};
+
+template<typename Sch, typename Encode>
+struct sycl_copy_helper;
+
+template<typename... Members, typename Encode>
+struct sycl_copy_helper<sch::Struct<Members...>, Encode> final {
+ using Schema = sch::Struct<Members...>;
+
+ template<uint64_t i>
+ static saw::error_or<void> copy_to_device_member(saw::data<Schema,Encode>& host_data, saw::data<Schema,encode::Sycl<Encode>>& sycl_data, sycl::queue& q){
+ if constexpr (i < sizeof...(Members)){
+ using M = typename saw::parameter_pack_type<i,Members...>::type;
+ auto& host_member_data = host_data.template get<i>();
+ auto& sycl_member_data = sycl_data.template get<i>();
+
+ auto host_ptr = host_member_data.flat_data();
+ auto sycl_ptr = sycl_member_data.flat_data();
+
+ q.memcpy(host_ptr, sycl_ptr, sizeof(std::decay_t<decltype(host_ptr)>) * host_member_data.flat_size() );
+
+ return copy_to_device_member<i+1u>(host_data,sycl_data,q);
+ }
+
+ return saw::make_void();
+ }
+
+ static saw::error_or<void> copy_to_device(saw::data<Schema,Encode>& host_data, saw::data<Schema, encode::Sycl<Encode>>& sycl_data, sycl::queue& q){
+
+ return copy_to_device_member<0u>(host_data, sycl_data, q);
+ }
+
+ template<uint64_t i>
+ static saw::error_or<void> copy_to_host_member(saw::data<Schema,encode::Sycl<Encode>>& sycl_data, saw::data<Schema,Encode>& host_data, sycl::queue& q){
+ if constexpr (i < sizeof...(Members)){
+ using M = typename saw::parameter_pack_type<i,Members...>::type;
+ auto& host_member_data = host_data.template get<i>();
+ auto& sycl_member_data = sycl_data.template get<i>();
+
+ auto host_ptr = host_member_data.flat_data();
+ auto sycl_ptr = sycl_member_data.flat_data();
+
+ q.memcpy(sycl_ptr, host_ptr, sizeof(std::decay_t<decltype(host_ptr)>) * host_member_data.flat_size() );
+
+ return copy_to_host_member<i+1u>(sycl_data,host_data,q);
+ }
+
+ return saw::make_void();
+ }
+
+
+ static saw::error_or<void> copy_to_host(saw::data<Schema,Encode>& sycl_data, saw::data<Schema,encode::Sycl<Encode>>& host_data, sycl::queue& q){
+ return copy_to_host_member<0u>(sycl_data, host_data, q);
+ }
+};
}
class device final {
private:
@@ -120,7 +245,21 @@ public:
if(eov.is_error()){
return eov;
}
- sycl_data.set_queue(q_);
+ return saw::make_void();
+ }
+
+ template<typename Sch, typename Encode>
+ saw::error_or<void> copy_to_device(saw::data<Sch,Encode>& host_data, saw::data<Sch,encode::Sycl<Encode>>& sycl_data){
+ return impl::sycl_copy_helper<Sch,Encode>::copy_to_device(host_data, sycl_data, q_);
+ }
+
+ template<typename Sch, typename Encode>
+ saw::error_or<void> copy_to_host(saw::data<Sch,encode::Sycl<Encode>>& sycl_data, saw::data<Sch,Encode>& host_data){
+ return impl::sycl_copy_helper<Sch,Encode>::copy_to_host(sycl_data, host_data, q_);
+ }
+
+ auto& get_handle(){
+ return q_;
}
};
}