Skip to content

Commit

Permalink
feat: Implement tests for uncovered label types
Browse files Browse the repository at this point in the history
  • Loading branch information
lokitoth committed Feb 7, 2024
1 parent ce1cefb commit 8cb4998
Show file tree
Hide file tree
Showing 4 changed files with 342 additions and 6 deletions.
88 changes: 87 additions & 1 deletion vowpalwabbit/fb_parser/tests/flatbuffer_parser_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,62 @@ TEST(FlatbufferParser, MultiExample_Multiline)
run_parse_and_verify_test(*all, prototype);
}

TEST(FlatBufferParser, LabelSmokeTest_ContinuousLabel)
{
using namespace vwtest;
using example = vwtest::example;

auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer"));
example_data_generator datagen;

example ex = {{datagen.create_namespace("U_a", 1, 1)},

continuous_label({{1, 0.5f, 0.25}})};

run_parse_and_verify_test(*all, ex);
}

TEST(FlatBufferParser, LabelSmokeTest_Slates)
{
using namespace vwtest;

auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer", "--slates"));
example_data_generator datagen;

// this is not the best way to describe it as it is technically labelled in the strictest sense
// (namely, having slates labels associated with the examples), but there is no labelling data
// there, because we do not have a global cost or probabilities for the slots.
multiex unlabeled_example = {{{{datagen.create_namespace("Context", 1, 1)},

slates::shared()},
{{datagen.create_namespace("Action", 1, 1)},

slates::action(0)},
{{datagen.create_namespace("Action", 1, 1)},

slates::action(0)},
{{datagen.create_namespace("Slot", 1, 1)},

slates::slot(0)}}};

run_parse_and_verify_test(*all, unlabeled_example);

multiex labeled_example{{{{datagen.create_namespace("Context", 1, 1)},

slates::shared(0.5)},
{{datagen.create_namespace("Action", 1, 1)},

slates::action(0)},
{{datagen.create_namespace("Action", 1, 1)},

slates::action(0)},
{{datagen.create_namespace("Slot", 1, 1)},

slates::slot(0, {{1, 0.6}, {0, 0.4}})}}};

run_parse_and_verify_test(*all, labeled_example);
}

