SFTTrainer: Fix backward compatibility issue with TrainingArguments #1707

Merged 2 commits on Jun 6, 2024
26 changes: 26 additions & 0 deletions tests/test_sft_trainer.py
@@ -25,6 +25,7 @@
    AutoProcessor,
    AutoTokenizer,
    LlavaForConditionalGeneration,
    TrainingArguments,
)

from trl import SFTConfig, SFTTrainer
@@ -213,6 +214,31 @@ def test_constant_length_dataset(self):
            decoded_text = self.tokenizer.decode(example["input_ids"])
            assert ("Question" in decoded_text) and ("Answer" in decoded_text)

    def test_sft_trainer_backward_compatibility(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = TrainingArguments(
                output_dir=tmp_dir,
                eval_strategy="steps",
                max_steps=4,
                eval_steps=2,
                save_steps=2,
                per_device_train_batch_size=2,
            )

            trainer = SFTTrainer(
                model=self.model_id,
                args=training_args,
                train_dataset=self.train_dataset,
                eval_dataset=self.eval_dataset,
            )

            trainer.train()

            assert trainer.state.log_history[-1]["train_loss"] is not None
            assert trainer.state.log_history[0]["eval_loss"] is not None

            assert "model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2")

    def test_sft_trainer(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = SFTConfig(
2 changes: 2 additions & 0 deletions trl/trainer/sft_trainer.py
@@ -145,6 +145,8 @@ def __init__(
            output_dir = "tmp_trainer"
            warnings.warn(f"No `SFTConfig` passed, using `output_dir={output_dir}`.")
            args = SFTConfig(output_dir=output_dir)
        elif args is not None and args.__class__.__name__ == "TrainingArguments":
            args = SFTConfig(**args.to_dict())
Comment on lines +148 to +149
Contributor

Nice fix!

This still fails for me with the error TypeError: SFTConfig.__init__() got an unexpected keyword argument 'cache_dir', because TrainingArguments has a field for cache_dir but SFTConfig does not.

Contributor Author

Hmm, that's strange: there is no cache_dir field in TrainingArguments (https://github.com/huggingface/transformers/blob/main/src/transformers/training_args.py). Can you show us how you get that error?

Apologies, I see the confusion now: my project builds a class off of transformers.TrainingArguments that we pass into trl.SFTTrainer, and the additional parameters we add on top of TrainingArguments are what cause the discrepancy.
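
Editor's note for anyone who hits the same TypeError: the SFTConfig(**args.to_dict()) conversion forwards every key of the dict, so any extra field defined on a custom subclass of TrainingArguments would be rejected by SFTConfig.__init__. The sketch below shows one possible workaround under that assumption; the MyTrainingArguments subclass, its cache_dir field, and the to_sft_config helper are hypothetical illustrations and are not part of this PR.

import dataclasses

from transformers import TrainingArguments
from trl import SFTConfig


@dataclasses.dataclass
class MyTrainingArguments(TrainingArguments):
    # Hypothetical project-specific field that SFTConfig does not define.
    cache_dir: str = "/tmp/model_cache"


def to_sft_config(args: TrainingArguments) -> SFTConfig:
    # Keep only the keys that SFTConfig's __init__ accepts, dropping any
    # extra fields contributed by a custom subclass.
    allowed = {f.name for f in dataclasses.fields(SFTConfig) if f.init}
    kwargs = {k: v for k, v in args.to_dict().items() if k in allowed}
    return SFTConfig(**kwargs)


# Build the custom arguments as usual, then convert them before passing
# the result to SFTTrainer as `args`.
custom_args = MyTrainingArguments(output_dir="tmp_dir")
sft_config = to_sft_config(custom_args)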


        if model_init_kwargs is not None:
            warnings.warn(