#pragma once

// NOTE(review): in this copy every `#include` target and most `<...>`
// template argument lists appear to be missing (e.g. `template class
// remote_data`, `conveyor> on_receive()`, `std::get>(...)`). The code below
// is kept token-for-token as received; restore the angle-bracketed text from
// version control before building.
#include
#include
#include
#include

/**
 * Sycl backend pieces for saw: the `rmt::Sycl` marker, natively encoded
 * `data` specialisations, and the `rpc_client`, `rpc_server`,
 * `remote_address` and `remote` types for this backend.
 */
namespace saw {
namespace rmt {
/**
 * Empty marker type for the Sycl backend.
 */
struct Sycl {};
} // namespace rmt

template<> class remote;

/**
 * Remote data class for the Sycl backend.
 *
 * Non-owning handle: it stores an id plus a raw pointer to the id_map it was
 * issued from, so that map must outlive this object.
 */
template class remote_data {
private:
	id id_;
	id_map* map_;
public:
	/**
	 * Main constructor
	 *
	 * Copies `id` and keeps the address of `map` (the map itself is not
	 * copied).
	 */
	remote_data(const id& id, id_map& map):
		id_{id},
		map_{&map}
	{}

	/**
	 * Request data asynchronously
	 *
	 * NOTE(review): declared only - no definition is visible in this file
	 * (see the "Stopped here" marker below).
	 */
	conveyor> on_receive();
	/// Stopped here
};

/**
 * Natively encoded data for a primitive schema.
 * Holds a single value of NativeType by value.
 */
template class data, encode::Native> {
public:
	using Schema = schema::Primitive;
	using NativeType = typename native_data_type::type;
private:
	/**
	 * The stored value.
	 */
	NativeType val_;
public:
	/**
	 * Constructs from a native value. The constructor is not `explicit`, so
	 * a NativeType converts implicitly to this type.
	 */
	data(NativeType val__):
		val_{val__}
	{}

	/**
	 * Returns a copy of the stored value.
	 */
	NativeType get(){
		return val_;
	}
};

/**
 * Natively encoded data for a 1D array schema. Its element storage is a
 * Sycl USM device allocation: it is obtained with cl::sycl::malloc_device on
 * the given queue and released with cl::sycl::free in the destructor.
 *
 * If the allocation fails, device_data_ is null and total_length_ is 0.
 *
 * NOTE(review): this class owns a raw device pointer and declares a
 * destructor, but it declares no copy or move operations. The implicit copy
 * constructor is still generated and copies the pointer. A "move" also falls
 * back to that copy, because the user-declared destructor suppresses the
 * implicit move constructor. Either operation leaves two objects that both
 * call cl::sycl::free on the same pointer. Consider deleting the copy
 * operations and adding pointer-stealing move operations.
 */
template class data, encode::Native> {
public:
	using Schema = schema::Array;
private:
	/// Number of elements allocated (reset to 0 if allocation failed).
	uint64_t total_length_;
	/// Device (USM) allocation; null if allocation failed.
	typename native_data_type::type* device_data_;
	// data* device_data_;
	/// Queue used for allocation and for free(); not owned by this object.
	cl::sycl::queue* queue_;

	static_assert(is_primitive::value, "Only supports primitives for now");
	static_assert(D==1u, "For now we only support 1D Arrays");
public:
	/**
	 * Allocates `size` elements of device memory on queue `q__`.
	 * Only the address of `q__` is kept (it is used again in the destructor),
	 * so the queue must outlive this object.
	 */
	data(uint64_t size, cl::sycl::queue& q__):
		total_length_{size},
		device_data_{cl::sycl::malloc_device::type>(size, q__)},
		queue_{&q__}
	{
		if(!device_data_){
			total_length_ = 0u;
		}
	}

	/**
	 * Allocates device memory sized like `from` on queue `q__`.
	 *
	 * NOTE(review): only `from.size()` is used. No element data is copied
	 * into the new allocation here - confirm whether that is intended.
	 */
	template data(const data& from, cl::sycl::queue& q__):
		total_length_{from.size()},
		device_data_{cl::sycl::malloc_device::type>(from.size(), q__)},
		queue_{&q__}
	{
		if(!device_data_){
			total_length_ = 0u;
		}
	}

	/**
	 * Releases the device allocation, if there is one, through the stored
	 * queue.
	 */
	~data(){
		// free data
		if(device_data_){
			/// SYCL FREE
			cl::sycl::free(device_data_, *queue_);
		}
	}

	/**
	 * Returns a reference to element `i`. There is no bounds check against
	 * total_length_.
	 *
	 * NOTE(review): device_data_ comes from malloc_device. Under the SYCL USM
	 * rules, device allocations are not accessible from host code, so calling
	 * this on the host is presumably invalid. The declared return type (a
	 * `data` reference) also differs from the element type of device_data_ -
	 * confirm against the original template arguments.
	 */
	data>& at(uint64_t i){
		return device_data_[i];
	}
};

namespace impl {
/**
 * Storage helper for rpc_server. The primary template rejects every schema
 * that is not an Interface.
 */
template struct rpc_id_map_helper {
	static_assert(always_false, "Only support Interface schema types.");
};

/**
 * Specialisation for Interface schemas. `maps` is a std::tuple of id_map
 * objects that rpc_server selects with std::get by type. Presumably this is
 * one map per request/response type of the interface - confirm against the
 * stripped template arguments.
 */
template struct rpc_id_map_helper, Encoding> {
	std::tuple...> maps;
};
} // namespace impl

/**
 * Rpc Client class for the Sycl backend.
 *
 * Keeps a non-owning pointer to the rpc_server it was constructed with.
 */
template class rpc_client {
public:
private:
	/**
	 * Server this client is tied to
	 */
	rpc_server* srv_;
public:
	rpc_client(rpc_server& srv):
		srv_{&srv}
	{}

	/**
	 * Rpc call
	 *
	 * Not implemented yet: `input` is ignored and an error is always
	 * returned.
	 */
	template
	error_or<
		id<
			typename schema_member_type::type::ResponseT
		>
	> call(data_or_id::type::RequestT, ClientEncoding> input){
		return make_error("RpcClient side is not implemented");
	}
};

/**
 * Rpc Server class for the Sycl backend.
 *
 * Owns a default-constructed cl::sycl::queue, the interface implementation,
 * and id_map storage that holds the results of calls.
 */
template class rpc_server {
public:
	using InterfaceCtxT = cl::sycl::queue*;
	using InterfaceT = interface;
private:
	/**
	 * Command queue for the sycl backend
	 */
	cl::sycl::queue cmd_queue_;

	/**
	 * The interface including the relevant context class.
	 */
	interface cl_interface_;

	/**
	 * Basic storage for response data.
	 */
	impl::rpc_id_map_helper storage_;
public:
	/**
	 * Moves the interface implementation in. The queue is default-constructed
	 * and the storage starts empty.
	 */
	rpc_server(interface cl_iface):
		cmd_queue_{},
		cl_interface_{std::move(cl_iface)},
		storage_{}
	{}

	/**
	 * Returns a remote_data handle for id `dat`, bound to the matching id_map
	 * inside storage_. The handle points into this server's storage_, so the
	 * server must outlive the handle.
	 */
	template remote_data request_data(id dat){
		return {dat, std::get>(storage_.maps)};
	}

	/**
	 * Rpc call
	 *
	 * Steps:
	 * 1. Resolves `input` to request data: the inline data, or for an id,
	 *    the entry looked up in storage_.
	 * 2. Invokes the interface method with &cmd_queue_ as context.
	 * 3. Inserts the result into storage_ and returns the id from that
	 *    insert.
	 *
	 * Errors from the lookup, the call, or the insert are propagated to the
	 * caller.
	 */
	template
	error_or<
		id<
			typename schema_member_type::type::ResponseT
		>
	> call(data_or_id::type::RequestT, ClientAllocation> input){
		/**
		 * First check if it's data or an id.
		 * If it's an id, check if it's registered within the storage and retrieve it.
		 */
		auto eoinp = [&,this]() -> error_or::type::RequestT, Encoding>* > {
			if(input.is_id()){
				// storage_.maps
				auto& inner_map = std::get::type::RequestT, Encoding >> (storage_.maps);
				auto eov = inner_map.find(input.get_id());
				if(eov.is_error()){
					return std::move(eov.get_error());
				}
				return eov.get_value();
			} else {
				return &input.get_data();
			}
		}();
		if(eoinp.is_error()){
			return std::move(eoinp.get_error());
		}
		// Refers either to an entry in storage_ or to the by-value parameter
		// `input`; both stay valid for the rest of this call.
		// NOTE(review): in the id case, std::move(inp) below may leave the
		// stored request entry moved-from (depending on how the interface call
		// takes its argument) - confirm that is acceptable.
		auto& inp = *(eoinp.get_value());

		auto eod = cl_interface_.template call(std::move(inp), &cmd_queue_);
		if(eod.is_error()){
			return std::move(eod.get_error());
		}

		/**
		 * Store returned data in rpc storage
		 */
		auto& val = eod.get_value();
		// NOTE(review): this selects the id_map for ...::RequestT, but it stores
		// the call's result, and the function returns an id of ...::ResponseT.
		// It looks copied from the lookup above - confirm whether the ResponseT
		// map was intended.
		auto& inner_map = std::get::type::RequestT, Encoding >> (storage_.maps);
		auto eoid = inner_map.insert(std::move(val));
		if(eoid.is_error()){
			return std::move(eoid.get_error());
		}
		return eoid.get_value();
	}
};

/**
 * Address of a Sycl remote. Holds a non-owning pointer to the remote it was
 * created from. SAW_FORBID_COPY / SAW_FORBID_MOVE presumably delete the copy
 * and move operations.
 */
template<> struct remote_address {
private:
	remote* ctx_;

	SAW_FORBID_COPY(remote_address);
	SAW_FORBID_MOVE(remote_address);
public:
	remote_address(remote& r_ctx):
		ctx_{&r_ctx}
	{}
};

/**
 * Sycl remote: resolves addresses and starts RPC servers.
 */
template<> class remote {
private:
	SAW_FORBID_COPY(remote);
	SAW_FORBID_MOVE(remote);
public:
	/**
	 * Default constructor
	 */
	remote(){}

	/**
	 * For now we don't need to specify the location since
	 * we just create a default.
	 *
	 * The returned address (presumably heap-allocated via `heap`) keeps a
	 * pointer to *this, so this remote must outlive it.
	 */
	conveyor>> resolve_address(){
		return heap>(*this);
	}

	/**
	 * Spin up a rpc server
	 *
	 * The address parameter is currently unused. The server is built from
	 * `iface` alone and returned by value.
	 *
	 * NOTE(review): the local aliases RpcServerT and InterfaceT are not used
	 * by the return statement.
	 */
	template
	rpc_server listen(const remote_address&, typename rpc_server::InterfaceT iface){
		using RpcServerT = rpc_server;
		using InterfaceT = typename RpcServerT::InterfaceT;
		return {std::move(iface)};
	}
};
} // namespace saw