Transformer 席卷了全球,从 ChatGPT 到高级代码补全工具无所不在。它们最神奇的能力之一是 上下文学习 (in-context learning, ICL) ——在不更新任何权重的情况下,从输入提示中提供的示例中学习的能力。如果你给一个大型语言模型展示几个任务示例,它通常可以立刻在新示例上执行相同任务。
长期以来,这一现象的工作原理一直颇为神秘。最近的研究开始层层揭示,表明对于像线性回归这样的简单任务,Transformer 内部实际上在运行一种梯度下降形式。每个注意力层都像一个优化步骤,根据提示中的数据逐步细化其内部的“解”。
但这引出了一个令人着迷的问题: Transformer 的能力是否仅限于此?它们是否只能模仿那些简单且众所周知的算法?Google Research 的一篇新论文 《线性 Transformer 是通用的上下文学习器》 (Linear Transformers are Versatile In-Context Learners) 给出了一个响亮的 否定答案。研究者们证明,即使是简化的“线性”Transformer,也能发现并实现极为复杂的优化算法——这些算法能够根据数据噪声动态调整,并且性能优于标准方法。
这项工作表明,Transformer 不仅仅是学习者,它们很可能还是算法的发明者。让我们来探究它是如何做到的。
背景设定: 线性 Transformer 与噪声问题
为了剖析上下文学习的核心机制,研究人员聚焦于一个极简模型:** 线性 Transformer**。与标准架构不同,这些模型去掉了 MLP 层和 LayerNorm,只保留线性自注意力层。
一个线性注意力层处理一个 token 序列 \( e_1, e_2, ..., e_n \)。对于每个 token \( e_i \),它通过关注序列中所有其他 token 来计算更新:
每个头 \(k\) 根据学习到的矩阵 \(Q_k\) 和 \(P_k\) 生成更新;所有头的更新最终相加。
每个 token \( e_i \) 包含一个特征向量和一个标签,\( e_i = (x_i, y_i) \)。输入序列含有 \(n\) 个示例以及一个查询 token \( e_{n+1} = (x_t, 0) \)。Transformer 的目标是为查询 \(x_t\) 预测正确的标签 \(y_t\)。
研究人员研究了两种问题设定:
固定噪声方差 (Fixed Noise Variance):
每个序列都使用相同的噪声水平 \( \sigma^2 \) 生成。先前的研究发现,Transformer 在这种数据上会学习一种类似梯度下降的算法,称为 GD++。混合噪声方差 (Mixed Noise Variance):
每个序列的噪声水平各不相同,从一个分布 (例如 \( U(0, \sigma_{\max}) \)) 中抽取。
这种设定更具挑战性: 最优解需要 岭回归 (ridge regression),它在损失函数中加入一个依赖噪声方差的正则化项。
岭回归对较大的权重施加惩罚,以提升对数据噪声的鲁棒性。
由于每个提示的噪声水平不同,Transformer 必须从数据中推断噪声强度并进行自适应——这比简单的最小二乘法要困难得多。
核心洞见: 每一层都在更新一个内部模型
研究人员的第一个重要理论成果解释了线性 Transformer 的内部机制。每一层都维护一个隐式的 线性回归模型,并在数据流动过程中持续更新。
在第 \(l\) 层,每个 token 的输出标签可表示为:
\[ y_i^{l+1} = a^l y_i - \langle w^l, x_i \rangle, \]其中 \(w^l\) 是隐式权重向量,\(a^l\) 是学习得出的缩放系数。这些变量在各层间不断演化——模型实际上在迭代更新其内部回归参数。
每一层根据前一层的特征协方差和互相关等统计量来计算更新。
在简化假设 (注意力矩阵为对角形式) 下,\(w^l\) 及一个辅助的“动量”向量 \(u^l\) 的更新规则表现如下:
这些更新与带动量的梯度下降过程非常相似。
紧凑形式为:
\[ u^{l+1} = (1-\beta)u^l + \nabla f(w^l), \]\[ w^{l+1} = w^l - \eta u^l. \]虽然这个类比不完全准确——Transformer 的系数是矩阵而非标量——但相似度令人惊叹。每一层实际上都在对一个内部模型执行复杂的优化步骤,而不只是做数据变换。
剖析学到的算法
令人意外的是,使用对角注意力权重的模型与完全矩阵模型的性能几乎相当。这让团队能够使用四个标量参数对学习到的算法进行逐项分析:
\( \omega_{xx}, \omega_{xy}, \omega_{yx}, \omega_{yy} \)。
这些参数描述了跨层特征与标签之间的信息流动。
每个参数控制着不同的动态过程:
1. 预处理梯度下降 (\( \omega_{yx} \), \( \omega_{xx} \))
当仅激活这两个分量时,Transformer 执行 GD++——一种通过预处理增强的梯度下降形式。
GD++ 通过预处理更新 \(x\),并通过梯度步骤调整 \(y\)。
作者证明,GD++ 是一种 二阶优化方法,类似于牛顿法。它能在 \(O(\log\log (1/\epsilon))\) 步内收敛,这解释了其在简单回归任务中的高效性。
2. 自适应缩放 (\( \omega_{yy} \))
分量 \( \omega_{yy} \) 引入了 噪声感知缩放。其更新规则可简化为:
较大的标签能量 (\(\lambda^l = \sum_i (y_i^l)^2\)) 会触发更强的缩放调整。
当数据表现出高方差 (即较大的 \(\lambda^l\)) 时,负的 \( \omega_{yy} \) 会收缩输出——这与岭回归的正则化效果类似。它能自动根据噪声强度调整,实现自适应校正以稳定学习过程。
3. 自适应步长 (\( \omega_{xy} \))
最后的项 \( \omega_{xy} \) 允许动态控制步长,其作用分两层实现:
- 第一层根据残差误差调整 \(x_i\)。
- 第二层使用调整后的特征执行梯度下降。
模型通过数据残差自动调节有效步长。
步长与 \( (1 + \omega_{xy} \sum_i r_i^2) \) 成正比,其中 \( \sum_i r_i^2 \) 是噪声方差的估计。负的 \( \omega_{xy} \) 在噪声较大时会让步长减小——相当于一种基于经验学习的自适应“早停”机制。
这些组件共同形成了一个经过精细调控的优化过程,融合了高级的梯度更新、缩放和噪声敏感调整——这一切都是 Transformer 自动发现的。
实验: 验证涌现的优化能力
团队在混合噪声回归任务上训练了三种线性 Transformer 变体,以验证前述理论:
- FULL: 使用完整矩阵注意力参数。
- DIAG: 限制为对角矩阵。
- GD++: 模拟标准预处理梯度下降的简化版本。
他们将这些模型与岭回归基线进行比较,包括一个经过调优的变体 (TUNEDRR
),该变体使用理想化的噪声估计。
随着噪声范围和模型层数增加,DIAG 和 FULL 的性能超过基线,而 GD++ 无法适应可变噪声。
当层深和噪声多样性 (\(\sigma_{\max}\)) 增加时,DIAG 和 FULL 实现了近乎完美的噪声自适应,在高噪声条件下甚至超越了调优的岭回归。
为了分析模型在不同噪声水平下的表现:
GD++ 仅在一个特定噪声水平上有效;而 DIAG 和 FULL 在所有噪声水平上都表现稳健。
GD++ 的性能在某个固定水平附近停滞,说明它假设了一个平均噪声方差。相比之下,DIAG 和 FULL 学会了能在整个范围内有效应对噪声的灵活策略。
在分析学习到的权重矩阵时:
学习到的矩阵几乎完全是对角形式,验证了理论简化的有效性。
明显的对角主导性表明,对角模型分析确实准确地捕捉了 Transformer 的主要行为模式。
结论: 作为算法发现者的 Transformer
这项研究为我们揭示了 Transformer 内部的运作机制——即便是简单的线性变体,也能通过训练过程自行发现强大的优化算法。
主要洞见包括:
隐式优化:
每一层都执行类似于带动量的梯度下降迭代,但更为复杂,涉及矩阵动力学和自适应控制。噪声下的算法发现:
面对含噪数据,Transformer 能自发发明策略,如自适应缩放和可变步长控制,这对稳健优化至关重要。经验上的优越性:
学习到的算法可媲美甚至超越如调整后的岭回归等闭式解,显示出 Transformer 自动构建高效解决方案的能力。
更广泛地说,这项工作将 Transformer 定位为不仅是学习者,而是涌现的算法设计者。通过在特定任务上进行训练,我们可以利用它们发现全新的优化与学习算法类别——突破传统机器学习设计的边界。
这项研究展示了复杂推理与计算如何从简单的架构原则中涌现,并为未来通过神经网络实现自动算法发现的研究方向铺平了道路。