diff --git a/tests/test_utils.py b/tests/test_utils.py
index 9fd7ed9e0f..64ab68bf74 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -17,6 +17,7 @@
 import numpy as np
 import torch
 from datasets import load_dataset
+from parameterized import parameterized
 from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
 from transformers.testing_utils import require_peft
 from transformers.utils import is_peft_available
@@ -32,6 +33,7 @@
     generate_model_card,
     get_peft_config,
     pad,
+    selective_log_softmax,
 )
 
 
@@ -506,3 +508,24 @@ def test_batch_accuracy(self):
         )
         accuracy = compute_token_accuracy(logits, labels)
         self.assertAlmostEqual(accuracy, 0.8)
+
+
+class TestSelectiveLogSoftmax(unittest.TestCase):
+    @parameterized.expand([(torch.float64,), (torch.float32,), (torch.float16,), (torch.bfloat16,)])
+    def test_selective_log_softmax(self, dtype):
+        """Test selective_log_softmax with logits of different dtypes"""
+        vocab_size = 1024
+        batch_size = 4
+        seq_len = 32
+
+        input_ids = torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len))
+        logits = torch.randn(batch_size, seq_len, vocab_size, dtype=dtype)
+
+        expected_output = torch.gather(logits.log_softmax(-1), dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)
+        actual_output = selective_log_softmax(logits, input_ids)
+
+        if dtype in [torch.float16, torch.bfloat16]:
+            # half-precision dtypes fall back to an exact method
+            self.assertTrue(torch.equal(actual_output, expected_output))
+        else:
+            torch.testing.assert_close(actual_output, expected_output, rtol=1e-5, atol=1e-5)
diff --git a/trl/trainer/bco_trainer.py b/trl/trainer/bco_trainer.py
index 18df4a7210..db4fa156e4 100644
--- a/trl/trainer/bco_trainer.py
+++ b/trl/trainer/bco_trainer.py
@@ -65,6 +65,7 @@
     log_table_to_comet_experiment,
     pad_to_length,
     peft_module_casting_to_bf16,
+    selective_log_softmax,
 )
 
 
@@ -897,9 +898,11 @@ def _load_optimizer_and_scheduler(self, checkpoint):
     @contextmanager
     def null_ref_context(self):
         """Context manager for handling null reference model (that is, peft adapter manipulation)."""
-        with self.accelerator.unwrap_model(
-            self.model
-        ).disable_adapter() if self.is_peft_model and not self.ref_adapter_name else nullcontext():
+        with (
+            self.accelerator.unwrap_model(self.model).disable_adapter()
+            if self.is_peft_model and not self.ref_adapter_name
+            else nullcontext()
+        ):
             if self.ref_adapter_name:
                 self.model.set_adapter(self.ref_adapter_name)
             yield
@@ -1062,7 +1065,7 @@ def get_batch_logps(
         # dummy token; we'll ignore the losses on these tokens later
         labels[labels == label_pad_token_id] = 0
 
-        per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
+        per_token_logps = selective_log_softmax(logits, labels)
 
         if average_log_prob:
             return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
diff --git a/trl/trainer/cpo_trainer.py b/trl/trainer/cpo_trainer.py
index 644a2d5353..174cb4f255 100644
--- a/trl/trainer/cpo_trainer.py
+++ b/trl/trainer/cpo_trainer.py
@@ -60,6 +60,7 @@
     log_table_to_comet_experiment,
     pad_to_length,
     peft_module_casting_to_bf16,
+    selective_log_softmax,
 )
 
 
@@ -711,7 +712,7 @@ def get_batch_logps(
         # dummy token; we'll ignore the losses on these tokens later
         labels[labels == label_pad_token_id] = 0
 
-        per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
+        per_token_logps = selective_log_softmax(logits, labels)
 
         if average_log_prob:
             return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py
index cd3c3b4dca..a16edb6f37 100644
--- a/trl/trainer/dpo_trainer.py
+++ b/trl/trainer/dpo_trainer.py
@@ -69,6 +69,7 @@
     pad,
     pad_to_length,
     peft_module_casting_to_bf16,
+    selective_log_softmax,
 )
 
 
