summaryrefslogtreecommitdiff
path: root/modules/remote-hip
diff options
context:
space:
mode:
authorClaudius 'keldu' Holeksa <mail@keldu.de>2024-09-17 11:21:05 +0200
committerClaudius 'keldu' Holeksa <mail@keldu.de>2024-09-17 11:21:05 +0200
commit1d578450dc82843bd4b24f3a6aad2c1a82bbda5e (patch)
treefdb34a5629308fad6ef9c5e3f0a4290bb688c3c4 /modules/remote-hip
parentb23b2276b1ab7977e4cea721322f9d31f6ef85ca (diff)
Managed to get hip to compile
Diffstat (limited to 'modules/remote-hip')
-rw-r--r--modules/remote-hip/c++/data.hpp8
-rw-r--r--modules/remote-hip/c++/device.tmpl.hpp2
-rw-r--r--modules/remote-hip/examples/hip_transfer_data.cpp75
3 files changed, 59 insertions, 26 deletions
diff --git a/modules/remote-hip/c++/data.hpp b/modules/remote-hip/c++/data.hpp
index 3e7c3ed..5d3635f 100644
--- a/modules/remote-hip/c++/data.hpp
+++ b/modules/remote-hip/c++/data.hpp
@@ -8,16 +8,16 @@ namespace saw {
* Generic wrapper class which stores data on the sycl side.
* Most of the times this will be a root object.
*/
-template<typename Schema>
-class data<Schema, encode::Hip<encode::Native>> {
+template<typename Schema,typename T>
+class data<Schema, encode::Hip<T>> {
private:
- data<Schema, encode::Native>* data_;
+ data<Schema, T>* data_;
public:
data():
data_{nullptr}
{}
- data<Schema, encode::Native>** get_device_data() {
+ data<Schema, T>** get_device_data() {
return &data_;
}
};
diff --git a/modules/remote-hip/c++/device.tmpl.hpp b/modules/remote-hip/c++/device.tmpl.hpp
index 6edf431..0517f67 100644
--- a/modules/remote-hip/c++/device.tmpl.hpp
+++ b/modules/remote-hip/c++/device.tmpl.hpp
@@ -33,7 +33,7 @@ struct hip_copy_to_device<schema::Array<T,Dim>, Encoding> {
static error_or<void> apply(data<Schema, Encoding>& from, data<Schema,Encoding>** to){
typename native_data_type<T>::type* dat{};
hipError_t data_malloc_err = hipMalloc(&dat,sizeof(typename native_data_type<T>::type) * from.size());
- hipError_t data_copy_err = hipMemcpy(&dat, &(from.get_raw_data()),sizeof(typename native_data_type<T>::type) * from.size(), hipMemcpyHostToDevice);
+ hipError_t data_copy_err = hipMemcpy(&dat, (from.get_raw_data()),sizeof(typename native_data_type<T>::type) * from.size(), hipMemcpyHostToDevice);
if(from.size() == 0u){
// Everything is fine. We just don't want to allocate data which doesn't exist.
diff --git a/modules/remote-hip/examples/hip_transfer_data.cpp b/modules/remote-hip/examples/hip_transfer_data.cpp
index 18c82df..3f02d87 100644
--- a/modules/remote-hip/examples/hip_transfer_data.cpp
+++ b/modules/remote-hip/examples/hip_transfer_data.cpp
@@ -1,20 +1,22 @@
+#include <forstio/codec/data_raw.hpp>
+
#include "../c++/remote.hpp"
#include "../c++/transfer.hpp"
#include <iostream>
-__global__ void print_value(saw::data<saw::schema::Int16,saw::encode::Native>* val){
+__global__ void print_value(saw::data<saw::schema::Int16,saw::encode::NativeRaw>* val){
int v = val->get();
printf("Hello world: %d\n", v);
}
-__global__ void print_array_vals(saw::data<saw::schema::Array<saw::schema::Int16>* val){
+__global__ void print_array_vals(saw::data<saw::schema::Array<saw::schema::Int16>, saw::encode::NativeRaw>* val){
uint64_t orig_len = val->size();
long len = (long) orig_len;
printf("Array size: %ld\n", len);
- for(uint64_t i = 0; i < orig_len; +i){
- int v = val->at(i);
+ for(uint64_t i = 0; i < orig_len; ++i){
+ int v = val->at(i).get();
printf("%d ", v);
}
printf("\n");
@@ -35,28 +37,59 @@ saw::error_or<void> real_main(){
}
auto& addr = eo_addr.get_value();
- auto eo_dat_srv = rmt.data_listen<sch::Int16, encode::Native>(*addr);
- if(eo_dat_srv.is_error()){
- return std::move(eo_dat_srv.get_error());
- }
- auto& dat_srv = eo_dat_srv.get_value();
+ {
+ auto eo_dat_srv = rmt.data_listen<sch::Int16, encode::NativeRaw>(*addr);
+ if(eo_dat_srv.is_error()){
+ return std::move(eo_dat_srv.get_error());
+ }
+ auto& dat_srv = eo_dat_srv.get_value();
- data<sch::Int16> val{42};
- id<sch::Int16> id_val{0u};
- auto eo_send = dat_srv->send(val, id_val);
- if(eo_send.is_error()){
- return std::move(eo_send.get_error());
- }
+ data<sch::Int16,encode::NativeRaw> val{42};
+ id<sch::Int16> id_val{0u};
+ auto eo_send = dat_srv->send(val, id_val);
+ if(eo_send.is_error()){
+ return std::move(eo_send.get_error());
+ }
+
+ auto eo_dfind = dat_srv->find(id_val);
+ if(eo_dfind.is_error()){
+ return std::move(eo_dfind.get_error());
+ }
+ auto dfind = eo_dfind.get_value();
- auto eo_dfind = dat_srv->find(id_val);
- if(eo_dfind.is_error()){
- return std::move(eo_dfind.get_error());
+ auto& v = dfind();
+
+ print_value<<<dim3(2),dim3(2),0,hipStreamDefault>>>(*(v.get_device_data()));
}
- auto dfind = eo_dfind.get_value();
- auto& v = dfind();
+ {
+ auto eo_dat_srv = rmt.data_listen<sch::Array<sch::Int16>, encode::NativeRaw>(*addr);
+ if(eo_dat_srv.is_error()){
+ return std::move(eo_dat_srv.get_error());
+ }
+ auto& dat_srv = eo_dat_srv.get_value();
+
+ data<sch::Array<sch::Int16>,encode::NativeRaw> val{4};
+ val.at(0u).set(5);
+ val.at(1u).set(3);
+ val.at(2u).set(-6);
+ val.at(3u).set(1);
+ id<sch::Array<sch::Int16>> id_val{0u};
+ auto eo_send = dat_srv->send(val, id_val);
+ if(eo_send.is_error()){
+ return std::move(eo_send.get_error());
+ }
- print_value<<<dim3(2),dim3(2),0,hipStreamDefault>>>(*(v.get_device_data()));
+ auto eo_dfind = dat_srv->find(id_val);
+ if(eo_dfind.is_error()){
+ return std::move(eo_dfind.get_error());
+ }
+ auto dfind = eo_dfind.get_value();
+
+ auto& v = dfind();
+
+ print_array_vals<<<dim3(2),dim3(2),0,hipStreamDefault>>>(*(v.get_device_data()));
+ }
return make_void();
}