Skip to content

Commit

Permalink
Print delegation info in export_llama in verbose (#7803)
Browse files Browse the repository at this point in the history
Co-authored-by: Martin Yuan <[email protected]>
  • Loading branch information
iseeyuan and Martin Yuan authored Jan 22, 2025
1 parent 74aace6 commit 99912cd
Showing 1 changed file with 12 additions and 0 deletions.
12 changes: 12 additions & 0 deletions examples/models/llama/export_llama_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import pkg_resources
import torch
from executorch.devtools.backend_debug import get_delegation_info

from executorch.devtools.etrecord import generate_etrecord
from executorch.exir.passes.init_mutable_pass import InitializedMutableBufferPass
Expand All @@ -43,6 +44,7 @@
get_vulkan_quantizer,
)
from executorch.util.activation_memory_profiler import generate_memory_trace
from tabulate import tabulate

from ..model_factory import EagerModelFactory
from .source_transformation.apply_spin_quant_r1_r2 import (
Expand Down Expand Up @@ -777,6 +779,12 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
for partitioner in partitioners:
logging.info(f"--> {partitioner.__class__.__name__}")

def print_delegation_info(graph_module: torch.fx.GraphModule):
delegation_info = get_delegation_info(graph_module)
print(delegation_info.get_summary())
df = delegation_info.get_operator_delegation_dataframe()
print(tabulate(df, headers="keys", tablefmt="fancy_grid"))

additional_passes = []
if args.model in TORCHTUNE_DEFINED_MODELS:
additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])]
Expand All @@ -788,6 +796,8 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
# Copy the edge manager which will be serialized into etrecord. This is memory-wise expensive.
edge_manager_copy = copy.deepcopy(builder_exported_to_edge.edge_manager)
builder = builder_exported_to_edge.to_backend(partitioners)
if args.verbose:
print_delegation_info(builder.edge_manager.exported_program().graph_module)
if args.num_sharding > 0 and args.qnn:
from executorch.backends.qualcomm.utils.utils import canonicalize_program

Expand All @@ -808,6 +818,8 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
logging.info("Generated etrecord.bin")
else:
builder = builder_exported_to_edge.to_backend(partitioners)
if args.verbose:
print_delegation_info(builder.edge_manager.exported_program().graph_module)
if args.num_sharding > 0 and args.qnn:
from executorch.backends.qualcomm.utils.utils import canonicalize_program

Expand Down

0 comments on commit 99912cd

Please sign in to comment.