diff options
Diffstat (limited to 'modules/remote-hip')
-rw-r--r-- | modules/remote-hip/c++/data.hpp | 10 | ||||
-rw-r--r-- | modules/remote-hip/c++/device.hpp | 17 | ||||
-rw-r--r-- | modules/remote-hip/c++/device.tmpl.hpp | 25 | ||||
-rw-r--r-- | modules/remote-hip/c++/remote.hpp | 24 | ||||
-rw-r--r-- | modules/remote-hip/c++/transfer.hpp | 188 | ||||
-rw-r--r-- | modules/remote-hip/examples/hip_transfer_data.cpp | 13 |
6 files changed, 75 insertions, 202 deletions
diff --git a/modules/remote-hip/c++/data.hpp b/modules/remote-hip/c++/data.hpp index 5e8e6f9..3e7c3ed 100644 --- a/modules/remote-hip/c++/data.hpp +++ b/modules/remote-hip/c++/data.hpp @@ -11,14 +11,14 @@ namespace saw { template<typename Schema> class data<Schema, encode::Hip<encode::Native>> { private: - data<Schema, encode::Native> data_; + data<Schema, encode::Native>* data_; public: - data(const data<Schema, encode::Native>& data__): - data_{data__} + data(): + data_{nullptr} {} - ref<data<Schema, encode::Native>> get_data() { - return {data_}; + data<Schema, encode::Native>** get_device_data() { + return &data_; } }; } diff --git a/modules/remote-hip/c++/device.hpp b/modules/remote-hip/c++/device.hpp index 227ed1b..f760024 100644 --- a/modules/remote-hip/c++/device.hpp +++ b/modules/remote-hip/c++/device.hpp @@ -2,7 +2,10 @@ #include "common.hpp" + +#include "device.tmpl.hpp" namespace saw { + /** * Represents a remote Sycl device. */ @@ -14,6 +17,20 @@ public: SAW_FORBID_COPY(device); SAW_FORBID_MOVE(device); + + template<typename Schema, typename Encoding> + error_or<void> copy_to_device(data<Schema, Encoding>& from, data<Schema, encode::Hip<Encoding>>& to){ + + auto dev_data = to.get_device_data(); + + auto eov = impl::hip_copy_to_device<Schema,Encoding>::apply(from, dev_data); + return eov; + } + + template<typename Schema, typename Encoding> + error_or<void> copy_to_host(data<Schema,encode::Hip<Encoding>>& from, data<Schema,Encoding>& to){ + return make_error<err::not_implemented>(); + } }; } diff --git a/modules/remote-hip/c++/device.tmpl.hpp b/modules/remote-hip/c++/device.tmpl.hpp new file mode 100644 index 0000000..4777660 --- /dev/null +++ b/modules/remote-hip/c++/device.tmpl.hpp @@ -0,0 +1,25 @@ +namespace saw { +namespace impl { +template<typename Schema, typename Encoding> +struct hip_copy_to_device { + static error_or<void> apply(data<Schema, Encoding>& from, data<Schema, Encoding>** to){ + static_assert(always_false<Schema,Encoding>, "Unsupported case."); + return make_void(); + } +}; + +template<typename T, uint64_t N, typename Encoding> +struct hip_copy_to_device<schema::Primitive<T,N>, Encoding> { + using Schema = schema::Primitive<T,N>; + static error_or<void> apply(data<Schema, Encoding>& from, data<Schema,Encoding>** to){ + hipError_t malloc_err = hipMalloc(to, sizeof(data<Schema,Encoding>)); + // HIP_CHECK(malloc_err); + + hipError_t copy_err = hipMemcpy(*to, &from, sizeof(data<Schema,Encoding>), hipMemcpyHostToDevice); + // HIP_CHECK(copy_err); + + return make_void(); + } +}; +} +} diff --git a/modules/remote-hip/c++/remote.hpp b/modules/remote-hip/c++/remote.hpp index 794d629..242c06d 100644 --- a/modules/remote-hip/c++/remote.hpp +++ b/modules/remote-hip/c++/remote.hpp @@ -62,6 +62,13 @@ public: conveyor<own<remote_address<rmt::Hip>>> resolve_address(uint64_t dev_id = 0u){ return heap<remote_address<rmt::Hip>>(dev_id); } + + /** + * Parse address, but don't resolve it. + */ + error_or<own<remote_address<rmt::Hip>>> parse_address(uint64_t dev_id = 0u){ + return heap<remote_address<rmt::Hip>>(dev_id); + } /** * Info. @@ -97,13 +104,6 @@ public: } /** - * Parse address, but don't resolve it. - */ - error_or<own<remote_address<rmt::Hip>>> parse_address(uint64_t dev_id = 0u){ - return heap<remote_address<rmt::Hip>>(dev_id); - } - - /** * Spin up data server */ template<typename Schema, typename Encoding> @@ -115,16 +115,6 @@ public: } return heap<data_server<Schema, Encoding, rmt::Hip>>(ins.first->second); } - - /** - * Spin up a rpc server - */ - template<typename Iface, typename Encoding> - rpc_server<Iface, Encoding, rmt::Hip> listen(remote_address<rmt::Hip>& dev, typename rpc_server<Iface, Encoding, rmt::Hip>::InterfaceT iface){ - //using RpcServerT = rpc_server<Iface, Encoding, rmt::Hip>; - //using InterfaceT = typename RpcServerT::InterfaceT; - return {share<device<rmt::Hip>>(), std::move(iface)}; - } }; } diff --git a/modules/remote-hip/c++/transfer.hpp b/modules/remote-hip/c++/transfer.hpp index cdde8ba..d0ece27 100644 --- a/modules/remote-hip/c++/transfer.hpp +++ b/modules/remote-hip/c++/transfer.hpp @@ -28,14 +28,20 @@ public: device_{std::move(device__)} {} - error_or<void> send(const data<Schema,Encoding>& dat, id<Schema> store_id){ + error_or<void> send(data<Schema,Encoding>& dat, id<Schema> store_id){ + data<Schema, encode::Hip<Encoding>> hip_dat; + { + auto eov = device_->copy_to_device(dat, hip_dat); + if(eov.is_error()){ + return eov; + } + } - auto ins = values_.emplace(std::make_pair(store_id.get_value(), data<Schema, encode::Hip<Encoding>>{dat})); + auto ins = values_.emplace(std::make_pair(store_id.get_value(), hip_dat)); if(!ins.second){ return make_error<err::already_exists>(); } - return make_error<err::not_implemented>("Allocate not implemented. Since we don't actually do any device copies."); return make_void(); } @@ -63,180 +69,4 @@ public: } }; -template<typename... Schema, typename Encoding> -class data_server<tmpl_group<Schema...>, Encoding, rmt::Hip> { -private: - /** - * Device context class - */ - our<device<rmt::Hip>> device_; - - /** - * Store for the data the server manages. - */ - typename impl::data_server_redux<encode::Hip<Encoding>, typename tmpl_reduce<tmpl_group<Schema...>>::type >::type values_; -public: - /** - * Main constructor - */ - data_server(our<device<rmt::Hip>> device__): - device_{std::move(device__)} - {} - - /** - * Get data which we will store. - */ - template<typename Sch> - error_or<void> send(const data<Sch, Encoding>& dat, id<Sch> store_id){ - return make_error<err::not_implemented>(); - /* - auto& vals = std::get<std::unordered_map<uint64_t, data<Sch,encode::Hip<Encoding>>>>(values_); - auto eoval = device_->template copy_to_device<Sch, Encoding>(dat); - if(eoval.is_error()){ - auto& err = eoval.get_error(); - return std::move(err); - } - auto& val = eoval.get_value(); - try { - auto insert_res = vals.insert(std::make_pair(store_id.get_value(), std::move(val))); - if(!insert_res.second){ - return make_error<err::already_exists>(); - } - }catch ( std::exception& ){ - return make_error<err::out_of_memory>(); - } - return void_t{}; - */ - } - - template<typename Sch> - error_or<void> allocate(const data<typename meta_schema<Sch>::MetaSchema, Encoding>& dat, id<Sch> store_id){ - return make_error<err::not_implemented>(); - /* - auto& vals = std::get<std::unordered_map<uint64_t, data<Sch,encode::Hip<Encoding>>>>(values_); - auto eoval = device_->template allocate_on_device<Sch, Encoding>(dat); - if(eoval.is_error()){ - auto& err = eoval.get_error(); - return std::move(err); - } - auto& val = eoval.get_value(); - try { - auto insert_res = vals.insert(std::make_pair(store_id.get_value(), std::move(val))); - if(!insert_res.second){ - return make_error<err::already_exists>(); - } - }catch ( std::exception& ){ - return make_error<err::out_of_memory>(); - } - return void_t{}; - */ - } - - /** - * Requests data from the server - */ - template<typename Sch> - error_or<data<Sch, Encoding>> receive(id<Sch> store_id){ - auto& vals = std::get<std::unordered_map<uint64_t, data<Sch,encode::Hip<Encoding>>>>(values_); - auto find_res = vals.find(store_id.get_value()); - if(find_res == vals.end()){ - return make_error<err::not_found>(); - } - auto& dat = find_res->second; - - return make_error<err::not_implemented>(); - } - - /** - * Request an erase of the stored data - */ - template<typename Sch> - error_or<void> erase(id<Sch> store_id){ - auto& vals = std::get<std::unordered_map<uint64_t, data<Sch,Encoding>>>(values_); - auto erase_op = vals.erase(store_id.get_value()); - if(erase_op == 0u){ - return make_error<err::not_found>(); - } - return void_t{}; - } - - /** - * Get the stored data on the server side for immediate use. - * Insert operations may invalidate the pointer. - */ - template<typename Sch> - error_or<data<Sch, encode::Hip<Encoding>>*> find(id<Sch> store_id){ - auto& vals = std::get<std::unordered_map<uint64_t, data<Sch,Encoding>>>(values_); - auto find_res = vals.find(store_id.get_value()); - if(find_res == vals.end()){ - return make_error<err::not_found>(); - } - - return &(find_res.second); - } -}; - -/** - * Client for transporting data to remote and receiving data back - */ -template<typename... Schema, typename Encoding> -class data_client<tmpl_group<Schema...>, Encoding, rmt::Hip> { -private: - /** - * Corresponding server for this client - */ - data_server<tmpl_group<Schema...>, Encoding, rmt::Hip>* srv_; - - /** - * The next id for identifying issues on the remote side. - */ - uint64_t next_id_; -public: - /** - * Main constructor - */ - data_client(data_server<tmpl_group<Schema...>, Encoding, rmt::Hip>& srv__): - srv_{&srv__}, - next_id_{0u} - {} - - /** - * Send data to the remote. - */ - template<typename Sch> - error_or<id<Sch>> send(const data<Sch, Encoding>& dat){ - id<Sch> dat_id{next_id_}; - auto eov = srv_->send(dat, dat_id); - if(eov.is_error()){ - auto& err = eov.get_error(); - return std::move(err); - } - - ++next_id_; - return dat_id; - } - - /** - * Receive data - */ - template<typename Sch> - conveyor<data<Sch, Encoding>> receive(id<Sch> dat_id){ - auto eov = srv_->receive(dat_id); - if(eov.is_error()){ - auto& err = eov.get_error(); - return std::move(err); - } - - auto& val = eov.get_value(); - return std::move(val); - } - - /** - * Erase data - */ - template<typename Sch> - error_or<void> erase(id<Sch> dat_id){ - return srv_->erase(dat_id); - } -}; } diff --git a/modules/remote-hip/examples/hip_transfer_data.cpp b/modules/remote-hip/examples/hip_transfer_data.cpp index 49ff856..ae530bd 100644 --- a/modules/remote-hip/examples/hip_transfer_data.cpp +++ b/modules/remote-hip/examples/hip_transfer_data.cpp @@ -3,6 +3,10 @@ #include <iostream> +__global__ print_value(int16_t val){ + printf("Hello world: %d", val); +} + namespace sch { using namespace saw::schema; } @@ -25,13 +29,20 @@ saw::error_or<void> real_main(){ auto& dat_srv = eo_dat_srv.get_value(); data<sch::Int16> val{42}; - id<sch::Int16> id_val{0u}; auto eo_send = dat_srv->send(val, id_val); if(eo_send.is_error()){ return std::move(eo_send.get_error()); } + auto eo_dfind = dat_srv->find(id_val); + if(eo_dfind.is_error()){ + return std::move(eo_dfind.get_error()); + } + auto dfind = eo_dfind.get_value(); + + print_value<<<dim3(2),dim3(2),0,hipStreamDefault>>>(dfind()); + return make_void(); } |