add repro of the mqa flash attention
vanbasten23 committed Oct 18, 2024
1 parent 5dbdb8d commit 8716795
Showing 2 changed files with 334 additions and 0 deletions.
55 changes: 55 additions & 0 deletions torch_xla/experimental/flash_mqa_test.py
@@ -0,0 +1,55 @@
"""Demo of existing custom kernels."""

from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import random
from jax._src import test_util as jtu
import jax.numpy as jnp
# from jax._src.pallas.google.tpu_ops import flash_mqa
from torch_xla.experimental.pallas_kernels import flash_mqa



RUN_BENCHMARK = False


@jtu.with_config(jax_legacy_prng_key='allow')
class FlashMQATest(jtu.JaxTestCase):

  @parameterized.product(
      causal=(True,),
      block_q=(128,),
      block_k_major=(128,),
      block_k=(128,),
  )
  @jtu.skip_on_flag("jax_skip_slow_tests", True)
  def test_flash_attention(
      self, causal, block_q, block_k_major, block_k
  ):
    if block_k_major < block_k:
      self.skipTest("Invalid block_k.")
    if causal and block_q > block_k:
      # TODO(sharadmv, apaszke): enable this
      self.skipTest("Not yet working")
    q_seq_len = kv_seq_len = 1024
    n_heads = 2
    batch_size = 4
    head_dim = 256
    dtype = jnp.bfloat16
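    # MQA: a single K/V head is shared across all query heads, so the K/V
    # shapes carry no num_heads dimension while the Q shape does.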
    kv_shape = (batch_size, kv_seq_len, head_dim)
    q_shape = (batch_size, n_heads, q_seq_len, head_dim)
    q_key, k_key, v_key = jax.random.split(jax.random.PRNGKey(0), 3)
    q = random.normal(q_key, q_shape, dtype=dtype)
    k = random.normal(k_key, kv_shape, dtype=dtype)
    v = random.normal(v_key, kv_shape, dtype=dtype)
    out = flash_mqa.flash_mqa(
        q, k, v, causal=causal, block_k=block_k, block_k_major=block_k_major,
        block_q=block_q
    )
    out_ref = flash_mqa.mqa_reference(q, k, v, causal=causal)
    self.assertAllClose(out, out_ref, atol=1e-2, rtol=1e-2)


if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())
279 changes: 279 additions & 0 deletions torch_xla/experimental/pallas_kernels/flash_mqa.py
@@ -0,0 +1,279 @@
"""Example flash MQA TPU kernel."""
import functools

import jax
from jax import lax
from jax._src.lax.control_flow import for_loop
import jax.numpy as jnp

from jax.experimental import pallas as pl
from jax._src import test_util as jtu
from jax.experimental.pallas import tpu as pltpu


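# `when(cond)(f)` runs `f` only when `cond` holds; a thin decorator-style
# wrapper around jax.lax.cond.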
def when(condition):
  return lambda f: jax.lax.cond(condition, f, lambda: None)


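# The pallas grid below iterates over (batch block, head, q block, kv block);
# this wrapper unrolls over the batch (and head) entries contained in a single
# tile and runs the single-(batch, head) kernel body for each.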
def flash_attention_kernel(*args, **kwargs):
  nb, nh = args[0].shape[:2]
  for ib in range(nb):
    for ih in range(nh):
      flash_attention_kernel_unbatched((ib, ih), *args, **kwargs)


def flash_attention_kernel_unbatched(
    batch_idx,
    q_tile_ref,
    k_tile_ref,
    v_tile_ref,  # Input arrays
    o_tile_ref,  # Output arrays
    m_scratch_ref,
    l_scratch_ref,
    acc_scratch_ref,
    *,
    causal: bool,
    sm_scale: float,
    block_k: int,
    kv_seq_len: int,
):
  _, _, block_q, head_dim = q_tile_ref.shape
  _, block_k_major, _ = k_tile_ref.shape
  local_batch_index, _ = batch_idx

  q_seq_idx = pl.program_id(2)
  kv_seq_idx = pl.program_id(3)
  pl.debug_print('line45 kv_seq_idx={}', kv_seq_idx)

  kv_major_index = kv_seq_idx * block_k_major
  q_index = q_seq_idx * block_q

  on_diag, below_or_on_diag = False, False
  if block_q == block_k_major:
    on_diag = q_seq_idx == kv_seq_idx
    below_or_on_diag = q_seq_idx >= kv_seq_idx
  else:
    q_end = (q_seq_idx + 1) * block_q
    kv_index = kv_seq_idx * block_k_major
    below_or_on_diag = q_end > kv_index
    diag_index = jax.lax.div(q_seq_idx * block_q, block_k_major)
    on_diag = kv_seq_idx == diag_index

  @when(kv_seq_idx == 0)
  def start_new_sequence():
    m_scratch_ref[:] = jnp.full(
        m_scratch_ref.shape, -jnp.inf, dtype=jnp.float32
    )
    l_scratch_ref[:] = jnp.zeros(l_scratch_ref.shape, dtype=jnp.float32)
    acc_scratch_ref[:, :] = jnp.zeros(o_tile_ref.shape[2:], dtype=jnp.float32)

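  # Inner loop: iterate over the block_k-sized sub-blocks of the current
  # block_k_major KV tile, folding each one into the running softmax
  # statistics held in the scratch refs.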
  def body(i, refs):
    kv_index = kv_major_index + i * block_k

    def run_iter():
      () = refs
      m_i = m_scratch_ref[:]
      l_i = l_scratch_ref[:]
      start_k = pl.multiple_of(i * block_k, block_k)
      q = q_tile_ref[batch_idx].astype(jnp.float32)
      k = pl.load(
          k_tile_ref, (local_batch_index, pl.dslice(start_k, block_k),
                       pl.dslice(None))
      ).astype(jnp.float32)

      p_ij = pl.dot(q, k, trans_b=True)  # [block_q, block_k]
      if sm_scale != 1.0:
        p_ij *= sm_scale

      if causal:
        q_span = q_index + jax.lax.broadcasted_iota(
            jnp.int32, (block_q, block_k), 0
        )
        kv_span = kv_index + jax.lax.broadcasted_iota(
            jnp.int32, (block_q, block_k), 1
        )
        causal_mask = jnp.where(q_span < kv_span, float("-inf"), 0.)
        p_ij = p_ij + causal_mask

      m_ij = jnp.max(p_ij, axis=1)[:, None]  # Row max, shape [block_q, 1].
      p_ij = jnp.exp(p_ij - m_ij)  # Shape [block_q, block_k].

      if causal and block_q > block_k:
        # If we have skinny blocks, we might have rows that are entirely
        # -inf. We need to mask out the nans that are created as a result
        # TODO(sharadmv,apaszke): enable this nan mask here
        # p_ij = jnp.where(jnp.isnan(p_ij), 0., p_ij)
        raise NotImplementedError

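      # Online-softmax update (FlashAttention-style): with running row max m_i
      # and running denominator l_i, and this block's row max m_ij and row sum
      # l_ij, the rescaled statistics are
      #   m_new = max(m_i, m_ij)
      #   l_new = exp(m_i - m_new) * l_i + exp(m_ij - m_new) * l_ij
      # and the accumulator is rescaled by l_i / l_new * exp(m_i - m_new)
      # before this block's contribution is added.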
      m_i_new = jnp.maximum(m_i, m_ij)  # Shape [block_q, 128].
      alpha = jnp.exp(m_i - m_i_new)  # Shape [block_q, 128].
      beta = jnp.exp(m_ij - m_i_new)  # Shape [block_q, 128].

      l_ij = jnp.sum(p_ij, axis=1)[:, None]  # Shape [block_q, 1].
      l_i_new = alpha * l_i + beta * l_ij  # Shape [block_q, 128].
      p_scale = beta / l_i_new  # Shape [block_q, 128].
      p_scale_repeats, rem = divmod(block_k, 128)
      if rem != 0:
        raise NotImplementedError("block_k should be a multiple of 128")
      p_ij = p_ij * pltpu.repeat(p_scale, p_scale_repeats, axis=1)
      acc_scale = l_i / l_i_new * alpha  # Shape [block_q, 128].

      acc_scale_repeats, rem = divmod(head_dim, 128)
      if rem != 0:
        raise NotImplementedError("head_dim should be a multiple of 128")
      acc_scratch_ref[:] *= pltpu.repeat(acc_scale, acc_scale_repeats, axis=1)

      # Update m_i and l_i for the next block_k.
      l_scratch_ref[:] = l_i_new
      m_scratch_ref[:] = m_i_new

      # Add the new block of attention weights.
      v = pl.load(
          v_tile_ref, (local_batch_index, pl.dslice(start_k, block_k),
                       pl.dslice(None))
      ).astype(jnp.float32)
      acc_scratch_ref[:] += jnp.dot(p_ij, v)

    if causal:
      should_run_iter = (q_seq_idx + 1) * block_q > kv_index
      when(should_run_iter)(run_iter)
    else:
      run_iter()

  if causal:
    @when(below_or_on_diag)
    def _run_body():
      for_loop.for_loop(block_k_major // block_k, body, init_state=())
  else:
    for_loop.for_loop(block_k_major // block_k, body, init_state=())

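  # Write the accumulator to the output tile once every contributing KV block
  # has been processed: under causal masking the diagonal block is the last
  # one that contributes to this q block; otherwise it is the final KV block.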
  if causal:
    @when(on_diag)
    def store_output():
      o_tile_ref[batch_idx] = acc_scratch_ref[:].astype(o_tile_ref.dtype)
  else:
    @when(kv_seq_idx == (kv_seq_len // block_k_major) - 1)
    def store_output():
      o_tile_ref[batch_idx] = acc_scratch_ref[:].astype(o_tile_ref.dtype)


# @functools.partial(
# jax.jit,
# static_argnames=[
# "causal", "sm_scale", "block_b", "block_q", "block_k_major", "block_k",
# "debug", "interpret"
# ],
# )
def flash_mqa(
    q,  # [batch_size, num_heads, seq_len, d_model]
    k,  # [batch_size, seq_len, d_model]
    v,  # [batch_size, seq_len, d_model]
    *,
    causal: bool = False,
    sm_scale: float = 1.0,
    block_b: int = 1,
    block_q: int = 128,
    block_k_major: int = 128,
    block_k: int = 128,
    debug: bool = False,
    interpret: bool = False,
):
  batch_size, num_heads, q_seq_len, head_dim = q.shape
  _, kv_seq_len, _ = k.shape

  if block_b > batch_size:
    raise ValueError(f"{block_b=} should be smaller than or equal to {batch_size=}")
  if block_q > q_seq_len:
    raise ValueError(f"{block_q=} should be smaller than or equal to {q_seq_len=}")
  if block_k > kv_seq_len:
    raise ValueError(f"{block_k=} should be smaller than or equal to {kv_seq_len=}")
  if block_k_major > kv_seq_len:
    raise ValueError(
        f"{block_k_major=} should be smaller than or equal to {kv_seq_len=}"
    )
  if block_k_major < block_k:
    raise ValueError(
        f"{block_k_major=} should be greater than or equal to {block_k=}"
    )
  grid = (
      batch_size // block_b,
      num_heads,
      q_seq_len // block_q,
      kv_seq_len // block_k_major,
  )

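  # The KV dimension is the innermost grid axis, so the scratch statistics
  # defined below are carried across KV iterations for a fixed q block. For
  # causal attention, kv_index_map points blocks strictly above the diagonal
  # back at the diagonal block, so no new K/V data is fetched for them; their
  # compute is skipped inside the kernel via the below_or_on_diag guard.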
  def kv_index_map(batch_index, _, q_seq_index, kv_seq_index):
    if not causal:
      return (batch_index, kv_seq_index, 0)
    q_end = (q_seq_index + 1) * block_q
    kv_index = kv_seq_index * block_k_major
    if block_q == block_k_major:
      default_index = q_seq_index
    else:
      default_index = jax.lax.div(q_seq_index * block_q, block_k_major)
    def _below_or_on_diag():
      return (batch_index, kv_seq_index, 0)
    def _above_diag():
      return (batch_index, default_index, 0)
    return lax.cond(q_end > kv_index, _below_or_on_diag, _above_diag)

  def qo_index_map(batch_index, head_index, q_seq_idx, _):
    return (batch_index, head_index, q_seq_idx, 0)

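  # The running max (m), denominator (l), and accumulator are threaded through
  # pallas_call as extra outputs whose index maps are constant, so the same
  # blocks are revisited on every grid step and their values persist across
  # the innermost KV iterations for a given q block.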
  kernel = functools.partial(
      flash_attention_kernel,
      causal=causal,
      sm_scale=sm_scale,
      block_k=block_k,
      kv_seq_len=kv_seq_len,
  )
  out_shape = jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype)
  m_scratch = jax.ShapeDtypeStruct((block_q, 128), dtype=jnp.float32)
  l_scratch = jax.ShapeDtypeStruct((block_q, 128), dtype=jnp.float32)
  acc_scratch = jax.ShapeDtypeStruct((block_q, head_dim), dtype=jnp.float32)
  kernel = pl.pallas_call(
      kernel,
      out_shape=(out_shape, m_scratch, l_scratch, acc_scratch),
      in_specs=[
          pl.BlockSpec((block_b, 1, block_q, head_dim), qo_index_map),
          pl.BlockSpec((block_b, block_k_major, head_dim), kv_index_map),
          pl.BlockSpec((block_b, block_k_major, head_dim), kv_index_map),
      ],
      out_specs=[
          pl.BlockSpec((block_b, 1, block_q, head_dim), qo_index_map),
          pl.BlockSpec(m_scratch.shape, lambda *_: (0, 0)),
          pl.BlockSpec(l_scratch.shape, lambda *_: (0, 0)),
          pl.BlockSpec(acc_scratch.shape, lambda *_: (0, 0)),
      ],
      grid=grid,
      debug=debug,
      interpret=interpret,
  )
  compiled_kernel = (
      jax.jit(kernel).lower(q, k, v).compile({'xla_tpu_enable_log_recorder': 'true'})
  )
  with jtu.capture_stderr() as get_output:
    o = jax.block_until_ready(compiled_kernel(q, k, v)[0])

  print('xw32 line256 out=', get_output())
  return o


@functools.partial(jax.jit, static_argnames=["sm_scale", "causal"])
@jax.default_matmul_precision("bfloat16")
def mqa_reference(q, k, v, sm_scale: float = 1.0, causal: bool = False):
  logits = jnp.einsum(
      "bhqc,bkc->bhqk",
      q.astype(jnp.float32),
      k.astype(jnp.float32),
      _dot_general=functools.partial(
          lax.dot_general, preferred_element_type=jnp.float32,
      ),
      precision=jax.lax.Precision.DEFAULT,
  ).astype(jnp.float32)
  if causal:
    mask = jnp.tril(jnp.ones((1, 1, q.shape[2], k.shape[1]), dtype=bool))
    mask = jnp.broadcast_to(mask, logits.shape)
    logits = jnp.where(mask, logits, float("-inf"))
  weights = jax.nn.softmax(logits * sm_scale, axis=-1)
  return jnp.einsum("bhqk,bkc->bhqc", weights, v.astype(jnp.float32)).astype(
      q.dtype
  )
