diff options
author | Claudius 'keldu' Holeksa <mail@keldu.de> | 2024-09-17 11:58:24 +0200 |
---|---|---|
committer | Claudius 'keldu' Holeksa <mail@keldu.de> | 2024-09-17 11:58:24 +0200 |
commit | d2c1c0d73a602b77ae2eac1570e9f95141c0c666 (patch) | |
tree | c86642ac1f13b1566c4558eafd5b1ef44afd86c5 /modules | |
parent | 81a5bc78c326181dd7f8d5181e146979d12ba753 (diff) |
wip
Diffstat (limited to 'modules')
-rw-r--r-- | modules/remote-hip/c++/device.tmpl.hpp | 9 | ||||
-rw-r--r-- | modules/remote-hip/examples/hip_transfer_data.cpp | 6 |
2 files changed, 9 insertions, 6 deletions
diff --git a/modules/remote-hip/c++/device.tmpl.hpp b/modules/remote-hip/c++/device.tmpl.hpp index ce8b4ed..2ee991d 100644 --- a/modules/remote-hip/c++/device.tmpl.hpp +++ b/modules/remote-hip/c++/device.tmpl.hpp @@ -32,12 +32,9 @@ struct hip_copy_to_device<schema::Array<T,Dim>, Encoding> { static error_or<void> apply(data<Schema, Encoding>& from, data<Schema,Encoding>** to){ typename native_data_type<T>::type* dat{}; - hipError_t data_malloc_err = hipMalloc(&dat,sizeof(typename native_data_type<T>::type) * from.size()); - hipError_t data_copy_err = hipMemcpy(dat, (from.get_raw_data()),sizeof(typename native_data_type<T>::type) * from.size(), hipMemcpyHostToDevice); - - if(from.size() == 0u){ - // Everything is fine. We just don't want to allocate data which doesn't exist. - return make_void(); + if(from.size() > 0u){ + hipError_t data_malloc_err = hipMalloc(&dat,sizeof(typename native_data_type<T>::type) * from.size()); + hipError_t data_copy_err = hipMemcpy(dat, (from.get_raw_data()),sizeof(typename native_data_type<T>::type) * from.size(), hipMemcpyHostToDevice); } // auto from_dat = &from.at(0); diff --git a/modules/remote-hip/examples/hip_transfer_data.cpp b/modules/remote-hip/examples/hip_transfer_data.cpp index f112c68..2742715 100644 --- a/modules/remote-hip/examples/hip_transfer_data.cpp +++ b/modules/remote-hip/examples/hip_transfer_data.cpp @@ -19,6 +19,12 @@ __global__ void print_array_vals(saw::data<saw::schema::Array<saw::schema::Int16 int v = val->at(i).get(); printf("%d ", v); } + + auto raw_d = val->get_raw_data(); + + for(uint64_t i = 0; i < orig_len; ++i){ + printf("%d ", raw_d[i]); + } printf("\n"); } |