让 Transformer 飞起来: 深入探究线性注意力
自 2017 年里程碑式的论文 Attention Is All You Need 问世以来,Transformer 已经席卷了整个人工智能领域。像 BERT、GPT-3 和 DALL·E 这样的模型彻底改变了自然语言处理、计算机视觉等多个领域。它们是生成式 AI 浪潮背后的核心引擎——能够编写代码、创作艺术,并开展令人惊讶的连贯对话。
但这些强大的模型隐藏着一个代价高昂的秘密: 计算上的瓶颈。直到最近,这个瓶颈都在限制它们一次能处理的信息量。Transformer 的核心——自注意力机制——其计算和内存复杂度为 \(O(N^2)\),其中 \(N\) 是输入序列的长度。这意味着,如果你将文本长度或图像像素数加倍,成本不只是加倍,而是增加到四倍。对于超长序列——例如高分辨率图像、长篇文档或音频片段——这种二次方的增长会变得极其昂贵。
这一限制激发了关于更“高效”Transformer 的研究浪潮: 我们如何在不付出高昂二次方代价的情况下保留全局注意力的威力?
一篇发表于 2020 年的精彩论文 Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention 给出了一个优雅而强大的答案。研究者提出了 线性 Transformer,它将自注意力的复杂度从 \(O(N^2)\) 降低到更易处理的 \(O(N)\)。这一转变带来了惊人的速度提升——在某些情况下达 4000 倍——并揭示了 Transformer 与循环神经网络 (RNN) 之间的深层联系,而后者传统上被视为序列建模的完全不同范式。
本文将剖析线性 Transformer 背后的关键思想。我们会探讨一个巧妙的数学技巧如何驯服二次方的怪兽,这种改变怎样催生闪电般快速的生成模型,以及为什么从某种意义上说,Transformer 可以被视为 RNN。
问题所在: 理解 \(O(N^2)\) 的瓶颈
在深入解决方案之前,让我们先回顾标准自注意力的机制,并定位瓶颈的来源。
Transformer 由多层堆叠而成,每一层包括两部分:
- 自注意力 —— 使序列中各元素能够交互。
- 前馈网络 —— 独立处理每个元素。
在自注意力步骤中,序列的每个元素都会生成三个向量:
- 查询 (Query, Q) —— 表示“我在寻找什么”;
- 键 (Key, K) —— 表示“我包含什么”;
- 值 (Value, V) —— 代表实际携带的信息。
为了确定第 i 个元素对第 j 个元素的注意力程度,模型计算 \(Q_i\) 与 \(K_j\) 的点积。这些点积在整个序列上形成注意力矩阵,通过 softmax 归一化后,每个元素会聚合来自其他所有元素的信息。
图: 在标准自注意力中,计算 \(QK^T\) 会生成一个庞大的 \(N \times N\) 矩阵——这就是时间和内存呈二次增长的根源。
问题就在这里: 序列长度为 \(N\) 时,我们必须计算并存储所有成对交互,得到一个 \(N \times N\) 矩阵。其成本——无论内存还是时间——都随 \(O(N^2)\) 扩展。对于长文本或大型图像,这个矩阵会迅速变得巨大且难以管理。
核心思想: 用核函数线性化注意力
作者提出了新的视角: 注意力实质上是一种加权平均,其权重由查询和键之间的相似度函数决定。标准注意力使用缩放点积的指数形式:
\[ \text{sim}(Q_i, K_j) = \exp(Q_i^T K_j / \sqrt{D}) \]图: 广义注意力公式,允许将相似度评分替换为任何合适的核函数。
关键洞见在于: 可以把这个指数相似度换成一个能表示为特征映射之间点积的形式:
\[ \text{sim}(Q_i, K_j) = \phi(Q_i)^T \phi(K_j) \]其中 \(\phi(\cdot)\) 是一个特征映射函数。
将这种核表示代入注意力公式后得到:
图: 使用核函数特征映射重新表示查询和键的线性化注意力。
突破点在于: 通过运用 矩阵乘法的结合律,我们可以重新组织计算。无需显式求每个查询与所有键的相似度,而是预先计算所有键和值的加总:
\[ V_i' = \frac{\phi(Q_i)^T \sum_j \phi(K_j)V_j^T}{\phi(Q_i)^T \sum_j \phi(K_j)} \]图: 改写公式后,将求和操作移到每个查询之外,从而实现高效计算。
耗时的部分——\(\sum_j \phi(K_j)V_j^T\) 和 \(\sum_j \phi(K_j)\)——可以一次性计算并重复使用。模型不再需要构建完整的注意力矩阵。
从数学上看,计算从 \((\phi(Q)\phi(K)^T)V\) 转化为 \(\phi(Q)(\phi(K)^T V)\)。
图: 结合律使模型只需处理较小矩阵,将二次时间复杂度降为线性。
结果是: 注意力计算的时间与序列长度线性相关,内存需求也降至 \(O(N)\)。
当然,这取决于特征映射的合适选择。理想的指数核维度是无限的,但一个有限且实用的替代形式运作得很好。作者选用:
图: 特征映射 \(\phi(x) = \text{elu}(x) + 1\) 保证相似度为正并稳定梯度。
这一简单的激活函数确保了相似度为正、梯度稳定,使线性注意力的性能几乎与标准 softmax 注意力相当,却大幅提升了效率。
顿悟时刻: Transformer 即 RNN
这种优势在 自回归 任务中尤其明显——例如逐步生成文本或图像。在这类任务里,模型每一步的输出都会成为下一步的输入。
为了防止“作弊”,我们使用 因果掩码,保证每个元素只能关注它之前的部分:
图: 因果掩码确保注意力遵循序列顺序 (\(j ≤ i\)) 。
采用相同的线性化技巧后,我们得到只涉及截至位置 \(i\) 的累加项的公式:
图: 线性化的因果注意力,实现高效的自回归推理。
接着定义两个累计量:
图: 递归更新的状态 \(S_i\) 和 \(Z_i\),汇总所有过去的信息。
这些项可以递归更新:
- \(S_i = S_{i-1} + \phi(K_i)V_i^T\)
- \(Z_i = Z_{i-1} + \phi(K_i)\)
这个递归揭示了深刻的事实:** Transformer 层的行为类似循环神经网络 (RNN)** 。\((S_i, Z_i)\) 成为随时间演化的隐藏状态。
每一步中,模型更新这些状态并计算输出:
图: 将 Transformer 层重新表达为具有演化隐藏状态的循环模型。
这种等价性使 Transformer 同时具备两种架构的优点:
- 并行训练 —— 在训练时可高效并行计算所有时间步的累计项。
- 常数时间推理 —— 在生成时,模型只需以恒定成本更新紧凑状态 \((s_i, z_i)\)。
不同于传统 Transformer (生成第 1000 个词元需要重新计算前 999 个词元的注意力) ,线性 Transformer 在序列变长时仍能保持稳定的生成速度。
实验: 检验线性注意力性能
理论优雅,实践决定成败。作者从运行时间、内存占用、收敛行为及真实任务表现多个维度评估该模型。
性能与内存效率
图 1: softmax 的时间与 GPU 内存需求呈二次增长,而新的核函数注意力呈线性增长。
如预期,当序列长度增加时,标准 softmax 注意力的成本急剧膨胀。而线性注意力依旧高效,即便序列长度达到 65,536 个词元,也实现了与理论一致的线性扩展。
为验证稳定性,作者在一个合成的序列复制任务上测试。线性注意力平滑收敛,表现和完整 softmax 一样可靠,并优于其他高效替代算法如 Reformer。
图 2: 线性 Transformer 达到与 softmax 相同的最终性能,同时保持稳定收敛。
自回归图像生成 —— 提速高达 4000 倍
在线性 Transformer 应用于自回归图像生成 (逐像素预测) 时,它展现出惊艳性能。
方法 | Bits/dim | 图像/秒 |
---|---|---|
Softmax | 0.621 | 0.45 (1×) |
LSH-1 | 0.745 | 0.68 (1.5×) |
LSH-4 | 0.676 | 0.27 (0.6×) |
线性 (本文方法) | 0.644 | 142.8 (317×) |
表 1: MNIST 生成结果——质量相当,速度快数百倍。
在 MNIST 上,线性 Transformer 生成的图像质量与 softmax 一样,但速度快 317 倍。在 CIFAR-10 上,由于序列更长,速度提升更超过 4400 倍。
图 3: 线性注意力在保持图像质量的同时实现数千倍推理加速。
生成样本表明,效率提升并未以牺牲质量为代价。
图 4: 线性 Transformer 生成的 MNIST 样本与补全结果。输出质量与 softmax 相当,同时具备优秀的扩展性。
语音识别
为了验证更广泛的适用性,作者在 WSJ 数据集上测试了自动语音识别任务。结果再次彰显线性注意力的优势。
方法 | 验证集 PER | 每轮时间 (秒) |
---|---|---|
Bi-LSTM | 10.94 | 1047 |
Softmax | 5.12 | 2711 |
LSH-4 | 9.33 | 2250 |
线性 (本文方法) | 8.08 | 824 |
表 2: 语音识别结果——线性 Transformer 训练速度超过所有基线模型,同时保持竞争性准确率。
即便在非自回归的帧级任务中,线性化公式也显著缩短训练时间,而准确率仅略有下降。
结论: 快速注意力的未来
线性 Transformer 不仅仅是一种优化,它重新定义了自注意力的概念。通过用核函数特征映射取代 softmax,并利用矩阵乘法的结合律,注意力的时间与空间复杂度从二次降为线性。
核心要点
- 线性复杂度: 线性 Transformer 以 \(O(N)\) 的效率计算注意力,使得处理超长序列成为可能。
- 惊人提速: 在自回归任务中,推理每步的时间与内存成本保持恒定,实现数千倍的速度提升。
- 统一框架: Transformer 可表示为 RNN 的发现连接了两种模型家族,为新的混合架构提供灵感。
这项研究使 Transformer 能够应对过去被认为不可能的任务——长文本生成、高分辨率视频、整段音频等。它重新定义了我们对于注意力与循环机制的理解,开启了不仅学习更好也运行更快的架构。
让 Transformer 飞起来的 线性注意力,不仅让它更高效,也扩展了现代人工智能的可能疆界。