diff --git a/vowpalwabbit/core/src/reductions/eigen_memory_tree.cc b/vowpalwabbit/core/src/reductions/eigen_memory_tree.cc index 37c856644fa..11da7bbb24b 100644 --- a/vowpalwabbit/core/src/reductions/eigen_memory_tree.cc +++ b/vowpalwabbit/core/src/reductions/eigen_memory_tree.cc @@ -489,15 +489,7 @@ void tree_bound(emt_tree& b, emt_example* ec) } } -void scorer_features(const emt_feats& f1, VW::features& out) -{ - for (auto p : f1) - { - if (p.second != 0) { out.push_back(p.second, p.first); } - } -} - -void scorer_features(const emt_feats& f1, const emt_feats& f2, VW::features& out) +void scorer_features_sub(const emt_feats& f1, const emt_feats& f2, VW::features& out) { auto iter1 = f1.begin(); auto iter2 = f2.begin(); @@ -535,15 +527,31 @@ void scorer_features(const emt_feats& f1, const emt_feats& f2, VW::features& out } } +void scorer_features_mul(const emt_feats& f1, const emt_feats& f2, VW::features& out) +{ + auto iter1 = f1.begin(); + auto iter2 = f2.begin(); + + while (iter1 != f1.end() && iter2 != f2.end()) + { + if (iter1->first < iter2->first) { iter1++; } + else if (iter2->first < iter1->first) { iter2++; } + else + { + out.push_back(iter1->second * iter2->second, iter1->first); + iter1++; + iter2++; + } + } +} + void scorer_example(emt_tree& b, const emt_example& ex1, const emt_example& ex2) { VW::example& out = *b.ex; static constexpr VW::namespace_index X_NS = 'x'; - static constexpr VW::namespace_index Z_NS = 'z'; out.feature_space[X_NS].clear(); - out.feature_space[Z_NS].clear(); if (b.scorer_type == emt_scorer_type::SELF_CONSISTENT_RANK) { @@ -552,7 +560,7 @@ void scorer_example(emt_tree& b, const emt_example& ex1, const emt_example& ex2) out.interactions->clear(); - scorer_features(ex1.full, ex2.full, out.feature_space[X_NS]); + scorer_features_sub(ex1.full, ex2.full, out.feature_space[X_NS]); out.total_sum_feat_sq = out.feature_space[X_NS].sum_feat_sq; out.num_features = out.feature_space[X_NS].size(); @@ -565,26 +573,13 @@ void scorer_example(emt_tree& b, const emt_example& ex1, const emt_example& ex2) { out.indices.clear(); out.indices.push_back(X_NS); - out.indices.push_back(Z_NS); out.interactions->clear(); - out.interactions->push_back({X_NS, Z_NS}); - b.all->feature_tweaks_config.ignore_some_linear = true; - b.all->feature_tweaks_config.ignore_linear[X_NS] = true; - b.all->feature_tweaks_config.ignore_linear[Z_NS] = true; + scorer_features_mul(ex1.full, ex2.full, out.feature_space[X_NS]); - scorer_features(ex1.full, out.feature_space[X_NS]); - scorer_features(ex2.full, out.feature_space[Z_NS]); - - // when we receive ex1 and ex2 their features are indexed on top of eachother. In order - // to make sure VW recognizes the features from the two examples as separate features - // we apply a map of multiplying by 2 and then offseting by 1 on the second example. - for (auto& j : out.feature_space[X_NS].indices) { j = j * 2; } - for (auto& j : out.feature_space[Z_NS].indices) { j = j * 2 + 1; } - - out.total_sum_feat_sq = out.feature_space[X_NS].sum_feat_sq + out.feature_space[Z_NS].sum_feat_sq; - out.num_features = out.feature_space[X_NS].size() + out.feature_space[Z_NS].size(); + out.total_sum_feat_sq = out.feature_space[X_NS].sum_feat_sq; + out.num_features = out.feature_space[X_NS].size(); auto initial = emt_initial(b.initial_type, ex1.full, ex2.full); out.ex_reduction_features.get().initial = initial; @@ -741,13 +736,14 @@ void node_split(emt_tree& b, emt_node& cn) cn.examples.clear(); } -void node_insert(emt_node& cn, std::unique_ptr ex) +void node_insert(emt_tree& b, emt_node& cn, std::unique_ptr ex) { for (auto& cn_ex : cn.examples) { if (cn_ex->full == ex->full) { return; } } cn.examples.push_back(std::move(ex)); + tree_bound(b, cn.examples.back().get()); } emt_example* node_pick(emt_tree& b, learner& base, emt_node& cn, const emt_example& ex) @@ -779,16 +775,15 @@ void node_predict(emt_tree& b, learner& base, emt_node& cn, emt_example& ex, VW: auto* closest_ex = node_pick(b, base, cn, ex); ec.pred.multiclass = (closest_ex != nullptr) ? closest_ex->label : 0; ec.loss = (ec.l.multi.label != ec.pred.multiclass) ? ec.weight : 0; + if (closest_ex != nullptr) { tree_bound(b, closest_ex); } } void emt_predict(emt_tree& b, learner& base, VW::example& ec) { b.all->feature_tweaks_config.ignore_some_linear = false; emt_example ex(*b.all, &ec); - emt_node& cn = *tree_route(b, ex); node_predict(b, base, cn, ex, ec); - tree_bound(b, &ex); } void emt_learn(emt_tree& b, learner& base, VW::example& ec) @@ -797,10 +792,9 @@ void emt_learn(emt_tree& b, learner& base, VW::example& ec) auto ex = VW::make_unique(*b.all, &ec); emt_node& cn = *tree_route(b, *ex); - scorer_learn(b, base, cn, *ex, ec.weight); node_predict(b, base, cn, *ex, ec); // vw learners predict and emt_learn - tree_bound(b, ex.get()); - node_insert(cn, std::move(ex)); + scorer_learn(b, base, cn, *ex, ec.weight); + node_insert(b, cn, std::move(ex)); node_split(b, cn); } diff --git a/vowpalwabbit/core/tests/eigen_memory_tree_test.cc b/vowpalwabbit/core/tests/eigen_memory_tree_test.cc index 82624c846df..da9aeafc0cc 100644 --- a/vowpalwabbit/core/tests/eigen_memory_tree_test.cc +++ b/vowpalwabbit/core/tests/eigen_memory_tree_test.cc @@ -131,7 +131,7 @@ TEST(EigenMemoryTree, ExactMatchWithRouterTest) } } -TEST(EigenMemoryTree, Bounding) +TEST(EigenMemoryTree, BoundingDrop) { auto vw = VW::initialize(vwtest::make_args("--quiet", "--emt", "--emt_tree", "5")); auto* tree = get_emt_tree(*vw); @@ -148,6 +148,45 @@ TEST(EigenMemoryTree, Bounding) EXPECT_EQ(tree->root->router_weights.size(), 0); } +TEST(EigenMemoryTree, BoundingPredict) +{ + auto vw = VW::initialize(vwtest::make_args("--quiet", "--emt", "--emt_tree", "3")); + auto* tree = get_emt_tree(*vw); + + auto* ex = VW::read_example(*vw, "1 | 1"); + vw->predict(*ex); + vw->finish_example(*ex); + + EXPECT_EQ(tree->bounder->list.size(), 0); +} + +TEST(EigenMemoryTree, BoundingRecency) +{ + auto vw = VW::initialize(vwtest::make_args("--quiet", "--emt", "--emt_tree", "3")); + auto* tree = get_emt_tree(*vw); + + for (int i = 0; i < 3; i++) + { + auto* ex = VW::read_example(*vw, std::to_string(i) + " | " + std::to_string(i)); + vw->learn(*ex); + vw->finish_example(*ex); + } + + EXPECT_EQ((*tree->bounder->list.begin())->base[0].first, 2); + + auto* ex1 = VW::read_example(*vw, "1 | 1"); + vw->predict(*ex1); + vw->finish_example(*ex1); + + EXPECT_EQ((*tree->bounder->list.begin())->base[0].first, 1); + + auto* ex2 = VW::read_example(*vw, "1 | 0"); + vw->predict(*ex2); + vw->finish_example(*ex2); + + EXPECT_EQ((*tree->bounder->list.begin())->base[0].first, 0); +} + TEST(EigenMemoryTree, Split) { auto args = vwtest::make_args("--quiet", "--emt", "--emt_tree", "10", "--emt_leaf", "3");