Transformer 几乎无处不在——从 ChatGPT 到代码补全助手,各类工具背后皆有它的身影——但它有一个众所周知的短板: 自注意力机制。随着输入序列变长,注意力计算量会呈二次方增长。序列长度加倍意味着计算量翻四倍。这一计算瓶颈使得在超长文档、高分辨率图像或庞大代码库上训练变得既困难又昂贵。

研究人员早就怀疑,这其中很多计算是浪费的。实际上,一个 token 只需关注少量与其相关的其他 token。这一洞察推动了稀疏注意力的研究——即跳过不必要计算的方法。部分方法采用固定模式,另一些则尝试动态的、数据驱动的策略。

理论上,动态稀疏方法颇具吸引力。但在实践中,它们的速度往往落后于直接计算完整密集注意力矩阵——尤其是在 FlashAttention 出现之后。FlashAttention 是一种 I/O 感知的 GPU 优化注意力实现,巧妙减少了缓慢的内存操作,使原生注意力快如闪电。然而,它的优势依赖于标准因果注意力的密集、规则结构。动态稀疏模式的无规律性打破了 FlashAttention 的假设,使速度优势不复存在。

这正是论文《Faster Causal Attention Over Large Sequences Through Sparse Flash Attention》切入的地方。研究团队扩展了 FlashAttention,使其能够兼顾动态稀疏的不规则性,同时保持高速度。他们提出的方法——稀疏因果 Flash Attention (SCFA)——将稀疏注意力的理论优势引入高性能 GPU 内核的实践中。使用 SCFA,他们在 8k 和 16k token 序列上的语言模型训练速度分别提升了 2.0 倍3.3 倍,且模型质量无损。

接下来看看他们是如何做到的。

瓶颈与以往尝试

二次方问题

自注意力要计算序列中每一对 token 的得分。序列长度为 \(T\) 时,会得到一个 \(T \times T\) 的得分矩阵——计算复杂度随 \(T^2\) 增长。

短序列时尚可应付,但长度达 16,000 个 token 时,注意力矩阵将拥有超过 2.5 亿个元素。随着长度增加,注意力计算逐渐成为模型运行的主导部分。

一张堆叠面积图,显示对于长序列,注意力 (橙色) 的计算成本呈二次方增长,并主导了前馈层 (蓝色) 的成本。

图 9: 对于长序列,自注意力的二次方计算成本占据主导地位。

FlashAttention 的兴起

注意力的瓶颈不仅在于浮点运算量 (FLOPs) ——还受限于内存访问。GPU 配备速度极快但容量较小的 SRAM,而存储完整注意力矩阵需要容量大但速度慢的高带宽内存 (HBM) 。朴素实现会将大量时间浪费在两类内存间的数据传输上。

FlashAttention (Dao 等,2022) 通过将注意力计算拆分成适配 SRAM 的小块 (tile) ,大幅减少了慢速内存访问。这种重排保留了精确的注意力计算数学,但带来了超过 5 倍的加速。

自回归模型的因果注意力——即一个 token 只能看到其之前的 token——可以利用下三角掩码高效实现。FlashAttention 充分利用这种可预测的块结构来保持高速。

动态稀疏性为何会破坏它

动态稀疏性通过以下方式实时移除不必要的 token 交互:

  1. 查询/键丢弃 (QK-Sparse) : 动态剪掉不重要的查询和键。
  2. 哈希稀疏 (Hash-Sparse) : 利用局部敏感哈希 (LSH) 等技术,将相似 token 分组到“桶”中,仅在桶内计算注意力 (如 Reformer 模型所用) 。

当根据哈希重排或丢弃 token 时,原本整齐的因果顺序会被打乱。即使张量被压缩,每个 token 仍必须遵守其原始位置以保证因果性。由此产生的不规则掩码会破坏 FlashAttention 的高效性。

SCFA 的方法是: 在适应这种不规则性的同时,保持 FlashAttention 的内存高效性。

SCFA 解决方案

稀疏因果 Flash Attention 继承了 FlashAttention 的分块理念,并增加了对原始 token 索引的识别能力。它针对 QK-sparse 和 Hash-sparse 两种注意力模式分别提供专门的内核,使 SCFA 能够在不规则布局中剪掉整个分块并精确施加因果掩码。

1. QK-Sparse: 高效的 Token 丢弃

丢弃部分 token 会得到更小、更紧凑的 query/key/value 张量。

图示 QK-sparse 注意力如何丢弃某些键和查询 (红色标记) 以创建一个更小的注意力问题,而 Hash-sparse 注意力则按哈希码 (颜色) 对键和查询进行分组,以创建块稀疏注意力。

图 1: QK-sparse (上) 与 Hash-sparse (下) 注意力对注意力矩阵的改造方式。

