diff --git a/docs/source/grpo_trainer.md b/docs/source/grpo_trainer.md
index 973dadf03a..af54406dfe 100644
--- a/docs/source/grpo_trainer.md
+++ b/docs/source/grpo_trainer.md
@@ -149,28 +149,33 @@ You can test it as follows:
 
 #### Example 2: Conversational format
 
-For conversational format, completions consist of structured messages. Here’s an example that rewards longer completion content:
+For conversational format, completions consist of structured messages. Here’s an example of a reward function that checks if the completion has a specific format. This example is inspired by the reward function used in the paper [DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning](https://huggingface.co/papers/2501.12948).
 
 ```python
-def reward_func(prompts, completions):
-    """Reward function that gives higher scores to longer completion content."""
+import re
+
+def format_reward_func(prompts, completions):
+    """Reward function that checks if the completion has a specific format."""
+    pattern = r"^<think>.*?</think><answer>.*?</answer>$"
     completion_contents = [completion[0]["content"] for completion in completions]
-    return [float(len(content)) for content in completion_contents]
+    matches = [re.match(pattern, content) for content in completion_contents]
+    return [1.0 if match else 0.0 for match in matches]
 ```
 
 You can test this function as follows:
 
 ```python
 >>> prompts = [
-...     [{"role": "user", "content": "What color is the sky?"}],
-...     [{"role": "user", "content": "Where is the sun?"}],
+...     [{"role": "user", "content": "What is the result of (1 + 2) * 4?"}],
+...     [{"role": "user", "content": "What is the result of (3 + 1) * 2?"}],
 ... ]
 >>> completions = [
-...     [{"role": "assistant", "content": "It is blue."}],
-...     [{"role": "assistant", "content": "In the sky."}],
+...     [{"role": "assistant", "content": "<think>The sum of 1 and 2 is 3, which we multiply by 4 to get 12.</think><answer>(1 + 2) * 4 = 12</answer>"}],
+...     [{"role": "assistant", "content": "The sum of 3 and 1 is 4, which we multiply by 2 to get 8. So (3 + 1) * 2 = 8."}],
 ... ]
->>> print(reward_func(prompts, completions))
-[11.0, 11.0]
+>>> format_reward_func(prompts, completions)
+[1.0, 0.0]
+>>>
 ```
 
 #### Passing the reward function to the trainer
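
Review note: `re.match` without `re.DOTALL` does not let `.` cross a newline, so a format check written this way would likely score a completion with multi-line reasoning as 0.0. Below is a minimal sketch of a variant that tolerates newlines. It assumes the DeepSeek-R1-style layout of a `<think>` block followed by an `<answer>` block; the function name `format_reward_func_multiline` and the optional whitespace between the two blocks are illustrative assumptions, not part of this patch or of TRL.

```python
import re

# Hypothetical variant for illustration only (not part of TRL or this patch).
# Assumed layout: <think>...</think> followed by <answer>...</answer>.
# re.DOTALL lets "." match newlines; \s* allows whitespace between the blocks.
_FORMAT_PATTERN = re.compile(
    r"<think>.*?</think>\s*<answer>.*?</answer>", re.DOTALL
)


def format_reward_func_multiline(prompts, completions):
    """Return 1.0 if the whole completion matches the assumed layout, else 0.0."""
    # Conversational format: each completion is a list of message dicts.
    contents = [completion[0]["content"] for completion in completions]
    # fullmatch requires the entire string to fit the pattern.
    return [1.0 if _FORMAT_PATTERN.fullmatch(content) else 0.0 for content in contents]
```

The `prompts` argument is unused and kept only to mirror the signature in the hunk above. Whether a newline between the two blocks should be accepted is a design choice for the reward; drop `\s*` to require them to be adjacent.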