From 4eb7a37589fa5efafd23072041135e22808603ce Mon Sep 17 00:00:00 2001
From: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com>
Date: Fri, 31 Jan 2025 13:26:27 +0100
Subject: [PATCH] Auto quantization (#313)

---
 optimum_benchmark/backends/pytorch/backend.py | 60 +++++--------------
 optimum_benchmark/backends/pytorch/config.py  |  4 --
 2 files changed, 14 insertions(+), 50 deletions(-)

diff --git a/optimum_benchmark/backends/pytorch/backend.py b/optimum_benchmark/backends/pytorch/backend.py
index ea8aa0a1..dd11ddfd 100644
--- a/optimum_benchmark/backends/pytorch/backend.py
+++ b/optimum_benchmark/backends/pytorch/backend.py
@@ -8,15 +8,12 @@
 from datasets import Dataset
 from safetensors.torch import save_file
 from transformers import (
-    AwqConfig,
-    BitsAndBytesConfig,
-    GPTQConfig,
-    TorchAoConfig,
     Trainer,
     TrainerCallback,
     TrainerState,
     TrainingArguments,
 )
+from transformers.quantizers import AutoQuantizationConfig
 
 from ...import_utils import is_deepspeed_available, is_torch_distributed_available, is_zentorch_available
 from ..base import Backend
@@ -286,8 +283,6 @@ def create_no_weights_model(self) -> None:
 
     def process_quantization_config(self) -> None:
         if self.is_gptq_quantized:
-            self.logger.info("\t+ Processing GPTQ config")
-
             try:
                 import exllamav2_kernels  # noqa: F401
             except ImportError:
@@ -299,12 +294,7 @@ def process_quantization_config(self) -> None:
                     "`optimum-benchmark` repository at `https://github.com/huggingface/optimum-benchmark`."
                 )
 
-            self.quantization_config = GPTQConfig(
-                **dict(getattr(self.pretrained_config, "quantization_config", {}), **self.config.quantization_config)
-            )
         elif self.is_awq_quantized:
-            self.logger.info("\t+ Processing AWQ config")
-
             try:
                 import exlv2_ext  # noqa: F401
             except ImportError:
@@ -316,55 +306,30 @@ def process_quantization_config(self) -> None:
                     "`optimum-benchmark` repository at `https://github.com/huggingface/optimum-benchmark`."
                )
 
-            self.quantization_config = AwqConfig(
-                **dict(getattr(self.pretrained_config, "quantization_config", {}), **self.config.quantization_config)
-            )
-        elif self.is_bnb_quantized:
-            self.logger.info("\t+ Processing BitsAndBytes config")
-            self.quantization_config = BitsAndBytesConfig(
-                **dict(getattr(self.pretrained_config, "quantization_config", {}), **self.config.quantization_config)
-            )
-        elif self.is_torchao_quantized:
-            self.logger.info("\t+ Processing TorchAO config")
-            self.quantization_config = TorchAoConfig(
-                **dict(getattr(self.pretrained_config, "quantization_config", {}), **self.config.quantization_config)
-            )
-        else:
-            raise ValueError(f"Quantization scheme {self.config.quantization_scheme} not recognized")
+        self.logger.info("\t+ Processing AutoQuantization config")
+        self.quantization_config = AutoQuantizationConfig.from_dict(
+            dict(getattr(self.pretrained_config, "quantization_config", {}), **self.config.quantization_config)
+        )
 
     @property
     def is_quantized(self) -> bool:
         return self.config.quantization_scheme is not None or (
             hasattr(self.pretrained_config, "quantization_config")
-            and self.pretrained_config.quantization_config.get("quant_method", None) is not None
-        )
-
-    @property
-    def is_bnb_quantized(self) -> bool:
-        return self.config.quantization_scheme == "bnb" or (
-            hasattr(self.pretrained_config, "quantization_config")
-            and self.pretrained_config.quantization_config.get("quant_method", None) == "bnb"
+            and self.pretrained_config.quantization_config.get("quant_method") is not None
         )
 
     @property
     def is_gptq_quantized(self) -> bool:
         return self.config.quantization_scheme == "gptq" or (
             hasattr(self.pretrained_config, "quantization_config")
-            and self.pretrained_config.quantization_config.get("quant_method", None) == "gptq"
+            and self.pretrained_config.quantization_config.get("quant_method") == "gptq"
         )
 
     @property
     def is_awq_quantized(self) -> bool:
         return self.config.quantization_scheme == "awq" or (
             hasattr(self.pretrained_config, "quantization_config")
-            and self.pretrained_config.quantization_config.get("quant_method", None) == "awq"
-        )
-
-    @property
-    def is_torchao_quantized(self) -> bool:
-        return self.config.quantization_scheme == "torchao" or (
-            hasattr(self.pretrained_config, "quantization_config")
-            and self.pretrained_config.quantization_config.get("quant_method", None) == "torchao"
+            and self.pretrained_config.quantization_config.get("quant_method") == "awq"
         )
 
     @property
@@ -376,11 +341,11 @@ def is_exllamav2(self) -> bool:
                 (
                     hasattr(self.pretrained_config, "quantization_config")
                     and hasattr(self.pretrained_config.quantization_config, "exllama_config")
-                    and self.pretrained_config.quantization_config.exllama_config.get("version", None) == 2
+                    and self.pretrained_config.quantization_config.exllama_config.get("version") == 2
                 )
                 or (
                     "exllama_config" in self.config.quantization_config
-                    and self.config.quantization_config["exllama_config"].get("version", None) == 2
+                    and self.config.quantization_config["exllama_config"].get("version") == 2
                 )
             )
         )
@@ -390,7 +355,10 @@ def automodel_kwargs(self) -> Dict[str, Any]:
         kwargs = {}
 
         if self.config.torch_dtype is not None:
-            kwargs["torch_dtype"] = getattr(torch, self.config.torch_dtype)
+            if hasattr(torch, self.config.torch_dtype):
+                kwargs["torch_dtype"] = getattr(torch, self.config.torch_dtype)
+            else:
+                kwargs["torch_dtype"] = self.config.torch_dtype
 
         if self.is_quantized:
             kwargs["quantization_config"] = self.quantization_config
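
Reviewer note: the scheme-specific branches above can collapse into a single call because
AutoQuantizationConfig.from_dict dispatches on the "quant_method" key of the merged dict,
much like AutoConfig dispatches on "model_type". A minimal sketch of the merge-and-dispatch
behavior; the dict values below are illustrative, not taken from this patch:

    from transformers.quantizers import AutoQuantizationConfig

    # Same merge as in process_quantization_config(): keys from the benchmark's
    # quantization_config override keys stored in the pretrained model's config.
    pretrained_quant = {"quant_method": "gptq", "bits": 4, "group_size": 128}  # illustrative
    override_quant = {"group_size": 64}  # illustrative benchmark-level override

    config = AutoQuantizationConfig.from_dict(dict(pretrained_quant, **override_quant))
    print(type(config).__name__, config.group_size)  # -> GPTQConfig 64

An unknown "quant_method" now fails inside from_dict with transformers' own error, which is
why the explicit `else: raise ValueError(...)` branch could be dropped.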
diff --git a/optimum_benchmark/backends/pytorch/config.py b/optimum_benchmark/backends/pytorch/config.py
index ec48f639..61f9dfc0 100644
--- a/optimum_benchmark/backends/pytorch/config.py
+++ b/optimum_benchmark/backends/pytorch/config.py
@@ -5,7 +5,6 @@
 from ...system_utils import is_rocm_system
 from ..config import BackendConfig
 
-DEVICE_MAPS = ["auto", "sequential"]
 AMP_DTYPES = ["bfloat16", "float16"]
 TORCH_DTYPES = ["bfloat16", "float16", "float32", "auto"]
 
@@ -60,9 +59,6 @@ def __post_init__(self):
                 "Please remove it from the `model_kwargs` and set it in the backend config directly."
             )
 
-        if self.device_map is not None and self.device_map not in DEVICE_MAPS:
-            raise ValueError(f"`device_map` must be one of {DEVICE_MAPS}. Got {self.device_map} instead.")
-
         if self.torch_dtype is not None and self.torch_dtype not in TORCH_DTYPES:
             raise ValueError(f"`torch_dtype` must be one of {TORCH_DTYPES}. Got {self.torch_dtype} instead.")
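
Reviewer note: the widened torch_dtype handling in automodel_kwargs is what lets the "auto"
entry retained in TORCH_DTYPES reach from_pretrained: names that exist as attributes of torch
resolve to real dtype objects, while anything else (after the config check, only "auto") is
passed through as a plain string, which transformers accepts. A standalone sketch of that
resolution rule; the helper name is ours, not from the patch:

    import torch

    def resolve_torch_dtype(torch_dtype: str):
        # "float16" -> torch.float16; "auto" is not an attribute of torch,
        # so it stays a string for from_pretrained to interpret.
        if hasattr(torch, torch_dtype):
            return getattr(torch, torch_dtype)
        return torch_dtype

    assert resolve_torch_dtype("float16") is torch.float16
    assert resolve_torch_dtype("auto") == "auto"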