summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorClaudius 'keldu' Holeksa <mail@keldu.de>2024-09-17 11:58:24 +0200
committerClaudius 'keldu' Holeksa <mail@keldu.de>2024-09-17 11:58:24 +0200
commitd2c1c0d73a602b77ae2eac1570e9f95141c0c666 (patch)
treec86642ac1f13b1566c4558eafd5b1ef44afd86c5
parent81a5bc78c326181dd7f8d5181e146979d12ba753 (diff)
wip
-rw-r--r--modules/remote-hip/c++/device.tmpl.hpp9
-rw-r--r--modules/remote-hip/examples/hip_transfer_data.cpp6
2 files changed, 9 insertions, 6 deletions
diff --git a/modules/remote-hip/c++/device.tmpl.hpp b/modules/remote-hip/c++/device.tmpl.hpp
index ce8b4ed..2ee991d 100644
--- a/modules/remote-hip/c++/device.tmpl.hpp
+++ b/modules/remote-hip/c++/device.tmpl.hpp
@@ -32,12 +32,9 @@ struct hip_copy_to_device<schema::Array<T,Dim>, Encoding> {
static error_or<void> apply(data<Schema, Encoding>& from, data<Schema,Encoding>** to){
typename native_data_type<T>::type* dat{};
- hipError_t data_malloc_err = hipMalloc(&dat,sizeof(typename native_data_type<T>::type) * from.size());
- hipError_t data_copy_err = hipMemcpy(dat, (from.get_raw_data()),sizeof(typename native_data_type<T>::type) * from.size(), hipMemcpyHostToDevice);
-
- if(from.size() == 0u){
- // Everything is fine. We just don't want to allocate data which doesn't exist.
- return make_void();
+ if(from.size() > 0u){
+ hipError_t data_malloc_err = hipMalloc(&dat,sizeof(typename native_data_type<T>::type) * from.size());
+ hipError_t data_copy_err = hipMemcpy(dat, (from.get_raw_data()),sizeof(typename native_data_type<T>::type) * from.size(), hipMemcpyHostToDevice);
}
// auto from_dat = &from.at(0);
diff --git a/modules/remote-hip/examples/hip_transfer_data.cpp b/modules/remote-hip/examples/hip_transfer_data.cpp
index f112c68..2742715 100644
--- a/modules/remote-hip/examples/hip_transfer_data.cpp
+++ b/modules/remote-hip/examples/hip_transfer_data.cpp
@@ -19,6 +19,12 @@ __global__ void print_array_vals(saw::data<saw::schema::Array<saw::schema::Int16
int v = val->at(i).get();
printf("%d ", v);
}
+
+ auto raw_d = val->get_raw_data();
+
+ for(uint64_t i = 0; i < orig_len; ++i){
+ printf("%d ", raw_d[i]);
+ }
printf("\n");
}