Optimize TPU Flash Attention (20x XLA compilation speed-up on 32k long context) #908
Conversation
Use the splash attention lazy mask instead of a `jnp` mask, which is O(T^2).

The host memory usage for the `jnp` mask is O(T^2). Currently, a `jnp` mask is created and then wrapped with `NumpyMask` for use in Splash Attention, resulting in O(T^2) temporary HBM usage (although XLA appears to avoid actually allocating it). This PR proposes using `CausalMask` or `LocalMask` instead, allowing each Splash Attention block to lazily create the mask slice it needs inside the Pallas kernel (see the sketch after the benchmark results below).

The runtime performance of Splash Attention remains nearly the same. However, the JIT compilation time of functions using Splash Attention improves significantly. It appears that allocating the O(T^2) mask in HBM and then wrapping it with `NumpyMask` consumes a large amount of XLA compilation time.

Benchmark: TPU v5p, `tools/attention_benchmark.py`; parameters are (model_dim/num_heads/num_kv_heads/seq_len).

1) Time including XLA compilation

NumpyMask (as is):

```
----------------------------------------------------------------------------------------
Benchmark                                      Time             CPU   Iterations
----------------------------------------------------------------------------------------
FlashAttentionBenchmark/4096/16/2/4096      3609 ms         1645 ms            1
FlashAttentionBenchmark/4096/16/2/8192      7828 ms         5696 ms            1
FlashAttentionBenchmark/4096/16/2/32768    94368 ms        91442 ms            1
```

CausalMask (this PR): significant XLA compilation speed-up.

```
----------------------------------------------------------------------------------------
Benchmark                                      Time             CPU   Iterations
----------------------------------------------------------------------------------------
FlashAttentionBenchmark/4096/16/2/4096      22.9 ms         1.60 ms          127
FlashAttentionBenchmark/4096/16/2/8192      40.8 ms         2.12 ms           88
FlashAttentionBenchmark/4096/16/2/32768     7641 ms         5458 ms            1
```

2) Time excluding XLA compilation (pure jit computation)

NumpyMask (as is):

```
----------------------------------------------------------------------------------------
Benchmark                                      Time             CPU   Iterations
----------------------------------------------------------------------------------------
FlashAttentionBenchmark/4096/16/2/4096      9.97 ms        0.757 ms          918
FlashAttentionBenchmark/4096/16/2/8192      19.4 ms        0.934 ms          832
FlashAttentionBenchmark/4096/16/2/32768      116 ms         1.03 ms          100
```

CausalMask (this PR): slight step time speed-up.

```
----------------------------------------------------------------------------------------
Benchmark                                      Time             CPU   Iterations
----------------------------------------------------------------------------------------
FlashAttentionBenchmark/4096/16/2/4096      9.82 ms        0.690 ms          964
FlashAttentionBenchmark/4096/16/2/8192      19.2 ms        0.822 ms          837
FlashAttentionBenchmark/4096/16/2/32768      116 ms        0.997 ms          100
```

In addition, `tpu_attention_benchmark.py` is changed to use an 8k sequence length (instead of 2k) with a sliding window of 1k.

NumpyMask (as is):

```
Benchmarking attention representative of 1.2b model layer on TPU v5.
ref_fwd:0.2288s, flash_fwd:0.0014s
ref_bwd:0.0218s, flash_bwd:0.0058s
Benchmarking attention representative of 12.6b model layer on TPU v5.
ref_fwd:0.5700s, flash_fwd:0.0032s
ref_bwd:0.0527s, flash_bwd:0.0149s
Benchmarking attention representative of 29.6b model layer on TPU v5.
ref_fwd:0.7958s, flash_fwd:0.0044s
ref_bwd:0.0730s, flash_bwd:0.0205s
Benchmarking attention representative of 65.2b model layer on TPU v5.
ref_fwd:1.0222s, flash_fwd:0.0055s
ref_bwd:0.0949s, flash_bwd:0.0262s
Benchmarking attention representative of 134b model layer on TPU v5.
ref_fwd:1.2486s, flash_fwd:0.0067s
ref_bwd:0.1161s, flash_bwd:0.0318s
Benchmarking attention representative of 261.7b model layer on TPU v5.
ref_fwd:1.5577s, flash_fwd:0.0072s
ref_bwd:0.1348s, flash_bwd:0.0375s
```

LocalMask (this PR): slight fwd/bwd time speed-up.

```
Benchmarking attention representative of 1.2b model layer on TPU v5.
ref_fwd:0.2291s, flash_fwd:0.0014s
ref_bwd:0.0217s, flash_bwd:0.0058s
Benchmarking attention representative of 12.6b model layer on TPU v5.
ref_fwd:0.5699s, flash_fwd:0.0032s
ref_bwd:0.0524s, flash_bwd:0.0152s
Benchmarking attention representative of 29.6b model layer on TPU v5.
ref_fwd:0.7957s, flash_fwd:0.0043s
ref_bwd:0.0731s, flash_bwd:0.0204s
Benchmarking attention representative of 65.2b model layer on TPU v5.
ref_fwd:1.0225s, flash_fwd:0.0055s
ref_bwd:0.0948s, flash_bwd:0.0262s
Benchmarking attention representative of 134b model layer on TPU v5.
ref_fwd:1.2485s, flash_fwd:0.0067s
ref_bwd:0.1159s, flash_bwd:0.0313s
Benchmarking attention representative of 261.7b model layer on TPU v5.
ref_fwd:1.5577s, flash_fwd:0.0072s
ref_bwd:0.1349s, flash_bwd:0.0373s
```
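For reference, here is a minimal sketch of the difference between the two approaches, using the mask classes and kernel builder from JAX's `jax.experimental.pallas.ops.tpu.splash_attention`. This is not the AXLearn wrapper code itself; the shapes, shard counts, and the 1k window size are illustrative assumptions.

```python
# Sketch: dense NumpyMask (old) vs. lazy CausalMask/LocalMask (this PR).
# Shapes and shard counts are illustrative, not AXLearn's actual configuration.
import numpy as np
import jax.numpy as jnp
from jax.experimental.pallas.ops.tpu.splash_attention import (
    splash_attention_kernel,
    splash_attention_mask,
)

num_heads, seq_len, head_dim = 4, 4096, 128

# Old approach: materialize an O(T^2) boolean mask on the host, then wrap it.
dense = np.tril(np.ones((seq_len, seq_len), dtype=np.bool_))
numpy_mask = splash_attention_mask.NumpyMask(dense)

# New approach: describe the mask lazily; each kernel block derives only the
# mask slice it needs inside the Pallas kernel, so no O(T^2) buffer is built.
causal_mask = splash_attention_mask.CausalMask(shape=(seq_len, seq_len))
# Sliding-window variant (e.g. a 1k window as in tpu_attention_benchmark.py):
local_mask = splash_attention_mask.LocalMask(
    shape=(seq_len, seq_len), window_size=(1024, 1024), offset=0
)

mask = splash_attention_mask.MultiHeadMask(masks=[causal_mask] * num_heads)
kernel = splash_attention_kernel.make_splash_mha(
    mask=mask, head_shards=1, q_seq_shards=1
)

# Per-device call: q/k/v are [num_heads, seq_len, head_dim]; in real use this
# runs under jax.jit and is vmapped over the batch axis.
q = jnp.zeros((num_heads, seq_len, head_dim), dtype=jnp.float32)
k = jnp.zeros((num_heads, seq_len, head_dim), dtype=jnp.float32)
v = jnp.zeros((num_heads, seq_len, head_dim), dtype=jnp.float32)
out = kernel(q, k, v)
```

The key point is that `CausalMask` and `LocalMask` carry only shape, offset, and window metadata, so nothing O(T^2) is handed to XLA at trace time.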
@apghml could you review? From 916
Thank you for the review!
@ds-hwang nit: the PR description is wrong (perhaps copy-pasted from a different PR?).
Wait, how is this PR approved and merged? I'm not sure that the discussion has been resolved on the internal PR.
The PR description is wrong. This is not related to GANs. The title is correct.
Oh, it looks like you hadn't approved the actual PR this comes from internally either though. I mistakenly thought you had.
Sorry for the confusion. The PR that was merged is the Flash Attention PR, not the GAN loss PR. I copied and pasted the content from the internal PR by mistake. Fortunately, the PR title and commit description that were merged are correct, and I've now updated the body text.