Skip to content

Commit

Permalink
🫔 [GRPO] Pass wrapped model to unwrap_model_for_generation for DeepSpeed Stage-3 compatibility (#2871)
Browse files Browse the repository at this point in the history

Co-authored-by: Quentin Gallouédec <[email protected]>
  • Loading branch information
kiddj and qgallouedec authored Feb 28, 2025
1 parent 7bc9858 commit ad6a35b
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions trl/trainer/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,7 +655,7 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep)
@profiling_decorator
def _move_model_to_vllm(self):
with unwrap_model_for_generation(
self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
self.model_wrapped, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
) as unwrapped_model:
if is_compiled_module(unwrapped_model):
unwrapped_model = unwrapped_model._orig_mod
Expand Down Expand Up @@ -754,7 +754,7 @@ def _generate_and_score_completions(
prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
else:
# Regular generation path
with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
with unwrap_model_for_generation(self.model_wrapped, self.accelerator) as unwrapped_model:
prompt_completion_ids = unwrapped_model.generate(
prompt_ids, attention_mask=prompt_mask, generation_config=self.generation_config
)
Expand Down

0 comments on commit ad6a35b

Please sign in to comment.