@@ -822,9 +823,11 @@ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoa
     @contextmanager
     def null_ref_context(self):
         """Context manager for handling null reference model (that is, peft adapter manipulation)."""
-        with self.accelerator.unwrap_model(
-            self.model
-        ).disable_adapter() if self.is_peft_model and not self.ref_adapter_name else nullcontext():
+        with (
+            self.accelerator.unwrap_model(self.model).disable_adapter()
+            if self.is_peft_model and not self.ref_adapter_name
+            else nullcontext()
+        ):
             if self.ref_adapter_name:
                 self.model.set_adapter(self.ref_adapter_name)
             yield
@@ -1211,7 +1214,7 @@ def concatenated_forward(self, model: nn.Module, batch: dict[str, Union[list, to
 
         # Compute the log probabilities of the labels
         labels[~loss_mask] = 0  # dummy token; we'll ignore the losses on these tokens later
-        per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
+        per_token_logps = selective_log_softmax(logits, labels)
         per_token_logps[~loss_mask] = 0
         per_token_logps = torch.roll(per_token_logps, shifts=1, dims=1)
 
diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py
index 4cbba2fcaa..bab86f7bcc 100644
--- a/trl/trainer/grpo_trainer.py
+++ b/trl/trainer/grpo_trainer.py
@@ -47,7 +47,7 @@
 from ..models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation
 from .callbacks import SyncRefModelCallback
 from .grpo_config import GRPOConfig
-from .utils import generate_model_card, get_comet_experiment_url, pad
+from .utils import generate_model_card, get_comet_experiment_url, pad, selective_log_softmax
 
 
 if is_peft_available():
@@ -442,12 +442,7 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep)
         # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
         # See https://github.com/huggingface/trl/issues/2770
         logits = logits[:, -logits_to_keep:]
-
-        # Compute the log probabilities for the input tokens.
-        token_logits = logits.gather(dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)
-        logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])  # loop to reduce memory peak
-        token_log_probs = token_logits - logsumexp_values  # log_softmax = logits - log(sum(exp(logits)))
-        return token_log_probs
+        return selective_log_softmax(logits, input_ids)  # compute logprobs for the input tokens
 
     def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
         device = self.accelerator.device
diff --git a/trl/trainer/kto_trainer.py b/trl/trainer/kto_trainer.py
index c45a88d554..0c92ad70b8 100644
--- a/trl/trainer/kto_trainer.py
+++ b/trl/trainer/kto_trainer.py
@@ -63,6 +63,7 @@
     log_table_to_comet_experiment,
     pad_to_length,
     peft_module_casting_to_bf16,
+    selective_log_softmax,
 )
 
 
@@ -812,9 +813,11 @@ def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
     @contextmanager
     def null_ref_context(self):
         """Context manager for handling null reference model (that is, peft adapter manipulation)."""
-        with self.accelerator.unwrap_model(
-            self.model
-        ).disable_adapter() if self.is_peft_model and not self.ref_adapter_name else nullcontext():
+        with (
+            self.accelerator.unwrap_model(self.model).disable_adapter()
+            if self.is_peft_model and not self.ref_adapter_name
+            else nullcontext()
+        ):
             if self.ref_adapter_name:
                 self.model.set_adapter(self.ref_adapter_name)
             yield
@@ -1032,7 +1035,7 @@ def get_batch_logps(
         # dummy token; we'll ignore the losses on these tokens later
         labels[labels == label_pad_token_id] = 0
 
-        per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
+        per_token_logps = selective_log_softmax(logits, labels)
 
         if average_log_prob:
             return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
diff --git a/trl/trainer/nash_md_trainer.py b/trl/trainer/nash_md_trainer.py
index cbe218066e..5d2a8e830d 100644
--- a/trl/trainer/nash_md_trainer.py
+++ b/trl/trainer/nash_md_trainer.py
@@ -46,6 +46,7 @@
     generate_model_card,
     get_comet_experiment_url,
     get_reward,
+    selective_log_softmax,
     truncate_right,
 )
 
