diff options
author | Claudius "keldu" Holeksa <mail@keldu.de> | 2024-06-21 19:44:34 +0200 |
---|---|---|
committer | Claudius "keldu" Holeksa <mail@keldu.de> | 2024-06-21 19:44:34 +0200 |
commit | 86b06a3fee2cd7635a9ab486e2a35bdf1e81ce38 (patch) | |
tree | 5485b323cdce1c1347f1a20c7f33e8f772c73dbf | |
parent | 601113a445658d8b15273dd91c66cf20daf50d30 (diff) |
Moving forward with basic test for sycl
-rw-r--r-- | default.nix | 2 | ||||
-rw-r--r-- | modules/codec/c++/interface.hpp | 22 | ||||
-rw-r--r-- | modules/codec/c++/schema.hpp | 6 | ||||
-rw-r--r-- | modules/codec/tests/codec.cpp | 18 | ||||
-rw-r--r-- | modules/io_codec/c++/transfer.hpp | 9 | ||||
-rw-r--r-- | modules/remote-sycl/.nix/derivation.nix | 2 | ||||
-rw-r--r-- | modules/remote-sycl/c++/remote.hpp | 71 | ||||
-rw-r--r-- | modules/remote-sycl/tests/calculator.foo (renamed from modules/remote-sycl/tests/calculator.cpp) | 2 | ||||
-rw-r--r-- | modules/remote-sycl/tests/sycl_basics.cpp | 96 |
9 files changed, 179 insertions, 49 deletions
diff --git a/default.nix b/default.nix index 79f7777..681daf3 100644 --- a/default.nix +++ b/default.nix @@ -79,7 +79,7 @@ in rec { inherit clang-tools; openmp = pkgs.llvmPackages_15.openmp; - build_examples = "true"; + build_examples = "false"; }; tools = pkgs.callPackage modules/tools/.nix/derivation.nix { diff --git a/modules/codec/c++/interface.hpp b/modules/codec/c++/interface.hpp index 0f41f55..e1c9a12 100644 --- a/modules/codec/c++/interface.hpp +++ b/modules/codec/c++/interface.hpp @@ -12,28 +12,40 @@ template<typename SchemaFunc, typename Encode, typename Storage, typename Contex class function; namespace impl { +template<typename DataSchema, typename Encode, typename Storage> +struct FuncReturnTypeHelper { + using Type = data<DataSchema,Encode,Storage>; +}; + +template<typename Encode, typename Storage> +struct FuncReturnTypeHelper<schema::Void, Encode, Storage> { + using Type = void; +}; + template<typename Request, typename Response, typename Encode, typename Storage, typename Ctx> struct FuncTypeHelper { - using Type = std::function<data<Response, Encode, Storage>(data<Request, Encode, Storage>&, Ctx)>; + using Type = std::function<error_or<typename FuncReturnTypeHelper<Response,Encode,Storage>::Type>(data<Request, Encode, Storage>&, Ctx)>; }; template<typename Request, typename Response, typename Encode, typename Storage> struct FuncTypeHelper<Request, Response, Encode, Storage, void_t> { - using Type = std::function<data<Response, Encode, Storage>(data<Request, Encode, Storage>&)>; + using Type = std::function<error_or<typename FuncReturnTypeHelper<Response, Encode, Storage>::Type>(data<Request, Encode, Storage>&)>; }; + } template<typename Request, typename Response, typename Encode, typename Storage, typename Context> class function<schema::Function<Request, Response>, Encode, Storage, Context> { private: typename impl::FuncTypeHelper<Request, Response, Encode, Storage, Context>::Type func_; + using ResponseDataType = typename impl::FuncReturnTypeHelper<Response, Encode, Storage>::Type; public: template<typename Func> function(Func func): func_{std::move(func)} {} - error_or<data<Response, Encode, Storage>> call(data<Request, Encode, Storage>& req, Context ctx = {}){ + error_or<ResponseDataType> call(data<Request, Encode, Storage>& req, Context ctx = {}){ if constexpr (std::is_same_v<Context, void_t>){ (void) ctx; return func_(req); @@ -84,13 +96,13 @@ public: template<string_literal Lit> error_or< - data< + typename impl::FuncReturnTypeHelper< typename parameter_pack_type< parameter_key_pack_index< Lit, Names... >::value , Responses...>::type - , Encode, Storage>> call( + , Encode, Storage>::Type > call( data< typename parameter_pack_type< parameter_key_pack_index< diff --git a/modules/codec/c++/schema.hpp b/modules/codec/c++/schema.hpp index 2ef7c77..6a69425 100644 --- a/modules/codec/c++/schema.hpp +++ b/modules/codec/c++/schema.hpp @@ -11,6 +11,12 @@ template <class T> struct is_primitive { namespace schema { // NOLINTBEGIN + +/** + * Void Type used for function schemas + */ +struct Void {}; + template <typename T, string_literal Literal> struct Member { static constexpr string_literal name = "Member"; diff --git a/modules/codec/tests/codec.cpp b/modules/codec/tests/codec.cpp index 720b734..1bec214 100644 --- a/modules/codec/tests/codec.cpp +++ b/modules/codec/tests/codec.cpp @@ -40,12 +40,15 @@ using TestInt32Pair = Tuple< Int32 >; +using TestVoidReturnFunction = Function<Int32, Void>; + using TestCalcFunction = Function<TestInt32Pair, Int32>; using TestInterface = Interface< Member<TestCalcFunction, "add">, Member<TestCalcFunction, "sub">, - Member<TestCalcFunction, "multiply"> + Member<TestCalcFunction, "multiply">, + Member<TestVoidReturnFunction, "void"> >; } SAW_TEST("One Dimensional Array") { @@ -378,7 +381,7 @@ SAW_TEST("Interface basics"){ data<schema::TestInt32Pair, encode::Native, storage::Default> native; auto func_add = - [](data<schema::TestInt32Pair, encode::Native, storage::Default> req){ + [](data<schema::TestInt32Pair, encode::Native, storage::Default>& req){ data<schema::Int32, encode::Native, storage::Default> resp; resp.set(req.get<0>().get() + req.get<1>().get()); @@ -386,14 +389,14 @@ SAW_TEST("Interface basics"){ return resp; }; auto func_sub = - [](data<schema::TestInt32Pair, encode::Native, storage::Default> req){ + [](data<schema::TestInt32Pair, encode::Native, storage::Default>& req){ data<schema::Int32, encode::Native, storage::Default> resp; resp.set(req.get<0>().get() - req.get<1>().get()); return resp; }; - auto func_multiply = [](data<schema::TestInt32Pair, encode::Native, storage::Default> req){ + auto func_multiply = [](data<schema::TestInt32Pair, encode::Native, storage::Default>& req){ data<schema::Int32, encode::Native, storage::Default> resp; resp.set(req.get<0>().get() * req.get<1>().get()); @@ -401,7 +404,12 @@ SAW_TEST("Interface basics"){ return resp; }; - auto iface = interface_factory<schema::TestInterface, encode::Native, storage::Default>::create(std::move(func_add), std::move(func_sub), std::move(func_multiply)); + auto func_void = [](data<schema::Int32>& req) -> error_or<void> { + (void) req; + return void_t{}; + }; + + auto iface = interface_factory<schema::TestInterface, encode::Native, storage::Default>::create(std::move(func_add), std::move(func_sub), std::move(func_multiply), std::move(func_void)); { data<schema::TestInt32Pair, encode::Native, storage::Default> native; diff --git a/modules/io_codec/c++/transfer.hpp b/modules/io_codec/c++/transfer.hpp new file mode 100644 index 0000000..b6aa977 --- /dev/null +++ b/modules/io_codec/c++/transfer.hpp @@ -0,0 +1,9 @@ +#pragma once + +namespace saw { +template<typename T> +class data_client; + +template<typename T> +class data_server; +} diff --git a/modules/remote-sycl/.nix/derivation.nix b/modules/remote-sycl/.nix/derivation.nix index 488b8a8..2247ec0 100644 --- a/modules/remote-sycl/.nix/derivation.nix +++ b/modules/remote-sycl/.nix/derivation.nix @@ -49,7 +49,7 @@ in stdenv.mkDerivation { scons prefix=$out build_examples=${build_examples} install ''; - doCheck = false; + doCheck = true; checkPhase = '' scons test ./bin/tests diff --git a/modules/remote-sycl/c++/remote.hpp b/modules/remote-sycl/c++/remote.hpp index 54b7a7b..24756be 100644 --- a/modules/remote-sycl/c++/remote.hpp +++ b/modules/remote-sycl/c++/remote.hpp @@ -11,11 +11,30 @@ class remote<rmt::Sycl>; template<typename T> class device; +template<typename Schema> +class data<Schema, encode::Native, rmt::Sycl> { +private: + cl::sycl::buffer<data<Schema, encode::Native, storage::Default>> data_; +public: + data(data<Schema, encode::Native, storage::Default>& data__): + data_{&data__, 1u} + {} + + auto& get_handle() { + return data_; + } + + template<cl::sycl::access::mode AccessMode> + auto access(cl::sycl::handler& h){ + return data_.template get_access<AccessMode>(h); + } +}; + /** * Remote data class for the Sycl backend. */ template<typename T, typename Encoding, typename Storage> -class remote_data<T, Encoding, Storage, rmt::Sycl> { +class remote_data<T, Encoding, Storage, rmt::Sycl> final { private: /** * An identifier to the data being held on the remote @@ -30,20 +49,15 @@ public: /** * Main constructor */ - remote_data(data<T,Encoding,Storage>& remote_data__, cl::sycl::queue& queue__): - remote_data_{&remote_data__}, + remote_data(id<T> data_id__, cl::sycl::queue& queue__): + data_id_{data_id__}, queue_{&queue__} {} /** * Destructor specifically designed to deallocate on the device. */ - ~remote_data(){ - if(remote_data_){ - cl::sycl::free(remote_data_,queue_); - remote_data_ = nullptr; - } - } + ~remote_data(){} SAW_FORBID_COPY(remote_data); SAW_FORBID_MOVE(remote_data); @@ -82,7 +96,7 @@ private: /** * The actual data */ - data<Schema,Encoding,Storage>* device_data_; + data<Schema,Encoding,storage::Default>* device_data_; /** * The sycl queue object */ @@ -91,7 +105,7 @@ public: /** * Main constructor */ - device_data(data<Schema,Encoding,Storage>& device_data__, cl::sycl::queue& queue__): + device_data(data<Schema,Encoding,storage::Default>& device_data__, cl::sycl::queue& queue__): device_data_{&device_data__}, queue_{&queue__} {} @@ -111,6 +125,11 @@ public: }; namespace impl { +template<typename Schema, typename Encoding, typename Backend> +struct device_id_map { + std::vector<device_data<Schema, Encoding, Backend>> data; +}; + template<typename Iface, typename Encoding, typename Storage> struct rpc_id_map_helper { static_assert(always_false<Iface, Encoding,Storage>, "Only supports Interface schema types."); @@ -122,9 +141,12 @@ struct rpc_id_map_helper<schema::Interface<Members...>, Encoding, Storage> { }; } +} +// Maybe a helper impl tmpl file? +namespace saw { + /** * Represents a remote Sycl device. - * */ template<> class device<rmt::Sycl> final { @@ -161,11 +183,6 @@ public: }; /** - * Device data transport - */ - - -/** * Rpc Client class for the Sycl backend. */ template<typename Iface, typename Encoding, typename Storage> @@ -222,7 +239,7 @@ private: /** * Basic storage for response data. */ - // impl::rpc_id_map_helper<Iface, Encoding, rmt::Sycl> storage_; + impl::rpc_id_map_helper<Iface, Encoding, rmt::Sycl> storage_; public: /** @@ -369,22 +386,4 @@ public: } }; -template<typename T, uint64_t D> -template<typename Storage> -error_or<data<schema::Array<T,D>, encode::Native, rmt::Sycl>> data<schema::Array<T,D>, encode::Native, rmt::Sycl>::copy_to_device(const data<schema::Array<T,D>, encode::Native, Storage>& host_data, device<rmt::Sycl>& dev){ - /** - * Retrieve handle - */ - auto& cmd_handle = dev.get_handle(); - - uint64_t* dev_len = cl::sycl::malloc_device<uint64_t>(1u, cmd_handle); - uint64_t len = host_data.size(); - cmd_handle.template copy<uint64_t>(&len,dev_len, 1u); - - auto dev_dat = cl::sycl::malloc_device<data<T,encode::Native,storage::Default>>(host_data.size(), cmd_handle); - cmd_handle.copy(&host_data.at(0), dev_dat, host_data.size()); - cmd_handle.wait(); - - return data<schema::Array<T,D>,encode::Native, rmt::Sycl>{dev_len, dev_dat, cmd_handle}; -} } diff --git a/modules/remote-sycl/tests/calculator.cpp b/modules/remote-sycl/tests/calculator.foo index 6d061ad..745bd3d 100644 --- a/modules/remote-sycl/tests/calculator.cpp +++ b/modules/remote-sycl/tests/calculator.foo @@ -15,7 +15,7 @@ using Calculator = Interface< >; } -SAW_TEST("Sycl Interface Calculator"){ +SAW_TEST("SYCL Interface Calculator"){ using namespace saw; cl::sycl::queue cmd_queue; diff --git a/modules/remote-sycl/tests/sycl_basics.cpp b/modules/remote-sycl/tests/sycl_basics.cpp new file mode 100644 index 0000000..bf41983 --- /dev/null +++ b/modules/remote-sycl/tests/sycl_basics.cpp @@ -0,0 +1,96 @@ +#include <forstio/test/suite.hpp> + +#include "../c++/remote.hpp" + +namespace { +namespace schema { +using namespace saw::schema; + +using TestStruct = Struct< + Member<UInt64, "foo">, + Member<Float64, "bar">, + Member<Array<Float64>, "doubles"> +>; + +using Foo = Interface< + Member<Function<TestStruct, Void>, "foo"> +>; + +using Calculator = Interface< + Member< + Function<Tuple<Int64, Int64>, Int64>, "add" + > +, Member< + Function<Tuple<Int64, Int64>, Int64>, "multiply" + > +>; +} +SAW_TEST("SYCL Test Setup"){ + using namespace saw; + + data<schema::TestStruct> host_data; + host_data.template get<"foo">() = 321; + host_data.template get<"bar">() = 50.0; + host_data.template get<"doubles">() = {1024u}; + + saw::event_loop loop; + saw::wait_scope wait{loop}; + + remote<rmt::Sycl> rmt; + saw::own<saw::remote_address<saw::rmt::Sycl>> rmt_addr{}; + + rmt.resolve_address().then([&](auto addr){ + rmt_addr = std::move(addr); + }).detach(); + + wait.poll(); + SAW_EXPECT(rmt_addr, "Remote Address class hasn't been filled"); + + auto device = rmt.connect_device(*rmt_addr); + + data<schema::TestStruct, encode::Native, rmt::Sycl> device_data{host_data}; + + interface<schema::Foo, encode::Native,rmt::Sycl, cl::sycl::queue*> cl_iface { +[&](data<schema::TestStruct, encode::Native, rmt::Sycl>& in, cl::sycl::queue* cmd) -> error_or<void> { + + cmd->submit([&](cl::sycl::handler& h){ + + auto acc_buff = in.template access<cl::sycl::access::mode::write>(h); + + uint64_t si = host_data.template get<"doubles">().size(); + + h.parallel_for(cl::sycl::range<1>(si), [=] (cl::sycl::id<1> it){ + acc_buff[0u].template get<"foo">() = acc_buff[0u].template get<"doubles">().size(); + auto& dbls = acc_buff[0u].template get<"doubles">(); + dbls.at(it[0u]) = it[0u] * 2.0; + }); + }); + /* + cmd->submit([&](cl::sycl::handler& h){ + auto acc_buff = in.template access<cl::sycl::access::mode::read>(h); + h.copy(acc_buff, &host_data); + }); + */ + + /** + cl::sycl::host_accessor result{in.get_handle()}; + std::cout<<result[0u].template get<"foo">().get()<<std::endl; + std::cout<<result[0u].template get<"bar">().get()<<std::endl; + **/ + return saw::void_t{}; + } + }; + + std::cout<<"Running on:\n"<<device.get_handle().get_device().get_info<cl::sycl::info::device::name>()<<std::endl; + + std::cout<<host_data.template get<"foo">().get()<<std::endl; + std::cout<<host_data.template get<"bar">().get()<<std::endl; + std::cout<<std::endl; + cl_iface.template call <"foo">(device_data, &device.get_handle()); + device.get_handle().wait(); + std::cout<<host_data.template get<"foo">().get()<<std::endl; + std::cout<<host_data.template get<"bar">().get()<<std::endl; + auto& dbls = host_data.template get<"doubles">(); + std::cout<<std::endl; +} +} |