fixed adding bos and eos token unconditionally (#1591)
* fixed adding bos and eos token unconditionally

* fixed typo of tokenizer -> self.tokenizer. Also added update to ORPO

* fixed code quality, and added BOS/EOS fix to KTO

* code reformatting with pre-commit run --all-files

* bug fix: check input id length before checking for EOS/BOS
jasonyux authored May 3, 2024
1 parent 0347f58 commit 3b4c249
Showing 4 changed files with 95 additions and 56 deletions.
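All four diffs apply the same fix: prepend BOS and append EOS only when the token is not already present, and check the sequence length before indexing. A minimal standalone sketch of that pattern, not TRL's actual API (the function name and argument layout here are illustrative; it assumes a Hugging Face tokenizer that defines bos_token_id and eos_token_id):

from typing import List, Tuple


def add_special_tokens(
    prompt_ids: List[int],
    prompt_mask: List[int],
    answer_ids: List[int],
    answer_mask: List[int],
    bos_token_id: int,
    eos_token_id: int,
) -> Tuple[List[int], List[int], List[int], List[int]]:
    """Prepend BOS / append EOS only when they are not already present."""
    # Check the length first: indexing [0] or [-1] on an empty list raises
    # IndexError (the bug addressed by the last commit bullet above).
    if len(prompt_ids) == 0 or prompt_ids[0] != bos_token_id:
        prompt_ids = [bos_token_id] + prompt_ids
        prompt_mask = [1] + prompt_mask
    if len(answer_ids) == 0 or answer_ids[-1] != eos_token_id:
        answer_ids = answer_ids + [eos_token_id]
        answer_mask = answer_mask + [1]
    return prompt_ids, prompt_mask, answer_ids, answer_mask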
35 changes: 20 additions & 15 deletions trl/trainer/cpo_trainer.py
@@ -409,21 +409,26 @@ def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module
                     "last token due to tokenizer merge ops."
                 )

-            # add BOS token to head of prompt
-            prompt_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + prompt_tokens["prompt_input_ids"]
-            chosen_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + chosen_tokens["prompt_input_ids"]
-            rejected_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + rejected_tokens["prompt_input_ids"]
-
-            prompt_tokens["prompt_attention_mask"] = [1] + prompt_tokens["prompt_attention_mask"]
-            chosen_tokens["prompt_attention_mask"] = [1] + chosen_tokens["prompt_attention_mask"]
-            rejected_tokens["prompt_attention_mask"] = [1] + rejected_tokens["prompt_attention_mask"]
-
-            # add EOS token to end of answer
-            chosen_tokens["input_ids"].append(self.tokenizer.eos_token_id)
-            chosen_tokens["attention_mask"].append(1)
-
-            rejected_tokens["input_ids"].append(self.tokenizer.eos_token_id)
-            rejected_tokens["attention_mask"].append(1)
+            # add BOS token to head of prompt. Avoid adding if it's already there
+            bos_token_id = self.tokenizer.bos_token_id
+            if prompt_len_input_ids == 0 or bos_token_id != prompt_tokens["prompt_input_ids"][0]:
+                prompt_tokens["prompt_input_ids"] = [bos_token_id] + prompt_tokens["prompt_input_ids"]
+                prompt_tokens["prompt_attention_mask"] = [1] + prompt_tokens["prompt_attention_mask"]
+            if chosen_prompt_len_input_ids == 0 or bos_token_id != chosen_tokens["prompt_input_ids"][0]:
+                chosen_tokens["prompt_input_ids"] = [bos_token_id] + chosen_tokens["prompt_input_ids"]
+                chosen_tokens["prompt_attention_mask"] = [1] + chosen_tokens["prompt_attention_mask"]
+            if rejected_prompt_len_input_ids == 0 or bos_token_id != rejected_tokens["prompt_input_ids"][0]:
+                rejected_tokens["prompt_input_ids"] = [bos_token_id] + rejected_tokens["prompt_input_ids"]
+                rejected_tokens["prompt_attention_mask"] = [1] + rejected_tokens["prompt_attention_mask"]
+
+            # add EOS token to end of answer. Avoid adding if it's already there
+            eos_token_id = self.tokenizer.eos_token_id
+            if len(chosen_tokens["input_ids"]) == 0 or eos_token_id != chosen_tokens["input_ids"][-1]:
+                chosen_tokens["input_ids"].append(eos_token_id)
+                chosen_tokens["attention_mask"].append(1)
+            if len(rejected_tokens["input_ids"]) == 0 or eos_token_id != rejected_tokens["input_ids"][-1]:
+                rejected_tokens["input_ids"].append(eos_token_id)
+                rejected_tokens["attention_mask"].append(1)

             longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))
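For context, a sketch of the failure mode the new BOS check avoids, assuming a tokenizer that already inserts BOS during encoding (Llama-style); the token ids below are made up:

# Hypothetical ids: suppose bos_token_id == 1 and the tokenizer already
# prepends it, so encoding a prompt yields [1, 42, 43].
prompt_input_ids = [1, 42, 43]
bos_token_id = 1

# Old behavior: unconditional prepend produces a double BOS, which shifts
# every position and skews the log-probabilities computed over the sequence.
old = [bos_token_id] + prompt_input_ids
assert old == [1, 1, 42, 43]

# New behavior: prepend only when BOS is actually missing.
if len(prompt_input_ids) == 0 or bos_token_id != prompt_input_ids[0]:
    prompt_input_ids = [bos_token_id] + prompt_input_ids
assert prompt_input_ids == [1, 42, 43]  # unchanged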
35 changes: 20 additions & 15 deletions trl/trainer/dpo_trainer.py
@@ -771,21 +771,26 @@ def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module
                     "last token due to tokenizer merge ops."
                 )

-            # add BOS token to head of prompt
-            prompt_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + prompt_tokens["prompt_input_ids"]
-            chosen_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + chosen_tokens["prompt_input_ids"]
-            rejected_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + rejected_tokens["prompt_input_ids"]
-
-            prompt_tokens["prompt_attention_mask"] = [1] + prompt_tokens["prompt_attention_mask"]
-            chosen_tokens["prompt_attention_mask"] = [1] + chosen_tokens["prompt_attention_mask"]
-            rejected_tokens["prompt_attention_mask"] = [1] + rejected_tokens["prompt_attention_mask"]
-
-            # add EOS token to end of answer
-            chosen_tokens["input_ids"].append(self.tokenizer.eos_token_id)
-            chosen_tokens["attention_mask"].append(1)
-
-            rejected_tokens["input_ids"].append(self.tokenizer.eos_token_id)
-            rejected_tokens["attention_mask"].append(1)
+            # add BOS token to head of prompt. Avoid adding if it's already there
+            bos_token_id = self.tokenizer.bos_token_id
+            if prompt_len_input_ids == 0 or bos_token_id != prompt_tokens["prompt_input_ids"][0]:
+                prompt_tokens["prompt_input_ids"] = [bos_token_id] + prompt_tokens["prompt_input_ids"]
+                prompt_tokens["prompt_attention_mask"] = [1] + prompt_tokens["prompt_attention_mask"]
+            if chosen_prompt_len_input_ids == 0 or bos_token_id != chosen_tokens["prompt_input_ids"][0]:
+                chosen_tokens["prompt_input_ids"] = [bos_token_id] + chosen_tokens["prompt_input_ids"]
+                chosen_tokens["prompt_attention_mask"] = [1] + chosen_tokens["prompt_attention_mask"]
+            if rejected_prompt_len_input_ids == 0 or bos_token_id != rejected_tokens["prompt_input_ids"][0]:
+                rejected_tokens["prompt_input_ids"] = [bos_token_id] + rejected_tokens["prompt_input_ids"]
+                rejected_tokens["prompt_attention_mask"] = [1] + rejected_tokens["prompt_attention_mask"]
+
+            # add EOS token to end of answer. Avoid adding if it's already there
+            eos_token_id = self.tokenizer.eos_token_id
+            if len(chosen_tokens["input_ids"]) == 0 or eos_token_id != chosen_tokens["input_ids"][-1]:
+                chosen_tokens["input_ids"].append(eos_token_id)
+                chosen_tokens["attention_mask"].append(1)
+            if len(rejected_tokens["input_ids"]) == 0 or eos_token_id != rejected_tokens["input_ids"][-1]:
+                rejected_tokens["input_ids"].append(eos_token_id)
+                rejected_tokens["attention_mask"].append(1)

             longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))
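And a quick check of the EOS branch with made-up ids, showing that a completion already ending in EOS is left untouched:

eos_token_id = 2
chosen_tokens = {"input_ids": [5, 6, 2], "attention_mask": [1, 1, 1]}

if len(chosen_tokens["input_ids"]) == 0 or eos_token_id != chosen_tokens["input_ids"][-1]:
    chosen_tokens["input_ids"].append(eos_token_id)
    chosen_tokens["attention_mask"].append(1)

assert chosen_tokens["input_ids"] == [5, 6, 2]  # EOS not duplicated
assert chosen_tokens["attention_mask"] == [1, 1, 1]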
46 changes: 35 additions & 11 deletions trl/trainer/kto_trainer.py
@@ -186,7 +186,15 @@ def _process_tokens(example: Dict[str, Any], model: "PreTrainedModel" = None, **
             "answer_attention_mask": example["answer_attention_mask"],
         }

-        max_length = kwargs["max_length"] - 2
+        # calculate max length by checking if BOS/EOS is already there
+        max_length = kwargs["max_length"]
+        bos_token_id = kwargs["tokenizer"].bos_token_id
+        eos_token_id = kwargs["tokenizer"].eos_token_id
+        if bos_token_id != all_tokens["prompt_input_ids"][0]:
+            max_length -= 1
+        if eos_token_id != all_tokens["answer_input_ids"][-1]:
+            max_length -= 1
+
         # if combined sequence is too long (> max_length - 1 for BOS token - 1 for EOS), truncate the prompt
         if len(all_tokens["prompt_input_ids"]) + len(all_tokens["answer_input_ids"]) > max_length:
             for k in ["prompt_input_ids", "prompt_attention_mask"]:
@@ -202,21 +210,37 @@ def _process_tokens(example: Dict[str, Any], model: "PreTrainedModel" = None, **
             for k in ["answer_input_ids", "answer_attention_mask"]:
                 all_tokens[k] = all_tokens[k][: max_length - kwargs["max_prompt_length"]]

-        # for legacy reasons, use the completion_* prefix to now refer to the joint sequence
-        batch[f"{kwargs['prefix']}prompt_input_ids"] = [kwargs["tokenizer"].bos_token_id] + all_tokens[
-            "prompt_input_ids"
-        ]
-        batch[f"{kwargs['prefix']}prompt_attention_mask"] = [1] + all_tokens["prompt_attention_mask"]
+        # all input_ids and attention mask as is. We then check if we need to add BOS/EOS tokens
+        batch[f"{kwargs['prefix']}prompt_input_ids"] = all_tokens["prompt_input_ids"]
+        batch[f"{kwargs['prefix']}prompt_attention_mask"] = all_tokens["prompt_attention_mask"]
         batch[f"{kwargs['prefix']}completion_input_ids"] = (
-            [kwargs["tokenizer"].bos_token_id]
-            + all_tokens["prompt_input_ids"]
-            + all_tokens["answer_input_ids"]
-            + [kwargs["tokenizer"].eos_token_id]
+            all_tokens["prompt_input_ids"] + all_tokens["answer_input_ids"]
         )
         batch[f"{kwargs['prefix']}completion_attention_mask"] = (
-            [1] + all_tokens["prompt_attention_mask"] + all_tokens["answer_attention_mask"] + [1]
+            all_tokens["prompt_attention_mask"] + all_tokens["answer_attention_mask"]
         )

+        # add BOS, which affects both prompt and the full completion
+        if len(all_tokens["prompt_input_ids"]) == 0 or bos_token_id != all_tokens["prompt_input_ids"][0]:
+            batch[f"{kwargs['prefix']}prompt_input_ids"] = [bos_token_id] + batch[
+                f"{kwargs['prefix']}prompt_input_ids"
+            ]
+            batch[f"{kwargs['prefix']}prompt_attention_mask"] = [1] + batch[f"{kwargs['prefix']}prompt_attention_mask"]
+            batch[f"{kwargs['prefix']}completion_input_ids"] = [bos_token_id] + batch[
+                f"{kwargs['prefix']}completion_input_ids"
+            ]
+            batch[f"{kwargs['prefix']}completion_attention_mask"] = [1] + batch[
+                f"{kwargs['prefix']}completion_attention_mask"
+            ]
+        # add EOS, which affects only the full completion
+        if len(all_tokens["answer_input_ids"]) == 0 or eos_token_id != all_tokens["answer_input_ids"][-1]:
+            batch[f"{kwargs['prefix']}completion_input_ids"] = batch[f"{kwargs['prefix']}completion_input_ids"] + [
+                eos_token_id
+            ]
+            batch[f"{kwargs['prefix']}completion_attention_mask"] = batch[
+                f"{kwargs['prefix']}completion_attention_mask"
+            ] + [1]
+
         batch[f"{kwargs['prefix']}completion_labels"] = batch[f"{kwargs['prefix']}completion_input_ids"][:]
         batch[f"{kwargs['prefix']}completion_labels"][: len(batch[f"{kwargs['prefix']}prompt_input_ids"])] = [
             kwargs["label_pad_token_id"]
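The KTO change also adjusts the truncation budget: the old code always reserved two slots via max_length - 2, while the new code reserves a slot only for each special token that will actually be added. A worked illustration with made-up ids (the helper below just mirrors the diff's budget logic, which assumes non-empty sequences at that point):

def effective_max_length(max_length, prompt_ids, answer_ids, bos_id, eos_id):
    # Reserve one slot per special token that is still missing.
    if prompt_ids[0] != bos_id:
        max_length -= 1  # room for the BOS that will be prepended
    if answer_ids[-1] != eos_id:
        max_length -= 1  # room for the EOS that will be appended
    return max_length

# Tokenizer already added BOS/EOS: the full budget stays available.
assert effective_max_length(512, [1, 5, 6], [7, 2], bos_id=1, eos_id=2) == 512
# Neither token present: two slots reserved, as in the old unconditional code.
assert effective_max_length(512, [5, 6], [7], bos_id=1, eos_id=2) == 510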
35 changes: 20 additions & 15 deletions trl/trainer/orpo_trainer.py
@@ -439,21 +439,26 @@ def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module
                     "last token due to tokenizer merge ops."
                 )

-            # add BOS token to head of prompt
-            prompt_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + prompt_tokens["prompt_input_ids"]
-            chosen_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + chosen_tokens["prompt_input_ids"]
-            rejected_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + rejected_tokens["prompt_input_ids"]
-
-            prompt_tokens["prompt_attention_mask"] = [1] + prompt_tokens["prompt_attention_mask"]
-            chosen_tokens["prompt_attention_mask"] = [1] + chosen_tokens["prompt_attention_mask"]
-            rejected_tokens["prompt_attention_mask"] = [1] + rejected_tokens["prompt_attention_mask"]
-
-            # add EOS token to end of answer
-            chosen_tokens["input_ids"].append(self.tokenizer.eos_token_id)
-            chosen_tokens["attention_mask"].append(1)
-
-            rejected_tokens["input_ids"].append(self.tokenizer.eos_token_id)
-            rejected_tokens["attention_mask"].append(1)
+            # add BOS token to head of prompt. Avoid adding if it's already there
+            bos_token_id = self.tokenizer.bos_token_id
+            if prompt_len_input_ids == 0 or bos_token_id != prompt_tokens["prompt_input_ids"][0]:
+                prompt_tokens["prompt_input_ids"] = [bos_token_id] + prompt_tokens["prompt_input_ids"]
+                prompt_tokens["prompt_attention_mask"] = [1] + prompt_tokens["prompt_attention_mask"]
+            if chosen_prompt_len_input_ids == 0 or bos_token_id != chosen_tokens["prompt_input_ids"][0]:
+                chosen_tokens["prompt_input_ids"] = [bos_token_id] + chosen_tokens["prompt_input_ids"]
+                chosen_tokens["prompt_attention_mask"] = [1] + chosen_tokens["prompt_attention_mask"]
+            if rejected_prompt_len_input_ids == 0 or bos_token_id != rejected_tokens["prompt_input_ids"][0]:
+                rejected_tokens["prompt_input_ids"] = [bos_token_id] + rejected_tokens["prompt_input_ids"]
+                rejected_tokens["prompt_attention_mask"] = [1] + rejected_tokens["prompt_attention_mask"]
+
+            # add EOS token to end of answer. Avoid adding if it's already there
+            eos_token_id = self.tokenizer.eos_token_id
+            if len(chosen_tokens["input_ids"]) == 0 or eos_token_id != chosen_tokens["input_ids"][-1]:
+                chosen_tokens["input_ids"].append(eos_token_id)
+                chosen_tokens["attention_mask"].append(1)
+            if len(rejected_tokens["input_ids"]) == 0 or eos_token_id != rejected_tokens["input_ids"][-1]:
+                rejected_tokens["input_ids"].append(eos_token_id)
+                rejected_tokens["attention_mask"].append(1)

             longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))
