Keep fx node name consistent with aot_export (pytorch#107068)
torch.export() initially derives its node names from aot_export. Without this change, any no-op transformation would break name consistency and, with it, GraphSignature correctness.

Pull Request resolved: pytorch#107068
Approved by: https://github.com/tugsbayasgalan
gmagogsfm authored and pytorchmergebot committed Aug 12, 2023
1 parent 8472c24 commit f26aa2d
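
For illustration, a minimal sketch of the invariant this commit protects (not part of the commit; the module M, the chosen ops, and the torch._export import paths are assumptions based on this point in time and may differ in other releases): a no-op pass should leave node names, and therefore the GraphSignature's name-based references, unchanged.

import torch
from torch._export import export  # assumed export entry point at the time of this commit
from torch._export.pass_base import _ExportPassBase

class M(torch.nn.Module):
    def forward(self, x):
        return x.cos() + x.sin()

ep_before = export(M(), (torch.ones(3),))
# Re-tracing in a no-op pass used to rename nodes (e.g. "cos" -> "cos_default"),
# so any GraphSignature entry that refers to a node by name went stale.
ep_after = ep_before.transform(_ExportPassBase())
assert [n.name for n in ep_before.graph.nodes] == [n.name for n in ep_after.graph.nodes]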
Showing 4 changed files with 45 additions and 6 deletions.
2 changes: 1 addition & 1 deletion test/dynamo/test_export.py
@@ -2361,7 +2361,7 @@ def f(x):

with self.assertRaisesRegex(
RuntimeError,
r"_local_scalar_dense_default is outside of inline constraint \[4, 7\]",
r"_local_scalar_dense is outside of inline constraint \[4, 7\]",
) as cm:
ep(torch.tensor([30]))

34 changes: 34 additions & 0 deletions test/export/test_pass_infra.py
@@ -62,6 +62,40 @@ def false_fn(x, y):
        mod = M()
        _ = export(mod, (torch.tensor(True), x, y)).transform(_ExportPassBase())

    def test_node_name_stability(self) -> None:
        # Tests that graph node names stay the same for nodes that are not
        # touched during transformation
        class CustomModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

                # Define a parameter
                self.my_parameter = torch.nn.Parameter(torch.tensor(2.0))

                # Define two buffers
                self.register_buffer('my_buffer1', torch.tensor(3.0))
                self.register_buffer('my_buffer2', torch.tensor(4.0))

            def forward(self, x1, x2):
                # Use the parameter, buffers, and both inputs in the forward method
                output = (x1 + self.my_parameter) * self.my_buffer1 + x2 * self.my_buffer2

                # Mutate one of the buffers (e.g., increment it by 1)
                self.my_buffer2.add_(1.0)

                return output

        inps = (torch.rand(1), torch.rand(1))
        m = CustomModule()

        ep_before = export(m, inps)

        # A no-op transformation that doesn't make any meaningful changes to the nodes
        ep_after = ep_before.transform(_ExportPassBase())

        for before_node, after_node in zip(ep_before.graph.nodes, ep_after.graph.nodes):
            self.assertEqual(before_node.name, after_node.name)


if __name__ == '__main__':
    run_tests()
8 changes: 4 additions & 4 deletions test/export/test_passes.py
@@ -250,7 +250,7 @@ def forward(self, x):
mod = M()
ep = export(mod, (x,))

with self.assertRaisesRegex(RuntimeError, r"_local_scalar_dense_default is outside of inline constraint \[2, 5\]."):
with self.assertRaisesRegex(RuntimeError, r"_local_scalar_dense is outside of inline constraint \[2, 5\]."):
ep(torch.tensor([6]))

new_inp = torch.tensor([5])
@@ -278,10 +278,10 @@ def forward(self, x):
self.assertEqual(num_assert, 4)
self.assertEqual(num_scalar_tensor, 4)

with self.assertRaisesRegex(RuntimeError, r"nonzero_default.shape\[0\] is outside of inline constraint \[3, 5\]."):
with self.assertRaisesRegex(RuntimeError, r"nonzero.shape\[0\] is outside of inline constraint \[3, 5\]."):
ep(torch.tensor([1, 1, 0, 0, 0]))

with self.assertRaisesRegex(RuntimeError, r"nonzero_default.shape\[0\] is outside of inline constraint \[3, 5\]."):
with self.assertRaisesRegex(RuntimeError, r"nonzero.shape\[0\] is outside of inline constraint \[3, 5\]."):
ep(torch.ones(6))

new_inp = torch.tensor([1, 1, 1, 1])
@@ -361,7 +361,7 @@ def f(x):

with self.assertRaisesRegex(
RuntimeError,
r"_local_scalar_dense_default is outside of inline constraint \[4, 7\]",
r"_local_scalar_dense is outside of inline constraint \[4, 7\]",
) as cm:
gm(torch.tensor([20]))

7 changes: 6 additions & 1 deletion torch/_export/pass_base.py
@@ -256,7 +256,12 @@ def _fx(
args_proxy, kwargs_proxy = pytree.tree_map_only(
ProxyValue, lambda x: x.proxy, (args, kwargs)
)
- res_proxy = self.tracer.create_proxy(kind, target, args_proxy, kwargs_proxy)
+
+ name = None
+ if isinstance(target, torch._ops.OpOverload):
+     name = self.tracer.graph._target_to_str(target.overloadpacket.__name__)
+
+ res_proxy = self.tracer.create_proxy(kind, target, args_proxy, kwargs_proxy, name=name)
res_proxy.node.meta.update(meta.data)
self.tracer.set_metadata(res_proxy.node, res_data)
return ProxyValue(res_data, res_proxy)
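
As a side note, a small sketch of where the new name comes from (illustrative only; the printed value assumes the stock aten op registry): deriving the name from the overload packet keeps the ".default" overload suffix out of node names, matching aot_export.

import torch

op = torch.ops.aten._local_scalar_dense.default  # an OpOverload, as handled above
print(op.overloadpacket.__name__)  # "_local_scalar_dense" -- the name aot_export uses
# Before this change, the proxy name was derived from the full overload and came
# out as "_local_scalar_dense_default", diverging from aot_export.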
