diff options
Diffstat (limited to 'modules/remote-sycl/c++/remote.hpp')
-rw-r--r-- | modules/remote-sycl/c++/remote.hpp | 100 |
1 files changed, 98 insertions, 2 deletions
diff --git a/modules/remote-sycl/c++/remote.hpp b/modules/remote-sycl/c++/remote.hpp index dbbefcb..d311ca5 100644 --- a/modules/remote-sycl/c++/remote.hpp +++ b/modules/remote-sycl/c++/remote.hpp @@ -4,7 +4,7 @@ #include <forstio/codec/data.hpp> #include <forstio/codec/id_map.hpp> -#include <CL/sycl.hpp> +#include <AdaptiveCpp/CL/sycl.hpp> namespace saw { namespace rmt { @@ -37,13 +37,77 @@ public: conveyor<data<T,Encoding>> on_receive(); /// Stopped here }; +/** + * + */ +template<typename T, uint64_t N> +class data<schema::Primitive<T,N>, encode::Native<rmt::Sycl>> { +public: + using Schema = schema::Primitive<T,N>; + using NativeType = typename native_data_type<Schema>::type; +private: + /** + * + */ + NativeType val_; +public: + /** + * + */ + data(NativeType val__): + val_{val__} + {} + + NativeType get(){ + return val_; + } +}; + template<typename T, uint64_t D> class data<schema::Array<T,D>, encode::Native<rmt::Sycl>> { +public: + using Schema = schema::Array<T,D>; private: - cl::sycl::buffer<typename native_data_type<T>::type, D> device_data_; + uint64_t total_length_; + typename native_data_type<T>::type* device_data_; + // data<T>* device_data_; + cl::sycl::queue* queue_; + static_assert(is_primitive<T>::value, "Only supports primitives for now"); static_assert(D==1u, "For now we only support 1D Arrays"); public: + data(uint64_t size, cl::sycl::queue& q__): + total_length_{size}, + device_data_{cl::sycl::malloc_device<typename native_data_type<T>::type>(size, q__)}, + queue_{&q__} + { + if(!device_data_){ + total_length_ = 0u; + } + } + + template<typename Encoding> + data(const data<Schema, Encoding>& from, cl::sycl::queue& q__): + total_length_{from.size()}, + device_data_{cl::sycl::malloc_device<typename native_data_type<T>::type>(from.size(), q__)}, + queue_{&q__} + { + if(!device_data_){ + total_length_ = 0u; + } + } + + ~data(){ + // free data + if(device_data_){ + /// SYCL FREE + cl::sycl::free(device_data_, *queue_); + } + } + + data<T,encode::Native<rmt::Sycl>>& at(uint64_t i){ + return device_data_[i]; + } }; namespace impl { @@ -59,6 +123,36 @@ struct rpc_id_map_helper<schema::Interface<Members...>, Encoding> { }; } /** + * Rpc Client class for the Sycl backend. + */ +template<typename Iface, typename Encoding> +class rpc_client<Iface, Encoding, rmt::Sycl> { +public: +private: + /** + * Server this client is tied to + */ + rpc_server<Iface, Encoding, rmt::Sycl>* srv_; +public: + rpc_client(rpc_server<Iface, Encoding, rmt::Sycl>& srv): + srv_{&srv} + {} + + /** + * Rpc call + */ + template<string_literal Name, typename ClientEncoding> + error_or< + id< + typename schema_member_type<Name, Iface>::type::ResponseT + > + > call(data_or_id<typename schema_member_type<Name, Iface>::type::RequestT, ClientEncoding> input){ + return make_error<err::not_implemented>("RpcClient side is not implemented"); + } + +}; + +/** * Rpc Server class for the Sycl backend. */ template<typename Iface, typename Encoding> @@ -103,6 +197,8 @@ public: > > call(data_or_id<typename schema_member_type<Name, Iface>::type::RequestT, Encoding> input){ + + /** * First check if it's data or an id. * If it's an id, check if it's registered within the storage and retrieve it. |