Add Context Parallelism support to cudnn Flash Attention #1133
base: main
Conversation
Awesome, this looks great! Just one change to the sharding rules (with the current state of this PR it looks like the old sequence parallelism is broken).
I'd appreciate it if you could also run bash code_style.sh to run our linter (I recommend saving the branch before running this, just in case...).
logical_axis_rules: [
  ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
  ['activation_batch_no_exp', ['data', 'fsdp', 'fsdp_transpose']],
  ['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']],
  ['activation_heads', ['tensor','sequence']],
  ['activation_kv_heads', ['tensor','sequence']],
  ['activation_length', 'context'],
We should shard length by both context and sequence, i.e. keep only one activation_length rule: ['activation_length', ['context', 'sequence']].
We may deprecate sequence in favor of context at some point in the future, but for now this covers both.
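For illustration, here is a minimal sketch (not part of this PR) of how the merged rule resolves a logical axis to mesh axes, assuming the Flax logical-to-mesh translation that MaxText's sharding builds on:

```python
# Minimal sketch, not part of this PR: the suggested merged rule maps
# 'activation_length' onto both the 'context' and 'sequence' mesh axes.
from flax import linen as nn

rules = (
    # One rule covering both context and sequence parallelism:
    ("activation_length", ("context", "sequence")),
)

# Resolve the logical axes of an activation shaped [batch, length, embed].
spec = nn.logical_to_mesh_axes(
    ("activation_batch", "activation_length", "activation_embed"), rules
)
# 'activation_batch' and 'activation_embed' have no rule here, so they stay
# unsharded; 'activation_length' is sharded over ('context', 'sequence').
print(spec)
```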
Faysal is OOO. I'll make this change for him and re-request a review.
@@ -425,6 +432,8 @@ def cudnn_flash_attention(
    scale_factor=1.0 / math.sqrt(head_dim),
    transpose_batch_sequence=False,
    window_size=sliding_window_size,
    context_parallel_causal_load_balanced=True,
There is some additional detail to handle in the framework here for correctness. When causal load balancing is applied, the attention kernel expects the tokens to be reordered. This is not done in this PR, so the output results will not be correct.
The reordering is the same as described in the "Context parallelism for long sequences" paragraph of the Llama 3 paper: essentially, the sequence is split into 2*CP chunks and reordered so that each rank gathers chunks i and (2*CP − 1 − i).
We have a helper function in TE that performs this reordering, or it could easily be added separately here: https://github.com/NVIDIA/TransformerEngine/blob/c9ea6be92948e1ec553037f1a04900617b9f7f6b/transformer_engine/jax/cpp_extensions/attention.py#L1005
We will likely employ a different reordering approach if the user wants to use sequence packing, due to the limited masking flexibility that exists in cuDNN today. This PR does not intend to add support for sequence-packed CP attention yet. We have several approaches, but each has some nuance in how it works, and it may be worth discussing how to pipe through the appropriate configurations to ensure consistency with cuDNN.
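For reference, a minimal sketch of this reordering (illustrative only, not the TE helper linked above; it assumes the sequence length is divisible by 2 * CP):

```python
import jax.numpy as jnp


def reorder_for_load_balancing(x: jnp.ndarray, cp_size: int, seq_dim: int = 1) -> jnp.ndarray:
  """Reorders tokens along seq_dim so chunk i is grouped with chunk (2*cp_size - 1 - i).

  After this reordering, splitting the sequence contiguously across cp_size ranks
  gives every rank one "early" and one "late" chunk, which balances the causal
  attention work. Assumes the sequence length is divisible by 2 * cp_size.
  """
  chunks = jnp.split(x, 2 * cp_size, axis=seq_dim)
  paired = []
  for i in range(cp_size):
    paired.append(chunks[i])
    paired.append(chunks[2 * cp_size - 1 - i])
  return jnp.concatenate(paired, axis=seq_dim)
```

For example, with cp_size=2 the chunk order [0, 1, 2, 3] becomes [0, 3, 1, 2], so rank 0 holds chunks (0, 3) and rank 1 holds chunks (1, 2).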
Description
This PR adds Context Parallelism support to GPU Flash Attention, which is necessary to support large sequence lengths in MaxText. Right now, the support is offered through Transformer-Engine and uses an All-Gather type implementation. Note that it requires the mask type to be causal and does not work with sliding window attention yet. It also requires transformer-engine==1.13 or above.
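For example, a hypothetical launch command (the flag names below, in particular ici_context_parallelism, are illustrative assumptions; check this PR's config changes for the actual names):

```bash
# Hypothetical example: run the base model with 4-way context parallelism on GPUs.
# 'ici_context_parallelism' is an assumed name following MaxText's ici_* convention;
# verify it against the config keys added in this PR.
python3 MaxText/train.py MaxText/configs/base.yml \
    run_name=cp_smoke_test \
    hardware=gpu \
    attention=cudnn_flash_te \
    ici_context_parallelism=4 \
    max_target_length=32768
```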
Tests
A unit test is included with the PR, running the base model on 4 x A100 GPUs.
Checklist