summaryrefslogtreecommitdiff
path: root/modules/remote-sycl
diff options
context:
space:
mode:
authorClaudius "keldu" Holeksa <mail@keldu.de>2025-11-18 21:56:06 +0100
committerClaudius "keldu" Holeksa <mail@keldu.de>2025-11-18 21:56:06 +0100
commit2073aef795f74e5c24b7992d6c2f0fadde3fa271 (patch)
tree755f89cbe50260d375cbc255a46f5630f9670a1e /modules/remote-sycl
parent2116b46f784aa9bc0dcbb1b6bfc22183d979b919 (diff)
downloadforstio-forstio-2073aef795f74e5c24b7992d6c2f0fadde3fa271.tar.gz
Making ref work?
Diffstat (limited to 'modules/remote-sycl')
-rw-r--r--modules/remote-sycl/c++/data.hpp16
-rw-r--r--modules/remote-sycl/tests/data_ref.cpp13
2 files changed, 27 insertions, 2 deletions
diff --git a/modules/remote-sycl/c++/data.hpp b/modules/remote-sycl/c++/data.hpp
index 149e3f2..763e4c3 100644
--- a/modules/remote-sycl/c++/data.hpp
+++ b/modules/remote-sycl/c++/data.hpp
@@ -50,7 +50,7 @@ public:
data_{size_.get(),sycl_alloc_}
{}
- auto* get_internal_data() {
+ data<Sch,Encode>* get_internal_data() {
if(data_.empty()){
return nullptr;
}
@@ -135,11 +135,23 @@ private:
public:
data() = delete;
- data(ref<data<schema::FixedArray<schema::UInt64, Dim>, Encode>> ref_data__):
+ data(ref<data<schema::Array<Sch, Dim>, encode::Sycl<Encode>>> ref_data__):
internal_data_ptr_{ref_data__().get_internal_data()},
dims_{ref_data__().dims()},
size_{ref_data__().size()}
{}
+
+ auto* get_internal_data() {
+ return internal_data_ptr_;
+ }
+
+ constexpr data<Sch, Encode>& at(const data<schema::FixedArray<schema::UInt64, Dim>, Encode>& i){
+ return internal_data_ptr_[this->get_flat_index(i)];
+ }
+
+ constexpr const data<Sch, Encode>& at(const data<schema::FixedArray<schema::UInt64, Dim>, Encode>& i)const{
+ return internal_data_ptr_[this->get_flat_index(i)];
+ }
private:
template<typename U>
uint64_t get_flat_index(const U& i) const {
diff --git a/modules/remote-sycl/tests/data_ref.cpp b/modules/remote-sycl/tests/data_ref.cpp
index 61a2c8e..8fb5be4 100644
--- a/modules/remote-sycl/tests/data_ref.cpp
+++ b/modules/remote-sycl/tests/data_ref.cpp
@@ -13,6 +13,19 @@ SAW_TEST("Data Ref Basics"){
acpp::sycl::queue sycl_q;
data<sch::Array<sch::Float64>, encode::Sycl<encode::Native>> dat{{{100u}},sycl_q};
+
+ data<sch::Ref<sch::Array<sch::Float64>>, encode::Sycl<encode::Native>> dat_ref{dat};
+ auto dat_ptr = dat_ref.get_internal_data();
+
+ sycl_q.parallel_for(100u, [=](acpp::sycl::id<1> idx){
+ size_t i = idx[0];
+
+ dat_ptr[i] = {static_cast<double>(i)};
+ }).wait();
+
+ for(uint64_t i = 0u; i < 100u; ++i){
+ SAW_EXPECT(dat_ptr[i].get() == i, std::string{"Unexpected value: "} + std::to_string(i));
+ }
}
}