diff options
author | Claudius "keldu" Holeksa <mail@keldu.de> | 2025-08-07 13:39:54 +0200 |
---|---|---|
committer | Claudius "keldu" Holeksa <mail@keldu.de> | 2025-08-07 13:39:54 +0200 |
commit | 03c56bcc83c8984165208a3826fb06f68109d1ac (patch) | |
tree | 55ea7e07130af5e1e9dce903b9cc32308632012f | |
parent | 45ca44d5f0387a0551cef87168a59d6df97f66fe (diff) |
Added subtraction
-rw-r--r-- | modules/codec/c++/data_math.hpp | 11 | ||||
-rw-r--r-- | modules/codec/tests/math.cpp | 8 |
2 files changed, 19 insertions, 0 deletions
diff --git a/modules/codec/c++/data_math.hpp b/modules/codec/c++/data_math.hpp index b5491eb..5791f00 100644 --- a/modules/codec/c++/data_math.hpp +++ b/modules/codec/c++/data_math.hpp @@ -104,6 +104,17 @@ public: return c; } + + data<schema::Tensor<Inner, Dims...>, encode::Native> operator-(const data<schema::Tensor<Inner, Dims...>, encode::Native>& rhs) { + data<schema::Tensor<Inner, Dims...>, encode::Native> c; + + rank_iterator<Dims...>::in_fixed_bounds([&](const data<schema::FixedArray<schema::UInt64, sizeof...(Dims)>, encode::Native>& index) -> error_or<void>{ + c.at(index) = at(index) - rhs.at(index); + return make_void(); + }); + + return c; + } }; } diff --git a/modules/codec/tests/math.cpp b/modules/codec/tests/math.cpp index 8b3b6f9..ad2d9a6 100644 --- a/modules/codec/tests/math.cpp +++ b/modules/codec/tests/math.cpp @@ -46,5 +46,13 @@ SAW_TEST("Math/Tensor"){ SAW_EXPECT(c.at({{1u,0u}}).get() == 7.0, std::string{"Unexpected value at (1,0): "} + std::to_string(c.at({{1u,0u}}).get())); SAW_EXPECT(c.at({{1u,1u}}).get() == 9.0, std::string{"Unexpected value at (1,1): "} + std::to_string(c.at({{1u,1u}}).get())); } + + auto d = b - a; + { + SAW_EXPECT(d.at({{0u,0u}}).get() == 1.0, std::string{"Unexpected value at (0,0): "} + std::to_string(d.at({{0u,0u}}).get()) ); + SAW_EXPECT(d.at({{0u,1u}}).get() == 1.0, std::string{"Unexpected value at (0,1): "} + std::to_string(d.at({{0u,1u}}).get())); + SAW_EXPECT(d.at({{1u,0u}}).get() == 1.0, std::string{"Unexpected value at (1,0): "} + std::to_string(d.at({{1u,0u}}).get())); + SAW_EXPECT(d.at({{1u,1u}}).get() == 1.0, std::string{"Unexpected value at (1,1): "} + std::to_string(d.at({{1u,1u}}).get())); + } } } |