Disable caching when gradient checkpointing is enabled in GRPO (#2653)
* disable caching when grad checkpointing

* style
qgallouedec authored Jan 25, 2025
1 parent 317d2d4 commit 807046b
Showing 1 changed file with 4 additions and 0 deletions.
trl/trainer/grpo_trainer.py (4 additions, 0 deletions)
@@ -172,6 +172,10 @@ def __init__(
                     "Invalid `torch_dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing "
                     f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
                 )
+            # Disable caching if gradient checkpointing is enabled (not supported)
+            model_init_kwargs["use_cache"] = (
+                False if args.gradient_checkpointing else model_init_kwargs.get("use_cache")
+            )
             model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
         else:
             model_id = model.config._name_or_path
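
Context for the change: with `use_cache=True`, the model stores past key/value states to speed up generation, but gradient checkpointing recomputes activations during the backward pass, and the two are incompatible (Transformers otherwise warns and disables the cache at runtime). Below is a minimal standalone sketch of the same pattern; the model id and the `gradient_checkpointing` flag are illustrative stand-ins, not part of the commit, where the values come from `GRPOConfig` via `args.gradient_checkpointing` and `model_init_kwargs`.

    from transformers import AutoModelForCausalLM

    # Hypothetical stand-ins for the trainer config values used in the commit.
    model_init_kwargs = {}
    gradient_checkpointing = True

    # Same pattern as the added lines: force use_cache off only when gradient
    # checkpointing is on; otherwise keep whatever the user passed (here None).
    model_init_kwargs["use_cache"] = (
        False if gradient_checkpointing else model_init_kwargs.get("use_cache")
    )

    model = AutoModelForCausalLM.from_pretrained("gpt2", **model_init_kwargs)
    if gradient_checkpointing:
        model.gradient_checkpointing_enable()  # recompute activations in backward

Setting `use_cache` at load time, rather than relying on the runtime fallback, avoids the per-forward warning and makes the config consistent with how the model is actually trained.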
