Commit

🛣️ inference_mode to no_grad when computing old_per_token_logps
qgallouedec authored Feb 28, 2025
1 parent ad6a35b commit 491921c
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion trl/trainer/grpo_trainer.py
@@ -776,7 +776,7 @@ def _generate_and_score_completions(

         logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens

-        with torch.inference_mode():
+        with torch.no_grad():
             # When using num_iterations == 1, old_per_token_logps == per_token_logps, so we can skip its
             # computation here, and use per_token_logps.detach() instead.
             if self.num_iterations > 1:
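The commit message does not state the motivation, but the likely reason for the switch is a standard PyTorch distinction: tensors created under `torch.inference_mode()` are "inference tensors" and can never participate in a later autograd graph, whereas tensors created under `torch.no_grad()` are ordinary tensors that can still be used as constants in differentiable computations. A minimal standalone sketch of the difference (not taken from the TRL code):

```python
import torch

x = torch.ones(3)

# inference_mode produces "inference tensors": permanently excluded from autograd.
with torch.inference_mode():
    y_inf = x * 2

# no_grad produces ordinary tensors with requires_grad=False; they may still
# appear as constants in a later autograd graph.
with torch.no_grad():
    y_ng = x * 2

w = torch.ones(3, requires_grad=True)

# Using an inference tensor in a grad-enabled computation raises a RuntimeError.
raised = False
try:
    (y_inf * w).sum().backward()
except RuntimeError:
    raised = True
print("inference tensor raised:", raised)

# The no_grad tensor works fine as a constant: d/dw of sum(y_ng * w) is y_ng.
(y_ng * w).sum().backward()
print(w.grad)  # tensor([2., 2., 2.])
```

Since `old_per_token_logps` computed under `inference_mode` would later be combined with grad-tracked `per_token_logps` inside the loss, `no_grad` is the safer choice here.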
