summaryrefslogtreecommitdiff
path: root/modules/remote-sycl
diff options
context:
space:
mode:
authorClaudius "keldu" Holeksa <mail@keldu.de>2024-06-21 19:44:34 +0200
committerClaudius "keldu" Holeksa <mail@keldu.de>2024-06-21 19:44:34 +0200
commit86b06a3fee2cd7635a9ab486e2a35bdf1e81ce38 (patch)
tree5485b323cdce1c1347f1a20c7f33e8f772c73dbf /modules/remote-sycl
parent601113a445658d8b15273dd91c66cf20daf50d30 (diff)
Moving forward with basic test for sycl
Diffstat (limited to 'modules/remote-sycl')
-rw-r--r--modules/remote-sycl/.nix/derivation.nix2
-rw-r--r--modules/remote-sycl/c++/remote.hpp71
-rw-r--r--modules/remote-sycl/tests/calculator.foo (renamed from modules/remote-sycl/tests/calculator.cpp)2
-rw-r--r--modules/remote-sycl/tests/sycl_basics.cpp96
4 files changed, 133 insertions, 38 deletions
diff --git a/modules/remote-sycl/.nix/derivation.nix b/modules/remote-sycl/.nix/derivation.nix
index 488b8a8..2247ec0 100644
--- a/modules/remote-sycl/.nix/derivation.nix
+++ b/modules/remote-sycl/.nix/derivation.nix
@@ -49,7 +49,7 @@ in stdenv.mkDerivation {
scons prefix=$out build_examples=${build_examples} install
'';
- doCheck = false;
+ doCheck = true;
checkPhase = ''
scons test
./bin/tests
diff --git a/modules/remote-sycl/c++/remote.hpp b/modules/remote-sycl/c++/remote.hpp
index 54b7a7b..24756be 100644
--- a/modules/remote-sycl/c++/remote.hpp
+++ b/modules/remote-sycl/c++/remote.hpp
@@ -11,11 +11,30 @@ class remote<rmt::Sycl>;
template<typename T>
class device;
+template<typename Schema>
+class data<Schema, encode::Native, rmt::Sycl> {
+private:
+ cl::sycl::buffer<data<Schema, encode::Native, storage::Default>> data_;
+public:
+ data(data<Schema, encode::Native, storage::Default>& data__):
+ data_{&data__, 1u}
+ {}
+
+ auto& get_handle() {
+ return data_;
+ }
+
+ template<cl::sycl::access::mode AccessMode>
+ auto access(cl::sycl::handler& h){
+ return data_.template get_access<AccessMode>(h);
+ }
+};
+
/**
* Remote data class for the Sycl backend.
*/
template<typename T, typename Encoding, typename Storage>
-class remote_data<T, Encoding, Storage, rmt::Sycl> {
+class remote_data<T, Encoding, Storage, rmt::Sycl> final {
private:
/**
* An identifier to the data being held on the remote
@@ -30,20 +49,15 @@ public:
/**
* Main constructor
*/
- remote_data(data<T,Encoding,Storage>& remote_data__, cl::sycl::queue& queue__):
- remote_data_{&remote_data__},
+ remote_data(id<T> data_id__, cl::sycl::queue& queue__):
+ data_id_{data_id__},
queue_{&queue__}
{}
/**
* Destructor specifically designed to deallocate on the device.
*/
- ~remote_data(){
- if(remote_data_){
- cl::sycl::free(remote_data_,queue_);
- remote_data_ = nullptr;
- }
- }
+ ~remote_data(){}
SAW_FORBID_COPY(remote_data);
SAW_FORBID_MOVE(remote_data);
@@ -82,7 +96,7 @@ private:
/**
* The actual data
*/
- data<Schema,Encoding,Storage>* device_data_;
+ data<Schema,Encoding,storage::Default>* device_data_;
/**
* The sycl queue object
*/
@@ -91,7 +105,7 @@ public:
/**
* Main constructor
*/
- device_data(data<Schema,Encoding,Storage>& device_data__, cl::sycl::queue& queue__):
+ device_data(data<Schema,Encoding,storage::Default>& device_data__, cl::sycl::queue& queue__):
device_data_{&device_data__},
queue_{&queue__}
{}
@@ -111,6 +125,11 @@ public:
};
namespace impl {
+template<typename Schema, typename Encoding, typename Backend>
+struct device_id_map {
+ std::vector<device_data<Schema, Encoding, Backend>> data;
+};
+
template<typename Iface, typename Encoding, typename Storage>
struct rpc_id_map_helper {
static_assert(always_false<Iface, Encoding,Storage>, "Only supports Interface schema types.");
@@ -122,9 +141,12 @@ struct rpc_id_map_helper<schema::Interface<Members...>, Encoding, Storage> {
};
}
+}
+// Maybe a helper impl tmpl file?
+namespace saw {
+
/**
* Represents a remote Sycl device.
- *
*/
template<>
class device<rmt::Sycl> final {
@@ -161,11 +183,6 @@ public:
};
/**
- * Device data transport
- */
-
-
-/**
* Rpc Client class for the Sycl backend.
*/
template<typename Iface, typename Encoding, typename Storage>
@@ -222,7 +239,7 @@ private:
/**
* Basic storage for response data.
*/
- // impl::rpc_id_map_helper<Iface, Encoding, rmt::Sycl> storage_;
+ impl::rpc_id_map_helper<Iface, Encoding, rmt::Sycl> storage_;
public:
/**
@@ -369,22 +386,4 @@ public:
}
};
-template<typename T, uint64_t D>
-template<typename Storage>
-error_or<data<schema::Array<T,D>, encode::Native, rmt::Sycl>> data<schema::Array<T,D>, encode::Native, rmt::Sycl>::copy_to_device(const data<schema::Array<T,D>, encode::Native, Storage>& host_data, device<rmt::Sycl>& dev){
- /**
- * Retrieve handle
- */
- auto& cmd_handle = dev.get_handle();
-
- uint64_t* dev_len = cl::sycl::malloc_device<uint64_t>(1u, cmd_handle);
- uint64_t len = host_data.size();
- cmd_handle.template copy<uint64_t>(&len,dev_len, 1u);
-
- auto dev_dat = cl::sycl::malloc_device<data<T,encode::Native,storage::Default>>(host_data.size(), cmd_handle);
- cmd_handle.copy(&host_data.at(0), dev_dat, host_data.size());
- cmd_handle.wait();
-
- return data<schema::Array<T,D>,encode::Native, rmt::Sycl>{dev_len, dev_dat, cmd_handle};
-}
}
diff --git a/modules/remote-sycl/tests/calculator.cpp b/modules/remote-sycl/tests/calculator.foo
index 6d061ad..745bd3d 100644
--- a/modules/remote-sycl/tests/calculator.cpp
+++ b/modules/remote-sycl/tests/calculator.foo
@@ -15,7 +15,7 @@ using Calculator = Interface<
>;
}
-SAW_TEST("Sycl Interface Calculator"){
+SAW_TEST("SYCL Interface Calculator"){
using namespace saw;
cl::sycl::queue cmd_queue;
diff --git a/modules/remote-sycl/tests/sycl_basics.cpp b/modules/remote-sycl/tests/sycl_basics.cpp
new file mode 100644
index 0000000..bf41983
--- /dev/null
+++ b/modules/remote-sycl/tests/sycl_basics.cpp
@@ -0,0 +1,96 @@
+#include <forstio/test/suite.hpp>
+
+#include "../c++/remote.hpp"
+
+namespace {
+namespace schema {
+using namespace saw::schema;
+
+using TestStruct = Struct<
+ Member<UInt64, "foo">,
+ Member<Float64, "bar">,
+ Member<Array<Float64>, "doubles">
+>;
+
+using Foo = Interface<
+ Member<Function<TestStruct, Void>, "foo">
+>;
+
+using Calculator = Interface<
+ Member<
+ Function<Tuple<Int64, Int64>, Int64>, "add"
+ >
+, Member<
+ Function<Tuple<Int64, Int64>, Int64>, "multiply"
+ >
+>;
+}
+SAW_TEST("SYCL Test Setup"){
+ using namespace saw;
+
+ data<schema::TestStruct> host_data;
+ host_data.template get<"foo">() = 321;
+ host_data.template get<"bar">() = 50.0;
+ host_data.template get<"doubles">() = {1024u};
+
+ saw::event_loop loop;
+ saw::wait_scope wait{loop};
+
+ remote<rmt::Sycl> rmt;
+ saw::own<saw::remote_address<saw::rmt::Sycl>> rmt_addr{};
+
+ rmt.resolve_address().then([&](auto addr){
+ rmt_addr = std::move(addr);
+ }).detach();
+
+ wait.poll();
+ SAW_EXPECT(rmt_addr, "Remote Address class hasn't been filled");
+
+ auto device = rmt.connect_device(*rmt_addr);
+
+ data<schema::TestStruct, encode::Native, rmt::Sycl> device_data{host_data};
+
+ interface<schema::Foo, encode::Native,rmt::Sycl, cl::sycl::queue*> cl_iface {
+[&](data<schema::TestStruct, encode::Native, rmt::Sycl>& in, cl::sycl::queue* cmd) -> error_or<void> {
+
+ cmd->submit([&](cl::sycl::handler& h){
+
+ auto acc_buff = in.template access<cl::sycl::access::mode::write>(h);
+
+ uint64_t si = host_data.template get<"doubles">().size();
+
+ h.parallel_for(cl::sycl::range<1>(si), [=] (cl::sycl::id<1> it){
+ acc_buff[0u].template get<"foo">() = acc_buff[0u].template get<"doubles">().size();
+ auto& dbls = acc_buff[0u].template get<"doubles">();
+ dbls.at(it[0u]) = it[0u] * 2.0;
+ });
+ });
+ /*
+ cmd->submit([&](cl::sycl::handler& h){
+ auto acc_buff = in.template access<cl::sycl::access::mode::read>(h);
+ h.copy(acc_buff, &host_data);
+ });
+ */
+
+ /**
+ cl::sycl::host_accessor result{in.get_handle()};
+ std::cout<<result[0u].template get<"foo">().get()<<std::endl;
+ std::cout<<result[0u].template get<"bar">().get()<<std::endl;
+ **/
+ return saw::void_t{};
+ }
+ };
+
+ std::cout<<"Running on:\n"<<device.get_handle().get_device().get_info<cl::sycl::info::device::name>()<<std::endl;
+
+ std::cout<<host_data.template get<"foo">().get()<<std::endl;
+ std::cout<<host_data.template get<"bar">().get()<<std::endl;
+ std::cout<<std::endl;
+ cl_iface.template call <"foo">(device_data, &device.get_handle());
+ device.get_handle().wait();
+ std::cout<<host_data.template get<"foo">().get()<<std::endl;
+ std::cout<<host_data.template get<"bar">().get()<<std::endl;
+ auto& dbls = host_data.template get<"doubles">();
+ std::cout<<std::endl;
+}
+}