Alignment-aware reads for the flatbuffer parser.

Summary of the patch:
  * error_data.h: adds error code 20000 (internal_error).
  * io_buf.h / io_buf.cc: adds VW::desired_align and threads it through
    io_buf::buf_read and shift_to_front, so that data moved to the front of
    the buffer is placed at begin + padding for the requested alignment/offset.
  * parse_example_flatbuffer.{h,cc}: parser::parse returns an int error code,
    accepts an api_status*, and returns internal_error for misaligned data.

Review fixes applied on top of the submitted patch:
  * flatbuffer_to_examples: the status branch returned the raw error code
    (0 on success) while the else branch returns 1 on success. Both branches
    now return 1 on success and 0 otherwise.
  * parse_examples: `status` is forwarded to parse() so the error message it
    builds reaches the caller.

NOTE(review): the copy under review had lost its line breaks, indentation and
every <...> span. Layout and template arguments below were restored by
inspection; the four #include targets in parse_example_flatbuffer.cc could not
be recovered and are left blank. Verify against the upstream files before
applying. The blob hashes on the `index` lines are those of the original patch.

diff --git a/vowpalwabbit/core/include/vw/core/error_data.h b/vowpalwabbit/core/include/vw/core/error_data.h
index edf4376715e..847f52f190b 100644
--- a/vowpalwabbit/core/include/vw/core/error_data.h
+++ b/vowpalwabbit/core/include/vw/core/error_data.h
@@ -26,8 +26,10 @@ ERROR_CODE_DEFINITION(
     13, fb_parser_size_mismatch_ft_names_ft_values, "Size of feature names and feature values do not match. ")
 ERROR_CODE_DEFINITION(14, unknown_label_type, "Label type in Flatbuffer not understood. ")
 
+
 // TODO: This is temporary until we switch to the new error handling mechanism.
 ERROR_CODE_DEFINITION(10000, vw_exception, "vw_exception: ")
+ERROR_CODE_DEFINITION(20000, internal_error, "BUGBUG: ")
 
 #ifdef ERROR_CODE_DEFINITION_NOOP
 #undef ERROR_CODE_DEFINITION
