Skip to content

Commit

Permalink
Merge branch 'master' into master
Browse files Browse the repository at this point in the history
  • Loading branch information
ataymano authored Jan 23, 2024
2 parents 059566d + f8091b6 commit 5907c04
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 35 deletions.
62 changes: 28 additions & 34 deletions vowpalwabbit/core/src/reductions/eigen_memory_tree.cc
Original file line number Diff line number Diff line change
Expand Up @@ -489,15 +489,7 @@ void tree_bound(emt_tree& b, emt_example* ec)
}
}

// Copies the entries of f1 into out, dropping any entry whose `second` member is zero.
// Each kept entry is appended as out.push_back(second, first) — presumably
// (feature value, feature index); confirm against VW::features::push_back.
void scorer_features(const emt_feats& f1, VW::features& out)
{
  for (const auto& entry : f1)
  {
    // Zero-valued entries contribute nothing, so they are not emitted.
    if (entry.second == 0) { continue; }
    out.push_back(entry.second, entry.first);
  }
}

void scorer_features(const emt_feats& f1, const emt_feats& f2, VW::features& out)
void scorer_features_sub(const emt_feats& f1, const emt_feats& f2, VW::features& out)
{
auto iter1 = f1.begin();
auto iter2 = f2.begin();
Expand Down Expand Up @@ -535,15 +527,31 @@ void scorer_features(const emt_feats& f1, const emt_feats& f2, VW::features& out
}
}

// Walks f1 and f2 in lockstep and, for every key (`first`) that appears in both,
// appends the product of the two `second` members to out under that key.
// Entries whose key appears in only one input are skipped.
// NOTE(review): the merge only finds all common keys if both inputs are sorted
// ascending by `first` — confirm that emt_feats guarantees this ordering.
void scorer_features_mul(const emt_feats& f1, const emt_feats& f2, VW::features& out)
{
  auto lhs = f1.begin();
  auto rhs = f2.begin();
  const auto lhs_end = f1.end();
  const auto rhs_end = f2.end();

  while (lhs != lhs_end && rhs != rhs_end)
  {
    // Advance whichever side currently holds the smaller key.
    if (lhs->first < rhs->first)
    {
      ++lhs;
      continue;
    }
    if (rhs->first < lhs->first)
    {
      ++rhs;
      continue;
    }

    // Keys match: emit the element-wise product and move both cursors forward.
    out.push_back(lhs->second * rhs->second, lhs->first);
    ++lhs;
    ++rhs;
  }
}

void scorer_example(emt_tree& b, const emt_example& ex1, const emt_example& ex2)
{
VW::example& out = *b.ex;

static constexpr VW::namespace_index X_NS = 'x';
static constexpr VW::namespace_index Z_NS = 'z';

out.feature_space[X_NS].clear();
out.feature_space[Z_NS].clear();

if (b.scorer_type == emt_scorer_type::SELF_CONSISTENT_RANK)
{
Expand All @@ -552,7 +560,7 @@ void scorer_example(emt_tree& b, const emt_example& ex1, const emt_example& ex2)

out.interactions->clear();

scorer_features(ex1.full, ex2.full, out.feature_space[X_NS]);
scorer_features_sub(ex1.full, ex2.full, out.feature_space[X_NS]);

out.total_sum_feat_sq = out.feature_space[X_NS].sum_feat_sq;
out.num_features = out.feature_space[X_NS].size();
Expand All @@ -565,26 +573,13 @@ void scorer_example(emt_tree& b, const emt_example& ex1, const emt_example& ex2)
{
out.indices.clear();
out.indices.push_back(X_NS);
out.indices.push_back(Z_NS);

out.interactions->clear();
out.interactions->push_back({X_NS, Z_NS});

b.all->feature_tweaks_config.ignore_some_linear = true;
b.all->feature_tweaks_config.ignore_linear[X_NS] = true;
b.all->feature_tweaks_config.ignore_linear[Z_NS] = true;
scorer_features_mul(ex1.full, ex2.full, out.feature_space[X_NS]);

scorer_features(ex1.full, out.feature_space[X_NS]);
scorer_features(ex2.full, out.feature_space[Z_NS]);

// When we receive ex1 and ex2 their features are indexed on top of each other. In order
// to make sure VW recognizes the features from the two examples as separate features
// we apply a map of multiplying by 2 and then offsetting by 1 on the second example.
for (auto& j : out.feature_space[X_NS].indices) { j = j * 2; }
for (auto& j : out.feature_space[Z_NS].indices) { j = j * 2 + 1; }

out.total_sum_feat_sq = out.feature_space[X_NS].sum_feat_sq + out.feature_space[Z_NS].sum_feat_sq;
out.num_features = out.feature_space[X_NS].size() + out.feature_space[Z_NS].size();
out.total_sum_feat_sq = out.feature_space[X_NS].sum_feat_sq;
out.num_features = out.feature_space[X_NS].size();

auto initial = emt_initial(b.initial_type, ex1.full, ex2.full);
out.ex_reduction_features.get<VW::simple_label_reduction_features>().initial = initial;
Expand Down Expand Up @@ -741,13 +736,14 @@ void node_split(emt_tree& b, emt_node& cn)
cn.examples.clear();
}

void node_insert(emt_node& cn, std::unique_ptr<emt_example> ex)
// Stores ex in node cn unless cn already holds an example whose `full` features
// compare equal to ex->full; in that case ex is discarded and nothing else happens.
// After a successful insert the stored example is passed to tree_bound.
void node_insert(emt_tree& b, emt_node& cn, std::unique_ptr<emt_example> ex)
{
  for (const auto& stored : cn.examples)
  {
    // Duplicate by full feature set: leave the node untouched.
    if (stored->full == ex->full) { return; }
  }

  // Remember the raw pointer before ownership moves into the node.
  emt_example* inserted = ex.get();
  cn.examples.push_back(std::move(ex));
  tree_bound(b, inserted);
}

emt_example* node_pick(emt_tree& b, learner& base, emt_node& cn, const emt_example& ex)
Expand Down Expand Up @@ -779,16 +775,15 @@ void node_predict(emt_tree& b, learner& base, emt_node& cn, emt_example& ex, VW:
auto* closest_ex = node_pick(b, base, cn, ex);
ec.pred.multiclass = (closest_ex != nullptr) ? closest_ex->label : 0;
ec.loss = (ec.l.multi.label != ec.pred.multiclass) ? ec.weight : 0;
if (closest_ex != nullptr) { tree_bound(b, closest_ex); }
}

void emt_predict(emt_tree& b, learner& base, VW::example& ec)
{
b.all->feature_tweaks_config.ignore_some_linear = false;
emt_example ex(*b.all, &ec);

emt_node& cn = *tree_route(b, ex);
node_predict(b, base, cn, ex, ec);
tree_bound(b, &ex);
}

void emt_learn(emt_tree& b, learner& base, VW::example& ec)
Expand All @@ -797,10 +792,9 @@ void emt_learn(emt_tree& b, learner& base, VW::example& ec)
auto ex = VW::make_unique<emt_example>(*b.all, &ec);

emt_node& cn = *tree_route(b, *ex);
scorer_learn(b, base, cn, *ex, ec.weight);
node_predict(b, base, cn, *ex, ec); // vw learners predict and emt_learn
tree_bound(b, ex.get());
node_insert(cn, std::move(ex));
scorer_learn(b, base, cn, *ex, ec.weight);
node_insert(b, cn, std::move(ex));
node_split(b, cn);
}

Expand Down
41 changes: 40 additions & 1 deletion vowpalwabbit/core/tests/eigen_memory_tree_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ TEST(EigenMemoryTree, ExactMatchWithRouterTest)
}
}

