[RLOO] fix token_level_kl (#2575)
* fix token_level_kl

* fix non_score_reward and rlhf_reward

* add rloo test

* update test

* fix docs

* fix doc
kashif authored Jan 17, 2025
1 parent 4c7eb6f commit 1b1140a
Showing 4 changed files with 101 additions and 6 deletions.
docs/source/rloo_trainer.md (2 changes: 1 addition & 1 deletion)
@@ -276,7 +276,7 @@ The [Reinforce++](https://hijkzzz.notion.site/reinforce-plus-plus) report by Jia
- Clipping rewards: limiting reward values within a specific range to mitigate the impact of extreme rewards on model updates, thus preventing gradient explosion
- Normalizing rewards: scaling rewards to have a mean of 0 and a standard deviation of 1, which helps in stabilizing the training process
- Normalizing advantages: scaling advantages to have a mean of 0 and a standard deviation of 1, which helps in stabilizing the training process
- Using token-level KL penalty (default) vs. sequence-level KL penalty
- Using token-level KL penalty that is defined as equation (1) of the report vs. sequence-level KL penalty (default)

These options are available via the appropriate arguments in the [`RLOOConfig`] class.

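For context (illustrative, not part of the diff): the two variants differ only in where the KL term enters the reward. A minimal sketch, assuming per-token KL estimates `kl` of shape `(batch, seq_len)`, a scalar reward-model score per sequence, and no padding:

```python
import torch

kl = torch.rand(4, 6)      # per-token KL estimates, shape (batch, seq_len)
scores = torch.rand(4)     # scalar reward-model score per sequence
kl_coef = 0.05

# Sequence-level KL penalty (the new default): sum KL over tokens, then add the score.
rlhf_reward_seq = -kl_coef * kl.sum(1) + scores

# Token-level KL penalty (token_level_kl=True): penalize every token,
# place the score on the last token, then sum over the sequence.
per_token_reward = -kl_coef * kl
per_token_reward[:, -1] += scores
rlhf_reward_tok = per_token_reward.sum(1)

# With no padding and summed rewards, the two totals coincide.
torch.testing.assert_close(rlhf_reward_seq, rlhf_reward_tok)
```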
tests/test_rloo_trainer.py (86 changes: 85 additions & 1 deletion)
@@ -61,15 +61,66 @@ def test_rloo_checkpoint(self):
def test_rloo_reward(self):
local_batch_size = 3
rloo_k = 4
sequence_length = 5 # Add sequence length for testing token-level rewards

# fmt: off
rlhf_reward = torch.tensor([
1, 2, 3, # first rlhf reward for three prompts
2, 3, 4, # second rlhf reward for three prompts
5, 6, 7, # third rlhf reward for three prompts
8, 9, 10, # fourth rlhf reward for three prompts
]).float()

# Create padding mask where 1 indicates valid token, 0 indicates padding
padding_mask = torch.ones(local_batch_size * rloo_k, sequence_length)
# Set padding based on sequence lengths
sequence_lengths = torch.tensor([
3, 4, 3, # lengths for first batch
4, 3, 4, # lengths for second batch
3, 4, 3, # lengths for third batch
4, 3, 4, # lengths for fourth batch
])
for i, length in enumerate(sequence_lengths):
padding_mask[i, length:] = 0

# Add kl tensor for testing token-level rewards
kl = torch.ones(local_batch_size * rloo_k, sequence_length) # Dummy KL values
# fmt: on

# Test token-level KL rewards following OpenRLHF implementation
kl_coef = 0.1
kl_reward = -kl_coef * kl

# Find last non-padded position
eos_indices = padding_mask.size(1) - 1 - padding_mask.long().fliplr().argmax(dim=1, keepdim=True)

# Create last reward tensor
last_reward = torch.zeros_like(kl)
last_reward.scatter_(dim=1, index=eos_indices, src=rlhf_reward.reshape(-1, 1))

# Test last_reward - should have rlhf_reward at the last non-padded position
for i, (length, reward) in enumerate(zip(sequence_lengths, rlhf_reward)):
# Check reward is at correct position
self.assertEqual(last_reward[i, length - 1].item(), reward.item())
# Check zeros elsewhere
self.assertTrue(torch.all(last_reward[i, : length - 1] == 0))
self.assertTrue(torch.all(last_reward[i, length:] == 0))

# Combine rewards
reward = last_reward + kl_reward
non_score_reward = kl_reward.sum(1)
token_level_rlhf_reward = reward.sum(1)

# Test reward components
# KL reward should be -0.1 for each token in sequence length
expected_kl_reward = -0.1 * sequence_length # Each position gets -0.1 KL reward
torch.testing.assert_close(non_score_reward, torch.tensor(expected_kl_reward).expand_as(non_score_reward))

# Total reward should be rlhf_reward + kl_reward
expected_total = rlhf_reward + expected_kl_reward
torch.testing.assert_close(token_level_rlhf_reward, expected_total)

# Test sequence-level rewards (existing test)
baseline = (rlhf_reward.sum(0) - rlhf_reward) / (rloo_k - 1)
advantages = torch.zeros_like(rlhf_reward)
for i in range(0, len(advantages), local_batch_size):
@@ -83,8 +134,41 @@ def test_rloo_reward(self):
self.assertLess((1 - (2 + 5 + 8) / 3 - advantages[0].item()), 1e-6)
self.assertLess((6 - (3 + 2 + 9) / 3 - advantages[7].item()), 1e-6)

# vectorized impl
# Test vectorized implementation
rlhf_reward = rlhf_reward.reshape(rloo_k, local_batch_size)
baseline = (rlhf_reward.sum(0) - rlhf_reward) / (rloo_k - 1)
vec_advantages = rlhf_reward - baseline
torch.testing.assert_close(vec_advantages.flatten(), advantages)

def test_rloo_training(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = RLOOConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
per_device_eval_batch_size=2,
total_episodes=1,
num_train_epochs=1,
max_steps=2,
report_to="none",
)

# Create a simple dataset
dummy_text = [{"content": "Hello World!", "role": "user"}]
dummy_data = self.tokenizer.apply_chat_template(dummy_text)
dummy_dataset = Dataset.from_dict({"input_ids": [dummy_data, dummy_data]})

trainer = RLOOTrainer(
config=training_args,
policy=self.policy_model,
reward_model=self.reward_model,
ref_policy=self.policy_ref_model,
processing_class=self.tokenizer,
train_dataset=dummy_dataset,
eval_dataset=dummy_dataset,
)

# Test that training completes without errors
trainer.train()

# Check if objective/rlhf_reward is available
self.assertIn("objective/rlhf_reward", trainer.state.log_history[-1])
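As an aside (not part of the diff), the leave-one-out baseline exercised by the assertions above can be checked by hand on a tiny example; the `rloo_k` value and rewards below are made up for illustration:

```python
import torch

# Two completions (rloo_k = 2) for each of three prompts.
rloo_k, local_batch_size = 2, 3
rlhf_reward = torch.tensor([1., 2., 3.,   # first completion per prompt
                            4., 5., 6.])  # second completion per prompt

grouped = rlhf_reward.reshape(rloo_k, local_batch_size)
# Leave-one-out baseline: mean reward of the *other* completions for the same prompt.
baseline = (grouped.sum(0) - grouped) / (rloo_k - 1)
advantages = grouped - baseline

# Prompt 0: advantage of reward 1 is 1 - 4 = -3, advantage of reward 4 is 4 - 1 = 3.
torch.testing.assert_close(advantages[:, 0], torch.tensor([-3., 3.]))
```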
trl/trainer/rloo_config.py (2 changes: 1 addition & 1 deletion)
Expand Up @@ -93,6 +93,6 @@ class RLOOConfig(OnPolicyConfig):
metadata={"help": "Whether to normalize advantages"},
)
token_level_kl: bool = field(
default=True,
default=False,
metadata={"help": "Whether to use token-level KL penalty or sequence-level KL penalty"},
)
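Usage note (illustrative, not from the diff): with the new default, the sequence-level penalty is used unless the flag is set explicitly; the `output_dir` value below is a placeholder.

```python
from trl import RLOOConfig

# Sequence-level KL penalty is now the default (token_level_kl=False).
config = RLOOConfig(output_dir="rloo_out")

# Opt back into the per-token penalty described in the docs change above.
config_token_kl = RLOOConfig(output_dir="rloo_out", token_level_kl=True)
```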
trl/trainer/rloo_trainer.py (17 changes: 14 additions & 3 deletions)
@@ -397,13 +397,24 @@ def repeat_generator():
# Compute total reward with KL penalty
if args.token_level_kl:
# Token-level KL penalty: apply KL penalty per token
token_kl_penalty = -args.kl_coef * kl
non_score_reward = token_kl_penalty.sum(1)
kl_reward = -args.kl_coef * kl

# Get the index of the last non-padded token for each sequence
eos_indices = padding_mask.size(1) - 1 - padding_mask.long().fliplr().argmax(dim=1, keepdim=True)
last_reward = torch.zeros_like(kl)
# Ensure scores has correct shape and type
scores_shaped = scores.reshape(-1, 1).to(kl.dtype)
last_reward.scatter_(dim=1, index=eos_indices, src=scores_shaped)

# Combine KL reward and last reward
non_score_reward = kl_reward.sum(1) # Keep this for logging
reward = last_reward + kl_reward
rlhf_reward = reward.sum(1) # Sum across sequence length
else:
# Sequence-level KL penalty: sum KL across tokens first
sequence_kl = kl.sum(1)
non_score_reward = -args.kl_coef * sequence_kl
rlhf_reward = scores + non_score_reward
rlhf_reward = non_score_reward + scores

# vectorized RLOO advantages implementation
rlhf_reward = rlhf_reward.reshape(args.rloo_k, -1)
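For illustration (not part of the commit), the last-token placement above can be traced on a toy batch: flipping the padding mask and taking `argmax` finds the last non-padded index of each row, and `scatter_` writes each sequence's score at that position.

```python
import torch

padding_mask = torch.tensor([[1, 1, 1, 0, 0],   # valid length 3
                             [1, 1, 1, 1, 0]])  # valid length 4
scores = torch.tensor([2.0, 5.0])

# Flip each row, find the first valid token, and map back to the original index.
eos_indices = padding_mask.size(1) - 1 - padding_mask.fliplr().argmax(dim=1, keepdim=True)

last_reward = torch.zeros(padding_mask.shape)
last_reward.scatter_(dim=1, index=eos_indices, src=scores.reshape(-1, 1))

print(eos_indices.squeeze(1))  # tensor([2, 3])
print(last_reward)
# tensor([[0., 0., 2., 0., 0.],
#         [0., 0., 0., 5., 0.]])
```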
