Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DRAFT] Use custom SDPA in Llama 3.2 MM #7471

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
10 changes: 8 additions & 2 deletions examples/models/llama/export_llama_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import torch

from executorch.devtools.etrecord import generate_etrecord
from executorch.exir.passes.init_mutable_pass import InitializedMutableBufferPass

from executorch.extension.llm.export.builder import DType, LLMEdgeManager

Expand Down Expand Up @@ -760,6 +761,9 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
for partitioner in partitioners:
logging.info(f"--> {partitioner.__class__.__name__}")

additional_passes = []
if args.model in TORCHTUNE_DEFINED_MODELS:
additional_passes = [InitializedMutableBufferPass(["cache_pos"])]
if args.generate_etrecord:
if not builder_exported_to_edge.edge_manager:
raise ValueError("Unable to generate etrecord due to missing edge manager.")
Expand All @@ -774,7 +778,9 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
# pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
canonicalize_program(builder.edge_manager.exported_program())

builder = builder.to_executorch()
builder = builder.to_executorch(
passes=additional_passes,
)

# Generate ETRecord
if edge_manager_copy:
Expand All @@ -792,7 +798,7 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
# pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
canonicalize_program(builder.edge_manager.exported_program())

builder = builder.to_executorch()
builder = builder.to_executorch(passes=additional_passes)

if args.profile_memory:
generate_memory_trace(builder.export_program, "memory_profile.json")
Expand Down
18 changes: 15 additions & 3 deletions examples/models/llama3_2_vision/runner/native.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@
TorchTuneLlamaRunner,
)

from executorch.extension.pybindings.portable_lib import _load_for_executorch
from executorch.extension.pybindings.portable_lib import (
_load_for_executorch_from_buffer,
)

# Load custom ops and quantized ops.
from executorch.extension.pybindings import portable_lib # noqa # usort: skip

# Note: import this after portable_lib
from executorch.extension.llm.custom_ops import sdpa_with_kv_cache # noqa # usort: skip
from executorch.extension.llm.custom_ops import custom_ops # noqa # usort: skip
from executorch.kernels import quantized # noqa


Expand All @@ -43,7 +45,17 @@ def __init__(self, args):
use_kv_cache=args.kv_cache,
vocab_size=params["vocab_size"],
)
self.model = _load_for_executorch(args.pte)
# Save the loaded model bytes to prevent data from going out of
# scope after the `with` and getting cleaned up by Python's
# garbage collector.
self.model_bytes = None
with open(args.pte, "rb") as f:
self.model_bytes = f.read()
# Need to use _load_for_executorch_from_buffer instead of
# _load_for_executorch because the latter uses MmapDataLoader,
# which doesn't have load_into() implemented, which is needed
# for loading initialized mutable buffers.
self.model = _load_for_executorch_from_buffer(self.model_bytes)
self.use_kv_cache = args.kv_cache

