🫔 [GRPO] Pass wrapped model to unwrap_model_for_generation for DeepSpeed Stage-3 compatibility #2871
Conversation
I have already submitted a similar PR.
Thanks for your contribution @kiddj! Can you provide code that fails with the main branch but not with yours? I can't get one. Currently, the following works for me:

```python
# 2871.py
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

training_args = GRPOConfig(
    output_dir="2871",
    per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
    num_generations=3,  # reduce the number of generations to reduce memory usage
    max_completion_length=32,  # reduce the completion length to reduce memory usage
    max_prompt_length=32,
    bf16=True,
    report_to="none",
)

def dummy_reward_func(completions, **kwargs):
    return [0.0] * len(completions)

trainer = GRPOTrainer(
    model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
    reward_funcs=dummy_reward_func,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
```

```sh
# 2 GPUs
accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml --num_processes 2 sandbox/2871.py
```
@kiddj what was the crash you hit? Can you share a stack trace? I just hit the error below; I think it might be what your PR aims to resolve:
After pulling in the PR, I did not get the same error, so I think this PR works.
Thanks!
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
What does this PR do?
Fixes DeepSpeed Stage-3 compatibility by passing the wrapped model (`self.model_wrapped`) to `unwrap_model_for_generation` instead of `self.model`.

Previously, `unwrap_model_for_generation(model)` was called with the model that was passed into functions such as `compute_loss`, which avoided any DeepSpeed Stage-3 conflicts. In the current implementation, the model is obtained directly from `self`, causing a crash under Stage-3. This fix ensures consistency by always passing the wrapped model instead of `self.model`, preventing the Stage-3 compatibility issue.
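For readers less familiar with the generation path, the sketch below illustrates the pattern the fix relies on. It is not the trainer's actual code: the import path, model name, and surrounding calls are assumptions for illustration. The point is that `unwrap_model_for_generation` should receive the accelerator-prepared (wrapped) model, which under ZeRO Stage-3 is the DeepSpeed engine that can gather the sharded parameters before `.generate()` is called; inside `GRPOTrainer`, that object is `self.model_wrapped` rather than `self.model`.

```python
# Minimal, self-contained sketch (not GRPOTrainer's code) of passing the
# accelerator-prepared ("wrapped") model to unwrap_model_for_generation.
# The import path and model name are assumptions and may vary by TRL version.
from accelerate import Accelerator
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl.models import unwrap_model_for_generation

model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
accelerator = Accelerator()
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# Under DeepSpeed ZeRO Stage-3, prepare() returns the engine that knows how to
# gather the sharded parameters; this is the analogue of the trainer's
# self.model_wrapped, which the PR now passes instead of self.model.
wrapped_model = accelerator.prepare(model)

prompts = tokenizer(["Hello"], return_tensors="pt").to(accelerator.device)
with unwrap_model_for_generation(wrapped_model, accelerator) as unwrapped_model:
    completion_ids = unwrapped_model.generate(**prompts, max_new_tokens=8)
print(tokenizer.decode(completion_ids[0], skip_special_tokens=True))
```

With a plain (non-DeepSpeed) Accelerator config the snippet runs as-is; under a ZeRO-3 launch, preparing a model on its own may additionally require batch-size entries in the DeepSpeed config.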
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@qgallouedec