#pragma once
/**
 * Address / remote specializations whose device queries go through the HIP
 * runtime (see the hip* calls in remote::get_info()).
 *
 * NOTE(review): in this copy every template argument list appears to have
 * been stripped (e.g. `template<> struct remote_address`, `std::map>>`,
 * `heap>(...)`, `static_cast(...)`), and the file had been collapsed onto two
 * physical lines. The code below is unchanged token-for-token; restore the
 * missing `<...>` lists from version control before building.
 */
#include "common.hpp"
#include "device.hpp"

namespace saw {
/**
 * Address of a single device, identified only by a numeric device id.
 *
 * NOTE(review): the specialization's template argument is missing here.
 */
template<>
struct remote_address {
private:
	// Device id passed to the constructor; there is no setter, so it never changes.
	uint64_t dev_id_;

	// saw helper macros; presumably delete the copy and move operations - confirm in common.hpp.
	SAW_FORBID_COPY(remote_address);
	SAW_FORBID_MOVE(remote_address);
public:
	/**
	 * Store the given device id.
	 *
	 * NOTE(review): not explicit, so a bare uint64_t converts implicitly to an
	 * address; consider marking it explicit.
	 */
	remote_address(uint64_t id):
		dev_id_{id}
	{}

	/**
	 * @return the device id this address was constructed with.
	 */
	uint64_t get_device_id() const {
		return dev_id_;
	}
};

/**
 * Creates device addresses and spins up data / rpc servers on devices.
 * get_info() enumerates devices through the HIP runtime.
 *
 * NOTE(review): the specialization's template argument is missing here.
 */
template<>
class remote {
private:
	SAW_FORBID_COPY(remote);
	SAW_FORBID_MOVE(remote);

	/**
	 * Composite key ordered lexicographically by (device_id, sch_id, enc_id).
	 * sch_id / enc_id are presumably schema and encoding ids - confirm.
	 * Presumably the key type of reg_dat_srvs_ (whose template arguments are
	 * missing here). The comparison is equivalent to comparing
	 * std::tie(device_id, sch_id, enc_id) of both sides.
	 */
	struct key_t {
		uint64_t device_id;
		uint32_t sch_id;
		uint32_t enc_id;

		bool operator<(const key_t& rhs) const {
			if(device_id != rhs.device_id){
				return device_id < rhs.device_id;
			}
			if(sch_id != rhs.sch_id){
				return sch_id < rhs.sch_id;
			}
			if(enc_id != rhs.enc_id){
				return enc_id < rhs.enc_id;
			}
			// All three fields equal: neither key is less than the other.
			return false;
		}
	};

	// Device handles keyed by device id, filled lazily by data_listen().
	std::map>> devs_;
	// NOTE(review): not read or written anywhere in this file.
	std::map>> reg_dat_srvs_;
public:
	/**
	 * Default constructor. Both maps start empty.
	 */
	remote(){}

	/**
	 * Create an address for device `dev_id` (default 0).
	 *
	 * For now we don't need to specify the location since we just create a
	 * default. The id is not checked against the devices the runtime reports;
	 * the address is created with heap<...>() and handed back in a conveyor.
	 */
	conveyor>> resolve_address(uint64_t dev_id = 0u){
		return heap>(dev_id);
	}

