Skip to content

Commit

Permalink
Recurse partitioner through branches
Browse files Browse the repository at this point in the history
  • Loading branch information
dvorjackz committed Jan 6, 2025
1 parent 73591f1 commit ed5b952
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 12 deletions.
11 changes: 6 additions & 5 deletions backends/xnnpack/partition/xnnpack_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
ConfigerationBasedPartitioner,
)
from executorch.exir.backend.partitioner import DelegationSpec
from torch.fx import GraphModule
from torch.fx.passes.infra.partitioner import Partition

logging.basicConfig(level=logging.WARNING)
Expand Down Expand Up @@ -65,25 +66,25 @@ def __init__(
self.per_op_mode = per_op_mode
super().__init__(delegation_spec, initialized_configs)

def generate_partitions(self, ep: ExportedProgram) -> List[Partition]:
def generate_partitions(self, ep: ExportedProgram, gm: Optional[GraphModule] = None) -> List[Partition]:
"""
generate_partitions is different if partitioner is set to per_op_mode
for per_op_mode we only need to generate unmerged partitions instead
of using the default generate_partitions method.
"""
if self.per_op_mode:
return self.generate_per_op_partitions(ep)
return self.generate_per_op_partitions(ep, gm)
else:
return super().generate_partitions(ep)
return super().generate_partitions(ep, gm)

def generate_per_op_partitions(self, ep: ExportedProgram) -> List[Partition]:
def generate_per_op_partitions(self, ep: ExportedProgram, gm: Optional[GraphModule] = None) -> List[Partition]:
"""
Uses configs to generate per_op_partitions. That is no partitions are
merged together. All partitions (node + deps) returned by PartitionerConfigs
are put into their own partition.
"""
partitions = []
matched_nodes = self.get_matched_nodes_from_configs(ep)
matched_nodes = self.get_matched_nodes_from_configs(ep, gm)
partition_id = itertools.count()
nodes_seen = {}
for match in matched_nodes:
Expand Down
22 changes: 15 additions & 7 deletions exir/backend/canonical_partitioners/config_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
Partitioner,
PartitionResult,
)
from torch.fx import GraphModule
from torch.fx.passes.infra.partitioner import Partition
from executorch.exir.graph_module import get_control_flow_submodules


def format_target_name(target_name: str) -> str:
Expand Down Expand Up @@ -160,11 +162,11 @@ def filter_fn(node: torch.fx.Node) -> bool:
return (do_not_decomp, filter_fn)

def get_matched_nodes_from_configs(
self, ep: ExportedProgram
self, ep: ExportedProgram, gm: Optional[GraphModule] = None
) -> List[List[torch.fx.Node]]:
# gather supported nodes
matched_nodes = []
gm = ep.graph_module
gm = gm or ep.graph_module
for node in gm.graph.nodes:
if node.op == "call_function":
target = format_target_name(node.target.__name__)
Expand All @@ -175,17 +177,19 @@ def get_matched_nodes_from_configs(

return matched_nodes

def generate_partitions(self, ep: ExportedProgram) -> List[Partition]:
matched_nodes = self.get_matched_nodes_from_configs(ep)
def generate_partitions(self, ep: ExportedProgram, gm: Optional[GraphModule] = None) -> List[Partition]:
gm = gm or ep.graph_module
matched_nodes = self.get_matched_nodes_from_configs(ep, gm)
# create partitions
partitions = generate_partitions_from_list_of_nodes(
ep.graph_module,
gm,
matched_nodes,
)
return partitions

def partition(self, exported_program: ExportedProgram) -> PartitionResult:
partitions = self.generate_partitions(exported_program)
def partition(self, exported_program: ExportedProgram, graph_module: Optional[GraphModule] = None) -> PartitionResult:
graph_module = graph_module or exported_program.graph_module
partitions = self.generate_partitions(exported_program, graph_module)

# tag nodes
partition_tags: Dict[str, DelegationSpec] = {}
Expand All @@ -199,6 +203,10 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
node.meta["delegation_tag"] = delegation_tag
partition_tags[delegation_tag] = self.delegation_spec

for _, submodule, _ in get_control_flow_submodules(graph_module):
# pyre-ignore
self.partition(exported_program, submodule)

return PartitionResult(
tagged_exported_program=exported_program, partition_tags=partition_tags
)

0 comments on commit ed5b952

Please sign in to comment.