diff options
author | Claudius "keldu" Holeksa <mail@keldu.de> | 2025-09-22 15:53:45 +0200 |
---|---|---|
committer | Claudius "keldu" Holeksa <mail@keldu.de> | 2025-09-22 15:53:45 +0200 |
commit | f398d046ed5ff6e3bf803c3f16d1fb470b464bdb (patch) | |
tree | 31ea6489cd77ab6169f39a5e25dfded993c3d4b2 /modules | |
parent | a1583da62ea0f7e9affe868cd509557b5e91fae3 (diff) |
Add sqrt and normalize
Diffstat (limited to 'modules')
-rw-r--r-- | modules/codec/c++/data_math.hpp | 11 | ||||
-rw-r--r-- | modules/codec/c++/math.hpp | 18 |
2 files changed, 29 insertions, 0 deletions
diff --git a/modules/codec/c++/data_math.hpp b/modules/codec/c++/data_math.hpp index efdca7d..51657c7 100644 --- a/modules/codec/c++/data_math.hpp +++ b/modules/codec/c++/data_math.hpp @@ -120,6 +120,17 @@ public: return c; } + + template<typename InnerChange> + data<schema::Tensor<InnerChange, Dims...>, encode::Native> cast_to(){ + data<schema::Tensor<InnerChange, Dims...>, encode::Native> native_change; + rank_iterator<Dims...>::in_fixed_bounds([&](const data<schema::FixedArray<schema::UInt64, sizeof...(Dims)>, encode::Native>& index) -> error_or<void>{ + native_change.at(index) = at(index).template cast_to<InnerChange>(); + return make_void(); + }); + + return inner_change; + } }; } diff --git a/modules/codec/c++/math.hpp b/modules/codec/c++/math.hpp index ddeea3f..67d8ed3 100644 --- a/modules/codec/c++/math.hpp +++ b/modules/codec/c++/math.hpp @@ -21,5 +21,23 @@ data<schema::Scalar<T>, Encoding> dot(const data<schema::Vector<T,D>, Encoding>& return val; } + +template<typename T,typename Encoding = encode::Native> +data<schema::Scalar<T>,Encoding> sqrt(const data<schema::Scalar<T>,Encoding>& inp){ + data<schema::Scalar<T>,Encoding> out; + out.at({}).set(std::sqrt(inp.at({}).get())); + return out; +} + +template<typename T, uint64_t D, typename Encoding = encode::Native> +data<schema::Vector<T,D>, Encoding> normalize(const data<schema::Vector<T,D>>& input ) { + auto inp_dot = dot<T,D,Encoding>(input,input); + auto sqrt_inp_dot = sqrt<T>(inp_dot); + + saw::data<schema::Vector<T,D>, Encoding> out; + out.at({{0u}}).set(out.at({{0u}}).get() / sqrt_inp_dot.at({}).get()); + out.at({{1u}}).set(out.at({{1u}}).get() / sqrt_inp_dot.at({}).get()); + return out; +} } } |