✂️ Truncate by default #2587

Merged (4 commits) on Jan 17, 2025
24 changes: 23 additions & 1 deletion CONTRIBUTING.md
@@ -257,6 +257,28 @@ That's how `make test` is implemented (without the `pip install` line)!
You can specify a smaller set of tests to test only the feature
you're working on.

+### Default values guidelines
+
+1. **Use defaults when appropriate**:
+
+   Provide default values unless the parameter's value varies significantly by use case. For example, datasets or models should not have defaults, but parameters like `learning_rate` should (see the sketch after this list).
+
+2. **Prioritize proven defaults**:
+
+   Default values should align with those recommended in the original paper or method. Alternatives require strong evidence that they perform better in most cases.
+
+3. **Ensure safety and predictability**:
+
+   Defaults must be safe, expected, and reliable. Avoid settings that could lead to surprising outcomes, such as excessive memory usage or poor performance in edge cases.
+
+4. **Balance consistency and flexibility**:
+
+   Aim for consistent defaults across similar functions or methods. However, consistency should not take precedence over points 2 and 3.
+
+5. **Opt-in for new features**:
+
+   Do not enable new features or improvements (e.g., novel loss functions) by default. Users should explicitly opt in to use these.
+
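Taken together, the guidelines above might translate into a config class like this minimal sketch (`ExampleConfig` and its fields are illustrative, not TRL API; guideline 4 concerns consistency across classes and is not shown):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ExampleConfig:
    # Guideline 1: no default for a value that varies strongly by use case.
    model_name: str
    # Guideline 2: a proven default, e.g. the value recommended in the original paper.
    learning_rate: float = 1e-6
    # Guideline 3: a safe, predictable default that bounds memory usage.
    max_length: Optional[int] = 1024
    # Guideline 5: new features are opt-in and disabled by default.
    use_novel_loss: bool = False
```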
### Writing documentation

High-quality documentation is crucial for maintaining a project that is easy to use, understand, and extend. When adding new features, ensure they are thoroughly documented to maintain consistency and clarity throughout the project.
@@ -356,7 +378,7 @@ def replicate_str(string: str, n: int, sep: str = " ") -> str:
...
```

-### Deprecation and Backward Compatibility
+### Deprecation and backward compatibility

Our approach to deprecation and backward compatibility is flexible and based on the feature’s usage and impact. Each deprecation is carefully evaluated, aiming to balance innovation with user needs.

14 changes: 11 additions & 3 deletions docs/source/reducing_memory_usage.md
@@ -8,13 +8,13 @@ Section under construction. Feel free to contribute!

## Truncation

-Sequence lengths in the dataset can vary widely, and by default, TRL does not modify the data. When data is batched, sequences are padded to match the longest one in the batch, which can cause high memory usage, even if most sequences are relatively short.
+Sequence lengths in the dataset can vary widely. When data is batched, sequences are padded to match the longest one in the batch, which can cause high memory usage, even if most sequences are relatively short.

<div class="flex justify-center">
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/why_you_should_truncate.png" alt="Truncation prompt completion" width="600"/>
</div>

-To reduce memory usage, it’s important to truncate sequences to a reasonable length. Even discarding just a few tokens from the dataset can result in significant memory savings by minimizing unnecessary padding. Truncation is a good practice and should always be applied to ensure efficient use of resources. While the truncation limit doesn’t need to be overly restrictive, setting a sensible value is essential for optimal performance.
+To reduce memory usage, it’s important to truncate sequences to a reasonable length. While TRL trainers truncate sequences by default, you may want to adjust the default truncation length to better align with your specific use case.

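The memory impact of padding is easy to see with a back-of-the-envelope calculation (illustrative numbers, not TRL code):

```python
# A batch of mostly short sequences with one long outlier
seq_lens = [32, 48, 40, 2048]

padded_slots = len(seq_lens) * max(seq_lens)  # 4 * 2048 = 8192 token slots
actual_tokens = sum(seq_lens)                 # 2168 real tokens
print(f"padding waste: {1 - actual_tokens / padded_slots:.0%}")  # ~74%

# Truncating the outlier to 512 shrinks the batch to 4 * 512 = 2048 slots
```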
<hfoptions id="dpo">
<hfoption id="DPO">
@@ -30,7 +30,15 @@ To set the truncation parameters, use the following code snippet:
```python
from trl import DPOConfig

-training_args = DPOConfig(..., max_prompt_length=..., max_completion_length=..., max_length=...)
+training_args = DPOConfig(..., max_prompt_length=..., max_length=...)
```

+You can also use the `max_completion_length` parameter to truncate the completion, though this is less common since the goal is typically to preserve the completion's full length whenever possible.
+
+```python
+from trl import DPOConfig
+
+training_args = DPOConfig(..., max_completion_length=...)
+```

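For instance, with illustrative values that mirror the new defaults:

```python
from trl import DPOConfig

training_args = DPOConfig(output_dir="dpo-model", max_prompt_length=512, max_length=1024)
```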
</hfoption>
8 changes: 4 additions & 4 deletions trl/trainer/bco_config.py
@@ -28,10 +28,10 @@ class BCOConfig(TrainingArguments):
command line.

Parameters:
-max_length (`int` or `None`, *optional*, defaults to `None`):
+max_length (`int` or `None`, *optional*, defaults to `1024`):
Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
to use the default data collator.
-max_prompt_length (`int` or `None`, *optional*, defaults to `None`):
+max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
Maximum length of the prompt. This argument is required if you want to use the default data collator.
max_completion_length (`int` or `None`, *optional*, defaults to `None`):
Maximum length of the completion. This argument is required if you want to use the default data collator
@@ -74,14 +74,14 @@ class BCOConfig(TrainingArguments):
"""

max_length: Optional[int] = field(
-default=None,
+default=1024,
metadata={
"help": "Maximum length of the sequences (prompt + completion) in the batch. "
"This argument is required if you want to use the default data collator."
},
)
max_prompt_length: Optional[int] = field(
-default=None,
+default=512,
metadata={
"help": "Maximum length of the prompt. "
"This argument is required if you want to use the default data collator."
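With these defaults, `BCOConfig` now truncates out of the box; the CPO, KTO, and ORPO configs below follow the same pattern. A quick sketch of keeping or opting out of truncation (assuming, per the docstrings, that `None` disables it):

```python
from trl import BCOConfig

training_args = BCOConfig(output_dir="bco-model")  # truncates to max_length=1024, max_prompt_length=512
no_truncation = BCOConfig(output_dir="bco-model", max_length=None, max_prompt_length=None)
```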
8 changes: 4 additions & 4 deletions trl/trainer/cpo_config.py
@@ -31,10 +31,10 @@ class CPOConfig(TrainingArguments):
learning_rate (`float`, *optional*, defaults to `1e-6`):
Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
[`~transformers.TrainingArguments`].
-max_length (`int` or `None`, *optional*, defaults to `None`):
+max_length (`int` or `None`, *optional*, defaults to `1024`):
Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
to use the default data collator.
-max_prompt_length (`int` or `None`, *optional*, defaults to `None`):
+max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
Maximum length of the prompt. This argument is required if you want to use the default data collator.
max_completion_length (`int` or `None`, *optional*, defaults to `None`):
Maximum length of the completion. This argument is required if you want to use the default data collator
@@ -86,11 +86,11 @@ class CPOConfig(TrainingArguments):
},
)
max_length: Optional[int] = field(
-default=None,
+default=1024,
metadata={"help": "Maximum length of the sequences (prompt + completion) in the batch."},
)
max_prompt_length: Optional[int] = field(
-default=None,
+default=512,
metadata={
"help": "Maximum length of the prompt. This argument is required if you want to use the default data "
"collator and your model is an encoder-decoder."
8 changes: 4 additions & 4 deletions trl/trainer/dpo_config.py
@@ -73,11 +73,11 @@ class DPOConfig(TrainingArguments):
Padding value to use for labels.
truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
Truncation mode to use when the prompt is too long, either `keep_end` or `keep_start`.
-max_prompt_length (`int` or `None`, *optional*, defaults to `None`):
+max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
Maximum length of the prompt.
max_completion_length (`int` or `None`, *optional*, defaults to `None`):
Maximum length of the completion.
-max_length (`int` or `None`, *optional*, defaults to `None`):
+max_length (`int` or `None`, *optional*, defaults to `1024`):
Maximum length of the full sequence (prompt + completion).
padding_free (`bool`, *optional*, defaults to `False`):
Whether forward passes are performed without padding by flattening all sequences in the batch
@@ -224,15 +224,15 @@ class DPOConfig(TrainingArguments):
},
)
max_prompt_length: Optional[int] = field(
-default=None,
+default=512,
metadata={"help": "Maximum length of the prompt."},
)
max_completion_length: Optional[int] = field(
default=None,
metadata={"help": "Maximum length of the completion."},
)
max_length: Optional[int] = field(
-default=None,
+default=1024,
metadata={"help": "Maximum length of the full sequence (prompt + completion)."},
)
padding_free: bool = field(
8 changes: 4 additions & 4 deletions trl/trainer/kto_config.py
@@ -31,10 +31,10 @@ class KTOConfig(TrainingArguments):
learning_rate (`float`, *optional*, defaults to `5e-7`):
Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
[`~transformers.TrainingArguments`].
-max_length (`int` or `None`, *optional*, defaults to `None`):
+max_length (`int` or `None`, *optional*, defaults to `1024`):
Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
to use the default data collator.
-max_prompt_length (`int` or `None`, *optional*, defaults to `None`):
+max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
Maximum length of the prompt. This argument is required if you want to use the default data collator.
max_completion_length (`int` or `None`, *optional*, defaults to `None`):
Maximum length of the completion. This argument is required if you want to use the default data collator
@@ -88,11 +88,11 @@ class KTOConfig(TrainingArguments):
},
)
max_length: Optional[int] = field(
-default=None,
+default=1024,
metadata={"help": "Maximum length of the sequences (prompt + completion) in the batch."},
)
max_prompt_length: Optional[int] = field(
-default=None,
+default=512,
metadata={
"help": "Maximum length of the prompt. This argument is required if you want to use the default data "
"collator and your model is an encoder-decoder."
2 changes: 1 addition & 1 deletion trl/trainer/online_dpo_config.py
@@ -90,7 +90,7 @@ class OnlineDPOConfig(TrainingArguments):
metadata={"help": "Maximum number of tokens to generate per completion."},
)
max_length: int = field(
-default=256,
+default=512,
metadata={
"help": "Maximum total length of the sequence (prompt + completion) used to compute log probabilities. If "
"the sequence exceeds this limit, the leftmost tokens will be truncated to preserve as much of the "
8 changes: 4 additions & 4 deletions trl/trainer/orpo_config.py
@@ -31,10 +31,10 @@ class ORPOConfig(TrainingArguments):
learning_rate (`float`, *optional*, defaults to `1e-6`):
Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
[`~transformers.TrainingArguments`].
-max_length (`int` or `None`, *optional*, defaults to `None`):
+max_length (`int` or `None`, *optional*, defaults to `1024`):
Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
to use the default data collator.
-max_prompt_length (`int` or `None`, *optional*, defaults to `None`):
+max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
Maximum length of the prompt. This argument is required if you want to use the default data collator.
max_completion_length (`int` or `None`, *optional*, defaults to `None`):
Maximum length of the completion. This argument is required if you want to use the default data collator
@@ -71,11 +71,11 @@ class ORPOConfig(TrainingArguments):
},
)
max_length: Optional[int] = field(
-default=None,
+default=1024,
metadata={"help": "Maximum length of the sequences (prompt + completion) in the batch."},
)
max_prompt_length: Optional[int] = field(
-default=None,
+default=512,
metadata={
"help": "Maximum length of the prompt. This argument is required if you want to use the default data "
"collator and your model is an encoder-decoder."
10 changes: 8 additions & 2 deletions trl/trainer/prm_config.py
@@ -31,8 +31,10 @@ class PRMConfig(TrainingArguments):
learning_rate (`float`, *optional*, defaults to `1e-5`):
Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
[`~transformers.TrainingArguments`].
-max_length (`int` or `None`, *optional*, defaults to `None`):
+max_length (`int` or `None`, *optional*, defaults to `1024`):
Maximum length of the sequences (prompt + completion) used for truncation.
+max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
+Maximum length of the prompt used for truncation.
max_completion_length (`int` or `None`, *optional*, defaults to `None`):
Maximum length of the completion used for truncation. The completion is the concatenation of the steps.
disable_dropout (`bool`, *optional*, defaults to `True`):
@@ -53,9 +55,13 @@
},
)
max_length: Optional[int] = field(
-default=None,
+default=1024,
metadata={"help": "Maximum length of the sequences (prompt + completion) used for truncation."},
)
+max_prompt_length: Optional[int] = field(
+default=512,
+metadata={"help": "Maximum length of the prompt used for truncation."},
+)
max_completion_length: Optional[int] = field(
default=None,
metadata={
21 changes: 17 additions & 4 deletions trl/trainer/prm_trainer.py
@@ -150,6 +150,7 @@ def __init__(
"tokenizer": processing_class,
"step_separator": args.step_separator,
"max_length": args.max_length,
"max_prompt_length": args.max_prompt_length,
"max_completion_length": args.max_completion_length,
"train_on_last_step_only": args.train_on_last_step_only,
}
@@ -204,7 +205,14 @@ def __init__(

@staticmethod
def tokenize_row(
-features, tokenizer, step_separator, max_length, max_completion_length, train_on_last_step_only, is_eval
+features,
+tokenizer,
+step_separator,
+max_length,
+max_prompt_length,
+max_completion_length,
+train_on_last_step_only,
+is_eval,
):
r"""
Tokenize a row of the dataset.
@@ -218,6 +226,8 @@ def tokenize_row(
Separator between steps in the completion.
max_length (`int` or `None`):
Maximum length of the sequences (prompt + completion). If `None`, the sequences are not truncated.
+max_prompt_length (`int` or `None`):
+Maximum length of the prompt. If `None`, the prompt is not truncated.
max_completion_length (`int` or `None`):
Maximum length of the completion sequences. If `None`, the completion sequences are not truncated.
train_on_last_step_only (`bool`):
@@ -264,13 +274,16 @@ def tokenize_row(
completion_ids = list(chain(*completions_ids))
labels = list(chain(*labels))

-if tokenizer.bos_token_id is not None:
-prompt_ids = [tokenizer.bos_token_id] + prompt_ids

+# Truncate prompt and completion sequences
+if max_prompt_length is not None:
+prompt_ids = prompt_ids[-max_prompt_length:]
if max_completion_length is not None:
completion_ids = completion_ids[:max_completion_length]
labels = labels[:max_completion_length]

+if tokenizer.bos_token_id is not None:
+prompt_ids = [tokenizer.bos_token_id] + prompt_ids

input_ids = prompt_ids + completion_ids
labels = [-100] * len(prompt_ids) + labels

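Note the reordering above: prompt truncation (`prompt_ids[-max_prompt_length:]`) now runs before the BOS token is prepended, so left-truncation cannot slice the BOS token off. A toy illustration of why the order matters (not TRL code):

```python
bos = 0
prompt_ids = [11, 12, 13, 14]
max_prompt_length = 2

# BOS first, then left-truncate: BOS is lost
bos_first = ([bos] + prompt_ids)[-max_prompt_length:]      # [13, 14]

# Left-truncate first, then BOS: BOS survives
truncate_first = [bos] + prompt_ids[-max_prompt_length:]   # [0, 13, 14]
```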
4 changes: 2 additions & 2 deletions trl/trainer/reward_config.py
@@ -28,7 +28,7 @@ class RewardConfig(TrainingArguments):
command line.

Parameters:
-max_length (`int` or `None`, *optional*, defaults to `None`):
+max_length (`int` or `None`, *optional*, defaults to `1024`):
Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
to use the default data collator.
disable_dropout (`bool`, *optional*, defaults to `True`):
@@ -44,7 +44,7 @@ class RewardConfig(TrainingArguments):
"""

max_length: Optional[int] = field(
-default=None,
+default=1024,
metadata={
"help": "Maximum length of the sequences (prompt + completion) in the batch. This argument is required if "
"you want to use the default data collator."
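With the trainer-level `max_length` argument removed in `reward_trainer.py` below, the truncation length is now set in a single place (illustrative usage):

```python
from trl import RewardConfig

# max_length now defaults to 1024 and is configured only through RewardConfig
training_args = RewardConfig(output_dir="reward-model", max_length=1024)
```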
9 changes: 2 additions & 7 deletions trl/trainer/reward_trainer.py
@@ -105,7 +105,6 @@ def __init__(
None,
),
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
-max_length: Optional[int] = None,
peft_config: Optional[dict] = None,
):
"""
@@ -140,10 +139,6 @@ def __init__(
peft_config (`dict`, defaults to `None`):
The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
"""
-if max_length is not None and args.max_length is not None:
-raise ValueError(
-"You cannot specify both `max_length` and `args.max_length`. Please use the `RewardConfig` to set `max_length` once."
-)
if not is_peft_available() and peft_config is not None:
raise ValueError(
"PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
@@ -182,8 +177,8 @@ def __init__(
raise ValueError(
"A processing_class must be specified when using the default RewardDataCollatorWithPadding"
)
-if max_length is None:
-max_length = 512 if args.max_length is None else args.max_length
+
+max_length = args.max_length

data_collator = RewardDataCollatorWithPadding(processing_class)
