Add batch splitting in attention layer to hide NIC latency (#14) #1640
Conversation
kalyank007 commented on Dec 19, 2024
(huggingface#14)
- Introduced the `--attn_batch_split` parameter to enable batch splitting in the attention and MLP layers.
- This approach aims to overlap communication and computation, effectively hiding NIC latency during distributed attention operations.
- Perform the add at the beginning of the next layer for better pipelining.
- Updated the README.
- [SW-212702] Fix the attn_batch_split argument specific to the Llama config (huggingface#74)

Co-authored-by: Kalyan <[email protected]>
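For context, here is a minimal sketch of the batch-splitting idea, not the PR's actual implementation: the decoder layer splits its input batch into `attn_batch_split` chunks so that the collective communication issued for one chunk can overlap with the attention/MLP compute of the next, and the residual add is deferred to the start of the following layer. Apart from `attn_batch_split` and `prev_layer_residual`, all names below are illustrative assumptions.

```python
# Illustrative sketch only (not the optimum-habana code).
# Splitting the batch lets the communication triggered by one sub-batch overlap
# with the attention/MLP compute of the next sub-batch, hiding NIC latency.
import torch


def decoder_layer_forward(hidden_states, attn, mlp, attn_batch_split=1, prev_layer_residual=None):
    # The residual from the previous layer is added at the *start* of this layer,
    # so the add pipelines with the previous layer's trailing communication.
    if prev_layer_residual is not None:
        hidden_states = hidden_states + prev_layer_residual

    outputs, residuals = [], []
    for sub in torch.chunk(hidden_states, attn_batch_split, dim=0):
        residuals.append(sub)        # residual handed off to the next layer
        sub = attn(sub)              # attention; its all-reduce can overlap ...
        outputs.append(mlp(sub))     # ... with the next sub-batch's compute

    # Return both; the caller feeds the residual to the next layer as prev_layer_residual.
    return torch.cat(outputs, dim=0), torch.cat(residuals, dim=0)
```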
Does this PR improve the current example run? If so, can you please post some numbers to show how much it improves?
And can you please run CI to confirm this doesn't break anything?
python -m pytest tests/test_text_generation_example.py -s -v -k llama
python -m pytest tests/test_examples.py -s -v -k llama
@@ -1274,7 +1362,13 @@ def forward(
valid_sequence_lengths=valid_sequence_lengths,
cache_idx=cache_idx,
num_virtual_tokens=num_virtual_tokens,
attn_batch_split=attn_batch_split,
prev_layer_residual=prev_layer_residual,
From the previous review, it seems this needs to be changed to
prev_layer_residual=prev_layer_residual if use_prev_layer_residual else None,
And where is the prev_layer_residual value set after line 1298?
prev_layer_residual is set on line 1370 (prev_layer_residual = layer_outputs[index]); for the first layer it will be None.
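To make the discussion concrete, here is a hypothetical sketch of the layer loop showing where prev_layer_residual is produced and consumed; the `use_prev_layer_residual` gating and the output index are assumptions, not the PR's actual code.

```python
# Hypothetical loop structure (illustrative only, not the actual modeling code).
def run_decoder_layers(hidden_states, layers, attn_batch_split=1):
    prev_layer_residual = None                      # None for the first layer
    use_prev_layer_residual = attn_batch_split > 1  # assumed gating condition
    for decoder_layer in layers:
        layer_outputs = decoder_layer(
            hidden_states,
            attn_batch_split=attn_batch_split,
            prev_layer_residual=prev_layer_residual if use_prev_layer_residual else None,
        )
        hidden_states = layer_outputs[0]
        if use_prev_layer_residual:
            index = -1                              # assumed position of the residual in the outputs
            prev_layer_residual = layer_outputs[index]
    return hidden_states
```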
Can you please post the CI results as well?
@yeonsily Can we mark this conversation as resolved?
Can you please re-run test_examples.py with RUN_SLOW=true GAUDI2_CI=1? The Llama tests were all skipped and didn't run.
LGTM.
@kalyank007 Can you please make sure the Llama training examples still pass?
Some tests are failing because of access issues to the dataset. Log:
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
LGTM!