Skip to content

Commit

Permalink
feat: add unit tests for desired_align
Browse files Browse the repository at this point in the history
  • Loading branch information
lokitoth committed Feb 1, 2024
1 parent 926ef5d commit 1b16e13
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 4 deletions.
1 change: 1 addition & 0 deletions vowpalwabbit/core/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,7 @@ set(vw_core_test_sources
tests/flat_example_test.cc
tests/guard_test.cc
tests/interactions_test.cc
tests/io_alignment_test.cc
tests/loss_functions_test.cc
tests/math_test.cc
tests/merge_header_opts_test.cc
Expand Down
4 changes: 3 additions & 1 deletion vowpalwabbit/core/include/vw/core/io_buf.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ struct desired_align
struct flatbuffer_t
{
flatbuffer_t() = delete;

static constexpr align_t align = 8;
};

// print to ostream
Expand All @@ -83,7 +85,7 @@ struct desired_align
{
// if T is a flatbuffer type, we need to align to 8 bytes,
// otherwise alignof(T)
return std::is_base_of<flatbuffer_t, T>::value ? 8 : alignof(T);
return std::is_base_of<flatbuffer_t, T>::value ? flatbuffer_t::align : alignof(T);
}
};

Expand Down
6 changes: 3 additions & 3 deletions vowpalwabbit/core/src/io_buf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@ size_t VW::io_buf::buf_read(char*& pointer, size_t n, desired_align align)
_buffer.shift_to_front(_head, align);
}
if (_current < _input_files.size() && fill(_input_files[_current].get()) > 0)
{ // read more bytes from _current file if present
return buf_read(pointer, n); // more bytes are read.
{ // read more bytes from _current file if present
return buf_read(pointer, n, align); // more bytes are read.
}
else if (++_current < _input_files.size())
{
return buf_read(pointer, n); // No more bytes, so go to next file and try again.
return buf_read(pointer, n, align); // No more bytes, so go to next file and try again.
}
else
{
Expand Down
103 changes: 103 additions & 0 deletions vowpalwabbit/core/tests/io_alignment_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
// Copyright (c) by respective owners including Yahoo!, Microsoft, and
// individual contributors. All rights reserved. Released under a BSD (revised)
// license as described in the file LICENSE.

#include "vw/core/io_buf.h"

#include <gmock/gmock.h>
#include <gtest/gtest.h>

using namespace testing;

using VW::desired_align::align_t;
using VW::desired_align::flatbuffer_t;
using VW::desired_align;

struct positioned_ptr
{
void* allocation_unit;
int8_t* p; static_assert(sizeof(int8_t) == 1, "int8_t is not 1 byte");
size_t allocation;

positioned_ptr(size_t allocation) : allocation(allocation), allocation_unit(malloc(allocation)), p(reinterpret_cast<int8_t*>(allocation_unit)) {}
~positioned_ptr() { free(allocation_unit); }

void realign(align_t alignment, align_t offset)
{
size_t base_address = reinterpret_cast<size_t>(allocation_unit);
size_t base_offset = base_address % alignment;

size_t padding = alignment - base_offset + offset;
assert(padding < allocation);

p += padding;
}
};

template <typename T>
positioned_ptr prepare_pointer(align_t offset)
{
VW::align_t base_alignment = alignof(T);
size_t playground = 2 * sizeof(T);

positioned_ptr ptr(playground);
ptr.realign(base_alignment, offset);

return ptr;
}

template <>
positioned_ptr prepare_pointer<flatbuffer_t>(align_t offset)
{
VW::align_t base_alignment = flatbuffer_t::align;
size_t playground = 16;

positioned_ptr ptr(playground);
ptr.realign(base_alignment, offset);

return ptr;
}

template <typename T>
void test_desired_alignment_checker(align_t offset)
{
typename desired_align da = desired_align::align_for<T>(offset);

for (size_t i_offset = 0; i_offset < da.alignment, i_offset++)
{
positioned_ptr ptr = prepare_pointer<T>(i_offset);

if (i_offset == offset)
{
EXPECT_TRUE(da.is_aligned(ptr.p));
}
else
{
EXPECT_FALSE(da.is_aligned(ptr.p));
}
}
}

template <typename T>
void test_all_alignments()
{
for (size_t i_offset = 0; i_offset < alignof(T); i_offset++)
{
test_desired_alignment_checker<T>(i_offset);
}
}

TEST(DesiredAlign, TestsAlignmentCorrectly)
{
test_all_alignments<int8_t>();
test_all_alignments<int16_t>();
test_all_alignments<int32_t>();
test_all_alignments<int64_t>();
test_all_alignments<uint8_t>();
test_all_alignments<uint16_t>();
test_all_alignments<uint32_t>();
test_all_alignments<float>();
test_all_alignments<double>();
test_all_alignments<char>();
test_all_alignments<flatbuffer_t>();
}

0 comments on commit 1b16e13

Please sign in to comment.