diff options
Diffstat (limited to 'lib/core/c++/math')
| -rw-r--r-- | lib/core/c++/math/n_closest.hpp | 18 |
1 files changed, 14 insertions, 4 deletions
diff --git a/lib/core/c++/math/n_closest.hpp b/lib/core/c++/math/n_closest.hpp index 13414e2..ac0fe2f 100644 --- a/lib/core/c++/math/n_closest.hpp +++ b/lib/core/c++/math/n_closest.hpp @@ -7,7 +7,7 @@ namespace kel { namespace lbm { template<typename FieldSchema, typename Encode, typename T, uint64_t D> -saw::data<typename FieldSchema::InnerValueSchema> n_closest_read(const saw::data<sch::Ptr<FieldSchema>,Encode>& f, const saw::data<sch::Vector<T,D>>& frac_ind){ +saw::data<typename FieldSchema::StoredValueSchema> n_closest_read(const saw::data<sch::Ptr<FieldSchema>,Encode>& f, const saw::data<sch::Vector<T,D>>& frac_ind){ auto shift_frac_ind = frac_ind; for(uint64_t i{0u}; i < D; ++i){ @@ -18,13 +18,16 @@ saw::data<typename FieldSchema::InnerValueSchema> n_closest_read(const saw::data } } - auto shift_ind = frac_ind.template cast_to<sch::UInt64>(); + saw::data<sch::FixedArray<sch::UInt64,D>> shift_ind; + for(uint64_t i{0u}; i < D; ++i){ + shift_ind.at({i}) = frac_ind.at({{i}}).template cast_to<sch::UInt64>(); + } return f.at(shift_ind); } template<typename FieldSchema, typename Encode, typename T, uint64_t D> -void n_closest_add(saw::data<sch::Ptr<FieldSchema>,Encode>& f, const saw::data<sch::Vector<T,D>>& frac_ind, const saw::data<typename FieldSchema::InnerValueSchema>& val){ +void n_closest_add(const saw::data<sch::Ptr<FieldSchema>,Encode>& f, const saw::data<sch::Vector<T,D>>& frac_ind, const saw::data<typename FieldSchema::StoredValueSchema>& val){ auto shift_frac_ind = frac_ind; for(uint64_t i{0u}; i < D; ++i){ @@ -34,7 +37,14 @@ void n_closest_add(saw::data<sch::Ptr<FieldSchema>,Encode>& f, const saw::data<s } } - auto shift_ind = frac_ind.template cast_to<sch::UInt64>(); + auto f_meta = f.meta(); + saw::data<sch::FixedArray<sch::UInt64,D>> shift_ind; + for(uint64_t i{0u}; i < D; ++i){ + shift_ind.at({i}) = frac_ind.at({{i}}).template cast_to<sch::UInt64>(); + if(shift_ind.at({i}) < f_meta.at({i})){ + shift_ind.at({i}) = f_meta.at({i}) - 1u; + } + } auto& f_i = f.at(shift_ind); f_i = f_i + val; |
