/*
 * Math helpers over saw::data values: dot, sqrt, normalize, vectorize_data,
 * multiply, cross, cos, sin, rotate, scale, fill.
 *
 * NOTE(review): this copy of the header appears to be damaged and will not
 * compile as written. Restore it from version control before relying on it.
 *  - Every template parameter list and template argument list (the text
 *    between '<' and '>') appears to have been stripped, as in
 *    "template data, Encoding> dot(". The schema types involved (presumably
 *    scalar / vector / matrix / tensor schemas from data_math.hpp) cannot be
 *    recovered from this file.
 *  - The #include and the namespace body sit on the same physical line as
 *    "#pragma once". A preprocessing directive runs to the end of its line,
 *    so everything on that line is part of one #pragma directive rather than
 *    an #include followed by code.
 *  - The body of class iterator is missing.
 *  - A stray comment terminator (star-slash) follows norm_2 with no
 *    matching opener.
 *  - Two closing braces end the file but only "namespace saw {" is visible.
 *    Some scope (presumably an inner namespace) was opened in the lost text.
 */
#pragma once #include "data_math.hpp" namespace saw { /* NOTE(review): the template parameter list and the body of this class are missing; fill() below calls iterator::apply(func, start, end). */ template class iterator final { public: /* norm_2: stub that returns a value-initialized result; the stray terminator right after it suggests it was commented out in the original. */ template data norm_2(const data& d){ return {}; } */ /* dot: sums left[i] * right[i] for i in [0, D) into the scalar val, through the reference inner. Assumes a default-constructed data value starts at zero (TODO confirm). Unlike the other functions here it indexes with operator() instead of at(). */ template data, Encoding> dot(const data, Encoding>& left, const data, Encoding>& right){ data,Encoding> val; auto& inner = val({}); for(uint64_t i = 0u; i < D; ++i){ inner = inner + left({{i}}) * right({{i}}); } return val; } /* sqrt: square root of a scalar value, applying std::sqrt to its underlying value via get()/set(). std::sqrt needs <cmath>, which is not included here; presumably it arrives via data_math.hpp (TODO confirm). */ template data,Encoding> sqrt(const data,Encoding>& inp){ data,Encoding> out; out.at({}).set(std::sqrt(inp.at({}).get())); return out; } /* normalize: returns input divided element-wise by sqrt(dot(input, input)), its Euclidean length. When dot(input, input) <= 0 (the zero vector), input is returned unchanged instead of dividing by zero. NOTE(review): the parameter remnant "data>" lacks the ", Encoding" that the return type has; check whether Encoding can be deduced from the argument. */ template data, Encoding> normalize(const data>& input ) { auto inp_dot = dot(input,input); if(inp_dot.at({}).get() <= 0){ return input; } auto sqrt_inp_dot = sqrt(inp_dot); saw::data, Encoding> out; for(uint64_t i = 0u; i < D; ++i){ out.at({{i}}) = (input.at({{i}}) / sqrt_inp_dot.at({})); } return out; } /* vectorize_data: copies elements 0 .. D-1 of dat, index by index, into a new value vec_data. NOTE(review): the parameter remnant "data>" likewise shows no Encoding argument. */ template data, Encoding> vectorize_data(const data>& dat){ data,Encoding> vec_data; for(uint64_t i{0u}; i < D; ++i){ vec_data.at({{i}}) = dat.at({{i}}); } return vec_data; } /* multiply (matrix times matrix): lr[i,j] = sum over k in [0, K) of left[i,k] * right[k,j], for i in [0, M) and j in [0, N). Accumulates into a default-constructed lr, so it assumes zero initialization (TODO confirm). */ template data, Encoding> multiply(const data, Encoding>& left, const data, Encoding>& right){ data, Encoding> lr; for(uint64_t i = 0u; i < M; ++i){ for(uint64_t j = 0u; j < N; ++j){ for(uint64_t k = 0u; k < K; ++k){ lr.at({{i,j}}) = lr.at({{i,j}}) + left.at({{i,k}}) * right.at({{k,j}}); } } } return lr; } /* multiply (matrix times vector): lr[i] = sum over j in [0, N) of left[i,j] * right[j], for i in [0, M). Also accumulates into a default-constructed lr (assumes zero initialization, TODO confirm). */ template data, Encoding> multiply(const data, Encoding>& left, const data, Encoding>& right){ data, Encoding> lr; for(uint64_t i = 0u; i < M; ++i){ for(uint64_t j = 0u; j < N; ++j){ lr.at({{i}}) = lr.at({{i}}) + left.at({{i,j}}) * right.at({{j}}); } } return lr; } /* cross: three-component cross product lh x rh. Both arguments are taken by value, so each call copies its inputs. */ template data, Encoding> cross( const data, Encoding> lh, const data, Encoding> rh ){ data, Encoding> cross_prod; cross_prod.at({{0u}}) = lh.at({{1u}}) * rh.at({{2u}}) - lh.at({{2u}}) * rh.at({{1u}}); cross_prod.at({{1u}}) = lh.at({{2u}}) * rh.at({{0u}}) - lh.at({{0u}}) * rh.at({{2u}}); cross_prod.at({{2u}}) = lh.at({{0u}}) * rh.at({{1u}}) - lh.at({{1u}}) * rh.at({{0u}}); return cross_prod; } /* cross (two-component overload): returns the scalar lh[0] * rh[1] - lh[1] * rh[0]. This is the z-component of the 3D cross product of the two inputs extended with z = 0. */ template data,Encoding> cross( const data, 
Encoding> lh, const data, Encoding> rh ){ data, Encoding> cross_prod; cross_prod.at({}) = lh.at({{0u}}) * rh.at({{1u}}) - lh.at({{1u}}) * rh.at({{0u}}); return cross_prod; } /* cos: cosine of a scalar value, applying std::cos (argument in radians) to its underlying value via get()/set(). */ template data,Encoding> cos( const data,Encoding>& val ){ data,Encoding> ret; ret.at({}).set(std::cos(val.at({}).get())); return ret; } /* sin: sine of a scalar value, applying std::sin (argument in radians) to its underlying value via get()/set(). */ template data,Encoding> sin( const data,Encoding>& val ){ data,Encoding> ret; ret.at({}).set(std::sin(val.at({}).get())); return ret; } /* rotate: returns (vec[0] * cos(rot), vec[1] * sin(rot)). NOTE(review): that is not a rotation of vec by the angle rot; a 2D rotation of (x, y) would be (x * cos(rot) - y * sin(rot), x * sin(rot) + y * cos(rot)). Confirm the intended semantics before use. */ template data,Encoding> rotate( const data, Encoding> vec, const data, Encoding> rot ){ data, Encoding> rot_vec; rot_vec.at({{0u}}) = vec.at({{0u}}) * cos(rot).at({}); rot_vec.at({{1u}}) = vec.at({{1u}}) * sin(rot).at({}); return rot_vec; } /* scale: multiplies each of the D elements of vec by scale. NOTE(review): the whole parameter object scale is used here. Elsewhere a scalar value is read through at({}) (for example cos(rot).at({}) in rotate), so confirm that element * scalar-data is supported. The parameter name also shadows the function name. */ template data,Encoding> scale( const data, Encoding> vec, const data, Encoding> scale ){ data, Encoding> sc_vec; for(uint64_t i = 0u; i< D; ++i){ sc_vec.at({{i}}) = vec.at({{i}}) * scale; } return sc_vec; } /* fill: returns a value in which every index produced by iterator::apply, from start {} to end {{Ds...}}, is assigned filler. */ template data, Encoding> fill( const data& filler){ data, Encoding> tbf; iterator::apply([&](const auto& index){ tbf.at(index) = filler; }, {}, {{Ds...}}); return tbf; } /* Closes two enclosing scopes; only "namespace saw {" is visible above. */ } }