The Transformer architecture powers modern AI—from ChatGPT to Gemini—thanks to its attention mechanism, which allows models to focus selectively on relevant parts of the input. But with great power comes a serious bottleneck: as sequence lengths grow to entire books or massive codebases, the computational and memory demands of attention scale quadratically. Double the input length, and you quadruple the work. This is the infamous quadratic bottleneck.

Breakthroughs like FlashAttention have reduced these costs for standard use cases by avoiding expensive intermediate memory allocations. However, FlashAttention struggles when faced with the complex attention masks needed for modern training tasks—masks that dictate which tokens can “see” each other. These masks are crucial in scenarios like preference optimization, fine-tuning, or sequence packing. Current approaches often revert to dense, memory-hungry computations for such masks.

Researchers at Baidu Inc. have introduced FLASHMASK, a major extension of FlashAttention that applies the same IO-awareness and efficiency principles to a broad variety of complex masks. By rethinking mask representation, FLASHMASK achieves linear memory usage and massive speedups, enabling training on contexts up to 128K tokens—and beyond—without sacrificing precision.

In this post, we’ll unpack the FLASHMASK paper: the problem it solves, the core algorithm, and the impressive results that show why it matters.


The Problem with Masks in a Long-Context World

At its core, attention computes a matrix of scores for how much each token should attend to every other token:

\[ O = \text{Softmax}\left(\frac{QK^{T}}{\sqrt{d_k}}\right) V \]

Here, \(Q\) (Queries) and \(K\) (Keys) interact to produce scores; \(V\) (Values) are then aggregated according to these scores to form the output \(O\).
To control token visibility, we apply a mask \(M\) before softmax:

\[ \text{Attention}(Q, K, V, M) = \text{Softmax}\left(\frac{QK^{T}}{\sqrt{d_k}} + M\right) V \]

By adding \(-\infty\) to certain positions, their corresponding softmax scores become zero—those query-key pairs are “invisible.”
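
As a concrete reference point, here is a minimal NumPy sketch of this dense masked attention for a single head; the function name and shapes are illustrative, not part of any library API.

```python
import numpy as np

def masked_attention(Q, K, V, M):
    """Dense reference: softmax(Q K^T / sqrt(d_k) + M) V for one head.

    Q, K, V have shape (N, d_k); M has shape (N, N) with 0 for visible
    pairs and -inf for masked ones (assumes no row is fully masked).
    """
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k) + M            # (N, N) score matrix
    scores -= scores.max(axis=-1, keepdims=True)   # numerical stability
    weights = np.exp(scores)                       # masked entries become 0
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V
```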


A Zoo of Attention Masks

Different training stages and tasks require different masks:

  • Causal Mask: Used in autoregressive models like GPT; blocks future tokens from being seen.
  • Document Mask: Used in sequence packing; restricts tokens to attend only within the same document.
  • Shared Question Mask: In Reward Modeling (RM) and Direct Preference Optimization (DPO), all answers can attend to the shared question but not to each other—reducing redundant work.
  • Global + Sliding Window Mask: Mixes global context tokens with local-window attention.
  • Prefix Masks, Blockwise Masks, Sparse Masks, and more.

The FLASHMASK paper identifies over a dozen common mask types, each with structured sparsity.
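
To make one of these concrete, here is a small sketch that builds the additive bias for a causal document mask (sequence packing) in dense form. It is illustrative only; materializing the full matrix is exactly what FLASHMASK avoids.

```python
import numpy as np

def causal_document_mask(doc_lens):
    """Dense additive mask for sequence packing: each token attends causally
    and only within its own packed document (0 = visible, -inf = masked)."""
    N = sum(doc_lens)
    mask = np.full((N, N), -np.inf)
    start = 0
    for length in doc_lens:
        for i in range(start, start + length):
            mask[i, start:i + 1] = 0.0   # same document, up to and including self
        start += length
    return mask

# Three documents of lengths 4, 3, and 5 packed into one 12-token sequence.
M = causal_document_mask([4, 3, 5])
```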

Figure 1: Overview of FLASHMASK. (a) Common attention mask types supported. (b) The column-wise sparse mask representation. (c) The efficient kernel implementation.

The problem: implementing these masks naively means materializing a dense \(N \times N\) matrix, with \(O(N^2)\) memory complexity. For \(N = 128{,}000\), that's over 16 billion score entries, more than 30 GB even at 16-bit precision, which is prohibitively expensive.


FlashAttention’s Revolution—and Its Limits

FlashAttention avoids creating the full \(N \times N\) attention matrix. It splits computation into tiles that fit in on-chip SRAM, reading/writing to GPU memory efficiently. This shrinks memory overhead to \(O(N)\) and speeds up attention dramatically.
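
To see the idea in miniature, here is a hedged NumPy sketch of the tiling-plus-online-softmax trick for a single head; it ignores masking, query blocking, and all the real IO management that makes FlashAttention fast on a GPU.

```python
import numpy as np

def tiled_attention(Q, K, V, block=128):
    """Minimal sketch of FlashAttention-style tiling with an online softmax:
    K/V are processed block by block, so the full N x N score matrix is
    never materialized. Illustrative only."""
    N, d = Q.shape
    out = np.zeros_like(Q, dtype=np.float64)
    row_max = np.full(N, -np.inf)          # running max of scores per query
    row_sum = np.zeros(N)                  # running softmax denominator
    for j0 in range(0, N, block):
        Kj, Vj = K[j0:j0 + block], V[j0:j0 + block]
        S = Q @ Kj.T / np.sqrt(d)          # (N, block) scores for this key block
        new_max = np.maximum(row_max, S.max(axis=1))
        scale = np.exp(row_max - new_max)  # rescale previous accumulators
        P = np.exp(S - new_max[:, None])
        row_sum = row_sum * scale + P.sum(axis=1)
        out = out * scale[:, None] + P @ Vj
        row_max = new_max
    return out / row_sum[:, None]
```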

However, FlashAttention-2 only supports a small set of masks efficiently (e.g., causal, sliding window). For custom masks, it falls back to a dense mask method—reintroducing quadratic memory costs. Compiler-based solutions like FlexAttention improve flexibility but still leave performance on the table.


The Core Insight: Sparse Intervals Beat Dense Matrices

FLASHMASK’s key observation: in almost all practical masks, the masked-out rows form contiguous intervals in each column of the attention score matrix.

Take a causal mask: in the \(j\)-th column (key token \(j\)), all queries before position \(j\) are blocked, one contiguous block of rows \([0, j)\). Document masks block one or two contiguous ranges per column.

So rather than storing an \(N \times N\) boolean grid, FLASHMASK keeps just the start and end indices of these masked intervals.

FLASHMASK uses four vectors of length \(N\):

  • LTS — Lower Triangular Start
  • LTE — Lower Triangular End
  • UTS — Upper Triangular Start
  • UTE — Upper Triangular End

For column \(j\), masked rows are:
\([LTS_j, LTE_j) \cup [UTS_j, UTE_j)\).

Example: in Figure 1(b)(6), column 5 has \([7, 10) \cup [2, 4)\) → rows 2–3 and 7–9 are masked.
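
Here is a minimal sketch of how these four vectors encode a mask, using the vector names from the paper (the "empty interval" defaults below are my own convention for illustration, not necessarily the paper's). Expanding them back to a dense boolean mask is handy for testing, even though FLASHMASK itself never does this.

```python
import numpy as np

def expand_column_mask(LTS, LTE, UTS, UTE):
    """Expand the four column-wise index vectors into a dense boolean mask
    (True = masked). For column j, masked rows are
    [LTS[j], LTE[j]) union [UTS[j], UTE[j])."""
    N = len(LTS)
    masked = np.zeros((N, N), dtype=bool)
    for j in range(N):
        masked[LTS[j]:LTE[j], j] = True   # lower-triangular masked interval
        masked[UTS[j]:UTE[j], j] = True   # upper-triangular masked interval
    return masked

# The column-5 example from the text: [7, 10) union [2, 4).
N = 12
LTS, LTE = np.full(N, N), np.full(N, N)                     # [N, N) = no lower mask
UTS, UTE = np.zeros(N, dtype=int), np.zeros(N, dtype=int)   # [0, 0) = no upper mask
LTS[5], LTE[5], UTS[5], UTE[5] = 7, 10, 2, 4
print(np.flatnonzero(expand_column_mask(LTS, LTE, UTS, UTE)[:, 5]))  # [2 3 7 8 9]
```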

Benefits:

  1. Compactness — \(O(N)\) storage vs \(O(N^2)\) for dense.
  2. Flexibility — Captures most real-world masks.
  3. Speed — Perfect for block/tile-based skipping in FlashAttention.

Integrating FLASHMASK into FlashAttention-2

FLASHMASK slots into FlashAttention-2 in two steps:

Algorithm 1: The FlashAttention-2 forward pass extended with FLASHMASK (the additions are highlighted in blue in the original paper).

Step 1 — Preprocessing:
Divide LTS, LTE, UTS, and UTE into column (key) blocks and, for each block, compute the minimum and maximum of each vector. The result is eight small summary vectors, cheap to compute and cache-friendly.
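
A sketch of this preprocessing, under the assumption that blocks simply partition the key axis; the function and variable names are illustrative, not the PaddlePaddle implementation.

```python
import numpy as np

def block_min_max(vec, block_size):
    """Per-key-block minimum and maximum of one column-index vector.
    Applied to each of LTS, LTE, UTS, UTE, this yields the eight
    summary vectors used for coarse block classification."""
    n_blocks = (len(vec) + block_size - 1) // block_size
    mins = np.empty(n_blocks, dtype=vec.dtype)
    maxs = np.empty(n_blocks, dtype=vec.dtype)
    for b in range(n_blocks):
        chunk = vec[b * block_size:(b + 1) * block_size]
        mins[b], maxs[b] = chunk.min(), chunk.max()
    return mins, maxs
```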

Step 2 — Real-time Block Skipping:
When processing tile \((Q_i, K_j)\), the kernel classifies it into one of three cases (a sketch of this check follows the list):

  • Fully Masked — Skip entirely (no K/V load, no matmul).
  • Unmasked — Run standard FlashAttention.
  • Partially Masked — Load detailed LTS/LTE/UTS/UTE for this block and mask elements selectively.
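
The coarse check can be sketched as follows, taking one family of intervals at a time with half-open boundaries; these conventions are my reading of the scheme, not the paper's exact kernel code.

```python
def classify_tile(r0, r1, s_min, s_max, e_min, e_max):
    """Classify a tile covering query rows [r0, r1) against one family of
    per-column masked intervals, summarized by the block-wise min/max of
    their start (s) and end (e) indices."""
    if s_max <= r0 and e_min >= r1:
        return "fully_masked"       # every column masks all rows of this tile
    if e_max <= r0 or s_min >= r1:
        return "unmasked"           # no masked interval intersects this tile
    return "partially_masked"       # load per-column indices, mask elementwise
```

The kernel runs this check for both the lower- and upper-triangular families: a tile is skipped if either family reports it fully masked, takes the plain FlashAttention path only if both report it unmasked, and otherwise falls through to element-wise masking.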

This coarse-to-fine approach cuts unnecessary work early, then applies detailed masking only when needed. Mask sparsity \(\rho\) directly reduces computation:

\[ O((1 - \rho) N^2) \quad \text{vs.} \quad O(N^2) \]

And because it’s exact, outputs match dense-mask results bit-for-bit.


Results: FLASHMASK in Action

End-to-End Training Speed

Researchers fine-tuned Llama-2 (7B, 13B, 70B) on:

  • Supervised Fine-Tuning (SFT)
  • LoRA
  • Direct Preference Optimization (DPO)
  • Reward Modeling (RM)

Figure 2: End-to-end training throughput (tokens/sec/GPU) vs. sequence length for Llama-2 models. FLASHMASK (green) consistently outpaces FlashAttention-2 DenseMask (orange) and Vanilla Attention (blue).

Speedups: 1.65× to 3.22× faster than FlashAttention-2 DenseMask.
Capacity: FLASHMASK handled 544K tokens (LoRA, Llama-2 7B) vs 64K limit for dense.


Convergence and Correctness

Loss curves for FLASHMASK vs dense mask match perfectly under deterministic execution; even without determinism, convergence trends are identical.

Figure 3: Training loss curves for SFT, LoRA, DPO, and RM. With deterministic execution, the FLASHMASK and dense-mask curves overlay exactly, confirming numerically identical results.


Performance vs. Sparsity

As block sparsity increases, latency drops linearly—validating the design.

Figure 4: (a) Kernel latency decreases as block sparsity rises. (b) Memory usage vs. sequence length (log scale): FLASHMASK's \(O(N)\) mask storage scales linearly.


Kernel-Level: FLASHMASK vs FlexAttention

Figure 5: Kernel throughput (TFLOPs/s) for FLASHMASK vs. FlexAttention at various sequence lengths. FLASHMASK is faster in every case, beating FlexAttention by 12.1%–60.7% and reaching up to 62.3% of the A100's theoretical peak.

Across 12 mask types and sequences up to 128K, FLASHMASK is consistently faster.


Key Takeaways

FLASHMASK blends elegant algorithmic insight with hardware-conscious engineering:

  • Linear Memory Complexity: From \(O(N^2)\) to \(O(N)\), enabling ultra-long contexts.
  • Massive Speedups: Skip fully masked blocks, cut computation.
  • Broad Mask Support: Handle causal, bidirectional, blockwise, prefix, sparse, and combined masks efficiently.
  • Exact Results: Bitwise match with dense-mask outputs.
  • State-of-the-Art Kernel Performance: Outpaces FlexAttention across the board.

Why This Matters

As models stretch toward million-token contexts, attention efficiency is critical. FLASHMASK eliminates a key bottleneck by making complex masking as fast and lightweight as simple causal attention. This opens doors for:

  • Richer context modeling
  • Efficient multi-document training
  • Scalable preference optimization & RLHF
  • Long-sequence tasks in code, vision, and multimodal domains

The implementation is open-sourced in PaddlePaddle and integrated into PaddleNLP, ready for large-scale adoption.

In short: FLASHMASK is a milestone in efficient Transformer design—a leaner, faster attention that doesn’t compromise on mask flexibility. It’s tailor-made for the long-context future of AI.