Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🫷 Include stop token in policy model's generation_config #2528

Merged
merged 11 commits into from
Jan 22, 2025
18 changes: 12 additions & 6 deletions trl/trainer/ppo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,18 @@ def __init__(
if data_collator is None:
data_collator = DataCollatorWithPadding(self.processing_class)

self.policy_model.generation_config.eos_token_id = (
None # disable `pad_token_id` and `eos_token_id` because we just want to
)
self.policy_model.generation_config.pad_token_id = None # generate tokens without truncation / padding
# Handle stop token settings
if args.stop_token and args.stop_token_id:
raise ValueError("You cannot set both `stop_token` and `stop_token_id`. ")
qgallouedec marked this conversation as resolved.
Show resolved Hide resolved
if args.stop_token:
if args.stop_token == "eos":
args.stop_token_id = processing_class.eos_token_id
else:
raise ValueError(
f"Unknown `stop_token` {args.stop_token}. " f"Allowed values are: `eos`, None (no stop token)"
)
# Update policy model's generation_config to use provided stop token
self.policy_model.generation_config.eos_token_id = args.stop_token_id
qgallouedec marked this conversation as resolved.
Show resolved Hide resolved

# peft support
if not is_peft_available() and peft_config is not None:
Expand Down Expand Up @@ -220,8 +228,6 @@ def __init__(
for module in [self.policy_model, self.ref_model, self.value_model, self.reward_model]:
if module is not None:
disable_dropout_in_model(module)
if args.stop_token and args.stop_token == "eos":
args.stop_token_id = processing_class.eos_token_id
self.model = PolicyAndValueWrapper(self.policy_model, self.value_model)
self.model.config = self.policy_model.config # needed for pushing to hub
self.create_optimizer_and_scheduler(
Expand Down