From f175ba75fb7fa37680a475dc4bb37cfada4cc00f Mon Sep 17 00:00:00 2001 From: jason Date: Wed, 12 Jan 2022 13:47:37 +0900 Subject: [PATCH] rewrite Pattern_3.pattern_condition_checker to filter new test case: Conv + Mul(w/ one init) + Add(w/o init) --- .../onnx/transformer/fuse_bn_into_conv.py | 16 +++------- .../transformer/test_fuse_bn_into_conv.py | 31 ++++++++++++++++++- 2 files changed, 34 insertions(+), 13 deletions(-) diff --git a/python/furiosa-quantizer/furiosa/quantizer/frontend/onnx/transformer/fuse_bn_into_conv.py b/python/furiosa-quantizer/furiosa/quantizer/frontend/onnx/transformer/fuse_bn_into_conv.py index a0e3a9c2..3e37ee3b 100644 --- a/python/furiosa-quantizer/furiosa/quantizer/frontend/onnx/transformer/fuse_bn_into_conv.py +++ b/python/furiosa-quantizer/furiosa/quantizer/frontend/onnx/transformer/fuse_bn_into_conv.py @@ -230,18 +230,10 @@ def pattern_matching(self, base_node): def pattern_condition_checker(self, nodes_to_check): _, mul_node, add_node = nodes_to_check - if self.check_condition_1(mul_node): - return True - - if self.check_condition_1(add_node): - return True - - return False - - def check_condition_1(self, node): - if self.get_init_node_input(node): - return True - return False + # This checks if a node has a initializer, \ + # assuming a node has exactly one initializer if it has one. \ + # That is, there is no node with two initialzier. + return self.get_init_node_input(mul_node) and self.get_init_node_input(add_node) def make_new_node(self, matched_nodes): top_node, middle_node, bottom_node = matched_nodes diff --git a/python/furiosa-quantizer/tests/frontend/onnx/transformer/test_fuse_bn_into_conv.py b/python/furiosa-quantizer/tests/frontend/onnx/transformer/test_fuse_bn_into_conv.py index 0545d26d..68a7e8a1 100644 --- a/python/furiosa-quantizer/tests/frontend/onnx/transformer/test_fuse_bn_into_conv.py +++ b/python/furiosa-quantizer/tests/frontend/onnx/transformer/test_fuse_bn_into_conv.py @@ -60,6 +60,21 @@ def forward(self, x): return x +class UnitTestModel3_1(UnitTestModel): + """ + This creates Conv + Mul + Add graph for testing Pattern_3 + """ + + def __init__(self, in_channel, out_channel): + super().__init__(in_channel, out_channel) + + def forward(self, x): + x = self.conv(x) + x = torch.mul(x, torch.ones((1, x.shape[1], 1, 1))) + x = torch.add(x, x) + return x + + class MultiTestModel(UnitTestModel): def __init__(self, in_channel, out_channel): super(MultiTestModel, self).__init__(in_channel, out_channel) @@ -163,7 +178,7 @@ def test_case5(self): self.check_output_value(orig_model, trans_model, input_shapes) self.check_value_info(trans_model) - def test_case5(self): + def test_case5_1(self): """ This tests Pattern_3 """ @@ -179,3 +194,17 @@ def test_case5(self): self.check_graph_node(trans_model, op_types) self.check_output_value(orig_model, trans_model, input_shapes) self.check_value_info(trans_model) + + def test_case5_2(self): + input_shapes = [(1, 4, 4, 4)] + in_channel = 4 + out_channel = 8 + + op_types = ['Conv', 'Mul', 'Add'] + + orig_model, trans_model = self._make_test_model( + UnitTestModel3_1(in_channel, out_channel), input_shapes + ) + self.check_graph_node(trans_model, op_types) + self.check_output_value(orig_model, trans_model, input_shapes) + self.check_value_info(trans_model)