
Tiny upstream pr #1094

Merged
merged 704 commits into facebookresearch:main from qianfengz:tiny_upstream_pr on Sep 10, 2024

Conversation

qianfengz
Contributor

This PR provides:

  1. A sync to the latest Composable Kernel commit, which adds an inline-asm implementation of the fp32-to-bf16 RTN (round-to-nearest) conversion. Using the inline-asm RTN conversion improves performance when BF16+RTN is used; a reference sketch of the RTN semantics follows this list.
  2. Compiler options for building the C++ extension on ROCm/HIP, which improve the performance of the HIP FMHA BWD on ROCm 6.2; see the setup.py sketch after the benchmark results below.
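
The conversion in Composable Kernel is inline asm; for reference only, here is a minimal NumPy sketch of the same round-to-nearest-even fp32-to-bf16 semantics (the standard bias-and-truncate trick; NaN handling omitted — an illustration, not the CK implementation):

```python
import numpy as np

def fp32_to_bf16_rtn(x: np.ndarray) -> np.ndarray:
    """Round fp32 to bf16 with round-to-nearest-even; returns raw uint16 bits."""
    bits = np.asarray(x, dtype=np.float32).view(np.uint32)
    lsb = (bits >> 16) & np.uint32(1)   # last bit that survives truncation
    bias = np.uint32(0x7FFF) + lsb      # ties round toward even mantissas
    return ((bits + bias) >> 16).astype(np.uint16)
```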
  • The following are benchmark results compared with Triton when using RTN, with the compile options added, on ROCm 6.2:
Run reference fwd:
Reference fwd time: 28.90159034729004
Run reference bwd:
Reference bwd time: 48.68329620361328
Run triton fwd:
Triton fwd time: 2.0252671241760254
Run triton bwd:
Triton bwd time: 6.977703094482422
Run CK fwd:
xformers fwd time: 1.8350895643234253
Run CK bwd:
xformers bwd time: 7.089707374572754
(triton_dq - ref_dq).abs().mean()=tensor(0.0002, device='cuda:0', dtype=torch.bfloat16)
(triton_dk - ref_dk).abs().mean()=tensor(0.0001, device='cuda:0', dtype=torch.bfloat16)
(triton_dv - ref_dv).abs().mean()=tensor(0.0004, device='cuda:0', dtype=torch.bfloat16)
(xformer_dq - ref_dq).abs().mean()=tensor(0.0002, device='cuda:0', dtype=torch.bfloat16, grad_fn=<MeanBackward0>)
(xformer_dk - ref_dk).abs().mean()=tensor(0.0002, device='cuda:0', dtype=torch.bfloat16, grad_fn=<MeanBackward0>)
(xformer_dv - ref_dv).abs().mean()=tensor(6.7234e-05, device='cuda:0', dtype=torch.bfloat16, grad_fn=<MeanBackward0>)
  • The following are benchmark results compared with Triton when using RTN, without the compile options added, on ROCm 6.2:
Run reference fwd:
Reference fwd time: 28.867050170898438
Run reference bwd:
Reference bwd time: 48.91793441772461
Run triton fwd:
Triton fwd time: 2.056668996810913
Run triton bwd:
Triton bwd time: 6.982858180999756
Run CK fwd:
xformers fwd time: 1.8234171867370605
Run CK bwd:
xformers bwd time: 7.428786754608154
(triton_dq - ref_dq).abs().mean()=tensor(0.0002, device='cuda:0', dtype=torch.bfloat16)
(triton_dk - ref_dk).abs().mean()=tensor(0.0001, device='cuda:0', dtype=torch.bfloat16)
(triton_dv - ref_dv).abs().mean()=tensor(0.0004, device='cuda:0', dtype=torch.bfloat16)
(xformer_dq - ref_dq).abs().mean()=tensor(0.0002, device='cuda:0', dtype=torch.bfloat16, grad_fn=<MeanBackward0>)
(xformer_dk - ref_dk).abs().mean()=tensor(0.0002, device='cuda:0', dtype=torch.bfloat16, grad_fn=<MeanBackward0>)
(xformer_dv - ref_dv).abs().mean()=tensor(8.7738e-05, device='cuda:0', dtype=torch.bfloat16, grad_fn=<MeanBackward0>)
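
For context on item 2: compile options for a torch C++/HIP extension are passed through extra_compile_args in setup.py. Below is a hedged sketch of that wiring only — the flag, extension name, and source path are placeholders, not the options this PR actually adds (those are in the setup.py diff):

```python
# Hedged sketch: flag, name, and source path below are placeholders.
import torch
from torch.utils.cpp_extension import CUDAExtension

extra_compile_args = {"cxx": ["-O3"], "nvcc": ["-O3"]}
if torch.version.hip is not None:
    # Under ROCm, torch's build system feeds the "nvcc" flag list to hipcc,
    # so HIP-specific tuning options are appended here.
    extra_compile_args["nvcc"].append("-DPLACEHOLDER_HIP_TUNING_FLAG")

ext_module = CUDAExtension(
    name="example_fmha_ext",                # placeholder extension name
    sources=["csrc/example_fmha_hip.cpp"],  # placeholder source
    extra_compile_args=extra_compile_args,
)
```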

qianfengz and others added 30 commits starting February 5, 2024, including:
  • ensure ck_decoder does not dispatch in test_attn_bias_padded
  • Apply the existing linters (1/n)
@facebook-github-bot added the CLA Signed and module: rocm labels on Sep 5, 2024.
@danthe3rd
Contributor

Thanks! Can you fix the formatting of setup.py though? (See the linter CI.)

@qianfengz
Contributor Author

Is any further layout change needed?

@danthe3rd
Contributor

Sorry, forgot about that PR :)
Let me merge it

@danthe3rd danthe3rd merged commit 0004c67 into facebookresearch:main Sep 10, 2024
22 of 27 checks passed
@qianfengz qianfengz deleted the tiny_upstream_pr branch September 20, 2024 08:10