diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index cd73ef928..0454da4ce 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -69,4 +69,5 @@ repos:
| ^docs
| ^configs
| ^.*/configs*
+ | ^projects
)
diff --git a/README.md b/README.md
index 4dbb364d5..ad92732cf 100644
--- a/README.md
+++ b/README.md
@@ -61,6 +61,8 @@ English | [简体中文](README_zh-CN.md)
+**:star: MMRazor for Large Models** is available now! Please refer to [MMRazorLarge](projects/mmrazor_large/README.md).
+
## Introduction
MMRazor is a model compression toolkit for model slimming and AutoML, which includes 4 mainstream technologies:
diff --git a/mmrazor/implementations/pruning/__init__.py b/mmrazor/implementations/pruning/__init__.py
index e28ae7dc2..d536adf1f 100644
--- a/mmrazor/implementations/pruning/__init__.py
+++ b/mmrazor/implementations/pruning/__init__.py
@@ -1,4 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
-from . import group_fisher
+from . import group_fisher, sparse_gpt
-__all__ = ['group_fisher']
+__all__ = ['group_fisher', 'sparse_gpt']
diff --git a/mmrazor/implementations/pruning/sparse_gpt/__init__.py b/mmrazor/implementations/pruning/sparse_gpt/__init__.py
new file mode 100644
index 000000000..8caae7fd0
--- /dev/null
+++ b/mmrazor/implementations/pruning/sparse_gpt/__init__.py
@@ -0,0 +1,9 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .compressor import SparseGptCompressor
+from .ops import SparseGptLinear, SparseGptMixIn
+from .utils import replace_with_dynamic_ops
+
+__all__ = [
+ 'SparseGptLinear', 'SparseGptMixIn', 'replace_with_dynamic_ops',
+ 'SparseGptCompressor'
+]
diff --git a/mmrazor/implementations/pruning/sparse_gpt/compressor.py b/mmrazor/implementations/pruning/sparse_gpt/compressor.py
new file mode 100644
index 000000000..f5ef42ec6
--- /dev/null
+++ b/mmrazor/implementations/pruning/sparse_gpt/compressor.py
@@ -0,0 +1,106 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+
+from mmrazor.utils import print_log
+from .ops import SparseGptConv2d, SparseGptLinear, SparseGptMixIn
+from .utils import replace_with_dynamic_ops
+
+
+def to_static_model(model: nn.Module):
+ """Replace dynamicops with torch modules."""
+ from mmrazor.structures.subnet.fix_subnet import (export_fix_subnet,
+ load_fix_subnet)
+ fix_subnet = export_fix_subnet(model)[0]
+ load_fix_subnet(model, fix_subnet)
+ return model
+
+
+class SparseGptCompressor():
+ """The compressor with SparseGPT."""
+
+ def __init__(self) -> None:
+ self.model: nn.Module = None
+
+ def prepare(self,
+ model: nn.Module,
+ prune_conv=True,
+ prune_linear=True) -> None:
+ """Prepare for compressing model."""
+ self.model = model
+ prune_modules: dict = {}
+ if prune_conv:
+ prune_modules[nn.Conv2d] = SparseGptConv2d
+ if prune_linear:
+ prune_modules[nn.Linear] = SparseGptLinear
+ replace_with_dynamic_ops(model, prune_modules)
+
+ @classmethod
+ def to_static_model(cls, model):
+ """Convert replaced op with the original torch model."""
+ return to_static_model(model)
+
+ # hessian
+
+ def register_hessian_hooks(self):
+ """Register updating hessian hooks for specified ops."""
+ for module in self.sparse_ops:
+ module.register_hessian_hook()
+
+ def remove_hessian_hooks(self):
+ """Remove updating hessian hooks for specified ops."""
+ for module in self.sparse_ops:
+ module.remove_hessian_hook()
+
+ def init_hessian(self, device=None):
+ """Init hessian."""
+ for op in self.sparse_ops:
+ op.init_hessian(device=device)
+
+ # prune
+ def prune(self,
+ sparsity,
+ prunen=0,
+ prunem=0,
+ blocksize=128,
+ percdamp=.01,
+ device=torch.device('cuda')):
+ """Apply the compression algorithm to the model."""
+ for name, module in self.named_sparse_ops:
+ try:
+ original_device = next(module.parameters()).device
+ module: SparseGptMixIn = module.to(device)
+ error = module.prune(
+ sparsity=sparsity,
+ prunen=prunen,
+ prunem=prunem,
+ blocksize=blocksize,
+ percdamp=percdamp,
+ )
+ print_log(f'prune {name} success \t error = {error}')
+ module.to(original_device)
+ torch.cuda.empty_cache()
+ except Exception as e:
+ print_log(f'prune {name} failed as {e}')
+
+ def prune_24(self, device=torch.device('cuda:0')):
+ """Apply the compression algorithm to the model with the specified
+ setting."""
+ self.prune(0.5, prunen=2, prunem=4, device=device)
+
+ # ops
+
+ @property
+ def sparse_ops(self):
+ """The ops to be applied the algorithm."""
+ assert self.model is not None
+ for module in self.model.modules():
+ if isinstance(module, SparseGptMixIn):
+ yield module
+
+ @property
+ def named_sparse_ops(self):
+ """The named ops to be applied the algorithm."""
+ for name, module in self.model.named_modules():
+ if isinstance(module, SparseGptMixIn):
+ yield name, module
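
For readers of this PR, a minimal usage sketch of `SparseGptCompressor` (not part of the diff): the toy model, calibration batches and device handling below are placeholders, while the method names follow `compressor.py` above.

```python
import torch
import torch.nn as nn

from mmrazor.implementations.pruning.sparse_gpt import SparseGptCompressor

# hypothetical toy model and calibration set
model = nn.Sequential(nn.Linear(128, 256), nn.ReLU(), nn.Linear(256, 256))
calib_batches = [torch.randn(8, 128) for _ in range(4)]

compressor = SparseGptCompressor()
compressor.prepare(model, prune_conv=False, prune_linear=True)

# 1) accumulate per-layer Hessians (H = 2 * X @ X.T) on calibration data
compressor.init_hessian()
compressor.register_hessian_hooks()
with torch.no_grad():
    for batch in calib_batches:
        model(batch)
compressor.remove_hessian_hooks()

# 2) apply SparseGPT, here with 2:4 semi-structured sparsity
device = torch.device('cuda:0') if torch.cuda.is_available() \
    else torch.device('cpu')
compressor.prune_24(device=device)

# 3) fold the dynamic ops back into plain torch modules
model = SparseGptCompressor.to_static_model(model)
```
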
diff --git a/mmrazor/implementations/pruning/sparse_gpt/ops.py b/mmrazor/implementations/pruning/sparse_gpt/ops.py
new file mode 100644
index 000000000..0f11b176f
--- /dev/null
+++ b/mmrazor/implementations/pruning/sparse_gpt/ops.py
@@ -0,0 +1,278 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import sys
+
+if sys.version_info < (3, 8):
+ from typing_extensions import Protocol
+else:
+ from typing import Protocol
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+import torch.nn.functional as F
+
+from mmrazor.models.architectures.dynamic_ops import (DynamicConv2d,
+ DynamicLinear)
+from .utils import ModuleProtocol, torch_setting
+
+
+class SparseGptMixIn(ModuleProtocol):
+ """The core algorithm implementation for SparseGpt."""
+
+ def _sparse_gpt_mix_in_init(self):
+ """Init mixin."""
+ self.sparse_gpt_handles = []
+ self.rows = self.weight_matrix.shape[0]
+ self.columns = self.weight_matrix.shape[1]
+
+ self._hessian: torch.Tensor = None
+ self.hessian_batch = 0
+
+ # weight and input adaptive
+
+ @property
+ def weight_matrix(self):
+ """Return weight with shape (out in)"""
+ return self.weight.flatten(1) # out in
+
+ @weight_matrix.setter
+ def weight_matrix(self, value: torch.Tensor):
+ """Set weight."""
+ with torch.no_grad():
+ value = value.reshape(self.weight.shape).to(self.weight.device).to(
+ self.weight.dtype)
+ self.weight.data.copy_(value)
+
+ def format_input(self, input: torch.Tensor):
+ """Return input with shape (B N C)"""
+ if len(input.shape) == 2: # N C
+ input = input.unsqueeze(0) # 1 N C
+ return input
+
+ # compute hessian
+
+ @property
+ def hessian(self):
+ """hessian always return float."""
+ if dist.is_initialized():
+ if dist.get_rank() == 0:
+ assert self._hessian is not None, 'hessian is not initialized.'
+ hessian = self._hessian.to(self.weight_matrix.device)
+ else:
+ hessian = torch.zeros(
+ self.columns,
+ self.columns,
+ device=self.weight_matrix.device)
+ dist.broadcast(hessian, 0)
+ return hessian
+ else:
+ return self._hessian
+
+ @hessian.setter
+ def hessian(self, value: torch.Tensor):
+ """Set hessian."""
+ with torch.no_grad():
+ if dist.is_initialized():
+ if dist.get_rank() == 0:
+ assert self._hessian is not None, 'hessian is not initialized.' # noqa
+ self._hessian.data.copy_(
+ value.data.to(self._hessian.device))
+ else:
+ self._hessian = None
+ else:
+ self._hessian.data.copy_(value.data.to(self._hessian.device))
+
+ @torch.no_grad()
+ def update_hessian(self, input: torch.Tensor):
+ """Update hessian."""
+ input = self.format_input(input).float()
+ H_save = self.hessian
+ H_save = H_save.to(input.device)
+
+ assert len(input.shape) == 3
+ B = input.shape[0] # B N C
+ input = input.transpose(0, -1).flatten(1) # C D
+
+ H = input @ input.T * 2 # C C
+
+ if dist.is_initialized():
+ dist.all_reduce(H)
+ B *= dist.get_world_size()
+ H_save = (H_save * self.hessian_batch + H) / (self.hessian_batch + B)
+ self.hessian = H_save
+ self.hessian_batch = self.hessian_batch + B
+
+ def register_hessian_hook(self):
+ """Register updating hessian hook."""
+
+ @torch.no_grad()
+ def forward_pre_hook(module: Protocol, input: tuple):
+ assert len(input) == 1
+ self.update_hessian(input[0])
+
+ handle = self.register_forward_pre_hook(forward_pre_hook)
+ self.sparse_gpt_handles.append(handle)
+
+ def remove_hessian_hook(self):
+ """Remove updating hessian hook."""
+ for h in self.sparse_gpt_handles:
+ h.remove()
+
+ def init_hessian(self, device=None):
+ """Init hessian."""
+ if dist.is_initialized():
+ if dist.get_rank() == 0:
+ self._hessian = torch.zeros([self.columns, self.columns],
+ device=device,
+ dtype=torch.float)
+ else:
+ self._hessian = None
+ else:
+ self._hessian = torch.zeros([self.columns, self.columns],
+ device=device,
+ dtype=torch.float)
+
+ # prune
+
+ @torch.no_grad()
+ def prune(self, sparsity, prunen=0, prunem=0, blocksize=128, percdamp=.01):
+ """The implementation for SparseGPT."""
+ with torch_setting(dtype=torch.float):
+ # Converted from https://github.com/ist-daslab/sparsegpt
+
+ assert self.hessian is not None
+ W: torch.Tensor = self.weight_matrix.float() # out in
+
+ H = self.hessian.float().to(W.device)
+
+ dead = torch.diag(H) == 0
+ H[dead, dead] = 1
+ W[:, dead] = 0
+
+ Losses = torch.zeros(self.rows, device=W.device)
+
+ damp = percdamp * torch.mean(torch.diag(H))
+ diag = torch.arange(self.columns, device=W.device)
+ H[diag, diag] += damp
+ H = torch.linalg.cholesky(H)
+ H = torch.cholesky_inverse(H)
+ H = torch.linalg.cholesky(H, upper=True)
+ Hinv = H
+
+ mask = None
+
+ for i1 in range(0, self.columns, blocksize):
+ i2 = min(i1 + blocksize, self.columns)
+ count = i2 - i1
+
+ W1 = W[:, i1:i2].clone()
+ Q1 = torch.zeros_like(W1)
+ Err1 = torch.zeros_like(W1)
+ Losses1 = torch.zeros_like(W1)
+ Hinv1 = Hinv[i1:i2, i1:i2]
+
+ if prunen == 0:
+ if mask is not None:
+ mask1 = mask[:, i1:i2]
+ else:
+ tmp = W1**2 / (torch.diag(Hinv1).reshape((1, -1)))**2
+ thresh = torch.sort(tmp.flatten())[0][int(tmp.numel() *
+ sparsity)]
+ mask1 = tmp <= thresh
+ else:
+ mask1 = torch.zeros_like(W1) == 1
+
+ for i in range(count):
+ w = W1[:, i]
+ d = Hinv1[i, i]
+
+ if prunen != 0 and i % prunem == 0:
+ tmp = W1[:, i:(i + prunem)]**2 / (torch.diag(Hinv1)[i:(
+ i + prunem)].reshape((1, -1)))**2
+ mask1.scatter_(
+ 1, i +
+ torch.topk(tmp, prunen, dim=1, largest=False)[1],
+ True)
+
+ q = w.clone()
+ q[mask1[:, i]] = 0
+
+ Q1[:, i] = q
+ Losses1[:, i] = (w - q)**2 / d**2
+
+ err1 = (w - q) / d
+ W1[:,
+ i:] -= err1.unsqueeze(1).matmul(Hinv1[i,
+ i:].unsqueeze(0))
+ Err1[:, i] = err1
+
+ W[:, i1:i2] = Q1
+ Losses += torch.sum(Losses1, 1) / 2
+
+ W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])
+
+ if W.device.type == 'cuda':
+ torch.cuda.synchronize()
+ from .sparse24_utils import is_weight_sparse_24
+ if prunen == 2 and prunem == 4:
+                assert is_weight_sparse_24(W, -1), \
+                    f'Weight does not satisfy 2:4 sparsity: {W.shape}'
+ error = torch.sum(Losses)
+
+ if torch.isnan(error).any():
+ raise Exception('get nan error')
+ else:
+ self.weight_matrix = W.data
+
+ return error.item()
+
+
+# SparseGpt Ops for Linear and Conv2d
+
+
+class SparseGptLinear(DynamicLinear, SparseGptMixIn):
+ """Custom Linear for SparseGpt."""
+
+ def __init__(self, *args, **kwargs) -> None:
+ super().__init__(*args, **kwargs)
+ self._sparse_gpt_mix_in_init()
+
+ @classmethod
+    def convert_from(cls, module: nn.Linear) -> 'DynamicLinear':
+ """Convert to cls from torch's module."""
+ if module.out_features < module.in_features:
+ return module
+ new_module = super().convert_from(module)
+ new_module.load_state_dict(module.state_dict(), strict=False)
+
+ dtype = next(module.parameters()).dtype
+ new_module = new_module.to(dtype)
+
+ return new_module
+
+
+class SparseGptConv2d(DynamicConv2d, SparseGptMixIn):
+ """Custom Conv2d for SparseGpt."""
+
+ def __init__(self, *args, **kwargs) -> None:
+ super().__init__(*args, **kwargs)
+ self._sparse_gpt_mix_in_init()
+
+ @classmethod
+ def convert_from(cls, module: nn.Conv2d) -> 'DynamicConv2d':
+ """Convert to cls from torch's module."""
+ new_module = super().convert_from(module)
+ new_module.load_state_dict(module.state_dict(), strict=False)
+
+ dtype = next(module.parameters()).dtype
+ new_module = new_module.to(dtype)
+
+ return new_module
+
+ def format_input(self, input: torch.Tensor):
+ """Format input shape."""
+ # input B C H W
+ input = F.unfold(
+ input, self.kernel_size, padding=self.padding,
+ stride=self.stride) # B C D
+ return input.transpose(-1, -2)
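
To illustrate what the hook-driven Hessian update above computes, here is a hedged sketch for a single `SparseGptLinear` (layer sizes are arbitrary): the accumulated Hessian is `2 * X @ X.T`, averaged over the batch dimension.

```python
import torch
import torch.nn as nn

from mmrazor.implementations.pruning.sparse_gpt import SparseGptLinear

linear = SparseGptLinear.convert_from(nn.Linear(16, 32))
linear.init_hessian(device='cpu')
linear.register_hessian_hook()

x = torch.randn(4, 10, 16)  # (B, N, C)
linear(x)  # the forward pre-hook calls update_hessian(x)
linear.remove_hessian_hook()

# manual reference: flatten (B, N, C) -> (C, B*N), then H = 2 * X X^T / B
ref = x.float().transpose(0, -1).flatten(1)
ref = (ref @ ref.T * 2) / x.shape[0]
assert torch.allclose(linear.hessian, ref)
```
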
diff --git a/mmrazor/implementations/pruning/sparse_gpt/sparse24_utils.py b/mmrazor/implementations/pruning/sparse_gpt/sparse24_utils.py
new file mode 100644
index 000000000..1d646dee1
--- /dev/null
+++ b/mmrazor/implementations/pruning/sparse_gpt/sparse24_utils.py
@@ -0,0 +1,10 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+
+@torch.no_grad()
+def is_weight_sparse_24(weight: torch.Tensor, dim=-1):
+ """"Check if the weight is sparse 24."""
+ weight = weight.transpose(-1, dim).reshape(-1, 4) # N 4
+ is_zero = (weight == 0).sum(-1) # N
+ return (is_zero >= 2).all()
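
A quick illustration of the helper above: a weight is 2:4 sparse when every contiguous group of four values along the chosen dimension contains at least two zeros.

```python
import torch

from mmrazor.implementations.pruning.sparse_gpt.sparse24_utils import \
    is_weight_sparse_24

w = torch.tensor([[1., 0., 2., 0., 0., 3., 0., 4.]])
assert is_weight_sparse_24(w, dim=-1)  # each group of 4 has >= 2 zeros
assert not is_weight_sparse_24(torch.ones(1, 8), dim=-1)  # dense weight fails
```
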
diff --git a/mmrazor/implementations/pruning/sparse_gpt/utils.py b/mmrazor/implementations/pruning/sparse_gpt/utils.py
new file mode 100644
index 000000000..df82784c1
--- /dev/null
+++ b/mmrazor/implementations/pruning/sparse_gpt/utils.py
@@ -0,0 +1,140 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import sys
+from typing import Dict, Type
+
+if sys.version_info < (3, 8):
+ from typing_extensions import Protocol
+else:
+ from typing import Protocol
+
+import torch
+import torch.nn as nn
+
+from mmrazor.models.architectures.dynamic_ops import DynamicMixin
+from mmrazor.utils import print_log
+
+
+class ModuleProtocol(Protocol):
+ """Custom module protocol for algorithm mixin."""
+ weight: torch.Tensor
+
+ def forward(self, x):
+ """The abstract method."""
+ pass
+
+ def register_forward_hook(self, hook):
+ """The abstract method."""
+ pass
+
+ def register_backward_hook(self, hook):
+ """The abstract method."""
+ pass
+
+ def register_forward_pre_hook(self, hook):
+ """The abstract method."""
+ pass
+
+ def register_buffer(self, name, tensor):
+ """The abstract method."""
+ pass
+
+
+def replace_with_dynamic_ops(model: nn.Module,
+ dynamicop_map: Dict[Type[nn.Module],
+ Type[DynamicMixin]]):
+ """Replace torch modules with dynamic-ops."""
+
+ def replace_op(model: nn.Module, name: str, module: nn.Module):
+ names = name.split('.')
+ for sub_name in names[:-1]:
+ model = getattr(model, sub_name)
+
+ setattr(model, names[-1], module)
+
+ for name, module in model.named_modules():
+ if type(module) in dynamicop_map:
+ new_module = dynamicop_map[type(module)].convert_from(module)
+ replace_op(model, name, new_module)
+
+
+def register_efficient_forward_hook(module: nn.Module,
+ device=torch.device('cuda:0')):
+ """Register efficient forward hook."""
+
+ def forward_pre_hook(module: nn.Module, input):
+ module.to(device)
+
+ def forward_hook(module: nn.Module, input, output):
+ module.to('cpu')
+ torch.cuda.empty_cache()
+
+ h1 = module.register_forward_pre_hook(forward_pre_hook)
+ h2 = module.register_forward_hook(forward_hook)
+ return [h1, h2]
+
+
+def enable_efficient_forward(model: nn.Module,
+ device=torch.device('cuda:0'),
+ wrap_modules=[]):
+ """Enable efficient forward."""
+ handles = []
+ blocks = []
+ for name, module in model.named_children():
+ if type(module) in wrap_modules or len(module._parameters) != 0 or len(
+ module._buffers) != 0:
+ handles_ = register_efficient_forward_hook(module, device)
+ blocks_ = [name]
+ else:
+ handles_, blocks_ = enable_efficient_forward(
+ module, device, wrap_modules)
+ handles += handles_
+ blocks += blocks_
+ return handles, blocks
+
+
+class memory_efficient_forward:
+ """The class for Memory efficient forward."""
+
+ def __init__(self,
+ model: nn.Module,
+ enabled=True,
+ device=torch.device('cuda:0'),
+ wrap_modules=[]) -> None:
+ self.model = model
+ self.device = device
+ self.wrap_modules = wrap_modules
+ self.enabled = enabled
+ self.handlers: list = []
+
+ if not enabled:
+ model.to(device)
+
+ def __enter__(self, ):
+ """Enter."""
+ if self.enabled:
+ handles, blocks = enable_efficient_forward(self.model, self.device,
+ self.wrap_modules)
+ print_log(f'enable memory efficient forward for {blocks}')
+ self.handlers = handles
+
+ def __exit__(self, exc_type, exc_value, exc_traceback):
+ """Exit."""
+ for h in self.handlers:
+ h.remove()
+
+
+class torch_setting():
+ """Set the default torch dtype setting."""
+
+ def __init__(self, dtype=None) -> None:
+ self.original_dtype = torch.get_default_dtype()
+ self.dtype = dtype
+
+ def __enter__(self):
+ """Enter."""
+ if self.dtype is not None:
+ torch.set_default_dtype(self.dtype)
+
+ def __exit__(self, exc_type, exc_value, exc_traceback):
+ """Exit."""
+ torch.set_default_dtype(self.original_dtype)
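
A hedged sketch of the two context managers above (model and device are placeholders): `memory_efficient_forward` keeps parameters on the CPU and moves each parameterized child to the GPU only for the duration of its forward, while `torch_setting` temporarily switches the default dtype.

```python
import torch
import torch.nn as nn

from mmrazor.implementations.pruning.sparse_gpt.utils import (
    memory_efficient_forward, torch_setting)

model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64))
x = torch.randn(2, 64)

if torch.cuda.is_available():
    with memory_efficient_forward(model, enabled=True,
                                  device=torch.device('cuda:0')):
        # each Linear hops to the GPU for its forward, then back to CPU
        model(x.cuda())

with torch_setting(dtype=torch.float):
    assert torch.get_default_dtype() == torch.float32  # restored on exit
```
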
diff --git a/mmrazor/implementations/quantization/gptq/__init__.py b/mmrazor/implementations/quantization/gptq/__init__.py
new file mode 100644
index 000000000..4981c8014
--- /dev/null
+++ b/mmrazor/implementations/quantization/gptq/__init__.py
@@ -0,0 +1,14 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .compressor import GPTQCompressor
+from .gptq import GPTQMixIn
+from .ops import GPTQConv2d, GPTQLinear, TritonGPTQLinear
+from .quantizer import Quantizer
+
+__all__ = [
+ 'GPTQCompressor',
+ 'GPTQMixIn',
+ 'GPTQConv2d',
+ 'GPTQLinear',
+ 'TritonGPTQLinear',
+ 'Quantizer',
+]
diff --git a/mmrazor/implementations/quantization/gptq/compressor.py b/mmrazor/implementations/quantization/gptq/compressor.py
new file mode 100644
index 000000000..4a5aadd80
--- /dev/null
+++ b/mmrazor/implementations/quantization/gptq/compressor.py
@@ -0,0 +1,146 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Any, Dict, Type
+
+import torch
+import torch.nn as nn
+
+from mmrazor.utils import print_log
+from .ops import GPTQConv2d, GPTQLinear, GPTQMixIn, TritonGPTQLinear
+from .quantizer import Quantizer
+
+
+def replace_with_dynamic_ops(model: nn.Module,
+ dynamicop_map: Dict[Type[nn.Module], Type[Any]],
+ skipped_layers=[],
+ a_qconfig=None,
+ **kwargs):
+ """Replace torch modules with dynamic-ops."""
+
+ def replace_op(model: nn.Module, name: str, module: nn.Module):
+ names = name.split('.')
+ for sub_name in names[:-1]:
+ model = getattr(model, sub_name)
+
+ setattr(model, names[-1], module)
+
+ for name, module in model.named_modules():
+ if type(module) in dynamicop_map and name not in skipped_layers:
+ if isinstance(module, nn.Linear):
+ if a_qconfig:
+ a_fakequant = Quantizer()
+ a_fakequant.configure(**a_qconfig)
+ kwargs.update({'a_fakequant': a_fakequant})
+ new_module = dynamicop_map[type(module)].convert_from(
+ module, **kwargs)
+ else:
+ new_module = dynamicop_map[type(module)].convert_from(module)
+ replace_op(model, name, new_module)
+
+
+def to_static_model(model: nn.Module):
+ """Replace dynamicops with torch modules."""
+ from mmrazor.structures.subnet.fix_subnet import (export_fix_subnet,
+ load_fix_subnet)
+ fix_subnet = export_fix_subnet(model)[0]
+ load_fix_subnet(model, fix_subnet)
+ return model
+
+
+class GPTQCompressor():
+ """The compressor with GPTQ."""
+
+ def __init__(self) -> None:
+ self.model: nn.Module = None
+
+ def prepare(self,
+ model: nn.Module,
+ quant_conv=True,
+ quant_linear=True,
+ use_triton_ops=True,
+ skipped_layers=[],
+ a_qconfig=None,
+ **kwargs) -> None:
+ """Prepare for compressing model."""
+ self.model = model
+ quant_modules: dict = {}
+ if quant_conv:
+ quant_modules[nn.Conv2d] = GPTQConv2d
+ if quant_linear:
+ gptq_linear = TritonGPTQLinear if use_triton_ops else GPTQLinear
+ quant_modules[nn.Linear] = gptq_linear
+ replace_with_dynamic_ops(model, quant_modules, skipped_layers,
+ a_qconfig, **kwargs)
+
+ @classmethod
+ def to_static_model(cls, model):
+ """Convert replaced op with the original torch model."""
+ return to_static_model(model)
+
+ # hessian
+
+ def register_hessian_hooks(self):
+ """Register updating hessian hooks for specified ops."""
+ for module in self.quant_ops:
+ module.register_hessian_hook()
+
+ def remove_hessian_hooks(self):
+ """Remove updating hessian hooks for specified ops."""
+ for module in self.quant_ops:
+ module.remove_hessian_hook()
+
+ def init_hessian(self, device=None):
+ """Init hessian."""
+ for op in self.quant_ops:
+ op.init_hessian(device=device)
+
+ # quant
+ def quant(self,
+ blocksize=128,
+ percdamp=0.01,
+ groupsize=-1,
+ actorder=False,
+ device=torch.device('cuda:0'),
+ **qconfig):
+ """Apply the compression algorithm to the model."""
+ for name, module in self.named_quant_ops:
+ try:
+ original_device = next(module.parameters()).device
+ module: GPTQMixIn = module.to(device)
+ quantizer = Quantizer()
+ quantizer.configure(**qconfig)
+ error = module.quant(
+ quantizer=quantizer,
+ blocksize=blocksize,
+ percdamp=percdamp,
+ groupsize=groupsize,
+ actorder=actorder)
+ print_log(f'quant {name} success \t error = {error}')
+ module.to(original_device)
+ module.free()
+ except Exception as e:
+ print_log(f'quant {name} failed as {e}')
+
+ def quant_with_default_qconfig(self, groupsize=128, device='cpu'):
+ """Apply the compression algorithm to the model with the specified
+ setting."""
+ qconfig = dict(bits=4, perchannel=True, sym=False)
+ self.quant(
+ groupsize=groupsize, actorder=True, device=device, **qconfig)
+
+ # ops
+
+ @property
+ def quant_ops(self):
+ """The ops to be applied the algorithm."""
+ assert self.model is not None
+ for module in self.model.modules():
+ if isinstance(module, GPTQMixIn):
+ yield module
+
+ @property
+ def named_quant_ops(self):
+ """The named ops to be applied the algorithm."""
+ for name, module in self.model.named_modules():
+ if isinstance(module, GPTQMixIn):
+ yield name, module
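
As with the SparseGPT compressor, a minimal usage sketch (not part of the diff; the model and calibration batches are placeholders): prepare the model, collect Hessians on a few calibration batches, then quantize with the default 4-bit config.

```python
import torch
import torch.nn as nn

from mmrazor.implementations.quantization.gptq import GPTQCompressor

model = nn.Sequential(nn.Linear(128, 256), nn.ReLU(), nn.Linear(256, 128))
calib_batches = [torch.randn(4, 128) for _ in range(4)]

compressor = GPTQCompressor()
# use_triton_ops=False keeps plain GPTQLinear, so no Triton/GPU is required
compressor.prepare(model, quant_conv=False, quant_linear=True,
                   use_triton_ops=False)

compressor.init_hessian()
compressor.register_hessian_hooks()
with torch.no_grad():
    for batch in calib_batches:
        model(batch)
compressor.remove_hessian_hooks()

compressor.quant_with_default_qconfig(groupsize=128, device='cpu')
model = GPTQCompressor.to_static_model(model)
```
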
diff --git a/mmrazor/implementations/quantization/gptq/custom_autotune.py b/mmrazor/implementations/quantization/gptq/custom_autotune.py
new file mode 100644
index 000000000..1bc0d7d5f
--- /dev/null
+++ b/mmrazor/implementations/quantization/gptq/custom_autotune.py
@@ -0,0 +1,254 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# https://github.com/fpgaminer/GPTQ-triton
+"""Mostly the same as the autotuner in Triton, but with a few changes like
+using 40 runs instead of 100."""
+
+import builtins
+import math
+import time
+from typing import Dict
+
+try:
+ import triton
+except ImportError:
+ from mmrazor.utils import get_package_placeholder
+ triton = get_package_placeholder('triton >= 2.0.0')
+
+
+class Autotuner(triton.KernelInterface):
+ """Autotuner."""
+
+ def __init__(self,
+ fn,
+ arg_names,
+ configs,
+ key,
+ reset_to_zero,
+ prune_configs_by: Dict = None,
+ nearest_power_of_two: bool = False):
+        '''prune_configs_by: a dict of functions that are used to prune
+            configs, fields:
+            'perf_model': performance model used to predict the running
+            time with different configs; returns the running time
+            'top_k': number of configs to bench
+            'early_config_prune'(optional): a function used to prune
+            num_stages early. It takes configs: List[Config] as its input
+            and returns the pruned configs.
+            'nearest_power_of_two'(optional): whether to round key arguments
+            to the nearest power of two when caching tuning results.'''
+ if not configs:
+ self.configs = [triton.Config({}, num_warps=4, num_stages=2)]
+ else:
+ self.configs = configs
+ self.key_idx = [arg_names.index(k) for k in key]
+ self.nearest_power_of_two = nearest_power_of_two
+ self.cache: Dict = {}
+ # hook to reset all required tensor to zeros before relaunching
+ # a kernel
+ self.hook = lambda args: 0
+ if reset_to_zero is not None:
+ self.reset_idx = [arg_names.index(k) for k in reset_to_zero]
+
+ def _hook(args):
+ for i in self.reset_idx:
+ args[i].zero_()
+
+ self.hook = _hook
+ self.arg_names = arg_names
+ # prune configs
+ if prune_configs_by:
+ perf_model, top_k = prune_configs_by[
+ 'perf_model'], prune_configs_by['top_k']
+ if 'early_config_prune' in prune_configs_by:
+ early_config_prune = prune_configs_by['early_config_prune']
+ else:
+ perf_model, top_k, early_config_prune = None, None, None
+ self.perf_model, self.configs_top_k = perf_model, top_k
+ self.early_config_prune = early_config_prune
+ self.fn = fn
+
+ def _bench(self, *args, config, **meta):
+ """Check for conflicts, i.e. meta-parameters both provided as kwargs
+ and by the autotuner."""
+ conflicts = meta.keys() & config.kwargs.keys()
+ if conflicts:
+ raise ValueError(
+ f"Conflicting meta-parameters: {', '.join(conflicts)}."
+ " Make sure that you don't re-define auto-tuned symbols.")
+ # augment meta-parameters with tunable ones
+ current = dict(meta, **config.kwargs)
+
+ def kernel_call():
+ if config.pre_hook:
+ config.pre_hook(self.nargs)
+ self.hook(args)
+ self.fn.run(
+ *args,
+ num_warps=config.num_warps,
+ num_stages=config.num_stages,
+ **current)
+
+ try:
+            # In testing, using only 40 reps seems to be close enough and
+            # it appears to be what PyTorch uses.
+            # PyTorch also sets fast_flush to True, but I didn't see any
+            # speedup, so I'll leave the default.
+ return triton.testing.do_bench(
+ kernel_call, percentiles=(0.5, 0.2, 0.8), rep=40)
+ except triton.compiler.OutOfResources:
+ return (float('inf'), float('inf'), float('inf'))
+
+ def run(self, *args, **kwargs):
+ """Run."""
+ self.nargs = dict(zip(self.arg_names, args))
+ if len(self.configs) > 1:
+ key = tuple(args[i] for i in self.key_idx)
+
+ # This reduces the amount of autotuning by rounding the keys to
+ # the nearest power of two
+ # In my testing this gives decent results, and greatly reduces
+ # the amount of tuning required
+ if self.nearest_power_of_two:
+ key = tuple([2**int(math.log2(x) + 0.5) for x in key])
+
+ if key not in self.cache:
+ # prune configs
+ pruned_configs = self.prune_configs(kwargs)
+ bench_start = time.time()
+ timings = {
+ config: self._bench(*args, config=config, **kwargs)
+ for config in pruned_configs
+ }
+ bench_end = time.time()
+ self.bench_time = bench_end - bench_start
+ self.cache[key] = builtins.min(timings, key=timings.get)
+ self.hook(args)
+ self.configs_timings = timings
+ config = self.cache[key]
+ else:
+ config = self.configs[0]
+ self.best_config = config
+ if config.pre_hook is not None:
+ config.pre_hook(self.nargs)
+ return self.fn.run(
+ *args,
+ num_warps=config.num_warps,
+ num_stages=config.num_stages,
+ **kwargs,
+ **config.kwargs)
+
+ def prune_configs(self, kwargs):
+ """Prune configs."""
+ pruned_configs = self.configs
+ if self.early_config_prune:
+ pruned_configs = self.early_config_prune(self.configs, self.nargs)
+ if self.perf_model:
+ top_k = self.configs_top_k
+ if isinstance(top_k, float) and top_k <= 1.0:
+ top_k = int(len(self.configs) * top_k)
+ if len(pruned_configs) > top_k:
+ est_timing = {
+ config: self.perf_model(
+ **self.nargs,
+ **kwargs,
+ **config.kwargs,
+ num_stages=config.num_stages,
+ num_warps=config.num_warps)
+ for config in pruned_configs
+ }
+ pruned_configs = sorted(
+ est_timing.keys(), key=lambda x: est_timing[x])[:top_k]
+ return pruned_configs
+
+ def warmup(self, *args, **kwargs):
+ """Warm up."""
+ self.nargs = dict(zip(self.arg_names, args))
+ for config in self.prune_configs(kwargs):
+ self.fn.warmup(
+ *args,
+ num_warps=config.num_warps,
+ num_stages=config.num_stages,
+ **kwargs,
+ **config.kwargs,
+ )
+ self.nargs = None
+
+
+def autotune(configs,
+ key,
+ prune_configs_by=None,
+ reset_to_zero=None,
+ nearest_power_of_two=False):
+ """Decorator for auto-tuning a :code:`triton.jit`'d function.
+
+ .. highlight:: python
+ .. code-block:: python
+ @triton.autotune(configs=[
+ triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4),
+ triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8),
+ ],
+ key=['x_size'] # the two above configs will be evaluated
+ # anytime the value of x_size changes
+ )
+ @triton.jit
+ def kernel(x_ptr, x_size, **META):
+ BLOCK_SIZE = META['BLOCK_SIZE']
+    :note: When all the configurations are evaluated, the kernel will run
+        multiple times. This means that whatever value the kernel updates
+        will be updated multiple times. To avoid this undesired behavior,
+        you can use the `reset_to_zero` argument, which resets the value of
+        the provided tensors to zero before running any configuration.
+ :param configs: a list of :code:`triton.Config` objects
+ :type configs: list[triton.Config]
+ :param key: a list of argument names whose change in value will trigger
+ the evaluation of all provided configs.
+ :type key: list[str]
+    :param prune_configs_by: a dict of functions that are used to prune
+        configs, fields:
+        'perf_model': performance model used to predict the running time
+        with different configs; returns the running time
+        'top_k': number of configs to bench
+        'early_config_prune'(optional): a function used to do early pruning
+        (e.g. of num_stages). It takes configs: List[Config] as its input
+        and returns the pruned configs.
+ :param reset_to_zero: a list of argument names whose value will be reset
+ to zero before evaluating any configs.
+ :type reset_to_zero: list[str]
+ """
+
+ def decorator(fn):
+ return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero,
+ prune_configs_by, nearest_power_of_two)
+
+ return decorator
+
+
+def matmul248_kernel_config_pruner(configs, nargs):
+ """The main purpose of this function is to shrink BLOCK_SIZE_* when the
+ corresponding dimension is smaller."""
+ m = max(2**int(math.ceil(math.log2(nargs['M']))), 16)
+ n = max(2**int(math.ceil(math.log2(nargs['N']))), 16)
+ k = max(2**int(math.ceil(math.log2(nargs['K']))), 16)
+
+ used = set()
+ for config in configs:
+ block_size_m = min(m, config.kwargs['BLOCK_SIZE_M'])
+ block_size_n = min(n, config.kwargs['BLOCK_SIZE_N'])
+ block_size_k = min(k, config.kwargs['BLOCK_SIZE_K'])
+ group_size_m = config.kwargs['GROUP_SIZE_M']
+
+ if (block_size_m, block_size_n, block_size_k, group_size_m,
+ config.num_stages, config.num_warps) in used:
+ continue
+
+ used.add((block_size_m, block_size_n, block_size_k, group_size_m,
+ config.num_stages, config.num_warps))
+ yield triton.Config(
+ {
+ 'BLOCK_SIZE_M': block_size_m,
+ 'BLOCK_SIZE_N': block_size_n,
+ 'BLOCK_SIZE_K': block_size_k,
+ 'GROUP_SIZE_M': group_size_m
+ },
+ num_stages=config.num_stages,
+ num_warps=config.num_warps)
diff --git a/mmrazor/implementations/quantization/gptq/gptq.py b/mmrazor/implementations/quantization/gptq/gptq.py
new file mode 100644
index 000000000..84cfd3a4b
--- /dev/null
+++ b/mmrazor/implementations/quantization/gptq/gptq.py
@@ -0,0 +1,318 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import sys
+
+if sys.version_info < (3, 8):
+ from typing_extensions import Protocol
+else:
+ from typing import Protocol
+
+import numpy as np
+import torch
+import torch.distributed as dist
+
+from mmrazor.implementations.pruning.sparse_gpt.utils import torch_setting
+
+
+class ModuleProtocol(Protocol):
+ """Custom module protocol for algorithm mixin."""
+ weight: torch.Tensor
+
+ def forward(self, x):
+ """The abstract method."""
+ pass
+
+ def register_forward_hook(self, hook):
+ """The abstract method."""
+ pass
+
+ def register_backward_hook(self, hook):
+ """The abstract method."""
+ pass
+
+ def register_forward_pre_hook(self, hook):
+ """The abstract method."""
+ pass
+
+ def register_buffer(self, name, tensor):
+ """The abstract method."""
+ pass
+
+
+class GPTQMixIn(ModuleProtocol):
+ """The core algorithm implementation for GPTQ."""
+
+ def _gptq_mix_in_init(self):
+ """Init mixin."""
+ self.gptq_handles = []
+ self.rows = self.weight_matrix.shape[0]
+ self.columns = self.weight_matrix.shape[1]
+
+ self._hessian: torch.Tensor = None
+ self.hessian_batch = 0
+
+ # weight and input adaptive
+
+ @property
+ def weight_matrix(self):
+ """Return weight with shape (out in)"""
+ return self.weight.flatten(1) # out in
+
+ @weight_matrix.setter
+ def weight_matrix(self, value: torch.Tensor):
+ """Set weight."""
+ with torch.no_grad():
+ value = value.reshape(self.weight.shape).to(self.weight.device).to(
+ self.weight.dtype)
+ self.weight.data.copy_(value)
+
+ def format_input(self, input: torch.Tensor):
+ """Return input with shape (B N C)"""
+ if len(input.shape) == 2: # N C
+ input = input.unsqueeze(0) # 1 N C
+ return input
+
+ # compute hessian
+
+ @property
+ def hessian(self):
+ """hessian always return float."""
+ if dist.is_initialized():
+ if dist.get_rank() == 0:
+ assert self._hessian is not None, 'hessian is not initialized.'
+ hessian = self._hessian.to(self.weight_matrix.device)
+ else:
+ hessian = torch.zeros(
+ self.columns,
+ self.columns,
+ device=self.weight_matrix.device)
+ dist.broadcast(hessian, 0)
+ return hessian
+ else:
+ return self._hessian
+
+ @hessian.setter
+ def hessian(self, value: torch.Tensor):
+ """Set hessian."""
+ with torch.no_grad():
+ if dist.is_initialized():
+ if dist.get_rank() == 0:
+ assert self._hessian is not None, 'hessian is not initialized.' # noqa
+ self._hessian.data.copy_(
+ value.data.to(self._hessian.device))
+ else:
+ self._hessian = None
+ else:
+ self._hessian.data.copy_(value.data.to(self._hessian.device))
+
+ @torch.no_grad()
+ def update_hessian(self, input: torch.Tensor):
+ """Update hessian."""
+ input = self.format_input(input).float()
+ H_save = self.hessian
+ H_save = H_save.to(input.device)
+
+ assert len(input.shape) == 3
+ B = input.shape[0] # B N C
+ input = input.transpose(0, -1).flatten(1) # C D
+
+ H = input @ input.T * 2 # C C
+
+ if dist.is_initialized():
+ dist.all_reduce(H)
+ B *= dist.get_world_size()
+ H_save = (H_save * self.hessian_batch + H) / (self.hessian_batch + B)
+ self.hessian = H_save
+ self.hessian_batch = self.hessian_batch + B
+
+ def register_hessian_hook(self):
+ """Register updating hessian hook."""
+
+ @torch.no_grad()
+ def forward_pre_hook(module: Protocol, input: tuple):
+ assert len(input) == 1
+ self.update_hessian(input[0])
+
+ handle = self.register_forward_pre_hook(forward_pre_hook)
+ self.gptq_handles.append(handle)
+
+ def remove_hessian_hook(self):
+ """Remove updating hessian hook."""
+ for h in self.gptq_handles:
+ h.remove()
+
+ def init_hessian(self, device=None):
+ """Init hessian."""
+ if dist.is_initialized():
+ if dist.get_rank() == 0:
+ self._hessian = torch.zeros([self.columns, self.columns],
+ device=device,
+ dtype=torch.float)
+ else:
+ self._hessian = None
+ else:
+ self._hessian = torch.zeros([self.columns, self.columns],
+ device=device,
+ dtype=torch.float)
+
+ def pack(self, scales, zeros, g_idx=None):
+ """Pack and update qparams with groupsize_idx."""
+ self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx
+
+ scales = scales.t().contiguous()
+ zeros = zeros.t().contiguous()
+ scale_zeros = zeros * scales
+ self.scales = scales.clone().half()
+ if self.bias is not None:
+            self.bias.data = self.bias.data.half()
+
+ intweight = []
+ for idx in range(self.in_features):
+ intweight.append(
+ torch.round(
+ (self.weight.data[:, idx] + scale_zeros[self.g_idx[idx]]) /
+ self.scales[self.g_idx[idx]]).to(torch.int)[:, None])
+ intweight = torch.cat(intweight, dim=1)
+ intweight = intweight.t().contiguous()
+ intweight = intweight.cpu().numpy().astype(np.uint32)
+ qweight = np.zeros(
+ (intweight.shape[0] // 32 * self.bits, intweight.shape[1]),
+ dtype=np.uint32)
+ i = 0
+ row = 0
+ while row < qweight.shape[0]:
+ if self.bits in [2, 4, 8]:
+ for j in range(i, i + (32 // self.bits)):
+ qweight[row] |= intweight[j] << (self.bits * (j - i))
+ i += 32 // self.bits
+ row += 1
+ else:
+ raise NotImplementedError('Only 2,4,8 bits are supported.')
+
+ qweight = qweight.astype(np.int32)
+ self.qweight = torch.from_numpy(qweight).to(self.weight.device)
+
+ zeros -= 1
+ zeros = zeros.cpu().numpy().astype(np.uint32)
+ qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits),
+ dtype=np.uint32)
+ i = 0
+ col = 0
+ while col < qzeros.shape[1]:
+ if self.bits in [2, 4, 8]:
+ for j in range(i, i + (32 // self.bits)):
+ qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
+ i += 32 // self.bits
+ col += 1
+ else:
+ raise NotImplementedError('Only 2,4,8 bits are supported.')
+
+ qzeros = qzeros.astype(np.int32)
+ self.qzeros = torch.from_numpy(qzeros).to(self.weight.device)
+
+ @torch.no_grad()
+ def quant(self,
+ quantizer,
+ blocksize=128,
+ percdamp=0.01,
+ groupsize=-1,
+ actorder=False):
+ """The implementation for GPTQ."""
+ with torch_setting(dtype=torch.float):
+ assert self.hessian is not None
+ W: torch.Tensor = self.weight_matrix.float() # out in
+
+ if not quantizer.ready():
+ quantizer.find_params(W, weight=True)
+
+ H = self.hessian.float().to(W.device)
+ dead = torch.diag(H) == 0
+ H[dead, dead] = 1
+ W[:, dead] = 0
+
+ if actorder:
+ perm = torch.argsort(torch.diag(H), descending=True)
+ W = W[:, perm]
+ H = H[perm][:, perm]
+
+ Losses = torch.zeros_like(W)
+ Q = torch.zeros_like(W)
+
+ damp = percdamp * torch.mean(torch.diag(H))
+ diag = torch.arange(self.columns, device=W.device)
+ H[diag, diag] += damp
+ H = torch.linalg.cholesky(H)
+ H = torch.cholesky_inverse(H)
+ H = torch.linalg.cholesky(H, upper=True)
+ Hinv = H
+
+ g_idx = []
+ scale = []
+ zero = []
+ now_idx = 1
+
+ for i1 in range(0, self.columns, blocksize):
+ i2 = min(i1 + blocksize, self.columns)
+ count = i2 - i1
+
+ W1 = W[:, i1:i2].clone()
+ Q1 = torch.zeros_like(W1)
+ Err1 = torch.zeros_like(W1)
+ Losses1 = torch.zeros_like(W1)
+ Hinv1 = Hinv[i1:i2, i1:i2]
+
+ for i in range(count):
+ w = W1[:, i]
+ d = Hinv1[i, i]
+
+ if groupsize != -1:
+ if (i1 + i) % groupsize == 0:
+ quantizer.find_params(
+ W[:, (i1 + i):(i1 + i + groupsize)],
+ weight=True)
+
+ if ((i1 + i) // groupsize) - now_idx == -1:
+ scale.append(quantizer.scale)
+ zero.append(quantizer.zero)
+ now_idx += 1
+
+ q = quantizer.quantize(w.unsqueeze(1)).flatten()
+ Q1[:, i] = q
+ Losses1[:, i] = (w - q)**2 / d**2
+
+ err1 = (w - q) / d
+ W1[:,
+ i:] -= err1.unsqueeze(1).matmul(Hinv1[i,
+ i:].unsqueeze(0))
+ Err1[:, i] = err1
+
+ Q[:, i1:i2] = Q1
+ Losses[:, i1:i2] = Losses1 / 2
+
+ W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])
+
+ torch.cuda.synchronize()
+ error = torch.sum(Losses).item()
+
+ groupsize = groupsize if groupsize != -1 else self.columns
+ g_idx = [i // groupsize for i in range(self.columns)]
+ g_idx = torch.tensor(g_idx, dtype=torch.int32, device=Q.device)
+ if actorder:
+ invperm = torch.argsort(perm)
+ Q = Q[:, invperm]
+ g_idx = g_idx[invperm]
+
+ if scale == []:
+ scale.append(quantizer.scale)
+ zero.append(quantizer.zero)
+ scale = torch.cat(scale, dim=1)
+ zero = torch.cat(zero, dim=1)
+ self.weight_matrix = Q.data.to(self.weight_matrix.dtype)
+ if self.is_custom_kernel:
+ self.pack(scale, zero, g_idx)
+ del self.weight
+ return error
+
+ def free(self):
+ """Free some cache and memory."""
+ self._hessian = None
+ torch.cuda.empty_cache()
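
The `pack` method above stores the quantized weights as 32-bit words, placing `32 // bits` values per word. A small pure-Python illustration of that packing scheme (not taken from the PR):

```python
bits = 4
vals = [3, 0, 15, 7, 1, 2, 4, 8]   # eight 4-bit values -> one 32-bit word

word = 0
for j, v in enumerate(vals):
    word |= v << (bits * j)        # the j-th value occupies bits [4j, 4j+4)

# unpacking reverses the shifts and masks out 4 bits at a time
unpacked = [(word >> (bits * j)) & 0xF for j in range(32 // bits)]
assert unpacked == vals
```
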
diff --git a/mmrazor/implementations/quantization/gptq/ops.py b/mmrazor/implementations/quantization/gptq/ops.py
new file mode 100644
index 000000000..b8c139412
--- /dev/null
+++ b/mmrazor/implementations/quantization/gptq/ops.py
@@ -0,0 +1,566 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+from torch.cuda.amp import custom_bwd, custom_fwd
+
+from mmrazor.models.architectures.dynamic_ops import (DynamicConv2d,
+ DynamicLinear)
+from .gptq import GPTQMixIn
+
+try:
+ import triton
+ import triton.language as tl
+
+ from . import custom_autotune
+
+ # code based https://github.com/fpgaminer/GPTQ-triton
+ @custom_autotune.autotune(
+ configs=[
+ triton.Config(
+ {
+ 'BLOCK_SIZE_M': 64,
+ 'BLOCK_SIZE_N': 256,
+ 'BLOCK_SIZE_K': 32,
+ 'GROUP_SIZE_M': 8
+ },
+ num_stages=4,
+ num_warps=4),
+ triton.Config(
+ {
+ 'BLOCK_SIZE_M': 128,
+ 'BLOCK_SIZE_N': 128,
+ 'BLOCK_SIZE_K': 32,
+ 'GROUP_SIZE_M': 8
+ },
+ num_stages=4,
+ num_warps=4),
+ triton.Config(
+ {
+ 'BLOCK_SIZE_M': 64,
+ 'BLOCK_SIZE_N': 128,
+ 'BLOCK_SIZE_K': 32,
+ 'GROUP_SIZE_M': 8
+ },
+ num_stages=4,
+ num_warps=4),
+ triton.Config(
+ {
+ 'BLOCK_SIZE_M': 128,
+ 'BLOCK_SIZE_N': 32,
+ 'BLOCK_SIZE_K': 32,
+ 'GROUP_SIZE_M': 8
+ },
+ num_stages=4,
+ num_warps=4),
+ triton.Config(
+ {
+ 'BLOCK_SIZE_M': 64,
+ 'BLOCK_SIZE_N': 64,
+ 'BLOCK_SIZE_K': 32,
+ 'GROUP_SIZE_M': 8
+ },
+ num_stages=4,
+ num_warps=4),
+ triton.Config(
+ {
+ 'BLOCK_SIZE_M': 64,
+ 'BLOCK_SIZE_N': 128,
+ 'BLOCK_SIZE_K': 32,
+ 'GROUP_SIZE_M': 8
+ },
+ num_stages=2,
+ num_warps=8),
+ triton.Config(
+ {
+ 'BLOCK_SIZE_M': 64,
+ 'BLOCK_SIZE_N': 64,
+ 'BLOCK_SIZE_K': 64,
+ 'GROUP_SIZE_M': 8
+ },
+ num_stages=3,
+ num_warps=8),
+ triton.Config(
+ {
+ 'BLOCK_SIZE_M': 32,
+ 'BLOCK_SIZE_N': 32,
+ 'BLOCK_SIZE_K': 128,
+ 'GROUP_SIZE_M': 8
+ },
+ num_stages=2,
+ num_warps=4),
+ ],
+ key=['M', 'N', 'K'],
+ nearest_power_of_two=True,
+ prune_configs_by={
+ 'early_config_prune':
+ custom_autotune.matmul248_kernel_config_pruner,
+ 'perf_model': None,
+ 'top_k': None,
+ },
+ )
+ @triton.jit
+ def matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, g_ptr, M,
+ N, K, bits, maxq, stride_am, stride_ak, stride_bk,
+ stride_bn, stride_cm, stride_cn, stride_scales,
+ stride_zeros, BLOCK_SIZE_M: tl.constexpr,
+ BLOCK_SIZE_N: tl.constexpr,
+ BLOCK_SIZE_K: tl.constexpr,
+ GROUP_SIZE_M: tl.constexpr):
+ """
+ Compute the matrix multiplication C = A x B.
+ A is of shape (M, K) float16
+ B is of shape (K//8, N) int32
+ C is of shape (M, N) float16
+ scales is of shape (G, N) float16
+ zeros is of shape (G, N) float16
+ g_ptr is of shape (K) int32
+ """
+ infearure_per_bits = 32 // bits
+
+ pid = tl.program_id(axis=0)
+ num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+ num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+ num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
+ num_pid_in_group = GROUP_SIZE_M * num_pid_n
+ group_id = pid // num_pid_in_group
+ first_pid_m = group_id * GROUP_SIZE_M
+ group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+ pid_m = first_pid_m + (pid % group_size_m)
+ pid_n = (pid % num_pid_in_group) // group_size_m
+
+ offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+ offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
+ a_ptrs = a_ptr + (
+ offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak
+ ) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
+ a_mask = (offs_am[:, None] < M)
+ # b_ptrs is set up such that it repeats elements along the K axis 8
+ # times
+ b_ptrs = b_ptr + (
+ (offs_k[:, None] // infearure_per_bits) * stride_bk +
+ offs_bn[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N)
+ g_ptrs = g_ptr + offs_k
+ # shifter is used to extract the N bits of each element in the 32-bit
+ # word from B
+ scales_ptrs = scales_ptr + offs_bn[None, :]
+ zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits)
+
+ shifter = (offs_k % infearure_per_bits) * bits
+ zeros_shifter = (offs_bn % infearure_per_bits) * bits
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+
+ for k in range(0, num_pid_k):
+ g_idx = tl.load(g_ptrs)
+
+ # Fetch scales and zeros; these are per-outfeature and thus reused
+ # in the inner loop
+ scales = tl.load(scales_ptrs + g_idx[:, None] *
+ stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
+ zeros = tl.load(
+ zeros_ptrs +
+ g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
+
+ zeros = (zeros >> zeros_shifter[None, :]) & maxq
+ zeros = (zeros + 1)
+
+ a = tl.load(
+ a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
+ b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
+
+ # Now we need to unpack b (which is N-bit values) into 32-bit
+ # values
+ b = (b >> shifter[:, None]) & maxq # Extract the N-bit values
+ b = (b - zeros) * scales # Scale and shift
+
+ accumulator += tl.dot(a, b)
+ a_ptrs += BLOCK_SIZE_K
+ b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk
+ g_ptrs += BLOCK_SIZE_K
+
+ c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[
+ None, :]
+ c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)
+ tl.store(c_ptrs, accumulator, mask=c_mask)
+
+ @custom_autotune.autotune(
+ configs=[
+ triton.Config(
+ {
+ 'BLOCK_SIZE_M': 64,
+ 'BLOCK_SIZE_N': 32,
+ 'BLOCK_SIZE_K': 256,
+ 'GROUP_SIZE_M': 8
+ },
+ num_stages=4,
+ num_warps=4),
+ triton.Config(
+ {
+ 'BLOCK_SIZE_M': 128,
+ 'BLOCK_SIZE_N': 32,
+ 'BLOCK_SIZE_K': 128,
+ 'GROUP_SIZE_M': 8
+ },
+ num_stages=4,
+ num_warps=4),
+ triton.Config(
+ {
+ 'BLOCK_SIZE_M': 64,
+ 'BLOCK_SIZE_N': 32,
+ 'BLOCK_SIZE_K': 128,
+ 'GROUP_SIZE_M': 8
+ },
+ num_stages=4,
+ num_warps=4),
+ triton.Config(
+ {
+ 'BLOCK_SIZE_M': 128,
+ 'BLOCK_SIZE_N': 32,
+ 'BLOCK_SIZE_K': 32,
+ 'GROUP_SIZE_M': 8
+ },
+ num_stages=4,
+ num_warps=4),
+ triton.Config(
+ {
+ 'BLOCK_SIZE_M': 64,
+ 'BLOCK_SIZE_N': 32,
+ 'BLOCK_SIZE_K': 64,
+ 'GROUP_SIZE_M': 8
+ },
+ num_stages=4,
+ num_warps=4),
+ triton.Config(
+ {
+ 'BLOCK_SIZE_M': 64,
+ 'BLOCK_SIZE_N': 32,
+ 'BLOCK_SIZE_K': 128,
+ 'GROUP_SIZE_M': 8
+ },
+ num_stages=2,
+ num_warps=8),
+ triton.Config(
+ {
+ 'BLOCK_SIZE_M': 64,
+ 'BLOCK_SIZE_N': 64,
+ 'BLOCK_SIZE_K': 64,
+ 'GROUP_SIZE_M': 8
+ },
+ num_stages=3,
+ num_warps=8),
+ triton.Config(
+ {
+ 'BLOCK_SIZE_M': 32,
+ 'BLOCK_SIZE_N': 128,
+ 'BLOCK_SIZE_K': 32,
+ 'GROUP_SIZE_M': 8
+ },
+ num_stages=2,
+ num_warps=4),
+ ],
+ key=['M', 'N', 'K'],
+ nearest_power_of_two=True)
+ @triton.jit
+ def transpose_matmul_248_kernel(
+ a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, g_ptr, M, N, K, bits,
+ maxq, stride_am, stride_ak, stride_bk, stride_bn, stride_cm,
+ stride_cn, stride_scales, stride_zeros, BLOCK_SIZE_M: tl.constexpr,
+ BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
+ GROUP_SIZE_M: tl.constexpr):
+ """
+ Compute the matrix multiplication C = A x B.
+ A is of shape (M, N) float16
+ B is of shape (K//8, N) int32
+ C is of shape (M, K) float16
+ scales is of shape (G, N) float16
+ zeros is of shape (G, N) float16
+ g_ptr is of shape (K) int32
+ """
+ infearure_per_bits = 32 // bits
+
+ pid = tl.program_id(axis=0)
+ num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+ num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
+ num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+ num_pid_in_group = GROUP_SIZE_M * num_pid_k
+ group_id = pid // num_pid_in_group
+ first_pid_m = group_id * GROUP_SIZE_M
+ group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+ pid_m = first_pid_m + (pid % group_size_m)
+ pid_k = (pid % num_pid_in_group) // group_size_m
+
+ offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+ offs_bk = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
+ offs_n = tl.arange(0, BLOCK_SIZE_N)
+ a_ptrs = a_ptr + (
+ offs_am[:, None] * stride_am + offs_n[None, :] * stride_ak
+ ) # (BLOCK_SIZE_M, BLOCK_SIZE_N)
+ a_mask = (offs_am[:, None] < M)
+ # b_ptrs is set up such that it repeats elements along the K axis 8
+ # times
+ b_ptrs = b_ptr + (
+ (offs_bk[:, None] // infearure_per_bits) * stride_bk +
+ offs_n[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N)
+ g_ptrs = g_ptr + offs_bk
+ g_idx = tl.load(g_ptrs)
+
+ # shifter is used to extract the N bits of each element in the 32-bit
+ # word from B
+ scales_ptrs = scales_ptr + offs_n[
+ None, :] + g_idx[:, None] * stride_scales
+ zeros_ptrs = zeros_ptr + (offs_n[None, :] // infearure_per_bits
+ ) + g_idx[:, None] * stride_zeros
+
+ shifter = (offs_bk % infearure_per_bits) * bits
+ zeros_shifter = (offs_n % infearure_per_bits) * bits
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)
+
+ for n in range(0, num_pid_n):
+ # Fetch scales and zeros; these are per-outfeature and thus reused
+ # in the inner loop
+ scales = tl.load(scales_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
+ zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
+
+ zeros = (zeros >> zeros_shifter[None, :]) & maxq
+ zeros = (zeros + 1)
+
+ a = tl.load(
+ a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_N)
+ b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
+
+ # Now we need to unpack b (which is N-bit values) into 32-bit
+ # values
+ b = (b >> shifter[:, None]) & maxq # Extract the N-bit values
+ b = (b - zeros) * scales # Scale and shift
+ b = tl.trans(b)
+
+ accumulator += tl.dot(a, b)
+ a_ptrs += BLOCK_SIZE_N
+ b_ptrs += BLOCK_SIZE_N
+ scales_ptrs += BLOCK_SIZE_N
+ zeros_ptrs += (BLOCK_SIZE_N // infearure_per_bits)
+
+ c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bk[
+ None, :]
+ c_mask = (offs_am[:, None] < M) & (offs_bk[None, :] < K)
+ tl.store(c_ptrs, accumulator, mask=c_mask)
+except: # noqa: E722
+ print('triton not installed.')
+
+
+def matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):
+ """matmul248 function with matmul_248_kernel."""
+ with torch.cuda.device(input.device):
+ output = torch.empty((input.shape[0], qweight.shape[1]),
+ device=input.device,
+ dtype=torch.float16)
+ grid = lambda META: ( # noqa: E731
+ triton.cdiv( # noqa: E731
+ input.shape[0], META['BLOCK_SIZE_M']) * triton. # noqa: E731
+ cdiv( # noqa: E731
+ qweight.shape[1], META['BLOCK_SIZE_N']), ) # noqa: E731
+ matmul_248_kernel[grid](input, qweight, output, scales, qzeros, g_idx,
+ input.shape[0], qweight.shape[1],
+ input.shape[1], bits, maxq, input.stride(0),
+ input.stride(1), qweight.stride(0),
+ qweight.stride(1), output.stride(0),
+ output.stride(1), scales.stride(0),
+ qzeros.stride(0))
+ return output
+
+
+def transpose_matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):
+ """transpose_matmul248 function with transpose_matmul_248_kernel."""
+ with torch.cuda.device(input.device):
+ output_dim = (qweight.shape[0] * 32) // bits
+ output = torch.empty((input.shape[0], output_dim),
+ device=input.device,
+ dtype=torch.float16)
+ grid = lambda META: ( # noqa: E731
+ triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) # noqa: E731
+ * triton.cdiv(output_dim, META['BLOCK_SIZE_K']), ) # noqa: E731
+ transpose_matmul_248_kernel[grid](input, qweight, output, scales,
+ qzeros, g_idx, input.shape[0],
+ qweight.shape[1], output_dim,
+ bits, maxq, input.stride(0),
+ input.stride(1), qweight.stride(0),
+ qweight.stride(1), output.stride(0),
+ output.stride(1), scales.stride(0),
+ qzeros.stride(0))
+ return output
+
+
+class QuantLinearFunction(torch.autograd.Function):
+ """Custom QuantLinearFunction."""
+
+ @staticmethod
+ @custom_fwd(cast_inputs=torch.float16)
+ def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq):
+ """Custom forward."""
+ output = matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq)
+ ctx.save_for_backward(qweight, scales, qzeros, g_idx)
+ ctx.bits, ctx.maxq = bits, maxq
+ return output
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, grad_output):
+ """Custom backward."""
+ qweight, scales, qzeros, g_idx = ctx.saved_tensors
+ bits, maxq = ctx.bits, ctx.maxq
+ grad_input = None
+
+ if ctx.needs_input_grad[0]:
+ grad_input = transpose_matmul248(grad_output, qweight, scales,
+ qzeros, g_idx, bits, maxq)
+ return grad_input, None, None, None, None, None, None
+
+
+class TritonGPTQLinear(nn.Module, GPTQMixIn):
+ """Custom Linear for GPTQ with custom triton kernel."""
+
+ def __init__(self, bits, groupsize, weight, in_features, out_features,
+ bias):
+ super().__init__()
+ if bits not in [2, 4, 8]:
+ raise NotImplementedError('Only 2,4,8 bits are supported.')
+ self.weight = weight
+ self.bias = bias
+
+ self.in_features = in_features
+ self.out_features = out_features
+ self.bits = bits
+ self.maxq = 2**self.bits - 1
+ self.groupsize = groupsize if groupsize != -1 else in_features
+
+ self.register_buffer(
+ 'qweight',
+ torch.zeros((in_features // 32 * self.bits, out_features),
+ dtype=torch.int32))
+ self.register_buffer(
+ 'qzeros',
+ torch.zeros((math.ceil(
+ in_features / self.groupsize), out_features // 32 * self.bits),
+ dtype=torch.int32))
+ self.register_buffer(
+ 'scales',
+ torch.zeros(
+ (math.ceil(in_features / self.groupsize), out_features),
+ dtype=torch.float16))
+ self.register_buffer(
+ 'g_idx',
+ torch.tensor([i // self.groupsize for i in range(in_features)],
+ dtype=torch.int32))
+
+ self._gptq_mix_in_init()
+
+ @property
+ def is_custom_kernel(self):
+ """Whether use custom kernel."""
+ return True
+
+ @classmethod
+ def convert_from(cls, module: nn.Linear, bits, groupsize):
+ """Convert to cls from torch's module."""
+ new_module = cls(
+ bits,
+ groupsize,
+ weight=module.weight,
+ in_features=module.in_features,
+ out_features=module.out_features,
+ bias=module.bias)
+
+ return new_module
+
+ def forward(self, x):
+ """Custom forward."""
+ if torch.all(self.qweight == 0):
+ out = F.linear(x, self.weight, self.bias)
+ else:
+ out_shape = x.shape[:-1] + (self.out_features, )
+ out = QuantLinearFunction.apply(
+ x.reshape(-1, x.shape[-1]), self.qweight, self.scales,
+ self.qzeros, self.g_idx, self.bits, self.maxq)
+ out = out + self.bias if self.bias is not None else out
+ out = out.reshape(out_shape)
+ return out
+
+
+class GPTQLinear(DynamicLinear, GPTQMixIn):
+ """Custom Linear for GPTQ without custom triton kernel."""
+
+ def __init__(self, a_fakequant=None, *args, **kwargs) -> None:
+ super().__init__(*args, **kwargs)
+ self._gptq_mix_in_init()
+ self.a_fakequant = a_fakequant
+ self.fix_qparams = False
+
+ @property
+ def is_custom_kernel(self):
+ """Whether use custom kernel."""
+ return False
+
+ @classmethod
+ def convert_from(cls,
+ module: nn.Linear,
+ a_fakequant=None) -> 'DynamicLinear':
+ """Convert to cls from torch's module."""
+ new_module = cls(
+ a_fakequant=a_fakequant,
+ in_features=module.in_features,
+ out_features=module.out_features,
+ bias=True if module.bias is not None else False)
+ new_module.load_state_dict(module.state_dict(), strict=False)
+
+ dtype = next(module.parameters()).dtype
+ new_module = new_module.to(dtype)
+
+ return new_module
+
+ def forward(self, input: Tensor) -> Tensor:
+ """Custom forward."""
+ if self.a_fakequant:
+ dtype = self.weight.dtype
+ if not self.fix_qparams:
+ self.a_fakequant.find_params(input)
+ input = self.a_fakequant.quantize(input).to(dtype)
+ return super().forward(input)
+
+
+class GPTQConv2d(DynamicConv2d, GPTQMixIn):
+ """Custom Conv2d for GPTQ without custom triton kernel."""
+
+ def __init__(self, *args, **kwargs) -> None:
+ super().__init__(*args, **kwargs)
+ self._gptq_mix_in_init()
+
+ @property
+ def is_custom_kernel(self):
+ """Whether use custom kernel."""
+ return False
+
+ @classmethod
+ def convert_from(cls, module: nn.Conv2d) -> 'DynamicConv2d':
+ """Convert to cls from torch's module."""
+ new_module = super().convert_from(module)
+ new_module.load_state_dict(module.state_dict(), strict=False)
+
+ dtype = next(module.parameters()).dtype
+ new_module = new_module.to(dtype)
+
+ return new_module
+
+ def format_input(self, input: torch.Tensor):
+ """Format input shape."""
+ # input B C H W
+ input = F.unfold(
+ input, self.kernel_size, padding=self.padding,
+ stride=self.stride) # B C D
+ return input.transpose(-1, -2)
diff --git a/mmrazor/implementations/quantization/gptq/quantizer.py b/mmrazor/implementations/quantization/gptq/quantizer.py
new file mode 100644
index 000000000..0db2fb998
--- /dev/null
+++ b/mmrazor/implementations/quantization/gptq/quantizer.py
@@ -0,0 +1,144 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+
+
+class Quantizer(nn.Module):
+ """Quantizer for some basic quantization functions."""
+
+ def __init__(self, shape=1):
+ super(Quantizer, self).__init__()
+ self.register_buffer('maxq', torch.tensor(0))
+ self.register_buffer('scale', torch.zeros(shape))
+ self.register_buffer('zero', torch.zeros(shape))
+
+ def configure(self,
+ bits,
+ perchannel=False,
+ sym=True,
+ mse=False,
+ norm=2.4,
+ grid=100,
+ maxshrink=.8,
+ trits=False):
+ """Configure qconfig."""
+
+ self.maxq = torch.tensor(2**bits - 1)
+ self.perchannel = perchannel
+ self.sym = sym
+ self.mse = mse
+ self.norm = norm
+ self.grid = grid
+ self.maxshrink = maxshrink
+ if trits:
+ self.maxq = torch.tensor(-1)
+ self.scale = torch.zeros_like(self.scale)
+
+ def _quantize(self, x, scale, zero, maxq):
+ """Fakequant."""
+ if maxq < 0:
+ return (x > scale / 2).float() * scale + (x <
+ zero / 2).float() * zero
+ q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)
+ return scale * (q - zero)
+
+ def find_params(self, x, weight=False):
+ """Observe the specified data and calculate the qparams."""
+ dev = x.device
+ self.maxq = self.maxq.to(dev)
+
+ shape = x.shape
+ if self.perchannel:
+ if weight:
+ x = x.flatten(1)
+ else:
+ if len(shape) == 4:
+ x = x.permute([1, 0, 2, 3])
+ x = x.flatten(1)
+ if len(shape) == 3:
+ x = x.reshape((-1, shape[-1])).t()
+ if len(shape) == 2:
+ x = x.t()
+ else:
+ x = x.flatten().unsqueeze(0)
+
+ tmp = torch.zeros(x.shape[0], device=dev)
+ xmin = torch.minimum(x.min(1)[0], tmp)
+ xmax = torch.maximum(x.max(1)[0], tmp)
+
+ if self.sym:
+ xmax = torch.maximum(torch.abs(xmin), xmax)
+ tmp = xmin < 0
+ if torch.any(tmp):
+ xmin[tmp] = -xmax[tmp]
+ tmp = (xmin == 0) & (xmax == 0)
+ xmin[tmp] = -1
+ xmax[tmp] = +1
+
+ if self.maxq < 0:
+ self.scale = xmax
+ self.zero = xmin
+ else:
+ self.scale = (xmax - xmin) / self.maxq
+ if self.sym:
+ self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2)
+ else:
+ self.zero = torch.round(-xmin / self.scale)
+
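+        # optional MSE refinement: shrink (xmin, xmax) by a factor p and keep
+        # the qparams giving the lowest quantization error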
+ if self.mse:
+ best = torch.full([x.shape[0]], float('inf'), device=dev)
+ for i in range(int(self.maxshrink * self.grid)):
+ p = 1 - i / self.grid
+ xmin1 = p * xmin
+ xmax1 = p * xmax
+ scale1 = (xmax1 - xmin1) / self.maxq
+ zero1 = torch.round(-xmin1 /
+ scale1) if not self.sym else self.zero
+ q = self._quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1),
+ self.maxq)
+ q -= x
+ q.abs_()
+ q.pow_(self.norm)
+ err = torch.sum(q, 1)
+ tmp = err < best
+ if torch.any(tmp):
+ best[tmp] = err[tmp]
+ self.scale[tmp] = scale1[tmp]
+ self.zero[tmp] = zero1[tmp]
+ if not self.perchannel:
+ if weight:
+ tmp = shape[0]
+ else:
+ tmp = shape[1] if len(shape) != 3 else shape[2]
+ self.scale = self.scale.repeat(tmp)
+ self.zero = self.zero.repeat(tmp)
+
+ if weight:
+ shape = [-1] + [1] * (len(shape) - 1)
+ self.scale = self.scale.reshape(shape)
+ self.zero = self.zero.reshape(shape)
+ return
+ if len(shape) == 4:
+ self.scale = self.scale.reshape((1, -1, 1, 1))
+ self.zero = self.zero.reshape((1, -1, 1, 1))
+ if len(shape) == 3:
+ self.scale = self.scale.reshape((1, 1, -1))
+ self.zero = self.zero.reshape((1, 1, -1))
+ if len(shape) == 2:
+ self.scale = self.scale.unsqueeze(0)
+ self.zero = self.zero.unsqueeze(0)
+
+ def quantize(self, x):
+ """Fakequant."""
+ if self.ready():
+ return self._quantize(x, self.scale, self.zero, self.maxq)
+
+ return x
+
+    def enabled(self):
+        """Whether quantization is enabled."""
+        return self.maxq > 0
+
+    def ready(self):
+        """Whether the qparams have been computed."""
+        return torch.all(self.scale != 0)
diff --git a/mmrazor/implementations/quantization/gptq/utils.py b/mmrazor/implementations/quantization/gptq/utils.py
new file mode 100644
index 000000000..a27b3ff8d
--- /dev/null
+++ b/mmrazor/implementations/quantization/gptq/utils.py
@@ -0,0 +1,56 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+
+# copy from https://github.com/openppl-public/ppq/blob/master/ppq/quantization/measure/norm.py # noqa: E501
+def torch_snr_error(y_pred: torch.Tensor,
+ y_real: torch.Tensor,
+ reduction: str = 'mean') -> torch.Tensor:
+    """Compute the SNR error between y_pred (tensor) and y_real (tensor).
+
+    SNR can be calculated with the following equation:
+
+    SNR(pred, real) = (pred - real) ^ 2 / (real) ^ 2
+
+    If pred and real are matrices, the SNR error over the matrix is the mean
+    of the element-wise SNR errors:
+
+    SNR(pred, real) = mean((pred - real) ^ 2 / (real) ^ 2)
+
+    Args:
+        y_pred (torch.Tensor): predicted (e.g. quantized) tensor.
+        y_real (torch.Tensor): reference tensor.
+        reduction (str, optional): 'mean', 'sum' or 'none'.
+            Defaults to 'mean'.
+    Raises:
+        ValueError: if y_pred and y_real have different shapes.
+        ValueError: if reduction is not one of 'mean', 'sum' or 'none'.
+    Returns:
+        torch.Tensor: the SNR error.
+    """
+ y_pred = y_pred.type(torch.float32)
+ y_real = y_real.type(torch.float32)
+
+ if y_pred.shape != y_real.shape:
+ raise ValueError(
+ f'Can not compute snr loss for tensors with different shape. '
+ f'({y_pred.shape} and {y_real.shape})')
+ reduction = str(reduction).lower()
+
+ if y_pred.ndim == 1:
+ y_pred = y_pred.unsqueeze(0)
+ y_real = y_real.unsqueeze(0)
+
+ y_pred = y_pred.flatten(start_dim=1)
+ y_real = y_real.flatten(start_dim=1)
+
+ noise_power = torch.pow(y_pred - y_real, 2).sum(dim=-1)
+ signal_power = torch.pow(y_real, 2).sum(dim=-1)
+ snr = (noise_power) / (signal_power + 1e-7)
+
+ if reduction == 'mean':
+ return torch.mean(snr)
+ elif reduction == 'sum':
+ return torch.sum(snr)
+ elif reduction == 'none':
+ return snr
+ else:
+ raise ValueError('Unsupported reduction method.')
diff --git a/mmrazor/structures/quantization/backend_config/common_operator_config_utils.py b/mmrazor/structures/quantization/backend_config/common_operator_config_utils.py
index 0a381d5d0..bea018975 100644
--- a/mmrazor/structures/quantization/backend_config/common_operator_config_utils.py
+++ b/mmrazor/structures/quantization/backend_config/common_operator_config_utils.py
@@ -49,7 +49,9 @@
'relu_qat', 'bn_qat', 'bn_relu_qat', 'func'
])
-if digit_version(torch.__version__) >= digit_version('1.13.0'):
+if digit_version(
+ torch.__version__) >= digit_version('1.13.0') and digit_version(
+ torch.__version__) <= digit_version('1.13.1'):
_Conv1dMetadata = _ConvMetadata(
nn.Conv1d, nn.ConvTranspose1d, nn.BatchNorm1d, nnqr.Conv1d,
nnqr.ConvTranspose1d, nni.ConvReLU1d, nni.ConvBn1d, nni.ConvBnReLU1d,
diff --git a/mmrazor/structures/quantization/backend_config/mapping.py b/mmrazor/structures/quantization/backend_config/mapping.py
index b9cc5372b..0a02ac1b7 100644
--- a/mmrazor/structures/quantization/backend_config/mapping.py
+++ b/mmrazor/structures/quantization/backend_config/mapping.py
@@ -7,7 +7,9 @@
from .openvino import get_openvino_backend_config
from .tensorrt import get_tensorrt_backend_config
-if digit_version(torch.__version__) >= digit_version('1.13.0'):
+if digit_version(
+ torch.__version__) >= digit_version('1.13.0') and digit_version(
+ torch.__version__) <= digit_version('1.13.1'):
BackendConfigs = {
'academic': get_academic_backend_config(),
'native': get_native_backend_config(),
diff --git a/mmrazor/utils/log_tools.py b/mmrazor/utils/log_tools.py
index 787dc1927..935349a03 100644
--- a/mmrazor/utils/log_tools.py
+++ b/mmrazor/utils/log_tools.py
@@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging
+import torch.distributed as dist
from mmengine import MMLogger
from mmengine import print_log as engine_print_log
@@ -17,8 +18,15 @@ def get_level(level='info'):
return level
-def print_log(msg, logger='current', level='info'):
- engine_print_log(msg, logger, get_level(level))
+def print_log(msg, logger='current', level='info', only_rank0=True):
+    """Print log, only on the rank-0 process by default when distributed."""
+    if only_rank0 and dist.is_initialized() and dist.get_rank() != 0:
+        return
+    engine_print_log(msg, logger, get_level(level))
def set_log_level(level='debug'):
diff --git a/projects/mmrazor_large/README.md b/projects/mmrazor_large/README.md
new file mode 100644
index 000000000..378b9102b
--- /dev/null
+++ b/projects/mmrazor_large/README.md
@@ -0,0 +1,42 @@
+
+
+
+
+# MMRazor for Large Models
+
+## Introduction
+
+MMRazor is dedicated to developing general-purpose model compression tools. It now supports not only conventional CV model compression but also large models. This project provides examples of compressing various large models with MMRazor, including LLaMA, Stable Diffusion, and more.
+
+Code structure overview for large models:
+
+```
+mmrazor
+├── implementations # core algorithm components
+ ├── pruning
+ └── quantization
+projects
+└── mmrazor_large
+ ├── algorithms # algorithms usage introduction
+ └── examples # examples for various models about algorithms
+ ├── language_models
+ │ ├── LLaMA
+ │ └── OPT
+ └── ResNet
+```
+
+## Model-Algorithm Example Matrix
+
+| | ResNet | OPT | LLaMA | Stable Diffusion |
+| ------------------------------------ | ----------------------------------------------- | ------------------------------------------------------------ | -------------------------------------------------------------- | ---------------- |
+| [SparseGPT](algorithms/SparseGPT.md) | [:white_check_mark:](examples/ResNet/README.md) | [:white_check_mark:](examples/language_models/OPT/README.md) | [:white_check_mark:](examples/language_models/LLaMA/README.md) | |
+| [GPTQ](algorithms/GPTQ.md) | [:white_check_mark:](examples/ResNet/README.md) | [:white_check_mark:](examples/language_models/OPT/README.md) | [:white_check_mark:](examples/language_models/LLaMA/README.md) | |
+
+## Paper List
+
+We provide a paper list for researchers in the field of model compression for large models. If you want to add your paper to this list, please submit a PR.
+
+| Paper | Title | Type | MMRazor |
+| --------- | --------------------------------------------------------------------------------------------------------------------- | ------------ | --------------------------------------------- |
+| SparseGPT | [SparseGPT: Massive Language Models Can Be Accurately Pruned in One-Shot](https://arxiv.org/abs/2301.00774) | Pruning | [:white_check_mark:](algorithms/SparseGPT.md) |
+| GPTQ | [GPTQ: Accurate Post-Training Quantization for Generative Pre-trained Transformers](https://arxiv.org/abs/2210.17323) | Quantization | [:white_check_mark:](algorithms/GPTQ.md) |
diff --git a/projects/mmrazor_large/algorithms/GPTQ.md b/projects/mmrazor_large/algorithms/GPTQ.md
new file mode 100644
index 000000000..b013a73a2
--- /dev/null
+++ b/projects/mmrazor_large/algorithms/GPTQ.md
@@ -0,0 +1,56 @@
+# GPTQ
+
+> [GPTQ: Accurate Post-Training Quantization for Generative Pre-trained Transformers](https://arxiv.org/abs/2210.17323)
+
+
+
+## Abstract
+
+Generative Pre-trained Transformer models, known as GPT or OPT, set themselves apart through breakthrough performance across complex language modelling tasks, but also by their extremely high computational and storage costs. Specifically, due to their massive size, even inference for large, highly-accurate GPT models may require multiple performant GPUs, which limits the usability of such models. While there is emerging work on relieving this pressure via model compression, the applicability and performance of existing compression techniques is limited by the scale and complexity of GPT models. In this paper, we address this challenge, and propose GPTQ, a new one-shot weight quantization method based on approximate second-order information, that is both highly-accurate and highly-efficient. Specifically, GPTQ can quantize GPT models with 175 billion parameters in approximately four GPU hours, reducing the bitwidth down to 3 or 4 bits per weight, with negligible accuracy degradation relative to the uncompressed baseline. Our method more than doubles the compression gains relative to previously-proposed one-shot quantization methods, preserving accuracy, allowing us for the first time to execute an 175 billion-parameter model inside a single GPU for generative inference. Moreover, we also show that our method can still provide reasonable accuracy in the extreme quantization regime, in which weights are quantized to 2-bit or even ternary quantization levels. We show experimentally that these improvements can be leveraged for end-to-end inference speedups over FP16, of around 3.25x when using high-end GPUs (NVIDIA A100) and 4.5x when using more cost-effective ones (NVIDIA A6000). The implementation is available at https://github.com/IST-DASLab/gptq.
+
+## Usage
+
+GPTQ is easy to use in MMRazor. A typical workflow looks like this:
+
+```python
+from mmrazor.implementations.quantization import gptq
+
+# `model`, `train_loader` and `test_loader` are prepared beforehand
+model
+train_loader, test_loader
+
+## init gptq compressor and prepare for quantization
+compressor = gptq.GPTQCompressor()
+compressor.prepare(model)
+
+## get hessian matrix
+compressor.init_hessian()
+compressor.register_hessian_hooks()
+infer(model, test_loader, num_samples=num_samples)
+compressor.remove_hessian_hooks()
+
+## quant
+compressor.quant_with_default_qconfig()
+
+## to a normal torch model
+model = compressor.to_static_model(model)
+
+```
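+
+Here, `infer` is any forward-only calibration loop that feeds a few batches through the model while the hessian hooks are registered; `num_samples` is the number of calibration samples. A minimal sketch (mirroring the ResNet example in this project) might look like:
+
+```python
+import torch
+
+
+@torch.no_grad()
+def infer(model, dataloader, num_samples=256, device='cuda:0'):
+    """Forward-only pass so the registered hooks can collect statistics."""
+    model.eval()
+    seen = 0
+    for x, _ in dataloader:
+        model(x.to(device))
+        seen += x.shape[0]
+        if seen >= num_samples:
+            break
+```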
+
+## Full Examples
+
+- [ResNet](../examples/ResNet/README.md)
+- [LLaMA](../examples/language_models/LLaMA/README.md)
+
+## Cite
+
+```latex
+@misc{Frantar_Ashkboos_Hoefler_Alistarh_2022,
+  title={GPTQ: Accurate Post-Training Quantization for Generative Pre-trained Transformers},
+  author={Frantar, Elias and Ashkboos, Saleh and Hoefler, Torsten and Alistarh, Dan},
+  year={2022},
+  month={Oct},
+  language={en-US}
+}
+```
diff --git a/projects/mmrazor_large/algorithms/SparseGPT.md b/projects/mmrazor_large/algorithms/SparseGPT.md
new file mode 100644
index 000000000..479235baa
--- /dev/null
+++ b/projects/mmrazor_large/algorithms/SparseGPT.md
@@ -0,0 +1,55 @@
+# SparseGPT
+
+> [SparseGPT: Massive Language Models Can Be Accurately Pruned in One-Shot](https://arxiv.org/abs/2301.00774)
+
+
+
+## Abstract
+
+We show for the first time that large-scale generative pretrained transformer (GPT) family models can be pruned to at least 50% sparsity in one-shot, without any retraining, at minimal loss of accuracy. This is achieved via a new pruning method called SparseGPT, specifically designed to work efficiently and accurately on massive GPT-family models. We can execute SparseGPT on the largest available open-source models, OPT-175B and BLOOM-176B, in under 4.5 hours, and can reach 60% unstructured sparsity with negligible increase in perplexity: remarkably, more than 100 billion weights from these models can be ignored at inference time. SparseGPT generalizes to semi-structured (2:4 and 4:8) patterns, and is compatible with weight quantization approaches.
+
+## Usage
+
+SparseGPT is easy to use in MMRazor. A typical workflow looks like this:
+
+```python
+from mmrazor.implementations.pruning import sparse_gpt
+
+# `model`, `train_loader` and `test_loader` are prepared beforehand
+model
+train_loader, test_loader
+
+## init sparse gpt compressor and prepare for pruning
+compressor = sparse_gpt.SparseGptCompressor()
+compressor.prepare(model)
+
+## get hessian matrix
+compressor.init_hessian()
+compressor.register_hessian_hooks()
+infer(model, test_loader, num_samples=num_samples)
+compressor.remove_hessian_hooks()
+
+## prune
+compressor.prune_24()
+
+## to a normal torch model
+model = compressor.to_static_model(model)
+
+```
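+
+`prune_24()` applies the 2:4 semi-structured sparsity pattern (two zeros in every group of four weights), which maps well to hardware sparse kernels. The compressor also exposes a generic `prune` method; assuming its `prune(sparsity, prunen=0, prunem=0)` interface, unstructured pruning can be requested instead:
+
+```python
+## e.g. 50% unstructured sparsity instead of the 2:4 pattern
+compressor.prune(sparsity=0.5)
+```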
+
+## Full Examples
+
+- [ResNet](../examples/ResNet/README.md)
+- [OPT](../examples/language_models/OPT/README.md)
+- [LLaMA](../examples/language_models/LLaMA/README.md)
+
+## Cite
+
+```latex
+@article{frantar2023massive,
+  title={SparseGPT: Massive Language Models Can Be Accurately Pruned in One-Shot},
+ author={Frantar, Elias and Alistarh, Dan},
+ journal={arXiv preprint arXiv:2301.00774},
+ year={2023}
+}
+```
diff --git a/projects/mmrazor_large/examples/ResNet/README.md b/projects/mmrazor_large/examples/ResNet/README.md
new file mode 100644
index 000000000..aa4eb374c
--- /dev/null
+++ b/projects/mmrazor_large/examples/ResNet/README.md
@@ -0,0 +1,25 @@
+# Examples for ResNet
+
+## SparseGPT
+
+For more details about SparseGPT, please refer to [SparseGPT](../../algorithms/SparseGPT.md)
+
+### Usage
+
+```shell
+python projects/mmrazor_large/examples/ResNet/resnet18_sparse_gpt.py --data {imagenet_path} --batch_size 128 --num_samples 512
+```
+
+**Note**: the ImageNet folder must follow the standard torchvision `ImageFolder` layout, with `train/` and `val/` splits containing one sub-folder per class.
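+
+For reference, the expected layout is (class folder names are only illustrative):
+
+```
+imagenet_torch
+├── train
+│   ├── n01440764
+│   │   └── xxx.JPEG
+│   └── ...
+└── val
+    ├── n01440764
+    │   └── yyy.JPEG
+    └── ...
+```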
+
+## GPTQ
+
+For more details about GPTQ, please refer to [GPTQ](../../algorithms/GPTQ.md)
+
+### Usage
+
+```shell
+python projects/mmrazor_large/examples/ResNet/resnet18_gptq.py --data {imagenet_path} --batch_size 128 --num_samples 512
+```
+
+**Note**: the ImageNet folder must follow the standard torchvision `ImageFolder` layout (see the note and layout above).
diff --git a/projects/mmrazor_large/examples/ResNet/resnet18_gptq.py b/projects/mmrazor_large/examples/ResNet/resnet18_gptq.py
new file mode 100644
index 000000000..9aa6877a6
--- /dev/null
+++ b/projects/mmrazor_large/examples/ResNet/resnet18_gptq.py
@@ -0,0 +1,187 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# model settings
+import os.path as osp
+
+import torch
+import torch.nn as nn
+import torchvision
+import torchvision.datasets as datasets
+import torchvision.transforms as transforms
+from torch.utils.data import DataLoader
+
+from mmrazor.implementations.quantization.gptq import (GPTQCompressor,
+ GPTQLinear)
+from mmrazor.utils import print_log
+
+
+def enable_observer_linear(model):
+ print_log('Enable updating qparams for GPTQLinear!')
+ for _, module in model.named_modules():
+ if isinstance(module, GPTQLinear):
+ module.fix_qparams = False
+
+
+def disable_observer_linear(model):
+ print_log('Disable updating qparams for GPTQLinear!')
+ for _, module in model.named_modules():
+ if isinstance(module, GPTQLinear):
+ module.fix_qparams = True
+
+
+def get_dataloaders(batch_size, n_workers, path=''):
+ normalize = transforms.Normalize(
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+ train_dataset = datasets.ImageFolder(
+ osp.join(path, 'train'),
+ transforms.Compose([
+ transforms.RandomResizedCrop(224),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ normalize,
+ ]),
+ )
+
+ test_dataset = datasets.ImageFolder(
+ osp.join(path, 'val'),
+ transforms.Compose([
+ transforms.Resize(256),
+ transforms.CenterCrop(224),
+ transforms.ToTensor(),
+ normalize,
+ ]),
+ )
+
+ dataloader_train = torch.utils.data.DataLoader(
+ train_dataset,
+ batch_size=batch_size,
+ shuffle=True,
+ num_workers=n_workers,
+ pin_memory=True,
+ )
+ dataloader_test = torch.utils.data.DataLoader(
+ test_dataset,
+ batch_size=batch_size,
+ shuffle=False,
+ num_workers=n_workers,
+ pin_memory=True,
+ )
+ return dataloader_train, dataloader_test
+
+
+@torch.no_grad()
+def eval(model: nn.Module,
+ dataloader_test: DataLoader,
+ device=torch.device('cuda:0'),
+ is_half=True):
+
+ total = 0
+ correct = 0
+
+ model.eval()
+ with torch.no_grad():
+ for x, y in dataloader_test:
+ x: torch.Tensor # type: ignore
+ y: torch.Tensor # type: ignore
+ x = x.to(device)
+ y = y.to(device)
+            if is_half:
+                x = x.half()
+ outputs = model(x)
+ _, predicted = outputs.max(1)
+ correct += (y == predicted).long().sum()
+ total += y.numel()
+ acc = correct / total
+ return acc
+
+
+@torch.no_grad()
+def infer(model: nn.Module,
+ dataloader: torch.utils.data.DataLoader,
+ num_samples=256,
+ device=torch.device('cuda:0'),
+ is_half=True):
+ model.eval()
+ with torch.no_grad():
+ accumulate_batch = 0
+ for x, _ in dataloader:
+ x = x.to(device)
+ if is_half:
+ x = x.half()
+ model(x)
+ B = x.shape[0]
+ accumulate_batch += B
+ if accumulate_batch > num_samples:
+ break
+
+
+if __name__ == '__main__':
+ import argparse
+ arg_parser = argparse.ArgumentParser()
+ arg_parser.add_argument(
+ '--data',
+ type=str,
+ default='data/imagenet_torch',
+ help='path to imagenet in torch folder format')
+ arg_parser.add_argument(
+ '--num_samples',
+ type=int,
+ default=512,
+ help='number of samples to estimate hessian matrix')
+ arg_parser.add_argument(
+ '--batch_size',
+ type=int,
+ default=128,
+ help='batch size for evaluation and inference')
+    arg_parser.add_argument(
+        '--fp16',
+        action='store_true',
+        help='whether to use fp16 for evaluation and inference')
+ args = arg_parser.parse_args()
+
+ data_path = args.data
+ num_samples = args.num_samples
+ batch_size = args.batch_size
+
+ model = torchvision.models.resnet18(pretrained=True)
+ if args.fp16:
+ model = model.half()
+ train_loader, test_loader = get_dataloaders(batch_size, 4, data_path)
+
+ compressor = GPTQCompressor()
+
+    # # Alternative: pass explicit bit-width and group size, e.g.
+ # compressor.prepare(model,
+ # quant_conv=True,
+ # quant_linear=True,
+ # use_triton_ops=False,
+ # skipped_layers=['conv1'],
+ # bits=4,
+ # groupsize=128)
+
+ # # quantize activation for linear
+ # a_qconfig = dict(bits=4, perchannel=True, sym=False)
+ compressor.prepare(
+ model,
+ quant_conv=True,
+ quant_linear=True,
+ use_triton_ops=False,
+ skipped_layers=['conv1'],
+ # a_qconfig=a_qconfig
+ )
+
+ model.cuda()
+
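+    # calibration: enable observers and run a few forward passes to collect
+    # hessian statistics, then quantize with the default qconfig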
+ enable_observer_linear(model)
+ compressor.init_hessian()
+ compressor.register_hessian_hooks()
+ infer(model, test_loader, num_samples=num_samples, is_half=args.fp16)
+ compressor.remove_hessian_hooks()
+ compressor.quant_with_default_qconfig()
+
+ print('start evaluation')
+ disable_observer_linear(model)
+ model = model.cuda()
+ acc = eval(model, test_loader, is_half=args.fp16)
+ print('accuracy:', acc.item())
diff --git a/projects/mmrazor_large/examples/ResNet/resnet18_sparse_gpt.py b/projects/mmrazor_large/examples/ResNet/resnet18_sparse_gpt.py
new file mode 100644
index 000000000..0e6658a6f
--- /dev/null
+++ b/projects/mmrazor_large/examples/ResNet/resnet18_sparse_gpt.py
@@ -0,0 +1,137 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# model settings
+import os.path as osp
+
+import torch
+import torch.nn as nn
+import torchvision
+import torchvision.datasets as datasets
+import torchvision.transforms as transforms
+from torch.utils.data import DataLoader
+
+from mmrazor.implementations.pruning import sparse_gpt
+
+
+def get_dataloaders(batch_size, n_workers, path=''):
+ normalize = transforms.Normalize(
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+ train_dataset = datasets.ImageFolder(
+ osp.join(path, 'train'),
+ transforms.Compose([
+ transforms.RandomResizedCrop(224),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ normalize,
+ ]),
+ )
+
+ test_dataset = datasets.ImageFolder(
+ osp.join(path, 'val'),
+ transforms.Compose([
+ transforms.Resize(256),
+ transforms.CenterCrop(224),
+ transforms.ToTensor(),
+ normalize,
+ ]),
+ )
+
+ dataloader_train = torch.utils.data.DataLoader(
+ train_dataset,
+ batch_size=batch_size,
+ shuffle=True,
+ num_workers=n_workers,
+ pin_memory=True,
+ )
+ dataloader_test = torch.utils.data.DataLoader(
+ test_dataset,
+ batch_size=batch_size,
+ shuffle=False,
+ num_workers=n_workers,
+ pin_memory=True,
+ )
+ return dataloader_train, dataloader_test
+
+
+@torch.no_grad()
+def eval(model: nn.Module,
+ dataloader_test: DataLoader,
+ device=torch.device('cuda:0')):
+
+ total = 0
+ correct = 0
+
+ model.eval()
+ with torch.no_grad():
+ for x, y in dataloader_test:
+ x: torch.Tensor # type: ignore
+ y: torch.Tensor # type: ignore
+ x = x.to(device)
+ outputs = model(x)
+ _, predicted = outputs.max(1)
+ y = y.to(device)
+ correct += (y == predicted).long().sum()
+ total += y.numel()
+ acc = correct / total
+ return acc
+
+
+@torch.no_grad()
+def infer(model: nn.Module,
+ dataloader: torch.utils.data.DataLoader,
+ num_samples=256,
+ device=torch.device('cuda:0')):
+ model.eval()
+ with torch.no_grad():
+ accumulate_batch = 0
+ for x, _ in dataloader:
+ x = x.to(device)
+ model(x)
+ B = x.shape[0]
+ accumulate_batch += B
+ if accumulate_batch > num_samples:
+ break
+
+
+if __name__ == '__main__':
+ import argparse
+ arg_parser = argparse.ArgumentParser()
+ arg_parser.add_argument(
+ '--data',
+ type=str,
+ default='data/imagenet_torch',
+ help='path to imagenet in torch folder format')
+ arg_parser.add_argument(
+ '--num_samples',
+ type=int,
+ default=512,
+ help='number of samples to estimate hessian matrix')
+ arg_parser.add_argument(
+ '--batch_size',
+ type=int,
+ default=128,
+ help='batch size for evaluation and inference')
+ args = arg_parser.parse_args()
+
+ data_path = args.data
+ num_samples = args.num_samples
+ batch_size = args.batch_size
+
+ model = torchvision.models.resnet18(pretrained=True)
+ train_loader, test_loader = get_dataloaders(batch_size, 4, data_path)
+
+ compressor = sparse_gpt.SparseGptCompressor()
+ compressor.prepare(model)
+
+ model.cuda()
+
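+    # calibration: run a few forward passes to collect hessian statistics,
+    # then apply 2:4 semi-structured pruning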
+ compressor.init_hessian()
+ compressor.register_hessian_hooks()
+ infer(model, test_loader, num_samples=num_samples)
+ compressor.remove_hessian_hooks()
+ compressor.prune_24()
+ model = compressor.to_static_model(model)
+
+ print('start evaluation')
+ model = model.cuda()
+ acc = eval(model, test_loader)
+ print('accuracy:', acc.item())
diff --git a/projects/mmrazor_large/examples/language_models/LLaMA/README.md b/projects/mmrazor_large/examples/language_models/LLaMA/README.md
new file mode 100644
index 000000000..7d9862de8
--- /dev/null
+++ b/projects/mmrazor_large/examples/language_models/LLaMA/README.md
@@ -0,0 +1,55 @@
+# Examples for LLaMA
+
+## SparseGPT
+
+For more details about SparseGPT, please refer to [SparseGPT](../../../algorithms/SparseGPT.md)
+
+### Usage
+
+```shell
+# example for decapoda-research/llama-7b-hf
+python projects/mmrazor_large/examples/language_models/LLaMA/llama_sparse_gpt.py decapoda-research/llama-7b-hf c4
+
+# help
+usage: llama_sparse_gpt.py [-h] [--seed SEED] [--nsamples NSAMPLES] [--batch_size BATCH_SIZE] [--save SAVE] [-m M] model {wikitext2,ptb,c4}
+
+positional arguments:
+ model Llama model to load
+ {wikitext2,ptb,c4} Where to extract calibration data from.
+
+optional arguments:
+ -h, --help show this help message and exit
+ --seed SEED Seed for sampling the calibration data.
+ --nsamples NSAMPLES Number of calibration data samples.
+ --batch_size BATCH_SIZE
+ Batchsize for calibration and evaluation.
+ --save SAVE Path to saved model.
+ -m M Whether to enable memory efficient forward
+```
+
+## GPTQ
+
+For more details about GPTQ, please refer to [GPTQ](../../../algorithms/GPTQ.md)
+
+### Usage
+
+```shell
+# example for decapoda-research/llama-7b-hf
+python projects/mmrazor_large/examples/language_models/LLaMA/llama_gptq.py decapoda-research/llama-7b-hf c4
+
+# help
+usage: llama_gptq.py [-h] [--seed SEED] [--nsamples NSAMPLES] [--batch_size BATCH_SIZE] [--save SAVE] [-m M] model {wikitext2,ptb,c4}
+
+positional arguments:
+ model Llama model to load
+ {wikitext2,ptb,c4} Where to extract calibration data from.
+
+optional arguments:
+ -h, --help show this help message and exit
+ --seed SEED Seed for sampling the calibration data.
+ --nsamples NSAMPLES Number of calibration data samples.
+ --batch_size BATCH_SIZE
+ Batchsize for calibration and evaluation.
+ --save SAVE Path to saved model.
+ -m M Whether to enable memory efficient forward
+```
diff --git a/projects/mmrazor_large/examples/language_models/LLaMA/datautils.py b/projects/mmrazor_large/examples/language_models/LLaMA/datautils.py
new file mode 100755
index 000000000..04697d560
--- /dev/null
+++ b/projects/mmrazor_large/examples/language_models/LLaMA/datautils.py
@@ -0,0 +1,152 @@
+import numpy as np
+import torch
+from torch.utils.data import DataLoader
+from torch.utils.data import Dataset as TorchDataset
+from torch.utils.data import DistributedSampler
+
+
+def set_seed(seed):
+ np.random.seed(seed)
+ torch.random.manual_seed(seed)
+
+
+def get_wikitext2(nsamples, seed, seqlen, model):
+ from datasets import load_dataset
+ traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
+ testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
+
+ from transformers import AutoTokenizer
+ tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False)
+ trainenc = tokenizer(' '.join(traindata['text']), return_tensors='pt')
+ testenc = tokenizer('\n\n'.join(testdata['text']), return_tensors='pt')
+
+ import random
+ random.seed(seed)
+ trainloader = []
+ for _ in range(nsamples):
+ i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
+ j = i + seqlen
+ inp = trainenc.input_ids[:, i:j]
+ tar = inp.clone()
+ tar[:, :-1] = -100
+ trainloader.append((inp, tar))
+ return trainloader, testenc
+
+
+def get_ptb(nsamples, seed, seqlen, model):
+ from datasets import load_dataset
+ traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train')
+ testdata = load_dataset('ptb_text_only', 'penn_treebank', split='test')
+
+ from transformers import AutoTokenizer
+ tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False)
+ trainenc = tokenizer(' '.join(traindata['sentence']), return_tensors='pt')
+ testenc = tokenizer(' '.join(testdata['sentence']), return_tensors='pt')
+
+ import random
+ random.seed(seed)
+ trainloader = []
+ for _ in range(nsamples):
+ i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
+ j = i + seqlen
+ inp = trainenc.input_ids[:, i:j]
+ tar = inp.clone()
+ tar[:, :-1] = -100
+ trainloader.append((inp, tar))
+ return trainloader, testenc
+
+
+def get_c4(nsamples, seed, seqlen, model):
+ from datasets import load_dataset
+ traindata = load_dataset(
+ 'allenai/c4',
+ 'allenai--c4',
+ data_files={'train': 'en/c4-train.00000-of-01024.json.gz'},
+ split='train')
+ valdata = load_dataset(
+ 'allenai/c4',
+ 'allenai--c4',
+ data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'},
+ split='validation')
+
+ from transformers import AutoTokenizer
+ tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False)
+
+ import random
+ random.seed(seed)
+ trainloader = []
+ for _ in range(nsamples):
+ while True:
+ i = random.randint(0, len(traindata) - 1)
+ trainenc = tokenizer(traindata[i]['text'], return_tensors='pt')
+ if trainenc.input_ids.shape[1] >= seqlen:
+ break
+ i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
+ j = i + seqlen
+ inp = trainenc.input_ids[:, i:j]
+ tar = inp.clone()
+ tar[:, :-1] = -100
+ trainloader.append((inp, tar))
+
+ valenc = tokenizer(' '.join(valdata[:1100]['text']), return_tensors='pt')
+ valenc = valenc.input_ids[:, :(256 * seqlen)]
+
+ class TokenizerWrapper:
+
+ def __init__(self, input_ids):
+ self.input_ids = input_ids
+
+ valenc = TokenizerWrapper(valenc)
+
+ return trainloader, valenc
+
+
+def get_loaders(name, nsamples=128, seed=0, seqlen=2048, model=''):
+ if 'wikitext2' in name:
+ return get_wikitext2(nsamples, seed, seqlen, model)
+ if 'ptb' in name:
+ return get_ptb(nsamples, seed, seqlen, model)
+ if 'c4' in name:
+ return get_c4(nsamples, seed, seqlen, model)
+
+
+def fold_tokens(tokens: torch.Tensor, batch_seq_len=2048):
+ # tokens: 1 N
+ N = tokens.shape[1]
+ num_drop = N % batch_seq_len
+ if num_drop != 0:
+ tokens = tokens[:, :-num_drop]
+ tokens = tokens.reshape([-1, batch_seq_len]) # B N
+ return tokens
+
+
+class LanguageDataset(TorchDataset):
+
+ def __init__(self, seq: torch.Tensor, seq_len: int = 2048) -> None:
+ super().__init__()
+ # seq: 1, N
+ self.seq_len = seq_len
+
+        self.seq = fold_tokens(seq, seq_len)  # B N
+
+ def __len__(self) -> int:
+ return self.seq.shape[0]
+
+ def __getitem__(self, index):
+ return self.seq[index]
+
+
+def build_language_loader(testloader, world_size, rank, model, batch_size=128):
+ val_dataset = LanguageDataset(testloader.input_ids, seq_len=model.seqlen)
+ distributed_sampler = DistributedSampler(
+ val_dataset, num_replicas=world_size, rank=rank, shuffle=False)
+ batch_size = min(len(val_dataset) // world_size, batch_size)
+ val_dataloader = DataLoader(
+ val_dataset,
+ batch_size=batch_size,
+ shuffle=False,
+ num_workers=0,
+ pin_memory=True,
+ drop_last=True,
+ sampler=distributed_sampler)
+ return val_dataloader
diff --git a/projects/mmrazor_large/examples/language_models/LLaMA/llama_gptq.py b/projects/mmrazor_large/examples/language_models/LLaMA/llama_gptq.py
new file mode 100644
index 000000000..0eae9b4f0
--- /dev/null
+++ b/projects/mmrazor_large/examples/language_models/LLaMA/llama_gptq.py
@@ -0,0 +1,162 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from datautils import get_loaders
+from transformers.models.llama import LlamaForCausalLM
+from transformers.models.llama.modeling_llama import LlamaDecoderLayer
+from utils import opt_eval, opt_infer
+
+from mmrazor.implementations.pruning.sparse_gpt.utils import \
+ memory_efficient_forward
+from mmrazor.implementations.quantization.gptq import (GPTQLinear,
+ TritonGPTQLinear)
+from mmrazor.utils import print_log
+
+
+def enable_observer_linear(model):
+ print_log('Enable updating qparams for GPTQLinear!')
+ for _, module in model.named_modules():
+ if isinstance(module, GPTQLinear):
+ module.fix_qparams = False
+
+
+def disable_observer_linear(model):
+ print_log('Disable updating qparams for GPTQLinear!')
+ for _, module in model.named_modules():
+ if isinstance(module, GPTQLinear):
+ module.fix_qparams = True
+
+
+def del_redundant_attr(model):
+ print_log('Del redundant weight for GPTQLinear!')
+ for _, module in model.named_modules():
+ if isinstance(module, TritonGPTQLinear):
+ del module.weight
+
+
+def get_model(model):
+
+ def skip(*args, **kwargs):
+ pass
+
+ torch.nn.init.kaiming_uniform_ = skip
+ torch.nn.init.uniform_ = skip
+ torch.nn.init.normal_ = skip
+ model: LlamaForCausalLM = LlamaForCausalLM.from_pretrained(
+ model,
+ torch_dtype='auto',
+ )
+ model.seqlen = 2048
+ return model
+
+
+if __name__ == '__main__':
+
+ import argparse
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument('model', type=str, help='Llama model to load')
+ parser.add_argument(
+        'dataset',
+ type=str,
+ choices=['wikitext2', 'ptb', 'c4'],
+ help='Where to extract calibration data from.')
+ parser.add_argument(
+ '--seed',
+ type=int,
+ default=0,
+ help='Seed for sampling the calibration data.')
+ parser.add_argument(
+ '--nsamples',
+ type=int,
+ default=128,
+ help='Number of calibration data samples.')
+ parser.add_argument(
+ '--batch_size',
+ type=int,
+ default=16,
+ help='Batchsize for calibration and evaluation.')
+ parser.add_argument(
+ '--save', type=str, default='', help='Path to saved model.')
+ parser.add_argument(
+ '--quant_ckpt', type=str, default='', help='Quantized ckpt to load.')
+ parser.add_argument(
+ '--dev', type=str, default='cuda:0', help='Use which device.')
+ parser.add_argument(
+ '-m',
+ type=bool,
+ default=False,
+ help='Whether to enable memory efficient forward')
+
+ args = parser.parse_args()
+
+ DEV = args.dev
+
+ model = get_model(args.model)
+ model.to(DEV)
+ model.eval()
+ print_log('load model over')
+
+ from mmrazor.implementations.quantization import gptq
+ compressor = gptq.GPTQCompressor()
+ # use_triton_ops is True
+ compressor.prepare(
+ model.model.layers,
+ quant_conv=True,
+ use_triton_ops=True,
+ quant_linear=True,
+ bits=4,
+ groupsize=128)
+
+ # # quantize activation for linear
+ # # a_qconfig = dict(bits=4, perchannel=False, sym=False)
+ # compressor.prepare(
+ # model.model.layers,
+ # quant_conv=True,
+ # quant_linear=True,
+ # use_triton_ops=False,
+ # # a_qconfig=a_qconfig
+ # )
+
+ if args.quant_ckpt:
+ del_redundant_attr(model)
+ model.load_state_dict(torch.load(args.quant_ckpt))
+ else:
+ dataloader, testloader = get_loaders(
+ args.dataset,
+ seed=args.seed,
+ model=args.model,
+ seqlen=model.seqlen)
+ print_log('load data for infer over')
+
+ compressor.init_hessian()
+ enable_observer_linear(model)
+ with memory_efficient_forward(
+ model,
+ wrap_modules=[LlamaDecoderLayer],
+ enabled=args.m,
+ device=DEV):
+ compressor.register_hessian_hooks()
+ opt_infer(
+ model,
+ testloader,
+ DEV,
+ batch_size=args.batch_size,
+ num_samples=args.nsamples)
+ compressor.remove_hessian_hooks()
+ compressor.quant_with_default_qconfig(device=DEV)
+
+ disable_observer_linear(model)
+ with memory_efficient_forward(
+ model, wrap_modules=[LlamaDecoderLayer], enabled=args.m,
+ device=DEV):
+
+ # for dataset in ['wikitext2', 'ptb', 'c4']:
+ for dataset in ['wikitext2']:
+ dataloader, testloader = get_loaders(
+ dataset, seed=args.seed, model=args.model, seqlen=model.seqlen)
+ print_log(dataset)
+ opt_eval(model, testloader, DEV, batch_size=args.batch_size)
+
+ if args.save and not args.quant_ckpt:
+ print_log(f'save model in {args.save}')
+ torch.save(model.state_dict(), args.save)
diff --git a/projects/mmrazor_large/examples/language_models/LLaMA/llama_sparse_gpt.py b/projects/mmrazor_large/examples/language_models/LLaMA/llama_sparse_gpt.py
new file mode 100644
index 000000000..972feff2f
--- /dev/null
+++ b/projects/mmrazor_large/examples/language_models/LLaMA/llama_sparse_gpt.py
@@ -0,0 +1,106 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from datautils import get_loaders
+from transformers.models.llama import LlamaForCausalLM
+from transformers.models.llama.modeling_llama import LlamaDecoderLayer
+from utils import opt_eval, opt_infer
+
+from mmrazor.implementations.pruning.sparse_gpt.utils import \
+ memory_efficient_forward
+from mmrazor.utils import print_log
+
+
+def get_model(model):
+ import torch
+
+ def skip(*args, **kwargs):
+ pass
+
+ torch.nn.init.kaiming_uniform_ = skip
+ torch.nn.init.uniform_ = skip
+ torch.nn.init.normal_ = skip
+ model: LlamaForCausalLM = LlamaForCausalLM.from_pretrained(
+ model,
+ torch_dtype='auto',
+ )
+ model.seqlen = 2048
+ return model
+
+
+if __name__ == '__main__':
+
+ import argparse
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument('model', type=str, help='Llama model to load')
+ parser.add_argument(
+ 'dataset',
+ type=str,
+ choices=['wikitext2', 'ptb', 'c4'],
+ help='Where to extract calibration data from.')
+ parser.add_argument(
+ '--seed',
+ type=int,
+ default=0,
+ help='Seed for sampling the calibration data.')
+ parser.add_argument(
+ '--nsamples',
+ type=int,
+ default=128,
+ help='Number of calibration data samples.')
+ parser.add_argument(
+ '--batch_size',
+ type=int,
+ default=16,
+ help='Batchsize for calibration and evaluation.')
+ parser.add_argument(
+ '--save', type=str, default='', help='Path to saved model.')
+ parser.add_argument(
+ '-m',
+ type=bool,
+ default=False,
+ help='Whether to enable memory efficient forward')
+
+ args = parser.parse_args()
+
+ torch.set_default_dtype(torch.half)
+ DEV = torch.device('cuda:0')
+
+ model = get_model(args.model)
+ model.eval()
+ print_log('load model over')
+
+ dataloader, testloader = get_loaders(
+ args.dataset, seed=args.seed, model=args.model, seqlen=model.seqlen)
+ print_log('load data for infer over')
+
+ from mmrazor.implementations.pruning import sparse_gpt
+ compressor = sparse_gpt.SparseGptCompressor()
+ compressor.prepare(model.model.layers)
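+    # only the decoder layers are pruned; embeddings and lm_head stay dense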
+
+ compressor.init_hessian()
+ with memory_efficient_forward(
+ model, wrap_modules=[LlamaDecoderLayer], enabled=args.m):
+ compressor.register_hessian_hooks()
+ opt_infer(
+ model,
+ testloader,
+ DEV,
+ batch_size=args.batch_size,
+ num_samples=args.nsamples)
+ compressor.remove_hessian_hooks()
+ compressor.prune_24()
+
+ model = compressor.to_static_model(model)
+ if args.save:
+ print_log(f'save model in {args.save}')
+ model.save_pretrained(args.save)
+
+ with memory_efficient_forward(
+ model, wrap_modules=[LlamaDecoderLayer], enabled=args.m):
+
+ for dataset in ['wikitext2', 'ptb', 'c4']:
+ dataloader, testloader = get_loaders(
+ dataset, seed=args.seed, model=args.model, seqlen=model.seqlen)
+ print_log(dataset)
+ opt_eval(model, testloader, DEV, batch_size=args.batch_size)
diff --git a/projects/mmrazor_large/examples/language_models/LLaMA/llama_sparse_gpt_fsdp.py b/projects/mmrazor_large/examples/language_models/LLaMA/llama_sparse_gpt_fsdp.py
new file mode 100644
index 000000000..14d40172b
--- /dev/null
+++ b/projects/mmrazor_large/examples/language_models/LLaMA/llama_sparse_gpt_fsdp.py
@@ -0,0 +1,198 @@
+import functools
+import os
+
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+import torch.nn as nn
+from datautils import build_language_loader, get_loaders
+from llama_sparse_gpt import get_model
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp.api import ShardingStrategy
+from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
+from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
+from utils import init_on_meta, opt_eval_fsdp, opt_infer_fsdp
+
+from mmrazor.implementations.pruning import sparse_gpt
+from mmrazor.utils import print_log
+
+
+def setup(rank, world_size):
+ os.environ['MASTER_ADDR'] = 'localhost'
+ os.environ['MASTER_PORT'] = '12356'
+
+ dist.init_process_group('nccl', rank=rank, world_size=world_size)
+ torch.cuda.set_device(rank)
+ print_log(f'init {rank}/{world_size}', only_rank0=False)
+
+
+def init_fn_wrapper(model: nn.Module, model_copy: nn.Module):
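+    """Return an FSDP ``param_init_fn`` that materializes meta-device modules
+    and, on rank 0, fills them with the weights from ``model_copy``."""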
+
+ def find_module_in_model_copy(module: nn.Module):
+ name2module = dict(model.named_modules())
+ module2name = dict([(v, k) for k, v in name2module.items()])
+
+ name = module2name[module]
+ return dict(model_copy.named_modules())[name]
+
+ def _materialize_meta_module(module: nn.Module, ):
+
+ def meta_to_empty(p: torch.Tensor):
+ if p.device == torch.device('meta'):
+ return p.new_empty(p.shape, device='cpu')
+ else:
+ return p
+
+ module._apply(meta_to_empty)
+ if dist.get_rank() == 0:
+ assert model_copy is not None
+ module_copy = find_module_in_model_copy(module)
+
+ name2p = dict(module_copy.named_parameters(remove_duplicate=False))
+ for n, p in module.named_parameters():
+ if '_flat_param' not in n:
+ n = n.replace('_fsdp_wrapped_module.', '')
+ try:
+ p.data.copy_(name2p[n])
+ except Exception:
+ pass
+ name2p = dict(module_copy.named_buffers(remove_duplicate=False))
+ for n, p in module.named_buffers():
+ if '_flat_param' not in n:
+ n = n.replace('_fsdp_wrapped_module.', '')
+ try:
+ p.data.copy_(name2p[n])
+ except Exception:
+ pass
+
+ return _materialize_meta_module
+
+
+def main(rank, world_size=8, args=None):
+ setup(rank, world_size)
+
+ model_name = args.model
+ batch_size = args.batch_size
+
+ def build():
+ model = get_model(model_name)
+
+ # init compressor
+ compressor = sparse_gpt.SparseGptCompressor()
+ compressor.prepare(model.model.layers)
+ return model, compressor
+
+ with init_on_meta(enable=True):
+ model, compressor = build()
+
+ if rank == 0:
+ model_copy, _ = build() # init on cpu
+ else:
+ model_copy = None
+
+ # init fsdp
+ size_based_auto_wrap_policy_x = functools.partial(
+ size_based_auto_wrap_policy, min_num_params=int(1e8))
+
+ model = FSDP(
+ model,
+ auto_wrap_policy=size_based_auto_wrap_policy_x,
+ cpu_offload=CPUOffload(True),
+ sharding_strategy=ShardingStrategy.FULL_SHARD,
+ device_id=rank,
+ param_init_fn=init_fn_wrapper(model, model_copy),
+ sync_module_states=True)
+ print_log(model)
+
+ # init hessian
+
+ compressor.init_hessian(device='cuda')
+ compressor.register_hessian_hooks()
+
+ _, testloader = get_loaders(
+ args.dataset, seed=args.seed, model=model_name, seqlen=model.seqlen)
+ testloader = build_language_loader(
+ testloader, world_size, rank, model, batch_size=batch_size)
+ opt_infer_fsdp(model, testloader)
+
+ compressor.remove_hessian_hooks()
+
+ # prune
+ name2module = dict(model.named_modules())
+ module2name = {}
+ module2name = dict([(v, k) for k, v in name2module.items()])
+
+ with torch.no_grad():
+ for fsdp in FSDP.fsdp_modules(model):
+ fsdp._reset_lazy_init()
+ with FSDP.summon_full_params(fsdp, recurse=False):
+ fsdp_name = module2name[fsdp]
+ for name, op in fsdp.named_modules():
+ if name.count('_fsdp_wrapped_module') <= 1:
+ if isinstance(op, sparse_gpt.SparseGptMixIn):
+ try:
+ op.prune(0.5, prunen=2, prunem=4)
+ print_log(
+ f'prune {fsdp_name}.{name} successfully.', # noqa
+ only_rank0=True)
+ except Exception as e:
+ print_log(
+ f'prune {fsdp_name}.{name} failed, as {e}', # noqa
+ only_rank0=True)
+ fsdp._reset_lazy_init()
+
+ # save
+ if args.save:
+ print_log(f'save model in {args.save}')
+ model._reset_lazy_init()
+ with FSDP.summon_full_params(model, rank0_only=True, writeback=False):
+ if dist.get_rank() == 0:
+ model.save_pretrained(args.save)
+
+ # val
+ torch.cuda.empty_cache()
+ model._reset_lazy_init()
+ for dataset in ['wikitext2', 'ptb', 'c4']:
+ _, testloader = get_loaders(
+ dataset, seed=args.seed, model=model_name, seqlen=model.seqlen)
+ testloader = build_language_loader(
+ testloader, world_size, rank, model, batch_size=batch_size)
+ print_log(dataset)
+ opt_eval_fsdp(model, testloader, torch.device('cuda'))
+
+
+if __name__ == '__main__':
+ import argparse
+ parser = argparse.ArgumentParser()
+
+    parser.add_argument('model', type=str, help='Llama model to load')
+ parser.add_argument(
+ 'dataset',
+ type=str,
+ choices=['wikitext2', 'ptb', 'c4'],
+ help='Where to extract calibration data from.')
+ parser.add_argument(
+ '--seed',
+ type=int,
+ default=0,
+ help='Seed for sampling the calibration data.')
+ parser.add_argument(
+ '--nsamples',
+ type=int,
+ default=128,
+ help='Number of calibration data samples.')
+ parser.add_argument(
+ '--batch_size',
+ type=int,
+ default=64,
+ help='Batchsize for calibration and evaluation.')
+
+ parser.add_argument(
+ '--save', type=str, default='', help='Path to saved model.')
+ parser.add_argument(
+ '--world_size', type=int, default=1, help='Number of GPUs to use.')
+ args = parser.parse_args()
+
+ WORLD_SIZE = args.world_size
+ mp.spawn(main, args=(WORLD_SIZE, args), nprocs=WORLD_SIZE, join=True)
diff --git a/projects/mmrazor_large/examples/language_models/LLaMA/utils.py b/projects/mmrazor_large/examples/language_models/LLaMA/utils.py
new file mode 100644
index 000000000..1f8ceb87d
--- /dev/null
+++ b/projects/mmrazor_large/examples/language_models/LLaMA/utils.py
@@ -0,0 +1,173 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# Example for opt is converted from https://github.com/ist-daslab/sparsegpt
+import torch
+import torch.nn as nn
+from torch import distributed as dist
+from torch.utils.data import DataLoader
+from transformers import OPTForCausalLM
+
+from mmrazor.utils import print_log
+
+
+def fold_tokens(tokens: torch.Tensor, batch_seq_len=2048):
+ # tokens: 1 N
+ N = tokens.shape[1]
+ num_drop = N % batch_seq_len
+ if num_drop != 0:
+ tokens = tokens[:, :-num_drop]
+ tokens = tokens.reshape([-1, batch_seq_len]) # B N
+ return tokens
+
+
+@torch.no_grad()
+def opt_eval(model: OPTForCausalLM,
+ testenc,
+ dev=torch.device('cuda:0'),
+ batch_size=16):
+ print_log('Evaluating ...')
+
+ seqlen = model.seqlen
+
+ testenc: torch.Tensor = testenc.input_ids # type: ignore # 1, N
+ testenc = fold_tokens(testenc, seqlen) # B N
+
+ use_cache = model.config.use_cache
+ model.config.use_cache = False
+ nlls = []
+
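+    # accumulate token-scaled NLL per batch; perplexity = exp(sum / tokens)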
+ for i, batch in enumerate(torch.split(testenc, batch_size)):
+ B = batch.shape[0]
+
+ batch = batch.to(dev)
+ out: torch.Tensor = model(batch)[0] # 1
+
+ shift_logits = out[:, :-1, :].contiguous().flatten(0, 1) # (B N) C
+ shift_labels = batch[:, 1:].flatten() # (B N)
+
+ loss_fct = nn.CrossEntropyLoss()
+ loss = loss_fct(shift_logits, shift_labels)
+ neg_log_likelihood = loss.float() * seqlen * B
+ nlls.append(neg_log_likelihood)
+
+ print_log(f'{(i+1)*batch_size} / {len(testenc)}')
+
+ ppl = torch.exp(torch.stack(nlls).sum() / (testenc.numel()))
+    print_log(f'Perplexity: {ppl.item():.3f}')
+ model.config.use_cache = use_cache
+
+
+@torch.no_grad()
+def opt_infer(
+ model: OPTForCausalLM,
+ testenc,
+ dev,
+ batch_size=16,
+ num_samples=128,
+):
+ print_log('Infer ...')
+
+ seqlen = model.seqlen
+
+ testenc: torch.Tensor = testenc.input_ids # type: ignore # 1, N
+ testenc = fold_tokens(testenc, seqlen) # B N
+
+ use_cache = model.config.use_cache
+ model.config.use_cache = False
+
+ for i, batch in enumerate(torch.split(testenc, batch_size)):
+ batch = batch.to(dev)
+ _ = model(batch)[0] # 1
+ print_log(f'{(i+1)*batch_size} / {num_samples}')
+
+ if (i + 1) * batch_size >= num_samples:
+ break
+ model.config.use_cache = use_cache
+
+
+class init_on_meta:
+
+ def __init__(self, enable=True) -> None:
+ self.enable = enable
+ self.default_device = torch.ones([]).device
+
+ def __enter__(self):
+ if self.enable:
+ torch.set_default_device('meta')
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ if self.enable:
+ torch.set_default_device(self.default_device)
+
+
+@torch.no_grad()
+def opt_eval_fsdp(
+ model: nn.Module,
+ dataloader: DataLoader,
+ dev=torch.device('cuda:0'),
+):
+ print_log('Evaluating ...')
+
+ use_cache = model.config.use_cache
+ model.config.use_cache = False
+ loss_sum = torch.zeros([1], device=dev)
+ total_seq_len = torch.zeros([1], device=dev, dtype=torch.long)
+
+ for i, batch in enumerate(dataloader):
+ B, seq_len = batch.shape[:2]
+
+ batch = batch.to(dev)
+ out: torch.Tensor = model(batch)[0] # 1
+
+ shift_logits = out[:, :-1, :].contiguous().flatten(0, 1) # (B N) C
+ shift_labels = batch[:, 1:].flatten() # (B N)
+
+ loss_fct = nn.CrossEntropyLoss()
+ loss = loss_fct(shift_logits, shift_labels)
+
+ neg_log_likelihood = loss.float() * seq_len * B
+ total_seq_len += seq_len * B
+ loss_sum += neg_log_likelihood
+
+ if dist.is_initialized():
+ world_size = dist.get_world_size()
+ else:
+ world_size = 1
+ infered_batch = (i + 1) * B * world_size
+
+ print_log(f'{infered_batch} / {len(dataloader.dataset)}')
+
+ if dist.is_initialized():
+ dist.all_reduce(loss_sum)
+ dist.all_reduce(total_seq_len)
+
+ ppl = torch.exp(loss_sum / total_seq_len)
+    print_log(f'Perplexity: {ppl.item():.3f}')
+ model.config.use_cache = use_cache
+
+
+@torch.no_grad()
+def opt_infer_fsdp(
+ model: nn.Module,
+ dataloader: DataLoader,
+ dev=torch.device('cuda:0'),
+ num_samples=128,
+):
+    print_log('Inferring ...')
+
+ model.config.use_cache = False
+
+ for i, batch in enumerate(dataloader):
+ B = batch.shape[0]
+
+ batch = batch.to(dev)
+ model(batch)[0] # 1
+
+ if dist.is_initialized():
+ world_size = dist.get_world_size()
+ else:
+ world_size = 1
+ infered_batch = (i + 1) * B * world_size
+
+ print_log(f'{infered_batch} / {len(dataloader.dataset)}')
+ if infered_batch >= num_samples:
+ break
diff --git a/projects/mmrazor_large/examples/language_models/OPT/README.md b/projects/mmrazor_large/examples/language_models/OPT/README.md
new file mode 100644
index 000000000..a5d1c8030
--- /dev/null
+++ b/projects/mmrazor_large/examples/language_models/OPT/README.md
@@ -0,0 +1,55 @@
+# Examples for OPT
+
+## SparseGPT
+
+For more details about SparseGPT, please refer to [SparseGPT](../../../algorithms/SparseGPT.md)
+
+### Usage
+
+```shell
+# example for facebook/opt-125m
+python projects/mmrazor_large/examples/language_models/OPT/opt_sparse_gpt.py facebook/opt-125m c4
+
+# help
+usage: opt_sparse_gpt.py [-h] [--seed SEED] [--nsamples NSAMPLES] [--batch_size BATCH_SIZE] [--save SAVE] [-m M] model {wikitext2,ptb,c4}
+
+positional arguments:
+ model OPT model to load; pass `facebook/opt-X`.
+ {wikitext2,ptb,c4} Where to extract calibration data from.
+
+optional arguments:
+ -h, --help show this help message and exit
+ --seed SEED Seed for sampling the calibration data.
+ --nsamples NSAMPLES Number of calibration data samples.
+ --batch_size BATCH_SIZE
+ Batchsize for calibration and evaluation.
+ --save SAVE Path to saved model.
+ -m M Whether to enable memory efficient forward
+```
+
+## GPTQ
+
+For more details about GPTQ, please refer to [GPTQ](../../../algorithms/GPTQ.md)
+
+### Usage
+
+```shell
+# example for facebook/opt-125m
+python projects/mmrazor_large/examples/language_models/OPT/opt_gptq.py facebook/opt-125m c4
+
+# help
+usage: opt_gptq.py [-h] [--seed SEED] [--nsamples NSAMPLES] [--batch_size BATCH_SIZE] [--save SAVE] [-m M] model {wikitext2,ptb,c4}
+
+positional arguments:
+ model OPT model to load; pass `facebook/opt-X`.
+ {wikitext2,ptb,c4} Where to extract calibration data from.
+
+optional arguments:
+ -h, --help show this help message and exit
+ --seed SEED Seed for sampling the calibration data.
+ --nsamples NSAMPLES Number of calibration data samples.
+ --batch_size BATCH_SIZE
+ Batchsize for calibration and evaluation.
+ --save SAVE Path to saved model.
+ -m M Whether to enable memory efficient forward
+```
diff --git a/projects/mmrazor_large/examples/language_models/OPT/datautils.py b/projects/mmrazor_large/examples/language_models/OPT/datautils.py
new file mode 100755
index 000000000..04697d560
--- /dev/null
+++ b/projects/mmrazor_large/examples/language_models/OPT/datautils.py
@@ -0,0 +1,152 @@
+import numpy as np
+import torch
+from torch.utils.data import DataLoader
+from torch.utils.data import Dataset as TorchDataset
+from torch.utils.data import DistributedSampler
+
+
+def set_seed(seed):
+ np.random.seed(seed)
+ torch.random.manual_seed(seed)
+
+
+def get_wikitext2(nsamples, seed, seqlen, model):
+ from datasets import load_dataset
+ traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
+ testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
+
+ from transformers import AutoTokenizer
+ tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False)
+ trainenc = tokenizer(' '.join(traindata['text']), return_tensors='pt')
+ testenc = tokenizer('\n\n'.join(testdata['text']), return_tensors='pt')
+
+ import random
+ random.seed(seed)
+ trainloader = []
+ for _ in range(nsamples):
+ i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
+ j = i + seqlen
+ inp = trainenc.input_ids[:, i:j]
+ tar = inp.clone()
+ tar[:, :-1] = -100
+ trainloader.append((inp, tar))
+ return trainloader, testenc
+
+
+def get_ptb(nsamples, seed, seqlen, model):
+ from datasets import load_dataset
+ traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train')
+ testdata = load_dataset('ptb_text_only', 'penn_treebank', split='test')
+
+ from transformers import AutoTokenizer
+ tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False)
+ trainenc = tokenizer(' '.join(traindata['sentence']), return_tensors='pt')
+ testenc = tokenizer(' '.join(testdata['sentence']), return_tensors='pt')
+
+ import random
+ random.seed(seed)
+ trainloader = []
+ for _ in range(nsamples):
+ i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
+ j = i + seqlen
+ inp = trainenc.input_ids[:, i:j]
+ tar = inp.clone()
+ tar[:, :-1] = -100
+ trainloader.append((inp, tar))
+ return trainloader, testenc
+
+
+def get_c4(nsamples, seed, seqlen, model):
+ from datasets import load_dataset
+ traindata = load_dataset(
+ 'allenai/c4',
+ 'allenai--c4',
+ data_files={'train': 'en/c4-train.00000-of-01024.json.gz'},
+ split='train')
+ valdata = load_dataset(
+ 'allenai/c4',
+ 'allenai--c4',
+ data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'},
+ split='validation')
+
+ from transformers import AutoTokenizer
+ tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False)
+
+ import random
+ random.seed(seed)
+ trainloader = []
+ for _ in range(nsamples):
+ while True:
+ i = random.randint(0, len(traindata) - 1)
+ trainenc = tokenizer(traindata[i]['text'], return_tensors='pt')
+ if trainenc.input_ids.shape[1] >= seqlen:
+ break
+ i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
+ j = i + seqlen
+ inp = trainenc.input_ids[:, i:j]
+ tar = inp.clone()
+ tar[:, :-1] = -100
+ trainloader.append((inp, tar))
+
+ valenc = tokenizer(' '.join(valdata[:1100]['text']), return_tensors='pt')
+ valenc = valenc.input_ids[:, :(256 * seqlen)]
+
+ class TokenizerWrapper:
+
+ def __init__(self, input_ids):
+ self.input_ids = input_ids
+
+ valenc = TokenizerWrapper(valenc)
+
+ return trainloader, valenc
+
+
+def get_loaders(name, nsamples=128, seed=0, seqlen=2048, model=''):
+ if 'wikitext2' in name:
+ return get_wikitext2(nsamples, seed, seqlen, model)
+ if 'ptb' in name:
+ return get_ptb(nsamples, seed, seqlen, model)
+ if 'c4' in name:
+ return get_c4(nsamples, seed, seqlen, model)
+
+
+def fold_tokens(tokens: torch.Tensor, batch_seq_len=2048):
+ # tokens: 1 N
+ N = tokens.shape[1]
+ num_drop = N % batch_seq_len
+ if num_drop != 0:
+ tokens = tokens[:, :-num_drop]
+ tokens = tokens.reshape([-1, batch_seq_len]) # B N
+ return tokens
+
+
+class LanguageDataset(TorchDataset):
+
+ def __init__(self, seq: torch.Tensor, seq_len: int = 2048) -> None:
+ super().__init__()
+ # seq: 1, N
+ self.seq_len = seq_len
+
+        self.seq = fold_tokens(seq, seq_len)  # B N
+
+ def __len__(self) -> int:
+ return self.seq.shape[0]
+
+ def __getitem__(self, index):
+ return self.seq[index]
+
+
+def build_language_loader(testloader, world_size, rank, model, batch_size=128):
+ val_dataset = LanguageDataset(testloader.input_ids, seq_len=model.seqlen)
+ distributed_sampler = DistributedSampler(
+ val_dataset, num_replicas=world_size, rank=rank, shuffle=False)
+ batch_size = min(len(val_dataset) // world_size, batch_size)
+ val_dataloader = DataLoader(
+ val_dataset,
+ batch_size=batch_size,
+ shuffle=False,
+ num_workers=0,
+ pin_memory=True,
+ drop_last=True,
+ sampler=distributed_sampler)
+ return val_dataloader
diff --git a/projects/mmrazor_large/examples/language_models/OPT/opt_gptq.py b/projects/mmrazor_large/examples/language_models/OPT/opt_gptq.py
new file mode 100644
index 000000000..5cd48e563
--- /dev/null
+++ b/projects/mmrazor_large/examples/language_models/OPT/opt_gptq.py
@@ -0,0 +1,157 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# Example for opt is converted from https://github.com/ist-daslab/sparsegpt
+import torch
+from datautils import get_loaders
+from transformers import OPTForCausalLM
+from transformers.models.opt.modeling_opt import OPTDecoderLayer
+from utils import opt_eval, opt_infer
+
+from mmrazor.implementations.pruning.sparse_gpt.utils import \
+ memory_efficient_forward
+from mmrazor.implementations.quantization.gptq import (GPTQLinear,
+ TritonGPTQLinear)
+from mmrazor.utils import print_log
+
+
+def enable_observer_linear(model):
+ print_log('Enable updating qparams for GPTQLinear!')
+ for _, module in model.named_modules():
+ if isinstance(module, GPTQLinear):
+ module.fix_qparams = False
+
+
+def disable_observer_linear(model):
+ print_log('Disable updating qparams for GPTQLinear!')
+ for _, module in model.named_modules():
+ if isinstance(module, GPTQLinear):
+ module.fix_qparams = True
+
+
+def del_redundant_attr(model):
+ print_log('Del redundant weight for GPTQLinear!')
+ for _, module in model.named_modules():
+ if isinstance(module, TritonGPTQLinear):
+ del module.weight
+
+
+def get_model(model):
+
+ def skip(*args, **kwargs):
+ pass
+
+ torch.nn.init.kaiming_uniform_ = skip
+ torch.nn.init.uniform_ = skip
+ torch.nn.init.normal_ = skip
+ model = OPTForCausalLM.from_pretrained(model, torch_dtype='auto')
+ model.seqlen = model.config.max_position_embeddings
+ return model
+
+
+if __name__ == '__main__':
+
+ import argparse
+ parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        'model', type=str, help='OPT model to load; pass `facebook/opt-X`.')
+ parser.add_argument(
+        'dataset',
+ type=str,
+ choices=['wikitext2', 'ptb', 'c4'],
+ help='Where to extract calibration data from.')
+ parser.add_argument(
+ '--seed',
+ type=int,
+ default=0,
+ help='Seed for sampling the calibration data.')
+ parser.add_argument(
+ '--nsamples',
+ type=int,
+ default=128,
+ help='Number of calibration data samples.')
+ parser.add_argument(
+ '--batch_size',
+ type=int,
+ default=16,
+ help='Batchsize for calibration and evaluation.')
+ parser.add_argument(
+ '--save', type=str, default='', help='Path to saved model.')
+ parser.add_argument(
+ '--quant_ckpt', type=str, default='', help='Quantized ckpt to load.')
+ parser.add_argument(
+ '--dev', type=str, default='cuda:0', help='Use which device.')
+ parser.add_argument(
+        '-m',
+        action='store_true',
+        help='Whether to enable memory efficient forward')
+
+ args = parser.parse_args()
+
+ DEV = args.dev
+
+ model = get_model(args.model)
+ model.to(DEV)
+ model.eval()
+ print_log('load model over')
+
+ from mmrazor.implementations.quantization import gptq
+ compressor = gptq.GPTQCompressor()
+    # Replace Linear/Conv modules in the decoder layers with triton-based
+    # GPTQ ops (use_triton_ops=True), quantizing weights to 4 bit with a
+    # group size of 128.
+    compressor.prepare(
+        model.model.decoder.layers,
+        quant_conv=True,
+        use_triton_ops=True,
+        quant_linear=True,
+        bits=4,
+        groupsize=128)
+
+    # Alternative: prepare with pure-python GPTQ ops (use_triton_ops=False);
+    # uncomment a_qconfig to also quantize activations for linear layers.
+    # a_qconfig = dict(bits=4, perchannel=False, sym=False)
+    # compressor.prepare(
+    #     model.model.decoder,
+    #     quant_conv=True,
+    #     quant_linear=True,
+    #     use_triton_ops=False,
+    #     # a_qconfig=a_qconfig
+    # )
+
+ if args.quant_ckpt:
+ del_redundant_attr(model)
+ model.load_state_dict(torch.load(args.quant_ckpt))
+ else:
+ dataloader, testloader = get_loaders(
+ args.dataset,
+ seed=args.seed,
+ model=args.model,
+ seqlen=model.seqlen)
+ print_log('load data for infer over')
+
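+        # Calibration: the hooks accumulate approximate Hessian statistics
+        # of each layer's inputs while opt_infer streams `nsamples`
+        # sequences through the model; quant_with_default_qconfig then
+        # quantizes every prepared layer using those statistics.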
+ compressor.init_hessian()
+ enable_observer_linear(model)
+ with memory_efficient_forward(
+ model, wrap_modules=[OPTDecoderLayer], enabled=args.m,
+ device=DEV):
+ compressor.register_hessian_hooks()
+ opt_infer(
+ model,
+ testloader,
+ DEV,
+ batch_size=args.batch_size,
+ num_samples=args.nsamples)
+ compressor.remove_hessian_hooks()
+ compressor.quant_with_default_qconfig(device=DEV)
+
+ disable_observer_linear(model)
+ with memory_efficient_forward(
+ model, wrap_modules=[OPTDecoderLayer], enabled=args.m, device=DEV):
+
+ # for dataset in ['wikitext2', 'ptb', 'c4']:
+ for dataset in ['wikitext2']:
+ dataloader, testloader = get_loaders(
+ dataset, seed=args.seed, model=args.model, seqlen=model.seqlen)
+ print_log(dataset)
+ opt_eval(model, testloader, DEV, batch_size=args.batch_size)
+
+ if args.save and not args.quant_ckpt:
+ print_log(f'save model in {args.save}')
+ torch.save(model.state_dict(), args.save)
diff --git a/projects/mmrazor_large/examples/language_models/OPT/opt_sparse_gpt.py b/projects/mmrazor_large/examples/language_models/OPT/opt_sparse_gpt.py
new file mode 100644
index 000000000..29d0947d3
--- /dev/null
+++ b/projects/mmrazor_large/examples/language_models/OPT/opt_sparse_gpt.py
@@ -0,0 +1,105 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# Example for opt is converted from https://github.com/ist-daslab/sparsegpt
+import torch
+from datautils import get_loaders
+from transformers import OPTForCausalLM
+from transformers.models.opt.modeling_opt import OPTDecoderLayer
+from utils import opt_eval, opt_infer
+
+from mmrazor.implementations.pruning.sparse_gpt.utils import \
+ memory_efficient_forward
+from mmrazor.utils import print_log
+
+
+def get_model(model):
+
+ def skip(*args, **kwargs):
+ pass
+
+ torch.nn.init.kaiming_uniform_ = skip
+ torch.nn.init.uniform_ = skip
+ torch.nn.init.normal_ = skip
+ model = OPTForCausalLM.from_pretrained(model, torch_dtype='auto')
+ model.seqlen = model.config.max_position_embeddings
+ return model
+
+
+if __name__ == '__main__':
+
+ import argparse
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ 'model', type=str, help='OPT model to load; pass `facebook/opt-X`.')
+ parser.add_argument(
+ 'dataset',
+ type=str,
+ choices=['wikitext2', 'ptb', 'c4'],
+ help='Where to extract calibration data from.')
+ parser.add_argument(
+ '--seed',
+ type=int,
+ default=0,
+ help='Seed for sampling the calibration data.')
+ parser.add_argument(
+ '--nsamples',
+ type=int,
+ default=128,
+ help='Number of calibration data samples.')
+ parser.add_argument(
+ '--batch_size',
+ type=int,
+ default=64,
+ help='Batchsize for calibration and evaluation.')
+ parser.add_argument(
+ '--save', type=str, default='', help='Path to saved model.')
+ parser.add_argument(
+        '-m',
+        action='store_true',
+        help='Whether to enable memory efficient forward')
+
+ args = parser.parse_args()
+
+ DEV = torch.device('cuda:0')
+
+ model = get_model(args.model)
+ model.eval()
+ print_log('load model over')
+
+ dataloader, testloader = get_loaders(
+ args.dataset, seed=args.seed, model=args.model, seqlen=model.seqlen)
+ print_log('load data for infer over')
+
+ from mmrazor.implementations.pruning import sparse_gpt
+ compressor = sparse_gpt.SparseGptCompressor()
+ compressor.prepare(model.model.decoder)
+
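+    # SparseGPT calibration and pruning: the hooks accumulate per-layer
+    # input (Hessian) statistics during opt_infer, after which prune_24()
+    # applies 2:4 semi-structured sparsity (two of every four weights are
+    # zeroed) to each prepared layer.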
+ compressor.init_hessian()
+ with memory_efficient_forward(
+ model, wrap_modules=[OPTDecoderLayer], enabled=args.m):
+
+ compressor.register_hessian_hooks()
+ opt_infer(
+ model,
+ testloader,
+ DEV,
+ batch_size=args.batch_size,
+ num_samples=args.nsamples)
+ compressor.remove_hessian_hooks()
+ compressor.prune_24()
+
+ model = compressor.to_static_model(model)
+ if args.save:
+ print_log(f'save model in {args.save}')
+ model.save_pretrained(args.save)
+
+ with memory_efficient_forward(
+ model, wrap_modules=[OPTDecoderLayer], enabled=args.m):
+
+ for dataset in ['wikitext2', 'ptb', 'c4']:
+ dataloader, testloader = get_loaders(
+ dataset, seed=args.seed, model=args.model, seqlen=model.seqlen)
+ print_log(dataset)
+ opt_eval(model, testloader, DEV, batch_size=args.batch_size)
diff --git a/projects/mmrazor_large/examples/language_models/OPT/opt_sparse_gpt_fsdp.py b/projects/mmrazor_large/examples/language_models/OPT/opt_sparse_gpt_fsdp.py
new file mode 100644
index 000000000..e357be01a
--- /dev/null
+++ b/projects/mmrazor_large/examples/language_models/OPT/opt_sparse_gpt_fsdp.py
@@ -0,0 +1,198 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import functools
+import os
+
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+import torch.nn as nn
+from datautils import build_language_loader, get_loaders
+from opt_sparse_gpt import get_model
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp.api import ShardingStrategy
+from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
+from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
+from utils import init_on_meta, opt_eval_fsdp, opt_infer_fsdp
+
+from mmrazor.implementations.pruning import sparse_gpt
+from mmrazor.utils import print_log
+
+
+def setup(rank, world_size):
+ os.environ['MASTER_ADDR'] = 'localhost'
+ os.environ['MASTER_PORT'] = '12356'
+
+ dist.init_process_group('nccl', rank=rank, world_size=world_size)
+ torch.cuda.set_device(rank)
+ print_log(f'init {rank}/{world_size}', only_rank0=False)
+
+
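+# The model is built on the meta device (see init_on_meta in utils.py), so
+# no rank materializes the full weights up front. Rank 0 additionally builds
+# a real copy on CPU; the param_init_fn returned below copies that copy's
+# parameters and buffers into each FSDP-wrapped submodule as it is
+# materialized, and sync_module_states=True broadcasts them to other ranks.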
+def init_fn_wrapper(model: nn.Module, model_copy: nn.Module):
+
+ def find_module_in_model_copy(module: nn.Module):
+ name2module = dict(model.named_modules())
+ module2name = dict([(v, k) for k, v in name2module.items()])
+
+ name = module2name[module]
+ return dict(model_copy.named_modules())[name]
+
+    def _materialize_meta_module(module: nn.Module):
+
+ def meta_to_empty(p: torch.Tensor):
+ if p.device == torch.device('meta'):
+ return p.new_empty(p.shape, device='cpu')
+ else:
+ return p
+
+ module._apply(meta_to_empty)
+ if dist.get_rank() == 0:
+ assert model_copy is not None
+ module_copy = find_module_in_model_copy(module)
+
+ name2p = dict(module_copy.named_parameters(remove_duplicate=False))
+ for n, p in module.named_parameters():
+ if '_flat_param' not in n:
+ n = n.replace('_fsdp_wrapped_module.', '')
+ try:
+ p.data.copy_(name2p[n])
+ except Exception:
+ pass
+ name2p = dict(module_copy.named_buffers(remove_duplicate=False))
+ for n, p in module.named_buffers():
+ if '_flat_param' not in n:
+ n = n.replace('_fsdp_wrapped_module.', '')
+ try:
+ p.data.copy_(name2p[n])
+ except Exception:
+ pass
+
+ return _materialize_meta_module
+
+
+def main(rank, world_size=8, args=None):
+ setup(rank, world_size)
+
+ model_name = args.model
+ batch_size = args.batch_size
+
+ def build():
+ model = get_model(model_name)
+
+ # init mutator
+ mutator = sparse_gpt.SparseGptCompressor()
+ mutator.prepare(model.model.decoder)
+ return model, mutator
+
+ with init_on_meta(enable=True):
+ model, mutator = build()
+
+ if rank == 0:
+ model_copy, _ = build() # init on cpu
+ else:
+ model_copy = None
+
+ # init fsdp
+ size_based_auto_wrap_policy_x = functools.partial(
+ size_based_auto_wrap_policy, min_num_params=int(1e8))
+
+ model = FSDP(
+ model,
+ auto_wrap_policy=size_based_auto_wrap_policy_x,
+ cpu_offload=CPUOffload(True),
+ sharding_strategy=ShardingStrategy.FULL_SHARD,
+ device_id=rank,
+ param_init_fn=init_fn_wrapper(model, model_copy),
+ sync_module_states=True)
+ print_log(model)
+
+ # init hessian
+
+ mutator.init_hessian(device='cuda')
+    mutator.register_hessian_hooks()
+
+ _, testloader = get_loaders(
+ args.dataset, seed=args.seed, model=model_name, seqlen=model.seqlen)
+ testloader = build_language_loader(
+ testloader, world_size, rank, model, batch_size=batch_size)
+ opt_infer_fsdp(model, testloader)
+
+ mutator.remove_hessian_hooks()
+
+ # prune
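+    # Pruning is done shard by shard: summon_full_params gathers the full
+    # parameters of each FSDP-wrapped module so every SparseGptMixIn op can
+    # be pruned in place with a 2:4 pattern (sparsity=0.5, prunen=2,
+    # prunem=4); _reset_lazy_init is called so that FSDP re-initializes its
+    # internal flat-parameter state around the modified weights.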
+ name2module = dict(model.named_modules())
+    module2name = dict([(v, k) for k, v in name2module.items()])
+
+ with torch.no_grad():
+ for fsdp in FSDP.fsdp_modules(model):
+ fsdp._reset_lazy_init()
+ with FSDP.summon_full_params(fsdp, recurse=False):
+ fsdp_name = module2name[fsdp]
+ for name, op in fsdp.named_modules():
+ if name.count('_fsdp_wrapped_module') <= 1:
+ if isinstance(op, sparse_gpt.SparseGptMixIn):
+ try:
+ op.prune(0.5, prunen=2, prunem=4)
+ print_log(
+ f'prune {fsdp_name}.{name} successfully.', # noqa
+ only_rank0=True)
+ except Exception as e:
+ print_log(
+ f'prune {fsdp_name}.{name} failed, as {e}', # noqa
+ only_rank0=True)
+ fsdp._reset_lazy_init()
+
+ # save
+ if args.save:
+ print_log(f'save model in {args.save}')
+ model._reset_lazy_init()
+ with FSDP.summon_full_params(model, rank0_only=True, writeback=False):
+ if dist.get_rank() == 0:
+ model.save_pretrained(args.save)
+
+ # val
+ torch.cuda.empty_cache()
+ model._reset_lazy_init()
+ for dataset in ['wikitext2', 'ptb', 'c4']:
+ _, testloader = get_loaders(
+ dataset, seed=args.seed, model=model_name, seqlen=model.seqlen)
+ testloader = build_language_loader(
+ testloader, world_size, rank, model, batch_size=batch_size)
+ print_log(dataset)
+ opt_eval_fsdp(model, testloader, torch.device('cuda'))
+
+
+if __name__ == '__main__':
+ import argparse
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ 'model', type=str, help='OPT model to load; pass `facebook/opt-X`.')
+ parser.add_argument(
+ 'dataset',
+ type=str,
+ choices=['wikitext2', 'ptb', 'c4'],
+ help='Where to extract calibration data from.')
+ parser.add_argument(
+ '--seed',
+ type=int,
+ default=0,
+ help='Seed for sampling the calibration data.')
+ parser.add_argument(
+ '--nsamples',
+ type=int,
+ default=128,
+ help='Number of calibration data samples.')
+ parser.add_argument(
+ '--batch_size',
+ type=int,
+ default=64,
+ help='Batchsize for calibration and evaluation.')
+
+ parser.add_argument(
+ '--save', type=str, default='', help='Path to saved model.')
+ parser.add_argument(
+ '--world_size', type=int, default=1, help='Number of GPUs to use.')
+ args = parser.parse_args()
+
+ WORLD_SIZE = args.world_size
+ mp.spawn(main, args=(WORLD_SIZE, args), nprocs=WORLD_SIZE, join=True)
diff --git a/projects/mmrazor_large/examples/language_models/OPT/utils.py b/projects/mmrazor_large/examples/language_models/OPT/utils.py
new file mode 100644
index 000000000..a728a2268
--- /dev/null
+++ b/projects/mmrazor_large/examples/language_models/OPT/utils.py
@@ -0,0 +1,171 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# Example for opt is converted from https://github.com/ist-daslab/sparsegpt
+import torch
+import torch.nn as nn
+from torch import distributed as dist
+from torch.utils.data import DataLoader
+from transformers import OPTForCausalLM
+
+from mmrazor.utils import print_log
+
+
+def fold_tokens(tokens: torch.Tensor, batch_seq_len=2048):
+ # tokens: 1 N
+ N = tokens.shape[1]
+ num_drop = N % batch_seq_len
+ if num_drop != 0:
+ tokens = tokens[:, :-num_drop]
+ tokens = tokens.reshape([-1, batch_seq_len]) # B N
+ return tokens
+
+
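+# opt_eval computes perplexity in the usual way: logits are shifted by one
+# position against the labels, the mean cross-entropy of each batch is
+# rescaled by its token count, the scaled losses are summed over the test
+# set, and the sum is exponentiated after dividing by the total token count.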
+@torch.no_grad()
+def opt_eval(model: OPTForCausalLM,
+ testenc,
+ dev=torch.device('cuda:0'),
+ batch_size=16):
+ print_log('Evaluating ...')
+
+ seqlen = model.seqlen
+
+ testenc: torch.Tensor = testenc.input_ids # type: ignore # 1, N
+ testenc = fold_tokens(testenc, seqlen) # B N
+
+ use_cache = model.config.use_cache
+ model.config.use_cache = False
+ nlls = []
+
+ for i, batch in enumerate(torch.split(testenc, batch_size)):
+ B = batch.shape[0]
+
+ batch = batch.to(dev)
+ out: torch.Tensor = model(batch)[0] # 1
+
+ shift_logits = out[:, :-1, :].contiguous().flatten(0, 1) # (B N) C
+ shift_labels = batch[:, 1:].flatten() # (B N)
+
+ loss_fct = nn.CrossEntropyLoss()
+ loss = loss_fct(shift_logits, shift_labels)
+ neg_log_likelihood = loss.float() * seqlen * B
+ nlls.append(neg_log_likelihood)
+
+ print_log(f'{(i+1)*batch_size} / {len(testenc)}')
+
+ ppl = torch.exp(torch.stack(nlls).sum() / (testenc.numel()))
+    print_log(f'Perplexity: {ppl.item():.3f}')
+ model.config.use_cache = use_cache
+
+
+@torch.no_grad()
+def opt_infer(
+ model: OPTForCausalLM,
+ testenc,
+ dev,
+ batch_size=16,
+ num_samples=128,
+):
+ print_log('Infer ...')
+
+ seqlen = model.seqlen
+
+ testenc: torch.Tensor = testenc.input_ids # type: ignore # 1, N
+ testenc = fold_tokens(testenc, seqlen) # B N
+
+ model.config.use_cache = False
+
+ for i, batch in enumerate(torch.split(testenc, batch_size)):
+ batch = batch.to(dev)
+ _ = model(batch)[0] # 1
+ print_log(f'{(i+1)*batch_size} / {num_samples}')
+
+ if (i + 1) * batch_size >= num_samples:
+ break
+
+
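+# Context manager that temporarily makes 'meta' the default device (relies
+# on torch.set_default_device, available in torch 2.0+), so modules created
+# inside the block allocate no real parameter storage.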
+class init_on_meta:
+
+ def __init__(self, enable=True) -> None:
+ self.enable = enable
+ self.default_device = torch.ones([]).device
+
+ def __enter__(self):
+ if self.enable:
+ torch.set_default_device('meta')
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ if self.enable:
+ torch.set_default_device(self.default_device)
+
+
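+# Distributed variant of opt_eval: every rank accumulates its own loss sum
+# and token count, both tensors are all-reduced across ranks, and perplexity
+# is computed from the global totals so every rank reports the same value.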
+@torch.no_grad()
+def opt_eval_fsdp(
+ model: nn.Module,
+ dataloader: DataLoader,
+ dev=torch.device('cuda:0'),
+):
+ print_log('Evaluating ...')
+
+ use_cache = model.config.use_cache
+ model.config.use_cache = False
+ loss_sum = torch.zeros([1], device=dev)
+ total_seq_len = torch.zeros([1], device=dev, dtype=torch.long)
+
+ for i, batch in enumerate(dataloader):
+ B, seq_len = batch.shape[:2]
+
+ batch = batch.to(dev)
+ out: torch.Tensor = model(batch)[0] # 1
+
+ shift_logits = out[:, :-1, :].contiguous().flatten(0, 1) # (B N) C
+ shift_labels = batch[:, 1:].flatten() # (B N)
+
+ loss_fct = nn.CrossEntropyLoss()
+ loss = loss_fct(shift_logits, shift_labels)
+
+ neg_log_likelihood = loss.float() * seq_len * B
+ total_seq_len += seq_len * B
+ loss_sum += neg_log_likelihood
+
+ if dist.is_initialized():
+ world_size = dist.get_world_size()
+ else:
+ world_size = 1
+ infered_batch = (i + 1) * B * world_size
+
+ print_log(f'{infered_batch} / {len(dataloader.dataset)}')
+
+ if dist.is_initialized():
+ dist.all_reduce(loss_sum)
+ dist.all_reduce(total_seq_len)
+
+ ppl = torch.exp(loss_sum / total_seq_len)
+    print_log(f'Perplexity: {ppl.item():.3f}')
+ model.config.use_cache = use_cache
+
+
+@torch.no_grad()
+def opt_infer_fsdp(
+ model: nn.Module,
+ dataloader: DataLoader,
+ dev=torch.device('cuda:0'),
+ num_samples=128,
+):
+    print_log('Inferring ...')
+
+ model.config.use_cache = False
+
+ for i, batch in enumerate(dataloader):
+ B = batch.shape[0]
+
+ batch = batch.to(dev)
+ model(batch)[0] # 1
+
+ if dist.is_initialized():
+ world_size = dist.get_world_size()
+ else:
+ world_size = 1
+ infered_batch = (i + 1) * B * world_size
+
+ print_log(f'{infered_batch} / {len(dataloader.dataset)}')
+ if infered_batch >= num_samples:
+ break
diff --git a/requirements/tests.txt b/requirements/tests.txt
index 5980dc303..b025f5a67 100644
--- a/requirements/tests.txt
+++ b/requirements/tests.txt
@@ -7,5 +7,6 @@ nbformat
numpy < 1.24.0 # A temporary solution for tests with mmdet.
onnx
pytest
+triton==2.0.0
xdoctest >= 0.10.0
yapf
diff --git a/tests/test_impl/test_pruning/test_sparse_gpt/test_op.py b/tests/test_impl/test_pruning/test_sparse_gpt/test_op.py
new file mode 100644
index 000000000..636d64f67
--- /dev/null
+++ b/tests/test_impl/test_pruning/test_sparse_gpt/test_op.py
@@ -0,0 +1,75 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import unittest
+
+import torch
+import torch.nn as nn
+
+from mmrazor import digit_version
+from mmrazor.implementations.pruning import sparse_gpt
+
+
+class TestSparseGptOps(unittest.TestCase):
+
+ @torch.no_grad()
+ def test_op(self):
+ if digit_version(torch.__version__) < digit_version('1.12.0'):
+ self.skipTest('torch<1.12.0')
+
+ def get_loss(linear, linear1, data):
+ y = linear(data)
+ y1 = linear1(data)
+ return (y - y1).square().sum()
+
+ def infer(model, dataset):
+ for x in dataset:
+ model(x)
+
+ for device in ['cpu']:
+ device = torch.device(device)
+
+ # prepare
+
+ linear = nn.Linear(12, 20, bias=False).to(device)
+ sparse_linear = sparse_gpt.SparseGptLinear(
+ 12, 20, bias=False).to(device)
+ sparse_linear.load_state_dict(linear.state_dict(), strict=False)
+
+ random_data = torch.rand([10, 5, 12]).to(
+ device) # [loader_batch,batch,feature]
+ data_0 = random_data[0]
+
+ self.assertTrue(get_loss(linear, sparse_linear, data_0) == 0)
+
+ # prune
+
+ sparse_linear.init_hessian()
+ sparse_linear.register_hessian_hook()
+ infer(sparse_linear, random_data)
+ sparse_linear.remove_hessian_hook()
+
+ sparse_linear.prune(0.5)
+
+ # compare
+
+ print('norm:', linear(data_0).norm(2))
+ print('distance:', get_loss(linear, sparse_linear, data_0))
+
+ @torch.no_grad()
+ def test_model(self):
+ if digit_version(torch.__version__) < digit_version('1.12.0'):
+ self.skipTest('torch<1.12.0')
+ import torchvision
+ model = torchvision.models.resnet18()
+
+ mutator = sparse_gpt.SparseGptCompressor()
+ mutator.prepare(model)
+
+ x = torch.rand(10, 3, 224, 224)
+ mutator.init_hessian()
+ mutator.register_hessian_hooks()
+ model(x)
+ mutator.remove_hessian_hooks()
+ mutator.prune_24()
+
+ model = mutator.to_static_model(model)
+ assert type(model.conv1) is nn.Conv2d
diff --git a/tests/test_impl/test_quantization/test_gptq/test_op_gptq.py b/tests/test_impl/test_quantization/test_gptq/test_op_gptq.py
new file mode 100644
index 000000000..4928e0f17
--- /dev/null
+++ b/tests/test_impl/test_quantization/test_gptq/test_op_gptq.py
@@ -0,0 +1,80 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import unittest
+
+import torch
+import torch.nn as nn
+
+from mmrazor import digit_version
+from mmrazor.implementations.quantization import gptq
+
+
+class TestGPTQOps(unittest.TestCase):
+
+ @torch.no_grad()
+ def test_op(self):
+        if digit_version(torch.__version__) < digit_version(
+                '1.12.0') or not torch.cuda.is_available():
+            self.skipTest('requires torch>=1.12.0 and CUDA')
+
+ def get_loss(linear, linear1, data):
+ y = linear(data)
+ y1 = linear1(data)
+ return (y - y1).square().sum()
+
+ def infer(model, dataset):
+ for x in dataset:
+ model(x)
+
+ for device in ['cpu']:
+ device = torch.device(device)
+
+ # prepare
+
+ linear = nn.Linear(12, 20, bias=False).to(device)
+ gptq_linear = gptq.GPTQLinear(
+ in_features=12, out_features=20, bias=False).to(device)
+ gptq_linear.load_state_dict(linear.state_dict(), strict=False)
+
+ random_data = torch.rand([10, 5, 12]).to(
+ device) # [loader_batch,batch,feature]
+ data_0 = random_data[0]
+
+ self.assertTrue(get_loss(linear, gptq_linear, data_0) == 0)
+
+ # quant
+
+ gptq_linear.init_hessian()
+ gptq_linear.register_hessian_hook()
+ infer(gptq_linear, random_data)
+ gptq_linear.remove_hessian_hook()
+
+ qconfig = dict(bits=4, perchannel=True, sym=False)
+ quantizer = gptq.Quantizer()
+ quantizer.configure(**qconfig)
+ gptq_linear.quant(quantizer=quantizer)
+
+ # compare
+
+ print('norm:', linear(data_0).norm(2))
+ print('distance:', get_loss(linear, gptq_linear, data_0))
+
+ @torch.no_grad()
+ def test_model(self):
+        if digit_version(torch.__version__) < digit_version(
+                '1.12.0') or not torch.cuda.is_available():
+            self.skipTest('requires torch>=1.12.0 and CUDA')
+ import torchvision
+ model = torchvision.models.resnet18()
+
+ compressor = gptq.GPTQCompressor()
+ compressor.prepare(model, use_triton_ops=False)
+
+ x = torch.rand(10, 3, 224, 224)
+ compressor.init_hessian()
+ compressor.register_hessian_hooks()
+ model(x)
+ compressor.remove_hessian_hooks()
+ compressor.quant_with_default_qconfig()
+
+ model = compressor.to_static_model(model)
+ assert type(model.conv1) is nn.Conv2d