让 Transformer 飞起来: 深入探究线性注意力

自 2017 年里程碑式的论文 Attention Is All You Need 问世以来,Transformer 已经席卷了整个人工智能领域。像 BERT、GPT-3 和 DALL·E 这样的模型彻底改变了自然语言处理、计算机视觉等多个领域。它们是生成式 AI 浪潮背后的核心引擎——能够编写代码、创作艺术,并开展令人惊讶的连贯对话。

但这些强大的模型隐藏着一个代价高昂的秘密: 计算上的瓶颈。直到最近,这个瓶颈都在限制它们一次能处理的信息量。Transformer 的核心——自注意力机制——其计算和内存复杂度为 \(O(N^2)\),其中 \(N\) 是输入序列的长度。这意味着,如果你将文本长度或图像像素数加倍,成本不只是加倍,而是增加到四倍。对于超长序列——例如高分辨率图像、长篇文档或音频片段——这种二次方的增长会变得极其昂贵。

这一限制激发了关于更“高效”Transformer 的研究浪潮: 我们如何在不付出高昂二次方代价的情况下保留全局注意力的威力?

一篇发表于 2020 年的精彩论文 Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention 给出了一个优雅而强大的答案。研究者提出了 线性 Transformer,它将自注意力的复杂度从 \(O(N^2)\) 降低到更易处理的 \(O(N)\)。这一转变带来了惊人的速度提升——在某些情况下达 4000 倍——并揭示了 Transformer 与循环神经网络 (RNN) 之间的深层联系,而后者传统上被视为序列建模的完全不同范式。

本文将剖析线性 Transformer 背后的关键思想。我们会探讨一个巧妙的数学技巧如何驯服二次方的怪兽,这种改变怎样催生闪电般快速的生成模型,以及为什么从某种意义上说,Transformer 可以被视为 RNN。


问题所在: 理解 \(O(N^2)\) 的瓶颈

在深入解决方案之前,让我们先回顾标准自注意力的机制,并定位瓶颈的来源。

Transformer 由多层堆叠而成,每一层包括两部分:

  1. 自注意力 —— 使序列中各元素能够交互。
  2. 前馈网络 —— 独立处理每个元素。

在自注意力步骤中,序列的每个元素都会生成三个向量:

  • 查询 (Query, Q) —— 表示“我在寻找什么”;
  • 键 (Key, K) —— 表示“我包含什么”;
  • 值 (Value, V) —— 代表实际携带的信息。

为了确定第 i 个元素对第 j 个元素的注意力程度,模型计算 \(Q_i\) 与 \(K_j\) 的点积。这些点积在整个序列上形成注意力矩阵,通过 softmax 归一化后,每个元素会聚合来自其他所有元素的信息。

标准的 softmax 自注意力方程。QKᵀ 乘法是二次复杂度的来源。

图: 在标准自注意力中,计算 \(QK^T\) 会生成一个庞大的 \(N \times N\) 矩阵——这就是时间和内存呈二次增长的根源。

问题就在这里: 序列长度为 \(N\) 时,我们必须计算并存储所有成对交互,得到一个 \(N \times N\) 矩阵。其成本——无论内存还是时间——都随 \(O(N^2)\) 扩展。对于长文本或大型图像,这个矩阵会迅速变得巨大且难以管理。


核心思想: 用核函数线性化注意力

作者提出了新的视角: 注意力实质上是一种加权平均,其权重由查询和键之间的相似度函数决定。标准注意力使用缩放点积的指数形式:

\[ \text{sim}(Q_i, K_j) = \exp(Q_i^T K_j / \sqrt{D}) \]

广义注意力方程,其中 sim() 可以是任意相似度函数。

图: 广义注意力公式,允许将相似度评分替换为任何合适的核函数。

关键洞见在于: 可以把这个指数相似度换成一个能表示为特征映射之间点积的形式:

