📉 Optimize GRPO memory usage by redefining per_device_batch_size as generations per device #2776

Merged (18 commits, Feb 6, 2025)
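For context on the change itself: `per_device_train_batch_size` now counts generations (completions) per device rather than prompts, and the global batch (`num_processes * per_device_train_batch_size`) must be evenly divisible by `num_generations` so it can be regrouped into complete prompt groups. The sketch below illustrates that arithmetic under these semantics; the helper name and the numbers are illustrative, not part of TRL.

```python
# Hypothetical helper illustrating the redefined batch semantics (not TRL code).
def prompts_per_optimizer_step(num_processes: int,
                               per_device_train_batch_size: int,
                               num_generations: int,
                               gradient_accumulation_steps: int = 1) -> int:
    """How many unique prompts contribute to one optimizer step under the new semantics."""
    # Each device's forward pass now sees exactly `per_device_train_batch_size` sequences
    # (previously it saw per_device_train_batch_size * num_generations).
    global_batch = num_processes * per_device_train_batch_size
    if global_batch % num_generations != 0:
        raise ValueError(
            f"Global batch size ({global_batch}) must be evenly divisible "
            f"by num_generations ({num_generations})."
        )
    return (global_batch // num_generations) * gradient_accumulation_steps


# Example matching the values used throughout the updated tests:
# 1 process, batch size 3, 3 generations per prompt -> 1 prompt per optimizer step.
assert prompts_per_optimizer_step(1, 3, 3) == 1
```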
2 changes: 0 additions & 2 deletions docs/source/grpo_trainer.md
@@ -24,8 +24,6 @@ This example demonstrates how to train a model using the GRPO method. We train a
></iframe>

Below is the script to train the model.
- Note that the input tensor for the forward pass has a size of `num_generations * per_device_train_batch_size` because GRPO generates `num_generations` completions for each prompt in the batch. Adjusting these values appropriately can help prevent OOM errors.
- Consequently, the effective train batch size is `num_generations * per_device_train_batch_size * gradient_accumulation_steps`.

```python
# train_grpo.py
2 changes: 1 addition & 1 deletion tests/test_cli.py
@@ -47,7 +47,7 @@ def test_grpo(self):
from trl.cli import main

with tempfile.TemporaryDirectory() as tmp_dir: # Create a temporary directory
- command = f"trl grpo --output_dir {tmp_dir} --model_name_or_path trl-internal-testing/tiny-Qwen2ForCausalLM-2.5 --reward_model_name_or_path trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5 --dataset_name trl-internal-testing/zen --dataset_config standard_prompt_only --num_generations 3 --max_completion_length 32 --report_to none"
+ command = f"trl grpo --output_dir {tmp_dir} --model_name_or_path trl-internal-testing/tiny-Qwen2ForCausalLM-2.5 --reward_model_name_or_path trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5 --dataset_name trl-internal-testing/zen --dataset_config standard_prompt_only --num_generations 4 --max_completion_length 32 --report_to none"
with patch("sys.argv", command.split(" ")):
main()

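A plausible reading of the `--num_generations` 3 -> 4 bump above (an inference, not stated in the diff): with no batch size passed on the command line, the run falls back to the `transformers.TrainingArguments` default of `per_device_train_batch_size=8` now that the GRPO-specific override is removed, and the global batch must be divisible by `num_generations`.

```python
# Illustrative check of the assumption above (single test process, default batch size 8).
default_global_batch = 1 * 8
assert default_global_batch % 4 == 0   # new CLI value: satisfies the divisibility rule
assert default_global_batch % 3 != 0   # old CLI value: would violate it
```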
26 changes: 13 additions & 13 deletions tests/test_grpo_trainer.py
@@ -49,7 +49,7 @@ def test_training(self, config_name):
training_args = GRPOConfig(
output_dir=tmp_dir,
learning_rate=0.1, # increase the learning rate to speed up the test
- per_device_train_batch_size=2, # reduce the batch size to reduce memory usage
+ per_device_train_batch_size=3, # reduce the batch size to reduce memory usage
Member (Author) commented: New requirement: the global batch size must be evenly divisible by the number of generations per prompt.
num_generations=3, # reduce the number of generations to reduce memory usage
max_completion_length=32, # reduce the completion length to reduce memory usage
report_to="none",
@@ -78,8 +78,8 @@ def test_training_with_eval(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = GRPOConfig(
output_dir=tmp_dir,
- per_device_train_batch_size=2, # reduce the batch size to reduce memory usage
- per_device_eval_batch_size=2, # reduce the batch size to reduce memory usage
+ per_device_train_batch_size=3, # reduce the batch size to reduce memory usage
+ per_device_eval_batch_size=3, # reduce the batch size to reduce memory usage
num_generations=3, # reduce the number of generations to reduce memory usage
max_completion_length=32, # reduce the completion length to reduce memory usage
eval_strategy="steps",
@@ -106,7 +106,7 @@ def test_training_peft(self):
training_args = GRPOConfig(
output_dir=tmp_dir,
learning_rate=0.1, # increase the learning rate to speed up the test
- per_device_train_batch_size=2, # reduce the batch size to reduce memory usage
+ per_device_train_batch_size=3, # reduce the batch size to reduce memory usage
num_generations=3, # reduce the number of generations to reduce memory usage
max_completion_length=32, # reduce the completion length to reduce memory usage
report_to="none",
@@ -149,7 +149,7 @@ def test_training_different_reward_model(self):
training_args = GRPOConfig(
output_dir=tmp_dir,
learning_rate=0.1, # increase the learning rate to speed up the test
- per_device_train_batch_size=2, # reduce the batch size to reduce memory usage
+ per_device_train_batch_size=3, # reduce the batch size to reduce memory usage
num_generations=3, # reduce the number of generations to reduce memory usage
max_completion_length=32, # reduce the completion length to reduce memory usage
report_to="none",
@@ -185,7 +185,7 @@ def reward_func(completions, **kwargs):
training_args = GRPOConfig(
output_dir=tmp_dir,
learning_rate=0.1, # increase the learning rate to speed up the test
- per_device_train_batch_size=2, # reduce the batch size to reduce memory usage
+ per_device_train_batch_size=3, # reduce the batch size to reduce memory usage
num_generations=3, # reduce the number of generations to reduce memory usage
max_completion_length=32, # reduce the completion length to reduce memory usage
report_to="none",
@@ -221,7 +221,7 @@ def reward_func(completions, **kwargs):
training_args = GRPOConfig(
output_dir=tmp_dir,
learning_rate=0.1, # increase the learning rate to speed up the test
- per_device_train_batch_size=2, # reduce the batch size to reduce memory usage
+ per_device_train_batch_size=3, # reduce the batch size to reduce memory usage
num_generations=3, # reduce the number of generations to reduce memory usage
max_completion_length=32, # reduce the completion length to reduce memory usage
report_to="none",
@@ -260,7 +260,7 @@ def reward_func2(completions, **kwargs):
training_args = GRPOConfig(
output_dir=tmp_dir,
learning_rate=0.1, # increase the learning rate to speed up the test
- per_device_train_batch_size=2, # reduce the batch size to reduce memory usage
+ per_device_train_batch_size=3, # reduce the batch size to reduce memory usage
num_generations=3, # reduce the number of generations to reduce memory usage
max_completion_length=32, # reduce the completion length to reduce memory usage
report_to="none",
@@ -295,7 +295,7 @@ def reward_func(completions, **kwargs):
training_args = GRPOConfig(
output_dir=tmp_dir,
learning_rate=0.1, # increase the learning rate to speed up the test
- per_device_train_batch_size=2, # reduce the batch size to reduce memory usage
+ per_device_train_batch_size=3, # reduce the batch size to reduce memory usage
num_generations=3, # reduce the number of generations to reduce memory usage
max_completion_length=32, # reduce the completion length to reduce memory usage
report_to="none",
@@ -334,7 +334,7 @@ def reward_func(completions, some_values, **kwargs):
training_args = GRPOConfig(
output_dir=tmp_dir,
learning_rate=0.1, # increase the learning rate to speed up the test
- per_device_train_batch_size=2, # reduce the batch size to reduce memory usage
+ per_device_train_batch_size=3, # reduce the batch size to reduce memory usage
num_generations=3, # reduce the number of generations to reduce memory usage
max_completion_length=32, # reduce the completion length to reduce memory usage
report_to="none",
@@ -367,7 +367,7 @@ def test_training_vllm(self):
training_args = GRPOConfig(
output_dir=tmp_dir,
learning_rate=0.1, # increase the learning rate to speed up the test
- per_device_train_batch_size=2, # reduce the batch size to reduce memory usage
+ per_device_train_batch_size=3, # reduce the batch size to reduce memory usage
num_generations=3, # reduce the number of generations to reduce memory usage
max_completion_length=32, # reduce the completion length to reduce memory usage
report_to="none",
@@ -400,7 +400,7 @@ def test_training_torch_compile(self):
training_args = GRPOConfig(
output_dir=tmp_dir,
learning_rate=0.1, # increase the learning rate to speed up the test
- per_device_train_batch_size=2, # reduce the batch size to reduce memory usage
+ per_device_train_batch_size=3, # reduce the batch size to reduce memory usage
num_generations=3, # reduce the number of generations to reduce memory usage
max_completion_length=32, # reduce the completion length to reduce memory usage
torch_compile=True,
@@ -431,7 +431,7 @@ def test_training_with_sync_ref_model(self):
training_args = GRPOConfig(
output_dir=tmp_dir,
learning_rate=0.1, # increase the learning rate to speed up the test
- per_device_train_batch_size=2, # reduce the batch size to reduce memory usage
+ per_device_train_batch_size=3, # reduce the batch size to reduce memory usage
num_generations=3, # reduce the number of generations to reduce memory usage
max_completion_length=32, # reduce the completion length to reduce memory usage
sync_ref_model=True,
4 changes: 4 additions & 0 deletions trl/models/utils.py
@@ -137,6 +137,8 @@ def setup_chat_format(

def remove_hooks(model: "DeepSpeedEngine") -> None:
"""Removes the optimizer hooks from a DeepSpeed ZeRO-3 model."""
+ if not hasattr(model, "optimizer"): # before the first training step, the model has no optimizer
+     return
if model.optimizer is not None and hasattr(model.optimizer, "parameter_offload"):
optimizer_offload = model.optimizer.parameter_offload
elif model.optimizer is not None:
@@ -164,6 +166,8 @@ def iter_params(module, recurse=False):

def add_hooks(model: "DeepSpeedEngine") -> None:
"""Adds the optimizer hooks from a DeepSpeed ZeRO-3 model."""
+ if not hasattr(model, "optimizer"): # before the first training step, the model has no optimizer
+     return
if model.optimizer is not None and hasattr(model.optimizer, "parameter_offload"):
optimizer_offload = model.optimizer.parameter_offload
elif model.optimizer is not None:
30 changes: 6 additions & 24 deletions trl/trainer/grpo_config.py
@@ -45,7 +45,8 @@ class GRPOConfig(TrainingArguments):
max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left.
num_generations (`int` or `None`, *optional*, defaults to `8`):
- Number of generations per prompt to sample.
+ Number of generations per prompt to sample. The global batch size (num_processes * per_device_batch_size)
+ must be divisible by this value.
temperature (`float`, *optional*, defaults to `0.9`):
Temperature for sampling. The higher the temperature, the more random the completions.
max_completion_length (`int` or `None`, *optional*, defaults to `256`):
Expand Down Expand Up @@ -83,11 +84,6 @@ class GRPOConfig(TrainingArguments):
learning_rate (`float`, *optional*, defaults to `1e-6`):
Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
[`~transformers.TrainingArguments`].
- per_device_train_batch_size (`int`, *optional*, defaults to `1`):
- Number of prompts sampled per device for training. The actual batch passed into the model will be this
- value multiplied by `num_generations`.
- gradient_accumulation_steps (`int`, *optional*, defaults to `8`):
- Number of updates steps to accumulate the gradients for, before performing a backward/update pass.
beta (`float`, *optional*, defaults to `0.04`):
KL coefficient.
sync_ref_model (`bool`, *optional*, defaults to `False`):
Expand Down Expand Up @@ -132,7 +128,10 @@ class GRPOConfig(TrainingArguments):
)
num_generations: Optional[int] = field(
default=8,
- metadata={"help": "Number of generations to sample."},
+ metadata={
+ "help": "Number of generations to sample. The global batch size (num_processes * per_device_batch_size) "
+ "must be divisible by this value."
+ },
)
temperature: Optional[float] = field(
default=0.9,
Expand Down Expand Up @@ -202,23 +201,6 @@ class GRPOConfig(TrainingArguments):
"`transformers.TrainingArguments`."
},
)
- # GRPO generates multiple completions per prompt, increasing memory usage.
- # To accommodate this, the per-device train batch size is decreased (overriden from the parent class),
- # and the number gradient accumulation steps is increased to maintain the effective batch size.
- per_device_train_batch_size: int = field(
- default=1,
- metadata={
- "help": "Number of prompts sampled per device for training. The actual batch passed into the model will "
- "be this value multiplied by `num_generations`."
- },
- )
- gradient_accumulation_steps: int = field(
- default=8,
- metadata={
- "help": "Number of updates steps to accumulate the gradients for, before performing a backward/update "
- "pass."
- },
- )
beta: float = field(
default=0.04,
metadata={"help": "KL coefficient."},
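Taken together, a minimal usage sketch under the new semantics (single process, values mirroring the updated tests; the output path is illustrative, not taken from the PR):

```python
from trl import GRPOConfig

# With one process the global batch equals per_device_train_batch_size (3), which is
# divisible by num_generations (3): each step trains on one prompt with three completions.
training_args = GRPOConfig(
    output_dir="grpo-output",        # illustrative path
    per_device_train_batch_size=3,   # now counts generations (completions) per device
    num_generations=3,
    max_completion_length=32,
    report_to="none",
)
```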