
Optimize TPU Flash Attention (20x XLA compilation speed-up on 32k long context) #908

Merged: 1 commit into apple:main, Jan 7, 2025

Conversation

@ds-hwang (Contributor) commented Jan 7, 2025

Optimize TPU Flash Attention (20x XLA compilation speed-up on 32k long context) (#908)
Use a lazy splash attention mask instead of a `jnp` mask, which is O(T^2).

The host memory usage for the `jnp` mask is O(T^2). Currently, a `jnp` mask is
created and then wrapped with `NumpyMask` for use in Splash Attention,
resulting in O(T^2) temporary HBM usage (although XLA somehow avoids allocating it).
This PR proposes using `CausalMask` or `LocalMask` instead, allowing each Splash
Attention block to lazily create and use the required mask inside the Pallas
kernel.

The runtime performance of Splash Attention remains nearly the same. However,
the JIT compilation time for functions that use Splash Attention improves
significantly. It appears that allocating the O(T^2) mask in HBM and then
wrapping it with `NumpyMask` consumes a large share of the XLA compilation time.
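
For illustration, here is a minimal sketch of the difference (not the PR's actual diff), assuming the mask classes and kernel builder exposed by `jax.experimental.pallas.ops.tpu.splash_attention`; exact signatures may vary across JAX versions:

```python
import numpy as np
from jax.experimental.pallas.ops.tpu.splash_attention import (
    splash_attention_kernel as kernel_lib,
    splash_attention_mask as mask_lib,
)

seq_len, num_heads = 4096, 16

# As-is: materialize a dense O(T^2) boolean mask on the host, then wrap it.
dense_causal = np.tril(np.ones((seq_len, seq_len), dtype=np.bool_))
eager_mask = mask_lib.NumpyMask(dense_causal)

# This PR: describe the same causal structure lazily; each Pallas block
# generates only the mask slice it needs, so no O(T^2) array is ever built.
lazy_mask = mask_lib.CausalMask(shape=(seq_len, seq_len))

# Either per-head mask is replicated across heads and handed to the kernel builder.
mha_mask = mask_lib.MultiHeadMask(masks=[lazy_mask] * num_heads)
splash_kernel = kernel_lib.make_splash_mha(
    mask=mha_mask, head_shards=1, q_seq_shards=1
)
```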

* Benchmark: on TPU v5p (model_dim/heads/kv_heads/seq_len), tools/attention_benchmark.py

1) measure time with XLA compilation
NumpyMask (as-is)

----------------------------------------------------------------------------------------
Benchmark                                              Time             CPU   Iterations
----------------------------------------------------------------------------------------
FlashAttentionBenchmark/4096/16/2/4096        3609 ms         1645 ms            1
FlashAttentionBenchmark/4096/16/2/8192        7828 ms         5696 ms            1
FlashAttentionBenchmark/4096/16/2/32768      94368 ms        91442 ms            1

CausalMask (Proposed PR): significant XLA compilation speed-up.

----------------------------------------------------------------------------------------
Benchmark                                              Time             CPU   Iterations
----------------------------------------------------------------------------------------
FlashAttentionBenchmark/4096/16/2/4096        22.9 ms         1.60 ms          127
FlashAttentionBenchmark/4096/16/2/8192        40.8 ms         2.12 ms           88
FlashAttentionBenchmark/4096/16/2/32768       7641 ms         5458 ms            1

2) measure time without XLA compilation (pure JIT computation)

NumpyMask (as-is)

----------------------------------------------------------------------------------------
Benchmark                                              Time             CPU   Iterations
----------------------------------------------------------------------------------------
FlashAttentionBenchmark/4096/16/2/4096        9.97 ms        0.757 ms          918
FlashAttentionBenchmark/4096/16/2/8192        19.4 ms        0.934 ms          832
FlashAttentionBenchmark/4096/16/2/32768        116 ms         1.03 ms          100

CausalMask (Proposed PR): slight step time speed-up.

----------------------------------------------------------------------------------------
Benchmark                                              Time             CPU   Iterations
----------------------------------------------------------------------------------------
FlashAttentionBenchmark/4096/16/2/4096        9.82 ms        0.690 ms          964
FlashAttentionBenchmark/4096/16/2/8192        19.2 ms        0.822 ms          837
FlashAttentionBenchmark/4096/16/2/32768        116 ms        0.997 ms          100

In addition, tpu_attention_benchmark.py is changed to use an 8k sequence length
with a 1k sliding window, instead of 2k.
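
For reference, the sliding-window setting maps onto the lazy form roughly as follows (a hypothetical sketch, assuming `LocalMask` from the same splash-attention mask module and a `(left, right)` window convention):

```python
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask as mask_lib

seq_len, window = 8192, 1024

# Causal sliding window: each query may attend to at most `window` keys to its
# left and none to its right; fully-masked blocks are skipped by the kernel
# instead of being masked out element-wise.
local_mask = mask_lib.LocalMask(
    shape=(seq_len, seq_len), window_size=(window, 0), offset=0
)
```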

NumpyMask (as-is)

Benchmarking attention representative of 1.2b model layer on TPU v5.
ref_fwd:0.2288s, flash_fwd:0.0014s
ref_bwd:0.0218s, flash_bwd:0.0058s
Benchmarking attention representative of 12.6b model layer on TPU v5.
ref_fwd:0.5700s, flash_fwd:0.0032s
ref_bwd:0.0527s, flash_bwd:0.0149s
Benchmarking attention representative of 29.6b model layer on TPU v5.
ref_fwd:0.7958s, flash_fwd:0.0044s
ref_bwd:0.0730s, flash_bwd:0.0205s
Benchmarking attention representative of 65.2b model layer on TPU v5.
ref_fwd:1.0222s, flash_fwd:0.0055s
ref_bwd:0.0949s, flash_bwd:0.0262s
Benchmarking attention representative of 134b model layer on TPU v5.
ref_fwd:1.2486s, flash_fwd:0.0067s
ref_bwd:0.1161s, flash_bwd:0.0318s
Benchmarking attention representative of 261.7b model layer on TPU v5.
ref_fwd:1.5577s, flash_fwd:0.0072s
ref_bwd:0.1348s, flash_bwd:0.0375s

LocalMask (Proposed PR): slight fwd/bwd time speed-up.

Benchmarking attention representative of 1.2b model layer on TPU v5.
ref_fwd:0.2291s, flash_fwd:0.0014s
ref_bwd:0.0217s, flash_bwd:0.0058s
Benchmarking attention representative of 12.6b model layer on TPU v5.
ref_fwd:0.5699s, flash_fwd:0.0032s
ref_bwd:0.0524s, flash_bwd:0.0152s
Benchmarking attention representative of 29.6b model layer on TPU v5.
ref_fwd:0.7957s, flash_fwd:0.0043s
ref_bwd:0.0731s, flash_bwd:0.0204s
Benchmarking attention representative of 65.2b model layer on TPU v5.
ref_fwd:1.0225s, flash_fwd:0.0055s
ref_bwd:0.0948s, flash_bwd:0.0262s
Benchmarking attention representative of 134b model layer on TPU v5.
ref_fwd:1.2485s, flash_fwd:0.0067s
ref_bwd:0.1159s, flash_bwd:0.0313s
Benchmarking attention representative of 261.7b model layer on TPU v5.
ref_fwd:1.5577s, flash_fwd:0.0072s
ref_bwd:0.1349s, flash_bwd:0.0373s

@ds-hwang requested review from ruomingp, markblee and a team as code owners January 7, 2025 21:01
@ds-hwang requested a review from apghml January 7, 2025 21:02
@ds-hwang (Contributor, Author) commented Jan 7, 2025

@apghml could you review? From 916

@ds-hwang added this pull request to the merge queue Jan 7, 2025
@ds-hwang (Contributor, Author) commented Jan 7, 2025

Thank you for the review!

Merged via the queue into apple:main with commit c40b39a Jan 7, 2025
6 checks passed
@ds-hwang deleted the flsh_op branch January 7, 2025 22:23
@hanzhi713 (Member) commented:

@ds-hwang nit: the PR description is wrong (perhaps copy-pasted from a different PR?).

@ruomingp (Contributor) commented Jan 8, 2025

Wait, how was this PR approved and merged? I'm not sure the discussion on the internal PR has been resolved.

@apghml (Contributor) commented Jan 8, 2025

The PR description is wrong. This is not related to GANs. The title is correct.

@apghml (Contributor) commented Jan 8, 2025

Oh, it looks like you hadn't approved the actual internal PR this comes from either, though. I mistakenly thought you had.

@ds-hwang (Contributor, Author) commented Jan 8, 2025

Sorry for the confusion. The PR that was merged is the Flash Attention PR, not the GAN loss PR. I copied and pasted the content from the internal PR by mistake. Fortunately, the PR title and commit description that were merged are correct, and I’ve now updated the body text.
