From df7574bf64b014e152d100a224d29ecdda32a7b2 Mon Sep 17 00:00:00 2001 From: Claudius 'keldu' Holeksa Date: Wed, 11 Sep 2024 13:43:18 +0200 Subject: Remote Hip example work --- modules/remote-hip/c++/data.hpp | 10 +- modules/remote-hip/c++/device.hpp | 17 ++ modules/remote-hip/c++/device.tmpl.hpp | 25 +++ modules/remote-hip/c++/remote.hpp | 24 +-- modules/remote-hip/c++/transfer.hpp | 188 ++-------------------- modules/remote-hip/examples/hip_transfer_data.cpp | 13 +- 6 files changed, 75 insertions(+), 202 deletions(-) create mode 100644 modules/remote-hip/c++/device.tmpl.hpp diff --git a/modules/remote-hip/c++/data.hpp b/modules/remote-hip/c++/data.hpp index 5e8e6f9..3e7c3ed 100644 --- a/modules/remote-hip/c++/data.hpp +++ b/modules/remote-hip/c++/data.hpp @@ -11,14 +11,14 @@ namespace saw { template class data> { private: - data data_; + data* data_; public: - data(const data& data__): - data_{data__} + data(): + data_{nullptr} {} - ref> get_data() { - return {data_}; + data** get_device_data() { + return &data_; } }; } diff --git a/modules/remote-hip/c++/device.hpp b/modules/remote-hip/c++/device.hpp index 227ed1b..f760024 100644 --- a/modules/remote-hip/c++/device.hpp +++ b/modules/remote-hip/c++/device.hpp @@ -2,7 +2,10 @@ #include "common.hpp" + +#include "device.tmpl.hpp" namespace saw { + /** * Represents a remote Sycl device. */ @@ -14,6 +17,20 @@ public: SAW_FORBID_COPY(device); SAW_FORBID_MOVE(device); + + template + error_or copy_to_device(data& from, data>& to){ + + auto dev_data = to.get_device_data(); + + auto eov = impl::hip_copy_to_device::apply(from, dev_data); + return eov; + } + + template + error_or copy_to_host(data>& from, data& to){ + return make_error(); + } }; } diff --git a/modules/remote-hip/c++/device.tmpl.hpp b/modules/remote-hip/c++/device.tmpl.hpp new file mode 100644 index 0000000..4777660 --- /dev/null +++ b/modules/remote-hip/c++/device.tmpl.hpp @@ -0,0 +1,25 @@ +namespace saw { +namespace impl { +template +struct hip_copy_to_device { + static error_or apply(data& from, data** to){ + static_assert(always_false, "Unsupported case."); + return make_void(); + } +}; + +template +struct hip_copy_to_device, Encoding> { + using Schema = schema::Primitive; + static error_or apply(data& from, data** to){ + hipError_t malloc_err = hipMalloc(to, sizeof(data)); + // HIP_CHECK(malloc_err); + + hipError_t copy_err = hipMemcpy(*to, &from, sizeof(data), hipMemcpyHostToDevice); + // HIP_CHECK(copy_err); + + return make_void(); + } +}; +} +} diff --git a/modules/remote-hip/c++/remote.hpp b/modules/remote-hip/c++/remote.hpp index 794d629..242c06d 100644 --- a/modules/remote-hip/c++/remote.hpp +++ b/modules/remote-hip/c++/remote.hpp @@ -62,6 +62,13 @@ public: conveyor>> resolve_address(uint64_t dev_id = 0u){ return heap>(dev_id); } + + /** + * Parse address, but don't resolve it. + */ + error_or>> parse_address(uint64_t dev_id = 0u){ + return heap>(dev_id); + } /** * Info. @@ -96,13 +103,6 @@ public: return sstr.str(); } - /** - * Parse address, but don't resolve it. - */ - error_or>> parse_address(uint64_t dev_id = 0u){ - return heap>(dev_id); - } - /** * Spin up data server */ @@ -115,16 +115,6 @@ public: } return heap>(ins.first->second); } - - /** - * Spin up a rpc server - */ - template - rpc_server listen(remote_address& dev, typename rpc_server::InterfaceT iface){ - //using RpcServerT = rpc_server; - //using InterfaceT = typename RpcServerT::InterfaceT; - return {share>(), std::move(iface)}; - } }; } diff --git a/modules/remote-hip/c++/transfer.hpp b/modules/remote-hip/c++/transfer.hpp index cdde8ba..d0ece27 100644 --- a/modules/remote-hip/c++/transfer.hpp +++ b/modules/remote-hip/c++/transfer.hpp @@ -28,14 +28,20 @@ public: device_{std::move(device__)} {} - error_or send(const data& dat, id store_id){ + error_or send(data& dat, id store_id){ + data> hip_dat; + { + auto eov = device_->copy_to_device(dat, hip_dat); + if(eov.is_error()){ + return eov; + } + } - auto ins = values_.emplace(std::make_pair(store_id.get_value(), data>{dat})); + auto ins = values_.emplace(std::make_pair(store_id.get_value(), hip_dat)); if(!ins.second){ return make_error(); } - return make_error("Allocate not implemented. Since we don't actually do any device copies."); return make_void(); } @@ -63,180 +69,4 @@ public: } }; -template -class data_server, Encoding, rmt::Hip> { -private: - /** - * Device context class - */ - our> device_; - - /** - * Store for the data the server manages. - */ - typename impl::data_server_redux, typename tmpl_reduce>::type >::type values_; -public: - /** - * Main constructor - */ - data_server(our> device__): - device_{std::move(device__)} - {} - - /** - * Get data which we will store. - */ - template - error_or send(const data& dat, id store_id){ - return make_error(); - /* - auto& vals = std::get>>>(values_); - auto eoval = device_->template copy_to_device(dat); - if(eoval.is_error()){ - auto& err = eoval.get_error(); - return std::move(err); - } - auto& val = eoval.get_value(); - try { - auto insert_res = vals.insert(std::make_pair(store_id.get_value(), std::move(val))); - if(!insert_res.second){ - return make_error(); - } - }catch ( std::exception& ){ - return make_error(); - } - return void_t{}; - */ - } - - template - error_or allocate(const data::MetaSchema, Encoding>& dat, id store_id){ - return make_error(); - /* - auto& vals = std::get>>>(values_); - auto eoval = device_->template allocate_on_device(dat); - if(eoval.is_error()){ - auto& err = eoval.get_error(); - return std::move(err); - } - auto& val = eoval.get_value(); - try { - auto insert_res = vals.insert(std::make_pair(store_id.get_value(), std::move(val))); - if(!insert_res.second){ - return make_error(); - } - }catch ( std::exception& ){ - return make_error(); - } - return void_t{}; - */ - } - - /** - * Requests data from the server - */ - template - error_or> receive(id store_id){ - auto& vals = std::get>>>(values_); - auto find_res = vals.find(store_id.get_value()); - if(find_res == vals.end()){ - return make_error(); - } - auto& dat = find_res->second; - - return make_error(); - } - - /** - * Request an erase of the stored data - */ - template - error_or erase(id store_id){ - auto& vals = std::get>>(values_); - auto erase_op = vals.erase(store_id.get_value()); - if(erase_op == 0u){ - return make_error(); - } - return void_t{}; - } - - /** - * Get the stored data on the server side for immediate use. - * Insert operations may invalidate the pointer. - */ - template - error_or>*> find(id store_id){ - auto& vals = std::get>>(values_); - auto find_res = vals.find(store_id.get_value()); - if(find_res == vals.end()){ - return make_error(); - } - - return &(find_res.second); - } -}; - -/** - * Client for transporting data to remote and receiving data back - */ -template -class data_client, Encoding, rmt::Hip> { -private: - /** - * Corresponding server for this client - */ - data_server, Encoding, rmt::Hip>* srv_; - - /** - * The next id for identifying issues on the remote side. - */ - uint64_t next_id_; -public: - /** - * Main constructor - */ - data_client(data_server, Encoding, rmt::Hip>& srv__): - srv_{&srv__}, - next_id_{0u} - {} - - /** - * Send data to the remote. - */ - template - error_or> send(const data& dat){ - id dat_id{next_id_}; - auto eov = srv_->send(dat, dat_id); - if(eov.is_error()){ - auto& err = eov.get_error(); - return std::move(err); - } - - ++next_id_; - return dat_id; - } - - /** - * Receive data - */ - template - conveyor> receive(id dat_id){ - auto eov = srv_->receive(dat_id); - if(eov.is_error()){ - auto& err = eov.get_error(); - return std::move(err); - } - - auto& val = eov.get_value(); - return std::move(val); - } - - /** - * Erase data - */ - template - error_or erase(id dat_id){ - return srv_->erase(dat_id); - } -}; } diff --git a/modules/remote-hip/examples/hip_transfer_data.cpp b/modules/remote-hip/examples/hip_transfer_data.cpp index 49ff856..ae530bd 100644 --- a/modules/remote-hip/examples/hip_transfer_data.cpp +++ b/modules/remote-hip/examples/hip_transfer_data.cpp @@ -3,6 +3,10 @@ #include +__global__ print_value(int16_t val){ + printf("Hello world: %d", val); +} + namespace sch { using namespace saw::schema; } @@ -25,13 +29,20 @@ saw::error_or real_main(){ auto& dat_srv = eo_dat_srv.get_value(); data val{42}; - id id_val{0u}; auto eo_send = dat_srv->send(val, id_val); if(eo_send.is_error()){ return std::move(eo_send.get_error()); } + auto eo_dfind = dat_srv->find(id_val); + if(eo_dfind.is_error()){ + return std::move(eo_dfind.get_error()); + } + auto dfind = eo_dfind.get_value(); + + print_value<<>>(dfind()); + return make_void(); } -- cgit v1.2.3