summaryrefslogtreecommitdiff
path: root/modules/remote-sycl
diff options
context:
space:
mode:
authorClaudius "keldu" Holeksa <mail@keldu.de>2024-06-20 16:35:25 +0200
committerClaudius "keldu" Holeksa <mail@keldu.de>2024-06-20 16:35:25 +0200
commit601113a445658d8b15273dd91c66cf20daf50d30 (patch)
treebcb6c2a77e85bb64d6beb9b3f93a5f7bc5a6e400 /modules/remote-sycl
parentc1d352270add2f205d038d7e4f69c1b4f35f014d (diff)
Changing towards a better allocated structure for sycl
Diffstat (limited to 'modules/remote-sycl')
-rw-r--r--modules/remote-sycl/c++/remote.hpp225
-rw-r--r--modules/remote-sycl/examples/sycl_basic_kernel.cpp6
-rw-r--r--modules/remote-sycl/tests/calculator.cpp4
3 files changed, 94 insertions, 141 deletions
diff --git a/modules/remote-sycl/c++/remote.hpp b/modules/remote-sycl/c++/remote.hpp
index 1873669..54b7a7b 100644
--- a/modules/remote-sycl/c++/remote.hpp
+++ b/modules/remote-sycl/c++/remote.hpp
@@ -18,169 +18,96 @@ template<typename T, typename Encoding, typename Storage>
class remote_data<T, Encoding, Storage, rmt::Sycl> {
private:
/**
- * Id representing the remote data
+ * An identifier to the data being held on the remote
*/
- id<T> id_;
+ id<T> data_id_;
+
/**
- * Storage for the
+ * The sycl queue object
*/
- id_map<T,Encoding,rmt::Sycl>* map_;
+ cl::sycl::queue* queue_;
public:
/**
* Main constructor
*/
+ remote_data(data<T,Encoding,Storage>& remote_data__, cl::sycl::queue& queue__):
+ remote_data_{&remote_data__},
+ queue_{&queue__}
+ {}
+
+ /**
+ * Destructor: releases the allocation behind remote_data_ via the stored sycl queue and clears the pointer.
+ */
+ ~remote_data(){
+ if(remote_data_){ // NOTE(review): remote_data_ is not among the members declared above (data_id_, queue_) - confirm the declaration
+ cl::sycl::free(remote_data_,*queue_); // queue_ is a cl::sycl::queue*; sycl::free takes the queue by reference
+ remote_data_ = nullptr;
+ }
+ }
+
+ SAW_FORBID_COPY(remote_data);
+ SAW_FORBID_MOVE(remote_data);
+ /**
remote_data(const id<T>& id, id_map<T, Encoding, rmt::Sycl>& map, cl::sycl::queue& queue__):
id_{id},
map_{&map}
{}
+ */
/**
* Wait for the data
*/
error_or<data<T,Encoding,Storage>> wait(){
- auto eov = map_->find(id_);
- if(eov.is_error()){
- auto& err = eov.get_error();
- return std::move(err);
- }
- auto& val = eov.get_value();
- std::cout<<"Values Sycl in Map: "<<val->size()<<std::endl;
-
- {
- auto eocop = val->template copy_to_host<storage::Default>();
- if(eocop.is_error()){
- return eocop;
- }
- return eocop.get_value();
- }
+ return make_error<err::not_implemented>();
}
/**
* Request data asynchronously
*/
- conveyor<data<T,Encoding,Storage>> on_receive(); /// Stopped here
+ // conveyor<data<T,Encoding,Storage>> on_receive(); /// Stopped here
};
/**
- * Sycl data class for handling the array Schema.
+ * Meant to be a helper object which holds the allocated data on the sycl side
*/
-template<typename T, uint64_t D>
-class data<schema::Array<T,D>, encode::Native, rmt::Sycl> {
-public:
- using Schema = schema::Array<T,D>;
+template<typename Schema, typename Encoding, typename Backend>
+class device_data;
+
+/**
+ * This class helps with regard to the ownership on the server side
+ */
+template<typename Schema, typename Encoding>
+class device_data<Schema, Encoding, rmt::Sycl> {
private:
/**
- * Absolute size of the stored elementes.
- */
- uint64_t total_length_;
- /**
- * The data itself.
+ * The actual data
*/
- data<T,encode::Native,storage::Default>* device_data_;
+ data<Schema,Encoding,Storage>* device_data_;
/**
- * Referenced sycl queue
+ * The sycl queue object
*/
cl::sycl::queue* queue_;
-
- static_assert(is_primitive<T>::value, "Only supports primitives for now");
- static_assert(D==1u, "For now we only support 1D Arrays");
public:
- data(uint64_t size, cl::sycl::queue& q__):
- total_length_{size},
- device_data_{cl::sycl::malloc_device<data<T,encode::Native,storage::Default>>(size, q__)},
- queue_{&q__}
- {
- if(!device_data_){
- total_length_ = 0u;
- return;
- }
- queue_->wait();
- }
-
- template<typename Encoding, typename Storage>
- data(const data<Schema, Encoding, Storage>& from, cl::sycl::queue& q__):
- total_length_{from.size()},
- device_data_{cl::sycl::malloc_device<data<T,encode::Native,storage::Default>>(from.size(), q__)},
- queue_{&q__}
- {
- if(!device_data_){
- total_length_ = 0u;
- return;
- }
- queue_->template copy<data<T,encode::Native,storage::Default>>(&from.at(0), device_data_, total_length_);
- queue_->wait();
- }
-
- data(const data<Schema, encode::Native, rmt::Sycl>& from):
- total_length_{from.size()},
- device_data_{nullptr},
- queue_{from.queue_}
- {
- if(total_length_ == 0u || !queue_){
- return;
- }
- device_data_ = cl::sycl::malloc_device<data<T,encode::Native,storage::Default>>(from.size(), *queue_);
- // device_data_ = cl::sycl::malloc_device<typename native_data_type<T>::type>(from.size(), *queue_);
- if(!device_data_){
- total_length_ = 0u;
- return;
- }
-
- queue_->template copy<data<T,encode::Native,storage::Default>>(from.device_data_, device_data_, total_length_);
- }
+ /**
+ * Main constructor
+ */
+ device_data(data<Schema,Encoding,Storage>& device_data__, cl::sycl::queue& queue__):
+ device_data_{&device_data__},
+ queue_{&queue__}
+ {}
- data(data<Schema, encode::Native, rmt::Sycl>&& rhs):
- total_length_{rhs.total_length_},
- device_data_{rhs.device_data_},
- queue_{rhs.queue_}
- {
- rhs.total_length_ = 0u;
- rhs.device_data_ = nullptr;
- rhs.queue_ = nullptr;
- }
-
- data<Schema, encode::Native, rmt::Sycl>& operator=(data<Schema, encode::Native, rmt::Sycl>&& rhs){
- total_length_ = rhs.total_length_;
- device_data_ = rhs.device_data_;
- queue_ = rhs.queue_;
- rhs.total_length_ = 0u;
- rhs.device_data_ = nullptr;
- rhs.queue_ = nullptr;
- return *this;
- }
-
- ~data(){
- // free data
- if(device_data_){
- /// SYCL FREE
- cl::sycl::free(device_data_, *queue_);
- }
- }
-
/**
- * Allocate appropriate meta data and then copy to host
+ * Destructor specifically designed to deallocate on the device.
*/
- template<typename Storage>
- error_or<data<Schema, encode::Native, Storage>> copy_to_host() const {
- data<Schema,encode::Native, Storage> data_{total_length_};
-
- /// TODO Check success
- queue_->template copy<data<T,encode::Native,storage::Default>>(device_data_, &data_.at(0), total_length_);
- queue_->wait();
- return data_;
- }
-
- template<typename Storage>
- static error_or<data<Schema, encode::Native, rmt::Sycl>> copy_to_device(const data<Schema, encode::Native, Storage>& host_data, device<rmt::Sycl>& dev);
-
-
- data<T, encode::Native, storage::Default>& at(uint64_t i){
- return device_data_[i];
+ ~device_data(){ // frees the allocation behind device_data_ via the stored sycl queue
+ if(device_data_){
+ cl::sycl::free(device_data_,*queue_); // queue_ is a cl::sycl::queue*; sycl::free takes the queue by reference
+ device_data_ = nullptr; // cleared after the free so the pointer no longer refers to released memory
+ }
+ }
- uint64_t size() const {
- return total_length_;
- }
+ SAW_FORBID_COPY(device_data);
+ SAW_FORBID_MOVE(device_data);
};
namespace impl {
@@ -214,7 +141,7 @@ public:
*/
template<typename Schema, typename Encoding, typename Storage>
error_or<data<Schema, Encoding, rmt::Sycl>> copy_to_device(const data<Schema, Encoding, Storage>& host_data){
- return data<Schema, Encoding, rmt::Sycl>::copy_to_device(host_data);
+ return data<Schema, Encoding, rmt::Sycl>::copy_to_device(host_data, *this);
}
/**
@@ -280,7 +207,7 @@ template<typename Iface, typename Encoding>
class rpc_server<Iface, Encoding, rmt::Sycl> {
public:
using InterfaceCtxT = cl::sycl::queue*;
- using InterfaceT = interface<Iface, Encoding, rmt::Sycl, InterfaceCtxT>;
+ using InterfaceT = interface<Iface, Encoding, storage::Default, InterfaceCtxT>;
private:
/**
* Device instance enabling the use of the remote device.
@@ -290,18 +217,18 @@ private:
/**
* The interface including the relevant context class.
*/
- interface<Iface, Encoding, rmt::Sycl, InterfaceCtxT> cl_interface_;
+ interface<Iface, Encoding, storage::Default, InterfaceCtxT> cl_interface_;
/**
* Basic storage for response data.
*/
- impl::rpc_id_map_helper<Iface, Encoding, rmt::Sycl> storage_;
+ // impl::rpc_id_map_helper<Iface, Encoding, rmt::Sycl> storage_;
public:
/**
* Main constructor
*/
- rpc_server(device<rmt::Sycl>& dev__, interface<Iface, Encoding, rmt::Sycl, InterfaceCtxT> cl_iface):
+ rpc_server(device<rmt::Sycl>& dev__, InterfaceT cl_iface):
device_{&dev__},
cl_interface_{std::move(cl_iface)},
storage_{}
@@ -310,31 +237,35 @@ public:
/**
* Ask which id the server prefers as the next one. Only available for fast requests on no roundtrip setups.
*/
+ /**
template<typename T>
id<T> next_free_id() const {
return std::get<id_map<T,Encoding,rmt::Sycl>>(storage_.maps).next_free_id();
}
+ */
+ /**
template<typename IdT, typename Storage>
remote_data<IdT, Encoding, Storage, rmt::Sycl> request_data(id<IdT> dat_id){
return {dat_id, std::get<id_map<IdT,Encoding,rmt::Sycl>>(storage_.maps), device_->get_handle()};
}
+ */
/**
* Rpc call based on the name
*/
- template<string_literal Name, typename ClientAllocation>
+ template<string_literal Name>
error_or<
id<
typename schema_member_type<Name, Iface>::type::ResponseT
>
- > call(data_or_id<typename schema_member_type<Name, Iface>::type::RequestT, Encoding, ClientAllocation> input, id<typename schema_member_type<Name,Iface>::type::ResponseT> rpc_id){
+ > call(data_or_id<typename schema_member_type<Name, Iface>::type::RequestT, Encoding, storage::Default> input, id<typename schema_member_type<Name,Iface>::type::ResponseT> rpc_id){
using FuncT = typename schema_member_type<Name, Iface>::type;
/**
* Object needed if and only if the provided data type is not an id
*/
- own<data<typename FuncT::RequestT, Encoding, rmt::Sycl>> dev_tmp_inp = nullptr;
+ own<device_data<typename FuncT::RequestT, Encoding, rmt::Sycl>> dev_tmp_inp = nullptr;
/**
* First check if it's data or an id.
* If it's an id, check if it's registered within the storage and retrieve it.
@@ -350,7 +281,14 @@ public:
return eov.get_value();
} else {
auto& client_data = input.get_data();
- dev_tmp_inp = heap<data<typename FuncT::RequestT, Encoding, rmt::Sycl>>(client_data, device_->get_handle());
+
+ auto eov = device_->template copy_to_device(client_data);
+ if(eov.is_error()){
+ return std::move(eov.get_error());
+ }
+ auto& val = eov.get_value();
+
+ dev_tmp_inp = heap<data<typename FuncT::RequestT, Encoding, rmt::Sycl>>(std::move(val));
device_->get_handle().wait();
return dev_tmp_inp.get();
}
@@ -434,8 +372,19 @@ public:
template<typename T, uint64_t D>
template<typename Storage>
error_or<data<schema::Array<T,D>, encode::Native, rmt::Sycl>> data<schema::Array<T,D>, encode::Native, rmt::Sycl>::copy_to_device(const data<schema::Array<T,D>, encode::Native, Storage>& host_data, device<rmt::Sycl>& dev){
- data<schema::Array<T,D>, encode::Native, rmt::Sycl> device_data{host_data.size(), dev.get_handle()};
- return make_error<err::not_implemented>();
-}
+ /**
+ * Retrieve handle
+ */
+ auto& cmd_handle = dev.get_handle();
+
+ uint64_t* dev_len = cl::sycl::malloc_device<uint64_t>(1u, cmd_handle);
+ uint64_t len = host_data.size();
+ cmd_handle.template copy<uint64_t>(&len,dev_len, 1u);
+ auto dev_dat = cl::sycl::malloc_device<data<T,encode::Native,storage::Default>>(host_data.size(), cmd_handle);
+ cmd_handle.copy(&host_data.at(0), dev_dat, host_data.size());
+ cmd_handle.wait();
+
+ return data<schema::Array<T,D>,encode::Native, rmt::Sycl>{dev_len, dev_dat, cmd_handle};
+}
}
diff --git a/modules/remote-sycl/examples/sycl_basic_kernel.cpp b/modules/remote-sycl/examples/sycl_basic_kernel.cpp
index 6481eb9..f9a838e 100644
--- a/modules/remote-sycl/examples/sycl_basic_kernel.cpp
+++ b/modules/remote-sycl/examples/sycl_basic_kernel.cpp
@@ -2,9 +2,13 @@
saw::rpc_server<schema::BasicInterface, saw::encode::Native, saw::rmt::Sycl> listen_basic_sycl(saw::remote<saw::rmt::Sycl>& ctx, saw::device<saw::rmt::Sycl>& dev, saw::remote_address<saw::rmt::Sycl>& addr){
saw::interface<schema::BasicInterface, saw::encode::Native, saw::rmt::Sycl, cl::sycl::queue*> iface{
+ /**
+ * This is the increment kernel
+ */
+ [](saw::data<saw::schema::Array<saw::schema::UInt64>, saw::encode::Native, saw::rmt::Sycl>& in, cl::sycl::queue* q) -> saw::error_or<void> {
- [](saw::data<saw::schema::Array<saw::schema::UInt64>, saw::encode::Native, saw::rmt::Sycl> in, cl::sycl::queue* q) -> saw::data<saw::schema::Array<saw::schema::UInt64>, saw::encode::Native, saw::rmt::Sycl> {
q->submit([&](cl::sycl::handler& h){
+
h.single_task([&] (){
in.at(0u).set(in.at(0u).get() + 1u);
});
diff --git a/modules/remote-sycl/tests/calculator.cpp b/modules/remote-sycl/tests/calculator.cpp
index 730838d..6d061ad 100644
--- a/modules/remote-sycl/tests/calculator.cpp
+++ b/modules/remote-sycl/tests/calculator.cpp
@@ -21,7 +21,7 @@ SAW_TEST("Sycl Interface Calculator"){
cl::sycl::queue cmd_queue;
interface<schema::Calculator, encode::Native<storage::Default>, cl::sycl::queue*> cl_iface {
-[](data<schema::Tuple<schema::Int64, schema::Int64>> in, cl::sycl::queue* cmd) -> data<schema::Int64> {
+[](data<schema::Tuple<schema::Int64, schema::Int64>>& in, cl::sycl::queue* cmd) -> data<schema::Int64> {
std::array<int64_t,2> h_xy{in.get<0>().get(), in.get<1>().get()};
int64_t res{};
cl::sycl::buffer<int64_t,1> d_xy { h_xy.data(), h_xy.size() };
@@ -37,7 +37,7 @@ SAW_TEST("Sycl Interface Calculator"){
cmd->wait();
return data<schema::Int64>{res};
},
- [](data<schema::Tuple<schema::Int64, schema::Int64>> in, cl::sycl::queue* cmd) -> data<schema::Int64> {
+ [](data<schema::Tuple<schema::Int64, schema::Int64>,encode::Native,rmt::Sycl>& in, cl::sycl::queue* cmd) -> data<schema::Int64> {
return data<schema::Int64>{in.get<0>().get() * in.get<1>().get()};
}
};