[RPO] fix nll loss #1705

Merged · merged 1 commit on Jun 7, 2024
28 changes: 23 additions & 5 deletions trl/trainer/dpo_trainer.py
@@ -1122,7 +1122,7 @@ def get_batch_logps(
 
     def concatenated_forward(
         self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
-    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
+    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
         """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
 
         We do this to avoid doing two forward passes, because it's faster for FSDP.
@@ -1158,7 +1158,23 @@ def concatenated_forward
             is_encoder_decoder=self.is_encoder_decoder,
             label_pad_token_id=self.label_pad_token_id,
         )
-        chosen_logps_avg = all_logps[:len_chosen] / size_completion[:len_chosen]
+
+        def cross_entropy_loss(logits, labels):
+            if not self.is_encoder_decoder:
+                # Shift so that tokens < n predict n
+                logits = logits[..., :-1, :].contiguous()
+                labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            loss_fct = nn.CrossEntropyLoss()
+            logits = logits.view(-1, logits.shape[-1])
+            labels = labels.view(-1)
+            # Enable model parallelism
+            labels = labels.to(logits.device)
+            loss = loss_fct(logits, labels)
+            return loss
+
+        labels = concatenated_batch["concatenated_labels"].clone()
+        nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])
 
         if self.loss_type == "ipo":
             all_logps = all_logps / size_completion
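
Editorial note: the added helper follows the standard causal-LM pattern — logits at position i are scored against the label at position i + 1, then everything is flattened for nn.CrossEntropyLoss. A minimal standalone sketch on toy tensors (the toy_* names are illustrative, not from this PR):

import torch
import torch.nn as nn

# Toy decoder-only batch: 2 sequences, 5 positions, vocab size 11.
toy_logits = torch.randn(2, 5, 11)
toy_labels = torch.randint(0, 11, (2, 5))

# Shift so that tokens < n predict n, as in the diff above.
shift_logits = toy_logits[..., :-1, :].contiguous()  # (2, 4, 11)
shift_labels = toy_labels[..., 1:].contiguous()      # (2, 4)

# Flatten and reduce; reduction="mean" is the default, so the result is
# the NLL averaged over all 2 * 4 = 8 target tokens.
loss_fct = nn.CrossEntropyLoss()
nll = loss_fct(shift_logits.view(-1, shift_logits.shape[-1]), shift_labels.view(-1))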
@@ -1169,7 +1185,7 @@ def concatenated_forward
         chosen_logits = all_logits[:len_chosen]
         rejected_logits = all_logits[len_chosen:]
 
-        return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps_avg)
+        return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss)
 
     def get_batch_loss_metrics(
         self,
@@ -1185,7 +1201,7 @@
             policy_rejected_logps,
             policy_chosen_logits,
             policy_rejected_logits,
-            policy_chosen_logps_avg,
+            policy_nll_loss,
         ) = self.concatenated_forward(model, batch)
 
         # if reference_chosen_logps and reference_rejected_logps in batch use them, otherwise use the reference model
@@ -1225,7 +1241,7 @@ def get_batch_loss_metrics(
         reward_accuracies = (chosen_rewards > rejected_rewards).float()
 
         if self.args.rpo_alpha is not None:
-            losses = losses * self.args.rpo_alpha - policy_chosen_logps_avg
+            losses = losses * self.args.rpo_alpha + policy_nll_loss
Member commented:
Perhaps I'm missing something, but isn't the NLL loss normalized by the total sequence length?

[Screenshot attached: 2024-06-06 at 14:54:39]

Collaborator (author) replied:
Yes, CrossEntropyLoss reduces by mean by default, so the NLL term is normalized by the number of target tokens.

Member replied:
Thanks for the clarification!
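
To make the thread concrete: with default arguments, nn.CrossEntropyLoss uses reduction="mean" and ignore_index=-100, so the NLL term is the sum of per-token losses divided by the number of non-ignored target tokens (DPOTrainer's label_pad_token_id defaults to -100, so padded completion tokens are excluded). A quick illustrative check, not part of the PR:

import torch
import torch.nn as nn

logits = torch.randn(6, 11)                      # 6 target positions, vocab size 11
labels = torch.tensor([3, 7, -100, 1, -100, 5])  # -100 marks padded positions

mean_loss = nn.CrossEntropyLoss()(logits, labels)               # default reduction="mean"
sum_loss = nn.CrossEntropyLoss(reduction="sum")(logits, labels)

# The mean is taken over the 4 non-ignored tokens, not all 6 positions.
assert torch.isclose(mean_loss, sum_loss / 4)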


         prefix = "eval_" if train_eval == "eval" else ""
         metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().cpu()
@@ -1236,6 +1252,8 @@
         metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean().cpu()
         metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean().cpu()
         metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean().cpu()
+        if self.args.rpo_alpha is not None:
+            metrics[f"{prefix}nll_loss"] = policy_nll_loss.detach().mean().cpu()
 
         return losses.mean(), metrics
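
For context, this code path is driven entirely by rpo_alpha, so the objective trained here is rpo_alpha * L_DPO + L_NLL. A minimal usage sketch (model and dataset names are placeholders, assuming the DPOConfig/DPOTrainer API from around the time of this PR):

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")

# A non-None rpo_alpha activates the branch above:
#   losses = losses * self.args.rpo_alpha + policy_nll_loss
# and logs the extra "nll_loss" metric.
args = DPOConfig(output_dir="dpo-rpo", rpo_alpha=1.0)
trainer = DPOTrainer(model=model, args=args, train_dataset=dataset, tokenizer=tokenizer)
trainer.train()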