namespace vwtest
{
template <typename T>
Expand Down Expand Up @@ -451,6 +507,8 @@ void create_flatbuffer_and_validate<prototype_label_t, void>(VW::workspace& w, c
{
case fb::Label_SimpleLabel:
case fb::Label_CBLabel:
case fb::Label_ContinuousLabel:
case fb::Label_Slates_Label:
{
prototype.verify(w, prototype.label_type, builder.GetBufferPointer());
break;
Expand All @@ -461,7 +519,7 @@ void create_flatbuffer_and_validate<prototype_label_t, void>(VW::workspace& w, c
}
default:
{
THROW("Label type not currently supported");
THROW("Label type not currently supported for create_flatbuffer_and_validate");
break;
}
}
Expand All @@ -487,6 +545,34 @@ TEST(FlatBufferParser, ValidateTestAffordances_CBLabel)
create_flatbuffer_and_validate(*all, cb_label({1.5, 2, 0.25f}));
}

TEST(FlatBufferParser, ValidateTestAffordances_ContinuousLabel)
{
auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer"));

std::vector<VW::cb_continuous::continuous_label_elm> probabilities = {{1, 0.5f, 0.25}};

create_flatbuffer_and_validate(*all, continuous_label(probabilities));
}

TEST(FlatBufferParser, ValidateTestAffordances_Slates)
{
auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer", "--slates"));

std::vector<VW::action_score> probabilities = {{1, 0.5f}, {2, 0.25f}};

VW::slates::example_type types[] = {
VW::slates::example_type::UNSET,
VW::slates::example_type::ACTION,
VW::slates::example_type::SHARED,
VW::slates::example_type::SLOT,
};

for (VW::slates::example_type type : types)
{
create_flatbuffer_and_validate(*all, slates_label_raw(type, 0.5, true, 0.3, 1, probabilities));
}
}

TEST(FlatbufferParser, ValidateTestAffordances_Namespace)
{
auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer"));
Expand Down
3 changes: 2 additions & 1 deletion vowpalwabbit/fb_parser/tests/prototype_example_root.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ inline void verify_example_root(
{
EXPECT_EQ(examples.size(), 1);
EXPECT_EQ(examples[0].size(), 1);
expected.verify<expect_feature_names>(vw, examples[0]);

expected.verify<expect_feature_names>(vw, *(examples[0][0]));
}

template <bool include_feature_names = true>
Expand Down
209 changes: 205 additions & 4 deletions vowpalwabbit/fb_parser/tests/prototype_label.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@

#include "prototype_label.h"

#include "vw/core/cb_continuous_label.h"
#include "vw/core/slates_label.h"

namespace vwtest
{
Offset<void> prototype_label_t::create_flatbuffer(FlatBufferBuilder& builder, VW::workspace&) const
Expand Down Expand Up @@ -35,9 +38,60 @@ Offset<void> prototype_label_t::create_flatbuffer(FlatBufferBuilder& builder, VW
{
return 0;
}
case fb::Label_ContinuousLabel:
{
std::vector<Offset<fb::Continuous_Label_Elm>> costs;
costs.reserve(label.cb_cont.costs.size());

for (const auto& cost : label.cb_cont.costs)
{
costs.push_back(fb::CreateContinuous_Label_Elm(builder, cost.action, cost.cost));
}

Offset<Vector<Offset<fb::Continuous_Label_Elm>>> costs_fb_vector = builder.CreateVector(costs);
return fb::CreateContinuousLabel(builder, costs_fb_vector).Union();
}
case fb::Label_Slates_Label:
{
fb::CCB_Slates_example_type example_type = fb::CCB_Slates_example_type::CCB_Slates_example_type_unset;
switch (label.slates.type)
{
case VW::slates::example_type::UNSET:
example_type = fb::CCB_Slates_example_type::CCB_Slates_example_type_unset;
break;
case VW::slates::example_type::ACTION:
example_type = fb::CCB_Slates_example_type::CCB_Slates_example_type_action;
break;
case VW::slates::example_type::SHARED:
example_type = fb::CCB_Slates_example_type::CCB_Slates_example_type_shared;
break;
case VW::slates::example_type::SLOT:
example_type = fb::CCB_Slates_example_type::CCB_Slates_example_type_slot;
break;
default:
THROW("Slate label example type not currently supported");
}

auto action_scores = label.slates.probabilities;

// TODO: This conversion is kind of painful: we should consider expanding the probabilities
// vector into a pair of vectors
std::vector<Offset<fb::action_score>> fb_action_scores;
fb_action_scores.reserve(action_scores.size());
for (const auto& action_score : action_scores)
{
fb_action_scores.push_back(fb::Createaction_score(builder, action_score.action, action_score.score));
}

Offset<Vector<Offset<fb::action_score>>> fb_action_scores_fb_vector = builder.CreateVector(fb_action_scores);

return fb::CreateSlates_Label(builder, example_type, label.slates.weight, label.slates.labeled,
label.slates.cost, label.slates.slot_id, fb_action_scores_fb_vector)
.Union();
}
default:
{
THROW("Label type not currently supported");
THROW("Label type not currently supported for create_flatbuffer");
return 0;
}
}
Expand All @@ -62,10 +116,20 @@ void prototype_label_t::verify(VW::workspace&, const fb::Example* e) const
{
break;
}
case fb::Label_ContinuousLabel:
{
verify_continuous_label(e);
break;
}
case fb::Label_Slates_Label:
{
verify_slates_label(e);
break;
}
// TODO: other label types
default:
{
THROW("Label type not currently supported");
THROW("Label type not currently supported for verify");
break;
}
}
Expand All @@ -89,9 +153,19 @@ void prototype_label_t::verify(VW::workspace&, const VW::example& e) const
{
break;
}
case fb::Label_ContinuousLabel:
{
verify_continuous_label(e);
break;
}
case fb::Label_Slates_Label:
{
verify_slates_label(e);
break;
}
default:
{
THROW("Label type not currently supported");
THROW("Label type not currently supported for verify");
break;
}
}
Expand All @@ -116,9 +190,19 @@ void prototype_label_t::verify(VW::workspace&, fb::Label label_type, const void*
EXPECT_EQ(label, nullptr);
break;
}
case fb::Label_ContinuousLabel:
{
verify_continuous_label(GetRoot<fb::ContinuousLabel>(label));
break;
}
case fb::Label_Slates_Label:
{
verify_slates_label(GetRoot<fb::Slates_Label>(label));
break;
}
default:
{
THROW("Label type not currently supported");
THROW("Label type not currently supported for verify");
break;
}
}
Expand Down Expand Up @@ -180,6 +264,91 @@ void prototype_label_t::verify_cb_label(const VW::example& e) const
}
}

void prototype_label_t::verify_continuous_label(const fb::ContinuousLabel* actual_label) const
{
EXPECT_FLOAT_EQ(actual_label->costs()->size(), label.cb_cont.costs.size());

for (size_t i = 0; i < actual_label->costs()->size(); i++)
{
auto actual_cost = actual_label->costs()->Get(i);
auto expected_cost = label.cb_cont.costs[i];

EXPECT_EQ(actual_cost->action(), expected_cost.action);
EXPECT_FLOAT_EQ(actual_cost->cost(), expected_cost.cost);
}
}

