summaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
authorClaudius "keldu" Holeksa <mail@keldu.de>2025-11-18 17:46:04 +0100
committerClaudius "keldu" Holeksa <mail@keldu.de>2025-11-18 17:46:04 +0100
commit668e53e42e210d2cedf29281eb187e8d7f129651 (patch)
treeda3034eb02203cd9882c6de13ec556aaf6ac19ae /modules
parente2071b41bf8547057c485fea2a8d3aed7bb710ed (diff)
downloadforstio-forstio-668e53e42e210d2cedf29281eb187e8d7f129651.tar.gz
Working on tests in sycl
Diffstat (limited to 'modules')
-rw-r--r--modules/remote-sycl/.nix/derivation.nix2
-rw-r--r--modules/remote-sycl/c++/common.hpp1
-rw-r--r--modules/remote-sycl/c++/data.hpp63
-rw-r--r--modules/remote-sycl/tests/data.foo (renamed from modules/remote-sycl/tests/data.cpp)0
-rw-r--r--modules/remote-sycl/tests/data_ref.cpp18
-rw-r--r--modules/remote-sycl/tests/mixed_precision.foo (renamed from modules/remote-sycl/tests/mixed_precision.cpp)0
-rw-r--r--modules/remote-sycl/tests/sycl_basics.cpp31
7 files changed, 95 insertions, 20 deletions
diff --git a/modules/remote-sycl/.nix/derivation.nix b/modules/remote-sycl/.nix/derivation.nix
index 688af18..5839acb 100644
--- a/modules/remote-sycl/.nix/derivation.nix
+++ b/modules/remote-sycl/.nix/derivation.nix
@@ -48,7 +48,7 @@ in stdenv.mkDerivation {
scons prefix=$out build_benchmarks=${build_benchmarks} build_examples=${build_examples} install
'';
- doCheck = false;
+ doCheck = true;
checkPhase = ''
export ACPP_APPDB_DIR=.
scons test
diff --git a/modules/remote-sycl/c++/common.hpp b/modules/remote-sycl/c++/common.hpp
index 287075f..54a09d1 100644
--- a/modules/remote-sycl/c++/common.hpp
+++ b/modules/remote-sycl/c++/common.hpp
@@ -11,6 +11,7 @@ namespace saw {
namespace rmt {
struct Sycl {};
}
+
namespace encode {
template<typename Inner>
struct Sycl {};
diff --git a/modules/remote-sycl/c++/data.hpp b/modules/remote-sycl/c++/data.hpp
index 11dfbf2..e057766 100644
--- a/modules/remote-sycl/c++/data.hpp
+++ b/modules/remote-sycl/c++/data.hpp
@@ -57,7 +57,7 @@ public:
return &(data_[0u]);
}
- const auto& get_internal_size() const {
+ auto get_internal_size() const {
return size_;
}
@@ -112,4 +112,63 @@ private:
return s;
}
};
-}
+
+template<typename Sch, uint64_t Dim, typename Encode>
+class data<schema::Ref<schema::Array<Sch, Dim>>, encode::Sycl<Encode>> {
+public:
+ using Schema = schema::Ref<schema::Array<Sch,Dim>>;
+private:
+ data<schema::Sch, Encode>* internal_data_ptr_;
+ data<schema::FixedArray<schema::UInt64, Dim>, Encode> dims_;
+ data<schema::UInt64, Encode> size_;
+
+ uint64_t get_full_size() const {
+ uint64_t s = 1;
+
+ for(uint64_t iter = 0; iter < Dim; ++iter){
+ auto& dim_iter = dims_.at(data<schema::UInt64>{iter});
+ s *= dim_iter.get();
+ }
+
+ return s;
+ }
+public:
+ data() = delete;
+
+ data(ref<data<schema::FixedArray<schema::UInt64, Dim>, Encode>> ref_data__):
+ internal_data_ptr_{ref_data__().get_internal_data()},
+ dims_{ref_data__().dims()},
+ size_{ref_data__().size()}
+ {}
+private:
+ template<typename U>
+ uint64_t get_flat_index(const U& i) const {
+ static_assert(
+ std::is_same_v<U,data<schema::FixedArray<schema::UInt64,Dim>, Encode>> or
+ std::is_same_v<U,std::array<uint64_t,Dim>>,
+ "Unsupported type"
+ );
+ assert(data_.size() == get_full_size());
+ uint64_t s = 0;
+
+ uint64_t stride = 1;
+
+ for(uint64_t iter = 0; iter < Dim; ++iter){
+ uint64_t ind = [](auto val) -> uint64_t {
+ using V = std::decay_t<decltype(val)>;
+ if constexpr (std::is_same_v<V,data<schema::UInt64>>){
+ return val.get();
+ }else if constexpr (std::is_same_v<V, uint64_t>){
+ return val;
+ }else{
+ static_assert(always_false<V>, "Cases exhausted");
+ }
+ }(i.at(iter));
+ assert(ind < dims_.at({iter}).get() );
+ s += ind * stride;
+ stride *= dims_.at(iter).get();
+ }
+
+ return s;
+ }
+};
diff --git a/modules/remote-sycl/tests/data.cpp b/modules/remote-sycl/tests/data.foo
index 798b7a5..798b7a5 100644
--- a/modules/remote-sycl/tests/data.cpp
+++ b/modules/remote-sycl/tests/data.foo
diff --git a/modules/remote-sycl/tests/data_ref.cpp b/modules/remote-sycl/tests/data_ref.cpp
new file mode 100644
index 0000000..03afb8f
--- /dev/null
+++ b/modules/remote-sycl/tests/data_ref.cpp
@@ -0,0 +1,18 @@
+#include <forstio/test/suite.hpp>
+
+#include "../c++/data.hpp"
+
+namespace {
+namespace sch {
+using namespace saw::schema;
+}
+
+SAW_TEST("Data Ref Basics"){
+ using namespace saw;
+
+ acpp::sycl::queue sycl_q;
+
+ data<sch::Array<sch::Float64>, encode::Sycl<encode::Native>> dat{{{100u}},sycl_q};
+
+}
+}
diff --git a/modules/remote-sycl/tests/mixed_precision.cpp b/modules/remote-sycl/tests/mixed_precision.foo
index 4a5218d..4a5218d 100644
--- a/modules/remote-sycl/tests/mixed_precision.cpp
+++ b/modules/remote-sycl/tests/mixed_precision.foo
diff --git a/modules/remote-sycl/tests/sycl_basics.cpp b/modules/remote-sycl/tests/sycl_basics.cpp
index 4ad3cf7..970f4d6 100644
--- a/modules/remote-sycl/tests/sycl_basics.cpp
+++ b/modules/remote-sycl/tests/sycl_basics.cpp
@@ -18,6 +18,14 @@ using Foo = Interface<
Member<Function<TestStruct, Void>, "foo">
>;
}
+SAW_TEST("SYCL Basics"){
+ using namespace saw;
+
+ acpp::sycl::queue q;
+ data<schema::TestStruct,encode::Sycl<encode::Native>> host_data;
+}
+
+/*
SAW_TEST("SYCL Test Setup"){
using namespace saw;
@@ -41,33 +49,21 @@ SAW_TEST("SYCL Test Setup"){
data<schema::TestStruct, encode::Sycl<encode::Native>> device_data{host_data};
- interface<schema::Foo, encode::Sycl<encode::Native>,cl::sycl::queue*> cl_iface {
-[&](data<schema::TestStruct, encode::Sycl<encode::Native>>& in, cl::sycl::queue* cmd) -> error_or<void> {
+ interface<schema::Foo, encode::Sycl<encode::Native>,acpp::sycl::queue*> cl_iface {
+[&](data<schema::TestStruct, encode::Sycl<encode::Native>>& in, acpp::sycl::queue* cmd) -> error_or<void> {
- cmd->submit([&](cl::sycl::handler& h){
+ cmd->submit([&](acpp::sycl::handler& h){
- auto acc_buff = in.template access<cl::sycl::access::mode::write>(h);
+ auto acc_buff = in.template access<acpp::sycl::access::mode::write>(h);
auto si = host_data.template get<"doubles">().size();
- h.parallel_for(cl::sycl::range<1>(si.get()), [=] (cl::sycl::id<1> it){
+ h.parallel_for(acpp::sycl::range<1>(si.get()), [=] (acpp::sycl::id<1> it){
acc_buff[0u].template get<"foo">() = acc_buff[0u].template get<"doubles">().size();
auto& dbls = acc_buff[0u].template get<"doubles">();
dbls.at(it[0u]) = it[0u] * 2.0;
});
});
- /*
- cmd->submit([&](cl::sycl::handler& h){
- auto acc_buff = in.template access<cl::sycl::access::mode::read>(h);
- h.copy(acc_buff, &host_data);
- });
- */
-
- /**
- cl::sycl::host_accessor result{in.get_handle()};
- std::cout<<result[0u].template get<"foo">().get()<<std::endl;
- std::cout<<result[0u].template get<"bar">().get()<<std::endl;
- **/
return saw::void_t{};
}
};
@@ -77,4 +73,5 @@ SAW_TEST("SYCL Test Setup"){
cl_iface.template call <"foo">(device_data, &(device.get_handle()));
device.get_handle().wait();
}
+*/
}