Skip to content

Commit

Permalink
[compute/cker] Fix RMSNorm shape assert error (#14247)
Browse files Browse the repository at this point in the history
* [compute/cker] Fix RMSNorm shape assert error

This commit fixes shape assert error when running model including RMSNorm operation.

ONE-DCO-1.0-Signed-off-by: Seockho Kim [email protected]

* [compute/cker] Add RMSNorm unittests

Unit test added for RMSNorm to test rank 3 input.

ONE-DCO-1.0-Signed-off-by: Seockho Kim [email protected]
  • Loading branch information
seockho-kim authored Oct 24, 2024
1 parent 9c49d99 commit 1ba6970
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 4 deletions.
6 changes: 3 additions & 3 deletions compute/cker/include/cker/operation/RmsNorm.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,9 @@ inline void RmsNorm(const RmsNormParams &params, const Shape &input_shape, const
}
else if (input_shape.DimensionsCount() == 3)
{
const int32_t heights = MatchingDim(input_shape, 1, output_shape, 0);
const int32_t widths = MatchingDim(input_shape, 2, output_shape, 1);
const int32_t channels = MatchingDim(input_shape, 3, output_shape, 2);
const int32_t heights = MatchingDim(input_shape, 0, output_shape, 0);
const int32_t widths = MatchingDim(input_shape, 1, output_shape, 1);
const int32_t channels = MatchingDim(input_shape, 2, output_shape, 2);

for (int32_t height = 0; height < heights; height++)
{
Expand Down
25 changes: 24 additions & 1 deletion compute/cker/src/RmsNorm.test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ TEST(CKer_Operation, RmsNorm)
EXPECT_NEAR(output[i], expected_output[i], 1e-5f);
}

// Default gamma
// rank 4
{
std::vector<float> input = {0, 1, 2, 3, 4, 5, 6, 7};
nnfw::cker::Shape input_shape{1, 2, 2, 2};
Expand All @@ -65,6 +65,29 @@ TEST(CKer_Operation, RmsNorm)
for (size_t i = 0; i < expected_output.size(); ++i)
EXPECT_NEAR(output[i], expected_output[i], 1e-5f);
}

// rank 3
{
std::vector<float> input = {0, 1, 2, 3, 4, 5, 6, 7};
nnfw::cker::Shape input_shape{2, 2, 2};

std::vector<float> expected_output = {0, 1.412802, 0.784404, 1.176606,
0.883431, 1.104288, 0.920347, 1.073738};
std::vector<float> output(expected_output.size());
nnfw::cker::Shape output_shape{2, 2, 2};

std::vector<float> gamma = {1, 1};
nnfw::cker::Shape gamma_shape{2};

nnfw::cker::RmsNormParams param;
param.epsilon = 0.001f;

nnfw::cker::RmsNorm(param, input_shape, input.data(), gamma_shape, gamma.data(), output_shape,
output.data());

for (size_t i = 0; i < expected_output.size(); ++i)
EXPECT_NEAR(output[i], expected_output[i], 1e-5f);
}
}

TEST(CKer_Operation, neg_RmsNormWrongInputDims)
Expand Down

0 comments on commit 1ba6970

Please sign in to comment.