@@ -277,8 +278,7 @@ def _compute_logprobs(self, model, model_data, context_length):
         def compute_logprobs_for_data(m, data):
             output = m(data["input_ids"], attention_mask=data["attention_mask"])
             logits = output.logits[:, context_length - 1 : -1]
-            logprobs = F.log_softmax(logits, dim=-1)
-            token_logprobs = torch.gather(logprobs, 2, data["input_ids"][:, context_length:].unsqueeze(-1)).squeeze(-1)
+            token_logprobs = selective_log_softmax(logits, data["input_ids"][:, context_length:])
             return token_logprobs
 
         # Compute logprobs for model completions under the model
diff --git a/trl/trainer/orpo_trainer.py b/trl/trainer/orpo_trainer.py
index 832927d05c..72436d321d 100644
--- a/trl/trainer/orpo_trainer.py
+++ b/trl/trainer/orpo_trainer.py
@@ -64,6 +64,7 @@
     log_table_to_comet_experiment,
     pad_to_length,
     peft_module_casting_to_bf16,
+    selective_log_softmax,
 )
 
 
@@ -718,7 +719,7 @@ def get_batch_logps(
         # dummy token; we'll ignore the losses on these tokens later
         labels = torch.where(labels == label_pad_token_id, 0, labels)
 
-        per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
+        per_token_logps = selective_log_softmax(logits, labels)
 
        if average_log_prob:
             return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
diff --git a/trl/trainer/ppo_trainer.py b/trl/trainer/ppo_trainer.py
index 27cbdd016c..cf7f6768a6 100644
--- a/trl/trainer/ppo_trainer.py
+++ b/trl/trainer/ppo_trainer.py
@@ -25,7 +25,6 @@
 import pandas as pd
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 from accelerate import Accelerator
 from accelerate.utils import broadcast, gather_object
 from datasets import Dataset
@@ -65,6 +64,7 @@
     peft_module_casting_to_bf16,
     prepare_deepspeed,
     print_rich_table,
+    selective_log_softmax,
     truncate_response,
 )
 
@@ -310,9 +310,11 @@ def get_eval_dataloader(self) -> DataLoader:
     @contextmanager
     def null_ref_context(self):
         """Context manager for handling null reference model (that is, peft adapter manipulation)."""
-        with self.accelerator.unwrap_model(
-            self.model.policy
-        ).disable_adapter() if self.is_peft_model and not self.ref_adapter_name else nullcontext():
+        with (
+            self.accelerator.unwrap_model(self.model.policy).disable_adapter()
+            if self.is_peft_model and not self.ref_adapter_name
+            else nullcontext()
+        ):
             if self.ref_adapter_name:
                 self.model.policy.set_adapter(self.ref_adapter_name)
             yield
@@ -427,9 +429,8 @@ def repeat_generator():
                     query_response = query_responses[i : i + args.local_rollout_forward_batch_size]
                     response = query_response[:, context_length:]
                     logits = logitss[i : i + args.local_rollout_forward_batch_size]
-                    all_logprob = F.log_softmax(logits, dim=-1)
-                    logprob = torch.gather(all_logprob, 2, response.unsqueeze(-1)).squeeze(-1)
-                    del logits, all_logprob
+                    logprob = selective_log_softmax(logits, response)
+                    del logits
                     torch.cuda.empty_cache()
 
                     if ref_policy is None:
@@ -439,9 +440,8 @@ def repeat_generator():
                         ref_output = forward(ref_policy, query_response, processing_class.pad_token_id)
                     ref_logits = ref_output.logits[:, context_length - 1 : -1]
                     ref_logits /= args.temperature + 1e-7
