diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py
index b5755e6a4b..f2d274c842 100644
--- a/trl/trainer/dpo_trainer.py
+++ b/trl/trainer/dpo_trainer.py
@@ -15,7 +15,7 @@
 import inspect
 import random
 import warnings
-from collections import defaultdict
+from collections import OrderedDict
 from copy import deepcopy
 from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
 
@@ -34,7 +34,7 @@
     TrainingArguments,
 )
 from transformers.trainer_callback import TrainerCallback
-from transformers.trainer_utils import EvalLoopOutput
+from transformers.trainer_utils import EvalLoopOutput, EvalPrediction
 
 from ..import_utils import is_peft_available, is_wandb_available
 from ..models import PreTrainedModelWrapper, create_reference_model
@@ -48,6 +48,7 @@
 if is_wandb_available():
     import wandb
 
+
 if is_deepspeed_available():
     import deepspeed
 
@@ -298,6 +299,8 @@ def make_inputs_require_grad(module, input, output):
             )
 
             self.use_dpo_data_collator = True
+
+            args.label_names = ["chosen_labels", "rejected_labels"]
         else:
             self.use_dpo_data_collator = False
 
@@ -320,7 +323,8 @@ def make_inputs_require_grad(module, input, output):
         self.label_smoothing = label_smoothing
         self.loss_type = loss_type
 
-        self._stored_metrics = defaultdict(lambda: defaultdict(list))
+        if compute_metrics is None:
+            compute_metrics = compute_dpo_metrics
 
         super().__init__(
             model=model,
@@ -545,21 +549,24 @@ def concatenated_forward(
         return (chosen_logps, rejected_logps, chosen_logits, rejected_logits)
 
-    def get_batch_metrics(
+    def compute_loss(
         self,
-        model,
-        batch: Dict[str, Union[List, torch.LongTensor]],
-        train_eval: Literal["train", "eval"] = "train",
-    ):
-        """Compute the DPO loss and other metrics for the given batch of inputs for train or test."""
-        metrics = {}
-
+        model: Union[PreTrainedModel, nn.Module],
+        inputs: Dict[str, Union[torch.Tensor, Any]],
+        return_outputs=False,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
+        if not self.use_dpo_data_collator:
+            warnings.warn(
+                "compute_loss is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
+                "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
+            )
         (
             policy_chosen_logps,
             policy_rejected_logps,
             policy_chosen_logits,
             policy_rejected_logits,
-        ) = self.concatenated_forward(model, batch)
+        ) = self.concatenated_forward(model, inputs)
+
         with torch.no_grad():
             if self.ref_model is None:
                 with self.accelerator.unwrap_model(self.model).disable_adapter():
                     (
@@ -568,14 +575,14 @@ def get_batch_metrics(
                         reference_rejected_logps,
                         _,
                         _,
-                    ) = self.concatenated_forward(self.model, batch)
+                    ) = self.concatenated_forward(self.model, inputs)
             else:
                 (
                     reference_chosen_logps,
                     reference_rejected_logps,
                     _,
                     _,
-                ) = self.concatenated_forward(self.ref_model, batch)
+                ) = self.concatenated_forward(self.ref_model, inputs)
 
         losses, chosen_rewards, rejected_rewards = self.dpo_loss(
             policy_chosen_logps,
@@ -583,40 +590,20 @@ def get_batch_metrics(
             reference_chosen_logps,
             reference_rejected_logps,
         )
-        reward_accuracies = (chosen_rewards > rejected_rewards).float()
-
-        prefix = "eval_" if train_eval == "eval" else ""
-        metrics[f"{prefix}rewards/chosen"] = chosen_rewards.cpu().mean()
-        metrics[f"{prefix}rewards/rejected"] = rejected_rewards.cpu().mean()
-        metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.cpu().mean()
-        metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).cpu().mean()
-        metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().cpu().mean()
-        metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().cpu().mean()
-        metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().cpu().mean()
-        metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().cpu().mean()
-
-        return losses.mean(), metrics
-
-    def compute_loss(
-        self,
-        model: Union[PreTrainedModel, nn.Module],
-        inputs: Dict[str, Union[torch.Tensor, Any]],
-        return_outputs=False,
-    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
-        if not self.use_dpo_data_collator:
-            warnings.warn(
-                "compute_loss is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
-                "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
-            )
-        loss, metrics = self.get_batch_metrics(model, inputs, train_eval="train")
-
-        # force log the metrics
-        if self.accelerator.is_main_process:
-            self.store_metrics(metrics, train_eval="train")
+        loss = losses.mean()
+        outputs = OrderedDict(
+            {
+                "loss": loss,
+                "chosen_rewards": chosen_rewards,
+                "rejected_rewards": rejected_rewards,
+                "policy_chosen_logps": policy_chosen_logps.detach(),
+                "policy_rejected_logps": policy_rejected_logps.detach(),
+                "reference_chosen_logps": reference_chosen_logps.detach(),
+                "reference_rejected_logps": reference_rejected_logps.detach(),
+            }
+        )
 
-        if return_outputs:
-            return (loss, metrics)
-        return loss
+        return (loss, outputs) if return_outputs else loss
 
     def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]:
         """Generate samples from the model and reference model for the given batch of inputs."""
@@ -667,36 +654,8 @@ def prediction_step(
                 "prediction_step is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
                 "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
            )
-        if ignore_keys is None:
-            if hasattr(model, "config"):
-                ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
-            else:
-                ignore_keys = []
-
-        with torch.no_grad():
-            loss, metrics = self.get_batch_metrics(model, inputs, train_eval="eval")
-
-        # force log the metrics
-        if self.accelerator.is_main_process:
-            self.store_metrics(metrics, train_eval="eval")
-
-        if prediction_loss_only:
-            return (loss.detach(), None, None)
-
-        # logits for the chosen and rejected samples from model
-        logits_dict = {
-            "eval_logits/chosen": metrics["eval_logits/chosen"],
-            "eval_logits/rejected": metrics["eval_logits/rejected"],
-        }
-        logits = tuple(v.unsqueeze(dim=0) for k, v in logits_dict.items() if k not in ignore_keys)
-        logits = torch.stack(logits).mean(axis=1).to(self.accelerator.device)
-        labels = torch.zeros(logits.shape[0], device=self.accelerator.device)
-
-        return (loss.detach(), logits, labels)
-
-    def store_metrics(self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
-        for key, value in metrics.items():
-            self._stored_metrics[train_eval][key].append(value)
+
+        return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys)
 
     def evaluation_loop(
         self,
@@ -748,18 +707,25 @@ def evaluation_loop(
         return initial_output
 
-    def log(self, logs: Dict[str, float]) -> None:
-        """
-        Log `logs` on the various objects watching training, including stored metrics.
-
-        Args:
-            logs (`Dict[str, float]`):
-                The values to log.
-        """
-        # logs either has 'loss' or 'eval_loss'
-        train_eval = "train" if "loss" in logs else "eval"
-        # Add averaged stored metrics to logs
-        for key, metrics in self._stored_metrics[train_eval].items():
-            logs[key] = torch.tensor(metrics).mean().item()
-        del self._stored_metrics[train_eval]
-        return super().log(logs)
+
+def compute_dpo_metrics(eval_preds: EvalPrediction):
+    (
+        chosen_rewards,
+        rejected_rewards,
+        policy_chosen_logps,
+        policy_rejected_logps,
+        reference_chosen_logps,
+        reference_rejected_logps,
+    ) = eval_preds.predictions
+
+    reward_accuracies = (chosen_rewards > rejected_rewards).mean()
+
+    metrics = {}
+    metrics["rewards/chosen"] = chosen_rewards.mean()
+    metrics["rewards/rejected"] = rejected_rewards.mean()
+    metrics["rewards/accuracies"] = reward_accuracies.mean()
+    metrics["rewards/margins"] = (chosen_rewards - rejected_rewards).mean()
+    metrics["logps/rejected"] = policy_rejected_logps.mean()
+    metrics["logps/chosen"] = policy_chosen_logps.mean()
+
+    return metrics
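Below is a minimal, hypothetical sketch of how the refactored evaluation path is intended to be exercised: with this patch, per-batch reward/logp statistics are no longer accumulated through store_metrics/log; instead compute_loss returns them in an outputs dict, the stock Trainer.prediction_step collects them as predictions, and compute_dpo_metrics (installed as the default compute_metrics) reduces them when evaluate() is called. The model name, toy dataset, and hyperparameters below are placeholders and not part of the patch; the constructor arguments assume the DPOTrainer signature of this TRL version.

    # Hypothetical usage sketch; "gpt2" and the toy preference pairs are placeholders.
    from datasets import Dataset
    from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
    from trl import DPOTrainer

    model = AutoModelForCausalLM.from_pretrained("gpt2")
    ref_model = AutoModelForCausalLM.from_pretrained("gpt2")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default

    # Tiny preference dataset with the prompt/chosen/rejected columns DPOTrainer expects.
    pairs = Dataset.from_dict(
        {
            "prompt": ["Hello", "How are you?"],
            "chosen": [" Hi, nice to meet you.", " I'm doing well, thanks."],
            "rejected": [" Go away.", " None of your business."],
        }
    )

    training_args = TrainingArguments(
        output_dir="dpo-compute-metrics-test",
        per_device_eval_batch_size=2,
        remove_unused_columns=False,  # required with DPODataCollatorWithPadding
    )

    trainer = DPOTrainer(
        model,
        ref_model,
        beta=0.1,
        args=training_args,
        train_dataset=pairs,
        eval_dataset=pairs,
        tokenizer=tokenizer,
        max_length=64,
        max_prompt_length=32,
    )

    # evaluate() should now route through Trainer.prediction_step and compute_dpo_metrics,
    # returning keys such as eval_rewards/accuracies and eval_rewards/margins.
    metrics = trainer.evaluate()
    print(metrics)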