summaryrefslogtreecommitdiff
path: root/modules/remote-sycl/benchmarks/mixed_precision.cpp
blob: b979b0c47f605c0d3bca50eb2cfe2bd6744905d2 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
#include "./mixed_precision.hpp"
#include <forstio/codec/schema.hpp>

#include <cstdint>
#include <iostream>
#include <random>
#include <sstream>
#include <utility>

/**
 * Benchmark driver for the "float64_32", "float64" and "float32" calls of the
 * interface returned by listen_mixed_precision().
 *
 * Flow:
 *  1. Resolve the SYCL remote address; exit with -1 if it is not available
 *     after a single poll of the event loop.
 *  2. Warmup: run all three calls for every power-of-two size from 1024 up to
 *     (but excluding) max_test_size, without recording timings.
 *  3. Benchmark: for every power-of-two size from 1 up to (but excluding)
 *     max_test_size, run each call, wait on the device handle and record the
 *     elapsed time taken from the matching event's profiling info.
 *
 * Output: one line per size on stdout in the form
 * "<size>,\t<float64_32>,\t<float64>,\t<float32>".
 *
 * @return 0 on success, -1 if the remote address could not be resolved.
 */
int main(){
	using namespace saw;

	constexpr uint64_t max_test_size = 1024ul * 1024ul * 256ul;

	// Random input values drawn uniformly from [-1.0, 1.0).
	std::random_device r;
	std::default_random_engine e1{r()};
	std::uniform_real_distribution<> dis{-1.0,1.0};


	saw::event_loop loop;
	saw::wait_scope wait{loop};

	remote<rmt::Sycl> rmt;

	own<remote_address<rmt::Sycl>> rmt_addr{};

	rmt.resolve_address().then([&](auto addr){
		rmt_addr = std::move(addr);
	}).detach();

	// Only a single poll is performed; if the address has not been delivered
	// by then, there is nothing to run the benchmark on.
	wait.poll();
	if(!rmt_addr){
		std::cerr<<"Failed to resolve the SYCL remote address"<<'\n';
		return -1;
	}
	
	// Events handed to listen_mixed_precision(); they are read back below via
	// time_eval. NOTE(review): presumably each interface call stores its
	// kernel event in the matching object - confirm in mixed_precision.hpp.
	cl::sycl::event mixed_ev;
	cl::sycl::event float32_ev;
	cl::sycl::event float64_ev;
	
	auto sycl_iface = listen_mixed_precision(mixed_ev, float64_ev, float32_ev);

	/**
	 * Appends (command_end - command_start) / 1e9 of the given event to the
	 * stream. SYCL reports these profiling timestamps in nanoseconds, so the
	 * written value is in seconds.
	 */
	auto time_eval = [](std::stringstream& sstr, cl::sycl::event& ev){
		auto end = ev.get_profiling_info<cl::sycl::info::event_profiling::command_end>();
		auto start = ev.get_profiling_info<cl::sycl::info::event_profiling::command_start>();

		sstr<<(end-start) / 1.0e9;
	};

	/**
	 * Writes the same random number into index i of all three host arrays,
	 * as double for the mixed and float64 arrays and narrowed to float for
	 * the float32 array, for every i in [0, size).
	 */
	auto fill_random = [&dis, &e1](auto& mixed, auto& f64, auto& f32, uint64_t size){
		for(uint64_t i = 0; i < size; ++i){
			double gen_num = dis(e1);
			mixed.at(i) = static_cast<double>(gen_num);
			f64.at(i) = static_cast<double>(gen_num);
			f32.at(i) = static_cast<float>(gen_num);
		}
	};
	
	auto& device = rmt_addr->get_device();
	
	/**
	 * Warmup
	 */
	std::cout<<"Warming up ..."<<std::endl;
	for(uint64_t test_size = 1024ul; test_size < max_test_size; test_size *= 2ul){
		// Host arrays are scoped to the iteration so the largest warmup
		// allocation is released before the benchmark phase starts.
		data<sch::MixedArray> mixed_host_data;
		data<sch::Float64Array> float64_host_data;
		data<sch::Float32Array> float32_host_data;
		mixed_host_data = {test_size};
		float64_host_data = {test_size};
		float32_host_data = {test_size};
		fill_random(mixed_host_data, float64_host_data, float32_host_data, test_size);

		data<sch::MixedArray, encode::Native, rmt::Sycl> mixed_device_data{mixed_host_data};
		data<sch::Float64Array, encode::Native, rmt::Sycl> float64_device_data{float64_host_data};
		data<sch::Float32Array, encode::Native, rmt::Sycl> float32_device_data{float32_host_data};
		
		// All three calls are submitted back to back; one wait at the end is
		// enough because no timing is taken during warmup.
		sycl_iface.template call<"float64_32">(mixed_device_data, &(device.get_handle()));
		sycl_iface.template call<"float64">(float64_device_data, &(device.get_handle()));
		sycl_iface.template call<"float32">(float32_device_data, &(device.get_handle()));
		device.get_handle().wait();
	}

	std::cout<<"Benchmark starting ..."<<std::endl;
	/**
	 * Benchmark
	 */
	// Results are buffered and printed once at the end.
	std::stringstream sstr;
	for(uint64_t test_size = 1ul; test_size < max_test_size; test_size *= 2ul){
		data<sch::MixedArray> mixed_host_data;
		data<sch::Float64Array> float64_host_data;
		data<sch::Float32Array> float32_host_data;
		mixed_host_data = {test_size};
		float64_host_data = {test_size};
		float32_host_data = {test_size};
		fill_random(mixed_host_data, float64_host_data, float32_host_data, test_size);

		data<sch::MixedArray, encode::Native, rmt::Sycl> mixed_device_data{mixed_host_data};
		data<sch::Float64Array, encode::Native, rmt::Sycl> float64_device_data{float64_host_data};
		data<sch::Float32Array, encode::Native, rmt::Sycl> float32_device_data{float32_host_data};
		
		// Each call is followed by a wait on the device handle before its
		// event is queried for profiling information.
		sstr<<test_size<<",\t";
		sycl_iface.template call<"float64_32">(mixed_device_data, &(device.get_handle()));
		device.get_handle().wait();
		time_eval(sstr, mixed_ev);
		sstr<<",\t";
		sycl_iface.template call<"float64">(float64_device_data, &(device.get_handle()));
		device.get_handle().wait();
		time_eval(sstr, float64_ev);
		sstr<<",\t";
		sycl_iface.template call<"float32">(float32_device_data, &(device.get_handle()));
		device.get_handle().wait();
		time_eval(sstr, float32_ev);
		sstr<<'\n';
	}
	std::cout<<sstr.str()<<std::endl;

	return 0;
}