diff options
author | Claudius "keldu" Holeksa <mail@keldu.de> | 2025-08-06 18:37:49 +0200 |
---|---|---|
committer | Claudius "keldu" Holeksa <mail@keldu.de> | 2025-08-06 18:37:49 +0200 |
commit | 16ea7cfc4d834a3c0cedb4a833ff0f212eeee2dc (patch) | |
tree | 1e20f18e1c984d96ef351cb4f6bd5dff81641f89 | |
parent | de03ab009d6f7a5e10d3bfd5dd3d4ad1672b1ace (diff) |
Fixing math operators
-rw-r--r-- | modules/codec/c++/data_math.hpp | 19 | ||||
-rw-r--r-- | modules/codec/c++/iterator.hpp | 55 |
2 files changed, 68 insertions, 6 deletions
diff --git a/modules/codec/c++/data_math.hpp b/modules/codec/c++/data_math.hpp index 7277168..b5491eb 100644 --- a/modules/codec/c++/data_math.hpp +++ b/modules/codec/c++/data_math.hpp @@ -2,6 +2,7 @@ #include "data.hpp" #include "schema_math.hpp" +#include "iterator.hpp" #include <cmath> @@ -71,31 +72,37 @@ public: using Schema = schema::Tensor<Inner, Dims...>; private: data<schema::Array<Inner,Schema::Rank>, encode::Native> values_; + public: data(): values_{data<schema::FixedArray<schema::UInt64,sizeof...(Dims)>>{{data<schema::UInt64, encode::Native>{Dims}...}}} {} - data<Inner, encode::Native>& at(const data<schema::FixedArray<schema::UInt64>, encode::Native>& index){ + data<Inner, encode::Native>& at(const data<schema::FixedArray<schema::UInt64, sizeof...(Dims)>, encode::Native>& index){ return values_.at(index); } - const data<Inner, encode::Native>& at(const data<schema::FixedArray<schema::UInt64>, encode::Native>& index) const { + const data<Inner, encode::Native>& at(const data<schema::FixedArray<schema::UInt64, sizeof...(Dims)>, encode::Native>& index) const { return values_.at(index); } - data<Inner, encode::Native>& operator()(const data<schema::FixedArray<schema::UInt64>, encode::Native>& index){ + data<Inner, encode::Native>& operator()(const data<schema::FixedArray<schema::UInt64, sizeof...(Dims)>, encode::Native>& index){ return values_.at(index); } - const data<Inner, encode::Native>& operator()(const data<schema::FixedArray<schema::UInt64>, encode::Native>& index) const { + const data<Inner, encode::Native>& operator()(const data<schema::FixedArray<schema::UInt64, sizeof...(Dims)>, encode::Native>& index) const { return values_.at(index); } - data<schema::Tensor<Inner, Dims...>, encode::Native> operator+(const data<schema::Tensor<Inner, Dims...>, encode::Native>& rhs) const { + data<schema::Tensor<Inner, Dims...>, encode::Native> operator+(const data<schema::Tensor<Inner, Dims...>, encode::Native>& rhs) { data<schema::Tensor<Inner, Dims...>, encode::Native> c; - return {}; + rank_iterator<Dims...>::in_fixed_bounds([&](const data<schema::FixedArray<schema::UInt64, sizeof...(Dims)>, encode::Native>& index) -> error_or<void>{ + c.at(index) = at(index) + rhs.at(index); + return make_void(); + }); + + return c; } }; diff --git a/modules/codec/c++/iterator.hpp b/modules/codec/c++/iterator.hpp new file mode 100644 index 0000000..96a4441 --- /dev/null +++ b/modules/codec/c++/iterator.hpp @@ -0,0 +1,55 @@ +#pragma once + +#include "data.hpp" + +namespace saw { +namespace impl { +template<uint64_t i, uint64_t val_0, uint64_t... vals> +struct iterator_index_parameter_value { + static_assert(i > 0, "Shouldn't happen"); + static constexpr uint64_t value = iterator_index_parameter_value<i-1u, vals...>::value; +}; + +template<uint64_t val_0, uint64_t... vals> +struct iterator_index_parameter_value<0u, val_0, vals...> { + static constexpr uint64_t value = val_0; +}; + +template<typename Func, uint64_t... Dims> +struct iterator_in_bounds { +private: + static constexpr std::array<uint64_t, sizeof...(Dims)> dims_{Dims...}; + + template<uint64_t level> + static error_or<void> apply_i(Func& func, data<schema::FixedArray<schema::UInt64, sizeof...(Dims)>, encode::Native>& index){ + if constexpr (level >= sizeof...(Dims)){ + return func(index); + }else{ + index.at({level}).set(0); + + for(;index.at({level}) < impl::iterator_index_parameter_value<level,Dims...>::value; ++index.at({level})){ + auto eov = apply_i<level+1u>(func, index); + if(eov.is_error()){ + return eov; + } + } + } + return saw::make_void(); + } +public: + static error_or<void> apply(Func& func){ + data<schema::FixedArray<schema::UInt64,sizeof...(Dims)>, encode::Native> index; + return apply_i<0u>(func, index); + } +}; +} + +template<uint64_t... Dims> +struct rank_iterator { + template<typename Func> + static error_or<void> in_fixed_bounds(Func&& func){ + // static_assert(D == 2u, "Currently a lazy implementation for AND combinations of intervalls."); + return impl::iterator_in_bounds<Func,Dims...>::apply(func); + } +}; +} |