diff --git a/vowpalwabbit/fb_parser/tests/flatbuffer_parser_test.cc b/vowpalwabbit/fb_parser/tests/flatbuffer_parser_test.cc index 153df495b8d..ad3a18a6ae0 100644 --- a/vowpalwabbit/fb_parser/tests/flatbuffer_parser_test.cc +++ b/vowpalwabbit/fb_parser/tests/flatbuffer_parser_test.cc @@ -285,8 +285,10 @@ void run_parse_and_verify_test(VW::workspace& w, const root_prototype_t& root_ob VW::io_buf buf; - span fb_span = builder.GetBufferSpan(); - buf.add_file(VW::io::create_buffer_view((const char*)fb_span.data(), fb_span.size())); + uint8_t* buf_ptr = builder.GetBufferPointer(); + size_t buf_size = builder.GetSize(); + + buf.add_file(VW::io::create_buffer_view((const char*)buf_ptr, buf_size)); VW::multi_ex examples; @@ -347,7 +349,143 @@ TEST(FlatbufferParser, MultiExample) }, } }; - + run_parse_and_verify_test(*all, prototype); } +namespace vwtest +{ + template + struct fb_type + { + }; + + template <> + struct fb_type + { + using type = VW::parsers::flatbuffer::Namespace; + }; + + template <> + struct fb_type + { + using type = VW::parsers::flatbuffer::Example; + }; + + template <> + struct fb_type + { + using type = VW::parsers::flatbuffer::MultiExample; + }; + + template <> + struct fb_type + { + using type = VW::parsers::flatbuffer::ExampleCollection; + }; +} + +template ::type> +void create_flatbuffer_and_validate(VW::workspace& w, const T& prototype) +{ + flatbuffers::FlatBufferBuilder builder; + + Offset buffer_offset = prototype.create_flatbuffer(builder, w); + builder.Finish(buffer_offset); + + const FB_t* fb_obj = GetRoot(builder.GetBufferPointer()); + + prototype.verify(w, fb_obj); +} + +TEST(FlatbufferParser, ValidateTestAffordances_Namespace) +{ + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer")); + + prototype_namespace_t ns_prototype = { "U_a", { { "a", 1.f }, { "b", 2.f } } }; + create_flatbuffer_and_validate(*all, ns_prototype); +} + +TEST(FlatbufferParser, ValidateTestAffordances_Example_Simple) +{ + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer")); + + prototype_example_t ex_prototype = { + { + { "U_a", { { "a", 1.f }, { "b", 2.f } } }, + { "U_b", { { "a", 3.f }, { "b", 4.f } } }, + }, + vwtest::simple_label(0.5, 1.0) + }; + create_flatbuffer_and_validate(*all, ex_prototype); +} + +// TEST(FlatbufferParser, ValidateTestAffordances_Example_Unlabeled) +// { +// auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer")); + +// prototype_example_t ex_prototype = { +// { +// { "U_a", { { "a", 1.f }, { "b", 2.f } } }, +// { "U_b", { { "a", 3.f }, { "b", 4.f } } }, +// } +// }; +// create_flatbuffer_and_validate(*all, ex_prototype); + +// } + +TEST(FlatbufferParser, ValidateTestAffordances_Example_CBShared) +{ + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer", "--cb_explore_adf")); + + prototype_example_t ex_prototype = { + { + { "U_a", { { "a", 1.f }, { "b", 2.f } } }, + { "U_b", { { "a", 3.f }, { "b", 4.f } } }, + }, + vwtest::cb_label_shared(), + "tag1" + }; + create_flatbuffer_and_validate(*all, ex_prototype); +} + +TEST(FlatbufferParser, ValidateTestAffordances_Example_CB) +{ + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer", "--cb_explore_adf")); + + prototype_example_t ex_prototype = { + { + { "T_a", { { "a", 5.f }, { "b", 6.f } } }, + { "T_b", { { "a", 7.f }, { "b", 8.f } } }, + }, + vwtest::cb_label({ 1, 1, 0.5f }), + "tag1" + }; + create_flatbuffer_and_validate(*all, ex_prototype); +} + +TEST(FlatbufferParser, ValidateTestAffordances_MultiExample) +{ + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer")); + + prototype_multiexample_t multiex_prototype = { + { + { + { + { "U_a", { { "a", 1.f }, { "b", 2.f } } }, + { "U_b", { { "a", 3.f }, { "b", 4.f } } }, + }, + vwtest::cb_label_shared(), + "tag1" + }, + { + { + { "T_a", { { "a", 5.f }, { "b", 6.f } } }, + { "T_b", { { "a", 7.f }, { "b", 8.f } } }, + }, + vwtest::cb_label({ { 1, 1, 0.5f } }), + }, + } + }; + create_flatbuffer_and_validate(*all, multiex_prototype); +} diff --git a/vowpalwabbit/fb_parser/tests/prototype_example.h b/vowpalwabbit/fb_parser/tests/prototype_example.h index b90cf983f84..3ef14ed66e7 100644 --- a/vowpalwabbit/fb_parser/tests/prototype_example.h +++ b/vowpalwabbit/fb_parser/tests/prototype_example.h @@ -18,8 +18,6 @@ namespace vwtest struct prototype_example_t { - - std::vector namespaces; prototype_label_t label; const char* tag = nullptr; diff --git a/vowpalwabbit/fb_parser/tests/prototype_namespace.h b/vowpalwabbit/fb_parser/tests/prototype_namespace.h index df319bdc323..72907adb823 100644 --- a/vowpalwabbit/fb_parser/tests/prototype_namespace.h +++ b/vowpalwabbit/fb_parser/tests/prototype_namespace.h @@ -28,11 +28,9 @@ struct feature_t { } - feature_t(feature_t&& other) : name(std::move(other.name)), value(other.value), hash(other.hash) - { - } + feature_t(feature_t&& other) = delete; - feature_t(const feature_t& other) : name(other.name), value(other.value), hash(other.hash) + feature_t(const feature_t& other) : has_name(other.has_name), name(other.name), value(other.value), hash(other.hash) { } @@ -44,16 +42,16 @@ struct feature_t struct prototype_namespace_t { - prototype_namespace_t(const char* name, std::vector&& features) : has_name(true), name(name), features(features), hash(0), feature_group(name[0]) + prototype_namespace_t(const char* name, const std::vector& features) : has_name(true), name(name), features{features}, hash(0), feature_group(name[0]) { } - prototype_namespace_t(char feature_group, uint64_t hash, std::vector&& features) : has_name(false), name(nullptr), features(features), hash(hash), feature_group(feature_group) + prototype_namespace_t(char feature_group, uint64_t hash, const std::vector& features) : has_name(false), name(nullptr), features{features}, hash(hash), feature_group(feature_group) { } prototype_namespace_t(prototype_namespace_t&& other) = delete; - prototype_namespace_t(const prototype_namespace_t& other) : has_name(other.has_name), name(other.name), features(other.features), hash(other.hash), feature_group(other.feature_group) + prototype_namespace_t(const prototype_namespace_t& other) : has_name(other.has_name), name(other.name), features{other.features}, hash(other.hash), feature_group(other.feature_group) { } @@ -91,7 +89,7 @@ struct prototype_namespace_t const auto name_offset = has_name ? builder.CreateString(name) : Offset(); - Offset>> feature_names_offset = include_feature_names ? builder.CreateVector(feature_names) : Offset>>() /* nullptr */; + Offset>> feature_names_offset = builder.CreateVector(feature_names); Offset> feature_values_offset = builder.CreateVector(feature_values); Offset> feature_hashes_offset = builder.CreateVector(feature_hashes); @@ -175,9 +173,9 @@ struct prototype_namespace_t EXPECT_LT(extent_index, features.namespace_extents.size()); const auto& extent = features.namespace_extents[extent_index]; - for (size_t i_f = extent.begin_index; i_f < extent.end_index; i_f++) + for (size_t i_f = extent.begin_index, i = 0; i_f < extent.end_index && i < this->features.size(); i_f++, i++) { - auto& f = this->features[i_f]; + auto& f = this->features[i]; if (f.has_name) { EXPECT_EQ(features.indices[i_f], VW::hash_feature(w, f.name, hash));