diff options
Diffstat (limited to 'modules/remote-sycl/benchmarks/mixed_precision.cpp')
-rw-r--r-- | modules/remote-sycl/benchmarks/mixed_precision.cpp | 139 |
1 files changed, 103 insertions, 36 deletions
diff --git a/modules/remote-sycl/benchmarks/mixed_precision.cpp b/modules/remote-sycl/benchmarks/mixed_precision.cpp index b554a1c..e804f4e 100644 --- a/modules/remote-sycl/benchmarks/mixed_precision.cpp +++ b/modules/remote-sycl/benchmarks/mixed_precision.cpp @@ -3,15 +3,69 @@ #include <sstream> -int main(){ +int main(int argc, char** argv){ using namespace saw; uint64_t start_test_size = 1024ul * 1024ul; + + if(argc <= 0 || argc >= 256){ + std::cerr<<"Argument size being weird. Got "<<argc<<" args"<<std::endl; + return -1; + } + + std::vector<std::string_view> args; + args.resize(static_cast<uint64_t>(argc)); + for(uint64_t i = 0; i < args.size(); ++i){ + args.at(i) = {argv[i]}; + } + if(args.size() > 1){ + auto& str = args.at(1); + auto ec = std::from_chars(str.data(), str.data() + str.size(), start_test_size); + if(ec.ec != std::errc{}){ + std::cerr<<"Start size is not an int."<<std::endl; + return -1; + } + } uint64_t max_test_size = start_test_size * 1024ul; + + if(args.size() > 2){ + auto& str = args.at(2); + auto ec = std::from_chars(str.data(), str.data() + str.size(), max_test_size); + if(ec.ec != std::errc{}){ + std::cerr<<"Stop size is not an int."<<std::endl; + return -1; + } + } + + if(start_test_size > max_test_size){ + std::cerr<<"Invalid arguments. Stop size is smaller than Start size."<<std::endl; + return -1; + } + + uint64_t runs = 128ul; + + if(args.size() > 3){ + auto& str = args.at(3); + auto ec = std::from_chars(str.data(), str.data() + str.size(), runs); + if(ec.ec != std::errc{}){ + std::cerr<<"Run size is not an int."<<std::endl; + return -1; + } + } + uint64_t arithmetic_intensity = 1u; + if(args.size() > 4){ + auto& str = args.at(4); + auto ec = std::from_chars(str.data(), str.data() + str.size(), arithmetic_intensity); + if(ec.ec != std::errc{}){ + std::cerr<<"Arithmetic intensity is not an int."<<std::endl; + return -1; + } + } + std::random_device r; std::default_random_engine e1{r()}; - std::uniform_real_distribution<> dis{-1.0,1.0}; + std::uniform_real_distribution<> dis{-3.0,-1.0}; saw::event_loop loop; @@ -34,17 +88,18 @@ int main(){ cl::sycl::event float32_ev; cl::sycl::event float64_ev; - auto sycl_iface = listen_mixed_precision(mixed_ev, float64_ev, float32_ev); + auto sycl_iface = listen_mixed_precision(mixed_ev, float64_ev, float32_ev, arithmetic_intensity); data<sch::MixedArray> mixed_host_data; data<sch::Float64Array> float64_host_data; data<sch::Float32Array> float32_host_data; - auto time_eval = [](std::stringstream& sstr, cl::sycl::event& ev){ + auto time_eval = [](uint64_t & current_min_time, cl::sycl::event& ev){ auto end = ev.get_profiling_info<cl::sycl::info::event_profiling::command_end>(); auto start = ev.get_profiling_info<cl::sycl::info::event_profiling::command_start>(); - sstr<<(end-start) / 1.0e9; + uint64_t curr_time = (end-start); + current_min_time = std::min(curr_time, current_min_time); }; auto& device = rmt_addr->get_device(); @@ -79,41 +134,53 @@ int main(){ * Benchmark */ std::stringstream sstr; - for(uint64_t test_size = start_test_size; test_size < max_test_size; test_size *= 2ul){ - - (std::cout<<'.').flush(); - - data<sch::MixedArray> mixed_host_data; - data<sch::Float64Array> float64_host_data; - data<sch::Float32Array> float32_host_data; - mixed_host_data = {test_size}; - float64_host_data = {test_size}; - float32_host_data = {test_size}; - for(uint64_t i = 0; i < test_size; ++i){ - double gen_num = dis(e1); - mixed_host_data.at(i) = static_cast<double>(gen_num); - float64_host_data.at(i) = static_cast<double>(gen_num); - float32_host_data.at(i) = static_cast<float>(gen_num); + for(uint64_t test_size = start_test_size; test_size <= max_test_size; test_size *= 2ul){ + uint64_t time_mixed = std::numeric_limits<uint64_t>::max(); + uint64_t time_float64 = std::numeric_limits<uint64_t>::max(); + uint64_t time_float32 = std::numeric_limits<uint64_t>::max(); + for(uint64_t runs_i = 0u; runs_i < runs; ++runs_i){ + + (std::cout<<'.').flush(); + + data<sch::MixedArray> mixed_host_data; + data<sch::Float64Array> float64_host_data; + data<sch::Float32Array> float32_host_data; + + mixed_host_data = {test_size}; + float64_host_data = {test_size}; + float32_host_data = {test_size}; + + for(uint64_t i = 0; i < test_size; ++i){ + double gen_num = dis(e1); + mixed_host_data.at(i) = static_cast<double>(gen_num); + float64_host_data.at(i) = static_cast<double>(gen_num); + float32_host_data.at(i) = static_cast<float>(gen_num); + } + + data<sch::MixedArray, encode::Native, rmt::Sycl> mixed_device_data{mixed_host_data}; + data<sch::Float64Array, encode::Native, rmt::Sycl> float64_device_data{float64_host_data}; + data<sch::Float32Array, encode::Native, rmt::Sycl> float32_device_data{float32_host_data}; + + sycl_iface.template call<"float64_32">(mixed_device_data, &(device.get_handle())); + device.get_handle().wait(); + time_eval(time_mixed, mixed_ev); + sycl_iface.template call<"float64">(float64_device_data, &(device.get_handle())); + device.get_handle().wait(); + time_eval(time_float64, float64_ev); + sycl_iface.template call<"float32">(float32_device_data, &(device.get_handle())); + device.get_handle().wait(); + time_eval(time_float32, float32_ev); } - data<sch::MixedArray, encode::Native, rmt::Sycl> mixed_device_data{mixed_host_data}; - data<sch::Float64Array, encode::Native, rmt::Sycl> float64_device_data{float64_host_data}; - data<sch::Float32Array, encode::Native, rmt::Sycl> float32_device_data{float32_host_data}; - - sstr<<test_size<<",\t"; - sycl_iface.template call<"float64_32">(mixed_device_data, &(device.get_handle())); - device.get_handle().wait(); - time_eval(sstr, mixed_ev); + sstr<<test_size; sstr<<",\t"; - sycl_iface.template call<"float64">(float64_device_data, &(device.get_handle())); - device.get_handle().wait(); - time_eval(sstr, float64_ev); + sstr<<time_mixed / 1.0e9; sstr<<",\t"; - sycl_iface.template call<"float32">(float32_device_data, &(device.get_handle())); - device.get_handle().wait(); - time_eval(sstr, float32_ev); - sstr<<'\n'; + sstr<<time_float64 / 1.0e9; + sstr<<",\t"; + sstr<<time_float32 / 1.0e9; + sstr<<"\n"; } - std::cout<<sstr.str()<<std::endl; + std::cout<<'\n'<<'\n'<<sstr.str()<<std::endl; return 0; } |