summaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
Diffstat (limited to 'modules')
-rw-r--r--modules/codec/c++/data_math.hpp11
-rw-r--r--modules/codec/tests/math.cpp8
2 files changed, 19 insertions, 0 deletions
diff --git a/modules/codec/c++/data_math.hpp b/modules/codec/c++/data_math.hpp
index b5491eb..5791f00 100644
--- a/modules/codec/c++/data_math.hpp
+++ b/modules/codec/c++/data_math.hpp
@@ -104,6 +104,17 @@ public:
return c;
}
+
+ data<schema::Tensor<Inner, Dims...>, encode::Native> operator-(const data<schema::Tensor<Inner, Dims...>, encode::Native>& rhs) {
+ data<schema::Tensor<Inner, Dims...>, encode::Native> c;
+
+ rank_iterator<Dims...>::in_fixed_bounds([&](const data<schema::FixedArray<schema::UInt64, sizeof...(Dims)>, encode::Native>& index) -> error_or<void>{
+ c.at(index) = at(index) - rhs.at(index);
+ return make_void();
+ });
+
+ return c;
+ }
};
}
diff --git a/modules/codec/tests/math.cpp b/modules/codec/tests/math.cpp
index 8b3b6f9..ad2d9a6 100644
--- a/modules/codec/tests/math.cpp
+++ b/modules/codec/tests/math.cpp
@@ -46,5 +46,13 @@ SAW_TEST("Math/Tensor"){
SAW_EXPECT(c.at({{1u,0u}}).get() == 7.0, std::string{"Unexpected value at (1,0): "} + std::to_string(c.at({{1u,0u}}).get()));
SAW_EXPECT(c.at({{1u,1u}}).get() == 9.0, std::string{"Unexpected value at (1,1): "} + std::to_string(c.at({{1u,1u}}).get()));
}
+
+ auto d = b - a;
+ {
+ SAW_EXPECT(d.at({{0u,0u}}).get() == 1.0, std::string{"Unexpected value at (0,0): "} + std::to_string(d.at({{0u,0u}}).get()) );
+ SAW_EXPECT(d.at({{0u,1u}}).get() == 1.0, std::string{"Unexpected value at (0,1): "} + std::to_string(d.at({{0u,1u}}).get()));
+ SAW_EXPECT(d.at({{1u,0u}}).get() == 1.0, std::string{"Unexpected value at (1,0): "} + std::to_string(d.at({{1u,0u}}).get()));
+ SAW_EXPECT(d.at({{1u,1u}}).get() == 1.0, std::string{"Unexpected value at (1,1): "} + std::to_string(d.at({{1u,1u}}).get()));
+ }
}
}