Add rmsnorm kernel #633

Merged: 2 commits merged from main_perf-rmsnorm into main_perf on Sep 7, 2024
Conversation

rahulbatra85

Adds forward kernel for RMSNorm

triton.Config({}, num_warps=16, num_stages=1),
]

def get_hip_autotune_config():
A reviewer left a comment on the autotune configs above:

Nothing wrong here. Just wondering if these configs are comprehensive enough to cover the typical use cases you have encountered in actual models.

Author

Good question. I just came up with these based on what I thought made sense; there was no systematic way of deriving them. I was thinking maybe I could add an argument that takes in a custom config file, so that nobody needs to touch the code in this file if other configs need to be benchmarked.

Actually, I noticed a small bug in the CUDA configs: there are repeated entries.
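
For illustration, loading such configs from a file could look roughly like this (an editorial sketch; the helper name and JSON layout are hypothetical and not part of this PR):

import json

import triton


# Hypothetical helper: build triton.Config objects from a user-supplied JSON file,
# e.g. [{"num_warps": 8, "num_stages": 1}, {"num_warps": 16, "num_stages": 1}].
def load_autotune_configs(path):
    with open(path) as f:
        entries = json.load(f)
    return [
        triton.Config({}, num_warps=e.get("num_warps", 4), num_stages=e.get("num_stages", 1))
        for e in entries
    ]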

(Two comments from brunomazzottiamd were marked as resolved.)

row = tl.load(row_start_ptr + col_offsets, mask=mask, other=0.0)
g = tl.load(g_ptr + col_offsets, mask=mask, other=0.0)
row_norm = row * row  # square each value
row_norm = tl.sum(row_norm, axis=-1)  # sum across columns (axis=-1)
Collaborator

For RMSNorm, is it always safe to accumulate FP16 values with an FP16 accumulator? Do we need an FP32 accumulator here? I'm assuming that tl.sum of FP16 values is also FP16.

For instance, if we square 7 three-digit numbers and then add them together, we're already overflowing FP16.

I'm not that familiar with normalization layers; maybe I'm just being too conservative and thinking too much about the general case.
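
To put concrete numbers on that concern, here is a small editorial NumPy sketch (not code from the thread): squaring seven three-digit values and summing them already exceeds the FP16 maximum of 65504.

import numpy as np

x = np.array([100.0] * 7, dtype=np.float16)
print(np.sum(x * x))                       # inf: 7 * 10000 does not fit in FP16
print(np.sum((x * x).astype(np.float32)))  # 70000.0 with an FP32 accumulator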

Author

I need to think about this case a bit more.

Collaborator

Let me try to give you some food for thought...

Use a wider data type as accumulator

Our GEMM kernel does this:

# INT32 accumulator for INT8 data and FP32 accumulator for everything else.
acc_dtype = tl.float32 if a_ptr.type.element_ty != tl.int8 else tl.int32
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype)

You can do something similar. I think it's safe to accumulate using FP32.
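
Applied to the loads quoted earlier in this thread, that could look roughly like the following (an editorial sketch, untested; n_cols, eps, and output_ptr are assumed names, not verified against the diff):

# Upcast the row to FP32 so the reduction accumulates in FP32, then cast back on store.
row = tl.load(row_start_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32)
g = tl.load(g_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32)
row_norm = tl.sum(row * row, axis=-1)             # FP32 sum of squares
rms = 1.0 / tl.sqrt(row_norm / n_cols + eps)      # n_cols and eps assumed from the kernel
output = row * rms * g
tl.store(output_ptr + col_offsets, output.to(output_ptr.type.element_ty), mask=mask)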

Use a scale factor

You can scale down the numbers before squaring them and scale up the mean of squares afterwards. Here is a NumPy prototype of the idea:

import numpy as np


def mean_sqr(x):
    return np.sum(x * x) / len(x)


# Scale factor can be a constant or computed on the fly.
def mean_sqr_with_scale_factor(x, scale_factor):
    x = (1 / scale_factor) * x
    mean_sqr = np.sum(x * x) / len(x)
    return (scale_factor * scale_factor) * mean_sqr


def compute_scale_factor(x):
    max_x = np.max(np.abs(x))
    return np.exp(np.floor(np.log(max_x))) if max_x > 0 else 1


np.random.seed(42)
x = np.random.uniform(size=4096, low=-500.0, high=500.0).astype(np.float32)
print(mean_sqr(x))
print(mean_sqr_with_scale_factor(x, compute_scale_factor(x)))

Larger scale factors reduce the risk of overflow but increase the chance of losing precision. Smaller scale factors retain precision but may not prevent overflow effectively if the numbers are too large.

Inspect PyTorch implementation

Try to find out what PyTorch is doing. If PyTorch doesn't care about sum-of-squares overflow, then we can follow suit and not worry about it either.
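
One quick empirical check (an editorial sketch, assuming PyTorch >= 2.4 for torch.nn.functional.rms_norm and a GPU build): feed FP16 values whose squares sum far past 65504 and see whether the reference output stays finite, which would suggest a wider internal accumulator.

import torch

# 4096 values of 100.0: a naive FP16 sum of squares would be ~41M, far past the FP16 max.
x = torch.full((1, 4096), 100.0, dtype=torch.float16, device="cuda")
w = torch.ones(4096, dtype=torch.float16, device="cuda")
out = torch.nn.functional.rms_norm(x, (4096,), weight=w, eps=1e-6)
print(torch.isfinite(out).all())  # True would hint at FP32 (or wider) accumulation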

brunomazzottiamd (Collaborator)

Seeing this PR reminded me that I added a kernel without a benchmark and without a test! Shame on me! I think adding a benchmark and a correctness test must be mandatory from now on; reviewers should reject PRs lacking them.

brunomazzottiamd mentioned this pull request on Sep 5, 2024.
(Two comments from rahulbatra85 were marked as resolved.)

micmelesse (Collaborator) left a comment

Can we add tests for this kernel?
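
For what it's worth, a correctness test could compare against PyTorch's reference, roughly along these lines (an editorial sketch; rmsnorm is a placeholder for whatever wrapper this PR exposes, and torch.nn.functional.rms_norm needs PyTorch >= 2.4):

import pytest
import torch


@pytest.mark.parametrize("n_rows, n_cols", [(1, 128), (32, 4096), (7, 8192)])
def test_rmsnorm(n_rows, n_cols):
    torch.manual_seed(0)
    x = torch.randn(n_rows, n_cols, dtype=torch.float16, device="cuda")
    g = torch.randn(n_cols, dtype=torch.float16, device="cuda")
    ref = torch.nn.functional.rms_norm(x, (n_cols,), weight=g, eps=1e-6)
    # rmsnorm is a hypothetical wrapper imported from the kernel module added in this PR.
    out = rmsnorm(x, g, eps=1e-6)
    torch.testing.assert_close(out, ref, atol=1e-2, rtol=1e-2)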

micmelesse (Collaborator)

@rahulbatra85 There are failures in the RMSNorm code.

rahulbatra85 (Author)

> There are failures in the RMSNorm code.

Yeah, looking.

rahulbatra85 (Author) commented on Sep 6, 2024

> There are failures in the RMSNorm code.

> Yeah, looking.

@micmelesse Ah, OK. So the RMS norm layer was added in PyTorch starting with version 2.4. Which docker image does the CI use?

micmelesse (Collaborator)

@rahulbatra85 It's rocm/pytorch:rocm6.0.2_ubuntu22.04_py3.10_pytorch_2.1.2; you can probably update it to the newest one.

rahulbatra85 (Author)

> It's rocm/pytorch:rocm6.0.2_ubuntu22.04_py3.10_pytorch_2.1.2; you can probably update it to the newest one.

@micmelesse It's passing with the new docker image.
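
(As an alternative to bumping the image, a test-side guard could skip on older PyTorch builds; an editorial sketch:)

import pytest
import torch

# Skip RMSNorm tests on PyTorch builds that predate 2.4, where
# torch.nn.functional.rms_norm is not available.
requires_rms_norm = pytest.mark.skipif(
    not hasattr(torch.nn.functional, "rms_norm"),
    reason="torch.nn.functional.rms_norm requires PyTorch >= 2.4",
)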

rahulbatra85 merged commit c4bd738 into main_perf on Sep 7, 2024 (4 checks passed).
rahulbatra85 added a commit that referenced this pull request on Sep 12, 2024.
rahulbatra85 deleted the main_perf-rmsnorm branch on September 24, 2024.