👐 DeepSpeed integration for GRPO #2652

Merged: 1 commit, Jan 25, 2025
trl/models/__init__.py (4 changes: 2 additions & 2 deletions)

@@ -20,7 +20,7 @@
 _import_structure = {
     "modeling_base": ["GeometricMixtureWrapper", "PreTrainedModelWrapper", "create_reference_model"],
     "modeling_value_head": ["AutoModelForCausalLMWithValueHead", "AutoModelForSeq2SeqLMWithValueHead"],
-    "utils": ["SUPPORTED_ARCHITECTURES", "setup_chat_format", "unwrap_model_for_generation"],
+    "utils": ["SUPPORTED_ARCHITECTURES", "prepare_deepspeed", "setup_chat_format", "unwrap_model_for_generation"],
 }
 
 try:
@@ -39,7 +39,7 @@
 if TYPE_CHECKING:
     from .modeling_base import GeometricMixtureWrapper, PreTrainedModelWrapper, create_reference_model
     from .modeling_value_head import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead
-    from .utils import SUPPORTED_ARCHITECTURES, setup_chat_format, unwrap_model_for_generation
+    from .utils import SUPPORTED_ARCHITECTURES, prepare_deepspeed, setup_chat_format, unwrap_model_for_generation
 
 try:
     if not is_diffusers_available():
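The export is registered in two places on purpose: `_import_structure` feeds the lazy module loader that resolves `trl.models` attributes on first access, while the `TYPE_CHECKING` block keeps static type checkers aware of the symbol without importing it at runtime. A quick way to confirm the new export, assuming a TRL install that includes this PR:

```python
# Hypothetical check, not part of the diff; assumes a TRL version with this PR.
from trl.models import prepare_deepspeed

print(prepare_deepspeed.__module__)  # expected: "trl.models.utils"
```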
trl/models/utils.py (35 changes: 35 additions & 0 deletions)

@@ -14,6 +14,7 @@
 
 import itertools
 from contextlib import contextmanager
+from copy import deepcopy
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Literal, Optional, Union
 
@@ -193,3 +194,37 @@ def unwrap_model_for_generation
             add_hooks(model)
     else:
         yield unwrapped_model
+
+
+def prepare_deepspeed(model, accelerator):
+    # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
+    deepspeed_plugin = accelerator.state.deepspeed_plugin
+    config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
+    stage = config_kwargs["zero_optimization"]["stage"]
+
+    if model is not None:
+        hidden_size = (
+            max(model.config.hidden_sizes)
+            if getattr(model.config, "hidden_sizes", None)
+            else getattr(model.config, "hidden_size", None)
+        )
+        if hidden_size is not None and stage == 3:
+            # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like:
+            # `Invalidate trace cache @ step 0: expected module 1, but got module 0`
+            # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
+            config_kwargs.update(
+                {
+                    "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
+                    "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
+                    "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
+                }
+            )
+
+    # If ZeRO-3 is used, we shard both the active and reference model.
+    # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO
+    # disabled (stage 0)
+    if stage != 3:
+        config_kwargs["zero_optimization"]["stage"] = 0
+    model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
+    model.eval()
+    return model
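`prepare_deepspeed` builds an inference-only DeepSpeed engine for the reference model from the same plugin config used for training: under ZeRO-3 it sizes the reduce and prefetch buckets from the model's hidden size and shards the weights; under stages 1 and 2 it forces stage 0, since a frozen model has no optimizer states or gradients to shard. A minimal standalone usage sketch (the model name is a placeholder; the script is assumed to be launched via `accelerate launch` with a DeepSpeed config):

```python
# Usage sketch, not part of the diff. Assumes launch via `accelerate launch`
# with a DeepSpeed plugin configured; the model name is a placeholder.
from accelerate import Accelerator
from transformers import AutoModelForCausalLM

from trl.models import prepare_deepspeed

accelerator = Accelerator()
ref_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")

if accelerator.state.deepspeed_plugin is not None:
    # Returns a DeepSpeed engine in eval mode: ZeRO-3 shards its weights across
    # ranks; any other stage keeps a full per-device copy at stage 0.
    ref_model = prepare_deepspeed(ref_model, accelerator)
```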
trl/trainer/grpo_trainer.py (14 changes: 11 additions & 3 deletions)

@@ -32,10 +32,11 @@
     TrainerCallback,
     is_wandb_available,
 )
+from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
 from transformers.utils import is_peft_available
 
 from ..data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
-from ..models import create_reference_model, unwrap_model_for_generation
+from ..models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation
 from .grpo_config import GRPOConfig
 from .utils import generate_model_card, get_comet_experiment_url
 
@@ -158,6 +159,7 @@ def __init__(
         # Trained model
         model_init_kwargs = args.model_init_kwargs or {}
         if isinstance(model, str):
+            model_id = model
             torch_dtype = model_init_kwargs.get("torch_dtype")
             if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None:
                 pass  # torch_dtype is already a torch.dtype or "auto" or None
@@ -171,6 +173,7 @@
             )
             model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
         else:
+            model_id = model.config._name_or_path
             if args.model_init_kwargs is not None:
                 raise ValueError(
                     "You passed `model_init_kwargs` to the `GRPOConfig`, but your model is already instantiated. "
@@ -181,7 +184,9 @@
             model = get_peft_model(model, peft_config)
 
         # Reference model
-        if peft_config is None:
+        if is_deepspeed_zero3_enabled():
+            self.ref_model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs)
+        elif peft_config is None:
             # If PEFT configuration is not provided, create a reference model based on the initial model.
             self.ref_model = create_reference_model(model)
         else:
@@ -269,7 +274,10 @@ def data_collator(features): # No data collation is needed in GRPO
         self.model_accepts_loss_kwargs = False
 
         if self.ref_model is not None:
-            self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
+            if self.is_deepspeed_enabled:
+                self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
+            else:
+                self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
 
         for i, reward_func in enumerate(self.reward_funcs):
             if isinstance(reward_func, PreTrainedModel):
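Taken together: under ZeRO-3 the reference model is re-instantiated from `model_id` rather than deep-copied, since `create_reference_model`'s copy is not safe once parameters are sharded, and at prepare time it is wrapped by `prepare_deepspeed` instead of `accelerator.prepare_model`. A minimal end-to-end sketch of the resulting workflow (dataset, model, and reward function are placeholders in the style of the TRL docs, not part of this PR):

```python
# train_grpo.py: minimal GRPO + DeepSpeed sketch (placeholders throughout).
# Run with, e.g.:
#   accelerate launch --config_file deepspeed_zero3.yaml train_grpo.py
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

dataset = load_dataset("trl-lib/tldr", split="train")

def reward_len(completions, **kwargs):
    # Toy reward: prefer completions close to 20 characters.
    return [-abs(20 - len(completion)) for completion in completions]

trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=reward_len,
    args=GRPOConfig(output_dir="Qwen2-0.5B-GRPO"),
    train_dataset=dataset,
)
trainer.train()
```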