	/**
	 * Build a human-readable summary (name and total global memory in GiB)
	 * of every device reported by hipGetDeviceCount().
	 * Marked for removal.
	 *
	 * NOTE(review):
	 *  - dev_count is left uninitialized and hipGetDeviceCount's return value
	 *    is ignored, so if the call fails the loop bound is indeterminate.
	 *    Initialize it to 0 and check the result against hipSuccess.
	 *  - hipSetDevice(i) changes the calling thread's current device as a
	 *    side effect of this const getter (the last enumerated device stays
	 *    selected). hipGetDeviceProperties already takes the device index,
	 *    so the call looks unnecessary.
	 */
	std::string get_info() const {
		std::stringstream sstr;

		int dev_count;
		hipGetDeviceCount(&dev_count);

		for(int i = 0; i < dev_count; ++i){
			hipSetDevice(i);
			hipDeviceProp_t props{};
			hipGetDeviceProperties(&props, i);
			sstr << "Name: " << props.name << '\n';
			// Bytes / 1024^3 -> GiB. NOTE(review): the static_cast target type is missing in this copy.
			sstr << "totalGlobalMem: "<< static_cast(props.totalGlobalMem / (1024.0 * 1024.0 * 1024.0) ) << "GiB\n";
			// Disabled extra fields. NOTE(review): if re-enabled, these print raw
			// hipDeviceProp_t values (memory sizes in bytes, clockRate in kHz),
			// so the " GiB" / " KiB" / " Mhz" labels would not match.
			/*
			sstr << "totalGlobalMem: " << props.totalGlobalMem << " GiB\n";
			sstr << "sharedMemPerBlock: " << props.sharedMemPerBlock << " KiB\n";
			sstr << "regsPerBlock: " << props.regsPerBlock << '\n';
			sstr << "warpSize: " << props.warpSize << '\n';
			sstr << "maxThreadsPerBlock: " << props.maxThreadsPerBlock << '\n';
			sstr << "maxThreadsDim: " << "(" << props.maxThreadsDim[0] << ", " << props.maxThreadsDim[1] << ", " << props.maxThreadsDim[2] << ")\n";
			sstr << "maxGridSize: " << "(" << props.maxGridSize[0] << ", " << props.maxGridSize[1] << ", " << props.maxGridSize[2] << ")\n";
			sstr << "clockRate: " << props.clockRate << " Mhz\n";
			*/
		}

		return sstr.str();
	}

	/**
	 * Parse address, but don't resolve it.
	 *
	 * Currently builds the same heap<...>(dev_id) object as resolve_address(),
	 * but returns it in an error_or instead of a conveyor. dev_id is not
	 * validated, and no error value is produced on any path.
	 */
	error_or>> parse_address(uint64_t dev_id = 0u){
		return heap>(dev_id);
	}

	/**
	 * Spin up a data server for the device addressed by `dev_addr`.
	 *
	 * The device handle for dev_addr.get_device_id() is created with
	 * share<...>() on first use and cached in devs_; later calls with the
	 * same id reuse the cached handle. The returned object is created with
	 * heap<...>() from that shared handle. No error value is produced on
	 * any path.
	 *
	 * NOTE(review): the template parameter list and the type arguments are
	 * missing in this copy.
	 * NOTE(review): the local `dev` is unused.
	 * NOTE(review): if share<...>() can throw, the nullptr placeholder
	 * inserted by emplace stays in devs_, and later calls for the same id
	 * would reuse a null handle. devs_.try_emplace would also avoid building
	 * the std::pair.
	 */
	template
	error_or>> data_listen(const remote_address& dev_addr){
		our> dev = nullptr;

		// Insert a null placeholder; ins.second is true only if the id was not cached yet.
		auto ins = devs_.emplace(std::make_pair(dev_addr.get_device_id(), our>{nullptr}));
		if(ins.second){
			ins.first->second = share>();
		}

		return heap>(ins.first->second);
	}

	/**
	 * Spin up an rpc server constructed from a freshly shared object
	 * (share<...>()) and the moved-in interface `iface`.
	 *
	 * NOTE(review): the parameter `dev` is unused and devs_ is not
	 * consulted, so each call creates a new shared object instead of reusing
	 * the handle cached by data_listen() - confirm this is intended.
	 * NOTE(review): the template parameter list and the type arguments are
	 * missing in this copy.
	 */
	template
	rpc_server listen(remote_address& dev, typename rpc_server::InterfaceT iface){
		//using RpcServerT = rpc_server;
		//using InterfaceT = typename RpcServerT::InterfaceT;
		return {share>(), std::move(iface)};
	}
};
} // namespace saw