From b244ce9b4ab26defe2192fd6a628146323ba3603 Mon Sep 17 00:00:00 2001 From: Tyler Romero Date: Fri, 7 Feb 2025 10:57:47 -0800 Subject: [PATCH 01/12] Reduce mem consumption across many trainers with efficient selective log-softmax approach --- tests/test_core.py | 23 ++++++++++++++++++++++- trl/core.py | 25 +++++++++++++++++++++++++ trl/trainer/bco_trainer.py | 11 +++++++---- trl/trainer/cpo_trainer.py | 3 ++- trl/trainer/dpo_trainer.py | 11 +++++++---- trl/trainer/grpo_trainer.py | 8 ++------ trl/trainer/kto_trainer.py | 11 +++++++---- trl/trainer/nash_md_trainer.py | 4 ++-- trl/trainer/orpo_trainer.py | 3 ++- trl/trainer/ppo_trainer.py | 31 +++++++++++++++---------------- trl/trainer/rloo_trainer.py | 19 ++++++++----------- trl/trainer/xpo_trainer.py | 4 ++-- 12 files changed, 101 insertions(+), 52 deletions(-) diff --git a/tests/test_core.py b/tests/test_core.py index 4be0810912..1bd9235266 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -16,7 +16,7 @@ import torch -from trl.core import masked_mean, masked_var, masked_whiten +from trl.core import extract_per_token_logprobs, masked_mean, masked_var, masked_whiten class CoreTester(unittest.TestCase): @@ -44,3 +44,24 @@ def whiten(values: torch.Tensor) -> torch.Tensor: whiten_masked = masked_whiten(self.test_input, self.test_mask)[1:3] diffs = (whiten_unmasked - whiten_masked).sum() self.assertLess(abs(diffs.item()), 0.00001) + + def test_extract_per_token_logprobs(self): + """Test extract_per_token_logprobs with different dtypes""" + dtypes = [torch.float64, torch.float32, torch.float16, torch.bfloat16] + vocab_size = 32768 + batch_size = 4 + seq_len = 256 + + for dtype in dtypes: + with self.subTest(dtype=dtype): + input_ids = torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len), device="cuda") + logits = torch.randn(batch_size, seq_len, vocab_size, device="cuda", dtype=dtype) + + expected_output = extract_per_token_logprobs(logits=logits, input_ids=input_ids) + actual_output = extract_per_token_logprobs(logits=logits, input_ids=input_ids) + + if dtype in [torch.float16, torch.bfloat16]: + # float16 falls back to an exact method + self.assertTrue(torch.equal(actual_output, expected_output)) + else: + torch.testing.assert_close(actual_output, expected_output, rtol=1e-5, atol=1e-5) diff --git a/trl/core.py b/trl/core.py index b12d3a79d5..44dfe111c9 100644 --- a/trl/core.py +++ b/trl/core.py @@ -20,6 +20,7 @@ import numpy as np import torch +import torch.nn.functional as F from transformers import is_torch_npu_available, is_torch_xpu_available @@ -157,3 +158,27 @@ def randn_tensor( latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device) return latents + + +def extract_per_token_logprobs(logits, input_ids): + """ + A memory efficient implementation of a common log_softmax -> gather operation. + Equivalent to the following naive implementation: + ```python + per_token_logps = torch.gather(logits.log_softmax(-1), dim=-1, index=labels.unsqueeze(-1)).squeeze(-1) + ``` + """ + if logits.dtype in [torch.float32, torch.float64]: + selected_logits = torch.gather(logits, dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1) + # loop to reduce peak mem consumption + logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits]) + per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x) + else: + # logsumexp approach is unstable with bfloat16, fall back to slightly less efficent approach + per_token_logps = [] + for row_logits, row_labels in zip(logits, input_ids): # loop to reduce peak mem consumption + row_logps = F.log_softmax(row_logits, dim=-1) + row_per_token_logps = row_logps.gather(dim=-1, index=row_labels.unsqueeze(-1)).squeeze(-1) + per_token_logps.append(row_per_token_logps) + per_token_logps = torch.stack(per_token_logps) + return per_token_logps diff --git a/trl/trainer/bco_trainer.py b/trl/trainer/bco_trainer.py index 18df4a7210..a9db3afb3b 100644 --- a/trl/trainer/bco_trainer.py +++ b/trl/trainer/bco_trainer.py @@ -53,6 +53,7 @@ from transformers.trainer_utils import EvalLoopOutput, has_length from transformers.utils import is_peft_available +from ..core import extract_per_token_logprobs from ..data_utils import maybe_apply_chat_template from ..models import PreTrainedModelWrapper, create_reference_model from .bco_config import BCOConfig @@ -897,9 +898,11 @@ def _load_optimizer_and_scheduler(self, checkpoint): @contextmanager def null_ref_context(self): """Context manager for handling null reference model (that is, peft adapter manipulation).""" - with self.accelerator.unwrap_model( - self.model - ).disable_adapter() if self.is_peft_model and not self.ref_adapter_name else nullcontext(): + with ( + self.accelerator.unwrap_model(self.model).disable_adapter() + if self.is_peft_model and not self.ref_adapter_name + else nullcontext() + ): if self.ref_adapter_name: self.model.set_adapter(self.ref_adapter_name) yield @@ -1062,7 +1065,7 @@ def get_batch_logps( # dummy token; we'll ignore the losses on these tokens later labels[labels == label_pad_token_id] = 0 - per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2) + per_token_logps = extract_per_token_logprobs(logits, labels) if average_log_prob: return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) diff --git a/trl/trainer/cpo_trainer.py b/trl/trainer/cpo_trainer.py index 644a2d5353..20eeeb3421 100644 --- a/trl/trainer/cpo_trainer.py +++ b/trl/trainer/cpo_trainer.py @@ -48,6 +48,7 @@ from transformers.trainer_utils import EvalLoopOutput from transformers.utils import is_peft_available, is_torch_fx_proxy +from ..core import extract_per_token_logprobs from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt from .cpo_config import CPOConfig from .utils import ( @@ -711,7 +712,7 @@ def get_batch_logps( # dummy token; we'll ignore the losses on these tokens later labels[labels == label_pad_token_id] = 0 - per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2) + per_token_logps = extract_per_token_logprobs(logits, labels) if average_log_prob: return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index cd3c3b4dca..b50e3cbb43 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -53,6 +53,7 @@ from transformers.utils import is_peft_available, is_torch_xpu_available from transformers.utils.deprecation import deprecate_kwarg +from ..core import extract_per_token_logprobs from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt from ..models import PreTrainedModelWrapper, create_reference_model from .callbacks import SyncRefModelCallback @@ -822,9 +823,11 @@ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoa @contextmanager def null_ref_context(self): """Context manager for handling null reference model (that is, peft adapter manipulation).""" - with self.accelerator.unwrap_model( - self.model - ).disable_adapter() if self.is_peft_model and not self.ref_adapter_name else nullcontext(): + with ( + self.accelerator.unwrap_model(self.model).disable_adapter() + if self.is_peft_model and not self.ref_adapter_name + else nullcontext() + ): if self.ref_adapter_name: self.model.set_adapter(self.ref_adapter_name) yield @@ -1211,7 +1214,7 @@ def concatenated_forward(self, model: nn.Module, batch: dict[str, Union[list, to # Compute the log probabilities of the labels labels[~loss_mask] = 0 # dummy token; we'll ignore the losses on these tokens later - per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2) + per_token_logps = extract_per_token_logprobs(logits, labels) per_token_logps[~loss_mask] = 0 per_token_logps = torch.roll(per_token_logps, shifts=1, dims=1) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 4cbba2fcaa..54540ec252 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -42,6 +42,7 @@ from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled from transformers.utils import is_peft_available +from ..core import extract_per_token_logprobs from ..data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template from ..import_utils import is_vllm_available from ..models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation @@ -442,12 +443,7 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep) # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves. # See https://github.com/huggingface/trl/issues/2770 logits = logits[:, -logits_to_keep:] - - # Compute the log probabilities for the input tokens. - token_logits = logits.gather(dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1) - logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits]) # loop to reduce memory peak - token_log_probs = token_logits - logsumexp_values # log_softmax = logits - log(sum(exp(logits))) - return token_log_probs + return extract_per_token_logprobs(logits, input_ids) # compute logprobs for the input tokens def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]: device = self.accelerator.device diff --git a/trl/trainer/kto_trainer.py b/trl/trainer/kto_trainer.py index c45a88d554..be092dfccc 100644 --- a/trl/trainer/kto_trainer.py +++ b/trl/trainer/kto_trainer.py @@ -52,6 +52,7 @@ from transformers.trainer_utils import EvalLoopOutput, has_length from transformers.utils import is_peft_available +from ..core import extract_per_token_logprobs from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt, maybe_unpair_preference_dataset from ..models import PreTrainedModelWrapper, create_reference_model from .kto_config import KTOConfig @@ -812,9 +813,11 @@ def _prepare_deepspeed(self, model: PreTrainedModelWrapper): @contextmanager def null_ref_context(self): """Context manager for handling null reference model (that is, peft adapter manipulation).""" - with self.accelerator.unwrap_model( - self.model - ).disable_adapter() if self.is_peft_model and not self.ref_adapter_name else nullcontext(): + with ( + self.accelerator.unwrap_model(self.model).disable_adapter() + if self.is_peft_model and not self.ref_adapter_name + else nullcontext() + ): if self.ref_adapter_name: self.model.set_adapter(self.ref_adapter_name) yield @@ -1032,7 +1035,7 @@ def get_batch_logps( # dummy token; we'll ignore the losses on these tokens later labels[labels == label_pad_token_id] = 0 - per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2) + per_token_logps = extract_per_token_logprobs(logits, labels) if average_log_prob: return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) diff --git a/trl/trainer/nash_md_trainer.py b/trl/trainer/nash_md_trainer.py index cbe218066e..5441d6cc40 100644 --- a/trl/trainer/nash_md_trainer.py +++ b/trl/trainer/nash_md_trainer.py @@ -34,6 +34,7 @@ from transformers.training_args import OptimizerNames from transformers.utils import is_apex_available +from ..core import extract_per_token_logprobs from ..data_utils import is_conversational, maybe_apply_chat_template from ..models.modeling_base import GeometricMixtureWrapper from ..models.utils import unwrap_model_for_generation @@ -277,8 +278,7 @@ def _compute_logprobs(self, model, model_data, context_length): def compute_logprobs_for_data(m, data): output = m(data["input_ids"], attention_mask=data["attention_mask"]) logits = output.logits[:, context_length - 1 : -1] - logprobs = F.log_softmax(logits, dim=-1) - token_logprobs = torch.gather(logprobs, 2, data["input_ids"][:, context_length:].unsqueeze(-1)).squeeze(-1) + token_logprobs = extract_per_token_logprobs(logits, data["input_ids"][:, context_length]) return token_logprobs # Compute logprobs for model completions under the model diff --git a/trl/trainer/orpo_trainer.py b/trl/trainer/orpo_trainer.py index 832927d05c..9a69999a39 100644 --- a/trl/trainer/orpo_trainer.py +++ b/trl/trainer/orpo_trainer.py @@ -51,6 +51,7 @@ from transformers.trainer_utils import EvalLoopOutput from transformers.utils import is_peft_available, is_torch_fx_proxy +from ..core import extract_per_token_logprobs from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt from ..models import PreTrainedModelWrapper from .orpo_config import ORPOConfig @@ -718,7 +719,7 @@ def get_batch_logps( # dummy token; we'll ignore the losses on these tokens later labels = torch.where(labels == label_pad_token_id, 0, labels) - per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2) + per_token_logps = extract_per_token_logprobs(logits, labels) if average_log_prob: return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) diff --git a/trl/trainer/ppo_trainer.py b/trl/trainer/ppo_trainer.py index 27cbdd016c..cb922419da 100644 --- a/trl/trainer/ppo_trainer.py +++ b/trl/trainer/ppo_trainer.py @@ -47,7 +47,7 @@ from transformers.trainer_callback import CallbackHandler, ExportableState, PrinterCallback from transformers.utils import is_peft_available -from ..core import masked_mean, masked_whiten +from ..core import extract_per_token_logprobs, masked_mean, masked_whiten from ..models import create_reference_model from ..models.utils import unwrap_model_for_generation from .ppo_config import PPOConfig @@ -197,9 +197,9 @@ def __init__( args.local_batch_size, args.num_mini_batches, "`local_batch_size` must be a multiple of `num_mini_batches`" ) if args.whiten_rewards: - assert ( - args.local_mini_batch_size >= 8 - ), f"Per-rank minibatch size {args.local_mini_batch_size} is insufficient for whitening" + assert args.local_mini_batch_size >= 8, ( + f"Per-rank minibatch size {args.local_mini_batch_size} is insufficient for whitening" + ) # `per_rank_rollout_batch_size` is our `args.local_batch_size` # `per_rank_minibatch_size` is our `args.local_mini_batch_size` args.num_total_batches = math.ceil( @@ -310,9 +310,11 @@ def get_eval_dataloader(self) -> DataLoader: @contextmanager def null_ref_context(self): """Context manager for handling null reference model (that is, peft adapter manipulation).""" - with self.accelerator.unwrap_model( - self.model.policy - ).disable_adapter() if self.is_peft_model and not self.ref_adapter_name else nullcontext(): + with ( + self.accelerator.unwrap_model(self.model.policy).disable_adapter() + if self.is_peft_model and not self.ref_adapter_name + else nullcontext() + ): if self.ref_adapter_name: self.model.policy.set_adapter(self.ref_adapter_name) yield @@ -427,9 +429,8 @@ def repeat_generator(): query_response = query_responses[i : i + args.local_rollout_forward_batch_size] response = query_response[:, context_length:] logits = logitss[i : i + args.local_rollout_forward_batch_size] - all_logprob = F.log_softmax(logits, dim=-1) - logprob = torch.gather(all_logprob, 2, response.unsqueeze(-1)).squeeze(-1) - del logits, all_logprob + logprob = extract_per_token_logprobs(logits, response) + del logits torch.cuda.empty_cache() if ref_policy is None: @@ -439,9 +440,8 @@ def repeat_generator(): ref_output = forward(ref_policy, query_response, processing_class.pad_token_id) ref_logits = ref_output.logits[:, context_length - 1 : -1] ref_logits /= args.temperature + 1e-7 - ref_all_logprob = F.log_softmax(ref_logits, dim=-1) - ref_logprob = torch.gather(ref_all_logprob, 2, response.unsqueeze(-1)).squeeze(-1) - del ref_output, ref_logits, ref_all_logprob + ref_logprob = extract_per_token_logprobs(ref_logits, response) + del ref_output, ref_logits torch.cuda.empty_cache() # Response Processing 1. truncate response after the first occurrence of `stop_token_id` @@ -547,8 +547,7 @@ def repeat_generator(): output, vpred_temp = forward(model, mb_query_responses, processing_class.pad_token_id) logits = output.logits[:, context_length - 1 : -1] logits /= args.temperature + 1e-7 - new_all_logprobs = F.log_softmax(logits, dim=-1) - new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1) + new_logprobs = extract_per_token_logprobs(logits, mb_responses) new_logprobs = torch.masked_fill( new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB ) @@ -599,7 +598,7 @@ def repeat_generator(): # del everything and empty cache # fmt: off del ( - output, vpred_temp, logits, new_all_logprobs, new_logprobs, vpred, vpredclipped, + output, vpred_temp, logits, new_logprobs, vpred, vpredclipped, vf_losses1, vf_losses2, vf_loss, vf_clipfrac, logprobs_diff, ratio, pg_losses, pg_losses2, pg_loss_max, pg_loss, loss, pg_clipfrac, prob_dist, entropy, approxkl, mb_return, mb_advantage, mb_values, mb_responses, mb_query_responses, mb_logprobs, diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index 6626baa15c..1eae9bdbe5 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -45,6 +45,7 @@ from transformers.trainer import DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK from transformers.trainer_callback import CallbackHandler, ExportableState, PrinterCallback +from ..core import extract_per_token_logprobs from ..models.utils import unwrap_model_for_generation from ..trainer.utils import ( OnlineTrainerState, @@ -330,17 +331,15 @@ def repeat_generator(): query_response = query_responses[i : i + args.local_rollout_forward_batch_size] response = query_response[:, context_length:] logits = logitss[i : i + args.local_rollout_forward_batch_size] - all_logprob = F.log_softmax(logits, dim=-1) - logprob = torch.gather(all_logprob, 2, response.unsqueeze(-1)).squeeze(-1) - del logits, all_logprob + logprob = extract_per_token_logprobs(logits, response) + del logits torch.cuda.empty_cache() ref_output = forward(ref_policy, query_response, processing_class.pad_token_id) ref_logits = ref_output.logits[:, context_length - 1 : -1] ref_logits /= args.temperature + 1e-7 - ref_all_logprob = F.log_softmax(ref_logits, dim=-1) - ref_logprob = torch.gather(ref_all_logprob, 2, response.unsqueeze(-1)).squeeze(-1) - del ref_output, ref_logits, ref_all_logprob + ref_logprob = extract_per_token_logprobs(ref_logits, response) + del ref_output, ref_logits torch.cuda.empty_cache() # Response Processing 1. truncate response after the first occurrence of `stop_token_id` @@ -467,8 +466,7 @@ def repeat_generator(): logits /= args.temperature + 1e-7 # Compute new logprobs - new_all_logprobs = F.log_softmax(logits, dim=-1) - new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1) + new_logprobs = extract_per_token_logprobs(logits, mb_responses) new_logprobs = torch.masked_fill( new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB ) @@ -512,9 +510,8 @@ def repeat_generator(): # del everything and empty cache # fmt: off del ( - output, logits, new_all_logprobs, new_logprobs, - logprobs_diff, ratio, pg_losses, pg_losses2, - pg_loss, loss, pg_clipfrac, prob_dist, entropy, approxkl, + output, logits, new_logprobs, logprobs_diff, ratio, pg_losses, + pg_losses2, pg_loss, loss, pg_clipfrac, prob_dist, entropy, approxkl, mb_advantage, mb_responses, mb_query_responses, mb_logprobs, ) # fmt: on diff --git a/trl/trainer/xpo_trainer.py b/trl/trainer/xpo_trainer.py index 2d535344e7..a92ed9cdf9 100644 --- a/trl/trainer/xpo_trainer.py +++ b/trl/trainer/xpo_trainer.py @@ -34,6 +34,7 @@ from transformers.trainer_utils import EvalPrediction from transformers.training_args import OptimizerNames +from ..core import extract_per_token_logprobs from ..data_utils import is_conversational, maybe_apply_chat_template from ..models.utils import unwrap_model_for_generation from .judges import BasePairwiseJudge @@ -274,8 +275,7 @@ def _compute_logprobs(self, model, model_data, ref_data, context_length): def compute_logprobs_for_data(m, data): output = m(data["input_ids"], attention_mask=data["attention_mask"]) logits = output.logits[:, context_length - 1 : -1] - logprobs = F.log_softmax(logits, dim=-1) - token_logprobs = torch.gather(logprobs, 2, data["input_ids"][:, context_length:].unsqueeze(-1)).squeeze(-1) + token_logprobs = extract_per_token_logprobs(logits, data["input_ids"][:, context_length:]) return token_logprobs # Compute logprobs for model completions From 3b1f91f8f5bb50fd77df24af8269b0089d68fc6a Mon Sep 17 00:00:00 2001 From: Tyler Romero Date: Fri, 7 Feb 2025 10:59:07 -0800 Subject: [PATCH 02/12] rename --- tests/test_core.py | 8 +++++--- trl/core.py | 2 +- trl/trainer/bco_trainer.py | 4 ++-- trl/trainer/cpo_trainer.py | 4 ++-- trl/trainer/dpo_trainer.py | 4 ++-- trl/trainer/grpo_trainer.py | 4 ++-- trl/trainer/kto_trainer.py | 4 ++-- trl/trainer/nash_md_trainer.py | 4 ++-- trl/trainer/orpo_trainer.py | 4 ++-- trl/trainer/ppo_trainer.py | 8 ++++---- trl/trainer/rloo_trainer.py | 8 ++++---- trl/trainer/xpo_trainer.py | 4 ++-- 12 files changed, 30 insertions(+), 28 deletions(-) diff --git a/tests/test_core.py b/tests/test_core.py index 1bd9235266..2fe252ac61 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -16,7 +16,7 @@ import torch -from trl.core import extract_per_token_logprobs, masked_mean, masked_var, masked_whiten +from trl.core import masked_mean, masked_var, masked_whiten, selective_log_softmax class CoreTester(unittest.TestCase): @@ -57,8 +57,10 @@ def test_extract_per_token_logprobs(self): input_ids = torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len), device="cuda") logits = torch.randn(batch_size, seq_len, vocab_size, device="cuda", dtype=dtype) - expected_output = extract_per_token_logprobs(logits=logits, input_ids=input_ids) - actual_output = extract_per_token_logprobs(logits=logits, input_ids=input_ids) + expected_output = torch.gather(logits.log_softmax(-1), dim=-1, index=input_ids.unsqueeze(-1)).squeeze( + -1 + ) + actual_output = selective_log_softmax(logits=logits, input_ids=input_ids) if dtype in [torch.float16, torch.bfloat16]: # float16 falls back to an exact method diff --git a/trl/core.py b/trl/core.py index 44dfe111c9..c41e625165 100644 --- a/trl/core.py +++ b/trl/core.py @@ -160,7 +160,7 @@ def randn_tensor( return latents -def extract_per_token_logprobs(logits, input_ids): +def selective_log_softmax(logits, input_ids): """ A memory efficient implementation of a common log_softmax -> gather operation. Equivalent to the following naive implementation: diff --git a/trl/trainer/bco_trainer.py b/trl/trainer/bco_trainer.py index a9db3afb3b..8873f62477 100644 --- a/trl/trainer/bco_trainer.py +++ b/trl/trainer/bco_trainer.py @@ -53,7 +53,7 @@ from transformers.trainer_utils import EvalLoopOutput, has_length from transformers.utils import is_peft_available -from ..core import extract_per_token_logprobs +from ..core import selective_log_softmax from ..data_utils import maybe_apply_chat_template from ..models import PreTrainedModelWrapper, create_reference_model from .bco_config import BCOConfig @@ -1065,7 +1065,7 @@ def get_batch_logps( # dummy token; we'll ignore the losses on these tokens later labels[labels == label_pad_token_id] = 0 - per_token_logps = extract_per_token_logprobs(logits, labels) + per_token_logps = selective_log_softmax(logits, labels) if average_log_prob: return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) diff --git a/trl/trainer/cpo_trainer.py b/trl/trainer/cpo_trainer.py index 20eeeb3421..ad1c4a1b74 100644 --- a/trl/trainer/cpo_trainer.py +++ b/trl/trainer/cpo_trainer.py @@ -48,7 +48,7 @@ from transformers.trainer_utils import EvalLoopOutput from transformers.utils import is_peft_available, is_torch_fx_proxy -from ..core import extract_per_token_logprobs +from ..core import selective_log_softmax from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt from .cpo_config import CPOConfig from .utils import ( @@ -712,7 +712,7 @@ def get_batch_logps( # dummy token; we'll ignore the losses on these tokens later labels[labels == label_pad_token_id] = 0 - per_token_logps = extract_per_token_logprobs(logits, labels) + per_token_logps = selective_log_softmax(logits, labels) if average_log_prob: return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index b50e3cbb43..9b3578ef29 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -53,7 +53,7 @@ from transformers.utils import is_peft_available, is_torch_xpu_available from transformers.utils.deprecation import deprecate_kwarg -from ..core import extract_per_token_logprobs +from ..core import selective_log_softmax from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt from ..models import PreTrainedModelWrapper, create_reference_model from .callbacks import SyncRefModelCallback @@ -1214,7 +1214,7 @@ def concatenated_forward(self, model: nn.Module, batch: dict[str, Union[list, to # Compute the log probabilities of the labels labels[~loss_mask] = 0 # dummy token; we'll ignore the losses on these tokens later - per_token_logps = extract_per_token_logprobs(logits, labels) + per_token_logps = selective_log_softmax(logits, labels) per_token_logps[~loss_mask] = 0 per_token_logps = torch.roll(per_token_logps, shifts=1, dims=1) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 54540ec252..faacfabb37 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -42,7 +42,7 @@ from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled from transformers.utils import is_peft_available -from ..core import extract_per_token_logprobs +from ..core import selective_log_softmax from ..data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template from ..import_utils import is_vllm_available from ..models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation @@ -443,7 +443,7 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep) # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves. # See https://github.com/huggingface/trl/issues/2770 logits = logits[:, -logits_to_keep:] - return extract_per_token_logprobs(logits, input_ids) # compute logprobs for the input tokens + return selective_log_softmax(logits, input_ids) # compute logprobs for the input tokens def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]: device = self.accelerator.device diff --git a/trl/trainer/kto_trainer.py b/trl/trainer/kto_trainer.py index be092dfccc..361370d344 100644 --- a/trl/trainer/kto_trainer.py +++ b/trl/trainer/kto_trainer.py @@ -52,7 +52,7 @@ from transformers.trainer_utils import EvalLoopOutput, has_length from transformers.utils import is_peft_available -from ..core import extract_per_token_logprobs +from ..core import selective_log_softmax from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt, maybe_unpair_preference_dataset from ..models import PreTrainedModelWrapper, create_reference_model from .kto_config import KTOConfig @@ -1035,7 +1035,7 @@ def get_batch_logps( # dummy token; we'll ignore the losses on these tokens later labels[labels == label_pad_token_id] = 0 - per_token_logps = extract_per_token_logprobs(logits, labels) + per_token_logps = selective_log_softmax(logits, labels) if average_log_prob: return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) diff --git a/trl/trainer/nash_md_trainer.py b/trl/trainer/nash_md_trainer.py index 5441d6cc40..5219061682 100644 --- a/trl/trainer/nash_md_trainer.py +++ b/trl/trainer/nash_md_trainer.py @@ -34,7 +34,7 @@ from transformers.training_args import OptimizerNames from transformers.utils import is_apex_available -from ..core import extract_per_token_logprobs +from ..core import selective_log_softmax from ..data_utils import is_conversational, maybe_apply_chat_template from ..models.modeling_base import GeometricMixtureWrapper from ..models.utils import unwrap_model_for_generation @@ -278,7 +278,7 @@ def _compute_logprobs(self, model, model_data, context_length): def compute_logprobs_for_data(m, data): output = m(data["input_ids"], attention_mask=data["attention_mask"]) logits = output.logits[:, context_length - 1 : -1] - token_logprobs = extract_per_token_logprobs(logits, data["input_ids"][:, context_length]) + token_logprobs = selective_log_softmax(logits, data["input_ids"][:, context_length]) return token_logprobs # Compute logprobs for model completions under the model diff --git a/trl/trainer/orpo_trainer.py b/trl/trainer/orpo_trainer.py index 9a69999a39..7eed87f665 100644 --- a/trl/trainer/orpo_trainer.py +++ b/trl/trainer/orpo_trainer.py @@ -51,7 +51,7 @@ from transformers.trainer_utils import EvalLoopOutput from transformers.utils import is_peft_available, is_torch_fx_proxy -from ..core import extract_per_token_logprobs +from ..core import selective_log_softmax from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt from ..models import PreTrainedModelWrapper from .orpo_config import ORPOConfig @@ -719,7 +719,7 @@ def get_batch_logps( # dummy token; we'll ignore the losses on these tokens later labels = torch.where(labels == label_pad_token_id, 0, labels) - per_token_logps = extract_per_token_logprobs(logits, labels) + per_token_logps = selective_log_softmax(logits, labels) if average_log_prob: return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) diff --git a/trl/trainer/ppo_trainer.py b/trl/trainer/ppo_trainer.py index cb922419da..d544b38320 100644 --- a/trl/trainer/ppo_trainer.py +++ b/trl/trainer/ppo_trainer.py @@ -47,7 +47,7 @@ from transformers.trainer_callback import CallbackHandler, ExportableState, PrinterCallback from transformers.utils import is_peft_available -from ..core import extract_per_token_logprobs, masked_mean, masked_whiten +from ..core import masked_mean, masked_whiten, selective_log_softmax from ..models import create_reference_model from ..models.utils import unwrap_model_for_generation from .ppo_config import PPOConfig @@ -429,7 +429,7 @@ def repeat_generator(): query_response = query_responses[i : i + args.local_rollout_forward_batch_size] response = query_response[:, context_length:] logits = logitss[i : i + args.local_rollout_forward_batch_size] - logprob = extract_per_token_logprobs(logits, response) + logprob = selective_log_softmax(logits, response) del logits torch.cuda.empty_cache() @@ -440,7 +440,7 @@ def repeat_generator(): ref_output = forward(ref_policy, query_response, processing_class.pad_token_id) ref_logits = ref_output.logits[:, context_length - 1 : -1] ref_logits /= args.temperature + 1e-7 - ref_logprob = extract_per_token_logprobs(ref_logits, response) + ref_logprob = selective_log_softmax(ref_logits, response) del ref_output, ref_logits torch.cuda.empty_cache() @@ -547,7 +547,7 @@ def repeat_generator(): output, vpred_temp = forward(model, mb_query_responses, processing_class.pad_token_id) logits = output.logits[:, context_length - 1 : -1] logits /= args.temperature + 1e-7 - new_logprobs = extract_per_token_logprobs(logits, mb_responses) + new_logprobs = selective_log_softmax(logits, mb_responses) new_logprobs = torch.masked_fill( new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB ) diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index 1eae9bdbe5..32d2ac037b 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -45,7 +45,7 @@ from transformers.trainer import DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK from transformers.trainer_callback import CallbackHandler, ExportableState, PrinterCallback -from ..core import extract_per_token_logprobs +from ..core import selective_log_softmax from ..models.utils import unwrap_model_for_generation from ..trainer.utils import ( OnlineTrainerState, @@ -331,14 +331,14 @@ def repeat_generator(): query_response = query_responses[i : i + args.local_rollout_forward_batch_size] response = query_response[:, context_length:] logits = logitss[i : i + args.local_rollout_forward_batch_size] - logprob = extract_per_token_logprobs(logits, response) + logprob = selective_log_softmax(logits, response) del logits torch.cuda.empty_cache() ref_output = forward(ref_policy, query_response, processing_class.pad_token_id) ref_logits = ref_output.logits[:, context_length - 1 : -1] ref_logits /= args.temperature + 1e-7 - ref_logprob = extract_per_token_logprobs(ref_logits, response) + ref_logprob = selective_log_softmax(ref_logits, response) del ref_output, ref_logits torch.cuda.empty_cache() @@ -466,7 +466,7 @@ def repeat_generator(): logits /= args.temperature + 1e-7 # Compute new logprobs - new_logprobs = extract_per_token_logprobs(logits, mb_responses) + new_logprobs = selective_log_softmax(logits, mb_responses) new_logprobs = torch.masked_fill( new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB ) diff --git a/trl/trainer/xpo_trainer.py b/trl/trainer/xpo_trainer.py index a92ed9cdf9..5ab2b2c15d 100644 --- a/trl/trainer/xpo_trainer.py +++ b/trl/trainer/xpo_trainer.py @@ -34,7 +34,7 @@ from transformers.trainer_utils import EvalPrediction from transformers.training_args import OptimizerNames -from ..core import extract_per_token_logprobs +from ..core import selective_log_softmax from ..data_utils import is_conversational, maybe_apply_chat_template from ..models.utils import unwrap_model_for_generation from .judges import BasePairwiseJudge @@ -275,7 +275,7 @@ def _compute_logprobs(self, model, model_data, ref_data, context_length): def compute_logprobs_for_data(m, data): output = m(data["input_ids"], attention_mask=data["attention_mask"]) logits = output.logits[:, context_length - 1 : -1] - token_logprobs = extract_per_token_logprobs(logits, data["input_ids"][:, context_length:]) + token_logprobs = selective_log_softmax(logits, data["input_ids"][:, context_length:]) return token_logprobs # Compute logprobs for model completions From f0f5300de38598ec980cf7c46eaa0db32665590d Mon Sep 17 00:00:00 2001 From: Tyler Romero Date: Fri, 7 Feb 2025 11:25:22 -0800 Subject: [PATCH 03/12] typo fix --- trl/trainer/nash_md_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/trainer/nash_md_trainer.py b/trl/trainer/nash_md_trainer.py index 5219061682..bad028401d 100644 --- a/trl/trainer/nash_md_trainer.py +++ b/trl/trainer/nash_md_trainer.py @@ -278,7 +278,7 @@ def _compute_logprobs(self, model, model_data, context_length): def compute_logprobs_for_data(m, data): output = m(data["input_ids"], attention_mask=data["attention_mask"]) logits = output.logits[:, context_length - 1 : -1] - token_logprobs = selective_log_softmax(logits, data["input_ids"][:, context_length]) + token_logprobs = selective_log_softmax(logits, data["input_ids"][:, context_length:]) return token_logprobs # Compute logprobs for model completions under the model From 0d3924a583523341146f00c53c159fbeb1af7ebc Mon Sep 17 00:00:00 2001 From: Tyler Romero Date: Fri, 7 Feb 2025 11:27:21 -0800 Subject: [PATCH 04/12] precommit --- examples/scripts/sft_video_llm.py | 2 +- trl/trainer/ppo_trainer.py | 7 +++---- trl/trainer/rloo_trainer.py | 1 - 3 files changed, 4 insertions(+), 6 deletions(-) diff --git a/examples/scripts/sft_video_llm.py b/examples/scripts/sft_video_llm.py index aeb79433e6..3dd0995799 100644 --- a/examples/scripts/sft_video_llm.py +++ b/examples/scripts/sft_video_llm.py @@ -50,12 +50,12 @@ import requests import torch -import wandb from datasets import load_dataset from peft import LoraConfig from qwen_vl_utils import process_vision_info from transformers import AutoModelForVision2Seq, AutoProcessor, BitsAndBytesConfig, Qwen2VLProcessor +import wandb from trl import ModelConfig, ScriptArguments, SFTConfig, SFTTrainer, TrlParser, get_kbit_device_map diff --git a/trl/trainer/ppo_trainer.py b/trl/trainer/ppo_trainer.py index d544b38320..c0b118df7f 100644 --- a/trl/trainer/ppo_trainer.py +++ b/trl/trainer/ppo_trainer.py @@ -25,7 +25,6 @@ import pandas as pd import torch import torch.nn as nn -import torch.nn.functional as F from accelerate import Accelerator from accelerate.utils import broadcast, gather_object from datasets import Dataset @@ -197,9 +196,9 @@ def __init__( args.local_batch_size, args.num_mini_batches, "`local_batch_size` must be a multiple of `num_mini_batches`" ) if args.whiten_rewards: - assert args.local_mini_batch_size >= 8, ( - f"Per-rank minibatch size {args.local_mini_batch_size} is insufficient for whitening" - ) + assert ( + args.local_mini_batch_size >= 8 + ), f"Per-rank minibatch size {args.local_mini_batch_size} is insufficient for whitening" # `per_rank_rollout_batch_size` is our `args.local_batch_size` # `per_rank_minibatch_size` is our `args.local_mini_batch_size` args.num_total_batches = math.ceil( diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index 32d2ac037b..f44680e30f 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -24,7 +24,6 @@ import pandas as pd import torch import torch.nn as nn -import torch.nn.functional as F from accelerate import Accelerator from accelerate.utils import broadcast, gather_object from datasets import Dataset From fa2d67ecc13f7774db65c1f1f048abc6be213ad9 Mon Sep 17 00:00:00 2001 From: Tyler Romero Date: Fri, 7 Feb 2025 12:46:28 -0800 Subject: [PATCH 05/12] Update tests/test_core.py --- tests/test_core.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/test_core.py b/tests/test_core.py index 2fe252ac61..a7e0d586b6 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -45,8 +45,7 @@ def whiten(values: torch.Tensor) -> torch.Tensor: diffs = (whiten_unmasked - whiten_masked).sum() self.assertLess(abs(diffs.item()), 0.00001) - def test_extract_per_token_logprobs(self): - """Test extract_per_token_logprobs with different dtypes""" + def test_selective_log_softmax(self): dtypes = [torch.float64, torch.float32, torch.float16, torch.bfloat16] vocab_size = 32768 batch_size = 4 From 279c262df294e9a5f8a428966d13d80438826dd2 Mon Sep 17 00:00:00 2001 From: Tyler Romero Date: Fri, 7 Feb 2025 14:33:43 -0800 Subject: [PATCH 06/12] relocate --- tests/test_core.py | 24 +----------------------- tests/test_utils.py | 31 +++++++++++++++++++++++++++++++ trl/core.py | 24 ------------------------ trl/trainer/bco_trainer.py | 2 +- trl/trainer/cpo_trainer.py | 2 +- trl/trainer/dpo_trainer.py | 2 +- trl/trainer/grpo_trainer.py | 3 +-- trl/trainer/kto_trainer.py | 2 +- trl/trainer/nash_md_trainer.py | 2 +- trl/trainer/orpo_trainer.py | 2 +- trl/trainer/ppo_trainer.py | 3 ++- trl/trainer/rloo_trainer.py | 2 +- trl/trainer/utils.py | 25 +++++++++++++++++++++++++ trl/trainer/xpo_trainer.py | 2 +- 14 files changed, 68 insertions(+), 58 deletions(-) diff --git a/tests/test_core.py b/tests/test_core.py index a7e0d586b6..4be0810912 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -16,7 +16,7 @@ import torch -from trl.core import masked_mean, masked_var, masked_whiten, selective_log_softmax +from trl.core import masked_mean, masked_var, masked_whiten class CoreTester(unittest.TestCase): @@ -44,25 +44,3 @@ def whiten(values: torch.Tensor) -> torch.Tensor: whiten_masked = masked_whiten(self.test_input, self.test_mask)[1:3] diffs = (whiten_unmasked - whiten_masked).sum() self.assertLess(abs(diffs.item()), 0.00001) - - def test_selective_log_softmax(self): - dtypes = [torch.float64, torch.float32, torch.float16, torch.bfloat16] - vocab_size = 32768 - batch_size = 4 - seq_len = 256 - - for dtype in dtypes: - with self.subTest(dtype=dtype): - input_ids = torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len), device="cuda") - logits = torch.randn(batch_size, seq_len, vocab_size, device="cuda", dtype=dtype) - - expected_output = torch.gather(logits.log_softmax(-1), dim=-1, index=input_ids.unsqueeze(-1)).squeeze( - -1 - ) - actual_output = selective_log_softmax(logits=logits, input_ids=input_ids) - - if dtype in [torch.float16, torch.bfloat16]: - # float16 falls back to an exact method - self.assertTrue(torch.equal(actual_output, expected_output)) - else: - torch.testing.assert_close(actual_output, expected_output, rtol=1e-5, atol=1e-5) diff --git a/tests/test_utils.py b/tests/test_utils.py index 9fd7ed9e0f..de3f54119d 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -20,6 +20,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig from transformers.testing_utils import require_peft from transformers.utils import is_peft_available +from parameterized import parameterized from trl import ModelConfig from trl.trainer import compute_accuracy @@ -32,6 +33,7 @@ generate_model_card, get_peft_config, pad, + selective_log_softmax ) @@ -506,3 +508,32 @@ def test_batch_accuracy(self): ) accuracy = compute_token_accuracy(logits, labels) self.assertAlmostEqual(accuracy, 0.8) + + + +class TestSelectiveLogSoftmax(unittest.TestCase): + @parameterized.expand([ + (torch.float64,), + (torch.float32,), + (torch.float16,), + (torch.bfloat16,) + ]) + def test_selective_log_softmax(self, dtype): + """Test selective_log_softmax with logits of different dtypes""" + vocab_size = 32768 + batch_size = 4 + seq_len = 256 + + input_ids = torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len), device="cuda") + logits = torch.randn(batch_size, seq_len, vocab_size, device="cuda", dtype=dtype) + + expected_output = torch.gather(logits.log_softmax(-1), dim=-1, index=input_ids.unsqueeze(-1)).squeeze( + -1 + ) + actual_output = selective_log_softmax(logits=logits, input_ids=input_ids) + + if dtype in [torch.float16, torch.bfloat16]: + # half-precision dtypes fall back to an exact method + self.assertTrue(torch.equal(actual_output, expected_output)) + else: + torch.testing.assert_close(actual_output, expected_output, rtol=1e-5, atol=1e-5) diff --git a/trl/core.py b/trl/core.py index c41e625165..f4e8b074f6 100644 --- a/trl/core.py +++ b/trl/core.py @@ -158,27 +158,3 @@ def randn_tensor( latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device) return latents - - -def selective_log_softmax(logits, input_ids): - """ - A memory efficient implementation of a common log_softmax -> gather operation. - Equivalent to the following naive implementation: - ```python - per_token_logps = torch.gather(logits.log_softmax(-1), dim=-1, index=labels.unsqueeze(-1)).squeeze(-1) - ``` - """ - if logits.dtype in [torch.float32, torch.float64]: - selected_logits = torch.gather(logits, dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1) - # loop to reduce peak mem consumption - logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits]) - per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x) - else: - # logsumexp approach is unstable with bfloat16, fall back to slightly less efficent approach - per_token_logps = [] - for row_logits, row_labels in zip(logits, input_ids): # loop to reduce peak mem consumption - row_logps = F.log_softmax(row_logits, dim=-1) - row_per_token_logps = row_logps.gather(dim=-1, index=row_labels.unsqueeze(-1)).squeeze(-1) - per_token_logps.append(row_per_token_logps) - per_token_logps = torch.stack(per_token_logps) - return per_token_logps diff --git a/trl/trainer/bco_trainer.py b/trl/trainer/bco_trainer.py index 8873f62477..e163a2046e 100644 --- a/trl/trainer/bco_trainer.py +++ b/trl/trainer/bco_trainer.py @@ -53,7 +53,6 @@ from transformers.trainer_utils import EvalLoopOutput, has_length from transformers.utils import is_peft_available -from ..core import selective_log_softmax from ..data_utils import maybe_apply_chat_template from ..models import PreTrainedModelWrapper, create_reference_model from .bco_config import BCOConfig @@ -66,6 +65,7 @@ log_table_to_comet_experiment, pad_to_length, peft_module_casting_to_bf16, + selective_log_softmax ) diff --git a/trl/trainer/cpo_trainer.py b/trl/trainer/cpo_trainer.py index ad1c4a1b74..b3f0600828 100644 --- a/trl/trainer/cpo_trainer.py +++ b/trl/trainer/cpo_trainer.py @@ -48,7 +48,6 @@ from transformers.trainer_utils import EvalLoopOutput from transformers.utils import is_peft_available, is_torch_fx_proxy -from ..core import selective_log_softmax from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt from .cpo_config import CPOConfig from .utils import ( @@ -61,6 +60,7 @@ log_table_to_comet_experiment, pad_to_length, peft_module_casting_to_bf16, + selective_log_softmax ) diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index 9b3578ef29..42be2a7610 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -53,7 +53,6 @@ from transformers.utils import is_peft_available, is_torch_xpu_available from transformers.utils.deprecation import deprecate_kwarg -from ..core import selective_log_softmax from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt from ..models import PreTrainedModelWrapper, create_reference_model from .callbacks import SyncRefModelCallback @@ -70,6 +69,7 @@ pad, pad_to_length, peft_module_casting_to_bf16, + selective_log_softmax ) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index faacfabb37..bab86f7bcc 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -42,13 +42,12 @@ from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled from transformers.utils import is_peft_available -from ..core import selective_log_softmax from ..data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template from ..import_utils import is_vllm_available from ..models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation from .callbacks import SyncRefModelCallback from .grpo_config import GRPOConfig -from .utils import generate_model_card, get_comet_experiment_url, pad +from .utils import generate_model_card, get_comet_experiment_url, pad, selective_log_softmax if is_peft_available(): diff --git a/trl/trainer/kto_trainer.py b/trl/trainer/kto_trainer.py index 361370d344..c2bb1ad413 100644 --- a/trl/trainer/kto_trainer.py +++ b/trl/trainer/kto_trainer.py @@ -52,7 +52,6 @@ from transformers.trainer_utils import EvalLoopOutput, has_length from transformers.utils import is_peft_available -from ..core import selective_log_softmax from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt, maybe_unpair_preference_dataset from ..models import PreTrainedModelWrapper, create_reference_model from .kto_config import KTOConfig @@ -64,6 +63,7 @@ log_table_to_comet_experiment, pad_to_length, peft_module_casting_to_bf16, + selective_log_softmax ) diff --git a/trl/trainer/nash_md_trainer.py b/trl/trainer/nash_md_trainer.py index bad028401d..c320f23384 100644 --- a/trl/trainer/nash_md_trainer.py +++ b/trl/trainer/nash_md_trainer.py @@ -34,7 +34,6 @@ from transformers.training_args import OptimizerNames from transformers.utils import is_apex_available -from ..core import selective_log_softmax from ..data_utils import is_conversational, maybe_apply_chat_template from ..models.modeling_base import GeometricMixtureWrapper from ..models.utils import unwrap_model_for_generation @@ -48,6 +47,7 @@ get_comet_experiment_url, get_reward, truncate_right, + selective_log_softmax ) diff --git a/trl/trainer/orpo_trainer.py b/trl/trainer/orpo_trainer.py index 7eed87f665..0bdf51b282 100644 --- a/trl/trainer/orpo_trainer.py +++ b/trl/trainer/orpo_trainer.py @@ -51,7 +51,6 @@ from transformers.trainer_utils import EvalLoopOutput from transformers.utils import is_peft_available, is_torch_fx_proxy -from ..core import selective_log_softmax from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt from ..models import PreTrainedModelWrapper from .orpo_config import ORPOConfig @@ -65,6 +64,7 @@ log_table_to_comet_experiment, pad_to_length, peft_module_casting_to_bf16, + selective_log_softmax ) diff --git a/trl/trainer/ppo_trainer.py b/trl/trainer/ppo_trainer.py index c0b118df7f..87abf84340 100644 --- a/trl/trainer/ppo_trainer.py +++ b/trl/trainer/ppo_trainer.py @@ -46,7 +46,7 @@ from transformers.trainer_callback import CallbackHandler, ExportableState, PrinterCallback from transformers.utils import is_peft_available -from ..core import masked_mean, masked_whiten, selective_log_softmax +from ..core import masked_mean, masked_whiten from ..models import create_reference_model from ..models.utils import unwrap_model_for_generation from .ppo_config import PPOConfig @@ -65,6 +65,7 @@ prepare_deepspeed, print_rich_table, truncate_response, + selective_log_softmax ) diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index f44680e30f..82171ce31f 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -44,7 +44,6 @@ from transformers.trainer import DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK from transformers.trainer_callback import CallbackHandler, ExportableState, PrinterCallback -from ..core import selective_log_softmax from ..models.utils import unwrap_model_for_generation from ..trainer.utils import ( OnlineTrainerState, @@ -57,6 +56,7 @@ prepare_deepspeed, print_rich_table, truncate_response, + selective_log_softmax ) from .rloo_config import RLOOConfig from .utils import generate_model_card, get_comet_experiment_url, log_table_to_comet_experiment diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index 029e5639ab..947930a8a1 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -1668,3 +1668,28 @@ def compute_token_accuracy(logits: torch.Tensor, labels: torch.Tensor, ignore_in accuracy = correct_tokens.item() / total_tokens.item() if total_tokens > 0 else 0.0 return accuracy + + +def selective_log_softmax(logits, input_ids): + """ + A memory efficient implementation of a common log_softmax -> gather operation. + + Equivalent to the following naive implementation: + ```python + per_token_logps = torch.gather(logits.log_softmax(-1), dim=-1, index=labels.unsqueeze(-1)).squeeze(-1) + ``` + """ + if logits.dtype in [torch.float32, torch.float64]: + selected_logits = torch.gather(logits, dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1) + # loop to reduce peak mem consumption + logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits]) + per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x) + else: + # logsumexp approach is unstable with bfloat16, fall back to slightly less efficent approach + per_token_logps = [] + for row_logits, row_labels in zip(logits, input_ids): # loop to reduce peak mem consumption + row_logps = F.log_softmax(row_logits, dim=-1) + row_per_token_logps = row_logps.gather(dim=-1, index=row_labels.unsqueeze(-1)).squeeze(-1) + per_token_logps.append(row_per_token_logps) + per_token_logps = torch.stack(per_token_logps) + return per_token_logps \ No newline at end of file diff --git a/trl/trainer/xpo_trainer.py b/trl/trainer/xpo_trainer.py index 5ab2b2c15d..e708bfefa0 100644 --- a/trl/trainer/xpo_trainer.py +++ b/trl/trainer/xpo_trainer.py @@ -34,7 +34,6 @@ from transformers.trainer_utils import EvalPrediction from transformers.training_args import OptimizerNames -from ..core import selective_log_softmax from ..data_utils import is_conversational, maybe_apply_chat_template from ..models.utils import unwrap_model_for_generation from .judges import BasePairwiseJudge @@ -46,6 +45,7 @@ get_comet_experiment_url, get_reward, truncate_right, + selective_log_softmax ) from .xpo_config import XPOConfig From 786b2026b0051c4025a61d837086fbd0a05b8f61 Mon Sep 17 00:00:00 2001 From: Tyler Romero Date: Fri, 7 Feb 2025 14:37:12 -0800 Subject: [PATCH 07/12] precommit --- tests/test_utils.py | 16 ++++------------ trl/core.py | 1 - trl/trainer/bco_trainer.py | 2 +- trl/trainer/cpo_trainer.py | 2 +- trl/trainer/dpo_trainer.py | 2 +- trl/trainer/kto_trainer.py | 2 +- trl/trainer/nash_md_trainer.py | 2 +- trl/trainer/orpo_trainer.py | 2 +- trl/trainer/ppo_trainer.py | 2 +- trl/trainer/rloo_trainer.py | 2 +- trl/trainer/utils.py | 2 +- trl/trainer/xpo_trainer.py | 2 +- 12 files changed, 14 insertions(+), 23 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index de3f54119d..f560e5afe0 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -17,10 +17,10 @@ import numpy as np import torch from datasets import load_dataset +from parameterized import parameterized from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig from transformers.testing_utils import require_peft from transformers.utils import is_peft_available -from parameterized import parameterized from trl import ModelConfig from trl.trainer import compute_accuracy @@ -33,7 +33,7 @@ generate_model_card, get_peft_config, pad, - selective_log_softmax + selective_log_softmax, ) @@ -510,14 +510,8 @@ def test_batch_accuracy(self): self.assertAlmostEqual(accuracy, 0.8) - class TestSelectiveLogSoftmax(unittest.TestCase): - @parameterized.expand([ - (torch.float64,), - (torch.float32,), - (torch.float16,), - (torch.bfloat16,) - ]) + @parameterized.expand([(torch.float64,), (torch.float32,), (torch.float16,), (torch.bfloat16,)]) def test_selective_log_softmax(self, dtype): """Test selective_log_softmax with logits of different dtypes""" vocab_size = 32768 @@ -527,9 +521,7 @@ def test_selective_log_softmax(self, dtype): input_ids = torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len), device="cuda") logits = torch.randn(batch_size, seq_len, vocab_size, device="cuda", dtype=dtype) - expected_output = torch.gather(logits.log_softmax(-1), dim=-1, index=input_ids.unsqueeze(-1)).squeeze( - -1 - ) + expected_output = torch.gather(logits.log_softmax(-1), dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1) actual_output = selective_log_softmax(logits=logits, input_ids=input_ids) if dtype in [torch.float16, torch.bfloat16]: diff --git a/trl/core.py b/trl/core.py index f4e8b074f6..b12d3a79d5 100644 --- a/trl/core.py +++ b/trl/core.py @@ -20,7 +20,6 @@ import numpy as np import torch -import torch.nn.functional as F from transformers import is_torch_npu_available, is_torch_xpu_available diff --git a/trl/trainer/bco_trainer.py b/trl/trainer/bco_trainer.py index e163a2046e..db4fa156e4 100644 --- a/trl/trainer/bco_trainer.py +++ b/trl/trainer/bco_trainer.py @@ -65,7 +65,7 @@ log_table_to_comet_experiment, pad_to_length, peft_module_casting_to_bf16, - selective_log_softmax + selective_log_softmax, ) diff --git a/trl/trainer/cpo_trainer.py b/trl/trainer/cpo_trainer.py index b3f0600828..174cb4f255 100644 --- a/trl/trainer/cpo_trainer.py +++ b/trl/trainer/cpo_trainer.py @@ -60,7 +60,7 @@ log_table_to_comet_experiment, pad_to_length, peft_module_casting_to_bf16, - selective_log_softmax + selective_log_softmax, ) diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index 42be2a7610..a16edb6f37 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -69,7 +69,7 @@ pad, pad_to_length, peft_module_casting_to_bf16, - selective_log_softmax + selective_log_softmax, ) diff --git a/trl/trainer/kto_trainer.py b/trl/trainer/kto_trainer.py index c2bb1ad413..0c92ad70b8 100644 --- a/trl/trainer/kto_trainer.py +++ b/trl/trainer/kto_trainer.py @@ -63,7 +63,7 @@ log_table_to_comet_experiment, pad_to_length, peft_module_casting_to_bf16, - selective_log_softmax + selective_log_softmax, ) diff --git a/trl/trainer/nash_md_trainer.py b/trl/trainer/nash_md_trainer.py index c320f23384..5d2a8e830d 100644 --- a/trl/trainer/nash_md_trainer.py +++ b/trl/trainer/nash_md_trainer.py @@ -46,8 +46,8 @@ generate_model_card, get_comet_experiment_url, get_reward, + selective_log_softmax, truncate_right, - selective_log_softmax ) diff --git a/trl/trainer/orpo_trainer.py b/trl/trainer/orpo_trainer.py index 0bdf51b282..72436d321d 100644 --- a/trl/trainer/orpo_trainer.py +++ b/trl/trainer/orpo_trainer.py @@ -64,7 +64,7 @@ log_table_to_comet_experiment, pad_to_length, peft_module_casting_to_bf16, - selective_log_softmax + selective_log_softmax, ) diff --git a/trl/trainer/ppo_trainer.py b/trl/trainer/ppo_trainer.py index 87abf84340..cf7f6768a6 100644 --- a/trl/trainer/ppo_trainer.py +++ b/trl/trainer/ppo_trainer.py @@ -64,8 +64,8 @@ peft_module_casting_to_bf16, prepare_deepspeed, print_rich_table, + selective_log_softmax, truncate_response, - selective_log_softmax ) diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index 82171ce31f..344253c2b8 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -55,8 +55,8 @@ get_reward, prepare_deepspeed, print_rich_table, + selective_log_softmax, truncate_response, - selective_log_softmax ) from .rloo_config import RLOOConfig from .utils import generate_model_card, get_comet_experiment_url, log_table_to_comet_experiment diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index 947930a8a1..45dd53842c 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -1692,4 +1692,4 @@ def selective_log_softmax(logits, input_ids): row_per_token_logps = row_logps.gather(dim=-1, index=row_labels.unsqueeze(-1)).squeeze(-1) per_token_logps.append(row_per_token_logps) per_token_logps = torch.stack(per_token_logps) - return per_token_logps \ No newline at end of file + return per_token_logps diff --git a/trl/trainer/xpo_trainer.py b/trl/trainer/xpo_trainer.py index e708bfefa0..6c7579ae8a 100644 --- a/trl/trainer/xpo_trainer.py +++ b/trl/trainer/xpo_trainer.py @@ -44,8 +44,8 @@ generate_model_card, get_comet_experiment_url, get_reward, + selective_log_softmax, truncate_right, - selective_log_softmax ) from .xpo_config import XPOConfig From 69abfd5bd0e1de81acabf7e648ec755b89daea7d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 7 Feb 2025 22:49:44 +0000 Subject: [PATCH 08/12] style --- examples/scripts/sft_video_llm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/scripts/sft_video_llm.py b/examples/scripts/sft_video_llm.py index 3dd0995799..aeb79433e6 100644 --- a/examples/scripts/sft_video_llm.py +++ b/examples/scripts/sft_video_llm.py @@ -50,12 +50,12 @@ import requests import torch +import wandb from datasets import load_dataset from peft import LoraConfig from qwen_vl_utils import process_vision_info from transformers import AutoModelForVision2Seq, AutoProcessor, BitsAndBytesConfig, Qwen2VLProcessor -import wandb from trl import ModelConfig, ScriptArguments, SFTConfig, SFTTrainer, TrlParser, get_kbit_device_map From 788e82c526259e39224f68d83862d16556798aea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 7 Feb 2025 22:50:09 +0000 Subject: [PATCH 09/12] smaller values for test, and run on cpu --- tests/test_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index f560e5afe0..9b39615c1b 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -514,12 +514,12 @@ class TestSelectiveLogSoftmax(unittest.TestCase): @parameterized.expand([(torch.float64,), (torch.float32,), (torch.float16,), (torch.bfloat16,)]) def test_selective_log_softmax(self, dtype): """Test selective_log_softmax with logits of different dtypes""" - vocab_size = 32768 + vocab_size = 1024 batch_size = 4 - seq_len = 256 + seq_len = 32 - input_ids = torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len), device="cuda") - logits = torch.randn(batch_size, seq_len, vocab_size, device="cuda", dtype=dtype) + input_ids = torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len)) + logits = torch.randn(batch_size, seq_len, vocab_size, dtype=dtype) expected_output = torch.gather(logits.log_softmax(-1), dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1) actual_output = selective_log_softmax(logits=logits, input_ids=input_ids) From 9e0d40aa2427ab5aaed6ea4be3e961b1ce95ccb5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 7 Feb 2025 22:59:21 +0000 Subject: [PATCH 10/12] nit doc improvements --- trl/trainer/utils.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index 45dd53842c..5153600670 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -26,6 +26,7 @@ import numpy as np import pandas as pd import torch +import torch.nn.functional as F import torch.utils.data from accelerate import Accelerator, PartialState from accelerate.state import AcceleratorState @@ -1670,24 +1671,34 @@ def compute_token_accuracy(logits: torch.Tensor, labels: torch.Tensor, ignore_in return accuracy -def selective_log_softmax(logits, input_ids): +def selective_log_softmax(logits, index): """ - A memory efficient implementation of a common log_softmax -> gather operation. + A memory-efficient implementation of the common `log_softmax -> gather` operation. - Equivalent to the following naive implementation: + This function is equivalent to the following naive implementation: ```python - per_token_logps = torch.gather(logits.log_softmax(-1), dim=-1, index=labels.unsqueeze(-1)).squeeze(-1) + logps = torch.gather(logits.log_softmax(-1), dim=-1, index=index.unsqueeze(-1)).squeeze(-1) ``` + + Args: + logits (`torch.Tensor`): + Logits tensor of shape `(..., num_classes)`. + index (`torch.Tensor`): + Index tensor of shape `(...)`, specifying the positions to gather from the log-softmax output. + + Returns: + `torch.Tensor`: + Gathered log probabilities with the same shape as `index`. """ if logits.dtype in [torch.float32, torch.float64]: - selected_logits = torch.gather(logits, dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1) + selected_logits = torch.gather(logits, dim=-1, index=index.unsqueeze(-1)).squeeze(-1) # loop to reduce peak mem consumption logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits]) per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x) else: # logsumexp approach is unstable with bfloat16, fall back to slightly less efficent approach per_token_logps = [] - for row_logits, row_labels in zip(logits, input_ids): # loop to reduce peak mem consumption + for row_logits, row_labels in zip(logits, index): # loop to reduce peak mem consumption row_logps = F.log_softmax(row_logits, dim=-1) row_per_token_logps = row_logps.gather(dim=-1, index=row_labels.unsqueeze(-1)).squeeze(-1) per_token_logps.append(row_per_token_logps) From f37fa7d048596400ff5ef42dfb07de0ab58fb703 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 7 Feb 2025 23:01:01 +0000 Subject: [PATCH 11/12] style --- trl/trainer/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index 5153600670..ea603b9637 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -1687,7 +1687,7 @@ def selective_log_softmax(logits, index): Index tensor of shape `(...)`, specifying the positions to gather from the log-softmax output. Returns: - `torch.Tensor`: + `torch.Tensor`: Gathered log probabilities with the same shape as `index`. """ if logits.dtype in [torch.float32, torch.float64]: From 0714d8132951a2107f68317c77294fdb695c4675 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 7 Feb 2025 23:25:08 +0000 Subject: [PATCH 12/12] fix test --- tests/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index 9b39615c1b..64ab68bf74 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -522,7 +522,7 @@ def test_selective_log_softmax(self, dtype): logits = torch.randn(batch_size, seq_len, vocab_size, dtype=dtype) expected_output = torch.gather(logits.log_softmax(-1), dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1) - actual_output = selective_log_softmax(logits=logits, input_ids=input_ids) + actual_output = selective_log_softmax(logits, input_ids) if dtype in [torch.float16, torch.bfloat16]: # half-precision dtypes fall back to an exact method