// saw::math: small linear-algebra helpers over saw `data<...>` containers.
// Provides a dot product, a square root of a value stored at the empty index,
// vector normalization, an element-by-element copy into a new data
// (`vectorize_data`), and two `multiply` overloads (2-index x 2-index and
// 2-index x 1-index element access).
//
// NOTE(review): as it appears here, this whole header sits on ONE physical
// line. A preprocessing directive extends to the end of its line, so
// everything after `#pragma once` (including the `#include "data_math.hpp"`
// and the entire `saw { math { ... } }` namespace) is consumed as extra
// tokens of that pragma and is never compiled. Compilers typically warn about
// extra tokens after `#pragma once`. Restore the original line breaks.
//
// NOTE(review): all template parameter lists and `<...>` template argument
// lists appear to be missing. Examples: `template data, Encoding> dot(`, and
// `const data>& input` in normalize. This looks like angle-bracketed text was
// stripped by some tool. Recover the real signatures from version control
// rather than guessing. The bodies reference `D`, `M`, `N`, `K`, and
// `Encoding`, so those are presumably among the template parameters.
// TODO confirm.
//
// Per-function notes, in order of appearance:
//
// norm_2
//   Commented-out stub that would only return `{}`. Not compiled.
//
// dot(left, right)
//   Result element at the empty index `{}` = sum over i in [0, D) of
//   left({{i}}) * right({{i}}).
//   It accumulates into the element of a default-constructed `val` through a
//   reference (`inner`). This assumes that element starts out as zero.
//   TODO confirm for the element type.
//   Element access here uses `operator()`, while every other function in this
//   file uses `.at()`.
//
// sqrt(inp)
//   Reads the value at the empty index with `inp.at({}).get()`, applies
//   `std::sqrt`, and stores the result with `out.at({}).set(...)`.
//   `std::sqrt` is declared in <cmath>, which this file does not include
//   directly. It is presumably reached via data_math.hpp; verify.
//
// normalize(input)
//   Computes dot(input, input). If that value is <= 0 (e.g. the zero vector),
//   returns `input` unchanged, which avoids a division by zero.
//   Otherwise, divides each component i in [0, D) by sqrt(dot(input, input))
//   and returns the new data.
//   NOTE(review): the parameter type (`data>`) and the return type
//   (`data, Encoding>`) differ textually. Check whether the original
//   signature lets `Encoding` be deduced from `input`.
//
// vectorize_data(dat)
//   Copies elements 0..D-1 of `dat` into a freshly constructed result
//   (`vec_data`) one element at a time, via `.at({{i}})`.
//
// multiply(left, right), 2-index x 2-index overload
//   lr({{i,j}}) = sum over k in [0, K) of left({{i,k}}) * right({{k,j}}),
//   for i in [0, M) and j in [0, N).
//
// multiply(left, right), 2-index x 1-index overload
//   lr({{i}}) = sum over j in [0, N) of left({{i,j}}) * right({{j}}),
//   for i in [0, M).
//
// Both multiply overloads, like dot, accumulate into a default-constructed
// result. They therefore assume its elements start at zero. TODO confirm.
#pragma once #include "data_math.hpp" namespace saw { namespace math { /* template data norm_2(const data& d){ return {}; } */ template data, Encoding> dot(const data, Encoding>& left, const data, Encoding>& right){ data,Encoding> val; auto& inner = val({}); for(uint64_t i = 0u; i < D; ++i){ inner = inner + left({{i}}) * right({{i}}); } return val; } template data,Encoding> sqrt(const data,Encoding>& inp){ data,Encoding> out; out.at({}).set(std::sqrt(inp.at({}).get())); return out; } template data, Encoding> normalize(const data>& input ) { auto inp_dot = dot(input,input); if(inp_dot.at({}).get() <= 0){ return input; } auto sqrt_inp_dot = sqrt(inp_dot); saw::data, Encoding> out; for(uint64_t i = 0u; i < D; ++i){ out.at({{i}}) = (input.at({{i}}) / sqrt_inp_dot.at({})); } return out; } template data, Encoding> vectorize_data(const data>& dat){ data,Encoding> vec_data; for(uint64_t i{0u}; i < D; ++i){ vec_data.at({{i}}) = dat.at({{i}}); } return vec_data; } template data, Encoding> multiply(const data, Encoding>& left, const data, Encoding>& right){ data, Encoding> lr; for(uint64_t i = 0u; i < M; ++i){ for(uint64_t j = 0u; j < N; ++j){ for(uint64_t k = 0u; k < K; ++k){ lr.at({{i,j}}) = lr.at({{i,j}}) + left.at({{i,k}}) * right.at({{k,j}}); } } } return lr; } template data, Encoding> multiply(const data, Encoding>& left, const data, Encoding>& right){ data, Encoding> lr; for(uint64_t i = 0u; i < M; ++i){ for(uint64_t j = 0u; j < N; ++j){ lr.at({{i}}) = lr.at({{i}}) + left.at({{i,j}}) * right.at({{j}}); } } return lr; } } }