Skip to content

Commit

Permalink
fix outputs
Browse files (browse the repository at this point in the history)
  • Loading branch information
kashif committed Jan 15, 2025
1 parent f50e74d commit e3eebd3
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 11 deletions.
10 changes: 6 additions & 4 deletions tests/test_dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1271,12 +1271,14 @@ def test_dpo_trainer_with_liger(self):

# Verify model can still do forward pass after training
dummy_batch = next(iter(trainer.get_train_dataloader()))
model_inputs = {
"input_ids": dummy_batch["prompt_input_ids"],
"attention_mask": dummy_batch["prompt_attention_mask"],
}
with torch.no_grad():
output = trainer.model(
**{k: v for k, v in dummy_batch.items() if k in trainer.model.forward.__code__.co_varnames}
)
output = trainer.model(**model_inputs)
self.assertIsNotNone(output)
self.assertTrue(torch.isfinite(output.loss))
self.assertIsNone(output.loss)


@require_vision
Expand Down
18 changes: 11 additions & 7 deletions trl/trainer/dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1233,19 +1233,23 @@ def concatenated_forward(self, model: nn.Module, batch: dict[str, Union[list, to
ref_weight=ref_weight if not self.reference_free else None,
ref_bias=ref_bias if not self.reference_free else None,
)
loss, (chosen_logps, rejected_logps, chosen_logits_mean, rejected_logits_mean, nll_loss, *aux_outputs) = (
loss_output
)
(
loss,
(chosen_logps, rejected_logps, chosen_logits_mean, rejected_logits_mean, nll_loss, _, _, *aux_outputs),
) = loss_output

output = {
"loss": loss,
"chosen_logps": chosen_logps,
"rejected_logps": rejected_logps,
"mean_chosen_logits": chosen_logits_mean,
"mean_rejected_logits": rejected_logits_mean,
"nll_loss": nll_loss,
"chosen_rewards": aux_outputs[0],
"rejected_rewards": aux_outputs[1],
}
if self.aux_loss_enabled and aux_outputs:
output["aux_loss"] = aux_outputs[0] # Assuming aux_loss is the first aux output
if self.aux_loss_enabled:
output["aux_loss"] = outputs.aux_loss

return output
else:
Expand Down Expand Up @@ -1374,8 +1378,8 @@ def concatenated_forward(self, model: nn.Module, batch: dict[str, Union[list, to

if self.args.rpo_alpha is not None:
# Only use the chosen logits for the RPO loss
chosen_logits = logits[:num_examples]
chosen_labels = labels[:num_examples]
chosen_logits = logits[:num_examples, :-1] if not self.is_encoder_decoder else logits[:num_examples]
chosen_labels = labels[:num_examples, 1:] if self.is_encoder_decoder else labels[:num_examples]

# Compute the log probabilities of the labels
output["nll_loss"] = F.cross_entropy(
Expand Down

0 comments on commit e3eebd3

Please sign in to comment.