Quantize compatible node + activation patterns as one block #7555

Open. Wants to merge 1 commit into main.
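Summary, inferred from the diff (the PR body is not shown on this page): when a ReLU, or a hardtanh with min_val == 0, feeds a quantize op whose zero-point equals qmin, the quantize op's clamp at qmin already enforces the activation's lower bound, so the activation node is redundant. This change makes the Arm quantizer annotate conv1d/conv2d/linear followed by such an activation as one quantized block, and adds a FuseQuantizedActivationPass that drops the activation during lowering:

conv2d -> relu -> quantize(zero_point == qmin)   # before
conv2d -> quantize(zero_point == qmin)           # after fusion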
4 changes: 4 additions & 0 deletions backends/arm/_passes/arm_pass_manager.py
@@ -37,6 +37,9 @@
    QuantizeFullArgument,
    RetraceFoldedDtypesPass,
)
from executorch.backends.arm._passes.fuse_quantized_activation_pass import (
    FuseQuantizedActivationPass,
)
from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass
from executorch.backends.arm._passes.keep_dims_false_to_squeeze_pass import (
    KeepDimsFalseToSqueezePass,
@@ -72,6 +75,7 @@ def transform_to_backend_pipeline(
        self, exported_program: ExportedProgram, compile_spec: list[CompileSpec]
    ):
        """Apply passes before transforming program to backend"""
        self.add_pass(FuseQuantizedActivationPass())
        self.add_pass(DecomposeLinearPass())
        self.add_pass(RemoveGetItemPass())
        self.add_pass(DecomposeLayerNormPass())
60 changes: 60 additions & 0 deletions backends/arm/_passes/fuse_quantized_activation_pass.py
@@ -0,0 +1,60 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.backends.arm.tosa_quant_utils import q_op
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from torch.fx import Node


class FuseQuantizedActivationPass(ExportPass):
    def _is_fuseable_quantized_activation(self, node: Node):
        """Fuse activations that have a lower bound of 0 and are followed by a
        quantize op whose zero-point equals qmin."""
        is_fuseable = node.target == exir_ops.edge.aten.relu.default
        if node.target == exir_ops.edge.aten.hardtanh.default:
            min_val = node.args[1]
            is_fuseable = min_val == 0

        is_quantized = len(node.users) == 1 and next(iter(node.users)).target == q_op
        if is_quantized:
            quant_node = next(iter(node.users))
            zp = quant_node.args[2]
            qmin = quant_node.args[3]

        # `and` short-circuits, so zp/qmin are only read when is_quantized is
        # True and therefore bound above.
        return is_fuseable and is_quantized and zp == qmin

    def _is_fuseable_input(self, node: Node):
        return (
            node.target
            in (
                exir_ops.edge.aten.convolution.default,
                exir_ops.edge.aten.linear.default,
            )
            and len(node.users) == 1
        )

    def call(self, graph_module: torch.fx.GraphModule):
        modified = False
        for node in graph_module.graph.nodes:
            if node.op != "call_function":
                continue

            if not self._is_fuseable_quantized_activation(node):
                continue

            input_node = node.args[0]
            if not self._is_fuseable_input(input_node):
                continue

            node.replace_all_uses_with(input_node)
            graph_module.graph.erase_node(node)
            modified = True

        if modified:
            graph_module.recompile()
            graph_module = super().call(graph_module).graph_module

        return PassResult(graph_module, modified)
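The rewrite above relies on an identity of affine quantization: when zero_point == qmin, clamping to [qmin, qmax] zeroes out everything a ReLU (or hardtanh with a lower bound of 0) would have removed. A minimal self-contained check of that identity, in plain PyTorch with hypothetical uint8 parameters (not the executorch q_op itself):

import torch

def quantize(x, scale, zero_point, qmin, qmax):
    # Affine quantization: scale, round, shift by the zero-point, clamp.
    return torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax)

x = torch.randn(1000)
scale, zero_point, qmin, qmax = 0.02, 0, 0, 255  # hypothetical values

# With zero_point == qmin, quantize(relu(x)) == quantize(x): the clamp at
# qmin maps every negative input to the same value the ReLU would produce.
assert torch.equal(
    quantize(torch.relu(x), scale, zero_point, qmin, qmax),
    quantize(x, scale, zero_point, qmin, qmax),
)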
66 changes: 65 additions & 1 deletion backends/arm/quantizer/quantization_annotator.py
@@ -89,6 +89,41 @@ def _annotate_output(node: Node, quant_property: _QuantProperty):
    _annotate_output_qspec(node, quant_property.qspec)


def _match_pattern(
    node: Node, pattern: List[List], filter_fn: Optional[Callable[[Node], bool]] = None
) -> bool:
    """
    Check whether `node` is part of a parent -> child chain matching the chain
    provided in 'pattern'; `node` may be either the parent or the child. If
    'filter_fn' is provided, check that all nodes in the chain pass the
    filtering.

    Each element of 'pattern' is a list of disjunctive node types.
    """
    assert len(pattern) == 2, "Only two-node patterns are currently supported"

    if node.target in pattern[0]:
        assert len(node.users) != 0
        parent = node
        child = next(iter(node.users))
    elif node.target in pattern[1]:
        assert len(node.args) != 0
        parent = node.args[0]
        child = node
    else:
        return False

    if len(parent.users) != 1:
        return False

    if parent.target not in pattern[0] or child.target not in pattern[1]:
        return False

    if filter_fn is not None:
        return filter_fn(parent) and filter_fn(child)

    return True


_one_to_one = [
    torch.ops.aten.exp.default,
    torch.ops.aten.log.default,
@@ -164,7 +199,36 @@ def get_quant_properties(  # noqa: C901
    bias_qspec = quantization_config.get_bias_qspec()

    quant_properties = _OpQuantProperties()

    def any_or_hardtanh_min_zero(n: Node):
        # Accept any node; for hardtanh, additionally require min_val == 0
        return n.target != torch.ops.aten.hardtanh.default or n.args[1] == 0

    if _match_pattern(
        node,
        [
            [
                torch.ops.aten.conv1d.default,
                torch.ops.aten.conv2d.default,
                torch.ops.aten.linear.default,
            ],
            [torch.ops.aten.relu.default, torch.ops.aten.hardtanh.default],
        ],
        any_or_hardtanh_min_zero,
    ):
        if node.target in (
            torch.ops.aten.conv1d.default,
            torch.ops.aten.conv2d.default,
            torch.ops.aten.linear.default,
        ):
            quant_properties.quant_inputs = [
                _QuantProperty(0, input_act_qspec),
                _QuantProperty(1, weight_qspec, mark_annotated=True),
                _QuantProperty(2, bias_qspec, optional=True, mark_annotated=True),
            ]
        else:
            quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
    elif node.target in (
        torch.ops.aten.conv1d.default,
        torch.ops.aten.conv2d.default,
        torch.ops.aten.linear.default,
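For reference, a sketch of how the new _match_pattern helper behaves, assuming a checkout where the private helper is importable and a PyTorch version whose torch.export keeps conv as torch.ops.aten.conv2d.default (the target the pattern above uses); the module and shapes are illustrative only:

import torch

from executorch.backends.arm.quantizer.quantization_annotator import _match_pattern

class ConvRelu(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, kernel_size=3)

    def forward(self, x):
        return torch.relu(self.conv(x))

ep = torch.export.export(ConvRelu(), (torch.randn(1, 3, 16, 16),))
relu = next(n for n in ep.graph.nodes if n.target == torch.ops.aten.relu.default)

# Matches whether the helper is handed the parent (conv2d) or the child
# (relu): it walks one step up or down and checks both ends of the chain.
assert _match_pattern(
    relu,
    [
        [torch.ops.aten.conv2d.default],
        [torch.ops.aten.relu.default],
    ],
)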
7 changes: 4 additions & 3 deletions backends/arm/test/ops/test_conv_combos.py
@@ -137,10 +137,11 @@ class ComboConvRelu6(torch.nn.Module):
    ]

    test_data = [
        (20 * torch.randn(1, 3, 256, 256),),
        (5 * torch.randn(1, 3, 256, 256),),
        (2 * torch.randn(1, 3, 256, 256),),
        (0.5 * torch.randn(1, 3, 256, 256),),
        (torch.randn(1, 3, 256, 256),),
        (-5 * torch.randn(1, 3, 256, 256),),
        (-0.5 * torch.randn(1, 3, 256, 256),),
        (-2 * torch.randn(1, 3, 256, 256),),
    ]

    def __init__(self):