diff --git a/vowpalwabbit/core/include/vw/core/io_buf.h b/vowpalwabbit/core/include/vw/core/io_buf.h
index 2cb06c8b0fd..84834b4733f 100644
--- a/vowpalwabbit/core/include/vw/core/io_buf.h
+++ b/vowpalwabbit/core/include/vw/core/io_buf.h
@@ -44,6 +44,56 @@
 
 namespace VW
 {
+struct desired_align
+{
+  using align_t = size_t;
+
+  align_t align;
+
+  // DO NOT USE THIS UNLESS YOU *REALLY* KNOW WHAT YOU ARE DOING
+  // Off-alignment reads are UB. Only use this if you know you need an offset
+  // from a true aligned address.
+  align_t offset;
+
+  template <typename T>
+  static constexpr desired_align align_for(align_t offset = 0)
+  {
+    return desired_align{compute_align<T>(), offset};
+  }
+
+  desired_align(align_t align = 1, align_t offset = 0) : align(align), offset(offset) {}
+
+  struct flatbuffer_t { flatbuffer_t() = delete; };
+
+  // print to ostream
+  friend std::ostream& operator<<(std::ostream& os, const desired_align& da)
+  {
+    os << "align: " << da.align << ", offset: " << da.offset;
+    return os;
+  }
+
+  inline bool is_aligned(const void* ptr) const
+  {
+    return (reinterpret_cast<uintptr_t>(ptr) % align) == offset;
+  }
+
+private:
+  template <typename T>
+  static constexpr align_t compute_align()
+  {
+    // if T is a flatbuffer type, we need to align to 8 bytes,
+    // otherwise alignof(T)
+    return std::is_base_of<flatbuffer_t, T>::value ? 8 : alignof(T);
+  }
+};
+
+
+
+namespace known_alignments
+{
+  const desired_align TEXT = desired_align::align_for<char>();
+}
+
 class io_buf
 {
 public:
@@ -204,7 +254,7 @@ class io_buf
   }
 
   void buf_write(char*& pointer, size_t n);
-  size_t buf_read(char*& pointer, size_t n);
+  size_t buf_read(char*& pointer, size_t n, desired_align align = known_alignments::TEXT);
 
   size_t bin_read_fixed(char* data, size_t len)
   {
@@ -274,15 +324,29 @@ class io_buf
       memset(end, 0, sizeof(char) * (end_array - end));
     }
 
-    void shift_to_front(char* head_ptr)
+    void shift_to_front(char*& head_ptr, desired_align align = known_alignments::TEXT)
     {
+      size_t required_padding = 0;
+      if (align.align != 1)
+      {
+        // we are moving head => begin, but if begin is misaligned, we need to pad it
+        size_t begin_address = reinterpret_cast<size_t>(begin);
+        if (begin_address % align.align != align.offset)
+        {
+          required_padding = ((align.align << 1) - (begin_address % align.align) + align.offset) % align.align;
+
+          required_padding /= sizeof(char); // sizeof(char) = 1, but this is more explicit
+        }
+      }
+
       assert(end >= head_ptr);
       const size_t space_left = end - head_ptr;
       // Only call memmove if we are within the bounds of the loaded buffer.
       // Also, this ensures we don't memmove when head_ptr == end_array which
       // would be undefined behavior.
-      if (head_ptr >= begin && head_ptr < end) { std::memmove(begin, head_ptr, space_left); }
-      end = begin + space_left;
+      if (head_ptr >= (begin + required_padding) && head_ptr < end) { std::memmove(begin + required_padding, head_ptr, space_left); }
+      end = begin + required_padding + space_left;
+      head_ptr = begin + required_padding;
     }
 
     size_t capacity() const { return end_array - begin; }
diff --git a/vowpalwabbit/core/src/io_buf.cc b/vowpalwabbit/core/src/io_buf.cc
index 3a15693bab3..c4314c52044 100644
--- a/vowpalwabbit/core/src/io_buf.cc
+++ b/vowpalwabbit/core/src/io_buf.cc
@@ -3,7 +3,7 @@
 // license as described in the file LICENSE.
 #include "vw/core/io_buf.h"
 
-size_t VW::io_buf::buf_read(char*& pointer, size_t n)
+size_t VW::io_buf::buf_read(char*& pointer, size_t n, desired_align align)
 {
   // return a pointer to the next n bytes. n must be smaller than the maximum size.
   if (_head + n <= _buffer.end)
@@ -17,8 +17,7 @@ size_t VW::io_buf::buf_read(char*& pointer, size_t n)
     if (_head != _buffer.begin)  // There exists room to shift.
     {
       // Out of buffer so swap to beginning.
-      _buffer.shift_to_front(_head);
-      _head = _buffer.begin;
+      _buffer.shift_to_front(_head, align);
     }
     if (_current < _input_files.size() && fill(_input_files[_current].get()) > 0)
     {  // read more bytes from _current file if present
diff --git a/vowpalwabbit/fb_parser/include/vw/fb_parser/parse_example_flatbuffer.h b/vowpalwabbit/fb_parser/include/vw/fb_parser/parse_example_flatbuffer.h
index 8a419eb76e9..e8078f8d5d5 100644
--- a/vowpalwabbit/fb_parser/include/vw/fb_parser/parse_example_flatbuffer.h
+++ b/vowpalwabbit/fb_parser/include/vw/fb_parser/parse_example_flatbuffer.h
@@ -28,6 +28,7 @@ class parser
       VW::experimental::api_status* status = nullptr);
 
 private:
+  size_t _num_example_roots = 0;
   const VW::parsers::flatbuffer::ExampleRoot* _data;
   uint8_t* _flatbuffer_pointer;
   flatbuffers::uoffset_t _object_size = 0;
@@ -39,7 +40,7 @@ class parser
   uint32_t _labeled_action = 0;
   uint64_t _c_hash = 0;
 
-  bool parse(io_buf& buf, uint8_t* buffer_pointer = nullptr);
+  int parse(io_buf& buf, uint8_t* buffer_pointer = nullptr, VW::experimental::api_status* status = nullptr);
   int process_collection_item(
       VW::workspace* all, VW::multi_ex& examples, VW::experimental::api_status* status = nullptr);
   int parse_example(VW::workspace* all, example* ae, const Example* eg, VW::experimental::api_status* status = nullptr);
diff --git a/vowpalwabbit/fb_parser/src/parse_example_flatbuffer.cc b/vowpalwabbit/fb_parser/src/parse_example_flatbuffer.cc
index 6a3aa3e2c27..4eafa302fa4 100644
--- a/vowpalwabbit/fb_parser/src/parse_example_flatbuffer.cc
+++ b/vowpalwabbit/fb_parser/src/parse_example_flatbuffer.cc
@@ -15,6 +15,7 @@
 #include
 #include
 #include
+#include
 
 namespace VW
 {
@@ -28,8 +29,13 @@ int flatbuffer_to_examples(VW::workspace* all, io_buf& buf, VW::multi_ex& exampl
   {
     // TODO: At what point do we report the error?
     VW::experimental::api_status status;
-    return static_cast<int>(all->parser_runtime.flat_converter->parse_examples(all, buf, examples, nullptr,
-                                &status) == VW::experimental::error_code::success);
+    const int result = all->parser_runtime.flat_converter->parse_examples(all, buf, examples, nullptr, &status);
+    if (result != VW::experimental::error_code::success)
+    {
+      std::cerr << "Error parsing examples: " << status.get_error_msg() << std::endl;
+    }
+    // Return 1 on success and 0 otherwise, the same convention as the else branch below.
+    return static_cast<int>(result == VW::experimental::error_code::success);
   }
   else
     return static_cast<int>(all->parser_runtime.flat_converter->parse_examples(all, buf, examples, nullptr, nullptr) ==
@@ -38,32 +44,59 @@ int flatbuffer_to_examples(VW::workspace* all, io_buf& buf, VW::multi_ex& exampl
 
 const VW::parsers::flatbuffer::ExampleRoot* parser::data() { return _data; }
 
-bool parser::parse(io_buf& buf, uint8_t* buffer_pointer)
+int parser::parse(io_buf& buf, uint8_t* buffer_pointer, VW::experimental::api_status* status)
 {
+#define RETURN_IF_ALIGN_ERROR(target_align, actual_ptr, example_root_count) \
+  if (!target_align.is_aligned(actual_ptr)) \
+  { \
+    size_t address = reinterpret_cast<size_t>(actual_ptr); \
+    RETURN_ERROR_LS(status, internal_error) \
+      << "fb_parser error: flatbuffer data not aligned to " << target_align << '\n' \
+      << " example_root[" << example_root_count << "] => @" << address << " % " \
+      << target_align.align << " - " << target_align.offset << " = " \
+      << address % target_align.align - target_align.offset << '\n' \
+      << " ^^ -4 is the size of the flatbuffer prefix, which we read explicitly."; \
+  }
+
+  constexpr std::size_t EXPECTED_ALIGNMENT = 8; // this is where FB expects the size-prefixed FB to be aligned
+  constexpr std::size_t EXPECTED_OFFSET = sizeof(uint32_t); //
+
+  desired_align align_prefixed = {EXPECTED_ALIGNMENT, 0};
+  desired_align align_data = {EXPECTED_ALIGNMENT, EXPECTED_OFFSET};
+
   if (buffer_pointer)
   {
+    RETURN_IF_ALIGN_ERROR(align_prefixed, buffer_pointer, _num_example_roots);
+
     _flatbuffer_pointer = buffer_pointer;
 
     _data = VW::parsers::flatbuffer::GetSizePrefixedExampleRoot(_flatbuffer_pointer);
-    return true;
+    _num_example_roots++;
+    return VW::experimental::error_code::success;
   }
 
   char* line = nullptr;
-  auto len = buf.buf_read(line, sizeof(uint32_t));
+  auto len = buf.buf_read(line, sizeof(uint32_t), align_prefixed);
 
-  if (len < sizeof(uint32_t)) { return false; }
+  if (len < sizeof(uint32_t)) { RETURN_ERROR(status, nothing_to_parse); }
 
   _object_size = flatbuffers::ReadScalar<flatbuffers::uoffset_t>(line);
 
   // read one object, object size defined by the read prefix
-  buf.buf_read(line, _object_size);
+  buf.buf_read(line, _object_size, align_data);
+
+  RETURN_IF_ALIGN_ERROR(align_data, line, _num_example_roots);
 
   _flatbuffer_pointer = reinterpret_cast<uint8_t*>(line);
   _data = VW::parsers::flatbuffer::GetExampleRoot(_flatbuffer_pointer);
-  return true;
+
+  _num_example_roots++;
+  return VW::experimental::error_code::success;
+
+#undef RETURN_IF_ALIGN_ERROR
 }
 
-int parser::process_collection_item(VW::workspace* all, VW::multi_ex& examples, VW::experimental::api_status* status)
+int parser::process_collection_item(VW::workspace* all, VW::multi_ex& examples, VW::experimental::api_status*)
 {
   // new example/multi example object to process from collection
   if (_data->example_obj_as_ExampleCollection()->is_multiline())
@@ -101,7 +134,7 @@ int parser::parse_examples(VW::workspace* all, io_buf& buf, VW::multi_ex& exampl
   else
   {
     // new object to be read from file
-    if (!parse(buf, buffer_pointer)) { RETURN_ERROR(status, nothing_to_parse); }
+    RETURN_IF_FAIL(parse(buf, buffer_pointer, status));
 
     switch (_data->example_obj_type())
     {
@@ -132,7 +165,7 @@ int parser::parse_examples(VW::workspace* all, io_buf& buf, VW::multi_ex& exampl
   return VW::experimental::error_code::success;
 }
 
-int parser::parse_example(VW::workspace* all, example* ae, const Example* eg, VW::experimental::api_status* status)
+int parser::parse_example(VW::workspace* all, example* ae, const Example* eg, VW::experimental::api_status*)
 {
   all->parser_runtime.example_parser->lbl_parser.default_label(ae->l);
   ae->is_newline = eg->is_newline();
@@ -150,7 +183,7 @@ int parser::parse_example(VW::workspace* all, example* ae, const Example* eg, VW
 }
 
 int parser::parse_multi_example(
-    VW::workspace* all, example* ae, const MultiExample* eg, VW::experimental::api_status* status)
+    VW::workspace* all, example* ae, const MultiExample* eg, VW::experimental::api_status*)
 {
   all->parser_runtime.example_parser->lbl_parser.default_label(ae->l);
   if (_multi_ex_index >= eg->examples()->size())