[DPO] Refactor eval logging of dpo trainer (#954)
* first attempt at refactoring the DPO trainer

* removed the custom metric handling from prediction_step, delegating to the base Trainer implementation

* import fixes

* set args.label_names to the chosen/rejected labels

* all working

---------

Co-authored-by: Leandro von Werra <[email protected]>
mnoukhov and lvwerra authored Nov 30, 2023
1 parent c203e47 commit 6d9ea38
Showing 1 changed file with 56 additions and 90 deletions.
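For context before the diff: the refactor routes DPO evaluation metrics through the standard transformers Trainer path (extra outputs returned by compute_loss are gathered into an EvalPrediction and passed to a compute_metrics callable) instead of accumulating them in self._stored_metrics and flushing them in an overridden log. Below is a minimal sketch of supplying a custom compute_metrics under this scheme, assuming the constructor arguments shown are accepted at this commit; the model, tokenizer, tiny in-memory dataset, and training arguments are placeholders, not taken from the commit.

```python
# Hedged sketch, not part of the commit: wiring a custom compute_metrics into
# DPOTrainer after this refactor. Model, tokenizer, and the tiny in-memory
# dataset are illustrative; omitting compute_metrics falls back to compute_dpo_metrics.
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from transformers.trainer_utils import EvalPrediction
from trl import DPOTrainer


def my_dpo_metrics(eval_preds: EvalPrediction):
    # predictions holds the extra outputs of compute_loss, gathered over the eval set
    chosen_rewards, rejected_rewards, *_ = eval_preds.predictions
    return {"rewards/accuracies": float((chosen_rewards > rejected_rewards).mean())}


model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

pairs = Dataset.from_dict(
    {
        "prompt": ["What is 2 + 2?", "Name a primary color."],
        "chosen": [" 4", " Blue"],
        "rejected": [" 5", " Dog"],
    }
)

trainer = DPOTrainer(
    model=model,
    ref_model=None,  # a reference model is created internally when None is passed
    args=TrainingArguments(
        output_dir="dpo-out",
        evaluation_strategy="steps",
        eval_steps=1,
        remove_unused_columns=False,
    ),
    beta=0.1,
    train_dataset=pairs,
    eval_dataset=pairs,
    tokenizer=tokenizer,
    max_length=64,
    max_prompt_length=32,
    compute_metrics=my_dpo_metrics,  # with None, compute_dpo_metrics is used
)
trainer.train()
```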
146 changes: 56 additions & 90 deletions trl/trainer/dpo_trainer.py
@@ -15,7 +15,7 @@
import inspect
import random
import warnings
from collections import defaultdict
from collections import OrderedDict
from copy import deepcopy
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union

@@ -34,7 +34,7 @@
TrainingArguments,
)
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import EvalLoopOutput
from transformers.trainer_utils import EvalLoopOutput, EvalPrediction

from ..import_utils import is_peft_available, is_wandb_available
from ..models import PreTrainedModelWrapper, create_reference_model
@@ -48,6 +48,7 @@
if is_wandb_available():
import wandb


if is_deepspeed_available():
import deepspeed

@@ -298,6 +299,8 @@ def make_inputs_require_grad(module, input, output):
)

self.use_dpo_data_collator = True

args.label_names = ["chosen_labels", "rejected_labels"]
else:
self.use_dpo_data_collator = False

@@ -320,7 +323,8 @@ def make_inputs_require_grad(module, input, output):
self.label_smoothing = label_smoothing
self.loss_type = loss_type

self._stored_metrics = defaultdict(lambda: defaultdict(list))
if compute_metrics is None:
compute_metrics = compute_dpo_metrics

super().__init__(
model=model,
@@ -545,21 +549,24 @@ def concatenated_forward(

return (chosen_logps, rejected_logps, chosen_logits, rejected_logits)

def get_batch_metrics(
def compute_loss(
self,
model,
batch: Dict[str, Union[List, torch.LongTensor]],
train_eval: Literal["train", "eval"] = "train",
):
"""Compute the DPO loss and other metrics for the given batch of inputs for train or test."""
metrics = {}

model: Union[PreTrainedModel, nn.Module],
inputs: Dict[str, Union[torch.Tensor, Any]],
return_outputs=False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
if not self.use_dpo_data_collator:
warnings.warn(
"compute_loss is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
"DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
)
(
policy_chosen_logps,
policy_rejected_logps,
policy_chosen_logits,
policy_rejected_logits,
) = self.concatenated_forward(model, batch)
) = self.concatenated_forward(model, inputs)

with torch.no_grad():
if self.ref_model is None:
with self.accelerator.unwrap_model(self.model).disable_adapter():
@@ -568,55 +575,35 @@ def get_batch_metrics(
reference_rejected_logps,
_,
_,
) = self.concatenated_forward(self.model, batch)
) = self.concatenated_forward(self.model, inputs)
else:
(
reference_chosen_logps,
reference_rejected_logps,
_,
_,
) = self.concatenated_forward(self.ref_model, batch)
) = self.concatenated_forward(self.ref_model, inputs)

losses, chosen_rewards, rejected_rewards = self.dpo_loss(
policy_chosen_logps,
policy_rejected_logps,
reference_chosen_logps,
reference_rejected_logps,
)
reward_accuracies = (chosen_rewards > rejected_rewards).float()

prefix = "eval_" if train_eval == "eval" else ""
metrics[f"{prefix}rewards/chosen"] = chosen_rewards.cpu().mean()
metrics[f"{prefix}rewards/rejected"] = rejected_rewards.cpu().mean()
metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.cpu().mean()
metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).cpu().mean()
metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().cpu().mean()
metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().cpu().mean()
metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().cpu().mean()
metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().cpu().mean()

return losses.mean(), metrics

def compute_loss(
self,
model: Union[PreTrainedModel, nn.Module],
inputs: Dict[str, Union[torch.Tensor, Any]],
return_outputs=False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
if not self.use_dpo_data_collator:
warnings.warn(
"compute_loss is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
"DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
)
loss, metrics = self.get_batch_metrics(model, inputs, train_eval="train")

# force log the metrics
if self.accelerator.is_main_process:
self.store_metrics(metrics, train_eval="train")
loss = losses.mean()
outputs = OrderedDict(
{
"loss": loss,
"chosen_rewards": chosen_rewards,
"rejected_rewards": rejected_rewards,
"policy_chosen_logps": policy_chosen_logps.detach(),
"policy_rejected_logps": policy_chosen_logps.detach(),
"reference_chosen_logps": reference_chosen_logps.detach(),
"reference_rejected_logps": reference_chosen_logps.detach(),
}
)

if return_outputs:
return (loss, metrics)
return loss
return (loss, outputs) if return_outputs else loss

def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]:
"""Generate samples from the model and reference model for the given batch of inputs."""
@@ -667,36 +654,8 @@ def prediction_step(
"prediction_step is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
"DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
)
if ignore_keys is None:
if hasattr(model, "config"):
ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
else:
ignore_keys = []

with torch.no_grad():
loss, metrics = self.get_batch_metrics(model, inputs, train_eval="eval")

# force log the metrics
if self.accelerator.is_main_process:
self.store_metrics(metrics, train_eval="eval")

if prediction_loss_only:
return (loss.detach(), None, None)

# logits for the chosen and rejected samples from model
logits_dict = {
"eval_logits/chosen": metrics["eval_logits/chosen"],
"eval_logits/rejected": metrics["eval_logits/rejected"],
}
logits = tuple(v.unsqueeze(dim=0) for k, v in logits_dict.items() if k not in ignore_keys)
logits = torch.stack(logits).mean(axis=1).to(self.accelerator.device)
labels = torch.zeros(logits.shape[0], device=self.accelerator.device)

return (loss.detach(), logits, labels)

def store_metrics(self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
for key, value in metrics.items():
self._stored_metrics[train_eval][key].append(value)
return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys)

def evaluation_loop(
self,
@@ -748,18 +707,25 @@ def evaluation_loop(

return initial_output

def log(self, logs: Dict[str, float]) -> None:
"""
Log `logs` on the various objects watching training, including stored metrics.

Args:
logs (`Dict[str, float]`):
The values to log.
"""
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
# Add averaged stored metrics to logs
for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
return super().log(logs)
def compute_dpo_metrics(eval_preds: EvalPrediction):
(
chosen_rewards,
rejected_rewards,
policy_chosen_logps,
policy_rejected_logps,
reference_chosen_logps,
reference_rejected_logps,
) = eval_preds.predictions

reward_accuracies = (chosen_rewards > rejected_rewards).mean()

metrics = {}
metrics["rewards/chosen"] = chosen_rewards.mean()
metrics["rewards/rejected"] = rejected_rewards.mean()
metrics["rewards/accuracies"] = reward_accuracies.mean()
metrics["rewards/margins"] = (chosen_rewards - rejected_rewards).mean()
metrics["logps/rejected"] = policy_rejected_logps.mean()
metrics["logps/chosen"] = policy_chosen_logps.mean()

return metrics
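For completeness, here is a self-contained sketch (not part of the commit) of how compute_dpo_metrics consumes the tuple of arrays that the evaluation loop gathers from the extra outputs of compute_loss. The array values are fabricated, and the import path assumes this commit's module layout.

```python
# Hedged sketch: feeding compute_dpo_metrics an EvalPrediction by hand.
# Values are fabricated; the import assumes this commit's version of trl.
import numpy as np
from transformers.trainer_utils import EvalPrediction

from trl.trainer.dpo_trainer import compute_dpo_metrics

rng = np.random.default_rng(0)
chosen_rewards = rng.normal(loc=1.0, size=8)
rejected_rewards = rng.normal(loc=0.0, size=8)
policy_chosen_logps = rng.normal(size=8)
policy_rejected_logps = rng.normal(size=8)
reference_chosen_logps = rng.normal(size=8)
reference_rejected_logps = rng.normal(size=8)

eval_preds = EvalPrediction(
    predictions=(
        chosen_rewards,
        rejected_rewards,
        policy_chosen_logps,
        policy_rejected_logps,
        reference_chosen_logps,
        reference_rejected_logps,
    ),
    # label_ids corresponds to args.label_names = ["chosen_labels", "rejected_labels"];
    # it is unused by compute_dpo_metrics, so dummies suffice here.
    label_ids=(np.zeros((8, 4), dtype=np.int64), np.zeros((8, 4), dtype=np.int64)),
)

print(compute_dpo_metrics(eval_preds))
# e.g. {'rewards/chosen': ..., 'rewards/rejected': ..., 'rewards/accuracies': ..., ...}
```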
