From 80b706332a48f54ae289093ee11b17f20ab2dc2e Mon Sep 17 00:00:00 2001 From: "Claudius \"keldu\" Holeksa" Date: Tue, 18 Jun 2024 16:35:39 +0200 Subject: Changed some setup designs --- modules/remote-sycl/c++/remote.hpp | 93 ++++++++++++++++++++------------------ 1 file changed, 50 insertions(+), 43 deletions(-) (limited to 'modules/remote-sycl/c++/remote.hpp') diff --git a/modules/remote-sycl/c++/remote.hpp b/modules/remote-sycl/c++/remote.hpp index 4a383d1..1873669 100644 --- a/modules/remote-sycl/c++/remote.hpp +++ b/modules/remote-sycl/c++/remote.hpp @@ -1,19 +1,16 @@ #pragma once -#include -#include -#include - -#include +#include "common.hpp" +#include "data.hpp" namespace saw { -namespace rmt { -struct Sycl {}; -} template<> class remote; +template +class device; + /** * Remote data class for the Sycl backend. */ @@ -65,32 +62,24 @@ public: }; /** - * -template -class data, encode::Native, rmt::Sycl> { -public: - using Schema = schema::Primitive; - using NativeType = typename native_data_type::type; -private: - NativeType val_; -public: - data(NativeType val__): - val_{val__} - {} - - NativeType get(){ - return val_; - } -}; + * Sycl data class for handling the array Schema. */ - template class data, encode::Native, rmt::Sycl> { public: using Schema = schema::Array; private: + /** + * Absolute size of the stored elementes. + */ uint64_t total_length_; + /** + * The data itself. + */ data* device_data_; + /** + * Referenced sycl queue + */ cl::sycl::queue* queue_; static_assert(is_primitive::value, "Only supports primitives for now"); @@ -181,6 +170,10 @@ public: return data_; } + template + static error_or> copy_to_device(const data& host_data, device& dev); + + data& at(uint64_t i){ return device_data_[i]; } @@ -202,34 +195,41 @@ struct rpc_id_map_helper, Encoding, Storage> { }; } -template -class device; - /** * Represents a remote Sycl device. * */ template<> -class device { +class device final { private: cl::sycl::queue cmd_queue_; public: + device() = default; + + SAW_FORBID_COPY(device); + SAW_FORBID_MOVE(device); + /** * Copy data to device */ - template - error_or> copy_to_device(const data>& host_data){ - (void) host_data; - return make_error(); + template + error_or> copy_to_device(const data& host_data){ + return data::copy_to_device(host_data); } /** * Copy data to host */ - template - error_or> copy_to_host(const data>& device_data){ - (void) device_data; - return make_error(); + template + error_or> copy_to_host(const data& device_data){ + return device_data.copy_to_host(); + } + + /** + * Get a reference to the handle + */ + cl::sycl::queue& get_handle(){ + return cmd_queue_; } }; @@ -317,7 +317,7 @@ public: template remote_data request_data(id dat_id){ - return {dat_id, std::get>(storage_.maps), device_.cmd_queue_}; + return {dat_id, std::get>(storage_.maps), device_->get_handle()}; } /** @@ -350,8 +350,8 @@ public: return eov.get_value(); } else { auto& client_data = input.get_data(); - dev_tmp_inp = heap>(client_data, cmd_queue_); - cmd_queue_.wait(); + dev_tmp_inp = heap>(client_data, device_->get_handle()); + device_->get_handle().wait(); return dev_tmp_inp.get(); } }(); @@ -360,7 +360,7 @@ public: } auto& inp = *(eoinp.get_value()); - auto eod = cl_interface_.template call(inp, &cmd_queue_); + auto eod = cl_interface_.template call(inp, &(device_->get_handle())); if(eod.is_error()){ return std::move(eod.get_error()); @@ -424,11 +424,18 @@ public: * Spin up a rpc server */ template - rpc_server listen(const device& dev, typename rpc_server::InterfaceT iface){ + rpc_server listen(device& dev, typename rpc_server::InterfaceT iface){ using RpcServerT = rpc_server; using InterfaceT = typename RpcServerT::InterfaceT; return {dev, std::move(iface)}; } }; +template +template +error_or, encode::Native, rmt::Sycl>> data, encode::Native, rmt::Sycl>::copy_to_device(const data, encode::Native, Storage>& host_data, device& dev){ + data, encode::Native, rmt::Sycl> device_data{host_data.size(), dev.get_handle()}; + return make_error(); +} + } -- cgit v1.2.3