diff options
author | Claudius "keldu" Holeksa <mail@keldu.de> | 2024-06-18 16:35:39 +0200 |
---|---|---|
committer | Claudius "keldu" Holeksa <mail@keldu.de> | 2024-06-18 16:35:39 +0200 |
commit | 80b706332a48f54ae289093ee11b17f20ab2dc2e (patch) | |
tree | dc5d54e9aa583e9d2c7519958b3735a73d60f85b /modules/remote-sycl/c++/remote.hpp | |
parent | dee0184f2bedfb3919309c9f372bd0cdec520e9d (diff) |
Changed some setup designs
Diffstat (limited to 'modules/remote-sycl/c++/remote.hpp')
-rw-r--r-- | modules/remote-sycl/c++/remote.hpp | 93 |
1 files changed, 50 insertions, 43 deletions
diff --git a/modules/remote-sycl/c++/remote.hpp b/modules/remote-sycl/c++/remote.hpp index 4a383d1..1873669 100644 --- a/modules/remote-sycl/c++/remote.hpp +++ b/modules/remote-sycl/c++/remote.hpp @@ -1,19 +1,16 @@ #pragma once -#include <forstio/io_codec/rpc.hpp> -#include <forstio/codec/data.hpp> -#include <forstio/codec/id_map.hpp> - -#include <AdaptiveCpp/CL/sycl.hpp> +#include "common.hpp" +#include "data.hpp" namespace saw { -namespace rmt { -struct Sycl {}; -} template<> class remote<rmt::Sycl>; +template<typename T> +class device; + /** * Remote data class for the Sycl backend. */ @@ -65,32 +62,24 @@ public: }; /** - * -template<typename T, uint64_t N> -class data<schema::Primitive<T,N>, encode::Native, rmt::Sycl> { -public: - using Schema = schema::Primitive<T,N>; - using NativeType = typename native_data_type<Schema>::type; -private: - NativeType val_; -public: - data(NativeType val__): - val_{val__} - {} - - NativeType get(){ - return val_; - } -}; + * Sycl data class for handling the array Schema. */ - template<typename T, uint64_t D> class data<schema::Array<T,D>, encode::Native, rmt::Sycl> { public: using Schema = schema::Array<T,D>; private: + /** + * Absolute size of the stored elementes. + */ uint64_t total_length_; + /** + * The data itself. + */ data<T,encode::Native,storage::Default>* device_data_; + /** + * Referenced sycl queue + */ cl::sycl::queue* queue_; static_assert(is_primitive<T>::value, "Only supports primitives for now"); @@ -181,6 +170,10 @@ public: return data_; } + template<typename Storage> + static error_or<data<Schema, encode::Native, rmt::Sycl>> copy_to_device(const data<Schema, encode::Native, Storage>& host_data, device<rmt::Sycl>& dev); + + data<T, encode::Native, storage::Default>& at(uint64_t i){ return device_data_[i]; } @@ -202,34 +195,41 @@ struct rpc_id_map_helper<schema::Interface<Members...>, Encoding, Storage> { }; } -template<typename T> -class device; - /** * Represents a remote Sycl device. * */ template<> -class device<rmt::Sycl> { +class device<rmt::Sycl> final { private: cl::sycl::queue cmd_queue_; public: + device() = default; + + SAW_FORBID_COPY(device); + SAW_FORBID_MOVE(device); + /** * Copy data to device */ - template<typename Schema, typename Encoding, typename Storage, typename EncodingHost, typename StorageHost> - error_or<data<Schema, Encoding, Storage>> copy_to_device(const data<Schema, EncodingHost, StorageHost>>& host_data){ - (void) host_data; - return make_error<err::not_implemented>(); + template<typename Schema, typename Encoding, typename Storage> + error_or<data<Schema, Encoding, rmt::Sycl>> copy_to_device(const data<Schema, Encoding, Storage>& host_data){ + return data<Schema, Encoding, rmt::Sycl>::copy_to_device(host_data); } /** * Copy data to host */ - template<typename Schema, typename Encoding, typename Storage, typename EncodingDevice, typename StorageDevice> - error_or<data<Schema, Encoding, Storage>> copy_to_host(const data<Schema, EncodingDevice, StorageDevice>>& device_data){ - (void) device_data; - return make_error<err::not_implemented>(); + template<typename Schema, typename Encoding, typename Storage> + error_or<data<Schema, Encoding, Storage>> copy_to_host(const data<Schema, Encoding, rmt::Sycl>& device_data){ + return device_data.copy_to_host(); + } + + /** + * Get a reference to the handle + */ + cl::sycl::queue& get_handle(){ + return cmd_queue_; } }; @@ -317,7 +317,7 @@ public: template<typename IdT, typename Storage> remote_data<IdT, Encoding, Storage, rmt::Sycl> request_data(id<IdT> dat_id){ - return {dat_id, std::get<id_map<IdT,Encoding,rmt::Sycl>>(storage_.maps), device_.cmd_queue_}; + return {dat_id, std::get<id_map<IdT,Encoding,rmt::Sycl>>(storage_.maps), device_->get_handle()}; } /** @@ -350,8 +350,8 @@ public: return eov.get_value(); } else { auto& client_data = input.get_data(); - dev_tmp_inp = heap<data<typename FuncT::RequestT, Encoding, rmt::Sycl>>(client_data, cmd_queue_); - cmd_queue_.wait(); + dev_tmp_inp = heap<data<typename FuncT::RequestT, Encoding, rmt::Sycl>>(client_data, device_->get_handle()); + device_->get_handle().wait(); return dev_tmp_inp.get(); } }(); @@ -360,7 +360,7 @@ public: } auto& inp = *(eoinp.get_value()); - auto eod = cl_interface_.template call<Name>(inp, &cmd_queue_); + auto eod = cl_interface_.template call<Name>(inp, &(device_->get_handle())); if(eod.is_error()){ return std::move(eod.get_error()); @@ -424,11 +424,18 @@ public: * Spin up a rpc server */ template<typename Iface, typename Encoding> - rpc_server<Iface, Encoding, rmt::Sycl> listen(const device<rmt::Sycl>& dev, typename rpc_server<Iface, Encoding, rmt::Sycl>::InterfaceT iface){ + rpc_server<Iface, Encoding, rmt::Sycl> listen(device<rmt::Sycl>& dev, typename rpc_server<Iface, Encoding, rmt::Sycl>::InterfaceT iface){ using RpcServerT = rpc_server<Iface, Encoding, rmt::Sycl>; using InterfaceT = typename RpcServerT::InterfaceT; return {dev, std::move(iface)}; } }; +template<typename T, uint64_t D> +template<typename Storage> +error_or<data<schema::Array<T,D>, encode::Native, rmt::Sycl>> data<schema::Array<T,D>, encode::Native, rmt::Sycl>::copy_to_device(const data<schema::Array<T,D>, encode::Native, Storage>& host_data, device<rmt::Sycl>& dev){ + data<schema::Array<T,D>, encode::Native, rmt::Sycl> device_data{host_data.size(), dev.get_handle()}; + return make_error<err::not_implemented>(); +} + } |