DPO models generate multiple / corrupted responses #1025

Open
Devy99 opened this issue Nov 22, 2023 · 57 comments
Labels
🏋 DPO (Related to DPO) · 🙋 help from community wanted (Open invitation for community members to contribute)

Comments

@Devy99

Devy99 commented Nov 22, 2023

Hi, I am running some tests with DPOTrainer to see how it works, but I have encountered some problems during the inference phase of the resulting model. In detail, this is the pipeline of operations I performed:

  1. I pre-trained a T5 model from scratch on natural language (English), following the instructions in the Hugging Face library. The tokenizer was trained with the sentencepiece library; the generated file (extension .model) was then loaded through the T5Tokenizer class, which accepts a .model file instead of a JSON file.

  2. I fine-tuned T5 using a very trivial dataset such as the following.

    | Input | Target |
    | --- | --- |
    | I love cats | a |
    | The cat is orange | b |
    | The cat is on the table | c |
    | The cat chased the mouse under the table. | d |

    In summary, if the word 'the' does not appear in the input the output is 'a', if it appears exactly once the output is 'b', and so on (see the sketch after this list). For fine-tuning, I did not use the SFTTrainer class but the classic Seq2SeqTrainer.

  3. Then, I performed DPO with the same inputs as the dataset above, but in JSON format. The code used is the same as the example in the repository; in this case, however, we used our fine-tuned T5 model and tokenizer (with the classes T5ForConditionalGeneration, T5Tokenizer, T5Config). You can find the JSON file and the full code at the end of this message.
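
For context, here is a minimal sketch of how the tokenizer from step 1 could be loaded and the toy dataset from step 2 regenerated. The file name spiece.model and the label_for helper are illustrative assumptions, not the original scripts.

```
# Minimal sketch (assumptions: a sentencepiece file named "spiece.model" and the
# count-of-"the" rule described above; this is not the original training script).
from transformers import T5Tokenizer

tokenizer = T5Tokenizer(vocab_file="spiece.model")  # load the sentencepiece .model file

def label_for(text: str) -> str:
    # 0 occurrences of "the" -> "a", 1 -> "b", 2 -> "c", 3+ -> "d"
    count = text.lower().split().count("the")
    return "abcd"[min(count, 3)]

inputs = [
    "I love cats",
    "The cat is orange",
    "The cat is on the table",
    "The cat chased the mouse under the table.",
]
toy_dataset = [{"input": text, "target": label_for(text)} for text in inputs]
```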

The problem arises in the inference phase of the model produced by the DPOTrainer: for several instances the generated output is 'a a a a a a', 'b b b b b b b', 'c c c c c c c c', and so on (the number of repetitions varies). This behavior becomes more pronounced as the number of training steps increases. Also, with more steps, words from the training inputs start appearing in the output (e.g., 'aaacat' is generated).

I cannot figure out the cause of this behavior. When running inference with the merely fine-tuned model, the output is as expected (i.e., one of 'a', 'b', 'c', or 'd'), so the problem is introduced during DPO training. I also tried the pre-trained 't5-small' model/tokenizer instead of the ones trained from scratch, but the problem persists.

I look forward to your feedback; let me know if more information or code snippets are needed.

DPO dataset:

```
[
  { 'prompt': 'I love cats', 'chosen': 'a', 'rejected': 'b' },
  { 'prompt': 'I love cats', 'chosen': 'a', 'rejected': 'c' },
  { 'prompt': 'I love cats', 'chosen': 'a', 'rejected': 'd' },
  { 'prompt': 'The cat is orange', 'chosen': 'b', 'rejected': 'a' },
  { 'prompt': 'The cat is orange', 'chosen': 'b', 'rejected': 'c' },
  { 'prompt': 'The cat is orange', 'chosen': 'b', 'rejected': 'd' },
  ...
]
```

DPO code:

```
# 0. imports
import os
from dataclasses import dataclass, field
from typing import Dict, Optional

import torch
from datasets import Dataset, load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, TrainingArguments, T5Config, T5Tokenizer, T5ForConditionalGeneration

from trl import DPOTrainer


# Define and parse arguments.
@dataclass
class ScriptArguments:
    """
    The arguments for the DPO training script.
    """

    # data parameters
    beta: Optional[float] = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"})

    # training parameters
    model_name_or_path: Optional[str] = field(
        default="../sft/results/final_checkpoint",
        metadata={"help": "the location of the SFT model name or path"},
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
    )
    use_fast_tokenizer: bool = field(
        default=False,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    use_auth_token: Optional[str] = field(
        default=None,
        metadata={
            "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
            "with private models)."
        },
    )

    train_file: Optional[str] = field(
        default=None, metadata={"help": "The input training data file (a jsonlines or csv file)."}
    )

    eval_file: Optional[str] = field(
        default=None, metadata={"help": "The input eval data file (a jsonlines or csv file)."}
    )

    learning_rate: Optional[float] = field(default=5e-4, metadata={"help": "optimizer learning rate"})
    lr_scheduler_type: Optional[str] = field(default="cosine", metadata={"help": "the lr scheduler type"})
    warmup_steps: Optional[int] = field(default=100, metadata={"help": "the number of warmup steps"})
    weight_decay: Optional[float] = field(default=0.05, metadata={"help": "the weight decay"})
    optimizer_type: Optional[str] = field(default="paged_adamw_32bit", metadata={"help": "the optimizer type"})

    per_device_train_batch_size: Optional[int] = field(default=4, metadata={"help": "train batch size per device"})
    per_device_eval_batch_size: Optional[int] = field(default=1, metadata={"help": "eval batch size per device"})
    gradient_accumulation_steps: Optional[int] = field(
        default=4, metadata={"help": "the number of gradient accumulation steps"}
    )
    gradient_checkpointing: Optional[bool] = field(
        default=True, metadata={"help": "whether to use gradient checkpointing"}
    )

    lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"})
    lora_dropout: Optional[float] = field(default=0.05, metadata={"help": "the lora dropout parameter"})
    lora_r: Optional[int] = field(default=8, metadata={"help": "the lora r parameter"})

    max_prompt_length: Optional[int] = field(default=512, metadata={"help": "the maximum prompt length"})
    max_length: Optional[int] = field(default=1024, metadata={"help": "the maximum sequence length"})
    max_steps: Optional[int] = field(default=1000, metadata={"help": "max number of training steps"})
    logging_steps: Optional[int] = field(default=10, metadata={"help": "the logging frequency"})
    save_steps: Optional[int] = field(default=100, metadata={"help": "the saving frequency"})
    eval_steps: Optional[int] = field(default=100, metadata={"help": "the evaluation frequency"})

    output_dir: Optional[str] = field(default="./results", metadata={"help": "the output directory"})
    log_freq: Optional[int] = field(default=1, metadata={"help": "the logging frequency"})

    # instrumentation
    sanity_check: Optional[bool] = field(default=False, metadata={"help": "only train on 1000 samples"})
    report_to: Optional[str] = field(
        default="wandb",
        metadata={
            "help": 'The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,'
            '`"comet_ml"`, `"mlflow"`, `"neptune"`, `"tensorboard"`,`"clearml"` and `"wandb"`. '
            'Use `"all"` to report to all integrations installed, `"none"` for no integrations.'
        },
    )
    # debug argument for distributed training
    ignore_bias_buffers: Optional[bool] = field(
        default=False,
        metadata={
            "help": "fix for DDP issues with LM bias/mask buffers - invalid scalar type,`inplace operation. See"
            "https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992"
        },
    )


def convert(
    dataset: Dataset = None,
    sanity_check: bool = False,
    cache_dir: str = None,
    num_proc=24,
) -> Dataset:
    """Load the dataset and convert it to the necessary format.

    The dataset is converted to a dictionary with the following structure:
    {
        'prompt': List[str],
        'chosen': List[str],
        'rejected': List[str],
    }
    """
    original_columns = dataset.column_names

    if sanity_check:
        dataset = dataset.select(range(min(len(dataset), 1000)))

    def return_prompt_and_responses(samples) -> Dict[str, str]:
        return {
            "prompt": samples["prompt"],
            "chosen": samples["chosen"],
            "rejected": samples["rejected"],
        }

    return dataset.map(
        return_prompt_and_responses,
        batched=True,
        num_proc=num_proc,
        remove_columns=original_columns,
    )


if __name__ == "__main__":
    parser = HfArgumentParser(ScriptArguments)
    script_args = parser.parse_args_into_dataclasses()[0]

    # 1. load a pretrained model
    config = T5Config.from_pretrained(
        script_args.config_name if script_args.config_name else script_args.model_name_or_path,
        cache_dir=script_args.cache_dir,
        revision=script_args.model_revision,
        use_auth_token=script_args.use_auth_token,
    )

    tokenizer = T5Tokenizer.from_pretrained(
        script_args.tokenizer_name if script_args.tokenizer_name else script_args.model_name_or_path,
        cache_dir=script_args.cache_dir,
        use_fast=script_args.use_fast_tokenizer,
        revision=script_args.model_revision,
        use_auth_token=script_args.use_auth_token,
    )

    model = T5ForConditionalGeneration.from_pretrained(
        script_args.model_name_or_path,
        config=config,
        cache_dir=script_args.cache_dir,
        revision=script_args.model_revision,
        use_auth_token=script_args.use_auth_token,
    )
    model.config.use_cache = False

    model_ref = T5ForConditionalGeneration.from_pretrained(
        script_args.model_name_or_path,
        config=config,
        cache_dir=script_args.cache_dir,
        revision=script_args.model_revision,
        use_auth_token=script_args.use_auth_token,
    )


    if script_args.ignore_bias_buffers:
        # torch distributed hack
        model._ddp_params_and_buffers_to_ignore = [
            name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool
        ]

    # 2. Load the dataset and split in train / eval
    train_dataset = load_dataset("json", data_files=script_args.train_file, split="train")

    train_dataset = convert(dataset=train_dataset, sanity_check=script_args.sanity_check)
    train_dataset = train_dataset.filter(
        lambda x: len(x["prompt"]) + len(x["chosen"]) <= script_args.max_length
        and len(x["prompt"]) + len(x["rejected"]) <= script_args.max_length
    )
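    # Note: the filter above compares character counts (string lengths), while the
    # trainer interprets max_length in tokens, so this is only a rough pre-filter
    # (the same applies to the eval filter below).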

    # 3. Load evaluation dataset
    eval_dataset = load_dataset("json", data_files=script_args.eval_file, split="train")

    eval_dataset = convert(dataset=eval_dataset, sanity_check=script_args.sanity_check)
    eval_dataset = eval_dataset.filter(
        lambda x: len(x["prompt"]) + len(x["chosen"]) <= script_args.max_length
        and len(x["prompt"]) + len(x["rejected"]) <= script_args.max_length
    )

    # 4. initialize training arguments:
    training_args = TrainingArguments(
        per_device_train_batch_size=script_args.per_device_train_batch_size,
        per_device_eval_batch_size=script_args.per_device_eval_batch_size,
        max_steps=script_args.max_steps,
        logging_steps=script_args.logging_steps,
        save_steps=script_args.save_steps,
        gradient_accumulation_steps=script_args.gradient_accumulation_steps,
        gradient_checkpointing=script_args.gradient_checkpointing,
        learning_rate=script_args.learning_rate,
        evaluation_strategy="steps",
        eval_steps=script_args.eval_steps,
        output_dir=script_args.output_dir,
        lr_scheduler_type=script_args.lr_scheduler_type,
        warmup_steps=script_args.warmup_steps,
        remove_unused_columns=False,
        run_name="dpo",
    )


    # 5. initialize the DPO trainer
    dpo_trainer = DPOTrainer(
        model,
        model_ref,
        args=training_args,
        beta=script_args.beta,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        max_prompt_length=script_args.max_prompt_length,
        max_length=script_args.max_length,
    )

    # 6. train
    dpo_trainer.train()
    dpo_trainer.save_model(script_args.output_dir)

    # 7. save
    output_dir = os.path.join(script_args.output_dir, "final_checkpoint")
    dpo_trainer.model.save_pretrained(output_dir)
```

@Ricardokevins

I encountered a similar problem.

@lvwerra
Member

lvwerra commented Nov 24, 2023

Is it possible that you didn't have EOS tokens in the fine-tuning/dpo phase? Then the model wouldn't know what token to produce after the letter and just keep generating things.

Also cc @kashif
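
As a quick check, here is a minimal sketch of one way to verify that tokenized targets end with the EOS token; the t5-small tokenizer and the example target "a" are only illustrative.

```
# Minimal sketch (assumption: a T5-style tokenizer; the example target is illustrative).
from transformers import T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")

# T5Tokenizer appends </s> automatically when add_special_tokens=True (the default).
ids = tokenizer("a", add_special_tokens=True)["input_ids"]
print(ids[-1] == tokenizer.eos_token_id)  # should print True

# If targets were tokenized with add_special_tokens=False, the EOS token must be
# appended manually:
ids_no_eos = tokenizer("a", add_special_tokens=False)["input_ids"]
ids_fixed = ids_no_eos + [tokenizer.eos_token_id]
```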

@kashif
Collaborator

kashif commented Nov 24, 2023

ok checking!

@Ricardokevins

Is it possible that you didn't have EOS tokens in the fine-tuning/dpo phase? Then the model wouldn't know what token to produce after the letter and just keep generating things.

Also cc @kashif

I checked my code; I have correctly set the tokenizer's EOS token and padding side.

I also tried adjusting various settings and, in the end, suspected a training collapse, so I lowered the learning rate, which resolved the issue.

@Devy99
Author

Devy99 commented Nov 25, 2023

Is it possible that you didn't have EOS tokens in the fine-tuning/dpo phase? Then the model wouldn't know what token to produce after the letter and just keep generating things.

Also cc @kashif

Among the various tests I have done, I also already checked that the EOS token was correctly set, but despite this I have not had any positive results.

@Ricardokevins, would you mind sharing some code snippets and the learning rate you used? Even after that test, I unfortunately could not solve the problem.

@Devy99
Author

Devy99 commented Nov 25, 2023

Small update: I have done more testing with a reduced learning rate, and it seems to work much better than before. There are still outputs that copy data from the training inputs (e.g., 'cat' instead of 'a', 'b', or 'c'), but they are significantly fewer than before.

@Ricardokevins

Oh, I still have the problem: if I increase the learning rate from 1e-6 to 2e-6, the repetition still occurs.

@Ricardokevins

I think this issue may be related to the padding process in DPODataCollatorWithPadding; the padding side looks strange:

the chosen_input_ids and rejected_input_ids are right-padded, while the prompt_input_ids are left-padded,

and the instruction model I am tuning uses left padding.

Could this cause the corrupted repetition? @Devy99 @lvwerra
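
One rough way to see what the collator actually produces is to pull a single batch from the trainer's dataloader and check where the pad tokens sit. This is only a sketch: it assumes a dpo_trainer built as in the script above, and the batch key names have changed across trl versions.

```
# Rough inspection sketch (assumptions: `dpo_trainer` and `tokenizer` as in the
# script above; these batch keys may differ between trl versions).
batch = next(iter(dpo_trainer.get_train_dataloader()))
pad_id = tokenizer.pad_token_id
for key in ("prompt_input_ids", "chosen_input_ids", "rejected_input_ids"):
    if key in batch:
        row = batch[key][0].tolist()
        print(key, "starts with pad:", row[0] == pad_id, "ends with pad:", row[-1] == pad_id)
```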

@Ricardokevins

I checked other open-source projects and observed a similar issue (repetitive generation):

eric-mitchell/direct-preference-optimization#8
eric-mitchell/direct-preference-optimization#35

@kashif
Collaborator

kashif commented Nov 26, 2023

@Ricardokevins do you mind trying with the refactored PR #885 where we have done some more fixes for the padding etc.

@Ricardokevins

@Ricardokevins do you mind trying with the refactored PR #885 where we have done some more fixes for the padding etc.

@kashif Yeah, I adopted the PR and tuned the model for 500 steps with a 2e-6 learning rate.
So far, the repetition problem seems to have disappeared (I am not sure; I may run more experiments on the modified code later).

@Devy99 You can also try this code.

@Devy99
Author

Devy99 commented Nov 26, 2023

Thank you for the support! I'll try as soon as possible and I'll keep you updated.

@Ricardokevins

Ricardokevins commented Nov 27, 2023

I observed the problem again in a new experiment (5e-7, 1700 steps).

I am still looking for new solutions. Based on my experiments, it appears that larger learning rates and more update steps are more likely to trigger this issue. Therefore, I wonder whether it is related to the DPO loss discussed in another open-source project (eric-mitchell/direct-preference-optimization#35).

@Devy99
Author

Devy99 commented Nov 28, 2023

Sorry for the late reply, but in my case the problem seems to be solved. Specifically, I took the following measures:

  • I performed the tests with the PR mentioned above;
  • I performed tests by lowering the value of the learning rate.

No changes were made to the data collator (I used the default one, as in the attached code).
@Ricardokevins , I am waiting to hear from you before closing the issue. Thank you very much @kashif and @lvwerra for the support!

@Ricardokevins

Ricardokevins commented Nov 29, 2023

No changes were made to the data collator (I used the default one, as in the attached code).

Congrats!

I am a little confused about, and interested in, the root cause of the problem. I checked your code and noticed that you didn't specify the data_collator; however, in the dpo_trainer the default data_collator is DPODataCollatorWithPadding. Did you use the changes to DPODataCollatorWithPadding from the PR in trl/trainer/utils.py? I'm confused about the statement "No changes were made to the data collator (I used the default one, as in the attached code)."

@Devy99
Author

Devy99 commented Nov 29, 2023

Exactly, I simply tested the above code with the changes made in the PR and the data collator was not specified ( hence, the default one is used according to the documentation ).

@Ricardokevins

Exactly, I simply tested the above code with the changes made in the PR and the data collator was not specified ( hence, the default one is used according to the documentation ).

Thank you!

I haven't observed this problem in my subsequent experiments either; the previous occurrence was likely random. We can go ahead and close the issue for now. I may need to seek your assistance again if I encounter any further issues.

@Devy99
Author

Devy99 commented Dec 7, 2023

@kashif unluckily, I have to re-open the issue. I am now testing the script on a real, more complex dataset for a text generation task, and the problem seems to persist. In particular, as the number of steps increases, the performance of the model deteriorates considerably, leading to checkpoints that generate a large number of repeated characters. I also tried changing the learning rate several times, without getting any promising results. Of course, for all experiments I used the pull request code, without success.

Devy99 reopened this on Dec 7, 2023
@kashif
Collaborator

kashif commented Dec 7, 2023

No worries @Devy99, I can test with my branch to check... the main issue is that your model is an encoder-decoder style model, while the code has mostly been tested on decoder-only models...

@Devy99
Author

Devy99 commented Dec 7, 2023

Thank you very much! Then, I look forward to hearing from you.

@Devy99
Author

Devy99 commented Dec 17, 2023

Hi @kashif , sorry to bother you. Is there any update about this problem?
Also, do you know feasible alternatives to test the DPO for this kind of task? I was thinking of doing the same study with GPT2, since it is decoder-only and small enough for a small-scale test. However I don't know if it was tested on GPT2 before.

P.S. @Ricardokevins did you solve your problem / find any solution?

@raghavgarg97

raghavgarg97 commented Dec 27, 2023

I am also facing a similar issue, though the repetition is more at the token level. I am using a custom-prepared dataset.
I trained a Mistral-7B model with full DPO training (no PEFT) and an LR of 5.6e-5 (the same as for SFT).
The LR may be the issue, so I will try reducing it and retriggering training.
Example of my model's response:
Thank you for contacting us. We're here to assist you with your specific issue. We're here to assist you with your specific issue. We're here to assist you with your specific issue. We're here to assist you with your specific issue. We're here to assist you with your specific issue. We're here to assist you with your specific issue. We're here to assist you with your specific issue. We're here to assist you with your

@raghavgarg97

Update: reducing the LR did help remove the redundancy; I dropped it to 2e-6.

vwxyzjn added the 🏋 DPO (Related to DPO) label on Jan 5, 2024
@Devy99
Author

Devy99 commented Jan 10, 2024

Hi @kashif @lvwerra @younesbelkada, sorry to ask again. Can you tell me whether the DPO implementation will be further tested on encoder-decoder models like T5? I have tried replacing my model with a decoder-only one, but an encoder-decoder like T5 would be ideal for the experiments I am doing. Can you tell me more about it?

@Ricardokevins

unluckily

Unluckily, when I explored more base models for DPO, I encountered the problem again...

I still think there might be some bugs in the code or the algorithm itself. A lower learning rate might just slow down the deterioration, and the problem might still exist.

But I have no idea why...


@yata0

yata0 commented Mar 6, 2024

@Devy99 @Ricardokevins Hi, I met the problem,too.
Could you tell me the setting of "average_log_pi"?

#789

@Ricardokevins

@Devy99 @Ricardokevins Hi, I met the problem,too. Could you tell me the setting of "average_log_pi"?

#789

I didn't change this setting (it is probably the default value).

@kashif
Collaborator

kashif commented Mar 6, 2024

@Devy99 I think that, with the refactorings that have happened, the encoder-decoder setup needs a closer look. I will take a look with a tiny T5-style model and report back.
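
For anyone who wants to reproduce this cheaply, a randomly initialized tiny T5 can be built roughly as follows; the sizes below are arbitrary assumptions, and the t5-small tokenizer is reused only for its vocabulary.

```
# Sketch of a tiny, randomly initialized T5 for smoke tests (sizes are arbitrary).
from transformers import T5Config, T5ForConditionalGeneration, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")  # reuse an existing sentencepiece vocab
config = T5Config(
    vocab_size=len(tokenizer),
    d_model=64,
    d_kv=16,
    d_ff=256,
    num_layers=2,
    num_decoder_layers=2,
    num_heads=4,
)
tiny_model = T5ForConditionalGeneration(config)
```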

@Ricardokevins

Ricardokevins commented Mar 6, 2024

In my setting (my recent research on multilingual reasoning via preference optimization), I use 1e-6 for DPO training of LLaMA-2 (7B/13B)-based LLMs, and it works well on these decoder-only LLMs. If 1e-6 does not work, you can lower it to 2e-7.

Meanwhile, you can put the bad sampled outputs containing repetition into the "rejected" side to penalize this behavior:

```
from collections import Counter

def check_repeated_sentences(paragraph):
    # Returns True if any line appears at least twice in the paragraph.
    paragraph = paragraph.replace('\n\n', '\n')
    sentences = paragraph.split('\n')
    sentence_counts = Counter(sentences)

    for sentence, count in sentence_counts.items():
        if count >= 2:
            return True
    return False
```
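
For example, the check above could be used to route repetitive samples into the rejected side of a preference pair. The lists below (prompts, references, model_outputs) are purely illustrative placeholders, not data from this thread.

```
# Illustrative usage of check_repeated_sentences (all data below is made up).
prompts = ["Translate to French: I love cats"]
references = ["J'adore les chats"]
model_outputs = ["J'adore les chats\nJ'adore les chats\nJ'adore les chats"]

pairs = []
for prompt, reference, sampled in zip(prompts, references, model_outputs):
    if check_repeated_sentences(sampled):
        pairs.append({"prompt": prompt, "chosen": reference, "rejected": sampled})
```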

@Devy99
Author

Devy99 commented Mar 6, 2024

@kashif nice, thank you! I'll wait for your updates then 🙏


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

@Devy99
Author

Devy99 commented Mar 30, 2024

Up (not stale).

@thusinh1969

Up ...!

@kashif
Collaborator

kashif commented Apr 4, 2024

@Devy99 I think the trainers haven't really been exercised with encoder-decoder style models, so I fear there might still be a bug with masking, tokenization, etc. Let me find some time to have a look at T5; can I use your example above to check?

@Devy99
Author

Devy99 commented Apr 4, 2024

@kashif ok, thanks for your patience and your work! Actually, I no longer have the dataset used for the example, since some time has passed since I opened this issue. But it can easily be regenerated from scratch, since I applied trivial heuristics to determine the output value.

@LDelPinoNT

Hi! Any update on this issue? I also have a similar problem with the KTO Trainer and TinyMistral.

@Devy99
Author

Devy99 commented Apr 24, 2024

Not yet, but @kashif is currently testing it with T5.

@Klein-Lan

Hi @kashif , sorry to bother you. Is there any update about this problem? Also, do you know feasible alternatives to test the DPO for this kind of task? I was thinking of doing the same study with GPT2, since it is decoder-only and small enough for a small-scale test. However I don't know if it was tested on GPT2 before.

P.S. @Ricardokevins did you solve your problem / find any solution?

Hello, I am planning to run DPO on a 100M-parameter GPT-2 model. Since this is just for testing purposes, I haven't formatted my training data in any specific way; I randomly selected a DPO dataset from Hugging Face for training.

After 2 days of DPO training, I found that my GPT-2 model couldn't even generate coherent sentences, or rather, the model's performance degraded catastrophically.

I am puzzled. Do I need to format my training data in a specific way like yours to see results? Is it normal for the model's performance to degrade with random data? Of course, I am more than willing to reproduce the issue you initially encountered, as long as it doesn't result in the model degrading to an untrained state.

@sajastu

sajastu commented May 9, 2024

Up... I'm facing the same problem with encoder-decoder style models (BART) on PPO. Having read the relevant thread(s), I suspect this might be a universal issue across the DPO and PPO trainers(?)

@santyzenith

Same problem here: after tuning a decoder-only model using SFT, the behaviour of the model is as expected, but after applying DPO training the model does not even generate consistent responses. As a guide, I have used the parameters from the huggingface/alignment-handbook.

@starmpcc

This blog post may be related to this problem. It seems to be a fundamental issue with the DPO/KTO loss:

https://kyunghyuncho.me/a-proper-preference-optimization-loss-and-its-gradient/


github-actions bot commented Jun 7, 2024

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

@Devy99
Author

Devy99 commented Jun 7, 2024

Up. Any news, @kashif ?

@WJMacro

WJMacro commented Jun 12, 2024

I'm facing the same problem. Currently, I'm using in-distribution data sampled from model output and a low learning rate (5e-7). The repetition still exists.


github-actions bot commented Jul 6, 2024

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

@Devy99
Author

Devy99 commented Jul 6, 2024

Up

@shengminp

shengminp commented Jul 24, 2024

I am currently working on a project that requires me to use the DPO trainer for the T5 model.

I have noticed that the DPO trainer provided by Huggingface has some issues. Unfortunately, the development team does not seem to be paying much attention to this so far. (I understand most people are now focusing more on decoder-only language models.)

This is suggested by the following snippet in https://github.com/huggingface/trl/blob/main/tests/test_dpo_trainer.py:

```
if name == "t5":
    self.skipTest("For some reason t5 does not compute gradients properly on tiny models")
```

Anyway, I encountered a specific issue with the original DPO trainer when applying it to the T5 model, mainly concerning the padding value added in the decoder.

After addressing these problems, the DPO trainer now works correctly for my use case. Additionally, I have found that using a fine-tuned T5 model and a very small learning rate is crucial for stable training.

@Devy99
Author

Devy99 commented Jul 24, 2024

@shengminp thank you for sharing your experience!
Can you please provide more details on how you fixed the problem for T5? That way we can check whether your fix also applies to our use cases.

Thanks in advance 🙏

@shengminp

@Devy99 Sure. I've made several modifications based on my specific needs, so I can only point you to the problematic code. I'm not sure whether these details will help you solve your current issues, but here they are.

The library versions are as follows:

  • transformers: 4.37.2
  • trl: 0.9.6

The problematic code is primarily in DPODataCollatorWithPadding. At line 380 of https://github.com/huggingface/trl/blob/main/trl/trainer/utils.py, you will find the following condition:

```
elif k.startswith(("chosen", "rejected", "completion")) or ("decoder" in k):
    padding_value = self.label_pad_token_id
```

This condition causes the decoder's input_ids to be padded with -100, which results in an error. To solve this issue, I modified the condition to:

```
elif k.startswith(("chosen", "rejected", "completion")):
    padding_value = self.label_pad_token_id
```

This simple change resolved the problem I mentioned.

Please note that my current changes are solely intended to ensure the T5 model runs without issues. This means the code has not been tested with decoder-only models or other encoder-decoder models, so it is possible that my changes could cause errors in other situations.

@shengminp

shengminp commented Jul 25, 2024

Let me share more details about my experience. I followed the multi-stage tuning pipeline used in most of today's LM training, which includes the following steps:

  1. Pretrained Model Initialization: I started with a pretrained T5 model, specifically using "google-t5/t5-base." For this, I used T5ForConditionalGeneration and T5Tokenizer to load the corresponding model and tokenizer.
  2. Supervised Fine-Tuning: In this stage, I performed supervised fine-tuning with a learning rate of 0.0005. I ensured that the model was fine-tuned to generate fluent and meaningful sentences. During this phase, I applied Seq2SeqTrainer.
  3. Preference Data Generation and DPO Tuning: After fine-tuning, I generated preference data pairs based on the fine-tuned T5 model. I then used this generated data to apply DPOTrainer to the fine-tuned model.

I am not sure whether this experience will help you. At present I am still debugging and modifying the code, so it is quite complicated; please forgive me for not being able to provide many more details of the code at this time.
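
For completeness, here is a minimal sketch of step 2 (supervised fine-tuning with Seq2SeqTrainer). The column names, the toy example, and most hyperparameters are illustrative assumptions; only the base model name and the 0.0005 learning rate come from the description above.

```
# Minimal SFT sketch for step 2 (column names, toy data and most settings are assumptions).
from datasets import Dataset
from transformers import (
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    T5ForConditionalGeneration,
    T5Tokenizer,
)

model_name = "google-t5/t5-base"
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)

raw = Dataset.from_list([{"input": "I love cats", "target": "a"}])

def tokenize(example):
    model_inputs = tokenizer(example["input"], truncation=True, max_length=512)
    labels = tokenizer(text_target=example["target"], truncation=True, max_length=16)
    model_inputs["labels"] = labels["input_ids"]  # the tokenizer appends EOS to the labels
    return model_inputs

tokenized = raw.map(tokenize, remove_columns=raw.column_names)

trainer = Seq2SeqTrainer(
    model=model,
    args=Seq2SeqTrainingArguments(output_dir="./sft", learning_rate=5e-4, num_train_epochs=1),
    train_dataset=tokenized,
    data_collator=DataCollatorForSeq2Seq(tokenizer, model=model),
)
trainer.train()
```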

@Devy99
Author

Devy99 commented Jul 25, 2024

@shengminp thanks! I'll give it a try 😄


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

qgallouedec added the 🙋 help from community wanted (Open invitation for community members to contribute) label on Aug 22, 2024
@kashif
Collaborator

kashif commented Aug 28, 2024

@Devy99 can you try now on the latest head with an encoder-decoder model?

@Devy99
Author

Devy99 commented Aug 30, 2024

@kashif thanks for your work! I have been quite busy lately, so I'll update you as soon as I can. In the meantime, I also invite the others who experienced the same issue to check whether the problem is fixed.
