👨‍👨‍👧‍👧 GRPO (#2565)
* init grpo [ci skip]
* initial version
* refine args defs
* model card
* initial doc
* fix badges
* fix spaces
* try link to super in doc
* temperature, fix indexing, and std=0.0
* grpo script for cli
* peft support
* move data preparation in `compute_loss`
* weird doc trial
* fix device and some logging
* unwrap_model_for_generation for distributed setting
* Compat with distrib training
* revert grpo config doc trial (didn't work)
* test
* allow model to be str and processing_class to be none; fix loss computation
* advantage is always 0.0: don't log
* fix peft not installed
* proper reward model for testing
* fix script for cli
* add trl grpo to cli doc
* test peft
* flush left
* fix reward calculation
* new reward model
* support any reward model
* fix reward processing class def
* log reward std
* fix reward logging
* fix grad computation
* skip embed layer in test
* remove optimizer_cls_and_kwargs
* improve GRPO default args
* reduce mem usage for grpo test
* reduce mem usage in test grpo
* reduce memory usage for test
* Fix the test
* remove redondant
* fix min version
* Update test_grpo_trainer.py
* Update test_grpo_trainer.py
* Fix test, finally found the solution!
* some doc
* Update doc-builder workflow to use specific commit sha
* more doc
* advantages
* drop cancel fo no grad
* logged metrics [ci skip]
* completion col is ignored [ci skip]
* fix latex
* double space? ~?
* try a latex fix
* with branch
* Empty commit
* Empty commit
* double space seems to be the solution
1 parent 88514d5 · commit 0f5ffad · 18 changed files with 975 additions and 14 deletions.
@@ -0,0 +1,123 @@
# GRPO Trainer

[![](https://img.shields.io/badge/All_models-GRPO-blue)](https://huggingface.co/models?other=grpo,trl)

## Overview

TRL supports the GRPO Trainer for training language models, as described in the paper [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300) by [Zhihong Shao](https://huggingface.co/syhia), [Peiyi Wang](https://huggingface.co/peiyiwang89), [Qihao Zhu](https://huggingface.co/zqh11), Runxin Xu, [Junxiao Song](https://huggingface.co/haha-point), Mingchuan Zhang, Y. K. Li, Y. Wu, [Daya Guo](https://huggingface.co/guoday).

The abstract from the paper is the following:

> Mathematical reasoning poses a significant challenge for language models due to its complex and structured nature. In this paper, we introduce DeepSeekMath 7B, which continues pre-training DeepSeek-Coder-Base-v1.5 7B with 120B math-related tokens sourced from Common Crawl, together with natural language and code data. DeepSeekMath 7B has achieved an impressive score of 51.7% on the competition-level MATH benchmark without relying on external toolkits and voting techniques, approaching the performance level of Gemini-Ultra and GPT-4. Self-consistency over 64 samples from DeepSeekMath 7B achieves 60.9% on MATH. The mathematical reasoning capability of DeepSeekMath is attributed to two key factors: First, we harness the significant potential of publicly available web data through a meticulously engineered data selection pipeline. Second, we introduce Group Relative Policy Optimization (GRPO), a variant of Proximal Policy Optimization (PPO), that enhances mathematical reasoning abilities while concurrently optimizing the memory usage of PPO.

This post-training method was contributed by [Quentin Gallouédec](https://huggingface.co/qgallouedec).

## Quick start

This example demonstrates how to train a model using the GRPO method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B) as the base model and the [RM-Gemma-2B model](https://huggingface.co/weqweasdas/RM-Gemma-2B) as the reward model. We use the prompts from the [TLDR dataset](https://huggingface.co/datasets/trl-lib/tldr) (the completion column is ignored!). You can view the data in the dataset here:

<iframe
  src="https://huggingface.co/datasets/trl-lib/tldr/embed/viewer/default/train?row=0"
  frameborder="0"
  width="100%"
  height="560px"
></iframe>

Below is the script to train the model. We use PEFT to reduce the memory requirements.

```python
# train_grpo.py
from datasets import load_dataset
from peft import LoraConfig
from trl import GRPOConfig, GRPOTrainer

# Load the dataset
dataset = load_dataset("trl-lib/tldr", split="train")

training_args = GRPOConfig(
    output_dir="Qwen2-0.5B-GRPO",
    learning_rate=1e-5,
    logging_steps=10,
    gradient_accumulation_steps=16,
    max_completion_length=128,
)
trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_model="weqweasdas/RM-Gemma-2B",
    args=training_args,
    train_dataset=dataset,
    peft_config=LoraConfig(task_type="CAUSAL_LM"),
)

trainer.train()
```

Execute the script using the following command:

```bash
accelerate launch train_grpo.py
```

Distributed across 8 GPUs, the training takes approximately 1 day.

![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/grpo_curves.png)

## Looking deeper into the GRPO method

GRPO is an online learning algorithm, meaning it improves iteratively by using the data generated by the trained model itself during training. The intuition behind the GRPO objective is to maximize the advantage of the generated completions, while ensuring that the model remains close to the reference policy. To understand how GRPO works, it can be broken down into four main steps: **Generating completions**, **computing the advantage**, **estimating the KL divergence**, and **computing the loss**.

![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/grpo_visual.png)

### Generating completions

At each training step, we sample a batch of prompts and generate a set of \\( G \\) completions for each prompt (denoted as \\( o_i \\)).
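
For intuition only, here is a minimal sketch of what such a group looks like when sampled with the `transformers` generation API outside the trainer; the model name, prompt, and sampling settings below are purely illustrative, not the trainer's internals.

```python
# Hypothetical illustration: sample a group of G completions o_1..o_G for one prompt.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Qwen/Qwen2-0.5B-Instruct"  # illustrative model choice
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

G = 4  # completions per prompt
prompt = "Summarize: The cat sat on the mat because it was warm.\nTL;DR:"
inputs = tokenizer(prompt, return_tensors="pt")

# Sampling with num_return_sequences=G yields the group of completions for this prompt.
outputs = model.generate(
    **inputs,
    do_sample=True,
    temperature=1.0,
    max_new_tokens=64,
    num_return_sequences=G,
)
completions = tokenizer.batch_decode(
    outputs[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True
)
```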

### Computing the advantage

For each of the \\( G \\) sequences, we compute the reward using a reward model. To align with the comparative nature of reward models (they are typically trained on datasets of comparisons between outputs for the same question), the advantage is calculated to reflect these relative comparisons. It is normalized as follows:

$$\hat{A}_{i,t} = \frac{r_i - \text{mean}(\mathbf{r})}{\text{std}(\mathbf{r})}$$

This approach gives the method its name: **Group Relative Policy Optimization (GRPO)**.
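
As a minimal sketch, assuming `rewards` holds the scalar rewards \\( r_1, \dots, r_G \\) of one group (the values below are made up), the normalization is simply:

```python
import torch

# Illustrative rewards r_1..r_G for the G completions of a single prompt.
rewards = torch.tensor([0.2, 1.5, -0.3, 0.9])

# Group-relative advantage, shared by every token t of completion o_i.
# In practice a small epsilon may be added to the denominator to guard against std = 0.
advantages = (rewards - rewards.mean()) / rewards.std()
```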

### Estimating the KL divergence

KL divergence is estimated using the approximator introduced by [Schulman et al. (2020)](http://joschu.net/blog/kl-approx.html). The approximator is defined as follows:

$$\mathbb{D}_{\text{KL}}\left[\pi_\theta \|\pi_{\text{ref}}\right] = \frac{\pi_{\text{ref}}(o_{i,t} \mid q, o_{i,<t})}{\pi_\theta(o_{i,t} \mid q, o_{i,<t})} - \log \frac{\pi_{\text{ref}}(o_{i,t} \mid q, o_{i,<t})}{\pi_\theta(o_{i,t} \mid q, o_{i,<t})} - 1,$$
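
In code, this estimator only needs the per-token log-probabilities of the sampled tokens under both policies. A sketch, with made-up values for `logps` (current policy) and `ref_logps` (reference policy):

```python
import torch

# Per-token log pi_theta(o_{i,t} | q, o_{i,<t}) and log pi_ref(o_{i,t} | q, o_{i,<t}) (illustrative).
logps = torch.tensor([-1.2, -0.7, -2.3])
ref_logps = torch.tensor([-1.0, -0.9, -2.0])

log_ratio = ref_logps - logps
per_token_kl = torch.exp(log_ratio) - log_ratio - 1  # always >= 0, matching the formula above
```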

### Computing the loss

The objective is to maximize the advantage while ensuring that the model remains close to the reference policy. Consequently, the loss is defined as follows:

$$
\mathcal{L}_{\text{GRPO}}(\theta) = -\frac{1}{G} \sum_{i=1}^G \frac{1}{|o_i|} \sum_{t=1}^{|o_i|} \left[ \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\left[\pi_\theta(o_{i,t} \mid q, o_{i,< t})\right]_{\text{no grad}}} \hat{A}_{i,t} - \beta \mathbb{D}_{\text{KL}}\left[\pi_\theta \| \pi_{\text{ref}}\right] \right],
$$

where the first term represents the scaled advantage and the second term penalizes deviations from the reference policy through KL divergence.
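
To make the structure concrete, here is a minimal sketch of this simplified loss, assuming per-token log-probabilities for the current and reference policies, one advantage per completion, and a mask over completion tokens; the function name and the value of `beta` are illustrative, not the trainer's exact implementation.

```python
import torch

def grpo_loss(logps, ref_logps, advantages, completion_mask, beta=0.04):
    # logps, ref_logps, completion_mask: (G, T); advantages: (G,)
    # pi_theta / [pi_theta]_no_grad equals 1 in value but keeps the gradient of log pi_theta.
    ratio = torch.exp(logps - logps.detach())
    log_ratio = ref_logps - logps
    per_token_kl = torch.exp(log_ratio) - log_ratio - 1
    per_token_loss = -(ratio * advantages.unsqueeze(1) - beta * per_token_kl)
    # Average over the valid completion tokens of each o_i, then over the group.
    return ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
```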

In the original paper, this formulation is generalized to account for multiple updates after each generation by leveraging the **clipped surrogate objective**:

$$
\mathcal{L}_{\text{GRPO}}(\theta) = - \frac{1}{G} \sum_{i=1}^G \frac{1}{|o_i|} \sum_{t=1}^{|o_i|} \left[ \min \left( \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,< t})} \hat{A}_{i,t}, \, \text{clip}\left( \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,< t})}, 1 - \epsilon, 1 + \epsilon \right) \hat{A}_{i,t} \right) - \beta \mathbb{D}_{\text{KL}}\left[\pi_\theta \| \pi_{\text{ref}}\right] \right],
$$

where \\(\text{clip}(\cdot, 1 - \epsilon, 1 + \epsilon) \\) ensures that updates do not deviate excessively from the old policy by bounding the policy ratio between \\( 1 - \epsilon \\) and \\( 1 + \epsilon \\).
In TRL, though, as in the original paper, we only do one update per generation, so we can simplify the loss to the first form.
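
For reference, here is a small sketch of the clipped surrogate term used in that multi-update setting (the helper name and the \\( \epsilon \\) value are illustrative; with a single update per generation the ratio is 1 and this reduces to the first form):

```python
import torch

def clipped_surrogate(logps, old_logps, advantages, epsilon=0.2):
    # Per-token ratio pi_theta / pi_theta_old, computed from log-probabilities.
    ratio = torch.exp(logps - old_logps)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1 - epsilon, 1 + epsilon) * advantages
    # The beta * KL penalty is subtracted from this term separately.
    return torch.min(unclipped, clipped)
```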

## Logged metrics

The GRPO Trainer logs the following metrics:

- `reward`: The average reward.
- `reward_std`: The average standard deviation of rewards within each group of generations.
- `kl`: The average KL divergence between the model and the reference model, calculated on completions.
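
As a quick sketch, assuming a `trainer` like the one from the quick start above, these values can be read back from `trainer.state.log_history` after training:

```python
# Illustrative only: print the GRPO metrics from the logged history.
for entry in trainer.state.log_history:
    if "reward" in entry:
        print(entry.get("step"), entry["reward"], entry["reward_std"], entry["kl"])
```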

## GRPOTrainer

[[autodoc]] GRPOTrainer

## GRPOConfig

[[autodoc]] GRPOConfig

@@ -0,0 +1,148 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tempfile
import unittest

import torch
from datasets import load_dataset
from parameterized import parameterized
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
from transformers.testing_utils import require_peft
from transformers.utils import is_peft_available

from trl import GRPOConfig, GRPOTrainer


if is_peft_available():
    from peft import LoraConfig


class GRPOTrainerTester(unittest.TestCase):
    def test_init_minimal(self):
        # Test that GRPOTrainer can be instantiated with only model, reward_model and train_dataset
        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
        GRPOTrainer(
            model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
            reward_model="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
            train_dataset=dataset,
        )

    @parameterized.expand([("standard_prompt_only",), ("conversational_prompt_only",)])
    def test_training(self, config_name):
        dataset = load_dataset("trl-internal-testing/zen", config_name, split="train")

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = GRPOConfig(
                output_dir=tmp_dir,
                learning_rate=0.1,  # increase the learning rate to speed up the test
                per_device_train_batch_size=2,  # reduce the batch size to reduce memory usage
                num_generations=3,  # reduce the number of generations to reduce memory usage
                max_completion_length=32,  # reduce the completion length to reduce memory usage
                report_to="none",
            )
            trainer = GRPOTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                reward_model="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
                args=training_args,
                train_dataset=dataset,
            )

            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            trainer.train()

            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

            # Check that the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")

    @require_peft
    def test_training_peft(self):
        model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
        base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()]
        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = GRPOConfig(
                output_dir=tmp_dir,
                learning_rate=0.1,  # increase the learning rate to speed up the test
                per_device_train_batch_size=2,  # reduce the batch size to reduce memory usage
                num_generations=3,  # reduce the number of generations to reduce memory usage
                max_completion_length=32,  # reduce the completion length to reduce memory usage
                report_to="none",
            )
            trainer = GRPOTrainer(
                model=model,
                reward_model="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
                args=training_args,
                train_dataset=dataset,
                peft_config=LoraConfig(),
            )

            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            trainer.train()

            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

            # Check the peft params have changed and the base model params have not changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                if n in base_param_names:  # We expect the base model params to be the same
                    self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed.")
                elif "base_layer" not in n:  # We expect the peft params to be different (except for the base layer)
                    self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed.")

    def test_training_different_reward_model(self):
        # Use a reward model different from the model: different chat template, tokenization, etc.
        dataset = load_dataset("trl-internal-testing/zen", "conversational_prompt_only", split="train")
        reward_model_id = "trl-internal-testing/tiny-LlamaForSequenceClassification-3.2"
        reward_model = AutoModelForSequenceClassification.from_pretrained(reward_model_id)
        reward_tokenizer = AutoTokenizer.from_pretrained(reward_model_id)
        # By default, the trainer uses the eos token as the padding token. However, for Llama models, the eos token
        # appears in the chat template. Using it as a pad token disrupts the reward calculation, as the calculation
        # considers the score of the last token before the first pad token. To ensure correct reward calculations,
        # we use a separate pad token instead.
        reward_tokenizer.pad_token = "<|finetune_right_pad_id|>"

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = GRPOConfig(
                output_dir=tmp_dir,
                learning_rate=0.1,  # increase the learning rate to speed up the test
                per_device_train_batch_size=2,  # reduce the batch size to reduce memory usage
                num_generations=3,  # reduce the number of generations to reduce memory usage
                max_completion_length=32,  # reduce the completion length to reduce memory usage
                report_to="none",
            )
            trainer = GRPOTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                reward_model=reward_model,
                args=training_args,
                train_dataset=dataset,
                reward_processing_class=reward_tokenizer,
            )

            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            trainer.train()

            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

            # Check the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")