如果你一直关注大语言模型领域,就会知道其中一个重要目标是扩展上下文窗口。我们希望模型能阅读整本书籍、分析冗长的代码库或处理高分辨率图像。而主要的障碍是什么?正是 Transformer 架构核心的注意力机制。它的计算和内存成本会随着序列长度按二次方增长,使得长上下文变得异常昂贵。

2022 年,一篇突破性的论文 FlashAttention 直面了这个问题。通过巧妙地重排注意力计算,使其更好地利用 GPU 的内存层级结构,它实现了线性内存使用2–4 倍加速,并且没有任何近似。这是一次颠覆性的创新,并已被广泛采用。

但故事并未就此结束。虽然 FlashAttention 速度很快,但它仍未充分发挥现代硬件的潜力。其性能 (以每秒浮点运算次数 FLOPs/s 衡量) 徘徊在 NVIDIA A100 等 GPU 理论峰值的 25–40% 左右。相比之下,高度优化的矩阵乘法 (GEMM) 例程可以达到该峰值的 80–90%。仍有显著的性能差距有待填补。

FlashAttention-2 应运而生。这项后续工作剖析了原始算法的剩余低效之处,并引入了一系列优化。通过重新思考如何在 GPU 的计算单元之间以及单元内部划分工作,FlashAttention-2 实现了又一次约 2 倍加速,将硬件利用率推向了令人印象深刻的理论峰值的 50–73%

在这篇文章中,我们将深入探讨论文《FLASHATTENTION-2: 通过更优并行化与工作划分实现更快的注意力机制》。我们将探讨:

  • 快速回顾: 标准注意力机制为何缓慢,以及初代 FlashAttention 的工作原理
  • FlashAttention-2 的三大关键创新: 减少慢速操作、更智能的并行化以及更高效的工作划分
  • 令人瞩目的结果: 展示 FlashAttention-2 的速度已接近纯矩阵乘法

背景: GPU 瓶颈与 FlashAttention-1

要理解 FlashAttention-2,我们首先需要掌握标准注意力为何效率低下,以及它的前辈 FlashAttention 如何解决了问题的初步部分。

GPU 硬件基础

GPU 不是一个单一的计算引擎,而是一个具有内存层级结构的复杂系统。对我们来说,最重要的两个层级是:

  1. 高带宽内存 (HBM): 即大容量的显存 (例如 A100 上的 40–80 GB) ,用于存放模型和数据。与常规 RAM 相比,它的“带宽”更高,但相对片上内存来说依然较慢。
  2. SRAM (静态随机存取存储器) ,又称共享内存: 极其快速的片上内存,容量很小 (每个计算单元仅几 KB) ,但带宽远高于 HBM。

GPU 编程的黄金法则:** 尽量减少对 HBM 的读写**。最高效的算法会将数据从 HBM 加载到快速的 SRAM 中一次性完成尽可能多的计算,然后只将最终结果写回 HBM。每一次不必要的 HBM 往返都会造成严重的性能瓶颈。

标准注意力机制的问题

标准自注意力机制定义如下:

标准注意力机制的计算公式。

标准注意力机制: 给定查询 Q、键 K 和值 V,计算得分 S,应用 softmax 得到概率 P,然后用 PV 加权得到输出 O

这里,QKV 是形状为 \( N \times d \) 的矩阵,其中 \( N \) 为序列长度,\( d \) 为头的维度。

朴素实现方式如下:

  1. 计算 N × N 得分矩阵 S = QKT将 S 写入 HBM
  2. 从 HBM 读取 S,逐行应用 softmax 得到 P将 P 写入 HBM
  3. 从 HBM 读取 P,与 V 相乘得到输出 O

即使序列长度仅 8k,SP 就是 8k × 8k 的矩阵——含数亿个元素。在 HBM 中来回存储和传输这些巨大的矩阵非常缓慢,并且耗费大量内存。这就是二次方瓶颈

FlashAttention 的解法: 分块与在线 Softmax

FlashAttention 的关键洞见是通过分块 (tiling) 技术,避免将完整的 SP 矩阵写入 HBM。

算法将 QKV 划分成更小的块,一次将一个 Q 块和一个 K/V 块加载到 SRAM 中。

使用分块技术的 FlashAttention 前向传播示意图。

图 1: 初代 FlashAttention 算法将 K 和 V 的块加载到高速 SRAM 中,与一个 Q 块计算注意力并更新输出,从而避免在慢速 HBM 中生成完整的 N × N 注意力矩阵。

在 SRAM 内部,它只为该块计算注意力。但 softmax 需要整行的指数和才能正确归一化,不能只独立处理一个块。

这时在线 softmax (online softmax) 就派上用场。算法为每一行维护一个运行中的最大值和归一化因子。对每个新块,计算局部 softmax,使用更新后的统计数据重新缩放之前的结果,并将它们合并。处理完所有块后,得到的结果与标准注意力完全一致——但消除了昂贵的 HBM 往返。

在反向传播中,FlashAttention 不存储庞大的 P 矩阵,而是在需要时重新计算注意力块。这样内存复杂度从 \(O(N^2)\) 降至 \(O(N)\)。


FlashAttention-2: 进一步提升硬件效率

FlashAttention 是一次巨大飞跃,但仍未将 GPU 性能发挥到极致。作者们为 FlashAttention-2 确定了三项优化。

1. 减少非矩阵乘法浮点运算 (Non-Matmul FLOPs)

