Integrate f-divergence to DPO (Follow up) #1610

Merged: 25 commits, merged on Jun 19, 2024

Commits (changes from all commits)
0f6aafa  Step 1: update ppo_trainer and hello_world example  (1485840691, Jan 11, 2024)
afa1526  Step 2: Refine comments and add parameter type  (1485840691, Jan 11, 2024)
fc1d75f  Step 2: Add missing parameter comments  (1485840691, Jan 11, 2024)
99c8e6e  Step 1: Organize ptx loss into a function and add ptx_loss to train_s…  (1485840691, Jan 12, 2024)
635a647  Step 1 updates: add comment to ptx_loss function, fix a bug and add w…  (1485840691, Jan 13, 2024)
e947bea  Step 2: 1) Add ppo_ptx trainig example as ppo; 2) separate pretrain d…  (1485840691, Jan 18, 2024)
201bc91  Step 2: Remove loss from columns_to_log in ppo_ptx example  (1485840691, Jan 21, 2024)
b0773e5  Remove data set revision in load imbd dataset  (1485840691, Jan 25, 2024)
8fdb770  Run pre-commit and fix format issues  (1485840691, Jan 27, 2024)
7715dbf  Initial draft of f-divergence fn  (1485840691, Feb 12, 2024)
18ec6a1  Update f-divergence to avoid overflow  (1485840691, Feb 18, 2024)
d47fc79  Merge branch 'main' into dpo  (1485840691, Feb 24, 2024)
f3403f7  fix test errors and comments  (1485840691, Feb 29, 2024)
211c935  Add Unit tests for dpo loss with alpha and js div f  (1485840691, Mar 3, 2024)
9afcec9  Adjust format  (1485840691, Mar 3, 2024)
f77a83b  Fix test error  (1485840691, Mar 6, 2024)
5400f63  Merge branch 'main' of https://github.com/1485840691/trl  (1485840691, Apr 27, 2024)
e77b38e  Merge branch 'main' into dpo  (1485840691, Apr 27, 2024)
5064e79  Reverse this update  (1485840691, Apr 28, 2024)
2aa29b7  Add test cases  (1485840691, May 1, 2024)
c30eda9  Reverse un-needed updates  (1485840691, May 1, 2024)
b672720  Update code style  (1485840691, May 4, 2024)
2363c21  Merge branch 'main' into dpo  (kashif, Jun 7, 2024)
13d18d4  Try to fix code fmt error  (1485840691, Jun 10, 2024)
36bdc73  remove extra end line  (1485840691, Jun 19, 2024)
86 changes: 85 additions & 1 deletion tests/test_dpo_trainer.py
@@ -20,7 +20,7 @@
from pytest import mark
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer

-from trl import DPOConfig, DPOTrainer
+from trl import DPOConfig, DPOTrainer, FDivergenceType

from .testing_utils import require_bitsandbytes, require_no_wandb, require_peft

@@ -719,3 +719,87 @@ def test_dpo_lora_force_use_ref(self):

        # train the model
        trainer.train()
+
+    def test_dpo_loss_alpha_div_f(self):
+        model_id = "trl-internal-testing/tiny-random-LlamaForCausalLM"
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+        # model under test
+        model = AutoModelForCausalLM.from_pretrained(model_id)
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            training_args = DPOConfig(
+                output_dir=tmp_dir,
+                per_device_train_batch_size=2,
+                max_steps=3,
+                remove_unused_columns=False,
+                gradient_accumulation_steps=4,
+                learning_rate=9e-1,
+                evaluation_strategy="steps",
+                f_divergence_type=FDivergenceType.ALPHA_DIVERGENCE.value,
+                f_alpha_divergence_coef=0.5,
+            )
+
+            dummy_dataset = self._init_dummy_dataset()
+
+            # build the DPO trainer with the alpha-divergence config
+            trainer = DPOTrainer(
+                model=model,
+                ref_model=None,
+                args=training_args,
+                tokenizer=tokenizer,
+                train_dataset=dummy_dataset,
+                eval_dataset=dummy_dataset,
+            )
+
+            # Fake chosen and rejected log probs
+            policy_chosen_logps = torch.FloatTensor([410.0, 0.1])
+            policy_rejected_logps = torch.FloatTensor([810.5, 0.2])
+            reference_chosen_logps = torch.FloatTensor([-610.0, -0.1])
+            reference_rejected_logps = torch.FloatTensor([110.6, 0.5])
+            losses, _, _ = trainer.dpo_loss(
+                policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps
+            )
+            assert torch.isfinite(losses).cpu().numpy().all()
+
+    def test_dpo_loss_js_div_f(self):
+        model_id = "trl-internal-testing/tiny-random-LlamaForCausalLM"
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+        # model under test
+        model = AutoModelForCausalLM.from_pretrained(model_id)
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            training_args = DPOConfig(
+                output_dir=tmp_dir,
+                per_device_train_batch_size=2,
+                max_steps=3,
+                remove_unused_columns=False,
+                gradient_accumulation_steps=4,
+                learning_rate=9e-1,
+                evaluation_strategy="steps",
+                f_divergence_type=FDivergenceType.JS_DIVERGENCE.value,
+                f_alpha_divergence_coef=0.5,
+            )
+
+            dummy_dataset = self._init_dummy_dataset()
+
+            # build the DPO trainer with the JS-divergence config
+            trainer = DPOTrainer(
+                model=model,
+                ref_model=None,
+                args=training_args,
+                tokenizer=tokenizer,
+                train_dataset=dummy_dataset,
+                eval_dataset=dummy_dataset,
+            )
+
+            # Fake chosen and rejected log probs
+            policy_chosen_logps = torch.FloatTensor([410.0, 0.1])
+            policy_rejected_logps = torch.FloatTensor([95.5, 0.2])
+            reference_chosen_logps = torch.FloatTensor([-610.0, -0.1])
+            reference_rejected_logps = torch.FloatTensor([5.5, 0.5])
+            losses, _, _ = trainer.dpo_loss(
+                policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps
+            )
+            assert torch.isfinite(losses).cpu().numpy().all()
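
For orientation, here is a minimal sketch of how the options exercised by these tests are wired up outside the test harness. The inline dummy preference dataset stands in for `_init_dummy_dataset`, the tiny model id is reused from the tests, and the whole snippet is illustrative rather than a recommended training setup.

```python
# Minimal sketch mirroring the new tests: enable alpha-divergence in DPOConfig
# and check that the DPO loss stays finite even for extreme log-prob gaps.
import torch
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer, FDivergenceType

model_id = "trl-internal-testing/tiny-random-LlamaForCausalLM"
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Tiny stand-in preference dataset (prompt / chosen / rejected columns).
dummy_dataset = Dataset.from_dict(
    {
        "prompt": ["hello", "how are you"],
        "chosen": ["hi, nice to meet you", "I am fine"],
        "rejected": ["leave me alone", "I am not fine"],
    }
)

training_args = DPOConfig(
    output_dir="dpo-alpha-div",  # any scratch directory
    per_device_train_batch_size=2,
    max_steps=3,
    remove_unused_columns=False,
    f_divergence_type=FDivergenceType.ALPHA_DIVERGENCE.value,  # or JS_DIVERGENCE / REVERSE_KL
    f_alpha_divergence_coef=0.5,  # only used by the alpha-divergence branch
)

trainer = DPOTrainer(
    model=model,
    ref_model=None,
    args=training_args,
    tokenizer=tokenizer,
    train_dataset=dummy_dataset,
    eval_dataset=dummy_dataset,
)

losses, _, _ = trainer.dpo_loss(
    torch.FloatTensor([410.0, 0.1]),    # policy_chosen_logps
    torch.FloatTensor([810.5, 0.2]),    # policy_rejected_logps
    torch.FloatTensor([-610.0, -0.1]),  # reference_chosen_logps
    torch.FloatTensor([110.6, 0.5]),    # reference_rejected_logps
)
assert torch.isfinite(losses).all()
```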
4 changes: 4 additions & 0 deletions trl/__init__.py
@@ -51,6 +51,8 @@
"RewardTrainer",
"SFTConfig",
"SFTTrainer",
"FDivergenceConstants",
"FDivergenceType",
],
"commands": [],
"commands.cli_utils": ["init_zero_verbose", "SFTScriptArguments", "DPOScriptArguments", "TrlParser"],
@@ -117,6 +119,8 @@
        RewardTrainer,
        SFTConfig,
        SFTTrainer,
+        FDivergenceConstants,
+        FDivergenceType,
    )
    from .trainer.utils import get_kbit_device_map, get_peft_config, get_quantization_config, RichProgressCallback
    from .commands.cli_utils import init_zero_verbose, SFTScriptArguments, DPOScriptArguments, TrlParser
4 changes: 2 additions & 2 deletions trl/trainer/__init__.py
@@ -29,7 +29,7 @@
"peft_module_casting_to_bf16",
"RichProgressCallback",
],
"dpo_config": ["DPOConfig"],
"dpo_config": ["DPOConfig", "FDivergenceConstants", "FDivergenceType"],
"dpo_trainer": ["DPOTrainer"],
"cpo_config": ["CPOConfig"],
"cpo_trainer": ["CPOTrainer"],
@@ -76,7 +76,7 @@
    from .base import BaseTrainer
    from .ddpo_config import DDPOConfig

-    from .dpo_config import DPOConfig
+    from .dpo_config import DPOConfig, FDivergenceConstants, FDivergenceType
    from .dpo_trainer import DPOTrainer
    from .iterative_sft_trainer import IterativeSFTTrainer
    from .cpo_config import CPOConfig
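
With both re-exports in place, the new symbols resolve from the top-level package as well as from the trainer module. A quick smoke check (a sketch assuming an editable install of this branch; the asserted values come from the `dpo_config.py` diff below):

```python
from trl import FDivergenceConstants, FDivergenceType

assert FDivergenceType.ALPHA_DIVERGENCE.value == "alpha_divergence"
assert FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY == "alpha_divergence_coef"
assert FDivergenceConstants.ALPHA_DIVERGENCE_COEF_DEFAULT == 1.0
```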
18 changes: 18 additions & 0 deletions trl/trainer/dpo_config.py
@@ -12,11 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
+from enum import Enum
from typing import Dict, Literal, Optional

from transformers import TrainingArguments


+class FDivergenceType(Enum):
+    REVERSE_KL = "reverse_kl"
+    JS_DIVERGENCE = "js_divergence"
+    ALPHA_DIVERGENCE = "alpha_divergence"
+
+
+class FDivergenceConstants:
+    ALPHA_DIVERGENCE_COEF_KEY = "alpha_divergence_coef"
+    ALPHA_DIVERGENCE_COEF_DEFAULT = 1.0
+
+
@dataclass
class DPOConfig(TrainingArguments):
    r"""
@@ -65,6 +77,10 @@ class DPOConfig(TrainingArguments):
            If True, we ignore the _provided_ reference model and implicitly use a reference model that assigns equal probability to all responses.
        force_use_ref_model (`bool`, defaults to `False`):
            In case one passes a PEFT model for the active model and you want to use a different model for the ref_model, set this flag to `True`.
+        f_divergence_type (`FDivergenceType`, *optional*, defaults to `FDivergenceType.REVERSE_KL`):
+            The type of f-divergence regularization function used to compute the divergence between the policy and the reference model.
+        f_alpha_divergence_coef (`float`, *optional*, defaults to `1.0`):
+            The alpha coefficient in the alpha-divergence u^-alpha regularization function for the DPO loss.
        sync_ref_model ('bool', defaults to `False`):
            The flag for syncing reference model during training from the [TR-DPO](https://arxiv.org/pdf/2404.09656) paper.
        ref_model_mixup_alpha ('float', defaults to 1.0):
@@ -97,6 +113,8 @@ class DPOConfig(TrainingArguments):
    ref_adapter_name: Optional[str] = None
    reference_free: bool = False
    force_use_ref_model: bool = False
+    f_divergence_type: Optional[FDivergenceType] = FDivergenceType.REVERSE_KL
+    f_alpha_divergence_coef: Optional[float] = 1.0
    sync_ref_model: bool = False
    ref_model_mixup_alpha: float = 0.9
    ref_model_sync_steps: int = 64
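
For reference, writing u = pi_theta(y|x) / pi_ref(y|x) for the policy/reference probability ratio, the logits fed to the DPO sigmoid loss are h(u[w]) - h(u[l]) (later scaled by beta), where h depends on `f_divergence_type`: reverse KL, the default and the standard DPO objective, uses h(u) = log(u); alpha-divergence uses h(u) = (1 - u^-alpha) / alpha with alpha = `f_alpha_divergence_coef`; JS-divergence uses h(u) = log(2u / (1 + u)). The alpha- and JS-divergence forms restate the formulas documented in the `dpo_loss` comments below.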
50 changes: 41 additions & 9 deletions trl/trainer/dpo_trainer.py
@@ -42,11 +42,12 @@

from ..import_utils import is_peft_available, is_wandb_available
from ..models import PreTrainedModelWrapper, create_reference_model
from .dpo_config import DPOConfig
from .dpo_config import DPOConfig, FDivergenceConstants, FDivergenceType
from .utils import (
DPODataCollatorWithPadding,
RunningMoments,
SyncRefModelCallback,
cap_exp,
disable_dropout_in_model,
pad_to_length,
peft_module_casting_to_bf16,
@@ -479,6 +480,9 @@ def make_inputs_require_grad(module, input, output):

        self._stored_metrics = defaultdict(lambda: defaultdict(list))

+        self.f_divergence_type = args.f_divergence_type
+        self.f_divergence_params = {FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY: args.f_alpha_divergence_coef}
+
        if dataset_num_proc is not None:
            warnings.warn(
                "You passed `dataset_num_proc` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
@@ -996,15 +1000,43 @@ def dpo_loss(
            The losses tensor contains the DPO loss for each example in the batch.
            The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
        """
-        pi_logratios = policy_chosen_logps - policy_rejected_logps
-        if self.reference_free:
-            ref_logratios = torch.tensor([0], dtype=pi_logratios.dtype, device=pi_logratios.device)
-        else:
-            ref_logratios = reference_chosen_logps - reference_rejected_logps
-
-        pi_logratios = pi_logratios.to(self.accelerator.device)
-        ref_logratios = ref_logratios.to(self.accelerator.device)
-        logits = pi_logratios - ref_logratios
+        chosen_logratios = policy_chosen_logps.to(self.accelerator.device) - (
+            not self.reference_free
+        ) * reference_chosen_logps.to(self.accelerator.device)
+        rejected_logratios = policy_rejected_logps.to(self.accelerator.device) - (
+            not self.reference_free
+        ) * reference_rejected_logps.to(self.accelerator.device)
+
+        if self.f_divergence_type == FDivergenceType.ALPHA_DIVERGENCE.value:
+            # The alpha-divergence formula: (1 - u^-alpha) / alpha
+            # The divergence difference between the chosen and rejected sample is:
+            #     (1 - u[w]^-alpha) / alpha - (1 - u[l]^-alpha) / alpha
+            #       = (u[l]^-alpha - u[w]^-alpha) / alpha
+            # where u[w] and u[l] are the policy/reference probability ratios
+            # for the chosen and rejected samples, respectively.
+            alpha_coef = FDivergenceConstants.ALPHA_DIVERGENCE_COEF_DEFAULT
+            if self.f_divergence_params and FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY in self.f_divergence_params:
+                alpha_coef = float(self.f_divergence_params[FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY])
+            logits = (cap_exp(rejected_logratios * -alpha_coef) - cap_exp(chosen_logratios * -alpha_coef)) / alpha_coef
+        else:
+            pi_logratios = policy_chosen_logps - policy_rejected_logps
+            if self.reference_free:
+                ref_logratios = torch.tensor([0], dtype=pi_logratios.dtype, device=pi_logratios.device)
+            else:
+                ref_logratios = reference_chosen_logps - reference_rejected_logps
+
+            pi_logratios = pi_logratios.to(self.accelerator.device)
+            ref_logratios = ref_logratios.to(self.accelerator.device)
+            logits = pi_logratios - ref_logratios
+
+            if self.f_divergence_type == FDivergenceType.JS_DIVERGENCE.value:
+                # The js-divergence formula: log(2 * u / (1 + u))
+                # The divergence difference between the chosen and rejected sample is:
+                #     log(2 * u[w] / (1 + u[w])) - log(2 * u[l] / (1 + u[l]))
+                #       = log(u[w]) - log(u[l]) - (log(1 + u[w]) - log(1 + u[l]))
+                # where u[w] and u[l] are the policy/reference probability ratios
+                # for the chosen and rejected samples, respectively.
+                logits -= F.softplus(chosen_logratios) - F.softplus(rejected_logratios)

        # The beta is a temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5.
        # We ignore the reference model as beta -> 0. The label_smoothing parameter encodes our uncertainty about the labels and
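
To make the branch above easier to reason about in isolation, here is a hedged, self-contained sketch of the same per-example logit computation. `dpo_logits` and the simplified local `cap_exp` are illustrative helpers for this sketch only, not part of the TRL API:

```python
# Self-contained restatement of the dpo_loss branch above: given per-example
# log probabilities, compute the pre-beta logits for each supported f-divergence.
import math

import torch
import torch.nn.functional as F


def cap_exp(value: torch.Tensor) -> torch.Tensor:
    # Clamp the exponent just below log(dtype max) so exp() cannot overflow to inf.
    cap = math.log(torch.finfo(value.dtype).max) - 1e-2
    return torch.exp(torch.clamp(value, max=cap))


def dpo_logits(
    policy_chosen_logps, policy_rejected_logps,
    reference_chosen_logps, reference_rejected_logps,
    f_divergence_type="reverse_kl", alpha_coef=1.0,
):
    # log of the probability ratios u[w] and u[l]
    chosen_logratios = policy_chosen_logps - reference_chosen_logps
    rejected_logratios = policy_rejected_logps - reference_rejected_logps

    if f_divergence_type == "alpha_divergence":
        # (u[l]^-alpha - u[w]^-alpha) / alpha
        return (cap_exp(-alpha_coef * rejected_logratios) - cap_exp(-alpha_coef * chosen_logratios)) / alpha_coef

    # reverse KL (standard DPO): log u[w] - log u[l]
    logits = chosen_logratios - rejected_logratios
    if f_divergence_type == "js_divergence":
        # subtract log(1 + u[w]) - log(1 + u[l]), i.e. softplus(log u[w]) - softplus(log u[l])
        logits -= F.softplus(chosen_logratios) - F.softplus(rejected_logratios)
    return logits


# Example: extreme log probs stay finite in the alpha-divergence branch thanks to the cap.
logits = dpo_logits(
    torch.tensor([410.0, 0.1]), torch.tensor([810.5, 0.2]),
    torch.tensor([-610.0, -0.1]), torch.tensor([110.6, 0.5]),
    f_divergence_type="alpha_divergence", alpha_coef=0.5,
)
print(torch.isfinite(logits).all())  # tensor(True)
```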
26 changes: 26 additions & 0 deletions trl/trainer/utils.py
@@ -859,6 +859,32 @@ def on_train_end(self, args, state, control, **kwargs):
        self.current_step = None


+def get_exp_cap(value, decimal=4):
+    """
+    Get the exponent cap of a value. This is used to cap the exponent of a value to avoid overflow.
+    The formula is: log(value.dtype.max).
+    E.g. for the float32 data type, the maximum exponent value is 88.7228 to 4 decimal points.
+
+    Args:
+        value (`torch.Tensor`):
+            The input tensor used to obtain the data type.
+        decimal (`int`):
+            The number of decimal points to which the output exponent cap is floored.
+            Calling exp(log(torch.finfo(torch.float32).max)) directly can overflow to inf,
+            so the cap is rounded down (e.g. to 88.7228) to stay safely below the limit.
+    """
+    vdtype_max = torch.zeros([1]).to(value.dtype) + torch.finfo(value.dtype).max
+    vdtype_log_max = torch.log(vdtype_max).to(value.device)
+    return torch.floor(vdtype_log_max * 10**decimal) / 10**decimal if decimal > 0 else vdtype_log_max
+
+
+def cap_exp(value, cap=-1):
+    # Cap the exponent value below the upper bound to avoid overflow, before calling torch.exp
+    cap = get_exp_cap(value) if cap < 0 else cap
+    return torch.exp(torch.clamp(value, max=cap))


def print_rich_table(df: pd.DataFrame) -> Table:
    console = Console()
    table = Table(show_lines=True)
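
A quick way to see why the cap matters, as a small sketch assuming the helpers above are importable from `trl.trainer.utils`:

```python
# Naive exp overflows for large exponents in float32; cap_exp stays finite.
import torch
from trl.trainer.utils import cap_exp, get_exp_cap

x = torch.tensor([100.0, 1000.0], dtype=torch.float32)

print(torch.exp(x))       # tensor([inf, inf]) -- float32 overflows above ~88.72
print(get_exp_cap(x))     # tensor([88.7228])  -- log(float32 max), floored to 4 decimals
print(cap_exp(x))         # both entries clamped to exp(88.7228), which is finite
print(torch.isfinite(cap_exp(x)).all())  # tensor(True)
```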