如果你一直关注大语言模型领域,就会知道其中一个重要目标是扩展上下文窗口。我们希望模型能阅读整本书籍、分析冗长的代码库或处理高分辨率图像。而主要的障碍是什么?正是 Transformer 架构核心的注意力机制。它的计算和内存成本会随着序列长度按二次方增长,使得长上下文变得异常昂贵。
2022 年,一篇突破性的论文 FlashAttention 直面了这个问题。通过巧妙地重排注意力计算,使其更好地利用 GPU 的内存层级结构,它实现了线性内存使用和2–4 倍加速,并且没有任何近似。这是一次颠覆性的创新,并已被广泛采用。
但故事并未就此结束。虽然 FlashAttention 速度很快,但它仍未充分发挥现代硬件的潜力。其性能 (以每秒浮点运算次数 FLOPs/s 衡量) 徘徊在 NVIDIA A100 等 GPU 理论峰值的 25–40% 左右。相比之下,高度优化的矩阵乘法 (GEMM) 例程可以达到该峰值的 80–90%。仍有显著的性能差距有待填补。
FlashAttention-2 应运而生。这项后续工作剖析了原始算法的剩余低效之处,并引入了一系列优化。通过重新思考如何在 GPU 的计算单元之间以及单元内部划分工作,FlashAttention-2 实现了又一次约 2 倍加速,将硬件利用率推向了令人印象深刻的理论峰值的 50–73%。
在这篇文章中,我们将深入探讨论文《FLASHATTENTION-2: 通过更优并行化与工作划分实现更快的注意力机制》。我们将探讨:
- 快速回顾: 标准注意力机制为何缓慢,以及初代 FlashAttention 的工作原理
- FlashAttention-2 的三大关键创新: 减少慢速操作、更智能的并行化以及更高效的工作划分
- 令人瞩目的结果: 展示 FlashAttention-2 的速度已接近纯矩阵乘法
背景: GPU 瓶颈与 FlashAttention-1
要理解 FlashAttention-2,我们首先需要掌握标准注意力为何效率低下,以及它的前辈 FlashAttention 如何解决了问题的初步部分。
GPU 硬件基础
GPU 不是一个单一的计算引擎,而是一个具有内存层级结构的复杂系统。对我们来说,最重要的两个层级是:
- 高带宽内存 (HBM): 即大容量的显存 (例如 A100 上的 40–80 GB) ,用于存放模型和数据。与常规 RAM 相比,它的“带宽”更高,但相对片上内存来说依然较慢。
- SRAM (静态随机存取存储器) ,又称共享内存: 极其快速的片上内存,容量很小 (每个计算单元仅几 KB) ,但带宽远高于 HBM。
GPU 编程的黄金法则:** 尽量减少对 HBM 的读写**。最高效的算法会将数据从 HBM 加载到快速的 SRAM 中一次性完成尽可能多的计算,然后只将最终结果写回 HBM。每一次不必要的 HBM 往返都会造成严重的性能瓶颈。
标准注意力机制的问题
标准自注意力机制定义如下:
标准注意力机制: 给定查询 Q、键 K 和值 V,计算得分 S,应用 softmax 得到概率 P,然后用 P 对 V 加权得到输出 O。
这里,Q、K 和 V 是形状为 \( N \times d \) 的矩阵,其中 \( N \) 为序列长度,\( d \) 为头的维度。
朴素实现方式如下:
- 计算
N × N
得分矩阵 S = QKT,将 S 写入 HBM。 - 从 HBM 读取 S,逐行应用 softmax 得到 P,将 P 写入 HBM。
- 从 HBM 读取 P,与 V 相乘得到输出 O。
即使序列长度仅 8k,S 和 P 就是 8k × 8k 的矩阵——含数亿个元素。在 HBM 中来回存储和传输这些巨大的矩阵非常缓慢,并且耗费大量内存。这就是二次方瓶颈。
FlashAttention 的解法: 分块与在线 Softmax
FlashAttention 的关键洞见是通过分块 (tiling) 技术,避免将完整的 S 和 P 矩阵写入 HBM。
算法将 Q、K 和 V 划分成更小的块,一次将一个 Q 块和一个 K/V 块加载到 SRAM 中。
图 1: 初代 FlashAttention 算法将 K 和 V 的块加载到高速 SRAM 中,与一个 Q 块计算注意力并更新输出,从而避免在慢速 HBM 中生成完整的 N × N 注意力矩阵。
在 SRAM 内部,它只为该块计算注意力。但 softmax 需要整行的指数和才能正确归一化,不能只独立处理一个块。
这时在线 softmax (online softmax) 就派上用场。算法为每一行维护一个运行中的最大值和归一化因子。对每个新块,计算局部 softmax,使用更新后的统计数据重新缩放之前的结果,并将它们合并。处理完所有块后,得到的结果与标准注意力完全一致——但消除了昂贵的 HBM 往返。
在反向传播中,FlashAttention 不存储庞大的 P 矩阵,而是在需要时重新计算注意力块。这样内存复杂度从 \(O(N^2)\) 降至 \(O(N)\)。
FlashAttention-2: 进一步提升硬件效率
FlashAttention 是一次巨大飞跃,但仍未将 GPU 性能发挥到极致。作者们为 FlashAttention-2 确定了三项优化。
1. 减少非矩阵乘法浮点运算 (Non-Matmul FLOPs)
现代 GPU (如 A100) 拥有专为矩阵乘法优化的张量核心 (Tensor Cores) 。在 A100 上:
- FP16/BF16 矩阵乘法: 峰值 312 TFLOPs/s
- FP32 非矩阵乘法操作: 峰值 19.5 TFLOPs/s
非矩阵乘法 FLOP 的成本可能是矩阵乘法 FLOP 的 16 倍。
初代 FlashAttention 在其内循环中执行了一些非矩阵乘法的缩放操作,例如在处理每个块后重复缩放 O。
FlashAttention-2 修改了在线 softmax 流程:
图: FlashAttention-2 避免了在循环内部重复缩放输出,只在最终进行一次缩放。
它不再在每次迭代中缩放 O,而是维护一个未缩放的输出版本和归一化统计数据,只在最后执行一次缩放。这减少了昂贵的非矩阵乘法操作,让 GPU 能够更多地用于高吞吐量的矩阵乘法。
2. 跨 GPU 的更优并行化
一块 NVIDIA A100 有 108 个流式多处理器 (Streaming Multiprocessors,SMs) ,它们执行线程块 (thread blocks) ,成千上万的线程并行工作。为了最大化速度,所有 SM 都必须保持活跃。
初代 FlashAttention 在“批次大小 × 头数”维度上并行化,为每个头分配一个线程块。如果批次 × 头数 ≥ SM 数量,这种方式效果很好。
但在处理长序列时,批次大小通常很小 (为了适应内存) 。当批次 × 头数 < SM 数量时,许多 SM 会闲置。
FlashAttention-2 增加了沿序列长度维度 \(N\) 的并行化:
图 2: 前向传播 (左) : 每个线程块处理一个行块 (查询切片) 。后向传播 (右) : 线程块处理列块并使用原子加法累积共享行的梯度。
- 前向传播: Q 的不同行块相互独立,因此线程块可并行处理它们,即使在小批次下也能提升利用率。
- 反向传播: 由于 Q 梯度存在依赖,情况更复杂,通过重构使其可以在列块 (K 和 V 切片) 上并行处理,依赖关系通过原子加法 (atomic adds) 解决。
这确保了长序列场景下的 GPU 完全利用。
3. 线程块内更高效的工作划分
线程块内部的线程被分组为线程束 (warps) (通常 32 个线程) 。Warp 可通过共享内存快速共享数据,但共享内存访问也需要时间。
初代 FlashAttention 使用了“拆分 K”方案:
- Q 由所有 warp 共享
- K、V 在 warp 之间拆分
- 结果通过共享内存读写合并
这造成了瓶颈。
FlashAttention-2 则反其道而行之:
图 3: FlashAttention (a) 拆分 K 和 V,需要 warp 间通信。FlashAttention-2 (b) 拆分 Q,所有 warp 共享 K 和 V,几乎无需通信。
现在:
- K、V 由所有 warp 共享
- Q 在不同 warp 中拆分
- 每个 warp 独立计算各自的 QKT 切片并与共享的 V 相乘
- 在最终写出之前,几乎没有 warp 间通信
这减少了共享内存流量和同步开销。
结果: 缩小差距
这些优化的综合效果如何?基准测试展示了明显的性能提升。
图 4: A100 上的前向+后向传播综合速度。FlashAttention-2 (紫色) 始终比 FlashAttention (橙色) 快约 2 倍,并且远快于 PyTorch (蓝色) 和 xformers (绿色) 。
与 PyTorch 标准注意力相比,长序列的加速可达 10 倍。前向与后向传播的原始吞吐量如下:
图 5: 前向传播: FlashAttention-2 达到 230 TFLOPs/s——A100 理论峰值的 73%。
图 6: 后向传播: FlashAttention-2 达到理论峰值的 63%。
在注意力机制中实现超过 70% 的理论峰值——接近 GEMM 的效率——相当非凡。
端到端训练增益
最终的考验是训练真实模型。
表 1: GPT 模型在 8×A100 上的端到端训练吞吐量。FlashAttention-2 每块 GPU 达到 225 TFLOPs/s (72% 利用率) ,较 FlashAttention 提速 1.3 倍,相较无 FlashAttention 提速最高可达 2.8 倍。
对于上下文长度为 8k 的 GPT-1.3B 模型:
- FlashAttention-2: 220 TFLOPs/s
- FlashAttention: 170 TFLOPs/s
- 基线: 72 TFLOPs/s
这意味着,训练一个 16k 上下文模型所需的成本现在与之前训练一个 8k 上下文模型大致相同。
结论与未来方向
FlashAttention-2 是硬件感知算法设计的典范。通过识别从高层数学到 warp 级执行的瓶颈,作者们将精确注意力推近了硬件极限。
其意义重大:
- 长上下文训练与推理更经济、更可行
- 现有流水线获得显著加速
- 使得能处理更大量信息的模型成为可能
展望未来,作者们计划:
- 针对 NVIDIA H100 GPU 进行优化 (利用新张量内存加速器与 FP8 张量核心)
- 支持新的数据类型,如 FP8
- 与更高层的算法优化 (如局部、块稀疏注意力) 结合,支持更长上下文
有了 FlashAttention-2,拥有几乎无限上下文窗口的模型的梦想比以往更近了一步。