Transformers have revolutionized machine learning, but they have a well-known Achilles’ heel: the self-attention mechanism. While incredibly powerful, its computational and memory costs grow quadratically with the sequence length. This \(O(N^2)\) complexity has been a major barrier, making it prohibitively expensive to train models on long documents, high-resolution images, or lengthy audio clips.

For years, researchers have tried to tame this quadratic beast with approximate attention methods. These techniques trade a bit of model accuracy for better efficiency, often reducing complexity to linear or near-linear time. But here’s the catch: many of these theoretically faster methods don’t actually speed up training in practice. They reduce the number of calculations (FLOPs), but often overlook the real bottleneck on modern hardware like GPUs: memory access.

A groundbreaking paper from Stanford, “FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness”, argues that we’ve been looking in the wrong place. The authors propose that the key isn’t just to reduce computations, but to become IO-aware—to intelligently manage how data is moved between the different levels of GPU memory.

By doing so, they created FlashAttention, an algorithm that computes exact attention but is dramatically faster and more memory-efficient than standard implementations. It doesn’t approximate; it re-engineers the process from the ground up with the hardware in mind. This single innovation leads to massive end-to-end training speedups and, more excitingly, unlocks the ability for Transformers to handle sequence lengths that were previously unimaginable, opening up entirely new capabilities.

In this deep dive, we’ll unpack the magic behind FlashAttention. We’ll explore the hardware limitations that standard attention runs into and see how FlashAttention’s clever use of tiling and recomputation sidesteps them completely.

The Real Bottleneck: A Tale of Two Memories

To understand why FlashAttention is so effective, we first need to understand the hardware it runs on. Modern GPUs have a memory hierarchy, with different levels offering trade-offs between size and speed.

GPU memory hierarchy pyramid (left), schematic of FlashAttention data movement (center), and GPT-2 attention runtime comparison (right).

Figure 1: Left: Memory hierarchy showing fast but small SRAM and slower, larger HBM and DRAM. Center: FlashAttention computation loops through \(K, V\) and \(Q\) blocks on SRAM without materializing the full \(N \times N\) matrix in HBM. Right: Comparison of PyTorch attention and FlashAttention runtime for GPT-2—FlashAttention’s fused kernel achieves a 7.6× speedup.

The two most important levels for our discussion are:

  1. High Bandwidth Memory (HBM): Main GPU memory, large (e.g., 40–80 GB on an NVIDIA A100) but relatively slow compared to the GPU’s compute throughput, with bandwidth around 1.5–2.0 TB/s. Moving data in and out of HBM is a major performance bottleneck.
  2. SRAM (On-chip Static RAM): Much smaller but extremely fast memory on the GPU chip itself (192 KB per streaming multiprocessor on an A100), with bandwidth roughly an order of magnitude higher than HBM.

Operations can be classified as compute-bound (limited by calculations) or memory-bound (limited by the time to move data between slow memory and compute units). As GPU compute has scaled faster than memory bandwidth, many Transformer operations—element-wise ops and reductions like softmax—are now memory-bound.
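To make that concrete, here is a back-of-the-envelope check in Python. The peak FLOP rate and HBM bandwidth are assumed round numbers for an A100-class GPU, and the per-element FLOP count for softmax is a rough estimate; the point is only the orders of magnitude.

```python
# Rough arithmetic-intensity check: an operation is memory-bound when its
# FLOPs per byte of HBM traffic fall below the GPU's compute/bandwidth ratio.
peak_flops = 312e12      # assumed ~312 TFLOP/s (A100-class tensor cores)
hbm_bandwidth = 1.5e12   # assumed ~1.5 TB/s of HBM bandwidth

break_even = peak_flops / hbm_bandwidth   # ~200 FLOPs per byte of HBM traffic

# Row-wise softmax over an N x N score matrix in FP32:
N = 4096
flops = 5 * N * N                # rough count: max, subtract, exp, sum, divide
bytes_moved = 2 * N * N * 4      # read S from HBM, write P back, 4 bytes each
intensity = flops / bytes_moved  # ~0.6 FLOPs/byte, far below break-even

print(f"softmax intensity ≈ {intensity:.2f} FLOPs/byte, "
      f"break-even ≈ {break_even:.0f} FLOPs/byte -> memory-bound")
```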

How Standard Attention Gets Stuck

A standard attention computation is:

\[ S = QK^{\mathsf{T}}, \quad P = \mathrm{softmax}(S), \quad O = PV \]

The trouble comes from the intermediate matrices \(S\) and \(P\), each of size \(N \times N\) for sequence length \(N\). The typical implementation does:

  1. Compute \(S\) in full and write to HBM.
  2. Read \(S\) back from HBM to compute softmax.
  3. Write \(P\) to HBM.
  4. Read \(P\) and \(V\) back from HBM to compute \(O\).

For \(N = 8192\), one such matrix of 32-bit floats is ~256 MB. Multiply that by multiple heads and batch size, and your HBM traffic explodes. Plus, during backpropagation, the \(P\) matrix must be stored, consuming quadratic memory.
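For reference, here is a minimal NumPy sketch of that standard implementation. It is not the GPU code, just the same dataflow spelled out: \(S\) and \(P\) are full \(N \times N\) arrays, which on a GPU would live in HBM (the \(1/\sqrt{d}\) scaling is omitted, matching the paper’s notation; shapes and values are illustrative).

```python
import numpy as np

def standard_attention(Q, K, V):
    """The four steps above, with every intermediate materialized in full."""
    S = Q @ K.T                                  # step 1: N x N score matrix
    P = np.exp(S - S.max(axis=-1, keepdims=True))
    P /= P.sum(axis=-1, keepdims=True)           # steps 2-3: N x N softmax matrix
    return P @ V                                 # step 4: N x d output

N, d = 8192, 64
Q, K, V = (np.random.randn(N, d).astype(np.float32) for _ in range(3))
O = standard_attention(Q, K, V)
print(f"each N x N float32 intermediate: {N * N * 4 / 2**20:.0f} MiB")
```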

FlashAttention’s mission: compute exact attention without ever materializing the full \(N \times N\) matrices in HBM.

The Core Method: Tiling and Recomputation

FlashAttention’s performance breakthroughs come from two key techniques: tiling and recomputation. The algorithm fuses all attention operations—matrix multiplies, masking, softmax, and dropout—into a single GPU kernel. This kernel loads input from HBM, performs all steps in fast SRAM, and writes only the final output back to HBM.

Tiling: Computing Softmax in Blocks

Softmax poses a hurdle: to normalize any row of the score matrix, you seemingly need every value in that row at once. FlashAttention gets around this with a numerically stable decomposition of softmax:

\[ m(x) = \max_i x_i, \quad \ell(x) = \sum_i e^{x_i - m(x)}, \quad \text{softmax}(x)_i = \frac{e^{x_i - m(x)}}{\ell(x)} \]

If you split \(x\) into blocks \(x^{(1)}, x^{(2)}\), you can merge block statistics:

\[ m(x) = \max(m(x^{(1)}), m(x^{(2)})), \quad \ell(x) = e^{m(x^{(1)}) - m(x)}\ell(x^{(1)}) + e^{m(x^{(2)}) - m(x)}\ell(x^{(2)}) \]

That means we can process a row in blocks—never needing it all in SRAM at once.
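Here is a tiny NumPy demonstration of that merge rule (the function names are just for this sketch). It checks that statistics computed per block and then merged agree with statistics computed over the whole vector at once.

```python
import numpy as np

def softmax_stats(x):
    """Per-block statistics: max m(x) and scaled sum of exponentials l(x)."""
    m = x.max()
    return m, np.exp(x - m).sum()

def merge(a, b):
    """Combine the (m, l) statistics of two blocks into those of their concatenation."""
    (m1, l1), (m2, l2) = a, b
    m = max(m1, m2)
    return m, np.exp(m1 - m) * l1 + np.exp(m2 - m) * l2

x = np.random.randn(1000)
m, ell = merge(softmax_stats(x[:400]), softmax_stats(x[400:]))

m_full, ell_full = softmax_stats(x)        # statistics of the whole row at once
assert np.allclose([m, ell], [m_full, ell_full])
softmax_x = np.exp(x - m) / ell            # the full softmax, assembled blockwise
```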

FlashAttention works like so:

  • Outer loop: Iterate over blocks of \(K\) and \(V\), loading one block into SRAM.
  • Inner loop: For the current \(K_j, V_j\) block, iterate over all \(Q_i\) blocks, loading each in turn into SRAM.
  • On-chip compute: With \(Q_i, K_j, V_j\) in SRAM:
    • Compute \(S_{ij} = Q_iK_j^{\mathsf{T}}\).
    • Compute block softmax and update running \((m_i, \ell_i)\).
    • Update the output block \(O_i\) incrementally.
  • Write-back: Save updated \(O_i\) to HBM.

By the end, \(O\) is assembled without creating full \(S\) or \(P\) in HBM.
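For concreteness, here is a pure NumPy reference of that loop structure: a readable sketch of the same update rule, not the fused CUDA kernel. Block sizes, the helper name, and the small correctness check at the end are illustrative.

```python
import numpy as np

def flash_attention_forward(Q, K, V, block_q=128, block_kv=128):
    """Tiled exact attention: never forms the full N x N score matrix."""
    N, d = Q.shape
    O = np.zeros((N, d), dtype=np.float32)
    ell = np.zeros(N, dtype=np.float32)            # running softmax denominators
    m = np.full(N, -np.inf, dtype=np.float32)      # running row maxima

    for j in range(0, N, block_kv):                # outer loop over K/V blocks
        Kj, Vj = K[j:j + block_kv], V[j:j + block_kv]
        for i in range(0, N, block_q):             # inner loop over Q blocks
            Qi = Q[i:i + block_q]
            Sij = Qi @ Kj.T                        # only a small block of scores

            m_blk = Sij.max(axis=1)
            P_blk = np.exp(Sij - m_blk[:, None])
            ell_blk = P_blk.sum(axis=1)

            # Merge running (m, ell) with this block's statistics.
            m_new = np.maximum(m[i:i + block_q], m_blk)
            alpha = np.exp(m[i:i + block_q] - m_new)   # rescales old contribution
            beta = np.exp(m_blk - m_new)               # rescales new contribution
            ell_new = alpha * ell[i:i + block_q] + beta * ell_blk

            # Incrementally update the output block, keeping it normalized.
            O[i:i + block_q] = (
                (alpha * ell[i:i + block_q])[:, None] * O[i:i + block_q]
                + beta[:, None] * (P_blk @ Vj)
            ) / ell_new[:, None]

            m[i:i + block_q], ell[i:i + block_q] = m_new, ell_new
    return O

# Quick check against a direct computation on a small problem:
N, d = 512, 64
Q, K, V = (np.random.randn(N, d).astype(np.float32) for _ in range(3))
S = Q @ K.T
P = np.exp(S - S.max(axis=1, keepdims=True))
P /= P.sum(axis=1, keepdims=True)
assert np.allclose(flash_attention_forward(Q, K, V), P @ V, atol=1e-4)
```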

Recomputation: Faster Backward Without Huge Storage

During training, backpropagation typically needs \(P\). FlashAttention avoids storing it by recomputing blocks during the backward pass.

With \(Q, K, V, O\) and saved \((m, \ell)\) from forward, the kernel recomputes each attention block in SRAM. This adds extra FLOPs but eliminates massive HBM reads. Because compute is cheap and memory access is expensive, the recomputation yields faster overall runtime.
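In code terms, the idea looks roughly like the snippet below, reusing the names from the forward sketch above (a hypothetical helper, not the paper’s kernel): given a \(Q\) block, a \(K\) block, and the saved per-row statistics \((m, \ell)\), the exact attention weights for that tile can be rebuilt on chip.

```python
import numpy as np

def recompute_P_block(Qi, Kj, m_i, ell_i):
    """Rebuild one tile of the attention matrix P from Q, K and the saved
    per-row softmax statistics, instead of reading P back from HBM."""
    Sij = Qi @ Kj.T
    return np.exp(Sij - m_i[:, None]) / ell_i[:, None]   # exact weights for this tile
```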

Comparison of GFLOPs, HBM accesses, and runtime between standard attention and FlashAttention (left), effect of block size (middle), and speedup from block sparsity (right).

Figure 2: Left: FlashAttention uses slightly more FLOPs but drastically fewer HBM accesses—over 9× less—leading to huge speedups. Middle: Larger block sizes reduce HBM accesses until runtime becomes compute-bound. Right: Block-Sparse FlashAttention gains proportional speedups as sparsity increases.

The Theory: IO Complexity

Analyzing HBM accesses:

  • Standard attention: \(\Theta(Nd + N^2)\) HBM accesses.
  • FlashAttention: \(\Theta(N^2 d^2 / M)\) accesses, where \(M\) is SRAM size.

The ratio of the two is roughly \(N^2 / (N^2 d^2 / M) = M / d^2\). For typical values (head dimension \(d\) of 64–128, SRAM size \(M\) around 100 KB), \(d^2\) is many times smaller than \(M\), so FlashAttention touches HBM far less often, in line with the roughly 9× reduction measured in Figure 2.

Even Faster: Block-Sparse FlashAttention

FlashAttention can also accelerate approximate attention methods. In Block-Sparse FlashAttention, a predefined block sparsity mask marks which blocks need to be computed at all. The tiled loops simply skip the zero blocks, so the dominant IO term is scaled by \(s\), the fraction of non-zero blocks: the sparser the mask, the less IO and the faster the kernel.

Figure 2 (right) shows Block-Sparse FlashAttention running even faster than dense FlashAttention, with higher sparsity producing greater gains.
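As a rough illustration of the control flow (not the actual kernel), the sketch below uses a hypothetical boolean block_mask over (Q block, K/V block) tiles: masked-out tiles are never loaded or computed, so IO and work shrink with the kept fraction \(s\).

```python
import numpy as np

# Hypothetical block-sparsity pattern over (Q block, K/V block) tiles.
N, block = 4096, 128
nb = N // block
rng = np.random.default_rng(0)
block_mask = rng.random((nb, nb)) < 0.25     # keep roughly a quarter of the tiles
block_mask |= np.eye(nb, dtype=bool)         # e.g. always keep the diagonal tiles

visited = 0
for j in range(nb):                          # same loop structure as the dense sketch
    for i in range(nb):
        if not block_mask[i, j]:
            continue                         # zero tile: no HBM load, no compute
        visited += 1                         # here the dense per-tile update would run

s = visited / (nb * nb)
print(f"kept fraction s ≈ {s:.2f}; tile loads reduced ~{1 / s:.1f}x vs dense")
```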

Real-World Results

The gains are dramatic:

Faster Training

  • BERT-large: 15% faster than MLPerf 1.1 record.
  • GPT-2: Up to 3× faster than HuggingFace, 1.7× faster than Megatron-LM.
  • Long Range Arena: 2.4× speedup on long-sequence benchmark.

Runtime vs sequence length (left) showing FlashAttention and Block-Sparse outperforming baselines; memory footprint (right) showing FlashAttention’s linear and low usage.

Figure 3: Left: FlashAttention beats exact attention baselines; Block-Sparse is fastest overall. Right: Linear memory usage—up to 20× less than exact baselines.

Better Models with Longer Context

  • GPT-2 with 4K context: Still 30% faster than Megatron’s 1K context version, and 0.7 better perplexity.
  • Long document classification: Increasing sequence length to 8K or 16K yields up to 8.5-point gains on long-document tasks such as MIMIC-III (medical) and ECtHR (legal).
  • Path-X (sequence length 16K) and Path-256 (64K): First Transformers to achieve better-than-chance performance on these extreme long-context vision benchmarks, enabled only by FlashAttention’s ability to scale to such sequence lengths.

Conclusion and Future Directions

FlashAttention changes how we think about performance optimization in deep learning. Key insights:

  1. Memory I/O is the bottleneck: On modern GPUs, moving data is often costlier than computing.
  2. IO-aware algorithms yield massive speedups: Minimize HBM reads/writes with techniques like tiling.
  3. Efficiency expands capability: Handling longer sequences improves model quality and enables solving previously impossible problems.

The authors note that writing custom CUDA kernels is engineering-intensive, and envision high-level tools or compilers generating IO-aware kernels automatically.

FlashAttention is more than an optimization—it’s a new primitive for the ML stack. By making long-context Transformers practical, it unlocks advancements across domains: long-form text, high-res video, genomics, and beyond. Sometimes, a leap forward comes from simply paying closer attention—to your memory.