diff --git a/vowpalwabbit/fb_parser/tests/flatbuffer_parser_test.cc b/vowpalwabbit/fb_parser/tests/flatbuffer_parser_test.cc index 01e606cf40f..d7902975798 100644 --- a/vowpalwabbit/fb_parser/tests/flatbuffer_parser_test.cc +++ b/vowpalwabbit/fb_parser/tests/flatbuffer_parser_test.cc @@ -384,6 +384,62 @@ TEST(FlatbufferParser, MultiExample_Multiline) run_parse_and_verify_test(*all, prototype); } +TEST(FlatBufferParser, LabelSmokeTest_ContinuousLabel) +{ + using namespace vwtest; + using example = vwtest::example; + + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer")); + example_data_generator datagen; + + example ex = {{datagen.create_namespace("U_a", 1, 1)}, + + continuous_label({{1, 0.5f, 0.25}})}; + + run_parse_and_verify_test(*all, ex); +} + +TEST(FlatBufferParser, LabelSmokeTest_Slates) +{ + using namespace vwtest; + + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer", "--slates")); + example_data_generator datagen; + + // this is not the best way to describe it as it is technically labelled in the strictest sense + // (namely, having slates labels associated with the examples), but there is no labelling data + // there, because we do not have a global cost or probabilities for the slots. + multiex unlabeled_example = {{{{datagen.create_namespace("Context", 1, 1)}, + + slates::shared()}, + {{datagen.create_namespace("Action", 1, 1)}, + + slates::action(0)}, + {{datagen.create_namespace("Action", 1, 1)}, + + slates::action(0)}, + {{datagen.create_namespace("Slot", 1, 1)}, + + slates::slot(0)}}}; + + run_parse_and_verify_test(*all, unlabeled_example); + + multiex labeled_example{{{{datagen.create_namespace("Context", 1, 1)}, + + slates::shared(0.5)}, + {{datagen.create_namespace("Action", 1, 1)}, + + slates::action(0)}, + {{datagen.create_namespace("Action", 1, 1)}, + + slates::action(0)}, + {{datagen.create_namespace("Slot", 1, 1)}, + + slates::slot(0, {{1, 0.6}, {0, 0.4}})}}}; + + run_parse_and_verify_test(*all, labeled_example); +} + namespace vwtest { template @@ -451,6 +507,8 @@ void create_flatbuffer_and_validate(VW::workspace& w, c { case fb::Label_SimpleLabel: case fb::Label_CBLabel: + case fb::Label_ContinuousLabel: + case fb::Label_Slates_Label: { prototype.verify(w, prototype.label_type, builder.GetBufferPointer()); break; @@ -461,7 +519,7 @@ void create_flatbuffer_and_validate(VW::workspace& w, c } default: { - THROW("Label type not currently supported"); + THROW("Label type not currently supported for create_flatbuffer_and_validate"); break; } } @@ -487,6 +545,34 @@ TEST(FlatBufferParser, ValidateTestAffordances_CBLabel) create_flatbuffer_and_validate(*all, cb_label({1.5, 2, 0.25f})); } +TEST(FlatBufferParser, ValidateTestAffordances_ContinuousLabel) +{ + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer")); + + std::vector probabilities = {{1, 0.5f, 0.25}}; + + create_flatbuffer_and_validate(*all, continuous_label(probabilities)); +} + +TEST(FlatBufferParser, ValidateTestAffordances_Slates) +{ + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer", "--slates")); + + std::vector probabilities = {{1, 0.5f}, {2, 0.25f}}; + + VW::slates::example_type types[] = { + VW::slates::example_type::UNSET, + VW::slates::example_type::ACTION, + VW::slates::example_type::SHARED, + VW::slates::example_type::SLOT, + }; + + for (VW::slates::example_type type : types) + { + create_flatbuffer_and_validate(*all, slates_label_raw(type, 0.5, true, 0.3, 1, probabilities)); + } +} + TEST(FlatbufferParser, ValidateTestAffordances_Namespace) { auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer")); diff --git a/vowpalwabbit/fb_parser/tests/prototype_example_root.h b/vowpalwabbit/fb_parser/tests/prototype_example_root.h index ea58a68e0d7..facf693a304 100644 --- a/vowpalwabbit/fb_parser/tests/prototype_example_root.h +++ b/vowpalwabbit/fb_parser/tests/prototype_example_root.h @@ -38,7 +38,8 @@ inline void verify_example_root( { EXPECT_EQ(examples.size(), 1); EXPECT_EQ(examples[0].size(), 1); - expected.verify(vw, examples[0]); + + expected.verify(vw, *(examples[0][0])); } template diff --git a/vowpalwabbit/fb_parser/tests/prototype_label.cc b/vowpalwabbit/fb_parser/tests/prototype_label.cc index cb71489e346..f4af045050e 100644 --- a/vowpalwabbit/fb_parser/tests/prototype_label.cc +++ b/vowpalwabbit/fb_parser/tests/prototype_label.cc @@ -4,6 +4,9 @@ #include "prototype_label.h" +#include "vw/core/cb_continuous_label.h" +#include "vw/core/slates_label.h" + namespace vwtest { Offset prototype_label_t::create_flatbuffer(FlatBufferBuilder& builder, VW::workspace&) const @@ -35,9 +38,60 @@ Offset prototype_label_t::create_flatbuffer(FlatBufferBuilder& builder, VW { return 0; } + case fb::Label_ContinuousLabel: + { + std::vector> costs; + costs.reserve(label.cb_cont.costs.size()); + + for (const auto& cost : label.cb_cont.costs) + { + costs.push_back(fb::CreateContinuous_Label_Elm(builder, cost.action, cost.cost)); + } + + Offset>> costs_fb_vector = builder.CreateVector(costs); + return fb::CreateContinuousLabel(builder, costs_fb_vector).Union(); + } + case fb::Label_Slates_Label: + { + fb::CCB_Slates_example_type example_type = fb::CCB_Slates_example_type::CCB_Slates_example_type_unset; + switch (label.slates.type) + { + case VW::slates::example_type::UNSET: + example_type = fb::CCB_Slates_example_type::CCB_Slates_example_type_unset; + break; + case VW::slates::example_type::ACTION: + example_type = fb::CCB_Slates_example_type::CCB_Slates_example_type_action; + break; + case VW::slates::example_type::SHARED: + example_type = fb::CCB_Slates_example_type::CCB_Slates_example_type_shared; + break; + case VW::slates::example_type::SLOT: + example_type = fb::CCB_Slates_example_type::CCB_Slates_example_type_slot; + break; + default: + THROW("Slate label example type not currently supported"); + } + + auto action_scores = label.slates.probabilities; + + // TODO: This conversion is kind of painful: we should consider expanding the probabilities + // vector into a pair of vectors + std::vector> fb_action_scores; + fb_action_scores.reserve(action_scores.size()); + for (const auto& action_score : action_scores) + { + fb_action_scores.push_back(fb::Createaction_score(builder, action_score.action, action_score.score)); + } + + Offset>> fb_action_scores_fb_vector = builder.CreateVector(fb_action_scores); + + return fb::CreateSlates_Label(builder, example_type, label.slates.weight, label.slates.labeled, + label.slates.cost, label.slates.slot_id, fb_action_scores_fb_vector) + .Union(); + } default: { - THROW("Label type not currently supported"); + THROW("Label type not currently supported for create_flatbuffer"); return 0; } } @@ -62,10 +116,20 @@ void prototype_label_t::verify(VW::workspace&, const fb::Example* e) const { break; } + case fb::Label_ContinuousLabel: + { + verify_continuous_label(e); + break; + } + case fb::Label_Slates_Label: + { + verify_slates_label(e); + break; + } // TODO: other label types default: { - THROW("Label type not currently supported"); + THROW("Label type not currently supported for verify"); break; } } @@ -89,9 +153,19 @@ void prototype_label_t::verify(VW::workspace&, const VW::example& e) const { break; } + case fb::Label_ContinuousLabel: + { + verify_continuous_label(e); + break; + } + case fb::Label_Slates_Label: + { + verify_slates_label(e); + break; + } default: { - THROW("Label type not currently supported"); + THROW("Label type not currently supported for verify"); break; } } @@ -116,9 +190,19 @@ void prototype_label_t::verify(VW::workspace&, fb::Label label_type, const void* EXPECT_EQ(label, nullptr); break; } + case fb::Label_ContinuousLabel: + { + verify_continuous_label(GetRoot(label)); + break; + } + case fb::Label_Slates_Label: + { + verify_slates_label(GetRoot(label)); + break; + } default: { - THROW("Label type not currently supported"); + THROW("Label type not currently supported for verify"); break; } } @@ -180,6 +264,91 @@ void prototype_label_t::verify_cb_label(const VW::example& e) const } } +void prototype_label_t::verify_continuous_label(const fb::ContinuousLabel* actual_label) const +{ + EXPECT_FLOAT_EQ(actual_label->costs()->size(), label.cb_cont.costs.size()); + + for (size_t i = 0; i < actual_label->costs()->size(); i++) + { + auto actual_cost = actual_label->costs()->Get(i); + auto expected_cost = label.cb_cont.costs[i]; + + EXPECT_EQ(actual_cost->action(), expected_cost.action); + EXPECT_FLOAT_EQ(actual_cost->cost(), expected_cost.cost); + } +} + +void prototype_label_t::verify_continuous_label(const VW::example& e) const +{ + using label_t = VW::cb_continuous::continuous_label; + + const label_t actual_label = e.l.cb_cont; + + EXPECT_EQ(actual_label.costs.size(), label.cb_cont.costs.size()); + + for (size_t i = 0; i < actual_label.costs.size(); i++) + { + EXPECT_EQ(actual_label.costs[i].action, label.cb_cont.costs[i].action); + EXPECT_FLOAT_EQ(actual_label.costs[i].cost, label.cb_cont.costs[i].cost); + } +} + +bool are_equal(fb::CCB_Slates_example_type lhs, VW::slates::example_type rhs) +{ + switch (rhs) + { + case VW::slates::example_type::UNSET: + return lhs == fb::CCB_Slates_example_type_unset; + case VW::slates::example_type::ACTION: + return lhs == fb::CCB_Slates_example_type_action; + case VW::slates::example_type::SHARED: + return lhs == fb::CCB_Slates_example_type_shared; + case VW::slates::example_type::SLOT: + return lhs == fb::CCB_Slates_example_type_slot; + default: + THROW("Slates label example type not currently supported"); + } +} + +void prototype_label_t::verify_slates_label(const fb::Slates_Label* actual_label) const +{ + EXPECT_TRUE(are_equal(actual_label->example_type(), label.slates.type)); + EXPECT_FLOAT_EQ(actual_label->weight(), label.slates.weight); + EXPECT_FLOAT_EQ(actual_label->cost(), label.slates.cost); + EXPECT_EQ(actual_label->slot(), label.slates.slot_id); + EXPECT_EQ(actual_label->labeled(), label.slates.labeled); + EXPECT_EQ(actual_label->probabilities()->size(), label.slates.probabilities.size()); + + for (size_t i = 0; i < actual_label->probabilities()->size(); i++) + { + auto actual_prob = actual_label->probabilities()->Get(i); + auto expected_prob = label.slates.probabilities[i]; + + EXPECT_EQ(actual_prob->action(), expected_prob.action); + EXPECT_FLOAT_EQ(actual_prob->score(), expected_prob.score); + } +} + +void prototype_label_t::verify_slates_label(const VW::example& e) const +{ + using label_t = VW::slates::label; + + const label_t actual_label = e.l.slates; + + EXPECT_EQ(actual_label.type, label.slates.type); + EXPECT_FLOAT_EQ(actual_label.weight, label.slates.weight); + EXPECT_FLOAT_EQ(actual_label.cost, label.slates.cost); + EXPECT_EQ(actual_label.slot_id, label.slates.slot_id); + EXPECT_EQ(actual_label.labeled, label.slates.labeled); + EXPECT_EQ(actual_label.probabilities.size(), label.slates.probabilities.size()); + + for (size_t i = 0; i < actual_label.probabilities.size(); i++) + { + EXPECT_EQ(actual_label.probabilities[i].action, label.slates.probabilities[i].action); + EXPECT_FLOAT_EQ(actual_label.probabilities[i].score, label.slates.probabilities[i].score); + } +} + prototype_label_t no_label() { VW::polylabel actual_label; @@ -227,4 +396,36 @@ prototype_label_t cb_label_shared() */ return cb_label(VW::cb_class(0., 0, -1.), 1.); } + +prototype_label_t continuous_label(std::vector costs) +{ + VW::polylabel actual_label; + v_array costs_v; + costs_v.reserve(costs.size()); + for (size_t i = 0; i < costs.size(); i++) { costs_v.push_back(costs[i]); } + + actual_label.cb_cont = {costs_v}; + + return prototype_label_t{fb::Label_ContinuousLabel, actual_label, {}}; +} + +prototype_label_t slates_label_raw(VW::slates::example_type type, float weight, bool labeled, float cost, + uint32_t slot_id, std::vector probabilities) +{ + VW::slates::label slates_label; + slates_label.type = type; + slates_label.weight = weight; + slates_label.labeled = labeled; + slates_label.cost = cost; + slates_label.slot_id = slot_id; + + slates_label.probabilities.reserve(probabilities.size()); + for (const auto& action_score : probabilities) { slates_label.probabilities.push_back(action_score); } + + VW::polylabel actual_label; + actual_label.slates = slates_label; + + return prototype_label_t{fb::Label_Slates_Label, actual_label, {}}; +} + } // namespace vwtest diff --git a/vowpalwabbit/fb_parser/tests/prototype_label.h b/vowpalwabbit/fb_parser/tests/prototype_label.h index 770ca06844c..98d7a905212 100644 --- a/vowpalwabbit/fb_parser/tests/prototype_label.h +++ b/vowpalwabbit/fb_parser/tests/prototype_label.h @@ -18,6 +18,7 @@ using namespace flatbuffers; namespace vwtest { + struct prototype_label_t { fb::Label label_type; @@ -53,6 +54,28 @@ struct prototype_label_t void verify_cb_label(const fb::CBLabel* label) const; void verify_cb_label(const VW::example& ex) const; + + inline void verify_continuous_label(const fb::Example* ex) const + { + EXPECT_EQ(ex->label_type(), fb::Label_ContinuousLabel); + + const fb::ContinuousLabel* actual_label = ex->label_as_ContinuousLabel(); + verify_continuous_label(actual_label); + } + + void verify_continuous_label(const fb::ContinuousLabel* label) const; + void verify_continuous_label(const VW::example& ex) const; + + inline void verify_slates_label(const fb::Example* ex) const + { + EXPECT_EQ(ex->label_type(), fb::Label_Slates_Label); + + const fb::Slates_Label* actual_label = ex->label_as_Slates_Label(); + verify_slates_label(actual_label); + } + + void verify_slates_label(const fb::Slates_Label* label) const; + void verify_slates_label(const VW::example& ex) const; }; prototype_label_t no_label(); @@ -62,4 +85,29 @@ prototype_label_t simple_label(float label, float weight, float initial = 0.f); prototype_label_t cb_label(std::vector costs, float weight = 1.0f); prototype_label_t cb_label(VW::cb_class single_class, float weight = 1.0f); prototype_label_t cb_label_shared(); + +prototype_label_t continuous_label(std::vector costs); + +prototype_label_t slates_label_raw(VW::slates::example_type type, float weight, bool labeled, float cost, + uint32_t slot_id, std::vector probabilities); + +namespace slates +{ +inline prototype_label_t shared() +{ + return vwtest::slates_label_raw(VW::slates::example_type::SHARED, 0.0f, false, 0.0f, 0, {}); +} +inline prototype_label_t shared(float global_reward) +{ + return vwtest::slates_label_raw(VW::slates::example_type::SHARED, 0.0f, true, global_reward, 0, {}); +} +inline prototype_label_t action(uint32_t for_slot) +{ + return vwtest::slates_label_raw(VW::slates::example_type::ACTION, 0.0f, false, 0.0f, for_slot, {}); +} +inline prototype_label_t slot(uint32_t slot_id, std::vector probabilities = {}) +{ + return vwtest::slates_label_raw(VW::slates::example_type::SLOT, 0.0f, false, 0.0f, slot_id, probabilities); +} +}; // namespace slates } // namespace vwtest \ No newline at end of file