diff --git a/examples/models/llama2/TARGETS b/examples/models/llama2/TARGETS
index 09ebd5aead..9da7a26d6d 100644
--- a/examples/models/llama2/TARGETS
+++ b/examples/models/llama2/TARGETS
@@ -18,7 +18,6 @@ runtime.python_library(
     ],
     deps = [
         "//caffe2:torch",
-        "//executorch/examples/models/llama2/custom_ops:custom_ops_aot_py",
     ],
 )
 
@@ -86,6 +85,7 @@ runtime.python_library(
         "//executorch/backends/vulkan/partitioner:vulkan_partitioner",
         "//executorch/examples/models:model_base",
         "//executorch/examples/models:models",
+        "//executorch/examples/models/llama2/custom_ops:custom_ops_aot_py",
         "//executorch/examples/portable:utils",
         "//executorch/exir:lib",
         "//executorch/sdk/etrecord:etrecord",
diff --git a/examples/models/llama2/export_llama_lib.py b/examples/models/llama2/export_llama_lib.py
index 8728b3fdd2..aa195209ad 100644
--- a/examples/models/llama2/export_llama_lib.py
+++ b/examples/models/llama2/export_llama_lib.py
@@ -23,7 +23,11 @@
     XnnpackDynamicallyQuantizedPartitioner,
 )
 
-from executorch.examples.models.llama2.llama_transformer import Transformer
+from executorch.examples.models.llama2.llama_transformer import (
+    KVCache,
+    SDPA,
+    Transformer,
+)
 from executorch.exir.backend.backend_details import CompileSpec
 
 from executorch.sdk.etrecord import generate_etrecord
@@ -88,6 +92,58 @@ def materialze_broadcast_of_rope_freq_cis(
     return module
 
 
+class SDPACustom(torch.nn.Module):
+    def __init__(
+        self,
+        kv_cache: KVCache,
+        mask,
+        dim: int,
+    ):
+        super().__init__()
+        self.kv_cache = kv_cache
+        self.mask = mask
+        self.dim = dim
+
+    def forward(
+        self,
+        input_pos: torch.Tensor,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        bsz,
+        seqlen,
+    ):
+        output = torch.ops.llama.sdpa_with_kv_cache(
+            q,
+            k,
+            v,
+            self.kv_cache.k_cache,
+            self.kv_cache.v_cache,
+            input_pos[-1].item(),
+            seqlen,
+        )
+        return output.view(bsz, seqlen, self.dim)
+
+
+def _replace_sdpa_with_custom_op(module: torch.nn.Module):
+    for name, child in module.named_children():
+        if isinstance(child, SDPA):
+            setattr(
+                module,
+                name,
+                SDPACustom(child.kv_cache, child.mask, child.dim),
+            )
+        else:
+            _replace_sdpa_with_custom_op(child)
+
+
+def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:
+    from executorch.examples.models.llama2.custom_ops import sdpa_with_kv_cache  # noqa
+
+    _replace_sdpa_with_custom_op(module)
+    return module
+
+
 def quantize(
     model: torch.nn.Module,
     qmode: str,
@@ -493,8 +549,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LlamaEdgeManager:
         transforms.append(materialze_broadcast_of_rope_freq_cis)
 
     if args.use_sdpa_with_kv_cache:
-        pass
-        # TODO: Next diff transforms.append()
+        transforms.append(replace_sdpa_with_custom_op)
 
     return (
         load_llama_model(
diff --git a/examples/models/llama2/llama_transformer.py b/examples/models/llama2/llama_transformer.py
index d0794b8c37..c353a913bf 100644
--- a/examples/models/llama2/llama_transformer.py
+++ b/examples/models/llama2/llama_transformer.py
@@ -214,14 +214,12 @@ def __init__(
         self,
         kv_cache: KVCache,
         mask,
-        use_sdpa_with_kv_cache_op: bool,
         dim: int,
         n_rep: int,
     ):
         super().__init__()
         self.kv_cache = kv_cache
         self.mask = mask
-        self.use_sdpa_with_kv_cache_op = use_sdpa_with_kv_cache_op
         self.dim = dim
         self.n_rep = n_rep
 
@@ -233,56 +231,6 @@ def forward(
         v: torch.Tensor,
         bsz,
         seqlen,
-    ) -> torch.Tensor:
-        if not self.use_sdpa_with_kv_cache_op:
-            return self._forward_default(
-                input_pos,
-                q,
-                k,
-                v,
-                bsz,
-                seqlen,
-            )
-        else:
-            return self._forward_custom(
-                input_pos,
-                q,
-                k,
-                v,
-                bsz,
-                seqlen,
-            )
-
-    def _forward_custom(
-        self,
-        input_pos: torch.Tensor,
-        q: torch.Tensor,
-        k: torch.Tensor,
-        v: torch.Tensor,
-        bsz,
-        seqlen,
-    ):
-        from .custom_ops import sdpa_with_kv_cache  # noqa
-
-        output = torch.ops.llama.sdpa_with_kv_cache(
-            q,
-            k,
-            v,
-            self.kv_cache.k_cache,
-            self.kv_cache.v_cache,
-            input_pos[-1].item(),
-            seqlen,
-        )
-        return output.view(bsz, seqlen, self.dim)
-
-    def _forward_default(
-        self,
-        input_pos: torch.Tensor,
-        q: torch.Tensor,
-        k: torch.Tensor,
-        v: torch.Tensor,
-        bsz,
-        seqlen,
     ) -> torch.Tensor:
         q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
         k = k.transpose(1, 2)
@@ -341,7 +289,6 @@ def __init__(self, args: ModelArgs, layer_id: int):
         self.SDPA = SDPA(
             self.kv_cache,
             self.mask,
-            args.use_sdpa_with_kv_cache_op,
             self.dim,
             self.n_rep,
         )
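
A minimal sketch of how the new replace_sdpa_with_custom_op source transform is applied before export, assuming an eager model instance; the apply_source_transforms helper is illustrative only, since the real wiring happens inside _prepare_for_llama_export via load_llama_model when --use_sdpa_with_kv_cache is set:

import torch

from executorch.examples.models.llama2.export_llama_lib import (
    replace_sdpa_with_custom_op,
)


def apply_source_transforms(model: torch.nn.Module, transforms) -> torch.nn.Module:
    # Each transform takes an nn.Module and returns it, mirroring the
    # signature of replace_sdpa_with_custom_op introduced in this diff.
    for transform in transforms:
        model = transform(model)
    return model


# With --use_sdpa_with_kv_cache, every SDPA submodule is swapped for SDPACustom,
# which calls torch.ops.llama.sdpa_with_kv_cache at runtime:
# model = apply_source_transforms(model, [replace_sdpa_with_custom_op])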