diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py
index dd896823e2..30729a6a41 100644
--- a/tests/test_sft_trainer.py
+++ b/tests/test_sft_trainer.py
@@ -1156,6 +1156,29 @@ def test_sft_trainer_skip_prepare_dataset(self):
         assert trainer.train_dataset.features == self.dummy_vsft_instruction_dataset.features
         assert trainer.eval_dataset.features == self.dummy_vsft_instruction_dataset.features
 
+    def test_sft_trainer_skip_prepare_dataset_with_no_packing(self):
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            training_args = SFTConfig(
+                output_dir=tmp_dir,
+                dataloader_drop_last=True,
+                eval_strategy="steps",
+                max_steps=4,
+                eval_steps=2,
+                save_steps=2,
+                per_device_train_batch_size=2,
+                gradient_checkpointing=True,
+                remove_unused_columns=False,
+                packing=False,
+                dataset_kwargs={"skip_prepare_dataset": True},
+            )
+
+            trainer = SFTTrainer(
+                model=self.model_id,
+                args=training_args,
+                train_dataset=self.dummy_dataset,
+            )
+            assert trainer.train_dataset.features == self.dummy_dataset.features
+
     @requires_pil
     def test_sft_trainer_llava(self):
         with tempfile.TemporaryDirectory() as tmp_dir:
diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py
index c3bb7c7c98..a11c99128e 100644
--- a/trl/trainer/sft_trainer.py
+++ b/trl/trainer/sft_trainer.py
@@ -319,7 +319,16 @@ def make_inputs_require_grad(module, input, output):
             dataset_kwargs["add_special_tokens"] = False
 
         if not args.packing:
-            if args.dataset_text_field is None and formatting_func is None:
+            # If we aren't skipping data preparation, then a dataset_text_field
+            # or formatting_func must be provided.
+            if (
+                args.dataset_text_field is None
+                and formatting_func is None
+                and not (
+                    dataset_kwargs is not None
+                    and dataset_kwargs.get("skip_prepare_dataset", False)
+                )
+            ):
                 raise ValueError(
                     "You passed `packing=False` to the SFTTrainer/SFTConfig, but you didn't pass a `dataset_text_field` or `formatting_func` argument."
                 )
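
Usage sketch of the pattern this patch enables: with `packing=False` and `dataset_kwargs={"skip_prepare_dataset": True}`, a pre-tokenized dataset is passed through untouched, so no `dataset_text_field` or `formatting_func` is needed. This assumes a TRL build with the patch applied; the model id, column names, and token values below are illustrative placeholders, not taken from the test suite.

    from datasets import Dataset
    from trl import SFTConfig, SFTTrainer

    # Hypothetical pre-tokenized rows; in practice these would come from
    # your own tokenization pipeline.
    train_dataset = Dataset.from_dict(
        {
            "input_ids": [[1, 2, 3, 4], [1, 5, 6, 2]],
            "attention_mask": [[1, 1, 1, 1], [1, 1, 1, 1]],
            "labels": [[1, 2, 3, 4], [1, 5, 6, 2]],
        }
    )

    training_args = SFTConfig(
        output_dir="out",
        packing=False,
        remove_unused_columns=False,
        # Hand the dataset to the trainer as-is; with this patch, no
        # dataset_text_field or formatting_func is required in this case.
        dataset_kwargs={"skip_prepare_dataset": True},
    )

    trainer = SFTTrainer(
        model="gpt2",  # any causal LM checkpoint id; "gpt2" is only an example
        args=training_args,
        train_dataset=train_dataset,
    )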