
📉 Optimize GRPO memory usage by redefining per_device_batch_size as generations per device #2776

Merged · 18 commits · Feb 6, 2025

Conversation

@qgallouedec (Member) commented Feb 5, 2025

Background

In the current implementation of GRPO, each device samples N prompts. For these N prompts, G generations are produced, and the loss is computed over them. However, this approach is highly memory-intensive, especially when using G ≥ 8, often leading to OOM errors, even with a batch size of just 1 prompt per device.

Proposed change

This PR introduces a more flexible approach:

  • Instead of defining per_device_batch_size as the number of prompts per device, it now represents the number of generations per device.
  • This allows for much greater flexibility in choosing the number of generations (G) and the batch size per device.
  • The only constraint is that the global batch size (num_processes * per_device_batch_size) must be divisible by G.

Note that these settings should be equivalent:

num_generations = ...  # eg, 8
num_prompts_per_device = ...  # eg, 1
# main
GRPOConfig(num_generations=num_generations, per_device_batch_size=num_prompts_per_device, ...)
# this PR
GRPOConfig(num_generations=num_generations, per_device_batch_size=num_generations*num_prompts_per_device, ...)

i.e., `per_device_batch_size==8` (new setting) is equivalent to `per_device_batch_size==1` (old setting) when `num_generations=8` and there is one prompt per device.
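To make this concrete, here is a minimal sketch (plain arithmetic with the values from the snippet above, not TRL code) of why the two settings put the same number of completions on each device:

```python
# Illustrative arithmetic only, not TRL code.
num_generations = 8           # G: completions per prompt
num_prompts_per_device = 1

# Old definition: per_device_batch_size counts prompts,
# so each device still holds G completions per prompt.
old_per_device_batch_size = num_prompts_per_device
old_completions_per_device = old_per_device_batch_size * num_generations   # 8

# New definition (this PR): per_device_batch_size counts completions directly.
new_per_device_batch_size = num_generations * num_prompts_per_device       # 8

# Same number of completions per device -> the loss sees the same batch.
assert new_per_device_batch_size == old_completions_per_device
```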

Benefits

  • Reduces memory, making it feasible to use higher G or completion length without hitting OOM
  • Provides more control over batch size and generation count per device
  • Ensures scalability while maintaining the integrity of GRPO's training logic

Before:

[image]

Now:

[image]


per_device_train_batch_size=2, # reduce the batch size to reduce memory usage
per_device_train_batch_size=3, # reduce the batch size to reduce memory usage
Member Author

New requirement: The global batch size must be evenly divisible by the number of generations per prompt.
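A minimal sketch of that requirement, assuming hypothetical config values (the actual validation lives in the trainer and may be worded differently):

```python
# Hypothetical validation sketch mirroring the requirement above;
# the actual check in the trainer may differ.
num_processes = 8                 # training GPUs
per_device_train_batch_size = 4   # now: generations per device
num_generations = 8

global_batch_size = num_processes * per_device_train_batch_size
if global_batch_size % num_generations != 0:
    raise ValueError(
        f"The global batch size ({global_batch_size}) must be evenly divisible "
        f"by the number of generations per prompt ({num_generations})."
    )
```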

@@ -324,7 +419,6 @@ def data_collator(features): # No data collation is needed in GRPO
enable_prefix_caching=True,
)
self.sampling_params = SamplingParams(
n=self.num_generations,
Member Author (@qgallouedec, Feb 5, 2025)

Now the prompt is repeated, so generation is called G times. It should be equivalent in terms of compute thanks to `enable_prefix_caching=True`.
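A hedged sketch of the pattern described above, outside the trainer; the model name and prompt list are placeholders:

```python
from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct", enable_prefix_caching=True)
num_generations = 8                      # G
prompts = ["Summarize: ..."]             # placeholder prompt list

# Before: one entry per prompt, n=G completions per call.
# sampling_params = SamplingParams(n=num_generations, max_tokens=64)

# Now: each prompt is repeated G times and sampled with n=1.
# Prefix caching lets vLLM reuse the prompt's KV cache across the repeats,
# so the compute cost should be roughly the same.
repeated_prompts = [p for p in prompts for _ in range(num_generations)]
sampling_params = SamplingParams(n=1, max_tokens=64)
outputs = llm.generate(repeated_prompts, sampling_params=sampling_params)
```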

Comment on lines 423 to 442
all_prompts_text = gather_object(prompts_text)
if self.accelerator.is_main_process:
outputs = self.llm.generate(all_prompts_text, sampling_params=self.sampling_params, use_tqdm=False)
completion_ids = [out.token_ids for completions in outputs for out in completions.outputs]
else:
completion_ids = [None] * len(all_prompts_text) * self.num_generations

# Broadcast the completions from the main process to all processes, ensuring each process receives its
# corresponding slice.
completion_ids = broadcast_object_list(completion_ids, from_process=0)
process_slice = slice(
self.accelerator.process_index * len(prompts) * self.num_generations,
(self.accelerator.process_index + 1) * len(prompts) * self.num_generations,
)
completion_ids = completion_ids[process_slice]
Member Author

No need to gather/broadcast; everything here happens in a single process.

Comment on lines 408 to 426
# First, have main process load weights if needed
if self.state.global_step != self._last_loaded_step:
with unwrap_model_for_generation(
self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
) as unwrapped_model:
if is_compiled_module(unwrapped_model):
state_dict = unwrapped_model._orig_mod.state_dict()
else:
state_dict = unwrapped_model.state_dict()
if self.accelerator.is_main_process:
llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
llm_model.load_weights(state_dict.items())
self._last_loaded_step = self.state.global_step

Member Author

this is done just above

@@ -484,8 +575,6 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s
completions = [[{"role": "assistant", "content": completion}] for completion in completions]

# Compute the rewards
prompts = [prompt for prompt in prompts for _ in range(self.num_generations)] # repeat prompts
Member Author

prompts are already repeated at this point

@@ -453,7 +545,6 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s
prompt_length = prompt_ids.size(1)
prompt_ids = prompt_completion_ids[:, :prompt_length]
completion_ids = prompt_completion_ids[:, prompt_length:]
prompt_mask = prompt_mask.repeat_interleave(self.num_generations, dim=0)
Member Author

prompts are already repeated at this point
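For reference, a small sketch of what the removed `repeat_interleave` call used to do, assuming the old layout of one row per prompt:

```python
import torch

num_generations = 2
# Old layout: one mask row per prompt -> shape (2 prompts, 3 tokens).
prompt_mask = torch.tensor([[1, 1, 1],
                            [1, 1, 0]])

# The removed line expanded it to one row per generation:
expanded = prompt_mask.repeat_interleave(num_generations, dim=0)
print(expanded.shape)  # torch.Size([4, 3])

# With this PR the batch already arrives with one row per generation,
# so repeating again would duplicate rows, hence the removal.
```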

@@ -525,24 +611,24 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s
advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)

# Log the metrics
reward_per_func = self.accelerator.gather_for_metrics(rewards_per_func).mean(0)
Member Author

No need to gather; everything is processed in the main process.
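For context, a hedged sketch of the group-wise advantage computation shown in the diff above (dummy reward values; the reshaped layout here is only illustrative of the normalization, not the exact tensor layout used in the trainer):

```python
import torch

num_generations = 4
# Dummy rewards for 2 prompts x 4 generations, one group per prompt.
rewards = torch.tensor([1.0, 0.5, 0.0, 2.0,
                        3.0, 3.0, 1.0, 1.0])

grouped = rewards.view(-1, num_generations)                 # (num_prompts, G)
mean_grouped_rewards = grouped.mean(dim=1, keepdim=True)
std_grouped_rewards = grouped.std(dim=1, keepdim=True)

# Each completion's advantage is its reward standardized within its own
# group of G completions for the same prompt.
advantages = (grouped - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
print(advantages.view(-1))
```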

# Gather inputs and process them in the main process. This is important because the rewards are normalized
# per group.
inputs = gather_object(inputs)
prepared = self._prepare_main(inputs) if self.accelerator.is_main_process else None
Collaborator

In the non-vLLM setting, is only one GPU doing generation while the others sit idle?

Member Author

Yes, I hadn't thought of that.

Member Author (@qgallouedec, Feb 5, 2025)

It shouldn't be super hard to solve that. Let me try

Member Author

Done, it should be much easier to review now 🤗

@zaddy6 commented Feb 6, 2025

Does this also address #2688?

@qgallouedec (Member Author)

No, the OOM from #2688 is due to vLLM; the solution is here: #2688 (comment)

@qgallouedec (Member Author)

Regression test:
green: this PR
red: main

[screenshot]
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

dataset = load_dataset("trl-lib/tldr", split="train")


# Dummy reward function: the closer the completion is to 20 characters, the higher the reward
def reward_len(completions, **kwargs):
    return [-abs(20 - len(completion)) for completion in completions]


training_args = GRPOConfig(
    output_dir="Qwen2.5-0.5B-GRPO-2776",
    logging_steps=5,
    gradient_accumulation_steps=4,
    per_device_train_batch_size=4, # or 1 with main
    num_generations=4,
    max_completion_length=64,
    max_prompt_length=64,
    max_steps=100,
)
trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=dataset,
)

trainer.train()

@qgallouedec (Member Author)

Per device batch size = 4
Num completions = 4

@NicolasMejiaPetit

Per device batch size = 4 Num completions = 4

Thank you <3

@kiddj (Contributor) commented Feb 10, 2025

Did you check the outputs for each device in the case of num_generations > per_device_batch_size?

In my case, all devices generate the same outputs.

With num_generations = 8 and per_device_batch_size = 4, the current PR produces:

GPU0: [1, 2, 3, 4]
GPU1: [1, 2, 3, 4]

Expected:

GPU0: [1, 2, 3, 4]
GPU1: [5, 6, 7, 8]
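A hedged sketch of the per-rank slicing expected here (toy completion IDs; `rank` stands for the accelerator process index):

```python
# 8 completions generated globally for one group of prompts,
# to be split across 2 training GPUs with per_device_batch_size = 4.
all_completion_ids = [1, 2, 3, 4, 5, 6, 7, 8]
per_device_batch_size = 4

for rank in range(2):  # process index of each GPU
    process_slice = slice(rank * per_device_batch_size,
                          (rank + 1) * per_device_batch_size)
    print(rank, all_completion_ids[process_slice])
# Expected: rank 0 -> [1, 2, 3, 4], rank 1 -> [5, 6, 7, 8].
# The report above says both ranks instead received [1, 2, 3, 4].
```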

@ertugrul-dmr

Is it also possible to use a larger num_generations for the same prompt by repeating it (like accumulation)?

@zaddy6 commented Feb 14, 2025

I think this PR made alignment slower

[image]

Green is this PR

Current Config


training_args = GRPOConfig(
    output_dir=output_dir,
    run_name=run_name,
    learning_rate=5e-7,
    lr_scheduler_type='cosine',
    logging_steps=1,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=1,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    bf16=True,
    num_generations=2,
    max_prompt_length=512,
    max_completion_length=512,
    use_vllm = True,
    num_train_epochs=4,
    save_steps=100,
    report_to="wandb",
    save_total_limit=2,
    log_on_each_node=False,
    beta=0.001,
    warmup_ratio=0.07,
    vllm_gpu_memory_utilization=0.7,
    optim="adamw_8bit",
    # temperature=0.5
)

Previous config:

    per_device_train_batch_size=1,
    gradient_accumulation_steps=1,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    bf16=True,
    num_generations=2,

Training on H100x8, Qwen2.5-3B.


@2proveit

> I think this PR made alignment slower [...]

Indeed, I tried using trl==0.15.0 to reproduce open-r1, and the training speed became extremely slow.

@qgallouedec (Member Author) commented Feb 17, 2025

@zaddy6 it's because you changed num_generations. (Edit: no.)

But 2 is too small. How many devices do you use?

@zaddy6 commented Feb 17, 2025

8 devices, 1 assigned to vLLM @qgallouedec

@qgallouedec (Member Author)

I recommend setting num_generations to 7 then.

@zaddy6 commented Feb 17, 2025

What is the equivalent of my old config with this new approach?

    per_device_train_batch_size=1,
    gradient_accumulation_steps=1,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    bf16=True,
    num_generations=2,

My reward function makes an external API call, so I need to control how many generations are produced in order to not exceed the rate limit.

@qgallouedec (Member Author)

- per_device_train_batch_size=1,
+ per_device_train_batch_size=2,

Because per_device_train_batch_size now denotes the number of generations per device, not the number of prompts per device.
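A tiny arithmetic sketch of that equivalence, using the values from the config above:

```python
num_generations = 2

# Old semantics: 1 prompt per device -> 1 * G completions per device.
old_completions_per_device = 1 * num_generations             # 2

# New semantics: per_device_train_batch_size counts completions directly.
new_per_device_train_batch_size = 2
assert new_per_device_train_batch_size == old_completions_per_device
```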

@zaddy6 commented Feb 17, 2025

@qgallouedec using `num_generations=7` slowed down training almost 4x, from 3 hours to 12 hours.

@qgallouedec (Member Author)

OK, but what's the point of using GRPO with only 2 generations per prompt?

@zaddy6 commented Feb 17, 2025

[image] My bad, `num_generations=7` made it faster and better; green is the new run.

@ertugrul-dmr

> Is it also possible to use a larger num_generations for the same prompt by repeating it (like accumulation)?

Am I missing something, or is this still not possible?

@Tuziking

I don't know how much memory is needed to train Qwen2.5-7B, but I think 3 x H20 (96 GB) should be enough. I tried to train it, but hit OOM. Maybe I misunderstand per_device_train_batch_size and num_generations. I tested different combinations: with one H20 and per_device_train_batch_size==4, num_generations==4, training runs for some steps before OOM. But with 3 x H20 and per_device_train_batch_size==2, num_generations==6, OOM occurs earlier. I don't understand why training with more H20s makes OOM occur earlier. The error log is as follows:

[rank0]/[rank1]/[rank2]:[W reducer.cpp:1400] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) — identical warning on all three ranks
Traceback (most recent call last):
  File "/online1/sc100010/sc100010/qb_project/MARL/trl_GRPO_train.py", line 176, in <module>
    trainer.train()
  File "/home/export/base/sc100010/sc100010/.conda/envs/torch/lib/python3.10/site-packages/transformers/trainer.py", line 2241, in train
    return inner_training_loop(
  File "/home/export/base/sc100010/sc100010/.conda/envs/torch/lib/python3.10/site-packages/transformers/trainer.py", line 2599, in _inner_training_loop
    self.optimizer.step()
  File "/home/export/base/sc100010/sc100010/.conda/envs/torch/lib/python3.10/site-packages/accelerate/optimizer.py", line 178, in step
    self.optimizer.step(closure)
  File "/home/export/base/sc100010/sc100010/.conda/envs/torch/lib/python3.10/site-packages/torch/optim/lr_scheduler.py", line 137, in wrapper
    return func.__get__(opt, opt.__class__)(*args, **kwargs)
  File "/home/export/base/sc100010/sc100010/.conda/envs/torch/lib/python3.10/site-packages/torch/optim/optimizer.py", line 487, in wrapper
    out = func(*args, **kwargs)
  File "/home/export/base/sc100010/sc100010/.conda/envs/torch/lib/python3.10/site-packages/torch/optim/optimizer.py", line 91, in _use_grad
    ret = func(self, *args, **kwargs)
  File "/home/export/base/sc100010/sc100010/.conda/envs/torch/lib/python3.10/site-packages/torch/optim/adamw.py", line 220, in step
    adamw(
  File "/home/export/base/sc100010/sc100010/.conda/envs/torch/lib/python3.10/site-packages/torch/optim/optimizer.py", line 154, in maybe_fallback
    return func(*args, **kwargs)
  File "/home/export/base/sc100010/sc100010/.conda/envs/torch/lib/python3.10/site-packages/torch/optim/adamw.py", line 782, in adamw
    func(
  File "/home/export/base/sc100010/sc100010/.conda/envs/torch/lib/python3.10/site-packages/torch/optim/adamw.py", line 606, in _multi_tensor_adamw
    exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 130.00 MiB. GPU 0 has a total capacity of 94.99 GiB of which 87.19 MiB is free. Including non-PyTorch memory, this process has 94.90 GiB memory in use. Of the allocated memory 90.98 GiB is allocated by PyTorch, and 2.01 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
[rank0], [rank1], [rank2]: (the same CUDA out-of-memory traceback is repeated by each rank, differing only in the per-GPU free/allocated memory figures)
wandb: 
wandb: 🚀 View run outputs/Qwen2.5-7B-GRPO at: https://wandb.ai/bobo1398861921-nus/huggingface/runs/eorh7fyx
wandb: Find logs at: ../../../../../../../../online1/sc100010/sc100010/qb_project/MARL/wandb/run-20250223_173847-eorh7fyx/logs
W0223 17:39:00.122000 40366 /online1/sc100010/sc100010/.conda/envs/torch/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 40439 closing signal SIGTERM
W0223 17:39:00.125000 40366 /online1/sc100010/sc100010/.conda/envs/torch/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 40440 closing signal SIGTERM
E0223 17:39:00.741000 40366 /online1/sc100010/sc100010/.conda/envs/torch/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/api.py:869] failed (exitcode: 1) local_rank: 2 (pid: 40441) of binary: /home/export/base/sc100010/sc100010/.conda/envs/torch/bin/python
Traceback (most recent call last):
  File "/home/export/base/sc100010/sc100010/.conda/envs/torch/bin/accelerate", line 8, in <module>
    sys.exit(main())
  File "/home/export/base/sc100010/sc100010/.conda/envs/torch/lib/python3.10/site-packages/accelerate/commands/accelerate_cli.py", line 48, in main
    args.func(args)
  File "/home/export/base/sc100010/sc100010/.conda/envs/torch/lib/python3.10/site-packages/accelerate/commands/launch.py", line 1163, in launch_command
    multi_gpu_launcher(args)
  File "/home/export/base/sc100010/sc100010/.conda/envs/torch/lib/python3.10/site-packages/accelerate/commands/launch.py", line 792, in multi_gpu_launcher
    distrib_run.run(args)
  File "/home/export/base/sc100010/sc100010/.conda/envs/torch/lib/python3.10/site-packages/torch/distributed/run.py", line 910, in run
    elastic_launch(
  File "/home/export/base/sc100010/sc100010/.conda/envs/torch/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 138, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/home/export/base/sc100010/sc100010/.conda/envs/torch/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 269, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
============================================================
trl_GRPO_train.py FAILED
------------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2025-02-23_17:39:00
  host      : gpu018
  rank      : 2 (local_rank: 2)
  exitcode  : 1 (pid: 40441)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================

@congchan (Contributor) commented Feb 26, 2025

Hi @qgallouedec, thanks for this PR, which seems to solve my previous concern, namely: can I scale num_generations for each prompt with the settings below?

per_device_train_batch_size=x (some value that does not cause each worker to OOM)
num_generations=y (some scalable value, such as 64, 128, ..., 1024)

as long as per_device_train_batch_size * num_generations % num_workers == 0 (excluding the vLLM worker)?

@mengban (Contributor) commented Mar 3, 2025

When setting `use_vllm=True`, is this logic still valid?
