diff --git a/sharktank/sharktank/models/llama/llama.py b/sharktank/sharktank/models/llama/llama.py index 656b4432b..ef3c4800d 100644 --- a/sharktank/sharktank/models/llama/llama.py +++ b/sharktank/sharktank/models/llama/llama.py @@ -269,7 +269,6 @@ def decode( for block_idx, block in enumerate(self.attn_blocks): if block_idx == 0: self.trace_tensor(f"llama.attn_block.{block_idx}.input", h) - block.attn.attention_kernel = "decomposed" h = block( h, start_positions=start_positions,