generated from fastai/nbdev_template
-
Notifications
You must be signed in to change notification settings - Fork 1.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
⛰️ Reduce peak vram consumption with efficient selective log_softmax #2799
Merged
qgallouedec
merged 12 commits into
huggingface:main
from
tyler-romero:tyler/extract-logprobs-efficient
Feb 7, 2025
Merged
Changes from 4 commits
Commits
Show all changes
12 commits
Select commit
Hold shift + click to select a range
b244ce9
Reduce mem consumption across many trainers with efficient selective …
tyler-romero 3b1f91f
rename
tyler-romero f0f5300
typo fix
tyler-romero 0d3924a
precommit
tyler-romero fa2d67e
Update tests/test_core.py
tyler-romero 279c262
relocate
tyler-romero 786b202
precommit
tyler-romero 69abfd5
style
qgallouedec 788e82c
smaller values for test, and run on cpu
qgallouedec 9e0d40a
nit doc improvements
qgallouedec f37fa7d
style
qgallouedec 0714d81
fix test
qgallouedec File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,6 +20,7 @@ | |
|
||
import numpy as np | ||
import torch | ||
import torch.nn.functional as F | ||
from transformers import is_torch_npu_available, is_torch_xpu_available | ||
|
||
|
||
|
@@ -157,3 +158,27 @@ def randn_tensor( | |
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device) | ||
|
||
return latents | ||
|
||
|
||
def selective_log_softmax(logits, input_ids): | ||
tyler-romero marked this conversation as resolved.
Show resolved
Hide resolved
|
||
""" | ||
A memory efficient implementation of a common log_softmax -> gather operation. | ||
Equivalent to the following naive implementation: | ||
```python | ||
per_token_logps = torch.gather(logits.log_softmax(-1), dim=-1, index=labels.unsqueeze(-1)).squeeze(-1) | ||
``` | ||
""" | ||
if logits.dtype in [torch.float32, torch.float64]: | ||
selected_logits = torch.gather(logits, dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1) | ||
# loop to reduce peak mem consumption | ||
logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits]) | ||
per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x) | ||
else: | ||
# logsumexp approach is unstable with bfloat16, fall back to slightly less efficent approach | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nice finding! |
||
per_token_logps = [] | ||
for row_logits, row_labels in zip(logits, input_ids): # loop to reduce peak mem consumption | ||
row_logps = F.log_softmax(row_logits, dim=-1) | ||
row_per_token_logps = row_logps.gather(dim=-1, index=row_labels.unsqueeze(-1)).squeeze(-1) | ||
per_token_logps.append(row_per_token_logps) | ||
per_token_logps = torch.stack(per_token_logps) | ||
return per_token_logps |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
changed by running precommit