From 6f99f42f724123409422f2fad42bf56fa91f366f Mon Sep 17 00:00:00 2001
From: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date: Fri, 24 Jan 2025 16:23:16 +0100
Subject: [PATCH] 🥞 Fix KTO gradient accumulation loss scaling (#2648)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
 trl/trainer/kto_trainer.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/trl/trainer/kto_trainer.py b/trl/trainer/kto_trainer.py
index 897ce25520..c45a88d554 100644
--- a/trl/trainer/kto_trainer.py
+++ b/trl/trainer/kto_trainer.py
@@ -746,6 +746,11 @@ def make_inputs_require_grad(module, input, output):
             preprocess_logits_for_metrics=preprocess_logits_for_metrics,
         )
 
+        # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether
+        # the model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
+        # self.model_accepts_loss_kwargs to False to enable scaling.
+        self.model_accepts_loss_kwargs = False
+
         # Add tags for models that have been loaded with the correct transformers version
         if hasattr(self.model, "add_model_tags"):
            self.model.add_model_tags(self._tag_names)
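
The sketch below illustrates the loss-scaling decision this patch relies on. It is a simplified, standalone approximation of what the parent `Trainer` does during a training step, not the actual `transformers` implementation; the function name `scale_loss_for_accumulation` and the `__main__` example values are illustrative only.

```python
def scale_loss_for_accumulation(
    loss: float,
    gradient_accumulation_steps: int,
    model_accepts_loss_kwargs: bool,
) -> float:
    """Return the per-micro-batch loss to backpropagate (simplified sketch).

    If the model does NOT accept loss-related kwargs, the trainer divides the
    loss by the number of accumulation steps so the accumulated gradient
    matches a single large-batch update. If the model does accept those
    kwargs, the loss is assumed to be normalized already, so it is left as is.
    """
    if not model_accepts_loss_kwargs:
        return loss / gradient_accumulation_steps
    return loss


if __name__ == "__main__":
    # With the fix, KTOTrainer forces model_accepts_loss_kwargs = False, so its
    # self-computed loss is scaled before backward (illustrative numbers).
    print(scale_loss_for_accumulation(2.0, 4, model_accepts_loss_kwargs=False))  # 0.5
    print(scale_loss_for_accumulation(2.0, 4, model_accepts_loss_kwargs=True))   # 2.0
```

Because `KTOTrainer` computes its own loss rather than taking one from the model's forward pass, forcing `model_accepts_loss_kwargs = False` makes the parent class apply the division above, which is the behavior this patch restores under gradient accumulation.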