From 5a305f9273abdc61ecd8fd3f127fe68f4447223b Mon Sep 17 00:00:00 2001
From: bartoszzuk
Date: Mon, 13 May 2024 18:52:01 +0200
Subject: [PATCH 1/2] Fixed wrong logs prefixes in KTOTrainer

---
 trl/trainer/kto_trainer.py | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/trl/trainer/kto_trainer.py b/trl/trainer/kto_trainer.py
index a22e74909f..5b6503a1f2 100644
--- a/trl/trainer/kto_trainer.py
+++ b/trl/trainer/kto_trainer.py
@@ -1525,26 +1525,28 @@ def log(self, logs: Dict[str, float]) -> None:
         """
         # logs either has 'loss' or 'eval_loss'
         train_eval = "train" if "loss" in logs else "eval"
+        # train metrics should have no prefix, eval should have 'eval_'
+        prefix = "eval_" if train_eval == "eval" else ""
         # accumulate average metrics from sums and lengths
         for split in ["chosen", "rejected"]:
             if f"count/{split}" in self._stored_metrics[train_eval]:
                 count_sum = torch.Tensor(self._stored_metrics[train_eval][f"count/{split}"]).sum().item()
-                logs[f"{train_eval}/rewards/{split}"] = (
+                logs[f"{prefix}rewards/{split}"] = (
                     torch.Tensor(self._stored_metrics[train_eval][f"rewards/{split}_sum"]).sum().item() / count_sum
                 )
-                logs[f"{train_eval}/logps/{split}"] = (
+                logs[f"{prefix}logps/{split}"] = (
                     torch.Tensor(self._stored_metrics[train_eval][f"logps/{split}_sum"]).sum().item() / count_sum
                 )
                 for key in [f"count/{split}", f"rewards/{split}_sum", f"logps/{split}_sum"]:
                     del self._stored_metrics[train_eval][key]
         # calculate reward margin
-        if f"{train_eval}/rewards/chosen" in logs and f"{train_eval}/rewards/rejected" in logs:
-            logs[f"{train_eval}/rewards/margins"] = (
-                logs[f"{train_eval}/rewards/chosen"] - logs[f"{train_eval}/rewards/rejected"]
+        if f"{prefix}rewards/chosen" in logs and f"{prefix}rewards/rejected" in logs:
+            logs[f"{prefix}rewards/margins"] = (
+                logs[f"{prefix}rewards/chosen"] - logs[f"{prefix}rewards/rejected"]
             )
         # Add averaged stored metrics to logs
         for key, metrics in self._stored_metrics[train_eval].items():
-            logs[f"{train_eval}/{key}"] = torch.Tensor(metrics).mean().item()
+            logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item()
         del self._stored_metrics[train_eval]
         return super().log(logs)

From d75783f5b7f7cd032bedbbcdddcd05591e280cdb Mon Sep 17 00:00:00 2001
From: bartoszzuk
Date: Mon, 13 May 2024 19:11:45 +0200
Subject: [PATCH 2/2] Pre-commit formating

---
 trl/trainer/kto_trainer.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/trl/trainer/kto_trainer.py b/trl/trainer/kto_trainer.py
index 5b6503a1f2..c57650939d 100644
--- a/trl/trainer/kto_trainer.py
+++ b/trl/trainer/kto_trainer.py
@@ -1541,9 +1541,7 @@ def log(self, logs: Dict[str, float]) -> None:
                 del self._stored_metrics[train_eval][key]
         # calculate reward margin
         if f"{prefix}rewards/chosen" in logs and f"{prefix}rewards/rejected" in logs:
-            logs[f"{prefix}rewards/margins"] = (
-                logs[f"{prefix}rewards/chosen"] - logs[f"{prefix}rewards/rejected"]
-            )
+            logs[f"{prefix}rewards/margins"] = logs[f"{prefix}rewards/chosen"] - logs[f"{prefix}rewards/rejected"]
         # Add averaged stored metrics to logs
         for key, metrics in self._stored_metrics[train_eval].items():
             logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item()
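
For reference, the effect of the prefix change can be seen in the minimal standalone sketch below. It is not part of KTOTrainer: demo_log and the mocked stored-metrics dict are illustrative only. The sketch just reproduces the key-naming logic patch 1 introduces, namely that training metrics keep bare names while evaluation metrics gain an "eval_" prefix, which is the convention transformers.Trainer uses for eval metrics.

# Standalone illustration (hypothetical helper, not part of trl): reproduces the
# prefix logic from patch 1 against a mocked stored-metrics dict.
import torch


def demo_log(logs, stored_metrics):
    # logs either has 'loss' or 'eval_loss'
    train_eval = "train" if "loss" in logs else "eval"
    # train metrics should have no prefix, eval should have 'eval_'
    prefix = "eval_" if train_eval == "eval" else ""
    # average the accumulated sums, writing them under the prefixed key
    for split in ["chosen", "rejected"]:
        if f"count/{split}" in stored_metrics[train_eval]:
            count_sum = torch.Tensor(stored_metrics[train_eval][f"count/{split}"]).sum().item()
            logs[f"{prefix}rewards/{split}"] = (
                torch.Tensor(stored_metrics[train_eval][f"rewards/{split}_sum"]).sum().item() / count_sum
            )
    # reward margin uses the same prefixed keys
    if f"{prefix}rewards/chosen" in logs and f"{prefix}rewards/rejected" in logs:
        logs[f"{prefix}rewards/margins"] = logs[f"{prefix}rewards/chosen"] - logs[f"{prefix}rewards/rejected"]
    return logs


stored = {
    "count/chosen": [2],
    "rewards/chosen_sum": [1.0],
    "count/rejected": [2],
    "rewards/rejected_sum": [-1.0],
}

# Training step: 'loss' is present, so keys stay unprefixed
# -> {'loss': 0.5, 'rewards/chosen': 0.5, 'rewards/rejected': -0.5, 'rewards/margins': 1.0}
print(demo_log({"loss": 0.5}, {"train": dict(stored)}))

# Evaluation step: 'eval_loss' is present, so keys gain the 'eval_' prefix
# -> {'eval_loss': 0.5, 'eval_rewards/chosen': 0.5, 'eval_rewards/rejected': -0.5, 'eval_rewards/margins': 1.0}
print(demo_log({"eval_loss": 0.5}, {"eval": dict(stored)}))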