⚰️ Deprecate liger-kernel (#2949)
* Deprecate liger

* remove import

* oops, shouldn't be here

* Fix other deprecations

* remove liger from gkd for now

* remove liger for teacher

---------

Co-authored-by: Kashif Rasul <[email protected]>
qgallouedec and kashif authored Feb 28, 2025
1 parent ac7bde5 commit b882f57
Showing 6 changed files with 27 additions and 35 deletions.
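
For downstream users, the practical effect of this commit is a flag rename in `SFTConfig`: Liger patching is no longer handled by a TRL-specific option but delegated to the `use_liger_kernel` flag already provided by `transformers`. A minimal before/after sketch, assuming a TRL version that still accepts the deprecated flag during the transition window:

```python
from trl import SFTConfig

# Before: TRL-specific flag (now deprecated, emits a DeprecationWarning)
# training_args = SFTConfig(use_liger=True)

# After: the flag inherited from transformers.TrainingArguments
training_args = SFTConfig(use_liger_kernel=True)
```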
6 changes: 3 additions & 3 deletions docs/source/sft_trainer.md
@@ -604,17 +604,17 @@ With great memory reduction, you can potentially turn off cpu_offloading or grad
| ![Speed up](https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/docs/images/e2e-tps.png) | ![Memory](https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/docs/images/e2e-memory.png) |


-1. To use Liger-Kernel in `SFTTrainer`, first install by
+1. To use Liger-Kernel in [`SFTTrainer`], first install by

```bash
pip install liger-kernel
```

-2. Once installed, set `use_liger` in [`SFTConfig`]. No other changes are needed!
+2. Once installed, set `use_liger_kernel` in [`SFTConfig`]. No other changes are needed!

```python
training_args = SFTConfig(
-    use_liger=True
+    use_liger_kernel=True
)
```

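For context, a self-contained sketch of an SFT run with the renamed flag; the model and dataset names below are illustrative placeholders, not taken from this commit:

```python
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

# Illustrative dataset and model; substitute your own.
dataset = load_dataset("trl-lib/Capybara", split="train")

training_args = SFTConfig(
    output_dir="Qwen2.5-0.5B-SFT-Liger",
    max_length=512,
    packing=True,
    use_liger_kernel=True,  # inherited from transformers.TrainingArguments
)

trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B",
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
```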
2 changes: 1 addition & 1 deletion tests/slow/test_sft_slow.py
@@ -412,7 +412,7 @@ def test_sft_trainer_with_liger(self, model_name, packing):
max_steps=2,
packing=packing,
max_length=self.max_length,
-    use_liger=True,
+    use_liger_kernel=True,
)

trainer = SFTTrainer(
6 changes: 3 additions & 3 deletions trl/trainer/dpo_config.py
@@ -387,15 +387,15 @@ class DPOConfig(TrainingArguments):
)

# Deprecated parameters
-use_num_logits_to_keep: bool = field(
-    default=False,
+use_num_logits_to_keep: Optional[bool] = field(
+    default=None,
metadata={"help": "Deprecated. Use `use_logits_to_keep` instead."},
)

def __post_init__(self):
super().__post_init__()

-if self.use_num_logits_to_keep:
+if self.use_num_logits_to_keep is not None:
warnings.warn(
"`use_num_logits_to_keep` is deprecated and will be remove in version 0.17.0. Use "
"`use_logits_to_keep` instead.",
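The `dpo_config.py` change above follows the deprecation pattern used throughout this commit: the deprecated field becomes `Optional` with a `None` default, so `__post_init__` can distinguish "left unset" from "explicitly passed" before warning. A standalone sketch of that pattern with hypothetical field names, not TRL code:

```python
import warnings
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class ExampleConfig:
    # Current parameter.
    max_length: int = field(default=1024, metadata={"help": "Maximum sequence length."})
    # Deprecated parameter: Optional with a None default, so only an explicit value triggers the warning.
    max_seq_length: Optional[int] = field(
        default=None, metadata={"help": "Deprecated. Use `max_length` instead."}
    )

    def __post_init__(self):
        if self.max_seq_length is not None:
            warnings.warn(
                "`max_seq_length` is deprecated. Use `max_length` instead.",
                DeprecationWarning,
            )
            self.max_length = self.max_seq_length
```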
9 changes: 2 additions & 7 deletions trl/trainer/gkd_trainer.py
@@ -36,7 +36,7 @@
)
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import EvalPrediction
-from transformers.utils import is_liger_kernel_available, is_peft_available
+from transformers.utils import is_peft_available

from ..models import PreTrainedModelWrapper
from ..models.utils import unwrap_model_for_generation
@@ -54,8 +54,6 @@
if is_deepspeed_available():
import deepspeed

-if is_liger_kernel_available():
-    from liger_kernel.transformers import AutoLigerKernelForCausalLM

if is_peft_available():
from peft import PeftConfig
@@ -119,10 +117,7 @@ def __init__(
)

if isinstance(teacher_model, str):
-    if args.use_liger:
-        teacher_model = AutoLigerKernelForCausalLM.from_pretrained(teacher_model, **teacher_model_init_kwargs)
-    else:
-        teacher_model = AutoModelForCausalLM.from_pretrained(teacher_model, **teacher_model_init_kwargs)
+    teacher_model = AutoModelForCausalLM.from_pretrained(teacher_model, **teacher_model_init_kwargs)

# Disable dropout in the model
if args.disable_dropout:
23 changes: 14 additions & 9 deletions trl/trainer/sft_config.py
@@ -37,8 +37,6 @@ class SFTConfig(TrainingArguments):
model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model`
argument of the [`SFTTrainer`] is provided as a string.
-use_liger (`bool`, *optional*, defaults to `False`):
-    Monkey patch the model with Liger kernels to increase throughput and reduce memory usage.
> Parameters that control the data preprocessing
@@ -72,10 +70,6 @@ class SFTConfig(TrainingArguments):
"the `SFTTrainer` is provided as a string."
},
)
-use_liger: bool = field(
-    default=False,
-    metadata={"help": "Monkey patch the model with Liger kernels to increase throughput and reduce memory usage."},
-)

# Parameters that control the data preprocessing
dataset_text_field: str = field(
@@ -123,25 +117,29 @@ class SFTConfig(TrainingArguments):
)

# Deprecated parameters
-dataset_batch_size: int = field(
+dataset_batch_size: Optional[int] = field(
default=None,
metadata={"help": "Deprecated. You can safely remove this parameter from your configuration."},
)
-num_of_sequences: int = field(
+num_of_sequences: Optional[int] = field(
default=None,
metadata={
"help": "Deprecated. Use `max_length` instead, which specifies the maximum length of the tokenized "
"sequence, unlike `num_of_sequences`, which referred to string sequences."
},
)
-chars_per_token: float = field(
+chars_per_token: Optional[float] = field(
default=None,
metadata={"help": "Deprecated. If you want to customize the packing length, use `max_length`."},
)
max_seq_length: Optional[int] = field(
default=None,
metadata={"help": "Deprecated. Use `max_length` instead."},
)
+use_liger: Optional[bool] = field(
+    default=None,
+    metadata={"help": "Deprecated. Use `use_liger_kernel` instead."},
+)

def __post_init__(self):
super().__post_init__()
@@ -174,3 +172,10 @@ def __post_init__(self):
DeprecationWarning,
)
self.max_length = self.max_seq_length

+if self.use_liger is not None:
+    warnings.warn(
+        "`use_liger` is deprecated and will be remove in version 0.18.0. Use `use_liger_kernel` instead.",
+        DeprecationWarning,
+    )
+    self.use_liger_kernel = self.use_liger
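
With the `__post_init__` block above, passing the old flag should keep working for one more release while emitting a warning. A hedged sketch of that backward-compatible behavior (exact warning text aside):

```python
import warnings

from trl import SFTConfig

# Passing the deprecated flag should warn and forward the value to `use_liger_kernel`.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    args = SFTConfig(output_dir="tmp_sft", use_liger=True)

assert args.use_liger_kernel is True
assert any(issubclass(w.category, DeprecationWarning) for w in caught)
```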
16 changes: 4 additions & 12 deletions trl/trainer/sft_trainer.py
@@ -40,7 +40,7 @@
)
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import EvalPrediction
-from transformers.utils import is_liger_kernel_available, is_peft_available
+from transformers.utils import is_peft_available

from ..data_utils import is_conversational, maybe_apply_chat_template, maybe_convert_to_chatml, pack_examples
from .sft_config import SFTConfig
@@ -51,9 +51,6 @@
import peft
from peft import PeftConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training

-if is_liger_kernel_available():
-    from liger_kernel.transformers import AutoLigerKernelForCausalLM

if is_wandb_available():
import wandb

@@ -264,12 +261,7 @@ def _create_model_from_path(self, model_path: str, args: SFTConfig) -> PreTraine
model_init_kwargs["use_cache"] = False

# Create model
-if args.use_liger:
-    if not is_liger_kernel_available():
-        raise ImportError("Please install Liger-kernel for use_liger=True")
-    model = AutoLigerKernelForCausalLM.from_pretrained(model_path, **model_init_kwargs)
-else:
-    model = AutoModelForCausalLM.from_pretrained(model_path, **model_init_kwargs)
+model = AutoModelForCausalLM.from_pretrained(model_path, **model_init_kwargs)
return model

def _prepare_peft_model(self, model: PreTrainedModel, peft_config: Any, args: SFTConfig) -> PreTrainedModel:
@@ -457,7 +449,7 @@ def truncate(example, max_length):
**map_kwargs,
)
# For Liger kernel, ensure only input_ids is present
-if args.use_liger:
+if args.use_liger_kernel:
dataset = dataset.select_columns("input_ids")

return dataset
@@ -471,7 +463,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
)

# Compute token accuracy if we have labels and if the model is not using Liger (no logits)
if "labels" in inputs and not self.args.use_liger:
if "labels" in inputs and not self.args.use_liger_kernel:
shift_logits = outputs.logits[..., :-1, :].contiguous()
shift_labels = inputs["labels"][..., 1:].contiguous()

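The last hunk only renames the flag guarding the token-accuracy metric, which stays skipped under Liger because the fused kernel does not return full logits. For reference, a minimal sketch of what that metric computes from the shifted tensors shown above, assuming the usual `-100` ignore index (the helper name is hypothetical, not part of this diff):

```python
import torch


def token_accuracy(shift_logits: torch.Tensor, shift_labels: torch.Tensor, ignore_index: int = -100) -> float:
    """Fraction of non-ignored next-token predictions that match the labels."""
    predictions = shift_logits.argmax(dim=-1)   # (batch, seq_len - 1)
    mask = shift_labels != ignore_index         # drop padding / masked prompt tokens
    correct = (predictions == shift_labels) & mask
    total = mask.sum()
    return (correct.sum() / total).item() if total > 0 else 0.0
```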
