
[compute/cker] Fix RmsNorm cker #14218

Merged: 2 commits into Samsung:master on Oct 16, 2024

Conversation

seockho-kim (Contributor):

This commit fixes the RmsNorm cker to accept 3-dimensional input and a single gamma.

ONE-DCO-1.0-Signed-off-by: Seockho Kim [email protected]

issue: #14089
draft: #14088
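For reference, RMSNorm normalizes each vector over its last axis; the formula below restates what the kernel loop in this PR computes, with C the channel count and a single gamma broadcast to every channel:

$$\mathrm{rms}(x) = \sqrt{\frac{1}{C}\sum_{c=1}^{C} x_c^2 + \epsilon}, \qquad y_c = \gamma_c \cdot \frac{x_c}{\mathrm{rms}(x)}$$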

seockho-kim requested a review from a team October 14, 2024 06:49
// Input and output shapes must match; MatchingDim asserts this and returns the size.
const int32_t heights = MatchingDim(input_shape, 1, output_shape, 1);
const int32_t widths = MatchingDim(input_shape, 2, output_shape, 2);
const int32_t channels = MatchingDim(input_shape, 3, output_shape, 3);
// A rank-1, single-element gamma is broadcast across all channels.
bool single_gamma = gamma_shape.DimensionsCount() == 1 && gamma_shape.Dims(0) == 1;
Contributor:
@seockho-kim Could you check whether we need to allow a scalar number (it may have DimensionsCount() == 0)?

seockho-kim (Contributor, Author):

> @seockho-kim Could you check whether we need to allow a scalar number (it may have DimensionsCount() == 0)?

The current model has no gamma (scale), so it is set to the default gamma ([1.0]) after fusing.
So there is no need to allow a scalar number for now, I think.
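For clarity, a summary of the gamma-shape cases implied by the check in the diff above (illustrative comments only, not code from the PR):

```cpp
// Gamma-shape cases against the check above (illustrative):
//   rank-0 scalar: DimensionsCount() == 0               -> single_gamma == false
//                  (the case discussed here; not produced after fusing)
//   shape [1]:     DimensionsCount() == 1, Dims(0) == 1 -> single_gamma, broadcast
//   shape [C]:     DimensionsCount() == 1, Dims(0) == C -> per-channel gamma
bool single_gamma = gamma_shape.DimensionsCount() == 1 && gamma_shape.Dims(0) == 1;
```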

Contributor:

AFAIK, gamma is some real value, not 1.0. Maybe your current model is not the final or actual model. Let's check why gamma is 1.0.

seockho-kim (Contributor, Author):

> AFAIK, gamma is some real value, not 1.0. Maybe your current model is not the final or actual model. Let's check why gamma is 1.0.

The current model's RMSNorm pattern (the one we need to fuse) has no scale.
(#13964 (comment))
So the scale (gamma) is set to 1.0 when the pattern is fused into RMSNorm.

glistening (Contributor), Oct 16, 2024:

@seockho-kim I talked with @jinevening, who provided us the model, and we concluded that gamma is not 1.0. It has actual values, but the gamma values seem to be propagated into each successor's (fully connected) weights during our internal quantization and optimization.
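For intuition, this propagation works because a per-channel scale commutes into the next layer's weight matrix; a sketch of the algebra (not stated in the PR itself), writing $\hat{x}$ for the normalized input:

$$y = W(\gamma \odot \hat{x}) = \big(W \,\mathrm{diag}(\gamma)\big)\,\hat{x} = W'\,\hat{x}, \qquad W' = W\,\mathrm{diag}(\gamma)$$

After $\gamma$ is folded into $W'$, the gamma left in the fused RMSNorm is effectively 1.0, which matches what the fused model shows.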

glistening previously approved these changes Oct 16, 2024
glistening (Contributor) left a comment:

LGTM

glistening dismissed their stale review October 16, 2024 05:45

Let me look at the code some more.

Comment on lines 47 to 66
// The enclosing batch loop and its opening brace are above this excerpt.
for (int32_t height = 0; height < heights; height++)
{
  for (int32_t width = 0; width < widths; width++)
  {
    // First pass: sum of squares over the channel (last) axis.
    double square_sum = 0.0;
    for (int32_t channel = 0; channel < channels; channel++)
    {
      double input_val = input_data[Offset(input_shape, batch, height, width, channel)];
      square_sum += (input_val * input_val);
    }
    double rms = std::sqrt((square_sum / channels) + params.epsilon);
    // Second pass: scale each channel by gamma / rms.
    for (int32_t channel = 0; channel < channels; channel++)
    {
      double gamma = (single_gamma ? gamma_data[0] : gamma_data[channel]);
      output_data[Offset(output_shape, batch, height, width, channel)] =
        gamma * (input_data[Offset(input_shape, batch, height, width, channel)] / rms);
    }
  }
}
}
glistening (Contributor), Oct 16, 2024:

@seockho-kim width must be the innermost. Please refer to compute/cker/include/cker/operation/InstanceNorm.h.

I was wrong.
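For readers who want to experiment with the loop above outside cker, here is a minimal standalone sketch over a flat [rows, channels] view (RmsNormRef is a hypothetical name; the real kernel uses cker's Shape, Offset, and RmsNormParams):

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

// Standalone last-axis RMSNorm over a flat [rows, channels] view.
// A single-element gamma broadcasts to all channels, mirroring `single_gamma`.
void RmsNormRef(const std::vector<float> &input, int rows, int channels,
                const std::vector<float> &gamma, float epsilon,
                std::vector<float> &output)
{
  const bool single_gamma = gamma.size() == 1;
  for (int r = 0; r < rows; ++r)
  {
    // First pass: mean of squares over the last axis.
    double square_sum = 0.0;
    for (int c = 0; c < channels; ++c)
    {
      const double v = input[r * channels + c];
      square_sum += v * v;
    }
    const double rms = std::sqrt(square_sum / channels + epsilon);
    // Second pass: scale by gamma / rms.
    for (int c = 0; c < channels; ++c)
    {
      const double g = single_gamma ? gamma[0] : gamma[c];
      output[r * channels + c] = static_cast<float>(g * (input[r * channels + c] / rms));
    }
  }
}

int main()
{
  // One row of 4 channels with the default fused gamma [1.0] discussed above.
  std::vector<float> input{1.f, 2.f, 3.f, 4.f}, gamma{1.0f}, output(4);
  RmsNormRef(input, 1, 4, gamma, 1e-6f, output);
  for (float v : output)
    std::printf("%f\n", v);
  return 0;
}
```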

glistening requested a review from a team October 16, 2024 05:52
square_sum += (input_val * input_val);
}
double rms = std::sqrt((square_sum / channels) + params.epsilon);
for (int32_t channel = 0; channel < channels; channel++)
Contributor:

Suggested change:
-for (int32_t channel = 0; channel < channels; channel++)
+// normalize over last-axis
+for (int32_t channel = 0; channel < channels; channel++)

seockho-kim (Contributor, Author):

I'll update it :)

glistening requested a review from a team October 16, 2024 07:04
Comment updated to explain that current RMSNorm normalizes over the last axis.

ONE-DCO-1.0-Signed-off-by: Seockho Kim [email protected]
glistening (Contributor) left a comment:

LGTM

glistening requested a review from a team October 16, 2024 07:47
glistening merged commit c5fd64a into Samsung:master Oct 16, 2024
9 checks passed
seockho-kim deleted the compute_cker_rmsnorm branch October 17, 2024 00:41