Add Context Parallelism support to cudnn Flash Attention #1133

Open
wants to merge 2 commits into base: main

Conversation

@kocchop kocchop (Contributor) commented Jan 1, 2025

Description

This PR adds Context Parallelism (CP) support to GPU Flash Attention, which is necessary to support large sequence lengths in MaxText. Right now, the support is offered through Transformer-Engine and uses an All-Gather type implementation. Note that it requires the mask type to be causal and does not yet work with sliding window attention. It also requires transformer-engine 1.13 or above.
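As a rough illustration of the constraints above (the helper and parameter names here are hypothetical, not taken from this PR or from MaxText), the CP path amounts to a guard along these lines:

import transformer_engine as te  # assumes te.__version__ is exposed
from packaging import version

def can_use_cp_flash_attention(mask_type, sliding_window_size, context_parallel_size):
  # Hypothetical sketch: the all-gather CP path needs TE >= 1.13,
  # a causal mask, and no sliding window attention.
  if context_parallel_size <= 1:
    return False
  if version.parse(te.__version__) < version.parse("1.13"):
    return False  # requires transformer-engine 1.13 or above
  if mask_type != "causal":
    return False  # CP currently requires a causal mask
  if sliding_window_size is not None:
    return False  # sliding window attention is not supported with CP yet
  return True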

Tests

A unit test is included with the PR, running the base model on 4 x A100 GPUs.

Checklist

  • I have performed a self-review of my code.
  • I have added necessary comments to my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed.

@gobbleturk gobbleturk (Collaborator) left a comment

Awesome, this looks great! Just one change to the sharding rules (with the current state of this PR it looks like the old sequence parallelism is broken).

I'd appreciate it if you could also run bash code_style.sh to run our linter (I recommend saving the branch before running it, just in case...).

logical_axis_rules: [
['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
['activation_batch_no_exp', ['data', 'fsdp', 'fsdp_transpose']],
['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']],
['activation_heads', ['tensor','sequence']],
['activation_kv_heads', ['tensor','sequence']],
['activation_length', 'context'],
@gobbleturk gobbleturk (Collaborator) commented Jan 1, 2025

We should shard length by both context and sequence, i.e. keep only one activation_length entry with ['activation_length', ['context', 'sequence']] (see the sketch below).

We may deprecate sequence in favor of context at some point in the future, but for now this should cover both.
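For reference, a minimal sketch of how the quoted fragment would look with the combined rule (replacing the 'context'-only entry):

['activation_heads', ['tensor','sequence']],
['activation_kv_heads', ['tensor','sequence']],
['activation_length', ['context', 'sequence']],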

Contributor

Faysal is OOO. I'll make this change for him and re-request a review.

@@ -425,6 +432,8 @@ def cudnn_flash_attention(
scale_factor=1.0 / math.sqrt(head_dim),
transpose_batch_sequence=False,
window_size=sliding_window_size,
context_parallel_causal_load_balanced=True,

There is some additional detail to handle in the framework here for correctness. When we apply causal load balancing, the attention expects the tokens to be reordered. This is not done in the PR, so the output results will not be correct.

Details of this reordering are the same as described in the "Context parallelism for long sequences" paragraph of the Llama 3 paper. Essentially, the sequence is split into 2*CP chunks and reordered so that CP rank i gathers chunks i and (2 × CP − 1 − i); see the sketch below.

We have a helper function in TE that performs this re-ordering, or it could easily be added separately here: https://github.com/NVIDIA/TransformerEngine/blob/c9ea6be92948e1ec553037f1a04900617b9f7f6b/transformer_engine/jax/cpp_extensions/attention.py#L1005
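For illustration only, a minimal JAX sketch of that chunk reordering (not TE's helper; it assumes the sequence length is divisible by 2 * cp_size and that the sequence lives on axis seq_dim):

import jax.numpy as jnp

def reorder_causal_load_balanced(x, cp_size, seq_dim=1):
  # Split the sequence into 2 * cp_size chunks and pair chunk i with
  # chunk (2 * cp_size - 1 - i), so every CP rank gets a balanced share
  # of the causal-attention work.
  chunks = jnp.split(x, 2 * cp_size, axis=seq_dim)
  paired = []
  for i in range(cp_size):
    paired.append(chunks[i])
    paired.append(chunks[2 * cp_size - 1 - i])
  return jnp.concatenate(paired, axis=seq_dim)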

We are likely going to employ a different reordering approach if the user wants to use sequence packing, due to the limited masking flexibility that exists in cuDNN today. This PR does not intend to add support for sequence-packed CP attention yet. We have several approaches, but each has some nuance in how it works, and it may be worth discussing how we can pipe through the appropriate configurations to ensure consistency with cuDNN.
