summaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
Diffstat (limited to 'modules')
-rw-r--r--modules/remote-hip/c++/data.hpp10
-rw-r--r--modules/remote-hip/c++/device.hpp17
-rw-r--r--modules/remote-hip/c++/device.tmpl.hpp25
-rw-r--r--modules/remote-hip/c++/remote.hpp24
-rw-r--r--modules/remote-hip/c++/transfer.hpp188
-rw-r--r--modules/remote-hip/examples/hip_transfer_data.cpp13
6 files changed, 75 insertions, 202 deletions
diff --git a/modules/remote-hip/c++/data.hpp b/modules/remote-hip/c++/data.hpp
index 5e8e6f9..3e7c3ed 100644
--- a/modules/remote-hip/c++/data.hpp
+++ b/modules/remote-hip/c++/data.hpp
@@ -11,14 +11,14 @@ namespace saw {
template<typename Schema>
class data<Schema, encode::Hip<encode::Native>> {
private:
- data<Schema, encode::Native> data_;
+ data<Schema, encode::Native>* data_;
public:
- data(const data<Schema, encode::Native>& data__):
- data_{data__}
+ data():
+ data_{nullptr}
{}
- ref<data<Schema, encode::Native>> get_data() {
- return {data_};
+ data<Schema, encode::Native>** get_device_data() {
+ return &data_;
}
};
}
diff --git a/modules/remote-hip/c++/device.hpp b/modules/remote-hip/c++/device.hpp
index 227ed1b..f760024 100644
--- a/modules/remote-hip/c++/device.hpp
+++ b/modules/remote-hip/c++/device.hpp
@@ -2,7 +2,10 @@
#include "common.hpp"
+
+#include "device.tmpl.hpp"
namespace saw {
+
/**
* Represents a remote Sycl device.
*/
@@ -14,6 +17,20 @@ public:
SAW_FORBID_COPY(device);
SAW_FORBID_MOVE(device);
+
+ template<typename Schema, typename Encoding>
+ error_or<void> copy_to_device(data<Schema, Encoding>& from, data<Schema, encode::Hip<Encoding>>& to){
+
+ auto dev_data = to.get_device_data();
+
+ auto eov = impl::hip_copy_to_device<Schema,Encoding>::apply(from, dev_data);
+ return eov;
+ }
+
+ template<typename Schema, typename Encoding>
+ error_or<void> copy_to_host(data<Schema,encode::Hip<Encoding>>& from, data<Schema,Encoding>& to){
+ return make_error<err::not_implemented>();
+ }
};
}
diff --git a/modules/remote-hip/c++/device.tmpl.hpp b/modules/remote-hip/c++/device.tmpl.hpp
new file mode 100644
index 0000000..4777660
--- /dev/null
+++ b/modules/remote-hip/c++/device.tmpl.hpp
@@ -0,0 +1,25 @@
+namespace saw {
+namespace impl {
+template<typename Schema, typename Encoding>
+struct hip_copy_to_device {
+ static error_or<void> apply(data<Schema, Encoding>& from, data<Schema, Encoding>** to){
+ static_assert(always_false<Schema,Encoding>, "Unsupported case.");
+ return make_void();
+ }
+};
+
+template<typename T, uint64_t N, typename Encoding>
+struct hip_copy_to_device<schema::Primitive<T,N>, Encoding> {
+ using Schema = schema::Primitive<T,N>;
+ static error_or<void> apply(data<Schema, Encoding>& from, data<Schema,Encoding>** to){
+ hipError_t malloc_err = hipMalloc(to, sizeof(data<Schema,Encoding>));
+ // HIP_CHECK(malloc_err);
+
+ hipError_t copy_err = hipMemcpy(*to, &from, sizeof(data<Schema,Encoding>), hipMemcpyHostToDevice);
+ // HIP_CHECK(copy_err);
+
+ return make_void();
+ }
+};
+}
+}
diff --git a/modules/remote-hip/c++/remote.hpp b/modules/remote-hip/c++/remote.hpp
index 794d629..242c06d 100644
--- a/modules/remote-hip/c++/remote.hpp
+++ b/modules/remote-hip/c++/remote.hpp
@@ -62,6 +62,13 @@ public:
conveyor<own<remote_address<rmt::Hip>>> resolve_address(uint64_t dev_id = 0u){
return heap<remote_address<rmt::Hip>>(dev_id);
}
+
+ /**
+ * Parse address, but don't resolve it.
+ */
+ error_or<own<remote_address<rmt::Hip>>> parse_address(uint64_t dev_id = 0u){
+ return heap<remote_address<rmt::Hip>>(dev_id);
+ }
/**
* Info.
@@ -97,13 +104,6 @@ public:
}
/**
- * Parse address, but don't resolve it.
- */
- error_or<own<remote_address<rmt::Hip>>> parse_address(uint64_t dev_id = 0u){
- return heap<remote_address<rmt::Hip>>(dev_id);
- }
-
- /**
* Spin up data server
*/
template<typename Schema, typename Encoding>
@@ -115,16 +115,6 @@ public:
}
return heap<data_server<Schema, Encoding, rmt::Hip>>(ins.first->second);
}
-
- /**
- * Spin up a rpc server
- */
- template<typename Iface, typename Encoding>
- rpc_server<Iface, Encoding, rmt::Hip> listen(remote_address<rmt::Hip>& dev, typename rpc_server<Iface, Encoding, rmt::Hip>::InterfaceT iface){
- //using RpcServerT = rpc_server<Iface, Encoding, rmt::Hip>;
- //using InterfaceT = typename RpcServerT::InterfaceT;
- return {share<device<rmt::Hip>>(), std::move(iface)};
- }
};
}
diff --git a/modules/remote-hip/c++/transfer.hpp b/modules/remote-hip/c++/transfer.hpp
index cdde8ba..d0ece27 100644
--- a/modules/remote-hip/c++/transfer.hpp
+++ b/modules/remote-hip/c++/transfer.hpp
@@ -28,14 +28,20 @@ public:
device_{std::move(device__)}
{}
- error_or<void> send(const data<Schema,Encoding>& dat, id<Schema> store_id){
+ error_or<void> send(data<Schema,Encoding>& dat, id<Schema> store_id){
+ data<Schema, encode::Hip<Encoding>> hip_dat;
+ {
+ auto eov = device_->copy_to_device(dat, hip_dat);
+ if(eov.is_error()){
+ return eov;
+ }
+ }
- auto ins = values_.emplace(std::make_pair(store_id.get_value(), data<Schema, encode::Hip<Encoding>>{dat}));
+ auto ins = values_.emplace(std::make_pair(store_id.get_value(), hip_dat));
if(!ins.second){
return make_error<err::already_exists>();
}
- return make_error<err::not_implemented>("Allocate not implemented. Since we don't actually do any device copies.");
return make_void();
}
@@ -63,180 +69,4 @@ public:
}
};
-template<typename... Schema, typename Encoding>
-class data_server<tmpl_group<Schema...>, Encoding, rmt::Hip> {
-private:
- /**
- * Device context class
- */
- our<device<rmt::Hip>> device_;
-
- /**
- * Store for the data the server manages.
- */
- typename impl::data_server_redux<encode::Hip<Encoding>, typename tmpl_reduce<tmpl_group<Schema...>>::type >::type values_;
-public:
- /**
- * Main constructor
- */
- data_server(our<device<rmt::Hip>> device__):
- device_{std::move(device__)}
- {}
-
- /**
- * Get data which we will store.
- */
- template<typename Sch>
- error_or<void> send(const data<Sch, Encoding>& dat, id<Sch> store_id){
- return make_error<err::not_implemented>();
- /*
- auto& vals = std::get<std::unordered_map<uint64_t, data<Sch,encode::Hip<Encoding>>>>(values_);
- auto eoval = device_->template copy_to_device<Sch, Encoding>(dat);
- if(eoval.is_error()){
- auto& err = eoval.get_error();
- return std::move(err);
- }
- auto& val = eoval.get_value();
- try {
- auto insert_res = vals.insert(std::make_pair(store_id.get_value(), std::move(val)));
- if(!insert_res.second){
- return make_error<err::already_exists>();
- }
- }catch ( std::exception& ){
- return make_error<err::out_of_memory>();
- }
- return void_t{};
- */
- }
-
- template<typename Sch>
- error_or<void> allocate(const data<typename meta_schema<Sch>::MetaSchema, Encoding>& dat, id<Sch> store_id){
- return make_error<err::not_implemented>();
- /*
- auto& vals = std::get<std::unordered_map<uint64_t, data<Sch,encode::Hip<Encoding>>>>(values_);
- auto eoval = device_->template allocate_on_device<Sch, Encoding>(dat);
- if(eoval.is_error()){
- auto& err = eoval.get_error();
- return std::move(err);
- }
- auto& val = eoval.get_value();
- try {
- auto insert_res = vals.insert(std::make_pair(store_id.get_value(), std::move(val)));
- if(!insert_res.second){
- return make_error<err::already_exists>();
- }
- }catch ( std::exception& ){
- return make_error<err::out_of_memory>();
- }
- return void_t{};
- */
- }
-
- /**
- * Requests data from the server
- */
- template<typename Sch>
- error_or<data<Sch, Encoding>> receive(id<Sch> store_id){
- auto& vals = std::get<std::unordered_map<uint64_t, data<Sch,encode::Hip<Encoding>>>>(values_);
- auto find_res = vals.find(store_id.get_value());
- if(find_res == vals.end()){
- return make_error<err::not_found>();
- }
- auto& dat = find_res->second;
-
- return make_error<err::not_implemented>();
- }
-
- /**
- * Request an erase of the stored data
- */
- template<typename Sch>
- error_or<void> erase(id<Sch> store_id){
- auto& vals = std::get<std::unordered_map<uint64_t, data<Sch,Encoding>>>(values_);
- auto erase_op = vals.erase(store_id.get_value());
- if(erase_op == 0u){
- return make_error<err::not_found>();
- }
- return void_t{};
- }
-
- /**
- * Get the stored data on the server side for immediate use.
- * Insert operations may invalidate the pointer.
- */
- template<typename Sch>
- error_or<data<Sch, encode::Hip<Encoding>>*> find(id<Sch> store_id){
- auto& vals = std::get<std::unordered_map<uint64_t, data<Sch,Encoding>>>(values_);
- auto find_res = vals.find(store_id.get_value());
- if(find_res == vals.end()){
- return make_error<err::not_found>();
- }
-
- return &(find_res.second);
- }
-};
-
-/**
- * Client for transporting data to remote and receiving data back
- */
-template<typename... Schema, typename Encoding>
-class data_client<tmpl_group<Schema...>, Encoding, rmt::Hip> {
-private:
- /**
- * Corresponding server for this client
- */
- data_server<tmpl_group<Schema...>, Encoding, rmt::Hip>* srv_;
-
- /**
- * The next id for identifying issues on the remote side.
- */
- uint64_t next_id_;
-public:
- /**
- * Main constructor
- */
- data_client(data_server<tmpl_group<Schema...>, Encoding, rmt::Hip>& srv__):
- srv_{&srv__},
- next_id_{0u}
- {}
-
- /**
- * Send data to the remote.
- */
- template<typename Sch>
- error_or<id<Sch>> send(const data<Sch, Encoding>& dat){
- id<Sch> dat_id{next_id_};
- auto eov = srv_->send(dat, dat_id);
- if(eov.is_error()){
- auto& err = eov.get_error();
- return std::move(err);
- }
-
- ++next_id_;
- return dat_id;
- }
-
- /**
- * Receive data
- */
- template<typename Sch>
- conveyor<data<Sch, Encoding>> receive(id<Sch> dat_id){
- auto eov = srv_->receive(dat_id);
- if(eov.is_error()){
- auto& err = eov.get_error();
- return std::move(err);
- }
-
- auto& val = eov.get_value();
- return std::move(val);
- }
-
- /**
- * Erase data
- */
- template<typename Sch>
- error_or<void> erase(id<Sch> dat_id){
- return srv_->erase(dat_id);
- }
-};
}
diff --git a/modules/remote-hip/examples/hip_transfer_data.cpp b/modules/remote-hip/examples/hip_transfer_data.cpp
index 49ff856..ae530bd 100644
--- a/modules/remote-hip/examples/hip_transfer_data.cpp
+++ b/modules/remote-hip/examples/hip_transfer_data.cpp
@@ -3,6 +3,10 @@
#include <iostream>
+__global__ print_value(int16_t val){
+ printf("Hello world: %d", val);
+}
+
namespace sch {
using namespace saw::schema;
}
@@ -25,13 +29,20 @@ saw::error_or<void> real_main(){
auto& dat_srv = eo_dat_srv.get_value();
data<sch::Int16> val{42};
-
id<sch::Int16> id_val{0u};
auto eo_send = dat_srv->send(val, id_val);
if(eo_send.is_error()){
return std::move(eo_send.get_error());
}
+ auto eo_dfind = dat_srv->find(id_val);
+ if(eo_dfind.is_error()){
+ return std::move(eo_dfind.get_error());
+ }
+ auto dfind = eo_dfind.get_value();
+
+ print_value<<<dim3(2),dim3(2),0,hipStreamDefault>>>(dfind());
+
return make_void();
}