From 99912cda4a96b6b8f34cbccdac7851797a1ae197 Mon Sep 17 00:00:00 2001 From: Mengtao Yuan Date: Wed, 22 Jan 2025 13:29:52 -0800 Subject: [PATCH] Print delegation info in export_llama in verbose (#7803) Co-authored-by: Martin Yuan --- examples/models/llama/export_llama_lib.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index 69980990cf..6d4e1de0c0 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -21,6 +21,7 @@ import pkg_resources import torch +from executorch.devtools.backend_debug import get_delegation_info from executorch.devtools.etrecord import generate_etrecord from executorch.exir.passes.init_mutable_pass import InitializedMutableBufferPass @@ -43,6 +44,7 @@ get_vulkan_quantizer, ) from executorch.util.activation_memory_profiler import generate_memory_trace +from tabulate import tabulate from ..model_factory import EagerModelFactory from .source_transformation.apply_spin_quant_r1_r2 import ( @@ -777,6 +779,12 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901 for partitioner in partitioners: logging.info(f"--> {partitioner.__class__.__name__}") + def print_delegation_info(graph_module: torch.fx.GraphModule): + delegation_info = get_delegation_info(graph_module) + print(delegation_info.get_summary()) + df = delegation_info.get_operator_delegation_dataframe() + print(tabulate(df, headers="keys", tablefmt="fancy_grid")) + additional_passes = [] if args.model in TORCHTUNE_DEFINED_MODELS: additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])] @@ -788,6 +796,8 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901 # Copy the edge manager which will be serialized into etrecord. This is memory-wise expensive. edge_manager_copy = copy.deepcopy(builder_exported_to_edge.edge_manager) builder = builder_exported_to_edge.to_backend(partitioners) + if args.verbose: + print_delegation_info(builder.edge_manager.exported_program().graph_module) if args.num_sharding > 0 and args.qnn: from executorch.backends.qualcomm.utils.utils import canonicalize_program @@ -808,6 +818,8 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901 logging.info("Generated etrecord.bin") else: builder = builder_exported_to_edge.to_backend(partitioners) + if args.verbose: + print_delegation_info(builder.edge_manager.exported_program().graph_module) if args.num_sharding > 0 and args.qnn: from executorch.backends.qualcomm.utils.utils import canonicalize_program