Merge recent changes from ROCm xformers #1196

Open
wants to merge 884 commits into base: main
Conversation

qianfengz (Contributor) commented Jan 17, 2025

This PR provides the following changes:

  1. Adds support for hdim-512 in fmha-fwd (a usage sketch follows this list)
  2. Prevents PagedBlockDiagonal attn_bias types from being used with forward-training (since the paged KV cache is only implemented by the splitkv kernel)
  3. Fixes xformers/benchmarks/benchmark_attn_decoding.py so that it works correctly with ck.FwOp
  4. Improves decoder fmha-fwd performance with MQA/GQA
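
As a rough illustration of item 1, here is a minimal sketch of exercising an hdim-512 forward pass through the public API. The shapes and dtype are illustrative assumptions, and ck.FwOp is only available in ROCm builds of xformers:

```python
import torch
import xformers.ops as xops
from xformers.ops import fmha

# Illustrative shapes only: batch=2, seqlen=1024, 8 heads, head dim 512.
q = torch.randn(2, 1024, 8, 512, device="cuda", dtype=torch.bfloat16)
k = torch.randn(2, 1024, 8, 512, device="cuda", dtype=torch.bfloat16)
v = torch.randn(2, 1024, 8, 512, device="cuda", dtype=torch.bfloat16)

# Force the composable-kernel forward op so the new hdim-512 path is exercised.
out = xops.memory_efficient_attention_forward(q, k, v, op=fmha.ck.FwOp)
```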

The following scripts were used to test/verify the changes:

#> pytest tests/test_mem_eff_attention.py::test_forward
#> pytest tests/test_mem_eff_attention.py::test_backward
#> pytest tests/test_mem_eff_attention.py::test_dropout_ck
#> pytest tests/test_mem_eff_attention.py::test_dropout_backward_ck
#> pytest tests/test_mem_eff_attention.py::test_logsumexp
#> pytest tests/test_mem_eff_attention.py::test_paged_attention_ck

The following script was used to benchmark/verify decoder performance with MQA/GQA using ck.FwOp:

#> python xformers/benchmarks/benchmark_attn_decoding.py
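
For item 4, a minimal sketch of the MQA/GQA decode pattern the benchmark exercises, using xformers' 5D grouped layout; the group counts, sequence lengths, and the expand()-based KV sharing are assumptions for illustration, not taken from the benchmark script itself:

```python
import torch
import xformers.ops as xops
from xformers.ops import fmha

# Decode-style GQA: one query token per batch, G KV groups, Hq query heads per group.
B, Mkv, G, Hq, K = 4, 4096, 8, 4, 128
q = torch.randn(B, 1, G, Hq, K, device="cuda", dtype=torch.bfloat16)
# One KV head per group, broadcast across the Hq query heads via expand().
k = torch.randn(B, Mkv, G, 1, K, device="cuda", dtype=torch.bfloat16).expand(B, Mkv, G, Hq, K)
v = torch.randn(B, Mkv, G, 1, K, device="cuda", dtype=torch.bfloat16).expand(B, Mkv, G, Hq, K)

out = xops.memory_efficient_attention_forward(q, k, v, op=fmha.ck.FwOp)
```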

qianfengz and others added 30 commits July 9, 2024 18:22
    Avoid unused-const-variable warning
qianfengz and others added 23 commits January 13, 2025 04:52
    [CK] Memory-efficient attention (Head Dimension = 512)
    Remove using splitkv kernel from fmha fwd training path
    Disable PagedAttn bias types and hdim-512 for test_logsumexp
    Enable hdim=512 by default
    Further update to build hdim-512 by default
    Merge upstream into ROCM develop
facebook-github-bot added the CLA Signed and module: rocm labels on Jan 17, 2025
xw285cornell (Contributor) commented

Let's hold off a bit, I'm still working on merging the prior PR. Can we make sure all mem efficient tests are passing?

qianfengz (Contributor, Author) commented Jan 23, 2025

> Let's hold off a bit, I'm still working on merging the prior PR. Can we make sure all mem efficient tests are passing?

I just pushed commit f858c, which lets the forward training path still use the splitkv kernel; without this, the benchmark scripts that use memory_efficient_attention_partial would not benefit from our recent splitkv-kernel optimization for small q sizes. But with this enabled, the unit test

#> pytest tests/test_mem_eff_attention.py::test_forward 

will have 14 bfloat16 cases failing, even when export ENABLE_HIP_FMHA_RTN_CONVERT16=1 is used to enable the RTN method for fp32-to-bfloat16 conversion (which is much more accurate than the default RTZ method). So currently we are not able to judge whether the failing cases are due to a bug, or to the lower accuracy of the lse output from the fmha-forward kernel propagating into larger errors in the final outputs (dQuery, dKey, dValue).
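
For context on why the splitkv path matters here, a minimal sketch of the chunked-attention pattern that relies on memory_efficient_attention_partial, assuming the memory_efficient_attention_partial / merge_attentions interface in xformers.ops.fmha; the shapes and the two-way KV split are illustrative assumptions:

```python
import torch
from xformers.ops import fmha

# A single decode query attending to a KV cache processed in two chunks.
B, H, K = 2, 8, 128
q = torch.randn(B, 1, H, K, device="cuda", dtype=torch.bfloat16)

def kv_chunk():
    # Illustrative KV-cache chunk of length 2048.
    return torch.randn(B, 2048, H, K, device="cuda", dtype=torch.bfloat16)

k1, v1, k2, v2 = kv_chunk(), kv_chunk(), kv_chunk(), kv_chunk()

# Each partial call returns the chunk's attention output and its log-sum-exp.
out1, lse1 = fmha.memory_efficient_attention_partial(q, k1, v1, op=fmha.ck.FwOp)
out2, lse2 = fmha.memory_efficient_attention_partial(q, k2, v2, op=fmha.ck.FwOp)

# Merge the per-chunk partial results into the final attention output using the LSEs.
out, _ = fmha.merge_attentions([out1, out2], [lse1, lse2])
```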

danthe3rd (Contributor) commented Jan 30, 2025

Hi, is this ready to merge? We would like to do a new release soon (PT 2.6 is just out)
Linters are still failing

Labels: CLA Signed, module: rocm
7 participants