diff options
Diffstat (limited to 'modules/codec/c++/data_math.hpp')
-rw-r--r-- | modules/codec/c++/data_math.hpp | 19 |
1 files changed, 13 insertions, 6 deletions
diff --git a/modules/codec/c++/data_math.hpp b/modules/codec/c++/data_math.hpp index 7277168..b5491eb 100644 --- a/modules/codec/c++/data_math.hpp +++ b/modules/codec/c++/data_math.hpp @@ -2,6 +2,7 @@ #include "data.hpp" #include "schema_math.hpp" +#include "iterator.hpp" #include <cmath> @@ -71,31 +72,37 @@ public: using Schema = schema::Tensor<Inner, Dims...>; private: data<schema::Array<Inner,Schema::Rank>, encode::Native> values_; + public: data(): values_{data<schema::FixedArray<schema::UInt64,sizeof...(Dims)>>{{data<schema::UInt64, encode::Native>{Dims}...}}} {} - data<Inner, encode::Native>& at(const data<schema::FixedArray<schema::UInt64>, encode::Native>& index){ + data<Inner, encode::Native>& at(const data<schema::FixedArray<schema::UInt64, sizeof...(Dims)>, encode::Native>& index){ return values_.at(index); } - const data<Inner, encode::Native>& at(const data<schema::FixedArray<schema::UInt64>, encode::Native>& index) const { + const data<Inner, encode::Native>& at(const data<schema::FixedArray<schema::UInt64, sizeof...(Dims)>, encode::Native>& index) const { return values_.at(index); } - data<Inner, encode::Native>& operator()(const data<schema::FixedArray<schema::UInt64>, encode::Native>& index){ + data<Inner, encode::Native>& operator()(const data<schema::FixedArray<schema::UInt64, sizeof...(Dims)>, encode::Native>& index){ return values_.at(index); } - const data<Inner, encode::Native>& operator()(const data<schema::FixedArray<schema::UInt64>, encode::Native>& index) const { + const data<Inner, encode::Native>& operator()(const data<schema::FixedArray<schema::UInt64, sizeof...(Dims)>, encode::Native>& index) const { return values_.at(index); } - data<schema::Tensor<Inner, Dims...>, encode::Native> operator+(const data<schema::Tensor<Inner, Dims...>, encode::Native>& rhs) const { + data<schema::Tensor<Inner, Dims...>, encode::Native> operator+(const data<schema::Tensor<Inner, Dims...>, encode::Native>& rhs) { data<schema::Tensor<Inner, Dims...>, encode::Native> c; - return {}; + rank_iterator<Dims...>::in_fixed_bounds([&](const data<schema::FixedArray<schema::UInt64, sizeof...(Dims)>, encode::Native>& index) -> error_or<void>{ + c.at(index) = at(index) + rhs.at(index); + return make_void(); + }); + + return c; } }; |