[quant][graphmode][fx] Add graph mode quantization on fx (pytorch#43175)
Summary:
Pull Request resolved: pytorch#43175

This PR adds graph mode quantization on fx: pytorch#42741

Currently it matches eager mode quantization for torchvision with static/dynamic/qat.
The ddp/syncbn test is still wip.

Test Plan:
python test/test_quantization.py TestQuantizeFx

Imported from OSS

Reviewed By: vkuzo

Differential Revision: D23178602

fbshipit-source-id: 8e7e0322846fbda2cfa79ad188abd7235326f879
1 parent c89d2c6, commit dae2973
Showing 9 changed files with 1,158 additions and 0 deletions.
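For orientation, the end-to-end flow this commit enables, as exercised by the new TestQuantizeFx test below, is roughly the following. This is an illustrative sketch based only on the APIs used in this diff (symbolic_trace, fuse, Quantizer); the small ConvBnReLU model and the input shape are hypothetical, and the interface should not be read as stable documentation.

# Sketch of the FX graph mode quantization flow added in this commit (post-training static case).
# `ConvBnReLU` is a made-up example module; everything else comes from this diff.
import torch
from torch.fx import symbolic_trace
from torch.quantization import default_qconfig
from torch.quantization._quantize_fx import Quantizer, fuse

class ConvBnReLU(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 3, 3)
        self.bn = torch.nn.BatchNorm2d(3)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))

m = ConvBnReLU().eval()
graph = symbolic_trace(m)                                    # trace into an fx GraphModule
graph = fuse(graph)                                          # fuse conv-bn(-relu) style patterns
quantizer = Quantizer()
prepared = quantizer.prepare(graph, {'': default_qconfig})   # insert observers
prepared(torch.rand(1, 3, 224, 224))                         # calibrate with example data
quantized = quantizer.convert(prepared)                      # rewrite to quantized ops/modules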
@@ -0,0 +1,118 @@
import torch
import torch.nn.functional as F

# symbolic trace
from torch.fx import symbolic_trace

# graph mode quantization based on fx
from torch.quantization._quantize_fx import (
    Quantizer,
    fuse,
)

# eager mode quantization
from torch.quantization import default_qconfig

# test utils
from torch.testing._internal.common_quantization import (
    QuantizationTestCase,
    skipIfNoFBGEMM
)

class TestQuantizeFx(QuantizationTestCase):
    @skipIfNoFBGEMM
    def test_functional(self):
        """ Test quantizing functional conv and linear
        """
        class Conv(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.stride = (1, 1)
                self.padding = (0, 0)
                self.dilation = (1, 1)
                self.groups = 1

            def forward(self, x, weight):
                return F.conv2d(x, weight, None, self.stride, self.padding, self.dilation, self.groups)

        conv_input = torch.rand(1, 3, 224, 224)
        conv_weight = torch.rand(3, 3, 3, 3)

        class Linear(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, weight):
                return F.linear(x, weight)

        linear_input = torch.rand(8, 5)
        linear_weight = torch.rand(10, 5)

        class LinearModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(5, 10)

            def forward(self, x):
                return self.linear(x)

        linear_module_input = torch.rand(8, 5)

        tests = [
            (False, Conv, (conv_input, conv_weight), ('call_function', torch.ops.quantized.conv2d)),
            (True, Linear, (linear_input, linear_weight), ('call_function', torch.ops.quantized.linear_dynamic)),
            (False, Linear, (linear_input, linear_weight), ('call_function', torch.ops.quantized.linear)),
            (True, LinearModule, (linear_module_input,), ('call_module', torch.nn.quantized.dynamic.Linear)),
            (False, LinearModule, (linear_module_input,), ('call_module', torch.nn.quantized.Linear)),
        ]

        for is_dynamic, M, inputs, quantized_node in tests:
            m = M().eval()
            qconfig = default_qconfig

            graph = symbolic_trace(m)
            script = torch.jit.script(graph)

            a = m(*inputs)
            b = graph(*inputs)
            c = script(*inputs)
            assert (a - b).abs().max() == 0
            assert (a - c).abs().max() == 0
            assert torch.allclose(a, b)
            assert torch.allclose(a, c)

            graph = fuse(graph)

            quantizer = Quantizer()
            qconfig_dict = {'': qconfig}
            if is_dynamic:
                prepared = quantizer.prepare_dynamic(graph, qconfig_dict)
            else:
                prepared = quantizer.prepare(graph, qconfig_dict)

            prepared(*inputs)

            qgraph = quantizer.convert(prepared)
            qgraph_debug = quantizer.convert(prepared, debug=True)
            qgraph.eval()
            qgraph_debug.eval()
            qgraph_script = torch.jit.script(qgraph)

            d = qgraph(*inputs)
            d_debug = qgraph_debug(*inputs)
            e = qgraph_script(*inputs)
            e_debug = qgraph_debug(*inputs)

            found = False
            modules = dict(qgraph.root.named_modules())
            for node in qgraph.graph.nodes:
                if node.op == 'call_function':
                    found = found or node.op == quantized_node[0] and node.target == quantized_node[1]
                elif node.op == 'call_module':
                    found = found or node.op == quantized_node[0] and type(modules[node.target]) == quantized_node[1]
            assert found, 'Expected to find quantized node: ' + str(quantized_node)
            # assert (a - d).abs().max() < 2
            assert torch.allclose(d, e)
            assert (d - d_debug).abs().max() == 0
            assert (e - e_debug).abs().max() == 0
@@ -0,0 +1,2 @@
from .fx import Quantizer  # noqa: F401
from .fx import fuse  # noqa: F401
@@ -0,0 +1,3 @@
from __future__ import absolute_import, division, print_function, unicode_literals
from .quantize import Quantizer
from .fuse import fuse
@@ -0,0 +1,151 @@
import torch
from torch.quantization.fuse_modules import (
    fuse_conv_bn,
    fuse_conv_bn_relu,
)

from torch.fx import (
    GraphModule,
)

from torch.fx.graph import (
    Graph,
    map_arg,
)

from .pattern_utils import (
    matches,
    register_fusion_pattern,
    get_fusion_patterns,
)

from .utils import _parent_name

import copy

# Fusion Patterns
@register_fusion_pattern((torch.nn.BatchNorm2d, torch.nn.Conv2d))
@register_fusion_pattern((torch.nn.ReLU, torch.nn.Conv2d))
@register_fusion_pattern((torch.nn.functional.relu, torch.nn.Conv2d))
@register_fusion_pattern((torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.Conv2d)))
@register_fusion_pattern((torch.nn.functional.relu, (torch.nn.BatchNorm2d, torch.nn.Conv2d)))
class ConvBNReLUFusion():
    def __init__(self, quantizer, node):
        super().__init__()
        self.relu_node = None
        self.bn_node = None
        if (node.op == 'call_function' and node.target is torch.nn.functional.relu) or \
                (node.op == 'call_module' and type(quantizer.modules[node.target]) == torch.nn.ReLU):
            self.relu_node = node
            node = node.args[0]
        assert node.op == 'call_module'
        if isinstance(quantizer.modules[node.target], torch.nn.BatchNorm2d):
            self.bn_node = node
            self.bn = quantizer.modules[self.bn_node.target]
            node = node.args[0]
        assert node.op == 'call_module' and type(quantizer.modules[node.target]) == torch.nn.modules.Conv2d
        self.conv_node = node
        self.conv = quantizer.modules[self.conv_node.target]

    def fuse(self, quantizer, load_arg):
        conv_parent_name, conv_name = _parent_name(self.conv_node.target)
        if self.relu_node is not None:
            # since relu can be used multiple times, we'll need to create a relu module for each match
            if self.relu_node.op == 'call_module':
                relu = torch.nn.ReLU(quantizer.modules[self.relu_node.target].inplace)
            else:
                # TODO: get inplace argument from functional
                relu = torch.nn.ReLU()
            relu.training = self.conv.training
            if self.bn_node is not None:
                setattr(quantizer.modules[conv_parent_name], conv_name, fuse_conv_bn_relu(self.conv, self.bn, relu))
            else:
                # conv_relu
                setattr(quantizer.modules[conv_parent_name], conv_name, torch.nn.intrinsic.ConvReLU2d(self.conv, relu))
        else:
            assert self.bn_node is not None
            setattr(quantizer.modules[conv_parent_name], conv_name, fuse_conv_bn(self.conv, self.bn))

        # TODO: do we need to make sure bn is only used once?
        if self.bn_node is not None:
            parent_name, name = _parent_name(self.bn_node.target)
            setattr(quantizer.modules[parent_name], name, torch.nn.Identity())
        return quantizer.fused_graph.node_copy(self.conv_node, load_arg)

@register_fusion_pattern((torch.nn.functional.relu, torch.nn.Linear))
@register_fusion_pattern((torch.nn.ReLU, torch.nn.Linear))
class LinearReLUFusion():
    def __init__(self, quantizer, node):
        super().__init__()
        self.relu_node = node
        node = node.args[0]
        assert node.op == 'call_module'
        assert isinstance(quantizer.modules[node.target], torch.nn.modules.Linear)
        self.linear_node = node
        self.linear = quantizer.modules[self.linear_node.target]

    def fuse(self, quantizer, load_arg):
        # since relu can be used multiple times, we'll need to create a relu module for each match
        if self.relu_node.op == 'call_module':
            relu = torch.nn.ReLU(quantizer.modules[self.relu_node.target].inplace)
        else:
            # TODO: get inplace argument from functional
            relu = torch.nn.ReLU()
        relu.training = self.linear.training
        # linear_relu
        linear_parent_name, linear_name = _parent_name(self.linear_node.target)
        setattr(quantizer.modules[linear_parent_name], linear_name, torch.nn.intrinsic.LinearReLU(self.linear, relu))
        return quantizer.fused_graph.node_copy(self.linear_node, load_arg)

class Fuser:
    def fuse_conv_bn(self, model, inplace=False):
        input_root = model.root
        if not inplace:
            input_root = copy.deepcopy(input_root)
        input_graph = model.graph
        self.modules = dict(input_root.named_modules())

        fusion_patterns = get_fusion_patterns()
        # find conv-bn pairs
        conv_bn_pairs = self._find_matches(input_root, input_graph, fusion_patterns)
        self.fused_graph = Graph()
        env = {}

        def load_arg(a):
            return map_arg(a, lambda node: env[node.name])

        for node in input_graph.nodes:
            root_node, obj = conv_bn_pairs.get(node.name, (None, None))
            if root_node is node:
                env[node.name] = obj.fuse(self, load_arg)
            elif root_node is None:
                env[node.name] = self.fused_graph.node_copy(node, load_arg)
            # nodes that matched a pattern but are not the root node are dropped here

        self.fused_graph.output(load_arg(input_graph.result))
        return GraphModule(input_root, self.fused_graph)

    def _find_matches(self, root, graph, patterns):
        modules = dict(root.named_modules())
        match_map = {}  # node name -> (root_node, match_value?)

        def apply_match(pattern, node, match):
            if isinstance(pattern, tuple):
                s, *args = pattern
                apply_match(s, node, match)
                for subpattern, arg in zip(args, node.args):
                    apply_match(subpattern, arg, match)
            else:
                match_map[node.name] = match

        for node in reversed(graph.nodes):
            if node.name not in match_map:
                for pattern, value in patterns.items():
                    if matches(modules, node, pattern):
                        apply_match(pattern, node, (node, value(self, node)))

        return match_map

def fuse(graph_module, inplace=False):
    fuser = Fuser()
    return fuser.fuse_conv_bn(graph_module, inplace)
@@ -0,0 +1,86 @@
import torch
import sys
from collections import OrderedDict

# pattern for conv bn fusion
FUSION_PATTERNS = OrderedDict()
def register_fusion_pattern(pattern):
    def insert(fn):
        FUSION_PATTERNS[pattern] = fn
        return fn
    return insert

def get_fusion_patterns():
    return FUSION_PATTERNS

# pattern for both static quantization and qat
QUANTIZATION_PATTERNS = OrderedDict()
def register_quant_pattern(pattern):
    def insert(fn):
        QUANTIZATION_PATTERNS[pattern] = fn
        return fn
    return insert

def get_quant_patterns():
    return QUANTIZATION_PATTERNS

# pattern for dynamic quantization
DYNAMIC_QUANTIZATION_PATTERNS = OrderedDict()
def register_dynamic_pattern(pattern):
    def insert(fn):
        DYNAMIC_QUANTIZATION_PATTERNS[pattern] = fn
        return fn
    return insert

def get_dynamic_quant_patterns():
    return DYNAMIC_QUANTIZATION_PATTERNS

# Example use of the register pattern functions:
# @register_fusion_pattern((torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.Conv2d)))
# class ConvBNReLUFusion():
#     def __init__(...):
#         ...
#
# Note: the order of the patterns is important! The match function takes whatever is matched first,
# so fusion patterns need to be registered before single patterns, e.g. add_relu must be registered
# before relu. Keep in mind that decorators are applied in the reverse of the order they are written.
# Also, when matching the nodes in the graph against these patterns, we start from the last node of
# the graph and traverse backwards.

def matches(modules, node, pattern, max_uses=sys.maxsize):
    """ Matches a node in fx against a pattern
    """
    if isinstance(pattern, tuple):
        self_match, *arg_matches = pattern
        if self_match is getattr:
            assert len(pattern) == 2, 'Expecting getattr pattern to have two elements'
            arg_matches = []
    else:
        self_match = pattern
        arg_matches = []

    if node.uses > max_uses:
        return False

    if isinstance(self_match, type) and issubclass(self_match, torch.nn.Module):
        if node.op != 'call_module':
            return False
        if not type(modules[node.target]) == self_match:
            return False
    elif callable(self_match):
        if node.op != 'call_function' or node.target is not self_match:
            return False
        elif node.target is getattr:
            if node.args[1] != pattern[1]:
                return False
    elif node.target != self_match:
        return False

    if not arg_matches:
        return True

    if len(arg_matches) != len(node.args):
        return False

    return all(matches(modules, node, arg_match, max_uses=1) for node, arg_match in zip(node.args, arg_matches))
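For illustration only, registering a custom fusion handler against the FUSION_PATTERNS registry above follows the same shape as the handlers in fuse.py in this commit. The (torch.nn.ReLU6, torch.nn.Conv2d) pattern, the ConvReLU6Fusion class, and the torch.quantization.fx.pattern_utils import path are assumptions made for this sketch (the path is inferred from the relative imports in this diff); the commit does not claim ReLU6 fusion support.

# Hypothetical sketch: registering a handler for a Conv2d followed by ReLU6
# (patterns are written outermost-first, as in the fuse.py handlers above).
# Mirrors ConvBNReLUFusion/LinearReLUFusion: __init__ receives the outermost
# matched node, fuse() rewrites modules on the quantizer and copies the root node.
import torch
from torch.quantization.fx.pattern_utils import register_fusion_pattern  # path assumed

@register_fusion_pattern((torch.nn.ReLU6, torch.nn.Conv2d))
class ConvReLU6Fusion():
    def __init__(self, quantizer, node):
        # `node` is the ReLU6 call_module node; its first argument is the conv node
        self.relu_node = node
        self.conv_node = node.args[0]
        self.conv = quantizer.modules[self.conv_node.target]

    def fuse(self, quantizer, load_arg):
        # A real handler must fold the ReLU6 into the replacement module here
        # (see ConvBNReLUFusion.fuse), because matched non-root nodes are dropped
        # from the fused graph by Fuser.fuse_conv_bn.
        return quantizer.fused_graph.node_copy(self.conv_node, load_arg)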