diff --git a/trl/scripts/grpo.py b/trl/scripts/grpo.py index 552f6c3a4c..4b336b28e9 100644 --- a/trl/scripts/grpo.py +++ b/trl/scripts/grpo.py @@ -60,7 +60,7 @@ def main(script_args, training_args, model_args): # Initialize the GRPO trainer trainer = GRPOTrainer( model=model, - reward_model=reward_model, + reward_funcs=reward_model, args=training_args, train_dataset=dataset[script_args.dataset_train_split], eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,