Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🫔 [GRPO] Pass wrapped model to unwrap_model_for_generation for DeepSpeed Stage-3 compatibility #2871

Merged
merged 2 commits into from
Feb 28, 2025

Conversation

kiddj
Copy link
Contributor

@kiddj kiddj commented Feb 15, 2025

What does this PR do?

Fixes DeepSpeed Stage-3 compatibility by passing the wrapped model (self.model_wrapped) to unwrap-model_for_generation instead of self.model.

Previously, unwrap_model_for_generation(model) was called with a model passed via function like compute_loss, which avoided any DeepSpeed Stage-3 conflicts. In the current implementation, the model is obtained directly from self, causing a crash under Stage-3.

This fix ensures consistency by always passing the wrapped model instead of self.model, preventing the Stage-3 compatibility issue.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@qgallouedec

@dignfei
Copy link

dignfei commented Feb 15, 2025

I have already submitted a similar submission

@qgallouedec
Copy link
Member

qgallouedec commented Feb 18, 2025

Thanks for your contribution @kiddj! Can you provide a code that would fail with main branch and not with yours? I can get one. Currently, the following works for me:

# 2871.py
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

training_args = GRPOConfig(
    output_dir="2871",
    per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
    num_generations=3,  # reduce the number of generations to reduce memory usage
    max_completion_length=32,  # reduce the completion length to reduce memory usage
    max_prompt_length=32,
    bf16=True,
    report_to="none",
)

def dummy_reward_func(completions, **kwargs):
    return [0.0] * len(completions)

trainer = GRPOTrainer(
    model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
    reward_funcs=dummy_reward_func,
    args=training_args,
    train_dataset=dataset,
)

trainer.train()
# 2 GPUs
accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml --num_processes 2 sandbox/2871.py

@qgallouedec qgallouedec added the 😴 stale No update from the author, will be closed soon label Feb 24, 2025
@jamesbraza
Copy link
Contributor

@kiddj what was the crash you hit, can you share a stack trace?

I just hit the below AssertionError inside of grpo_trainer.py with accelerate launch --zero3_init_flag true --zero_stage 3, accelerate==1.4.0, deepspeed==0.16.4, current main branch of trl with #2963.

I think it might've been what your PR aims to resolve:

0: [rank3]: Traceback (most recent call last):
0: [rank3]:   File "/path/to/repo/train.py", line 268, in <module>
0: [rank3]:     main(script_args, training_args, model_args)
0: [rank3]:   File "/path/to/repo/train.py", line 254, in main
0: [rank3]:     trainer.train(**train_kw)
0: [rank3]:   File "/path/to/repo/.venv/lib/python3.12/site-packages/transformers/trainer.py", line 2245, in train
0: [rank3]:     return inner_training_loop(
0: [rank3]:            ^^^^^^^^^^^^^^^^^^^^
0: [rank3]:   File "/path/to/repo/.venv/lib/python3.12/site-packages/transformers/trainer.py", line 2556, in _inner_training_loop
0: [rank3]:     tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
0: [rank3]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
0: [rank3]:   File "/path/to/repo/.venv/lib/python3.12/site-packages/transformers/trainer.py", line 3700, in training_step
0: [rank3]:     inputs = self._prepare_inputs(inputs)
0: [rank3]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
0: [rank3]:   File "/path/to/repo/.venv/lib/python3.12/site-packages/trl/extras/profiling.py", line 33, in wrapper
0: [rank3]:     result = func(self, *args, **kwargs)
0: [rank3]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
0: [rank3]:   File "/path/to/repo/.venv/lib/python3.12/site-packages/trl/trainer/grpo_trainer.py", line 670, in _prepare_inputs
0: [rank3]:     inputs = self._generate_and_score_completions(inputs)
0: [rank3]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
0: [rank3]:   File "/path/to/repo/.venv/lib/python3.12/site-packages/trl/trainer/grpo_trainer.py", line 849, in _generate_and_score_completions
0: [rank3]:     with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
0: [rank3]:          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
0: [rank3]:   File "/home/james/.local/share/uv/python/cpython-3.12.9-linux-x86_64-gnu/lib/python3.12/contextlib.py", line 144, in __exit__
0: [rank3]:     next(self.gen)
0: [rank3]:   File "/path/to/repo/.venv/lib/python3.12/site-packages/trl/models/utils.py", line 217, in unwrap_model_for_generation
0: [rank3]:     with deepspeed.zero.GatheredParameters(model.parameters()):
0: [rank3]:          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
0: [rank3]:   File "/path/to/repo/.venv/lib/python3.12/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 2251, in __exit__
0: [rank3]:     self.params[0].partition(param_list=self.params, has_been_updated=False)
0: [rank3]:   File "/path/to/repo/.venv/lib/python3.12/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 1394, in partition
0: [rank3]:     self._partition(param_list, has_been_updated=has_been_updated, free_data=True)
0: [rank3]:   File "/path/to/repo/.venv/lib/python3.12/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 1543, in _partition
0: [rank3]:     self._partition_param(param, has_been_updated=has_been_updated, free_data=True)
0: [rank3]:   File "/path/to/repo/.venv/lib/python3.12/site-packages/deepspeed/utils/nvtx.py", line 18, in wrapped_fn
0: [rank3]:     ret_val = func(*args, **kwargs)
0: [rank3]:               ^^^^^^^^^^^^^^^^^^^^^
0: [rank3]:   File "/path/to/repo/.venv/lib/python3.12/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 1552, in _partition_param
0: [rank3]:     assert param.ds_status is not ZeroParamStatus.INFLIGHT, f" {param} Cannot partition a param in flight"
0:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
0: AssertionError:  Parameter containing:
0: tensor([[-2.9785e-02,  2.1484e-02,  4.7913e-03,  ...,  4.6875e-02,
0:           1.6357e-02,  3.2959e-02],
0:         [-7.4219e-02, -1.3611e-02,  1.2756e-02,  ...,  5.8899e-03,
0:           2.9144e-03,  7.0496e-03],
0:         [-3.4180e-02, -1.3306e-02,  7.2937e-03,  ...,  4.2343e-04,
0:          -1.8616e-03,  3.4424e-02],
0:         ...,
0:         [ 2.1104e-26, -2.4840e-26,  7.0177e-27,  ..., -2.4032e-26,
0:          -1.1471e-25, -2.4032e-26],
0:         [ 5.0083e-26, -2.2618e-26, -1.3329e-26,  ..., -6.8662e-26,
0:           1.4439e-26,  3.2110e-26],
0:         [ 1.0350e-26, -4.8872e-26,  1.9993e-26,  ...,  1.4136e-26,
0:           2.8879e-26,  7.8255e-28]], device='cuda:0', dtype=torch.bfloat16,
0:        requires_grad=True) Cannot partition a param in flight

@jamesbraza
Copy link
Contributor

After pulling in the PR, I did not get the same error. So I think this PR works

@qgallouedec qgallouedec changed the title [GRPO] Pass wrapped model to unwrap_model_for_generation for DeepSpeed Stage-3 compatibility 🫔 [GRPO] Pass wrapped model to unwrap_model_for_generation for DeepSpeed Stage-3 compatibility Feb 28, 2025
Copy link
Member

@qgallouedec qgallouedec left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@qgallouedec qgallouedec merged commit ad6a35b into huggingface:main Feb 28, 2025
12 of 13 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
😴 stale No update from the author, will be closed soon
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants