
The reward function is provided with all columns from the dataset
qgallouedec committed Jan 24, 2025
1 parent 6f99f42 commit a0ea03a
Showing 4 changed files with 101 additions and 13 deletions.
39 changes: 34 additions & 5 deletions docs/source/grpo_trainer.md
@@ -121,7 +121,11 @@ The GRPO Trainer logs the following metrics:
The [`GRPOTrainer`] supports using custom reward functions instead of dense reward models. To ensure compatibility, your reward function must satisfy the following requirements:

1. **Input arguments**:
- The function must accept two arguments: `prompts` and `completions`.
- The function must accept the following as keyword arguments:
- `prompts` (contains the prompts),
- `completions` (contains the generated completions),
  - All column names (except `prompt`) that the dataset may have. For example, if the dataset contains a column named `ground_truth`, the function will be called with `ground_truth` as a keyword argument.
  The easiest way to comply with this requirement is to use `**kwargs` in the function signature (see the minimal sketch just below this list).
- Depending on the dataset format, the input will vary:
- For [standard format](dataset_formats#standard), `prompts` and `completions` will be lists of strings.
- For [conversational format](dataset_formats#conversational), `prompts` and `completions` will be lists of message dictionaries.
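For orientation, here is a minimal sketch of a signature that satisfies these requirements; the `ground_truth` column mentioned in the comment is only an illustrative assumption, not something every dataset provides.

```python
def my_reward_func(prompts, completions, **kwargs):
    """Sketch of a compliant reward function signature.

    Every dataset column other than `prompt` arrives as a keyword argument
    (e.g. `ground_truth`, if the dataset defines such a column), with one
    entry per generated completion.
    """
    # kwargs.keys() lists the additional columns that were forwarded.
    print(f"Extra columns received: {sorted(kwargs.keys())}")
    # Must return one float reward per completion.
    return [0.0 for _ in completions]
```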
@@ -133,7 +137,7 @@ The [`GRPOTrainer`] supports using custom reward functions instead of dense reward models
Below is an example of a reward function for a standard format that rewards longer completions:

```python
def reward_func(prompts, completions):
def reward_func(completions, **kwargs):
"""Reward function that gives higher scores to longer completions."""
return [float(len(completion)) for completion in completions]
```
@@ -143,7 +147,7 @@ You can test it as follows:
```python
>>> prompts = ["The sky is", "The sun is"]
>>> completions = [" blue.", " in the sky."]
>>> print(reward_func(prompts, completions))
>>> print(reward_func(prompts=prompts, completions=completions))
[6.0, 12.0]
```

@@ -155,7 +159,7 @@ It is designed for conversational format, where prompts and completions consist
```python
import re

def format_reward_func(prompts, completions):
def format_reward_func(completions, **kwargs):
"""Reward function that checks if the completion has a specific format."""
pattern = r"^<think>.*?</think><answer>.*?</answer>$"
completion_contents = [completion[0]["content"] for completion in completions]
@@ -174,11 +178,36 @@ You can test this function as follows:
... [{"role": "assistant", "content": "<think>The sum of 1 and 2 is 3, which we multiply by 4 to get 12.</think><answer>(1 + 2) * 4 = 12</answer>"}],
... [{"role": "assistant", "content": "The sum of 3 and 1 is 4, which we multiply by 2 to get 8. So (3 + 1) * 2 = 8."}],
... ]
>>> format_reward_func(prompts, completions)
>>> format_reward_func(prompts=prompts, completions=completions)
[1.0, 0.0]
```
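
The body of `format_reward_func` is collapsed in the diff above. A plausible completion, consistent with the `[1.0, 0.0]` result of the test (a sketch, not necessarily the exact code in the file):

```python
import re

def format_reward_func(completions, **kwargs):
    """Reward function that checks if the completion has a specific format."""
    pattern = r"^<think>.*?</think><answer>.*?</answer>$"
    completion_contents = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, content) for content in completion_contents]
    # 1.0 when the content follows the <think>...</think><answer>...</answer> format, else 0.0
    return [1.0 if match else 0.0 for match in matches]
```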

#### Example 3: Reward completions based on a reference

Below is an example of a reward function that rewards completions based on a reference ground truth. This example is designed for [standard format](dataset_formats#standard), where the dataset contains a column named `ground_truth`.

```python
import re

def reward_func(completions, ground_truth, **kwargs):
# Regular expression to capture content inside \boxed{}
matches = [re.search(r"\\boxed\{(.*?)\}", completion) for completion in completions]
content = [match.group(1) if match else "" for match in matches]
# Reward 1 if the content is the same as the ground truth, 0 otherwise
return [1.0 if c == gt else 0.0 for c, gt in zip(content, ground_truth)]
```

You can test this function as follows:

```python
>>> prompts = ["Problem: Solve the equation $2x + 3 = 7$. Solution:", "Problem: Solve the equation $3x - 5 = 10$."]
>>> completions = [r" The solution is \boxed{2}.", r" The solution is \boxed{6}."]
>>> ground_truth = ["2", "5"]
>>> reward_func(prompts=prompts, completions=completions, ground_truth=ground_truth)
[1.0, 0.0]
```

#### Passing the reward function to the trainer

To use your custom reward function, pass it to the `GRPOTrainer` as follows:
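
A rough sketch of the wiring (the exact snippet is collapsed in this diff), assuming the length-based `reward_func` from Example 1 and placeholder model/dataset names that are not part of this commit:

```python
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

def reward_func(completions, **kwargs):
    """Reward function that gives higher scores to longer completions."""
    return [float(len(completion)) for completion in completions]

dataset = load_dataset("trl-lib/tldr", split="train")  # placeholder prompt-only dataset

training_args = GRPOConfig(output_dir="Qwen2-0.5B-GRPO")
trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",  # placeholder model ID
    reward_funcs=reward_func,          # a list of reward functions is also accepted
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
```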
49 changes: 44 additions & 5 deletions tests/test_grpo_trainer.py
@@ -151,7 +151,7 @@ def test_training_reward_func_standard(self):
# Test if trainer can handle reward function with standard format
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

def reward_func(prompts, completions):
def reward_func(completions, **kwargs):
"""Reward function that rewards longer completions."""
return [float(len(completion)) for completion in completions]

@@ -186,7 +186,7 @@ def test_training_reward_func_conversational(self):
# Test if trainer can handle reward function with conversational format
dataset = load_dataset("trl-internal-testing/zen", "conversational_prompt_only", split="train")

def reward_func(prompts, completions):
def reward_func(completions, **kwargs):
"""Reward function that gives higher scores to longer completion content."""
completion_contents = [completion[0]["content"] for completion in completions]
return [float(len(content)) for content in completion_contents]
@@ -222,11 +222,11 @@ def test_training_multiple_reward_funcs(self):
# Test that GRPOTrainer can be instantiated with multiple reward functions
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

def reward_func1(prompts, completions):
def reward_func1(completions, **kwargs):
"""Reward function that rewards longer completions."""
return [float(len(completion)) for completion in completions]

def reward_func2(prompts, completions):
def reward_func2(completions, **kwargs):
"""Reward function that rewards completions with more unique letters."""
return [float(len(set(completion))) for completion in completions]

@@ -261,7 +261,7 @@ def test_training_multiple_mixed_reward_funcs(self):
# Test if the trainer can handle a mix of reward functions and reward models
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

def reward_func(prompts, completions):
def reward_func(completions, **kwargs):
"""Reward function that rewards longer completions."""
return [float(len(completion)) for completion in completions]

@@ -291,3 +291,42 @@ def reward_func(prompts, completions):
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")

def test_training_reward_func_additional_column(self):
# Test if the trainer can handle a reward function that relies on additional columns in the dataset
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

# Add a column to the dataset (dummy example, the column could be anything)
some_values = list(range(len(dataset)))
dataset = dataset.add_column("some_values", some_values)

def reward_func(completions, some_values, **kwargs):
"""Reward function that rewards completions with lengths closer to the values in some_values."""
return [float(abs(len(completion) - value)) for completion, value in zip(completions, some_values)]

with tempfile.TemporaryDirectory() as tmp_dir:
training_args = GRPOConfig(
output_dir=tmp_dir,
learning_rate=0.1, # increase the learning rate to speed up the test
per_device_train_batch_size=2, # reduce the batch size to reduce memory usage
num_generations=3, # reduce the number of generations to reduce memory usage
max_completion_length=32, # reduce the completion length to reduce memory usage
report_to="none",
)
trainer = GRPOTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
reward_funcs=reward_func,
args=training_args,
train_dataset=dataset,
)

previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

trainer.train()

self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

# Check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
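
A quick standalone check of the reward logic in the test above, with arbitrarily chosen lengths and values (not part of the test itself):

```python
>>> reward_func(completions=["abcd", "ab"], some_values=[3, 10])
[1.0, 8.0]
```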
12 changes: 12 additions & 0 deletions trl/trainer/grpo_config.py
@@ -39,6 +39,9 @@ class GRPOConfig(TrainingArguments):
> Parameters that control the data preprocessing
remove_unused_columns (`bool`, *optional*, defaults to `False`):
Whether to keep only the column `"prompt"` in the dataset. If you use a custom reward function that
requires any column other than `"prompts"` and `"completions"`, you should keep this set to `False`.
max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left.
num_generations (`int` or `None`, *optional*, defaults to `8`):
@@ -67,6 +70,15 @@ class GRPOConfig(TrainingArguments):
)

# Parameters that control the data preprocessing
# The default value of `remove_unused_columns` is overridden from the parent class, because in GRPO we usually
# rely on additional columns to compute the reward
remove_unused_columns: Optional[bool] = field(
default=False,
metadata={
"help": "Whether to only keep the column 'prompt' in the dataset. If you use a custom reward function "
"that requires any column other than 'prompts' and 'completions', you should keep this to `False`."
},
)
max_prompt_length: Optional[int] = field(
default=512,
metadata={
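For a quick check of the overridden default (a small sketch, assuming a `trl` build that includes this change):

```python
from trl import GRPOConfig

config = GRPOConfig(output_dir="grpo-output")
# Unlike the parent TrainingArguments, GRPO keeps extra dataset columns by
# default so they can be forwarded to custom reward functions.
print(config.remove_unused_columns)  # False
```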
14 changes: 11 additions & 3 deletions trl/trainer/grpo_trainer.py
@@ -94,8 +94,9 @@ class GRPOTrainer(Trainer):
using [`~transformers.AutoModelForSequenceClassification.from_pretrained`] with `num_labels=1` and the
keyword arguments in `args.model_init_kwargs`.
- A [`~transformers.PreTrainedModel`] object: Only sequence classification models are supported.
- A custom reward function: This should take a list of prompts and completions and return a list of
rewards. For more details, see [Using a custom reward function](#using-a-custom-reward-function).
- A custom reward function: The function is provided with the prompts and the generated completions,
plus any additional columns in the dataset. It should return a list of rewards. For more details, see
[Using a custom reward function](#using-a-custom-reward-function).
- A list of reward functions, where each item can independently be any of the above types. Mixing different
types within the list (e.g., a string model ID and a custom reward function) is allowed.
args ([`GRPOConfig`], *optional*, defaults to `None`):
@@ -369,7 +370,14 @@ def get_per_token_logps(model, input_ids):
with torch.inference_mode():
rewards[i] = reward_func(**reward_inputs).logits[:, 0] # Shape (B*G,)
else:
rewards[i] = torch.tensor(reward_func(prompts, completions))
# Repeat all input columns (except "prompt" and "completion") to match the number of generations
reward_kwargs = {key: [] for key in inputs[0].keys() if key not in ["prompt", "completion"]}
for key in reward_kwargs:
for example in inputs:
# Repeat each value in the column `num_generations` times
reward_kwargs[key].extend([example[key]] * self.num_generations)
output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
rewards[i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
# Sum the rewards from all reward functions
rewards = rewards.sum(dim=0)

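To make the broadcasting above concrete, here is a standalone toy run of the same repetition logic (plain Python outside the trainer; the column name `ground_truth` is illustrative):

```python
num_generations = 2
inputs = [
    {"prompt": "The sky is", "ground_truth": "blue"},
    {"prompt": "The sun is", "ground_truth": "bright"},
]
# Flattened completions: num_generations per prompt, in prompt order.
completions = [" blue.", " gray.", " bright.", " hot."]

# Repeat every column except "prompt"/"completion" so it aligns with `completions`.
reward_kwargs = {key: [] for key in inputs[0] if key not in ["prompt", "completion"]}
for key in reward_kwargs:
    for example in inputs:
        reward_kwargs[key].extend([example[key]] * num_generations)

print(reward_kwargs["ground_truth"])  # ['blue', 'blue', 'bright', 'bright']
```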