\[ \text{sim}(Q_i, K_j) = \phi(Q_i)^T \phi(K_j) \]

其中 \(\phi(\cdot)\) 是一个特征映射函数。

将这种核表示代入注意力公式后得到:

使用特征映射 φ 的线性化注意力方程。

图: 使用核函数特征映射重新表示查询和键的线性化注意力。

突破点在于: 通过运用 矩阵乘法的结合律,我们可以重新组织计算。无需显式求每个查询与所有键的相似度,而是预先计算所有键和值的加总:

\[ V_i' = \frac{\phi(Q_i)^T \sum_j \phi(K_j)V_j^T}{\phi(Q_i)^T \sum_j \phi(K_j)} \]

应用结合律后的线性化注意力方程。求和项现在与查询 i 无关。

图: 改写公式后,将求和操作移到每个查询之外,从而实现高效计算。

耗时的部分——\(\sum_j \phi(K_j)V_j^T\) 和 \(\sum_j \phi(K_j)\)——可以一次性计算并重复使用。模型不再需要构建完整的注意力矩阵。

从数学上看,计算从 \((\phi(Q)\phi(K)^T)V\) 转化为 \(\phi(Q)(\phi(K)^T V)\)。

结合律的可视化。不再计算 N×N 矩阵,而是处理更小的 D×M 矩阵,从而避免二次开销。

图: 结合律使模型只需处理较小矩阵,将二次时间复杂度降为线性。

结果是: 注意力计算的时间与序列长度线性相关,内存需求也降至 \(O(N)\)。

当然,这取决于特征映射的合适选择。理想的指数核维度是无限的,但一个有限且实用的替代形式运作得很好。作者选用:

特征映射 φ(x) = elu(x) + 1,确保相似度为正,实践效果良好。

图: 特征映射 \(\phi(x) = \text{elu}(x) + 1\) 保证相似度为正并稳定梯度。

这一简单的激活函数确保了相似度为正、梯度稳定,使线性注意力的性能几乎与标准 softmax 注意力相当,却大幅提升了效率。


顿悟时刻: Transformer 即 RNN

这种优势在 自回归 任务中尤其明显——例如逐步生成文本或图像。在这类任务里,模型每一步的输出都会成为下一步的输入。

为了防止“作弊”,我们使用 因果掩码,保证每个元素只能关注它之前的部分:

因果掩码注意力,其中求和仅截至当前位置 i。

图: 因果掩码确保注意力遵循序列顺序 (\(j ≤ i\)) 。

采用相同的线性化技巧后,我们得到只涉及截至位置 \(i\) 的累加项的公式:

线性化版本的因果掩码注意力。

图: 线性化的因果注意力,实现高效的自回归推理。

接着定义两个累计量:

累加和项 Sᵢ 和 Zᵢ。

图: 递归更新的状态 \(S_i\) 和 \(Z_i\),汇总所有过去的信息。

这些项可以递归更新:

  • \(S_i = S_{i-1} + \phi(K_i)V_i^T\)
  • \(Z_i = Z_{i-1} + \phi(K_i)\)

这个递归揭示了深刻的事实:** Transformer 层的行为类似循环神经网络 (RNN)** 。\((S_i, Z_i)\) 成为随时间演化的隐藏状态。

每一步中,模型更新这些状态并计算输出:

表示为循环神经网络的完整 Transformer 层,其中 sᵢ 和 zᵢ 是隐藏状态。

图: 将 Transformer 层重新表达为具有演化隐藏状态的循环模型。

这种等价性使 Transformer 同时具备两种架构的优点:

  1. 并行训练 —— 在训练时可高效并行计算所有时间步的累计项。
  2. 常数时间推理 —— 在生成时,模型只需以恒定成本更新紧凑状态 \((s_i, z_i)\)。

不同于传统 Transformer (生成第 1000 个词元需要重新计算前 999 个词元的注意力) ,线性 Transformer 在序列变长时仍能保持稳定的生成速度。


实验: 检验线性注意力性能

理论优雅,实践决定成败。作者从运行时间、内存占用、收敛行为及真实任务表现多个维度评估该模型。

性能与内存效率

不同注意力机制的时间和内存需求对比。线性注意力 (黑线) 随序列长度线性扩展,而 softmax (红线) 呈二次扩展,迅速变得不可行。

图 1: softmax 的时间与 GPU 内存需求呈二次增长,而新的核函数注意力呈线性增长。

如预期,当序列长度增加时,标准 softmax 注意力的成本急剧膨胀。而线性注意力依旧高效,即便序列长度达到 65,536 个词元,也实现了与理论一致的线性扩展。

为验证稳定性,作者在一个合成的序列复制任务上测试。线性注意力平滑收敛,表现和完整 softmax 一样可靠,并优于其他高效替代算法如 Reformer。

序列复制任务上的收敛曲线。线性注意力 (黑色) 收敛效果与 softmax (红色虚线) 一致,而 LSH 注意力 (蓝色点线) 不够稳定。

图 2: 线性 Transformer 达到与 softmax 相同的最终性能,同时保持稳定收敛。


自回归图像生成 —— 提速高达 4000 倍

在线性 Transformer 应用于自回归图像生成 (逐像素预测) 时,它展现出惊艳性能。

方法Bits/dim图像/秒
Softmax0.6210.45 (1×)
LSH-10.7450.68 (1.5×)
LSH-40.6760.27 (0.6×)
线性 (本文方法)0.644142.8 (317×)

表 1: MNIST 生成结果——质量相当,速度快数百倍。

在 MNIST 上,线性 Transformer 生成的图像质量与 softmax 一样,但速度快 317 倍。在 CIFAR-10 上,由于序列更长,速度提升更超过 4400 倍

MNIST 和 CIFAR-10 图像生成结果,加速比列显示线性 Transformer 的显著优势。

图 3: 线性注意力在保持图像质量的同时实现数千倍推理加速。

生成样本表明,效率提升并未以牺牲质量为代价。

由线性 Transformer 生成的 MNIST 无条件样本与图像补全,生成的数字清晰可信。

图 4: 线性 Transformer 生成的 MNIST 样本与补全结果。输出质量与 softmax 相当,同时具备优秀的扩展性。


语音识别

为了验证更广泛的适用性,作者在 WSJ 数据集上测试了自动语音识别任务。结果再次彰显线性注意力的优势。

方法验证集 PER每轮时间 (秒)
Bi-LSTM10.941047
Softmax5.122711
LSH-49.332250
线性 (本文方法)8.08824

表 2: 语音识别结果——线性 Transformer 训练速度超过所有基线模型,同时保持竞争性准确率。

即便在非自回归的帧级任务中,线性化公式也显著缩短训练时间,而准确率仅略有下降。


结论: 快速注意力的未来

线性 Transformer 不仅仅是一种优化,它重新定义了自注意力的概念。通过用核函数特征映射取代 softmax,并利用矩阵乘法的结合律,注意力的时间与空间复杂度从二次降为线性。

核心要点

  1. 线性复杂度: 线性 Transformer 以 \(O(N)\) 的效率计算注意力,使得处理超长序列成为可能。
  2. 惊人提速: 在自回归任务中,推理每步的时间与内存成本保持恒定,实现数千倍的速度提升。
  3. 统一框架: Transformer 可表示为 RNN 的发现连接了两种模型家族,为新的混合架构提供灵感。

这项研究使 Transformer 能够应对过去被认为不可能的任务——长文本生成、高分辨率视频、整段音频等。它重新定义了我们对于注意力与循环机制的理解,开启了不仅学习更好也运行更快的架构。

让 Transformer 飞起来的 线性注意力,不仅让它更高效,也扩展了现代人工智能的可能疆界。