1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
|
#pragma once
#include <cstdint>
#include <map>
#include <sstream>
#include <string>
#include <tuple>
#include <utility>

#include "common.hpp"
#include "device.hpp"
namespace saw {
template<>
struct remote_address<rmt::Hip> {
public:
	/**
	 * Build an address that refers to the HIP device identified by `id`.
	 * Left implicit on purpose so a plain device id can be passed where a
	 * `const remote_address<rmt::Hip>&` is expected.
	 */
	remote_address(uint64_t id):
		dev_id_(id)
	{}

	/**
	 * @return the device id this address was constructed with.
	 */
	uint64_t get_device_id() const {
		const uint64_t id = dev_id_;
		return id;
	}
private:
	/// Device id supplied at construction; never modified afterwards.
	uint64_t dev_id_;

	// Addresses are neither copyable nor movable; they are handed around via own<>.
	SAW_FORBID_COPY(remote_address);
	SAW_FORBID_MOVE(remote_address);
};
template<>
class remote<rmt::Hip> {
private:
SAW_FORBID_COPY(remote);
SAW_FORBID_MOVE(remote);
struct key_t {
uint64_t device_id;
uint32_t sch_id;
uint32_t enc_id;
bool operator<(const key_t& rhs) const {
if(device_id != rhs.device_id){
return device_id < rhs.device_id;
}
if(sch_id != rhs.sch_id){
return sch_id < rhs.sch_id;
}
if(enc_id != rhs.enc_id){
return enc_id < rhs.enc_id;
}
return false;
}
};
std::map<uint64_t, our<device<rmt::Hip>>> devs_;
std::map<key_t, ptr<i_data_server<rmt::Hip>>> reg_dat_srvs_;
public:
/**
* Default constructor
*/
remote(){}
/**
* For now we don't need to specify the location since
* we just create a default.
*/
conveyor<own<remote_address<rmt::Hip>>> resolve_address(uint64_t dev_id = 0u){
return heap<remote_address<rmt::Hip>>(dev_id);
}
/**
* Info.
* Will be removed.
*/
std::string get_info() const {
std::stringstream sstr;
int dev_count;
hipGetDeviceCount(&dev_count);
for(int i = 0; i < dev_count; ++i){
hipSetDevice(i);
hipDeviceProp_t props{};
hipGetDeviceProperties(&props, i);
sstr << "Name: " << props.name << '\n';
/*
sstr << "totalGlobalMem: " << props.totalGlobalMem << " GiB\n";
sstr << "sharedMemPerBlock: " << props.sharedMemPerBlock << " KiB\n";
sstr << "regsPerBlock: " << props.regsPerBlock << '\n';
sstr << "warpSize: " << props.warpSize << '\n';
sstr << "maxThreadsPerBlock: " << props.maxThreadsPerBlock << '\n';
sstr << "maxThreadsDim: " << "(" << props.maxThreadsDim[0] << ", " << props.maxThreadsDim[1] << ", " << props.maxThreadsDim[2] << ")\n";
sstr << "maxGridSize: " << "(" << props.maxGridSize[0] << ", " << props.maxGridSize[1] << ", " << props.maxGridSize[2] << ")\n";
sstr << "clockRate: " << props.clockRate << " Mhz\n";
*/
}
return sstr.str();
}
/**
* Parse address, but don't resolve it.
*/
error_or<own<remote_address<rmt::Hip>>> parse_address(uint64_t dev_id = 0u){
return heap<remote_address<rmt::Hip>>(dev_id);
}
/**
* Spin up data server
*/
template<typename Schema, typename Encoding>
error_or<own<data_server<Schema, Encoding, rmt::Hip>>> data_listen(const remote_address<rmt::Hip>& dev_addr){
our<device<rmt::Hip>> dev = nullptr;
auto ins = devs_.emplace(std::make_pair(dev_addr.get_device_id(), our<device<rmt::Hip>>{nullptr}));
if(ins.second){
ins.first->second = share<device<rmt::Hip>>();
}
return heap<data_server<Schema, Encoding, rmt::Hip>>(ins.first->second);
}
/**
* Spin up a rpc server
*/
template<typename Iface, typename Encoding>
rpc_server<Iface, Encoding, rmt::Hip> listen(remote_address<rmt::Hip>& dev, typename rpc_server<Iface, Encoding, rmt::Hip>::InterfaceT iface){
//using RpcServerT = rpc_server<Iface, Encoding, rmt::Hip>;
//using InterfaceT = typename RpcServerT::InterfaceT;
return {share<device<rmt::Hip>>(), std::move(iface)};
}
};
}
|