🔧 Optimize GRPO VRAM Usage by Computing Prompt Tokens Just Once #2669
Conversation
Gently adding @qgallouedec to review.
For some reason I cannot upload images, but I let the training job run for one night without any issues.
Can you profile so that we can compare?
For sure. What are some profiling tools you guys recommend using?
Whatever profiler you're familiar with. Can be wandb, torch profiler...
Sure, let me make 2 side-by-side comparisons and I'll share the results.
Ok, done. The results match my expectations. I still have no clue why I cannot directly upload images, but I put all the comparisons in the following PDF, which should give people a pretty good overview. Cheers!
Seems very promising!
Super nice!
del prompt_last_logps

# Interleave the past key values for the G times
prompt_out.past_key_values.batch_repeat_interleave(G)
I don't have the full context, but if this shares memory per repeat (as an expand would) then perfect!
I believe so!
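For reference, here is a quick generic PyTorch check of the two behaviours being discussed (it says nothing about what the cache method does internally): expand returns a view that shares storage, while repeat_interleave materialises a copy.

```python
import torch

# Made-up shapes: (batch, heads, seq_len, head_dim)
k = torch.randn(1, 8, 512, 64)

expanded = k.expand(4, -1, -1, -1)        # view: no extra memory for the 4 repeats
repeated = k.repeat_interleave(4, dim=0)  # copy: allocates 4x the memory

print(expanded.data_ptr() == k.data_ptr())  # True  -> shares storage
print(repeated.data_ptr() == k.data_ptr())  # False -> new allocation
```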
If you have a static cache, you can also specifically compile this part of the code, as the shape will never change.
Hi @andyl98, many thanks for your contributions!
Hmm, I'm not familiar with the TRL code, but is there a danger here that gradients will not be propagated back through the prompt sequence if the cached hidden states for the prompt are computed without gradient? Remember that even if we are not computing a label or loss on the prompt tokens, we still want gradients to flow back to these tokens via their key and value vectors.
That's a good point @Rocketknight1. I'm currently looking into this.
I think it should still be possible to share the prompt computation and re-use the hidden states! It's just important that those states are computed with gradient, and the graph isn't deleted until you compute the loss on the completion tokens.
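A minimal way to see the difference (generic transformers usage; the model name is just an example, not part of this PR):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "Qwen/Qwen2.5-0.5B"  # any small causal LM works for this check
model = AutoModelForCausalLM.from_pretrained(model_name)
tok = AutoTokenizer.from_pretrained(model_name)
prompt_ids = tok("Translate to French: 'Hello'", return_tensors="pt").input_ids

# Prompt states computed without gradient: the cached keys/values carry no graph,
# so nothing can backpropagate into the prompt positions later on.
with torch.no_grad():
    out_detached = model(prompt_ids, use_cache=True)
print(out_detached.past_key_values[0][0].requires_grad)  # False

# Prompt states computed with gradient: the graph is kept, so the completion loss
# can still flow back through the prompt's key/value vectors.
out_attached = model(prompt_ids, use_cache=True)
print(out_attached.past_key_values[0][0].requires_grad)  # True
```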
Really appreciate all the feedback! Here are the changes I've made.
After a bit more research, it seems to me that in RL approaches such as PPO and GRPO, we treat the prompt as a fixed context that conditions the completions; the reward and the loss are attached only to the completion tokens.

Another angle to think of this: compare to SFT

For example, if the prompt is "Translate to French: 'Hello'" and the completion is "Bonjour", the model must learn how the prompt tokens ("Translate to French: 'Hello'") influence the prediction of "Bonjour". This requires gradients through the prompt's hidden states.

Contrast with GRPO/PPO

Here the prompt only serves as the condition under which completions are sampled and scored; the policy gradient is applied to the completion tokens' logprobs, so (it seems to me) detaching the prompt's hidden states should not change the learning signal.
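To make this concrete, here is a schematic of a GRPO-style per-token loss (a sketch with hypothetical names, not the actual TRL code): the prompt and padding positions are masked out, so only the completion tokens receive a policy-gradient signal.

```python
import torch

def grpo_style_loss(per_token_logps, advantages, completion_mask):
    # per_token_logps, completion_mask: (B * G, P + C); advantages: (B * G, 1).
    # completion_mask is 1 on completion tokens, 0 on prompt/padding tokens.
    per_token_loss = -advantages * per_token_logps     # REINFORCE-style term
    per_token_loss = per_token_loss * completion_mask  # no loss attached to prompt tokens
    return per_token_loss.sum() / completion_mask.sum()
```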
My conclusion (?)

Consulted with O1/R1 (not the best option, I know, but their responses are identical to mine above). Maybe, as an extra, I can do another round of comparison (one run using this branch and one using master) and just see how the loss/reward curves look for, let's say, 2 epochs?

Performance Testing

Compared with the master branch (from yesterday) and @qgallouedec's PR #2683 today, and I ran the PyTorch profiler a bit to get a sense. Here are some results.

Experimentation Setting
The comparisons cover:
- Old approach (yesterday)
- My previous approach (before this commit)
- Quentin's new approach #2683
- Mine after this commit
So I think this PR saves quite a bit of memory. The static cache / dynamic cache suggestion is a good one, but I currently don't have the bandwidth to look into it :) Will work on it in the future if we think it's a viable path.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Thank you for this detailed report. The rather large tables are not very readable; could you perhaps give a more concise overview? Maybe the final sec/it that you get with a training script?
Speed-wise (model = 1.5B, B = 1, P = 2k, C = 256, G = 4):

New version:

Original version:

Memory usage with the new version is about 33%. I think the new version will be a little slower (because of two forward passes instead of one), but I'm sure the bottleneck usually comes from the "generation" part, so this small difference can be dismissed.
What does this PR do?
"Fixes" issue here
Intuition
When running open-r1, I realized that the VRAM requirement scales poorly with the prompt length. After further investigation of the existing implementation, it became obvious that this issue mainly comes from this line: https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py#L332
There, we basically feed all B * G * (P + C) tokens into the model to calculate logprobs, where
- B denotes the batch size
- G denotes the number of generations per prompt (a.k.a. group size)
- P denotes the maximum prompt length
- C denotes the maximum completion length

This is actually a bit counterintuitive, because the prompt is shared across all G generations, so a single forward pass over it should already give us what we need.

Thus, a more ideal approach should tackle the following (see the sketch below):
- compute the prompt tokens just once per batch, and
- reuse (interleave) the resulting past key values G times for the completions.

This way, we only need to store information for B * (P + G * C) tokens per batch.
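As a rough sketch of the idea (simplified; not the exact diff in this PR, the function and argument names are illustrative, and it assumes the cache object exposes batch_repeat_interleave as in the snippet reviewed above):

```python
import torch
import torch.nn.functional as F

def completion_logps_with_shared_prompt(model, prompt_ids, prompt_mask, completion_ids, G):
    # prompt_ids, prompt_mask: (B, P); completion_ids: (B * G, C)

    # 1) Forward the B prompts only once (here under no_grad, as discussed in this thread).
    with torch.no_grad():
        prompt_out = model(prompt_ids, attention_mask=prompt_mask, use_cache=True)

    # 2) Repeat the prompt KV cache G times so it lines up with the B * G completions.
    past = prompt_out.past_key_values
    past.batch_repeat_interleave(G)  # assumed to repeat the cache along the batch dim

    # 3) Forward only the completion tokens, attending to the cached prompt.
    full_mask = torch.cat(
        [prompt_mask.repeat_interleave(G, dim=0), torch.ones_like(completion_ids)], dim=1
    )  # (B * G, P + C)
    out = model(completion_ids, attention_mask=full_mask, past_key_values=past)

    # 4) Per-token logprobs. Position j of the logits predicts completion token j + 1;
    #    the first completion token's logprob comes from the prompt pass's last logits,
    #    which is glossed over in this sketch.
    logps = F.log_softmax(out.logits[:, :-1], dim=-1)
    return torch.gather(logps, 2, completion_ids[:, 1:].unsqueeze(-1)).squeeze(-1)
```

This keeps only B * P prompt positions plus B * G * C completion positions in memory for the logprob pass, instead of B * G * (P + C).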
This PR

With this change, I can fit a 7B Qwen model with B = 8, G = 16, P = 5000 and C = 256 without issues, whereas previously this would absolutely cause OOM. When comparing with the previous approach, I do see some difference in logprobs (on the order of 1e-2 in absolute terms), but that should be due to the non-deterministic nature of PyTorch batching?
Compatibility
This method should be compatible with future improvements, such as using vLLM or SGLang for faster batch generation.

Testing (TBD?)
Tested with reward functions (not reward LLMs) and things are running. Open to contributions.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.