Custom SDPA in attention
dvorjackz committed Jan 2, 2025
1 parent 8145cda commit 1e3d978
Showing 1 changed file with 42 additions and 22 deletions.
64 changes: 42 additions & 22 deletions extension/llm/modules/attention.py
@@ -10,6 +10,7 @@
import torch
import torchtune.modules.attention as TorchTuneAttention
from executorch.extension.llm.modules.kv_cache import KVCache as InferenceKVCache
+from executorch.extension.llm.custom_ops import custom_ops
from torch import nn
from torchtune.modules.attention_utils import _MaskType, _sdpa_or_flex_attention
from torchtune.modules.kv_cache import KVCache
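
The added import is there for its side effect: loading executorch.extension.llm.custom_ops registers the ExecuTorch SDPA kernel with PyTorch, which is how the torch.ops.llama.custom_sdpa call further down in this diff resolves. A minimal sanity check, assuming the custom-op library is built and importable in your environment:

import torch
from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401, imported for its op-registration side effect

# After the import, the operator used later in this file should resolve.
assert hasattr(torch.ops.llama, "custom_sdpa"), "custom_sdpa was not registered"
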
@@ -146,6 +147,7 @@ def __init__(
        # Use flex attention if supported and we are sample packing
        self._attention_call = _sdpa_or_flex_attention()
        self._sdpa = SDPA(
+            max_seq_len=self.max_seq_len,
            num_kv_heads=self.num_kv_heads,
            num_heads=self.num_heads,
            head_dim=self.head_dim,
@@ -310,7 +312,7 @@ def false_fn(y):
            self.kv_cache.v_cache.copy_(v)
            self.kv_cache.cache_pos.copy_(cache_pos)

-        output = self._sdpa(q, k, v, b, s_x, mask=mask)
+        output = self._sdpa(q, k, v, b, s_x, mask=mask, input_pos=input_pos)
        return self.output_proj(output)


@@ -322,6 +324,7 @@ class SDPA(nn.Module):

    def __init__(
        self,
+        max_seq_len: int,
        num_kv_heads: int,
        num_heads: int,
        head_dim: int,
@@ -331,6 +334,7 @@ def __init__(
        kv_cache,
    ) -> None:
        super().__init__()
+        self.max_seq_len = max_seq_len
        self.num_kv_heads = num_kv_heads
        self.num_heads = num_heads
        self.head_dim = head_dim
@@ -348,32 +352,48 @@ def forward(
        bsz: int,
        seq_len: int,
        mask: Optional[_MaskType] = None,
+        # Below args are only used for ET custom sdpa op.
+        input_pos: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
-        # View + expand + reshape bring num_kv_heads to num_heads for k and v
-        # to match q.
-
-        # [bsz, n_h, s, h_d]
-        q = q.transpose(1, 2)
-        k = k.transpose(1, 2)
-        v = v.transpose(1, 2)
-
-        # Expand the key and value tensors to have the same shape
-        # as the query tensor by copying values across the relevant dim
-        if self.num_heads != self.num_kv_heads:
-            expand_shape = (-1, -1, self.q_per_kv, -1, -1)
-            k = k.unsqueeze(2).expand(expand_shape).flatten(1, 2)
-            v = v.unsqueeze(2).expand(expand_shape).flatten(1, 2)
-
-        output = self._attention_fn(
+        start_pos = input_pos[0][-1].item() - seq_len + 1
+        torch._check_is_size(start_pos)
+        torch._check(start_pos <= self.max_seq_len)
+        output = torch.ops.llama.custom_sdpa(
            q,
            k,
            v,
-            mask=mask,
-            dropout_p=self.attn_dropout,
-            is_causal=self.kv_cache is None and mask is None and self.is_causal,
+            start_pos,
+            None, # Attention mask
+            0, # dropout probability. Ignored by the code
+            True, # is_causal TODO: flip to false if kv cache is enabled???
        )
-        # Reshape the output to be the same shape as the input
-        return output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
+        return output.view(bsz, seq_len, -1)
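
Since the custom op takes an absolute start position rather than a position tensor, the new code derives it from input_pos: the last entry holds the cache position of the chunk's final token, so subtracting seq_len - 1 gives the position of its first token, and the torch._check calls record the same bounds symbolically so the value traces cleanly under export with dynamic shapes. A worked example of just that arithmetic (plain PyTorch, no ExecuTorch dependency):

import torch

# A 3-token chunk whose tokens occupy KV-cache slots 4, 5 and 6 (bsz = 1).
input_pos = torch.tensor([[4, 5, 6]])
seq_len = 3

start_pos = input_pos[0][-1].item() - seq_len + 1
print(start_pos)  # 4: the chunk starts at cache position 4
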

+        # # View + expand + reshape bring num_kv_heads to num_heads for k and v
+        # # to match q.
+
+        # # [bsz, n_h, s, h_d]
+        # q = q.transpose(1, 2)
+        # k = k.transpose(1, 2)
+        # v = v.transpose(1, 2)
+
+        # # Expand the key and value tensors to have the same shape
+        # # as the query tensor by copying values across the relevant dim
+        # if self.num_heads != self.num_kv_heads:
+        #     expand_shape = (-1, -1, self.q_per_kv, -1, -1)
+        #     k = k.unsqueeze(2).expand(expand_shape).flatten(1, 2)
+        #     v = v.unsqueeze(2).expand(expand_shape).flatten(1, 2)
+
+        # output = self._attention_fn(
+        #     q,
+        #     k,
+        #     v,
+        #     mask=mask,
+        #     dropout_p=self.attn_dropout,
+        #     is_causal=self.kv_cache is None and mask is None and self.is_causal,
+        # )
+        # # Reshape the output to be the same shape as the input
+        # return output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
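
The commented-out block above is the eager path the custom op replaces: transpose to [b, heads, s, h_d], repeat each KV head to match the query head count, then call PyTorch's scaled_dot_product_attention. Restated as a standalone helper, it makes a convenient reference for numerically cross-checking torch.ops.llama.custom_sdpa on a prefill chunk (start_pos 0, causal, no mask, no dropout); the function below is an illustrative sketch, not part of the file:

import torch
import torch.nn.functional as F

def eager_reference_sdpa(q, k, v, num_heads, num_kv_heads):
    # q: [b, s, n_h, h_d]; k, v: [b, s, n_kv, h_d], the layout SDPA.forward receives.
    bsz, seq_len = q.shape[0], q.shape[1]
    q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)  # -> [b, heads, s, h_d]
    if num_heads != num_kv_heads:
        # Repeat each KV head so k and v match q's head count (GQA).
        q_per_kv = num_heads // num_kv_heads
        expand_shape = (-1, -1, q_per_kv, -1, -1)
        k = k.unsqueeze(2).expand(expand_shape).flatten(1, 2)
        v = v.unsqueeze(2).expand(expand_shape).flatten(1, 2)
    out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True)
    # [b, heads, s, h_d] -> [b, s, heads * h_d], matching the original return shape.
    return out.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
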


def _replace_mha_with_inference_mha(module: torch.nn.Module) -> None:
