summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--modules/codec/c++/data_math.hpp11
-rw-r--r--modules/codec/c++/math.hpp18
2 files changed, 29 insertions, 0 deletions
diff --git a/modules/codec/c++/data_math.hpp b/modules/codec/c++/data_math.hpp
index efdca7d..51657c7 100644
--- a/modules/codec/c++/data_math.hpp
+++ b/modules/codec/c++/data_math.hpp
@@ -120,6 +120,17 @@ public:
return c;
}
+
+ template<typename InnerChange>
+ data<schema::Tensor<InnerChange, Dims...>, encode::Native> cast_to(){
+ data<schema::Tensor<InnerChange, Dims...>, encode::Native> native_change;
+ rank_iterator<Dims...>::in_fixed_bounds([&](const data<schema::FixedArray<schema::UInt64, sizeof...(Dims)>, encode::Native>& index) -> error_or<void>{
+ native_change.at(index) = at(index).template cast_to<InnerChange>();
+ return make_void();
+ });
+
+ return inner_change;
+ }
};
}
diff --git a/modules/codec/c++/math.hpp b/modules/codec/c++/math.hpp
index ddeea3f..67d8ed3 100644
--- a/modules/codec/c++/math.hpp
+++ b/modules/codec/c++/math.hpp
@@ -21,5 +21,23 @@ data<schema::Scalar<T>, Encoding> dot(const data<schema::Vector<T,D>, Encoding>&
return val;
}
+
+template<typename T,typename Encoding = encode::Native>
+data<schema::Scalar<T>,Encoding> sqrt(const data<schema::Scalar<T>,Encoding>& inp){
+ data<schema::Scalar<T>,Encoding> out;
+ out.at({}).set(std::sqrt(inp.at({}).get()));
+ return out;
+}
+
+template<typename T, uint64_t D, typename Encoding = encode::Native>
+data<schema::Vector<T,D>, Encoding> normalize(const data<schema::Vector<T,D>>& input ) {
+ auto inp_dot = dot<T,D,Encoding>(input,input);
+ auto sqrt_inp_dot = sqrt<T>(inp_dot);
+
+ saw::data<schema::Vector<T,D>, Encoding> out;
+ out.at({{0u}}).set(out.at({{0u}}).get() / sqrt_inp_dot.at({}).get());
+ out.at({{1u}}).set(out.at({{1u}}).get() / sqrt_inp_dot.at({}).get());
+ return out;
+}
}
}