SCFA 额外输入两个向量:

  • q_idx: 查询的原始下标。
  • k_idx: 键的原始下标。

对每个分块 \(\mathcal{T}_{i,j}\) (查询块 \(Q_i\) 对键块 \(K_j\)) :

  1. 块剪枝:max(q_idx_i) < min(k_idx_j),跳过该分块——即所有查询在所有键之前。
  2. 元素掩码: 在有效分块中,利用 q_idxk_idx 屏蔽未来的 token。

图示 SCFA 如何计算不同的分块模式。左: 标准 FlashAttention。中: 用于 QK-sparse 注意力的 SCFA,具有不规则因果掩码。右: 用于 Hash-sparse 注意力的 SCFA,具有块稀疏的不规则掩码。

图 2: 标准、QK-sparse 和 Hash-sparse SCFA 的分块计算模式。

2. Hash-Sparse: 精确的桶内注意力

在基于 LSH 的注意力中,查询与键按哈希 ID 分配到不同的桶中,形成类似块对角的结构。

SCFA 会使用:

  • q_idxk_idx: 原始下标。
  • q_hashk_hash: 桶 ID。

剪枝与掩码规则为:

  1. 桶剪枝: 仅保留查询块与键块的桶 ID 匹配的分块。
  2. 因果剪枝: 跳过违反因果顺序的桶分块。
  3. 元素级掩码: 分块内部同时按照因果性与桶匹配进行掩码。

这样即可实现精确的桶内注意力,并保持 GPU 友好的高效执行。

实验与结果

Hash-Sparse 基准测试

比较对象包括:

  • 朴素哈希稀疏实现 (先算完整矩阵,再掩码) 。
  • FlashAttention。
  • SCFA Hash-sparse。

一张图表显示了不同哈希注意力实现的运行时。朴素实现 (橙色) 极其缓慢。随着序列变长和桶数增加,SCFA (红色曲线) 性能优于 FlashAttention (蓝色虚线) 。

图 3: SCFA Hash-sparse 相较基线的运行时提升。

朴素方法非常慢;SCFA 通过二次方的计算节省迅速抵消了排序的开销。序列长度与桶数量越大,性能优势越显著。

Reformer 注意力相比,SCFA 更快,并实现对桶冲突的 100% 覆盖,而 Reformer 的覆盖率会下降。

两张图比较 SCFA (Hash-sparse,红色) 与 Reformer (绿色) 和 FlashAttention (蓝色) 的表现。左: 长序列下 SCFA 更快。右: Reformer 覆盖率骤降;SCFA 保持 100%。

图 4: SCFA Hash-sparse 与 Reformer 的速度及准确性对比。

Hash-Sparse 在语言模型训练中的应用

在 OpenWebText2 数据集上,使用 SCFA Hash-sparse (H-LM) 的 1.22 亿参数 Transformer 模型,其困惑度与 FlashAttention 基线 (F-LM) 相当或更优,并能更快达到目标:

三张图显示 H-LM 与基线困惑度持平,但训练速度更快: 8k tokens 提速 2.0 倍,16k tokens 提速 3.3 倍。

图 6: H-LM 在不损失质量的情况下获得提速。

  • 8k tokens: 提速 2.0 倍
  • 16k tokens: 提速 3.3 倍

QK-Sparse 基准测试

两张图对比 QK-sparse 的运行时。朴素实现 (a) 缓慢;SCFA (b) 在各稀疏度下均优于 FlashAttention,尤其在长序列。

图 7: QK-sparse 的运行时表现。

朴素的 token 丢弃仅在极高稀疏度 (丢弃超过 70%) 时才有效;而 SCFA 即便在中等稀疏度 (20–30%) 下也有速度优势。

QK-Sparse 在语言模型训练中的应用

丢弃 30% QK 头的模型 (D-LM) 在困惑度上与基线持平,同时获得 1.9 倍提速;更高丢弃率可带来更大速度收益,但会略有质量损失。

QK-dropping 语言模型训练结果: 丢弃 30% QK 头时,困惑度持平,速度提升 1.9 倍。

图 8: 使用 QK-sparse SCFA 可实现快速且具竞争力的语言模型训练。

结论

稀疏因果 Flash Attention 是一项扎实的工程创新,融合了动态稀疏的理论效率与 FlashAttention 的实用速度。通过将 I/O 感知分块扩展到不规则掩码,SCFA 释放了 QK 丢弃、哈希稀疏以及潜在更多动态稀疏注意力策略的性能潜力。

虽然 SCFA 不改变最坏情况下的 \(O(T^2)\) 复杂度,但它大幅降低了平均情况的计算负担,使长上下文 Transformer 的效率显著提升。它是下一代模型的重要基础工具——开源可用,能够驱动更具适应性、更加高效的注意力机制。