[quant][graphmode][fx] Add graph mode quantization on fx (pytorch#43175)

Summary:
Pull Request resolved: pytorch#43175

This PR adds graph mode quantization on fx: pytorch#42741
Currently it matches eager mode quantization for torchvision with static/dynamic/qat.
The DDP/SyncBN test is still a work in progress.
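
A minimal usage sketch of the new API, pieced together from the test added in this PR (the model and calibration input are hypothetical placeholders; for dynamic quantization, prepare_dynamic replaces prepare and no calibration run is needed):

    from torch.fx import symbolic_trace
    from torch.quantization import default_qconfig
    from torch.quantization._quantize_fx import Quantizer, fuse

    model = MyFloatModel().eval()          # hypothetical float model
    graph = symbolic_trace(model)          # trace into a torch.fx GraphModule
    graph = fuse(graph)                    # fuse conv-bn(-relu) / linear-relu patterns

    quantizer = Quantizer()
    prepared = quantizer.prepare(graph, {'': default_qconfig})  # insert observers
    prepared(calib_input)                  # run calibration data through the observers
    quantized = quantizer.convert(prepared)  # produce the quantized GraphModule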

Test Plan:
python test/test_quantization.py TestQuantizeFx

Imported from OSS

Reviewed By: vkuzo

Differential Revision: D23178602

fbshipit-source-id: 8e7e0322846fbda2cfa79ad188abd7235326f879
jerryzh168 authored and facebook-github-bot committed Aug 20, 2020
1 parent c89d2c6 commit dae2973
Showing 9 changed files with 1,158 additions and 0 deletions.
6 changes: 6 additions & 0 deletions mypy.ini
@@ -81,6 +81,12 @@ ignore_errors = True
[mypy-torch.quantization._numeric_suite]
ignore_errors = True

[mypy-torch.quantization._quantize_fx]
ignore_errors = True

[mypy-torch.quantization.fx.*]
ignore_errors = True

[mypy-torch.quasirandom]
ignore_errors = True

118 changes: 118 additions & 0 deletions test/quantization/test_quantize_fx.py
@@ -0,0 +1,118 @@
import torch
import torch.nn.functional as F

# symbolic trace
from torch.fx import symbolic_trace

# graph mode quantization based on fx
from torch.quantization._quantize_fx import (
Quantizer,
fuse,
)

# eager mode quantization
from torch.quantization import default_qconfig

# test utils
from torch.testing._internal.common_quantization import (
QuantizationTestCase,
skipIfNoFBGEMM
)

class TestQuantizeFx(QuantizationTestCase):
@skipIfNoFBGEMM
def test_functional(self):
""" Test quantizing functional conv and linear
"""
class Conv(torch.nn.Module):
def __init__(self):
super().__init__()
self.stride = (1, 1)
self.padding = (0, 0)
self.dilation = (1, 1)
self.groups = 1

def forward(self, x, weight):
return F.conv2d(x, weight, None, self.stride, self.padding, self.dilation, self.groups)

conv_input = torch.rand(1, 3, 224, 224)
conv_weight = torch.rand(3, 3, 3, 3)

class Linear(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x, weight):
return F.linear(x, weight)

linear_input = torch.rand(8, 5)
linear_weight = torch.rand(10, 5)

class LinearModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(5, 10)

def forward(self, x):
return self.linear(x)

linear_module_input = torch.rand(8, 5)

tests = [
(False, Conv, (conv_input, conv_weight), ('call_function', torch.ops.quantized.conv2d)),
(True, Linear, (linear_input, linear_weight), ('call_function', torch.ops.quantized.linear_dynamic)),
(False, Linear, (linear_input, linear_weight), ('call_function', torch.ops.quantized.linear)),
(True, LinearModule, (linear_module_input,), ('call_module', torch.nn.quantized.dynamic.Linear)),
(False, LinearModule, (linear_module_input,), ('call_module', torch.nn.quantized.Linear)),
]

for is_dynamic, M, inputs, quantized_node in tests:
m = M().eval()
qconfig = default_qconfig

graph = symbolic_trace(m)
script = torch.jit.script(graph)

a = m(*inputs)
b = graph(*inputs)
c = script(*inputs)
assert (a - b).abs().max() == 0
assert (a - c).abs().max() == 0
assert torch.allclose(a, b)
assert torch.allclose(a, c)


graph = fuse(graph)

quantizer = Quantizer()
qconfig_dict = {'': qconfig}
if is_dynamic:
prepared = quantizer.prepare_dynamic(graph, qconfig_dict)
else:
prepared = quantizer.prepare(graph, qconfig_dict)

prepared(*inputs)

qgraph = quantizer.convert(prepared)
qgraph_debug = quantizer.convert(prepared, debug=True)
qgraph.eval()
qgraph_debug.eval()
qgraph_script = torch.jit.script(qgraph)

d = qgraph(*inputs)
d_debug = qgraph_debug(*inputs)
e = qgraph_script(*inputs)
e_debug = qgraph_debug(*inputs)

found = False
modules = dict(qgraph.root.named_modules())
for node in qgraph.graph.nodes:
if node.op == 'call_function':
found = found or node.op == quantized_node[0] and node.target == quantized_node[1]
elif node.op == 'call_module':
found = found or node.op == quantized_node[0] and type(modules[node.target]) == quantized_node[1]
            assert found, 'Expected to find quantized node: ' + str(quantized_node)
# assert (a-d).abs().max() < 2
assert torch.allclose(d, e)
assert (d - d_debug).abs().max() == 0
assert (e - e_debug).abs().max() == 0
3 changes: 3 additions & 0 deletions test/test_quantization.py
@@ -60,6 +60,9 @@
from quantization.test_quantize_jit import TestQuantizeDynamicJitPasses # noqa: F401
from quantization.test_quantize_jit import TestQuantizeDynamicJitOps  # noqa: F401

# 3. GraphModule based graph mode quantization
from quantization.test_quantize_fx import TestQuantizeFx # noqa: F401

# Tooling: numeric_suite
from quantization.test_numeric_suite import TestEagerModeNumericSuite # noqa: F401

2 changes: 2 additions & 0 deletions torch/quantization/_quantize_fx.py
@@ -0,0 +1,2 @@
from .fx import Quantizer # noqa: F401
from .fx import fuse # noqa: F401
3 changes: 3 additions & 0 deletions torch/quantization/fx/__init__.py
@@ -0,0 +1,3 @@
from __future__ import absolute_import, division, print_function, unicode_literals
from .quantize import Quantizer
from .fuse import fuse
151 changes: 151 additions & 0 deletions torch/quantization/fx/fuse.py
@@ -0,0 +1,151 @@
import torch
from torch.quantization.fuse_modules import (
fuse_conv_bn,
fuse_conv_bn_relu,
)

from torch.fx import (
GraphModule,
)

from torch.fx.graph import (
Graph,
map_arg,
)

from .pattern_utils import (
matches,
register_fusion_pattern,
get_fusion_patterns,
)

from .utils import _parent_name

import copy

# Fusion Patterns
@register_fusion_pattern((torch.nn.BatchNorm2d, torch.nn.Conv2d))
@register_fusion_pattern((torch.nn.ReLU, torch.nn.Conv2d))
@register_fusion_pattern((torch.nn.functional.relu, torch.nn.Conv2d))
@register_fusion_pattern((torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.Conv2d)))
@register_fusion_pattern((torch.nn.functional.relu, (torch.nn.BatchNorm2d, torch.nn.Conv2d)))
class ConvBNReLUFusion():
def __init__(self, quantizer, node):
super().__init__()
self.relu_node = None
self.bn_node = None
if (node.op == 'call_function' and node.target is torch.nn.functional.relu) or \
(node.op == 'call_module' and type(quantizer.modules[node.target]) == torch.nn.ReLU):
self.relu_node = node
node = node.args[0]
assert node.op == 'call_module'
if isinstance(quantizer.modules[node.target], torch.nn.BatchNorm2d):
self.bn_node = node
self.bn = quantizer.modules[self.bn_node.target]
node = node.args[0]
assert node.op == 'call_module' and type(quantizer.modules[node.target]) == torch.nn.modules.Conv2d
self.conv_node = node
self.conv = quantizer.modules[self.conv_node.target]

def fuse(self, quantizer, load_arg):
conv_parent_name, conv_name = _parent_name(self.conv_node.target)
if self.relu_node is not None:
# since relu can be used multiple times, we'll need to create a relu module for each match
if self.relu_node.op == 'call_module':
relu = torch.nn.ReLU(quantizer.modules[self.relu_node.target].inplace)
else:
# TODO: get inplace argument from functional
relu = torch.nn.ReLU()
relu.training = self.conv.training
if self.bn_node is not None:
setattr(quantizer.modules[conv_parent_name], conv_name, fuse_conv_bn_relu(self.conv, self.bn, relu))
else:
# conv_relu
setattr(quantizer.modules[conv_parent_name], conv_name, torch.nn.intrinsic.ConvReLU2d(self.conv, relu))
else:
assert self.bn_node is not None
setattr(quantizer.modules[conv_parent_name], conv_name, fuse_conv_bn(self.conv, self.bn))

# TODO: do we need to make sure bn is only used once?
if self.bn_node is not None:
parent_name, name = _parent_name(self.bn_node.target)
setattr(quantizer.modules[parent_name], name, torch.nn.Identity())
return quantizer.fused_graph.node_copy(self.conv_node, load_arg)

@register_fusion_pattern((torch.nn.functional.relu, torch.nn.Linear))
@register_fusion_pattern((torch.nn.ReLU, torch.nn.Linear))
class LinearReLUFusion():
def __init__(self, quantizer, node):
super().__init__()
self.relu_node = node
node = node.args[0]
assert node.op == 'call_module'
assert isinstance(quantizer.modules[node.target], torch.nn.modules.Linear)
self.linear_node = node
self.linear = quantizer.modules[self.linear_node.target]

def fuse(self, quantizer, load_arg):
# since relu can be used multiple times, we'll need to create a relu module for each match
if self.relu_node.op == 'call_module':
relu = torch.nn.ReLU(quantizer.modules[self.relu_node.target].inplace)
else:
# TODO: get inplace argument from functional
relu = torch.nn.ReLU()
relu.training = self.linear.training
# linear_relu
linear_parent_name, linear_name = _parent_name(self.linear_node.target)
setattr(quantizer.modules[linear_parent_name], linear_name, torch.nn.intrinsic.LinearReLU(self.linear, relu))
return quantizer.fused_graph.node_copy(self.linear_node, load_arg)

class Fuser:
def fuse_conv_bn(self, model, inplace=False):
input_root = model.root
if not inplace:
input_root = copy.deepcopy(input_root)
input_graph = model.graph
self.modules = dict(input_root.named_modules())

fusion_patterns = get_fusion_patterns()
# find conv-bn pairs
conv_bn_pairs = self._find_matches(input_root, input_graph, fusion_patterns)
self.fused_graph = Graph()
env = {}

def load_arg(a):
return map_arg(a, lambda node: env[node.name])

for node in input_graph.nodes:
root_node, obj = conv_bn_pairs.get(node.name, (None, None))
if root_node is node:
env[node.name] = obj.fuse(self, load_arg)
elif root_node is None:
env[node.name] = self.fused_graph.node_copy(node, load_arg)
# node matched in patterns and is not root is removed here

self.fused_graph.output(load_arg(input_graph.result))
return GraphModule(input_root, self.fused_graph)

def _find_matches(self, root, graph, patterns):
modules = dict(root.named_modules())
match_map = {} # node name -> (root_node, match_value?)

def apply_match(pattern, node, match):
if isinstance(pattern, tuple):
s, *args = pattern
apply_match(s, node, match)
for subpattern, arg in zip(args, node.args):
apply_match(subpattern, arg, match)
else:
match_map[node.name] = match

for node in reversed(graph.nodes):
if node.name not in match_map:
for pattern, value in patterns.items():
if matches(modules, node, pattern):
apply_match(pattern, node, (node, value(self, node)))

return match_map

def fuse(graph_module, inplace=False):
fuser = Fuser()
return fuser.fuse_conv_bn(graph_module, inplace)
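
As a rough illustration of the new fuse entry point (a sketch only; SmallNet is a hypothetical module), a traced Conv2d -> BatchNorm2d -> ReLU chain is rewritten so that, in the returned GraphModule, the conv submodule holds the fused module (batch norm folded in, relu attached) and the bn submodule becomes an Identity; the original model is left untouched unless inplace=True:

    import torch
    from torch.fx import symbolic_trace
    from torch.quantization._quantize_fx import fuse

    class SmallNet(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, 3)
            self.bn = torch.nn.BatchNorm2d(3)
            self.relu = torch.nn.ReLU()

        def forward(self, x):
            return self.relu(self.bn(self.conv(x)))

    fused = fuse(symbolic_trace(SmallNet().eval()))
    modules = dict(fused.root.named_modules())
    print(type(modules['conv']))  # expected: torch.nn.intrinsic.ConvReLU2d
    print(type(modules['bn']))    # expected: torch.nn.Identity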
86 changes: 86 additions & 0 deletions torch/quantization/fx/pattern_utils.py
@@ -0,0 +1,86 @@

import torch
import sys
from collections import OrderedDict

# pattern for conv bn fusion
FUSION_PATTERNS = OrderedDict()
def register_fusion_pattern(pattern):
def insert(fn):
FUSION_PATTERNS[pattern] = fn
return fn
return insert

def get_fusion_patterns():
return FUSION_PATTERNS

# pattern for both static quantization and qat
QUANTIZATION_PATTERNS = OrderedDict()
def register_quant_pattern(pattern):
def insert(fn):
QUANTIZATION_PATTERNS[pattern] = fn
return fn
return insert

def get_quant_patterns():
return QUANTIZATION_PATTERNS

# pattern for dynamic quantization
DYNAMIC_QUANTIZATION_PATTERNS = OrderedDict()
def register_dynamic_pattern(pattern):
def insert(fn):
DYNAMIC_QUANTIZATION_PATTERNS[pattern] = fn
return fn
return insert

def get_dynamic_quant_patterns():
return DYNAMIC_QUANTIZATION_PATTERNS

# Example use of a register pattern decorator:
# @register_fusion_pattern((torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.Conv2d)))
# class ConvBNReLUFusion():
#     def __init__(...):
#         ...
#
# Note: The order of patterns is important! The match function takes whatever is matched first,
# so fusion patterns need to be registered before single-op patterns; for example, add_relu should
# be registered before relu. Decorators are applied in reverse order (bottom-up). Also, when
# matching the nodes in the graph against these patterns, we start from the last node of the
# graph and traverse backwards.


def matches(modules, node, pattern, max_uses=sys.maxsize):
""" Matches a node in fx against a pattern
"""
if isinstance(pattern, tuple):
self_match, *arg_matches = pattern
if self_match is getattr:
assert len(pattern) == 2, 'Expecting getattr pattern to have two elements'
arg_matches = []
else:
self_match = pattern
arg_matches = []

if node.uses > max_uses:
return False

if isinstance(self_match, type) and issubclass(self_match, torch.nn.Module):
if node.op != 'call_module':
return False
if not type(modules[node.target]) == self_match:
return False
elif callable(self_match):
if node.op != 'call_function' or node.target is not self_match:
return False
elif node.target is getattr:
if node.args[1] != pattern[1]:
return False
elif node.target != self_match:
return False

if not arg_matches:
return True

if len(arg_matches) != len(node.args):
return False

return all(matches(modules, node, arg_match, max_uses=1) for node, arg_match in zip(node.args, arg_matches))
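
To make the nested-tuple convention concrete (the node and attribute names here are illustrative), the pattern (torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.Conv2d)) matches a traced graph fragment of the form:

    conv = self.conv(x)    # call_module, type torch.nn.Conv2d
    bn = self.bn(conv)     # call_module, type torch.nn.BatchNorm2d
    relu = self.relu(bn)   # call_module, type torch.nn.ReLU  <- matching starts here

Matching begins at the outermost element of the tuple (the relu node, since traversal runs from the last node of the graph backwards) and recurses into node.args for the inner tuple; inner nodes are only accepted if they have a single use (max_uses=1).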