Skip to content

Commit

Permalink
fix grad computation
Browse files Browse the repository at this point in the history
  • Loading branch information
qgallouedec committed Jan 16, 2025
1 parent be1c21e commit 071c19a
Showing 1 changed file with 4 additions and 7 deletions.
11 changes: 4 additions & 7 deletions trl/trainer/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,9 +226,7 @@ def get_per_token_logps(model, input_ids):
ref_per_token_logps = ref_per_token_logps[:, prompt_length:] # get rid of the prompt

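# Compute a per-token estimate of the KL divergence between the model and the reference model:
# exp(d) - d - 1, where d = ref_per_token_logps - per_token_logps (always >= 0, and 0 when the two agree)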
per_token_kl = (
torch.exp(ref_per_token_logps) / torch.exp(per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
)
per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1

# Mask everything after the first EOS token
is_eos = completion_ids == self.processing_class.eos_token_id
Expand Down Expand Up @@ -268,10 +266,9 @@ def get_per_token_logps(model, input_ids):
std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)

# Compute the loss
per_token_loss = -(
advantages.unsqueeze(1) * per_token_logps / per_token_logps.detach() - self.beta * per_token_kl
)
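# exp(x - x.detach()) always evaluates to 1, but it keeps the gradient of x: the value of the
# term below equals the advantage, while gradients still flow through per_token_logps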
advatages = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
per_token_loss = -(advatages - self.beta * per_token_kl)
loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()

# Log the metrics
Expand Down

0 comments on commit 071c19a

Please sign in to comment.