Auto quantization (#313)
IlyasMoutawwakil authored Jan 31, 2025
1 parent 92cd2b2 commit 4eb7a37
Showing 2 changed files with 14 additions and 50 deletions.
60 changes: 14 additions & 46 deletions optimum_benchmark/backends/pytorch/backend.py
@@ -8,15 +8,12 @@
 from datasets import Dataset
 from safetensors.torch import save_file
 from transformers import (
-    AwqConfig,
-    BitsAndBytesConfig,
-    GPTQConfig,
-    TorchAoConfig,
     Trainer,
     TrainerCallback,
     TrainerState,
     TrainingArguments,
 )
+from transformers.quantizers import AutoQuantizationConfig
 
 from ...import_utils import is_deepspeed_available, is_torch_distributed_available, is_zentorch_available
 from ..base import Backend
@@ -286,8 +283,6 @@ def create_no_weights_model(self) -> None:
 
     def process_quantization_config(self) -> None:
         if self.is_gptq_quantized:
-            self.logger.info("\t+ Processing GPTQ config")
-
             try:
                 import exllamav2_kernels  # noqa: F401
             except ImportError:
@@ -299,12 +294,7 @@ def process_quantization_config(self) -> None:
                     "`optimum-benchmark` repository at `https://github.com/huggingface/optimum-benchmark`."
                 )
 
-            self.quantization_config = GPTQConfig(
-                **dict(getattr(self.pretrained_config, "quantization_config", {}), **self.config.quantization_config)
-            )
         elif self.is_awq_quantized:
-            self.logger.info("\t+ Processing AWQ config")
-
             try:
                 import exlv2_ext  # noqa: F401
             except ImportError:
@@ -316,55 +306,30 @@ def process_quantization_config(self) -> None:
                     "`optimum-benchmark` repository at `https://github.com/huggingface/optimum-benchmark`."
                 )
 
-            self.quantization_config = AwqConfig(
-                **dict(getattr(self.pretrained_config, "quantization_config", {}), **self.config.quantization_config)
-            )
-        elif self.is_bnb_quantized:
-            self.logger.info("\t+ Processing BitsAndBytes config")
-            self.quantization_config = BitsAndBytesConfig(
-                **dict(getattr(self.pretrained_config, "quantization_config", {}), **self.config.quantization_config)
-            )
-        elif self.is_torchao_quantized:
-            self.logger.info("\t+ Processing TorchAO config")
-            self.quantization_config = TorchAoConfig(
-                **dict(getattr(self.pretrained_config, "quantization_config", {}), **self.config.quantization_config)
-            )
-        else:
-            raise ValueError(f"Quantization scheme {self.config.quantization_scheme} not recognized")
+        self.logger.info("\t+ Processing AutoQuantization config")
+        self.quantization_config = AutoQuantizationConfig.from_dict(
+            dict(getattr(self.pretrained_config, "quantization_config", {}), **self.config.quantization_config)
+        )
 
     @property
     def is_quantized(self) -> bool:
         return self.config.quantization_scheme is not None or (
             hasattr(self.pretrained_config, "quantization_config")
-            and self.pretrained_config.quantization_config.get("quant_method", None) is not None
-        )
-
-    @property
-    def is_bnb_quantized(self) -> bool:
-        return self.config.quantization_scheme == "bnb" or (
-            hasattr(self.pretrained_config, "quantization_config")
-            and self.pretrained_config.quantization_config.get("quant_method", None) == "bnb"
+            and self.pretrained_config.quantization_config.get("quant_method") is not None
         )
 
     @property
     def is_gptq_quantized(self) -> bool:
         return self.config.quantization_scheme == "gptq" or (
             hasattr(self.pretrained_config, "quantization_config")
-            and self.pretrained_config.quantization_config.get("quant_method", None) == "gptq"
+            and self.pretrained_config.quantization_config.get("quant_method") == "gptq"
         )
 
     @property
     def is_awq_quantized(self) -> bool:
         return self.config.quantization_scheme == "awq" or (
             hasattr(self.pretrained_config, "quantization_config")
-            and self.pretrained_config.quantization_config.get("quant_method", None) == "awq"
-        )
-
-    @property
-    def is_torchao_quantized(self) -> bool:
-        return self.config.quantization_scheme == "torchao" or (
-            hasattr(self.pretrained_config, "quantization_config")
-            and self.pretrained_config.quantization_config.get("quant_method", None) == "torchao"
+            and self.pretrained_config.quantization_config.get("quant_method") == "awq"
         )
 
     @property
@@ -376,11 +341,11 @@ def is_exllamav2(self) -> bool:
                 (
                     hasattr(self.pretrained_config, "quantization_config")
                     and hasattr(self.pretrained_config.quantization_config, "exllama_config")
-                    and self.pretrained_config.quantization_config.exllama_config.get("version", None) == 2
+                    and self.pretrained_config.quantization_config.exllama_config.get("version") == 2
                 )
                 or (
                     "exllama_config" in self.config.quantization_config
-                    and self.config.quantization_config["exllama_config"].get("version", None) == 2
+                    and self.config.quantization_config["exllama_config"].get("version") == 2
                 )
             )
         )
@@ -390,7 +355,10 @@ def automodel_kwargs(self) -> Dict[str, Any]:
         kwargs = {}
 
         if self.config.torch_dtype is not None:
-            kwargs["torch_dtype"] = getattr(torch, self.config.torch_dtype)
+            if hasattr(torch, self.config.torch_dtype):
+                kwargs["torch_dtype"] = getattr(torch, self.config.torch_dtype)
+            else:
+                kwargs["torch_dtype"] = self.config.torch_dtype
 
         if self.is_quantized:
             kwargs["quantization_config"] = self.quantization_config
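
The core of the change: instead of building `GPTQConfig`, `AwqConfig`, `BitsAndBytesConfig`, or `TorchAoConfig` by hand, the backend now merges the model's stored `quantization_config` with the user's overrides and lets `AutoQuantizationConfig.from_dict` pick the right class from the `quant_method` key. A minimal sketch of that merge-and-dispatch pattern; the dictionary values below are assumptions for illustration, not taken from the commit:

```python
from transformers.quantizers import AutoQuantizationConfig

# Hypothetical values: what a quantized model's config.json might carry on disk.
on_disk = {"quant_method": "gptq", "bits": 4, "group_size": 128}
# Hypothetical backend-level override from the benchmark config.
overrides = {"bits": 8}

# dict(a, **b): keys from `overrides` win over `on_disk`, as in the new code.
merged = dict(on_disk, **overrides)

quantization_config = AutoQuantizationConfig.from_dict(merged)
print(type(quantization_config).__name__)  # GPTQConfig, selected via quant_method
```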
4 changes: 0 additions & 4 deletions optimum_benchmark/backends/pytorch/config.py
@@ -5,7 +5,6 @@
 from ...system_utils import is_rocm_system
 from ..config import BackendConfig
 
-DEVICE_MAPS = ["auto", "sequential"]
 AMP_DTYPES = ["bfloat16", "float16"]
 TORCH_DTYPES = ["bfloat16", "float16", "float32", "auto"]
 
@@ -60,9 +59,6 @@ def __post_init__(self):
                 "Please remove it from the `model_kwargs` and set it in the backend config directly."
             )
 
-        if self.device_map is not None and self.device_map not in DEVICE_MAPS:
-            raise ValueError(f"`device_map` must be one of {DEVICE_MAPS}. Got {self.device_map} instead.")
-
         if self.torch_dtype is not None and self.torch_dtype not in TORCH_DTYPES:
             raise ValueError(f"`torch_dtype` must be one of {TORCH_DTYPES}. Got {self.torch_dtype} instead.")
 
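
On the config side, dropping the `DEVICE_MAPS` whitelist means `device_map` is no longer validated by the backend and is forwarded as-is, so values other than "auto" and "sequential" that transformers/accelerate understand (for example "balanced", or an explicit module-to-device dict) are no longer rejected at config time. A small sketch reconstructing the removed check, for comparison only:

```python
DEVICE_MAPS = ["auto", "sequential"]  # the removed whitelist

def old_device_map_check(device_map):
    # Behaviour before this commit: anything outside the whitelist raised immediately.
    if device_map is not None and device_map not in DEVICE_MAPS:
        raise ValueError(f"`device_map` must be one of {DEVICE_MAPS}. Got {device_map} instead.")

old_device_map_check("auto")        # accepted before and after
# old_device_map_check("balanced")  # raised before; now passed through untouched
```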
