From c40b39a5fd12e62a664432ed8e035662c884094a Mon Sep 17 00:00:00 2001 From: Dongseong Hwang Date: Tue, 7 Jan 2025 13:52:08 -0800 Subject: [PATCH] Optimize TPU Flash Attention (20x XLA compilation speed-up on 32k long context) (#908) Use splash attention lazy mask instead of jnp mask, which is O(T^2). The host memory usage for the `jnp` mask is O(T^2). Currently, a `jnp` mask is created and then wrapped with `NumpyMask` for use in Splash Attention, resulting in O(T^2) temporary HBM usage (though XLA somehow avoids allocating it). This PR proposes using `CausalMask` or `LocalMask`, allowing each Splash Attention block to lazily create and use the required mask in the pallas kernel. The runtime performance of Splash Attention remains nearly the same. However, the JIT compilation time for the function using Splash Attention has improved significantly. It appears that allocating the O(T^2) mask in HBM and then wrapping it with `NumpyMask` consumes a lot of XLA compilation time. * Benchmark: on TPUv5p, (model_dim/heads/kv_heads/seq_len), tools/attention_benchmark.py 1) measure time with XLA compilation NumpyMask (ASIS) ---------------------------------------------------------------------------------------- Benchmark Time CPU Iterations ---------------------------------------------------------------------------------------- FlashAttentionBenchmark/4096/16/2/4096 3609 ms 1645 ms 1 FlashAttentionBenchmark/4096/16/2/8192 7828 ms 5696 ms 1 FlashAttentionBenchmark/4096/16/2/32768 94368 ms 91442 ms 1 CausalMask (Proposed PR): significant XLA compilation speed-up. 
---------------------------------------------------------------------------------------- Benchmark Time CPU Iterations ---------------------------------------------------------------------------------------- FlashAttentionBenchmark/4096/16/2/4096 22.9 ms 1.60 ms 127 FlashAttentionBenchmark/4096/16/2/8192 40.8 ms 2.12 ms 88 FlashAttentionBenchmark/4096/16/2/32768 7641 ms 5458 ms 1 2) measure time without XLA compilation (pure jit computation) NumpyMask (ASIS) ---------------------------------------------------------------------------------------- Benchmark Time CPU Iterations ---------------------------------------------------------------------------------------- FlashAttentionBenchmark/4096/16/2/4096 9.97 ms 0.757 ms 918 FlashAttentionBenchmark/4096/16/2/8192 19.4 ms 0.934 ms 832 FlashAttentionBenchmark/4096/16/2/32768 116 ms 1.03 ms 100 CausalMask (Proposed PR): slight step time speed-up. ---------------------------------------------------------------------------------------- Benchmark Time CPU Iterations ---------------------------------------------------------------------------------------- FlashAttentionBenchmark/4096/16/2/4096 9.82 ms 0.690 ms 964 FlashAttentionBenchmark/4096/16/2/8192 19.2 ms 0.822 ms 837 FlashAttentionBenchmark/4096/16/2/32768 116 ms 0.997 ms 100 In addition, tpu_attention_benchmark.py is changed to use an 8k sequence length (instead of 2k) with a sliding window of 1k. NumpyMask (ASIS) Benchmarking attention representative of 1.2b model layer on TPU v5. ref_fwd:0.2288s, flash_fwd:0.0014s ref_bwd:0.0218s, flash_bwd:0.0058s Benchmarking attention representative of 12.6b model layer on TPU v5. ref_fwd:0.5700s, flash_fwd:0.0032s ref_bwd:0.0527s, flash_bwd:0.0149s Benchmarking attention representative of 29.6b model layer on TPU v5. ref_fwd:0.7958s, flash_fwd:0.0044s ref_bwd:0.0730s, flash_bwd:0.0205s Benchmarking attention representative of 65.2b model layer on TPU v5. 
ref_fwd:1.0222s, flash_fwd:0.0055s ref_bwd:0.0949s, flash_bwd:0.0262s Benchmarking attention representative of 134b model layer on TPU v5. ref_fwd:1.2486s, flash_fwd:0.0067s ref_bwd:0.1161s, flash_bwd:0.0318s Benchmarking attention representative of 261.7b model layer on TPU v5. ref_fwd:1.5577s, flash_fwd:0.0072s ref_bwd:0.1348s, flash_bwd:0.0375s LocalMask (Proposed PR): slight fwd/bwd time speed-up. Benchmarking attention representative of 1.2b model layer on TPU v5. ref_fwd:0.2291s, flash_fwd:0.0014s ref_bwd:0.0217s, flash_bwd:0.0058s Benchmarking attention representative of 12.6b model layer on TPU v5. ref_fwd:0.5699s, flash_fwd:0.0032s ref_bwd:0.0524s, flash_bwd:0.0152s Benchmarking attention representative of 29.6b model layer on TPU v5. ref_fwd:0.7957s, flash_fwd:0.0043s ref_bwd:0.0731s, flash_bwd:0.0204s Benchmarking attention representative of 65.2b model layer on TPU v5. ref_fwd:1.0225s, flash_fwd:0.0055s ref_bwd:0.0948s, flash_bwd:0.0262s Benchmarking attention representative of 134b model layer on TPU v5. ref_fwd:1.2485s, flash_fwd:0.0067s ref_bwd:0.1159s, flash_bwd:0.0313s Benchmarking attention representative of 261.7b model layer on TPU v5. 
ref_fwd:1.5577s, flash_fwd:0.0072s ref_bwd:0.1349s, flash_bwd:0.0373s --- axlearn/common/attention_bias.py | 6 +- .../common/flash_attention/tpu_attention.py | 245 ++++++++++++------ .../tpu_attention_benchmark.py | 62 +++-- .../flash_attention/tpu_attention_test.py | 64 ++++- axlearn/common/flash_attention/utils.py | 2 + 5 files changed, 257 insertions(+), 122 deletions(-) diff --git a/axlearn/common/attention_bias.py b/axlearn/common/attention_bias.py index d605242bb..9830b077a 100644 --- a/axlearn/common/attention_bias.py +++ b/axlearn/common/attention_bias.py @@ -687,7 +687,11 @@ def sliding_window_causal_mask(sliding_window_size: int) -> MaskFn: def mask(query_position: Tensor, key_position: Tensor): return query_position - key_position <= sliding_window_size - return and_masks(causal_mask, mask) + fun = and_masks(causal_mask, mask) + # Flash attention needs to recognize sliding window size in _to_splash_mask(). + # pylint: disable-next=protected-access + fun._sliding_window_size = sliding_window_size + return fun def make_causal_biases(seq_len: int) -> Tensor: diff --git a/axlearn/common/flash_attention/tpu_attention.py b/axlearn/common/flash_attention/tpu_attention.py index de18a402f..4de717b4b 100644 --- a/axlearn/common/flash_attention/tpu_attention.py +++ b/axlearn/common/flash_attention/tpu_attention.py @@ -42,15 +42,16 @@ def tpu_flash_attention( - query: Tensor, # [batch_size, source_len, num_heads, head_dim] - key: Tensor, # [batch_size, target_len, num_heads, head_dim] - value: Tensor, # [batch_size, target_len, num_heads, head_dim] - bias: Tensor = None, # [batch_size, num_heads, source_len, target_len] - segment_ids: Tensor = None, # [batch_size, source_len] + query: Tensor, # [batch_size, target_len, num_heads, head_dim] + key: Tensor, # [batch_size, source_len, num_heads, head_dim] + value: Tensor, # [batch_size, source_len, num_heads, head_dim] + bias: Tensor = None, # [batch_size, num_heads, target_len, source_len] + segment_ids: Tensor = None, 
# [batch_size, target_len] *, mask: Optional[MaskFnAttentionBias] = None, softmax_scale: float = 1.0, block_size: int = 128, + interpret: bool = False, ): """Wraps JAX's TPU flash-attention, with reshapes and softmax-scaling outside kernel. @@ -63,20 +64,21 @@ def tpu_flash_attention( If provided, bias, segment_ids, and mask are applied on top of one another. Args: - query: The query tensor, of shape [batch_size, source_len, num_heads, head_dim]. - key: The key tensor, of shape [batch_size, target_len, num_heads, head_dim]. + query: The query tensor, of shape [batch_size, target_len, num_heads, head_dim]. + key: The key tensor, of shape [batch_size, source_len, num_heads, head_dim]. value: The value tensor, of shape [batch_size, source_len, num_heads, head_dim]. bias: The attention biases, can broadcast to shape - [batch_size, num_heads, source_len, target_len]. + [batch_size, num_heads, target_len, source_len]. segment_ids: The id of which segment each token belongs to. Attention is not computed between tokens in different segments. - Shape: [batch_size, source_len]. + Shape: [batch_size, target_len]. mask: The mask to apply. This is more compute efficient compared to setting bias = -inf. softmax_scale: A scaling factor applied to the query. block_size: The block size to use for chunking data in the kernel. + interpret: If True, interpret the kernel using the pallas interpreter. CPU needs it. Returns: - The context tensor, of shape [batch_size, source_len, num_heads, head_dim]. + The context tensor, of shape [batch_size, target_len, num_heads, head_dim]. Raises: NotImplementedError: If no implementation with support for the arguments is found. @@ -111,13 +113,20 @@ def tpu_flash_attention( f"Source seq len {key.shape[1]} must be divisible by block size {block_size}." ) - mask = as_attention_bias(mask) + mask: Union[MaskFnAttentionBias | ZeroAttentionBias] = as_attention_bias(mask) # Switch num_heads and seq_len axes. 
query = jnp.einsum("btnh->bnth", query) key = jnp.einsum("bsnh->bnsh", key) value = jnp.einsum("bsnh->bnsh", value) try: + check_tpu_splash_attention( + query=query, + key=key, + mask=mask, + has_segment_ids=(segment_ids is not None), + has_bias=(bias is not None), + ) block_sizes = splash_attention_kernel.BlockSizes( block_q=block_size, block_kv=block_size, @@ -131,7 +140,13 @@ def tpu_flash_attention( use_fused_bwd_kernel=True, ) context = _tpu_splash_attention( - query, key, value, bias, segment_ids=segment_ids, mask=mask, block_sizes=block_sizes + query, + key, + value, + mask=mask, + segment_ids=segment_ids, + block_sizes=block_sizes, + interpret=interpret, ) logging.info("Using SplashAttention.") except SplashAttentionUnsupportedError as e: @@ -150,7 +165,14 @@ def tpu_flash_attention( block_q_dq=block_size, ) context = _legacy_tpu_flash_attention( - query, key, value, bias, segment_ids=segment_ids, mask=mask, block_sizes=block_sizes + query, + key, + value, + bias, + segment_ids=segment_ids, + mask=mask, + block_sizes=block_sizes, + interpret=interpret, ) logging.warning( "Falling back to legacy flash attention because SplashAttention is not supported.\n" @@ -167,36 +189,39 @@ def tpu_flash_attention( static_argnames=[ "mask", # Mask objects don't actually contain jax arrays, so they are static. 
"block_sizes", + "interpret", ], ) def _legacy_tpu_flash_attention( - query: Tensor, # [batch_size, num_heads, source_len, head_dim] - key: Tensor, # [batch_size, num_heads, target_len, head_dim] - value: Tensor, # [batch_size, num_heads, target_len, head_dim] - bias: Tensor = None, # [batch_size, num_heads, source_len, target_len] - segment_ids: Tensor = None, # [batch_size, source_len] + query: Tensor, # [batch_size, num_heads, target_len, head_dim] + key: Tensor, # [batch_size, num_heads, source_len, head_dim] + value: Tensor, # [batch_size, num_heads, source_len, head_dim] + bias: Tensor = None, # [batch_size, num_heads, target_len, source_len] + segment_ids: Tensor = None, # [batch_size, target_len] *, mask: MaskFnAttentionBias, block_sizes: Optional[LegacyBlockSizes] = None, -) -> Tensor: # [batch_size, num_heads, source_len, head_dim]. + interpret: bool = False, +) -> Tensor: # [batch_size, num_heads, target_len, head_dim]. """Wraps JAX's legacy TPU flash-attention. If provided, bias, segment_ids, and mask are applied on top of one another. Args: - query: The query tensor, of shape [batch_size, num_heads, source_len, head_dim]. - key: The key tensor, of shape [batch_size, num_heads, target_len, head_dim]. + query: The query tensor, of shape [batch_size, num_heads, target_len, head_dim]. + key: The key tensor, of shape [batch_size, num_heads, source_len, head_dim]. value: The value tensor, of shape [batch_size, num_heads, source_len, head_dim]. - bias: The attention biases, of shape [batch_size, num_heads, source_len, target_len]. + bias: The attention biases, of shape [batch_size, num_heads, target_len, source_len]. segment_ids: The id of which segment each token belongs to. Attention is not computed between tokens in different segments. - Shape: [batch_size, source_len]. + Shape: [batch_size, target_len]. mask: The mask to apply. This is more compute efficient compared to setting bias = -inf. 
block_sizes: An object containing values that can be used to tune the performance such as the block size to chunk things into. + interpret: If True, interpret the kernel using the pallas interpreter. CPU needs it. Returns: - The context tensor, of shape [batch_size, num_heads, source_len, head_dim]. + The context tensor, of shape [batch_size, num_heads, target_len, head_dim]. Raises: NotImplementedError: If a custom (non-causal, non-full) mask is specified. @@ -216,6 +241,7 @@ def _legacy_tpu_flash_attention( softmax_scale=1.0, block_sizes=block_sizes, debug=False, + interpret=interpret, ) return context @@ -225,19 +251,103 @@ class SplashAttentionUnsupportedError(NotImplementedError): """An error indicating splash attention is not supported.""" +def check_tpu_splash_attention( + *, + query: Tensor, # [batch_size, num_heads, source_len, head_dim] + key: Tensor, # [batch_size, num_heads, target_len, head_dim] + mask: Union[MaskFnAttentionBias | ZeroAttentionBias], + has_segment_ids: bool = False, + has_bias: bool = False, +): + """Checks if splash attention is supported on TPU for the given arguments. + + Args: + query: The query tensor, of shape [batch_size, num_heads, target_len, head_dim]. + key: The key tensor, of shape [batch_size, num_heads, source_len, head_dim]. + mask: The mask to apply. This is more compute efficient compared to setting bias = -inf. + has_segment_ids: Whether segment_ids is None or not. + has_bias: Whether attention involves a bias. + + Raises: + SplashAttentionUnsupportedError: If splash attention is not supported for the given + arguments. + """ + target_len = query.shape[2] + source_len = key.shape[2] + head_dim = query.shape[3] + + if has_bias: + return False # SplashAttention does not support specifying a bias. 
+ with jax.ensure_compile_time_eval(): + if jnp.any( + jnp.asarray([target_len, source_len, head_dim]) % splash_attention_kernel.NUM_LANES != 0 + ): + raise SplashAttentionUnsupportedError( + "SplashAttention requires target_len, source_len, head_dim are divisible by" + f" {splash_attention_kernel.NUM_LANES}, got {target_len, source_len, head_dim}." + ) + if has_segment_ids: + raise SplashAttentionUnsupportedError( + "The public API for SplashAttention that we " + "currently use does not support segment ids." + ) + if mask.value() is not None: + assert isinstance(mask, MaskFnAttentionBias) + if target_len != source_len: + raise SplashAttentionUnsupportedError( + "Query and key/value must have same length when mask is used." + ) + if isinstance(mask.target_positions, jax.core.Tracer): + raise SplashAttentionUnsupportedError( + "Non-static value of `target_positions` is not supported.\n" + "Are you decoding using SplashAttention? That's not supported." + ) + + +def _to_splash_mask( + mask: Union[MaskFnAttentionBias | ZeroAttentionBias], + *, + mask_shape: tuple[int, int], + q_seq_shards: int = 1, +) -> splash_attention_mask.Mask: + """Converts a mask to a splash mask.""" + if mask.value() is None: + return splash_attention_mask.FullMask(mask_shape) + assert isinstance(mask, MaskFnAttentionBias) + if isinstance(mask, CausalAttentionBias): + return splash_attention_mask.CausalMask(shape=mask_shape, shard_count=q_seq_shards) + if hasattr(mask.mask, "_sliding_window_size"): + # TODO(dhwang2): introduce SlidingWindowAttentionBias instead of "_sliding_window_size". + # This is set in sliding_window_causal_mask(). + left_size = getattr(mask.mask, "_sliding_window_size") + return splash_attention_mask.LocalMask( + shape=mask_shape, window_size=(left_size, 0), offset=0, shard_count=q_seq_shards + ) + + with jax.ensure_compile_time_eval(): + # MaskFn always supports compile time eval. + mask_array = np.asarray(mask.bool_value()) + # Squeeze first two leading dimensions. 
+ mask_array = mask_array.reshape(mask_array.shape[-2:]) + + # NumpyMask is backed by a dense [target_len, source_len] numpy array. + # May consume a large amount of host memory for long sequences at compile time. + return splash_attention_mask.NumpyMask(array=mask_array) + + @functools.partial( jax.jit, - static_argnames=["block_sizes"], + static_argnames=["block_sizes", "interpret"], ) def _tpu_splash_attention( query: Tensor, # [batch_size, num_heads, target_len, head_dim] key: Tensor, # [batch_size, num_heads, source_len, head_dim] value: Tensor, # [batch_size, num_heads, source_len, head_dim] - bias: Optional[Tensor] = None, # [batch_size, num_heads, target_len, source_len] - segment_ids: Optional[Tensor] = None, # [batch_size, target_len] *, - mask: Union[MaskFnAttentionBias, ZeroAttentionBias], + mask: Union[MaskFnAttentionBias | ZeroAttentionBias], + segment_ids: Optional[Tensor] = None, # [batch_size, target_len] block_sizes: Optional[splash_attention_kernel.BlockSizes] = None, + interpret: bool = False, ) -> Tensor: # [batch_size, num_heads, target_len, head_dim]. """Wraps JAX's sparse TPU flash-attention. @@ -245,13 +355,12 @@ def _tpu_splash_attention( query: The query tensor, of shape [batch_size, num_heads, target_len, head_dim]. key: The key tensor, of shape [batch_size, num_heads, source_len, head_dim]. value: The value tensor, of shape [batch_size, num_heads, source_len, head_dim]. - bias: The attention biases, of shape [batch_size, num_heads, target_len, source_len]. - segment_ids: The id of which segment each token belongs to. Attention is not computed - between tokens in different segments. - Shape: [batch_size, target_len]. mask: The mask to apply. This is more compute efficient compared to setting bias = -inf. + segment_ids: The id of which segment each token belongs to. Attention is not computed + between tokens in different segments, [batch_size, target_len]. 
block_sizes: An object containing values that can be used to tune the performance such as the block size to chunk things into. + interpret: If True, interpret the kernel using the pallas interpreter. CPU needs it. Returns: The context tensor, of shape [batch_size, num_heads, target_len, head_dim]. @@ -266,57 +375,19 @@ def _tpu_splash_attention( TypeError: If mask is not an instance of `MaskFnAttentionBias. """ - target_len = query.shape[2] - source_len = key.shape[2] + # TODO(dhwang2): splash attention can support segment_ids. Support it when needed. + del segment_ids num_heads = query.shape[1] - head_dim = query.shape[3] - - if bias is not None: - raise SplashAttentionUnsupportedError("SplashAttention does not support specifying a bias.") - with jax.ensure_compile_time_eval(): - if jnp.any( - jnp.asarray([target_len, source_len, head_dim]) % splash_attention_kernel.NUM_LANES != 0 - ): - raise SplashAttentionUnsupportedError( - "SplashAttention requires target_len, source_len, head_dim are divisible by" - f" {splash_attention_kernel.NUM_LANES}, got {target_len, source_len, head_dim}." - ) - if segment_ids is not None: - raise SplashAttentionUnsupportedError( - "The public API for SplashAttention that we " - "currently use does not support segment ids." - ) - if target_len != source_len and mask.value() is not None: - raise SplashAttentionUnsupportedError( - "Query and key/value must have same length when mask is used." - ) - if mask.value() is not None and not isinstance(mask, MaskFnAttentionBias): - raise TypeError(type(mask)) - if mask.value() is not None and isinstance(mask.target_positions, jax.core.Tracer): - raise SplashAttentionUnsupportedError( - "Non-static value of `target_positions` is not supported.\n" - "Are you decoding using SplashAttention? That's not supported." 
- ) - - mask_shape = (target_len, source_len) - if mask.value() is None: - mask = splash_attention_mask.FullMask(mask_shape) - else: - with jax.ensure_compile_time_eval(): - # MaskFn always supports compile time eval. - mask_array = np.asarray(mask.bool_value()) - # Squeeze first two leading dimensions. - mask_array = mask_array.reshape(mask_array.shape[-2:]) - - # NumpyMask is backed by a dense [target_len, source_len] numpy array. - # May consume a large amount of host memory for long sequences at compile time. - mask = splash_attention_mask.NumpyMask(array=mask_array) + mask_shape = (query.shape[2], key.shape[2]) + splash_mask = _to_splash_mask(mask, mask_shape=mask_shape) kernel = splash_attention_kernel.make_splash_mha( - mask=splash_attention_mask.MultiHeadMask(masks=[mask] * num_heads), + mask=splash_attention_mask.MultiHeadMask(masks=[splash_mask] * num_heads), block_sizes=block_sizes, + # TODO(dhwang2): support "seq" and "model" shard. head_shards=1, q_seq_shards=1, + interpret=interpret, ) kernel = jax.vmap(kernel) context = kernel(q=query, k=key, v=value) @@ -335,6 +406,7 @@ def _tpu_splash_attention( "softmax_scale", "block_sizes", "debug", + "interpret", ], ) def pallas_tpu_flash_attention( @@ -348,6 +420,7 @@ def pallas_tpu_flash_attention( softmax_scale: float = 1.0, block_sizes: Optional[LegacyBlockSizes] = None, debug: bool = False, + interpret: bool = False, ): batch_size, num_heads, q_seq_len, d_model = q.shape batch_size_k, num_heads_k, kv_seq_len, d_model_k = k.shape @@ -397,11 +470,11 @@ def pallas_tpu_flash_attention( batch_size, num_heads, q_seq_len, kv_seq_len, d_model ) return _flash_attention( - q, k, v, ab, segment_ids, False, causal, softmax_scale, block_sizes, debug + q, k, v, ab, segment_ids, False, causal, softmax_scale, block_sizes, debug, interpret ) -@functools.partial(jax.custom_vjp, nondiff_argnums=range(5, 10)) +@functools.partial(jax.custom_vjp, nondiff_argnums=range(5, 11)) def _flash_attention( q, k, @@ -413,6 +486,7 @@ def 
_flash_attention( softmax_scale, block_sizes, debug, + interpret, ): return _flash_attention_impl( q, @@ -428,6 +502,7 @@ def _flash_attention( block_sizes.block_k_major, block_sizes.block_k, debug, + interpret, ) @@ -442,11 +517,12 @@ def _flash_attention_fwd( softmax_scale, block_sizes, debug, + interpret, ): if save_residuals: raise NotImplementedError("Higher-order AD not supported") o, l, m = _flash_attention( - q, k, v, ab, segment_ids, True, causal, softmax_scale, block_sizes, debug + q, k, v, ab, segment_ids, True, causal, softmax_scale, block_sizes, debug, interpret ) return o, (q, k, v, ab, segment_ids, o, l, m) @@ -457,6 +533,7 @@ def _flash_attention_bwd( softmax_scale: float, block_sizes: LegacyBlockSizes, debug: bool, + interpret: bool, residuals, do, ): @@ -491,6 +568,7 @@ def _flash_attention_bwd( causal=causal, mask_value=DEFAULT_MASK_VALUE, debug=debug, + interpret=interpret, ) dq, ds = _flash_attention_bwd_dq( @@ -510,6 +588,7 @@ def _flash_attention_bwd( causal=causal, mask_value=DEFAULT_MASK_VALUE, debug=debug, + interpret=interpret, ) return dq, dk, dv, ds, None @@ -531,6 +610,7 @@ def _flash_attention_impl( block_k_major, block_k, debug, + interpret, ): batch_size, num_heads, q_seq_len, head_dim = q.shape _, _, kv_seq_len, _ = k.shape @@ -693,6 +773,7 @@ def kv_segment_ids_index_map(batch_index, head_index, q_seq_index, kv_seq_index) ), out_shape=out_shape, debug=debug, + interpret=interpret, compiler_params=dict( mosaic=dict( dimension_semantics=( @@ -730,6 +811,7 @@ def _flash_attention_bwd_dkv( causal: bool = False, mask_value: float = DEFAULT_MASK_VALUE, debug: bool = False, + interpret: bool = False, ): batch_size, num_heads, q_seq_len, head_dim = q.shape _, _, kv_seq_len, _ = k.shape @@ -896,6 +978,7 @@ def dkv_index_map(batch_index, head_index, kv_seq_index, _): ), out_shape=out_shapes, debug=debug, + interpret=interpret, compiler_params=dict( mosaic=dict( dimension_semantics=( @@ -930,6 +1013,7 @@ def _flash_attention_bwd_dq( causal: 
bool, mask_value: float, debug: bool, + interpret: bool, ): batch_size, num_heads, q_seq_len, head_dim = q.shape _, _, kv_seq_len, _ = k.shape @@ -1087,6 +1171,7 @@ def kv_segment_ids_index_map(batch_index, head_index, q_seq_index, kv_seq_index) ), out_shape=out_shapes, debug=debug, + interpret=interpret, compiler_params=dict( mosaic=dict( dimension_semantics=( diff --git a/axlearn/common/flash_attention/tpu_attention_benchmark.py b/axlearn/common/flash_attention/tpu_attention_benchmark.py index 379048bc7..7e6400fa5 100644 --- a/axlearn/common/flash_attention/tpu_attention_benchmark.py +++ b/axlearn/common/flash_attention/tpu_attention_benchmark.py @@ -3,34 +3,28 @@ """Benchmark TPU FlashAttention kernels. Sample outputs: (v5p) +CMD: python \ +/opt/venv/lib/python3.10/site-packages/axlearn/common/flash_attention/tpu_attention_benchmark.py \ +2>&1 | grep -E "Benchmarking|ref_|HBM usage" Benchmarking attention representative of 1.2b model layer on TPU v5. -ref_fwd:0.0008s, flash_fwd:0.0007s -ref_bwd:0.0027s, flash_bwd:0.0026s - - Benchmarking attention representative of 12.6b model layer on TPU v5. -ref_fwd:0.0012s, flash_fwd:0.0010s -ref_bwd:0.0037s, flash_bwd:0.0026s - - Benchmarking attention representative of 29.6b model layer on TPU v5. -ref_fwd:0.0017s, flash_fwd:0.0013s -ref_bwd:0.0053s, flash_bwd:0.0034s - - Benchmarking attention representative of 65.2b model layer on TPU v5. -ref_fwd:0.0021s, flash_fwd:0.0015s -ref_bwd:0.0067s, flash_bwd:0.0042s - - Benchmarking attention representative of 134b model layer on TPU v5. -ref_fwd:0.0024s, flash_fwd:0.0018s -ref_bwd:0.0080s, flash_bwd:0.0050s - - Benchmarking attention representative of 261.7b model layer on TPU v5. -ref_fwd:0.0027s, flash_fwd:0.0019s -ref_bwd:0.0092s, flash_bwd:0.0056s - - Benchmarking attention representative of 539.5b model layer on TPU v5. 
-ref_fwd:0.0034s, flash_fwd:0.0023s -ref_bwd:0.0126s, flash_bwd:0.0070s +ref_fwd:0.2291s, flash_fwd:0.0014s +ref_bwd:0.0217s, flash_bwd:0.0058s +Benchmarking attention representative of 12.6b model layer on TPU v5. +ref_fwd:0.5699s, flash_fwd:0.0032s +ref_bwd:0.0524s, flash_bwd:0.0152s +Benchmarking attention representative of 29.6b model layer on TPU v5. +ref_fwd:0.7957s, flash_fwd:0.0043s +ref_bwd:0.0731s, flash_bwd:0.0204s +Benchmarking attention representative of 65.2b model layer on TPU v5. +ref_fwd:1.0225s, flash_fwd:0.0055s +ref_bwd:0.0948s, flash_bwd:0.0262s +Benchmarking attention representative of 134b model layer on TPU v5. +ref_fwd:1.2485s, flash_fwd:0.0067s +ref_bwd:0.1159s, flash_bwd:0.0313s +Benchmarking attention representative of 261.7b model layer on TPU v5. +ref_fwd:1.5577s, flash_fwd:0.0072s +ref_bwd:0.1349s, flash_bwd:0.0373s """ import time from typing import Callable, Optional @@ -49,8 +43,8 @@ _BENCHMARK_CONFIGS = { "1.2b": dict( - num_heads=32, - per_head_dim=64, + num_heads=16, + per_head_dim=128, ), "12.6b": dict( num_heads=40, @@ -72,10 +66,11 @@ num_heads=110, per_head_dim=128, ), - "539.5b": dict( - num_heads=140, - per_head_dim=128, - ), + # OOM in mha_reference. 
+ # "539.5b": dict( + # num_heads=140, + # per_head_dim=128, + # ), } @@ -167,7 +162,8 @@ def _benchmark( print(f"Benchmarking attention representative of {name} model layer on {device_kind}.") _benchmark( batch_size=2, - seq_len=2048, + seq_len=1024 * 8, block_size=4 * 128, + sliding_window_size=1024, **cfg, ) diff --git a/axlearn/common/flash_attention/tpu_attention_test.py b/axlearn/common/flash_attention/tpu_attention_test.py index 14c3dea7f..f9a99c310 100644 --- a/axlearn/common/flash_attention/tpu_attention_test.py +++ b/axlearn/common/flash_attention/tpu_attention_test.py @@ -5,11 +5,12 @@ import unittest +import chex import jax import jax.numpy as jnp import numpy as np import pytest -from absl.testing import parameterized +from absl.testing import absltest, parameterized from jax.experimental import mesh_utils from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask from jax.experimental.shard_map import shard_map @@ -17,7 +18,9 @@ from jax.sharding import Mesh, NamedSharding, PartitionSpec from axlearn.common.attention_bias import ( + CausalAttentionBias, MaskFnAttentionBias, + ZeroAttentionBias, causal_mask, sliding_window_causal_mask, ) @@ -26,10 +29,16 @@ from axlearn.common.test_utils import TestCase, is_supported_mesh_shape from axlearn.common.utils import Tensor +# Comment out to test on CPU manually. Technically, this test runs on the CPU, albeit very slowly. if jax.default_backend() != "tpu": pytest.skip(reason="Incompatible hardware", allow_module_level=True) +def setUpModule(): + # If on CPU, emulate 4 devices. + chex.set_n_cpu_devices(4) + + def jax_fn_mask(query_position: Tensor, key_position: Tensor) -> Tensor: """A MaskFn that calls jax. 
@@ -73,21 +82,36 @@ def test_sliding_window_mask_equivalence(self, seq_len, sliding_window_size): for i in range(seq_len): self.assertNestedAllClose(ref_mask[i:, i:], test_mask[i:, i:]) + @parameterized.parameters( + [ZeroAttentionBias(), splash_attention_mask.FullMask((8, 8))], + [CausalAttentionBias(shape=(8, 8)), splash_attention_mask.CausalMask(shape=(8, 8))], + [ + MaskFnAttentionBias(sliding_window_causal_mask(4), shape=(8, 8)), + splash_attention_mask.LocalMask(shape=(8, 8), window_size=(4, 0), offset=0), + ], + ) + def test_to_splash_mask(self, mask, expected): + # pylint: disable-next=protected-access + splash_mask = tpu_attention._to_splash_mask(mask, mask_shape=(8, 8)) + self.assertEqual(splash_mask, expected) + @parameterized.product( batch_size=[4], - seq_len=[32768], + seq_len=[1024, 32768], + mask_fn=["zero", "causal", "sliding"], sliding_window_size=[1024], num_heads=[4], per_head_dim=[256], mesh=[(4, 1)], mesh_axis_names=[("data", "model")], ) - def test_sliding_window_mask( + def test_forward( self, batch_size, seq_len, num_heads, per_head_dim, + mask_fn, sliding_window_size, mesh, mesh_axis_names, @@ -121,12 +145,22 @@ def fn(q, k, v): ) softmax_scale = q.shape[-1] ** -0.5 - mask = MaskFnAttentionBias( - sliding_window_causal_mask(sliding_window_size), shape=(seq_len, seq_len) - ) + if mask_fn == "zero": + mask = ZeroAttentionBias() + elif mask_fn == "causal": + mask = CausalAttentionBias(shape=(seq_len, seq_len)) + elif mask_fn.startswith("sliding"): + mask = MaskFnAttentionBias( + sliding_window_causal_mask(sliding_window_size), shape=(seq_len, seq_len) + ) attn = lambda q, k, v: tpu_attention.tpu_flash_attention( - q, k, v, mask=mask, softmax_scale=softmax_scale + q, + k, + v, + mask=mask, + softmax_scale=softmax_scale, + interpret=(jax.default_backend() == "cpu"), ) partitioned_mha = shard_map( @@ -168,6 +202,9 @@ def test_forward_and_backward( attention_bias_type, with_segment_ids, ): + if jax.default_backend() == "cpu": + # TODO(dhwang2): 
this has been broken for a while on CPU. + pytest.skip(reason="Backward path is broken on CPU") # pylint: disable=protected-access causal = mask in [causal_mask, jax_fn_mask] @@ -224,7 +261,14 @@ def fn(q, k, v, bias, ids): ) with record_legacy_call: return tpu_attention.tpu_flash_attention( - q, k, v, bias, ids, mask=mask, softmax_scale=softmax_scale + q, + k, + v, + bias, + ids, + mask=mask, + softmax_scale=softmax_scale, + interpret=(jax.default_backend() == "cpu"), ) # Compare outputs. @@ -246,3 +290,7 @@ def fn(q, k, v, bias, ids): legacy_flash_wrapper.assert_called() else: legacy_flash_wrapper.assert_not_called() + + +if __name__ == "__main__": + absltest.main() diff --git a/axlearn/common/flash_attention/utils.py b/axlearn/common/flash_attention/utils.py index 938a47080..3859f8b6e 100644 --- a/axlearn/common/flash_attention/utils.py +++ b/axlearn/common/flash_attention/utils.py @@ -241,6 +241,8 @@ def get_segment_ids(segment_ids: SegmentIdAttentionBias) -> Optional[Tensor]: ) elif backend == "tpu": + # TODO(dhwang2): splash attention supports GQA natively, so don't repeat it. + # https://github.com/jax-ml/jax/blob/7b9914d711593dca8725d46aa1dadb2194284519/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py#L934 key = _repeat_kv_heads(query.shape[2], key) value = _repeat_kv_heads(query.shape[2], value) # `mask` is supported.