Skip to content

Commit

Permalink
Fix warnings
Browse files Browse the repository at this point in the history
  • Loading branch information
vshampor committed Nov 19, 2024
1 parent a33f255 commit a0818c2
Showing 1 changed file with 36 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ static T get_tolerance() {

template <>
float get_tolerance<float>() {
return 1e-6;
return 1e-6f;
}

template <>
Expand All @@ -135,7 +135,7 @@ ov::bfloat16 get_tolerance<ov::bfloat16>() {
}

template <class TypeParam>
class CacheRotationKernelTest : public ::testing::Test {
class CacheRotationKernelInputTypeParameterizedTest : public ::testing::Test {
public:
void SetUp() override {
Rank3Matrix<TypeParam> values_before_rotation = {
Expand Down Expand Up @@ -170,16 +170,16 @@ class CacheRotationKernelTest : public ::testing::Test {
std::shared_ptr<float[]> rotation_coefficients_mem_ptr;
Rank3Matrix<TypeParam> ref_values_after_rotation = {
{
{-0.36602540, 1.41421356, 1.36602540, 0.00000000},
{0.36602540, 1.00000000, 1.36602540, 1.00000000},
{-1.41421356, -1.00000000, 0.00000000, 1.00000000},
{1.00000000, 1.40000000, -1.00000000, -0.20000000},
{-0.36602540f, 1.41421356f, 1.36602540f, 0.00000000f},
{0.36602540f, 1.00000000f, 1.36602540f, 1.00000000f},
{-1.41421356f, -1.00000000f, 0.00000000f, 1.00000000f},
{1.00000000f, 1.40000000f, -1.00000000f, -0.20000000f},
},
{
{0.73205081, -2.82842712, -2.73205081, 0.00000000},
{0.73205081, 2.00000000, 2.73205081, 2.00000000},
{2.82842712, -4.00000000, 1.41421356, 2.00000000},
{2.00000000, 2.80000000, -2.00000000, -0.40000000},
{0.73205081f, -2.82842712f, -2.73205081f, 0.00000000f},
{0.73205081f, 2.00000000f, 2.73205081f, 2.00000000f},
{2.82842712f, -4.00000000f, 1.41421356f, 2.00000000f},
{2.00000000f, 2.80000000f, -2.00000000f, -0.40000000f},
},
};

Expand Down Expand Up @@ -235,9 +235,9 @@ class CacheRotationKernelTest : public ::testing::Test {

using OV_FP_TYPES = ::testing::Types<float, ov::float16, ov::bfloat16>;

TYPED_TEST_SUITE(CacheRotationKernelTest, OV_FP_TYPES);
TYPED_TEST_SUITE_P(CacheRotationKernelInputTypeParameterizedTest);

TYPED_TEST(CacheRotationKernelTest, SWBlockRotationGivesReferenceResults) {
TYPED_TEST_P(CacheRotationKernelInputTypeParameterizedTest, SWBlockRotationGivesReferenceResults) {
auto raw_cache_mem_ptr = this->cache_mem_ptr.get();
auto raw_rotation_coefficients_mem_ptr = this->rotation_coefficients_mem_ptr.get();

Expand Down Expand Up @@ -272,7 +272,8 @@ MATCHER_P3(IsNFirstValuesNear, ref_container, abs_err, n, "") {
return is_ok;
}

class CacheRotationHWKernelTest : public ::testing::TestWithParam<std::tuple<TargetInstructionSet, size_t>> {
class CacheRotationKernelInstructionParameterizedTest
: public ::testing::TestWithParam<std::tuple<TargetInstructionSet, size_t>> {
protected:
constexpr static size_t MAX_CHUNK_SIZE_IN_ELEMENTS = 16;
template <class T>
Expand Down Expand Up @@ -430,37 +431,38 @@ class CacheRotationHWKernelTest : public ::testing::TestWithParam<std::tuple<Tar
}
};

TEST_P(CacheRotationHWKernelTest, HWChunkRotationGivesReferenceResults) {
TEST_P(CacheRotationKernelInstructionParameterizedTest, HWChunkRotationGivesReferenceResults) {
test_chunk_rotation_for_type<float>();
test_chunk_rotation_for_type<ov::float16>();
test_chunk_rotation_for_type<ov::bfloat16>();
}

auto TEST_STRUCT_TO_NAME_FN = [](const testing::TestParamInfo<CacheRotationHWKernelTest::ParamType>& info) {
size_t num_elts = std::get<1>(info.param);
switch (std::get<0>(info.param)) {
case TargetInstructionSet::AVX2:
return std::string("avx2-") + std::to_string(num_elts);
case TargetInstructionSet::AVX512:
return std::string("avx512-") + std::to_string(num_elts);
}
return std::string("unknown");
};
auto TEST_STRUCT_TO_NAME_FN =
[](const testing::TestParamInfo<CacheRotationKernelInstructionParameterizedTest::ParamType>& info) {
size_t num_elts = std::get<1>(info.param);
switch (std::get<0>(info.param)) {
case TargetInstructionSet::AVX2:
return std::string("avx2-") + std::to_string(num_elts);
case TargetInstructionSet::AVX512:
return std::string("avx512-") + std::to_string(num_elts);
}
return std::string("unknown");
};

INSTANTIATE_TEST_SUITE_P(AVX2,
CacheRotationHWKernelTest,
CacheRotationKernelInstructionParameterizedTest,
::testing::Combine(::testing::Values(TargetInstructionSet::AVX2),
::testing::Range(size_t(0),
ov::Extensions::Cpu::XARCH::vec_len_f32_avx2 + 1)),
TEST_STRUCT_TO_NAME_FN);
INSTANTIATE_TEST_SUITE_P(AVX512,
CacheRotationHWKernelTest,
CacheRotationKernelInstructionParameterizedTest,
::testing::Combine(::testing::Values(TargetInstructionSet::AVX512),
::testing::Range(size_t(0),
ov::Extensions::Cpu::XARCH::vec_len_f32_avx512 + 1)),
TEST_STRUCT_TO_NAME_FN);

TYPED_TEST(CacheRotationKernelTest, HWBlockRotationGivesReferenceResults) {
TYPED_TEST_P(CacheRotationKernelInputTypeParameterizedTest, HWBlockRotationGivesReferenceResults) {
auto raw_cache_mem_ptr = this->cache_mem_ptr.get();
auto raw_rotation_coefficients_mem_ptr = this->rotation_coefficients_mem_ptr.get();

Expand All @@ -475,10 +477,16 @@ TYPED_TEST(CacheRotationKernelTest, HWBlockRotationGivesReferenceResults) {
compare_with_tolerance(test_values_after_rotation, this->ref_values_after_rotation, get_tolerance<TypeParam>());
}

TYPED_TEST(CacheRotationKernelTest, HWBlockRotationIsSimilarToSW) {
TYPED_TEST_P(CacheRotationKernelInputTypeParameterizedTest, HWBlockRotationIsSimilarToSW) {
// short case
this->test_block_hw_vs_sw(/* num_heads = */ 4, /* embedding_size = */ 64, /* block_size = */ 2);

// long case
this->test_block_hw_vs_sw(256, 1024, 32);
}

REGISTER_TYPED_TEST_SUITE_P(CacheRotationKernelInputTypeParameterizedTest,
SWBlockRotationGivesReferenceResults,
HWBlockRotationGivesReferenceResults,
HWBlockRotationIsSimilarToSW);
INSTANTIATE_TYPED_TEST_SUITE_P(AllFPTypes, CacheRotationKernelInputTypeParameterizedTest, OV_FP_TYPES);

0 comments on commit a0818c2

Please sign in to comment.