
Commit

Fix grpo
qgallouedec committed Jan 24, 2025
1 parent 8e65825 commit 6130c96
Showing 1 changed file with 6 additions and 1 deletion.
trl/trainer/grpo_trainer.py (7 changes: 6 additions & 1 deletion)
@@ -262,6 +262,11 @@ def data_collator(features):  # No data collation is needed in GRPO
             optimizers=optimizers,
         )
 
+        # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
+        # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
+        # self.model_accepts_loss_kwargs to False to enable scaling.
+        self.model_accepts_loss_kwargs = False
+
         if self.ref_model is not None:
             self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
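Per the comment above, the parent Trainer scales the loss for gradient accumulation only when model_accepts_loss_kwargs is False. To see why that scaling matters, here is a small self-contained PyTorch check (illustrative only, not TRL or transformers code): dividing each micro-batch loss by the number of accumulation steps reproduces the gradient of one large batch.

import torch

# Standalone demonstration (not TRL code): under gradient accumulation, each
# micro-batch loss must be divided by the number of accumulation steps so that
# the accumulated gradients match those of a single large batch.
torch.manual_seed(0)
w = torch.randn(4, requires_grad=True)
x = torch.randn(8, 4)
y = torch.randn(8)

# Reference: mean squared error over the full batch of 8 samples.
loss_full = ((x @ w - y) ** 2).mean()
loss_full.backward()
grad_full = w.grad.clone()

# Accumulated: two micro-batches of 4, each loss scaled by 1/accumulation_steps.
w.grad = None
accumulation_steps = 2
for chunk_x, chunk_y in zip(x.chunk(accumulation_steps), y.chunk(accumulation_steps)):
    loss = ((chunk_x @ w - chunk_y) ** 2).mean() / accumulation_steps
    loss.backward()  # gradients accumulate in w.grad across iterations

assert torch.allclose(grad_full, w.grad)  # scaled accumulation matches the full batch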

@@ -393,7 +398,7 @@ def get_per_token_logps(model, input_ids):
         return loss
 
     def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
-        metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()}  # average the metrics
+        metrics = {key: sum(val) / len(val) for key, val in self._metrics.items() if val}  # average the metrics
         logs = {**logs, **metrics}
         if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
             super().log(logs, start_time)
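The second change guards against metric keys whose value list is still empty: averaging an empty list divides by zero. A toy reproduction of the failure and the fix (the _metrics contents below are invented for illustration):

# Toy reproduction (invented data, not TRL code): averaging an empty metric
# list raises ZeroDivisionError, which the `if val` filter avoids.
_metrics = {"reward": [0.5, 0.7], "kl": []}  # "kl" has no values logged yet

try:
    {key: sum(val) / len(val) for key, val in _metrics.items()}  # old code
except ZeroDivisionError:
    print("unfiltered averaging fails on the empty 'kl' list")

metrics = {key: sum(val) / len(val) for key, val in _metrics.items() if val}  # fixed
print(metrics)  # {'reward': 0.6}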
