From d8deda2e4373ad5c44bda168b333da6232533a9a Mon Sep 17 00:00:00 2001
From: Jack Zhang
Date: Mon, 6 Jan 2025 10:15:56 -0800
Subject: [PATCH] Export llama uses to_edge_transform_and_lower

Switch the llama export flow from the separate export_to_edge() and
to_backend() steps to a single to_edge_transform_and_lower() call, and add the
corresponding LLMEdgeManager.to_edge_transform_and_lower() helper that lowers
the pre-autograd exported program with the provided partitioners.
---
 examples/models/llama/export_llama_lib.py |  20 ++++--
 extension/llm/export/builder.py           | 100 +++++++++++-----------
 2 files changed, 72 insertions(+), 48 deletions(-)

diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index 4e004e773f..0552f8f6b7 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -659,11 +659,12 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
     if args.export_only:
         exit()
 
-    builder_exported_to_edge = builder_exported.pt2e_quantize(
-        quantizers
-    ).export_to_edge()
+    # builder_exported_to_edge = builder_exported.pt2e_quantize(
+    #     quantizers
+    # ).export_to_edge()
 
-    modelname = builder_exported_to_edge.modelname
+    # modelname = builder_exported_to_edge.modelname
+    modelname = builder_exported.modelname
 
     # to_backend
     partitioners = []
@@ -793,14 +794,19 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
         )
         logging.info("Generated etrecord.bin")
     else:
-        builder = builder_exported_to_edge.to_backend(partitioners)
+        builder_lowered = builder_exported.pt2e_quantize(
+            quantizers
+        ).to_edge_transform_and_lower(
+            partitioners
+        )
+        # builder = builder_exported_to_edge.to_backend(partitioners)
         if args.num_sharding > 0 and args.qnn:
             from executorch.backends.qualcomm.utils.utils import canonicalize_program
 
             # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
-            canonicalize_program(builder.edge_manager.exported_program())
+            canonicalize_program(builder_lowered.edge_manager.exported_program())
 
-        builder = builder.to_executorch()
+        builder = builder_lowered.to_executorch()
 
         if args.profile_memory:
             generate_memory_trace(builder.export_program, "memory_profile.json")
diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py
index 7cab3c77b8..cbbab8d02c 100644
--- a/extension/llm/export/builder.py
+++ b/extension/llm/export/builder.py
@@ -21,7 +21,7 @@
     DuplicateDynamicQuantChainPass,
 )
 from executorch.backends.xnnpack._passes.convert_to_linear import ConvertToLinearPass
-from executorch.exir import EdgeProgramManager
+from executorch.exir import EdgeProgramManager, to_edge_transform_and_lower
 from executorch.exir.backend.partitioner import Partitioner
 from executorch.exir.backend.utils import format_delegated_graph
 from executorch.exir.capture._config import EdgeCompileConfig, ExecutorchBackendConfig
@@ -216,6 +216,7 @@ def export(self) -> "LLMEdgeManager":
                 )
+            self.pre_autograd_exported_program = exported_module
             # pyre-fixme[8]: Attribute has type `Optional[GraphModule]`; used as
             #  `Module`.
             self.pre_autograd_graph_module = exported_module.module()
             if hasattr(self.args, "export_only") and self.args.export_only:
                 torch.export.save(exported_module, self.args.output_name)
@@ -305,51 +306,51 @@ def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManage
         ), "export_to_edge is already called, please call pt2e_quantize before export_to_edge"
         logging.info(f"Using pt2e {quantizers} to quantizing the model...")
 
+        if not quantizers:
+            logging.info("No quantizer provided, passing...")
+            return self
+
         # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
         # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
-        if quantizers:
-            with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
-                if self.verbose:
-                    logging.info(f"Applied quantizers: {quantizers}")
-                composed_quantizer = ComposableQuantizer(quantizers)
-                assert (
-                    self.pre_autograd_graph_module is not None
-                ), "Please run export() first"
-                m = prepare_pt2e(self.pre_autograd_graph_module, composed_quantizer)
+        with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
+            if self.verbose:
+                logging.info(f"Applied quantizers: {quantizers}")
+            composed_quantizer = ComposableQuantizer(quantizers)
+            assert (
+                self.pre_autograd_graph_module is not None
+            ), "Please run export() first"
+            m = prepare_pt2e(self.pre_autograd_graph_module, composed_quantizer)
+            logging.info(
+                f"Calibrating with tasks: {self.calibration_tasks}, limit: {self.calibration_limit}, calibration_data: {self.calibration_data}, tokenizer_path: {self.tokenizer_path}, seq_length: {self.calibration_seq_length}"
+            )
+            # Calibrate
+            if (
+                self.calibration_tasks is not None
+                and self.calibration_limit is not None
+                and self.calibration_seq_length is not None
+                and self.calibration_data is not None
+                and self.tokenizer_path is not None
+            ):
                 logging.info(
                     f"Calibrating with tasks: {self.calibration_tasks}, limit: {self.calibration_limit}, calibration_data: {self.calibration_data}, tokenizer_path: {self.tokenizer_path}, seq_length: {self.calibration_seq_length}"
                 )
-                # Calibrate
-                if (
-                    self.calibration_tasks is not None
-                    and self.calibration_limit is not None
-                    and self.calibration_seq_length is not None
-                    and self.calibration_data is not None
-                    and self.tokenizer_path is not None
-                ):
-                    logging.info(
-                        f"Calibrating with tasks: {self.calibration_tasks}, limit: {self.calibration_limit}, calibration_data: {self.calibration_data}, tokenizer_path: {self.tokenizer_path}, seq_length: {self.calibration_seq_length}"
-                    )
-                    self.pt2e_calibrate(
-                        prepared_module=m,
-                        calibration_tasks=self.calibration_tasks,
-                        calibration_limit=self.calibration_limit,
-                        calibration_seq_length=self.calibration_seq_length,
-                        calibration_data=self.calibration_data,
-                        tokenizer_path=self.tokenizer_path,
-                    )
-                else:
-                    logging.info(
-                        "No calibration provided, using dummy input to calibrate..."
-                    )
-                    m(*self.example_inputs)
-                m = convert_pt2e(m)
-                DuplicateDynamicQuantChainPass()(m)
-                self.pre_autograd_graph_module = m
-            return self
-        else:
-            logging.info("No quantizer provided, passing...")
-            return self
+                self.pt2e_calibrate(
+                    prepared_module=m,
+                    calibration_tasks=self.calibration_tasks,
+                    calibration_limit=self.calibration_limit,
+                    calibration_seq_length=self.calibration_seq_length,
+                    calibration_data=self.calibration_data,
+                    tokenizer_path=self.tokenizer_path,
+                )
+            else:
+                logging.info(
+                    "No calibration provided, using dummy input to calibrate..."
+                )
+                m(*self.example_inputs, **self.example_kwarg_inputs)
+            m = convert_pt2e(m)
+            DuplicateDynamicQuantChainPass()(m)
+            self.pre_autograd_graph_module = m
+        return self
 
     def export_to_edge(self) -> "LLMEdgeManager":
         """
@@ -415,6 +416,23 @@ def to_backend(self, partitioners: Optional[List[Partitioner]]) -> "LLMEdgeManag
 
         return self
 
+    def to_edge_transform_and_lower(
+        self, partitioners: Optional[List[Partitioner]]
+    ) -> "LLMEdgeManager":
+        """
+        Lower the pre-autograd exported program to edge IR and delegate it to the
+        given partitioners in a single to_edge_transform_and_lower call.
+        """
+        if partitioners is None:
+            logging.info("No partitioner provided, skipping backend delegation...")
+        edge_config = self._get_edge_config()
+        self.edge_manager = to_edge_transform_and_lower(
+            self.pre_autograd_exported_program,
+            partitioner=partitioners,
+            compile_config=edge_config,
+        )
+        return self
+
     def to_executorch(self) -> "LLMEdgeManager":
         """
         Lower the model to executorch and get an ExecutorchProgram.
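
Note: after this patch, the non-etrecord path in _export_llama reduces to the
chain sketched below. This is a minimal illustration, not part of the patch,
using only names that appear in the diff (builder_exported, quantizers,
partitioners); the earlier setup of those variables is assumed to be unchanged.

    # Sketch of the new lowering chain (illustrative only):
    # pt2e_quantize() returns early when no quantizers are supplied, and
    # to_edge_transform_and_lower() converts to edge IR and applies the
    # partitioners in one step before the final executorch lowering.
    builder = (
        builder_exported.pt2e_quantize(quantizers)
        .to_edge_transform_and_lower(partitioners)
        .to_executorch()
    )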