summaryrefslogtreecommitdiff
path: root/modules/remote-sycl
diff options
context:
space:
mode:
authorClaudius "keldu" Holeksa <mail@keldu.de>2024-05-29 21:16:06 +0200
committerClaudius "keldu" Holeksa <mail@keldu.de>2024-05-29 21:16:06 +0200
commit60d0f8da2b754d1deb0dbb59f6e6783ba4c692c4 (patch)
treec0d49736f035220640ed01cfb210d37e7bb254cb /modules/remote-sycl
parent7b6e0ca99f8521e034452f0d0243a7f3e33843a9 (diff)
Reworked id_map and trying to fix sycl launch
Diffstat (limited to 'modules/remote-sycl')
-rw-r--r--modules/remote-sycl/c++/remote.hpp51
1 files changed, 40 insertions, 11 deletions
diff --git a/modules/remote-sycl/c++/remote.hpp b/modules/remote-sycl/c++/remote.hpp
index d4b114a..003dd0e 100644
--- a/modules/remote-sycl/c++/remote.hpp
+++ b/modules/remote-sycl/c++/remote.hpp
@@ -2,7 +2,7 @@
#include <forstio/io_codec/rpc.hpp>
#include <forstio/codec/data.hpp>
-#include <forstio/id_map.hpp>
+#include <forstio/codec/id_map.hpp>
#include <CL/sycl.hpp>
@@ -21,12 +21,12 @@ template<typename T, typename Encoding>
class remote_data<T, Encoding, rmt::Sycl> {
private:
id<T> id_;
- id_map<T>* map_;
+ id_map<T,Encoding>* map_;
public:
/**
* Main constructor
*/
- remote_data(const id<T>& id, id_map<data<T, Encoding>>& map):
+ remote_data(const id<T>& id, id_map<T, Encoding>& map):
id_{id},
map_{&map}
{}
@@ -55,7 +55,7 @@ struct rpc_id_map_helper {
template<typename... Members, typename Encoding>
struct rpc_id_map_helper<schema::Interface<Members...>, Encoding> {
- std::tuple<id_map<data<typename Members::ValueType::ResponseT, Encoding>>...> maps;
+ std::tuple<id_map<typename Members::ValueType::ResponseT, Encoding>...> maps;
};
}
/**
@@ -78,7 +78,7 @@ private:
interface<Iface, Encoding, InterfaceCtxT> cl_interface_;
/**
- *
+ * Basic storage for response data.
*/
impl::rpc_id_map_helper<Iface, Encoding> storage_;
public:
@@ -90,11 +90,11 @@ public:
template<typename IdT>
remote_data<IdT, Encoding, rmt::Sycl> request_data(id<IdT> dat){
- return {dat, std::get<id_map<data<IdT, Encoding>>>(storage_.maps)};
+ return {dat, std::get<id_map<IdT, Encoding>>(storage_.maps)};
}
/**
- * rpc call
+ * Rpc call
*/
template<string_literal Name>
error_or<
@@ -103,15 +103,44 @@ public:
>
> call(data_or_id<typename schema_member_type<Name, Iface>::type::RequestT, Encoding> input){
- auto eod = cl_interface_.template call<Name>(std::move(input), &cmd_queue_);
+ /**
+ * First check if it's data or an id.
+ * If it's an id, check if it's registered within the storage and retrieve it.
+ */
+ auto eoinp = [&,this]() -> error_or<data<typename schema_member_type<Name, Iface>::type::RequestT, Encoding>* > {
+ if(input.is_id()){
+ // storage_.maps
+ auto& inner_map = std::get<id_map<typename schema_member_type<Name, Iface>::type::RequestT, Encoding >> (storage_.maps);
+ auto eov = inner_map.find(input.get_id());
+ if(eov.is_error()){
+ return std::move(eov.get_error());
+ }
+ return eov.get_value();
+ }else {
+ return &input.get_data();
+ }
+ }();
+ if(eoinp.is_error()){
+ return std::move(eoinp.get_error());
+ }
+ auto& inp = *(eoinp.get_value());
+
+ auto eod = cl_interface_.template call<Name>(std::move(inp), &cmd_queue_);
if(eod.is_error()){
return std::move(eod.get_error());
}
- // using ResponseTMap = id_map<data<>>
-
- return id<typename schema_member_type<Name, Iface>::type::ResponseT>{};
+ /**
+ * Store returned data in rpc storage
+ */
+ auto& val = eod.get_value();
+ auto& inner_map = std::get<id_map<typename schema_member_type<Name, Iface>::type::RequestT, Encoding >> (storage_.maps);
+ auto eoid = inner_map.insert(std::move(val));
+ if(eoid.is_error()){
+ return std::move(eoid.get_error());
+ }
+ return eoid.get_value();
}
};