Skip to content

Commit

Permalink
[ONNX] Support restricted quantized range for activation.
Browse files Browse the repository at this point in the history
PyTorch restricts activations to be in the range (0, 127).
In ONNX, the supported ranges are (0, 255) and (-128, 127),
for uint8 and int8 respectively. This PR extends support to the range
(0, 127) by adding an extra Clip when that range is detected.

Pull Request resolved: pytorch#76055

Approved by: https://github.com/garymm
  • Loading branch information
BowenBao authored and pytorchmergebot committed Apr 25, 2022
1 parent cada2cd commit 8d31706
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 10 deletions.
26 changes: 26 additions & 0 deletions test/onnx/test_pytorch_onnx_onnxruntime.py
Original file line number Diff line number Diff line change
Expand Up @@ -8731,6 +8731,32 @@ def forward(self, input):
x = torch.randn(6, 4, 3, 3)
self.run_test(FakeQuantizePerChannelModel(), (x))

@skipIfUnsupportedMinOpsetVersion(13)
@disableScriptTest()  # RuntimeError: Can't redefine method: forward on class: __torch__.torch.nn.modules.linear.Linear
def test_fake_quantize_activation(self):
    """Export a QAT-prepared Linear and compare results on inputs outside (0, 127)."""
    from torch import quantization

    m = torch.nn.Linear(1, 1)
    m.qconfig = quantization.QConfig(
        activation=quantization.default_fake_quant,
        weight=quantization.default_per_channel_weight_fake_quant,
    )
    quantization.prepare_qat(m.train(), inplace=True)

    # Turn on observers and fake-quant, in that order, before computing qparams.
    for toggle in (quantization.enable_observer, quantization.enable_fake_quant):
        m.apply(toggle)
    fake_quant_modules = [
        sub for sub in m.modules() if isinstance(sub, quantization.FakeQuantize)
    ]
    for fake_quant in fake_quant_modules:
        fake_quant.calculate_qparams()

    m.apply(quantization.disable_observer)
    m.eval()

    # The activation fake-quantize restricts the quantized range to (0, 127),
    # whereas standard 8-bit quantization uses (-128, 127) or (0, 255).
    # Fixed weight, bias and inputs exercise whether ONNX handles the overflow correctly.
    fixed_weight = torch.tensor([[1.], [1.], [1.]])
    fixed_bias = torch.tensor([0.])
    m.weight = torch.nn.Parameter(fixed_weight)
    m.bias = torch.nn.Parameter(fixed_bias)
    x = torch.tensor([[150.], [127.], [-5.]])
    self.run_test(m, x)