-                    ref_all_logprob = F.log_softmax(ref_logits, dim=-1)
-                    ref_logprob = torch.gather(ref_all_logprob, 2, response.unsqueeze(-1)).squeeze(-1)
-                    del ref_output, ref_logits, ref_all_logprob
+                    ref_logprob = selective_log_softmax(ref_logits, response)
+                    del ref_output, ref_logits
                     torch.cuda.empty_cache()
 
                     # Response Processing 1. truncate response after the first occurrence of `stop_token_id`
@@ -547,8 +547,7 @@ def repeat_generator():
                             output, vpred_temp = forward(model, mb_query_responses, processing_class.pad_token_id)
                             logits = output.logits[:, context_length - 1 : -1]
                             logits /= args.temperature + 1e-7
-                            new_all_logprobs = F.log_softmax(logits, dim=-1)
-                            new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1)
+                            new_logprobs = selective_log_softmax(logits, mb_responses)
                             new_logprobs = torch.masked_fill(
                                 new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB
                             )
@@ -599,7 +598,7 @@ def repeat_generator():
                             # del everything and empty cache
                             # fmt: off
                             del (
-                                output, vpred_temp, logits, new_all_logprobs, new_logprobs, vpred, vpredclipped,
+                                output, vpred_temp, logits, new_logprobs, vpred, vpredclipped,
                                 vf_losses1, vf_losses2, vf_loss, vf_clipfrac, logprobs_diff, ratio, pg_losses, pg_losses2, pg_loss_max,
                                 pg_loss, loss, pg_clipfrac, prob_dist, entropy, approxkl, mb_return,
                                 mb_advantage, mb_values, mb_responses, mb_query_responses, mb_logprobs,
diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py
index 6626baa15c..344253c2b8 100644
--- a/trl/trainer/rloo_trainer.py
+++ b/trl/trainer/rloo_trainer.py
@@ -24,7 +24,6 @@
 import pandas as pd
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 from accelerate import Accelerator
 from accelerate.utils import broadcast, gather_object
 from datasets import Dataset
@@ -56,6 +55,7 @@
     get_reward,
     prepare_deepspeed,
     print_rich_table,
+    selective_log_softmax,
     truncate_response,
 )
 from .rloo_config import RLOOConfig
@@ -330,17 +330,15 @@ def repeat_generator():
                     query_response = query_responses[i : i + args.local_rollout_forward_batch_size]
                     response = query_response[:, context_length:]
                     logits = logitss[i : i + args.local_rollout_forward_batch_size]
-                    all_logprob = F.log_softmax(logits, dim=-1)
-                    logprob = torch.gather(all_logprob, 2, response.unsqueeze(-1)).squeeze(-1)
-                    del logits, all_logprob
+                    logprob = selective_log_softmax(logits, response)
+                    del logits
                     torch.cuda.empty_cache()
 
                     ref_output = forward(ref_policy, query_response, processing_class.pad_token_id)
                     ref_logits = ref_output.logits[:, context_length - 1 : -1]
                     ref_logits /= args.temperature + 1e-7
-                    ref_all_logprob = F.log_softmax(ref_logits, dim=-1)
-                    ref_logprob = torch.gather(ref_all_logprob, 2, response.unsqueeze(-1)).squeeze(-1)
-                    del ref_output, ref_logits, ref_all_logprob
+                    ref_logprob = selective_log_softmax(ref_logits, response)
+                    del ref_output, ref_logits
                     torch.cuda.empty_cache()
 
                     # Response Processing 1. truncate response after the first occurrence of `stop_token_id`
@@ -467,8 +465,7 @@ def repeat_generator():
                             logits /= args.temperature + 1e-7
 
                             # Compute new logprobs
