summaryrefslogtreecommitdiff
path: root/modules/remote-hip/c++/rpc.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'modules/remote-hip/c++/rpc.hpp')
-rw-r--r--modules/remote-hip/c++/rpc.hpp276
1 files changed, 276 insertions, 0 deletions
diff --git a/modules/remote-hip/c++/rpc.hpp b/modules/remote-hip/c++/rpc.hpp
new file mode 100644
index 0000000..f6b519b
--- /dev/null
+++ b/modules/remote-hip/c++/rpc.hpp
@@ -0,0 +1,276 @@
+#pragma once
+
+#include "common.hpp"
+#include "remote.hpp"
+#include "data.hpp"
+#include "device.hpp"
+#include "transfer.hpp"
+
+namespace saw {
+/**
+ * Remote data class for the Sycl backend.
+ */
+template<typename T, typename Encoding>
+class remote_data<T, Encoding, rmt::Hip> final {
+private:
+ /**
+ * An identifier to the data being held on the remote
+ */
+ id<T> data_id_;
+
+ /**
+ * The sycl queue object
+ */
+ cl::sycl::queue* queue_;
+public:
+ /**
+ * Main constructor
+ */
+ remote_data(id<T> data_id__, cl::sycl::queue& queue__):
+ data_id_{data_id__},
+ queue_{&queue__}
+ {}
+
+ /**
+ * Destructor specifically designed to deallocate on the device.
+ */
+ ~remote_data(){}
+
+ SAW_FORBID_COPY(remote_data);
+ SAW_FORBID_MOVE(remote_data);
+ /**
+ remote_data(const id<T>& id, id_map<T, Encoding, rmt::Hip>& map, cl::sycl::queue& queue__):
+ id_{id},
+ map_{&map}
+ {}
+ */
+
+ /**
+ * Wait for the data
+ */
+ error_or<data<T,Encoding>> wait(){
+ return make_error<err::not_implemented>();
+ }
+
+ /**
+ * Request data asynchronously
+ */
+ // conveyor<data<T,Encoding>> on_receive(); /// Stopped here
+};
+
+/**
+ * Meant to be a helper object which holds the allocated data on the sycl side
+ */
+//template<typename Schema, typename Encoding, typename Backend>
+//class device_data;
+
+/**
+ * This class helps in regards to the ownership on the server side
+template<typename Schema, typename Encoding>
+class device_data<Schema, Encoding, rmt::Hip> {
+private:
+ data<Schema,Encoding,storage::Default>* device_data_;
+ cl::sycl::queue* queue_;
+public:
+ device_data(data<Schema,Encoding,storage::Default>& device_data__, cl::sycl::queue& queue__):
+ device_data_{&device_data__},
+ queue_{&queue__}
+ {}
+
+ ~device_data(){
+ if(device_data_){
+ cl::sycl::free(device_data_,queue_);
+ device_data_ = nullptr;
+ }
+ }
+
+ SAW_FORBID_COPY(device_data);
+ SAW_FORBID_MOVE(device_data);
+};
+ */
+
+}
+// Maybe a helper impl tmpl file?
+namespace saw {
+
+
+namespace impl {
+template<typename Func>
+struct rpc_func_type_helper;
+
+template<typename Response, typename Request>
+struct rpc_func_type_helper<schema::Function<Request, Response>>{
+ using type = tmpl_group<Response, Request>;
+};
+
+template<typename Iface>
+struct rpc_iface_type_helper {
+ using type = tmpl_group<>;
+};
+
+template<typename Func, string_literal K, typename... Functions, string_literal... Keys>
+struct rpc_iface_type_helper<schema::Interface<schema::Member<Func,K>,schema::Member<Functions,Keys>...>> {
+ using inner_type = typename rpc_func_type_helper<Func>::type;
+ using type = typename tmpl_concat<inner_type, typename rpc_iface_type_helper<schema::Interface<schema::Member<Functions,Keys>...>>::type>::type;
+};
+}
+
+/**
+ * Rpc Client class for the Sycl backend.
+ */
+template<typename Iface, typename Encoding>
+class rpc_client<Iface, Encoding, rmt::Hip> {
+public:
+private:
+ /**
+ * Server this client is tied to
+ */
+ rpc_server<Iface, Encoding, rmt::Hip>* srv_;
+
+ /**
+ * TransferClient created from the internal RPC data server
+ */
+ data_client<typename impl::rpc_iface_type_helper<Iface>::type, Encoding, rmt::Hip> data_client_;
+
+ /**
+ * Generated some sort of id for the request.
+ */
+public:
+ rpc_client(rpc_server<Iface, Encoding, rmt::Hip>& srv):
+ srv_{&srv},
+ data_client_{srv_->data_server}
+ {}
+
+ /**
+ * Rpc call
+ */
+ template<string_literal Name>
+ error_or<
+ id<
+ typename schema_member_type<Name, Iface>::type::ResponseT
+ >
+ > call(const data_or_id<typename schema_member_type<Name, Iface>::type::RequestT, Encoding>& input){
+ auto next_free_id = srv_->template next_free_id<typename schema_member_type<Name, Iface>::type::ResponseT>();
+ return srv_->template call<Name>(input, next_free_id);
+ }
+
+};
+
+/**
+ * Rpc Server class for the Sycl backend.
+ */
+template<typename Iface, typename Encoding>
+class rpc_server<Iface, Encoding, rmt::Hip> {
+public:
+ using InterfaceCtxT = cl::sycl::queue*;
+ using InterfaceT = interface<Iface, encode::Sycl<Encoding>, InterfaceCtxT>;
+
+private:
+ /**
+ * Device instance enabling the use of the remote device.
+ */
+ our<device<rmt::Hip>> device_;
+
+ using DataServerT = data_server<typename impl::rpc_iface_type_helper<Iface>::type, Encoding, rmt::Hip>;
+ /**
+ * Data server storing the relevant data
+ */
+ DataServerT* data_server_;
+
+ /**
+ * The interface including the relevant context class.
+ */
+ interface<Iface, Encoding, InterfaceCtxT> cl_interface_;
+
+public:
+
+ /**
+ * Main constructor
+ */
+ rpc_server(our<device<rmt::Hip>> dev__, DataServerT& data_server__, InterfaceT cl_iface):
+ device_{std::move(dev__)},
+ data_server_{&data_server__},
+ cl_interface_{std::move(cl_iface)}
+ {}
+
+ /**
+ * Ask which id the server prefers as the next one. Only available for fast requests on no roundtrip setups.
+ */
+ /**
+ template<typename T>
+ id<T> next_free_id() const {
+ return std::get<id_map<T,Encoding,rmt::Hip>>(storage_.maps).next_free_id();
+ }
+ */
+
+ /**
+ template<typename IdT, typename Storage>
+ remote_data<IdT, Encoding, rmt::Hip> request_data(id<IdT> dat_id){
+ return {dat_id, std::get<id_map<IdT,Encoding,rmt::Hip>>(storage_.maps), device_->get_handle()};
+ }
+ */
+
+ /**
+ * Rpc call based on the name
+ */
+ template<string_literal Name>
+ error_or<
+ id<
+ typename schema_member_type<Name, Iface>::type::ResponseT
+ >
+ > call(data_or_id<typename schema_member_type<Name, Iface>::type::RequestT, Encoding> input, id<typename schema_member_type<Name,Iface>::type::ResponseT> rpc_id){
+ using FuncT = typename schema_member_type<Name, Iface>::type;
+
+ /**
+ * Object needed if and only if the provided data type is not an id
+ */
+ own<data<typename FuncT::RequestT, encode::Sycl<Encoding>>> dev_tmp_inp = nullptr;
+ /**
+ * First check if it's data or an id.
+ * If it's an id, check if it's registered within the storage and retrieve it.
+ */
+ auto eoinp = [&,this]() -> error_or<data<typename FuncT::RequestT, encode::Sycl<Encoding>>* > {
+ if(input.is_id()){
+ // storage_.maps
+ auto eov = data_server_->template find<typename FuncT::RequestT>(input.get_id());
+ if(eov.is_error()){
+ return std::move(eov.get_error());
+ }
+ return eov.get_value();
+ } else {
+ auto& client_data = input.get_data();
+
+ auto eov = device_->template copy_to_device(client_data);
+ if(eov.is_error()){
+ return std::move(eov.get_error());
+ }
+ auto& val = eov.get_value();
+
+ dev_tmp_inp = heap<data<typename FuncT::RequestT, encode::Sycl<Encoding>>>(std::move(val));
+ device_->get_handle().wait();
+ return dev_tmp_inp.get();
+ }
+ }();
+ if(eoinp.is_error()){
+ return std::move(eoinp.get_error());
+ }
+ auto& inp = *(eoinp.get_value());
+
+ auto eod = cl_interface_.template call<Name>(inp, &(device_->get_handle()));
+
+ if(eod.is_error()){
+ return std::move(eod.get_error());
+ }
+
+ auto& val = eod.get_value();
+ /**
+ * Store returned data in rpc storage
+ */
+ auto eoid = data_server_->template insert<typename schema_member_type<Name, Iface>::type::RequestT>(std::move(val), rpc_id);
+ if(eoid.is_error()){
+ return std::move(eoid.get_error());
+ }
+ return rpc_id;
+ }
+};
+}