#pragma once #include "common.hpp" #include "data.hpp" namespace saw { template<> class remote; template class device; template class data { private: cl::sycl::buffer> data_; public: data(data& data__): data_{&data__, 1u} {} auto& get_handle() { return data_; } template auto access(cl::sycl::handler& h){ return data_.template get_access(h); } }; /** * Remote data class for the Sycl backend. */ template class remote_data final { private: /** * An identifier to the data being held on the remote */ id data_id_; /** * The sycl queue object */ cl::sycl::queue* queue_; public: /** * Main constructor */ remote_data(id data_id__, cl::sycl::queue& queue__): data_id_{data_id__}, queue_{&queue__} {} /** * Destructor specifically designed to deallocate on the device. */ ~remote_data(){} SAW_FORBID_COPY(remote_data); SAW_FORBID_MOVE(remote_data); /** remote_data(const id& id, id_map& map, cl::sycl::queue& queue__): id_{id}, map_{&map} {} */ /** * Wait for the data */ error_or> wait(){ return make_error(); } /** * Request data asynchronously */ // conveyor> on_receive(); /// Stopped here }; /** * Meant to be a helper object which holds the allocated data on the sycl side */ template class device_data; /** * This class helps in regards to the ownership on the server side */ template class device_data { private: /** * The actual data */ data* device_data_; /** * The sycl queue object */ cl::sycl::queue* queue_; public: /** * Main constructor */ device_data(data& device_data__, cl::sycl::queue& queue__): device_data_{&device_data__}, queue_{&queue__} {} /** * Destructor specifically designed to deallocate on the device. */ ~device_data(){ if(device_data_){ cl::sycl::free(device_data_,queue_); device_data_ = nullptr; } } SAW_FORBID_COPY(device_data); SAW_FORBID_MOVE(device_data); }; namespace impl { template struct device_id_map { std::vector> data; }; template struct rpc_id_map_helper { static_assert(always_false, "Only supports Interface schema types."); }; template struct rpc_id_map_helper, Encoding, Storage> { std::tuple...> maps; }; } } // Maybe a helper impl tmpl file? namespace saw { /** * Represents a remote Sycl device. */ template<> class device final { private: cl::sycl::queue cmd_queue_; public: device() = default; SAW_FORBID_COPY(device); SAW_FORBID_MOVE(device); /** * Copy data to device */ template error_or> copy_to_device(const data& host_data){ return data::copy_to_device(host_data, *this); } /** * Copy data to host */ template error_or> copy_to_host(const data& device_data){ return device_data.copy_to_host(); } /** * Get a reference to the handle */ cl::sycl::queue& get_handle(){ return cmd_queue_; } }; /** * Rpc Client class for the Sycl backend. */ template class rpc_client { public: private: /** * Server this client is tied to */ rpc_server* srv_; /** * Generated some sort of id for the request. */ public: rpc_client(rpc_server& srv): srv_{&srv} {} /** * Rpc call */ template error_or< id< typename schema_member_type::type::ResponseT > > call(const data_or_id::type::RequestT, Encoding, Storage>& input){ auto next_free_id = srv_->template next_free_id::type::ResponseT>(); return srv_->template call(input, next_free_id); } }; /** * Rpc Server class for the Sycl backend. */ template class rpc_server { public: using InterfaceCtxT = cl::sycl::queue*; using InterfaceT = interface; private: /** * Device instance enabling the use of the remote device. */ device* device_; /** * The interface including the relevant context class. */ interface cl_interface_; /** * Basic storage for response data. */ impl::rpc_id_map_helper storage_; public: /** * Main constructor */ rpc_server(device& dev__, InterfaceT cl_iface): device_{&dev__}, cl_interface_{std::move(cl_iface)}, storage_{} {} /** * Ask which id the server prefers as the next one. Only available for fast requests on no roundtrip setups. */ /** template id next_free_id() const { return std::get>(storage_.maps).next_free_id(); } */ /** template remote_data request_data(id dat_id){ return {dat_id, std::get>(storage_.maps), device_->get_handle()}; } */ /** * Rpc call based on the name */ template error_or< id< typename schema_member_type::type::ResponseT > > call(data_or_id::type::RequestT, Encoding, storage::Default> input, id::type::ResponseT> rpc_id){ using FuncT = typename schema_member_type::type; /** * Object needed if and only if the provided data type is not an id */ own> dev_tmp_inp = nullptr; /** * First check if it's data or an id. * If it's an id, check if it's registered within the storage and retrieve it. */ auto eoinp = [&,this]() -> error_or* > { if(input.is_id()){ // storage_.maps auto& inner_map = std::get> (storage_.maps); auto eov = inner_map.find(input.get_id()); if(eov.is_error()){ return std::move(eov.get_error()); } return eov.get_value(); } else { auto& client_data = input.get_data(); auto eov = device_->template copy_to_device(client_data); if(eov.is_error()){ return std::move(eov.get_error()); } auto& val = eov.get_value(); dev_tmp_inp = heap>(std::move(val)); device_->get_handle().wait(); return dev_tmp_inp.get(); } }(); if(eoinp.is_error()){ return std::move(eoinp.get_error()); } auto& inp = *(eoinp.get_value()); auto eod = cl_interface_.template call(inp, &(device_->get_handle())); if(eod.is_error()){ return std::move(eod.get_error()); } auto& val = eod.get_value(); /** * Store returned data in rpc storage */ auto& inner_map = std::get::type::RequestT, Encoding,rmt::Sycl>> (storage_.maps); auto eoid = inner_map.insert_as(std::move(val), rpc_id); if(eoid.is_error()){ return std::move(eoid.get_error()); } return rpc_id; } }; template<> struct remote_address { private: remote* ctx_; SAW_FORBID_COPY(remote_address); SAW_FORBID_MOVE(remote_address); public: remote_address(remote& r_ctx): ctx_{&r_ctx} {} }; template<> class remote { private: SAW_FORBID_COPY(remote); SAW_FORBID_MOVE(remote); public: /** * Default constructor */ remote(){} /** * For now we don't need to specify the location since * we just create a default. */ conveyor>> resolve_address(){ return heap>(*this); } /** * Connect to a device */ device connect_device(const remote_address&){ return {}; } /** * Spin up a rpc server */ template rpc_server listen(device& dev, typename rpc_server::InterfaceT iface){ using RpcServerT = rpc_server; using InterfaceT = typename RpcServerT::InterfaceT; return {dev, std::move(iface)}; } }; }