
⛰️ Reduce peak vram consumption with efficient selective log_softmax #2799

Merged

Conversation

@tyler-romero (Contributor) commented Feb 7, 2025

What does this PR do?

Many TRL trainers use the same log_softmax -> gather operation to compute the log-probabilities of a selected set of tokens. This approach is inefficient because it allocates a full bs * seqlen * vocab_size tensor just to hold the log-probabilities; for modest batch size, sequence length, and vocabulary size, this intermediate tensor can require more than 2 GB of VRAM. There are several more memory-efficient (and faster) approaches.

This PR adds a utility function with a more memory-efficient implementation of this operation and uses it broadly across TRL.
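
For context, here is a minimal sketch (illustrative only, not the exact TRL code; the helper name is made up) of the log_softmax -> gather pattern being replaced, and why its intermediate tensor dominates peak memory:

```python
import torch
import torch.nn.functional as F

def naive_per_token_logps(logits: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    """Illustrative sketch of the pattern this PR replaces."""
    # logits: (batch, seq_len, vocab_size); index: (batch, seq_len) token ids
    logps = F.log_softmax(logits, dim=-1)  # materializes batch * seq_len * vocab_size values
    return torch.gather(logps, dim=-1, index=index.unsqueeze(-1)).squeeze(-1)

# Example: batch=4, seq_len=1024, vocab_size=128_000 in float32 is
# 4 * 1024 * 128_000 * 4 bytes ≈ 2.1 GB for the intermediate log-softmax alone.
```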

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@tyler-romero (Contributor, Author) commented Feb 7, 2025

See benchmarks here: #2773 (comment) (thanks @qgallouedec).

Notably, the most efficient approach in these benchmarks is not numerically stable in bfloat16, so for bfloat16 and float16 we fall back to the approach that loops over log_softmax.
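
A rough sketch of the approach described above, including the dtype-dependent fallback (the function name follows the PR title; the exact signature and location in TRL may differ from what ships in this PR):

```python
import torch
import torch.nn.functional as F

def selective_log_softmax(logits: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    """Sketch: per-token log-probs without materializing the full log-softmax at once."""
    if logits.dtype in (torch.float32, torch.float64):
        # log_softmax(x_i) = x_i - logsumexp(x): gather the selected logits, then
        # subtract a per-row logsumexp computed one batch element at a time.
        selected_logits = torch.gather(logits, dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
        logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
        return selected_logits - logsumexp_values
    # logsumexp is unstable in bfloat16/float16, so fall back to looping over
    # log_softmax row by row, capping the intermediate at seq_len * vocab_size.
    per_token_logps = []
    for row_logits, row_index in zip(logits, index):
        row_logps = F.log_softmax(row_logits, dim=-1)
        per_token_logps.append(row_logps.gather(dim=-1, index=row_index.unsqueeze(-1)).squeeze(-1))
    return torch.stack(per_token_logps)
```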

@tyler-romero tyler-romero marked this pull request as ready for review February 7, 2025 19:29
@tyler-romero (Contributor, Author) commented:

@qgallouedec

@@ -50,12 +50,12 @@

import requests
import torch
import wandb
tyler-romero (Contributor, Author) commented on the diff:

changed by running pre-commit

tests/test_core.py — two review threads marked outdated and resolved
@qgallouedec (Member) commented:

That's a super cool improvement! Thanks!
Just some minor remarks to address and we're good to merge.

trl/core.py — review thread marked outdated and resolved
trl/core.py (outdated):
    logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
    per_token_logps = selected_logits - logsumexp_values  # log_softmax(x_i) = x_i - logsumexp(x)
else:
    # logsumexp approach is unstable with bfloat16, fall back to a slightly less efficient approach
qgallouedec (Member) commented on the diff:

nice finding!

@tyler-romero (Contributor, Author) commented:

Ready for re-review!

@qgallouedec qgallouedec changed the title Reduce peak vram consumption with efficient selective log_softmax ⛰️ Reduce peak vram consumption with efficient selective log_softmax Feb 7, 2025
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@qgallouedec (Member) commented:

Thanks again!

@qgallouedec qgallouedec merged commit 09eefa7 into huggingface:main Feb 7, 2025
13 checks passed