diff --git a/compute/cker/include/cker/operation/RmsNorm.h b/compute/cker/include/cker/operation/RmsNorm.h
index 97eb2a2982d..dee3f618428 100644
--- a/compute/cker/include/cker/operation/RmsNorm.h
+++ b/compute/cker/include/cker/operation/RmsNorm.h
@@ -33,37 +33,71 @@ inline void RmsNorm(const RmsNormParams &params, const Shape &input_shape, const
                     const Shape &gamma_shape, const float *gamma_data,
                     const Shape &output_shape, float *output_data)
 {
-  const int32_t batches = MatchingDim(input_shape, 0, output_shape, 0);
-  const int32_t heights = MatchingDim(input_shape, 1, output_shape, 1);
-  const int32_t widths = MatchingDim(input_shape, 2, output_shape, 2);
-  const int32_t channels = MatchingDim(input_shape, 3, output_shape, 3);
+  bool single_gamma = gamma_shape.DimensionsCount() == 1 && gamma_shape.Dims(0) == 1;
 
-  if (gamma_shape.DimensionsCount() != 1 ||
-      gamma_shape.Dims(0) != input_shape.Dims(input_shape.DimensionsCount() - 1))
-    throw std::runtime_error("cker::RmsNorm: Unmatched gamma shape");
+  if (input_shape.DimensionsCount() == 4)
+  {
+    const int32_t batches = MatchingDim(input_shape, 0, output_shape, 0);
+    const int32_t heights = MatchingDim(input_shape, 1, output_shape, 1);
+    const int32_t widths = MatchingDim(input_shape, 2, output_shape, 2);
+    const int32_t channels = MatchingDim(input_shape, 3, output_shape, 3);
 
-  for (int32_t batch = 0; batch < batches; batch++)
+    for (int32_t batch = 0; batch < batches; batch++)
+    {
+      for (int32_t height = 0; height < heights; height++)
+      {
+        for (int32_t width = 0; width < widths; width++)
+        {
+          // normalize over last-axis
+          double square_sum = 0.0f;
+          for (int32_t channel = 0; channel < channels; channel++)
+          {
+            double input_val = input_data[Offset(input_shape, batch, height, width, channel)];
+            square_sum += (input_val * input_val);
+          }
+          double rms = std::sqrt((square_sum / channels) + params.epsilon);
+          for (int32_t channel = 0; channel < channels; channel++)
+          {
+            double gamma = (single_gamma ? gamma_data[0] : gamma_data[channel]);
+            output_data[Offset(output_shape, batch, height, width, channel)] =
+              gamma * (input_data[Offset(input_shape, batch, height, width, channel)] / rms);
+          }
+        }
+      }
+    }
+  }
+  else if (input_shape.DimensionsCount() == 3)
   {
+    // A rank-3 shape only has dimensions 0..2; index the input like the output.
+    const int32_t heights = MatchingDim(input_shape, 0, output_shape, 0);
+    const int32_t widths = MatchingDim(input_shape, 1, output_shape, 1);
+    const int32_t channels = MatchingDim(input_shape, 2, output_shape, 2);
+
     for (int32_t height = 0; height < heights; height++)
     {
       for (int32_t width = 0; width < widths; width++)
       {
+        // normalize over last-axis
         double square_sum = 0.0f;
         for (int32_t channel = 0; channel < channels; channel++)
         {
-          double input_val = input_data[Offset(input_shape, batch, height, width, channel)];
+          double input_val = input_data[(height * widths + width) * channels + channel];
           square_sum += (input_val * input_val);
         }
         double rms = std::sqrt((square_sum / channels) + params.epsilon);
         for (int32_t channel = 0; channel < channels; channel++)
         {
-          double gamma = (gamma_data ? gamma_data[channel] : 1.0);
-          output_data[Offset(output_shape, batch, height, width, channel)] =
-            gamma * (input_data[Offset(input_shape, batch, height, width, channel)] / rms);
+          double gamma = (single_gamma ? gamma_data[0] : gamma_data[channel]);
+          output_data[(height * widths + width) * channels + channel] =
+            gamma * (input_data[(height * widths + width) * channels + channel] / rms);
         }
       }
     }
   }
+  else
+  {
+    throw std::runtime_error("cker::RmsNorm: Unsupported input shape");
+  }
 }
 
 } // namespace cker
diff --git a/compute/cker/src/RmsNorm.test.cc b/compute/cker/src/RmsNorm.test.cc
index 3179f666b65..926524d5860 100644
--- a/compute/cker/src/RmsNorm.test.cc
+++ b/compute/cker/src/RmsNorm.test.cc
@@ -67,22 +67,21 @@ TEST(CKer_Operation, RmsNorm)
   }
 }
 
-TEST(CKer_Operation, neg_RmsNormWrongGammaDims)
+TEST(CKer_Operation, neg_RmsNormWrongInputDims)
 {
   {
-    std::vector<float> input = {0, 1, 2, 3, 4, 5, 6, 7};
-    nnfw::cker::Shape input_shape{1, 2, 2, 2};
+    std::vector<float> input = {0, 1, 2, 3};
+    nnfw::cker::Shape input_shape{2, 2};
 
-    std::vector<float> expected_output = {0, 1.412802, 0.784404, 1.176606,
-                                          0.883431, 1.104288, 0.920347, 1.073738};
+    std::vector<float> expected_output = {0, 1, 1, 1};
     std::vector<float> output(expected_output.size());
-    nnfw::cker::Shape output_shape{1, 2, 2, 2};
+    nnfw::cker::Shape output_shape{2, 2};
 
     std::vector<float> gamma = {1};
     nnfw::cker::Shape gamma_shape{1};
 
     nnfw::cker::RmsNormParams param;
-    param.epsilon = 0.001f;
+    param.epsilon = 0.00001f;
 
     EXPECT_ANY_THROW(nnfw::cker::RmsNorm(param, input_shape, input.data(), gamma_shape, gamma.data(),
                                          output_shape, output.data()));