diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index ea4296cc52c..65bc8991a8d 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -23,6 +23,9 @@ import torch from executorch.devtools.etrecord import generate_etrecord +from executorch.exir.passes.cache_pos_init_mutable_pass import ( + CachePosToInitializedMutableBufferPass, +) from executorch.extension.llm.export.builder import DType, LLMEdgeManager @@ -760,6 +763,9 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901 for partitioner in partitioners: logging.info(f"--> {partitioner.__class__.__name__}") + additional_passes = [] + if args.model in TORCHTUNE_DEFINED_MODELS: + additional_passes = [CachePosToInitializedMutableBufferPass()] if args.generate_etrecord: if not builder_exported_to_edge.edge_manager: raise ValueError("Unable to generate etrecord due to missing edge manager.") @@ -774,7 +780,9 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901 # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`. canonicalize_program(builder.edge_manager.exported_program()) - builder = builder.to_executorch() + builder = builder.to_executorch( + passes=additional_passes, + ) # Generate ETRecord if edge_manager_copy: @@ -792,7 +800,7 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901 # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`. canonicalize_program(builder.edge_manager.exported_program()) - builder = builder.to_executorch() + builder = builder.to_executorch(passes=additional_passes) if args.profile_memory: generate_memory_trace(builder.export_program, "memory_profile.json") diff --git a/examples/models/llama3_2_vision/runner/native.py b/examples/models/llama3_2_vision/runner/native.py index 8180f1abbf0..105ddf20545 100644 --- a/examples/models/llama3_2_vision/runner/native.py +++ b/examples/models/llama3_2_vision/runner/native.py @@ -19,7 +19,6 @@ ) from executorch.extension.pybindings.portable_lib import ( - _load_for_executorch, _load_for_executorch_from_buffer, ) @@ -50,7 +49,6 @@ def __init__(self, args): with open(args.pte, "rb") as f: self.model_bytes = f.read() self.model = _load_for_executorch_from_buffer(self.model_bytes) - # self.model = _load_for_executorch(args.pte) self.use_kv_cache = args.kv_cache def forward( diff --git a/exir/emit/_emitter.py b/exir/emit/_emitter.py index 2ee6bb60b67..119fee3cc61 100644 --- a/exir/emit/_emitter.py +++ b/exir/emit/_emitter.py @@ -1607,7 +1607,6 @@ def placeholder( if isinstance(target, str) and isinstance(spec, TensorSpec): fqn, is_mutable_buffer = self._find_fqn_for_placeholder(target, spec) - print(f"fqn: {fqn}, is_mutable_buffer: {is_mutable_buffer}") # If the placeholder has a constant_tag, it is external to the PTE file # and requires a fqn and location=TensorDataLocation.EXTERNAL diff --git a/exir/passes/init_mutable_buffer_pass.py b/exir/passes/init_mutable_buffer_pass.py deleted file mode 100644 index 688410cc2f2..00000000000 --- a/exir/passes/init_mutable_buffer_pass.py +++ /dev/null @@ -1,21 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import torch - -from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue -from executorch.exir.passes.spec_prop_pass import make_spec - - -class InitMutableBufferPass(ExportPass): - def __init__(self) -> None: - super().__init__() - - def placeholder(self, name: str, arg, meta): - if "cache_pos" in name: - meta["et_init_buffer"] = True - - return super().placeholder(name, arg, meta) diff --git a/exir/program/_program.py b/exir/program/_program.py index e6247231f0a..fd1d0aca3dc 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -34,7 +34,6 @@ OpReplacePass, ) from executorch.exir.passes.external_constants_pass import external_constants_pass -from executorch.exir.passes.init_mutable_buffer_pass import InitMutableBufferPass from executorch.exir.passes.insert_write_back_for_buffers_pass import ( insert_write_back_for_buffers_pass, ) @@ -707,7 +706,6 @@ def edge_to_executorch_passes( passes: List[PassType] = [ *config.passes, SpecPropPass(), - InitMutableBufferPass(), # ExecuTorch backend ops are unable to handle unbacked symints. So after # this pass, passes cannot be Interpreter-based, because it will fail if # there exists an unbacked symint operation. diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py index 8bb98ebeaeb..23c38bc0ce0 100644 --- a/extension/llm/export/builder.py +++ b/extension/llm/export/builder.py @@ -25,6 +25,7 @@ from executorch.exir.backend.utils import format_delegated_graph from executorch.exir.capture._config import EdgeCompileConfig, ExecutorchBackendConfig +from executorch.exir.pass_manager import PassType from executorch.exir.passes import MemoryPlanningPass from executorch.exir.passes.quant_fusion_pass import QuantFusionPass from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass @@ -395,26 +396,29 @@ def to_backend(self, partitioners: Optional[List[Partitioner]]) -> "LLMEdgeManag return self - def to_executorch(self) -> "LLMEdgeManager": + def to_executorch(self, passes: Optional[List[PassType]]) -> "LLMEdgeManager": """ Lower the model to executorch and get an ExecutorchProgram. """ assert self.edge_manager, "Need to run export_to_edge() first" + to_executorch_passes = [ + # If there are Linear operations left in the graph, let's execute + # them with the optimized op_linear rather than materializing a + # transpose followed by a regular op_mm. + ConvertToLinearPass(), + QuantFusionPass(), + ] + if passes: + to_executorch_passes.extend(passes) + self.export_program = self.edge_manager.to_executorch( ExecutorchBackendConfig( extract_delegate_segments=True, - passes=[ - # If there are Linear operations left in the graph, let's execute - # them with the optimized op_linear rather than materializing a - # transpose followed by a regular op_mm. - ConvertToLinearPass(), - QuantFusionPass(), - ], + passes=passes, memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False), sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(), ) ) - print(self.export_program.dump_executorch_program(verbose=True)) logging.info( "Required memory for activation in bytes: {}".format( self.export_program._emitter_output.program.execution_plan[