From 071c19a06bf106f550b9c5a52f6cb80999564614 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?=
Date: Thu, 16 Jan 2025 16:12:08 +0000
Subject: [PATCH] fix grad computation

---
 trl/trainer/grpo_trainer.py | 11 ++++-------
 1 file changed, 4 insertions(+), 7 deletions(-)

diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py
index 6e40740ad6..00e46baf13 100644
--- a/trl/trainer/grpo_trainer.py
+++ b/trl/trainer/grpo_trainer.py
@@ -226,9 +226,7 @@ def get_per_token_logps(model, input_ids):
         ref_per_token_logps = ref_per_token_logps[:, prompt_length:]  # get rid of the prompt
 
         # Compute the KL divergence between the model and the reference model
-        per_token_kl = (
-            torch.exp(ref_per_token_logps) / torch.exp(per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
-        )
+        per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
 
         # Mask everything after the first EOS token
         is_eos = completion_ids == self.processing_class.eos_token_id
@@ -268,10 +266,9 @@ def get_per_token_logps(model, input_ids):
         std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
         advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
 
-        # Compute the loss
-        per_token_loss = -(
-            advantages.unsqueeze(1) * per_token_logps / per_token_logps.detach() - self.beta * per_token_kl
-        )
+        # x - x.detach() allows for preserving gradients from x
+        advatages = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
+        per_token_loss = -(advatages - self.beta * per_token_kl)
         loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
 
         # Log the metrics
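
Context for the change (not part of the patch): the rewritten loss relies on the fact that exp(x - x.detach()) evaluates to 1 while its derivative with respect to x is also 1, so each token contributes exactly its advantage to the gradient. The old per_token_logps / per_token_logps.detach() likewise evaluates to 1, but it backpropagates advantages / per_token_logps, a wrongly scaled (and, since log-probs are negative, sign-flipped) gradient. The KL change appears mathematically equivalent, just computed as exp(ref - logps) instead of exp(ref) / exp(logps) so the two log-probs are not exponentiated separately. Below is a minimal, self-contained sketch of the gradient difference; the variable names (logps, advantages, old_term, new_term) are illustrative and not taken from trl.

    # Standalone sketch: compare the gradient of the old ratio term with the
    # new exp(x - x.detach()) term on a toy tensor.
    import torch

    logps = torch.tensor([-1.5, -0.2, -3.0], requires_grad=True)
    advantages = torch.tensor([0.7, -0.3, 1.2])

    # Old formulation: the value of logps / logps.detach() is 1, but the
    # gradient w.r.t. logps is advantages / logps.detach(), i.e. scaled by the
    # (negative) inverse log-prob.
    old_term = advantages * logps / logps.detach()
    old_term.sum().backward()
    print(logps.grad)  # advantages / logps -> wrong scale and sign

    logps.grad = None

    # New formulation: the value of exp(logps - logps.detach()) is still 1,
    # and the gradient w.r.t. logps is exactly `advantages`, the
    # REINFORCE-style gradient the loss intends.
    new_term = advantages * torch.exp(logps - logps.detach())
    new_term.sum().backward()
    print(logps.grad)  # equals advantages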