def test_batchnorm_training(self):
class MyModule(torch.nn.Module):
def __init__(self):
Expand Down
6 changes: 6 additions & 0 deletions torch/onnx/symbolic_opset10.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,12 @@ def embedding_bag(g,

@parse_args("v", "v", "v", "i", "i")
def fake_quantize_per_tensor_affine(g, inputs, scale, zero_point, quant_min=-128, quant_max=127):
# NOTE: (0, 127) is a special case. PyTorch restricts activations to be in the range (0, 127).
# https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422
if (quant_min, quant_max) == (0, 127):
sym_help._onnx_opset_unsupported_detailed(
"fake_quantize_per_tensor_affine", 10, 13,
"Quantize range (0, 127) not supported, requires opset 13 Clip")
if (quant_min, quant_max) not in [(0, 255), (-128, 127)]:
raise RuntimeError(
"For (quant_min, quant_max), ONNX allows only (0, 255) and (-128, 127). "
Expand Down
27 changes: 17 additions & 10 deletions torch/onnx/symbolic_opset13.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from torch.onnx.symbolic_helper import parse_args, _unimplemented
from torch.onnx.symbolic_opset9 import (overload_by_arg_count, _maybe_cast_reduce_op_input,
nonzero, expand, zeros, ones, size, linear, conv2d,
relu)
relu, unused)
from torch.onnx.symbolic_opset11 import unsqueeze
from torch.onnx.utils import _add_block, _add_input_to_block, _add_output_to_block

Expand Down Expand Up @@ -132,33 +132,40 @@ def where(g, condition, self=None, other=None, _outputs=None):

@parse_args("v", "v", "v", "i", "i", "i")
def fake_quantize_per_channel_affine(g, inputs, scale, zero_point, axis, quant_min=-128, quant_max=127):
    """Emit QuantizeLinear -> (optional Clip) -> DequantizeLinear for per-channel fake quantization.

    Args:
        g: graph object used to create ops through ``g.op``.
        inputs: value to fake-quantize.
        scale: quantization scale, passed to both QuantizeLinear and DequantizeLinear.
        zero_point: zero point; cast to UINT8 when ``quant_min == 0``, otherwise to INT8.
        axis: forwarded as the ``axis`` attribute of both quantization ops.
        quant_min: lower bound of the quantized range.
        quant_max: upper bound of the quantized range.

    Returns:
        The output of the DequantizeLinear op.

    Raises:
        RuntimeError: if ``(quant_min, quant_max)`` is not one of
            (0, 255), (-128, 127) or (0, 127).
    """
    # NOTE: (0, 127) is allowed as special case. PyTorch restricts activations to be in the range (0, 127).
    # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422
    if (quant_min, quant_max) not in [(0, 255), (-128, 127), (0, 127)]:
        raise RuntimeError(
            "For (quant_min, quant_max), ONNX allows only (0, 127), (0, 255) and (-128, 127). "
            "Got ({}, {})".format(quant_min, quant_max))
    # ONNX defines zero_point to be int8 or uint8
    if quant_min == 0:
        zero_point = g.op("Cast", zero_point, to_i=torch.onnx.TensorProtoDataType.UINT8)
    else:
        zero_point = g.op("Cast", zero_point, to_i=torch.onnx.TensorProtoDataType.INT8)
    quantized = g.op("QuantizeLinear", inputs, scale, zero_point, axis_i=axis)
    if (quant_min, quant_max) == (0, 127):
        # The uint8 path covers (0, 255), so cap the quantized values at 127.
        # `unused(g)` fills Clip's min slot (presumably an absent optional input; confirm in opset9).
        quantized = g.op("Clip", quantized, unused(g), g.op("Constant", value_t=torch.tensor(127, dtype=torch.uint8)))
    return g.op("DequantizeLinear", quantized, scale, zero_point, axis_i=axis)

@parse_args("v", "v", "v", "i", "i")
def fake_quantize_per_tensor_affine(g, inputs, scale, zero_point, quant_min=-128, quant_max=127):
    """Emit QuantizeLinear -> (optional Clip) -> DequantizeLinear for per-tensor fake quantization.

    Args:
        g: graph object used to create ops through ``g.op``.
        inputs: value to fake-quantize.
        scale: quantization scale; cast to FLOAT when its scalar type is not "Float".
        zero_point: zero point; cast to UINT8 when ``quant_min == 0``, otherwise to INT8.
        quant_min: lower bound of the quantized range.
        quant_max: upper bound of the quantized range.

    Returns:
        The output of the DequantizeLinear op.

    Raises:
        RuntimeError: if ``(quant_min, quant_max)`` is not one of
            (0, 255), (-128, 127) or (0, 127).
    """
    # NOTE: (0, 127) is allowed as special case. PyTorch restricts activations to be in the range (0, 127).
    # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422
    if (quant_min, quant_max) not in [(0, 255), (-128, 127), (0, 127)]:
        raise RuntimeError(
            "For (quant_min, quant_max), ONNX allows only (0, 127), (0, 255) and (-128, 127). "
            "Got ({}, {})".format(quant_min, quant_max))
    if quant_min == 0:
        zero_point = g.op("Cast", zero_point, to_i=torch.onnx.TensorProtoDataType.UINT8)
    else:
        zero_point = g.op("Cast", zero_point, to_i=torch.onnx.TensorProtoDataType.INT8)
    if scale.type().scalarType() != "Float":
        scale = g.op("Cast", scale, to_i=torch.onnx.TensorProtoDataType.FLOAT)
    quantized = g.op("QuantizeLinear", inputs, scale, zero_point)
    if (quant_min, quant_max) == (0, 127):
        # The uint8 path covers (0, 255), so cap the quantized values at 127.
        # `unused(g)` fills Clip's min slot (presumably an absent optional input; confirm in opset9).
        quantized = g.op("Clip", quantized, unused(g), g.op("Constant", value_t=torch.tensor(127, dtype=torch.uint8)))
    return g.op("DequantizeLinear", quantized, scale, zero_point)

def _reduce_op_symbolic(onnx_op_name):
def symbolic(g, self, dim=None, keepdim=None):
Expand Down

0 comments on commit 8d31706

Please sign in to comment.