diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py
index 7a6e0fe6af..2403539c6b 100644
--- a/trl/trainer/dpo_trainer.py
+++ b/trl/trainer/dpo_trainer.py
@@ -1122,7 +1122,7 @@ def get_batch_logps(
 
     def concatenated_forward(
         self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
-    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
+    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
         """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
 
         We do this to avoid doing two forward passes, because it's faster for FSDP.
@@ -1158,7 +1158,23 @@ def concatenated_forward(
             is_encoder_decoder=self.is_encoder_decoder,
             label_pad_token_id=self.label_pad_token_id,
         )
-        chosen_logps_avg = all_logps[:len_chosen] / size_completion[:len_chosen]
+
+        def cross_entropy_loss(logits, labels):
+            if not self.is_encoder_decoder:
+                # Shift so that tokens < n predict n
+                logits = logits[..., :-1, :].contiguous()
+                labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            loss_fct = nn.CrossEntropyLoss()
+            logits = logits.view(-1, logits.shape[-1])
+            labels = labels.view(-1)
+            # Enable model parallelism
+            labels = labels.to(logits.device)
+            loss = loss_fct(logits, labels)
+            return loss
+
+        labels = concatenated_batch["concatenated_labels"].clone()
+        nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])
 
         if self.loss_type == "ipo":
             all_logps = all_logps / size_completion
@@ -1169,7 +1185,7 @@ def concatenated_forward(
         chosen_logits = all_logits[:len_chosen]
         rejected_logits = all_logits[len_chosen:]
 
-        return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps_avg)
+        return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss)
 
     def get_batch_loss_metrics(
         self,
@@ -1185,7 +1201,7 @@ def get_batch_loss_metrics(
             policy_rejected_logps,
             policy_chosen_logits,
             policy_rejected_logits,
-            policy_chosen_logps_avg,
+            policy_nll_loss,
         ) = self.concatenated_forward(model, batch)
 
         # if reference_chosen_logps and reference_rejected_logps in batch use them, otherwise use the reference model
@@ -1225,7 +1241,7 @@ def get_batch_loss_metrics(
         reward_accuracies = (chosen_rewards > rejected_rewards).float()
 
         if self.args.rpo_alpha is not None:
-            losses = losses * self.args.rpo_alpha - policy_chosen_logps_avg
+            losses = losses * self.args.rpo_alpha + policy_nll_loss
 
         prefix = "eval_" if train_eval == "eval" else ""
         metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().cpu()
@@ -1236,6 +1252,8 @@ def get_batch_loss_metrics(
         metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean().cpu()
         metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean().cpu()
         metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean().cpu()
+        if self.args.rpo_alpha is not None:
+            metrics[f"{prefix}nll_loss"] = policy_nll_loss.detach().mean().cpu()
 
         return losses.mean(), metrics
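
For readers skimming the patch, here is a minimal, self-contained sketch of the objective the diff wires up when `rpo_alpha` is set: the per-pair DPO losses are scaled by `rpo_alpha` and the token-level NLL of the chosen completions is added on top. All tensor names and values below are illustrative placeholders, not the trainer's real batch or API.

```python
import torch
import torch.nn as nn

# Standalone sketch (not part of the patch) of the combined loss:
#     loss = rpo_alpha * dpo_losses + nll_loss(chosen completions)

rpo_alpha = 1.0

# Pretend per-pair DPO losses for a batch of 4 preference pairs.
dpo_losses = torch.tensor([0.7, 0.4, 0.9, 0.5])

# Pretend logits and labels for the chosen completions (batch=4, seq=8, vocab=32).
chosen_logits = torch.randn(4, 8, 32)
chosen_labels = torch.randint(0, 32, (4, 8))

# Same shift-and-flatten cross-entropy as the cross_entropy_loss helper in the diff.
shift_logits = chosen_logits[..., :-1, :].contiguous()
shift_labels = chosen_labels[..., 1:].contiguous()
nll_loss = nn.CrossEntropyLoss()(
    shift_logits.view(-1, shift_logits.shape[-1]), shift_labels.view(-1)
)

total_loss = (dpo_losses * rpo_alpha + nll_loss).mean()
print(total_loss)
```

In the trainer itself, `rpo_alpha` is read from `self.args`, so enabling this path is just a matter of setting that field on the DPO training arguments; the sketch above only mirrors the arithmetic, not the trainer's batching or padding logic.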