Skip to content

Commit

Permalink
Quantize compatible node + activation patterns as one block
Browse files Browse the repository at this point in the history
Annotate conv1d/conv2d/linear followed by relu/relu6 patterns as one block and fuse the activation into its parent. The activation will then be implicitly done in the tosa.rescale node that will have a -128 zero-point.

Change-Id: I5bf1e2c91be21ab842012fbc20d159af7fe2222d
  • Loading branch information
Tessil committed Jan 8, 2025
1 parent 08770b7 commit 8df7f54
Show file tree
Hide file tree
Showing 4 changed files with 133 additions and 4 deletions.
4 changes: 4 additions & 0 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@
QuantizeFullArgument,
RetraceFoldedDtypesPass,
)
from executorch.backends.arm._passes.fuse_quantized_activation_pass import (
FuseQuantizedActivationPass,
)
from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass
from executorch.backends.arm._passes.keep_dims_false_to_squeeze_pass import (
KeepDimsFalseToSqueezePass,
Expand Down Expand Up @@ -72,6 +75,7 @@ def transform_to_backend_pipeline(
self, exported_program: ExportedProgram, compile_spec: list[CompileSpec]
):
"""Apply passes before transforming program to backend"""
self.add_pass(FuseQuantizedActivationPass())
self.add_pass(DecomposeLinearPass())
self.add_pass(RemoveGetItemPass())
self.add_pass(DecomposeLayerNormPass())
Expand Down
60 changes: 60 additions & 0 deletions backends/arm/_passes/fuse_quantized_activation_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.backends.arm.tosa_quant_utils import q_op
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from torch.fx import Node


class FuseQuantizedActivationPass(ExportPass):
def _is_fuseable_quantized_activation(self, node: Node):
"""Fuse activations that have a 0 lower bound and quantized with a qmin zero-point"""
is_fuseable = node.target == exir_ops.edge.aten.relu.default
if node.target == exir_ops.edge.aten.hardtanh.default:
min_val = node.args[1]
is_fuseable = min_val == 0

is_quantized = len(node.users) == 1 and next(iter(node.users)).target == q_op
if is_quantized:
quant_node = next(iter(node.users))
zp = quant_node.args[2]
qmin = quant_node.args[3]

return is_fuseable and is_quantized and zp == qmin

def _is_fuseable_input(self, node: Node):
return (
node.target
in (
exir_ops.edge.aten.convolution.default,
exir_ops.edge.aten.linear.default,
)
and len(node.users) == 1
)

def call(self, graph_module: torch.fx.GraphModule):
modified = False
for node in graph_module.graph.nodes:
if node.op != "call_function":
continue

if not self._is_fuseable_quantized_activation(node):
continue

input_node = node.args[0]
if not self._is_fuseable_input(input_node):
continue

node.replace_all_uses_with(input_node)
graph_module.graph.erase_node(node)
modified = True

if modified:
graph_module.recompile()
graph_module = super().call(graph_module).graph_module

return PassResult(graph_module, modified)
66 changes: 65 additions & 1 deletion backends/arm/quantizer/quantization_annotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,41 @@ def _annotate_output(node: Node, quant_property: _QuantProperty):
_annotate_output_qspec(node, quant_property.qspec)


def match_pattern(
node: Node, pattern: List[List], filter_fn: Optional[Callable[[Node], bool]] = None
) -> bool:
"""
Check if there's a chain of node.ancestors? -> node -> node.descendant? that matches the
chain provided in 'pattern'. If 'filter_fn' is provided, check that all the nodes in the
chain pass the filtering.
Each 'pattern' element is composed of a list of disjunctive nodes types.
"""
assert len(pattern) == 2, "Only two-nodes patterns supported currently"

if node.target in pattern[0]:
assert len(node.users) != 0
parent = node
child = next(iter(node.users))
elif node.target in pattern[1]:
assert len(node.args) != 0
parent = node.args[0]
child = node
else:
return False

if len(parent.users) != 1:
return False

if parent.target not in pattern[0] or child.target not in pattern[1]:
return False

if filter_fn is not None:
return filter_fn(parent) and filter_fn(child)

return True


_one_to_one = [
torch.ops.aten.exp.default,
torch.ops.aten.log.default,
Expand Down Expand Up @@ -164,7 +199,36 @@ def get_quant_properties( # noqa: C901
bias_qspec = quantization_config.get_bias_qspec()

quant_properties = _OpQuantProperties()
if node.target in (

def any_or_hardtanh_min_zero(n: Node):
# Check that if the node is a hardtanh, its min_val is zero
return n.target != torch.ops.aten.hardtanh.default or n.args[1] == 0

if match_pattern(
node,
[
[
torch.ops.aten.conv1d.default,
torch.ops.aten.conv2d.default,
torch.ops.aten.linear.default,
],
[torch.ops.aten.relu.default, torch.ops.aten.hardtanh.default],
],
any_or_hardtanh_min_zero,
):
if node.target in (
torch.ops.aten.conv1d.default,
torch.ops.aten.conv2d.default,
torch.ops.aten.linear.default,
):
quant_properties.quant_inputs = [
_QuantProperty(0, input_act_qspec),
_QuantProperty(1, weight_qspec, mark_annotated=True),
_QuantProperty(2, bias_qspec, optional=True, mark_annotated=True),
]
else:
quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
elif node.target in (
torch.ops.aten.conv1d.default,
torch.ops.aten.conv2d.default,
torch.ops.aten.linear.default,
Expand Down
7 changes: 4 additions & 3 deletions backends/arm/test/ops/test_conv_combos.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,11 @@ class ComboConvRelu6(torch.nn.Module):
]

test_data = [
(20 * torch.randn(1, 3, 256, 256),),
(5 * torch.randn(1, 3, 256, 256),),
(2 * torch.randn(1, 3, 256, 256),),
(0.5 * torch.randn(1, 3, 256, 256),),
(torch.randn(1, 3, 256, 256),),
(-5 * torch.randn(1, 3, 256, 256),),
(-0.5 * torch.randn(1, 3, 256, 256),),
(-2 * torch.randn(1, 3, 256, 256),),
]

def __init__(self):
Expand Down

0 comments on commit 8df7f54

Please sign in to comment.