现代 GPU (如 A100) 拥有专为矩阵乘法优化的张量核心 (Tensor Cores) 。在 A100 上:

  • FP16/BF16 矩阵乘法: 峰值 312 TFLOPs/s
  • FP32 非矩阵乘法操作: 峰值 19.5 TFLOPs/s

非矩阵乘法 FLOP 的成本可能是矩阵乘法 FLOP 的 16 倍

初代 FlashAttention 在其内循环中执行了一些非矩阵乘法的缩放操作,例如在处理每个块后重复缩放 O

FlashAttention-2 修改了在线 softmax 流程:

FlashAttention-2 更新后的在线 softmax 计算公式。

图: FlashAttention-2 避免了在循环内部重复缩放输出,只在最终进行一次缩放。

它不再在每次迭代中缩放 O,而是维护一个未缩放的输出版本和归一化统计数据,只在最后执行一次缩放。这减少了昂贵的非矩阵乘法操作,让 GPU 能够更多地用于高吞吐量的矩阵乘法。


2. 跨 GPU 的更优并行化

一块 NVIDIA A100 有 108 个流式多处理器 (Streaming Multiprocessors,SMs) ,它们执行线程块 (thread blocks) ,成千上万的线程并行工作。为了最大化速度,所有 SM 都必须保持活跃。

初代 FlashAttention 在“批次大小 × 头数”维度上并行化,为每个头分配一个线程块。如果批次 × 头数 ≥ SM 数量,这种方式效果很好。

但在处理长序列时,批次大小通常很小 (为了适应内存) 。当批次 × 头数 < SM 数量时,许多 SM 会闲置。

FlashAttention-2 增加了沿序列长度维度 \(N\) 的并行化:

前向与后向传播的并行化策略。

图 2: 前向传播 (左) : 每个线程块处理一个行块 (查询切片) 。后向传播 (右) : 线程块处理列块并使用原子加法累积共享行的梯度。

  • 前向传播: Q 的不同行块相互独立,因此线程块可并行处理它们,即使在小批次下也能提升利用率。
  • 反向传播: 由于 Q 梯度存在依赖,情况更复杂,通过重构使其可以在列块 (KV 切片) 上并行处理,依赖关系通过原子加法 (atomic adds) 解决。

这确保了长序列场景下的 GPU 完全利用。


3. 线程块内更高效的工作划分

线程块内部的线程被分组为线程束 (warps) (通常 32 个线程) 。Warp 可通过共享内存快速共享数据,但共享内存访问也需要时间。

初代 FlashAttention 使用了“拆分 K”方案:

  • Q 由所有 warp 共享
  • KV 在 warp 之间拆分
  • 结果通过共享内存读写合并

这造成了瓶颈。

FlashAttention-2 则反其道而行之:

FlashAttention 与 FlashAttention-2 的工作划分对比。

图 3: FlashAttention (a) 拆分 K 和 V,需要 warp 间通信。FlashAttention-2 (b) 拆分 Q,所有 warp 共享 K 和 V,几乎无需通信。

现在:

  • KV 由所有 warp 共享
  • Q 在不同 warp 中拆分
  • 每个 warp 独立计算各自的 QKT 切片并与共享的 V 相乘
  • 在最终写出之前,几乎没有 warp 间通信

这减少了共享内存流量和同步开销。


结果: 缩小差距

这些优化的综合效果如何?基准测试展示了明显的性能提升。

A100 GPU 上前向与后向传播综合速度。

图 4: A100 上的前向+后向传播综合速度。FlashAttention-2 (紫色) 始终比 FlashAttention (橙色) 快约 2 倍,并且远快于 PyTorch (蓝色) 和 xformers (绿色) 。

与 PyTorch 标准注意力相比,长序列的加速可达 10 倍。前向与后向传播的原始吞吐量如下:

前向传播速度 (TFLOPs/s) 。

图 5: 前向传播: FlashAttention-2 达到 230 TFLOPs/s——A100 理论峰值的 73%。

后向传播速度 (TFLOPs/s) 。

图 6: 后向传播: FlashAttention-2 达到理论峰值的 63%。

在注意力机制中实现超过 70% 的理论峰值——接近 GEMM 的效率——相当非凡。


端到端训练增益

最终的考验是训练真实模型。

GPT 模型端到端训练速度表。

表 1: GPT 模型在 8×A100 上的端到端训练吞吐量。FlashAttention-2 每块 GPU 达到 225 TFLOPs/s (72% 利用率) ,较 FlashAttention 提速 1.3 倍,相较无 FlashAttention 提速最高可达 2.8 倍。

对于上下文长度为 8k 的 GPT-1.3B 模型:

  • FlashAttention-2: 220 TFLOPs/s
  • FlashAttention: 170 TFLOPs/s
  • 基线: 72 TFLOPs/s

这意味着,训练一个 16k 上下文模型所需的成本现在与之前训练一个 8k 上下文模型大致相同。


结论与未来方向

FlashAttention-2 是硬件感知算法设计的典范。通过识别从高层数学到 warp 级执行的瓶颈,作者们将精确注意力推近了硬件极限。

其意义重大:

  • 长上下文训练与推理更经济、更可行
  • 现有流水线获得显著加速
  • 使得能处理更大量信息的模型成为可能

展望未来,作者们计划:

  • 针对 NVIDIA H100 GPU 进行优化 (利用新张量内存加速器与 FP8 张量核心)
  • 支持新的数据类型,如 FP8
  • 与更高层的算法优化 (如局部、块稀疏注意力) 结合,支持更长上下文

有了 FlashAttention-2,拥有几乎无限上下文窗口的模型的梦想比以往更近了一步。