diff options
-rw-r--r-- | modules/remote-sycl/benchmarks/mixed_precision.cpp | 12 | ||||
-rw-r--r-- | modules/remote-sycl/benchmarks/mixed_precision.hpp | 1 | ||||
-rw-r--r-- | modules/remote-sycl/c++/data.hpp | 6 | ||||
-rw-r--r-- | modules/remote-sycl/c++/device.hpp | 16 | ||||
-rw-r--r-- | modules/remote-sycl/c++/rpc.hpp | 22 |
5 files changed, 29 insertions, 28 deletions
diff --git a/modules/remote-sycl/benchmarks/mixed_precision.cpp b/modules/remote-sycl/benchmarks/mixed_precision.cpp index e804f4e..8c6d4c9 100644 --- a/modules/remote-sycl/benchmarks/mixed_precision.cpp +++ b/modules/remote-sycl/benchmarks/mixed_precision.cpp @@ -119,9 +119,9 @@ int main(int argc, char** argv){ float64_host_data.at(i) = static_cast<double>(gen_num); float32_host_data.at(i) = static_cast<float>(gen_num); } - data<sch::MixedArray, encode::Native, rmt::Sycl> mixed_device_data{mixed_host_data}; - data<sch::Float64Array, encode::Native, rmt::Sycl> float64_device_data{float64_host_data}; - data<sch::Float32Array, encode::Native, rmt::Sycl> float32_device_data{float32_host_data}; + data<sch::MixedArray, encode::Sycl<encode::Native>> mixed_device_data{mixed_host_data}; + data<sch::Float64Array, encode::Sycl<encode::Native>> float64_device_data{float64_host_data}; + data<sch::Float32Array, encode::Sycl<encode::Native>> float32_device_data{float32_host_data}; sycl_iface.template call<"float64_32">(mixed_device_data, &(device.get_handle())); sycl_iface.template call<"float64">(float64_device_data, &(device.get_handle())); @@ -157,9 +157,9 @@ int main(int argc, char** argv){ float32_host_data.at(i) = static_cast<float>(gen_num); } - data<sch::MixedArray, encode::Native, rmt::Sycl> mixed_device_data{mixed_host_data}; - data<sch::Float64Array, encode::Native, rmt::Sycl> float64_device_data{float64_host_data}; - data<sch::Float32Array, encode::Native, rmt::Sycl> float32_device_data{float32_host_data}; + data<sch::MixedArray, encode::Sycl<encode::Native>> mixed_device_data{mixed_host_data}; + data<sch::Float64Array, encode::Sycl<encode::Native>> float64_device_data{float64_host_data}; + data<sch::Float32Array, encode::Sycl<encode::Native>> float32_device_data{float32_host_data}; sycl_iface.template call<"float64_32">(mixed_device_data, &(device.get_handle())); device.get_handle().wait(); diff --git a/modules/remote-sycl/benchmarks/mixed_precision.hpp b/modules/remote-sycl/benchmarks/mixed_precision.hpp index 784b9b5..cd8f9ec 100644 --- a/modules/remote-sycl/benchmarks/mixed_precision.hpp +++ b/modules/remote-sycl/benchmarks/mixed_precision.hpp @@ -1,5 +1,6 @@ #pragma once +#include "../c++/device.hpp" #include "../c++/remote.hpp" namespace sch { diff --git a/modules/remote-sycl/c++/data.hpp b/modules/remote-sycl/c++/data.hpp index d939a53..7481d53 100644 --- a/modules/remote-sycl/c++/data.hpp +++ b/modules/remote-sycl/c++/data.hpp @@ -9,12 +9,12 @@ namespace saw { * Most of the times this will be a root object. */ template<typename Schema> -class data<Schema, encode::Native, rmt::Sycl> { +class data<Schema, encode::Sycl<encode::Native>> { private: - cl::sycl::buffer<data<Schema, encode::Native, storage::Default>> data_; + cl::sycl::buffer<data<Schema, encode::Native>> data_; uint64_t size_; public: - data(const data<Schema, encode::Native, storage::Default>& data__): + data(const data<Schema, encode::Native>& data__): data_{&data__, 1u}, size_{data__.size()} {} diff --git a/modules/remote-sycl/c++/device.hpp b/modules/remote-sycl/c++/device.hpp index 3561da7..6d4dbbf 100644 --- a/modules/remote-sycl/c++/device.hpp +++ b/modules/remote-sycl/c++/device.hpp @@ -21,21 +21,21 @@ public: /** * Copy data to device */ - template<typename Schema, typename Encoding, typename Storage> - error_or<data<Schema, Encoding, rmt::Sycl>> copy_to_device(const data<Schema, Encoding, Storage>& host_data){ - return data<Schema, Encoding, rmt::Sycl>{host_data}; + template<typename Schema, typename Encoding> + error_or<data<Schema, encode::Sycl<Encoding>>> copy_to_device(const data<Schema, Encoding>& host_data){ + return data<Schema, encode::Sycl<Encoding>>{host_data}; } - template<typename Schema, typename Encoding, typename Storage> - error_or<data<Schema, Encoding, rmt::Sycl>> allocate_on_device(const data<typename meta_schema<Schema>::MetaSchema, Encoding, Storage>& host_meta){ - return copy_to_device(data<Schema, Encoding, Storage>{host_meta}); + template<typename Schema, typename Encoding> + error_or<data<Schema, encode::Sycl<Encoding>>> allocate_on_device(const data<typename meta_schema<Schema>::MetaSchema, Encoding>& host_meta){ + return copy_to_device(data<Schema, Encoding>{host_meta}); } /** * Copy data to host */ - template<typename Schema, typename Encoding, typename Storage> - error_or<data<Schema, Encoding, Storage>> copy_to_host(data<Schema, Encoding, rmt::Sycl>& dev_data){ + template<typename Schema, typename Encoding> + error_or<data<Schema, Encoding>> copy_to_host(data<Schema, encode::Sycl<Encoding>>& dev_data){ /** data<Schema,Encoding, Storage> host_data; cmd_queue_.submit([&](cl::sycl::handler& h){ diff --git a/modules/remote-sycl/c++/rpc.hpp b/modules/remote-sycl/c++/rpc.hpp index 780f7a0..65e2df5 100644 --- a/modules/remote-sycl/c++/rpc.hpp +++ b/modules/remote-sycl/c++/rpc.hpp @@ -118,14 +118,14 @@ struct rpc_iface_type_helper<schema::Interface<schema::Member<Func,K>,schema::Me /** * Rpc Client class for the Sycl backend. */ -template<typename Iface, typename Encoding, typename Storage> -class rpc_client<Iface, Encoding, Storage, rmt::Sycl> { +template<typename Iface, typename Encoding> +class rpc_client<Iface, Encoding, rmt::Sycl> { public: private: /** * Server this client is tied to */ - rpc_server<Iface, Encoding, Storage, rmt::Sycl>* srv_; + rpc_server<Iface, Encoding, rmt::Sycl>* srv_; /** * TransferClient created from the internal RPC data server @@ -136,7 +136,7 @@ private: * Generated some sort of id for the request. */ public: - rpc_client(rpc_server<Iface, Encoding, Storage, rmt::Sycl>& srv): + rpc_client(rpc_server<Iface, Encoding, rmt::Sycl>& srv): srv_{&srv}, data_client_{srv_->data_server} {} @@ -149,7 +149,7 @@ public: id< typename schema_member_type<Name, Iface>::type::ResponseT > - > call(const data_or_id<typename schema_member_type<Name, Iface>::type::RequestT, Encoding, Storage>& input){ + > call(const data_or_id<typename schema_member_type<Name, Iface>::type::RequestT, Encoding>& input){ auto next_free_id = srv_->template next_free_id<typename schema_member_type<Name, Iface>::type::ResponseT>(); return srv_->template call<Name, Storage>(input, next_free_id); } @@ -160,10 +160,10 @@ public: * Rpc Server class for the Sycl backend. */ template<typename Iface, typename Encoding> -class rpc_server<Iface, Encoding, storage::Default, rmt::Sycl> { +class rpc_server<Iface, Encoding, rmt::Sycl> { public: using InterfaceCtxT = cl::sycl::queue*; - using InterfaceT = interface<Iface, Encoding, storage::Default, InterfaceCtxT>; + using InterfaceT = interface<Iface, encode::Sycl<Encoding>, InterfaceCtxT>; private: /** @@ -218,18 +218,18 @@ public: id< typename schema_member_type<Name, Iface>::type::ResponseT > - > call(data_or_id<typename schema_member_type<Name, Iface>::type::RequestT, Encoding, storage::Default> input, id<typename schema_member_type<Name,Iface>::type::ResponseT> rpc_id){ + > call(data_or_id<typename schema_member_type<Name, Iface>::type::RequestT, Encoding> input, id<typename schema_member_type<Name,Iface>::type::ResponseT> rpc_id){ using FuncT = typename schema_member_type<Name, Iface>::type; /** * Object needed if and only if the provided data type is not an id */ - own<data<typename FuncT::RequestT, Encoding, rmt::Sycl>> dev_tmp_inp = nullptr; + own<data<typename FuncT::RequestT, encode::Sycl<Encoding>>> dev_tmp_inp = nullptr; /** * First check if it's data or an id. * If it's an id, check if it's registered within the storage and retrieve it. */ - auto eoinp = [&,this]() -> error_or<data<typename FuncT::RequestT, Encoding, rmt::Sycl>* > { + auto eoinp = [&,this]() -> error_or<data<typename FuncT::RequestT, encode::Sycl<Encoding>>* > { if(input.is_id()){ // storage_.maps auto eov = data_server_->template find<typename FuncT::RequestT>(input.get_id()); @@ -246,7 +246,7 @@ public: } auto& val = eov.get_value(); - dev_tmp_inp = heap<data<typename FuncT::RequestT, Encoding, rmt::Sycl>>(std::move(val)); + dev_tmp_inp = heap<data<typename FuncT::RequestT, encode::Sycl<Encoding>>>(std::move(val)); device_->get_handle().wait(); return dev_tmp_inp.get(); } |