Skip to content

Commit

Permalink
rewrite Pattern_3.pattern_condition_checker to filter new test case: …
Browse files Browse the repository at this point in the history
…Conv + Mul(w/ one init) + Add(w/o init)
  • Loading branch information
deeplearningfromscratch authored and hyunsik committed Jan 13, 2022
1 parent 2a8fd20 commit f175ba7
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -230,18 +230,10 @@ def pattern_matching(self, base_node):
def pattern_condition_checker(self, nodes_to_check):
_, mul_node, add_node = nodes_to_check

if self.check_condition_1(mul_node):
return True

if self.check_condition_1(add_node):
return True

return False

def check_condition_1(self, node):
if self.get_init_node_input(node):
return True
return False
# This checks if a node has a initializer, \
# assuming a node has exactly one initializer if it has one. \
# That is, there is no node with two initialzier.
return self.get_init_node_input(mul_node) and self.get_init_node_input(add_node)

def make_new_node(self, matched_nodes):
top_node, middle_node, bottom_node = matched_nodes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,21 @@ def forward(self, x):
return x


class UnitTestModel3_1(UnitTestModel):
"""
This creates Conv + Mul + Add graph for testing Pattern_3
"""

def __init__(self, in_channel, out_channel):
super().__init__(in_channel, out_channel)

def forward(self, x):
x = self.conv(x)
x = torch.mul(x, torch.ones((1, x.shape[1], 1, 1)))
x = torch.add(x, x)
return x


class MultiTestModel(UnitTestModel):
def __init__(self, in_channel, out_channel):
super(MultiTestModel, self).__init__(in_channel, out_channel)
Expand Down Expand Up @@ -163,7 +178,7 @@ def test_case5(self):
self.check_output_value(orig_model, trans_model, input_shapes)
self.check_value_info(trans_model)

def test_case5(self):
def test_case5_1(self):
"""
This tests Pattern_3
"""
Expand All @@ -179,3 +194,17 @@ def test_case5(self):
self.check_graph_node(trans_model, op_types)
self.check_output_value(orig_model, trans_model, input_shapes)
self.check_value_info(trans_model)

def test_case5_2(self):
input_shapes = [(1, 4, 4, 4)]
in_channel = 4
out_channel = 8

op_types = ['Conv', 'Mul', 'Add']

orig_model, trans_model = self._make_test_model(
UnitTestModel3_1(in_channel, out_channel), input_shapes
)
self.check_graph_node(trans_model, op_types)
self.check_output_value(orig_model, trans_model, input_shapes)
self.check_value_info(trans_model)

0 comments on commit f175ba7

Please sign in to comment.