🔧 Optimize GRPO VRAM Usage by Computing Prompt Tokens Just Once #2669

Open
wants to merge 23 commits into
base: main

andyl98
Contributor

@andyl98 andyl98 commented Jan 27, 2025

What does this PR do?

"Fixes" issue here

Intuition

When running open-r1, I realized that the VRAM requirement scales poorly with prompt length. After investigating the existing implementation, it became clear that the issue mainly comes from this line:

logits = model(input_ids).logits  # (B, L, V)

https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py#L332

Here we feed all B * G * (P + C) tokens into the model to calculate logprobs, where

  • B denotes the batch size
  • G denotes the number of generations per prompt (a.k.a. group size)
  • P denotes the maximum prompt length
  • C denotes the maximum completion length

This is a bit counterintuitive, because

  1. We only need the logprobs of the completion tokens for reward/loss calculation.
  2. The prompt is fixed for all G generations, so a single forward pass over it should already give us what we need.
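
To make point 1 concrete, here is a minimal sketch (not the TRL implementation) of extracting per-token logprobs for only the completion part of a sequence. The function name `completion_logps`, the tensor shapes, and the assumption that every sequence ends with exactly `completion_len` completion tokens are all illustrative.

```python
import torch


def completion_logps(logits, input_ids, completion_len):
    """Log-probabilities of the last `completion_len` tokens of each sequence.

    logits:    (B, L, V) output of a causal LM
    input_ids: (B, L) tokens that were fed to the model
    Requires completion_len >= 1.
    """
    # The logit at position t predicts token t + 1, so shift by one.
    logits = logits[:, :-1, :]      # (B, L-1, V)
    targets = input_ids[:, 1:]      # (B, L-1)

    # Slice to the completion BEFORE the softmax, so the (B, L, V)
    # log-softmax over prompt positions is never materialized.
    logits = logits[:, -completion_len:, :]
    targets = targets[:, -completion_len:]

    log_probs = torch.log_softmax(logits, dim=-1)
    return torch.gather(log_probs, 2, targets.unsqueeze(-1)).squeeze(-1)


# Toy usage with random data (hypothetical sizes).
B, L, V, C = 2, 6, 10, 3
demo = completion_logps(torch.randn(B, L, V), torch.randint(0, V, (B, L)), C)
print(demo.shape)  # torch.Size([2, 3])
```

Slicing before the softmax only trims the log-softmax activation; the forward pass itself still processes every prompt token, which is what the rest of this PR targets.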

Thus, a better approach should do the following:

  1. Only "forward" the prompt once per group
  2. Reuse the cached hidden states G times

This way, we only need to store information for B * (P + G * C) tokens per batch.

This PR

With this change, I can fit a 7B Qwen model with B = 8, G = 16, P = 5000 and C = 256 without issues, whereas previously this configuration would run out of memory (OOM).
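
For a rough sense of scale, the two token-count formulas from the Intuition section can be evaluated for this configuration. This is plain arithmetic on token counts, not a memory measurement; the function names are made up for illustration.

```python
def tokens_full_forward(B, G, P, C):
    """Tokens processed when every prompt + completion pair is forwarded."""
    return B * G * (P + C)


def tokens_shared_prompt(B, G, P, C):
    """Tokens processed when each prompt is forwarded once per group."""
    return B * (P + G * C)


B, G, P, C = 8, 16, 5000, 256
full = tokens_full_forward(B, G, P, C)
shared = tokens_shared_prompt(B, G, P, C)
print(full)                      # 672768
print(shared)                    # 72768
print(round(full / shared, 2))   # 9.25
```

The ratio approaches G as P grows relative to C, which is why long prompts benefit the most. Actual VRAM savings depend on what is kept per token (activations, KV cache, logits), so the ratio is only an upper-level guide.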

Compared with the previous approach, I do see some difference in logprobs (an absolute difference on the order of 1e-2), but that may be due to the non-deterministic nature of PyTorch batching?

Compatibility

This method should be compatible with future improvements such as using vLLM or SGLang for faster batch generation.

Testing (TBD?)

Tested with reward functions (not reward LLMs), and training runs. Open to contributions.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@andyl98
Contributor Author

andyl98 commented Jan 27, 2025

Gently adding @qgallouedec to review

@andyl98 andyl98 changed the title from "🔧 GRPO VRAM Optimization" to "🔧 Optimize GRPO VRAM Usage by Compute Prompt Tokens Once" Jan 27, 2025
@andyl98 andyl98 changed the title from "🔧 Optimize GRPO VRAM Usage by Compute Prompt Tokens Once" to "🔧 Optimize GRPO VRAM Usage by Computing Prompt Tokens Just Once" Jan 27, 2025
@andyl98 andyl98 marked this pull request as draft January 28, 2025 00:44
@andyl98 andyl98 marked this pull request as ready for review January 28, 2025 01:48
@andyl98
Contributor Author

andyl98 commented Jan 28, 2025

For some reason I cannot upload images, but I let the training job run overnight without any issues.

@qgallouedec
Member

Can you profile so that we can compare?

@andyl98
Contributor Author

andyl98 commented Jan 28, 2025

Can you profile so that we can compare?

For sure. What are some profiling tools you guys recommend using?
Would torch.cuda.memory_allocated() and torch.cuda.memory_reserved() be enough?
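
For reference, a minimal sketch of measuring peak usage around a single call with the `torch.cuda` counters. The helper name `peak_cuda_memory_mib` is hypothetical; peak counters are used instead of `memory_allocated()` because a snapshot taken after the forward pass can miss short-lived activations.

```python
import torch


def peak_cuda_memory_mib(fn, *args, **kwargs):
    """Run fn and return (result, peak allocated MiB, peak reserved MiB).

    Returns zeros for both peaks when CUDA is not available.
    """
    if not torch.cuda.is_available():
        return fn(*args, **kwargs), 0.0, 0.0

    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()   # start counting from here
    result = fn(*args, **kwargs)
    torch.cuda.synchronize()               # wait for queued kernels

    mib = 1024 ** 2
    allocated = torch.cuda.max_memory_allocated() / mib
    reserved = torch.cuda.max_memory_reserved() / mib
    return result, allocated, reserved


# Toy usage: a matmul stands in for the logprob computation.
def toy_step():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.randn(256, 256, device=device)
    return (x @ x).sum().item()


_, peak_alloc, peak_reserved = peak_cuda_memory_mib(toy_step)
print(f"peak allocated: {peak_alloc:.1f} MiB, peak reserved: {peak_reserved:.1f} MiB")
```

Allocated memory is what tensors actually occupy; reserved memory also includes blocks held by the caching allocator, so the latter is closer to what `nvidia-smi` reports.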

@qgallouedec
Member

Whatever profiler you're familiar with. Can be wandb, torch profiler...

@andyl98
Contributor Author

andyl98 commented Jan 28, 2025

Whatever profiler you're familiar with. Can be wandb, torch profiler...

Sure let me make 2 side-by-side comparisons and I'll share the results.

@andyl98
Contributor Author

andyl98 commented Jan 28, 2025

OK, done. The results match my expectations. I still have no clue why I cannot upload images directly, but I put all the comparisons in the following PDF, which should give people a pretty good overview.

Cheers!

GRPO VRAM Diff Investigation.pdf

@Superskyyy Superskyyy mentioned this pull request Jan 28, 2025
5 tasks
@qgallouedec
Member

Seems very promising!
What about the training speed?

@ArthurZucker ArthurZucker left a comment

Super nice!

trl/trainer/utils.py (outdated)
del prompt_last_logps

# Interleave the past key values for the G times
prompt_out.past_key_values.batch_repeat_interleave(G)

I don't have the full context, but if this shares memory per repeat (as an expand would) then perfect!

Contributor Author

I believe so!
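
Whether memory is shared depends on the underlying tensor operation. In plain PyTorch, `expand` returns a view while `repeat_interleave` allocates a new tensor; which of the two `batch_repeat_interleave` uses depends on the cache implementation and is not verified here. The sketch below shows the difference on a stand-in tensor (the shape is illustrative, not taken from the PR).

```python
import torch

G = 4
# Hypothetical cached keys: (batch, heads, prompt_len, head_dim)
kv = torch.randn(2, 3, 5, 8)

# repeat_interleave materializes G copies of every batch row.
repeated = kv.repeat_interleave(G, dim=0)             # (8, 3, 5, 8)

# expand adds a broadcast (stride-0) group axis without copying.
expanded = kv.unsqueeze(1).expand(-1, G, -1, -1, -1)  # (2, 4, 3, 5, 8)

print(repeated.data_ptr() == kv.data_ptr())   # False: new storage
print(expanded.data_ptr() == kv.data_ptr())   # True: view on kv

# Flattening the view back to (B * G, ...) cannot be done as a view,
# because the stride-0 axis cannot merge with the batch axis, so
# reshape falls back to copying.
flat = expanded.reshape(-1, 3, 5, 8)
print(flat.data_ptr() == kv.data_ptr())       # False: copied
print(torch.equal(flat, repeated))            # True: same values
```

So sharing survives only as long as downstream code can consume the extra group axis; once a (B * G, ...) layout is required, the copy happens anyway.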

trl/trainer/utils.py (outdated)
@ArthurZucker

If you have a static cache, you can also specifically compile this part of the code as the shape will never change

@fkxie

fkxie commented Jan 29, 2025

Hi @andyl98, many thanks for your contribution!
When I test your code, it aborts:

utils.py, line 1698, in compute_logps_with_prompt_cache
    prompt_out.past_key_values.batch_repeat_interleave(G)
AttributeError: 'tuple' object has no attribute 'batch_repeat_interleave'

It seems past_key_values is a tuple. Am I missing something?
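
One possible explanation, stated as an assumption: some model or transformers versions return the cache in the legacy format, a tuple with one (key, value) pair per layer, which has no `batch_repeat_interleave` method. A version-independent workaround is to repeat the tensors directly. The helper `repeat_legacy_cache` below is hypothetical and not part of this PR.

```python
import torch


def repeat_legacy_cache(past_key_values, G):
    """Repeat each batch row G times in a legacy tuple-style KV cache.

    past_key_values: tuple over layers of (key, value) tensors, each with
    the batch dimension first, e.g. (B, heads, seq_len, head_dim).
    """
    return tuple(
        tuple(t.repeat_interleave(G, dim=0) for t in layer)
        for layer in past_key_values
    )


# Toy usage: 2 layers, batch of 2 prompts, repeated for G = 3 generations.
legacy = tuple(
    (torch.zeros(2, 1, 5, 4), torch.ones(2, 1, 5, 4)) for _ in range(2)
)
repeated = repeat_legacy_cache(legacy, 3)
print(repeated[0][0].shape)  # torch.Size([6, 1, 5, 4])
```

Depending on the installed transformers version, wrapping the tuple with `DynamicCache.from_legacy_cache(...)` before calling `batch_repeat_interleave` may also work.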

@Rocketknight1
Member

Rocketknight1 commented Jan 29, 2025

Hmm, I'm not familiar with the TRL code, but is there a danger here that gradient will not be propagated back through the prompt sequence if the cached hidden states for the prompt are computed without gradient? Remember that even if we are not computing a label or loss on the prompt tokens, we still want gradient to flow back to these tokens via their key and value vectors

@qgallouedec
Member

That's a good point @Rocketknight1. I'm currently looking into this

@Rocketknight1
Member

I think it should still be possible to share the prompt computation and re-use the hidden states! It's just important that those states are computed with gradient, and the graph isn't deleted until you compute the loss on the completion tokens

@andyl98
Contributor Author

andyl98 commented Jan 29, 2025

Really appreciate all the feedback! Here are the changes I've made:

  • Removed the torch.no_grad() context manager when forwarding the prompt, in case it blocks gradient flow (need to verify whether this is needed, but it doesn't add significant memory overhead). Update: never mind, removing it breaks the training script, since use_cache=True only stores the minimal tensors required at inference time.
  • Used num_logits_to_keep, as @ArthurZucker suggested. This saves quite a bit of memory.

After a bit more research, it seems to me that in RL approaches such as PPO and GRPO, we treat the prompt as a given state/environment and don't need to calculate the policy gradient on it. This means we should be safe to use torch.no_grad() with use_cache to save lots of compute, and to keep only the last logprobs from the prompt tokens.

Another way to think about this is to compare with SFT, where we do need gradients to flow back to the prompt:

SFT

  • Input: The full sequence [prompt + completion] is processed in one forward pass.
  • Loss: Cross-entropy loss is computed only on the completion tokens.
  • Gradients: Backpropagation updates all parameters that contributed to the completion predictions, including those processing the prompt.

For example, if the prompt is "Translate to French: 'Hello'" and the completion is "Bonjour", the model must learn how the prompt tokens ("Translate to French: 'Hello'") influence the prediction of "Bonjour". This requires gradients through the prompt's hidden states.

Contrast with GRPO/PPO

  • The prompt is treated as a fixed "state" (like a game board in RL).
  • Only gradients from policy decisions (completions) are relevant for reward optimization.
  • Detaching the prompt’s hidden states saves memory without harming training, since the policy focuses on adjusting completion-generation behavior, not prompt interpretation.
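
The question raised earlier in the thread is whether detaching the prompt states changes the gradient that the shared weights receive. A toy single-head attention (no causal mask, keys reused as values, random data, all names hypothetical) can be used to check this in isolation. It says nothing about whether the difference matters for GRPO training, which is an empirical question.

```python
import torch

torch.manual_seed(0)
d = 4
W_k = torch.randn(d, d, requires_grad=True)  # key projection shared by all tokens
W_q = torch.randn(d, d, requires_grad=True)  # query projection
prompt = torch.randn(3, d)                   # 3 prompt token embeddings
completion = torch.randn(2, d)               # 2 completion token embeddings


def toy_loss(detach_prompt):
    k_prompt = prompt @ W_k
    if detach_prompt:
        # Mimics building the prompt cache under torch.no_grad().
        k_prompt = k_prompt.detach()
    k_completion = completion @ W_k
    keys = torch.cat([k_prompt, k_completion])

    # Only completion tokens issue queries and contribute to the loss.
    q = completion @ W_q
    attn = torch.softmax(q @ keys.T / d ** 0.5, dim=-1)
    return (attn @ keys).sum()


g_full = torch.autograd.grad(toy_loss(False), W_k)[0]
g_detached = torch.autograd.grad(toy_loss(True), W_k)[0]

# Compare the gradient of the shared projection in both settings.
print(torch.allclose(g_full, g_detached))
print((g_full - g_detached).abs().max().item())
```

In this toy setup the loss is computed on completion positions only, yet `W_k` also receives gradient through the prompt keys unless they are detached, so the two gradients generally differ.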

My conclusion (?)

I consulted o1/R1 (not the best option, I know, but their responses match mine above).

As an extra, I can do another round of comparison (one run on this branch and one on master) and see what the loss/reward curves look like for, say, 2 epochs?

Performance Testing

I compared against the master branch (from yesterday) and @qgallouedec's PR #2683 (today), running the PyTorch profiler to get a sense of the differences. Here are some results.

Experimental Setup

  • Hardware: 4090 GPU
  • Model: Qwen 1.5B
  • B = 2
  • G = 4
  • P = 660
  • C = 32

Old approach (yesterday)

-------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                             Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  
-------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                         aten::mm        18.18%      69.542ms        18.18%      69.542ms     716.928us     136.159ms        32.18%     136.159ms       1.404ms           0 b           0 b       4.05 Gb       4.05 Gb            97  
                                        aten::mul         4.98%      19.049ms         4.98%      19.049ms      86.586us      36.698ms         8.67%      36.698ms     166.809us           0 b           0 b       2.85 Gb       2.85 Gb           220  
                               aten::_log_softmax        15.92%      60.895ms        15.92%      60.895ms       7.612ms      71.630ms        16.93%      71.630ms       8.954ms           0 b           0 b       1.43 Gb       1.43 Gb             8  
                              aten::empty_strided         6.50%      24.862ms         6.50%      24.862ms     246.158us       1.026ms         0.24%       1.026ms      10.158us           0 b           0 b       1.26 Gb       1.26 Gb           101  
                                       aten::silu         3.83%      14.654ms         3.83%      14.654ms     610.583us      29.972ms         7.08%      29.972ms       1.249ms           0 b           0 b       1.10 Gb       1.10 Gb            24  
                                        aten::pow         5.31%      20.322ms         5.31%      20.322ms     414.735us      12.339ms         2.92%      12.789ms     261.000us           0 b           0 b     880.56 Mb     880.56 Mb            49  
                                      aten::empty         0.06%     223.000us         0.06%     223.000us       1.312us       1.862ms         0.44%       1.862ms      10.953us         288 b         288 b     784.39 Mb     784.39 Mb           170  
                                        aten::add         0.70%       2.674ms         0.70%       2.674ms      18.441us      21.883ms         5.17%      21.883ms     150.917us           0 b           0 b     659.22 Mb     659.22 Mb           145  
                                      aten::addmm         1.15%       4.405ms         1.46%       5.586ms      77.583us       5.216ms         1.23%       7.184ms      99.778us           0 b           0 b     266.88 Mb     266.88 Mb            72  
                                        aten::cat         0.24%     918.000us         0.24%     918.000us      18.360us      53.215ms        12.58%      53.215ms       1.064ms           0 b           0 b     243.68 Mb     243.68 Mb            50  
                                        aten::neg         0.22%     853.000us         0.22%     853.000us      17.771us     941.000us         0.22%     941.000us      19.604us           0 b           0 b     118.50 Mb     118.50 Mb            48  
                                   aten::_to_copy         1.74%       6.671ms         8.69%      33.243ms     329.139us       1.984ms         0.47%       5.408ms      53.545us           0 b           0 b       1.29 Gb      26.64 Mb           101  
               aten::_efficient_attention_forward         0.69%       2.625ms         0.71%       2.716ms     113.167us       5.236ms         1.24%       6.683ms     278.458us         384 b          96 b     213.94 Mb      26.47 Mb            24  
                                 aten::empty_like         0.17%     666.000us         0.20%     783.000us      10.875us     850.000us         0.20%       1.256ms      17.444us           0 b           0 b     622.84 Mb      25.92 Mb            72  
                                    aten::resize_         0.00%      16.000us         0.00%      16.000us       8.000us      18.000us         0.00%      18.000us       9.000us           0 b           0 b       8.65 Mb       8.65 Mb             2  
                                       aten::mean         0.31%       1.171ms         0.31%       1.171ms      23.898us     732.000us         0.17%     732.000us      14.939us           0 b           0 b     980.00 Kb     980.00 Kb            49  
                                      aten::rsqrt         0.38%       1.452ms         0.38%       1.452ms      29.633us     508.000us         0.12%     508.000us      10.367us           0 b           0 b     980.00 Kb     980.00 Kb            49  
                                        aten::cos         0.00%      11.000us         0.00%      11.000us      11.000us      10.000us         0.00%      10.000us      10.000us           0 b           0 b     158.00 Kb     158.00 Kb             1  
                                        aten::sin         0.00%       9.000us         0.00%       9.000us       9.000us      13.000us         0.00%      13.000us      13.000us           0 b           0 b     158.00 Kb     158.00 Kb             1  
                                        aten::bmm         0.02%      70.000us         0.02%      70.000us      70.000us      69.000us         0.02%      69.000us      69.000us           0 b           0 b      79.00 Kb      79.00 Kb             1  
                                     aten::gather         9.33%      35.675ms         9.33%      35.679ms       4.460ms     994.000us         0.23%       1.021ms     127.625us           0 b           0 b      12.00 Kb      12.00 Kb             8  
                                  aten::embedding         0.02%      59.000us         0.05%     184.000us     184.000us      41.000us         0.01%     449.000us     449.000us           0 b           0 b       8.64 Mb           0 b             1  
                                    aten::reshape         1.52%       5.832ms         2.19%       8.364ms      34.279us       2.623ms         0.62%       8.985ms      36.824us           0 b           0 b     414.75 Mb           0 b           244  
                                       aten::view         0.16%     608.000us         0.16%     608.000us       1.778us       3.602ms         0.85%       3.602ms      10.532us           0 b           0 b           0 b           0 b           342  
                               aten::index_select         0.02%      69.000us         0.02%      87.000us      87.000us     343.000us         0.08%     365.000us     365.000us           0 b           0 b       8.64 Mb           0 b             1  
                                     aten::arange         0.01%      49.000us         0.02%      90.000us      45.000us      52.000us         0.01%     103.000us      51.500us           0 b           0 b      10.00 Kb           0 b             2  
                                  aten::unsqueeze         0.48%       1.847ms         0.49%       1.858ms      17.204us       1.519ms         0.36%       2.095ms      19.398us           0 b           0 b           0 b           0 b           108  
                                 aten::as_strided         0.02%      93.000us         0.02%      93.000us       0.100us       7.878ms         1.86%       7.878ms       8.453us           0 b           0 b           0 b           0 b           932  
                                      aten::slice         1.28%       4.905ms         1.29%       4.918ms      16.448us       3.443ms         0.81%       6.368ms      21.298us           0 b           0 b           0 b           0 b           299  
                                         aten::to         0.63%       2.396ms         9.32%      35.639ms     232.935us       1.105ms         0.26%       6.513ms      42.569us           0 b           0 b       1.29 Gb           0 b           153  
                                     aten::expand         0.44%       1.680ms         0.44%       1.680ms      13.659us       2.265ms         0.54%       2.961ms      24.073us           0 b           0 b           0 b           0 b           123  
                                      aten::copy_         0.57%       2.192ms         0.57%       2.192ms      12.671us       4.328ms         1.02%       4.328ms      25.017us           0 b           0 b           0 b           0 b           173  
                                     aten::matmul         6.24%      23.865ms        24.97%      95.504ms     974.531us       1.751ms         0.41%     141.257ms       1.441ms           0 b           0 b       4.05 Gb           0 b            98  
                               aten::_unsafe_view         0.07%     261.000us         0.07%     261.000us       1.788us     459.000us         0.11%     459.000us       3.144us           0 b           0 b           0 b           0 b           146  
                                  aten::transpose         1.51%       5.783ms         1.53%       5.836ms      16.122us       3.799ms         0.90%       7.312ms      20.199us           0 b           0 b           0 b           0 b           362  
                                aten::result_type         0.00%       0.000us         0.00%       0.000us       0.000us     155.000us         0.04%     155.000us       3.163us           0 b           0 b           0 b           0 b            49  
                                     aten::linear        15.10%      57.753ms        43.81%     167.564ms     991.503us       3.572ms         0.84%     158.242ms     936.343us           0 b           0 b       4.31 Gb           0 b           169  
                                          aten::t         0.76%       2.914ms         1.53%       5.837ms      34.538us       1.692ms         0.40%       4.697ms      27.793us           0 b           0 b           0 b           0 b           169  
                                      aten::clone         0.42%       1.617ms         0.75%       2.882ms      40.028us       1.600ms         0.38%       4.786ms      66.472us           0 b           0 b     622.84 Mb           0 b            72  
                                 aten::contiguous         0.07%     270.000us         0.29%       1.101ms      45.875us     279.000us         0.07%       2.080ms      86.667us           0 b           0 b     208.09 Mb           0 b            24  
               aten::scaled_dot_product_attention         0.14%     517.000us         1.74%       6.661ms     277.542us     176.000us         0.04%      10.428ms     434.500us         384 b           0 b     213.94 Mb           0 b            24  
    aten::_scaled_dot_product_efficient_attention         0.56%       2.147ms         1.61%       6.144ms     256.000us     638.000us         0.15%      10.252ms     427.167us         384 b           0 b     213.94 Mb           0 b            24  
                                     aten::unbind         0.07%     268.000us         0.15%     564.000us     282.000us      48.000us         0.01%     392.000us     196.000us           0 b           0 b           0 b           0 b             2  
                                     aten::select         0.08%     287.000us         0.08%     296.000us      18.500us     216.000us         0.05%     344.000us      21.500us           0 b           0 b           0 b           0 b            16  
                                aten::log_softmax         0.03%     122.000us        15.95%      61.017ms       7.627ms      56.000us         0.01%      71.686ms       8.961ms           0 b           0 b       1.43 Gb           0 b             8  
                                    aten::squeeze         0.04%     142.000us         0.04%     145.000us      18.125us     102.000us         0.02%     115.000us      14.375us           0 b           0 b           0 b           0 b             8  
                                      aten::stack         0.01%      27.000us         0.02%      85.000us      85.000us       4.000us         0.00%     167.000us     167.000us           0 b           0 b      10.00 Kb           0 b             1  
                                         [memory]         0.00%       0.000us         0.00%       0.000us       0.000us       0.000us         0.00%       0.000us       0.000us           0 b           0 b      -5.30 Gb      -5.30 Gb           724  

My previous approach (before this commit)

-------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                             Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  
-------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                         aten::mm         4.15%       4.252ms         4.15%       4.252ms      21.918us      13.953ms        11.15%      13.953ms      71.923us           0 b           0 b       1.22 Gb       1.22 Gb           194  
                                        aten::mul         5.58%       5.713ms         5.58%       5.713ms      12.955us       9.110ms         7.28%       9.110ms      20.658us           0 b           0 b     937.33 Mb     937.33 Mb           441  
                                      aten::empty         0.55%     568.000us         0.55%     568.000us       1.349us       2.172ms         1.74%       2.172ms       5.159us         600 b         600 b     825.25 Mb     825.25 Mb           421  
                              aten::empty_strided         1.65%       1.689ms         1.65%       1.689ms       8.320us       5.890ms         4.71%       5.890ms      29.015us           0 b           0 b     378.82 Mb     378.82 Mb           203  
                                       aten::silu         0.39%     395.000us         0.39%     395.000us       8.229us     492.000us         0.39%     492.000us      10.250us           0 b           0 b     335.11 Mb     335.11 Mb            48  
                                        aten::pow         3.85%       3.945ms         3.85%       3.945ms      40.255us       6.909ms         5.52%       7.567ms      77.214us           0 b           0 b     265.29 Mb     265.29 Mb            98  
                                        aten::add         2.45%       2.509ms         2.45%       2.509ms       8.622us       2.351ms         1.88%       2.351ms       8.079us           0 b           0 b     237.44 Mb     237.44 Mb           291  
                                        aten::cat         1.74%       1.780ms         1.74%       1.780ms      12.109us       2.645ms         2.11%       2.645ms      17.993us           0 b           0 b     148.23 Mb     148.23 Mb           147  
                                      aten::addmm         4.68%       4.795ms         5.84%       5.984ms      41.556us       5.329ms         4.26%       7.494ms      52.042us           0 b           0 b      92.81 Mb      92.81 Mb           144  
                                        aten::neg         0.94%     963.000us         0.94%     963.000us      10.031us       1.508ms         1.21%       1.508ms      15.708us           0 b           0 b      35.77 Mb      35.77 Mb            96  
                                   aten::_to_copy         4.61%       4.724ms         7.75%       7.938ms      39.297us       2.577ms         2.06%      10.484ms      51.901us           0 b           0 b     398.67 Mb      21.22 Mb           202  
                                 aten::empty_like         2.03%       2.081ms         2.41%       2.473ms      12.490us       1.600ms         1.28%       2.687ms      13.571us           0 b           0 b     764.89 Mb      20.88 Mb           198  
                                    aten::resize_         0.04%      42.000us         0.04%      42.000us       7.000us      65.000us         0.05%      65.000us      10.833us           0 b           0 b       3.17 Mb       3.17 Mb             6  
                                         aten::eq         0.08%      78.000us         0.08%      78.000us      26.000us      64.000us         0.05%      64.000us      21.333us           0 b           0 b       1.38 Mb       1.38 Mb             3  
                            aten::constant_pad_nd         0.83%     853.000us         1.52%       1.562ms      65.083us     574.000us         0.46%       1.700ms      70.833us           0 b           0 b       7.50 Mb       1.25 Mb            24  
                                       aten::triu         0.03%      26.000us         0.03%      26.000us      13.000us      29.000us         0.02%      29.000us      14.500us           0 b           0 b     743.50 Kb     743.50 Kb             2  
               aten::_efficient_attention_forward         2.38%       2.437ms         2.51%       2.567ms      53.479us       2.999ms         2.40%       3.939ms      82.062us         768 b         168 b      76.08 Mb     448.00 Kb            48  
                                         aten::gt         0.03%      30.000us         0.03%      30.000us      15.000us      30.000us         0.02%      30.000us      15.000us           0 b           0 b     372.00 Kb     372.00 Kb             2  
                                       aten::mean         0.94%     965.000us         0.94%     965.000us       9.847us     711.000us         0.57%     711.000us       7.255us           0 b           0 b     294.00 Kb     294.00 Kb            98  
                                      aten::rsqrt         0.79%     813.000us         0.79%     813.000us       8.296us       1.089ms         0.87%       1.089ms      11.112us           0 b           0 b     294.00 Kb     294.00 Kb            98  
                                        aten::cos         0.02%      21.000us         0.02%      21.000us      10.500us      27.000us         0.02%      27.000us      13.500us           0 b           0 b     158.00 Kb     158.00 Kb             2  
                                        aten::sin         0.01%      13.000us         0.01%      13.000us       6.500us      18.000us         0.01%      18.000us       9.000us           0 b           0 b     158.00 Kb     158.00 Kb             2  
                                        aten::bmm         0.07%      71.000us         0.07%      71.000us      35.500us      45.000us         0.04%      45.000us      22.500us           0 b           0 b      79.00 Kb      79.00 Kb             2  
                                        aten::all         0.05%      48.000us         0.05%      49.000us      24.500us      49.000us         0.04%      57.000us      28.500us           0 b           0 b       2.00 Kb       2.00 Kb             2  
                                aten::bitwise_not         0.01%      11.000us         0.01%      11.000us      11.000us      11.000us         0.01%      11.000us      11.000us           0 b           0 b       1.50 Kb       1.50 Kb             1  
                                     aten::gather         0.05%      48.000us         0.05%      48.000us      24.000us      17.000us         0.01%      70.000us      35.000us           0 b           0 b       1.00 Kb       1.00 Kb             2  
                                    aten::reshape         5.52%       5.658ms        10.11%      10.353ms      21.129us       4.143ms         3.31%      10.605ms      21.643us           0 b           0 b     545.97 Mb           0 b           490  
                                       aten::view         0.58%     593.000us         0.58%     593.000us       0.810us       3.276ms         2.62%       3.276ms       4.475us           0 b           0 b           0 b           0 b           732  
                               aten::index_select         0.11%     108.000us         0.14%     144.000us      72.000us     109.000us         0.09%     160.000us      80.000us           0 b           0 b       3.16 Mb           0 b             2  
                                     aten::arange         0.15%     151.000us         0.25%     259.000us      32.375us     148.000us         0.12%     300.000us      37.500us           0 b           0 b      31.00 Kb           0 b             8  
                                  aten::unsqueeze         2.69%       2.754ms         2.75%       2.814ms      10.949us       2.391ms         1.91%       3.708ms      14.428us           0 b           0 b           0 b           0 b           257  
                                 aten::as_strided         0.13%     134.000us         0.13%     134.000us       0.062us       9.062ms         7.24%       9.062ms       4.191us           0 b           0 b           0 b           0 b          2162  
                                 aten::is_nonzero         0.03%      28.000us         0.07%      73.000us      73.000us      13.000us         0.01%      78.000us      78.000us           0 b           0 b           0 b           0 b             1  
                                       aten::item         0.01%      14.000us         0.04%      45.000us      45.000us      29.000us         0.02%      65.000us      65.000us           0 b           0 b           0 b           0 b             1  
                        aten::_local_scalar_dense         0.03%      31.000us         0.03%      31.000us      31.000us      36.000us         0.03%      36.000us      36.000us           0 b           0 b           0 b           0 b             1  
                                       aten::full         0.12%     118.000us         0.16%     166.000us      83.000us      28.000us         0.02%      82.000us      41.000us           0 b           0 b     743.50 Kb           0 b             2  
                                      aten::fill_         0.14%     148.000us         0.14%     148.000us       5.692us     289.000us         0.23%     289.000us      11.115us           0 b           0 b           0 b           0 b            26  
                                       aten::mul_         0.03%      27.000us         0.03%      27.000us      13.500us      30.000us         0.02%      30.000us      15.000us           0 b           0 b           0 b           0 b             2  
                                      aten::slice         8.55%       8.756ms         8.59%       8.797ms      10.612us       8.230ms         6.58%      11.979ms      14.450us           0 b           0 b           0 b           0 b           829  
                                     aten::expand         3.40%       3.480ms         3.41%       3.493ms      10.125us       3.231ms         2.58%       4.878ms      14.139us           0 b           0 b           0 b           0 b           345  
                                      aten::clone         4.50%       4.612ms         8.60%       8.811ms      44.500us       2.904ms         2.32%       8.417ms      42.510us           0 b           0 b     764.89 Mb           0 b           198  
                                      aten::copy_         3.41%       3.495ms         3.41%       3.495ms       8.224us       5.220ms         4.17%       5.220ms      12.282us           0 b           0 b           0 b           0 b           425  
                                aten::masked_fill         0.03%      32.000us         0.13%     133.000us     133.000us       5.000us         0.00%      55.000us      55.000us           0 b           0 b       1.37 Mb           0 b             1  
                               aten::masked_fill_         0.01%      15.000us         0.01%      15.000us      15.000us      18.000us         0.01%      18.000us      18.000us           0 b           0 b           0 b           0 b             1  
                                         aten::to         1.88%       1.923ms         9.63%       9.861ms      32.225us       1.586ms         1.27%      12.070ms      39.444us           0 b           0 b     398.67 Mb           0 b           306  
                                     aten::matmul         5.00%       5.119ms        11.23%      11.499ms      58.668us       3.189ms         2.55%      20.354ms     103.847us           0 b           0 b       1.22 Gb           0 b           196  
                               aten::_unsafe_view         0.22%     222.000us         0.22%     222.000us       0.758us     712.000us         0.57%     712.000us       2.430us           0 b           0 b           0 b           0 b           293  
                                  aten::transpose         5.79%       5.934ms         5.81%       5.953ms       8.222us       7.495ms         5.99%       9.773ms      13.499us           0 b           0 b           0 b           0 b           724  
                                aten::result_type         0.00%       0.000us         0.00%       0.000us       0.000us     326.000us         0.26%     326.000us       3.327us           0 b           0 b           0 b           0 b            98  
                                     aten::linear         8.85%       9.070ms        33.30%      34.116ms     100.935us       4.964ms         3.97%      42.947ms     127.062us           0 b           0 b       1.31 Gb           0 b           338  
                                          aten::t         3.29%       3.371ms         5.83%       5.975ms      17.678us       2.938ms         2.35%       8.435ms      24.956us           0 b           0 b           0 b           0 b           338  
                                 aten::contiguous         0.30%     304.000us         2.06%       2.113ms      42.260us     272.000us         0.22%       2.606ms      52.120us           0 b           0 b     148.20 Mb           0 b            50  
               aten::scaled_dot_product_attention         1.64%       1.681ms        10.54%      10.796ms     224.917us       1.034ms         0.83%      11.357ms     236.604us         432 b        -336 b      83.58 Mb           0 b            48  
    aten::_scaled_dot_product_efficient_attention         2.17%       2.220ms         6.24%       6.392ms     133.167us       1.074ms         0.86%       7.481ms     155.854us         768 b           0 b      76.08 Mb           0 b            48  
                                aten::log_softmax         0.01%      13.000us         0.21%     210.000us     105.000us       6.000us         0.00%     490.000us     245.000us           0 b           0 b      72.45 Mb           0 b             2  
                               aten::_log_softmax         0.03%      33.000us         0.19%     197.000us      98.500us     251.000us         0.20%     484.000us     242.000us           0 b           0 b      72.45 Mb           0 b             2  
                          aten::repeat_interleave         2.17%       2.228ms         6.57%       6.731ms     137.367us       1.041ms         0.83%       4.949ms     101.000us           0 b           0 b      67.97 Mb           0 b            49  
                                    aten::flatten         0.65%     669.000us         0.66%     672.000us      13.714us     402.000us         0.32%     531.000us      10.837us           0 b           0 b           0 b           0 b            49  
                                    aten::squeeze         0.01%      14.000us         0.01%      14.000us       7.000us       7.000us         0.01%      17.000us       8.500us           0 b           0 b           0 b           0 b             2  
                                        aten::pad         0.35%     356.000us         1.87%       1.918ms      79.917us     215.000us         0.17%       1.915ms      79.792us           0 b           0 b       7.50 Mb           0 b            24  
                                     aten::narrow         0.14%     146.000us         0.35%     361.000us      15.042us     163.000us         0.13%     436.000us      18.167us           0 b           0 b           0 b           0 b            24  
                                  aten::embedding         0.11%     112.000us         0.31%     320.000us     160.000us      42.000us         0.03%     431.000us     215.500us           0 b           0 b       3.16 Mb      -2.00 Kb             2  
                                         [memory]         0.00%       0.000us         0.00%       0.000us       0.000us       0.000us         0.00%       0.000us       0.000us         -48 b         -48 b      -3.61 Gb      -3.61 Gb          1946  

Contributor Author

andyl98 commented Jan 29, 2025

Profile of Quentin's new approach (#2683):

-------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                             Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  
-------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                        aten::mul        19.31%     141.976ms        19.31%     141.976ms     645.345us     137.645ms        18.11%     137.645ms     625.659us           0 b           0 b       2.85 Gb       2.85 Gb           220  
                                         aten::mm        16.74%     123.085ms        16.74%     123.085ms       1.269ms     162.452ms        21.37%     162.452ms       1.675ms           0 b           0 b       2.68 Gb       2.68 Gb            97  
                              aten::empty_strided        18.24%     134.109ms        18.24%     134.109ms       1.328ms     125.889ms        16.56%     125.889ms       1.246ms           0 b           0 b       1.25 Gb       1.25 Gb           101  
                                       aten::silu         8.87%      65.203ms         8.87%      65.203ms       2.717ms      79.626ms        10.48%      79.626ms       3.318ms           0 b           0 b       1.10 Gb       1.10 Gb            24  
                                        aten::pow        17.47%     128.482ms        17.47%     128.483ms       2.622ms     112.970ms        14.86%     113.246ms       2.311ms           0 b           0 b     877.69 Mb     877.69 Mb            49  
                                      aten::empty         8.39%      61.655ms         8.39%      61.655ms     360.556us      69.549ms         9.15%      69.549ms     406.719us         288 b         288 b     848.43 Mb     848.43 Mb           171  
                                        aten::add         0.38%       2.792ms         0.38%       2.792ms      19.255us       1.920ms         0.25%       1.920ms      13.241us           0 b           0 b     655.46 Mb     655.46 Mb           145  
                                      aten::addmm         0.42%       3.113ms         0.58%       4.281ms      59.458us       3.818ms         0.50%       5.348ms      74.278us           0 b           0 b     281.00 Mb     281.00 Mb            72  
                                        aten::cat         0.09%     691.000us         0.09%     691.000us      13.820us     988.000us         0.13%     988.000us      19.760us           0 b           0 b     239.90 Mb     239.90 Mb            50  
                                        aten::neg         1.50%      11.055ms         1.50%      11.055ms     230.312us      11.046ms         1.45%      11.046ms     230.125us           0 b           0 b     126.09 Mb     126.09 Mb            48  
                               aten::_log_softmax         2.25%      16.546ms         2.25%      16.546ms       2.068ms      16.281ms         2.14%      16.281ms       2.035ms           0 b           0 b      74.19 Mb      74.19 Mb             8  
                                   aten::_to_copy         0.58%       4.258ms        18.99%     139.610ms       1.382ms     794.000us         0.10%     128.044ms       1.268ms           0 b           0 b       1.27 Gb      25.92 Mb           101  
                                 aten::empty_like         0.10%     761.000us         0.14%       1.016ms      13.918us     593.000us         0.08%       1.203ms      16.479us           0 b           0 b     652.04 Mb      17.28 Mb            73  
                                    aten::resize_         0.00%      19.000us         0.00%      19.000us       9.500us     194.000us         0.03%     194.000us      97.000us           0 b           0 b       8.65 Mb       8.65 Mb             2  
                                       aten::mean         0.18%       1.338ms         0.18%       1.338ms      27.306us     752.000us         0.10%     752.000us      15.347us           0 b           0 b     980.00 Kb     980.00 Kb            49  
                                      aten::rsqrt         0.13%     992.000us         0.13%     992.000us      20.245us     462.000us         0.06%     462.000us       9.429us           0 b           0 b     980.00 Kb     980.00 Kb            49  
               aten::_efficient_attention_forward         0.48%       3.558ms         8.83%      64.950ms       2.706ms       2.381ms         0.31%      71.316ms       2.971ms         384 b          96 b     213.94 Mb     280.00 Kb            24  
                                        aten::cos         0.00%      10.000us         0.00%      10.000us      10.000us      11.000us         0.00%      11.000us      11.000us           0 b           0 b     158.00 Kb     158.00 Kb             1  
                                        aten::sin         0.00%       8.000us         0.00%       8.000us       8.000us      10.000us         0.00%      10.000us      10.000us           0 b           0 b     158.00 Kb     158.00 Kb             1  
                                        aten::bmm         0.00%      36.000us         0.00%      36.000us      36.000us      36.000us         0.00%      36.000us      36.000us           0 b           0 b      79.00 Kb      79.00 Kb             1  
                                     aten::gather         0.07%     485.000us         0.07%     499.000us      62.375us     253.000us         0.03%     297.000us      37.125us           0 b           0 b       4.00 Kb       4.00 Kb             8  
                                  aten::embedding         0.00%      36.000us         0.02%     117.000us     117.000us       8.000us         0.00%     233.000us     233.000us           0 b           0 b       8.64 Mb           0 b             1  
                                    aten::reshape         0.40%       2.947ms         0.81%       5.939ms      24.340us       1.915ms         0.25%       8.532ms      34.967us           0 b           0 b     444.67 Mb           0 b           244  
                                       aten::view         0.10%     714.000us         0.10%     714.000us       2.094us       3.772ms         0.50%       3.772ms      11.062us           0 b           0 b           0 b           0 b           341  
                               aten::index_select         0.01%      45.000us         0.01%      68.000us      68.000us      25.000us         0.00%     220.000us     220.000us           0 b           0 b       8.64 Mb           0 b             1  
                                     aten::arange         0.00%      24.000us         0.01%      43.000us      21.500us      37.000us         0.00%      73.000us      36.500us           0 b           0 b      10.00 Kb           0 b             2  
                                  aten::unsqueeze         0.17%       1.257ms         0.17%       1.278ms      11.833us       1.440ms         0.19%       2.012ms      18.630us           0 b           0 b           0 b           0 b           108  
                                 aten::as_strided         0.03%     248.000us         0.03%     248.000us       0.266us       4.108ms         0.54%       4.108ms       4.408us           0 b           0 b           0 b           0 b           932  
                                      aten::slice         0.52%       3.855ms         0.53%       3.910ms      13.077us       3.530ms         0.46%       5.105ms      17.074us           0 b           0 b           0 b           0 b           299  
                                         aten::to         0.17%       1.216ms        19.15%     140.826ms     920.431us     947.000us         0.12%     128.991ms     843.078us           0 b           0 b       1.27 Gb           0 b           153  
                                     aten::expand         0.24%       1.796ms         0.25%       1.820ms      14.797us       1.985ms         0.26%       2.516ms      20.455us           0 b           0 b           0 b           0 b           123  
                                      aten::copy_         0.25%       1.852ms         0.25%       1.852ms      10.644us       2.710ms         0.36%       2.710ms      15.575us           0 b           0 b           0 b           0 b           174  
                                     aten::matmul         0.47%       3.488ms        17.44%     128.210ms       1.308ms       1.436ms         0.19%     167.914ms       1.713ms           0 b           0 b       2.68 Gb           0 b            98  
                               aten::_unsafe_view         0.06%     461.000us         0.06%     461.000us       3.136us     569.000us         0.07%     569.000us       3.871us           0 b           0 b           0 b           0 b           147  
                                  aten::transpose         0.61%       4.478ms         0.63%       4.598ms      12.702us       2.775ms         0.37%       4.030ms      11.133us           0 b           0 b           0 b           0 b           362  
                                aten::result_type         0.00%       0.000us         0.00%       0.000us       0.000us      95.000us         0.01%      95.000us       1.939us           0 b           0 b           0 b           0 b            49  
                                     aten::linear         0.74%       5.422ms        19.43%     142.880ms     845.444us       2.910ms         0.38%     179.849ms       1.064ms           0 b           0 b       2.95 Gb           0 b           169  
                                          aten::t         0.29%       2.136ms         0.54%       3.976ms      23.527us       1.260ms         0.17%       2.869ms      16.976us           0 b           0 b           0 b           0 b           169  
                                      aten::clone         0.25%       1.807ms         0.47%       3.432ms      47.014us       1.586ms         0.21%       4.138ms      56.685us           0 b           0 b     652.04 Mb           0 b            73  
                                 aten::contiguous         0.04%     258.000us         0.18%       1.357ms      56.542us     223.000us         0.03%       1.501ms      62.542us           0 b           0 b     207.38 Mb           0 b            24  
               aten::scaled_dot_product_attention         0.05%     360.000us         9.30%      68.377ms       2.849ms     142.000us         0.02%      73.391ms       3.058ms         384 b           0 b     213.94 Mb           0 b            24  
    aten::_scaled_dot_product_efficient_attention         0.24%       1.786ms         9.25%      68.017ms       2.834ms     582.000us         0.08%      73.249ms       3.052ms         384 b           0 b     213.94 Mb           0 b            24  
                                     aten::unbind         0.04%     324.000us         0.08%     599.000us     299.500us      96.000us         0.01%     369.000us     184.500us           0 b           0 b           0 b           0 b             2  
                                     aten::select         0.04%     269.000us         0.04%     275.000us      17.188us     155.000us         0.02%     273.000us      17.062us           0 b           0 b           0 b           0 b            16  
                                aten::log_softmax         0.02%     120.000us         2.27%      16.666ms       2.083ms      46.000us         0.01%      16.327ms       2.041ms           0 b           0 b      74.19 Mb           0 b             8  
                                    aten::squeeze         0.02%     182.000us         0.03%     190.000us      23.750us      31.000us         0.00%      44.000us       5.500us           0 b           0 b           0 b           0 b             8  
                                      aten::stack         0.01%      38.000us         0.01%      63.000us      63.000us      21.000us         0.00%      44.000us      44.000us           0 b           0 b         512 b           0 b             1  
                                         [memory]         0.00%       0.000us         0.00%       0.000us       0.000us       0.000us         0.00%       0.000us       0.000us           0 b           0 b      -3.96 Gb      -3.96 Gb           725  
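
For reading tables like these side by side, a small script can pull the `Self CUDA Mem` column out of each profile and compare it per operator. This is only a minimal sketch, not part of the PR: the helper names `parse_size` and `self_cuda_mem` are hypothetical, the sample rows are copied from the tables in this thread with whitespace condensed, and it assumes the profiler's size units are binary (1 Kb = 1024 b). The ratios only mean something if both runs profiled the same inputs, which the tables alone do not show.

```python
import re

# Assumption: the profiler prints binary units (1 Kb = 1024 b).
_UNITS = {"b": 1, "Kb": 1024, "Mb": 1024**2, "Gb": 1024**3}

# Column layout of the tables in this thread (15 columns, 0-indexed).
_N_COLUMNS = 15
_SELF_CUDA_MEM = 13


def parse_size(text):
    """Convert a size such as '2.85 Gb' or '-336 b' into bytes."""
    value, unit = text.split()
    return float(value) * _UNITS[unit]


def self_cuda_mem(table_text):
    """Map operator name -> Self CUDA Mem (bytes) for every data row."""
    result = {}
    for line in table_text.splitlines():
        # Columns are separated by two or more spaces; values such as
        # '2.85 Gb' contain only a single space.
        cols = re.split(r"\s{2,}", line.strip())
        if len(cols) != _N_COLUMNS:
            continue
        if cols[0] == "Name" or set(cols[0]) == {"-"}:
            continue  # header row or dashed separator
        result[cols[0]] = parse_size(cols[_SELF_CUDA_MEM])
    return result


# Rows copied from the "#2683" table (whitespace condensed).
OTHER_ROWS = """
aten::mul  19.31%  141.976ms  19.31%  141.976ms  645.345us  137.645ms  18.11%  137.645ms  625.659us  0 b  0 b  2.85 Gb  2.85 Gb  220
aten::mm  16.74%  123.085ms  16.74%  123.085ms  1.269ms  162.452ms  21.37%  162.452ms  1.675ms  0 b  0 b  2.68 Gb  2.68 Gb  97
"""

# Rows copied from the "after this commit" table (whitespace condensed).
THIS_PR_ROWS = """
aten::mm  16.43%  131.428ms  16.43%  131.428ms  677.464us  177.986ms  20.36%  177.986ms  917.454us  0 b  0 b  883.89 Mb  883.89 Mb  194
aten::mul  15.69%  125.517ms  15.69%  125.517ms  284.619us  146.195ms  16.73%  146.195ms  331.508us  0 b  0 b  876.07 Mb  876.07 Mb  441
"""

if __name__ == "__main__":
    other = self_cuda_mem(OTHER_ROWS)
    this_pr = self_cuda_mem(THIS_PR_ROWS)
    for op in sorted(other.keys() & this_pr.keys()):
        ratio = other[op] / this_pr[op]
        print(f"{op}: {other[op] / 1024**2:.2f} Mb vs "
              f"{this_pr[op] / 1024**2:.2f} Mb (ratio {ratio:.2f})")
```

Pasting the full tables into the two strings works the same way, since header and separator lines are skipped.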

Profile of my approach after this commit:

-------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                             Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  
-------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                         aten::mm        16.43%     131.428ms        16.43%     131.428ms     677.464us     177.986ms        20.36%     177.986ms     917.454us           0 b           0 b     883.89 Mb     883.89 Mb           194  
                                        aten::mul        15.69%     125.517ms        15.69%     125.517ms     284.619us     146.195ms        16.73%     146.195ms     331.508us           0 b           0 b     876.07 Mb     876.07 Mb           441  
                                      aten::empty         9.24%      73.895ms         9.24%      73.895ms     176.360us     117.505ms        13.44%     117.505ms     280.442us         728 b         728 b     743.00 Mb     743.00 Mb           419  
                              aten::empty_strided        12.45%      99.564ms        12.45%      99.564ms     490.463us      99.619ms        11.40%      99.619ms     490.734us           0 b           0 b     380.25 Mb     380.25 Mb           203  
                                       aten::silu         7.83%      62.628ms         7.83%      62.628ms       1.305ms      74.492ms         8.52%      74.492ms       1.552ms           0 b           0 b     344.19 Mb     344.19 Mb            48  
                                        aten::pow        12.34%      98.706ms        12.34%      98.706ms       1.007ms      80.462ms         9.21%      81.197ms     828.541us           0 b           0 b     246.73 Mb     246.73 Mb            98  
                                        aten::add         0.82%       6.562ms         0.82%       6.562ms      22.550us       4.767ms         0.55%       4.767ms      16.381us           0 b           0 b     206.86 Mb     206.86 Mb           291  
                                        aten::cat         0.43%       3.469ms         0.43%       3.469ms      23.599us       3.973ms         0.45%       3.973ms      27.027us           0 b           0 b     150.40 Mb     150.40 Mb           147  
                                      aten::addmm         1.21%       9.707ms         1.54%      12.341ms      85.701us       8.279ms         0.95%      11.663ms      80.993us           0 b           0 b      85.61 Mb      85.61 Mb           144  
                               aten::_log_softmax         2.74%      21.911ms         2.74%      21.911ms      10.956ms      17.248ms         1.97%      17.248ms       8.624ms           0 b           0 b      74.77 Mb      74.77 Mb             2  
                                        aten::neg         0.26%       2.057ms         0.26%       2.057ms      21.427us       2.853ms         0.33%       2.853ms      29.719us           0 b           0 b      35.77 Mb      35.77 Mb            96  
                                    aten::resize_         0.01%      40.000us         0.01%      40.000us       6.667us      75.000us         0.01%      75.000us      12.500us           0 b           0 b       2.54 Mb       2.54 Mb             6  
                                         aten::eq         0.01%      80.000us         0.01%      80.000us      26.667us      85.000us         0.01%      85.000us      28.333us           0 b           0 b       1.38 Mb       1.38 Mb             3  
                                       aten::triu         0.00%      35.000us         0.00%      35.000us      17.500us      45.000us         0.01%      45.000us      22.500us           0 b           0 b     743.50 Kb     743.50 Kb             2  
                                         aten::gt         0.01%      45.000us         0.01%      45.000us      22.500us      54.000us         0.01%      54.000us      27.000us           0 b           0 b     372.00 Kb     372.00 Kb             2  
                                       aten::mean         0.30%       2.424ms         0.30%       2.424ms      24.735us       1.304ms         0.15%       1.304ms      13.306us           0 b           0 b     294.00 Kb     294.00 Kb            98  
                                      aten::rsqrt         0.26%       2.043ms         0.26%       2.043ms      20.847us       1.626ms         0.19%       1.626ms      16.592us           0 b           0 b     294.00 Kb     294.00 Kb            98  
                                        aten::cos         0.00%      35.000us         0.00%      35.000us      17.500us      42.000us         0.00%      42.000us      21.000us           0 b           0 b     158.00 Kb     158.00 Kb             2  
                                        aten::sin         0.00%      31.000us         0.00%      31.000us      15.500us      39.000us         0.00%      39.000us      19.500us           0 b           0 b     158.00 Kb     158.00 Kb             2  
                                        aten::bmm         0.01%      88.000us         0.01%      88.000us      44.000us     101.000us         0.01%     101.000us      50.500us           0 b           0 b      79.00 Kb      79.00 Kb             2  
                                        aten::all         0.01%      56.000us         0.01%      57.000us      28.500us      60.000us         0.01%      67.000us      33.500us           0 b           0 b       2.00 Kb       2.00 Kb             2  
                                aten::bitwise_not         0.00%      21.000us         0.00%      21.000us      21.000us      26.000us         0.00%      26.000us      26.000us           0 b           0 b       1.50 Kb       1.50 Kb             1  
                                     aten::gather         0.01%     119.000us         0.02%     123.000us      61.500us      62.000us         0.01%      78.000us      39.000us           0 b           0 b       1.00 Kb       1.00 Kb             2  
                                    aten::reshape         1.33%      10.662ms        11.48%      91.829ms     187.406us       6.181ms         0.71%     131.618ms     268.608us           0 b           0 b     534.24 Mb           0 b           490  
                                       aten::view         0.40%       3.172ms         0.40%       3.172ms       4.339us       8.341ms         0.95%       8.341ms      11.410us           0 b           0 b           0 b           0 b           731  
                               aten::index_select         0.01%     114.000us         0.02%     145.000us      72.500us     102.000us         0.01%     161.000us      80.500us           0 b           0 b       2.52 Mb           0 b             2  
                                     aten::arange         0.03%     203.000us         0.04%     339.000us      42.375us     181.000us         0.02%     372.000us      46.500us           0 b           0 b      31.00 Kb           0 b             8  
                                  aten::unsqueeze         0.56%       4.461ms         0.59%       4.708ms      18.319us       4.424ms         0.51%       6.021ms      23.428us           0 b           0 b           0 b           0 b           257  
                                 aten::as_strided         0.27%       2.121ms         0.27%       2.121ms       0.984us      22.779ms         2.61%      22.779ms      10.565us           0 b           0 b           0 b           0 b          2156  
                                 aten::is_nonzero         0.00%      25.000us         0.01%      70.000us      70.000us      14.000us         0.00%      75.000us      75.000us           0 b           0 b           0 b           0 b             1  
                                       aten::item         0.00%      14.000us         0.01%      45.000us      45.000us      26.000us         0.00%      61.000us      61.000us           0 b           0 b           0 b           0 b             1  
                        aten::_local_scalar_dense         0.00%      31.000us         0.00%      31.000us      31.000us      35.000us         0.00%      35.000us      35.000us           0 b           0 b           0 b           0 b             1  
                                       aten::full         0.01%      59.000us         0.01%      98.000us      49.000us      48.000us         0.01%     109.000us      54.500us           0 b           0 b     743.50 Kb           0 b             2  
                                      aten::fill_         0.04%     343.000us         0.04%     343.000us      13.192us     284.000us         0.03%     284.000us      10.923us           0 b           0 b           0 b           0 b            26  
                                       aten::mul_         0.00%      34.000us         0.00%      34.000us      17.000us      45.000us         0.01%      45.000us      22.500us           0 b           0 b           0 b           0 b             2  
                                      aten::slice         3.24%      25.908ms         3.34%      26.732ms      32.481us      14.685ms         1.68%      29.851ms      36.271us           0 b           0 b           0 b           0 b           823  
                                     aten::expand         2.15%      17.162ms         2.19%      17.512ms      50.759us      16.360ms         1.87%      18.636ms      54.017us           0 b           0 b           0 b           0 b           345  
                                      aten::clone         0.83%       6.630ms        10.75%      85.953ms     438.536us       4.821ms         0.55%     126.932ms     647.612us           0 b           0 b     670.20 Mb           0 b           196  
                                 aten::empty_like         0.40%       3.194ms         9.55%      76.405ms     389.821us       2.775ms         0.32%     118.547ms     604.832us           0 b           0 b     670.20 Mb           0 b           196  
                                      aten::copy_         0.77%       6.124ms         0.77%       6.124ms      14.478us       6.428ms         0.74%       6.428ms      15.196us           0 b           0 b           0 b           0 b           423  
                                aten::masked_fill         0.01%      42.000us         0.02%     135.000us     135.000us      37.000us         0.00%     139.000us     139.000us           0 b           0 b       1.45 Mb           0 b             1  
                               aten::masked_fill_         0.00%      14.000us         0.00%      14.000us      14.000us      18.000us         0.00%      18.000us      18.000us           0 b           0 b           0 b           0 b             1  
                                         aten::to         0.45%       3.566ms        14.23%     113.823ms     371.971us       2.572ms         0.29%     109.228ms     356.954us           0 b           0 b     378.80 Mb           0 b           306  
                                   aten::_to_copy         0.98%       7.856ms        13.78%     110.257ms     545.827us       4.555ms         0.52%     106.656ms     528.000us           0 b           0 b     378.80 Mb           0 b           202  
                                     aten::matmul         1.34%      10.702ms        18.42%     147.355ms     751.811us       4.109ms         0.47%     188.447ms     961.464us           0 b           0 b     883.97 Mb           0 b           196  
                               aten::_unsafe_view         0.15%       1.231ms         0.15%       1.231ms       4.201us       1.300ms         0.15%       1.300ms       4.437us           0 b           0 b           0 b           0 b           293  
                                  aten::transpose         1.59%      12.747ms         1.68%      13.440ms      18.564us       9.698ms         1.11%      13.389ms      18.493us           0 b           0 b           0 b           0 b           724  
                                aten::result_type         0.00%       0.000us         0.00%       0.000us       0.000us     401.000us         0.05%     401.000us       4.092us           0 b           0 b           0 b           0 b            98  
                                     aten::linear         2.16%      17.273ms        24.20%     193.612ms     572.817us       9.822ms         1.12%     224.641ms     664.618us           0 b           0 b     969.51 Mb           0 b           338  
                                          aten::t         0.88%       7.049ms         1.60%      12.819ms      37.926us       4.884ms         0.56%      10.953ms      32.405us           0 b           0 b           0 b           0 b           338  
                                 aten::contiguous         0.10%     783.000us         0.52%       4.134ms      86.125us     595.000us         0.07%       3.765ms      78.438us           0 b           0 b      64.89 Mb           0 b            48  
               aten::scaled_dot_product_attention         0.38%       3.049ms         2.61%      20.884ms     435.083us       1.664ms         0.19%      16.392ms     341.500us         768 b           0 b      73.53 Mb           0 b            48  
    aten::_scaled_dot_product_efficient_attention         0.57%       4.534ms         1.55%      12.367ms     257.646us       1.935ms         0.22%      10.446ms     217.625us         768 b           0 b      66.03 Mb           0 b            48  
               aten::_efficient_attention_forward         0.47%       3.763ms         0.54%       4.310ms      89.792us       3.969ms         0.45%       5.436ms     113.250us         768 b          40 b      66.03 Mb           0 b            48  
                             aten::_reshape_alias         0.00%       5.000us         0.00%       5.000us       5.000us      11.000us         0.00%      11.000us      11.000us           0 b           0 b           0 b           0 b             1  
                                aten::log_softmax         0.00%      34.000us         2.74%      21.945ms      10.973ms       8.000us         0.00%      17.256ms       8.628ms           0 b           0 b      74.77 Mb           0 b             2  
                          aten::repeat_interleave         0.38%       3.007ms         1.16%       9.317ms     190.143us       1.992ms         0.23%       9.505ms     193.980us           0 b           0 b      68.16 Mb           0 b            49  
                                    aten::flatten         0.09%     701.000us         0.11%     895.000us      18.265us     566.000us         0.06%     959.000us      19.571us           0 b           0 b           0 b           0 b            49  
                                    aten::squeeze         0.00%      38.000us         0.01%      40.000us      20.000us      21.000us         0.00%      47.000us      23.500us           0 b           0 b           0 b           0 b             2  
                                        aten::pad         0.07%     526.000us         0.48%       3.861ms     160.875us     271.000us         0.03%       2.646ms     110.250us           0 b           0 b       7.50 Mb           0 b            24  
                            aten::constant_pad_nd         0.22%       1.745ms         0.42%       3.335ms     138.958us     764.000us         0.09%       2.375ms      98.958us           0 b           0 b       7.50 Mb           0 b            24  
                                     aten::narrow         0.05%     414.000us         0.10%     828.000us      34.500us     412.000us         0.05%     829.000us      34.542us           0 b           0 b           0 b           0 b            24  
                                  aten::embedding         0.01%     119.000us         0.05%     378.000us     189.000us      63.000us         0.01%     340.000us     170.000us           0 b           0 b       2.52 Mb      -2.00 Kb             2  
                                         [memory]         0.00%       0.000us         0.00%       0.000us       0.000us       0.000us         0.00%       0.000us       0.000us           0 b           0 b      -1.37 Gb      -1.37 Gb          1547  

@andyl98
Contributor Author

andyl98 commented Jan 29, 2025

So I think this PR saves quite a bit of memory. The static cache/dynamic cache suggestion is a good one, but I currently don't have the bandwidth to look at it :) I'll work on it in the future if we think it's a viable path.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@qgallouedec
Member

Thank you for this detailed report. The rather large tables are not very readable; could you perhaps give a more concise overview? Maybe the final sec/it that you get with a training script?

@andyl98
Contributor Author

andyl98 commented Jan 29, 2025

Speed-wise (model = 1.5b, B = 1, P = 2k, C = 256, G = 4)

New version:

{'loss': 0.0, 'grad_norm': 1.8341439962387085, 'learning_rate': 9.997435897435897e-07, 'completion_length': 146.625, 'rewards/concise_reward': -575.4375, 'rewards/format_reward': 0.0, 'reward': -575.4375, 'reward_std': 236.35992431640625, 'kl': 0.0, 'epoch': 0.03}
{'loss': 0.0, 'grad_norm': 2.196131944656372, 'learning_rate': 9.994871794871794e-07, 'completion_length': 201.4375, 'rewards/concise_reward': -775.375, 'rewards/format_reward': 0.0, 'reward': -775.375, 'reward_std': 332.21124267578125, 'kl': 1.6450881958007812e-05, 'epoch': 0.05}
{'loss': 0.0, 'grad_norm': 1.8303061723709106, 'learning_rate': 9.992307692307693e-07, 'completion_length': 123.5625, 'rewards/concise_reward': -498.0625, 'rewards/format_reward': 0.0, 'reward': -498.0625, 'reward_std': 255.6620635986328, 'kl': 3.647804260253906e-05, 'epoch': 0.08}
{'loss': -0.0, 'grad_norm': 2.763676166534424, 'learning_rate': 9.98974358974359e-07, 'completion_length': 109.9375, 'rewards/concise_reward': -444.1875, 'rewards/format_reward': 0.0, 'reward': -444.1875, 'reward_std': 314.7947082519531, 'kl': -2.9325485229492188e-05, 'epoch': 0.1}
{'loss': -0.0, 'grad_norm': 2.452415704727173, 'learning_rate': 9.987179487179487e-07, 'completion_length': 146.4375, 'rewards/concise_reward': -603.4375, 'rewards/format_reward': 0.0, 'reward': -603.4375, 'reward_std': 300.88116455078125, 'kl': -5.555152893066406e-05, 'epoch': 0.13}
  0%|▏                                                                                                                            | 5/3900 [01:18<16:06:09, 14.88s/it]

Original version:

{'loss': 0.0, 'grad_norm': 2.3698227405548096, 'learning_rate': 9.997435897435897e-07, 'completion_length': 146.625, 'rewards/concise_reward': -575.4375, 'rewards/format_reward': 0.0, 'reward': -575.4375, 'reward_std': 236.35992431640625, 'kl': 0.0, 'epoch': 0.03}
{'loss': -0.0, 'grad_norm': 3.092586040496826, 'learning_rate': 9.994871794871794e-07, 'completion_length': 199.6875, 'rewards/concise_reward': -748.1875, 'rewards/format_reward': 0.0, 'reward': -748.1875, 'reward_std': 301.6230163574219, 'kl': -3.0279159545898438e-05, 'epoch': 0.05}
{'loss': -0.0, 'grad_norm': 2.243756055831909, 'learning_rate': 9.992307692307693e-07, 'completion_length': 116.25, 'rewards/concise_reward': -460.375, 'rewards/format_reward': 0.0, 'reward': -460.375, 'reward_std': 244.493408203125, 'kl': -2.3692846298217773e-06, 'epoch': 0.08}
{'loss': 0.0, 'grad_norm': 2.633629560470581, 'learning_rate': 9.98974358974359e-07, 'completion_length': 121.5, 'rewards/concise_reward': -498.0, 'rewards/format_reward': 0.0, 'reward': -498.0, 'reward_std': 289.4837646484375, 'kl': 5.9604644775390625e-05, 'epoch': 0.1}
  0%|▏                                                                                                                            | 4/3900 [01:01<15:48:59, 14.61s/it]

Memory usage with the new version is about 33%.

I think the new version is a little slower (14.88 s/it vs. 14.61 s/it in the runs above) because it runs the forward pass twice instead of once, but I'm fairly sure the bottleneck usually comes from the "generation" part, so this small difference can be dismissed.
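
As a rough sanity check, the token-count formulas from the PR description can be evaluated for the benchmark settings above. This is only a back-of-the-envelope sketch: it assumes "P = 2k" means 2048 tokens, and it counts tokens fed through the model rather than measuring VRAM. The function names are made up for illustration and are not part of TRL.

```python
# Illustrative helpers (hypothetical names, not TRL API).
# B = batch size, G = generations per prompt,
# P = max prompt length, C = max completion length.

def tokens_original(B: int, G: int, P: int, C: int) -> int:
    """Each of the G completions is forwarded together with its own prompt copy."""
    return B * G * (P + C)


def tokens_shared_prompt(B: int, G: int, P: int, C: int) -> int:
    """The prompt is forwarded once per group, then the G completions."""
    return B * (P + G * C)


# Benchmark settings from the comment above; P = 2048 is an assumption for "2k".
B, G, P, C = 1, 4, 2048, 256

original = tokens_original(B, G, P, C)
shared = tokens_shared_prompt(B, G, P, C)
ratio = shared / original

print(f"original: {original} tokens")
print(f"shared prompt: {shared} tokens")
print(f"ratio: {ratio:.1%}")
```

Under these assumptions the shared-prompt variant handles about a third of the tokens. Actual memory also depends on weights, optimizer state and cache layout, so this ratio is not a substitute for a measurement.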

@andyl98 andyl98 requested a review from qgallouedec January 29, 2025 22:27