diff options
| author | Claudius "keldu" Holeksa <mail@keldu.de> | 2025-11-18 21:56:06 +0100 |
|---|---|---|
| committer | Claudius "keldu" Holeksa <mail@keldu.de> | 2025-11-18 21:56:06 +0100 |
| commit | 2073aef795f74e5c24b7992d6c2f0fadde3fa271 (patch) | |
| tree | 755f89cbe50260d375cbc255a46f5630f9670a1e /modules | |
| parent | 2116b46f784aa9bc0dcbb1b6bfc22183d979b919 (diff) | |
| download | forstio-forstio-2073aef795f74e5c24b7992d6c2f0fadde3fa271.tar.gz | |
Making ref work?
Diffstat (limited to 'modules')
| -rw-r--r-- | modules/remote-sycl/c++/data.hpp | 16 | ||||
| -rw-r--r-- | modules/remote-sycl/tests/data_ref.cpp | 13 |
2 files changed, 27 insertions, 2 deletions
diff --git a/modules/remote-sycl/c++/data.hpp b/modules/remote-sycl/c++/data.hpp index 149e3f2..763e4c3 100644 --- a/modules/remote-sycl/c++/data.hpp +++ b/modules/remote-sycl/c++/data.hpp @@ -50,7 +50,7 @@ public: data_{size_.get(),sycl_alloc_} {} - auto* get_internal_data() { + data<Sch,Encode>* get_internal_data() { if(data_.empty()){ return nullptr; } @@ -135,11 +135,23 @@ private: public: data() = delete; - data(ref<data<schema::FixedArray<schema::UInt64, Dim>, Encode>> ref_data__): + data(ref<data<schema::Array<Sch, Dim>, encode::Sycl<Encode>>> ref_data__): internal_data_ptr_{ref_data__().get_internal_data()}, dims_{ref_data__().dims()}, size_{ref_data__().size()} {} + + auto* get_internal_data() { + return internal_data_ptr_; + } + + constexpr data<Sch, Encode>& at(const data<schema::FixedArray<schema::UInt64, Dim>, Encode>& i){ + return internal_data_ptr_[this->get_flat_index(i)]; + } + + constexpr const data<Sch, Encode>& at(const data<schema::FixedArray<schema::UInt64, Dim>, Encode>& i)const{ + return internal_data_ptr_[this->get_flat_index(i)]; + } private: template<typename U> uint64_t get_flat_index(const U& i) const { diff --git a/modules/remote-sycl/tests/data_ref.cpp b/modules/remote-sycl/tests/data_ref.cpp index 61a2c8e..8fb5be4 100644 --- a/modules/remote-sycl/tests/data_ref.cpp +++ b/modules/remote-sycl/tests/data_ref.cpp @@ -13,6 +13,19 @@ SAW_TEST("Data Ref Basics"){ acpp::sycl::queue sycl_q; data<sch::Array<sch::Float64>, encode::Sycl<encode::Native>> dat{{{100u}},sycl_q}; + + data<sch::Ref<sch::Array<sch::Float64>>, encode::Sycl<encode::Native>> dat_ref{dat}; + auto dat_ptr = dat_ref.get_internal_data(); + + sycl_q.parallel_for(100u, [=](acpp::sycl::id<1> idx){ + size_t i = idx[0]; + + dat_ptr[i] = {static_cast<double>(i)}; + }).wait(); + + for(uint64_t i = 0u; i < 100u; ++i){ + SAW_EXPECT(dat_ptr[i].get() == i, std::string{"Unexpected value: "} + std::to_string(i)); + } } } |
