If you’ve been following the world of large language models, you know that one of the biggest goals is expanding the context window. We want models that can read entire books, analyze lengthy codebases, or process high-resolution images. The main obstacle? The attention mechanism at the heart of the Transformer architecture. Its computational and memory costs grow quadratically with the sequence length, making long contexts prohibitively expensive.
A breakthrough paper in 2022, FlashAttention, tackled this problem head-on. By cleverly reordering the attention computation to be more aware of the GPU’s memory hierarchy, it achieved linear memory usage and a 2–4× speedup over standard implementations—all without any approximation. It was a game-changer and has been widely adopted.
But the story doesn’t end there. While FlashAttention was fast, it still wasn’t reaching the full potential of modern hardware. Its performance, measured in floating-point operations per second (FLOPs/s), hovered around 25–40% of the theoretical maximum for a GPU like the NVIDIA A100. For comparison, highly optimized matrix multiplication (GEMM) routines can hit 80–90% of that maximum. There was still a significant performance gap to close.
Enter FlashAttention-2. This follow-up work dissects remaining inefficiencies in the original algorithm and introduces a set of optimizations. By rethinking how work is partitioned across and within the GPU’s compute units, FlashAttention-2 delivers another ≈2× speedup, pushing hardware utilization to an impressive 50–73% of the theoretical maximum.
In this post, we’ll dive deep into the paper “FLASHATTENTION-2: Faster Attention with Better Parallelism and Work Partitioning”. We’ll explore:
- A quick refresher on why standard attention is slow and how the original FlashAttention works.
- The three key innovations in FlashAttention-2: reducing slow operations, smarter parallelization, and more efficient work distribution.
- The impressive results showing FlashAttention-2 getting tantalizingly close to the speed of pure matrix multiplication.
Background: The GPU Bottleneck and FlashAttention-1
To understand FlashAttention-2, we first need to grasp why the standard attention implementation is inefficient, and how its predecessor, FlashAttention, solved the initial part of the problem.
GPU Hardware 101
A GPU isn’t a single compute engine—it’s a complex system with a memory hierarchy. For our purposes, the two most important levels are:
- High Bandwidth Memory (HBM): The large VRAM (e.g., 40–80 GB on an A100) that holds your model and data. It’s “high bandwidth” compared to regular RAM, but relative to on-chip memory it’s slow.
- SRAM (Static RAM), aka Shared Memory: Extremely fast, on-chip memory. It’s tiny (kilobytes per compute unit) but has much higher bandwidth than HBM.
The golden rule of GPU programming: minimize reads and writes to HBM. The most efficient algorithms load data from HBM into fast SRAM once, perform as many computations as possible, then write only final results back to HBM. Each unnecessary round trip to HBM creates a major performance bottleneck.
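To make that concrete, here is a rough back-of-envelope estimate in Python. The bandwidth figure is an assumed ballpark for an A100-class GPU (HBM on the order of 1.5 TB/s), not an exact spec, but it shows why every extra round trip through HBM hurts.

```python
# Rough cost of one HBM round trip for an N x N attention score matrix.
# The bandwidth value is an assumed ballpark figure, for illustration only.
N = 8192                      # sequence length
bytes_per_elem = 2            # fp16 / bf16
hbm_bandwidth = 1.5e12        # ~1.5 TB/s effective HBM bandwidth (assumption)

matrix_bytes = N * N * bytes_per_elem            # one N x N matrix
round_trip_s = 2 * matrix_bytes / hbm_bandwidth  # write it out, then read it back

print(f"matrix size : {matrix_bytes / 1e6:.0f} MB")   # ~134 MB per head
print(f"write + read: {round_trip_s * 1e3:.2f} ms")   # and that is per head, per layer
```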
The Problem with Standard Attention
The standard self-attention mechanism is defined by the following equations:
\[
S = QK^\top \in \mathbb{R}^{N \times N}, \qquad P = \mathrm{softmax}(S) \in \mathbb{R}^{N \times N}, \qquad O = PV \in \mathbb{R}^{N \times d},
\]
where the softmax is applied row-wise: the scores S are turned into probabilities P, which then weight V to produce the output O.
Here, Q, K, and V are matrices of shape \( N \times d \), where \( N \) is the sequence length and \( d \) is the head dimension.
The naive implementation proceeds in three steps:
- Compute the \( N \times N \) score matrix \( S = QK^\top \). Write S to HBM.
- Read S from HBM, apply the row-wise softmax to get P. Write P to HBM.
- Read P from HBM, multiply by V to get the output O. Write O to HBM.
For a sequence length of just 8k, S and P are 8k × 8k matrices of about 67 million elements each. Storing and moving these massive matrices back and forth from HBM is extremely slow and uses vast amounts of memory. This is the quadratic bottleneck.
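For reference, here is a minimal NumPy sketch of that naive computation (single head, no batching, and no 1/√d scaling, to keep it short). In a real kernel the intermediates S and P would live in HBM, and moving them is exactly the traffic FlashAttention eliminates.

```python
import numpy as np

def naive_attention(Q, K, V):
    """Standard attention: materializes the full N x N matrices S and P."""
    S = Q @ K.T                                # (N, N) score matrix
    P = np.exp(S - S.max(axis=-1, keepdims=True))
    P /= P.sum(axis=-1, keepdims=True)         # row-wise softmax -> (N, N)
    return P @ V                               # (N, d) output

N, d = 1024, 64
rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))
print(naive_attention(Q, K, V).shape)          # (1024, 64)
```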
FlashAttention to the Rescue: Tiling and Online Softmax
The key insight of FlashAttention was to avoid ever writing the full S and P matrices to HBM, using a technique called tiling.
The algorithm breaks Q, K, and V into smaller blocks, loading one block of Q and one block of K/V into SRAM at a time.
Figure 1: The original FlashAttention algorithm loads blocks of K and V into fast SRAM, computes attention against a block of Q, and updates the output. This avoids materializing the full N × N attention matrix in slow HBM.
Inside SRAM, it computes attention for just that block. But softmax needs the sum of exponentials over the entire row to normalize correctly. You can’t just process one block independently.
This is where online softmax comes in. The algorithm maintains a running maximum and a running normalization factor for each row. For each new block, it calculates the local softmax, rescales earlier results using the updated statistics, and combines them. By the end of all blocks, you have the exact same result as standard attention—without the expensive HBM round trips.
For the backward pass, instead of storing the huge P matrix, FlashAttention recomputes attention blocks on-the-fly. The memory complexity drops from \(O(N^2)\) to \(O(N)\).
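Here is a simplified NumPy sketch of that forward pass (single head, written for clarity, not speed; the real implementation is a fused CUDA kernel). It mirrors the original FlashAttention update: per row it tracks a running max m and running normalizer l, and keeps the output O fully normalized after every block, which is the repeated rescaling FlashAttention-2 later removes.

```python
import numpy as np

def flash_attention_v1_forward(Q, K, V, block_size=128):
    """Block-wise attention with an online softmax (simplified, single head)."""
    N, d = Q.shape
    O = np.zeros((N, d))
    l = np.zeros((N, 1))                  # running sum of exp(S - m) per row
    m = np.full((N, 1), -np.inf)          # running row max

    for start in range(0, N, block_size):
        Kb = K[start:start + block_size]  # load one block of K and V "into SRAM"
        Vb = V[start:start + block_size]
        S = Q @ Kb.T                      # scores against this block only

        m_new = np.maximum(m, S.max(axis=-1, keepdims=True))
        P = np.exp(S - m_new)             # local, unnormalized softmax numerator
        l_new = np.exp(m - m_new) * l + P.sum(axis=-1, keepdims=True)

        # Keep O exactly normalized after every block: rescale the old output
        # and divide by the new normalizer (this is the per-block rescaling).
        O = (np.exp(m - m_new) * l * O + P @ Vb) / l_new
        m, l = m_new, l_new

    return O

# Matches the direct softmax(Q Kᵀ) V result up to floating-point error.
N, d = 1024, 64
rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))
S = Q @ K.T
ref = np.exp(S - S.max(axis=-1, keepdims=True))
ref = (ref / ref.sum(axis=-1, keepdims=True)) @ V
print(np.allclose(flash_attention_v1_forward(Q, K, V), ref))  # True
```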
FlashAttention-2: Pushing Hardware Efficiency Further
FlashAttention was a huge leap forward, but still left some GPU performance on the table. The authors identified three optimizations for FlashAttention-2.
1. Reducing Non-Matmul FLOPs
Modern GPUs like the A100 have Tensor Cores optimized for matrix multiplication. On an A100:
- FP16/BF16 matmuls: 312 TFLOPs/s (peak)
- FP32 non-matmul ops: 19.5 TFLOPs/s (peak)
A non-matmul FLOP can therefore be up to 16× more expensive than a matmul FLOP (312 / 19.5 = 16).
The original FlashAttention performed some non-matmul scaling inside its inner loop—for example, repeatedly rescaling O after each block.
FlashAttention-2 changes the online softmax procedure:
Figure: FlashAttention-2 avoids repeated re-scaling of the output inside the loop, only scaling once at the end.
Instead of scaling O in every iteration, it maintains an unscaled version of the output and normalization statistics, performing a single scaling at the end. This reduces the number of expensive non-matmul operations, letting the GPU spend more time doing high-throughput matmuls.
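A minimal NumPy sketch of the FlashAttention-2 variant (same simplified setup and my own variable names, not the paper's notation). The only change from the loop above is that O stays unnormalized inside the loop; the division by the accumulated row sums happens once, after the last block.

```python
import numpy as np

def flash_attention_v2_forward(Q, K, V, block_size=128):
    """Block-wise attention where the output is normalized only at the end."""
    N, d = Q.shape
    O = np.zeros((N, d))                  # unnormalized output accumulator
    l = np.zeros((N, 1))
    m = np.full((N, 1), -np.inf)

    for start in range(0, N, block_size):
        Kb, Vb = K[start:start + block_size], V[start:start + block_size]
        S = Q @ Kb.T
        m_new = np.maximum(m, S.max(axis=-1, keepdims=True))
        P = np.exp(S - m_new)
        scale = np.exp(m - m_new)         # only the max-shift correction remains
        l = scale * l + P.sum(axis=-1, keepdims=True)
        O = scale * O + P @ Vb            # no per-block division by l
        m = m_new

    return O / l                          # single normalization at the end
```

The per-block divisions removed here are part of the slow non-matmul work described above; the two matmuls per block are unchanged.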
2. Better Parallelism Across the GPU
An NVIDIA A100 has 108 Streaming Multiprocessors (SMs). These execute thread blocks, with thousands of threads working in parallel. To maximize speed, all SMs need active work.
Original FlashAttention parallelized computation over batch size × number of heads, assigning one thread block to each (batch element, head) pair. This works well if batch × heads ≥ SM count.
With long sequences, batch size is often small (to fit in memory). If batch × heads < SMs, many SMs sit idle.
FlashAttention-2 adds parallelization along the sequence length dimension \(N\):
Figure 2: Forward pass (left): each thread block handles a block of rows (query slices). Backward pass (right): thread blocks handle blocks of columns and use atomic adds to accumulate gradients for shared rows.
- Forward Pass: Blocks of rows in Q are independent, so thread blocks can process them in parallel. This boosts occupancy even for small batch sizes.
- Backward Pass: More complex due to gradient dependencies for Q, but restructured to work on column blocks (K, V slices) in parallel. Dependencies are handled with atomic adds.
This ensures full GPU utilization for long-sequence regimes.
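A quick way to see the occupancy argument in numbers (the batch, head, and block sizes below are made-up example values; only the 108-SM count is an actual A100 figure):

```python
# How many independent thread blocks does each scheme have to schedule?
num_sms = 108            # streaming multiprocessors on an A100
batch, heads = 2, 16     # small batch, as is typical for long-context training
seq_len, q_block = 8192, 128

blocks_v1 = batch * heads                          # one block per (batch, head)
blocks_v2 = batch * heads * (seq_len // q_block)   # also split along the sequence

print(f"FlashAttention  : {blocks_v1} blocks for {num_sms} SMs")    # 32   -> most SMs idle
print(f"FlashAttention-2: {blocks_v2} blocks for {num_sms} SMs")    # 2048 -> every SM busy
```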
3. More Efficient Work Partitioning Within a Thread Block
Inside a thread block, threads are grouped into warps (usually 32 threads). Warps can exchange data through shared memory, but every shared-memory write, synchronization, and read costs time compared to keeping values in registers.
Original FlashAttention used a “split-K” scheme:
- Q shared by all warps
- K, V split across warps
- Results combined via shared memory writes/reads
Writing these partial results to shared memory, synchronizing, and reading them back to combine them created a communication bottleneck.
FlashAttention-2 flips the scheme:
Figure 3: FlashAttention (a) splits K and V, requiring inter-warp communication. FlashAttention-2 (b) splits Q, with all warps sharing K and V—minimal communication.
Now:
- K, V shared by all warps
- Q split across warps
- Each warp independently computes its slice of \( QK^\top \) and multiplies by shared V
- Almost no inter-warp communication until final write-out
This reduces shared-memory traffic and synchronization overhead.
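The effect of the split is easier to see in a toy NumPy "simulation" of four warps (this only mimics the data partitioning; real warps are hardware scheduling units, not Python loops, and the slice sizes here are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(0)
N, d, n_warps = 256, 64, 4
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

# "Split-K" (FlashAttention style): each warp sees all of Q but only a slice of
# K/V, so it produces partial, unnormalized sums for *every* output row. The
# partials must be combined afterwards -- that cross-warp reduction is the
# shared-memory traffic and synchronization described above.
partial_num = np.zeros((n_warps, N, d))
partial_den = np.zeros((n_warps, N, 1))
for w in range(n_warps):
    cols = slice(w * N // n_warps, (w + 1) * N // n_warps)
    E = np.exp(Q @ K[cols].T)                    # unnormalized scores vs. this slice
    partial_num[w] = E @ V[cols]
    partial_den[w] = E.sum(axis=-1, keepdims=True)
O_split_k = partial_num.sum(0) / partial_den.sum(0)   # reduction across "warps"

# "Split-Q" (FlashAttention-2 style): each warp owns a slice of Q and sees all
# of K/V, so it finishes its own output rows with no cross-warp reduction.
O_split_q = np.empty((N, d))
for w in range(n_warps):
    rows = slice(w * N // n_warps, (w + 1) * N // n_warps)
    O_split_q[rows] = softmax(Q[rows] @ K.T) @ V

print(np.allclose(O_split_k, O_split_q))  # True: same math, different data movement
```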
The Results: Closing the Gap
What’s the total effect of these optimizations? Benchmarks show clear wins.
Figure 4: Combined forward + backward pass speed on A100. FlashAttention-2 (purple) is consistently ~2× faster than FlashAttention (orange), and much faster than PyTorch (blue) and xformers (green).
Compared to PyTorch’s standard attention, the speedup can reach 10× for long sequences. Breaking the forward and backward passes out separately shows the raw throughput:
Figure 5: Forward pass: FlashAttention-2 reaches 230 TFLOPs/s—73% of the A100’s theoretical max.
Figure 6: Backward pass: FlashAttention-2 reaches 63% of theoretical max.
Achieving >70% theoretical peak for attention—approaching GEMM efficiency—is remarkable.
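As a sanity check against the 312 TFLOPs/s FP16/BF16 peak quoted earlier, the forward-pass figure works out to

\[
\frac{230\ \text{TFLOPs/s}}{312\ \text{TFLOPs/s}} \approx 0.74,
\]

which is where the roughly 73% utilization number comes from.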
End-to-End Training Gains
The ultimate test: training real models.
Table 1: End-to-end training throughput for GPT models on 8×A100. FlashAttention-2 reaches up to 225 TFLOPs/s per GPU (72% utilization), giving 1.3× speedup over FlashAttention and up to 2.8× over no FlashAttention.
For a GPT-1.3B with 8k context:
- FlashAttention-2: 220 TFLOPs/s
- FlashAttention: 170 TFLOPs/s
- Baseline: 72 TFLOPs/s
This means training a 16k context model now costs about the same as an 8k context model before.
Conclusion and Future Directions
FlashAttention-2 is a masterclass in hardware-aware algorithm design. By identifying bottlenecks from high-level math down to warp-level execution, the authors have pushed exact attention close to the hardware's limits.
The implications are big:
- Long-context training and inference are more affordable and feasible.
- Existing pipelines get notable speed boosts.
- Models can process much larger amounts of information.
Looking ahead, the authors plan to:
- Optimize for NVIDIA H100 GPUs (using new Tensor Memory Accelerator and FP8 Tensor Cores)
- Support new data types like FP8
- Combine with high-level algorithmic optimizations (local, block-sparse attention) for even longer contexts
With FlashAttention-2, the dream of models with virtually unlimited context windows is closer than ever.