summaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
authorClaudius "keldu" Holeksa <mail@keldu.de>2025-10-29 17:47:39 +0100
committerClaudius "keldu" Holeksa <mail@keldu.de>2025-10-29 17:47:39 +0100
commit53b1bc01474a0612e8039485cd4e33fc441f673a (patch)
tree385c05031ab94065cad79b620f365322723fda40 /modules
parent7ea2e439dded1baa11c4c12207eee8e1033ae104 (diff)
downloadforstio-forstio-53b1bc01474a0612e8039485cd4e33fc441f673a.tar.gz
Adding some proper USM features :)
Diffstat (limited to 'modules')
-rw-r--r--modules/remote-sycl/c++/data.hpp126
1 files changed, 79 insertions, 47 deletions
diff --git a/modules/remote-sycl/c++/data.hpp b/modules/remote-sycl/c++/data.hpp
index 91f74f8..2d5893f 100644
--- a/modules/remote-sycl/c++/data.hpp
+++ b/modules/remote-sycl/c++/data.hpp
@@ -8,73 +8,105 @@ namespace saw {
* Generic wrapper class which stores data on the sycl side.
* Most of the times this will be a root object.
*/
-template<typename Schema>
-class data<Schema, encode::Sycl<encode::Native>> {
+
+template<typename Sch, uint64_t Dim, typename Encode>
+class data<schema::Array<Sch, Dim>, encode::Sycl<Encode>> {
+public:
+ using Schema = schema::Array<Sch,Dim>;
private:
- cl::sycl::buffer<data<Schema, encode::Native>> data_;
- data<schema::UInt64, encode::Native> size_;
+ // cl::sycl::buffer<data<Sch, encode::Native>> data_;
+ using sycl_usm_allocator = acpp::sycl::usm_allocator<data<Sch,Encode>, sycl::usm::alloc::shared>;
+ data<schema::FixedArray<schema::UInt64, sizeof...(D)>, Encode> dims_;
+ data<schema::UInt64, Encode> size_;
+ std::vector<data<Sch,Encode>, sycl_usm_allocator> data_;
+
+ uint64_t get_full_size() const {
+ uint64_t s = 1;
+
+ for(uint64_t iter = 0; iter < Dim; ++iter){
+ auto& dim_iter = dims_.at(data<schema::UInt64>{iter});
+ s *= dim_iter.get();
+ }
+
+ return s;
+ }
public:
- data(const data<Schema, encode::Native>& data__):
- data_{&data__, 1u},
- size_{data__.size()}
+ data():
+ dims_{},
+ size_{0u},
+ data_{}
+ {
+ for(uint64_t iter = 0; iter < Dim; ++iter){
+ dims_.at({iter}) = 0u;
+ }
+ }
+
+ data(const data<schema::FixedArray<schema::UInt64, sizeof...(D)>, Encode>& dims__):
+ dims_{dims__},
+ size_{get_full_size()},
+ data_{size_}
{}
- auto& get_handle() {
- return data_;
+ auto* get_internal_data() {
+ if(data_.empty()){
+ return nullptr;
+ }
+ return &(data_[0u]);
}
- const auto& get_handle() const {
- return data_;
+ const auto& get_internal_size() const {
+ return size_;
}
- data<schema::UInt64, encode::Native> size() const {
+ data<schema::UInt64, Encode> size() const {
return size_;
}
- template<cl::sycl::access::mode AccessMode>
- auto access(cl::sycl::handler& h){
- return data_.template get_access<AccessMode>(h);
- }
-
- template<cl::sycl::access::mode AccessMode>
- auto access(cl::sycl::handler& h) const {
- return data_.template get_access<AccessMode>(h);
+ data<schema::FixedArray<schema::UInt64, sizeof...(D)>, Encode> dims() const {
+ return dims_;
}
-};
-template<typename Sch, uint64_t Dim>
-class data<schema::Array<Sch, Dim>, encode::Sycl<encode::Native>> {
-public:
- using Schema = schema::Array<Sch,Dim>;
-private:
- cl::sycl::buffer<data<Sch, encode::Native>> data_;
- data<schema::UInt64, encode::Native> size_;
-public:
- data(const data<Schema, encode::Native>& host_data__):
- data_{&host_data__.at({0u}),host_data__.size().get()},
- size_{host_data__.size()}
- {}
-
- auto& get_handle() {
- return data_;
+ constexpr data<T, Encode>& at(const data<schema::FixedArray<schema::UInt64, sizeof...(D)>, Encode>& i){
+ return value_.at(this->get_flat_index(i));
}
- const auto& get_handle() const {
- return data_;
+ constexpr const data<T, Encode>& at(const data<schema::FixedArray<schema::UInt64, sizeof...(D)>, Encode>& i)const{
+ return value_.at(this->get_flat_index(i));
}
- data<schema::UInt64, encode::Native> size() const {
- return size_;
+ data<schema::UInt64,Encode> internal_flat_index(const data<schema::FixedArray<schema::UInt64, sizeof...(D)>, Encode>& i) const {
+ return {this->get_flat_index(i)};
}
+private:
+ template<typename U>
+ uint64_t get_flat_index(const U& i) const {
+ static_assert(
+ std::is_same_v<U,data<schema::FixedArray<schema::UInt64,Dim>, Encode>> or
+ std::is_same_v<U,std::array<uint64_t,Dim>>,
+ "Unsupported type"
+ );
+ assert(value_.size() == get_full_size());
+ uint64_t s = 0;
- template<cl::sycl::access::mode AccessMode>
- auto access(cl::sycl::handler& h){
- return data_.template get_access<AccessMode>(h);
- }
+ uint64_t stride = 1;
+
+ for(uint64_t iter = 0; iter < Dim; ++iter){
+ uint64_t ind = [](auto val) -> uint64_t {
+ using V = std::decay_t<decltype(val)>;
+ if constexpr (std::is_same_v<V,data<schema::UInt64>>){
+ return val.get();
+ }else if constexpr (std::is_same_v<V, uint64_t>){
+ return val;
+ }else{
+ static_assert(always_false<V>, "Cases exhausted");
+ }
+ }(i.at(iter));
+ assert(ind < dims_.at({iter}).get() );
+ s += ind * stride;
+ stride *= dims_.at(iter).get();
+ }
- template<cl::sycl::access::mode AccessMode>
- auto access(cl::sycl::handler& h) const {
- return data_.template get_access<AccessMode>(h);
+ return s;
}
};
}