Merge recent changes from ROCm xformers #1196

Open · wants to merge 885 commits into base: main

Commits (885)
bd49f48
Remove _check_large_shapes checking in fmha/ck.py (#1067)
qianfengz Jul 14, 2024
0d1d1be
make xformers install editable to fix cpp extensions detection
tenpercent Jul 18, 2024
9390d6a
Update to using the improved fmha-bwd (compiling passed)
qianfengz Jul 23, 2024
22fce7e
Update to get 80% of the test_backward and test_dropout_backward_ck c…
qianfengz Jul 23, 2024
463a475
Replace the using of ConvertGradQ by using torch tensor type converting
qianfengz Jul 25, 2024
3427a6f
Change the tile settings for MaxK=32
qianfengz Jul 25, 2024
fbc7c50
Fix padding setting bug in grouped_backward
qianfengz Jul 26, 2024
6e08666
Change -DCK_FMHA_FWD_FAST_EXP2=1 to -DCK_TILE_FMHA_FWD_FAST_EXP2=1
qianfengz Jul 26, 2024
94ab599
Point the composable_kernel_tiled submodule to ck_tile/fa_bwd_opt branch
qianfengz Jul 26, 2024
830697c
Disable flshattF and flshattB on ROCM
qianfengz Jul 27, 2024
afd7e02
Add -mllvm and -enable-post-misched=0 compiling options for ROCM on s…
qianfengz Jul 27, 2024
e67de41
Disable flshattF and flshattB on ROCM
qianfengz Jul 27, 2024
d72c2b3
Update to support separate grad_q_f32_strides due to the API change in…
qianfengz Jul 28, 2024
5ddff31
Use old method for setting BlockDropout due to the revert in fmha_fwd…
qianfengz Jul 28, 2024
cf2b622
Tiny fix in grouped_backward
qianfengz Jul 28, 2024
112aaed
Use packed tensor allocation for grad_q_f32
qianfengz Jul 28, 2024
dd83c62
Update to the ConvertGradQ kernel calling
qianfengz Jul 28, 2024
3e9b99d
Tiny update
qianfengz Jul 28, 2024
019448e
Fix the parameter location in grouped_backward
qianfengz Jul 29, 2024
c55966a
Adjust headdim128 tile shapes for better performance
qianfengz Aug 5, 2024
e22829a
Update backward kernel calling due to adding of nhead_stride_dk/nhead…
qianfengz Aug 5, 2024
cae1b77
Synchronize with CK to use separate pipeline for kPadHeadDim true of …
qianfengz Aug 5, 2024
e564f5e
Use convertDQ kernel
qianfengz Aug 6, 2024
b043765
Update to use unpadded lse layout
qianfengz Aug 7, 2024
c9e7595
Add explicit headdim256 instances for fmha backward
qianfengz Aug 7, 2024
4a7b7dc
Add leaked headdim256 instance references
qianfengz Aug 7, 2024
1ad9cbe
Change to generate.py and then re-generate the instance files using it
qianfengz Aug 7, 2024
7db2aa4
Change to generate.py to generate instances references and uses the gen…
qianfengz Aug 7, 2024
73dbf32
Relax the RTOL of ckFwOp from 4e-4 to 3e-3 due to one big result case
qianfengz Aug 8, 2024
0e6d0c3
Change to use .h rather than .hpp as suffix for generated header files
qianfengz Aug 12, 2024
914ccc5
Fix in .gitignore
qianfengz Aug 12, 2024
8503f87
Update to bwd setting to use only IGLP pipeline
qianfengz Aug 12, 2024
bfe164d
Synchronize to latest ck_tile fix and align the headdim64 tile shape …
qianfengz Aug 12, 2024
f75c3b2
Reformat the generated instances cpp files
qianfengz Aug 12, 2024
520e6ed
Merge pull request #18 from ROCm/fa_bwd_opt_test
qianfengz Aug 12, 2024
bc3db99
Fix to the backward Trait
qianfengz Aug 13, 2024
fa6d8b3
Set occupancy to -1 to avoid the compiling warning
qianfengz Aug 13, 2024
c5c7cce
Revert "Set occupancy to -1 to avoid the compiling warning"
qianfengz Aug 13, 2024
d230433
Add environment variable and compiler definition to control the gener…
qianfengz Aug 14, 2024
82a07ae
Add --ignore-hd256 argument to generate_instance.py and some update i…
qianfengz Aug 14, 2024
38593d6
Add environment variable ENABLE_HIP_FMHA_RTN_BF16_CONVERT to enable u…
qianfengz Aug 15, 2024
15dc911
Remove commented lines in test_mem_eff_attention.py
qianfengz Aug 15, 2024
367274c
Synchronize to latest ck_tile commit
qianfengz Aug 15, 2024
f7b28c5
apply black
tenpercent Aug 16, 2024
fd82f20
apply flake8
tenpercent Aug 16, 2024
7d21800
fix mypy
tenpercent Aug 16, 2024
d6b6456
revert disable flash operator on rocm
tenpercent Aug 16, 2024
87188ea
Synchronize to ck_tile latest commit again
qianfengz Aug 16, 2024
5be80a3
Re-position the composable_kernel submodule to the develop branch
qianfengz Aug 17, 2024
cee0980
Merge pull request #20 from tenpercent/develop
qianfengz Aug 17, 2024
2a5c141
Avoid the Async pipeline when khasBias is true
qianfengz Aug 17, 2024
2874842
clang-format for two files
qianfengz Aug 17, 2024
7a91589
Change allocation of grouped mode lse from [H, M] to [1, H, M] to mat…
qianfengz Aug 17, 2024
66efb2c
Change in generate_instances.py so that this scripts can be called fr…
qianfengz Aug 20, 2024
c19b1f5
Add manual for generate_instances.py (.md)
qianfengz Aug 20, 2024
b450d01
Modification in GENERATE_INSTANCES.md
qianfengz Aug 20, 2024
07dc8e7
Fix in GENERATE_INSTANCES.md
qianfengz Aug 20, 2024
72bf603
Update GENERATE_INSTANCES.md
qianfengz Aug 20, 2024
e397974
clean-up commented codes
qianfengz Aug 20, 2024
7a04357
Revert "Change allocation of grouped mode lse from [H, M] to [1, H, M…
qianfengz Aug 20, 2024
2923301
Merge branch 'main' into develop-asorb-upstream
qianfengz Aug 22, 2024
84b50ac
Merge pull request #22 from ROCm/develop-asorb-upstream
qianfengz Aug 22, 2024
e0e6863
Merge remote-tracking branch 'fair/main' into merge-xformers-0826
tenpercent Aug 26, 2024
e1387a4
Merge pull request #23 from tenpercent/merge-xformers-0826
tenpercent Aug 27, 2024
77a2c24
Synchronize to latest ck develop for using the latest RTN bf16 convert
qianfengz Sep 3, 2024
4e51efa
Add c++ extension compiling options for better performance on ROCM 6.2
qianfengz Sep 3, 2024
bce6363
Merge branch 'main' into develop
tenpercent Sep 10, 2024
2b08141
reformat setup.py
tenpercent Sep 10, 2024
43bb919
Enable complete BlockDiagonalGappyKeysMask and BlockDiagonalPaddedKey…
qianfengz Sep 20, 2024
8382c7d
Sync to latest ck_tile commits and adapt the random_uniform_kernel to…
qianfengz Sep 20, 2024
21ae9d9
Move ck decoder codes to xformers/csrc/attention/hip_decoder folder
qianfengz Sep 22, 2024
fb3628d
Sync to latest ck_tile commits for fixing NaN when seqlen_k == 0
qianfengz Sep 22, 2024
ffa9906
Separate the kernel/pipeline dispatch into two files for infer/forward
qianfengz Sep 22, 2024
221860e
Remove unused member variable in GroupedForwardParams
qianfengz Sep 22, 2024
6a07c16
delete autogenerated files
tenpercent Sep 24, 2024
74355e9
delete autogenerated files (2)
tenpercent Sep 25, 2024
0dbdc5f
Initial add support of fmha-forward splitk (compiling passed)
qianfengz Sep 23, 2024
6b8ddde
Add generated files under hip_decoder into gitignore list
qianfengz Sep 26, 2024
cf9be1c
apply black and fix lint
tenpercent Sep 26, 2024
08219dc
rewrite hipified split-k decoder invocation to ck-tile style
tenpercent Sep 25, 2024
f37fb3d
Merge pull request #25 from tenpercent/refactor-hip-decoder
tenpercent Sep 27, 2024
9c8d2f1
Force kPadSeqLenQ == true for grouped mode splitkv-combine kernel traits
qianfengz Sep 28, 2024
761e8a5
Add compile-time checking to save compile-time
qianfengz Sep 29, 2024
eb50024
add dockerfile
tenpercent Jul 3, 2024
669ee34
migrate base docker image to manylinux
tenpercent Oct 2, 2024
0a97ed6
build a wheel
tenpercent Oct 2, 2024
f8129c3
rename dockerfile
tenpercent Oct 2, 2024
a0221e5
lint
tenpercent Oct 2, 2024
9975759
Merge pull request #26 from ROCm/dockerfile
tenpercent Oct 2, 2024
4d2a37d
Try adding docker image build workflow
tenpercent Oct 3, 2024
eddb1ec
Merge branch 'develop' into dockerfile
tenpercent Oct 3, 2024
ea3b796
add newline
tenpercent Oct 3, 2024
983fc19
add newline
tenpercent Oct 3, 2024
d6ea535
Merge pull request #27 from ROCm/dockerfile
tenpercent Oct 3, 2024
eb986e1
Update README.md
tenpercent Oct 3, 2024
ac0b05c
Merge pull request #28 from ROCm/tenpercent-patch-1
tenpercent Oct 3, 2024
e391974
Remove directly including of <ck_tile/host.hpp>
qianfengz Oct 3, 2024
9d03beb
Synchronize with latest ck develop commit
qianfengz Oct 3, 2024
0a4d420
Remove the printing in attention_forward_generic_ck_tiled.cpp
qianfengz Oct 3, 2024
a1c788e
Tune the TilePartitioner for splitkv-combine kernel
qianfengz Oct 3, 2024
b1e5ee4
Use 64 as maximum possible number of splitkv
qianfengz Oct 3, 2024
a53ed75
Add environment variable to disable building fmha-fwd-splitkv
qianfengz Oct 4, 2024
772e8f6
Use 32 as maximum number of splits
qianfengz Oct 6, 2024
04bb150
Fix compilation errors due to CK interface change
poyenc Oct 6, 2024
7949da4
Determine kHasUnevenSplits at runtime
poyenc Oct 6, 2024
3a8d7cf
Determine kPadSeqLenK at runtime
poyenc Oct 7, 2024
28ac1ca
Let kPadSeqLenK be reversed value of kHasUnevenSplits
qianfengz Oct 7, 2024
00482a0
Merge branch 'develop' into add-splitkv
qianfengz Oct 7, 2024
b62e722
Merge pull request #29 from ROCm/add-splitkv
qianfengz Oct 7, 2024
e8143c3
Synchronize to latest ck develop commit for updates with regard to fm…
qianfengz Oct 8, 2024
7986c2c
fix build: stream type
tenpercent Oct 8, 2024
93524db
Merge pull request #30 from tenpercent/develop
qianfengz Oct 9, 2024
ee10600
Add support for fmha-bwd headdim-96
qianfengz Oct 11, 2024
c9fa526
Use kK2=96
qianfengz Oct 12, 2024
abc9361
Synchronize the change in ck-tile to rename kQKHeaddimForGemmN to kQK…
qianfengz Oct 14, 2024
5bb0542
Synchronize the change in ck-tile to replace kVHeaddimForGemmN by kVH…
qianfengz Oct 14, 2024
c5b594d
Simplify FmhaBwdPipelineEnumSelector templates
qianfengz Oct 14, 2024
dd6cf04
Merge branch 'develop' into bwd_hd96_perf
qianfengz Oct 14, 2024
723c420
Synchronize to latest ck_tile commit
qianfengz Oct 14, 2024
a15e559
Replace TileFmhaBwdTraits by TileFmhaTraits
qianfengz Oct 15, 2024
2773383
Relocate to ck_tile develop branch and synchronize to latest commits
qianfengz Oct 16, 2024
d4437ad
Merge pull request #31 from ROCm/bwd_hd96_perf
qianfengz Oct 16, 2024
f94fdfd
Remove using splitkv from fmha-fwd training path
qianfengz Oct 16, 2024
4b4327e
Revert "Remove using splitkv from fmha-fwd training path"
qianfengz Oct 17, 2024
bc107ad
Add kMaxSplits=8 support
qianfengz Oct 17, 2024
91e01f9
Add tile settings for splitkv kernel
qianfengz Oct 18, 2024
139334c
Use WarpTile 16x16x16 for fmha-fwd splitkv
qianfengz Oct 20, 2024
c553f1a
Add MaxSeqlenQ as parameter for creating tile shape settings
qianfengz Oct 20, 2024
eb4586e
Update in FmhaFwdSplitKVShape
qianfengz Oct 20, 2024
6b0fae2
Synchronize to the latest commit of ck_tile for split-kv support
qianfengz Oct 21, 2024
46bc17d
Merge pull request #32 from ROCm/splitkv_improve
qianfengz Oct 21, 2024
76b9738
Change the selection of Default2DEpilogue for Fwd SplitKV kernel to a…
qianfengz Oct 25, 2024
7243b49
Try to have kPadSeqLenK be false in splitkv dispatch
qianfengz Oct 25, 2024
6ffea6a
Revert "Try to have kPadSeqLenK be false in splitkv dispatch"
qianfengz Oct 26, 2024
5f1ec0c
Synchronize for latest splitkv support in ck-tile
qianfengz Oct 26, 2024
3437842
Use kSubQKHeaddim to replace kK0BlockLength
qianfengz Oct 27, 2024
6c8a8b4
Add headdim96 support for fmha-fwd
qianfengz Oct 28, 2024
cb58e69
Synchronize to latest commit in ck-tile
qianfengz Oct 28, 2024
7d8ced0
Reposition the composable_kernel_tiled submodule to latest ck develop…
qianfengz Oct 30, 2024
06b548c
Merge pull request #34 from ROCm/fwd_hd96_debug
qianfengz Oct 30, 2024
7f91bb1
Synchronize to latest ck_tile commit for some bug fixing in page-attn
qianfengz Nov 11, 2024
44b6def
Fix grad_k/grad_v strides
qianfengz Nov 13, 2024
b000bb3
Merge pull request #36 from ROCm/stride_fix
qianfengz Nov 13, 2024
bdfffaa
Synchronize to latest ck_tile commit for adding Paged-KVCache dependa…
qianfengz Nov 21, 2024
266e3c6
Let splitkv combine kernel not called when num_splits is 1
qianfengz Nov 22, 2024
273a892
Add support for Paged-KVCache (PagedBlockDiagonalPaddedKeysMask pas…
qianfengz Nov 25, 2024
22df8c9
Add is_gappy indicator to let kernel have special treatment for seqst…
qianfengz Nov 25, 2024
e768502
Fix in _custom_mask_type of ck.py
qianfengz Nov 26, 2024
00c70d0
Add test_paged_attention_ck in tests/test_mem_eff_attention.py
qianfengz Nov 26, 2024
468c83f
position to the latest ck develop branch
qianfengz Nov 26, 2024
95460bc
Change to check causalmask type and window_size parameter together to…
qianfengz Nov 26, 2024
56dba6b
Merge pull request #37 from ROCm/add_paged_kvcache
qianfengz Nov 26, 2024
9ccc42f
bump python op maxk
tenpercent Nov 27, 2024
760cdcc
run codegen
tenpercent Nov 27, 2024
4de46f4
run codegen (1)
tenpercent Nov 27, 2024
89e8e91
add missing FmhaFwdBlockTile instance; handle 512 case when computing…
tenpercent Dec 2, 2024
f13d987
Initial adding support for splitkv smallq pipeline
qianfengz Dec 3, 2024
672617b
fix compile error in qr_ks_vs pipeline
tenpercent Dec 3, 2024
d7099cb
fix occupancy related compilation errors
tenpercent Dec 3, 2024
a198345
try adding qsksvs pipeline and stash the result
tenpercent Dec 4, 2024
580ec51
Synchronize to latest ck_tile commit to utilize the padding optimization
qianfengz Dec 6, 2024
8a45436
Merge pull request #40 from ROCm/optimize_padding
qianfengz Dec 6, 2024
a19d6a3
Resync to latest ck-tile commit for padding optimization
qianfengz Dec 6, 2024
e27b84c
Fix in batched_forward splitkv dispatch
qianfengz Dec 6, 2024
5041a12
Merge branch 'develop' into add_splitkv_smallq
qianfengz Dec 6, 2024
aee3570
Fix in batched_forward splitkv smallq dispatch
qianfengz Dec 6, 2024
be06c43
Update the splits selector and instances settings for splitkv-smallq …
qianfengz Dec 9, 2024
aff7bfd
Enable gemm-0 to use 16x16x16 warp-gemm
qianfengz Dec 10, 2024
1922015
enable offload compression
tenpercent Dec 4, 2024
2cc18ef
run black
tenpercent Dec 11, 2024
da455ec
fix merge conflict (1)
tenpercent Dec 11, 2024
21330ed
reset submodule
tenpercent Dec 11, 2024
e8946b2
cleanup
tenpercent Dec 11, 2024
7e92d1f
Merge remote-tracking branch 'origin/develop' into ci-fixes
tenpercent Dec 11, 2024
8b580f4
run black
tenpercent Dec 11, 2024
3f9a40b
Merge pull request #41 from ROCm/ci-fixes
tenpercent Dec 12, 2024
afdfa46
Synchronize to use the latest optimization for splitkv combine kernel
qianfengz Dec 13, 2024
1258328
Update in ck FwOp apply() to well utilize the group query support in…
qianfengz Dec 15, 2024
08edbf9
Update to let fmha infer kernel can select either 16x16 or 32x32 inst…
qianfengz Dec 16, 2024
57e157e
Remove the conditional compiling of using splitkv kernel
qianfengz Dec 16, 2024
c1647c7
Merge remote-tracking branch 'origin/develop' into hdim-512
tenpercent Dec 16, 2024
84d7253
Sync to the latest commit of the ck_tile branch
qianfengz Dec 17, 2024
97523dd
Sync to the latest commit of the ck_tile branch for updated pipeline …
qianfengz Dec 17, 2024
aa781c8
Update in the method for determining num_kv_splits
qianfengz Dec 17, 2024
e53d164
Update to the tile setting for splitkv-smallq headdim128
qianfengz Dec 17, 2024
1ae3de9
call qsksvs pipeline on either async or sync codepath in dispatch
tenpercent Dec 18, 2024
f10bc80
more pipeline changes
tenpercent Dec 18, 2024
83cabd4
update submodule
tenpercent Dec 18, 2024
53d4e0e
update headdim switch
tenpercent Dec 18, 2024
d0431e1
Update to the splitkv and splitkv-smallq selector
qianfengz Dec 18, 2024
5644f9f
fix kernel not being called
tenpercent Dec 18, 2024
bb703b5
test head dimension 512 for ckF
tenpercent Dec 18, 2024
2ea82a9
re-run generate_instances.py to please clang-format
tenpercent Dec 18, 2024
82ba746
run clang-format
tenpercent Dec 18, 2024
70d767d
run black
tenpercent Dec 18, 2024
6605ddb
Add ck in tests/test_mem_eff_attention.py::test_backward_gqa
qianfengz Dec 18, 2024
c1ab8e5
Re-position to latest develop branch and rename the SplitkvSmallq pip…
qianfengz Dec 20, 2024
4a5298c
Merge branch 'develop' into add_splitkv_smallq_nwarps
qianfengz Dec 20, 2024
73204e1
Merge pull request #44 from ROCm/add_splitkv_smallq_nwarps
qianfengz Dec 20, 2024
73d06c1
Replace the reshape() by flatten/unflatten in ck.py
qianfengz Dec 20, 2024
2980a55
Update ck.py to support expanded 5-D input for ck.FwOp
qianfengz Dec 20, 2024
84414b1
Fix in ck.py
qianfengz Dec 21, 2024
bf33926
Remove using partitioner for fmha kernels
qianfengz Dec 26, 2024
256d6a4
Add support for mqa_decoder optimization which merge Hq/Hkv with seql…
qianfengz Jan 7, 2025
23d7b1c
Synchronize to latest ck_tile commit which has changed GridSize() of …
qianfengz Jan 7, 2025
d66e7bf
Merge pull request #46 from ROCm/mqa_decoder_improve
qianfengz Jan 7, 2025
e07d13c
bump submodule
tenpercent Jan 7, 2025
bf78988
bump submodule to today's merge commit in ck
tenpercent Jan 7, 2025
8c28fdb
Merge remote-tracking branch 'origin/develop' into hdim-512
tenpercent Jan 7, 2025
40cbefb
refactor dispatch
tenpercent Jan 8, 2025
40f92e7
bump ck submodule to the current develop branch
tenpercent Jan 8, 2025
e4a7f3b
fix flake8 lint
tenpercent Jan 8, 2025
e5a43d4
Merge branch 'develop' into hdim-512
tenpercent Jan 8, 2025
cbe8e20
clang-format
tenpercent Jan 8, 2025
c2d9939
Removing the compressing of expanded 5D to 4D for xops.fmha.ck.FwOp
qianfengz Jan 9, 2025
58fa14a
Merge branch 'develop' of https://github.com/ROCm/xformers into develop
qianfengz Jan 9, 2025
cb60bad
wheels
johnnynunez Jan 9, 2025
b301741
Synchronize to latest ck_tile commit
qianfengz Jan 10, 2025
2f75f5a
Skip PagedBlockDiagonal attn_bias types for hdim-512
qianfengz Jan 10, 2025
acb58a5
Remove using DISABLE_HD256_HIP_FMHA env-variable and FMHA_SUPPORT_MAX…
qianfengz Jan 10, 2025
1887a33
Add using ENABLE_HD512_HIP_FMHA env-variable and FMHA_LIMIT_MAX_HEADD…
qianfengz Jan 10, 2025
a5c68d2
Update to the selector to explicitly use non-splitkv kernel for hdim-512
qianfengz Jan 10, 2025
73d7b78
Merge pull request #1 from ROCm/hdim-512-testing
tenpercent Jan 10, 2025
6da69d3
Update wheels.yml
johnnynunez Jan 10, 2025
eeb581f
Synchronize to latest ck commit
qianfengz Jan 13, 2025
701685c
Use 64x128 Gemm0 Tile and WarpGemm-16x16x16 for hdim-512
qianfengz Jan 13, 2025
fd11dbd
Merge pull request #48 from ROCm/hdim-512
qianfengz Jan 13, 2025
84883b5
Remove using splitkv kernel from fmha fwd training path
qianfengz Jan 13, 2025
2f66b19
Merge pull request #49 from ROCm/hack_test_backward
qianfengz Jan 13, 2025
be6f8c2
Add -Wc++11-narrowing to hip_fmha compiling options to avoid any erro…
qianfengz Jan 14, 2025
1f12982
Merge branch 'develop' into develop
johnnynunez Jan 15, 2025
e14bf36
Update wheels.yml
johnnynunez Jan 15, 2025
6213bf6
Disable PagedAttn bias types and hdim-512 for test_logsumexp
qianfengz Jan 15, 2025
028196d
Merge pull request #50 from ROCm/fix_test_logsumexp
qianfengz Jan 15, 2025
d6e7e4f
Merge branch 'develop' into develop
johnnynunez Jan 15, 2025
58c037b
Update wheels.yml
johnnynunez Jan 15, 2025
1dcb9d8
hotfix typo
tenpercent Jan 15, 2025
21ede52
Merge pull request #51 from tenpercent/develop
tenpercent Jan 15, 2025
433f4f9
Merge branch 'develop' into develop
tenpercent Jan 15, 2025
4685c44
Merge pull request #47 from johnnynunez/develop
tenpercent Jan 16, 2025
6c78398
enable hdim=512 by default
tenpercent Jan 16, 2025
865e802
Merge branch 'develop' into develop
tenpercent Jan 16, 2025
beadd0b
Merge pull request #52 from tenpercent/develop
qianfengz Jan 17, 2025
0c85bee
Further update to build hdim-512 by default
qianfengz Jan 17, 2025
fdc410a
Merge pull request #53 from ROCm/further_fix
qianfengz Jan 17, 2025
9928374
Merge remote-tracking branch 'upstream/main' into merge_upstream
qianfengz Jan 17, 2025
9045af7
Merge pull request #54 from ROCm/merge_upstream
qianfengz Jan 17, 2025
8e84e22
Remove Dockerfile.rocm
qianfengz Jan 17, 2025
4cfab36
Revert "Remove using splitkv kernel from fmha fwd training path"
qianfengz Jan 23, 2025
d141385
Fix in ck.py to handle attn_bias types with 5-D bias tensor
qianfengz Feb 11, 2025
Files changed

2 changes: 1 addition & 1 deletion .github/workflows/rocm_build.yml
@@ -24,7 +24,7 @@ jobs:
python: ['3.11']
torch_version: ['2.5.1']
toolkit_type: ['rocm']
toolkit_short_version: ['6.1', '6.2']
toolkit_short_version: ['6.1', '6.2', '6.3']

uses: ./.github/workflows/wheels_build.yml
if: github.repository == 'rocm/xformers'
5 changes: 3 additions & 2 deletions .github/workflows/wheels.yml
@@ -32,7 +32,7 @@ jobs:
# NOTE: Don't forget to update `upload_pt`'s matrix
# when changing the CUDA/ROCM versions below!
CU_VERSIONS = ['118', '121', '124']
ROCM_VERSIONS = ["6.1"] # <- 6.0 broken in `manylinux_2_28`
ROCM_VERSIONS = ['6.1', '6.2', '6.3'] # <- 6.0 broken in `manylinux_2_28`
PY_CU = list(itertools.product(PY_VERSIONS, CU_VERSIONS))
PY_ROCM = list(itertools.product(PY_VERSIONS, ROCM_VERSIONS))
print("Full matrix PY_CU", PY_CU)
@@ -111,11 +111,12 @@ jobs:
- cu121
- cu124
- rocm6.1
- rocm6.2
- rocm6.3
uses: ./.github/workflows/wheels_upload_s3.yml
with:
aws_role: "arn:aws:iam::749337293305:role/pytorch_bot_uploader_role"
s3_path: s3://pytorch/whl/${{ matrix.suffix }}/
aws_s3_cp_extra_args: --acl public-read
filter: "*torch2.5.1+${{ matrix.suffix }}*"
execute: ${{ github.repository == 'facebookresearch/xformers' && github.ref_type == 'tag' }}
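
For context, the matrix step above is plain Python executed by the workflow: it takes the Cartesian product of the Python versions and the toolkit versions, so extending ROCM_VERSIONS from one entry to three triples the number of ROCm wheel builds. A minimal sketch of that expansion (PY_VERSIONS is outside this hunk, so its value below is an assumption):

    import itertools

    PY_VERSIONS = ["3.9", "3.10", "3.11"]  # assumed; not shown in this hunk
    CU_VERSIONS = ["118", "121", "124"]
    ROCM_VERSIONS = ["6.1", "6.2", "6.3"]  # extended by this PR

    PY_CU = list(itertools.product(PY_VERSIONS, CU_VERSIONS))
    PY_ROCM = list(itertools.product(PY_VERSIONS, ROCM_VERSIONS))
    print("Full matrix PY_CU", PY_CU)      # one wheel build per (python, CUDA) pair
    print("Full matrix PY_ROCM", PY_ROCM)  # one wheel build per (python, ROCm) pair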

11 changes: 1 addition & 10 deletions setup.py
@@ -522,14 +522,6 @@ def get_extensions():
elif torch.version.hip and (
torch.cuda.is_available() or os.getenv("HIP_ARCHITECTURES", "") != ""
):
disable_hd256_hip_fmha = os.getenv("DISABLE_HD256_HIP_FMHA", "0")
if disable_hd256_hip_fmha == "1":
source_hip_maxk_256 = []
for ff in source_hip:
if ff.endswith("maxk_256.cpp"):
source_hip_maxk_256 += [ff]
source_hip = list(set(source_hip) - set(source_hip_maxk_256))

rename_cpp_cu(source_hip)
hip_version = get_hip_version(ROCM_HOME)

@@ -549,8 +541,6 @@ def get_extensions():
]

generator_flag = []
if disable_hd256_hip_fmha == "1":
generator_flag += ["-DFMHA_SUPPORT_MAX_HEADDIM_128=1"]

cc_flag = ["-DBUILD_PYTHON_PACKAGE"]
use_rtn_bf16_convert = os.getenv("ENABLE_HIP_FMHA_RTN_BF16_CONVERT", "0")
@@ -575,6 +565,7 @@ def get_extensions():
"-DCK_TILE_FMHA_FWD_FAST_EXP2=1",
"-fgpu-flush-denormals-to-zero",
"-Werror",
"-Wc++11-narrowing",
"-Woverloaded-virtual",
"-mllvm",
"-enable-post-misched=0",
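
The lines deleted above follow the switch pattern setup.py uses for optional build features: read an environment variable, optionally drop files from source_hip, and append a -D define to the HIP compile flags. This PR removes the DISABLE_HD256_HIP_FMHA switch while keeping others such as ENABLE_HIP_FMHA_RTN_BF16_CONVERT. A hedged sketch of the pattern only (the define name below is a hypothetical stand-in; the real one is outside this hunk):

    import os

    cc_flag = ["-DBUILD_PYTHON_PACKAGE"]

    # Opt-in toggle: the env-var is read at build time and mapped to a define.
    use_rtn_bf16_convert = os.getenv("ENABLE_HIP_FMHA_RTN_BF16_CONVERT", "0")
    if use_rtn_bf16_convert == "1":
        # hypothetical define name, shown only to illustrate the pattern
        cc_flag += ["-DUSE_HIP_FMHA_RTN_BF16_CONVERT=1"]

    print(cc_flag)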
22 changes: 22 additions & 0 deletions tests/test_mem_eff_attention.py
@@ -457,6 +457,16 @@ def test_forward(opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv, packed, fmt, **kwargs)
if fmt == "BMK" and not fmha.common._is_bias_type_supported_in_BMK(bias_type):
pytest.skip("BMK incompatible with this bias")

if op is fmha.ck.FwOp:
if (k > 256 or kv > 256) and issubclass(
bias_type,
(
fmha.attn_bias.PagedBlockDiagonalPaddedKeysMask,
fmha.attn_bias.PagedBlockDiagonalGappyKeysMask,
),
):
pytest.skip("ck.FwOp hdim-512 is not supported when Paged-KVCache is used!")

query, key, value, attn_bias = create_tensors(
*opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv,
fmt="BMHK" if packed else fmt,
@@ -545,6 +555,18 @@ def test_logsumexp(opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv):
kv,
) = opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv

if op is fmha.ck.FwOp:
if issubclass(
bias_type,
(
fmha.attn_bias.PagedBlockDiagonalPaddedKeysMask,
fmha.attn_bias.PagedBlockDiagonalGappyKeysMask,
),
):
pytest.skip(
"With ck.FwOp Paged-KVCache has some problem with forward training!"
)

query, key, value, attn_bias = create_tensors(
*opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv,
fmt="BMHK",
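
The new skips above cover the ck.FwOp combinations this backend does not handle: head dimensions above 256 together with the Paged-KVCache bias types in test_forward, and the Paged-KVCache bias types altogether in test_logsumexp. Outside the test suite, the same capability check can be done through the operator's not_supported_reasons(), which the benchmark change further down also uses. A hedged sketch, assuming a ROCm build where xops.fmha.ck is available:

    import torch
    import xformers.ops as xops

    # BMHK tensors with head dimension 512, which this PR enables for ck.FwOp
    # as long as no Paged-KVCache bias is used (shapes are illustrative).
    q = torch.randn(1, 128, 8, 512, device="cuda", dtype=torch.float16)
    k = torch.randn(1, 256, 8, 512, device="cuda", dtype=torch.float16)
    v = torch.randn(1, 256, 8, 512, device="cuda", dtype=torch.float16)

    inp = xops.fmha.Inputs(query=q, key=k, value=v, attn_bias=None)
    reasons = xops.fmha.ck.FwOp.not_supported_reasons(inp)
    if reasons:
        print("ck.FwOp rejects this input:", reasons)
    else:
        out = xops.memory_efficient_attention_forward(q, k, v, op=xops.fmha.ck.FwOp)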
2 changes: 1 addition & 1 deletion third_party/composable_kernel_tiled
67 changes: 67 additions & 0 deletions xformers/benchmarks/benchmark_attn_decoding.py
@@ -168,6 +168,73 @@ class AttentionDecodingCUTLASS(AttentionDecodingBase):
class AttentionDecodingCK(AttentionDecodingBase):
OP = xops.fmha.ck.FwOp

def __init__(
self,
B: int,
Mq: int,
Mkv: int,
Hq: int,
Hkv: int,
K: int,
bw: bool,
attn_bias_type,
) -> None:
dtype = torch.float16
torch.manual_seed(10)
self.sub_label = (
f"B={B} Mq={Mq} Mkv={Mkv} Hq={Hq} Hkv={Hkv} K={K} TotalBytes="
f"{((B * Mkv * Hkv * K * 2) + (B * Mq * Hq * K) + (B * Mq * Hq * K)) * 2}"
)
self.label = "attn_decoding"
self.shapes = (B, Mq, Mkv, Hq, Hkv, K)

assert Hkv <= Hq
assert Hq % Hkv == 0
self.q = torch.randn(
[B, Mq, Hkv, Hq // Hkv, K], device="cuda", dtype=dtype, requires_grad=bw
)
self.k = torch.randn(
[B, Mkv, Hkv, 1, K], device="cuda", dtype=dtype, requires_grad=bw
).expand(-1, -1, -1, Hq // Hkv, -1)
self.v = torch.randn(
[B, Mkv, Hkv, 1, K], device="cuda", dtype=dtype, requires_grad=bw
).expand(-1, -1, -1, Hq // Hkv, -1)

if Hq == Hkv:
self.q = self.q[:, :, :, 0]
self.k = self.k[:, :, :, 0]
self.v = self.v[:, :, :, 0]

self.attn_bias = create_attn_bias(
attn_bias_type,
batch_size=B,
num_heads=Hq,
num_heads_groups=Hq // Hkv,
q_len=Mq,
kv_len=Mkv,
dtype=dtype,
device=device,
requires_grad=False,
fmt="BMHK",
op=self.OP,
)

if isinstance(
self.attn_bias,
xops.fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask,
):
self.q = self.q.view(1, -1, *self.q.shape[2:])
self.k = self.k.view(1, -1, *self.k.shape[2:])
self.v = self.v.view(1, -1, *self.v.shape[2:])

if hasattr(self.OP, "not_supported_reasons"):
inp = xops.fmha.Inputs(
query=self.q, key=self.k, value=self.v, attn_bias=self.attn_bias
)
not_supported_reasons = self.OP.not_supported_reasons(inp)
if not_supported_reasons:
raise NotSupportedInputError(not_supported_reasons)


class AttentionDecodingCKDecoder(AttentionDecodingBase):
OP = xops.fmha.ck_decoder.FwOp
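
The AttentionDecodingCK constructor above builds grouped-query inputs in the expanded 5-D (BMGHK) layout: the query keeps shape [B, M, Hkv, Hq // Hkv, K] while the shared key/value heads are broadcast with expand() rather than materialized, which the "expanded 5-D input for ck.FwOp" changes in this PR let ck.FwOp consume directly. A minimal usage sketch under those assumptions, on a ROCm build (shapes are illustrative):

    import torch
    import xformers.ops as xops

    B, Mq, Mkv, Hq, Hkv, K = 2, 1, 4096, 16, 2, 128
    dtype = torch.float16

    # One query group per KV head; K/V are expanded, not copied.
    q = torch.randn(B, Mq, Hkv, Hq // Hkv, K, device="cuda", dtype=dtype)
    k = torch.randn(B, Mkv, Hkv, 1, K, device="cuda", dtype=dtype).expand(
        -1, -1, -1, Hq // Hkv, -1
    )
    v = torch.randn(B, Mkv, Hkv, 1, K, device="cuda", dtype=dtype).expand(
        -1, -1, -1, Hq // Hkv, -1
    )

    out = xops.memory_efficient_attention_forward(q, k, v, op=xops.fmha.ck.FwOp)
    print(out.shape)  # matches the query shape: (B, Mq, Hkv, Hq // Hkv, K)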
35 changes: 21 additions & 14 deletions xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_forward.h
@@ -23,24 +23,31 @@ template <
void run_batched_forward_mask_bias_dropout_dispatch(
BatchedForwardParams& param,
hipStream_t stream) {
// currently split-kv implementation does not support dropout
// currently split-kv implementation does not support:
// (*) dropout
// (*) head dimension > 256
if constexpr (!kHasDropout) {
if (param.use_split_kv) {
if (use_splitkv_smallq(param.M, std::max(param.K, param.Kv))) {
batched_forward_splitkv_smallq_mask_bias_dropout_dispatch<
ScalarType,
kHasMask,
kHasBias,
MaxK>::Run(param, stream);
} else {
FMHA_FWD_SEQLEN_Q_SWITCH(param.M, MaxSeqlenQ, [&] {
batched_forward_splitkv_mask_bias_dropout_dispatch<
if (param.use_split_kv && MaxK <= 256) {
if constexpr (MaxK <= 256) {
if (use_splitkv_smallq(param.M, std::max(param.K, param.Kv))) {
batched_forward_splitkv_smallq_mask_bias_dropout_dispatch<
ScalarType,
kHasMask,
kHasBias,
MaxK,
MaxSeqlenQ>::Run(param, stream);
});
MaxK>::Run(param, stream);
} else {
FMHA_FWD_SEQLEN_Q_SWITCH(param.M, MaxSeqlenQ, [&] {
batched_forward_splitkv_mask_bias_dropout_dispatch<
ScalarType,
kHasMask,
kHasBias,
MaxK,
MaxSeqlenQ>::Run(param, stream);
});
}
} else {
// Unreachable. Do not instantiate split-kv pipelines with head
// dimension > 256
}
} else {
if (get_fmha_fwd_mtile(param.B, param.Hq, param.M) == 128)

Changes to an additional file (file name not shown in this view)
@@ -47,7 +47,7 @@ struct batched_forward_mask_bias_dropout_dispatch {

using FmhaFwdShape_ = typename FmhaFwdShape<MaxK, MTile>::Type;
constexpr ck_tile::index_t occupancy =
(MaxK == 64) ? 3 : ((MaxK == 256) ? 1 : 2);
(MaxK == 64) ? 3 : ((MaxK >= 256) ? 1 : 2);

constexpr auto kBiasEnum = kHasBias
? ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS
@@ -89,8 +89,10 @@ struct batched_forward_mask_bias_dropout_dispatch {
using FmhaPipelineProblem =
FmhaPipelineProblemTemp<FmhaFwdTraits_, FmhaMask>;

using FmhaFwdPipeline_ =
ck_tile::BlockFmhaPipelineQRKSVS<FmhaPipelineProblem>;
using FmhaFwdPipeline_ = std::conditional_t<
MaxK <= 256,
ck_tile::BlockFmhaPipelineQRKSVS<FmhaPipelineProblem>,
ck_tile::BlockFmhaPipelineQSKSVS<FmhaPipelineProblem>>;

using FmhaFwdEpilogue_ =
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<

Changes to an additional file (file name not shown in this view)
@@ -100,6 +100,7 @@ struct batched_forward_splitkv_mask_bias_dropout_dispatch {
false, // kDoFp8StaticQuant place-holder
false, // kIsPagedKV
kHasUnevenSplits,
false, // kMergeNumHeadGroupsSeqLenQ
occupancy>;

if (param.num_kv_splits > 1) {
@@ -305,7 +306,7 @@ }();
}();

dim3 kGridSize = FmhaFwdSplitKVKernel::GridSize(
param.B, param.Hq, param.M, param.Kv, param.num_kv_splits);
param.B, param.Hq, param.Hkv, param.M, param.Kv, param.num_kv_splits);
constexpr dim3 kBlockSize = FmhaFwdSplitKVKernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = FmhaFwdSplitKVKernel::kBlockPerCu;


Changes to an additional file (file name not shown in this view)
@@ -98,6 +98,7 @@ struct batched_forward_splitkv_smallq_mask_bias_dropout_dispatch {
false, // kDoFp8StaticQuant place-holder
false, // kIsPagedKV
kHasUnevenSplits,
false, // kMergeNumHeadGroupsSeqLenQ
occupancy>;

if (param.num_kv_splits > 1) {
@@ -304,7 +305,7 @@ }();
}();

dim3 kGridSize = FmhaFwdSplitKVKernel::GridSize(
param.B, param.Hq, param.M, param.Kv, param.num_kv_splits);
param.B, param.Hq, param.Hkv, param.M, param.Kv, param.num_kv_splits);
constexpr dim3 kBlockSize = FmhaFwdSplitKVKernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = FmhaFwdSplitKVKernel::kBlockPerCu;

35 changes: 21 additions & 14 deletions xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_infer.h
@@ -23,24 +23,31 @@ template <
void run_batched_infer_mask_bias_dropout_dispatch(
BatchedForwardParams& param,
hipStream_t stream) {
// currently split-kv implementation does not support dropout
// currently split-kv implementation does not support:
// (*) dropout
// (*) head dimension > 256
if constexpr (!kHasDropout) {
if (param.use_split_kv) {
if (use_splitkv_smallq(param.M, std::max(param.K, param.Kv))) {
batched_infer_splitkv_smallq_mask_bias_dropout_dispatch<
ScalarType,
kHasMask,
kHasBias,
MaxK>::Run(param, stream);
} else {
FMHA_FWD_SEQLEN_Q_SWITCH(param.M, MaxSeqlenQ, [&] {
batched_infer_splitkv_mask_bias_dropout_dispatch<
if (param.use_split_kv && MaxK <= 256) {
if constexpr (MaxK <= 256) {
if (use_splitkv_smallq(param.M, std::max(param.K, param.Kv))) {
batched_infer_splitkv_smallq_mask_bias_dropout_dispatch<
ScalarType,
kHasMask,
kHasBias,
MaxK,
MaxSeqlenQ>::Run(param, stream);
});
MaxK>::Run(param, stream);
} else {
FMHA_FWD_SEQLEN_Q_SWITCH(param.M, MaxSeqlenQ, [&] {
batched_infer_splitkv_mask_bias_dropout_dispatch<
ScalarType,
kHasMask,
kHasBias,
MaxK,
MaxSeqlenQ>::Run(param, stream);
});
}
} else {
// Unreachable. Do not instantiate split-kv pipelines with head
// dimension > 256
}
} else {
if (get_fmha_fwd_mtile(param.B, param.Hq, param.M) == 128)

Changes to an additional file (file name not shown in this view)
@@ -48,7 +48,7 @@ struct batched_infer_mask_bias_dropout_dispatch {

using FmhaShape = typename FmhaFwdShape<MaxK, MTile>::Type;
constexpr ck_tile::index_t occupancy =
(MaxK == 64) ? 3 : ((MaxK == 256) ? 1 : 2);
(MaxK == 64) ? 3 : ((MaxK >= 256) ? 1 : 2);

constexpr auto kBiasEnum = kHasBias
? ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS
@@ -92,8 +92,10 @@ struct batched_infer_mask_bias_dropout_dispatch {
using FmhaPipelineProblem =
FmhaPipelineProblemTemp<FmhaTraits, FmhaMask>;

using FmhaPipeline =
ck_tile::BlockFmhaPipelineQRKSVS<FmhaPipelineProblem>;
using FmhaPipeline = std::conditional_t<
MaxK <= 256,
ck_tile::BlockFmhaPipelineQRKSVS<FmhaPipelineProblem>,
ck_tile::BlockFmhaPipelineQSKSVS<FmhaPipelineProblem>>;

using FmhaEpilogue =
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<
@@ -124,8 +126,10 @@ struct batched_infer_mask_bias_dropout_dispatch {
using FmhaPipelineProblem =
FmhaPipelineProblemTemp<FmhaTraits, FmhaMask>;

using FmhaPipeline =
ck_tile::BlockFmhaPipelineQRKSVSAsync<FmhaPipelineProblem>;
using FmhaPipeline = std::conditional_t<
MaxK <= 256,
ck_tile::BlockFmhaPipelineQRKSVSAsync<FmhaPipelineProblem>,
ck_tile::BlockFmhaPipelineQSKSVS<FmhaPipelineProblem>>;

using FmhaEpilogue =
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<

Changes to an additional file (file name not shown in this view)
@@ -101,6 +101,7 @@ struct batched_infer_splitkv_mask_bias_dropout_dispatch {
false, // kDoFp8StaticQuant place-holder
false, // kIsPagedKV
kHasUnevenSplits,
false, // kMergeNumHeadGroupsSeqLenQ
occupancy>;

using ODataType =
@@ -136,6 +137,7 @@ struct batched_infer_splitkv_mask_bias_dropout_dispatch {
false, // kDoFp8StaticQuant place-holder
false, // kIsPagedKV
kHasUnevenSplits,
false, // kMergeNumHeadGroupsSeqLenQ
occupancy>;

using ODataType =
@@ -318,7 +320,7 @@ }();
}();

dim3 kGridSize = FmhaFwdSplitKVKernel::GridSize(
param.B, param.Hq, param.M, param.Kv, param.num_kv_splits);
param.B, param.Hq, param.Hkv, param.M, param.Kv, param.num_kv_splits);
constexpr dim3 kBlockSize = FmhaFwdSplitKVKernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = FmhaFwdSplitKVKernel::kBlockPerCu;
