diff --git a/docs/source/grpo_trainer.md b/docs/source/grpo_trainer.md
index 9db61842ac..4abd250ae4 100644
--- a/docs/source/grpo_trainer.md
+++ b/docs/source/grpo_trainer.md
@@ -110,6 +110,7 @@ In TRL though, as in the original paper, we only do one update per generation, s
 
 The GRPO Trainer logs the following metrics:
 
+- `rewards/{reward_func_name}`: The reward computed by each reward function.
 - `reward`: The average reward.
 - `reward_std` : The average standard deviation within reward groups.
 - `kl` : The average KL divergence between the model and the reference model calculated on completions.
diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py
index a484f1cb2d..6f18ac9959 100644
--- a/trl/trainer/grpo_trainer.py
+++ b/trl/trainer/grpo_trainer.py
@@ -14,6 +14,7 @@
 
 import os
 import textwrap
+from collections import defaultdict
 from typing import Any, Callable, Optional, Union
 
 import torch
@@ -313,7 +314,7 @@ def data_collator(features):  # No data collation is needed in GRPO
         model.warnings_issued["estimate_tokens"] = True
 
         # Initialize the metrics
-        self._metrics = {"kl": [], "reward": [], "reward_std": []}
+        self._metrics = defaultdict(list)
 
         super().__init__(
             model=model,
@@ -502,8 +503,8 @@ def get_per_token_logps(model, input_ids, attention_mask):
 
         # Compute the rewards
         prompts = [prompt for prompt in prompts for _ in range(self.num_generations)]
+        rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=self.accelerator.device)
 
-        rewards = torch.zeros(len(self.reward_funcs), len(prompts), device=self.accelerator.device)
         for i, (reward_func, reward_processing_class) in enumerate(
             zip(self.reward_funcs, self.reward_processing_classes)
         ):
@@ -518,7 +519,7 @@ def get_per_token_logps(model, input_ids, attention_mask):
                 )
                 reward_inputs = super()._prepare_inputs(reward_inputs)
                 with torch.inference_mode():
-                    rewards[i] = reward_func(**reward_inputs).logits[:, 0]  # Shape (B*G,)
+                    rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0]  # Shape (B*G,)
             else:
                 # Repeat all input columns (but "prompt" and "completion") to match the number of generations
                 reward_kwargs = {key: [] for key in inputs[0].keys() if key not in ["prompt", "completion"]}
@@ -527,9 +528,10 @@ def get_per_token_logps(model, input_ids, attention_mask):
                         # Repeat each value in the column for `num_generations` times
                         reward_kwargs[key].extend([example[key]] * self.num_generations)
                 output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
-                rewards[i] = torch.tensor(output_reward_func, dtype=torch.float32, device=self.accelerator.device)
+                rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=self.accelerator.device)
+
         # Sum the rewards from all reward functions
-        rewards = rewards.sum(dim=0)
+        rewards = rewards_per_func.sum(dim=1)
 
         # Compute grouped-wise rewards
         mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
@@ -546,6 +548,14 @@ def get_per_token_logps(model, input_ids, attention_mask):
         loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
 
         # Log the metrics
+        reward_per_func = self.accelerator.gather_for_metrics(rewards_per_func).mean(0)
+        for i, reward_func in enumerate(self.reward_funcs):
+            if isinstance(reward_func, PreTrainedModel):
+                reward_func_name = reward_func.config._name_or_path.split("/")[-1]
+            else:
+                reward_func_name = reward_func.__name__
+            self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item())
+
self._metrics["reward"].append(self.accelerator.gather_for_metrics(rewards).mean().item()) self._metrics["reward_std"].append(self.accelerator.gather_for_metrics(std_grouped_rewards).mean().item()) @@ -556,13 +566,13 @@ def get_per_token_logps(model, input_ids, attention_mask): return loss def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None: - metrics = {key: sum(val) / len(val) for key, val in self._metrics.items() if val} # average the metrics + metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()} # average the metrics logs = {**logs, **metrics} if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"): super().log(logs, start_time) else: # transformers<=4.46 super().log(logs) - self._metrics = {key: [] for key in self._metrics} + self._metrics.clear() def create_model_card( self,