⛰️ Reduce peak vram consumption with efficient selective log_softmax #2799
Conversation
…log-softmax approach
See benchmarks here: #2773 (comment) (thanks @qgallouedec). Notably, the most efficient approach in these benchmarks is not stable with bfloat16, so for bfloat16 and float16 we fall back to the approach that loops over log_softmax.
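For context, a minimal sketch of the fallback path described above, assuming `logits` of shape `(batch, seq_len, vocab)` and an `index` tensor of selected token ids (the function name and shapes are illustrative assumptions, not necessarily the PR's exact code):

```python
import torch
import torch.nn.functional as F

def _per_row_log_softmax_gather(logits: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    # Fallback for bfloat16/float16: loop over the batch so only one
    # (seq_len, vocab) log_softmax result is materialized at a time.
    per_token_logps = []
    for row_logits, row_index in zip(logits, index):
        row_logps = F.log_softmax(row_logits, dim=-1)  # numerically stable in low precision
        per_token_logps.append(row_logps.gather(dim=-1, index=row_index.unsqueeze(-1)).squeeze(-1))
    return torch.stack(per_token_logps)
```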
examples/scripts/sft_video_llm.py
Outdated
@@ -50,12 +50,12 @@
import requests
import torch
import wandb
Changed by running pre-commit.
That's a super cool improvement! Thanks!
trl/core.py
Outdated
logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
per_token_logps = selected_logits - logsumexp_values  # log_softmax(x_i) = x_i - logsumexp(x)
else:
    # logsumexp approach is unstable with bfloat16, fall back to slightly less efficient approach
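As a quick sanity check of the identity in the comment above, here is an illustrative snippet (not from the PR) comparing the naive and selective computations:

```python
import torch

logits = torch.randn(2, 5, 11)        # (batch, seq_len, vocab)
index = torch.randint(0, 11, (2, 5))  # selected token ids

# Naive: materialize the full (batch, seq_len, vocab) log-probs, then gather.
ref = torch.log_softmax(logits, dim=-1).gather(-1, index.unsqueeze(-1)).squeeze(-1)

# Selective: gather the chosen logits first, then subtract the per-position logsumexp.
selected = logits.gather(-1, index.unsqueeze(-1)).squeeze(-1)
eff = selected - torch.logsumexp(logits, dim=-1)

assert torch.allclose(ref, eff, atol=1e-6)
```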
nice finding!
Ready for re-review!
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Thanks again!
What does this PR do?
Many TRL Trainers use the same log_softmax -> gather operation to compute a selected set of logprobs. This approach is inefficient because it allocates a bs*seqlen*vocab_size tensor to hold the full log-probs; for modest bs/seqlen/vocab_size this tensor can require more than 2 GB of VRAM. There are a variety of more memory-efficient (and faster) approaches. This PR adds a utility function holding a more efficient implementation of this operation and uses it broadly across TRL.
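To put the savings in perspective, here is a rough back-of-the-envelope estimate; the sizes below are illustrative assumptions, not figures from the PR:

```python
# Peak allocation for the naive approach's intermediate log-probs tensor.
bs, seqlen, vocab_size = 4, 1024, 152_064  # e.g. a Qwen-sized vocabulary
print(f"{bs * seqlen * vocab_size * 4 / 1e9:.2f} GB")  # ~2.49 GB in float32

# The selective utility only keeps the (bs, seqlen) selected log-probs:
print(f"{bs * seqlen * 4 / 1e6:.3f} MB")  # ~0.016 MB
```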
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.