⛰️ Reduce peak VRAM consumption with efficient selective log_softmax #2799

Merged
23 changes: 23 additions & 0 deletions tests/test_utils.py
@@ -17,6 +17,7 @@
import numpy as np
import torch
from datasets import load_dataset
from parameterized import parameterized
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from transformers.testing_utils import require_peft
from transformers.utils import is_peft_available
@@ -32,6 +33,7 @@
generate_model_card,
get_peft_config,
pad,
selective_log_softmax,
)


@@ -506,3 +508,24 @@ def test_batch_accuracy(self):
)
accuracy = compute_token_accuracy(logits, labels)
self.assertAlmostEqual(accuracy, 0.8)


class TestSelectiveLogSoftmax(unittest.TestCase):
@parameterized.expand([(torch.float64,), (torch.float32,), (torch.float16,), (torch.bfloat16,)])
def test_selective_log_softmax(self, dtype):
"""Test selective_log_softmax with logits of different dtypes"""
vocab_size = 1024
batch_size = 4
seq_len = 32

input_ids = torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len))
logits = torch.randn(batch_size, seq_len, vocab_size, dtype=dtype)

expected_output = torch.gather(logits.log_softmax(-1), dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)
actual_output = selective_log_softmax(logits, input_ids)

if dtype in [torch.float16, torch.bfloat16]:
# for half-precision dtypes, selective_log_softmax falls back to the exact per-row log_softmax -> gather path, so the result matches the reference exactly
self.assertTrue(torch.equal(actual_output, expected_output))
else:
torch.testing.assert_close(actual_output, expected_output, rtol=1e-5, atol=1e-5)
11 changes: 7 additions & 4 deletions trl/trainer/bco_trainer.py
@@ -65,6 +65,7 @@
log_table_to_comet_experiment,
pad_to_length,
peft_module_casting_to_bf16,
selective_log_softmax,
)


@@ -897,9 +898,11 @@ def _load_optimizer_and_scheduler(self, checkpoint):
@contextmanager
def null_ref_context(self):
"""Context manager for handling null reference model (that is, peft adapter manipulation)."""
with self.accelerator.unwrap_model(
self.model
).disable_adapter() if self.is_peft_model and not self.ref_adapter_name else nullcontext():
with (
self.accelerator.unwrap_model(self.model).disable_adapter()
if self.is_peft_model and not self.ref_adapter_name
else nullcontext()
):
if self.ref_adapter_name:
self.model.set_adapter(self.ref_adapter_name)
yield
@@ -1062,7 +1065,7 @@ def get_batch_logps(
# dummy token; we'll ignore the losses on these tokens later
labels[labels == label_pad_token_id] = 0

per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
per_token_logps = selective_log_softmax(logits, labels)

if average_log_prob:
return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
3 changes: 2 additions & 1 deletion trl/trainer/cpo_trainer.py
@@ -60,6 +60,7 @@
log_table_to_comet_experiment,
pad_to_length,
peft_module_casting_to_bf16,
selective_log_softmax,
)


@@ -711,7 +712,7 @@ def get_batch_logps(
# dummy token; we'll ignore the losses on these tokens later
labels[labels == label_pad_token_id] = 0

per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
per_token_logps = selective_log_softmax(logits, labels)

if average_log_prob:
return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
11 changes: 7 additions & 4 deletions trl/trainer/dpo_trainer.py
@@ -69,6 +69,7 @@
pad,
pad_to_length,
peft_module_casting_to_bf16,
selective_log_softmax,
)


@@ -822,9 +823,11 @@ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoa
@contextmanager
def null_ref_context(self):
"""Context manager for handling null reference model (that is, peft adapter manipulation)."""
with self.accelerator.unwrap_model(
self.model
).disable_adapter() if self.is_peft_model and not self.ref_adapter_name else nullcontext():
with (
self.accelerator.unwrap_model(self.model).disable_adapter()
if self.is_peft_model and not self.ref_adapter_name
else nullcontext()
):
if self.ref_adapter_name:
self.model.set_adapter(self.ref_adapter_name)
yield
@@ -1211,7 +1214,7 @@ def concatenated_forward(self, model: nn.Module, batch: dict[str, Union[list, to

# Compute the log probabilities of the labels
labels[~loss_mask] = 0 # dummy token; we'll ignore the losses on these tokens later
per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
per_token_logps = selective_log_softmax(logits, labels)
per_token_logps[~loss_mask] = 0
per_token_logps = torch.roll(per_token_logps, shifts=1, dims=1)

9 changes: 2 additions & 7 deletions trl/trainer/grpo_trainer.py
@@ -47,7 +47,7 @@
from ..models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation
from .callbacks import SyncRefModelCallback
from .grpo_config import GRPOConfig
from .utils import generate_model_card, get_comet_experiment_url, pad
from .utils import generate_model_card, get_comet_experiment_url, pad, selective_log_softmax


if is_peft_available():
@@ -442,12 +442,7 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep)
# For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
# See https://github.com/huggingface/trl/issues/2770
logits = logits[:, -logits_to_keep:]

# Compute the log probabilities for the input tokens.
token_logits = logits.gather(dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)
logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits]) # loop to reduce memory peak
token_log_probs = token_logits - logsumexp_values # log_softmax = logits - log(sum(exp(logits)))
return token_log_probs
return selective_log_softmax(logits, input_ids) # compute logprobs for the input tokens

def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
device = self.accelerator.device
11 changes: 7 additions & 4 deletions trl/trainer/kto_trainer.py
@@ -63,6 +63,7 @@
log_table_to_comet_experiment,
pad_to_length,
peft_module_casting_to_bf16,
selective_log_softmax,
)


@@ -812,9 +813,11 @@ def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
@contextmanager
def null_ref_context(self):
"""Context manager for handling null reference model (that is, peft adapter manipulation)."""
with self.accelerator.unwrap_model(
self.model
).disable_adapter() if self.is_peft_model and not self.ref_adapter_name else nullcontext():
with (
self.accelerator.unwrap_model(self.model).disable_adapter()
if self.is_peft_model and not self.ref_adapter_name
else nullcontext()
):
if self.ref_adapter_name:
self.model.set_adapter(self.ref_adapter_name)
yield
@@ -1032,7 +1035,7 @@ def get_batch_logps(
# dummy token; we'll ignore the losses on these tokens later
labels[labels == label_pad_token_id] = 0

per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
per_token_logps = selective_log_softmax(logits, labels)

if average_log_prob:
return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
4 changes: 2 additions & 2 deletions trl/trainer/nash_md_trainer.py
@@ -46,6 +46,7 @@
generate_model_card,
get_comet_experiment_url,
get_reward,
selective_log_softmax,
truncate_right,
)

@@ -277,8 +278,7 @@ def _compute_logprobs(self, model, model_data, context_length):
def compute_logprobs_for_data(m, data):
output = m(data["input_ids"], attention_mask=data["attention_mask"])
logits = output.logits[:, context_length - 1 : -1]
logprobs = F.log_softmax(logits, dim=-1)
token_logprobs = torch.gather(logprobs, 2, data["input_ids"][:, context_length:].unsqueeze(-1)).squeeze(-1)
token_logprobs = selective_log_softmax(logits, data["input_ids"][:, context_length:])
return token_logprobs

# Compute logprobs for model completions under the model
3 changes: 2 additions & 1 deletion trl/trainer/orpo_trainer.py
@@ -64,6 +64,7 @@
log_table_to_comet_experiment,
pad_to_length,
peft_module_casting_to_bf16,
selective_log_softmax,
)


@@ -718,7 +719,7 @@ def get_batch_logps(
# dummy token; we'll ignore the losses on these tokens later
labels = torch.where(labels == label_pad_token_id, 0, labels)

per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
per_token_logps = selective_log_softmax(logits, labels)

if average_log_prob:
return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
25 changes: 12 additions & 13 deletions trl/trainer/ppo_trainer.py
@@ -25,7 +25,6 @@
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from accelerate import Accelerator
from accelerate.utils import broadcast, gather_object
from datasets import Dataset
@@ -65,6 +64,7 @@
peft_module_casting_to_bf16,
prepare_deepspeed,
print_rich_table,
selective_log_softmax,
truncate_response,
)

@@ -310,9 +310,11 @@ def get_eval_dataloader(self) -> DataLoader:
@contextmanager
def null_ref_context(self):
"""Context manager for handling null reference model (that is, peft adapter manipulation)."""
with self.accelerator.unwrap_model(
self.model.policy
).disable_adapter() if self.is_peft_model and not self.ref_adapter_name else nullcontext():
with (
self.accelerator.unwrap_model(self.model.policy).disable_adapter()
if self.is_peft_model and not self.ref_adapter_name
else nullcontext()
):
if self.ref_adapter_name:
self.model.policy.set_adapter(self.ref_adapter_name)
yield
@@ -427,9 +429,8 @@ def repeat_generator():
query_response = query_responses[i : i + args.local_rollout_forward_batch_size]
response = query_response[:, context_length:]
logits = logitss[i : i + args.local_rollout_forward_batch_size]
all_logprob = F.log_softmax(logits, dim=-1)
logprob = torch.gather(all_logprob, 2, response.unsqueeze(-1)).squeeze(-1)
del logits, all_logprob
logprob = selective_log_softmax(logits, response)
del logits
torch.cuda.empty_cache()

if ref_policy is None:
ref_output = forward(ref_policy, query_response, processing_class.pad_token_id)
ref_logits = ref_output.logits[:, context_length - 1 : -1]
ref_logits /= args.temperature + 1e-7
ref_all_logprob = F.log_softmax(ref_logits, dim=-1)
ref_logprob = torch.gather(ref_all_logprob, 2, response.unsqueeze(-1)).squeeze(-1)
del ref_output, ref_logits, ref_all_logprob
ref_logprob = selective_log_softmax(ref_logits, response)
del ref_output, ref_logits
torch.cuda.empty_cache()

# Response Processing 1. truncate response after the first occurrence of `stop_token_id`
@@ -547,8 +547,7 @@ def repeat_generator():
output, vpred_temp = forward(model, mb_query_responses, processing_class.pad_token_id)
logits = output.logits[:, context_length - 1 : -1]
logits /= args.temperature + 1e-7
new_all_logprobs = F.log_softmax(logits, dim=-1)
new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1)
new_logprobs = selective_log_softmax(logits, mb_responses)
new_logprobs = torch.masked_fill(
new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB
)
@@ -599,7 +598,7 @@ def repeat_generator():
# del everything and empty cache
# fmt: off
del (
output, vpred_temp, logits, new_all_logprobs, new_logprobs, vpred, vpredclipped,
output, vpred_temp, logits, new_logprobs, vpred, vpredclipped,
vf_losses1, vf_losses2, vf_loss, vf_clipfrac, logprobs_diff, ratio, pg_losses, pg_losses2, pg_loss_max,
pg_loss, loss, pg_clipfrac, prob_dist, entropy, approxkl, mb_return,
mb_advantage, mb_values, mb_responses, mb_query_responses, mb_logprobs,
20 changes: 8 additions & 12 deletions trl/trainer/rloo_trainer.py
@@ -24,7 +24,6 @@
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from accelerate import Accelerator
from accelerate.utils import broadcast, gather_object
from datasets import Dataset
@@ -56,6 +55,7 @@
get_reward,
prepare_deepspeed,
print_rich_table,
selective_log_softmax,
truncate_response,
)
from .rloo_config import RLOOConfig
@@ -330,17 +330,15 @@ def repeat_generator():
query_response = query_responses[i : i + args.local_rollout_forward_batch_size]
response = query_response[:, context_length:]
logits = logitss[i : i + args.local_rollout_forward_batch_size]
all_logprob = F.log_softmax(logits, dim=-1)
logprob = torch.gather(all_logprob, 2, response.unsqueeze(-1)).squeeze(-1)
del logits, all_logprob
logprob = selective_log_softmax(logits, response)
del logits
torch.cuda.empty_cache()

ref_output = forward(ref_policy, query_response, processing_class.pad_token_id)
ref_logits = ref_output.logits[:, context_length - 1 : -1]
ref_logits /= args.temperature + 1e-7
ref_all_logprob = F.log_softmax(ref_logits, dim=-1)
ref_logprob = torch.gather(ref_all_logprob, 2, response.unsqueeze(-1)).squeeze(-1)
del ref_output, ref_logits, ref_all_logprob
ref_logprob = selective_log_softmax(ref_logits, response)
del ref_output, ref_logits
torch.cuda.empty_cache()

# Response Processing 1. truncate response after the first occurrence of `stop_token_id`
@@ -467,8 +465,7 @@ def repeat_generator():
logits /= args.temperature + 1e-7

# Compute new logprobs
new_all_logprobs = F.log_softmax(logits, dim=-1)
new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1)
new_logprobs = selective_log_softmax(logits, mb_responses)
new_logprobs = torch.masked_fill(
new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB
)
@@ -512,9 +509,8 @@ def repeat_generator():
# del everything and empty cache
# fmt: off
del (
output, logits, new_all_logprobs, new_logprobs,
logprobs_diff, ratio, pg_losses, pg_losses2,
pg_loss, loss, pg_clipfrac, prob_dist, entropy, approxkl,
output, logits, new_logprobs, logprobs_diff, ratio, pg_losses,
pg_losses2, pg_loss, loss, pg_clipfrac, prob_dist, entropy, approxkl,
mb_advantage, mb_responses, mb_query_responses, mb_logprobs,
)
# fmt: on
36 changes: 36 additions & 0 deletions trl/trainer/utils.py
@@ -26,6 +26,7 @@
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import torch.utils.data
from accelerate import Accelerator, PartialState
from accelerate.state import AcceleratorState
@@ -1668,3 +1669,38 @@ def compute_token_accuracy(logits: torch.Tensor, labels: torch.Tensor, ignore_in
accuracy = correct_tokens.item() / total_tokens.item() if total_tokens > 0 else 0.0

return accuracy


def selective_log_softmax(logits, index):
"""
A memory-efficient implementation of the common `log_softmax -> gather` operation.
This function is equivalent to the following naive implementation:
```python
logps = torch.gather(logits.log_softmax(-1), dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
```
Args:
logits (`torch.Tensor`):
Logits tensor of shape `(..., num_classes)`.
index (`torch.Tensor`):
Index tensor of shape `(...)`, specifying the positions to gather from the log-softmax output.
Returns:
`torch.Tensor`:
Gathered log probabilities with the same shape as `index`.
"""
if logits.dtype in [torch.float32, torch.float64]:
selected_logits = torch.gather(logits, dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
# loop to reduce peak mem consumption
logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
else:
# the logsumexp approach is numerically unstable in half precision (e.g. bfloat16), so fall back to a slightly less efficient per-row approach
per_token_logps = []
for row_logits, row_labels in zip(logits, index): # loop to reduce peak mem consumption
row_logps = F.log_softmax(row_logits, dim=-1)
row_per_token_logps = row_logps.gather(dim=-1, index=row_labels.unsqueeze(-1)).squeeze(-1)
per_token_logps.append(row_per_token_logps)
per_token_logps = torch.stack(per_token_logps)
return per_token_logps
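
Below is a minimal usage sketch, not part of the PR diff, illustrating the identity the fast path relies on, `log_softmax(x)_i = x_i - logsumexp(x)`, and checking it against the naive `log_softmax -> gather` baseline. The import path `trl.trainer.utils` is an assumption based on the file location above.

```python
# Sketch only: compare selective_log_softmax against the naive baseline.
# Assumes `selective_log_softmax` is importable from `trl.trainer.utils`.
import torch

from trl.trainer.utils import selective_log_softmax

batch_size, seq_len, vocab_size = 2, 8, 128
logits = torch.randn(batch_size, seq_len, vocab_size)  # float32 -> logsumexp path
index = torch.randint(0, vocab_size, (batch_size, seq_len))

# Naive baseline materializes the full (batch, seq, vocab) log-softmax tensor.
naive = torch.gather(logits.log_softmax(-1), dim=-1, index=index.unsqueeze(-1)).squeeze(-1)

# The selective version only keeps the gathered logits plus a (batch, seq) logsumexp.
efficient = selective_log_softmax(logits, index)

torch.testing.assert_close(efficient, naive, rtol=1e-5, atol=1e-5)
print(efficient.shape)  # torch.Size([2, 8])
```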