From 601113a445658d8b15273dd91c66cf20daf50d30 Mon Sep 17 00:00:00 2001 From: "Claudius \"keldu\" Holeksa" Date: Thu, 20 Jun 2024 16:35:25 +0200 Subject: Changing towards a better allocated structure for sycl --- modules/remote-sycl/c++/remote.hpp | 225 ++++++++++++++----------------------- 1 file changed, 87 insertions(+), 138 deletions(-) (limited to 'modules/remote-sycl/c++/remote.hpp') diff --git a/modules/remote-sycl/c++/remote.hpp b/modules/remote-sycl/c++/remote.hpp index 1873669..54b7a7b 100644 --- a/modules/remote-sycl/c++/remote.hpp +++ b/modules/remote-sycl/c++/remote.hpp @@ -18,169 +18,96 @@ template class remote_data { private: /** - * Id representing the remote data + * An identifier to the data being held on the remote */ - id id_; + id data_id_; + /** - * Storage for the + * The sycl queue object */ - id_map* map_; + cl::sycl::queue* queue_; public: /** * Main constructor */ + remote_data(data& remote_data__, cl::sycl::queue& queue__): + remote_data_{&remote_data__}, + queue_{&queue__} + {} + + /** + * Destructor specifically designed to deallocate on the device. + */ + ~remote_data(){ + if(remote_data_){ + cl::sycl::free(remote_data_,queue_); + remote_data_ = nullptr; + } + } + + SAW_FORBID_COPY(remote_data); + SAW_FORBID_MOVE(remote_data); + /** remote_data(const id& id, id_map& map, cl::sycl::queue& queue__): id_{id}, map_{&map} {} + */ /** * Wait for the data */ error_or> wait(){ - auto eov = map_->find(id_); - if(eov.is_error()){ - auto& err = eov.get_error(); - return std::move(err); - } - auto& val = eov.get_value(); - std::cout<<"Values Sycl in Map: "<size()<template copy_to_host(); - if(eocop.is_error()){ - return eocop; - } - return eocop.get_value(); - } + return make_error(); } /** * Request data asynchronously */ - conveyor> on_receive(); /// Stopped here + // conveyor> on_receive(); /// Stopped here }; /** - * Sycl data class for handling the array Schema. + * Meant to be a helper object which holds the allocated data on the sycl side */ -template -class data, encode::Native, rmt::Sycl> { -public: - using Schema = schema::Array; +template +class device_data; + +/** + * This class helps in regards to the ownership on the server side + */ +template +class device_data { private: /** - * Absolute size of the stored elementes. - */ - uint64_t total_length_; - /** - * The data itself. + * The actual data */ - data* device_data_; + data* device_data_; /** - * Referenced sycl queue + * The sycl queue object */ cl::sycl::queue* queue_; - - static_assert(is_primitive::value, "Only supports primitives for now"); - static_assert(D==1u, "For now we only support 1D Arrays"); public: - data(uint64_t size, cl::sycl::queue& q__): - total_length_{size}, - device_data_{cl::sycl::malloc_device>(size, q__)}, - queue_{&q__} - { - if(!device_data_){ - total_length_ = 0u; - return; - } - queue_->wait(); - } - - template - data(const data& from, cl::sycl::queue& q__): - total_length_{from.size()}, - device_data_{cl::sycl::malloc_device>(from.size(), q__)}, - queue_{&q__} - { - if(!device_data_){ - total_length_ = 0u; - return; - } - queue_->template copy>(&from.at(0), device_data_, total_length_); - queue_->wait(); - } - - data(const data& from): - total_length_{from.size()}, - device_data_{nullptr}, - queue_{from.queue_} - { - if(total_length_ == 0u || !queue_){ - return; - } - device_data_ = cl::sycl::malloc_device>(from.size(), *queue_); - // device_data_ = cl::sycl::malloc_device::type>(from.size(), *queue_); - if(!device_data_){ - total_length_ = 0u; - return; - } - - queue_->template copy>(from.device_data_, device_data_, total_length_); - } + /** + * Main constructor + */ + device_data(data& device_data__, cl::sycl::queue& queue__): + device_data_{&device_data__}, + queue_{&queue__} + {} - data(data&& rhs): - total_length_{rhs.total_length_}, - device_data_{rhs.device_data_}, - queue_{rhs.queue_} - { - rhs.total_length_ = 0u; - rhs.device_data_ = nullptr; - rhs.queue_ = nullptr; - } - - data& operator=(data&& rhs){ - total_length_ = rhs.total_length_; - device_data_ = rhs.device_data_; - queue_ = rhs.queue_; - rhs.total_length_ = 0u; - rhs.device_data_ = nullptr; - rhs.queue_ = nullptr; - return *this; - } - - ~data(){ - // free data - if(device_data_){ - /// SYCL FREE - cl::sycl::free(device_data_, *queue_); - } - } - /** - * Allocate appropriate meta data and then copy to host + * Destructor specifically designed to deallocate on the device. */ - template - error_or> copy_to_host() const { - data data_{total_length_}; - - /// TODO Check success - queue_->template copy>(device_data_, &data_.at(0), total_length_); - queue_->wait(); - return data_; - } - - template - static error_or> copy_to_device(const data& host_data, device& dev); - - - data& at(uint64_t i){ - return device_data_[i]; + ~device_data(){ + if(device_data_){ + cl::sycl::free(device_data_,queue_); + device_data_ = nullptr; + } } - uint64_t size() const { - return total_length_; - } + SAW_FORBID_COPY(device_data); + SAW_FORBID_MOVE(device_data); }; namespace impl { @@ -214,7 +141,7 @@ public: */ template error_or> copy_to_device(const data& host_data){ - return data::copy_to_device(host_data); + return data::copy_to_device(host_data, *this); } /** @@ -280,7 +207,7 @@ template class rpc_server { public: using InterfaceCtxT = cl::sycl::queue*; - using InterfaceT = interface; + using InterfaceT = interface; private: /** * Device instance enabling the use of the remote device. @@ -290,18 +217,18 @@ private: /** * The interface including the relevant context class. */ - interface cl_interface_; + interface cl_interface_; /** * Basic storage for response data. */ - impl::rpc_id_map_helper storage_; + // impl::rpc_id_map_helper storage_; public: /** * Main constructor */ - rpc_server(device& dev__, interface cl_iface): + rpc_server(device& dev__, InterfaceT cl_iface): device_{&dev__}, cl_interface_{std::move(cl_iface)}, storage_{} @@ -310,31 +237,35 @@ public: /** * Ask which id the server prefers as the next one. Only available for fast requests on no roundtrip setups. */ + /** template id next_free_id() const { return std::get>(storage_.maps).next_free_id(); } + */ + /** template remote_data request_data(id dat_id){ return {dat_id, std::get>(storage_.maps), device_->get_handle()}; } + */ /** * Rpc call based on the name */ - template + template error_or< id< typename schema_member_type::type::ResponseT > - > call(data_or_id::type::RequestT, Encoding, ClientAllocation> input, id::type::ResponseT> rpc_id){ + > call(data_or_id::type::RequestT, Encoding, storage::Default> input, id::type::ResponseT> rpc_id){ using FuncT = typename schema_member_type::type; /** * Object needed if and only if the provided data type is not an id */ - own> dev_tmp_inp = nullptr; + own> dev_tmp_inp = nullptr; /** * First check if it's data or an id. * If it's an id, check if it's registered within the storage and retrieve it. @@ -350,7 +281,14 @@ public: return eov.get_value(); } else { auto& client_data = input.get_data(); - dev_tmp_inp = heap>(client_data, device_->get_handle()); + + auto eov = device_->template copy_to_device(client_data); + if(eov.is_error()){ + return std::move(eov.get_error()); + } + auto& val = eov.get_value(); + + dev_tmp_inp = heap>(std::move(val)); device_->get_handle().wait(); return dev_tmp_inp.get(); } @@ -434,8 +372,19 @@ public: template template error_or, encode::Native, rmt::Sycl>> data, encode::Native, rmt::Sycl>::copy_to_device(const data, encode::Native, Storage>& host_data, device& dev){ - data, encode::Native, rmt::Sycl> device_data{host_data.size(), dev.get_handle()}; - return make_error(); -} + /** + * Retrieve handle + */ + auto& cmd_handle = dev.get_handle(); + + uint64_t* dev_len = cl::sycl::malloc_device(1u, cmd_handle); + uint64_t len = host_data.size(); + cmd_handle.template copy(&len,dev_len, 1u); + auto dev_dat = cl::sycl::malloc_device>(host_data.size(), cmd_handle); + cmd_handle.copy(&host_data.at(0), dev_dat, host_data.size()); + cmd_handle.wait(); + + return data,encode::Native, rmt::Sycl>{dev_len, dev_dat, cmd_handle}; +} } -- cgit v1.2.3