From 317d2d477ba4dcebe81fbfcc6abc7bfb91b50ad7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com>
Date: Sat, 25 Jan 2025 11:43:00 +0100
Subject: [PATCH] =?UTF-8?q?=F0=9F=94=8E=20Finegrained=20reward=20logging?=
 =?UTF-8?q?=20for=20GRPO=20(#2651)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 docs/source/grpo_trainer.md |  1 +
 trl/trainer/grpo_trainer.py | 24 +++++++++++++++++-------
 2 files changed, 18 insertions(+), 7 deletions(-)

diff --git a/docs/source/grpo_trainer.md b/docs/source/grpo_trainer.md
index 9db61842ac..4abd250ae4 100644
--- a/docs/source/grpo_trainer.md
+++ b/docs/source/grpo_trainer.md
@@ -110,6 +110,7 @@ In TRL though, as in the original paper, we only do one update per generation, s
 
 The GRPO Trainer logs the following metrics:
 
+- `rewards/{reward_func_name}`: The reward computed by each reward function.
 - `reward`: The average reward.
 - `reward_std` : The average standard deviation within reward groups.
 - `kl` : The average KL divergence between the model and the reference model calculated on completions.
diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py
index bf2ef75aa0..3cddc9eb5c 100644
--- a/trl/trainer/grpo_trainer.py
+++ b/trl/trainer/grpo_trainer.py
@@ -14,6 +14,7 @@
 
 import os
 import textwrap
+from collections import defaultdict
 from typing import Any, Callable, Optional, Union
 
 import torch
@@ -255,7 +256,7 @@ def data_collator(features): # No data collation is needed in GRPO
         model.warnings_issued["estimate_tokens"] = True
 
         # Initialize the metrics
-        self._metrics = {"kl": [], "reward": [], "reward_std": []}
+        self._metrics = defaultdict(list)
 
         super().__init__(
             model=model,
@@ -361,7 +362,7 @@ def get_per_token_logps(model, input_ids):
 
         # Compute the rewards
         prompts = [prompt for prompt in prompts for _ in range(self.num_generations)]
-        rewards = torch.zeros(len(self.reward_funcs), len(prompts), device=device)
+        rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
         for i, (reward_func, reward_processing_class) in enumerate(
             zip(self.reward_funcs, self.reward_processing_classes)
         ):
@@ -376,7 +377,7 @@ def get_per_token_logps(model, input_ids):
                 )
                 reward_inputs = super()._prepare_inputs(reward_inputs)
                 with torch.inference_mode():
-                    rewards[i] = reward_func(**reward_inputs).logits[:, 0]  # Shape (B*G,)
+                    rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0]  # Shape (B*G,)
             else:
                 # Repeat all input columns (but "prompt" and "completion") to match the number of generations
                 reward_kwargs = {key: [] for key in inputs[0].keys() if key not in ["prompt", "completion"]}
@@ -385,9 +386,10 @@ def get_per_token_logps(model, input_ids):
                         # Repeat each value in the column for `num_generations` times
                         reward_kwargs[key].extend([example[key]] * self.num_generations)
                 output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
-                rewards[i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
+                rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
+
         # Sum the rewards from all reward functions
-        rewards = rewards.sum(dim=0)
+        rewards = rewards_per_func.sum(dim=1)
 
         # Compute grouped-wise rewards
         mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
@@ -404,6 +406,14 @@ def get_per_token_logps(model, input_ids):
         loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
 
         # Log the metrics
+        reward_per_func = self.accelerator.gather_for_metrics(rewards_per_func).mean(0)
+        for i, reward_func in enumerate(self.reward_funcs):
+            if isinstance(reward_func, PreTrainedModel):
+                reward_func_name = reward_func.config._name_or_path.split("/")[-1]
+            else:
+                reward_func_name = reward_func.__name__
+            self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item())
+
         self._metrics["reward"].append(self.accelerator.gather_for_metrics(rewards).mean().item())
 
         self._metrics["reward_std"].append(self.accelerator.gather_for_metrics(std_grouped_rewards).mean().item())
@@ -414,13 +424,13 @@ def get_per_token_logps(model, input_ids):
         return loss
 
     def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
-        metrics = {key: sum(val) / len(val) for key, val in self._metrics.items() if val}  # average the metrics
+        metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()}  # average the metrics
         logs = {**logs, **metrics}
        if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
             super().log(logs, start_time)
         else:  # transformers<=4.46
             super().log(logs)
-        self._metrics = {key: [] for key in self._metrics}
+        self._metrics.clear()
 
     def create_model_card(
         self,
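
Note (not part of the patch): a minimal usage sketch of how the new per-function metrics surface. The two reward functions, the dataset, and the model name below are illustrative assumptions; with the change above, the trainer would log `rewards/format_reward` and `rewards/length_reward` (keyed by each callable's `__name__`, or by the model name for reward models) alongside the aggregate `reward` and `reward_std`.

```python
# Sketch only: reward functions, dataset, and model are illustrative choices,
# not part of the patch itself.
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer


def format_reward(completions, **kwargs):
    # 1.0 if the completion ends with terminal punctuation, else 0.0.
    return [1.0 if completion.strip().endswith((".", "!", "?")) else 0.0 for completion in completions]


def length_reward(completions, **kwargs):
    # Penalize completions far from 200 characters.
    return [-abs(200 - len(completion)) / 200 for completion in completions]


dataset = load_dataset("trl-lib/tldr", split="train")

trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    reward_funcs=[format_reward, length_reward],  # each function gets its own "rewards/<name>" log entry
    args=GRPOConfig(output_dir="Qwen2.5-0.5B-GRPO", logging_steps=1),
    train_dataset=dataset,
)
trainer.train()
```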