
⛰️ Reduce peak VRAM consumption with efficient selective log_softmax #2799

Merged
2 changes: 1 addition & 1 deletion examples/scripts/sft_video_llm.py
@@ -50,12 +50,12 @@

import requests
import torch
import wandb
Inline review comment (PR author): changed by running pre-commit

from datasets import load_dataset
from peft import LoraConfig
from qwen_vl_utils import process_vision_info
from transformers import AutoModelForVision2Seq, AutoProcessor, BitsAndBytesConfig, Qwen2VLProcessor

import wandb
from trl import ModelConfig, ScriptArguments, SFTConfig, SFTTrainer, TrlParser, get_kbit_device_map


25 changes: 24 additions & 1 deletion tests/test_core.py
@@ -16,7 +16,7 @@

import torch

from trl.core import masked_mean, masked_var, masked_whiten
from trl.core import masked_mean, masked_var, masked_whiten, selective_log_softmax


class CoreTester(unittest.TestCase):
@@ -44,3 +44,26 @@ def whiten(values: torch.Tensor) -> torch.Tensor:
whiten_masked = masked_whiten(self.test_input, self.test_mask)[1:3]
diffs = (whiten_unmasked - whiten_masked).sum()
self.assertLess(abs(diffs.item()), 0.00001)

def test_selective_log_softmax(self):
"""Test selective_log_softmax with different dtypes"""
dtypes = [torch.float64, torch.float32, torch.float16, torch.bfloat16]
vocab_size = 32768
batch_size = 4
seq_len = 256

for dtype in dtypes:
with self.subTest(dtype=dtype):
input_ids = torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len), device="cuda")
logits = torch.randn(batch_size, seq_len, vocab_size, device="cuda", dtype=dtype)

expected_output = torch.gather(logits.log_softmax(-1), dim=-1, index=input_ids.unsqueeze(-1)).squeeze(
-1
)
actual_output = selective_log_softmax(logits=logits, input_ids=input_ids)

if dtype in [torch.float16, torch.bfloat16]:
# float16 and bfloat16 fall back to the exact per-row log_softmax path, so outputs match exactly
self.assertTrue(torch.equal(actual_output, expected_output))
else:
torch.testing.assert_close(actual_output, expected_output, rtol=1e-5, atol=1e-5)
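Editorial note: the test above requires a CUDA device. For a quick local check without a GPU, a minimal CPU-only sketch (hypothetical shapes, not part of this PR) could look like:

```python
import torch

from trl.core import selective_log_softmax

# Hypothetical small shapes; selective_log_softmax itself is device-agnostic.
logits = torch.randn(2, 8, 100, dtype=torch.float32)
input_ids = torch.randint(0, 100, (2, 8))

efficient = selective_log_softmax(logits, input_ids)
naive = torch.gather(logits.log_softmax(-1), dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)

torch.testing.assert_close(efficient, naive, rtol=1e-5, atol=1e-5)
print(efficient.shape)  # torch.Size([2, 8])
```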
25 changes: 25 additions & 0 deletions trl/core.py
@@ -20,6 +20,7 @@

import numpy as np
import torch
import torch.nn.functional as F
from transformers import is_torch_npu_available, is_torch_xpu_available


@@ -157,3 +158,27 @@ def randn_tensor(
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device)

return latents


def selective_log_softmax(logits, input_ids):
"""
A memory-efficient implementation of the common `log_softmax -> gather` operation.
Equivalent to the following naive implementation:
```python
per_token_logps = torch.gather(logits.log_softmax(-1), dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)
```
"""
if logits.dtype in [torch.float32, torch.float64]:
selected_logits = torch.gather(logits, dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)
# loop to reduce peak mem consumption
logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
else:
# the logsumexp approach is unstable with bfloat16, so fall back to a slightly less efficient approach
Inline review comment (Member): nice finding!

per_token_logps = []
for row_logits, row_labels in zip(logits, input_ids): # loop to reduce peak mem consumption
row_logps = F.log_softmax(row_logits, dim=-1)
row_per_token_logps = row_logps.gather(dim=-1, index=row_labels.unsqueeze(-1)).squeeze(-1)
per_token_logps.append(row_per_token_logps)
per_token_logps = torch.stack(per_token_logps)
return per_token_logps
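Editorial note: the naive implementation materializes a full (batch, seq_len, vocab) tensor of log-probabilities, while the float32/float64 path above keeps only the gathered logits plus a per-sequence logsumexp. A rough benchmark sketch with made-up shapes (requires a CUDA device; not part of this PR) could measure the difference in peak memory:

```python
import torch

from trl.core import selective_log_softmax

# Illustrative shapes only; the savings grow with vocab_size and seq_len.
batch_size, seq_len, vocab_size = 8, 1024, 32768
logits = torch.randn(batch_size, seq_len, vocab_size, device="cuda", dtype=torch.float32)
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device="cuda")


def peak_mib(fn):
    # Reset the peak-memory counter, run the op, and report the new peak in MiB.
    torch.cuda.reset_peak_memory_stats()
    fn()
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 2**20


naive = peak_mib(lambda: torch.gather(logits.log_softmax(-1), dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1))
selective = peak_mib(lambda: selective_log_softmax(logits, input_ids))
print(f"naive peak: {naive:.0f} MiB, selective peak: {selective:.0f} MiB")
```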
11 changes: 7 additions & 4 deletions trl/trainer/bco_trainer.py
@@ -53,6 +53,7 @@
from transformers.trainer_utils import EvalLoopOutput, has_length
from transformers.utils import is_peft_available

from ..core import selective_log_softmax
from ..data_utils import maybe_apply_chat_template
from ..models import PreTrainedModelWrapper, create_reference_model
from .bco_config import BCOConfig
@@ -897,9 +898,11 @@ def _load_optimizer_and_scheduler(self, checkpoint):
@contextmanager
def null_ref_context(self):
"""Context manager for handling null reference model (that is, peft adapter manipulation)."""
with self.accelerator.unwrap_model(
self.model
).disable_adapter() if self.is_peft_model and not self.ref_adapter_name else nullcontext():
with (
self.accelerator.unwrap_model(self.model).disable_adapter()
if self.is_peft_model and not self.ref_adapter_name
else nullcontext()
):
if self.ref_adapter_name:
self.model.set_adapter(self.ref_adapter_name)
yield
@@ -1062,7 +1065,7 @@ def get_batch_logps(
# dummy token; we'll ignore the losses on these tokens later
labels[labels == label_pad_token_id] = 0

per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
per_token_logps = selective_log_softmax(logits, labels)

if average_log_prob:
return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
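Editorial note: a toy example with made-up numbers illustrates how the loss mask combines the per-token log-probs returned by selective_log_softmax for the two return branches above:

```python
import torch

# Made-up per-token log-probs; loss_mask marks completion tokens (prompt/pad positions are masked out).
per_token_logps = torch.tensor([[-1.0, -2.0, -3.0, -4.0]])
loss_mask = torch.tensor([[False, True, True, False]])

summed = (per_token_logps * loss_mask).sum(-1)                        # tensor([-5.])
averaged = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)  # tensor([-2.5000])
```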
3 changes: 2 additions & 1 deletion trl/trainer/cpo_trainer.py
@@ -48,6 +48,7 @@
from transformers.trainer_utils import EvalLoopOutput
from transformers.utils import is_peft_available, is_torch_fx_proxy

from ..core import selective_log_softmax
from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt
from .cpo_config import CPOConfig
from .utils import (
@@ -711,7 +712,7 @@ def get_batch_logps(
# dummy token; we'll ignore the losses on these tokens later
labels[labels == label_pad_token_id] = 0

per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
per_token_logps = selective_log_softmax(logits, labels)

if average_log_prob:
return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
11 changes: 7 additions & 4 deletions trl/trainer/dpo_trainer.py
@@ -53,6 +53,7 @@
from transformers.utils import is_peft_available, is_torch_xpu_available
from transformers.utils.deprecation import deprecate_kwarg

from ..core import selective_log_softmax
from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt
from ..models import PreTrainedModelWrapper, create_reference_model
from .callbacks import SyncRefModelCallback
@@ -822,9 +823,11 @@ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoa
@contextmanager
def null_ref_context(self):
"""Context manager for handling null reference model (that is, peft adapter manipulation)."""
with self.accelerator.unwrap_model(
self.model
).disable_adapter() if self.is_peft_model and not self.ref_adapter_name else nullcontext():
with (
self.accelerator.unwrap_model(self.model).disable_adapter()
if self.is_peft_model and not self.ref_adapter_name
else nullcontext()
):
if self.ref_adapter_name:
self.model.set_adapter(self.ref_adapter_name)
yield
@@ -1211,7 +1214,7 @@ def concatenated_forward(self, model: nn.Module, batch: dict[str, Union[list, to

# Compute the log probabilities of the labels
labels[~loss_mask] = 0 # dummy token; we'll ignore the losses on these tokens later
per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
per_token_logps = selective_log_softmax(logits, labels)
per_token_logps[~loss_mask] = 0
per_token_logps = torch.roll(per_token_logps, shifts=1, dims=1)

8 changes: 2 additions & 6 deletions trl/trainer/grpo_trainer.py
@@ -42,6 +42,7 @@
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.utils import is_peft_available

from ..core import selective_log_softmax
from ..data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
from ..import_utils import is_vllm_available
from ..models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation
@@ -442,12 +443,7 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep)
# For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
# See https://github.com/huggingface/trl/issues/2770
logits = logits[:, -logits_to_keep:]

# Compute the log probabilities for the input tokens.
token_logits = logits.gather(dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)
logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits]) # loop to reduce memory peak
token_log_probs = token_logits - logsumexp_values # log_softmax = logits - log(sum(exp(logits)))
return token_log_probs
return selective_log_softmax(logits, input_ids) # compute logprobs for the input tokens
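Editorial note: a standalone sketch of the data flow in the rewritten helper, using random tensors with hypothetical shapes in place of real model logits:

```python
import torch

from trl.core import selective_log_softmax

# Hypothetical shapes; in the trainer, `logits` comes from the model forward pass.
batch_size, seq_len, vocab_size, logits_to_keep = 2, 16, 1000, 5
logits = torch.randn(batch_size, seq_len, vocab_size)                        # next-token logits per position
completion_ids = torch.randint(0, vocab_size, (batch_size, logits_to_keep))  # tokens to score

logits = logits[:, -logits_to_keep:]                             # keep only the completion positions
per_token_logps = selective_log_softmax(logits, completion_ids)  # shape: (batch_size, logits_to_keep)
print(per_token_logps.shape)  # torch.Size([2, 5])
```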

def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
device = self.accelerator.device
11 changes: 7 additions & 4 deletions trl/trainer/kto_trainer.py
@@ -52,6 +52,7 @@
from transformers.trainer_utils import EvalLoopOutput, has_length
from transformers.utils import is_peft_available

from ..core import selective_log_softmax
from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt, maybe_unpair_preference_dataset
from ..models import PreTrainedModelWrapper, create_reference_model
from .kto_config import KTOConfig
@@ -812,9 +813,11 @@ def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
@contextmanager
def null_ref_context(self):
"""Context manager for handling null reference model (that is, peft adapter manipulation)."""
with self.accelerator.unwrap_model(
self.model
).disable_adapter() if self.is_peft_model and not self.ref_adapter_name else nullcontext():
with (
self.accelerator.unwrap_model(self.model).disable_adapter()
if self.is_peft_model and not self.ref_adapter_name
else nullcontext()
):
if self.ref_adapter_name:
self.model.set_adapter(self.ref_adapter_name)
yield
@@ -1032,7 +1035,7 @@ def get_batch_logps(
# dummy token; we'll ignore the losses on these tokens later
labels[labels == label_pad_token_id] = 0

per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
per_token_logps = selective_log_softmax(logits, labels)

if average_log_prob:
return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
4 changes: 2 additions & 2 deletions trl/trainer/nash_md_trainer.py
@@ -34,6 +34,7 @@
from transformers.training_args import OptimizerNames
from transformers.utils import is_apex_available

from ..core import selective_log_softmax
from ..data_utils import is_conversational, maybe_apply_chat_template
from ..models.modeling_base import GeometricMixtureWrapper
from ..models.utils import unwrap_model_for_generation
@@ -277,8 +278,7 @@ def _compute_logprobs(self, model, model_data, context_length):
def compute_logprobs_for_data(m, data):
output = m(data["input_ids"], attention_mask=data["attention_mask"])
logits = output.logits[:, context_length - 1 : -1]
logprobs = F.log_softmax(logits, dim=-1)
token_logprobs = torch.gather(logprobs, 2, data["input_ids"][:, context_length:].unsqueeze(-1)).squeeze(-1)
token_logprobs = selective_log_softmax(logits, data["input_ids"][:, context_length:])
return token_logprobs

# Compute logprobs for model completions under the model
3 changes: 2 additions & 1 deletion trl/trainer/orpo_trainer.py
@@ -51,6 +51,7 @@
from transformers.trainer_utils import EvalLoopOutput
from transformers.utils import is_peft_available, is_torch_fx_proxy

from ..core import selective_log_softmax
from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt
from ..models import PreTrainedModelWrapper
from .orpo_config import ORPOConfig
@@ -718,7 +719,7 @@ def get_batch_logps(
# dummy token; we'll ignore the losses on these tokens later
labels = torch.where(labels == label_pad_token_id, 0, labels)

per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
per_token_logps = selective_log_softmax(logits, labels)

if average_log_prob:
return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
26 changes: 12 additions & 14 deletions trl/trainer/ppo_trainer.py
@@ -25,7 +25,6 @@
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from accelerate import Accelerator
from accelerate.utils import broadcast, gather_object
from datasets import Dataset
@@ -47,7 +46,7 @@
from transformers.trainer_callback import CallbackHandler, ExportableState, PrinterCallback
from transformers.utils import is_peft_available

from ..core import masked_mean, masked_whiten
from ..core import masked_mean, masked_whiten, selective_log_softmax
from ..models import create_reference_model
from ..models.utils import unwrap_model_for_generation
from .ppo_config import PPOConfig
@@ -310,9 +309,11 @@ def get_eval_dataloader(self) -> DataLoader:
@contextmanager
def null_ref_context(self):
"""Context manager for handling null reference model (that is, peft adapter manipulation)."""
with self.accelerator.unwrap_model(
self.model.policy
).disable_adapter() if self.is_peft_model and not self.ref_adapter_name else nullcontext():
with (
self.accelerator.unwrap_model(self.model.policy).disable_adapter()
if self.is_peft_model and not self.ref_adapter_name
else nullcontext()
):
if self.ref_adapter_name:
self.model.policy.set_adapter(self.ref_adapter_name)
yield
@@ -427,9 +428,8 @@ def repeat_generator():
query_response = query_responses[i : i + args.local_rollout_forward_batch_size]
response = query_response[:, context_length:]
logits = logitss[i : i + args.local_rollout_forward_batch_size]
all_logprob = F.log_softmax(logits, dim=-1)
logprob = torch.gather(all_logprob, 2, response.unsqueeze(-1)).squeeze(-1)
del logits, all_logprob
logprob = selective_log_softmax(logits, response)
del logits
torch.cuda.empty_cache()

if ref_policy is None:
@@ -439,9 +439,8 @@ def repeat_generator():
ref_output = forward(ref_policy, query_response, processing_class.pad_token_id)
ref_logits = ref_output.logits[:, context_length - 1 : -1]
ref_logits /= args.temperature + 1e-7
ref_all_logprob = F.log_softmax(ref_logits, dim=-1)
ref_logprob = torch.gather(ref_all_logprob, 2, response.unsqueeze(-1)).squeeze(-1)
del ref_output, ref_logits, ref_all_logprob
ref_logprob = selective_log_softmax(ref_logits, response)
del ref_output, ref_logits
torch.cuda.empty_cache()

# Response Processing 1. truncate response after the first occurrence of `stop_token_id`
@@ -547,8 +546,7 @@ def repeat_generator():
output, vpred_temp = forward(model, mb_query_responses, processing_class.pad_token_id)
logits = output.logits[:, context_length - 1 : -1]
logits /= args.temperature + 1e-7
new_all_logprobs = F.log_softmax(logits, dim=-1)
new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1)
new_logprobs = selective_log_softmax(logits, mb_responses)
new_logprobs = torch.masked_fill(
new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB
)
@@ -599,7 +597,7 @@ def repeat_generator():
# del everything and empty cache
# fmt: off
del (
output, vpred_temp, logits, new_all_logprobs, new_logprobs, vpred, vpredclipped,
output, vpred_temp, logits, new_logprobs, vpred, vpredclipped,
vf_losses1, vf_losses2, vf_loss, vf_clipfrac, logprobs_diff, ratio, pg_losses, pg_losses2, pg_loss_max,
pg_loss, loss, pg_clipfrac, prob_dist, entropy, approxkl, mb_return,
mb_advantage, mb_values, mb_responses, mb_query_responses, mb_logprobs,
20 changes: 8 additions & 12 deletions trl/trainer/rloo_trainer.py
@@ -24,7 +24,6 @@
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from accelerate import Accelerator
from accelerate.utils import broadcast, gather_object
from datasets import Dataset
@@ -45,6 +44,7 @@
from transformers.trainer import DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK
from transformers.trainer_callback import CallbackHandler, ExportableState, PrinterCallback

from ..core import selective_log_softmax
from ..models.utils import unwrap_model_for_generation
from ..trainer.utils import (
OnlineTrainerState,
@@ -330,17 +330,15 @@ def repeat_generator():
query_response = query_responses[i : i + args.local_rollout_forward_batch_size]
response = query_response[:, context_length:]
logits = logitss[i : i + args.local_rollout_forward_batch_size]
all_logprob = F.log_softmax(logits, dim=-1)
logprob = torch.gather(all_logprob, 2, response.unsqueeze(-1)).squeeze(-1)
del logits, all_logprob
logprob = selective_log_softmax(logits, response)
del logits
torch.cuda.empty_cache()

ref_output = forward(ref_policy, query_response, processing_class.pad_token_id)
ref_logits = ref_output.logits[:, context_length - 1 : -1]
ref_logits /= args.temperature + 1e-7
ref_all_logprob = F.log_softmax(ref_logits, dim=-1)
ref_logprob = torch.gather(ref_all_logprob, 2, response.unsqueeze(-1)).squeeze(-1)
del ref_output, ref_logits, ref_all_logprob
ref_logprob = selective_log_softmax(ref_logits, response)
del ref_output, ref_logits
torch.cuda.empty_cache()

# Response Processing 1. truncate response after the first occurrence of `stop_token_id`
@@ -467,8 +465,7 @@ def repeat_generator():
logits /= args.temperature + 1e-7

# Compute new logprobs
new_all_logprobs = F.log_softmax(logits, dim=-1)
new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1)
new_logprobs = selective_log_softmax(logits, mb_responses)
new_logprobs = torch.masked_fill(
new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB
)
@@ -512,9 +509,8 @@ def repeat_generator():
# del everything and empty cache
# fmt: off
del (
output, logits, new_all_logprobs, new_logprobs,
logprobs_diff, ratio, pg_losses, pg_losses2,
pg_loss, loss, pg_clipfrac, prob_dist, entropy, approxkl,
output, logits, new_logprobs, logprobs_diff, ratio, pg_losses,
pg_losses2, pg_loss, loss, pg_clipfrac, prob_dist, entropy, approxkl,
mb_advantage, mb_responses, mb_query_responses, mb_logprobs,
)
# fmt: on