Skip to content

Commit

Permalink
fix: test validation logic was indexing incorrectly
Browse files Browse the repository at this point in the history
The validator in prototype_namespace_t has to iterate the example's features carefully, because it must stay within the extent that belongs to the namespace being checked. There was an indexing bug: the extent-relative index (a position in the example's feature arrays) was also being used as a plain index into the prototype namespace's own feature vector. That read beyond the end of the vector, so spurious data was being loaded and compared.
  • Loading branch information
lokitoth committed Feb 6, 2024
1 parent 98ce35d commit f9cb076
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 15 deletions.
144 changes: 141 additions & 3 deletions vowpalwabbit/fb_parser/tests/flatbuffer_parser_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -285,8 +285,10 @@ void run_parse_and_verify_test(VW::workspace& w, const root_prototype_t& root_ob

VW::io_buf buf;

span<uint8_t> fb_span = builder.GetBufferSpan();
buf.add_file(VW::io::create_buffer_view((const char*)fb_span.data(), fb_span.size()));
uint8_t* buf_ptr = builder.GetBufferPointer();
size_t buf_size = builder.GetSize();

buf.add_file(VW::io::create_buffer_view((const char*)buf_ptr, buf_size));

VW::multi_ex examples;

Expand Down Expand Up @@ -347,7 +349,143 @@ TEST(FlatbufferParser, MultiExample)
},
}
};

run_parse_and_verify_test(*all, prototype);
}

namespace vwtest
{
// Type trait mapping a test prototype type to the flatbuffer table type it is
// paired with. The primary template has no `type` member, so
// `fb_type<T>::type` only names a type for the specializations listed below.
template <typename T>
struct fb_type {};

// prototype_namespace_t <-> flatbuffer Namespace
template <>
struct fb_type<prototype_namespace_t> { using type = VW::parsers::flatbuffer::Namespace; };

// prototype_example_t <-> flatbuffer Example
template <>
struct fb_type<prototype_example_t> { using type = VW::parsers::flatbuffer::Example; };

// prototype_multiexample_t <-> flatbuffer MultiExample
template <>
struct fb_type<prototype_multiexample_t> { using type = VW::parsers::flatbuffer::MultiExample; };

// prototype_example_collection_t <-> flatbuffer ExampleCollection
template <>
struct fb_type<prototype_example_collection_t> { using type = VW::parsers::flatbuffer::ExampleCollection; };
}  // namespace vwtest

// Serializes `prototype` into a fresh FlatBufferBuilder, finishes the buffer,
// reinterprets the finished bytes as the root table `FB_t`, and hands that
// table back to the prototype's own `verify` for checking.
//
// T     - prototype type; must provide `create_flatbuffer(builder, w)` and
//         `verify(w, const FB_t*)`.
// FB_t  - flatbuffer root type; defaults to the mapping in vwtest::fb_type<T>.
// w     - workspace forwarded to both the serialization and verification calls.
template <typename T, typename FB_t = typename vwtest::fb_type<T>::type>
void create_flatbuffer_and_validate(VW::workspace& w, const T& prototype)
{
  flatbuffers::FlatBufferBuilder builder;

  // Build the object graph, then mark it as the root of the buffer.
  const Offset<FB_t> root_offset = prototype.create_flatbuffer(builder, w);
  builder.Finish(root_offset);

  // The pointer is only read after Finish(), and `builder` stays alive for
  // the whole verification below.
  const void* const finished_bytes = builder.GetBufferPointer();
  const FB_t* root_table = GetRoot<FB_t>(finished_bytes);

  prototype.verify(w, root_table);
}

// Checks the namespace prototype against its own serialized form: one
// namespace named "U_a" holding features "a" (1.f) and "b" (2.f).
TEST(FlatbufferParser, ValidateTestAffordances_Namespace)
{
  auto workspace = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer"));

  const prototype_namespace_t expected_namespace = {
    "U_a",
    {
      { "a", 1.f },
      { "b", 2.f }
    }
  };

  create_flatbuffer_and_validate(*workspace, expected_namespace);
}

// Checks the example prototype against its own serialized form: two
// namespaces ("U_a", "U_b") of two features each, carrying the label built by
// vwtest::simple_label(0.5, 1.0).
TEST(FlatbufferParser, ValidateTestAffordances_Example_Simple)
{
  auto workspace = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer"));

  const prototype_example_t expected_example = {
    // namespaces
    {
      { "U_a", { { "a", 1.f }, { "b", 2.f } } },
      { "U_b", { { "a", 3.f }, { "b", 4.f } } },
    },
    // label
    vwtest::simple_label(0.5, 1.0)
  };

  create_flatbuffer_and_validate(*workspace, expected_example);
}

// TEST(FlatbufferParser, ValidateTestAffordances_Example_Unlabeled)
// {
// auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer"));

// prototype_example_t ex_prototype = {
// {
// { "U_a", { { "a", 1.f }, { "b", 2.f } } },
// { "U_b", { { "a", 3.f }, { "b", 4.f } } },
// }
// };
// create_flatbuffer_and_validate(*all, ex_prototype);

// }

// Checks the example prototype against its own serialized form when the
// example carries the label from vwtest::cb_label_shared() and the tag "tag1".
// The workspace is created with "--cb_explore_adf".
TEST(FlatbufferParser, ValidateTestAffordances_Example_CBShared)
{
  auto workspace =
      VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer", "--cb_explore_adf"));

  const prototype_example_t shared_example = {
    // namespaces
    {
      { "U_a", { { "a", 1.f }, { "b", 2.f } } },
      { "U_b", { { "a", 3.f }, { "b", 4.f } } },
    },
    // label
    vwtest::cb_label_shared(),
    // tag
    "tag1"
  };

  create_flatbuffer_and_validate(*workspace, shared_example);
}

