Fixed doc string and docs for the SFTConfig update (#1706)
GuilhermeFreire authored Jun 6, 2024
1 parent 275d33b commit 0bdc638
Showing 2 changed files with 10 additions and 10 deletions.
14 changes: 7 additions & 7 deletions docs/source/sft_trainer.mdx
@@ -24,7 +24,7 @@ sft_config = SFTConfig(
 trainer = SFTTrainer(
     "facebook/opt-350m",
     train_dataset=dataset,
-    args=training_args,
+    args=sft_config,
 )
 trainer.train()
 ```
@@ -263,23 +263,23 @@ To properly format your input make sure to process all the examples by looping o

 ### Packing dataset ([`ConstantLengthDataset`])

-[`SFTTrainer`] supports _example packing_, where multiple short examples are packed in the same input sequence to increase training efficiency. This is done with the [`ConstantLengthDataset`] utility class that returns constant length chunks of tokens from a stream of examples. To enable the usage of this dataset class, simply pass `packing=True` to the [`SFTTrainer`] constructor.
+[`SFTTrainer`] supports _example packing_, where multiple short examples are packed in the same input sequence to increase training efficiency. This is done with the [`ConstantLengthDataset`] utility class that returns constant length chunks of tokens from a stream of examples. To enable the usage of this dataset class, simply pass `packing=True` to the [`SFTConfig`] constructor.

 ```python
 ...
+sft_config = SFTConfig(packing=True, dataset_text_field="text",)

 trainer = SFTTrainer(
     "facebook/opt-350m",
     train_dataset=dataset,
-    dataset_text_field="text",
-    packing=True
+    args=sft_config
 )

 trainer.train()
 ```

 Note that if you use a packed dataset and if you pass `max_steps` in the training arguments you will probably train your models for more than few epochs, depending on the way you have configured the packed dataset and the training protocol. Double check that you know and understand what you are doing.
-If you don't want to pack your `eval_dataset`, you can pass `eval_packing=False` to the `SFTTrainer` init method.
+If you don't want to pack your `eval_dataset`, you can pass `eval_packing=False` to the `SFTConfig` init method.

 #### Customize your prompts using packed dataset

@@ -300,11 +300,11 @@ trainer = SFTTrainer(

 trainer.train()
 ```
-You can also customize the [`ConstantLengthDataset`] much more by directly passing the arguments to the [`SFTTrainer`] constructor. Please refer to that class' signature for more information.
+You can also customize the [`ConstantLengthDataset`] much more by directly passing the arguments to the [`SFTConfig`] constructor. Please refer to that class' signature for more information.

 ### Control over the pretrained model

-You can directly pass the kwargs of the `from_pretrained()` method to the [`SFTTrainer`]. For example, if you want to load a model in a different precision, analogous to
+You can directly pass the kwargs of the `from_pretrained()` method to the [`SFTConfig`]. For example, if you want to load a model in a different precision, analogous to

 ```python
 model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=torch.bfloat16)
6 changes: 3 additions & 3 deletions trl/trainer/sft_trainer.py
@@ -66,9 +66,9 @@ class SFTTrainer(Trainer):
             The model to train, can be a `PreTrainedModel`, a `torch.nn.Module` or a string with the model name to
             load from cache or download. The model can be also converted to a `PeftModel` if a `PeftConfig` object is
             passed to the `peft_config` argument.
-        args (Optional[`transformers.TrainingArguments`]):
-            The arguments to tweak for training. Please refer to the official documentation of `transformers.TrainingArguments`
-            for more information.
+        args (Optional[`SFTConfig`]):
+            The arguments to tweak for training. Will default to a basic instance of [`SFTConfig`] with the `output_dir`
+            set to a directory named *tmp_trainer* in the current directory if not provided.
         data_collator (Optional[`transformers.DataCollator`]):
             The data collator to use for training.
         train_dataset (Optional[`datasets.Dataset`]):
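
The docs touched by this commit describe _example packing_ as returning constant-length chunks of tokens from a stream of examples. A minimal pure-Python sketch of that idea follows. It is not TRL's `ConstantLengthDataset` implementation: the function name `pack_examples`, the token ids, and the choice to drop the incomplete tail are assumptions made only for illustration.

```python
from typing import Iterable, Iterator, List


def pack_examples(
    tokenized: Iterable[List[int]],
    seq_length: int,
    eos_token_id: int,
) -> Iterator[List[int]]:
    """Yield fixed-length token chunks from a stream of tokenized examples.

    Hypothetical helper for illustration; not part of TRL.
    """
    buffer: List[int] = []
    for tokens in tokenized:
        # Concatenate examples, separating them with an end-of-sequence token.
        buffer.extend(tokens)
        buffer.append(eos_token_id)
        # Emit as many full-length chunks as the buffer currently holds.
        while len(buffer) >= seq_length:
            yield buffer[:seq_length]
            buffer = buffer[seq_length:]
    # Assumption of this sketch: a trailing partial chunk is dropped.


# Four short "tokenized" examples packed into sequences of length 4.
examples = [[1, 2, 3], [4, 5], [6, 7, 8, 9], [10]]
packed = list(pack_examples(examples, seq_length=4, eos_token_id=0))
```

Because several short examples share one sequence, the number of packed sequences is usually smaller than the number of raw examples. This is one reason a fixed `max_steps` can correspond to more passes over the data than expected, as the note in the docs warns.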