TEST(EigenMemoryTree, Bounding)
TEST(EigenMemoryTree, BoundingDrop)
{
auto vw = VW::initialize(vwtest::make_args("--quiet", "--emt", "--emt_tree", "5"));
auto* tree = get_emt_tree(*vw);
Expand All @@ -148,6 +148,45 @@ TEST(EigenMemoryTree, Bounding)
EXPECT_EQ(tree->root->router_weights.size(), 0);
}

// Verifies that a predict call on a freshly initialized workspace (no learn calls
// have been made) leaves the bounder list empty.
TEST(EigenMemoryTree, BoundingPredict)
{
// NOTE(review): "--emt_tree 3" presumably caps the number of stored examples at 3 — confirm.
auto vw = VW::initialize(vwtest::make_args("--quiet", "--emt", "--emt_tree", "3"));
auto* tree = get_emt_tree(*vw);

// Predict only; the example is never learned.
auto* ex = VW::read_example(*vw, "1 | 1");
vw->predict(*ex);
vw->finish_example(*ex);

// Nothing should have been registered with the bounder by predict alone.
EXPECT_EQ(tree->bounder->list.size(), 0);
}

// Verifies how the front of the bounder list changes after learn and predict calls.
// NOTE(review): the front of the list presumably holds the most recently used
// example — confirm against the bounder implementation.
TEST(EigenMemoryTree, BoundingRecency)
{
auto vw = VW::initialize(vwtest::make_args("--quiet", "--emt", "--emt_tree", "3"));
auto* tree = get_emt_tree(*vw);

// Learn three examples "0 | 0", "1 | 1", "2 | 2" in that order.
for (int i = 0; i < 3; i++)
{
auto* ex = VW::read_example(*vw, std::to_string(i) + " | " + std::to_string(i));
vw->learn(*ex);
vw->finish_example(*ex);
}

// After learning, the front entry is expected to be the last example learned (feature 2).
EXPECT_EQ((*tree->bounder->list.begin())->base[0].first, 2);

// Predicting with feature 1 is expected to bring the stored feature-1 example to the front.
auto* ex1 = VW::read_example(*vw, "1 | 1");
vw->predict(*ex1);
vw->finish_example(*ex1);

EXPECT_EQ((*tree->bounder->list.begin())->base[0].first, 1);

// Predicting with feature 0 (label 1 in the input text) is expected to bring the
// stored feature-0 example to the front.
auto* ex2 = VW::read_example(*vw, "1 | 0");
vw->predict(*ex2);
vw->finish_example(*ex2);

EXPECT_EQ((*tree->bounder->list.begin())->base[0].first, 0);
}

TEST(EigenMemoryTree, Split)
{
auto args = vwtest::make_args("--quiet", "--emt", "--emt_tree", "10", "--emt_leaf", "3");
Expand Down

0 comments on commit 5907c04

Please sign in to comment.