diff options
Diffstat (limited to 'modules/remote-sycl/benchmarks/mixed_precision_alternative.cpp')
| -rw-r--r-- | modules/remote-sycl/benchmarks/mixed_precision_alternative.cpp | 58 |
1 files changed, 42 insertions, 16 deletions
diff --git a/modules/remote-sycl/benchmarks/mixed_precision_alternative.cpp b/modules/remote-sycl/benchmarks/mixed_precision_alternative.cpp index e1a1e90..4afb29e 100644 --- a/modules/remote-sycl/benchmarks/mixed_precision_alternative.cpp +++ b/modules/remote-sycl/benchmarks/mixed_precision_alternative.cpp @@ -1,29 +1,55 @@ #include "../c++/data.hpp" +#include <random> + namespace sch { using namespace saw::schema; } template<typename T> void inner_work(){ - acpp::sycl::queue sycl_q; - - constexpr uint64_t dat_size = 10000u; - - data<sch::Array<T>, encode::Sycl<encode::Native>> dat{{{dat_size}},sycl_q}; - data<sch::Ref<sch::Array<T>>, encode::Sycl<encode::Native>> dat_ref{dat}; - auto dat_ptr = dat_ref.get_internal_data(); + std::random_device r; + std::default_random_engine e1{r()}; + std::uniform_real_distribution<> dis{-3.0,-1.0}; - sycl_q.parallel_for(dat_size, [=](acpp::sycl::id<1> idx){ - size_t i = idx[0]; - - dat_ptr[i] = {i}; - }).wait(); - - for(uint64_t i = 0u; i < dat_size; ++i){ - SAW_EXPECT(dat_ptr[i].get() == i, std::string{"Unexpected value: "} + std::to_string(i)); + acpp::sycl::queue sycl_q; + acpp::sycl::event ev; + + auto time_eval = [](uint64_t & current_min_time, acpp::sycl::event& evt){ + auto end = evt.get_profiling_info<acpp::sycl::info::event_profiling::command_end>(); + auto start = evt.get_profiling_info<acpp::sycl::info::event_profiling::command_start>(); + + uint64_t curr_time = (end-start); + current_min_time = std::min(curr_time, current_min_time); + }; + + constexpr uint64_t arithmetic_intensity = 1024ul; + + /** + * Warmup + */ + std::cout<<"Warming up ..."<<std::endl; + for(uint64_t test_size = 1ul; test_size < max_test_size; test_size *= 2ul){ + data<sch::Array<T>, encode::Sycl<encode::Native>> dat{{{test_size}},sycl_q}; + data<sch::Ref<sch::Array<T>>, encode::Sycl<encode::Native>> dat_ref{dat}; + auto dat_ptr = dat_ref.get_internal_data(); + + for(uint64_t i = 0; i < test_size; ++i){ + double gen_num = dis(e1); + dat.at({{i}}) = {static_cast<double>(gen_num)}; + } + + sycl_q.parallel_for([=](acpp::sycl::id<1> idx){ + data<T::InterfaceSchema> foo = {dat_ptr[idx[0u]].get()}; + for(uint64_t i = 0; i < arithmetic_intensity; ++i){ + if( foo.get() == 1.1e12 ){ + dat_ptr[idx[0u]] = {}; + } + foo = foo + foo * saw::data<T::InterfaceSchema>{1.7342345}; + } + dat_ptr[idx[0u]] = foo; + }).wait(); } - } int main(){ |
