Export llama uses to_edge_transform_and_lower
jackzhxng committed Jan 6, 2025
1 parent 2600cc8 commit d8deda2
Showing 2 changed files with 67 additions and 47 deletions.
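
For context, a rough sketch of the flow change in _export_llama (illustrative only, not a verbatim excerpt; builder_exported, quantizers, and partitioners are the variables used in the diff below):

# Old flow: lower to edge, then delegate to backends in a separate to_backend step.
# builder = (
#     builder_exported.pt2e_quantize(quantizers)
#     .export_to_edge()
#     .to_backend(partitioners)
#     .to_executorch()
# )

# New flow in this commit: quantize, then transform and lower to the target
# backends in a single to_edge_transform_and_lower call on the exported program.
builder = (
    builder_exported.pt2e_quantize(quantizers)
    .to_edge_transform_and_lower(partitioners)
    .to_executorch()
)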
19 changes: 13 additions & 6 deletions examples/models/llama/export_llama_lib.py
@@ -659,11 +659,12 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
if args.export_only:
exit()

builder_exported_to_edge = builder_exported.pt2e_quantize(
quantizers
).export_to_edge()
# builder_exported_to_edge = builder_exported.pt2e_quantize(
# quantizers
# ).export_to_edge()

modelname = builder_exported_to_edge.modelname
# modelname = builder_exported_to_edge.modelname
modelname = builder_exported.modelname

# to_backend
partitioners = []
@@ -768,6 +769,7 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
for partitioner in partitioners:
logging.info(f"--> {partitioner.__class__.__name__}")

if args.generate_etrecord:
if not builder_exported_to_edge.edge_manager:
raise ValueError("Unable to generate etrecord due to missing edge manager.")
@@ -793,14 +795,19 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
)
logging.info("Generated etrecord.bin")
else:
builder = builder_exported_to_edge.to_backend(partitioners)
builder_lowered = builder_exported.pt2e_quantize(
quantizers
).to_edge_transform_and_lower(
partitioners
)
# builder = builder_exported_to_edge.to_backend(partitioners)
if args.num_sharding > 0 and args.qnn:
from executorch.backends.qualcomm.utils.utils import canonicalize_program

# pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
canonicalize_program(builder_lowered.edge_manager.exported_program())

builder = builder.to_executorch()
builder = builder_lowered.to_executorch()

if args.profile_memory:
generate_memory_trace(builder.export_program, "memory_profile.json")
95 changes: 54 additions & 41 deletions extension/llm/export/builder.py
@@ -21,7 +21,7 @@
DuplicateDynamicQuantChainPass,
)
from executorch.backends.xnnpack._passes.convert_to_linear import ConvertToLinearPass
from executorch.exir import EdgeProgramManager
from executorch.exir import EdgeProgramManager, to_edge_transform_and_lower
from executorch.exir.backend.partitioner import Partitioner

from executorch.exir.backend.utils import format_delegated_graph
@@ -216,6 +216,7 @@ def export(self) -> "LLMEdgeManager":
)
# pyre-fixme[8]: Attribute has type `Optional[GraphModule]`; used as
# `Module`.
self.pre_autograd_exported_program = exported_module
self.pre_autograd_graph_module = exported_module.module()
if hasattr(self.args, "export_only") and self.args.export_only:
torch.export.save(exported_module, self.args.output_name)
@@ -305,51 +306,51 @@ def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManager":
), "export_to_edge is already called, please call pt2e_quantize before export_to_edge"
logging.info(f"Using pt2e {quantizers} to quantizing the model...")

if not quantizers:
logging.info("No quantizer provided, passing...")
return self

# 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
# 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
if quantizers:
with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
if self.verbose:
logging.info(f"Applied quantizers: {quantizers}")
composed_quantizer = ComposableQuantizer(quantizers)
assert (
self.pre_autograd_graph_module is not None
), "Please run export() first"
m = prepare_pt2e(self.pre_autograd_graph_module, composed_quantizer)
with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
if self.verbose:
logging.info(f"Applied quantizers: {quantizers}")
composed_quantizer = ComposableQuantizer(quantizers)
assert (
self.pre_autograd_graph_module is not None
), "Please run export() first"
m = prepare_pt2e(self.pre_autograd_graph_module, composed_quantizer)
logging.info(
f"Calibrating with tasks: {self.calibration_tasks}, limit: {self.calibration_limit}, calibration_data: {self.calibration_data}, tokenizer_path: {self.tokenizer_path}, seq_length: {self.calibration_seq_length}"
)
# Calibrate
if (
self.calibration_tasks is not None
and self.calibration_limit is not None
and self.calibration_seq_length is not None
and self.calibration_data is not None
and self.tokenizer_path is not None
):
logging.info(
f"Calibrating with tasks: {self.calibration_tasks}, limit: {self.calibration_limit}, calibration_data: {self.calibration_data}, tokenizer_path: {self.tokenizer_path}, seq_length: {self.calibration_seq_length}"
)
# Calibrate
if (
self.calibration_tasks is not None
and self.calibration_limit is not None
and self.calibration_seq_length is not None
and self.calibration_data is not None
and self.tokenizer_path is not None
):
logging.info(
f"Calibrating with tasks: {self.calibration_tasks}, limit: {self.calibration_limit}, calibration_data: {self.calibration_data}, tokenizer_path: {self.tokenizer_path}, seq_length: {self.calibration_seq_length}"
)
self.pt2e_calibrate(
prepared_module=m,
calibration_tasks=self.calibration_tasks,
calibration_limit=self.calibration_limit,
calibration_seq_length=self.calibration_seq_length,
calibration_data=self.calibration_data,
tokenizer_path=self.tokenizer_path,
)
else:
logging.info(
"No calibration provided, using dummy input to calibrate..."
)
m(*self.example_inputs)
m = convert_pt2e(m)
DuplicateDynamicQuantChainPass()(m)
self.pre_autograd_graph_module = m
return self
else:
logging.info("No quantizer provided, passing...")
return self
self.pt2e_calibrate(
prepared_module=m,
calibration_tasks=self.calibration_tasks,
calibration_limit=self.calibration_limit,
calibration_seq_length=self.calibration_seq_length,
calibration_data=self.calibration_data,
tokenizer_path=self.tokenizer_path,
)
else:
logging.info(
"No calibration provided, using dummy input to calibrate..."
)
m(*self.example_inputs, **self.example_kwarg_inputs)
m = convert_pt2e(m)
DuplicateDynamicQuantChainPass()(m)
self.pre_autograd_graph_module = m
return self

def export_to_edge(self) -> "LLMEdgeManager":
"""
@@ -415,6 +416,18 @@ def to_backend(self, partitioners: Optional[List[Partitioner]]) -> "LLMEdgeManager":

return self

def to_edge_transform_and_lower(
    self, partitioners: Optional[List[Partitioner]]
) -> "LLMEdgeManager":
    """
    Transform the exported program and lower it to the given backends in a
    single pass, populating self.edge_manager.
    """
    if partitioners is None:
        logging.info("No partitioner provided, skipping backend lowering...")
    edge_config = self._get_edge_config()
    self.edge_manager = to_edge_transform_and_lower(
        self.pre_autograd_exported_program,
        partitioner=partitioners,
        compile_config=edge_config,
    )
    return self

def to_executorch(self) -> "LLMEdgeManager":
"""
Lower the model to executorch and get an ExecutorchProgram.
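For reference, a minimal usage sketch of the new LLMEdgeManager.to_edge_transform_and_lower (a sketch under assumptions, not a verbatim excerpt: manager, quantizers, and partitioners are placeholders, XnnpackPartitioner is only one example of a backend partitioner, and export() must have run first so that self.pre_autograd_exported_program is populated):

# Sketch only -- assumes `manager` is an LLMEdgeManager that has already run export().
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner

partitioners = [XnnpackPartitioner()]  # example partitioner; any list of backend Partitioners works here
manager = (
    manager.pt2e_quantize(quantizers)            # returns self unchanged when no quantizers are given
    .to_edge_transform_and_lower(partitioners)   # to_edge + backend delegation in a single pass
    .to_executorch()                             # produce the final ExecuTorch program
)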
