From 1e3d978cfb9185e763d7f4f6071817bdaf8e92d1 Mon Sep 17 00:00:00 2001
From: Jack Zhang
Date: Thu, 2 Jan 2025 09:26:44 -0800
Subject: [PATCH] Custom SDPA in attention

---
 extension/llm/modules/attention.py | 64 ++++++++++++++++++++----------
 1 file changed, 42 insertions(+), 22 deletions(-)

diff --git a/extension/llm/modules/attention.py b/extension/llm/modules/attention.py
index 60183801b42..b5671838c2b 100644
--- a/extension/llm/modules/attention.py
+++ b/extension/llm/modules/attention.py
@@ -10,6 +10,7 @@
 import torch
 import torchtune.modules.attention as TorchTuneAttention
 from executorch.extension.llm.modules.kv_cache import KVCache as InferenceKVCache
+from executorch.extension.llm.custom_ops import custom_ops
 from torch import nn
 from torchtune.modules.attention_utils import _MaskType, _sdpa_or_flex_attention
 from torchtune.modules.kv_cache import KVCache
@@ -146,6 +147,7 @@ def __init__(
         # Use flex attention if supported and we are sample packing
         self._attention_call = _sdpa_or_flex_attention()
         self._sdpa = SDPA(
+            max_seq_len=self.max_seq_len,
             num_kv_heads=self.num_kv_heads,
             num_heads=self.num_heads,
             head_dim=self.head_dim,
@@ -310,7 +312,7 @@ def false_fn(y):
             self.kv_cache.v_cache.copy_(v)
             self.kv_cache.cache_pos.copy_(cache_pos)
 
-        output = self._sdpa(q, k, v, b, s_x, mask=mask)
+        output = self._sdpa(q, k, v, b, s_x, mask=mask, input_pos=input_pos)
         return self.output_proj(output)
 
@@ -322,6 +324,7 @@ class SDPA(nn.Module):
 
     def __init__(
         self,
+        max_seq_len: int,
         num_kv_heads: int,
         num_heads: int,
         head_dim: int,
@@ -331,6 +334,7 @@ def __init__(
         kv_cache,
     ) -> None:
         super().__init__()
+        self.max_seq_len = max_seq_len
         self.num_kv_heads = num_kv_heads
         self.num_heads = num_heads
         self.head_dim = head_dim
@@ -348,32 +352,48 @@ def forward(
         bsz: int,
         seq_len: int,
         mask: Optional[_MaskType] = None,
+        # Below args are only used for ET custom sdpa op.
+        input_pos: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        # View + expand + reshape bring num_kv_heads to num_heads for k and v
-        # to match q.
-
-        # [bsz, n_h, s, h_d]
-        q = q.transpose(1, 2)
-        k = k.transpose(1, 2)
-        v = v.transpose(1, 2)
-
-        # Expand the key and value tensors to have the same shape
-        # as the query tensor by copying values across the relevant dim
-        if self.num_heads != self.num_kv_heads:
-            expand_shape = (-1, -1, self.q_per_kv, -1, -1)
-            k = k.unsqueeze(2).expand(expand_shape).flatten(1, 2)
-            v = v.unsqueeze(2).expand(expand_shape).flatten(1, 2)
-
-        output = self._attention_fn(
+        start_pos = input_pos[0][-1].item() - seq_len + 1
+        torch._check_is_size(start_pos)
+        torch._check(start_pos <= self.max_seq_len)
+        output = torch.ops.llama.custom_sdpa(
             q,
             k,
             v,
-            mask=mask,
-            dropout_p=self.attn_dropout,
-            is_causal=self.kv_cache is None and mask is None and self.is_causal,
+            start_pos,
+            None,  # Attention mask
+            0,  # dropout probability. Ignored by the code
+            True,  # is_causal TODO: flip to false if kv cache is enabled???
         )
-        # Reshape the output to be the same shape as the input
-        return output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
+        return output.view(bsz, seq_len, -1)
+
+        # # View + expand + reshape bring num_kv_heads to num_heads for k and v
+        # # to match q.
+
+        # # [bsz, n_h, s, h_d]
+        # q = q.transpose(1, 2)
+        # k = k.transpose(1, 2)
+        # v = v.transpose(1, 2)
+
+        # # Expand the key and value tensors to have the same shape
+        # # as the query tensor by copying values across the relevant dim
+        # if self.num_heads != self.num_kv_heads:
+        #     expand_shape = (-1, -1, self.q_per_kv, -1, -1)
+        #     k = k.unsqueeze(2).expand(expand_shape).flatten(1, 2)
+        #     v = v.unsqueeze(2).expand(expand_shape).flatten(1, 2)
+
+        # output = self._attention_fn(
+        #     q,
+        #     k,
+        #     v,
+        #     mask=mask,
+        #     dropout_p=self.attn_dropout,
+        #     is_causal=self.kv_cache is None and mask is None and self.is_causal,
+        # )
+        # # Reshape the output to be the same shape as the input
+        # return output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
 
 
 def _replace_mha_with_inference_mha(module: torch.nn.Module) -> None:
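
For reference, the call pattern introduced above can be sanity-checked in eager mode along these lines. This is a minimal sketch and not part of the patch: it assumes the ExecuTorch custom-op library loads in the local environment and that q, k, and v use the [bsz, seq_len, n_heads, head_dim] layout that the rewritten forward now passes through without transposing; the shapes and values below are illustrative only.

    import torch
    from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401 -- registers torch.ops.llama.custom_sdpa

    bsz, seq_len, n_heads, head_dim = 1, 4, 8, 64

    # q, k, v stay in [bsz, seq_len, n_heads, head_dim]; the new forward no longer transposes them.
    q = torch.randn(bsz, seq_len, n_heads, head_dim)
    k = torch.randn(bsz, seq_len, n_heads, head_dim)
    v = torch.randn(bsz, seq_len, n_heads, head_dim)

    # input_pos as the attention module would pass it for a prefill starting at position 0.
    input_pos = torch.arange(seq_len).unsqueeze(0)     # [1, seq_len]
    start_pos = input_pos[0][-1].item() - seq_len + 1  # same formula as the patch; 0 here

    output = torch.ops.llama.custom_sdpa(
        q,
        k,
        v,
        start_pos,
        None,  # attention mask
        0,     # dropout probability (ignored by the op)
        True,  # is_causal
    )
    print(output.view(bsz, seq_len, -1).shape)  # -> [bsz, seq_len, n_heads * head_dim]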