Skip to content

Commit

Permalink
Fix Graph builder for higher order ops.
Browse files Browse the repository at this point in the history
Differential Revision: D68231732

Pull Request resolved: #7684
  • Loading branch information
hsharma35 authored Jan 16, 2025
1 parent af7613c commit 745f17e
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 2 deletions.
19 changes: 18 additions & 1 deletion backends/cadence/aot/graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from typing import Optional, Sequence, Union

import torch
from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue
from executorch.exir.pass_base import ExportPass, NodeMetadata, PassResult, ProxyValue
from torch._dispatch.python import enable_python_dispatcher
from torch._subclasses import FakeTensor, FakeTensorMode
from torch.fx.node import Argument, Target
from torch.utils import _pytree as pytree
Expand Down Expand Up @@ -80,6 +81,22 @@ def call_operator(
kwargs = {}
return super().call_operator(op, args, kwargs, meta)

def call_submodule(
self, graph_module: torch.fx.GraphModule, inputs: tuple[Argument, ...]
) -> PassResult:
return ExportPass().call(graph_module)

def _fx(
self,
kind: str,
target: torch.fx.node.Target,
args: tuple[Argument, ...],
kwargs: dict[str, Argument],
meta: NodeMetadata,
) -> ProxyValue:
with self.fake_tensor_mode, enable_python_dispatcher():
return super()._fx(kind, target, args, kwargs, meta)


def single_op_builder(
placeholders: Sequence[Union[torch.Tensor, FakeTensor]],
Expand Down
33 changes: 32 additions & 1 deletion backends/cadence/aot/tests/test_graph_builder.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

# pyre-strict


from typing import Sequence

import executorch.backends.cadence.aot.ops_registrations # noqa
import torch
Expand All @@ -9,7 +13,7 @@
)
from executorch.backends.cadence.aot.pass_utils import count_node
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass
from executorch.exir.pass_base import ExportPass, NodeMetadata
from later.unittest import TestCase


Expand Down Expand Up @@ -68,3 +72,30 @@ def test_graph_with_single_im2row(self) -> None:
# Check graph has a single im2row node.
self.assertEqual(len([gm.graph.nodes]), 1)
self.assertEqual(count_node(gm, exir_ops.edge.cadence.im2row.default), 1)


class TestHigherOrderOps(TestCase):
def _get_inner_graph(self, x_shape: Sequence[int]) -> torch.fx.GraphModule:
builder = GraphBuilder()
x = builder.placeholder("x", torch.randn(*x_shape))
add = builder.call_operator(
exir_ops.edge.aten.add.Tensor,
(x, x), # pyre-ignore
)
builder.output([x, add])
gm = builder.get_graph_module()
# Check if graph module is valid by running exportpass on it.
gm = ExportPass().call(gm).graph_module
return gm

def test_call_map(self) -> None:
builder = GraphBuilder()
x_shape = (4, 8, 8)
x = builder.placeholder("x", torch.randn(*x_shape))
map_node = builder.call_map(
self._get_inner_graph(x_shape[1:]), [x], [], NodeMetadata({})
)
builder.output([map_node])
gm = builder.get_graph_module()
# Check if graph module is valid by running exportpass on it.
ExportPass().call(gm).graph_module

0 comments on commit 745f17e

Please sign in to comment.