[quant][pt2e] Add observer_or_fake_quant_ctr to QuantizationSpec (pytorch#101920)

Summary:
This is the second refactor to align the annotation API with the design. The next step is to change prepare_pt2e to consume the QuantizationSpec object directly.
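
For illustration (not part of the original commit message), a minimal sketch of what the new field looks like on a spec. It assumes the internal torch.ao.quantization._pt2e.quantizer module layout used in this diff, which may move in later releases:

```
import torch
from torch.ao.quantization.observer import HistogramObserver
from torch.ao.quantization._pt2e.quantizer.quantizer import QuantizationSpec

# The observer/fake-quant constructor (and its extra args) now travels with the
# spec instead of being hard-coded in the quantizer utility functions.
act_spec = QuantizationSpec(
    dtype=torch.int8,
    quant_min=-128,
    quant_max=127,
    qscheme=torch.per_tensor_affine,
    is_dynamic=False,
    observer_or_fake_quant_ctr=HistogramObserver.with_args(eps=2**-12),
)
```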

Test Plan:
```
buck2 test mode/opt caffe2/test:quantization_pt2e -- --exact 'caffe2/test:quantization_pt2e - test_resnet18_with_quantizer_api (quantization.pt2e.test_quantize_pt2e.TestQuantizePT2EModels)'
```

Reviewed By: kimishpatel

Differential Revision: D45927416

Pull Request resolved: pytorch#101920
Approved by: https://github.com/andrewor14
jerryzh168 authored and pytorchmergebot committed May 23, 2023
1 parent 8cab799 commit f7c736e
Showing 4 changed files with 71 additions and 58 deletions.
63 changes: 46 additions & 17 deletions torch/ao/quantization/_pt2e/quantizer/qnnpack_quantizer.py
@@ -3,7 +3,7 @@
import copy
import functools
import operator
from typing import Callable, Dict, List, Optional, Set
from typing import Callable, Dict, List, Optional, Set, Any

import torch
import torch._dynamo as torchdynamo
@@ -15,7 +15,6 @@
get_weight_obs_or_fq_ctr,
)

from torch.ao.quantization.observer import PlaceholderObserver
from torch.fx import Node

from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
@@ -30,6 +29,17 @@
_annotate_input_qspec_map,
_annotate_output_qspec,
)
from torch.ao.quantization.fake_quantize import FusedMovingAvgObsFakeQuantize
from torch.ao.quantization.observer import (
HistogramObserver,
MinMaxObserver,
PerChannelMinMaxObserver,
MovingAverageMinMaxObserver,
MovingAveragePerChannelMinMaxObserver,
PlaceholderObserver,
)
from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor


__all__ = [
"QNNPackQuantizer",
@@ -125,25 +135,47 @@ def get_symmetric_quantization_config(
is_per_channel: bool = False,
is_qat: bool = False,
):
act_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = \
FusedMovingAvgObsFakeQuantize if is_qat else HistogramObserver

act_quantization_spec = QuantizationSpec(
dtype=torch.int8,
quant_min=-128,
quant_max=127,
qscheme=torch.per_tensor_affine,
is_dynamic=False,
observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(eps=2**-12),
)
qscheme = (
torch.per_channel_symmetric if is_per_channel else torch.per_tensor_symmetric
)
weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = MinMaxObserver
if is_qat:
weight_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize
elif is_per_channel:
weight_observer_or_fake_quant_ctr = PerChannelMinMaxObserver

extra_args: Dict[str, Any] = {"eps": 2**-12}
if is_qat:
if qscheme == torch.per_tensor_symmetric:
extra_args["observer"] = MovingAverageMinMaxObserver
else:
extra_args["observer"] = MovingAveragePerChannelMinMaxObserver # type: ignore[dict-item]
weight_quantization_spec = QuantizationSpec(
dtype=torch.int8,
quant_min=-127,
quant_max=127,
qscheme=qscheme,
ch_axis=0,
is_dynamic=False,
observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args(**extra_args),
)

bias_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = PlaceholderObserver
bias_quantization_spec = QuantizationSpec(
dtype=torch.float,
observer_or_fake_quant_ctr=bias_observer_or_fake_quant_ctr
)
bias_quantization_spec = QuantizationSpec(dtype=torch.float)
quantization_config = QuantizationConfig(
act_quantization_spec, weight_quantization_spec, bias_quantization_spec, is_qat
)
@@ -153,11 +185,6 @@ def get_symmetric_quantization_config(
def get_supported_config_and_operators() -> List[OperatorConfig]:
return get_supported_symmetric_config_and_operators()


def _get_default_obs_or_fq_ctr():
return PlaceholderObserver.with_args(dtype=torch.float)


def _is_annotated(nodes: List[Node]):
"""
Given a list of nodes (that represents an operator pattern),
@@ -221,18 +248,20 @@ def set_config_for_operator_type(
def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
"""just handling global spec for now"""
global_config = self.global_config
_QUANT_CONFIG_TO_ANNOTATOR[global_config](self, model)
# _QUANT_CONFIG_TO_ANNOTATOR[global_config](self, model)
# TODO: validate that global_config is supported
self.annotate_symmetric_config(model, global_config)

return model

@register_annotator(
[
get_symmetric_quantization_config(is_per_channel=False, is_qat=False),
get_symmetric_quantization_config(is_per_channel=False, is_qat=True),
get_symmetric_quantization_config(is_per_channel=True, is_qat=True),
get_symmetric_quantization_config(is_per_channel=True, is_qat=False),
]
)
# @register_annotator(
# [
# get_symmetric_quantization_config(is_per_channel=False, is_qat=False),
# get_symmetric_quantization_config(is_per_channel=False, is_qat=True),
# get_symmetric_quantization_config(is_per_channel=True, is_qat=True),
# get_symmetric_quantization_config(is_per_channel=True, is_qat=False),
# ]
# )
def annotate_symmetric_config(
self, model: torch.fx.GraphModule, config: QuantizationConfig
) -> torch.fx.GraphModule:
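
A usage sketch (not part of the diff) for the updated helper; the comments reflect the branches above, and the module path is the internal _pt2e one from this file:

```
from torch.ao.quantization._pt2e.quantizer.qnnpack_quantizer import (
    get_symmetric_quantization_config,
)

# Per-channel QAT: per the branches above, the weight spec carries
# FusedMovingAvgObsFakeQuantize.with_args(eps=2**-12,
# observer=MovingAveragePerChannelMinMaxObserver) and the activation spec
# carries FusedMovingAvgObsFakeQuantize.with_args(eps=2**-12).
qat_config = get_symmetric_quantization_config(is_per_channel=True, is_qat=True)
print(qat_config.weight.qscheme)                     # torch.per_channel_symmetric
print(qat_config.weight.observer_or_fake_quant_ctr)  # a _PartialWrapper around the fake-quant class
```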
13 changes: 9 additions & 4 deletions torch/ao/quantization/_pt2e/quantizer/quantizer.py
@@ -3,6 +3,7 @@
from dataclasses import asdict, dataclass, field
from torch.fx import Node
from typing import Callable, List, NamedTuple, Optional, Dict, Any
from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor

import torch

@@ -28,20 +29,23 @@
torch.uint8: torch.quint8,
torch.int32: torch.qint32,
torch.float16: torch.float16,
torch.float32: torch.float32,
}


@dataclass(eq=True, frozen=True)
class QuantizationSpec:
dtype: torch.dtype
is_dynamic: bool = False
# observer or fake_quantize constructor such as
# MinMaxObserver, PerChannelHistogramObserver etc.
# or we can attach some custom args to them
# e.g. MinMaxObserver.with_args(eps=eps)
observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor
quant_min: Optional[int] = None
quant_max: Optional[int] = None
qscheme: Optional[torch.qscheme] = None
ch_axis: Optional[int] = None
# TODO: add this in a separate diff
# Kind of observer such as MinMaxObserver, PerChannelHistogramObserver etc.
# observer_or_fake_quant_type: Union[ObserverBase, FakeQuantizeBase]
is_dynamic: bool = False

def __post_init__(self):
# check dtype is one of the supported types
@@ -80,6 +84,7 @@ class QuantizationConfig:
activation: Optional[QuantizationSpec]
weight: Optional[QuantizationSpec]
bias: Optional[QuantizationSpec]
# TODO: remove, since we can use observer_or_fake_quant_ctr to express this
is_qat: bool = False

OperatorPatternType = List[Callable]
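
Because observer_or_fake_quant_ctr is declared without a default ahead of the optional fields, it is now a required argument when constructing a spec. A small sketch, under the same internal-module assumption as above:

```
import torch
from torch.ao.quantization.observer import PerChannelMinMaxObserver
from torch.ao.quantization._pt2e.quantizer.quantizer import QuantizationSpec

# eq=True, frozen=True keeps the spec immutable and hashable, so it can be
# used as a dictionary key.
weight_spec = QuantizationSpec(
    dtype=torch.int8,
    observer_or_fake_quant_ctr=PerChannelMinMaxObserver.with_args(eps=2**-12),
    quant_min=-127,
    quant_max=127,
    qscheme=torch.per_channel_symmetric,
    ch_axis=0,
)
```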
49 changes: 14 additions & 35 deletions torch/ao/quantization/_pt2e/quantizer/utils.py
@@ -4,27 +4,28 @@
QuantizationConfig,
QuantizationSpec,
)
from torch.ao.quantization.fake_quantize import FusedMovingAvgObsFakeQuantize
from torch.ao.quantization.observer import (
HistogramObserver,
MinMaxObserver,
MovingAverageMinMaxObserver,
MovingAveragePerChannelMinMaxObserver,
PerChannelMinMaxObserver,
_PartialWrapper,
PlaceholderObserver,
)
from torch.ao.quantization.qconfig import _obs_or_fq_ctr_equals

def create_observer(observer_type, quantization_spec: QuantizationSpec, **extra_kwargs):
def create_observer(quantization_spec: QuantizationSpec, **extra_kwargs):
if quantization_spec is None:
return None
observer_or_fake_quant_ctr = quantization_spec.observer_or_fake_quant_ctr
kwargs = get_observer_kwargs(quantization_spec)
kwargs.pop("observer_or_fake_quant_ctr")
# we will remove is_dynamic from QuantizationSpec because
# it seems that dynamic range quantization
if observer_type != PlaceholderObserver:
if not _obs_or_fq_ctr_equals(observer_or_fake_quant_ctr, PlaceholderObserver):
kwargs.pop("is_dynamic")
if "PerChannel" not in observer_type.__name__:
obs_or_fq_class = observer_or_fake_quant_ctr
if isinstance(observer_or_fake_quant_ctr, _PartialWrapper):
obs_or_fq_class = observer_or_fake_quant_ctr.p.func # type: ignore[union-attr, assignment]
if "PerChannel" not in obs_or_fq_class.__name__: # type: ignore[operator, union-attr]
kwargs.pop("ch_axis")
return observer_type.with_args(**kwargs, **extra_kwargs)
return observer_or_fake_quant_ctr.with_args(**kwargs, **extra_kwargs)


def get_act_obs_or_fq_ctr(quantization_config: QuantizationConfig):
@@ -42,18 +43,7 @@ def get_act_obs_or_fq_ctr(quantization_config: QuantizationConfig):
raise Exception(
"Unsupported quantization_spec for activation: {}".format(quantization_spec)
)
if quantization_config.is_qat:
return create_observer(
FusedMovingAvgObsFakeQuantize,
quantization_spec,
reduce_range=False,
eps=2**-12,
)
else: # ptq
return create_observer(
HistogramObserver, quantization_spec, reduce_range=False, eps=2**-12
)

return create_observer(quantization_spec)

def get_weight_obs_or_fq_ctr(quantization_config: QuantizationConfig):
if quantization_config is None:
@@ -69,18 +59,7 @@ def get_weight_obs_or_fq_ctr(quantization_config: QuantizationConfig):
raise ValueError(
f"Unsupported quantization_spec {quantization_spec} for weight"
)
observer_type = MinMaxObserver
extra_args = {}
if quantization_config.is_qat:
observer_type = FusedMovingAvgObsFakeQuantize # type: ignore[assignment]
if quantization_spec.qscheme == torch.per_tensor_symmetric:
extra_args = {"observer": MovingAverageMinMaxObserver}
else:
extra_args = {"observer": MovingAveragePerChannelMinMaxObserver} # type: ignore[dict-item]
elif quantization_spec.qscheme == torch.per_channel_symmetric:
observer_type = PerChannelMinMaxObserver # type: ignore[assignment]
return create_observer(observer_type, quantization_spec, eps=2**-12, **extra_args)

return create_observer(quantization_spec)

def get_bias_obs_or_fq_ctr(quantization_config: QuantizationConfig):
if quantization_config is None:
@@ -92,4 +71,4 @@ def get_bias_obs_or_fq_ctr(quantization_config: QuantizationConfig):
assert (
quantization_spec.dtype == torch.float
), "Only float dtype for bias is supported for bias right now"
return PlaceholderObserver.with_args(dtype=quantization_spec.dtype)
return create_observer(quantization_spec)
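
With this change, all three helpers defer to create_observer, which reads the constructor off the spec, unwraps a _PartialWrapper (via .p.func) to find the underlying class when deciding whether to keep ch_axis, and only keeps is_dynamic for PlaceholderObserver. A sketch of the resulting behavior, under the same internal-module assumption:

```
from torch.ao.quantization._pt2e.quantizer.qnnpack_quantizer import (
    get_symmetric_quantization_config,
)
from torch.ao.quantization._pt2e.quantizer.utils import (
    get_act_obs_or_fq_ctr,
    get_bias_obs_or_fq_ctr,
    get_weight_obs_or_fq_ctr,
)

config = get_symmetric_quantization_config(is_per_channel=True, is_qat=False)

# The observer classes come from each spec's observer_or_fake_quant_ctr rather
# than from hard-coded PTQ/QAT branches in these helpers.
act_ctr = get_act_obs_or_fq_ctr(config)        # HistogramObserver partial (eps=2**-12)
weight_ctr = get_weight_obs_or_fq_ctr(config)  # PerChannelMinMaxObserver partial, ch_axis kept
bias_ctr = get_bias_obs_or_fq_ctr(config)      # PlaceholderObserver partial, dtype=torch.float
```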
4 changes: 2 additions & 2 deletions torch/ao/quantization/qconfig.py
@@ -1,5 +1,5 @@
from collections import namedtuple
from typing import Optional, Any, Union
from typing import Optional, Any, Union, Type

import torch
import torch.nn as nn
@@ -491,7 +491,7 @@ def configure_constructor_to_put_obs_on_module_device(original_constructor):

return QConfig(activation, weight)

_ObserverOrFakeQuantizeConstructor = Union[_PartialWrapper, ObserverBase, FakeQuantizeBase]
_ObserverOrFakeQuantizeConstructor = Union[_PartialWrapper, Type[ObserverBase], Type[FakeQuantizeBase]]

def _obs_or_fq_ctr_equals(obs_or_fq1: _ObserverOrFakeQuantizeConstructor, obs_or_fq2: _ObserverOrFakeQuantizeConstructor):
if isinstance(obs_or_fq1, _PartialWrapper) and isinstance(obs_or_fq2, _PartialWrapper):
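
The alias now describes constructors (observer/fake-quant classes or their with_args partials) rather than instances. A tiny illustration with a hypothetical helper named describe, just to show what the annotation accepts:

```
from torch.ao.quantization.fake_quantize import FusedMovingAvgObsFakeQuantize
from torch.ao.quantization.observer import MinMaxObserver
from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor

def describe(ctr: _ObserverOrFakeQuantizeConstructor) -> str:
    # Hypothetical helper: accepts a bare class or a .with_args(...) partial.
    return repr(ctr)

describe(MinMaxObserver)                        # Type[ObserverBase]
describe(FusedMovingAvgObsFakeQuantize)         # Type[FakeQuantizeBase]
describe(MinMaxObserver.with_args(eps=2**-12))  # _PartialWrapper
```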
