Commit

Update grpo_trainer.py
andyl98 committed Jan 31, 2025
1 parent 41f195f · commit eced62b
Showing 1 changed file with 20 additions and 8 deletions.
28 changes: 20 additions & 8 deletions trl/trainer/grpo_trainer.py
@@ -423,13 +423,32 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
                 **prompt_inputs, generation_config=self.generation_config
             )
 
+        # Compute prompt length and extract completion ids
+        prompt_length = prompt_inputs["input_ids"].size(1)
+        completion_ids = prompt_completion_ids[:, prompt_length:]
+
+        # Mask everything after the first EOS token in completion_ids
+        completion_is_eos = completion_ids == self.processing_class.eos_token_id
+        completion_eos_idx = torch.full(
+            (completion_is_eos.size(0),), completion_is_eos.size(1), dtype=torch.long, device=device
+        )
+        completion_eos_idx[completion_is_eos.any(dim=1)] = completion_is_eos.int().argmax(dim=1)[
+            completion_is_eos.any(dim=1)
+        ]
+        completion_sequence_indices = torch.arange(completion_is_eos.size(1), device=device).expand(
+            completion_is_eos.size(0), -1
+        )
+        completion_mask = (completion_sequence_indices <= completion_eos_idx.unsqueeze(1)).long()  # (B*G, C)
+
+        # Concatenate prompt_mask with completion_mask for logit computation
+        prompt_mask_extended = prompt_inputs["attention_mask"].repeat_interleave(
+            self.num_generations, dim=0
+        )  # (B, P) -> (B*G, P)
+        attention_mask = torch.cat([prompt_mask_extended, completion_mask], dim=1)  # (B*G, P+C)
 
         # Get the per-token log probabilities for the completions for the model and the reference model
         def get_per_token_logps(model, input_ids, num_logits_to_keep):
             # We add 1 to `num_logits_to_keep` because the last logit of the sequence is later excluded
-            attention_mask = (input_ids != self.processing_class.pad_token_id).long()
             logits = model(
                 input_ids=input_ids, attention_mask=attention_mask, num_logits_to_keep=num_logits_to_keep + 1
             ).logits  # (B, L, V)
@@ -456,13 +475,6 @@ def get_per_token_logps(model, input_ids, num_logits_to_keep):
         # Compute the KL divergence between the model and the reference model
         per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
 
-        # Mask everything after the first EOS token
-        is_eos = completion_ids == self.processing_class.eos_token_id
-        eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
-        eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
-        sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
-        completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
-
         # Decode the generated completions
         completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
         if is_conversational(inputs[0]):
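
A note on the masking trick in the first hunk: argmax over a boolean tensor cast to int returns the index of the first True value, and rows without any EOS default to the full completion length, so they stay entirely unmasked. With the local pad-token mask deleted from get_per_token_logps, the inner function now picks up the `attention_mask` built in the enclosing scope. A minimal standalone sketch of the masking logic with toy values (the eos_token_id of 2 and the example tensors are assumptions for illustration, not from the commit):

import torch

eos_token_id = 2  # hypothetical toy value, not the commit's tokenizer setting
completion_ids = torch.tensor(
    [
        [5, 7, 2, 9, 2],  # first EOS at index 2 -> positions 3-4 get masked
        [5, 7, 9, 4, 3],  # no EOS -> the whole row stays unmasked
    ]
)

is_eos = completion_ids == eos_token_id  # (B, C) boolean
# Default to the sequence length so rows without EOS keep every position.
eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long)
# argmax over the int-cast rows returns the index of the first 1;
# only rows that actually contain an EOS token are overwritten.
eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]

# Keep positions up to and including the first EOS.
sequence_indices = torch.arange(is_eos.size(1)).expand(is_eos.size(0), -1)
completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).long()

print(completion_mask)
# tensor([[1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 1]])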
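
The repeat_interleave call is there because GRPO generates self.num_generations completions per prompt, so the (B, P) prompt mask has to be expanded row-wise to line up with the (B*G, C) completion mask before concatenation. A quick shape check with toy dimensions (all sizes here are illustrative, not from the commit):

import torch

B, G, P, C = 2, 3, 4, 5  # batch, generations per prompt, prompt len, completion len
prompt_mask = torch.ones(B, P, dtype=torch.long)
completion_mask = torch.ones(B * G, C, dtype=torch.long)

# Each prompt row is repeated G times so it lines up with its G completions.
prompt_mask_extended = prompt_mask.repeat_interleave(G, dim=0)  # (B, P) -> (B*G, P)
attention_mask = torch.cat([prompt_mask_extended, completion_mask], dim=1)
print(attention_mask.shape)  # torch.Size([6, 9]), i.e. (B*G, P+C)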
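
As for the per-token KL term in the second hunk: exp(r) - r - 1 with r = ref_logp - logp is an always-nonnegative estimator of the KL divergence between the policy and the reference model (the k3 estimator from John Schulman's "Approximating KL Divergence" note). A quick numerical check with made-up log-probabilities:

import torch

ref_logp = torch.tensor([-1.0, -2.0, -0.5])  # illustrative values only
logp = torch.tensor([-1.0, -1.5, -2.0])

r = ref_logp - logp
per_token_kl = torch.exp(r) - r - 1
print(per_token_kl)  # tensor([0.0000, 0.1065, 1.9817]) -- zero where logp == ref_logp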
