从 ChatGPT 到 Gemini,Transformer 架构凭借注意力机制为现代人工智能提供了强大动力。注意力机制让模型能选择性地聚焦输入中的相关部分。然而,强大的能力也伴随着严重的瓶颈: 当序列长度扩展到整本书或庞大的代码库时,注意力的计算和内存需求会呈二次方级增长——输入长度翻倍,计算量翻四倍。这就是臭名昭著的二次方瓶颈。
像 FlashAttention 这样的突破性技术,通过避免昂贵的中间内存分配,降低了标准场景下的成本。然而,当面对现代训练任务所需的复杂注意力掩码——即规定哪些令牌 (token) 可以“看到”彼此的规则——FlashAttention 就显得力不从心了。在偏好优化、微调或序列打包等场景,这些掩码至关重要。目前的方法在处理这类掩码时,往往要退回到密集且内存消耗巨大的计算方式。
百度公司的研究人员推出了 FLASHMASK,这是对 FlashAttention 的一次重大扩展,将同样的 IO 感知与高效原则推广到各种复杂掩码。通过重新设计掩码表示方式,FLASHMASK 实现了线性内存使用与大幅加速,能够在高达 128K 令牌甚至更长的上下文中进行训练,且不牺牲精度。
本文将解读 FLASHMASK 论文: 它解决了什么问题、核心算法是什么,以及令人印象深刻的结果为何值得关注。
长上下文世界中的掩码难题
注意力机制的核心是计算一个得分矩阵,用于衡量每个令牌应关注其他令牌的程度:
\[ O = \text{Softmax}\left(\frac{QK^{T}}{\sqrt{d_k}}\right) V \]这里,\(Q\) (Queries) 与 \(K\) (Keys) 交互产生得分;再根据这些得分对 \(V\) (Values) 进行加权聚合。
为了控制令牌的可见性,我们在 softmax 之前应用一个掩码 \(M\):
通过在特定位置加上 \(-\infty\),这些位置对应的 softmax 得分会变为零——相应的查询-键对将“不可见”。
多样化的注意力掩码
不同训练阶段和任务需要不同的掩码:
- 因果掩码 (Causal Mask): 用于 GPT 等自回归模型,阻止访问未来的令牌。
- 文档掩码 (Document Mask): 用于序列打包,限制令牌只能关注同一文档内的内容。
- 共享问题掩码 (Shared Question Mask): 用于奖励建模 (RM) 和直接偏好优化 (DPO),所有答案可以关注共享的问题,但不能相互关注——减少冗余计算。
- 全局 + 滑动窗口掩码 (Global + Sliding Window Mask): 将全局上下文令牌与局部窗口注意力相结合。
- 前缀掩码 (Prefix Masks)、分块掩码 (Blockwise Masks)、稀疏掩码 (Sparse Masks) 等。
FLASHMASK 论文识别了十余种常见掩码类型,每种都具有结构化的稀疏性。
图 1: FLASHMASK 概览。(a) 支持的常见掩码类型。(b) 列式稀疏掩码表示。(c) 高效的核函数实现。
问题在于: 直接实现这些掩码意味着需要构造一个密集的 \(N \times N\) 矩阵,内存复杂度为 \(O(N^2)\)。当 \(N = 128{,}000\) 时,即需存储 160 亿个元素——成本高得难以承受。
FlashAttention 的革命与局限
FlashAttention 避免了创建完整的 \(N \times N\) 注意力矩阵。它将计算切分为适合芯片上 SRAM 的瓦片 (tile) ,高效地读写 GPU 内存,从而将内存开销降至 \(O(N)\),并显著提升注意力计算速度。
然而,FlashAttention-2 对掩码的高效支持范围有限 (如因果掩码、滑动窗口掩码) 。遇到自定义掩码时,它会退回到密集掩码方法——重新引入二次方级的内存成本。基于编译器的解决方案如 FlexAttention 虽提升了灵活性,但性能仍有改进空间。
核心洞察: 稀疏区间胜于密集矩阵
FLASHMASK 的关键发现是: 在几乎所有实际掩码中,每列的被屏蔽行通常构成连续区间。
例如因果掩码: 在第 \(j\) 列 (键令牌 \(j\)) ,所有位置在 \(j\) 之后的查询都会被阻止——这是一个连续区间。文档掩码在每列中通常屏蔽一个或两个连续区间。
因此,与其存储一个 \(N \times N\) 的布尔矩阵,FLASHMASK 只需保存这些区间的起点和终点索引。
FLASHMASK 使用四个长度为 \(N\) 的向量:
- LTS — 下三角起点 (Lower Triangular Start)
- LTE — 下三角终点 (Lower Triangular End)
- UTS — 上三角起点 (Upper Triangular Start)
- UTE — 上三角终点 (Upper Triangular End)
对于第 \(j\) 列,被掩码的行为:
\([LTS_j, LTE_j) \cup [UTS_j, UTE_j)\)。
例: 在图 1(b)(6) 中,第 5 列的掩码区间为 \([7, 10) \cup [2, 4)\),即第 2–3 行和第 7–9 行被掩码。
优势:
- 紧凑 —— 存储复杂度 \(O(N)\),相比密集掩码的 \(O(N^2)\)。
- 灵活 —— 能表示大部分现实中的掩码。
- 高速 —— 天然适配 FlashAttention 的块/瓦片跳过机制。
将 FLASHMASK 集成到 FlashAttention-2
FLASHMASK 嵌入到 FlashAttention-2 的流程分两步:
算法 1: 使用 FLASHMASK 扩展的前向传播 (原论文中蓝色部分为新增) 。
步骤 1 — 预处理:
将 LTS、LTE、UTS、UTE 划分为列块。对每个块计算最小/最大起点/终点索引,并存储到 8 个小的摘要向量中——代价低且缓存友好。
步骤 2 — 实时块跳过:
处理瓦片 \((Q_i, K_j)\) 时:
- 全掩码 —— 完全跳过 (不加载 K/V、不做矩阵乘法) 。
- 无掩码 —— 按标准 FlashAttention 执行。
- 部分掩码 —— 加载该块的详细 LTS/LTE/UTS/UTE 信息,并选择性屏蔽元素。
这种粗到细的策略能尽早剔除无效计算,仅在需要时应用细粒度掩码。掩码稀疏度 \(\rho\) 可直接减少计算量:
\[ O((1 - \rho) N^2) \quad \text{对比} \quad O(N^2) \]且因其是精确执行,输出与密集掩码结果逐比特完全一致。
结果: FLASHMASK 实战表现
端到端训练速度
研究人员在以下任务上对 **Llama-2 **(7B、13B、70B) 进行了微调:
- 监督微调 (SFT)
- LoRA
- 直接偏好优化 (DPO)
- 奖励建模 (RM)
图 2: 端到端吞吐量。FLASHMASK (绿色) 始终优于 FlashAttention-2 DenseMask (橙色) 和 Vanilla Attention (蓝色) 。
加速效果: 比 FlashAttention-2 DenseMask 快 1.65× 至 3.22×。
处理上限: FLASHMASK 可处理 544K 令牌 (LoRA,Llama-2 7B) ,密集掩码为 64K。
收敛性与正确性
在确定性模式下,FLASHMASK 与密集掩码的损失曲线完全重合;即便关闭确定性,收敛趋势也保持一致。
图 3: 损失曲线表明 FLASHMASK 输出与密集掩码结果在数值上完全一致。
性能与稀疏度
随着块稀疏度增加,延迟线性下降——验证设计理念的正确性。
图 4: (a) 延迟与稀疏度关系。(b) 内存使用——FLASHMASK 的 \(O(N)\) 掩码存储实现了出色可扩展性。
核函数对比: FLASHMASK vs FlexAttention
图 5: 核函数速度 (TFLOPs/s)。FLASHMASK 比 FlexAttention 快 12.1%–60.7%,最高达 A100 峰值性能的 62.3%。
在 12 种掩码类型和最长 128K 序列下,FLASHMASK 始终更快。
关键要点
FLASHMASK 将巧妙的算法设计与硬件友好的工程实践相结合:
- 线性内存复杂度: 从 \(O(N^2)\) 降至 \(O(N)\),支持超长上下文。
- 显著加速: 跳过全掩码的块,减少计算量。
- 广谱掩码支持: 高效处理因果、双向、分块、前缀、稀疏及组合掩码。
- 精确结果: 与密集掩码输出逐比特匹配。
- 顶尖内核性能: 全面超越 FlexAttention。
为什么重要
随着模型迈向百万令牌上下文,注意力效率至关重要。FLASHMASK 让复杂掩码的计算如同简单因果注意力般快速轻量,从而消除核心瓶颈。这为以下场景打开了可能性:
- 更丰富的上下文建模
- 高效的多文档训练
- 可扩展的偏好优化与 RLHF
- 涉及代码、视觉和多模态的长序列任务
该实现已在 PaddlePaddle 开源,并集成到 PaddleNLP,适用于大规模应用。
简而言之: FLASHMASK 是高效 Transformer 设计的里程碑——更精简、更快速的注意力机制,掩码灵活性不减,专为长上下文未来的 AI 而生。