Skip to content

Commit

Permalink
Skip packing validation (#1673)
Browse files Browse the repository at this point in the history
* Add test for skipping preprocessing when packing=False

Signed-off-by: Alex-Brooks <[email protected]>

* Allow skipping of validation for packing=False

Signed-off-by: Alex-Brooks <[email protected]>

* Use dummy dataset in no packing preproc test

Signed-off-by: Alex-Brooks <[email protected]>

---------

Signed-off-by: Alex-Brooks <[email protected]>
  • Loading branch information
alex-jw-brooks authored Jun 3, 2024
1 parent 6c203f9 commit 4eb0b90
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 1 deletion.
23 changes: 23 additions & 0 deletions tests/test_sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1156,6 +1156,29 @@ def test_sft_trainer_skip_prepare_dataset(self):
assert trainer.train_dataset.features == self.dummy_vsft_instruction_dataset.features
assert trainer.eval_dataset.features == self.dummy_vsft_instruction_dataset.features

def test_sft_trainer_skip_prepare_dataset_with_no_packing(self):
    """Build an unpacked SFTTrainer that opts out of dataset preparation.

    The config sets ``packing=False`` together with
    ``dataset_kwargs={"skip_prepare_dataset": True}`` and supplies neither a
    ``dataset_text_field`` nor a ``formatting_func``. The test then checks
    that the trainer's training dataset exposes the same features as the
    dummy dataset it was given.
    """
    # Settings that do not depend on the temporary output directory.
    config_kwargs = {
        "dataloader_drop_last": True,
        "eval_strategy": "steps",
        "max_steps": 4,
        "eval_steps": 2,
        "save_steps": 2,
        "per_device_train_batch_size": 2,
        "gradient_checkpointing": True,
        "remove_unused_columns": False,
        "packing": False,
        "dataset_kwargs": {"skip_prepare_dataset": True},
    }

    with tempfile.TemporaryDirectory() as tmp_dir:
        training_args = SFTConfig(output_dir=tmp_dir, **config_kwargs)
        trainer = SFTTrainer(
            model=self.model_id,
            args=training_args,
            train_dataset=self.dummy_dataset,
        )

        expected_features = self.dummy_dataset.features
        assert trainer.train_dataset.features == expected_features

@requires_pil
def test_sft_trainer_llava(self):
with tempfile.TemporaryDirectory() as tmp_dir:
Expand Down
10 changes: 9 additions & 1 deletion trl/trainer/sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,15 @@ def make_inputs_require_grad(module, input, output):
dataset_kwargs["add_special_tokens"] = False

if not args.packing:
if args.dataset_text_field is None and formatting_func is None:
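# If we aren't skipping data preparation, then a dataset_text_field
# or formatting_func must be provided.
# NOTE(review): the condition below raises only when
# `skip_prepare_dataset` is truthy, which is the opposite of what this
# comment describes; the skip check looks like it should be negated — confirm.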
if (
args.dataset_text_field is None
and formatting_func is None
and dataset_kwargs is not None
and "skip_prepare_dataset" in dataset_kwargs
and dataset_kwargs["skip_prepare_dataset"]
):
raise ValueError(
"You passed `packing=False` to the SFTTrainer/SFTConfig, but you didn't pass a `dataset_text_field` or `formatting_func` argument."
)
Expand Down

0 comments on commit 4eb0b90

Please sign in to comment.