
Commit

fix reward logging
qgallouedec committed Jan 16, 2025
1 parent c597c62 commit be1c21e
Showing 1 changed file with 7 additions and 8 deletions.
trl/trainer/grpo_trainer.py (15 changes: 7 additions & 8 deletions)
@@ -155,7 +155,7 @@ def data_collator(features):  # No data collation is needed in GRPO
         model.warnings_issued["estimate_tokens"] = True

         # Initialize the metrics
-        self._metrics = {"kl": [], "reward": []}
+        self._metrics = {"kl": [], "reward": [], "reward_std": []}

         super().__init__(
             model=model,
@@ -261,24 +261,23 @@ def get_per_token_logps(model, input_ids):

         # Compute grouped-wise rewards
         mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
-        std_grouped_rewards_raw = rewards.view(-1, self.num_generations).std(dim=1)
-        std_grouped_rewards = torch.where(  # avoid division by zero
-            std_grouped_rewards_raw < 1e-8, torch.tensor(1.0, device=device), std_grouped_rewards_raw
-        )
+        std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)

         # Normalize the rewards to compute the advantages
         mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
         std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
-        advantages = (rewards - mean_grouped_rewards) / std_grouped_rewards
+        advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)

         # Compute the loss
-        per_token_loss = -(advantages.unsqueeze(1) - self.beta * per_token_kl)
+        per_token_loss = -(
+            advantages.unsqueeze(1) * per_token_logps / per_token_logps.detach() - self.beta * per_token_kl
+        )
         loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()

         # Log the metrics
         self._metrics["reward"].append(self.accelerator.gather_for_metrics(rewards).mean().item())

-        self._metrics["reward_std"].append(self.accelerator.gather_for_metrics(std_grouped_rewards_raw).mean().item())
+        self._metrics["reward_std"].append(self.accelerator.gather_for_metrics(std_grouped_rewards).mean().item())

         mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
         self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
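
The advantage computation in the second hunk normalizes each prompt's group of num_generations rewards by the group mean and standard deviation, now with a small epsilon in the denominator instead of the earlier torch.where guard. A minimal numeric sketch of the same arithmetic, using made-up rewards and an assumed group size of 4 (not taken from the trainer):

    import torch

    num_generations = 4                                             # assumed group size for this sketch
    # 2 prompts x 4 completions -> 8 scalar rewards, one per completion (dummy values)
    rewards = torch.tensor([1.0, 0.0, 1.0, 0.0, 2.0, 2.0, 2.0, 2.0])

    mean_grouped = rewards.view(-1, num_generations).mean(dim=1)    # shape (2,), one mean per prompt
    std_grouped = rewards.view(-1, num_generations).std(dim=1)      # shape (2,), one std per prompt

    # Broadcast the per-group statistics back to one entry per completion
    mean_grouped = mean_grouped.repeat_interleave(num_generations, dim=0)
    std_grouped = std_grouped.repeat_interleave(num_generations, dim=0)

    # The 1e-4 keeps the division finite when a group has identical rewards (std == 0),
    # as in the second group above
    advantages = (rewards - mean_grouped) / (std_grouped + 1e-4)
    print(advantages)   # first group: roughly +/-0.87; second group: exactly 0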

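The new form of the per-token loss multiplies the advantages (constants with respect to the policy) by per_token_logps / per_token_logps.detach(). That ratio is exactly 1 in the forward pass, so the loss value is unchanged, but it gives the advantage term a gradient through per_token_logps. A minimal sketch of the trick on dummy tensors (not the trainer code):

    import torch

    logps = torch.tensor([-1.2, -0.5, -2.0], requires_grad=True)   # dummy per-token log-probs
    advantages = torch.tensor([0.8, 0.8, 0.8])                     # dummy advantages, treated as constants

    ratio = logps / logps.detach()        # forward value: tensor([1., 1., 1.])
    loss = -(advantages * ratio).sum()    # same value as -(advantages * 1).sum()
    loss.backward()

    print(ratio.detach())   # tensor([1., 1., 1.])
    print(logps.grad)       # equals -advantages / logps.detach(): nonzero per-token gradient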