// Checks the example prototype against its own serialized form when the
// example carries the label from vwtest::cb_label({ 1, 1, 0.5f }) and the tag
// "tag1". The workspace is created with "--cb_explore_adf".
TEST(FlatbufferParser, ValidateTestAffordances_Example_CB)
{
  auto workspace =
      VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer", "--cb_explore_adf"));

  const prototype_example_t action_example = {
    // namespaces
    {
      { "T_a", { { "a", 5.f }, { "b", 6.f } } },
      { "T_b", { { "a", 7.f }, { "b", 8.f } } },
    },
    // label
    vwtest::cb_label({ 1, 1, 0.5f }),
    // tag
    "tag1"
  };

  create_flatbuffer_and_validate(*workspace, action_example);
}

// Checks the multi-example prototype against its own serialized form. The
// prototype holds two examples: the first uses vwtest::cb_label_shared() with
// tag "tag1", the second uses a vwtest::cb_label(...) label and sets no tag.
//
// NOTE(review): unlike the single-example CB tests in this file, the workspace
// here is created without "--cb_explore_adf" even though both examples carry
// CB labels -- confirm this is intended.
TEST(FlatbufferParser, ValidateTestAffordances_MultiExample)
{
  auto workspace = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer"));

  const prototype_multiexample_t expected_multiexample = {
    // examples
    {
      // first example: "U_*" namespaces, shared label, tagged
      {
        {
          { "U_a", { { "a", 1.f }, { "b", 2.f } } },
          { "U_b", { { "a", 3.f }, { "b", 4.f } } },
        },
        vwtest::cb_label_shared(),
        "tag1"
      },
      // second example: "T_*" namespaces, cb_label, no tag given
      {
        {
          { "T_a", { { "a", 5.f }, { "b", 6.f } } },
          { "T_b", { { "a", 7.f }, { "b", 8.f } } },
        },
        vwtest::cb_label({ { 1, 1, 0.5f } }),
      },
    }
  };

  create_flatbuffer_and_validate(*workspace, expected_multiexample);
}
2 changes: 0 additions & 2 deletions vowpalwabbit/fb_parser/tests/prototype_example.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@ namespace vwtest

struct prototype_example_t
{


std::vector<prototype_namespace_t> namespaces;
prototype_label_t label;
const char* tag = nullptr;
Expand Down
18 changes: 8 additions & 10 deletions vowpalwabbit/fb_parser/tests/prototype_namespace.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,9 @@ struct feature_t
{
}

feature_t(feature_t&& other) : name(std::move(other.name)), value(other.value), hash(other.hash)
{
}
feature_t(feature_t&& other) = delete;

feature_t(const feature_t& other) : name(other.name), value(other.value), hash(other.hash)
feature_t(const feature_t& other) : has_name(other.has_name), name(other.name), value(other.value), hash(other.hash)
{
}

Expand All @@ -44,16 +42,16 @@ struct feature_t

struct prototype_namespace_t
{
prototype_namespace_t(const char* name, std::vector<feature_t>&& features) : has_name(true), name(name), features(features), hash(0), feature_group(name[0])
prototype_namespace_t(const char* name, const std::vector<feature_t>& features) : has_name(true), name(name), features{features}, hash(0), feature_group(name[0])
{
}

prototype_namespace_t(char feature_group, uint64_t hash, std::vector<feature_t>&& features) : has_name(false), name(nullptr), features(features), hash(hash), feature_group(feature_group)
prototype_namespace_t(char feature_group, uint64_t hash, const std::vector<feature_t>& features) : has_name(false), name(nullptr), features{features}, hash(hash), feature_group(feature_group)
{
}

prototype_namespace_t(prototype_namespace_t&& other) = delete;
prototype_namespace_t(const prototype_namespace_t& other) : has_name(other.has_name), name(other.name), features(other.features), hash(other.hash), feature_group(other.feature_group)
prototype_namespace_t(const prototype_namespace_t& other) : has_name(other.has_name), name(other.name), features{other.features}, hash(other.hash), feature_group(other.feature_group)
{
}

Expand Down Expand Up @@ -91,7 +89,7 @@ struct prototype_namespace_t

const auto name_offset = has_name ? builder.CreateString(name) : Offset<String>();

Offset<Vector<Offset<String>>> feature_names_offset = include_feature_names ? builder.CreateVector(feature_names) : Offset<Vector<Offset<String>>>() /* nullptr */;
Offset<Vector<Offset<String>>> feature_names_offset = builder.CreateVector(feature_names);
Offset<Vector<float>> feature_values_offset = builder.CreateVector(feature_values);
Offset<Vector<uint64_t>> feature_hashes_offset = builder.CreateVector(feature_hashes);

Expand Down Expand Up @@ -175,9 +173,9 @@ struct prototype_namespace_t
EXPECT_LT(extent_index, features.namespace_extents.size());
const auto& extent = features.namespace_extents[extent_index];

for (size_t i_f = extent.begin_index; i_f < extent.end_index; i_f++)
for (size_t i_f = extent.begin_index, i = 0; i_f < extent.end_index && i < this->features.size(); i_f++, i++)
{
auto& f = this->features[i_f];
auto& f = this->features[i];
if (f.has_name)
{
EXPECT_EQ(features.indices[i_f], VW::hash_feature(w, f.name, hash));
Expand Down

0 comments on commit f9cb076

Please sign in to comment.