diff --git a/vowpalwabbit/core/src/reductions/eigen_memory_tree.cc b/vowpalwabbit/core/src/reductions/eigen_memory_tree.cc index a496d8fc905..11da7bbb24b 100644 --- a/vowpalwabbit/core/src/reductions/eigen_memory_tree.cc +++ b/vowpalwabbit/core/src/reductions/eigen_memory_tree.cc @@ -489,15 +489,7 @@ void tree_bound(emt_tree& b, emt_example* ec) } } -void scorer_features(const emt_feats& f1, VW::features& out) -{ - for (auto p : f1) - { - if (p.second != 0) { out.push_back(p.second, p.first); } - } -} - -void scorer_features(const emt_feats& f1, const emt_feats& f2, VW::features& out) +void scorer_features_sub(const emt_feats& f1, const emt_feats& f2, VW::features& out) { auto iter1 = f1.begin(); auto iter2 = f2.begin(); @@ -535,15 +527,31 @@ void scorer_features(const emt_feats& f1, const emt_feats& f2, VW::features& out } } +void scorer_features_mul(const emt_feats& f1, const emt_feats& f2, VW::features& out) +{ + auto iter1 = f1.begin(); + auto iter2 = f2.begin(); + + while (iter1 != f1.end() && iter2 != f2.end()) + { + if (iter1->first < iter2->first) { iter1++; } + else if (iter2->first < iter1->first) { iter2++; } + else + { + out.push_back(iter1->second * iter2->second, iter1->first); + iter1++; + iter2++; + } + } +} + void scorer_example(emt_tree& b, const emt_example& ex1, const emt_example& ex2) { VW::example& out = *b.ex; static constexpr VW::namespace_index X_NS = 'x'; - static constexpr VW::namespace_index Z_NS = 'z'; out.feature_space[X_NS].clear(); - out.feature_space[Z_NS].clear(); if (b.scorer_type == emt_scorer_type::SELF_CONSISTENT_RANK) { @@ -552,7 +560,7 @@ void scorer_example(emt_tree& b, const emt_example& ex1, const emt_example& ex2) out.interactions->clear(); - scorer_features(ex1.full, ex2.full, out.feature_space[X_NS]); + scorer_features_sub(ex1.full, ex2.full, out.feature_space[X_NS]); out.total_sum_feat_sq = out.feature_space[X_NS].sum_feat_sq; out.num_features = out.feature_space[X_NS].size(); @@ -565,26 +573,13 @@ void scorer_example(emt_tree& b, const emt_example& ex1, const emt_example& ex2) { out.indices.clear(); out.indices.push_back(X_NS); - out.indices.push_back(Z_NS); out.interactions->clear(); - out.interactions->push_back({X_NS, Z_NS}); - - b.all->feature_tweaks_config.ignore_some_linear = true; - b.all->feature_tweaks_config.ignore_linear[X_NS] = true; - b.all->feature_tweaks_config.ignore_linear[Z_NS] = true; - scorer_features(ex1.full, out.feature_space[X_NS]); - scorer_features(ex2.full, out.feature_space[Z_NS]); + scorer_features_mul(ex1.full, ex2.full, out.feature_space[X_NS]); - // when we receive ex1 and ex2 their features are indexed on top of eachother. In order - // to make sure VW recognizes the features from the two examples as separate features - // we apply a map of multiplying by 2 and then offseting by 1 on the second example. - for (auto& j : out.feature_space[X_NS].indices) { j = j * 2; } - for (auto& j : out.feature_space[Z_NS].indices) { j = j * 2 + 1; } - - out.total_sum_feat_sq = out.feature_space[X_NS].sum_feat_sq + out.feature_space[Z_NS].sum_feat_sq; - out.num_features = out.feature_space[X_NS].size() + out.feature_space[Z_NS].size(); + out.total_sum_feat_sq = out.feature_space[X_NS].sum_feat_sq; + out.num_features = out.feature_space[X_NS].size(); auto initial = emt_initial(b.initial_type, ex1.full, ex2.full); out.ex_reduction_features.get().initial = initial;