从 ChatGPT 到 Gemini,Transformer 架构凭借注意力机制为现代人工智能提供了强大动力。注意力机制让模型能选择性地聚焦输入中的相关部分。然而,强大的能力也伴随着严重的瓶颈: 当序列长度扩展到整本书或庞大的代码库时,注意力的计算和内存需求会呈二次方级增长——输入长度翻倍,计算量翻四倍。这就是臭名昭著的二次方瓶颈

FlashAttention 这样的突破性技术,通过避免昂贵的中间内存分配,降低了标准场景下的成本。然而,当面对现代训练任务所需的复杂注意力掩码——即规定哪些令牌 (token) 可以“看到”彼此的规则——FlashAttention 就显得力不从心了。在偏好优化、微调或序列打包等场景,这些掩码至关重要。目前的方法在处理这类掩码时,往往要退回到密集且内存消耗巨大的计算方式。

百度公司的研究人员推出了 FLASHMASK,这是对 FlashAttention 的一次重大扩展,将同样的 IO 感知与高效原则推广到各种复杂掩码。通过重新设计掩码表示方式,FLASHMASK 实现了线性内存使用大幅加速,能够在高达 128K 令牌甚至更长的上下文中进行训练,且不牺牲精度。

本文将解读 FLASHMASK 论文: 它解决了什么问题、核心算法是什么,以及令人印象深刻的结果为何值得关注。


长上下文世界中的掩码难题

注意力机制的核心是计算一个得分矩阵,用于衡量每个令牌应关注其他令牌的程度:

\[ O = \text{Softmax}\left(\frac{QK^{T}}{\sqrt{d_k}}\right) V \]

这里,\(Q\) (Queries) 与 \(K\) (Keys) 交互产生得分;再根据这些得分对 \(V\) (Values) 进行加权聚合。
为了控制令牌的可见性,我们在 softmax 之前应用一个掩码 \(M\):

\[ \text{Attention}(Q, K, V, M) = \text{Softmax}\left(\frac{QK^{T}}{\sqrt{d_k}} + M\right) V \]

通过在特定位置加上 \(-\infty\),这些位置对应的 softmax 得分会变为零——相应的查询-键对将“不可见”。


多样化的注意力掩码

不同训练阶段和任务需要不同的掩码:

  • 因果掩码 (Causal Mask): 用于 GPT 等自回归模型,阻止访问未来的令牌。
  • 文档掩码 (Document Mask): 用于序列打包,限制令牌只能关注同一文档内的内容。
  • 共享问题掩码 (Shared Question Mask): 用于奖励建模 (RM) 和直接偏好优化 (DPO),所有答案可以关注共享的问题,但不能相互关注——减少冗余计算。
  • 全局 + 滑动窗口掩码 (Global + Sliding Window Mask): 将全局上下文令牌与局部窗口注意力相结合。
  • 前缀掩码 (Prefix Masks)分块掩码 (Blockwise Masks)稀疏掩码 (Sparse Masks) 等。

FLASHMASK 论文识别了十余种常见掩码类型,每种都具有结构化的稀疏性。

FLASHMASK 支持的不同注意力掩码模式、其列式稀疏表示以及其高效核函数实现的示意图。

图 1: FLASHMASK 概览。(a) 支持的常见掩码类型。(b) 列式稀疏掩码表示。(c) 高效的核函数实现。

问题在于: 直接实现这些掩码意味着需要构造一个密集的 \(N \times N\) 矩阵,内存复杂度为 \(O(N^2)\)。当 \(N = 128{,}000\) 时,即需存储 160 亿个元素——成本高得难以承受。


FlashAttention 的革命与局限

FlashAttention 避免了创建完整的 \(N \times N\) 注意力矩阵。它将计算切分为适合芯片上 SRAM 的瓦片 (tile) ,高效地读写 GPU 内存,从而将内存开销降至 \(O(N)\),并显著提升注意力计算速度。

然而,FlashAttention-2 对掩码的高效支持范围有限 (如因果掩码、滑动窗口掩码) 。遇到自定义掩码时,它会退回到密集掩码方法——重新引入二次方级的内存成本。基于编译器的解决方案如 FlexAttention 虽提升了灵活性,但性能仍有改进空间。


核心洞察: 稀疏区间胜于密集矩阵

FLASHMASK 的关键发现是: 在几乎所有实际掩码中,每列的被屏蔽行通常构成连续区间

例如因果掩码: 在第 \(j\) 列 (键令牌 \(j\)) ,所有位置在 \(j\) 之后的查询都会被阻止——这是一个连续区间。文档掩码在每列中通常屏蔽一个或两个连续区间。

因此,与其存储一个 \(N \times N\) 的布尔矩阵,FLASHMASK 只需保存这些区间的起点终点索引。

FLASHMASK 使用四个长度为 \(N\) 的向量:

  • LTS — 下三角起点 (Lower Triangular Start)
  • LTE — 下三角终点 (Lower Triangular End)
  • UTS — 上三角起点 (Upper Triangular Start)
  • UTE — 上三角终点 (Upper Triangular End)

