From 3b4c24946b7d5580fd354b0e3800fc1047b82a41 Mon Sep 17 00:00:00 2001
From: Xiao Yu <39458711+jasonyux@users.noreply.github.com>
Date: Fri, 3 May 2024 18:19:35 -0400
Subject: [PATCH] fixed adding bos and eos token unconditionally (#1591)

* fixed adding bos and eos token unconditionally

* fixed typo of tokenizer -> self.tokenizer. Also added update to ORPO

* fixed code quality, and added BOS/EOS fix to KTO

* code reformatting with pre-commit run --all-files

* bug fix: check input id length before checking for EOS/BOS
---
 trl/trainer/cpo_trainer.py  | 35 ++++++++++++++++------------
 trl/trainer/dpo_trainer.py  | 35 ++++++++++++++++------------
 trl/trainer/kto_trainer.py  | 46 ++++++++++++++++++++++++++++---------
 trl/trainer/orpo_trainer.py | 35 ++++++++++++++++------------
 4 files changed, 95 insertions(+), 56 deletions(-)

diff --git a/trl/trainer/cpo_trainer.py b/trl/trainer/cpo_trainer.py
index 04e2cf2c9b..b5004a2532 100644
--- a/trl/trainer/cpo_trainer.py
+++ b/trl/trainer/cpo_trainer.py
@@ -409,21 +409,26 @@ def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module
                     "last token due to tokenizer merge ops."
                 )
 
-            # add BOS token to head of prompt
-            prompt_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + prompt_tokens["prompt_input_ids"]
-            chosen_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + chosen_tokens["prompt_input_ids"]
-            rejected_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + rejected_tokens["prompt_input_ids"]
-
-            prompt_tokens["prompt_attention_mask"] = [1] + prompt_tokens["prompt_attention_mask"]
-            chosen_tokens["prompt_attention_mask"] = [1] + chosen_tokens["prompt_attention_mask"]
-            rejected_tokens["prompt_attention_mask"] = [1] + rejected_tokens["prompt_attention_mask"]
-
-            # add EOS token to end of answer
-            chosen_tokens["input_ids"].append(self.tokenizer.eos_token_id)
-            chosen_tokens["attention_mask"].append(1)
-
-            rejected_tokens["input_ids"].append(self.tokenizer.eos_token_id)
-            rejected_tokens["attention_mask"].append(1)
+            # add BOS token to head of prompt. Avoid adding if it's already there
+            bos_token_id = self.tokenizer.bos_token_id
+            if prompt_len_input_ids == 0 or bos_token_id != prompt_tokens["prompt_input_ids"][0]:
+                prompt_tokens["prompt_input_ids"] = [bos_token_id] + prompt_tokens["prompt_input_ids"]
+                prompt_tokens["prompt_attention_mask"] = [1] + prompt_tokens["prompt_attention_mask"]
+            if chosen_prompt_len_input_ids == 0 or bos_token_id != chosen_tokens["prompt_input_ids"][0]:
+                chosen_tokens["prompt_input_ids"] = [bos_token_id] + chosen_tokens["prompt_input_ids"]
+                chosen_tokens["prompt_attention_mask"] = [1] + chosen_tokens["prompt_attention_mask"]
+            if rejected_prompt_len_input_ids == 0 or bos_token_id != rejected_tokens["prompt_input_ids"][0]:
+                rejected_tokens["prompt_input_ids"] = [bos_token_id] + rejected_tokens["prompt_input_ids"]
+                rejected_tokens["prompt_attention_mask"] = [1] + rejected_tokens["prompt_attention_mask"]
+
+            # add EOS token to end of answer. Avoid adding if it's already there
+            eos_token_id = self.tokenizer.eos_token_id
+            if len(chosen_tokens["input_ids"]) == 0 or eos_token_id != chosen_tokens["input_ids"][-1]:
+                chosen_tokens["input_ids"].append(eos_token_id)
+                chosen_tokens["attention_mask"].append(1)
+            if len(rejected_tokens["input_ids"]) == 0 or eos_token_id != rejected_tokens["input_ids"][-1]:
+                rejected_tokens["input_ids"].append(eos_token_id)
+                rejected_tokens["attention_mask"].append(1)
 
             longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))
 
diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py
index 5db34a6e3f..a990abb4c5 100644
--- a/trl/trainer/dpo_trainer.py
+++ b/trl/trainer/dpo_trainer.py
@@ -771,21 +771,26 @@ def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module
                     "last token due to tokenizer merge ops."
                 )
 
-            # add BOS token to head of prompt
-            prompt_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + prompt_tokens["prompt_input_ids"]
-            chosen_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + chosen_tokens["prompt_input_ids"]
-            rejected_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + rejected_tokens["prompt_input_ids"]
-
-            prompt_tokens["prompt_attention_mask"] = [1] + prompt_tokens["prompt_attention_mask"]
-            chosen_tokens["prompt_attention_mask"] = [1] + chosen_tokens["prompt_attention_mask"]
-            rejected_tokens["prompt_attention_mask"] = [1] + rejected_tokens["prompt_attention_mask"]
-
-            # add EOS token to end of answer
-            chosen_tokens["input_ids"].append(self.tokenizer.eos_token_id)
-            chosen_tokens["attention_mask"].append(1)
-
-            rejected_tokens["input_ids"].append(self.tokenizer.eos_token_id)
-            rejected_tokens["attention_mask"].append(1)
+            # add BOS token to head of prompt. Avoid adding if it's already there
+            bos_token_id = self.tokenizer.bos_token_id
+            if prompt_len_input_ids == 0 or bos_token_id != prompt_tokens["prompt_input_ids"][0]:
+                prompt_tokens["prompt_input_ids"] = [bos_token_id] + prompt_tokens["prompt_input_ids"]
+                prompt_tokens["prompt_attention_mask"] = [1] + prompt_tokens["prompt_attention_mask"]
+            if chosen_prompt_len_input_ids == 0 or bos_token_id != chosen_tokens["prompt_input_ids"][0]:
+                chosen_tokens["prompt_input_ids"] = [bos_token_id] + chosen_tokens["prompt_input_ids"]
+                chosen_tokens["prompt_attention_mask"] = [1] + chosen_tokens["prompt_attention_mask"]
+            if rejected_prompt_len_input_ids == 0 or bos_token_id != rejected_tokens["prompt_input_ids"][0]:
+                rejected_tokens["prompt_input_ids"] = [bos_token_id] + rejected_tokens["prompt_input_ids"]
+                rejected_tokens["prompt_attention_mask"] = [1] + rejected_tokens["prompt_attention_mask"]
+
+            # add EOS token to end of answer. Avoid adding if it's already there
+            eos_token_id = self.tokenizer.eos_token_id
+            if len(chosen_tokens["input_ids"]) == 0 or eos_token_id != chosen_tokens["input_ids"][-1]:
+                chosen_tokens["input_ids"].append(eos_token_id)
+                chosen_tokens["attention_mask"].append(1)
+            if len(rejected_tokens["input_ids"]) == 0 or eos_token_id != rejected_tokens["input_ids"][-1]:
+                rejected_tokens["input_ids"].append(eos_token_id)
+                rejected_tokens["attention_mask"].append(1)
 
             longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))
 
diff --git a/trl/trainer/kto_trainer.py b/trl/trainer/kto_trainer.py
index e66d4a24d0..a22e74909f 100644
--- a/trl/trainer/kto_trainer.py
+++ b/trl/trainer/kto_trainer.py
@@ -186,7 +186,15 @@ def _process_tokens(example: Dict[str, Any], model: "PreTrainedModel" = None, **
             "answer_attention_mask": example["answer_attention_mask"],
         }
 
-        max_length = kwargs["max_length"] - 2
+        # calculate max length by checking if BOS/EOS is already there
+        max_length = kwargs["max_length"]
+        bos_token_id = kwargs["tokenizer"].bos_token_id
+        eos_token_id = kwargs["tokenizer"].eos_token_id
+        if bos_token_id != all_tokens["prompt_input_ids"][0]:
+            max_length -= 1
+        if eos_token_id != all_tokens["answer_input_ids"][-1]:
+            max_length -= 1
+
         # if combined sequence is too long (> max_length - 1 for BOS token - 1 for EOS), truncate the prompt
         if len(all_tokens["prompt_input_ids"]) + len(all_tokens["answer_input_ids"]) > max_length:
             for k in ["prompt_input_ids", "prompt_attention_mask"]:
@@ -202,21 +210,37 @@ def _process_tokens(example: Dict[str, Any], model: "PreTrainedModel" = None, **
             for k in ["answer_input_ids", "answer_attention_mask"]:
                 all_tokens[k] = all_tokens[k][: max_length - kwargs["max_prompt_length"]]
 
-        # for legacy reasons, use the completion_* prefix to now refer to the joint sequence
-        batch[f"{kwargs['prefix']}prompt_input_ids"] = [kwargs["tokenizer"].bos_token_id] + all_tokens[
-            "prompt_input_ids"
-        ]
-        batch[f"{kwargs['prefix']}prompt_attention_mask"] = [1] + all_tokens["prompt_attention_mask"]
+        # all input_ids and attention mask as is. We then check if we need to add BOS/EOS tokens
+        batch[f"{kwargs['prefix']}prompt_input_ids"] = all_tokens["prompt_input_ids"]
+        batch[f"{kwargs['prefix']}prompt_attention_mask"] = all_tokens["prompt_attention_mask"]
         batch[f"{kwargs['prefix']}completion_input_ids"] = (
-            [kwargs["tokenizer"].bos_token_id]
-            + all_tokens["prompt_input_ids"]
-            + all_tokens["answer_input_ids"]
-            + [kwargs["tokenizer"].eos_token_id]
+            all_tokens["prompt_input_ids"] + all_tokens["answer_input_ids"]
         )
         batch[f"{kwargs['prefix']}completion_attention_mask"] = (
-            [1] + all_tokens["prompt_attention_mask"] + all_tokens["answer_attention_mask"]
+            all_tokens["prompt_attention_mask"] + all_tokens["answer_attention_mask"]
         )
 
+        # add BOS, which affects both prompt and the full completion
+        if len(all_tokens["prompt_input_ids"]) == 0 or bos_token_id != all_tokens["prompt_input_ids"][0]:
+            batch[f"{kwargs['prefix']}prompt_input_ids"] = [bos_token_id] + batch[
+                f"{kwargs['prefix']}prompt_input_ids"
+            ]
+            batch[f"{kwargs['prefix']}prompt_attention_mask"] = [1] + batch[f"{kwargs['prefix']}prompt_attention_mask"]
+            batch[f"{kwargs['prefix']}completion_input_ids"] = [bos_token_id] + batch[
+                f"{kwargs['prefix']}completion_input_ids"
+            ]
+            batch[f"{kwargs['prefix']}completion_attention_mask"] = [1] + batch[
+                f"{kwargs['prefix']}completion_attention_mask"
+            ]
+        # add EOS, which affects only the full completion
+        if len(all_tokens["answer_input_ids"]) == 0 or eos_token_id != all_tokens["answer_input_ids"][-1]:
+            batch[f"{kwargs['prefix']}completion_input_ids"] = batch[f"{kwargs['prefix']}completion_input_ids"] + [
+                eos_token_id
+            ]
+            batch[f"{kwargs['prefix']}completion_attention_mask"] = batch[
+                f"{kwargs['prefix']}completion_attention_mask"
+            ] + [1]
+
         batch[f"{kwargs['prefix']}completion_labels"] = batch[f"{kwargs['prefix']}completion_input_ids"][:]
         batch[f"{kwargs['prefix']}completion_labels"][: len(batch[f"{kwargs['prefix']}prompt_input_ids"])] = [
             kwargs["label_pad_token_id"]
diff --git a/trl/trainer/orpo_trainer.py b/trl/trainer/orpo_trainer.py
index c1d8f0dca7..84a800ba4c 100644
--- a/trl/trainer/orpo_trainer.py
+++ b/trl/trainer/orpo_trainer.py
@@ -439,21 +439,26 @@ def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module
                     "last token due to tokenizer merge ops."
                 )
 
-            # add BOS token to head of prompt
-            prompt_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + prompt_tokens["prompt_input_ids"]
-            chosen_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + chosen_tokens["prompt_input_ids"]
-            rejected_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + rejected_tokens["prompt_input_ids"]
-
-            prompt_tokens["prompt_attention_mask"] = [1] + prompt_tokens["prompt_attention_mask"]
-            chosen_tokens["prompt_attention_mask"] = [1] + chosen_tokens["prompt_attention_mask"]
-            rejected_tokens["prompt_attention_mask"] = [1] + rejected_tokens["prompt_attention_mask"]
-
-            # add EOS token to end of answer
-            chosen_tokens["input_ids"].append(self.tokenizer.eos_token_id)
-            chosen_tokens["attention_mask"].append(1)
-
-            rejected_tokens["input_ids"].append(self.tokenizer.eos_token_id)
-            rejected_tokens["attention_mask"].append(1)
+            # add BOS token to head of prompt. Avoid adding if it's already there
+            bos_token_id = self.tokenizer.bos_token_id
+            if prompt_len_input_ids == 0 or bos_token_id != prompt_tokens["prompt_input_ids"][0]:
+                prompt_tokens["prompt_input_ids"] = [bos_token_id] + prompt_tokens["prompt_input_ids"]
+                prompt_tokens["prompt_attention_mask"] = [1] + prompt_tokens["prompt_attention_mask"]
+            if chosen_prompt_len_input_ids == 0 or bos_token_id != chosen_tokens["prompt_input_ids"][0]:
+                chosen_tokens["prompt_input_ids"] = [bos_token_id] + chosen_tokens["prompt_input_ids"]
+                chosen_tokens["prompt_attention_mask"] = [1] + chosen_tokens["prompt_attention_mask"]
+            if rejected_prompt_len_input_ids == 0 or bos_token_id != rejected_tokens["prompt_input_ids"][0]:
+                rejected_tokens["prompt_input_ids"] = [bos_token_id] + rejected_tokens["prompt_input_ids"]
+                rejected_tokens["prompt_attention_mask"] = [1] + rejected_tokens["prompt_attention_mask"]
+
+            # add EOS token to end of answer. Avoid adding if it's already there
+            eos_token_id = self.tokenizer.eos_token_id
+            if len(chosen_tokens["input_ids"]) == 0 or eos_token_id != chosen_tokens["input_ids"][-1]:
+                chosen_tokens["input_ids"].append(eos_token_id)
+                chosen_tokens["attention_mask"].append(1)
+            if len(rejected_tokens["input_ids"]) == 0 or eos_token_id != rejected_tokens["input_ids"][-1]:
+                rejected_tokens["input_ids"].append(eos_token_id)
+                rejected_tokens["attention_mask"].append(1)
 
             longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))
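
All four files apply the same idempotent special-token logic: prepend BOS only when the first input id is not already bos_token_id, append EOS only when the last input id is not already eos_token_id, and check the sequence length before indexing position 0 or -1 (the crash fixed by the last commit in this series). A minimal standalone sketch of that logic follows; the token ids and the helper name add_bos_eos_once are illustrative assumptions and do not appear in TRL.

    # Sketch of the idempotent BOS/EOS handling used throughout this patch.
    # The BOS/EOS ids below are made up for illustration, not from a real tokenizer.
    BOS, EOS = 1, 2

    def add_bos_eos_once(input_ids, attention_mask, bos_token_id=BOS, eos_token_id=EOS):
        """Prepend BOS / append EOS only when absent, guarding empty sequences first."""
        if len(input_ids) == 0 or input_ids[0] != bos_token_id:
            input_ids = [bos_token_id] + input_ids
            attention_mask = [1] + attention_mask
        if len(input_ids) == 0 or input_ids[-1] != eos_token_id:
            input_ids = input_ids + [eos_token_id]
            attention_mask = attention_mask + [1]
        return input_ids, attention_mask

    # Applying the helper twice yields the same result as applying it once.
    ids, mask = add_bos_eos_once([5, 6, 7], [1, 1, 1])
    assert ids == [1, 5, 6, 7, 2] and mask == [1, 1, 1, 1, 1]
    assert (ids, mask) == add_bos_eos_once(ids, mask)

The same presence check drives the new KTO max_length computation: the token budget is reduced by one per special token that will actually be added, instead of always by two.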