diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py
index d787178e8f..86753a3a5c 100644
--- a/trl/trainer/grpo_trainer.py
+++ b/trl/trainer/grpo_trainer.py
@@ -262,6 +262,11 @@ def data_collator(features):  # No data collation is needed in GRPO
             optimizers=optimizers,
         )
 
+        # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
+        # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
+        # self.model_accepts_loss_kwargs to False to enable scaling.
+        self.model_accepts_loss_kwargs = False
+
         if self.ref_model is not None:
             self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
 
@@ -393,7 +398,7 @@ def get_per_token_logps(model, input_ids):
         return loss
 
     def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
-        metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()}  # average the metrics
+        metrics = {key: sum(val) / len(val) for key, val in self._metrics.items() if val}  # average the metrics
         logs = {**logs, **metrics}
         if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
             super().log(logs, start_time)
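
For context on the first hunk: the parent `transformers.Trainer` only divides the loss by `gradient_accumulation_steps` when it believes the model does not accept loss-related kwargs (otherwise it assumes the loss is already normalized over the accumulated batch). Since `GRPOTrainer` computes its own loss in `compute_loss`, that heuristic never applies, so the flag is forced to `False` to get the division back. Below is a minimal, paraphrased sketch of that effect; it is not the exact `training_step` from transformers, and `training_step_sketch` is a hypothetical helper used only to illustrate the scaling rule.

```python
def training_step_sketch(trainer, model, inputs):
    """Paraphrased sketch of how Trainer.training_step scales the loss for gradient accumulation."""
    loss = trainer.compute_loss(model, inputs)  # GRPOTrainer's custom GRPO loss

    # When the model does NOT accept loss kwargs, the Trainer treats the loss as a
    # per-micro-batch mean and divides it so that gradients accumulated over
    # `gradient_accumulation_steps` micro-batches average rather than sum.
    if not trainer.model_accepts_loss_kwargs:
        loss = loss / trainer.args.gradient_accumulation_steps

    trainer.accelerator.backward(loss)
    return loss.detach()
```

The second hunk is independent: the `if val` guard in `log` skips metrics whose lists are still empty, so `sum(val) / len(val)` cannot raise a `ZeroDivisionError` when `log` runs before any value has been recorded for a key.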