void prototype_label_t::verify_continuous_label(const VW::example& e) const
{
using label_t = VW::cb_continuous::continuous_label;

const label_t actual_label = e.l.cb_cont;

EXPECT_EQ(actual_label.costs.size(), label.cb_cont.costs.size());

for (size_t i = 0; i < actual_label.costs.size(); i++)
{
EXPECT_EQ(actual_label.costs[i].action, label.cb_cont.costs[i].action);
EXPECT_FLOAT_EQ(actual_label.costs[i].cost, label.cb_cont.costs[i].cost);
}
}

bool are_equal(fb::CCB_Slates_example_type lhs, VW::slates::example_type rhs)
{
switch (rhs)
{
case VW::slates::example_type::UNSET:
return lhs == fb::CCB_Slates_example_type_unset;
case VW::slates::example_type::ACTION:
return lhs == fb::CCB_Slates_example_type_action;
case VW::slates::example_type::SHARED:
return lhs == fb::CCB_Slates_example_type_shared;
case VW::slates::example_type::SLOT:
return lhs == fb::CCB_Slates_example_type_slot;
default:
THROW("Slates label example type not currently supported");
}
}

void prototype_label_t::verify_slates_label(const fb::Slates_Label* actual_label) const
{
EXPECT_TRUE(are_equal(actual_label->example_type(), label.slates.type));
EXPECT_FLOAT_EQ(actual_label->weight(), label.slates.weight);
EXPECT_FLOAT_EQ(actual_label->cost(), label.slates.cost);
EXPECT_EQ(actual_label->slot(), label.slates.slot_id);
EXPECT_EQ(actual_label->labeled(), label.slates.labeled);
EXPECT_EQ(actual_label->probabilities()->size(), label.slates.probabilities.size());

for (size_t i = 0; i < actual_label->probabilities()->size(); i++)
{
auto actual_prob = actual_label->probabilities()->Get(i);
auto expected_prob = label.slates.probabilities[i];

EXPECT_EQ(actual_prob->action(), expected_prob.action);
EXPECT_FLOAT_EQ(actual_prob->score(), expected_prob.score);
}
}

void prototype_label_t::verify_slates_label(const VW::example& e) const
{
using label_t = VW::slates::label;

const label_t actual_label = e.l.slates;

EXPECT_EQ(actual_label.type, label.slates.type);
EXPECT_FLOAT_EQ(actual_label.weight, label.slates.weight);
EXPECT_FLOAT_EQ(actual_label.cost, label.slates.cost);
EXPECT_EQ(actual_label.slot_id, label.slates.slot_id);
EXPECT_EQ(actual_label.labeled, label.slates.labeled);
EXPECT_EQ(actual_label.probabilities.size(), label.slates.probabilities.size());

for (size_t i = 0; i < actual_label.probabilities.size(); i++)
{
EXPECT_EQ(actual_label.probabilities[i].action, label.slates.probabilities[i].action);
EXPECT_FLOAT_EQ(actual_label.probabilities[i].score, label.slates.probabilities[i].score);
}
}

prototype_label_t no_label()
{
VW::polylabel actual_label;
Expand Down Expand Up @@ -227,4 +396,36 @@ prototype_label_t cb_label_shared()
*/
return cb_label(VW::cb_class(0., 0, -1.), 1.);
}

prototype_label_t continuous_label(std::vector<VW::cb_continuous::continuous_label_elm> costs)
{
VW::polylabel actual_label;
v_array<VW::cb_continuous::continuous_label_elm> costs_v;
costs_v.reserve(costs.size());
for (size_t i = 0; i < costs.size(); i++) { costs_v.push_back(costs[i]); }

actual_label.cb_cont = {costs_v};

return prototype_label_t{fb::Label_ContinuousLabel, actual_label, {}};
}

prototype_label_t slates_label_raw(VW::slates::example_type type, float weight, bool labeled, float cost,
uint32_t slot_id, std::vector<VW::action_score> probabilities)
{
VW::slates::label slates_label;
slates_label.type = type;
slates_label.weight = weight;
slates_label.labeled = labeled;
slates_label.cost = cost;
slates_label.slot_id = slot_id;

slates_label.probabilities.reserve(probabilities.size());
for (const auto& action_score : probabilities) { slates_label.probabilities.push_back(action_score); }

VW::polylabel actual_label;
actual_label.slates = slates_label;

return prototype_label_t{fb::Label_Slates_Label, actual_label, {}};
}

} // namespace vwtest
Loading

0 comments on commit 8cb4998

Please sign in to comment.