diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py
index 3b5bfcae2b..ec2b7d015a 100644
--- a/tests/test_dpo_trainer.py
+++ b/tests/test_dpo_trainer.py
@@ -20,7 +20,7 @@
 from pytest import mark
 from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
 
-from trl import DPOConfig, DPOTrainer
+from trl import DPOConfig, DPOTrainer, FDivergenceType
 
 from .testing_utils import require_bitsandbytes, require_no_wandb, require_peft
 
@@ -719,3 +719,87 @@ def test_dpo_lora_force_use_ref(self):
 
             # train the model
             trainer.train()
+
+    def test_dpo_loss_alpha_div_f(self):
+        model_id = "trl-internal-testing/tiny-random-LlamaForCausalLM"
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+        # base model
+        model = AutoModelForCausalLM.from_pretrained(model_id)
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            training_args = DPOConfig(
+                output_dir=tmp_dir,
+                per_device_train_batch_size=2,
+                max_steps=3,
+                remove_unused_columns=False,
+                gradient_accumulation_steps=4,
+                learning_rate=9e-1,
+                evaluation_strategy="steps",
+                f_divergence_type=FDivergenceType.ALPHA_DIVERGENCE.value,
+                f_alpha_divergence_coef=0.5,
+            )
+
+            dummy_dataset = self._init_dummy_dataset()
+
+            # build the DPO trainer
+            trainer = DPOTrainer(
+                model=model,
+                ref_model=None,
+                args=training_args,
+                tokenizer=tokenizer,
+                train_dataset=dummy_dataset,
+                eval_dataset=dummy_dataset,
+            )
+
+            # Fake chosen and rejected log probs
+            policy_chosen_logps = torch.FloatTensor([410.0, 0.1])
+            policy_rejected_logps = torch.FloatTensor([810.5, 0.2])
+            reference_chosen_logps = torch.FloatTensor([-610.0, -0.1])
+            reference_rejected_logps = torch.FloatTensor([110.6, 0.5])
+            losses, _, _ = trainer.dpo_loss(
+                policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps
+            )
+            assert torch.isfinite(losses).cpu().numpy().all()
+
+    def test_dpo_loss_js_div_f(self):
+        model_id = "trl-internal-testing/tiny-random-LlamaForCausalLM"
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+        # base model
+        model = AutoModelForCausalLM.from_pretrained(model_id)
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            training_args = DPOConfig(
+                output_dir=tmp_dir,
+                per_device_train_batch_size=2,
+                max_steps=3,
+                remove_unused_columns=False,
+                gradient_accumulation_steps=4,
+                learning_rate=9e-1,
+                evaluation_strategy="steps",
+                f_divergence_type=FDivergenceType.JS_DIVERGENCE.value,
+                f_alpha_divergence_coef=0.5,
+            )
+
+            dummy_dataset = self._init_dummy_dataset()
+
+            # build the DPO trainer
+            trainer = DPOTrainer(
+                model=model,
+                ref_model=None,
+                args=training_args,
+                tokenizer=tokenizer,
+                train_dataset=dummy_dataset,
+                eval_dataset=dummy_dataset,
+            )
+
+            # Fake chosen and rejected log probs
+            policy_chosen_logps = torch.FloatTensor([410.0, 0.1])
+            policy_rejected_logps = torch.FloatTensor([95.5, 0.2])
+            reference_chosen_logps = torch.FloatTensor([-610.0, -0.1])
+            reference_rejected_logps = torch.FloatTensor([5.5, 0.5])
+            losses, _, _ = trainer.dpo_loss(
+                policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps
+            )
+            assert torch.isfinite(losses).cpu().numpy().all()
diff --git a/trl/__init__.py b/trl/__init__.py
index 0879b8311e..44ed2f21b3 100644
--- a/trl/__init__.py
+++ b/trl/__init__.py
@@ -51,6 +51,8 @@
         "RewardTrainer",
         "SFTConfig",
         "SFTTrainer",
+        "FDivergenceConstants",
+        "FDivergenceType",
     ],
     "commands": [],
     "commands.cli_utils": ["init_zero_verbose", "SFTScriptArguments", "DPOScriptArguments", "TrlParser"],
@@ -117,6 +119,8 @@
         RewardTrainer,
         SFTConfig,
         SFTTrainer,
+        FDivergenceConstants,
+        FDivergenceType,
     )
     from .trainer.utils import get_kbit_device_map, get_peft_config, get_quantization_config, RichProgressCallback
     from .commands.cli_utils import init_zero_verbose, SFTScriptArguments, DPOScriptArguments, TrlParser
diff --git a/trl/trainer/__init__.py b/trl/trainer/__init__.py
index c291ff31ec..f417f7878f 100644
--- a/trl/trainer/__init__.py
+++ b/trl/trainer/__init__.py
@@ -29,7 +29,7 @@
         "peft_module_casting_to_bf16",
         "RichProgressCallback",
     ],
-    "dpo_config": ["DPOConfig"],
+    "dpo_config": ["DPOConfig", "FDivergenceConstants", "FDivergenceType"],
     "dpo_trainer": ["DPOTrainer"],
     "cpo_config": ["CPOConfig"],
     "cpo_trainer": ["CPOTrainer"],
@@ -76,7 +76,7 @@
     from .base import BaseTrainer
     from .ddpo_config import DDPOConfig
 
-    from .dpo_config import DPOConfig
+    from .dpo_config import DPOConfig, FDivergenceConstants, FDivergenceType
     from .dpo_trainer import DPOTrainer
     from .iterative_sft_trainer import IterativeSFTTrainer
     from .cpo_config import CPOConfig
diff --git a/trl/trainer/dpo_config.py b/trl/trainer/dpo_config.py
index 30a434cfd9..cc6293fc20 100644
--- a/trl/trainer/dpo_config.py
+++ b/trl/trainer/dpo_config.py
@@ -12,11 +12,23 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from dataclasses import dataclass
+from enum import Enum
 from typing import Dict, Literal, Optional
 
 from transformers import TrainingArguments
 
 
+class FDivergenceType(Enum):
+    REVERSE_KL = "reverse_kl"
+    JS_DIVERGENCE = "js_divergence"
+    ALPHA_DIVERGENCE = "alpha_divergence"
+
+
+class FDivergenceConstants:
+    ALPHA_DIVERGENCE_COEF_KEY = "alpha_divergence_coef"
+    ALPHA_DIVERGENCE_COEF_DEFAULT = 1.0
+
+
 @dataclass
 class DPOConfig(TrainingArguments):
     r"""
@@ -65,6 +77,10 @@ class DPOConfig(TrainingArguments):
             If True, we ignore the _provided_ reference model and implicitly use a reference model that assigns equal probability to all responses.
         force_use_ref_model (`bool`, defaults to `False`):
             In case one passes a PEFT model for the active model and you want to use a different model for the ref_model, set this flag to `True`.
+        f_divergence_type (`FDivergenceType`, *optional*, defaults to `FDivergenceType.REVERSE_KL`):
+            The type of f-divergence regularization function used to compute the divergence between the policy and the reference model.
+        f_alpha_divergence_coef (`float`, *optional*, defaults to `1.0`):
+            The alpha coefficient in the alpha-divergence u^-alpha regularization function for the DPO loss.
         sync_ref_model ('bool', defaults to `False`):
             The flag for syncing reference model during training from the [TR-DPO](https://arxiv.org/pdf/2404.09656) paper.
         ref_model_mixup_alpha ('float', defaults to 1.0):
@@ -97,6 +113,8 @@ class DPOConfig(TrainingArguments):
     ref_adapter_name: Optional[str] = None
     reference_free: bool = False
    force_use_ref_model: bool = False
+    f_divergence_type: Optional[FDivergenceType] = FDivergenceType.REVERSE_KL
+    f_alpha_divergence_coef: Optional[float] = 1.0
     sync_ref_model: bool = False
     ref_model_mixup_alpha: float = 0.9
     ref_model_sync_steps: int = 64
diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py
index 2403539c6b..b95ca7eb45 100644
--- a/trl/trainer/dpo_trainer.py
+++ b/trl/trainer/dpo_trainer.py
@@ -42,11 +42,12 @@
 
 from ..import_utils import is_peft_available, is_wandb_available
 from ..models import PreTrainedModelWrapper, create_reference_model
-from .dpo_config import DPOConfig
+from .dpo_config import DPOConfig, FDivergenceConstants, FDivergenceType
 from .utils import (
     DPODataCollatorWithPadding,
     RunningMoments,
     SyncRefModelCallback,
+    cap_exp,
     disable_dropout_in_model,
     pad_to_length,
     peft_module_casting_to_bf16,
@@ -479,6 +480,9 @@ def make_inputs_require_grad(module, input, output):
 
         self._stored_metrics = defaultdict(lambda: defaultdict(list))
 
+        self.f_divergence_type = args.f_divergence_type
+        self.f_divergence_params = {FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY: args.f_alpha_divergence_coef}
+
         if dataset_num_proc is not None:
             warnings.warn(
                 "You passed `dataset_num_proc` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
@@ -996,15 +1000,43 @@ def dpo_loss(
             The losses tensor contains the DPO loss for each example in the batch.
             The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
         """
-        pi_logratios = policy_chosen_logps - policy_rejected_logps
-        if self.reference_free:
-            ref_logratios = torch.tensor([0], dtype=pi_logratios.dtype, device=pi_logratios.device)
+        chosen_logratios = policy_chosen_logps.to(self.accelerator.device) - (
+            not self.reference_free
+        ) * reference_chosen_logps.to(self.accelerator.device)
+        rejected_logratios = policy_rejected_logps.to(self.accelerator.device) - (
+            not self.reference_free
+        ) * reference_rejected_logps.to(self.accelerator.device)
+
+        if self.f_divergence_type == FDivergenceType.ALPHA_DIVERGENCE.value:
+            # The alpha-divergence formula: (1 - u^-alpha) / alpha
+            # The divergence difference between the chosen and rejected sample is:
+            #     (1 - u[w]^-alpha) / alpha - (1 - u[l]^-alpha) / alpha
+            #         = (u[l]^-alpha - u[w]^-alpha) / alpha
+            # where u[w] and u[l] are the policy/reference probability ratios
+            # for the chosen and rejected samples, respectively.
+            alpha_coef = FDivergenceConstants.ALPHA_DIVERGENCE_COEF_DEFAULT
+            if self.f_divergence_params and FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY in self.f_divergence_params:
+                alpha_coef = float(self.f_divergence_params[FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY])
+            logits = (cap_exp(rejected_logratios * -alpha_coef) - cap_exp(chosen_logratios * -alpha_coef)) / alpha_coef
         else:
-            ref_logratios = reference_chosen_logps - reference_rejected_logps
-
-        pi_logratios = pi_logratios.to(self.accelerator.device)
-        ref_logratios = ref_logratios.to(self.accelerator.device)
-        logits = pi_logratios - ref_logratios
+            pi_logratios = policy_chosen_logps - policy_rejected_logps
+            if self.reference_free:
+                ref_logratios = torch.tensor([0], dtype=pi_logratios.dtype, device=pi_logratios.device)
+            else:
+                ref_logratios = reference_chosen_logps - reference_rejected_logps
+
+            pi_logratios = pi_logratios.to(self.accelerator.device)
+            ref_logratios = ref_logratios.to(self.accelerator.device)
+            logits = pi_logratios - ref_logratios
+
+            if self.f_divergence_type == FDivergenceType.JS_DIVERGENCE.value:
+                # The js-divergence formula: log(2 * u / (1 + u))
+                # The divergence difference between the chosen and rejected sample is:
+                #     log(2 * u[w] / (1 + u[w])) - log(2 * u[l] / (1 + u[l]))
+                #         = log(u[w]) - log(u[l]) - (log(1 + u[w]) - log(1 + u[l]))
+                # where u[w] and u[l] are the policy/reference probability ratios
+                # for the chosen and rejected samples, respectively.
+                logits -= F.softplus(chosen_logratios) - F.softplus(rejected_logratios)
 
         # The beta is a temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5.
         # We ignore the reference model as beta -> 0. The label_smoothing parameter encodes our uncertainty about the labels and
diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py
index c0197336e7..8709452945 100644
--- a/trl/trainer/utils.py
+++ b/trl/trainer/utils.py
@@ -859,6 +859,32 @@ def on_train_end(self, args, state, control, **kwargs):
         self.current_step = None
 
 
+def get_exp_cap(value, decimal=4):
+    """
+    Get the exponent cap of a value, used to cap the exponent of a value to avoid overflow.
+    The formula is: log(value.dtype.max).
+    E.g.
+      For the float32 data type, the maximum exponent value is 88.7228 to 4 decimal points;
+      calling exp() directly on log(torch.finfo(torch.float32).max) overflows to inf,
+      so the exponent is capped at 88.7228 instead.
+
+    Args:
+        value (`torch.Tensor`):
+            The input tensor whose data type determines the cap.
+        decimal (`int`):
+            The number of decimal points to which the exponent cap is truncated.
+    """
+    vdtype_max = torch.zeros([1]).to(value.dtype) + torch.finfo(value.dtype).max
+    vdtype_log_max = torch.log(vdtype_max).to(value.device)
+    return torch.floor(vdtype_log_max * 10**decimal) / 10**decimal if decimal > 0 else vdtype_log_max
+
+
+def cap_exp(value, cap=-1):
+    # Cap the exponent value below the upper-bound to avoid overflow, before calling torch.exp
+    cap = get_exp_cap(value) if cap < 0 else cap
+    return torch.exp(torch.clamp(value, max=cap))
+
+
 def print_rich_table(df: pd.DataFrame) -> Table:
     console = Console()
     table = Table(show_lines=True)
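Usage sketch (illustrative only): how the new `dpo_loss` branches map policy/reference log-ratios to logits for each `FDivergenceType`. It assumes a TRL build that includes this patch so `cap_exp` is importable from `trl.trainer.utils`; the tensors and the `beta`/`alpha_coef` values are made-up toy numbers.

import torch
import torch.nn.functional as F

from trl.trainer.utils import cap_exp

beta = 0.1
chosen_logratios = torch.tensor([0.7, -0.2])   # log(pi/ref) for chosen responses (toy values)
rejected_logratios = torch.tensor([0.1, 0.4])  # log(pi/ref) for rejected responses (toy values)

# Reverse KL (the default path): logits are the difference of the log-ratios
logits_rkl = chosen_logratios - rejected_logratios

# Alpha-divergence branch: (u[l]^-alpha - u[w]^-alpha) / alpha, with u = pi/ref
alpha_coef = 0.5
logits_alpha = (cap_exp(rejected_logratios * -alpha_coef) - cap_exp(chosen_logratios * -alpha_coef)) / alpha_coef

# JS-divergence branch: subtract log(1 + u[w]) - log(1 + u[l]) from the reverse-KL logits
logits_js = logits_rkl - (F.softplus(chosen_logratios) - F.softplus(rejected_logratios))

# All three variants feed the usual sigmoid DPO loss (label smoothing omitted here)
losses = -F.logsigmoid(beta * logits_js)
print(logits_rkl, logits_alpha, logits_js, losses)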
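A second sketch, under the same assumption that the patched TRL is installed, showing why `get_exp_cap`/`cap_exp` clamp exponents near log(torch.finfo(torch.float32).max) ≈ 88.7228 before calling `torch.exp`.

import torch

from trl.trainer.utils import cap_exp, get_exp_cap

x = torch.tensor([100.0, 1.0], dtype=torch.float32)

print(torch.exp(x))    # the first entry overflows to inf in float32
print(get_exp_cap(x))  # ~88.7228, i.e. log(float32 max) truncated to 4 decimals
print(cap_exp(x))      # finite everywhere: inputs are clamped to the cap before exp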