Transformer 已经彻底改变了机器学习领域,但它有一个众所周知的致命弱点: 自注意力机制。自注意力虽然极其强大,但其计算和内存成本会随着序列长度呈二次方增长。这种 \(O(N^2)\) 的复杂度一直是主要障碍,使得在长文档、高分辨率图像或长音频片段上训练模型的成本高得令人望而却步。
多年间,研究人员尝试过用 近似注意力 方法来驯服这只二次方增长的“猛兽”。这些方法牺牲部分模型精度以换取更高的效率,通常可以将复杂度降低到线性或近线性时间。但问题是: 这些理论上更快的方法在实践中并不一定能加快训练。它们减少了计算量 (FLOPs) ,却常常忽略现代硬件 (如 GPU) 上的真正瓶颈:** 内存访问**。
斯坦福大学的一篇开创性论文《FLASHATTENTION: Fast and Memory-Efficient Exact Attention with IO-Awareness》提出,我们一直找错了方向。作者认为关键不仅是减少计算量,更要做到 IO 感知——智能管理数据在 GPU 不同层级内存之间的流动。
他们由此设计了 FlashAttention,这是一种计算精确注意力的算法,比标准实现更快、更高效,而且不做近似计算,而是从硬件出发,彻底重新设计整个过程。这一创新带来了端到端训练速度的大幅提升,更重要的是,它使 Transformer 能够处理此前无法想象的超长序列,从而开启全新的能力。
在本文的深入解析中,我们将揭开 FlashAttention 的技术细节,探讨标准注意力机制在硬件上的限制,并看看 FlashAttention 如何凭借巧妙的分块 (tiling) 与重计算 (recomputation) 完全规避这些问题。
真正的瓶颈: 两种内存的故事
要理解 FlashAttention 的高效性,我们需要先了解它运行所在的硬件。现代 GPU 有着内存层级结构,不同层级内存在容量与速度上的权衡各不相同。
图 1: 左: 内存层级结构中,SRAM 容量小但速度快;HBM 和 DRAM 容量大但速度慢。中: FlashAttention 在 SRAM 中循环处理 \(K, V\) 和 \(Q\) 的分块,无需在 HBM 中生成完整的 \(N \times N\) 矩阵。右: PyTorch 注意力与 FlashAttention 在 GPT-2 上的运行时间对比——FlashAttention 的融合核实现了 7.6 倍加速。
与本文讨论最密切相关的两个层级是:
- 高带宽内存 (High Bandwidth Memory, HBM) : GPU 的主存,容量大 (如 NVIDIA A100 可达 40–80 GB) ,但相较 GPU 的计算吞吐速度要慢得多。访问 HBM 是主要性能瓶颈。
- SRAM (片上静态随机存取存储器) : 容量小得多但速度极快的内存,直接位于 GPU 核心单元上,可用容量在数 KB 到数 MB 之间。速度比 HBM 快一个数量级。
GPU 运算可分为**计算受限 (受计算能力限制) 或内存受限 **(受数据在慢速内存与计算单元之间传输的时间限制) 。由于 GPU 计算能力增长速度超过了内存带宽,许多 Transformer 操作——如逐元素运算及 softmax 等归约操作——如今主要受制于内存访问。
标准注意力机制如何陷入困境
标准注意力计算如下:
\[ S = QK^{\mathsf{T}}, \quad P = \mathrm{softmax}(S), \quad O = PV \]问题在于中间的 \(S\) 和 \(P\) 矩阵,对长度为 \(N\) 的序列而言它们大小都是 \(N \times N\)。典型实现过程为:
- 计算完整的 \(S\) 并写入 HBM。
- 从 HBM 读回 \(S\) 以计算 softmax。
- 将 \(P\) 写入 HBM。
- 再从 HBM 读回 \(P\) 和 \(V\) 来计算 \(O\)。
当 \(N = 8192\) 时,一个 32 位浮点矩阵的大小约为 256 MB。再乘以多头与批大小,HBM 流量将急剧膨胀。此外,反向传播时还需要存储 \(P\) 矩阵,消耗二次方级别的内存。
FlashAttention 的目标是:** 不在 HBM 中生成完整的 \(N \times N\) 矩阵,直接计算精确注意力**。
核心方法: 分块与重计算
FlashAttention 的性能突破源于两项关键技术:** 分块 (tiling)** 与重计算 (recomputation) 。该算法将矩阵乘法、掩码、softmax、dropout 等所有注意力操作融合为单一 GPU 核。核函数从 HBM 读取输入,在高速 SRAM 中完成全部处理,然后仅将最终输出写回 HBM。
分块: 分块计算 Softmax
Softmax 的难点在于: 对任一行归一化时需要该行所有元素。FlashAttention 利用一种数值稳定的 softmax 分解:
\[ m(x) = \max_i x_i, \quad \ell(x) = \sum_i e^{x_i - m(x)}, \quad \text{softmax}(x)_i = \frac{e^{x_i - m(x)}}{\ell(x)} \]若将 \(x\) 分成 \(x^{(1)}, x^{(2)}\) 两块,可合并分块的统计量:
\[ m(x) = \max(m(x^{(1)}), m(x^{(2)})), \quad \ell(x) = e^{m(x^{(1)}) - m(x)}\ell(x^{(1)}) + e^{m(x^{(2)}) - m(x)}\ell(x^{(2)}) \]意味着我们可分块处理一行,而无需一次性将整行加载进 SRAM。
FlashAttention 执行过程:
- 外层循环: 依次将 \(K\) 与 \(V\) 分块加载进 SRAM。
- 内层循环: 对当前 \(K_j, V_j\) 分块,遍历所有 \(Q_i\) 分块,依次加载进 SRAM。
- 片上计算: 当 \(Q_i, K_j, V_j\) 全部在 SRAM 中:
- 计算 \(S_{ij} = Q_iK_j^{\mathsf{T}}\)。
- 分块计算 softmax 并更新运行中的 \((m_i, \ell_i)\)。
- 增量更新输出分块 \(O_i\)。
- 写回: 将更新后的 \(O_i\) 保存到 HBM。
最终在不生成完整 \(S\) 或 \(P\) 的情况下组装出 \(O\)。
重计算: 快速反向传播且无需大量存储
训练中反向传播通常需要 \(P\) 矩阵。FlashAttention 通过在反向阶段重计算每个分块来避免存储它。
借助前向传播保存的 \(Q, K, V, O\) 和 \((m, \ell)\),核函数在 SRAM 中重新计算每个注意力分块。虽然增加了一些 FLOPs,却减少了大量 HBM 读取。由于计算成本低而内存访问成本高,重计算反而提升了整体速度。
图 2: 左: FlashAttention 使用稍多的 FLOPs,但 HBM 访问减少超过 9 倍,带来巨大加速。中: 分块越大,HBM 访问越少,直到运行时间受计算能力限制。右: 块稀疏 FlashAttention 随稀疏度提升获得相应加速。
理论分析: IO 复杂度
HBM 访问次数分析:
- 标准注意力: \(\Theta(Nd + N^2)\) 次 HBM 访问。
- FlashAttention: \(\Theta(N^2 d^2 / M)\) 次访问,其中 \(M\) 为 SRAM 容量。
在常见的 \(d\) 和 \(M\) 条件下,FlashAttention 的 IO 成本比标准方法小很多倍。
更进一步: 块稀疏 FlashAttention
FlashAttention 还可加速近似注意力方法。块稀疏 FlashAttention (Block-Sparse FlashAttention) 使用稀疏掩码指示需计算的分块。跳过零块可将 IO 缩减一个系数 \(s\) (非零块比例) 。
图 2 右显示,块稀疏版本的运行速度甚至快于密集版本,稀疏度越高,速度提升越显著。
真实世界的结果
效果显著:
更快的训练速度
- BERT-large: 比 MLPerf 1.1 记录快 15%。
- GPT-2: 比 HuggingFace 提快最高 3 倍,比 Megatron-LM 提快 1.7 倍。
- Long Range Arena: 在长序列基准中提速 2.4 倍。
图 3: 左: FlashAttention 超过精确注意力基线;块稀疏版为最快。右: 线性内存使用——比精确注意力基线少高达 20 倍。
利用更长上下文获得更优模型
- 4K 上下文的 GPT-2: 仍比 Megatron 的 1K 上下文版快 30%,困惑度 (perplexity) 降低 0.7。
- 长文档分类: 在 MIMIC-III (医疗) 与 ECtHR (法律) 任务中,将序列长度增至 8K 或 16K,准确率最高提升 8.5 点。
- Path-X 与 Path-256: 首批在这些极端长上下文视觉任务上击败随机准确率的 Transformer——得益于 FlashAttention 的可扩展性。
结论与未来方向
FlashAttention 改变了我们对深度学习性能优化的认知。主要洞见:
- 内存 I/O 是瓶颈: 在现代 GPU 上,数据传输的代价往往高于计算。
- IO 感知算法可带来巨大加速: 通过如分块等技术最小化 HBM 的读写。
- 效率提升拓展能力: 处理更长序列不仅提升模型质量,还可解决此前无法处理的问题。
作者指出,编写自定义 CUDA 核函数需要大量工程投入,他们设想未来能有高层工具或编译器自动生成 IO 感知核。
FlashAttention 不只是一次优化——它是机器学习技术栈的新原语。让长上下文 Transformer 真正可用,为长文本、高分辨率视频、基因组学等领域带来突破。有时,技术的飞跃源于对内存的深思熟虑。