-
Notifications
You must be signed in to change notification settings - Fork 158
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[compute/cker] Fix RmsNorm cker #14218
Conversation
This commit fixes RmsNorm cker to accept 3 dimension input and single gamma. ONE-DCO-1.0-Signed-off-by: Seockho Kim [email protected]
const int32_t heights = MatchingDim(input_shape, 1, output_shape, 1); | ||
const int32_t widths = MatchingDim(input_shape, 2, output_shape, 2); | ||
const int32_t channels = MatchingDim(input_shape, 3, output_shape, 3); | ||
bool single_gamma = gamma_shape.DimensionsCount() == 1 && gamma_shape.Dims(0) == 1; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@seockho-kim Could you check whether we need to allow scalar number (it may have DimensionCount() = 0
) ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@seockho-kim Could you check whether we need to allow scalar number (it may have
DimensionCount() = 0
) ?
Current model has no gamma(scale), so it's set to default gamma ( [1.0] ) after fusing.
So, no need to allow scalar number, now, I think.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
AFAIK, gamma
is some value, not 1.0
. Maybe your current model is not final or actual model. Let's check why gamma
is 1.0
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
AFAIK,
gamma
is some value, not1.0
. Maybe your current model is not final or actual model. Let's check whygamma
is1.0
.
Current model's RMSNorm pattern (need to fuse) has no scale.
(#13964 (comment))
So, scale(gamma) is set to 1.0 when it is fused to RMSNorm.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I talked with @jinevening who provided us model, and concluded gamma
is not 1.0
. It has actual values, but gamma
values seem to be propagated to each successor's (= fully connected) weight during our internal quantization and optimization.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
for (int32_t height = 0; height < heights; height++) | ||
{ | ||
for (int32_t width = 0; width < widths; width++) | ||
{ | ||
double square_sum = 0.0f; | ||
for (int32_t channel = 0; channel < channels; channel++) | ||
{ | ||
double input_val = input_data[Offset(input_shape, batch, height, width, channel)]; | ||
square_sum += (input_val * input_val); | ||
} | ||
double rms = std::sqrt((square_sum / channels) + params.epsilon); | ||
for (int32_t channel = 0; channel < channels; channel++) | ||
{ | ||
double gamma = (single_gamma ? gamma_data[0] : gamma_data[channel]); | ||
output_data[Offset(output_shape, batch, height, width, channel)] = | ||
gamma * (input_data[Offset(input_shape, batch, height, width, channel)] / rms); | ||
} | ||
} | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@seockho-kim width
must be the inner most. Please refer to compute/cker/include/cker/operation/InstanceNorm.h
.
I was wrong.
square_sum += (input_val * input_val); | ||
} | ||
double rms = std::sqrt((square_sum / channels) + params.epsilon); | ||
for (int32_t channel = 0; channel < channels; channel++) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for (int32_t channel = 0; channel < channels; channel++) | |
// normalize over last-axis | |
for (int32_t channel = 0; channel < channels; channel++) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'll update it, :)
Comment updated to explain that current RMSNorm normalizes over the last axis. ONE-DCO-1.0-Signed-off-by: Seockho Kim [email protected]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
This commit fixes RmsNorm cker to accept 3 dimension input and single gamma.
ONE-DCO-1.0-Signed-off-by: Seockho Kim [email protected]
issue: #14089
draft: #14088