fix tokenize_row in XXPOTrainer
AIR-hl committed May 30, 2024
1 parent 151a452 commit 177d213
Showing 5 changed files with 227 additions and 619 deletions.
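This commit removes the per-trainer `tokenize_row` and `build_tokenized_answer` methods and instead imports a shared `tokenize_row` helper alongside the other trainer utilities, binding its extra arguments with `functools.partial` before handing it to `datasets.map`. A minimal sketch of that call pattern, using a hypothetical stand-in for the helper (the shared function this diff imports presumably carries the truncation and label logic removed below) and an arbitrary example tokenizer:

    from functools import partial

    from datasets import Dataset
    from transformers import AutoTokenizer

    def tokenize_row(feature, args=None, tokenizer=None):
        # Hypothetical stand-in: tokenize prompt, chosen and rejected text for one row.
        prompt_ids = tokenizer(feature["prompt"], add_special_tokens=False)["input_ids"]
        chosen_ids = tokenizer(feature["prompt"] + feature["chosen"], add_special_tokens=False)["input_ids"]
        rejected_ids = tokenizer(feature["prompt"] + feature["rejected"], add_special_tokens=False)["input_ids"]
        return {
            "prompt_input_ids": prompt_ids,
            "chosen_input_ids": chosen_ids,
            "rejected_input_ids": rejected_ids,
        }

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    dataset = Dataset.from_dict({"prompt": ["2+2="], "chosen": ["4"], "rejected": ["5"]})

    # Same call shape as the updated trainer: bind the non-row arguments with
    # partial, then let datasets.map pass each row in as `feature`.
    dataset = dataset.map(partial(tokenize_row, args=None, tokenizer=tokenizer), num_proc=1)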
211 changes: 6 additions & 205 deletions trl/trainer/cpo_trainer.py
@@ -18,7 +18,7 @@
import warnings
from collections import defaultdict
from contextlib import nullcontext
from functools import wraps
from functools import wraps, partial
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union

import numpy as np
@@ -41,6 +41,7 @@
pad_to_length,
peft_module_casting_to_bf16,
trl_sanitze_kwargs_for_tagging,
tokenize_row
)


@@ -274,9 +275,11 @@ def make_inputs_require_grad(module, input, output):
# see: https://github.com/huggingface/trl/pull/1255
with PartialState().local_main_process_first():
# tokenize the dataset
train_dataset = train_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
train_dataset = train_dataset.map(partial(tokenize_row, args=args, tokenizer=self.tokenizer),
num_proc=args.dataset_num_proc)
if eval_dataset is not None:
eval_dataset = eval_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
eval_dataset = eval_dataset.map(partial(tokenize_row, args=args, tokenizer=self.tokenizer),
num_proc=args.dataset_num_proc)

super().__init__(
model=model,
@@ -300,208 +303,6 @@ def make_inputs_require_grad(module, input, output):
raise AttributeError(
"Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
)

def build_tokenized_answer(self, prompt, answer):
"""
Llama tokenizer does not satisfy `enc(a + b) = enc(a) + enc(b)`.
It does ensure `enc(a + b) = enc(a) + enc(a + b)[len(enc(a)):]` once a possible
merge of the last prompt token is accounted for, which is handled below.
Reference:
https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
"""

full_tokenized = self.tokenizer(prompt + answer, add_special_tokens=False)
prompt_input_ids = self.tokenizer(prompt, add_special_tokens=False)["input_ids"]

answer_input_ids = full_tokenized["input_ids"][len(prompt_input_ids) :]
answer_attention_mask = full_tokenized["attention_mask"][len(prompt_input_ids) :]

# Concat tokens to form `enc(a) + enc(a + b)[len(enc(a)):]`
full_concat_input_ids = np.concatenate([prompt_input_ids, answer_input_ids])

# Prepare input tokens for token by token comparison
full_input_ids = np.array(full_tokenized["input_ids"])

if len(full_input_ids) != len(full_concat_input_ids):
raise ValueError("Prompt input ids and answer input ids should have the same length.")

# On some tokenizers, like Llama-2 tokenizer, there are occasions where tokens
# can be merged together when tokenizing prompt+answer. This could result
in the last token from the prompt being different when tokenized on its own
# vs when done as prompt+answer.
response_token_ids_start_idx = len(prompt_input_ids)

# If tokenized prompt is different than both prompt+answer, then it means the
# last token has changed due to merging.
if prompt_input_ids != full_tokenized["input_ids"][:response_token_ids_start_idx]:
response_token_ids_start_idx -= 1

prompt_input_ids = full_tokenized["input_ids"][:response_token_ids_start_idx]
prompt_attention_mask = full_tokenized["attention_mask"][:response_token_ids_start_idx]

if len(prompt_input_ids) != len(prompt_attention_mask):
raise ValueError("Prompt input ids and attention mask should have the same length.")

answer_input_ids = full_tokenized["input_ids"][response_token_ids_start_idx:]
answer_attention_mask = full_tokenized["attention_mask"][response_token_ids_start_idx:]

return dict(
prompt_input_ids=prompt_input_ids,
prompt_attention_mask=prompt_attention_mask,
input_ids=answer_input_ids,
attention_mask=answer_attention_mask,
)
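
The docstring above rests on the fact that `enc(a + b)` need not equal `enc(a) + enc(b)` for BPE/sentencepiece tokenizers, so the answer ids are recovered by slicing the joint encoding instead. A quick standalone check of that behaviour, assuming any Hugging Face tokenizer ("gpt2" is only an example):

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")
    prompt, answer = "Hello wor", "ld!"

    enc_a = tok(prompt, add_special_tokens=False)["input_ids"]
    enc_b = tok(answer, add_special_tokens=False)["input_ids"]
    enc_ab = tok(prompt + answer, add_special_tokens=False)["input_ids"]

    # Naive concatenation can differ from the joint encoding when tokens merge
    # across the prompt/answer boundary, so enc(a) + enc(b) is not a safe split.
    print(enc_ab == enc_a + enc_b)

    # Like build_tokenized_answer, slice the joint encoding after the prompt and
    # step the boundary back by one token if the last prompt token was merged;
    # the two slices then reproduce enc(a + b) exactly.
    start = len(enc_a)
    if enc_ab[:start] != enc_a:
        start -= 1
    prompt_ids, answer_ids = enc_ab[:start], enc_ab[start:]
    print(prompt_ids + answer_ids == enc_ab)  # True by construction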

def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module]] = None) -> Dict:
"""Tokenize a single row from a CPO specific dataset.
At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation
in case the prompt + chosen or prompt + rejected sequence is too long. First
we truncate the prompt; if the result is still too long, we truncate the chosen/rejected.
We also create the labels for the chosen/rejected responses, which are of length equal to
the sum of the length of the prompt and the chosen/rejected response, with
label_pad_token_id for the prompt tokens.
"""
batch = {}
prompt = feature["prompt"]
chosen = feature["chosen"]
rejected = feature["rejected"]

if not self.is_encoder_decoder:
# Check issues below for more details
# 1. https://github.com/huggingface/trl/issues/907
# 2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
# 3. https://github.com/LianjiaTech/BELLE/issues/337

if not isinstance(prompt, str):
raise ValueError(f"prompt should be a str but got {type(prompt)}")
prompt_tokens = self.tokenizer(prompt, add_special_tokens=False)
prompt_tokens = {f"prompt_{k}": v for k, v in prompt_tokens.items()}

if not isinstance(chosen, str):
raise ValueError(f"chosen should be a str but got {type(chosen)}")
chosen_tokens = self.build_tokenized_answer(prompt, chosen)

if not isinstance(rejected, str):
raise ValueError(f"rejected should be a str but got {type(rejected)}")
rejected_tokens = self.build_tokenized_answer(prompt, rejected)

# Last prompt token might get merged by tokenizer and
# it should not be included for generation if that happens
prompt_len_input_ids = len(prompt_tokens["prompt_input_ids"])

chosen_prompt_len_input_ids = len(chosen_tokens["prompt_input_ids"])
rejected_prompt_len_input_ids = len(rejected_tokens["prompt_input_ids"])
prompt_len_input_ids = min(chosen_prompt_len_input_ids, rejected_prompt_len_input_ids)

for k, v in prompt_tokens.items():
prompt_tokens[k] = v[:prompt_len_input_ids]

# Make sure prompts only have one different token at most
# and that their lengths differ by at most 1
num_diff_tokens = sum(
[a != b for a, b in zip(chosen_tokens["prompt_input_ids"], rejected_tokens["prompt_input_ids"])]
)
num_diff_len = abs(chosen_prompt_len_input_ids - rejected_prompt_len_input_ids)
if num_diff_tokens > 1 or num_diff_len > 1:
raise ValueError(
"Chosen and rejected prompt_input_ids might only differ on the "
"last token due to tokenizer merge ops."
)

# add BOS token to head of prompt. Avoid adding if it's already there
bos_token_id = self.tokenizer.bos_token_id
if prompt_len_input_ids == 0 or bos_token_id != prompt_tokens["prompt_input_ids"][0]:
prompt_tokens["prompt_input_ids"] = [bos_token_id] + prompt_tokens["prompt_input_ids"]
prompt_tokens["prompt_attention_mask"] = [1] + prompt_tokens["prompt_attention_mask"]
if chosen_prompt_len_input_ids == 0 or bos_token_id != chosen_tokens["prompt_input_ids"][0]:
chosen_tokens["prompt_input_ids"] = [bos_token_id] + chosen_tokens["prompt_input_ids"]
chosen_tokens["prompt_attention_mask"] = [1] + chosen_tokens["prompt_attention_mask"]
if rejected_prompt_len_input_ids == 0 or bos_token_id != rejected_tokens["prompt_input_ids"][0]:
rejected_tokens["prompt_input_ids"] = [bos_token_id] + rejected_tokens["prompt_input_ids"]
rejected_tokens["prompt_attention_mask"] = [1] + rejected_tokens["prompt_attention_mask"]

# add EOS token to end of answer. Avoid adding if it's already there
eos_token_id = self.tokenizer.eos_token_id
if len(chosen_tokens["input_ids"]) == 0 or eos_token_id != chosen_tokens["input_ids"][-1]:
chosen_tokens["input_ids"].append(eos_token_id)
chosen_tokens["attention_mask"].append(1)
if len(rejected_tokens["input_ids"]) == 0 or eos_token_id != rejected_tokens["input_ids"][-1]:
rejected_tokens["input_ids"].append(eos_token_id)
rejected_tokens["attention_mask"].append(1)

longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))

# if combined sequence is too long, truncate the prompt
for answer_tokens in [chosen_tokens, rejected_tokens, prompt_tokens]:
if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
if self.truncation_mode == "keep_start":
for k in ["prompt_input_ids", "prompt_attention_mask"]:
answer_tokens[k] = answer_tokens[k][: self.max_prompt_length]
elif self.truncation_mode == "keep_end":
for k in ["prompt_input_ids", "prompt_attention_mask"]:
answer_tokens[k] = answer_tokens[k][-self.max_prompt_length :]
else:
raise ValueError(f"Unknown truncation mode: {self.truncation_mode}")

# if that's still too long, truncate the response
for answer_tokens in [chosen_tokens, rejected_tokens]:
if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
for k in ["input_ids", "attention_mask"]:
answer_tokens[k] = answer_tokens[k][: self.max_length - self.max_prompt_length]

# Create labels
chosen_sequence_tokens = {
k: chosen_tokens[f"prompt_{k}"] + chosen_tokens[k] for k in ["input_ids", "attention_mask"]
}
rejected_sequence_tokens = {
k: rejected_tokens[f"prompt_{k}"] + rejected_tokens[k] for k in ["input_ids", "attention_mask"]
}
chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:]
chosen_sequence_tokens["labels"][: len(chosen_tokens["prompt_input_ids"])] = [
self.label_pad_token_id
] * len(chosen_tokens["prompt_input_ids"])
rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:]
rejected_sequence_tokens["labels"][: len(rejected_tokens["prompt_input_ids"])] = [
self.label_pad_token_id
] * len(rejected_tokens["prompt_input_ids"])

for k, toks in {
"chosen_": chosen_sequence_tokens,
"rejected_": rejected_sequence_tokens,
"": prompt_tokens,
}.items():
for type_key, tokens in toks.items():
if type_key == "token_type_ids":
continue
batch[f"{k}{type_key}"] = tokens

else:
chosen_tokens = self.tokenizer(
chosen, truncation=True, max_length=self.max_target_length, add_special_tokens=True
)
rejected_tokens = self.tokenizer(
rejected, truncation=True, max_length=self.max_target_length, add_special_tokens=True
)
prompt_tokens = self.tokenizer(
prompt, truncation=True, max_length=self.max_prompt_length, add_special_tokens=True
)

batch["chosen_labels"] = chosen_tokens["input_ids"]
batch["rejected_labels"] = rejected_tokens["input_ids"]
batch["prompt_input_ids"] = prompt_tokens["input_ids"]
batch["prompt_attention_mask"] = prompt_tokens["attention_mask"]

if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"):
batch["rejected_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
labels=torch.tensor(batch["rejected_labels"])
)
batch["chosen_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
labels=torch.tensor(batch["chosen_labels"])
)

return batch
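
As the docstring notes, each sequence's labels are just the concatenated prompt + response input ids with the prompt span masked out, so the loss is only computed on response tokens. A toy illustration of that layout, assuming the default `label_pad_token_id` of -100 and made-up token ids:

    # Labels mirror input_ids, with the prompt positions replaced by the pad label.
    label_pad_token_id = -100  # assumed default, matching self.label_pad_token_id
    prompt_input_ids = [101, 102, 103]
    response_input_ids = [201, 202]

    input_ids = prompt_input_ids + response_input_ids
    labels = input_ids[:]
    labels[: len(prompt_input_ids)] = [label_pad_token_id] * len(prompt_input_ids)

    print(input_ids)  # [101, 102, 103, 201, 202]
    print(labels)     # [-100, -100, -100, 201, 202]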

@staticmethod
def concatenated_inputs(
batch: Dict[str, Union[List, torch.LongTensor]],
