Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✂️ Reintroduce truncation_mode in DPOTrainer #2551

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 40 additions & 1 deletion tests/test_dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def test_tokenize_row_no_truncation_no_special_tokens(self):
max_prompt_length=None,
max_completion_length=None,
add_special_tokens=False,
truncation_mode="keep_end",
)

# Assert the correct output without truncation or special tokens
Expand All @@ -86,7 +87,7 @@ def test_tokenize_row_no_truncation_no_special_tokens(self):
},
)

def test_tokenize_row_with_truncation(self):
def test_tokenize_row_with_truncation_keep_end(self):
# Define the input features
features = {"prompt": "The sky is", "chosen": " blue", "rejected": " green"}

Expand All @@ -97,6 +98,7 @@ def test_tokenize_row_with_truncation(self):
max_prompt_length=2,
max_completion_length=1,
add_special_tokens=False,
truncation_mode="keep_end",
)

# Assert the correct output with truncation applied
Expand All @@ -109,6 +111,41 @@ def test_tokenize_row_with_truncation(self):
},
)

def test_tokenize_row_with_truncation_keep_start(self):
    """`truncation_mode="keep_start"` keeps the *leading* tokens of the prompt."""
    row = {"prompt": "The sky is", "chosen": " blue", "rejected": " green"}

    # Tokenize with tight length limits so truncation actually kicks in.
    tokenized = DPOTrainer.tokenize_row(
        features=row,
        processing_class=self.tokenizer,
        max_prompt_length=2,
        max_completion_length=1,
        add_special_tokens=False,
        truncation_mode="keep_start",
    )

    # With keep_start, the prompt is cut from the end: only its first
    # 2 token ids survive; each completion is clipped to a single token.
    expected = {
        "prompt_input_ids": [464, 6766],  # truncated to the first 2 tokens
        "chosen_input_ids": [4171],  # truncated to 1 token
        "rejected_input_ids": [4077],  # truncated to 1 token
    }
    self.assertEqual(tokenized, expected)

def test_tokenize_row_invalid_truncation_mode(self):
    """An unrecognized `truncation_mode` must raise a descriptive ValueError.

    Uses `assertRaisesRegex` rather than a bare `assertRaises` so the test
    cannot pass on an unrelated ValueError raised earlier in tokenization;
    the implementation raises "Unknown truncation mode: <mode>".
    """
    with self.assertRaisesRegex(ValueError, "Unknown truncation mode"):
        DPOTrainer.tokenize_row(
            features={"prompt": "The sky is", "chosen": " blue", "rejected": " green"},
            processing_class=self.tokenizer,
            max_prompt_length=2,
            max_completion_length=1,
            add_special_tokens=False,
            truncation_mode="invalid",
        )

def test_tokenize_row_with_special_tokens(self):
# Define the input features
features = {"prompt": "The sky is", "chosen": " blue", "rejected": " green"}
Expand All @@ -120,6 +157,7 @@ def test_tokenize_row_with_special_tokens(self):
max_prompt_length=None,
max_completion_length=None,
add_special_tokens=True,
truncation_mode="keep_end",
)

# Assert the correct output with special tokens added
Expand All @@ -143,6 +181,7 @@ def test_tokenize_row_with_truncation_and_special_tokens(self):
max_prompt_length=4,
max_completion_length=1,
add_special_tokens=True,
truncation_mode="keep_end",
)

# Assert the correct output with both truncation and special tokens
Expand Down
29 changes: 24 additions & 5 deletions trl/trainer/dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,6 +554,7 @@ def _prepare_dataset(
"processing_class": processing_class,
"max_prompt_length": args.max_prompt_length,
"max_completion_length": args.max_completion_length,
"truncation_mode": args.truncation_mode,
anakin87 marked this conversation as resolved.
Show resolved Hide resolved
# for enc-dec, we add the special tokens ([bos_token] + prompt + [eos_token]; completion + [eos_token])
"add_special_tokens": False,
},
Expand All @@ -563,7 +564,9 @@ def _prepare_dataset(
return dataset

@staticmethod
def tokenize_row(features, processing_class, max_prompt_length, max_completion_length, add_special_tokens):
def tokenize_row(
features, processing_class, max_prompt_length, max_completion_length, add_special_tokens, truncation_mode
):
"""
Tokenize a row of the dataset.

Expand All @@ -580,6 +583,9 @@ def tokenize_row(features, processing_class, max_prompt_length, max_completion_l
Whether to add special tokens to the sequences. Typically used for encoder-decoder models. If `True`,
the prompt sequence will have a bos token prepended and an eos token appended. In any case, the
completion sequences will have an eos token appended.
truncation_mode (`str`):
Whether to truncate the prompt sequence from the end or the start. If `"keep_end"`, the prompt sequence
will be truncated from the end. If `"keep_start"`, the prompt sequence will be truncated from the start.

Returns:
`dict[str, list[int]]`:
Expand All @@ -591,7 +597,8 @@ def tokenize_row(features, processing_class, max_prompt_length, max_completion_l
>>> from transformers import GPT2Tokenizer
>>> tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
>>> features = {"prompt": "The sky is", "chosen": " blue", "rejected": " green"}
>>> DPOTrainer.tokenize_row(features, tokenizer, max_prompt_length=3, max_completion_length=3, add_special_tokens=False)
>>> DPOTrainer.tokenize_row(features, tokenizer, max_prompt_length=3, max_completion_length=3,
>>> add_special_tokens=False, truncation_mode="keep_end")
{'prompt_input_ids': [464, 6766, 318], 'chosen_input_ids': [4171, 50256], 'rejected_input_ids': [4077, 50256]}
```
"""
Expand All @@ -611,7 +618,12 @@ def tokenize_row(features, processing_class, max_prompt_length, max_completion_l

# Truncate prompt and completion sequences
if max_prompt_length is not None:
prompt_input_ids = prompt_input_ids[-max_prompt_length:]
if truncation_mode == "keep_end":
prompt_input_ids = prompt_input_ids[-max_prompt_length:]
elif truncation_mode == "keep_start":
prompt_input_ids = prompt_input_ids[:max_prompt_length]
else:
raise ValueError(f"Unknown truncation mode: {truncation_mode}")
if max_completion_length is not None:
chosen_input_ids = chosen_input_ids[:max_completion_length]
rejected_input_ids = rejected_input_ids[:max_completion_length]
Expand All @@ -623,7 +635,9 @@ def tokenize_row(features, processing_class, max_prompt_length, max_completion_l
}

@staticmethod
def process_row(features, processing_class, max_prompt_length, max_completion_length, add_special_tokens):
def process_row(
features, processing_class, max_prompt_length, max_completion_length, add_special_tokens, truncation_mode
):
"""
Same as `tokenize_row` but for vision models. Please refer to `tokenize_row` for more information.
"""
Expand All @@ -646,7 +660,12 @@ def process_row(features, processing_class, max_prompt_length, max_completion_le

# Truncate prompt and completion sequences
if max_prompt_length is not None:
prompt_input_ids = prompt_input_ids[-max_prompt_length:]
if truncation_mode == "keep_end":
prompt_input_ids = prompt_input_ids[-max_prompt_length:]
elif truncation_mode == "keep_start":
prompt_input_ids = prompt_input_ids[:max_prompt_length]
else:
raise ValueError(f"Unknown truncation mode: {truncation_mode}")
if max_completion_length is not None:
chosen_input_ids = chosen_input_ids[:max_completion_length]
rejected_input_ids = rejected_input_ids[:max_completion_length]
Expand Down