summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--modules/remote-sycl/benchmarks/mixed_precision.cpp12
-rw-r--r--modules/remote-sycl/benchmarks/mixed_precision.hpp1
-rw-r--r--modules/remote-sycl/c++/data.hpp6
-rw-r--r--modules/remote-sycl/c++/device.hpp16
-rw-r--r--modules/remote-sycl/c++/rpc.hpp22
5 files changed, 29 insertions, 28 deletions
diff --git a/modules/remote-sycl/benchmarks/mixed_precision.cpp b/modules/remote-sycl/benchmarks/mixed_precision.cpp
index e804f4e..8c6d4c9 100644
--- a/modules/remote-sycl/benchmarks/mixed_precision.cpp
+++ b/modules/remote-sycl/benchmarks/mixed_precision.cpp
@@ -119,9 +119,9 @@ int main(int argc, char** argv){
float64_host_data.at(i) = static_cast<double>(gen_num);
float32_host_data.at(i) = static_cast<float>(gen_num);
}
- data<sch::MixedArray, encode::Native, rmt::Sycl> mixed_device_data{mixed_host_data};
- data<sch::Float64Array, encode::Native, rmt::Sycl> float64_device_data{float64_host_data};
- data<sch::Float32Array, encode::Native, rmt::Sycl> float32_device_data{float32_host_data};
+ data<sch::MixedArray, encode::Sycl<encode::Native>> mixed_device_data{mixed_host_data};
+ data<sch::Float64Array, encode::Sycl<encode::Native>> float64_device_data{float64_host_data};
+ data<sch::Float32Array, encode::Sycl<encode::Native>> float32_device_data{float32_host_data};
sycl_iface.template call<"float64_32">(mixed_device_data, &(device.get_handle()));
sycl_iface.template call<"float64">(float64_device_data, &(device.get_handle()));
@@ -157,9 +157,9 @@ int main(int argc, char** argv){
float32_host_data.at(i) = static_cast<float>(gen_num);
}
- data<sch::MixedArray, encode::Native, rmt::Sycl> mixed_device_data{mixed_host_data};
- data<sch::Float64Array, encode::Native, rmt::Sycl> float64_device_data{float64_host_data};
- data<sch::Float32Array, encode::Native, rmt::Sycl> float32_device_data{float32_host_data};
+ data<sch::MixedArray, encode::Sycl<encode::Native>> mixed_device_data{mixed_host_data};
+ data<sch::Float64Array, encode::Sycl<encode::Native>> float64_device_data{float64_host_data};
+ data<sch::Float32Array, encode::Sycl<encode::Native>> float32_device_data{float32_host_data};
sycl_iface.template call<"float64_32">(mixed_device_data, &(device.get_handle()));
device.get_handle().wait();
diff --git a/modules/remote-sycl/benchmarks/mixed_precision.hpp b/modules/remote-sycl/benchmarks/mixed_precision.hpp
index 784b9b5..cd8f9ec 100644
--- a/modules/remote-sycl/benchmarks/mixed_precision.hpp
+++ b/modules/remote-sycl/benchmarks/mixed_precision.hpp
@@ -1,5 +1,6 @@
#pragma once
+#include "../c++/device.hpp"
#include "../c++/remote.hpp"
namespace sch {
diff --git a/modules/remote-sycl/c++/data.hpp b/modules/remote-sycl/c++/data.hpp
index d939a53..7481d53 100644
--- a/modules/remote-sycl/c++/data.hpp
+++ b/modules/remote-sycl/c++/data.hpp
@@ -9,12 +9,12 @@ namespace saw {
* Most of the times this will be a root object.
*/
template<typename Schema>
-class data<Schema, encode::Native, rmt::Sycl> {
+class data<Schema, encode::Sycl<encode::Native>> {
private:
- cl::sycl::buffer<data<Schema, encode::Native, storage::Default>> data_;
+ cl::sycl::buffer<data<Schema, encode::Native>> data_;
uint64_t size_;
public:
- data(const data<Schema, encode::Native, storage::Default>& data__):
+ data(const data<Schema, encode::Native>& data__):
data_{&data__, 1u},
size_{data__.size()}
{}
diff --git a/modules/remote-sycl/c++/device.hpp b/modules/remote-sycl/c++/device.hpp
index 3561da7..6d4dbbf 100644
--- a/modules/remote-sycl/c++/device.hpp
+++ b/modules/remote-sycl/c++/device.hpp
@@ -21,21 +21,21 @@ public:
/**
* Copy data to device
*/
- template<typename Schema, typename Encoding, typename Storage>
- error_or<data<Schema, Encoding, rmt::Sycl>> copy_to_device(const data<Schema, Encoding, Storage>& host_data){
- return data<Schema, Encoding, rmt::Sycl>{host_data};
+ template<typename Schema, typename Encoding>
+ error_or<data<Schema, encode::Sycl<Encoding>>> copy_to_device(const data<Schema, Encoding>& host_data){
+ return data<Schema, encode::Sycl<Encoding>>{host_data};
}
- template<typename Schema, typename Encoding, typename Storage>
- error_or<data<Schema, Encoding, rmt::Sycl>> allocate_on_device(const data<typename meta_schema<Schema>::MetaSchema, Encoding, Storage>& host_meta){
- return copy_to_device(data<Schema, Encoding, Storage>{host_meta});
+ template<typename Schema, typename Encoding>
+ error_or<data<Schema, encode::Sycl<Encoding>>> allocate_on_device(const data<typename meta_schema<Schema>::MetaSchema, Encoding>& host_meta){
+ return copy_to_device(data<Schema, Encoding>{host_meta});
}
/**
* Copy data to host
*/
- template<typename Schema, typename Encoding, typename Storage>
- error_or<data<Schema, Encoding, Storage>> copy_to_host(data<Schema, Encoding, rmt::Sycl>& dev_data){
+ template<typename Schema, typename Encoding>
+ error_or<data<Schema, Encoding>> copy_to_host(data<Schema, encode::Sycl<Encoding>>& dev_data){
/**
data<Schema,Encoding, Storage> host_data;
cmd_queue_.submit([&](cl::sycl::handler& h){
diff --git a/modules/remote-sycl/c++/rpc.hpp b/modules/remote-sycl/c++/rpc.hpp
index 780f7a0..65e2df5 100644
--- a/modules/remote-sycl/c++/rpc.hpp
+++ b/modules/remote-sycl/c++/rpc.hpp
@@ -118,14 +118,14 @@ struct rpc_iface_type_helper<schema::Interface<schema::Member<Func,K>,schema::Me
/**
* Rpc Client class for the Sycl backend.
*/
-template<typename Iface, typename Encoding, typename Storage>
-class rpc_client<Iface, Encoding, Storage, rmt::Sycl> {
+template<typename Iface, typename Encoding>
+class rpc_client<Iface, Encoding, rmt::Sycl> {
public:
private:
/**
* Server this client is tied to
*/
- rpc_server<Iface, Encoding, Storage, rmt::Sycl>* srv_;
+ rpc_server<Iface, Encoding, rmt::Sycl>* srv_;
/**
* TransferClient created from the internal RPC data server
@@ -136,7 +136,7 @@ private:
* Generated some sort of id for the request.
*/
public:
- rpc_client(rpc_server<Iface, Encoding, Storage, rmt::Sycl>& srv):
+ rpc_client(rpc_server<Iface, Encoding, rmt::Sycl>& srv):
srv_{&srv},
data_client_{srv_->data_server}
{}
@@ -149,7 +149,7 @@ public:
id<
typename schema_member_type<Name, Iface>::type::ResponseT
>
- > call(const data_or_id<typename schema_member_type<Name, Iface>::type::RequestT, Encoding, Storage>& input){
+ > call(const data_or_id<typename schema_member_type<Name, Iface>::type::RequestT, Encoding>& input){
auto next_free_id = srv_->template next_free_id<typename schema_member_type<Name, Iface>::type::ResponseT>();
return srv_->template call<Name, Storage>(input, next_free_id);
}
@@ -160,10 +160,10 @@ public:
* Rpc Server class for the Sycl backend.
*/
template<typename Iface, typename Encoding>
-class rpc_server<Iface, Encoding, storage::Default, rmt::Sycl> {
+class rpc_server<Iface, Encoding, rmt::Sycl> {
public:
using InterfaceCtxT = cl::sycl::queue*;
- using InterfaceT = interface<Iface, Encoding, storage::Default, InterfaceCtxT>;
+ using InterfaceT = interface<Iface, encode::Sycl<Encoding>, InterfaceCtxT>;
private:
/**
@@ -218,18 +218,18 @@ public:
id<
typename schema_member_type<Name, Iface>::type::ResponseT
>
- > call(data_or_id<typename schema_member_type<Name, Iface>::type::RequestT, Encoding, storage::Default> input, id<typename schema_member_type<Name,Iface>::type::ResponseT> rpc_id){
+ > call(data_or_id<typename schema_member_type<Name, Iface>::type::RequestT, Encoding> input, id<typename schema_member_type<Name,Iface>::type::ResponseT> rpc_id){
using FuncT = typename schema_member_type<Name, Iface>::type;
/**
* Object needed if and only if the provided data type is not an id
*/
- own<data<typename FuncT::RequestT, Encoding, rmt::Sycl>> dev_tmp_inp = nullptr;
+ own<data<typename FuncT::RequestT, encode::Sycl<Encoding>>> dev_tmp_inp = nullptr;
/**
* First check if it's data or an id.
* If it's an id, check if it's registered within the storage and retrieve it.
*/
- auto eoinp = [&,this]() -> error_or<data<typename FuncT::RequestT, Encoding, rmt::Sycl>* > {
+ auto eoinp = [&,this]() -> error_or<data<typename FuncT::RequestT, encode::Sycl<Encoding>>* > {
if(input.is_id()){
// storage_.maps
auto eov = data_server_->template find<typename FuncT::RequestT>(input.get_id());
@@ -246,7 +246,7 @@ public:
}
auto& val = eov.get_value();
- dev_tmp_inp = heap<data<typename FuncT::RequestT, Encoding, rmt::Sycl>>(std::move(val));
+ dev_tmp_inp = heap<data<typename FuncT::RequestT, encode::Sycl<Encoding>>>(std::move(val));
device_->get_handle().wait();
return dev_tmp_inp.get();
}