diff options
author | Claudius "keldu" Holeksa <mail@keldu.de> | 2024-06-14 14:33:22 +0200 |
---|---|---|
committer | Claudius "keldu" Holeksa <mail@keldu.de> | 2024-06-14 14:33:22 +0200 |
commit | 5329652f839b99b95d63cd471ff73d251f74d911 (patch) | |
tree | 20797eb65f3e48686979362f828a6e07b21e9b5a /modules/remote-sycl | |
parent | 57f6eacfcdbdba31185eb66b9a573a8923eecf16 (diff) |
Fixed calc of sycl vals
Diffstat (limited to 'modules/remote-sycl')
-rw-r--r-- | modules/remote-sycl/c++/remote.hpp | 47 | ||||
-rw-r--r-- | modules/remote-sycl/examples/SConscript | 1 | ||||
-rw-r--r-- | modules/remote-sycl/examples/sycl_basic.cpp | 41 | ||||
-rw-r--r-- | modules/remote-sycl/examples/sycl_basic_kernel.cpp | 3 |
4 files changed, 79 insertions, 13 deletions
diff --git a/modules/remote-sycl/c++/remote.hpp b/modules/remote-sycl/c++/remote.hpp index bcc8a3c..4510237 100644 --- a/modules/remote-sycl/c++/remote.hpp +++ b/modules/remote-sycl/c++/remote.hpp @@ -22,22 +22,34 @@ class remote_data<T, Encoding, Storage, rmt::Sycl> { private: id<T> id_; id_map<T,Encoding,rmt::Sycl>* map_; - cl::sycl::queue* queue_; public: /** * Main constructor */ remote_data(const id<T>& id, id_map<T, Encoding, rmt::Sycl>& map, cl::sycl::queue& queue__): id_{id}, - map_{&map}, - queue_{&queue__} + map_{&map} {} /** * Wait for the data */ error_or<data<T,Encoding,Storage>> wait(){ + auto eov = map_->find(id_); + if(eov.is_error()){ + auto& err = eov.get_error(); + return std::move(err); + } + auto& val = eov.get_value(); + std::cout<<"Values Sycl in Map: "<<val->size()<<std::endl; + { + auto eocop = val->template copy_to_host<storage::Default>(); + if(eocop.is_error()){ + return eocop; + } + return eocop.get_value(); + } } /** @@ -87,7 +99,9 @@ public: { if(!device_data_){ total_length_ = 0u; + return; } + queue_->wait(); } template<typename Encoding, typename Storage> @@ -102,6 +116,7 @@ public: return; } queue_->template copy<data<T,encode::Native,storage::Default>>(&from.at(0), device_data_, total_length_); + queue_->wait(); } data(const data<Schema, encode::Native, rmt::Sycl>& from): @@ -116,7 +131,10 @@ public: // device_data_ = cl::sycl::malloc_device<typename native_data_type<T>::type>(from.size(), *queue_); if(!device_data_){ total_length_ = 0u; + return; } + + queue_->template copy<data<T,encode::Native,storage::Default>>(from.device_data_, device_data_, total_length_); } data<Schema, encode::Native, rmt::Sycl>& operator=(const data<Schema, encode::Native, rmt::Sycl>& rhs) { @@ -158,7 +176,20 @@ public: } } - data<T, encode::Native, saw::storage::Default>& at(uint64_t i){ + /** + * Allocate appropriate meta data and then copy to host + */ + template<typename Storage> + error_or<data<Schema, encode::Native, Storage>> copy_to_host() const { + data<Schema,encode::Native, Storage> data_{total_length_}; + + /// TODO Check success + queue_->template copy<data<T,encode::Native,storage::Default>>(device_data_, &data_.at(0), total_length_); + queue_->wait(); + return data_; + } + + data<T, encode::Native, storage::Default>& at(uint64_t i){ //typename native_data_type<T>::type& at(uint64_t i){ return device_data_[i]; } @@ -247,6 +278,9 @@ public: return std::get<id_map<T,Encoding,rmt::Sycl>>(storage_.maps).next_free_id(); } + /** + * Main constructor + */ rpc_server(interface<Iface, Encoding, rmt::Sycl, InterfaceCtxT> cl_iface): cmd_queue_{}, cl_interface_{std::move(cl_iface)}, @@ -255,12 +289,11 @@ public: template<typename IdT, typename Storage> remote_data<IdT, Encoding, Storage, rmt::Sycl> request_data(id<IdT> dat_id){ - /// @TODO Fix so I can receive data - return {dat_id, std::get<id_map<IdT, Encoding,rmt::Sycl>>(storage_.maps)}; + return {dat_id, std::get<id_map<IdT,Encoding,rmt::Sycl>>(storage_.maps), cmd_queue_}; } /** - * Rpc call + * Rpc call based on the name */ template<string_literal Name, typename ClientAllocation> error_or< diff --git a/modules/remote-sycl/examples/SConscript b/modules/remote-sycl/examples/SConscript index 02e528b..015b492 100644 --- a/modules/remote-sycl/examples/SConscript +++ b/modules/remote-sycl/examples/SConscript @@ -15,6 +15,7 @@ examples_env = env.Clone(); examples_sycl_env = examples_env.Clone(); examples_sycl_env['CXX'] = 'acpp'; +examples_sycl_env['CXXFLAGS'] += ['-O2']; examples_env.sources = sorted(glob.glob(dir_path + "/*.cpp")) examples_env.headers = sorted(glob.glob(dir_path + "/*.hpp")) diff --git a/modules/remote-sycl/examples/sycl_basic.cpp b/modules/remote-sycl/examples/sycl_basic.cpp index 2e9a4f8..486aca1 100644 --- a/modules/remote-sycl/examples/sycl_basic.cpp +++ b/modules/remote-sycl/examples/sycl_basic.cpp @@ -22,8 +22,12 @@ int main(){ saw::rpc_client<schema::BasicInterface, saw::encode::Native, saw::storage::Default, saw::rmt::Sycl> client{rpc_server}; saw::id<schema::Array<schema::UInt64>> id_zero{0u}; + saw::data<schema::Array<schema::UInt64>, saw::encode::Native> ex_data{1u}; + ex_data.at(0u).set(50u); { - auto eov = client.template call<"increment">(saw::data<schema::Array<schema::UInt64>, saw::encode::Native>{1u}); + auto eov = client.template call<"increment">( + ex_data + ); if(eov.is_error()){ auto& err = eov.get_error(); std::cerr<<"Error: "<<err.get_category()<<" : "<<err.get_message()<<std::endl; @@ -32,17 +36,46 @@ int main(){ id_zero = eov.get_value(); } { + auto rmt_data = rpc_server.request_data<schema::Array<schema::UInt64>, saw::storage::Default>(id_zero); + auto eo_rd = rmt_data.wait(); + if(eo_rd.is_error()){ + auto& err = eo_rd.get_error(); + std::cerr<<"Error: "<<err.get_category()<<" : "<<err.get_message()<<std::endl; + return -2; + } + + auto& val = eo_rd.get_value(); + std::cout<<"Values: "<<val.size()<<"\n"; + for(uint64_t i = 0; i < val.size(); ++i){ + std::cout<<val.at(i).get()<<'\t'; + } + std::cout<<std::endl; + } + saw::id<schema::Array<schema::UInt64>> id_one{1u}; + { auto eov = client.template call<"increment">(id_zero); if(eov.is_error()){ auto& err = eov.get_error(); std::cerr<<"Error: "<<err.get_category()<<" : "<<err.get_message()<<std::endl; return -2; } - auto& val = eov.get_value(); - std::cout<<"Value: "<<val.get_value()<<std::endl; + id_one = eov.get_value(); } { - // auto eo_rd = rpc_server.request_data(id_one); + auto rmt_data = rpc_server.request_data<schema::Array<schema::UInt64>, saw::storage::Default>(id_one); + auto eo_rd = rmt_data.wait(); + if(eo_rd.is_error()){ + auto& err = eo_rd.get_error(); + std::cerr<<"Error: "<<err.get_category()<<" : "<<err.get_message()<<std::endl; + return -2; + } + + auto& val = eo_rd.get_value(); + std::cout<<"Values: "<<val.size()<<"\n"; + for(uint64_t i = 0; i < val.size(); ++i){ + std::cout<<val.at(i).get()<<'\t'; + } + std::cout<<std::endl; } return 0; diff --git a/modules/remote-sycl/examples/sycl_basic_kernel.cpp b/modules/remote-sycl/examples/sycl_basic_kernel.cpp index 888f905..03f0bac 100644 --- a/modules/remote-sycl/examples/sycl_basic_kernel.cpp +++ b/modules/remote-sycl/examples/sycl_basic_kernel.cpp @@ -4,9 +4,8 @@ saw::rpc_server<schema::BasicInterface, saw::encode::Native, saw::rmt::Sycl> lis saw::interface<schema::BasicInterface, saw::encode::Native, saw::rmt::Sycl, cl::sycl::queue*> iface{ [](saw::data<saw::schema::Array<saw::schema::UInt64>, saw::encode::Native, saw::rmt::Sycl> in, cl::sycl::queue* q) -> saw::data<saw::schema::Array<saw::schema::UInt64>, saw::encode::Native, saw::rmt::Sycl> { - q->submit([&](cl::sycl::handler& h){ - h.parallel_for(cl::sycl::range<1>(1u), [&] (cl::sycl::id<1> it){ + h.single_task([&] (){ in.at(0u).set(in.at(0u).get() + 1u); }); }); |