fix reward calculation
qgallouedec committed Jan 14, 2025
1 parent c70402c commit 106d271
Showing 1 changed file with 10 additions and 18 deletions.
trl/trainer/grpo_trainer.py (28 changes: 10 additions & 18 deletions)
@@ -38,7 +38,7 @@
 from ..data_utils import maybe_apply_chat_template
 from ..models import create_reference_model, unwrap_model_for_generation
 from .grpo_config import GRPOConfig
-from .utils import generate_model_card, get_comet_experiment_url
+from .utils import flush_left, generate_model_card, get_comet_experiment_url


 if is_peft_available():
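
The newly imported `flush_left` is the heart of the fix, but its implementation is not part of this diff. Below is a minimal sketch of the behavior its usage later in the diff implies; the name `flush_left_sketch` and the exact return convention are assumptions, not TRL's actual implementation:

```python
import torch

def flush_left_sketch(mask: torch.Tensor, *tensors: torch.Tensor):
    """Shift each row left so its first non-padded token lands in column 0,
    then trim columns that are padding in every row. A sketch of the implied
    behavior, not TRL's actual `flush_left`."""
    # Column index of the first non-padded token per row
    # (argmax returns the first occurrence of the maximum value).
    first = mask.int().argmax(dim=1)
    rolled_mask = torch.stack([torch.roll(row, -int(s)) for row, s in zip(mask, first)])
    rolled_tensors = [
        torch.stack([torch.roll(row, -int(s)) for row, s in zip(t, first)])
        for t in tensors
    ]
    # Keep only as many columns as the longest real sequence.
    keep = int(rolled_mask.sum(dim=1).max())
    return rolled_mask[:, :keep], *(t[:, :keep] for t in rolled_tensors)
```

With left-padded prompts concatenated to right-padded completions, each row looks like [pad, prompt, completion, pad]; after flushing, padding only trails the real tokens, which is the layout the reward model's pooling assumes.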
@@ -113,8 +113,11 @@ def __init__(
         # Reward model
         if isinstance(reward_model, str):
             reward_model = AutoModelForSequenceClassification.from_pretrained(
-                reward_model, num_labels=1, pad_token_id=0, **model_init_kwargs
+                reward_model, num_labels=1, **model_init_kwargs
             )
+        # The reward model computes the reward for the latest non-padded token in the input sequence.
+        # So it's important to set the pad token ID to the padding token ID of the processing class.
+        self.reward_model.config.pad_token_id = processing_class.pad_token_id
         self.reward_model = reward_model

         # Data loading and preprocessing
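
The two added comment lines are the rationale for the whole change: for decoder-style models, a `*ForSequenceClassification` head scores every position and then pools the logit at the last non-padded token, which it locates via `config.pad_token_id`. Roughly, the lookup works like this sketch (paraphrased from the transformers classification heads, not the library's literal code):

```python
import torch

def pooled_token_index(input_ids: torch.Tensor, pad_token_id: int) -> torch.Tensor:
    """Index of the token whose logit a sequence-classification head pools,
    assuming right padding: the position just before the first pad token."""
    first_pad = input_ids.eq(pad_token_id).int().argmax(dim=-1)
    # If a row contains no padding, argmax is 0 and the modulo wraps -1
    # around to the last column.
    return (first_pad - 1) % input_ids.shape[-1]
```

The previously hardcoded `pad_token_id=0` made this lookup hunt for token id 0, which for most tokenizers is not the padding token, so the reward could be read from an arbitrary position.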
@@ -131,7 +134,7 @@ def data_collator(features):  # No data collation is needed in GRPO
             do_sample=True,
             temperature=args.temperature,
             num_return_sequences=self.num_generations,
-            pad_token_id=processing_class.eos_token_id,
+            pad_token_id=processing_class.pad_token_id,
         )
         self.beta = args.beta

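Switching the generation padding token from `eos_token_id` to `pad_token_id` matters for the same reason: if completions are padded with EOS, the first "pad" the pooling finds is the completion's own EOS token, so the reward is read one position early. A toy check, with made-up token ids for illustration:

```python
import torch

pad_id, eos_id = 2, 1  # hypothetical ids for a tokenizer where pad != eos

def last_real_token(ids: torch.Tensor, pad: int) -> torch.Tensor:
    # Same first-pad-minus-one pooling as sketched above.
    return (ids.eq(pad).int().argmax(dim=-1) - 1) % ids.shape[-1]

seq = torch.tensor([[5, 6, 7, eos_id, pad_id, pad_id]])  # right-padded completion

print(last_real_token(seq, pad_id))  # tensor([3]): the EOS token, as intended
print(last_real_token(seq, eos_id))  # tensor([2]): one position early, wrong token scored
```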
@@ -229,21 +232,10 @@ def get_per_token_logps(model, input_ids):
         prompt_mask = inputs["attention_mask"].repeat_interleave(self.num_generations, dim=0)
         prompt_completion_mask = torch.cat([prompt_mask, completion_mask], dim=1)

-        def get_reward(model, input_ids, attention_mask):
-            # Forward pass through the reward model
-            base_model = getattr(model, model.base_model_prefix)  # usually base_model_prefix = "model"
-            output = base_model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
-            per_token_reward = model.score(output.hidden_states[-1]).squeeze(2)
-
-            # Get the last True index in the mask
-            flipped = torch.flip(prompt_completion_mask, dims=[1])
-            final_idx = prompt_completion_mask.shape[1] - torch.argmax(flipped.int(), dim=1) - 1
-
-            # Get the reward logits for the last token in the sequence
-            rewards = per_token_reward[torch.arange(per_token_reward.size(0)), final_idx]
-            return rewards
-
-        rewards = get_reward(self.reward_model, prompt_completion_ids, prompt_completion_mask)
+        attention_mask, input_ids = flush_left(
+            prompt_completion_mask, prompt_completion_ids
+        )  # needed for the reward model
+        rewards = self.reward_model(input_ids, attention_mask)

         # Compute grouped-wise rewards
         mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
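
After the rewrite, the reward path is three steps: flush the concatenated ids and mask left (something like the `flush_left_sketch` helper above), let the reward model do its own last-token pooling, and average within each prompt's group of generations. The deleted `get_reward` closure reimplemented that pooling by hand on sequences with left-padded prompts (and, incidentally, read the outer `prompt_completion_mask` instead of its own `attention_mask` parameter); the fix delegates pooling to the model, which only works once padding is right-flushed and `pad_token_id` is correct. A toy view of the grouping step (the values and `num_generations` are made up):

```python
import torch

num_generations = 2
# Rewards for 2 prompts x 2 generations each, flattened so that consecutive
# entries belong to the same prompt:
rewards = torch.tensor([1.0, 3.0, 0.0, 4.0])

mean_grouped_rewards = rewards.view(-1, num_generations).mean(dim=1)
print(mean_grouped_rewards)  # tensor([2., 2.]): one baseline per prompt

# GRPO compares each generation's reward against its own group's baseline
# (the trainer also normalizes by the group's std, not shown in this hunk):
advantages = rewards - mean_grouped_rewards.repeat_interleave(num_generations)
print(advantages)  # tensor([-1.,  1., -2.,  2.])
```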