Add JSD Loss for Distillation #425

Merged · 11 commits · Jan 30, 2025
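For context on what is benchmarked below: the distillation objective combines a hard cross-entropy term on the true labels with a soft Jensen-Shannon divergence term between the teacher and student next-token distributions. A sketch of the standard formulation, assuming the generalized JSD with interpolation weight β (β = 0.5 recovers the symmetric case) and the `weight_hard_loss` / `weight_soft_loss` knobs that appear throughout the diff:

$$
\mathcal{L} \;=\; w_{\text{hard}}\,\mathrm{CE}(p_{\text{student}}, y) \;+\; w_{\text{soft}}\,\mathrm{JSD}_\beta\!\left(P_{\text{teacher}} \,\|\, P_{\text{student}}\right)
$$

$$
\mathrm{JSD}_\beta(P \,\|\, Q) \;=\; \beta\,\mathrm{KL}(P \,\|\, M) + (1-\beta)\,\mathrm{KL}(Q \,\|\, M), \qquad M = \beta P + (1-\beta)\,Q
$$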
24 changes: 24 additions & 0 deletions benchmark/data/all_benchmark_data.csv
@@ -745,3 +745,27 @@
kto_loss,huggingface,full,memory,MB,B,Batch Size (B),4,5544.25390625,5544.253906
kto_loss,huggingface,full,memory,MB,B,Batch Size (B),8,9057.287109375,9057.287109375,9057.287109375,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:59,0.5.2
kto_loss,huggingface,full,memory,MB,B,Batch Size (B),16,16087.353515625,16087.353515625,16087.353515625,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:59,0.5.2
kto_loss,huggingface,full,memory,MB,B,Batch Size (B),32,30147.486328125,30147.486328125,30147.486328125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:59,0.5.2
distill_jsd_loss,liger,forward,speed,ms,BT,B x T,1024,7.735536098480225,7.729177474975586,7.798131465911865,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:58:46,0.4.2
distill_jsd_loss,liger,forward,speed,ms,BT,B x T,2048,15.20411205291748,15.165056228637695,15.226079940795898,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:58:46,0.4.2
distill_jsd_loss,liger,forward,speed,ms,BT,B x T,4096,30.159456253051758,30.126911163330078,30.165311813354492,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:58:46,0.4.2
distill_jsd_loss,liger,forward,speed,ms,BT,B x T,8192,60.24163055419922,60.24163055419922,60.24163055419922,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:58:46,0.4.2
distill_jsd_loss,torch,forward,speed,ms,BT,B x T,1024,10.906111717224121,10.903244972229004,10.91296672821045,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:59:18,0.4.2
distill_jsd_loss,torch,forward,speed,ms,BT,B x T,2048,21.480207443237305,21.465139389038086,21.489286422729492,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:59:18,0.4.2
distill_jsd_loss,torch,forward,speed,ms,BT,B x T,4096,42.96339416503906,42.96237564086914,42.96440887451172,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:59:18,0.4.2
distill_jsd_loss,torch,forward,speed,ms,BT,B x T,8192,85.3946533203125,85.3946533203125,85.3946533203125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:59:18,0.4.2
distill_jsd_loss,liger,full,speed,ms,BT,B x T,1024,8.312895774841309,8.310400009155273,8.326751708984375,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:59:51,0.4.2
distill_jsd_loss,liger,full,speed,ms,BT,B x T,2048,15.770208358764648,15.767775535583496,15.774784088134766,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:59:51,0.4.2
distill_jsd_loss,liger,full,speed,ms,BT,B x T,4096,30.922752380371094,30.920312881469727,30.927898406982422,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:59:51,0.4.2
distill_jsd_loss,liger,full,speed,ms,BT,B x T,8192,60.70627212524414,60.70627212524414,60.70627212524414,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:59:51,0.4.2
distill_jsd_loss,torch,full,speed,ms,BT,B x T,1024,28.72480010986328,28.718809127807617,28.728179931640625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:00:25,0.4.2
distill_jsd_loss,torch,full,speed,ms,BT,B x T,2048,54.281761169433594,54.281761169433594,54.281761169433594,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:00:25,0.4.2
distill_jsd_loss,torch,full,speed,ms,BT,B x T,4096,107.08905792236328,107.08905792236328,107.08905792236328,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:00:25,0.4.2
distill_jsd_loss,torch,full,speed,ms,BT,B x T,8192,213.1598663330078,213.1598663330078,213.1598663330078,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:00:25,0.4.2
distill_jsd_loss,liger,full,memory,MB,BT,B x T,1024,10913.541015625,10913.541015625,10913.541015625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:00:58,0.4.2
distill_jsd_loss,liger,full,memory,MB,BT,B x T,2048,10941.548828125,10941.548828125,10941.548828125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:00:58,0.4.2
distill_jsd_loss,liger,full,memory,MB,BT,B x T,4096,10997.564453125,10997.564453125,10997.564453125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:00:58,0.4.2
distill_jsd_loss,liger,full,memory,MB,BT,B x T,8192,11109.595703125,11109.595703125,11109.595703125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:00:58,0.4.2
distill_jsd_loss,torch,full,memory,MB,BT,B x T,1024,16174.0390625,16174.0390625,16174.0390625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:01:32,0.4.2
distill_jsd_loss,torch,full,memory,MB,BT,B x T,2048,23713.05078125,23713.05078125,23713.05078125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:01:32,0.4.2
distill_jsd_loss,torch,full,memory,MB,BT,B x T,4096,38791.07421875,38791.07421875,38791.07421875,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:01:32,0.4.2
distill_jsd_loss,torch,full,memory,MB,BT,B x T,8192,68947.1015625,68947.1015625,68947.1015625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:01:32,0.4.2
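For a quick read on the new rows: at BT = 8192 the full-pass median drops from ~213 ms (torch) to ~61 ms (liger), roughly 3.5×, and peak memory from ~69 GB to ~11 GB, roughly 6×. Below is a minimal sketch for slicing these numbers out of the CSV; the column names are an assumption inferred from the row layout, so adjust them if the actual header differs.

```python
# Hedged sketch: column names are assumed from the row layout above;
# adjust `cols` if the real CSV header differs.
import pandas as pd

cols = [
    "kernel_name", "kernel_provider", "kernel_operation_mode", "metric_name",
    "metric_unit", "x_name", "x_label", "x_value", "y_value_50", "y_value_20",
    "y_value_80", "extra_benchmark_config_str", "gpu_name", "timestamp",
    "liger_version",
]
df = pd.read_csv("benchmark/data/all_benchmark_data.csv", names=cols, header=0)

# Compare providers at each (mode, metric, BT) point using the median value.
jsd = df[df["kernel_name"] == "distill_jsd_loss"]
pivot = jsd.pivot_table(
    index=["kernel_operation_mode", "metric_name", "x_value"],
    columns="kernel_provider",
    values="y_value_50",
)
pivot["torch_over_liger"] = pivot["torch"] / pivot["liger"]  # >1 means liger wins
print(pivot)
```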
261 changes: 261 additions & 0 deletions benchmark/scripts/benchmark_distill_jsd_loss.py
@@ -0,0 +1,261 @@
import os
import sys

import torch
import triton

from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks

from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDFunction
from liger_kernel.utils import infer_device

device = infer_device()

# Make the repo root importable so the HF reference loss from the test suite can be loaded.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))


class TorchJSDLoss(torch.nn.Module):
def __init__(
self,
H: int,
V: int,
dtype: torch.dtype,
weight_hard_loss: float = 0.5,
weight_soft_loss: float = 0.5,
ignore_index: int = -100,
temperature: float = 1.0,
bias: bool = False,
):
# Deferred import: the HF reference implementation lives in the repo's test suite.
from test.chunked_loss.test_jsd_loss import HFJSDLoss

super().__init__()
self.student_lin = torch.nn.Linear(in_features=H // 2, out_features=V, bias=bias, dtype=dtype)
self.teacher_lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype)
self.jsd_loss = HFJSDLoss(
ignore_index=ignore_index,
weight_hard_loss=weight_hard_loss,
weight_soft_loss=weight_soft_loss,
temperature=temperature,
).get_batch_loss_metrics

def forward(self, student, teacher, target):
return self.jsd_loss(
student,
self.student_lin.weight,
teacher,
self.teacher_lin.weight,
target,
)


class LigerJSDLoss(torch.nn.Module):
def __init__(
self,
H: int,
V: int,
dtype: torch.dtype,
weight_hard_loss: float = 0.5,
weight_soft_loss: float = 0.5,
ignore_index: int = -100,
temperature: float = 1.0,
bias: bool = False,
):
super().__init__()
self.student_lin = torch.nn.Linear(in_features=H // 2, out_features=V, bias=bias, dtype=dtype)
self.teacher_lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype)
self.weight_hard_loss = weight_hard_loss
self.weight_soft_loss = weight_soft_loss
self.ignore_index = ignore_index
self.temperature = temperature
self.jsd_loss = LigerFusedLinearJSDFunction.apply

def forward(self, student, teacher, target):
return self.jsd_loss(
student,
self.student_lin.weight,
teacher,
self.teacher_lin.weight,
target,
self.weight_hard_loss,
self.weight_soft_loss,
)


def bench_memory_jsd_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
BT = input.x
H = input.extra_benchmark_config["H"]
V = input.extra_benchmark_config["V"]
dtype = input.extra_benchmark_config["dtype"]
bias = input.extra_benchmark_config["bias"]
weight_hard_loss = input.extra_benchmark_config["weight_hard_loss"]
weight_soft_loss = input.extra_benchmark_config["weight_soft_loss"]
ignore_index = input.extra_benchmark_config["ignore_index"]
provider = input.kernel_provider

torch_jsd_loss = TorchJSDLoss(
H=H,
V=V,
dtype=dtype,
ignore_index=ignore_index,
bias=bias,
weight_hard_loss=weight_hard_loss,
weight_soft_loss=weight_soft_loss,
).to(device)
liger_jsd_loss = LigerJSDLoss(
H=H,
V=V,
dtype=dtype,
ignore_index=ignore_index,
bias=bias,
weight_hard_loss=weight_hard_loss,
weight_soft_loss=weight_soft_loss,
).to(device)

# Clone one random tensor into two independent leaves so each provider
# backprops into its own input. The student hidden size is H // 2
# (a smaller student model); the teacher uses the full H.
_tensor = torch.rand(BT, H // 2, device=device, dtype=dtype)
student_input1 = _tensor.detach().clone().requires_grad_(True)
student_input2 = _tensor.detach().clone().requires_grad_(True)

teacher_input = torch.rand(BT, H, device=device, dtype=dtype)

target = torch.randint(0, V, (BT,), device=device, dtype=torch.long)

def fwd():
if provider == "liger":
return liger_jsd_loss(student_input1, teacher_input, target)
elif provider == "torch":
return torch_jsd_loss(student_input2, teacher_input, target)

def full():
y = fwd()
y.backward()

mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES)
return SingleBenchmarkRunOutput(
y_20=mem_20,
y_50=mem_50,
y_80=mem_80,
)


def bench_speed_jsd_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
BT = input.x
H = input.extra_benchmark_config["H"]
V = input.extra_benchmark_config["V"]
dtype = input.extra_benchmark_config["dtype"]
bias = input.extra_benchmark_config["bias"]
weight_hard_loss = input.extra_benchmark_config["weight_hard_loss"]
weight_soft_loss = input.extra_benchmark_config["weight_soft_loss"]
ignore_index = input.extra_benchmark_config["ignore_index"]
provider = input.kernel_provider
mode = input.kernel_operation_mode

torch_jsd_loss = TorchJSDLoss(
H=H,
V=V,
dtype=dtype,
ignore_index=ignore_index,
bias=bias,
weight_hard_loss=weight_hard_loss,
weight_soft_loss=weight_soft_loss,
).to(device)
liger_jsd_loss = LigerJSDLoss(
H=H,
V=V,
dtype=dtype,
ignore_index=ignore_index,
bias=bias,
weight_hard_loss=weight_hard_loss,
weight_soft_loss=weight_soft_loss,
).to(device)

_tensor = torch.rand(BT, H // 2, device=device, dtype=dtype)
student_input1 = _tensor.detach().clone().requires_grad_(True)
student_input2 = _tensor.detach().clone().requires_grad_(True)

teacher_input = torch.rand(BT, H, device=device, dtype=dtype)

target = torch.randint(0, V, (BT,), device=device, dtype=torch.long)

def fwd():
if provider == "liger":
return liger_jsd_loss(student_input1, teacher_input, target)
elif provider == "torch":
return torch_jsd_loss(student_input2, teacher_input, target)

if mode == "forward":
ms_50, ms_20, ms_80 = triton.testing.do_bench(
fwd,
rep=100,
quantiles=QUANTILES,
)
elif mode == "backward":
y = fwd()
ms_50, ms_20, ms_80 = triton.testing.do_bench(
lambda: y.backward(retain_graph=True),
grad_to_none=[student_input1, student_input2],
rep=100,
quantiles=QUANTILES,
)
elif mode == "full":

def full():
y = fwd()
y.backward()

ms_50, ms_20, ms_80 = triton.testing.do_bench(
full,
rep=100,
quantiles=QUANTILES,
)

return SingleBenchmarkRunOutput(
y_20=ms_20,
y_50=ms_50,
y_80=ms_80,
)


if __name__ == "__main__":
args = parse_benchmark_script_args()

common_configs = {
"kernel_name": "distill_jsd_loss",
"x_name": "BT",
"x_label": "B x T",
"x_values": [2**i for i in range(10, 14)],
"kernel_providers": ["liger", "torch"],
"extra_benchmark_configs": [
{
"H": 4096,
"V": 128256,
"mode": "forward",
"dtype": torch.bfloat16,
"bias": False,
"weight_hard_loss": 0.5,
"weight_soft_loss": 0.5,
"ignore_index": -100,
}
],
"overwrite": args.overwrite,
}

run_benchmarks(
bench_test_fn=bench_speed_jsd_loss,
kernel_operation_modes=["forward", "full"],
metric_name="speed",
metric_unit="ms",
**common_configs,
)

run_benchmarks(
bench_test_fn=bench_memory_jsd_loss,
kernel_operation_modes=["full"],
metric_name="memory",
metric_unit="MB",
**common_configs,
)
1 change: 1 addition & 0 deletions src/liger_kernel/chunked_loss/__init__.py
@@ -1,5 +1,6 @@
from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOLoss # noqa: F401
from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOLoss # noqa: F401
from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDLoss # noqa: F401
from liger_kernel.chunked_loss.kto_loss import LigerFusedLinearKTOLoss # noqa: F401
from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOLoss # noqa: F401
from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOLoss # noqa: F401
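With the new export in place, a module-style call looks roughly like the sketch below. The forward argument order is an assumption mirrored from the `LigerFusedLinearJSDFunction.apply` call in the benchmark script above; consult the class itself for the authoritative signature.

```python
# Hedged sketch: argument order assumed to mirror
# LigerFusedLinearJSDFunction.apply as used in the benchmark script.
import torch
from liger_kernel.chunked_loss import LigerFusedLinearJSDLoss

H, V, BT = 4096, 128256, 1024
loss_fn = LigerFusedLinearJSDLoss(weight_hard_loss=0.5, weight_soft_loss=0.5)

# Student projects from a smaller hidden size (H // 2), as in the benchmark.
student_hidden = torch.randn(BT, H // 2, dtype=torch.bfloat16, device="cuda", requires_grad=True)
student_weight = torch.randn(V, H // 2, dtype=torch.bfloat16, device="cuda", requires_grad=True)
teacher_hidden = torch.randn(BT, H, dtype=torch.bfloat16, device="cuda")
teacher_weight = torch.randn(V, H, dtype=torch.bfloat16, device="cuda")
labels = torch.randint(0, V, (BT,), device="cuda")

loss = loss_fn(student_hidden, student_weight, teacher_hidden, teacher_weight, labels)
loss.backward()
```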
2 changes: 2 additions & 0 deletions src/liger_kernel/chunked_loss/functional.py
@@ -1,11 +1,13 @@
from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction
from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction
from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDFunction
from liger_kernel.chunked_loss.kto_loss import LigerFusedLinearKTOFunction
from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOFunction
from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOFunction

liger_fused_linear_orpo = LigerFusedLinearORPOFunction.apply
liger_fused_linear_dpo = LigerFusedLinearDPOFunction.apply
liger_fused_linear_jsd = LigerFusedLinearJSDFunction.apply
liger_fused_linear_cpo = LigerFusedLinearCPOFunction.apply
liger_fused_linear_simpo = LigerFusedLinearSimPOFunction.apply
liger_fused_linear_kto = LigerFusedLinearKTOFunction.apply
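The functional alias follows the same positional order as the `Function.apply` call in the benchmark script (inputs, weights, targets, then the two loss weights); any remaining parameters such as `ignore_index` and `temperature` are assumed to take their defaults when omitted. Reusing the tensors from the previous sketch:

```python
# Hedged sketch: positional order taken from the benchmark's
# LigerFusedLinearJSDFunction.apply call; trailing args left at defaults.
from liger_kernel.chunked_loss.functional import liger_fused_linear_jsd

loss = liger_fused_linear_jsd(
    student_hidden,  # (BT, H // 2) student activations
    student_weight,  # (V, H // 2) student lm_head weight
    teacher_hidden,  # (BT, H) teacher activations
    teacher_weight,  # (V, H) teacher lm_head weight
    labels,          # (BT,) hard targets
    0.5,             # weight_hard_loss
    0.5,             # weight_soft_loss
)
loss.backward()
```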