Force callers to manually provide GPTQ args (#2795)
Summary:
Pull Request resolved: #2795

Default values for the quantization arguments are convenient, but they have caused confusion.

Since there are no active users of GPTQ, we want to pre-empt potential ambiguity by making the following fields required and manually specified:
* Group Size
* Calibration Limit
* Calibration Sequence Length

 ---

Note: Group Size's default value is untouched since it is utilized by int4 quant and has active callers.
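
For illustration, here is a minimal sketch of a GPTQ call site under the new contract. The argument names follow the signature in the diff below; the leading model argument and the literal values are placeholders, not recommended settings.

# Hypothetical call site; values are placeholders, not recommended settings.
quantized_model = quantize(
    model,                        # the eager model being quantized
    qmode="8da4w-gptq",
    activation_dtype=None,
    group_size=128,               # must now be passed explicitly for GPTQ
    calibration_limit=100,        # previously defaulted to 100
    calibration_seq_length=2048,  # previously defaulted to 2048
)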

Reviewed By: mergennachin

Differential Revision: D55605859

fbshipit-source-id: 7747435b0c462d341b466659ec79eda33a9893ec
Jack-Khuu authored and facebook-github-bot committed Apr 2, 2024
1 parent 06668cf commit 31d5c61
Showing 1 changed file with 22 additions and 9 deletions.
examples/models/llama2/export_llama_lib.py (22 additions & 9 deletions)
@@ -15,7 +15,7 @@

 from functools import partial
 from pathlib import Path
-from typing import List, Optional, Union
+from typing import Any, List, Optional, Union

 import pkg_resources
 import torch
@@ -214,12 +214,12 @@ def quantize(
     qmode: str,
     activation_dtype: Optional[DType],
     checkpoint_path: Optional[Path] = None,
-    # following arguments only available when setting int4 quantization.
-    group_size: int = 128,
-    # following arguments only used for GPTQ
+    # following arguments only available when setting int4 or gptq quantization.
+    group_size: Optional[int] = 128,
+    # following arguments are only used for GPTQ
     calibration_tasks: Optional[list] = None,
-    calibration_limit: int = 100,
-    calibration_seq_length: int = 2048,
+    calibration_limit: Optional[int] = None,
+    calibration_seq_length: Optional[int] = None,
     pad_calibration_inputs: bool = False,
     percdamp: float = 0.01,
     blocksize: int = 128,
@@ -245,13 +245,13 @@ def quantize(
     # if checkpoint_path is None:
     #     checkpoint_path = Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")

-    if calibration_tasks is None:
-        calibration_tasks = ["wikitext"]
-
     if qmode == "int8":
         # Add quantization mode options here: group size, bit width, etc.
         return WeightOnlyInt8QuantHandler(model).quantized_model()
     elif qmode == "8da4w":
+        # Check for required args
+        if group_size is None:
+            raise Exception("For 8da4w quantization, group size must be specified.")
         from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer

         model = Int8DynActInt4WeightQuantizer(
@@ -261,6 +261,19 @@
         print("quantized model:", model)
         return model
     elif qmode == "8da4w-gptq":
+        # Check for required args
+        required_args: Optional[Any] = [
+            group_size,
+            calibration_limit,
+            calibration_seq_length,
+        ]
+        if any(arg is None for arg in required_args):
+            raise Exception(
+                "For 8da4w-gptq quantization, group size, calibration limit and calibration sequence length must be specified."
+            )
+        if calibration_tasks is None:
+            calibration_tasks = ["wikitext"]
+
         from torchao.quantization.quant_api import Int8DynActInt4WeightGPTQQuantizer

         if tokenizer_path is None:
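
The commit implements the check inline in the 8da4w-gptq branch. As a rough sketch only (this helper is hypothetical and not part of the commit), the same required-argument pattern could be factored into a reusable function:

from typing import Any, Optional


def check_required_args(qmode: str, **required: Optional[Any]) -> None:
    # Raise if any named argument was left unspecified (None).
    missing = [name for name, value in required.items() if value is None]
    if missing:
        raise Exception(
            f"For {qmode} quantization, the following arguments must be specified: "
            + ", ".join(missing)
        )


# e.g. check_required_args(
#     "8da4w-gptq",
#     group_size=group_size,
#     calibration_limit=calibration_limit,
#     calibration_seq_length=calibration_seq_length,
# )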