def forward(
Expand Down
10 changes: 8 additions & 2 deletions exir/emit/_emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1575,7 +1575,8 @@ def _find_fqn_for_placeholder(
warnings.warn(
"Mutation on a buffer in the model is detected. ExecuTorch assumes "
"buffers that are mutated in the graph have a meaningless initial state, "
"only the shape and dtype will be serialized.",
"only the shape and dtype will be serialized, unless a pass which marks "
"spec.const=True such as InitializedMutableBufferPass is run.",
UserWarning,
stacklevel=1,
)
Expand All @@ -1602,6 +1603,7 @@ def placeholder(
"""
spec = self.node.meta["spec"]
constant_tag = self.node.meta.get("constant_tag", None)
initialize_buffer = self.node.meta.get("et_init_buffer", None)
is_user_input = True

if isinstance(target, str) and isinstance(spec, TensorSpec):
Expand Down Expand Up @@ -1655,7 +1657,11 @@ def placeholder(
spec.storage = real_tensor.untyped_storage()

# User inputs and mutable buffers are not constants, other buffers or parameters are.
spec.const = not (is_user_input or is_mutable_buffer)
if initialize_buffer:
assert is_mutable_buffer
spec.const = True
else:
spec.const = not (is_user_input or is_mutable_buffer)

evalue = (
self._tensor_spec_to_evalue(spec, constant_tag)
Expand Down
31 changes: 31 additions & 0 deletions exir/passes/init_mutable_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


from typing import List

from executorch.exir.pass_base import ExportPass


class InitializedMutableBufferPass(ExportPass):
"""
If the buffer has the name "cache_pos", such as in an kv_cache
module with `self.register_buffer("cache_pos", torch.arange(10))`,
mark it with a custom tag which later is used by the emitter to
flag spec.const to True, which provides the mutable buffer with
an initialized state.
"""

def __init__(self, patterns: List[str]) -> None:
super().__init__()
self.patterns = patterns

def placeholder(self, name: str, arg, meta):
for pattern in self.patterns:
if pattern in name:
meta["et_init_buffer"] = True

return super().placeholder(name, arg, meta)
23 changes: 15 additions & 8 deletions extension/llm/export/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from executorch.exir.backend.utils import format_delegated_graph
from executorch.exir.capture._config import EdgeCompileConfig, ExecutorchBackendConfig

from executorch.exir.pass_manager import PassType
from executorch.exir.passes import MemoryPlanningPass
from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
Expand Down Expand Up @@ -395,21 +396,27 @@ def to_backend(self, partitioners: Optional[List[Partitioner]]) -> "LLMEdgeManag

return self

def to_executorch(self) -> "LLMEdgeManager":
def to_executorch(
self, passes: Optional[List[PassType]] = None
) -> "LLMEdgeManager":
"""
Lower the model to executorch and get an ExecutorchProgram.
"""
assert self.edge_manager, "Need to run export_to_edge() first"
to_executorch_passes = [
# If there are Linear operations left in the graph, let's execute
# them with the optimized op_linear rather than materializing a
# transpose followed by a regular op_mm.
ConvertToLinearPass(),
QuantFusionPass(),
]
if passes:
to_executorch_passes.extend(passes)

self.export_program = self.edge_manager.to_executorch(
ExecutorchBackendConfig(
extract_delegate_segments=True,
passes=[
# If there are Linear operations left in the graph, let's execute
# them with the optimized op_linear rather than materializing a
# transpose followed by a regular op_mm.
ConvertToLinearPass(),
QuantFusionPass(),
],
passes=to_executorch_passes,
memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
)
Expand Down
22 changes: 21 additions & 1 deletion extension/llm/modules/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import torch
import torchtune.modules.attention as TorchTuneAttention
from executorch.extension.llm.modules.kv_cache import KVCache as InferenceKVCache
from executorch.extension.llm.custom_ops import custom_ops
from torch import nn
from torchtune.modules.attention_utils import _MaskType, _sdpa_or_flex_attention
from torchtune.modules.kv_cache import KVCache
Expand Down Expand Up @@ -146,6 +147,7 @@ def __init__(
# Use flex attention if supported and we are sample packing
self._attention_call = _sdpa_or_flex_attention()
self._sdpa = SDPA(
max_seq_len=self.max_seq_len,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
head_dim=self.head_dim,
Expand Down Expand Up @@ -310,7 +312,7 @@ def false_fn(y):
self.kv_cache.v_cache.copy_(v)
self.kv_cache.cache_pos.copy_(cache_pos)

output = self._sdpa(q, k, v, b, s_x, mask=mask)
output = self._sdpa(q, k, v, b, s_x, mask=mask, input_pos=input_pos)
return self.output_proj(output)


Expand All @@ -322,6 +324,7 @@ class SDPA(nn.Module):

def __init__(
self,
max_seq_len: int,
num_kv_heads: int,
num_heads: int,
head_dim: int,
Expand All @@ -331,6 +334,7 @@ def __init__(
kv_cache,
) -> None:
super().__init__()
self.max_seq_len = max_seq_len
self.num_kv_heads = num_kv_heads
self.num_heads = num_heads
self.head_dim = head_dim
Expand All @@ -348,7 +352,23 @@ def forward(
bsz: int,
seq_len: int,
mask: Optional[_MaskType] = None,
# Below args are only used for ET custom sdpa op.
input_pos: Optional[torch.Tensor] = None,
) -> torch.Tensor:
start_pos = input_pos[0][-1].item() - seq_len + 1
torch._check_is_size(start_pos)
torch._check(start_pos <= self.max_seq_len)
output = torch.ops.llama.custom_sdpa(
q,
k,
v,
start_pos,
None, # Attention mask
0, # dropout probability. Ignored by the code
True, # is_causal TODO: flip to false if kv cache is enabled???
)
return output.view(bsz, seq_len, -1)

# View + expand + reshape bring num_kv_heads to num_heads for k and v
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The below is skipped, just to make this diff more clear

# to match q.

Expand Down
Loading
Loading