Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🫷 Include stop token in policy model's generation_config #2528

Merged
11 commits merged on Jan 22, 2025
26 changes: 16 additions & 10 deletions trl/trainer/ppo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,18 @@ def __init__(
if data_collator is None:
data_collator = DataCollatorWithPadding(self.processing_class)

self.policy_model.generation_config.eos_token_id = (
None # disable `pad_token_id` and `eos_token_id` because we just want to
)
self.policy_model.generation_config.pad_token_id = None # generate tokens without truncation / padding
# Handle stop token settings: update policy model's generation_config to use provided stop token
if args.stop_token and args.stop_token_id:
raise ValueError("You cannot set both `stop_token` and `stop_token_id`.")
elif args.stop_token:
if args.stop_token == "eos":
self.policy_model.generation_config.eos_token_id = self.stop_token_id = processing_class.eos_token_id
else:
raise ValueError(
f"Unknown `stop_token` {args.stop_token}. Allowed values are: `'eos'` and `None` (no stop token)."
)
else:
self.policy_model.generation_config.eos_token_id = self.stop_token_id = args.stop_token_id # None or int

# peft support
if not is_peft_available() and peft_config is not None:
Expand Down Expand Up @@ -220,8 +228,6 @@ def __init__(
for module in [self.policy_model, self.ref_model, self.value_model, self.reward_model]:
if module is not None:
disable_dropout_in_model(module)
if args.stop_token and args.stop_token == "eos":
args.stop_token_id = processing_class.eos_token_id
self.model = PolicyAndValueWrapper(self.policy_model, self.value_model)
self.model.config = self.policy_model.config # needed for pushing to hub
self.create_optimizer_and_scheduler(
Expand Down Expand Up @@ -449,9 +455,9 @@ def repeat_generator():

# Response Processing 1. truncate response after the first occurrence of `stop_token_id`
postprocessed_response = response
if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0
if self.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0
postprocessed_response = truncate_response(
args.stop_token_id, processing_class.pad_token_id, response
self.stop_token_id, processing_class.pad_token_id, response
)

# Response Processing 2. run reward model on the truncated responses
Expand Down Expand Up @@ -706,9 +712,9 @@ def generate_completions(self, sampling: bool = False):
)
response = query_response[:, context_length:]
postprocessed_response = response
if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0
if self.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0
postprocessed_response = truncate_response(
args.stop_token_id, processing_class.pad_token_id, response
self.stop_token_id, processing_class.pad_token_id, response
)
table["query"].extend(
gather_object(processing_class.batch_decode(query, skip_special_tokens=True))
Expand Down
20 changes: 16 additions & 4 deletions trl/trainer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -993,9 +993,15 @@ class OnPolicyConfig(TrainingArguments):
response_length (`int`, *optional*, defaults to `53`):
Length of the response.
stop_token (`str` or `None`, *optional*, defaults to `None`):
Stop token.
Specifies the stop token to use for text generation. This parameter is mutually exclusive with
`stop_token_id`.

- `None`: No stop token is applied, unless `stop_token_id` is specified.
- `'eos'`: Uses the tokenizer's `eos_token`.

stop_token_id (`int` or `None`, *optional*, defaults to `None`):
Truncation token id.
Specifies the ID of the stop token to use for text generation. If `None`, no stop token ID is applied,
unless `stop_token` is specified. This parameter is mutually exclusive with `stop_token`.
temperature (`float`, *optional*, defaults to `0.7`):
Sampling temperature.
missing_eos_penalty (`float` or `None`, *optional*, defaults to `None`):
Expand Down Expand Up @@ -1054,11 +1060,17 @@ class OnPolicyConfig(TrainingArguments):
)
stop_token: Optional[Literal["eos"]] = field(
default=None,
metadata={"help": "Stop token."},
metadata={
"help": "Specifies the stop token to use for text generation. This parameter is mutually exclusive with "
"`stop_token_id`."
},
)
stop_token_id: Optional[int] = field(
default=None,
metadata={"help": "Truncation token id."},
metadata={
"help": "Specifies the ID of the stop token to use for text generation. If `None`, no stop token ID is "
"applied, unless `stop_token` is specified. This parameter is mutually exclusive with `stop_token`."
},
)
temperature: float = field(
default=0.7,
Expand Down
Loading