-                            new_all_logprobs = F.log_softmax(logits, dim=-1)
-                            new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1)
+                            new_logprobs = selective_log_softmax(logits, mb_responses)
                             new_logprobs = torch.masked_fill(
                                 new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB
                             )
@@ -512,9 +509,8 @@ def repeat_generator():
                             # del everything and empty cache
                             # fmt: off
                             del (
-                                output, logits, new_all_logprobs, new_logprobs,
-                                logprobs_diff, ratio, pg_losses, pg_losses2,
-                                pg_loss, loss, pg_clipfrac, prob_dist, entropy, approxkl,
+                                output, logits, new_logprobs, logprobs_diff, ratio, pg_losses,
+                                pg_losses2, pg_loss, loss, pg_clipfrac, prob_dist, entropy, approxkl,
                                 mb_advantage, mb_responses, mb_query_responses, mb_logprobs,
                             )
                             # fmt: on
diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py
index 029e5639ab..ea603b9637 100644
--- a/trl/trainer/utils.py
+++ b/trl/trainer/utils.py
@@ -26,6 +26,7 @@
 import numpy as np
 import pandas as pd
 import torch
+import torch.nn.functional as F
 import torch.utils.data
 from accelerate import Accelerator, PartialState
 from accelerate.state import AcceleratorState
@@ -1668,3 +1669,38 @@ def compute_token_accuracy(logits: torch.Tensor, labels: torch.Tensor, ignore_in
     accuracy = correct_tokens.item() / total_tokens.item() if total_tokens > 0 else 0.0
 
     return accuracy
+
+
+def selective_log_softmax(logits, index):
+    """
+    A memory-efficient implementation of the common `log_softmax -> gather` operation.
+
+    This function is equivalent to the following naive implementation:
+    ```python
+    logps = torch.gather(logits.log_softmax(-1), dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
+    ```
+
+    Args:
+        logits (`torch.Tensor`):
+            Logits tensor of shape `(..., num_classes)`.
+        index (`torch.Tensor`):
+            Index tensor of shape `(...)`, specifying the positions to gather from the log-softmax output.
+
+    Returns:
+        `torch.Tensor`:
+            Gathered log probabilities with the same shape as `index`.
+ """ + if logits.dtype in [torch.float32, torch.float64]: + selected_logits = torch.gather(logits, dim=-1, index=index.unsqueeze(-1)).squeeze(-1) + # loop to reduce peak mem consumption + logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits]) + per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x) + else: + # logsumexp approach is unstable with bfloat16, fall back to slightly less efficent approach + per_token_logps = [] + for row_logits, row_labels in zip(logits, index): # loop to reduce peak mem consumption + row_logps = F.log_softmax(row_logits, dim=-1) + row_per_token_logps = row_logps.gather(dim=-1, index=row_labels.unsqueeze(-1)).squeeze(-1) + per_token_logps.append(row_per_token_logps) + per_token_logps = torch.stack(per_token_logps) + return per_token_logps diff --git a/trl/trainer/xpo_trainer.py b/trl/trainer/xpo_trainer.py index 2d535344e7..6c7579ae8a 100644 --- a/trl/trainer/xpo_trainer.py +++ b/trl/trainer/xpo_trainer.py @@ -44,6 +44,7 @@ generate_model_card, get_comet_experiment_url, get_reward, + selective_log_softmax, truncate_right, ) from .xpo_config import XPOConfig @@ -274,8 +275,7 @@ def _compute_logprobs(self, model, model_data, ref_data, context_length): def compute_logprobs_for_data(m, data): output = m(data["input_ids"], attention_mask=data["attention_mask"]) logits = output.logits[:, context_length - 1 : -1] - logprobs = F.log_softmax(logits, dim=-1) - token_logprobs = torch.gather(logprobs, 2, data["input_ids"][:, context_length:].unsqueeze(-1)).squeeze(-1) + token_logprobs = selective_log_softmax(logits, data["input_ids"][:, context_length:]) return token_logprobs # Compute logprobs for model completions