From f7c736e1e7d3161304a3e93c3e36c55dcb5ebd05 Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Tue, 23 May 2023 05:48:23 +0000
Subject: [PATCH] [quant][pt2e] Add observer_or_fake_quant_ctr to QuantizationSpec (#101920)

Summary:
This is the second refactor to align the annotation API with the design; the next step is to change prepare_pt2e to consume QuantizationSpec objects directly.

Test Plan:
```
buck2 test mode/opt caffe2/test:quantization_pt2e -- --exact 'caffe2/test:quantization_pt2e - test_resnet18_with_quantizer_api (quantization.pt2e.test_quantize_pt2e.TestQuantizePT2EModels)'
```

Reviewed By: kimishpatel

Differential Revision: D45927416

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101920
Approved by: https://github.com/andrewor14
---
 .../_pt2e/quantizer/qnnpack_quantizer.py      | 63 ++++++++++++++-----
 .../quantization/_pt2e/quantizer/quantizer.py | 13 ++--
 .../ao/quantization/_pt2e/quantizer/utils.py  | 49 +++++----------
 torch/ao/quantization/qconfig.py              |  4 +-
 4 files changed, 71 insertions(+), 58 deletions(-)

diff --git a/torch/ao/quantization/_pt2e/quantizer/qnnpack_quantizer.py b/torch/ao/quantization/_pt2e/quantizer/qnnpack_quantizer.py
index 10880f20705bdb..e82427fac601af 100644
--- a/torch/ao/quantization/_pt2e/quantizer/qnnpack_quantizer.py
+++ b/torch/ao/quantization/_pt2e/quantizer/qnnpack_quantizer.py
@@ -3,7 +3,7 @@
 import copy
 import functools
 import operator
-from typing import Callable, Dict, List, Optional, Set
+from typing import Callable, Dict, List, Optional, Set, Any
 
 import torch
 import torch._dynamo as torchdynamo
@@ -15,7 +15,6 @@
     get_weight_obs_or_fq_ctr,
 )
 
-from torch.ao.quantization.observer import PlaceholderObserver
 from torch.fx import Node
 from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
 
@@ -30,6 +29,17 @@
     _annotate_input_qspec_map,
     _annotate_output_qspec,
 )
+from torch.ao.quantization.fake_quantize import FusedMovingAvgObsFakeQuantize
+from torch.ao.quantization.observer import (
+    HistogramObserver,
+    MinMaxObserver,
+    PerChannelMinMaxObserver,
+    MovingAverageMinMaxObserver,
+    MovingAveragePerChannelMinMaxObserver,
+    PlaceholderObserver,
+)
+from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
+
 
 __all__ = [
     "QNNPackQuantizer",
@@ -125,16 +135,32 @@ def get_symmetric_quantization_config(
     is_per_channel: bool = False,
     is_qat: bool = False,
 ):
+    act_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = \
+        FusedMovingAvgObsFakeQuantize if is_qat else HistogramObserver
+
     act_quantization_spec = QuantizationSpec(
         dtype=torch.int8,
         quant_min=-128,
         quant_max=127,
         qscheme=torch.per_tensor_affine,
         is_dynamic=False,
+        observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(eps=2**-12),
     )
     qscheme = (
         torch.per_channel_symmetric if is_per_channel else torch.per_tensor_symmetric
     )
+    weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = MinMaxObserver
+    if is_qat:
+        weight_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize
+    elif is_per_channel:
+        weight_observer_or_fake_quant_ctr = PerChannelMinMaxObserver
+
+    extra_args: Dict[str, Any] = {"eps": 2**-12}
+    if is_qat:
+        if qscheme == torch.per_tensor_symmetric:
+            extra_args["observer"] = MovingAverageMinMaxObserver
+        else:
+            extra_args["observer"] = MovingAveragePerChannelMinMaxObserver  # type: ignore[dict-item]
     weight_quantization_spec = QuantizationSpec(
         dtype=torch.int8,
         quant_min=-127,
@@ -142,8 +168,14 @@ def get_symmetric_quantization_config(
         qscheme=qscheme,
         ch_axis=0,
         is_dynamic=False,
+        observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args(**extra_args),
+    )
+
+    bias_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = PlaceholderObserver
+    bias_quantization_spec = QuantizationSpec(
+        dtype=torch.float,
+        observer_or_fake_quant_ctr=bias_observer_or_fake_quant_ctr
     )
-    bias_quantization_spec = QuantizationSpec(dtype=torch.float)
     quantization_config = QuantizationConfig(
         act_quantization_spec, weight_quantization_spec, bias_quantization_spec, is_qat
     )
@@ -153,11 +185,6 @@ def get_supported_config_and_operators() -> List[OperatorConfig]:
     return get_supported_symmetric_config_and_operators()
 
-
-def _get_default_obs_or_fq_ctr():
-    return PlaceholderObserver.with_args(dtype=torch.float)
-
-
 def _is_annotated(nodes: List[Node]):
     """
     Given a list of nodes (that represents an operator pattern),
@@ -221,18 +248,20 @@ def set_config_for_operator_type(
     def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
         """just handling global spec for now"""
         global_config = self.global_config
-        _QUANT_CONFIG_TO_ANNOTATOR[global_config](self, model)
+        # _QUANT_CONFIG_TO_ANNOTATOR[global_config](self, model)
+        # TODO: validate that global_config is supported
+        self.annotate_symmetric_config(model, global_config)
         return model
 
-    @register_annotator(
-        [
-            get_symmetric_quantization_config(is_per_channel=False, is_qat=False),
-            get_symmetric_quantization_config(is_per_channel=False, is_qat=True),
-            get_symmetric_quantization_config(is_per_channel=True, is_qat=True),
-            get_symmetric_quantization_config(is_per_channel=True, is_qat=False),
-        ]
-    )
+    # @register_annotator(
+    #     [
+    #         get_symmetric_quantization_config(is_per_channel=False, is_qat=False),
+    #         get_symmetric_quantization_config(is_per_channel=False, is_qat=True),
+    #         get_symmetric_quantization_config(is_per_channel=True, is_qat=True),
+    #         get_symmetric_quantization_config(is_per_channel=True, is_qat=False),
+    #     ]
+    # )
     def annotate_symmetric_config(
         self, model: torch.fx.GraphModule, config: QuantizationConfig
     ) -> torch.fx.GraphModule:
diff --git a/torch/ao/quantization/_pt2e/quantizer/quantizer.py b/torch/ao/quantization/_pt2e/quantizer/quantizer.py
index 2f14a882dba38d..aeb726c0eadf1b 100644
--- a/torch/ao/quantization/_pt2e/quantizer/quantizer.py
+++ b/torch/ao/quantization/_pt2e/quantizer/quantizer.py
@@ -3,6 +3,7 @@
 from dataclasses import asdict, dataclass, field
 from torch.fx import Node
 from typing import Callable, List, NamedTuple, Optional, Dict, Any
+from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
 
 import torch
 
@@ -28,20 +29,23 @@
     torch.uint8: torch.quint8,
     torch.int32: torch.qint32,
     torch.float16: torch.float16,
+    torch.float32: torch.float32,
 }
 
 
 @dataclass(eq=True, frozen=True)
 class QuantizationSpec:
     dtype: torch.dtype
-    is_dynamic: bool = False
+    # observer or fake_quantize constructor such as
+    # MinMaxObserver, PerChannelHistogramObserver etc.
+    # or we can attach some custom args to them
+    # e.g. MinMaxObserver.with_args(eps=eps)
+    observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor
     quant_min: Optional[int] = None
     quant_max: Optional[int] = None
     qscheme: Optional[torch.qscheme] = None
     ch_axis: Optional[int] = None
-    # TODO: add this in a separate diff
-    # Kind of observer such as MinMaxObserver, PerChannelHistogramObserver etc.
-    # observer_or_fake_quant_type: Union[ObserverBase, FakeQuantizeBase]
+    is_dynamic: bool = False
 
     def __post_init__(self):
         # check dtype is one of the supported types
@@ -80,6 +84,7 @@ class QuantizationConfig:
     activation: Optional[QuantizationSpec]
     weight: Optional[QuantizationSpec]
     bias: Optional[QuantizationSpec]
+    # TODO: remove, since we can use observer_or_fake_quant_ctr to express this
     is_qat: bool = False
 
 OperatorPatternType = List[Callable]
diff --git a/torch/ao/quantization/_pt2e/quantizer/utils.py b/torch/ao/quantization/_pt2e/quantizer/utils.py
index 7589420d424bda..41eba43b7c4ca5 100644
--- a/torch/ao/quantization/_pt2e/quantizer/utils.py
+++ b/torch/ao/quantization/_pt2e/quantizer/utils.py
@@ -4,27 +4,28 @@
     QuantizationConfig,
     QuantizationSpec,
 )
-from torch.ao.quantization.fake_quantize import FusedMovingAvgObsFakeQuantize
 from torch.ao.quantization.observer import (
-    HistogramObserver,
-    MinMaxObserver,
-    MovingAverageMinMaxObserver,
-    MovingAveragePerChannelMinMaxObserver,
-    PerChannelMinMaxObserver,
+    _PartialWrapper,
     PlaceholderObserver,
 )
+from torch.ao.quantization.qconfig import _obs_or_fq_ctr_equals
 
 
-def create_observer(observer_type, quantization_spec: QuantizationSpec, **extra_kwargs):
+def create_observer(quantization_spec: QuantizationSpec, **extra_kwargs):
     if quantization_spec is None:
         return None
+    observer_or_fake_quant_ctr = quantization_spec.observer_or_fake_quant_ctr
     kwargs = get_observer_kwargs(quantization_spec)
+    kwargs.pop("observer_or_fake_quant_ctr")
     # we will remove is_dynamic from QuantizationSpec because
     # it seems that dynamic range quantization
-    if observer_type != PlaceholderObserver:
+    if not _obs_or_fq_ctr_equals(observer_or_fake_quant_ctr, PlaceholderObserver):
         kwargs.pop("is_dynamic")
-    if "PerChannel" not in observer_type.__name__:
+    obs_or_fq_class = observer_or_fake_quant_ctr
+    if isinstance(observer_or_fake_quant_ctr, _PartialWrapper):
+        obs_or_fq_class = observer_or_fake_quant_ctr.p.func  # type: ignore[union-attr, assignment]
+    if "PerChannel" not in obs_or_fq_class.__name__:  # type: ignore[operator, union-attr]
         kwargs.pop("ch_axis")
-    return observer_type.with_args(**kwargs, **extra_kwargs)
+    return observer_or_fake_quant_ctr.with_args(**kwargs, **extra_kwargs)
 
 
 def get_act_obs_or_fq_ctr(quantization_config: QuantizationConfig):
@@ -42,18 +43,7 @@ def get_act_obs_or_fq_ctr(quantization_config: QuantizationConfig):
         raise Exception(
             "Unsupported quantization_spec for activation: {}".format(quantization_spec)
         )
-    if quantization_config.is_qat:
-        return create_observer(
-            FusedMovingAvgObsFakeQuantize,
-            quantization_spec,
-            reduce_range=False,
-            eps=2**-12,
-        )
-    else:  # ptq
-        return create_observer(
-            HistogramObserver, quantization_spec, reduce_range=False, eps=2**-12
-        )
-
+    return create_observer(quantization_spec)
 
 def get_weight_obs_or_fq_ctr(quantization_config: QuantizationConfig):
     if quantization_config is None:
@@ -69,18 +59,7 @@ def get_weight_obs_or_fq_ctr(quantization_config: QuantizationConfig):
         raise ValueError(
             f"Unsupported quantization_spec {quantization_spec} for weight"
         )
-    observer_type = MinMaxObserver
-    extra_args = {}
-    if quantization_config.is_qat:
-        observer_type = FusedMovingAvgObsFakeQuantize  # type: ignore[assignment]
-        if quantization_spec.qscheme == torch.per_tensor_symmetric:
-            extra_args = {"observer": MovingAverageMinMaxObserver}
-        else:
-            extra_args = {"observer": MovingAveragePerChannelMinMaxObserver}  # type: ignore[dict-item]
-    elif quantization_spec.qscheme == torch.per_channel_symmetric:
-        observer_type = PerChannelMinMaxObserver  # type: ignore[assignment]
-    return create_observer(observer_type, quantization_spec, eps=2**-12, **extra_args)
-
+    return create_observer(quantization_spec)
 
 def get_bias_obs_or_fq_ctr(quantization_config: QuantizationConfig):
     if quantization_config is None:
@@ -92,4 +71,4 @@ def get_bias_obs_or_fq_ctr(quantization_config: QuantizationConfig):
     assert (
         quantization_spec.dtype == torch.float
     ), "Only float dtype for bias is supported for bias right now"
-    return PlaceholderObserver.with_args(dtype=quantization_spec.dtype)
+    return create_observer(quantization_spec)
diff --git a/torch/ao/quantization/qconfig.py b/torch/ao/quantization/qconfig.py
index 80f2f6dd768d96..dd17cd2275fc90 100644
--- a/torch/ao/quantization/qconfig.py
+++ b/torch/ao/quantization/qconfig.py
@@ -1,5 +1,5 @@
 from collections import namedtuple
-from typing import Optional, Any, Union
+from typing import Optional, Any, Union, Type
 
 import torch
 import torch.nn as nn
@@ -491,7 +491,7 @@ def configure_constructor_to_put_obs_on_module_device(original_constructor):
 
     return QConfig(activation, weight)
 
-_ObserverOrFakeQuantizeConstructor = Union[_PartialWrapper, ObserverBase, FakeQuantizeBase]
+_ObserverOrFakeQuantizeConstructor = Union[_PartialWrapper, Type[ObserverBase], Type[FakeQuantizeBase]]
 
 def _obs_or_fq_ctr_equals(obs_or_fq1: _ObserverOrFakeQuantizeConstructor, obs_or_fq2: _ObserverOrFakeQuantizeConstructor):
     if isinstance(obs_or_fq1, _PartialWrapper) and isinstance(obs_or_fq2, _PartialWrapper):
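
Usage note (editor's sketch, not part of the patch): the snippet below illustrates the new flow this diff introduces, where the observer/fake-quant constructor travels on the QuantizationSpec itself and create_observer reads it off the spec instead of taking an observer_type argument. It assumes the in-tree _pt2e module paths exactly as patched here (torch.ao.quantization._pt2e.quantizer.*), which are internal and may move in later refactors, so treat it as an illustrative sketch rather than a stable API.

```python
# Sketch only: assumes the _pt2e modules as they exist after this PR.
import torch
from torch.ao.quantization.observer import HistogramObserver
from torch.ao.quantization._pt2e.quantizer.quantizer import QuantizationSpec
from torch.ao.quantization._pt2e.quantizer.utils import create_observer

# The constructor (optionally partially applied via with_args) is part of the spec,
# so callers no longer hard-code HistogramObserver vs. FusedMovingAvgObsFakeQuantize
# inside the get_*_obs_or_fq_ctr helpers.
act_spec = QuantizationSpec(
    dtype=torch.int8,
    observer_or_fake_quant_ctr=HistogramObserver.with_args(eps=2**-12),
    quant_min=-128,
    quant_max=127,
    qscheme=torch.per_tensor_affine,
    is_dynamic=False,
)

# create_observer() pops observer_or_fake_quant_ctr from the spec's kwargs and
# re-attaches the remaining ones (quant_min, quant_max, qscheme, mapped dtype, ...).
act_obs_ctr = create_observer(act_spec)
observer = act_obs_ctr()  # instantiate the partially-applied observer
print(observer)
```

The design choice visible in the diff is that get_act/weight/bias_obs_or_fq_ctr all collapse to a single create_observer(quantization_spec) call, which is what makes the planned next step (prepare_pt2e consuming QuantizationSpec directly) possible.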