diff options
author | Claudius "keldu" Holeksa <mail@keldu.de> | 2024-06-20 16:35:25 +0200 |
---|---|---|
committer | Claudius "keldu" Holeksa <mail@keldu.de> | 2024-06-20 16:35:25 +0200 |
commit | 601113a445658d8b15273dd91c66cf20daf50d30 (patch) | |
tree | bcb6c2a77e85bb64d6beb9b3f93a5f7bc5a6e400 /modules/remote-sycl | |
parent | c1d352270add2f205d038d7e4f69c1b4f35f014d (diff) |
Changing towards a better allocated structure for sycl
Diffstat (limited to 'modules/remote-sycl')
-rw-r--r-- | modules/remote-sycl/c++/remote.hpp | 225 | ||||
-rw-r--r-- | modules/remote-sycl/examples/sycl_basic_kernel.cpp | 6 | ||||
-rw-r--r-- | modules/remote-sycl/tests/calculator.cpp | 4 |
3 files changed, 94 insertions, 141 deletions
diff --git a/modules/remote-sycl/c++/remote.hpp b/modules/remote-sycl/c++/remote.hpp index 1873669..54b7a7b 100644 --- a/modules/remote-sycl/c++/remote.hpp +++ b/modules/remote-sycl/c++/remote.hpp @@ -18,169 +18,96 @@ template<typename T, typename Encoding, typename Storage> class remote_data<T, Encoding, Storage, rmt::Sycl> { private: /** - * Id representing the remote data + * An identifier to the data being held on the remote */ - id<T> id_; + id<T> data_id_; + /** - * Storage for the + * The sycl queue object */ - id_map<T,Encoding,rmt::Sycl>* map_; + cl::sycl::queue* queue_; public: /** * Main constructor */ + remote_data(data<T,Encoding,Storage>& remote_data__, cl::sycl::queue& queue__): + remote_data_{&remote_data__}, + queue_{&queue__} + {} + + /** + * Destructor specifically designed to deallocate on the device. + */ + ~remote_data(){ + if(remote_data_){ + cl::sycl::free(remote_data_,queue_); + remote_data_ = nullptr; + } + } + + SAW_FORBID_COPY(remote_data); + SAW_FORBID_MOVE(remote_data); + /** remote_data(const id<T>& id, id_map<T, Encoding, rmt::Sycl>& map, cl::sycl::queue& queue__): id_{id}, map_{&map} {} + */ /** * Wait for the data */ error_or<data<T,Encoding,Storage>> wait(){ - auto eov = map_->find(id_); - if(eov.is_error()){ - auto& err = eov.get_error(); - return std::move(err); - } - auto& val = eov.get_value(); - std::cout<<"Values Sycl in Map: "<<val->size()<<std::endl; - - { - auto eocop = val->template copy_to_host<storage::Default>(); - if(eocop.is_error()){ - return eocop; - } - return eocop.get_value(); - } + return make_error<err::not_implemented>(); } /** * Request data asynchronously */ - conveyor<data<T,Encoding,Storage>> on_receive(); /// Stopped here + // conveyor<data<T,Encoding,Storage>> on_receive(); /// Stopped here }; /** - * Sycl data class for handling the array Schema. + * Meant to be a helper object which holds the allocated data on the sycl side */ -template<typename T, uint64_t D> -class data<schema::Array<T,D>, encode::Native, rmt::Sycl> { -public: - using Schema = schema::Array<T,D>; +template<typename Schema, typename Encoding, typename Backend> +class device_data; + +/** + * This class helps in regards to the ownership on the server side + */ +template<typename Schema, typename Encoding> +class device_data<Schema, Encoding, rmt::Sycl> { private: /** - * Absolute size of the stored elementes. - */ - uint64_t total_length_; - /** - * The data itself. + * The actual data */ - data<T,encode::Native,storage::Default>* device_data_; + data<Schema,Encoding,Storage>* device_data_; /** - * Referenced sycl queue + * The sycl queue object */ cl::sycl::queue* queue_; - - static_assert(is_primitive<T>::value, "Only supports primitives for now"); - static_assert(D==1u, "For now we only support 1D Arrays"); public: - data(uint64_t size, cl::sycl::queue& q__): - total_length_{size}, - device_data_{cl::sycl::malloc_device<data<T,encode::Native,storage::Default>>(size, q__)}, - queue_{&q__} - { - if(!device_data_){ - total_length_ = 0u; - return; - } - queue_->wait(); - } - - template<typename Encoding, typename Storage> - data(const data<Schema, Encoding, Storage>& from, cl::sycl::queue& q__): - total_length_{from.size()}, - device_data_{cl::sycl::malloc_device<data<T,encode::Native,storage::Default>>(from.size(), q__)}, - queue_{&q__} - { - if(!device_data_){ - total_length_ = 0u; - return; - } - queue_->template copy<data<T,encode::Native,storage::Default>>(&from.at(0), device_data_, total_length_); - queue_->wait(); - } - - data(const data<Schema, encode::Native, rmt::Sycl>& from): - total_length_{from.size()}, - device_data_{nullptr}, - queue_{from.queue_} - { - if(total_length_ == 0u || !queue_){ - return; - } - device_data_ = cl::sycl::malloc_device<data<T,encode::Native,storage::Default>>(from.size(), *queue_); - // device_data_ = cl::sycl::malloc_device<typename native_data_type<T>::type>(from.size(), *queue_); - if(!device_data_){ - total_length_ = 0u; - return; - } - - queue_->template copy<data<T,encode::Native,storage::Default>>(from.device_data_, device_data_, total_length_); - } + /** + * Main constructor + */ + device_data(data<Schema,Encoding,Storage>& device_data__, cl::sycl::queue& queue__): + device_data_{&device_data__}, + queue_{&queue__} + {} - data(data<Schema, encode::Native, rmt::Sycl>&& rhs): - total_length_{rhs.total_length_}, - device_data_{rhs.device_data_}, - queue_{rhs.queue_} - { - rhs.total_length_ = 0u; - rhs.device_data_ = nullptr; - rhs.queue_ = nullptr; - } - - data<Schema, encode::Native, rmt::Sycl>& operator=(data<Schema, encode::Native, rmt::Sycl>&& rhs){ - total_length_ = rhs.total_length_; - device_data_ = rhs.device_data_; - queue_ = rhs.queue_; - rhs.total_length_ = 0u; - rhs.device_data_ = nullptr; - rhs.queue_ = nullptr; - return *this; - } - - ~data(){ - // free data - if(device_data_){ - /// SYCL FREE - cl::sycl::free(device_data_, *queue_); - } - } - /** - * Allocate appropriate meta data and then copy to host + * Destructor specifically designed to deallocate on the device. */ - template<typename Storage> - error_or<data<Schema, encode::Native, Storage>> copy_to_host() const { - data<Schema,encode::Native, Storage> data_{total_length_}; - - /// TODO Check success - queue_->template copy<data<T,encode::Native,storage::Default>>(device_data_, &data_.at(0), total_length_); - queue_->wait(); - return data_; - } - - template<typename Storage> - static error_or<data<Schema, encode::Native, rmt::Sycl>> copy_to_device(const data<Schema, encode::Native, Storage>& host_data, device<rmt::Sycl>& dev); - - - data<T, encode::Native, storage::Default>& at(uint64_t i){ - return device_data_[i]; + ~device_data(){ + if(device_data_){ + cl::sycl::free(device_data_,queue_); + device_data_ = nullptr; + } } - uint64_t size() const { - return total_length_; - } + SAW_FORBID_COPY(device_data); + SAW_FORBID_MOVE(device_data); }; namespace impl { @@ -214,7 +141,7 @@ public: */ template<typename Schema, typename Encoding, typename Storage> error_or<data<Schema, Encoding, rmt::Sycl>> copy_to_device(const data<Schema, Encoding, Storage>& host_data){ - return data<Schema, Encoding, rmt::Sycl>::copy_to_device(host_data); + return data<Schema, Encoding, rmt::Sycl>::copy_to_device(host_data, *this); } /** @@ -280,7 +207,7 @@ template<typename Iface, typename Encoding> class rpc_server<Iface, Encoding, rmt::Sycl> { public: using InterfaceCtxT = cl::sycl::queue*; - using InterfaceT = interface<Iface, Encoding, rmt::Sycl, InterfaceCtxT>; + using InterfaceT = interface<Iface, Encoding, storage::Default, InterfaceCtxT>; private: /** * Device instance enabling the use of the remote device. @@ -290,18 +217,18 @@ private: /** * The interface including the relevant context class. */ - interface<Iface, Encoding, rmt::Sycl, InterfaceCtxT> cl_interface_; + interface<Iface, Encoding, storage::Default, InterfaceCtxT> cl_interface_; /** * Basic storage for response data. */ - impl::rpc_id_map_helper<Iface, Encoding, rmt::Sycl> storage_; + // impl::rpc_id_map_helper<Iface, Encoding, rmt::Sycl> storage_; public: /** * Main constructor */ - rpc_server(device<rmt::Sycl>& dev__, interface<Iface, Encoding, rmt::Sycl, InterfaceCtxT> cl_iface): + rpc_server(device<rmt::Sycl>& dev__, InterfaceT cl_iface): device_{&dev__}, cl_interface_{std::move(cl_iface)}, storage_{} @@ -310,31 +237,35 @@ public: /** * Ask which id the server prefers as the next one. Only available for fast requests on no roundtrip setups. */ + /** template<typename T> id<T> next_free_id() const { return std::get<id_map<T,Encoding,rmt::Sycl>>(storage_.maps).next_free_id(); } + */ + /** template<typename IdT, typename Storage> remote_data<IdT, Encoding, Storage, rmt::Sycl> request_data(id<IdT> dat_id){ return {dat_id, std::get<id_map<IdT,Encoding,rmt::Sycl>>(storage_.maps), device_->get_handle()}; } + */ /** * Rpc call based on the name */ - template<string_literal Name, typename ClientAllocation> + template<string_literal Name> error_or< id< typename schema_member_type<Name, Iface>::type::ResponseT > - > call(data_or_id<typename schema_member_type<Name, Iface>::type::RequestT, Encoding, ClientAllocation> input, id<typename schema_member_type<Name,Iface>::type::ResponseT> rpc_id){ + > call(data_or_id<typename schema_member_type<Name, Iface>::type::RequestT, Encoding, storage::Default> input, id<typename schema_member_type<Name,Iface>::type::ResponseT> rpc_id){ using FuncT = typename schema_member_type<Name, Iface>::type; /** * Object needed if and only if the provided data type is not an id */ - own<data<typename FuncT::RequestT, Encoding, rmt::Sycl>> dev_tmp_inp = nullptr; + own<device_data<typename FuncT::RequestT, Encoding, rmt::Sycl>> dev_tmp_inp = nullptr; /** * First check if it's data or an id. * If it's an id, check if it's registered within the storage and retrieve it. @@ -350,7 +281,14 @@ public: return eov.get_value(); } else { auto& client_data = input.get_data(); - dev_tmp_inp = heap<data<typename FuncT::RequestT, Encoding, rmt::Sycl>>(client_data, device_->get_handle()); + + auto eov = device_->template copy_to_device(client_data); + if(eov.is_error()){ + return std::move(eov.get_error()); + } + auto& val = eov.get_value(); + + dev_tmp_inp = heap<data<typename FuncT::RequestT, Encoding, rmt::Sycl>>(std::move(val)); device_->get_handle().wait(); return dev_tmp_inp.get(); } @@ -434,8 +372,19 @@ public: template<typename T, uint64_t D> template<typename Storage> error_or<data<schema::Array<T,D>, encode::Native, rmt::Sycl>> data<schema::Array<T,D>, encode::Native, rmt::Sycl>::copy_to_device(const data<schema::Array<T,D>, encode::Native, Storage>& host_data, device<rmt::Sycl>& dev){ - data<schema::Array<T,D>, encode::Native, rmt::Sycl> device_data{host_data.size(), dev.get_handle()}; - return make_error<err::not_implemented>(); -} + /** + * Retrieve handle + */ + auto& cmd_handle = dev.get_handle(); + + uint64_t* dev_len = cl::sycl::malloc_device<uint64_t>(1u, cmd_handle); + uint64_t len = host_data.size(); + cmd_handle.template copy<uint64_t>(&len,dev_len, 1u); + auto dev_dat = cl::sycl::malloc_device<data<T,encode::Native,storage::Default>>(host_data.size(), cmd_handle); + cmd_handle.copy(&host_data.at(0), dev_dat, host_data.size()); + cmd_handle.wait(); + + return data<schema::Array<T,D>,encode::Native, rmt::Sycl>{dev_len, dev_dat, cmd_handle}; +} } diff --git a/modules/remote-sycl/examples/sycl_basic_kernel.cpp b/modules/remote-sycl/examples/sycl_basic_kernel.cpp index 6481eb9..f9a838e 100644 --- a/modules/remote-sycl/examples/sycl_basic_kernel.cpp +++ b/modules/remote-sycl/examples/sycl_basic_kernel.cpp @@ -2,9 +2,13 @@ saw::rpc_server<schema::BasicInterface, saw::encode::Native, saw::rmt::Sycl> listen_basic_sycl(saw::remote<saw::rmt::Sycl>& ctx, saw::device<saw::rmt::Sycl>& dev, saw::remote_address<saw::rmt::Sycl>& addr){ saw::interface<schema::BasicInterface, saw::encode::Native, saw::rmt::Sycl, cl::sycl::queue*> iface{ + /** + * This is the increment kernel + */ + [](saw::data<saw::schema::Array<saw::schema::UInt64>, saw::encode::Native, saw::rmt::Sycl>& in, cl::sycl::queue* q) -> saw::error_or<void> { - [](saw::data<saw::schema::Array<saw::schema::UInt64>, saw::encode::Native, saw::rmt::Sycl> in, cl::sycl::queue* q) -> saw::data<saw::schema::Array<saw::schema::UInt64>, saw::encode::Native, saw::rmt::Sycl> { q->submit([&](cl::sycl::handler& h){ + h.single_task([&] (){ in.at(0u).set(in.at(0u).get() + 1u); }); diff --git a/modules/remote-sycl/tests/calculator.cpp b/modules/remote-sycl/tests/calculator.cpp index 730838d..6d061ad 100644 --- a/modules/remote-sycl/tests/calculator.cpp +++ b/modules/remote-sycl/tests/calculator.cpp @@ -21,7 +21,7 @@ SAW_TEST("Sycl Interface Calculator"){ cl::sycl::queue cmd_queue; interface<schema::Calculator, encode::Native<storage::Default>, cl::sycl::queue*> cl_iface { -[](data<schema::Tuple<schema::Int64, schema::Int64>> in, cl::sycl::queue* cmd) -> data<schema::Int64> { +[](data<schema::Tuple<schema::Int64, schema::Int64>>& in, cl::sycl::queue* cmd) -> data<schema::Int64> { std::array<int64_t,2> h_xy{in.get<0>().get(), in.get<1>().get()}; int64_t res{}; cl::sycl::buffer<int64_t,1> d_xy { h_xy.data(), h_xy.size() }; @@ -37,7 +37,7 @@ SAW_TEST("Sycl Interface Calculator"){ cmd->wait(); return data<schema::Int64>{res}; }, - [](data<schema::Tuple<schema::Int64, schema::Int64>> in, cl::sycl::queue* cmd) -> data<schema::Int64> { + [](data<schema::Tuple<schema::Int64, schema::Int64>,encode::Native,rmt::Sycl>& in, cl::sycl::queue* cmd) -> data<schema::Int64> { return data<schema::Int64>{in.get<0>().get() * in.get<1>().get()}; } }; |