From 15ec8f485f8e6e7ab1def5cbaed41fa770943ab0 Mon Sep 17 00:00:00 2001 From: Dawid Motyka Date: Sat, 28 Dec 2024 14:31:12 +0100 Subject: [PATCH 01/10] Include stop token in policy model's generation_config --- trl/trainer/ppo_trainer.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/trl/trainer/ppo_trainer.py b/trl/trainer/ppo_trainer.py index 51897eeb44..804f2b427b 100644 --- a/trl/trainer/ppo_trainer.py +++ b/trl/trainer/ppo_trainer.py @@ -138,10 +138,19 @@ def __init__( if data_collator is None: data_collator = DataCollatorWithPadding(self.processing_class) - self.policy_model.generation_config.eos_token_id = ( - None # disable `pad_token_id` and `eos_token_id` because we just want to - ) - self.policy_model.generation_config.pad_token_id = None # generate tokens without truncation / padding + # Handle stop token settings + if args.stop_token and args.stop_token_id: + raise ValueError("You cannot set both `stop_token` and `stop_token_id`. ") + if args.stop_token: + if args.stop_token == "eos": + args.stop_token_id = processing_class.eos_token_id + else: + raise ValueError( + f"Unknown `stop_token` {args.stop_token}. " + f"Allowed values are: `eos`, None (no stop token)" + ) + # Update policy model's generation_config to use provided stop token + self.policy_model.generation_config.eos_token_id = args.stop_token_id # peft support if not is_peft_available() and peft_config is not None: @@ -220,8 +229,6 @@ def __init__( for module in [self.policy_model, self.ref_model, self.value_model, self.reward_model]: if module is not None: disable_dropout_in_model(module) - if args.stop_token and args.stop_token == "eos": - args.stop_token_id = processing_class.eos_token_id self.model = PolicyAndValueWrapper(self.policy_model, self.value_model) self.model.config = self.policy_model.config # needed for pushing to hub self.create_optimizer_and_scheduler( From 1c4e9edfce5ecce9609dd023f7cb8cc4ffd4ca23 Mon Sep 17 00:00:00 2001 From: Dawid Motyka Date: Sat, 28 Dec 2024 15:21:31 +0100 Subject: [PATCH 02/10] Fix formatting --- trl/trainer/ppo_trainer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/trl/trainer/ppo_trainer.py b/trl/trainer/ppo_trainer.py index 804f2b427b..723c3d8132 100644 --- a/trl/trainer/ppo_trainer.py +++ b/trl/trainer/ppo_trainer.py @@ -146,8 +146,7 @@ def __init__( args.stop_token_id = processing_class.eos_token_id else: raise ValueError( - f"Unknown `stop_token` {args.stop_token}. " - f"Allowed values are: `eos`, None (no stop token)" + f"Unknown `stop_token` {args.stop_token}. " f"Allowed values are: `eos`, None (no stop token)" ) # Update policy model's generation_config to use provided stop token self.policy_model.generation_config.eos_token_id = args.stop_token_id From 6f4f841f79d45bef953b42b86d52e6e008801c6c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Wed, 22 Jan 2025 09:04:55 +0100 Subject: [PATCH 03/10] Update trl/trainer/ppo_trainer.py --- trl/trainer/ppo_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/trainer/ppo_trainer.py b/trl/trainer/ppo_trainer.py index 723c3d8132..dec59c4489 100644 --- a/trl/trainer/ppo_trainer.py +++ b/trl/trainer/ppo_trainer.py @@ -140,7 +140,7 @@ def __init__( # Handle stop token settings if args.stop_token and args.stop_token_id: - raise ValueError("You cannot set both `stop_token` and `stop_token_id`. ") + raise ValueError("You cannot set both `stop_token` and `stop_token_id`.") if args.stop_token: if args.stop_token == "eos": args.stop_token_id = processing_class.eos_token_id From 5cd901b5f47925c34fd944e0e8099a5dd778afd2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Wed, 22 Jan 2025 09:05:03 +0100 Subject: [PATCH 04/10] Update trl/trainer/ppo_trainer.py --- trl/trainer/ppo_trainer.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/trl/trainer/ppo_trainer.py b/trl/trainer/ppo_trainer.py index dec59c4489..19eefdfd6b 100644 --- a/trl/trainer/ppo_trainer.py +++ b/trl/trainer/ppo_trainer.py @@ -141,13 +141,15 @@ def __init__( # Handle stop token settings if args.stop_token and args.stop_token_id: raise ValueError("You cannot set both `stop_token` and `stop_token_id`.") - if args.stop_token: + elif args.stop_token: if args.stop_token == "eos": - args.stop_token_id = processing_class.eos_token_id + stop_token_id = processing_class.eos_token_id else: raise ValueError( - f"Unknown `stop_token` {args.stop_token}. " f"Allowed values are: `eos`, None (no stop token)" + f"Unknown `stop_token` {args.stop_token}. " f"Allowed values are: `eos` and `None` (no stop token)" ) + elif args.stop_token_id: + stop_token_id = processing_class.eos_token_id # Update policy model's generation_config to use provided stop token self.policy_model.generation_config.eos_token_id = args.stop_token_id From d2c62aec7646853948ec4e839db5db40feb49f51 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Wed, 22 Jan 2025 08:16:06 +0000 Subject: [PATCH 05/10] don't modify args --- trl/trainer/ppo_trainer.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/trl/trainer/ppo_trainer.py b/trl/trainer/ppo_trainer.py index 4565ab3c58..4ebcda3190 100644 --- a/trl/trainer/ppo_trainer.py +++ b/trl/trainer/ppo_trainer.py @@ -138,20 +138,18 @@ def __init__( if data_collator is None: data_collator = DataCollatorWithPadding(self.processing_class) - # Handle stop token settings + # Handle stop token settings: update policy model's generation_config to use provided stop token if args.stop_token and args.stop_token_id: raise ValueError("You cannot set both `stop_token` and `stop_token_id`.") elif args.stop_token: if args.stop_token == "eos": - stop_token_id = processing_class.eos_token_id + self.policy_model.generation_config.eos_token_id = processing_class.eos_token_id else: raise ValueError( - f"Unknown `stop_token` {args.stop_token}. " f"Allowed values are: `eos` and `None` (no stop token)" + f"Unknown `stop_token` {args.stop_token}. Allowed values are: `'eos'` and `None` (no stop token)." ) - elif args.stop_token_id: - stop_token_id = processing_class.eos_token_id - # Update policy model's generation_config to use provided stop token - self.policy_model.generation_config.eos_token_id = args.stop_token_id + else: + self.policy_model.generation_config.eos_token_id = args.stop_token_id # either None or an integer # peft support if not is_peft_available() and peft_config is not None: From 91948b150593945a0a697a1216adf6d5cdfca241 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Wed, 22 Jan 2025 08:20:11 +0000 Subject: [PATCH 06/10] clarify doc --- trl/trainer/utils.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index 1228dc7ece..3642ec6061 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -992,8 +992,12 @@ class OnPolicyConfig(TrainingArguments): Number of debugging samples generations (i.e., `generate_completions` calls) throughout training. response_length (`int`, *optional*, defaults to `53`): Length of the response. - stop_token (`str` or `None`, *optional*, defaults to `None`): - Stop token. + stop_token (`str` or `None`, *optional*, defaults to `None`): + Specifies the token at which truncation should stop: + + - `None`: No truncation is applied. + - `"eos"`: Uses the tokenizer's `eos_token` as the stop token. + stop_token_id (`int` or `None`, *optional*, defaults to `None`): Truncation token id. temperature (`float`, *optional*, defaults to `0.7`): From 7dd8c4c5888f7e07c044297781e5b5fe883e6182 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Wed, 22 Jan 2025 08:28:15 +0000 Subject: [PATCH 07/10] more nice doc --- trl/trainer/utils.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index 3642ec6061..9bdca917ee 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -992,14 +992,16 @@ class OnPolicyConfig(TrainingArguments): Number of debugging samples generations (i.e., `generate_completions` calls) throughout training. response_length (`int`, *optional*, defaults to `53`): Length of the response. - stop_token (`str` or `None`, *optional*, defaults to `None`): - Specifies the token at which truncation should stop: - - - `None`: No truncation is applied. - - `"eos"`: Uses the tokenizer's `eos_token` as the stop token. + stop_token (`str` or `None`, *optional*, defaults to `None`): + Specifies the stop token to use for text generation. This parameter is mutually exclusive with + `stop_token_id`. + + - `None`: No stop token is applied, unless `stop_token_id` is specified. + - `'eos'`: Uses the tokenizer's `eos_token`. stop_token_id (`int` or `None`, *optional*, defaults to `None`): - Truncation token id. + Specifies the ID of the stop token to use for text generation. If `None`, stop token ID is applied, unless + `stop_token` is specified. This parameter is mutually exclusive with `stop_token`. temperature (`float`, *optional*, defaults to `0.7`): Sampling temperature. missing_eos_penalty (`float` or `None`, *optional*, defaults to `None`): @@ -1058,11 +1060,17 @@ class OnPolicyConfig(TrainingArguments): ) stop_token: Optional[Literal["eos"]] = field( default=None, - metadata={"help": "Stop token."}, + metadata={ + "help": "Specifies the stop token to use for text generation. This parameter is mutually exclusive with " + "`stop_token_id`." + }, ) stop_token_id: Optional[int] = field( default=None, - metadata={"help": "Truncation token id."}, + metadata={ + "help": "Specifies the ID of the stop token to use for text generation. If `None`, stop token ID is " + "applied, unless `stop_token` is specified. This parameter is mutually exclusive with `stop_token`." + }, ) temperature: float = field( default=0.7, From 7b8e66f69e7d11f87270372a1e79be43d29dd035 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Wed, 22 Jan 2025 08:29:24 +0000 Subject: [PATCH 08/10] missing no [ci skip] --- trl/trainer/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index 9bdca917ee..719d952f1f 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -1000,8 +1000,8 @@ class OnPolicyConfig(TrainingArguments): - `'eos'`: Uses the tokenizer's `eos_token`. stop_token_id (`int` or `None`, *optional*, defaults to `None`): - Specifies the ID of the stop token to use for text generation. If `None`, stop token ID is applied, unless - `stop_token` is specified. This parameter is mutually exclusive with `stop_token`. + Specifies the ID of the stop token to use for text generation. If `None`, no stop token ID is applied, + unless `stop_token` is specified. This parameter is mutually exclusive with `stop_token`. temperature (`float`, *optional*, defaults to `0.7`): Sampling temperature. missing_eos_penalty (`float` or `None`, *optional*, defaults to `None`): @@ -1068,7 +1068,7 @@ class OnPolicyConfig(TrainingArguments): stop_token_id: Optional[int] = field( default=None, metadata={ - "help": "Specifies the ID of the stop token to use for text generation. If `None`, stop token ID is " + "help": "Specifies the ID of the stop token to use for text generation. If `None`, no stop token ID is " "applied, unless `stop_token` is specified. This parameter is mutually exclusive with `stop_token`." }, ) From 57cc809c15c87a56024e765bf397b9b2760f9958 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Wed, 22 Jan 2025 08:33:57 +0000 Subject: [PATCH 09/10] really don't modify args --- trl/trainer/ppo_trainer.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/trl/trainer/ppo_trainer.py b/trl/trainer/ppo_trainer.py index 4ebcda3190..8facddaf15 100644 --- a/trl/trainer/ppo_trainer.py +++ b/trl/trainer/ppo_trainer.py @@ -143,13 +143,13 @@ def __init__( raise ValueError("You cannot set both `stop_token` and `stop_token_id`.") elif args.stop_token: if args.stop_token == "eos": - self.policy_model.generation_config.eos_token_id = processing_class.eos_token_id + self.policy_model.generation_config.eos_token_id = self.eos_token_id = processing_class.eos_token_id else: raise ValueError( f"Unknown `stop_token` {args.stop_token}. Allowed values are: `'eos'` and `None` (no stop token)." ) else: - self.policy_model.generation_config.eos_token_id = args.stop_token_id # either None or an integer + self.policy_model.generation_config.eos_token_id = self.eos_token_id = args.stop_token_id # None or int # peft support if not is_peft_available() and peft_config is not None: @@ -455,9 +455,9 @@ def repeat_generator(): # Response Processing 1. truncate response after the first occurrence of `stop_token_id` postprocessed_response = response - if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0 + if self.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0 postprocessed_response = truncate_response( - args.stop_token_id, processing_class.pad_token_id, response + self.stop_token_id, processing_class.pad_token_id, response ) # Response Processing 2. run reward model on the truncated responses @@ -712,9 +712,9 @@ def generate_completions(self, sampling: bool = False): ) response = query_response[:, context_length:] postprocessed_response = response - if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0 + if self.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0 postprocessed_response = truncate_response( - args.stop_token_id, processing_class.pad_token_id, response + self.stop_token_id, processing_class.pad_token_id, response ) table["query"].extend( gather_object(processing_class.batch_decode(query, skip_special_tokens=True)) From 6d78514a70dbaba87524a4a3f526c3dc9d785ad0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Wed, 22 Jan 2025 09:49:05 +0000 Subject: [PATCH 10/10] oups --- trl/trainer/ppo_trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/trl/trainer/ppo_trainer.py b/trl/trainer/ppo_trainer.py index 8facddaf15..83926cfd6a 100644 --- a/trl/trainer/ppo_trainer.py +++ b/trl/trainer/ppo_trainer.py @@ -143,13 +143,13 @@ def __init__( raise ValueError("You cannot set both `stop_token` and `stop_token_id`.") elif args.stop_token: if args.stop_token == "eos": - self.policy_model.generation_config.eos_token_id = self.eos_token_id = processing_class.eos_token_id + self.policy_model.generation_config.eos_token_id = self.stop_token_id = processing_class.eos_token_id else: raise ValueError( f"Unknown `stop_token` {args.stop_token}. Allowed values are: `'eos'` and `None` (no stop token)." ) else: - self.policy_model.generation_config.eos_token_id = self.eos_token_id = args.stop_token_id # None or int + self.policy_model.generation_config.eos_token_id = self.stop_token_id = args.stop_token_id # None or int # peft support if not is_peft_available() and peft_config is not None: