[RLOO] fix token_level_kl #2575

Merged · 8 commits · Jan 17, 2025
2 changes: 1 addition & 1 deletion docs/source/rloo_trainer.md
@@ -276,7 +276,7 @@ The [Reinforce++](https://hijkzzz.notion.site/reinforce-plus-plus) report by Jia
- Clipping rewards: limiting reward values within a specific range to mitigate the impact of extreme rewards on model updates, thus preventing gradient explosion
- Normalizing rewards: scaling rewards to have a mean of 0 and a standard deviation of 1, which helps in stabilizing the training process
- Normalizing advantages: scaling advantages to have a mean of 0 and a standard deviation of 1, which helps in stabilizing the training process
- Using token-level KL penalty (default) vs. sequence-level KL penalty
- Using a token-level KL penalty, as defined in equation (1) of the report, vs. a sequence-level KL penalty (default)

These options are available via the appropriate arguments in the [`RLOOConfig`] class.

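The doc line changed above distinguishes a token-level KL penalty from the sequence-level one that becomes the default. Below is a minimal, self-contained sketch of that distinction; the names, shapes, and values (`kl`, `scores`, `kl_coef`) are illustrative and this is not the trainer's exact code.

```python
import torch

kl = torch.rand(2, 5)               # per-token KL estimates, shape (batch, seq_len)
scores = torch.tensor([1.0, 2.0])   # scalar reward-model scores, one per sequence
kl_coef = 0.1

# Sequence-level KL penalty (the new default): charge the summed KL once per sequence.
rlhf_reward_seq = scores - kl_coef * kl.sum(dim=1)

# Token-level KL penalty: charge -kl_coef * KL_t at every token and add the scalar
# score only at the last token, then sum over the sequence.
kl_reward = -kl_coef * kl
last_reward = torch.zeros_like(kl)
last_index = torch.full((kl.size(0), 1), kl.size(1) - 1)  # assume no padding here
last_reward.scatter_(dim=1, index=last_index, src=scores.reshape(-1, 1))
rlhf_reward_tok = (last_reward + kl_reward).sum(dim=1)

# Without padding the two totals agree; what changes is where in the sequence the
# KL cost is attributed before the final sum.
torch.testing.assert_close(rlhf_reward_seq, rlhf_reward_tok)
```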
86 changes: 85 additions & 1 deletion tests/test_rloo_trainer.py
@@ -61,15 +61,66 @@ def test_rloo_checkpoint(self):
def test_rloo_reward(self):
local_batch_size = 3
rloo_k = 4
sequence_length = 5 # Add sequence length for testing token-level rewards

# fmt: off
rlhf_reward = torch.tensor([
1, 2, 3, # first rlhf reward for three prompts
2, 3, 4, # second rlhf reward for three prompts
5, 6, 7, # third rlhf reward for three prompts
8, 9, 10, # fourth rlhf reward for three prompts
]).float()

# Create padding mask where 1 indicates valid token, 0 indicates padding
padding_mask = torch.ones(local_batch_size * rloo_k, sequence_length)
# Set padding based on sequence lengths
sequence_lengths = torch.tensor([
3, 4, 3, # lengths for first batch
4, 3, 4, # lengths for second batch
3, 4, 3, # lengths for third batch
4, 3, 4, # lengths for fourth batch
])
for i, length in enumerate(sequence_lengths):
padding_mask[i, length:] = 0

# Add kl tensor for testing token-level rewards
kl = torch.ones(local_batch_size * rloo_k, sequence_length) # Dummy KL values
# fmt: on

# Test token-level KL rewards following OpenRLHF implementation
kl_coef = 0.1
kl_reward = -kl_coef * kl

# Find last non-padded position
eos_indices = padding_mask.size(1) - 1 - padding_mask.long().fliplr().argmax(dim=1, keepdim=True)

# Create last reward tensor
last_reward = torch.zeros_like(kl)
last_reward.scatter_(dim=1, index=eos_indices, src=rlhf_reward.reshape(-1, 1))

# Test last_reward - should have rlhf_reward at the last non-padded position
for i, (length, reward) in enumerate(zip(sequence_lengths, rlhf_reward)):
# Check reward is at correct position
self.assertEqual(last_reward[i, length - 1].item(), reward.item())
# Check zeros elsewhere
self.assertTrue(torch.all(last_reward[i, : length - 1] == 0))
self.assertTrue(torch.all(last_reward[i, length:] == 0))

# Combine rewards
reward = last_reward + kl_reward
non_score_reward = kl_reward.sum(1)
token_level_rlhf_reward = reward.sum(1)

# Test reward components
# KL reward should be -0.1 for each token in sequence length
expected_kl_reward = -0.1 * sequence_length # Each position gets -0.1 KL reward
torch.testing.assert_close(non_score_reward, torch.tensor(expected_kl_reward).expand_as(non_score_reward))

# Total reward should be rlhf_reward + kl_reward
expected_total = rlhf_reward + expected_kl_reward
torch.testing.assert_close(token_level_rlhf_reward, expected_total)

# Test sequence-level rewards (existing test)
baseline = (rlhf_reward.sum(0) - rlhf_reward) / (rloo_k - 1)
advantages = torch.zeros_like(rlhf_reward)
for i in range(0, len(advantages), local_batch_size):
@@ -83,8 +134,41 @@ def test_rloo_reward(self):
self.assertLess((1 - (2 + 5 + 8) / 3 - advantages[0].item()), 1e-6)
self.assertLess((6 - (3 + 2 + 9) / 3 - advantages[7].item()), 1e-6)

# vectorized impl
# Test vectorized implementation
rlhf_reward = rlhf_reward.reshape(rloo_k, local_batch_size)
baseline = (rlhf_reward.sum(0) - rlhf_reward) / (rloo_k - 1)
vec_advantages = rlhf_reward - baseline
torch.testing.assert_close(vec_advantages.flatten(), advantages)

def test_rloo_training(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = RLOOConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
per_device_eval_batch_size=2,
total_episodes=1,
num_train_epochs=1,
max_steps=2,
report_to="none",
)

# Create a simple dataset
dummy_text = [{"content": "Hello World!", "role": "user"}]
dummy_data = self.tokenizer.apply_chat_template(dummy_text)
dummy_dataset = Dataset.from_dict({"input_ids": [dummy_data, dummy_data]})

trainer = RLOOTrainer(
config=training_args,
policy=self.policy_model,
reward_model=self.reward_model,
ref_policy=self.policy_ref_model,
processing_class=self.tokenizer,
train_dataset=dummy_dataset,
eval_dataset=dummy_dataset,
)

# Test that training completes without errors
trainer.train()

# Check if objective/rlhf_reward is available
self.assertIn("objective/rlhf_reward", trainer.state.log_history[-1])
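The test above locates the last non-padded position with `padding_mask.size(1) - 1 - padding_mask.long().fliplr().argmax(dim=1, keepdim=True)`. A small standalone demo of why that expression works, with illustrative mask values:

```python
import torch

# Toy padding mask: 1 marks a real token, 0 marks padding (illustrative values).
padding_mask = torch.tensor([
    [1, 1, 1, 0, 0],   # length 3 -> last real token at index 2
    [1, 1, 1, 1, 0],   # length 4 -> last real token at index 3
])

# fliplr() reverses each row and argmax() returns the first occurrence of the
# maximum, i.e. the first 1 in the reversed row; subtracting from (seq_len - 1)
# maps that back to the index of the last non-padded token in the original order.
eos_indices = padding_mask.size(1) - 1 - padding_mask.fliplr().argmax(dim=1, keepdim=True)
print(eos_indices)  # tensor([[2], [3]])

# The scalar reward can then be scattered onto exactly that position.
scores = torch.tensor([[1.0], [2.0]])
last_reward = torch.zeros(padding_mask.shape, dtype=scores.dtype)
last_reward.scatter_(dim=1, index=eos_indices, src=scores)
```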
2 changes: 1 addition & 1 deletion trl/trainer/rloo_config.py
@@ -93,6 +93,6 @@ class RLOOConfig(OnPolicyConfig):
metadata={"help": "Whether to normalize advantages"},
)
token_level_kl: bool = field(
default=True,
default=False,
metadata={"help": "Whether to use token-level KL penalty or sequence-level KL penalty"},
)
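Because the default flips to `False` here, the sequence-level penalty becomes the out-of-the-box behaviour and the previous token-level behaviour has to be requested explicitly. A sketch of how that opt-in might look (the `output_dir` path and `kl_coef` value are purely illustrative):

```python
from trl import RLOOConfig

training_args = RLOOConfig(
    output_dir="rloo-model",   # illustrative output path
    token_level_kl=True,       # opt back in to the per-token KL penalty
    kl_coef=0.05,              # illustrative KL coefficient
)
```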
17 changes: 14 additions & 3 deletions trl/trainer/rloo_trainer.py
@@ -397,13 +397,24 @@ def repeat_generator():
# Compute total reward with KL penalty
if args.token_level_kl:
# Token-level KL penalty: apply KL penalty per token
token_kl_penalty = -args.kl_coef * kl
non_score_reward = token_kl_penalty.sum(1)
kl_reward = -args.kl_coef * kl

# Get the index of the last non-padded token for each sequence
eos_indices = padding_mask.size(1) - 1 - padding_mask.long().fliplr().argmax(dim=1, keepdim=True)
last_reward = torch.zeros_like(kl)
# Ensure scores has correct shape and type
scores_shaped = scores.reshape(-1, 1).to(kl.dtype)
last_reward.scatter_(dim=1, index=eos_indices, src=scores_shaped)

# Combine KL reward and last reward
non_score_reward = kl_reward.sum(1) # Keep this for logging
reward = last_reward + kl_reward
rlhf_reward = reward.sum(1) # Sum across sequence length
else:
# Sequence-level KL penalty: sum KL across tokens first
sequence_kl = kl.sum(1)
non_score_reward = -args.kl_coef * sequence_kl
rlhf_reward = scores + non_score_reward
rlhf_reward = non_score_reward + scores

# vectorized RLOO advantages implementation
rlhf_reward = rlhf_reward.reshape(args.rloo_k, -1)
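The hunk ends just as `rlhf_reward` is reshaped for the "vectorized RLOO advantages implementation". For reference, a small self-contained sketch of that leave-one-out step, using the same shapes as the test above (values are illustrative):

```python
import torch

# rloo_k completions per prompt, local_batch_size prompts, one scalar reward each.
rloo_k, local_batch_size = 4, 3
rlhf_reward = torch.arange(rloo_k * local_batch_size, dtype=torch.float32)
rlhf_reward = rlhf_reward.reshape(rloo_k, local_batch_size)

# Leave-one-out baseline: for each completion, the mean reward of the other
# (rloo_k - 1) completions generated for the same prompt.
baseline = (rlhf_reward.sum(0) - rlhf_reward) / (rloo_k - 1)
advantages = (rlhf_reward - baseline).flatten()
```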