diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py
index 3cddc9eb5c..1fdea42328 100644
--- a/trl/trainer/grpo_trainer.py
+++ b/trl/trainer/grpo_trainer.py
@@ -172,6 +172,10 @@ def __init__(
                     "Invalid `torch_dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing "
                     f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
                 )
+            # Disable caching if gradient checkpointing is enabled (not supported)
+            model_init_kwargs["use_cache"] = (
+                False if args.gradient_checkpointing else model_init_kwargs.get("use_cache")
+            )
             model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
         else:
             model_id = model.config._name_or_path