Using PEFT causes model to not predict EOS #1578

Closed
Km3888 opened this issue Apr 23, 2024 · 2 comments

Km3888 commented Apr 23, 2024

System Info

peft version: 0.9.0
accelerate version: 0.27.2
transformers version: 4.37.0
trl version: 0.7.12.dev0
base model: openai-community/gpt2
hardware: 2xA100

Issue

I'm fine-tuning GPT-2 with LoRA (via PEFT) through TRL and have noticed that my trained model assigns very low probability to the EOS token, which causes it to always generate the maximum number of tokens.

After trying a few different fixes, I ran the same code without the PEFT option, i.e. full fine-tuning of the base model. The problem resolved immediately.

To make the comparison clear I created a toy case with a dataset that contains the same datapoint ("Hello <|endoftext|>") repeatedly. I then overfit on this dataset with a small batch size for a few dozen iterations. To see the effect on the probability of generating the eos_token I inserted the following code fragment in my compute_metrics method:

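logits, labels = eval_preds
# Positions (batch index, sequence index) where the label is the EOS token
eos_indices = np.where(labels == tokenizer.eos_token_id)
model_distribution = torch.softmax(torch.tensor(logits), dim=-1).numpy()
# Probability assigned to EOS at those positions. GPT-2's EOS id (50256) is
# also the last vocabulary index, so this matches the original index of -1.
eos_probs = model_distribution[eos_indices[0], eos_indices[1], tokenizer.eos_token_id]
eos_probs = [format(x * 100, '.3f') for x in eos_probs.tolist()]
print('eos probs:', eos_probs)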

With basic full fine-tuning, the EOS token probability converges to 1 almost immediately as the model memorizes the location of the EOS tokens. However, if I just use TRL's code for LoRA via PEFT, the printed values remain close to zero and don't increase at all.

I've seen some references online suggesting that this could be caused by LoRA not updating the model's embedding matrix, so I added the following change to the peft_config: peft_config.modules_to_save = ["wte"]. This doesn't have any effect on the results. I also doubt this is the cause: when I run the full supervised fine-tuning, I don't see any change in the embedding matrix but get the desired results anyway.

Any help would be appreciated, as I would like to avoid a full fine-tuning but right now have no way of getting a functional model with PEFT.
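One detail that may be worth checking when reading these numbers: a causal LM's logits at position t score the token at position t+1, while the labels handed to compute_metrics are not shifted. The sketch below is a plain-Python illustration of looking up the EOS probability one step earlier. The helper names softmax and shifted_eos_probs and the toy values are hypothetical, and this is not a claim about what causes the behavior reported here.

```python
import math


def softmax(row):
    """Numerically stable softmax over a 1-D sequence of logits."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def shifted_eos_probs(logits, labels, eos_id):
    """Probability of EOS at each step whose *next* token is EOS.

    logits: [batch][seq][vocab], labels: [batch][seq]
    (nested lists or NumPy arrays both work).
    """
    probs = []
    for seq_logits, seq_labels in zip(logits, labels):
        # Start at 1: there are no logits that predict the first token.
        for t in range(1, len(seq_labels)):
            if seq_labels[t] == eos_id:
                # Logits at t-1 are the prediction for the token at t.
                probs.append(softmax(seq_logits[t - 1])[eos_id])
    return probs


# Toy check: vocabulary of 4 tokens, EOS id 3, one sequence "0 3".
toy_logits = [[[0.0, 0.0, 0.0, 10.0],   # step 0 strongly predicts EOS next
               [10.0, 0.0, 0.0, 0.0]]]  # step 1 predicts token 0 next
toy_labels = [[0, 3]]
toy_probs = shifted_eos_probs(toy_logits, toy_labels, eos_id=3)
print(toy_probs)
```

In this toy case the shifted lookup reads the logits at step 0, where EOS dominates, whereas an unshifted lookup would read step 1, where EOS has almost no mass.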

Reproduction

Use the following model_config (note the PEFT parameters) and training arguments:

ModelConfig(
    model_name_or_path='openai-community/gpt2',
    model_revision='main',
    torch_dtype=None,
    trust_remote_code=False,
    attn_implementation=None,
    use_peft=True,
    lora_r=64,
    lora_alpha=16,
    lora_dropout=0.05,
    lora_target_modules=None,
    lora_modules_to_save=None,
    load_in_8bit=False,
    load_in_4bit=False,
    bnb_4bit_quant_type='nf4',
    use_bnb_nested_quant=False,
)
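For reference, in the standard LoRA formulation the low-rank update is scaled by lora_alpha / lora_r, so the values above give a scale of 16 / 64 = 0.25. The sketch below works through that arithmetic in plain Python. The helper names are hypothetical, and it assumes the standard scaling (not a rank-stabilized variant) and the usual initialization with B set to zero. Whether this scale matters for the behavior reported here is not established.

```python
def lora_scaling(lora_alpha, lora_r):
    """Standard LoRA scale factor applied to the low-rank update."""
    return lora_alpha / lora_r


def matmul(a, b):
    """Plain-Python matrix product of two nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)]
            for row in a]


def lora_delta(B, A, lora_alpha, lora_r):
    """Effective weight update: (lora_alpha / lora_r) * B @ A."""
    scale = lora_scaling(lora_alpha, lora_r)
    return [[scale * v for v in row] for row in matmul(B, A)]


# Scale implied by the config in this issue.
scale = lora_scaling(lora_alpha=16, lora_r=64)
print(scale)

# Toy rank-2 adapter on a 3x3 weight (assumed shapes, for illustration only).
A = [[4.0, 0.0, 0.0],
     [0.0, 8.0, 0.0]]
B_init = [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]      # B starts at zero
B_trained = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]   # made-up trained values

delta_init = lora_delta(B_init, A, lora_alpha=1, lora_r=2)
delta_trained = lora_delta(B_trained, A, lora_alpha=1, lora_r=2)
print(delta_init)
print(delta_trained)
```

With B at zero the adapter contributes nothing at the start of training; after training, every entry of B @ A is multiplied by the same scale before being added to the frozen weight.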

python examples/scripts/sft_overfit.py \
    --model_name_or_path=$GPT2 \
    --report_to="wandb" \
    --learning_rate=1.41e-5 \
    --per_device_train_batch_size=4 \
    --per_device_eval_batch_size=4 \
    --gradient_accumulation_steps=1 \
    --output_dir='/scratch/km3888/gcode_peft/${SLURM_ARRAY_JOB_ID}_${SLURM_ARRAY_TASK_ID}' \
    --logging_steps=1 \
    --num_train_epochs=1 \
    --max_seq_length=1024 \
    --max_steps=20000 \
    --push_to_hub \
    --evaluation_strategy="steps" \
    --eval_steps=10 \
    --eval_accumulation_steps=1 \
    --gradient_checkpointing \
    --use_peft \
    --lora_r=64 \
    --lora_alpha=16

Create dataset:

import copy
import json

from datasets import load_dataset
from transformers import AutoTokenizer

# Toy dataset: the same example, ending in GPT-2's EOS token, repeated 1000 times
dummy_data = [{"text": "Hello <|endoftext|>"} for _ in range(1000)]
with open("dummy_data.json", "w") as f:
    json.dump(dummy_data, f)
full_dataset = load_dataset('json', data_files="dummy_data.json", split='train')
# add_eos is not defined in this snippet
full_dataset = full_dataset.map(lambda x: {'text': add_eos(x['text'])})
split_data = full_dataset.train_test_split(test_size=0.05)
train_dataset = split_data['train'].shuffle()
# Evaluate on a copy of the training data, since the goal is to overfit
eval_dataset = copy.deepcopy(train_dataset)
tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path, token=access_token, use_fast=True, add_eos=True)

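Set up custom evaluation function:

import evaluate
import numpy as np
import torch

def compute_metrics(eval_preds):
    metric = evaluate.load("accuracy", training_args.output_dir.split('/')[-1])
    logits, labels = eval_preds
    # Positions (batch index, sequence index) where the label is the EOS token
    eos_indices = np.where(labels == tokenizer.eos_token_id)
    model_distribution = torch.softmax(torch.tensor(logits), dim=-1).numpy()
    # Probability assigned to EOS at those positions. GPT-2's EOS id (50256) is
    # also the last vocabulary index, so this matches the original index of -1.
    eos_probs = model_distribution[eos_indices[0], eos_indices[1], tokenizer.eos_token_id]
    eos_probs = [format(x * 100, '.3f') for x in eos_probs.tolist()]
    print('eos probs:', eos_probs)
    # Token-level accuracy over the flattened predictions and labels
    predictions = np.argmax(logits, axis=-1)
    predictions = np.reshape(predictions.astype(np.int32), -1)
    labels = np.reshape(labels.astype(np.int32), -1)
    return metric.compute(predictions=predictions, references=labels)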

Instantiate and run SFTTrainer:

trainer = SFTTrainer(
    model=model_config.model_name_or_path,
    model_init_kwargs=model_kwargs,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    dataset_text_field="text",
    max_seq_length=args.max_seq_length,
    tokenizer=tokenizer,
    packing=False,
    peft_config=get_peft_config(model_config),
    compute_metrics=compute_metrics,
    dataset_num_proc=20,
)

trainer.train()

The eos_probs values printed in compute_metrics will be near zero.

Expected behavior

I would expect the above code to result in eos_probs values being nearly 1 after a few training iterations.

@derekelewis

Running into this same issue myself. Without PEFT, EOS is predicted fine. With PEFT, EOS is not predicted at all, which causes the text generation pipeline to continue until max_tokens.

github-actions bot commented Jun 3, 2024

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
