Skip to content

Commit

Permalink
[compute/cker] Fix RmsNorm cker (#14218)
Browse files Browse the repository at this point in the history
This commit fixes RmsNorm cker to accept 3 dimension input and single gamma.

ONE-DCO-1.0-Signed-off-by: Seockho Kim [email protected]
  • Loading branch information
seockho-kim authored Oct 16, 2024
1 parent ec45ed2 commit c5fd64a
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 19 deletions.
57 changes: 45 additions & 12 deletions compute/cker/include/cker/operation/RmsNorm.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,37 +33,70 @@ inline void RmsNorm(const RmsNormParams &params, const Shape &input_shape, const
const Shape &gamma_shape, const float *gamma_data, const Shape &output_shape,
float *output_data)
{
const int32_t batches = MatchingDim(input_shape, 0, output_shape, 0);
const int32_t heights = MatchingDim(input_shape, 1, output_shape, 1);
const int32_t widths = MatchingDim(input_shape, 2, output_shape, 2);
const int32_t channels = MatchingDim(input_shape, 3, output_shape, 3);
bool single_gamma = gamma_shape.DimensionsCount() == 1 && gamma_shape.Dims(0) == 1;

if (gamma_shape.DimensionsCount() != 1 ||
gamma_shape.Dims(0) != input_shape.Dims(input_shape.DimensionsCount() - 1))
throw std::runtime_error("cker::RmsNorm: Unmatched gamma shape");
if (input_shape.DimensionsCount() == 4)
{
const int32_t batches = MatchingDim(input_shape, 0, output_shape, 0);
const int32_t heights = MatchingDim(input_shape, 1, output_shape, 1);
const int32_t widths = MatchingDim(input_shape, 2, output_shape, 2);
const int32_t channels = MatchingDim(input_shape, 3, output_shape, 3);

for (int32_t batch = 0; batch < batches; batch++)
for (int32_t batch = 0; batch < batches; batch++)
{
for (int32_t height = 0; height < heights; height++)
{
for (int32_t width = 0; width < widths; width++)
{
// normalize over last-axis
double square_sum = 0.0f;
for (int32_t channel = 0; channel < channels; channel++)
{
double input_val = input_data[Offset(input_shape, batch, height, width, channel)];
square_sum += (input_val * input_val);
}
double rms = std::sqrt((square_sum / channels) + params.epsilon);
for (int32_t channel = 0; channel < channels; channel++)
{
double gamma = (single_gamma ? gamma_data[0] : gamma_data[channel]);
output_data[Offset(output_shape, batch, height, width, channel)] =
gamma * (input_data[Offset(input_shape, batch, height, width, channel)] / rms);
}
}
}
}
}
else if (input_shape.DimensionsCount() == 3)
{
const int32_t heights = MatchingDim(input_shape, 1, output_shape, 0);
const int32_t widths = MatchingDim(input_shape, 2, output_shape, 1);
const int32_t channels = MatchingDim(input_shape, 3, output_shape, 2);

for (int32_t height = 0; height < heights; height++)
{
for (int32_t width = 0; width < widths; width++)
{
// normalize over last-axis
double square_sum = 0.0f;
for (int32_t channel = 0; channel < channels; channel++)
{
double input_val = input_data[Offset(input_shape, batch, height, width, channel)];
double input_val = input_data[(height * widths + width) * channels + channel];
square_sum += (input_val * input_val);
}
double rms = std::sqrt((square_sum / channels) + params.epsilon);
for (int32_t channel = 0; channel < channels; channel++)
{
double gamma = (gamma_data ? gamma_data[channel] : 1.0);
output_data[Offset(output_shape, batch, height, width, channel)] =
gamma * (input_data[Offset(input_shape, batch, height, width, channel)] / rms);
double gamma = (single_gamma ? gamma_data[0] : gamma_data[channel]);
output_data[(height * widths + width) * channels + channel] =
gamma * (input_data[(height * widths + width) * channels + channel] / rms);
}
}
}
}
else
{
throw std::runtime_error("cker::RmsNorm: Unsupported input shape");
}
}

} // namespace cker
Expand Down
13 changes: 6 additions & 7 deletions compute/cker/src/RmsNorm.test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,22 +67,21 @@ TEST(CKer_Operation, RmsNorm)
}
}

TEST(CKer_Operation, neg_RmsNormWrongGammaDims)
TEST(CKer_Operation, neg_RmsNormWrongInputDims)
{
{
std::vector<float> input = {0, 1, 2, 3, 4, 5, 6, 7};
nnfw::cker::Shape input_shape{1, 2, 2, 2};
std::vector<float> input = {0, 1, 2, 3};
nnfw::cker::Shape input_shape{2, 2};

std::vector<float> expected_output = {0, 1.412802, 0.784404, 1.176606,
0.883431, 1.104288, 0.920347, 1.073738};
std::vector<float> expected_output = {0, 1, 1, 1};
std::vector<float> output(expected_output.size());
nnfw::cker::Shape output_shape{1, 2, 2, 2};
nnfw::cker::Shape output_shape{2, 2};

std::vector<float> gamma = {1};
nnfw::cker::Shape gamma_shape{1};

nnfw::cker::RmsNormParams param;
param.epsilon = 0.001f;
param.epsilon = 0.00001f;

EXPECT_ANY_THROW(nnfw::cker::RmsNorm(param, input_shape, input.data(), gamma_shape,
gamma.data(), output_shape, output.data()));
Expand Down

0 comments on commit c5fd64a

Please sign in to comment.