Skip to content

Commit

Permalink
remove exir.capture from test_arg_validator (#2808)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2808

title

Reviewed By: Gasoonjia

Differential Revision: D55444947

fbshipit-source-id: 83acc77b5cce792bd7f6c9a0d36f34441f8ff756
  • Loading branch information
JacobSzwejbka authored and facebook-github-bot committed Apr 2, 2024
1 parent 7938344 commit 8ab6daf
Showing 1 changed file with 9 additions and 10 deletions.
19 changes: 9 additions & 10 deletions exir/tests/test_arg_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@
import unittest

import torch
from executorch import exir
from executorch.exir import EdgeCompileConfig
from executorch.exir import EdgeCompileConfig, to_edge
from executorch.exir.dialects._ops import ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from executorch.exir.verification.arg_validator import EdgeOpArgValidator
from torch.export import export


class TestArgValidator(unittest.TestCase):
Expand All @@ -31,11 +31,7 @@ def forward(self, x):

m = TestModel()
inputs = (torch.randn(1, 3, 100, 100).to(dtype=torch.int),)
egm = (
exir.capture(m, inputs, exir.CaptureConfig())
.to_edge(EdgeCompileConfig(_check_ir_validity=False))
.exported_program.graph_module
)
egm = to_edge(export(m, inputs)).exported_program().graph_module
validator = EdgeOpArgValidator(egm)
validator.run(*inputs)
self.assertEqual(len(validator.violating_ops), 0)
Expand All @@ -52,9 +48,12 @@ def forward(self, x):

inputs = (torch.randn(1, 3, 100, 100).to(dtype=torch.bfloat16),)
egm = (
exir.capture(M(), inputs, exir.CaptureConfig())
.to_edge(EdgeCompileConfig(_check_ir_validity=False))
.exported_program.graph_module
to_edge(
export(M(), inputs),
compile_config=EdgeCompileConfig(_check_ir_validity=False),
)
.exported_program()
.graph_module
)
validator = EdgeOpArgValidator(egm)
validator.run(*inputs)
Expand Down

0 comments on commit 8ab6daf

Please sign in to comment.