📉 Optimize GRPO memory usage by redefining per_device_batch_size as generations per device #2776
Conversation
```diff
-    per_device_train_batch_size=2,  # reduce the batch size to reduce memory usage
+    per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
```
New requirement: The global batch size must be evenly divisible by the number of generations per prompt.
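A minimal sketch of what this requirement means in practice; the concrete numbers and variable names below are assumptions for illustration, not the exact check inside the trainer:

```python
# Hypothetical numbers, just to illustrate the new divisibility requirement.
num_processes = 8                # training GPUs
per_device_train_batch_size = 3  # now counts generations per device, not prompts
num_generations = 8              # G, completions sampled per prompt

global_batch_size = num_processes * per_device_train_batch_size  # 24
if global_batch_size % num_generations != 0:
    raise ValueError(
        f"The global batch size ({global_batch_size}) must be evenly divisible "
        f"by the number of generations per prompt ({num_generations})."
    )
```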
```diff
@@ -324,7 +419,6 @@ def data_collator(features):  # No data collation is needed in GRPO
     enable_prefix_caching=True,
 )
 self.sampling_params = SamplingParams(
-    n=self.num_generations,
```
Now the prompt is repeated, so generation will be called G times. It should be equivalent in terms of compute thanks to `enable_prefix_caching=True`.
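A hedged sketch of what that looks like with vLLM (`SamplingParams` and `LLM.generate` are real vLLM APIs; the model name and surrounding variables are placeholders):

```python
from vllm import LLM, SamplingParams

num_generations = 8  # G
prompts = ["Solve: 2 + 2 = ?"]

# Old approach: one entry per prompt, n=G completions per call.
# sampling_params = SamplingParams(n=num_generations, max_tokens=128)

# New approach: repeat each prompt G times and leave n at its default of 1.
# With enable_prefix_caching=True the shared prompt prefix is reused, so the
# compute cost should be roughly equivalent.
repeated_prompts = [p for p in prompts for _ in range(num_generations)]
sampling_params = SamplingParams(max_tokens=128)

# llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct", enable_prefix_caching=True)
# outputs = llm.generate(repeated_prompts, sampling_params=sampling_params)
```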
trl/trainer/grpo_trainer.py
Outdated
```python
all_prompts_text = gather_object(prompts_text)
if self.accelerator.is_main_process:
    outputs = self.llm.generate(all_prompts_text, sampling_params=self.sampling_params, use_tqdm=False)
    completion_ids = [out.token_ids for completions in outputs for out in completions.outputs]
else:
    completion_ids = [None] * len(all_prompts_text) * self.num_generations

# Broadcast the completions from the main process to all processes, ensuring each process receives its
# corresponding slice.
completion_ids = broadcast_object_list(completion_ids, from_process=0)
process_slice = slice(
    self.accelerator.process_index * len(prompts) * self.num_generations,
    (self.accelerator.process_index + 1) * len(prompts) * self.num_generations,
)
completion_ids = completion_ids[process_slice]
```
No need to gather/broadcast; everything here happens in only one process.
trl/trainer/grpo_trainer.py
Outdated
```python
# First, have main process load weights if needed
if self.state.global_step != self._last_loaded_step:
    with unwrap_model_for_generation(
        self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
    ) as unwrapped_model:
        if is_compiled_module(unwrapped_model):
            state_dict = unwrapped_model._orig_mod.state_dict()
        else:
            state_dict = unwrapped_model.state_dict()
    if self.accelerator.is_main_process:
        llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
        llm_model.load_weights(state_dict.items())
    self._last_loaded_step = self.state.global_step
```
This is done just above.
```diff
@@ -484,8 +575,6 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s
     completions = [[{"role": "assistant", "content": completion}] for completion in completions]

-    # Compute the rewards
-    prompts = [prompt for prompt in prompts for _ in range(self.num_generations)]  # repeat prompts
```
Prompts are already repeated at this point.
```diff
@@ -453,7 +545,6 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s
     prompt_length = prompt_ids.size(1)
     prompt_ids = prompt_completion_ids[:, :prompt_length]
     completion_ids = prompt_completion_ids[:, prompt_length:]
-    prompt_mask = prompt_mask.repeat_interleave(self.num_generations, dim=0)
```
Prompts are already repeated at this point.
```diff
@@ -525,24 +611,24 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s
     advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)

     # Log the metrics
     reward_per_func = self.accelerator.gather_for_metrics(rewards_per_func).mean(0)
```
No need to gather; everything is processed in the main process.
trl/trainer/grpo_trainer.py
Outdated
```python
# Gather inputs and process them in the main process. This is important because the rewards are normalized
# per group.
inputs = gather_object(inputs)
prepared = self._prepare_main(inputs) if self.accelerator.is_main_process else None
```
In the non-vLLM setting, is only one GPU doing generation while the others sit idle?
Yes, I hadn't thought of that.
It shouldn't be super hard to solve that. Let me try
Done, it should be much easier to review now 🤗
Does this also address #2688?
No, the OOM from #2688 is due to vLLM; the solution is here: #2688 (comment)
Per device batch size = 4
Thank you <3
Did you check the outputs for each device in that case? In my case, all devices generate the same outputs.
Expectation: different outputs on each device.
Is it also possible to create a larger num_generations for the same prompt by repeating (like accumulation)?
Indeed, I tried using trl==0.15.0 to reproduce open-r1, and the training speed became extremely slow.
@zaddy6 But 2 is too small. How many devices do you use?
8 devices, 1 assigned to vLLM @qgallouedec
I recommend setting num_generations to 7 then.
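The arithmetic behind that suggestion, as a small sketch (assuming one of the 8 GPUs is reserved for vLLM and a per-device batch size of 1, purely for illustration):

```python
# 8 GPUs total, 1 dedicated to vLLM generation -> 7 training processes.
training_processes = 8 - 1
per_device_train_batch_size = 1  # generations per device under the new definition

global_batch_size = training_processes * per_device_train_batch_size  # 7
# The global batch size must be divisible by num_generations,
# so num_generations=7 fits here while 8 would not.
assert global_batch_size % 7 == 0
assert global_batch_size % 8 != 0
```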
What's the equivalent of my old config in this new approach?
```diff
- per_device_train_batch_size=1,
+ per_device_train_batch_size=2,
```
Because now, `per_device_train_batch_size` is the number of generations per device rather than the number of prompts.
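A sketch of that equivalence using `GRPOConfig` (the field names are real TRL arguments; the values and output_dir are placeholders for illustration):

```python
from trl import GRPOConfig

# Old semantics: per_device_train_batch_size counted prompts per device,
# so per_device_train_batch_size=1 with num_generations=2 put 2 completions on each GPU.
# old_args = GRPOConfig(output_dir="out", per_device_train_batch_size=1, num_generations=2)

# New semantics: per_device_train_batch_size counts generations per device,
# so the equivalent configuration doubles the per-device batch size.
new_args = GRPOConfig(
    output_dir="out",               # placeholder path
    per_device_train_batch_size=2,  # 2 generations per device
    num_generations=2,              # G
)
```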
@qgallouedec using ...
OK, but what's the point of using GRPO with 2 generations per prompt?
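One way to see the concern, using the advantage normalization shown in the diff above (`(rewards - mean) / (std + 1e-4)`): with only 2 generations per group, the normalized advantages collapse to about ±0.71 whenever the two rewards differ, so the size of the reward gap no longer matters. A small illustrative check (a standalone sketch, not the trainer code):

```python
import torch

def group_advantages(rewards: torch.Tensor) -> torch.Tensor:
    # Same normalization as in the trainer diff above, applied to one group.
    return (rewards - rewards.mean()) / (rewards.std() + 1e-4)

# With G = 2 the advantages are ~±0.707 regardless of the reward gap.
print(group_advantages(torch.tensor([0.1, 0.2])))  # ~[-0.71, 0.71]
print(group_advantages(torch.tensor([0.0, 9.0])))  # ~[-0.71, 0.71]
```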
Am I missing something, or is this still not possible?
I don't know how much memory is needed to train Qwen2.5-7B, but I think 3 x H20 (96G) should be enough. I tried to train it, but got OOM. Maybe I misunderstand per_device_train_batch_size and num_generations. I tested with different values of per_device_train_batch_size and num_generations, and I find that if I use one H20, ...
Hi @qgallouedec, thanks for this PR, which seems to solve my previous concern, which is: ...
as long as ...
when set ...
Background
In the current implementation of GRPO, each device samples `N` prompts. For these `N` prompts, `G` generations are produced, and the loss is computed over them. However, this approach is highly memory-intensive, especially when using `G ≥ 8`, often leading to OOM errors, even with a batch size of just 1 prompt per device.

Proposed change
This PR introduces a more flexible approach:
- Instead of defining `per_device_batch_size` as the number of prompts per device, it now represents the number of generations per device.
- This decouples the number of generations (`G`) and the batch size per device.
- New requirement: the global batch size (`num_processes * per_device_batch_size`) must be divisible by `G`.

Note that these settings should be equivalent: `per_device_batch_size==num_generations` (new setting) and `per_device_batch_size==1` (old setting).

Benefits
- Ability to use a larger `G` or completion length without hitting OOM

Before: [diagram]
Now: [diagram]
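To make the new accounting concrete, a hedged worked example (the numbers are illustrative and the variable names are not the trainer's internals):

```python
num_processes = 8          # training GPUs
per_device_batch_size = 4  # generations per device (new definition)
G = 8                      # num_generations per prompt

global_batch_size = num_processes * per_device_batch_size  # 32 generations per step
assert global_batch_size % G == 0                          # new requirement

unique_prompts_per_step = global_batch_size // G           # 4 prompts
# Each prompt's 8 completions are spread across devices instead of all sitting
# on one GPU, which is what lowers per-device memory and allows a larger G or
# longer completions without OOM.
print(unique_prompts_per_step)
```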