// path: root/modules/remote-hip/c++/remote.hpp
// blob: 794d6298a9ed1fa9a6379c7395351f2c7733e6eb
#pragma once

#include "common.hpp"
#include "device.hpp"

namespace saw {

template<>
struct remote_address<rmt::Hip> {
public:
	/**
	 * Builds an address that refers to the HIP device identified by dev_id.
	 * The id is only stored here; it is not validated against the devices
	 * that are actually present.
	 */
	remote_address(uint64_t dev_id):
		device_id_{dev_id}
	{}

	/**
	 * @return the device id this address was constructed with.
	 */
	uint64_t get_device_id() const {
		return device_id_;
	}

private:
	// Addresses are pinned: neither copyable nor movable.
	SAW_FORBID_COPY(remote_address);
	SAW_FORBID_MOVE(remote_address);

	/// Device id supplied at construction.
	uint64_t device_id_;
};

template<>
class remote<rmt::Hip> {
private:
	SAW_FORBID_COPY(remote);
	SAW_FORBID_MOVE(remote);

	struct key_t {
		uint64_t device_id;
		uint32_t sch_id;
		uint32_t enc_id;
		
		bool operator<(const key_t& rhs) const {
			if(device_id != rhs.device_id){
				return device_id < rhs.device_id;
			}
			if(sch_id != rhs.sch_id){
				return sch_id < rhs.sch_id;
			}
			if(enc_id != rhs.enc_id){
				return enc_id < rhs.enc_id;
			}
			return false;
		}
	};

	std::map<uint64_t, our<device<rmt::Hip>>> devs_;
	std::map<key_t, ptr<i_data_server<rmt::Hip>>> reg_dat_srvs_;
public:
	/**
	 * Default constructor
	 */
	remote(){}

	/**
	 * For now we don't need to specify the location since
	 * we just create a default.
	 */
	conveyor<own<remote_address<rmt::Hip>>> resolve_address(uint64_t dev_id = 0u){
		return heap<remote_address<rmt::Hip>>(dev_id);
	}

	/**
	 * Info.
	 * Will be removed. 
	 */
	std::string get_info() const {
		std::stringstream sstr;
		int dev_count;
		hipGetDeviceCount(&dev_count);
		for(int i = 0; i < dev_count; ++i){
			hipSetDevice(i);

			hipDeviceProp_t props{};
			hipGetDeviceProperties(&props, i);

			sstr << "Name: " << props.name << '\n';
			sstr << "totalGlobalMem: "<< static_cast<double>(props.totalGlobalMem / (1024.0 * 1024.0 * 1024.0) ) << " GiB\n";
			sstr << "sharedMemPerBlock: "<< static_cast<double>(props.sharedMemPerBlock / (1024.0) ) << " KiB\n";
			sstr << "clockRate: " << static_cast<double>(props.clockRate / 1000.0) << " Mhz\n";
			/*
			sstr					<< "totalGlobalMem: " << props.totalGlobalMem			<< " GiB\n";
			sstr << "sharedMemPerBlock: " << props.sharedMemPerBlock								<< " KiB\n";
			sstr << "regsPerBlock: " << props.regsPerBlock << '\n';
			sstr << "warpSize: " << props.warpSize << '\n';
			sstr << "maxThreadsPerBlock: " << props.maxThreadsPerBlock << '\n';
			sstr << "maxThreadsDim: "								<< "(" << props.maxThreadsDim[0] << ", " << props.maxThreadsDim[1] << ", "								<< props.maxThreadsDim[2] << ")\n";
			sstr << "maxGridSize: "								<< "(" << props.maxGridSize[0] << ", " << props.maxGridSize[1] << ", "								<< props.maxGridSize[2] << ")\n";
			sstr << "clockRate: " << props.clockRate << " Mhz\n";
			*/
		}

		return sstr.str();
	}

	/**
	 * Parse address, but don't resolve it.
	 */
	error_or<own<remote_address<rmt::Hip>>> parse_address(uint64_t dev_id = 0u){
		return heap<remote_address<rmt::Hip>>(dev_id);
	}

	/**
	 * Spin up data server
	 */
	template<typename Schema, typename Encoding>
	error_or<own<data_server<Schema, Encoding, rmt::Hip>>> data_listen(const remote_address<rmt::Hip>& dev_addr){
		our<device<rmt::Hip>> dev = nullptr;
		auto ins = devs_.emplace(std::make_pair(dev_addr.get_device_id(), our<device<rmt::Hip>>{nullptr}));
		if(ins.second){
			ins.first->second = share<device<rmt::Hip>>();
		}
		return heap<data_server<Schema, Encoding, rmt::Hip>>(ins.first->second);
	}

	/**
	 * Spin up a rpc server
	 */
	template<typename Iface, typename Encoding>
	rpc_server<Iface, Encoding, rmt::Hip> listen(remote_address<rmt::Hip>& dev, typename rpc_server<Iface, Encoding, rmt::Hip>::InterfaceT iface){
		//using RpcServerT = rpc_server<Iface, Encoding, rmt::Hip>;
		//using InterfaceT = typename RpcServerT::InterfaceT;
		return {share<device<rmt::Hip>>(), std::move(iface)};
	}
};

}