summaryrefslogtreecommitdiff
path: root/modules/remote-sycl
diff options
context:
space:
mode:
authorClaudius "keldu" Holeksa <mail@keldu.de>2024-06-13 17:34:22 +0200
committerClaudius "keldu" Holeksa <mail@keldu.de>2024-06-13 17:34:22 +0200
commit57f6eacfcdbdba31185eb66b9a573a8923eecf16 (patch)
tree1683da4209744fabbe87a949134701d617c0d5f9 /modules/remote-sycl
parent0f317186de9fb11d336e564f808e4732386c4074 (diff)
Possible fix for transferring primitives to device without dropping STL
Diffstat (limited to 'modules/remote-sycl')
-rw-r--r--modules/remote-sycl/c++/remote.hpp83
-rw-r--r--modules/remote-sycl/examples/sycl_basic.cpp13
-rw-r--r--modules/remote-sycl/examples/sycl_basic_kernel.cpp2
3 files changed, 67 insertions, 31 deletions
diff --git a/modules/remote-sycl/c++/remote.hpp b/modules/remote-sycl/c++/remote.hpp
index 677a427..bcc8a3c 100644
--- a/modules/remote-sycl/c++/remote.hpp
+++ b/modules/remote-sycl/c++/remote.hpp
@@ -22,16 +22,25 @@ class remote_data<T, Encoding, Storage, rmt::Sycl> {
private:
id<T> id_;
id_map<T,Encoding,rmt::Sycl>* map_;
+ cl::sycl::queue* queue_;
public:
/**
* Main constructor
*/
- remote_data(const id<T>& id, id_map<T, Encoding, rmt::Sycl>& map):
+ remote_data(const id<T>& id, id_map<T, Encoding, rmt::Sycl>& map, cl::sycl::queue& queue__):
id_{id},
- map_{&map}
+ map_{&map},
+ queue_{&queue__}
{}
/**
+ * Wait for the data
+ */
+ error_or<data<T,Encoding,Storage>> wait(){
+
+ }
+
+ /**
* Request data asynchronously
*/
conveyor<data<T,Encoding,Storage>> on_receive(); /// Stopped here
@@ -39,21 +48,14 @@ public:
/**
*
- */
template<typename T, uint64_t N>
class data<schema::Primitive<T,N>, encode::Native, rmt::Sycl> {
public:
using Schema = schema::Primitive<T,N>;
using NativeType = typename native_data_type<Schema>::type;
private:
- /**
- *
- */
NativeType val_;
public:
- /**
- *
- */
data(NativeType val__):
val_{val__}
{}
@@ -62,6 +64,7 @@ public:
return val_;
}
};
+ */
template<typename T, uint64_t D>
class data<schema::Array<T,D>, encode::Native, rmt::Sycl> {
@@ -69,8 +72,8 @@ public:
using Schema = schema::Array<T,D>;
private:
uint64_t total_length_;
- typename native_data_type<T>::type* device_data_;
- // data<T>* device_data_;
+ // typename native_data_type<T>::type* device_data_;
+ data<T,encode::Native,storage::Default>* device_data_;
cl::sycl::queue* queue_;
static_assert(is_primitive<T>::value, "Only supports primitives for now");
@@ -78,7 +81,8 @@ private:
public:
data(uint64_t size, cl::sycl::queue& q__):
total_length_{size},
- device_data_{cl::sycl::malloc_device<typename native_data_type<T>::type>(size, q__)},
+ device_data_{cl::sycl::malloc_device<data<T,encode::Native,storage::Default>>(size, q__)},
+ //device_data_{cl::sycl::malloc_device<typename native_data_type<T>::type>(size, q__)},
queue_{&q__}
{
if(!device_data_){
@@ -89,12 +93,15 @@ public:
template<typename Encoding, typename Storage>
data(const data<Schema, Encoding, Storage>& from, cl::sycl::queue& q__):
total_length_{from.size()},
- device_data_{cl::sycl::malloc_device<typename native_data_type<T>::type>(from.size(), q__)},
+ device_data_{cl::sycl::malloc_device<data<T,encode::Native,storage::Default>>(from.size(), q__)},
+ //device_data_{cl::sycl::malloc_device<typename native_data_type<T>::type>(from.size(), q__)},
queue_{&q__}
{
if(!device_data_){
total_length_ = 0u;
+ return;
}
+ queue_->template copy<data<T,encode::Native,storage::Default>>(&from.at(0), device_data_, total_length_);
}
data(const data<Schema, encode::Native, rmt::Sycl>& from):
@@ -105,11 +112,23 @@ public:
if(total_length_ == 0u || !queue_){
return;
}
- device_data_ = cl::sycl::malloc_device<typename native_data_type<T>::type>(from.size(), *queue_);
+ device_data_ = cl::sycl::malloc_device<data<T,encode::Native,storage::Default>>(from.size(), *queue_);
+ // device_data_ = cl::sycl::malloc_device<typename native_data_type<T>::type>(from.size(), *queue_);
if(!device_data_){
total_length_ = 0u;
}
}
+
+ data<Schema, encode::Native, rmt::Sycl>& operator=(const data<Schema, encode::Native, rmt::Sycl>& rhs) {
+ total_length_ = rhs.total_length_;
+ device_data_ = cl::sycl::malloc_device<data<T,encode::Native,storage::Default>>(rhs.size(), *rhs.queue_);
+ // device_data_ = cl::sycl::malloc_device<typename native_data_type<T>::type>(rhs.size(), *rhs.queue_);
+ if(!device_data_){
+ total_length_ = 0u;
+ }
+ queue_ = rhs.queue_;
+ return *this;
+ }
data(data<Schema, encode::Native, rmt::Sycl>&& rhs):
total_length_{rhs.total_length_},
@@ -139,8 +158,8 @@ public:
}
}
- // data<T,encode::Native,rmt::Sycl>& at(uint64_t i){
- typename native_data_type<T>::type& at(uint64_t i){
+ data<T, encode::Native, saw::storage::Default>& at(uint64_t i){
+ //typename native_data_type<T>::type& at(uint64_t i){
return device_data_[i];
}
@@ -160,6 +179,7 @@ struct rpc_id_map_helper<schema::Interface<Members...>, Encoding, Storage> {
std::tuple<id_map<typename Members::ValueType::ResponseT, Encoding, Storage>...> maps;
};
}
+
/**
* Rpc Client class for the Sycl backend.
*/
@@ -171,6 +191,10 @@ private:
* Server this client is tied to
*/
rpc_server<Iface, Encoding, rmt::Sycl>* srv_;
+
+ /**
+ * Generated some sort of id for the request.
+ */
public:
rpc_client(rpc_server<Iface, Encoding, rmt::Sycl>& srv):
srv_{&srv}
@@ -184,9 +208,9 @@ public:
id<
typename schema_member_type<Name, Iface>::type::ResponseT
>
- > call(data_or_id<typename schema_member_type<Name, Iface>::type::RequestT, Encoding, Storage> input){
- (void) input;
- return make_error<err::not_implemented>("RpcClient side is not implemented");
+ > call(const data_or_id<typename schema_member_type<Name, Iface>::type::RequestT, Encoding, Storage>& input){
+ auto next_free_id = srv_->template next_free_id<typename schema_member_type<Name, Iface>::type::ResponseT>();
+ return srv_->template call<Name, Storage>(input, next_free_id);
}
};
@@ -215,6 +239,14 @@ private:
*/
impl::rpc_id_map_helper<Iface, Encoding, rmt::Sycl> storage_;
public:
+ /**
+ * Ask which id the server prefers as the next one. Only available for fast requests on no roundtrip setups.
+ */
+ template<typename T>
+ id<T> next_free_id() const {
+ return std::get<id_map<T,Encoding,rmt::Sycl>>(storage_.maps).next_free_id();
+ }
+
rpc_server(interface<Iface, Encoding, rmt::Sycl, InterfaceCtxT> cl_iface):
cmd_queue_{},
cl_interface_{std::move(cl_iface)},
@@ -222,9 +254,9 @@ public:
{}
template<typename IdT, typename Storage>
- remote_data<IdT, Encoding, Storage, rmt::Sycl> request_data(id<IdT> dat){
+ remote_data<IdT, Encoding, Storage, rmt::Sycl> request_data(id<IdT> dat_id){
/// @TODO Fix so I can receive data
- return {dat, std::get<id_map<IdT, Encoding,rmt::Sycl>>(storage_.maps)};
+ return {dat_id, std::get<id_map<IdT, Encoding,rmt::Sycl>>(storage_.maps)};
}
/**
@@ -235,7 +267,7 @@ public:
id<
typename schema_member_type<Name, Iface>::type::ResponseT
>
- > call(data_or_id<typename schema_member_type<Name, Iface>::type::RequestT, Encoding, ClientAllocation> input){
+ > call(data_or_id<typename schema_member_type<Name, Iface>::type::RequestT, Encoding, ClientAllocation> input, id<typename schema_member_type<Name,Iface>::type::ResponseT> rpc_id){
using FuncT = typename schema_member_type<Name, Iface>::type;
/**
@@ -258,6 +290,7 @@ public:
} else {
auto& client_data = input.get_data();
dev_tmp_inp = heap<data<typename FuncT::RequestT, Encoding, rmt::Sycl>>(client_data, cmd_queue_);
+ cmd_queue_.wait();
return dev_tmp_inp.get();
}
}();
@@ -272,16 +305,16 @@ public:
return std::move(eod.get_error());
}
+ auto& val = eod.get_value();
/**
* Store returned data in rpc storage
*/
- auto& val = eod.get_value();
auto& inner_map = std::get<id_map<typename schema_member_type<Name, Iface>::type::RequestT, Encoding,rmt::Sycl>> (storage_.maps);
- auto eoid = inner_map.insert(std::move(val));
+ auto eoid = inner_map.insert_as(std::move(val), rpc_id);
if(eoid.is_error()){
return std::move(eoid.get_error());
}
- return eoid.get_value();
+ return rpc_id;
}
};
diff --git a/modules/remote-sycl/examples/sycl_basic.cpp b/modules/remote-sycl/examples/sycl_basic.cpp
index 677fd29..2e9a4f8 100644
--- a/modules/remote-sycl/examples/sycl_basic.cpp
+++ b/modules/remote-sycl/examples/sycl_basic.cpp
@@ -14,25 +14,25 @@ int main(){
}).detach();
wait.poll();
-
if(!rmt_addr){
return -1;
}
auto rpc_server = listen_basic_sycl(remote_ctx, *rmt_addr);
+ saw::rpc_client<schema::BasicInterface, saw::encode::Native, saw::storage::Default, saw::rmt::Sycl> client{rpc_server};
- saw::id<schema::Array<schema::UInt64>> next_id{0u};
+ saw::id<schema::Array<schema::UInt64>> id_zero{0u};
{
- auto eov = rpc_server.template call<"increment", saw::storage::Default>(saw::data<schema::Array<schema::UInt64>, saw::encode::Native>{1u});
+ auto eov = client.template call<"increment">(saw::data<schema::Array<schema::UInt64>, saw::encode::Native>{1u});
if(eov.is_error()){
auto& err = eov.get_error();
std::cerr<<"Error: "<<err.get_category()<<" : "<<err.get_message()<<std::endl;
return -2;
}
- next_id = eov.get_value();
+ id_zero = eov.get_value();
}
{
- auto eov = rpc_server.template call<"increment", saw::storage::Default>(next_id);
+ auto eov = client.template call<"increment">(id_zero);
if(eov.is_error()){
auto& err = eov.get_error();
std::cerr<<"Error: "<<err.get_category()<<" : "<<err.get_message()<<std::endl;
@@ -41,6 +41,9 @@ int main(){
auto& val = eov.get_value();
std::cout<<"Value: "<<val.get_value()<<std::endl;
}
+ {
+ // auto eo_rd = rpc_server.request_data(id_one);
+ }
return 0;
}
diff --git a/modules/remote-sycl/examples/sycl_basic_kernel.cpp b/modules/remote-sycl/examples/sycl_basic_kernel.cpp
index 94583b9..888f905 100644
--- a/modules/remote-sycl/examples/sycl_basic_kernel.cpp
+++ b/modules/remote-sycl/examples/sycl_basic_kernel.cpp
@@ -7,7 +7,7 @@ saw::rpc_server<schema::BasicInterface, saw::encode::Native, saw::rmt::Sycl> lis
q->submit([&](cl::sycl::handler& h){
h.parallel_for(cl::sycl::range<1>(1u), [&] (cl::sycl::id<1> it){
- in.at(0u) += 1u;
+ in.at(0u).set(in.at(0u).get() + 1u);
});
});
q->wait();