对于第 \(j\) 列,被掩码的行为:
\([LTS_j, LTE_j) \cup [UTS_j, UTE_j)\)。

例: 在图 1(b)(6) 中,第 5 列的掩码区间为 \([7, 10) \cup [2, 4)\),即第 2–3 行和第 7–9 行被掩码。

优势:

  1. 紧凑 —— 存储复杂度 \(O(N)\),相比密集掩码的 \(O(N^2)\)。
  2. 灵活 —— 能表示大部分现实中的掩码。
  3. 高速 —— 天然适配 FlashAttention 的块/瓦片跳过机制。

将 FLASHMASK 集成到 FlashAttention-2

FLASHMASK 嵌入到 FlashAttention-2 的流程分两步:

算法 1: FlashAttention-2 的前向传播过程,其中高亮显示了为 FLASHMASK 添加的部分。

算法 1: 使用 FLASHMASK 扩展的前向传播 (原论文中蓝色部分为新增) 。

步骤 1 — 预处理:
将 LTS、LTE、UTS、UTE 划分为列块。对每个块计算最小/最大起点/终点索引,并存储到 8 个小的摘要向量中——代价低且缓存友好。

步骤 2 — 实时块跳过:
处理瓦片 \((Q_i, K_j)\) 时:

  • 全掩码 —— 完全跳过 (不加载 K/V、不做矩阵乘法) 。
  • 无掩码 —— 按标准 FlashAttention 执行。
  • 部分掩码 —— 加载该块的详细 LTS/LTE/UTS/UTE 信息,并选择性屏蔽元素。

这种粗到细的策略能尽早剔除无效计算,仅在需要时应用细粒度掩码。掩码稀疏度 \(\rho\) 可直接减少计算量:

\[ O((1 - \rho) N^2) \quad \text{对比} \quad O(N^2) \]

且因其是精确执行,输出与密集掩码结果逐比特完全一致


结果: FLASHMASK 实战表现

端到端训练速度

研究人员在以下任务上对 **Llama-2 **(7B、13B、70B) 进行了微调:

  • 监督微调 (SFT)
  • LoRA
  • 直接偏好优化 (DPO)
  • 奖励建模 (RM)

一组图表展示 Llama-2 模型在不同序列长度下的训练吞吐量 (Tokens/sec/GPU) 。

图 2: 端到端吞吐量。FLASHMASK (绿色) 始终优于 FlashAttention-2 DenseMask (橙色) 和 Vanilla Attention (蓝色) 。

加速效果: 比 FlashAttention-2 DenseMask 快 1.65× 至 3.22×。
处理上限: FLASHMASK 可处理 544K 令牌 (LoRA,Llama-2 7B) ,密集掩码为 64K。


收敛性与正确性

在确定性模式下,FLASHMASK 与密集掩码的损失曲线完全重合;即便关闭确定性,收敛趋势也保持一致。

SFT、LoRA、DPO、RM 的训练损失曲线——在确定性模式下曲线完全重合。

图 3: 损失曲线表明 FLASHMASK 输出与密集掩码结果在数值上完全一致。


性能与稀疏度

随着块稀疏度增加,延迟线性下降——验证设计理念的正确性。

(a) 核函数延迟随稀疏度增加而降低。(b) FLASHMASK 的内存使用随序列长度线性扩展 (对数尺度) 。

图 4: (a) 延迟与稀疏度关系。(b) 内存使用——FLASHMASK 的 \(O(N)\) 掩码存储实现了出色可扩展性。


核函数对比: FLASHMASK vs FlexAttention

柱状图: 不同序列长度下,FLASHMASK (橙色) 均比 FlexAttention (青色) 更快。

图 5: 核函数速度 (TFLOPs/s)。FLASHMASK 比 FlexAttention 快 12.1%–60.7%,最高达 A100 峰值性能的 62.3%。

在 12 种掩码类型和最长 128K 序列下,FLASHMASK 始终更快。


关键要点

FLASHMASK 将巧妙的算法设计与硬件友好的工程实践相结合:

  • 线性内存复杂度: 从 \(O(N^2)\) 降至 \(O(N)\),支持超长上下文。
  • 显著加速: 跳过全掩码的块,减少计算量。
  • 广谱掩码支持: 高效处理因果、双向、分块、前缀、稀疏及组合掩码。
  • 精确结果: 与密集掩码输出逐比特匹配。
  • 顶尖内核性能: 全面超越 FlexAttention。

为什么重要

随着模型迈向百万令牌上下文,注意力效率至关重要。FLASHMASK 让复杂掩码的计算如同简单因果注意力般快速轻量,从而消除核心瓶颈。这为以下场景打开了可能性:

  • 更丰富的上下文建模
  • 高效的多文档训练
  • 可扩展的偏好优化与 RLHF
  • 涉及代码、视觉和多模态的长序列任务

该实现已在 PaddlePaddle 开源,并集成到 PaddleNLP,适用于大规模应用。

简而言之: FLASHMASK 是高效 Transformer 设计的里程碑——更精简、更快速的注意力机制,掩码灵活性不减,专为长上下文未来的 AI 而生。