diff --git a/vowpalwabbit/fb_parser/src/parse_example_flatbuffer.cc b/vowpalwabbit/fb_parser/src/parse_example_flatbuffer.cc index 8bc7db22bdf..966377505f9 100644 --- a/vowpalwabbit/fb_parser/src/parse_example_flatbuffer.cc +++ b/vowpalwabbit/fb_parser/src/parse_example_flatbuffer.cc @@ -373,30 +373,11 @@ int parser::get_namespace_index(const Namespace* ns, namespace_index& ni, VW::ex ni = static_cast(ns->name()->c_str()[0]); return VW::experimental::error_code::success; } - else if (flatbuffers::IsFieldPresent(ns, Namespace::VT_HASH)) + else { ni = ns->hash(); return VW::experimental::error_code::success; } - - if (_active_collection && _active_multi_ex) - { - RETURN_ERROR_LS(status, fb_parser_name_hash_missing) - << "Either name or hash field must be specified to get the namespace index in collection item with example " - "index " - << _example_index << "and multi example index " << _multi_ex_index; - } - else if (_active_multi_ex) - { - RETURN_ERROR_LS(status, fb_parser_name_hash_missing) - << "Either name or hash field must be specified to get the namespace index in multi example index " - << _multi_ex_index; - } - else - { - RETURN_ERROR_LS(status, fb_parser_name_hash_missing) - << "Either name or hash field must be specified to get the namespace index"; - } } bool get_namespace_hash(VW::workspace* all, const Namespace* ns, uint64_t& hash) diff --git a/vowpalwabbit/fb_parser/tests/example_data_generator.h b/vowpalwabbit/fb_parser/tests/example_data_generator.h index 69a7805d1c8..6b12f9636fe 100644 --- a/vowpalwabbit/fb_parser/tests/example_data_generator.h +++ b/vowpalwabbit/fb_parser/tests/example_data_generator.h @@ -32,6 +32,10 @@ class example_data_generator static VW::rand_state create_test_random_state(); + inline bool random_bool() { return rng.get_and_update_random() >= 0.5; } + + inline int random_int(int min, int max) { return static_cast<int>(rng.get_and_update_random() * (max - min) + min); } + prototype_namespace_t create_namespace(std::string name, 
uint8_t numeric_features, uint8_t string_features); prototype_example_t create_simple_example(uint8_t numeric_features, uint8_t string_features); diff --git a/vowpalwabbit/fb_parser/tests/read_span_tests.cc b/vowpalwabbit/fb_parser/tests/read_span_tests.cc index 5622be90dae..cac108a85e9 100644 --- a/vowpalwabbit/fb_parser/tests/read_span_tests.cc +++ b/vowpalwabbit/fb_parser/tests/read_span_tests.cc @@ -67,7 +67,7 @@ inline void verify_multi_ex( } // namespace vwtest template ::type> -void create_flatbuffer_span_and_validate(VW::workspace& w, const T& prototype) +void create_flatbuffer_span_and_validate(VW::workspace& w, vwtest::example_data_generator& data_gen, const T& prototype) { // This is what we expect to see when we use read_span_flatbuffer, since this is intended // to be used for inference, and we would prefer not to force consumers of the API to have @@ -85,6 +85,11 @@ void create_flatbuffer_span_and_validate(VW::workspace& w, const T& prototype) flatbuffers::uoffset_t size = builder.GetSize(); VW::multi_ex parsed_examples; + if (data_gen.random_bool()) + { + parsed_examples.push_back(&ex_fac()); + } + VW::parsers::flatbuffer::read_span_flatbuffer(&w, buffer, size, ex_fac, parsed_examples); verify_multi_ex(w, prototype, parsed_examples); @@ -100,7 +105,7 @@ TEST(FlatbufferParser, ReadSpanFlatbuffer_SingleExample) vwtest::prototype_example_t prototype = { {data_gen.create_namespace("A", 3, 4), data_gen.create_namespace("B", 2, 5)}, vwtest::simple_label(1.0f)}; - create_flatbuffer_span_and_validate(*all, prototype); + create_flatbuffer_span_and_validate(*all, data_gen, prototype); } TEST(FlatbufferParser, ReadSpanFlatbuffer_MultiExample) @@ -110,7 +115,7 @@ TEST(FlatbufferParser, ReadSpanFlatbuffer_MultiExample) vwtest::example_data_generator data_gen; vwtest::prototype_multiexample_t prototype = data_gen.create_cb_adf_example(3, 1, "tag"); - create_flatbuffer_span_and_validate(*all, prototype); + create_flatbuffer_span_and_validate(*all, data_gen, 
prototype); } TEST(FlatbufferParser, ReadSpanFlatbuffer_ExampleCollectionSinglelines) @@ -120,7 +125,7 @@ TEST(FlatbufferParser, ReadSpanFlatbuffer_ExampleCollectionSinglelines) vwtest::example_data_generator data_gen; vwtest::prototype_example_collection_t prototype = data_gen.create_simple_log(3, 3, 4); - create_flatbuffer_span_and_validate(*all, prototype); + create_flatbuffer_span_and_validate(*all, data_gen, prototype); } TEST(FlatbufferParser, ReadSpanFlatbuffer_ExampleCollectionMultiline) @@ -130,13 +135,19 @@ TEST(FlatbufferParser, ReadSpanFlatbuffer_ExampleCollectionMultiline) vwtest::example_data_generator data_gen; vwtest::prototype_example_collection_t prototype = data_gen.create_cb_adf_log(1, 3, 4); - create_flatbuffer_span_and_validate(*all, prototype); + create_flatbuffer_span_and_validate(*all, data_gen, prototype); } template void finish_flatbuffer_and_expect_error(FlatBufferBuilder& builder, Offset root, VW::workspace& w) { VW::example_factory_t ex_fac = [&w]() -> VW::example& { return VW::get_unused_example(&w); }; + VW::example_sink_f ex_sink = [&w](VW::example& ex) { VW::finish_example(w, ex); }; + if (vwtest::example_data_generator{}.random_bool()) + { + // This is only valid because ex_fac is grabbing an example from the VW example pool + ex_sink = nullptr; + } builder.FinishSizePrefixed(root); @@ -147,7 +158,7 @@ void finish_flatbuffer_and_expect_error(FlatBufferBuilder& builder, Offset