[DO NOT REVIEW] Experiment with pl.debug_print in Pallas. #8284

Draft: wants to merge 1 commit into base: master
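
This draft experiments with calling pl.debug_print from inside a Pallas TPU kernel and reading the printed output back out of the runtime logs. As a rough standalone illustration of the API being exercised (not code from this PR: the kernel name, array, and grid below are made up for the example, and interpret=True is used so the sketch runs without a TPU), a minimal usage might look like:

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def add_one_kernel(x_ref, o_ref):
  # pl.debug_print takes a format string with {} placeholders followed by the
  # values to substitute; here it prints the current grid index.
  pl.debug_print("program_id(0) = {}", pl.program_id(0))
  o_ref[...] = x_ref[...] + 1.0


x = jnp.zeros((8, 128), dtype=jnp.float32)
out = pl.pallas_call(
    add_one_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    grid=(1,),
    interpret=True,
)(x)

On real hardware the PR goes one step further: flash_mqa.py below compiles the kernel with the xla_tpu_enable_log_recorder option and wraps execution in jtu.capture_stderr so the printed lines can be recovered from the device logs.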
55 changes: 55 additions & 0 deletions torch_xla/experimental/flash_mqa_test.py
@@ -0,0 +1,55 @@
"""Demo of existing custom kernels."""

from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import random
from jax._src import test_util as jtu
import jax.numpy as jnp
# from jax._src.pallas.google.tpu_ops import flash_mqa
from torch_xla.experimental.pallas_kernels import flash_mqa



RUN_BENCHMARK = False


@jtu.with_config(jax_legacy_prng_key='allow')
class FlashMQATest(jtu.JaxTestCase):

@parameterized.product(
causal=(True,),
block_q=(128,),
block_k_major=(128,),
block_k=(128,),
)
@jtu.skip_on_flag("jax_skip_slow_tests", True)
def test_flash_attention(
self, causal, block_q, block_k_major, block_k
):
if block_k_major < block_k:
self.skipTest("Invalid block_k.")
if causal and block_q > block_k:
# TODO(sharadmv, apaszke): enable this
self.skipTest("Not yet working")
q_seq_len = kv_seq_len = 1024
n_heads = 2
batch_size = 4
head_dim = 256
dtype = jnp.bfloat16
kv_shape = (batch_size, kv_seq_len, head_dim)
q_shape = (batch_size, n_heads, q_seq_len, head_dim)
q_key, k_key, v_key = jax.random.split(jax.random.PRNGKey(0), 3)
q = random.normal(q_key, q_shape, dtype=dtype)
k = random.normal(k_key, kv_shape, dtype=dtype)
v = random.normal(v_key, kv_shape, dtype=dtype)
out = flash_mqa.flash_mqa(
q, k, v, causal=causal, block_k=block_k, block_k_major=block_k_major,
block_q=block_q
)
out_ref = flash_mqa.mqa_reference(q, k, v, causal=causal)
self.assertAllClose(out, out_ref, atol=1e-2, rtol=1e-2)


if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())
279 changes: 279 additions & 0 deletions torch_xla/experimental/pallas_kernels/flash_mqa.py
@@ -0,0 +1,279 @@
"""Example flash MQA TPU kernel."""
import functools

import jax
from jax import lax
from jax._src.lax.control_flow import for_loop
import jax.numpy as jnp

from jax.experimental import pallas as pl
from jax._src import test_util as jtu
from jax.experimental.pallas import tpu as pltpu


def when(condition):
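  """Runs the decorated function only when `condition` holds (via lax.cond)."""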
return lambda f: jax.lax.cond(condition, f, lambda: None)


def flash_attention_kernel(*args, **kwargs):
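  """Runs the unbatched kernel for every (batch, head) entry in the tile."""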
nb, nh = args[0].shape[:2]
for ib in range(nb):
for ih in range(nh):
flash_attention_kernel_unbatched((ib, ih), *args, **kwargs)


def flash_attention_kernel_unbatched(
batch_idx,
q_tile_ref,
k_tile_ref,
v_tile_ref, # Input arrays
o_tile_ref, # Output arrays
m_scratch_ref,
l_scratch_ref,
acc_scratch_ref,
*,
causal: bool,
sm_scale: float,
block_k: int,
kv_seq_len: int,
):
_, _, block_q, head_dim = q_tile_ref.shape
_, block_k_major, _ = k_tile_ref.shape
local_batch_index, _ = batch_idx

q_seq_idx = pl.program_id(2)
kv_seq_idx = pl.program_id(3)
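  # The experiment in this PR: print the current kv grid index from inside the
  # kernel with pl.debug_print on every grid step.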
pl.debug_print('line45 kv_seq_idx={}', kv_seq_idx)

kv_major_index = kv_seq_idx * block_k_major
q_index = q_seq_idx * block_q

  # Work out where this kv block sits relative to the causal diagonal:
  # `on_diag` marks the last kv block this q block attends to, and
  # `below_or_on_diag` marks every kv block it attends to at all.
  on_diag, below_or_on_diag = False, False
if block_q == block_k_major:
on_diag = q_seq_idx == kv_seq_idx
below_or_on_diag = q_seq_idx >= kv_seq_idx
else:
q_end = (q_seq_idx + 1) * block_q
kv_index = kv_seq_idx * block_k_major
below_or_on_diag = q_end > kv_index
diag_index = jax.lax.div(q_seq_idx * block_q, block_k_major)
on_diag = kv_seq_idx == diag_index

@when(kv_seq_idx == 0)
def start_new_sequence():
m_scratch_ref[:] = jnp.full(
m_scratch_ref.shape, -jnp.inf, dtype=jnp.float32
)
l_scratch_ref[:] = jnp.zeros(l_scratch_ref.shape, dtype=jnp.float32)
acc_scratch_ref[:, :] = jnp.zeros(o_tile_ref.shape[2:], dtype=jnp.float32)

def body(i, refs):
kv_index = kv_major_index + i * block_k

def run_iter():
() = refs
m_i = m_scratch_ref[:]
l_i = l_scratch_ref[:]
start_k = pl.multiple_of(i * block_k, block_k)
q = q_tile_ref[batch_idx].astype(jnp.float32)
k = pl.load(
k_tile_ref, (local_batch_index, pl.dslice(start_k, block_k),
pl.dslice(None))
).astype(jnp.float32)

p_ij = pl.dot(q, k, trans_b=True) # [block_q, block_k]
if sm_scale != 1.0:
p_ij *= sm_scale

if causal:
q_span = q_index + jax.lax.broadcasted_iota(
jnp.int32, (block_q, block_k), 0
)
kv_span = kv_index + jax.lax.broadcasted_iota(
jnp.int32, (block_q, block_k), 1
)
causal_mask = jnp.where(q_span < kv_span, float("-inf"), 0.)
p_ij = p_ij + causal_mask

m_ij = jnp.max(p_ij, axis=1)[:, None] # Row max, shape [block_q, 1].
p_ij = jnp.exp(p_ij - m_ij) # Shape [block_q, block_k].

if causal and block_q > block_k:
# If we have skinny blocks, we might have rows that are entirely
# -inf. We need to mask out the nans that are created as a result
# TODO(sharadmv,apaszke): enable this nan mask here
# p_ij = jnp.where(jnp.isnan(p_ij), 0., p_ij)
raise NotImplementedError

m_i_new = jnp.maximum(m_i, m_ij) # Shape [block_q, 128].
alpha = jnp.exp(m_i - m_i_new) # Shape [block_q, 128].
beta = jnp.exp(m_ij - m_i_new) # Shape [block_q, 128].

l_ij = jnp.sum(p_ij, axis=1)[:, None] # Shape [block_q, 1].
l_i_new = alpha * l_i + beta * l_ij # Shape [block_q, 128].
p_scale = beta / l_i_new # Shape [block_q, 128].
p_scale_repeats, rem = divmod(block_k, 128)
if rem != 0:
raise NotImplementedError("block_k should be a multiple of 128")
p_ij = p_ij * pltpu.repeat(p_scale, p_scale_repeats, axis=1)
acc_scale = l_i / l_i_new * alpha # Shape [block_q, 128].

acc_scale_repeats, rem = divmod(head_dim, 128)
if rem != 0:
raise NotImplementedError("head_dim should be a multiple of 128")
acc_scratch_ref[:] *= pltpu.repeat(acc_scale, acc_scale_repeats, axis=1)

# Update m_i and l_i for the next block_k.
l_scratch_ref[:] = l_i_new
m_scratch_ref[:] = m_i_new

# Add the new block of attention weights.
v = pl.load(
v_tile_ref, (local_batch_index, pl.dslice(start_k, block_k),
pl.dslice(None))
).astype(jnp.float32)
acc_scratch_ref[:] += jnp.dot(p_ij, v)

if causal:
should_run_iter = (q_seq_idx + 1) * block_q > kv_index
when(should_run_iter)(run_iter)
else:
run_iter()

if causal:
@when(below_or_on_diag)
def _run_body():
for_loop.for_loop(block_k_major // block_k, body, init_state=())
else:
for_loop.for_loop(block_k_major // block_k, body, init_state=())

if causal:
@when(on_diag)
def store_output():
o_tile_ref[batch_idx] = acc_scratch_ref[:].astype(o_tile_ref.dtype)
else:
@when(kv_seq_idx == (kv_seq_len // block_k_major) - 1)
def store_output():
o_tile_ref[batch_idx] = acc_scratch_ref[:].astype(o_tile_ref.dtype)


# @functools.partial(
# jax.jit,
# static_argnames=[
# "causal", "sm_scale", "block_b", "block_q", "block_k_major", "block_k",
# "debug", "interpret"
# ],
# )
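# NOTE: the jit decorator above is left commented out; flash_mqa jits, lowers,
# and compiles the pallas_call by hand below, which allows TPU compiler options
# to be passed for the debug_print experiment.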
def flash_mqa(
q, # [batch_size, num_heads, seq_len, d_model]
k, # [batch_size, seq_len, d_model]
v, # [batch_size, seq_len, d_model]
*,
causal: bool = False,
sm_scale: float = 1.0,
block_b: int = 1,
block_q: int = 128,
block_k_major: int = 128,
block_k: int = 128,
debug: bool = False,
interpret: bool = False,
):
batch_size, num_heads, q_seq_len, head_dim = q.shape
_, kv_seq_len, _ = k.shape

if block_b > batch_size:
raise ValueError(f"{block_b=} should be smaller or equal to {batch_size=}")
if block_q > q_seq_len:
raise ValueError(f"{block_q=} should be smaller or equal to {q_seq_len=}")
if block_k > kv_seq_len:
raise ValueError(f"{block_k=} should be smaller or equal to {kv_seq_len=}")
if block_k_major > kv_seq_len:
raise ValueError(
f"{block_k_major=} should be smaller or equal to {kv_seq_len=}"
)
if block_k_major < block_k:
raise ValueError(f"{block_k_major=} should be smaller than {block_k=}")
grid = (
batch_size // block_b,
num_heads,
q_seq_len // block_q,
kv_seq_len // block_k_major,
)

def kv_index_map(batch_index, _, q_seq_index, kv_seq_index):
if not causal:
return (batch_index, kv_seq_index, 0)
q_end = (q_seq_index + 1) * block_q
kv_index = kv_seq_index * block_k_major
if block_q == block_k_major:
default_index = q_seq_index
else:
default_index = jax.lax.div(q_seq_index * block_q, block_k_major)
def _below_or_on_diag():
return (batch_index, kv_seq_index, 0)
def _above_diag():
return (batch_index, default_index, 0)
return lax.cond(q_end > kv_index, _below_or_on_diag, _above_diag)

def qo_index_map(batch_index, head_index, q_seq_idx, _):
return (batch_index, head_index, q_seq_idx, 0)

kernel = functools.partial(
flash_attention_kernel,
causal=causal,
sm_scale=sm_scale,
block_k=block_k,
kv_seq_len=kv_seq_len,
)
out_shape = jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype)
m_scratch = jax.ShapeDtypeStruct((block_q, 128), dtype=jnp.float32)
l_scratch = jax.ShapeDtypeStruct((block_q, 128), dtype=jnp.float32)
acc_scratch = jax.ShapeDtypeStruct((block_q, head_dim), dtype=jnp.float32)
kernel = pl.pallas_call(
kernel,
out_shape=(out_shape, m_scratch, l_scratch, acc_scratch),
in_specs=[
pl.BlockSpec((block_b, 1, block_q, head_dim), qo_index_map),
pl.BlockSpec((block_b, block_k_major, head_dim), kv_index_map),
pl.BlockSpec((block_b, block_k_major, head_dim), kv_index_map),
],
out_specs=[
pl.BlockSpec((block_b, 1, block_q, head_dim), qo_index_map),
pl.BlockSpec(m_scratch.shape, lambda *_: (0, 0)),
pl.BlockSpec(l_scratch.shape, lambda *_: (0, 0)),
pl.BlockSpec(acc_scratch.shape, lambda *_: (0, 0)),
],
grid=grid,
debug=debug,
interpret=interpret,
)
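  # Compile with the TPU log recorder enabled so pl.debug_print output lands in
  # the runtime logs, then capture stderr to read the printed lines back.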
compiled_kernel = (
jax.jit(kernel).lower(q, k, v).compile({'xla_tpu_enable_log_recorder': 'true'})
)
with jtu.capture_stderr() as get_output:
o = jax.block_until_ready(compiled_kernel(q, k, v)[0])

print('xw32 line256 out=', get_output())
return o


@functools.partial(jax.jit, static_argnames=["sm_scale", "causal"])
@jax.default_matmul_precision("bfloat16")
def mqa_reference(q, k, v, sm_scale: float = 1.0, causal: bool = False):
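  """Plain-attention MQA reference used to check the Pallas kernel output."""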
logits = jnp.einsum(
"bhqc,bkc->bhqk",
q.astype(jnp.float32),
k.astype(jnp.float32),
_dot_general=functools.partial(
lax.dot_general, preferred_element_type=jnp.float32,
),
precision=jax.lax.Precision.DEFAULT,
).astype(jnp.float32)
if causal:
mask = jnp.tril(jnp.ones((1, 1, q.shape[2], k.shape[1]), dtype=bool))
mask = jnp.broadcast_to(mask, logits.shape)
logits = jnp.where(mask, logits, float("-inf"))
weights = jax.nn.softmax(logits * sm_scale, axis=-1)
return jnp.einsum("bhqk,bkc->bhqc", weights, v.astype(jnp.float32)).astype(
q.dtype
)