summaryrefslogtreecommitdiff
path: root/modules/remote-hip/c++/remote.hpp
blob: b814ff27de107857f003031e745a245182f4f201 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
#pragma once

#include <cstdint>
#include <map>
#include <sstream>
#include <string>
#include <tuple>

#include "common.hpp"
#include "device.hpp"

namespace saw {

/**
 * Address of a HIP device inside the remote abstraction.
 *
 * It only stores the numeric device id it was constructed with.
 * Instances can be neither copied nor moved.
 */
template<>
struct remote_address<rmt::Hip> {
private:
	/// Device id handed in at construction; never modified afterwards.
	uint64_t dev_id_;

	SAW_FORBID_COPY(remote_address);
	SAW_FORBID_MOVE(remote_address);
public:
	/**
	 * @param id numeric id of the device this address refers to
	 */
	remote_address(uint64_t id):
		dev_id_(id)
	{}

	/**
	 * @return the device id passed to the constructor
	 */
	uint64_t get_device_id() const {
		const uint64_t id = dev_id_;
		return id;
	}
};

template<>
class remote<rmt::Hip> {
private:
	SAW_FORBID_COPY(remote);
	SAW_FORBID_MOVE(remote);

	struct key_t {
		uint64_t device_id;
		uint32_t sch_id;
		uint32_t enc_id;
		
		bool operator<(const key_t& rhs) const {
			if(device_id != rhs.device_id){
				return device_id < rhs.device_id;
			}
			if(sch_id != rhs.sch_id){
				return sch_id < rhs.sch_id;
			}
			if(enc_id != rhs.enc_id){
				return enc_id < rhs.enc_id;
			}
			return false;
		}
	};

	std::map<uint64_t, our<device<rmt::Hip>>> devs_;
	std::map<key_t, ptr<i_data_server<rmt::Hip>>> reg_dat_srvs_;
public:
	/**
	 * Default constructor
	 */
	remote(){}

	/**
	 * For now we don't need to specify the location since
	 * we just create a default.
	 */
	conveyor<own<remote_address<rmt::Hip>>> resolve_address(uint64_t dev_id = 0u){
		return heap<remote_address<rmt::Hip>>(dev_id);
	}

	/**
	 * Info.
	 * Will be removed. 
	 */
	std::string get_info() const {
		std::stringstream sstr;
		int dev_count;
		hipGetDeviceCount(&dev_count);
		for(int i = 0; i < dev_count; ++i){
			hipSetDevice(i);

			hipDeviceProp_t props{};
			hipGetDeviceProperties(&props, i);

			sstr << "Name: " << props.name << '\n';
			/*
			sstr					<< "totalGlobalMem: " << props.totalGlobalMem			<< " GiB\n";
			sstr << "sharedMemPerBlock: " << props.sharedMemPerBlock								<< " KiB\n";
			sstr << "regsPerBlock: " << props.regsPerBlock << '\n';
			sstr << "warpSize: " << props.warpSize << '\n';
			sstr << "maxThreadsPerBlock: " << props.maxThreadsPerBlock << '\n';
			sstr << "maxThreadsDim: "								<< "(" << props.maxThreadsDim[0] << ", " << props.maxThreadsDim[1] << ", "								<< props.maxThreadsDim[2] << ")\n";
			sstr << "maxGridSize: "								<< "(" << props.maxGridSize[0] << ", " << props.maxGridSize[1] << ", "								<< props.maxGridSize[2] << ")\n";
			sstr << "clockRate: " << props.clockRate << " Mhz\n";
			*/
		}

		return sstr.str();
	}

	/**
	 * Parse address, but don't resolve it.
	 */
	error_or<own<remote_address<rmt::Hip>>> parse_address(uint64_t dev_id = 0u){
		return heap<remote_address<rmt::Hip>>(dev_id);
	}

	/**
	 * Spin up data server
	 */
	template<typename Schema, typename Encoding>
	error_or<own<data_server<Schema, Encoding, rmt::Hip>>> data_listen(remote_address<rmt::Hip>& dev){
		our<device<rmt::Hip>> dev = nullptr;
		auto ins = devs_.emplace(std::make_pair(dev.get_device_id(), our<device<rmt::Hip>>{nullptr}));
		if(ins.second){
			ins.first->second = share<device<rmt::Hip>>();
		}
		return heap<data_server<Schema, Encoding, rmt::Hip>>(ins.first->second);
	}

	/**
	 * Spin up a rpc server
	 */
	template<typename Iface, typename Encoding>
	rpc_server<Iface, Encoding, rmt::Hip> listen(remote_address<rmt::Hip>& dev, typename rpc_server<Iface, Encoding, rmt::Hip>::InterfaceT iface){
		//using RpcServerT = rpc_server<Iface, Encoding, rmt::Hip>;
		//using InterfaceT = typename RpcServerT::InterfaceT;
		return {share<device<rmt::Hip>>(), std::move(iface)};
	}
};

}