Skip to content

Commit

Permalink
🚧 Add Optional ZeRO-3 Weight Gathering for GRPO in Sequence Generation (#2667)
Browse files Browse the repository at this point in the history

* Add (grpo) unwrap_model_generation zero3 gathering

* proper placement

* Disabling this option is not compatible with vLLM generation.

---------

Co-authored-by: Quentin Gallouédec <[email protected]>
Co-authored-by: Quentin Gallouédec <[email protected]>
  • Loading branch information
3 people authored Feb 4, 2025
1 parent b2ae999 commit af4ad47
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 1 deletion.
14 changes: 14 additions & 0 deletions trl/trainer/grpo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,11 @@ class GRPOConfig(TrainingArguments):
Temperature for sampling. The higher the temperature, the more random the completions.
max_completion_length (`int` or `None`, *optional*, defaults to `256`):
Maximum length of the generated completion.
ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation,
improving generation speed. However, disabling this option allows training models that exceed the VRAM
capacity of a single GPU, albeit at the cost of slower generation. Disabling this option is not compatible
with vLLM generation.
> Parameters that control generation acceleration powered by vLLM
Expand Down Expand Up @@ -133,6 +138,15 @@ class GRPOConfig(TrainingArguments):
default=256,
metadata={"help": "Maximum length of the generated completion."},
)
# DeepSpeed ZeRO-3 switch: gather the policy weights for generation (on by default).
# The help text below is laid out one sentence per literal.
ds3_gather_for_generation: bool = field(
    default=True,
    metadata={
        "help": (
            "This setting applies to DeepSpeed ZeRO-3. "
            "If enabled, the policy model weights are gathered for generation, improving generation speed. "
            "However, disabling this option allows training models that exceed the VRAM capacity of a single "
            "GPU, albeit at the cost of slower generation. "
            "Disabling this option is not compatible with vLLM generation."
        ),
    },
)

# Parameters that control generation acceleration powered by vLLM
use_vllm: Optional[bool] = field(
Expand Down
4 changes: 3 additions & 1 deletion trl/trainer/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,9 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s
if self.args.use_vllm:
# First, have main process load weights if needed
if self.state.global_step != self._last_loaded_step:
with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
with unwrap_model_for_generation(
self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
) as unwrapped_model:
if is_compiled_module(unwrapped_model):
state_dict = unwrapped_model._orig_mod.state_dict()
else:
Expand Down

0 comments on commit af4ad47

Please sign in to comment.