📏 Log completion length in GRPO (#2659)
qgallouedec authored Jan 25, 2025
1 parent 807046b commit 4720656
Showing 2 changed files with 4 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/source/grpo_trainer.md
@@ -110,6 +110,7 @@ In TRL though, as in the original paper, we only do one update per generation, s
 
 The GRPO Trainer logs the following metrics:
 
+- `completion_length`: The average completion length.
 - `reward/{reward_func_name}`: The reward computed by each reward function.
 - `reward`: The average reward.
 - `reward_std`: The average standard deviation within reward groups.
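
To make the new metric concrete, here is a minimal sketch (not part of the commit) of what `completion_length` measures, assuming a 0/1 `completion_mask` of shape `(num_completions, max_completion_length)` that marks non-padding completion tokens; the toy values are invented:

```python
import torch

# Toy mask for two completions: lengths 3 and 5 (1 = real token, 0 = padding).
completion_mask = torch.tensor(
    [
        [1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1],
    ]
)

# `completion_length` is the mean number of unmasked tokens per completion.
completion_length = completion_mask.sum(dim=1).float().mean().item()
print(completion_length)  # 4.0
```
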
3 changes: 3 additions & 0 deletions trl/trainer/grpo_trainer.py
@@ -410,6 +410,9 @@ def get_per_token_logps(model, input_ids):
         loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
 
         # Log the metrics
+        completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
+        self._metrics["completion_length"].append(completion_length)
+
         reward_per_func = self.accelerator.gather_for_metrics(rewards_per_func).mean(0)
         for i, reward_func in enumerate(self.reward_funcs):
             if isinstance(reward_func, PreTrainedModel):
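
For context, here is a runnable single-process sketch of the logging pattern the hunk above adds. The `Accelerator.gather_for_metrics` call is real `accelerate` API; `_metrics` is a simplified stand-in for the trainer's internal buffer, which gets averaged and cleared when the trainer logs:

```python
from collections import defaultdict

import torch
from accelerate import Accelerator

accelerator = Accelerator()
_metrics = defaultdict(list)  # stand-in for the trainer's self._metrics buffer

# Toy mask: two completions of lengths 2 and 3 on this process.
completion_mask = torch.tensor([[1, 1, 0], [1, 1, 1]])

# Gather per-completion lengths from every process, then average over the
# full (possibly distributed) batch; on one process this is the local mean.
lengths = accelerator.gather_for_metrics(completion_mask.sum(1))
_metrics["completion_length"].append(lengths.float().mean().item())

print(_metrics["completion_length"])  # [2.5] when run on a single process
```

Under `accelerate launch` with several processes, the gathered tensor concatenates every process's lengths, so the logged value reflects the whole batch rather than one shard.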
