Learning to generate EOS tokens #1623

Closed · vwxyzjn opened this issue May 6, 2024 · 7 comments

Comments

@vwxyzjn (Contributor) commented May 6, 2024

@edbeeching and I noticed that the trained SFT models sometimes do not learn to stop their generations. In other words, the model never learns to generate EOS tokens.

Upon some digging, I noticed this is mainly an issue with the dataset preprocessing. In particular, if we simply pass a dataset like https://huggingface.co/datasets/timdettmers/openassistant-guanaco to the SFTTrainer, the trainer may not append an EOS token to the completion.

If we run for item1, item2 in zip(inputs["input_ids"][1], inputs["attention_mask"][1]): print(item1, item2) at https://github.com/huggingface/transformers/blob/91d155ea92da372b319a79dd4eef69533ee15170/src/transformers/trainer.py#L3207 with our SFT example, we get

python examples/scripts/sft.py \
    --model_name_or_path="facebook/opt-350m" \
    --report_to="wandb" \
    --learning_rate=1.41e-5 \
    --per_device_train_batch_size=2 \
    --gradient_accumulation_steps=16 \
    --output_dir="sft_openassistant-guanaco" \
    --logging_steps=1 \
    --num_train_epochs=3 \
    --max_steps=-1 \
    --push_to_hub \
    --gradient_checkpointing \
    --dataset_text_field text
[screenshot: tokenized input_ids and attention_mask for one sample]

Notice how the pad token / eos token corresponds to attention mask = 0.
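
This is easy to check outside the trainer as well. Below is a minimal sketch (the sample string is a made-up guanaco-style row, and the tokenizer just mirrors the command above; setting pad_token = eos_token reproduces the setup described in this report):

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("facebook/opt-350m")
    tok.pad_token = tok.eos_token  # mirrors the setup described above
    sample = "### Human: Hi### Assistant: Hello!"  # hypothetical guanaco-style row

    # Nothing in the preprocessing appends an EOS token to the completion,
    # so the last real token is ordinary text, not EOS.
    ids = tok(sample)["input_ids"]
    print(ids[-1] == tok.eos_token_id)  # expected: False

    # With padding, EOS tokens only appear as padding, i.e. with
    # attention_mask = 0, which matches the screenshot above.
    enc = tok(sample, padding="max_length", max_length=16)
    for token, mask in zip(tok.convert_ids_to_tokens(enc["input_ids"]), enc["attention_mask"]):
        print(token, mask)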

Potential solution

This can be resolved if we add an EOS token to the dataset itself. For example, the chat template

"{% for message in messages %}{{' ' + message['content']}}{% endfor %}{{ eos_token }}"

always appends an EOS token to each tokenized example, and as a result we get

python examples/scripts/minimal/sft.py \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 32 \
    --learning_rate 5e-05 \
    --logging_steps 10 \
    --evaluation_strategy epoch \
    --max_seq_length 1024 \
    --num_train_epochs 5 \
    --output_dir models/minimal/sft
[screenshot: tokenized input_ids and attention_mask for one sample]

Notice how the first eos token corresponds to attention mask = 1.
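
The template fix can also be sanity-checked in isolation. A small sketch (the tokenizer and the single-message conversation are placeholders, not taken from the run above) sets the template and confirms the rendered text now ends with EOS, which is why the first EOS in the batch gets attention_mask = 1:

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("facebook/opt-350m")  # placeholder tokenizer
    tok.chat_template = (
        "{% for message in messages %}{{' ' + message['content']}}"
        "{% endfor %}{{ eos_token }}"
    )

    messages = [{"role": "user", "content": "### Human: Hi### Assistant: Hello!"}]

    text = tok.apply_chat_template(messages, tokenize=False)
    print(text.endswith(tok.eos_token))  # True: the template always appends EOS

    ids = tok.apply_chat_template(messages, tokenize=True)
    print(ids[-1] == tok.eos_token_id)   # True: the EOS is a real, attended token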

@edbeeching (Collaborator)
As I mentioned on our internal Slack, we should probably add a line such as:

    if sft_config.packing is False:
        tokenizer.add_eos_token = True

This needs to be reverted before saving the model, as otherwise generation is broken:

    if sft_config.packing is False:
        # setting this as true breaks generation during evaluation
        tokenizer.add_eos_token = False

I tested these additions in h4 and they resolved many of the issues we saw with models trained with packing=False.
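
For concreteness, here is a sketch of how that toggle could be wired into a training script (hypothetical wiring, not current TRL code; add_eos_token exists on LLaMA-style tokenizers but not on every tokenizer, hence the hasattr guard):

    # Hypothetical wiring around an existing SFTTrainer setup (trainer, tokenizer
    # and sft_config come from the surrounding script).
    if sft_config.packing is False and hasattr(tokenizer, "add_eos_token"):
        tokenizer.add_eos_token = True   # every non-packed sample now ends with EOS

    trainer.train()

    if sft_config.packing is False and hasattr(tokenizer, "add_eos_token"):
        # revert before saving: leaving this enabled appends EOS to prompts too,
        # which breaks generation during evaluation
        tokenizer.add_eos_token = False

    trainer.save_model(sft_config.output_dir)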

@yananchen1989
Is this also an issue when packing=True? I also find that the generations from the SFT model are quite wordy.

@derekelewis commented May 9, 2024

@yananchen1989 I believe the answer is yes for both packing=True and packing=False. I'm seeing a lack of EOS prediction from SFTTrainer fine-tuned models when using chat templates. Still doing testing, but it doesn't seem to be an issue when not using chat templates and using formatting_func instead.

PEFT also seems to be a contributing factor: without PEFT, EOS is predicted correctly; with PEFT, it is not.

@vwxyzjn (Contributor, Author) commented May 15, 2024

Actually, I am not even sure setting tokenizer.pad_token = tokenizer.eos_token would work. Even if the dataset has an EOS token, what happens is that the attention_mask is set to 1, but the label is still set to -100, so the loss on the EOS token is still masked out.

for input_id, attention_mask, label in zip(inputs["input_ids"][0], inputs["attention_mask"][0], inputs["labels"][0]): print(f"{input_id=}, {attention_mask=}, {label=}")
input_id=tensor(15, device='cuda:0'), attention_mask=tensor(1, device='cuda:0'), label=tensor(15, device='cuda:0')
input_id=tensor(0, device='cuda:0'), attention_mask=tensor(1, device='cuda:0'), label=tensor(-100, device='cuda:0')
input_id=tensor(0, device='cuda:0'), attention_mask=tensor(0, device='cuda:0'), label=tensor(-100, device='cuda:0')
input_id=tensor(0, device='cuda:0'), attention_mask=tensor(0, device='cuda:0'), label=tensor(-100, device='cuda:0')
input_id=tensor(0, device='cuda:0'), attention_mask=tensor(0, device='cuda:0'), label=tensor(-100, device='cuda:0')
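
One hypothetical workaround (not something merged into TRL, just an illustration of the mechanics): since DataCollatorForLanguageModeling masks labels wherever the token equals the pad token, and pad_token == eos_token here, the genuine EOS gets -100 as well; a thin wrapper could restore labels at positions that are actually attended to.

    import torch
    from transformers import DataCollatorForLanguageModeling

    class KeepEosLabelsCollator(DataCollatorForLanguageModeling):
        # Hypothetical collator: restore labels wherever attention_mask == 1,
        # so a real EOS (which shares its id with the pad token) is not masked out.
        def torch_call(self, examples):
            batch = super().torch_call(examples)
            attended = batch["attention_mask"].bool()
            batch["labels"] = torch.where(attended, batch["input_ids"], batch["labels"])
            return batch

    # usage sketch: collator = KeepEosLabelsCollator(tokenizer=tokenizer, mlm=False)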

@yananchen1989
Yes, I agree that whether or not packing is set, the EOS token is not properly predicted, which causes lengthy outputs.

@vwxyzjn (Contributor, Author) commented May 21, 2024

@yananchen1989 FYI when packing is set this should not be a problem. See #1646 (comment).


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
