
Commit

Merge branch 'main' into grpo_vllm
kashif authored Jan 25, 2025
2 parents (383b795 + 317d2d4), commit 4abe3ea
Showing 2 changed files with 18 additions and 7 deletions.
1 change: 1 addition & 0 deletions docs/source/grpo_trainer.md
@@ -110,6 +110,7 @@ In TRL though, as in the original paper, we only do one update per generation, s

The GRPO Trainer logs the following metrics:

+ - `rewards/{reward_func_name}`: The reward computed by each reward function.
- `reward`: The average reward.
- `reward_std` : The average standard deviation within reward groups.
- `kl` : The average KL divergence between the model and the reference model calculated on completions.
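
The naming scheme behind the new `rewards/{reward_func_name}` keys comes from the trainer change further down: a model-based reward function is keyed by the last segment of its `config._name_or_path`, while a plain Python reward function is keyed by its `__name__`. Below is a minimal sketch of that naming rule; the `ends_with_period` function and the `my-org/my-reward-model` path are hypothetical examples, and the duck-typed `config` check stands in for the trainer's actual `isinstance(reward_func, PreTrainedModel)` test.

def metric_key(reward_func):
    # Model-based reward: use the final segment of the checkpoint path.
    config = getattr(reward_func, "config", None)
    if config is not None and getattr(config, "_name_or_path", None):
        return "rewards/" + config._name_or_path.split("/")[-1]
    # Callable reward: use the function's __name__.
    return "rewards/" + reward_func.__name__

def ends_with_period(prompts, completions, **kwargs):
    # Hypothetical rule-based reward: 1.0 if the completion ends with a period.
    return [1.0 if c.strip().endswith(".") else 0.0 for c in completions]

print(metric_key(ends_with_period))  # -> "rewards/ends_with_period"
# A reward model loaded from "my-org/my-reward-model" would be logged under "rewards/my-reward-model".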
24 changes: 17 additions & 7 deletions trl/trainer/grpo_trainer.py
@@ -14,6 +14,7 @@

import os
import textwrap
+ from collections import defaultdict
from typing import Any, Callable, Optional, Union

import torch
@@ -313,7 +314,7 @@ def data_collator(features): # No data collation is needed in GRPO
model.warnings_issued["estimate_tokens"] = True

# Initialize the metrics
self._metrics = {"kl": [], "reward": [], "reward_std": []}
self._metrics = defaultdict(list)

super().__init__(
model=model,
@@ -502,8 +503,8 @@ def get_per_token_logps(model, input_ids, attention_mask):

# Compute the rewards
prompts = [prompt for prompt in prompts for _ in range(self.num_generations)]
+ rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=self.accelerator.device)

- rewards = torch.zeros(len(self.reward_funcs), len(prompts), device=self.accelerator.device)
for i, (reward_func, reward_processing_class) in enumerate(
zip(self.reward_funcs, self.reward_processing_classes)
):
@@ -518,7 +519,7 @@ def get_per_token_logps(model, input_ids, attention_mask):
)
reward_inputs = super()._prepare_inputs(reward_inputs)
with torch.inference_mode():
- rewards[i] = reward_func(**reward_inputs).logits[:, 0] # Shape (B*G,)
+ rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0] # Shape (B*G,)
else:
# Repeat all input columns (but "prompt" and "completion") to match the number of generations
reward_kwargs = {key: [] for key in inputs[0].keys() if key not in ["prompt", "completion"]}
@@ -527,9 +528,10 @@ def get_per_token_logps(model, input_ids, attention_mask):
# Repeat each value in the column for `num_generations` times
reward_kwargs[key].extend([example[key]] * self.num_generations)
output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
- rewards[i] = torch.tensor(output_reward_func, dtype=torch.float32, device=self.accelerator.device)
+ rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=self.accelerator.device)

+ # Sum the rewards from all reward functions
- rewards = rewards.sum(dim=0)
+ rewards = rewards_per_func.sum(dim=1)

# Compute grouped-wise rewards
mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
@@ -546,6 +548,14 @@ def get_per_token_logps(model, input_ids, attention_mask):
loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()

# Log the metrics
+ reward_per_func = self.accelerator.gather_for_metrics(rewards_per_func).mean(0)
+ for i, reward_func in enumerate(self.reward_funcs):
+     if isinstance(reward_func, PreTrainedModel):
+         reward_func_name = reward_func.config._name_or_path.split("/")[-1]
+     else:
+         reward_func_name = reward_func.__name__
+     self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item())

self._metrics["reward"].append(self.accelerator.gather_for_metrics(rewards).mean().item())

self._metrics["reward_std"].append(self.accelerator.gather_for_metrics(std_grouped_rewards).mean().item())
@@ -556,13 +566,13 @@ def get_per_token_logps(model, input_ids, attention_mask):
return loss

def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
- metrics = {key: sum(val) / len(val) for key, val in self._metrics.items() if val} # average the metrics
+ metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()} # average the metrics
logs = {**logs, **metrics}
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
super().log(logs, start_time)
else: # transformers<=4.46
super().log(logs)
- self._metrics = {key: [] for key in self._metrics}
+ self._metrics.clear()

def create_model_card(
self,

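To follow the tensor-shape change in `compute_loss` above: per-function rewards are now collected column-wise in a tensor of shape `(B*G, num_reward_funcs)` and summed over `dim=1`, and the summed rewards are then reshaped into groups of `num_generations` per prompt for the mean/std statistics. A minimal, self-contained sketch of that bookkeeping, with made-up sizes and reward values (the real trainer additionally gathers tensors across processes with `accelerator.gather_for_metrics` before logging):

import torch

num_prompts, num_generations, num_reward_funcs = 2, 4, 2
num_completions = num_prompts * num_generations  # B*G

# One column per reward function, one row per generated completion.
rewards_per_func = torch.zeros(num_completions, num_reward_funcs)
rewards_per_func[:, 0] = torch.rand(num_completions)                      # e.g. a reward-model score
rewards_per_func[:, 1] = torch.randint(0, 2, (num_completions,)).float()  # e.g. a rule-based reward

# Total reward per completion, then per-prompt group statistics.
rewards = rewards_per_func.sum(dim=1)                                 # shape (B*G,)
mean_grouped_rewards = rewards.view(-1, num_generations).mean(dim=1)  # shape (B,)
std_grouped_rewards = rewards.view(-1, num_generations).std(dim=1)    # shape (B,)

# Per-function means are what the new `rewards/{reward_func_name}` metrics report.
print(rewards_per_func.mean(dim=0))  # one value per reward function
print(mean_grouped_rewards, std_grouped_rewards)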