✂️ Reintroduce truncation_mode in DPOTrainer #2551

Merged
20 changes: 11 additions & 9 deletions trl/trainer/dpo_config.py
@@ -71,14 +71,15 @@ class DPOConfig(TrainingArguments):
Padding value to use. If `None`, the padding value of the tokenizer is used.
label_pad_token_id (`int`, *optional*, defaults to `-100`):
Padding value to use for labels.
truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
Truncation mode to use when the prompt is too long, either `keep_end` or `keep_start`.
max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
Maximum length of the prompt.
max_completion_length (`int` or `None`, *optional*, defaults to `None`):
Maximum length of the completion.
max_length (`int` or `None`, *optional*, defaults to `1024`):
Maximum length of the full sequence (prompt + completion).
truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
Truncation mode to use when the sequence exceeds `max_length`. Possible values are `"keep_end"` and
`"keep_start"`.
padding_free (`bool`, *optional*, defaults to `False`):
Whether forward passes are performed without padding by flattening all sequences in the batch
into a single continuous sequence. This approach requires associating a `position_ids` vector to track
@@ -219,13 +220,6 @@ class DPOConfig(TrainingArguments):
default=-100,
metadata={"help": "Padding value to use for labels."},
)
truncation_mode: str = field(
default="keep_end",
metadata={
"help": "Truncation mode to use when the prompt is too long.",
"choices": ["keep_end", "keep_start"],
},
)
max_prompt_length: Optional[int] = field(
default=512,
metadata={"help": "Maximum length of the prompt."},
@@ -238,6 +232,14 @@ class DPOConfig(TrainingArguments):
default=1024,
metadata={"help": "Maximum length of the full sequence (prompt + completion)."},
)
truncation_mode: str = field(
default="keep_end",
metadata={
"help": "Truncation mode to use when the sequence exceeds `max_length`. Possible values are `'keep_end'` "
"and `'keep_start'`.",
"choices": ["keep_end", "keep_start"],
},
)
padding_free: bool = field(
default=False,
metadata={
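
For context, a minimal sketch of how the reintroduced option is meant to be set (the `output_dir` value is illustrative, not from this PR; the other arguments mirror the defaults shown in the diff):

```python
from trl import DPOConfig

# With the field moved after the length settings, truncation_mode now
# governs which side of the *full* prompt + completion sequence is kept
# when it exceeds max_length, rather than applying to the prompt alone.
config = DPOConfig(
    output_dir="dpo-output",      # illustrative path
    max_prompt_length=512,
    max_completion_length=None,
    max_length=1024,
    truncation_mode="keep_end",   # or "keep_start"
)
```
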
26 changes: 19 additions & 7 deletions trl/trainer/dpo_trainer.py
@@ -388,12 +388,12 @@ def make_inputs_require_grad(module, input, output):
if self.ref_model is not None:
disable_dropout_in_model(self.ref_model)

self.max_length = args.max_length
self.generate_during_eval = args.generate_during_eval
self.label_pad_token_id = args.label_pad_token_id
self.max_prompt_length = args.max_prompt_length
self.truncation_mode = args.truncation_mode
self.max_completion_length = args.max_completion_length
self.max_length = args.max_length
self.truncation_mode = args.truncation_mode
self.precompute_ref_log_probs = args.precompute_ref_log_probs
self.use_num_logits_to_keep = args.use_num_logits_to_keep

@@ -595,7 +595,9 @@ def tokenize_row(features, processing_class, max_prompt_length, max_completion_l
>>> from transformers import GPT2Tokenizer
>>> tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
>>> features = {"prompt": "The sky is", "chosen": " blue", "rejected": " green"}
>>> DPOTrainer.tokenize_row(features, tokenizer, max_prompt_length=3, max_completion_length=3, add_special_tokens=False)
>>> DPOTrainer.tokenize_row(
... features, tokenizer, max_prompt_length=3, max_completion_length=3, add_special_tokens=False
... )
{'prompt_input_ids': [464, 6766, 318], 'chosen_input_ids': [4171, 50256], 'rejected_input_ids': [4077, 50256]}
```
"""
@@ -1145,10 +1147,20 @@ def concatenated_forward(self, model: nn.Module, batch: dict[str, Union[list, to
attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)

# Truncate right
if self.args.max_length is not None:
input_ids = input_ids[:, : self.args.max_length]
attention_mask = attention_mask[:, : self.args.max_length]
loss_mask = loss_mask[:, : self.args.max_length]
if self.max_length is not None:
if self.truncation_mode == "keep_end":
input_ids = input_ids[:, -self.max_length :]
attention_mask = attention_mask[:, -self.max_length :]
loss_mask = loss_mask[:, -self.max_length :]
elif self.truncation_mode == "keep_start":
input_ids = input_ids[:, : self.max_length]
attention_mask = attention_mask[:, : self.max_length]
loss_mask = loss_mask[:, : self.max_length]
else:
raise ValueError(
f"Unknown truncation mode: '{self.truncation_mode}'. Should be one of ['keep_end', "
"'keep_start']."
)

if self.use_num_logits_to_keep:
# Compute num_logits_to_keep based on loss_mask pattern:
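
A standalone sketch (with made-up tensor values) of what the two branches above do to an over-long batch:

```python
import torch

input_ids = torch.arange(12).reshape(2, 6)  # two sequences of length 6
max_length = 4

# "keep_end": keep the last max_length tokens of each row
keep_end = input_ids[:, -max_length:]
# "keep_start": keep the first max_length tokens of each row
keep_start = input_ids[:, :max_length]

print(keep_end)    # tensor([[ 2,  3,  4,  5], [ 8,  9, 10, 11]])
print(keep_start)  # tensor([[ 0,  1,  2,  3], [ 6,  7,  8,  9]])
```

The same slice is applied to `attention_mask` and `loss_mask` so all three stay aligned; `keep_end` preserves the previous right-truncation behavior only when sequences fit, which is why the explicit mode check was reintroduced.
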