[sharktank] Revert "[llama] Added the fused rotary embedding kernel (#719)" (#752)

This reverts commit fc9576b, which causes a segmentation fault during Llama prefill in the IREE runtime.
archana-ramalingam authored Jan 10, 2025
1 parent 4cc3f02 commit 63ff841
Showing 9 changed files with 69 additions and 200 deletions.
9 changes: 8 additions & 1 deletion sharktank/sharktank/examples/export_paged_llm_v1.py
@@ -79,6 +79,7 @@ def main():
hp,
tensor_parallelism_size=tensor_parallelism_size,
use_hf=False,
static_tables=False, # Rely on the compiler for hoisting tables.
kv_cache_type="direct" if args.bs == [1] else "paged",
attention_kernel=args.attention_kernel,
block_seq_stride=args.block_seq_stride,
@@ -218,16 +219,22 @@ def _(model, tokens, seq_lens, seq_block_ids, cs):
else:
cache_tensors = cs

sl = tokens.shape[1]
input_mask = model.input_mask(seq_lens, sl)
attention_mask = model.attention_mask(input_mask)

if llama_config.tensor_parallelism_size != 1:
shard_count = llama_config.tensor_parallelism_size

tokens = ops.replicate(tokens, count=shard_count)
attention_mask = ops.replicate(attention_mask, count=shard_count)
seq_block_ids = ops.replicate(seq_block_ids, count=shard_count)

cache_tensors = repack_cache(cs, cache_shard_dim)

logits = model.prefill(
tokens,
attention_mask=None, # We rely on causal attention
attention_mask=attention_mask,
seq_block_ids=seq_block_ids,
cache_state=cache_tensors,
)
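This hunk restores explicit attention-mask construction during prefill export: the mask is built from the per-sequence lengths, replicated when tensor parallelism is enabled, and passed to model.prefill instead of attention_mask=None. A minimal sketch of what such mask construction typically looks like is below; the helper names and the additive -inf convention are illustrative assumptions, not the actual model.input_mask / model.attention_mask implementations.

import torch

def make_input_mask(seq_lens: torch.Tensor, batch_seq_len: int) -> torch.Tensor:
    # True wherever a position lies beyond that sequence's real length (padding).
    positions = torch.arange(batch_seq_len)[None, :]   # [1, sl]
    return positions >= seq_lens[:, None]              # [bs, sl]

def make_attention_mask(input_mask: torch.Tensor, dtype=torch.float32) -> torch.Tensor:
    # Combine a causal mask with the padding mask and express the result additively,
    # so masked positions contribute -inf to the attention logits.
    sl = input_mask.shape[1]
    causal = torch.triu(torch.ones(sl, sl, dtype=torch.bool), diagonal=1)   # [sl, sl]
    masked = causal[None, None, :, :] | input_mask[:, None, None, :]        # [bs, 1, sl, sl]
    return torch.where(masked, torch.tensor(float("-inf"), dtype=dtype), torch.tensor(0.0, dtype=dtype))

# Two sequences of length 3 and 5, padded to a batch sequence length of 6.
attention_mask = make_attention_mask(make_input_mask(torch.tensor([3, 5]), 6))
print(attention_mask.shape)   # torch.Size([2, 1, 6, 6])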
1 change: 0 additions & 1 deletion sharktank/sharktank/kernels/__init__.py
@@ -10,7 +10,6 @@
from .mmt_block_scaled_offset_q4 import *
from .mmt_block_scaled_q8 import *
from .mmt_super_block_scaled_offset_q4 import *
from .rotary import *
from .batch_matmul_transpose_b import *
from .conv_2d_nchw_fchw import *
from .pooling_nchw_sum import *
70 changes: 0 additions & 70 deletions sharktank/sharktank/kernels/rotary.py

This file was deleted.

63 changes: 0 additions & 63 deletions sharktank/sharktank/kernels/templates/rotary_embedding.mlir

This file was deleted.

2 changes: 1 addition & 1 deletion sharktank/sharktank/layers/paged_llama_attention_block.py
@@ -221,7 +221,7 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor:
k=keys, # [bs, ..., sl, dim]
v=values, # [bs, ..., sl, dim]
a=attention_mask, # [bs, ..., sl, sl]
is_causal=attention_mask is None, # assumes causal masking when true
is_causal=False, # assumes causal masking when true
scale=None, # defaults to 1/sqrt(dim)
)

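With the explicit mask always supplied again, the attention call reverts to is_causal=False. In stock PyTorch the same trade-off looks roughly like the sketch below, which uses torch.nn.functional.scaled_dot_product_attention purely for illustration (the sharktank op has its own q/k/v/a signature): either pass an explicit additive mask, or pass no mask and let is_causal=True derive one internally.

import torch
import torch.nn.functional as F

bs, heads, sl, dim = 2, 8, 16, 64
q, k, v = (torch.randn(bs, heads, sl, dim) for _ in range(3))

# Explicit additive mask (the path this revert restores): disallowed logits get -inf.
causal_bool = torch.tril(torch.ones(sl, sl, dtype=torch.bool))
additive_mask = torch.zeros(sl, sl).masked_fill(~causal_bool, float("-inf"))
out_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=additive_mask, is_causal=False)

# Built-in causal masking (the path the reverted change relied on).
out_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# For a purely causal mask the two agree up to numerical tolerance.
torch.testing.assert_close(out_masked, out_causal)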
88 changes: 57 additions & 31 deletions sharktank/sharktank/layers/rotary_embedding.py
@@ -11,7 +11,6 @@

from .base import BaseLayer
from .. import ops
from .. import kernels
from ..types import SplitPrimitiveTensor, ReplicatedTensor, unbox_tensor


@@ -26,6 +25,7 @@ def __init__(
rope_freq_base: Optional[float],
device: Optional[torch.device] = None,
use_hf: bool = False,
static_tables: bool = False,
use_table: bool = True,
tensor_parallelism_size: int = 1,
):
@@ -34,44 +34,60 @@ def __init__(
self.rope_dimension_count = rope_dimension_count
self.max_seqlen = max_seqlen
self.use_hf = use_hf
self.static_tables = static_tables
self.use_table = use_table

self.rope_freq_base = rope_freq_base if rope_freq_base is not None else 10000.0
self.tensor_parallelism_size = tensor_parallelism_size
if static_tables:
ops.module_register_buffer(
self, "static_rotary_embed_table", self._create_rotary_embed_table()
)
else:
self.static_rotary_embed_table = None

@property
def rotary_embed_table(self):
return self._create_rotary_embed_table()
if self.use_table:
if self.static_tables:
return self.static_rotary_embed_table
return self._create_rotary_embed_table()

return None

def forward(
self,
*,
xt: Union[torch.Tensor, SplitPrimitiveTensor],
start_index: int,
):
table = self.rotary_embed_table
if not isinstance(xt, SplitPrimitiveTensor):
if isinstance(xt, SplitPrimitiveTensor):
rotary_shards = [None] * xt.shard_count
if self.rotary_embed_table is not None:
assert (
isinstance(self.rotary_embed_table, ReplicatedTensor)
and xt.shard_count == self.rotary_embed_table.shard_count
)
rotary_shards = [
unbox_tensor(shard) for shard in self.rotary_embed_table.shards
]

xt_shards = [
self.forward_unsharded(
xt=unbox_tensor(xt_shard),
start_index=start_index,
rotary_embed_table=rotary_shard,
)
for xt_shard, rotary_shard in zip(xt.shards, rotary_shards)
]
xt = SplitPrimitiveTensor(ts=xt_shards, shard_dim=xt.shard_dim)
return xt
else:
return self.forward_unsharded(
xt=xt,
start_index=start_index,
rotary_embed_table=table,
)

assert (
isinstance(table, ReplicatedTensor) and xt.shard_count == table.shard_count
)
rotary_shards = [unbox_tensor(shard) for shard in table.shards]

xt_shards = [
self.forward_unsharded(
xt=unbox_tensor(xt_shard),
start_index=start_index,
rotary_embed_table=rotary_shard,
rotary_embed_table=self.rotary_embed_table,
)
for xt_shard, rotary_shard in zip(xt.shards, rotary_shards)
]
xt = SplitPrimitiveTensor(ts=xt_shards, shard_dim=xt.shard_dim)
return xt

def _create_interleaved_tensor(_, dim):
"""Creates a tensor which indexes an tensor such that
@@ -127,17 +143,18 @@ def forward_unsharded(
# Offset the table based on starting position.
if self.use_table:
freqs_cis = rotary_embed_table[start_index : start_index + sl, :]
freqs_cis = freqs_cis[0:sl, :]
freqs_cis = freqs_cis[None, 0:sl, None, :]
else:
freqs_cis = torch.arange(sl, device=xt.device) + start_index
freqs_cis = self._compute_rotary_embed_table(freqs_cis)
freqs_cis = self._compute_rotary_embed_table(freqs_cis)[None, :, None, :]

assert (
freqs_cis.shape[0] >= sl
freqs_cis.shape[1] >= sl
), f"Sequence length longer than embedding table ({sl} vs {freqs_cis.shape[0]})"

freqs_cis = ops.repeat(freqs_cis[None, :, :], (xt_.shape[0], 1, 1))
xt_out = kernels.apply_rotary_embedding(xt_.to(freqs_cis.dtype), freqs_cis)
xt_ = ops.view_as_complex(xt_)
xt_ = xt_ * freqs_cis
xt_out = ops.view_as_real(xt_)

if self.use_hf:
xt_out = xt_out[..., self._create_ordering_tensor(xt_out.shape[-1])]
@@ -164,7 +181,7 @@ def compute_batch_mask(
self.trace_tensor("rope.positions_seq", positions_seq)

if self.use_table:
freqs_cis = self.rotary_embed_table[positions_seq.flatten()]
freqs_cis = self.rotary_embed_table[positions_seq]
else:
shape = positions_seq.shape
if isinstance(positions_seq, ReplicatedTensor):
@@ -175,8 +192,11 @@
freqs_cis = ReplicatedTensor(ts=ts)
else:
freqs_cis = self._compute_rotary_embed_table(positions_seq.flatten())
freqs_cis = freqs_cis.unflatten(0, shape)

return freqs_cis.unsqueeze(1)
# Unsqueeze a unit dim for attention heads.
broadcast_freqs_cis = freqs_cis.unsqueeze(2)
return broadcast_freqs_cis

def apply_batched_mask(
self,
@@ -212,7 +232,9 @@ def apply_batched_mask_unsharded(self, *, xt: torch.Tensor, mask: torch.Tensor):
if self.use_hf:
xt = xt[..., self._create_interleaved_tensor(xt.shape[-1])]

xt_out = kernels.apply_rotary_embedding(xt.to(mask.dtype), mask)
xt_ = ops.view_as_complex(xt)
xt_ = xt_ * mask
xt_out = ops.view_as_real(xt_)

if self.use_hf:
xt_out = xt_out[..., self._create_ordering_tensor(xt_out.shape[-1])]
@@ -222,10 +244,14 @@ def apply_batched_mask_unsharded(self, *, xt: torch.Tensor, mask: torch.Tensor):
def _compute_rotary_embed_table(self, t):
dim = self.rope_dimension_count
freqs = 1.0 / (
self.rope_freq_base ** ((torch.arange(0, dim) // 2).float() / dim * 2.0)
self.rope_freq_base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
)
freqs = torch.outer(t, freqs).float()
return freqs

cos = torch.cos(freqs)
sin = torch.sin(freqs)
complex = torch.complex(cos, sin)
return complex

def _create_rotary_embed_table(self):
t = torch.arange(self.max_seqlen, device=self.device)
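The bulk of this file returns to applying rotary position embeddings with ordinary complex arithmetic instead of the fused kernels.apply_rotary_embedding call: a table of cos + i*sin values is built per position and frequency, the activations are viewed as complex pairs, multiplied by the table, and viewed back as real. A standalone sketch of that math in plain PyTorch follows; shapes and helper names are illustrative rather than the layer's exact code.

import torch

def rotary_table(max_seqlen: int, dim: int, base: float = 10000.0) -> torch.Tensor:
    # One frequency per channel pair: 1 / base^(2i/dim), i = 0 .. dim/2 - 1.
    freqs = 1.0 / (base ** (torch.arange(0, dim, 2)[: dim // 2].float() / dim))
    angles = torch.outer(torch.arange(max_seqlen).float(), freqs)   # [sl, dim/2]
    return torch.complex(torch.cos(angles), torch.sin(angles))      # [sl, dim/2]

def apply_rotary(xt: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    # xt: [bs, sl, heads, dim]; freqs_cis: [sl, dim/2]
    xt_ = torch.view_as_complex(xt.float().reshape(*xt.shape[:-1], -1, 2))
    xt_ = xt_ * freqs_cis[None, :, None, :]   # broadcast over batch and heads
    return torch.view_as_real(xt_).flatten(-2)

x = torch.randn(1, 8, 4, 64)   # [bs, sl, heads, dim]
print(apply_rotary(x, rotary_table(max_seqlen=8, dim=64)).shape)   # torch.Size([1, 8, 4, 64])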
4 changes: 3 additions & 1 deletion sharktank/sharktank/models/llama/llama.py
@@ -67,6 +67,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig):
super().__init__(
theta,
context_length=config.hp.context_length,
static_tables=config.static_tables,
device=config.device,
activation_dtype=config.activation_dtype,
attention_dtype=config.attention_dtype,
@@ -91,6 +92,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig):
max_seqlen=hp.context_length,
device=self.device,
use_hf=self.use_hf,
static_tables=config.static_tables,
tensor_parallelism_size=config.tensor_parallelism_size,
),
)
@@ -124,7 +126,7 @@ def prefill(
tokens: Union[torch.Tensor, ReplicatedTensor],
*,
# [1, 1, batch_seq_len, batch_seq_len]
attention_mask: Optional[Union[torch.Tensor, ReplicatedTensor]],
attention_mask: Union[torch.Tensor, ReplicatedTensor],
# [bs, batch_seq_len // block_seq_stride]
seq_block_ids: Union[torch.Tensor, ReplicatedTensor],
cache_state: list[Union[torch.Tensor, SplitPrimitiveTensor]],
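The model config regains its static_tables switch, which is threaded through to the rotary embedding layer and decides whether the table is registered once as a module buffer or recomputed on every call (with the compiler expected to hoist it). In plain torch.nn terms the pattern is roughly the following sketch, where nn.Module.register_buffer stands in for sharktank's ops.module_register_buffer and the class is a hypothetical illustration, not the actual layer.

import torch
from torch import nn

class RotaryTable(nn.Module):
    def __init__(self, max_seqlen: int, dim: int, static_tables: bool = False):
        super().__init__()
        self.max_seqlen, self.dim, self.static_tables = max_seqlen, dim, static_tables
        if static_tables:
            # Baked into module state once; exporters and compilers see a constant.
            self.register_buffer("table", self._compute_table(), persistent=False)

    def _compute_table(self) -> torch.Tensor:
        freqs = 1.0 / (10000.0 ** (torch.arange(0, self.dim, 2).float() / self.dim))
        angles = torch.outer(torch.arange(self.max_seqlen).float(), freqs)
        return torch.complex(torch.cos(angles), torch.sin(angles))

    def forward(self) -> torch.Tensor:
        # static_tables=False: recompute each call and rely on the compiler to hoist it.
        return self.table if self.static_tables else self._compute_table()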
1 change: 0 additions & 1 deletion sharktank/tests/evaluate/perplexity_iree_test.py
@@ -34,7 +34,6 @@ def setUp(self):
with open(self.baseline_perplexity_scores, "r") as f:
self.baseline_perplexity = json.load(f)

@pytest.mark.xfail(reason="Runtime segfault", run=False)
def test_llama3_8B_f16_decomposed(self):

# Llama 3.1 8B decomposed
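Finally, the xfail guard on the 8B f16 decomposed perplexity test is dropped, since the segfaulting kernel is gone. For reference, pytest's xfail marker with run=False records a test as an expected failure without executing it, which is the usual way to quarantine a test that crashes the process; a minimal illustration:

import pytest

@pytest.mark.xfail(reason="Runtime segfault", run=False)
def test_known_crash():
    # Never executed while the marker is present; reported as xfail instead.
    raise SystemExit("simulated crash")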