summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--modules/codec/c++/data.hpp4
-rw-r--r--modules/remote-sycl/c++/device.hpp10
-rw-r--r--modules/remote-sycl/c++/remote.hpp4
-rw-r--r--modules/remote-sycl/tests/data_ref.cpp25
-rw-r--r--modules/remote-sycl/tests/remote.foo (renamed from modules/remote-sycl/tests/remote.cpp)2
5 files changed, 33 insertions, 12 deletions
diff --git a/modules/codec/c++/data.hpp b/modules/codec/c++/data.hpp
index 585501b..a06acdf 100644
--- a/modules/codec/c++/data.hpp
+++ b/modules/codec/c++/data.hpp
@@ -365,6 +365,10 @@ public:
constexpr data<schema::FixedArray<schema::UInt64, sizeof...(D)>> get_dims() const {
return {std::array<uint64_t, sizeof...(D)>{D...}};
}
+
+ constexpr data<schema::UInt64> flat_dims() const {
+ return {ct_multiply<uint64_t, D...>::value};
+ }
private:
constexpr uint64_t get_flat_index(const data<schema::FixedArray<schema::UInt64, sizeof...(D)>>& i) const {
uint64_t s = 0;
diff --git a/modules/remote-sycl/c++/device.hpp b/modules/remote-sycl/c++/device.hpp
index 05bb17a..6667711 100644
--- a/modules/remote-sycl/c++/device.hpp
+++ b/modules/remote-sycl/c++/device.hpp
@@ -22,13 +22,16 @@ public:
* Copy data to device
*/
template<typename Schema, typename Encoding>
- error_or<data<Schema, encode::Sycl<Encoding>>> copy_to_device(const data<Schema, Encoding>& host_data){
+ error_or<data<Schema, encode::Sycl<Encoding>>> copy_to_device(const data<Schema, Encoding>& host_data, data<Schema, encode::Sycl<Encoding>>& sycl_data){
+
return data<Schema, encode::Sycl<Encoding>>{host_data};
}
template<typename Schema, typename Encoding>
error_or<data<Schema, encode::Sycl<Encoding>>> allocate_on_device(const data<typename meta_schema<Schema>::MetaSchema, Encoding>& host_meta){
- return copy_to_device(data<Schema, Encoding>{host_meta});
+ data<Schema,Encoding> host_data{host_meta};
+ data<Schema,encode::Sycl<Encoding>> sycl_dat{host_data};
+ return sycl_dat;
}
/**
@@ -44,8 +47,7 @@ public:
});
cmd_queue_.wait();
*/
- acpp::sycl::host_accessor result{dev_data.get_handle()};
- return result[0];
+ return make_error<err::not_implemented>("device<rmt::Sycl>::copy_to_host");
}
/**
diff --git a/modules/remote-sycl/c++/remote.hpp b/modules/remote-sycl/c++/remote.hpp
index 65f645e..ef11d50 100644
--- a/modules/remote-sycl/c++/remote.hpp
+++ b/modules/remote-sycl/c++/remote.hpp
@@ -1,6 +1,8 @@
#pragma once
#include "common.hpp"
+#include "transfer.hpp"
+#include "rpc.hpp"
namespace saw {
@@ -72,7 +74,7 @@ public:
}
/**
- * Spin up data server
+ * Spin up a data server
*/
template<typename Schema, typename Encoding>
error_or<own<data_server<Schema, Encoding, rmt::Sycl>>> data_listen(remote_address<rmt::Sycl>& dev){
diff --git a/modules/remote-sycl/tests/data_ref.cpp b/modules/remote-sycl/tests/data_ref.cpp
index e92c693..7f5bb1b 100644
--- a/modules/remote-sycl/tests/data_ref.cpp
+++ b/modules/remote-sycl/tests/data_ref.cpp
@@ -10,14 +10,23 @@ using namespace saw::schema;
SAW_TEST("Data Ref Basics"){
using namespace saw;
- acpp::sycl::queue sycl_q;
+ device<rmt::Sycl> dev;
+ acpp::sycl::queue& sycl_q = dev.get_handle();
constexpr uint64_t dat_size = 1000u;
- data<sch::Array<sch::UInt64>, encode::Sycl<encode::Native>> dat{{{dat_size}},sycl_q};
+ data<sch::Array<sch::UInt64>, encode::Native> dat{{{dat_size}}};
+ auto eo_syc_dat = dev.template allocate_on_device<sch::Array<sch::UInt64>,encode::Native>({{dat_size}});
+ SAW_EXPECT(eo_syc_dat.is_value(), "Couldn't allocate on device");
+ auto& sycl_dat = eo_syc_dat.get_value();
- data<sch::Ref<sch::Array<sch::UInt64>>, encode::Sycl<encode::Native>> dat_ref{dat};
- auto dat_ptr = dat_ref.get_internal_data();
+ {
+ auto eov = dev.copy_to_device(dat,sycl_dat);
+ SAW_EXPECT_EOV(eov);
+ }
+
+ data<sch::Ref<sch::Array<sch::UInt64>>, encode::Sycl<encode::Native>> sdat_ref{sycl_dat};
+ auto dat_ptr = sdat_ref.get_internal_data();
sycl_q.parallel_for(dat_size, [=](acpp::sycl::id<1> idx){
size_t i = idx[0];
@@ -25,8 +34,12 @@ SAW_TEST("Data Ref Basics"){
dat_ptr[i] = {i};
}).wait();
- for(uint64_t i = 0u; i < dat_size; ++i){
- SAW_EXPECT(dat_ptr[i].get() == i, std::string{"Unexpected value: "} + std::to_string(i));
+ {
+ auto eov = dev.copy_to_host(sycl_dat,dat);
+ SAW_EXPECT_EOV(eov);
+ }
+ for(saw::data<sch::UInt64> i = 0u; i < saw::data<sch::UInt64>{dat_size}; ++i){
+ SAW_EXPECT(dat.at({i}) == i, std::string{"Unexpected value: "} + std::to_string(i.get()));
}
}
}
diff --git a/modules/remote-sycl/tests/remote.cpp b/modules/remote-sycl/tests/remote.foo
index e580f17..698c333 100644
--- a/modules/remote-sycl/tests/remote.cpp
+++ b/modules/remote-sycl/tests/remote.foo
@@ -17,7 +17,7 @@ SAW_TEST("Remote Basics"){
auto& sycl_addr = eo_sycl_addr.get_value();
- auto eo_dat_srv = sycl_rmt.data_listen(sycl_addr);
+ auto eo_dat_srv = sycl_rmt.data_listen<sch::Struct<>,saw::encode::Native>(*sycl_addr);
SAW_EOV_EXPECT(eo_dat_srv);
auto& dat_srv = eo_dat_srv.get_value();