多年来,Transformer 一直是序列建模的主导架构——从语言、代码到长文档。其 softmax 自注意力机制赋予了它无与伦比的灵活性,可以访问任意过去的 token,但这种灵活性也伴随着代价: 在推理过程中,内存和计算量随序列长度线性增长。这使得处理超长上下文问题的成本极高,并促使我们重新审视一个更早的思路: 采用循环式层,维持一个固定大小的状态,并在新 token 到达时更新它。

最近出现的一批模型 (Mamba、xLSTM、DeltaNet 等) 表明,精心设计的线性化注意力或“快速权重”RNN 可以在保持每个 token 的内存和计算恒定的同时,实现高质量的语言建模性能。这些模型有一个共同而有用的统一视角: 它们可以被理解为在执行测试时训练——当新的键值对到达时更新一个小型内部线性模型 (快速权重) 。

MesaNet 将这一思想进一步推进。它的 Mesa 层不是在每个 token 的损失上做小步梯度下降,而是在每个时间步上最优地求解一个 (正则化的) 最小二乘问题,生成从键到值的线性映射,使整个历史的累积平方误差最小化。直接计算该最优解是不切实际的,但 MesaNet 论文引入了一种数值稳定、分块并行的公式,利用了共轭梯度 (CG) 和硬件友好的核函数。其结果是一个循环层架构,具有以下特点:

  • 在每一步中进行明确的上下文优化 (最优的测试时训练) 。
  • 可分块并行化,以便在 GPU/TPU 上高效训练。
  • 通过提前停止迭代求解器,动态分配测试时计算资源。
  • 在合成任务上匹敌或超越其他线性 RNN,并在十亿参数规模上成为有竞争力的语言模型——同时揭示了一个重要的权衡: 可以在推理时投入额外的计算以提升预测。

在本文中,我将解析核心概念,展示 Mesa 层的高效实现方式,并梳理相关实验及其对 RNN 与 Transformer 权衡的启示。

图: 总体架构 (通道混合 + Mesa 层) 。
MesaNet 架构: 堆叠的残差块,包含通道混合的 MLP 和序列混合的 Mesa 层。Mesa 层生成键、查询和值,并在每个时间步计算最优的快速权重读出。

图 1: MesaNet 架构。每个残差块包含一个通道混合块 (SwiGLU MLP) 和一个序列混合块 (Mesa 层) 。Mesa 层计算键、查询、值以及输入/遗忘门,然后应用最优的更新和读出规则。

为什么要重新审视 RNN 风格的设计?因为在上下文长度特别大时,它们恒定的内存与每 token 恒定的计算量极具吸引力。许多现代线性化注意力 RNN 的核心技巧在于,将循环状态视为一个将键映射到值的线性模型 Φ,并在新的 (k, v) 对到达时在线更新 Φ。不同的更新规则对应不同的局部学习规律:

  • 类 Hebb 更新 (快速权重,GLA) 添加外积 v k^T (可选地进行缩放) 。
  • Delta 规则更新 (DeltaNet) 类似于针对当前 token 的平方误差执行单步梯度下降。
  • Mesa: 每个时间步最优地求解累积正则化最小二乘问题。

Mesa 的直觉非常简单但极其强大: 如果层的状态 Φ 旨在建模关系 v ≈ Φ k,为何不令 Φ 成为在所有观察到的 (k, v) 对下的最佳线性映射?这正是 Mesa 的目标。

Mesa 层的优化目标

在时间步 t,Mesa 层定义了累积目标 (正则化的最小二乘) :

\[ \hat\Phi_t^{\text{mesa}} \;=\; \arg\min_{\Phi}\; \mathcal{L}_t(\Phi) \quad\text{with}\quad \mathcal{L}_t(\Phi) \;=\; \frac{1}{2}\sum_{t'=1}^t \zeta_{t t'}\|v_{t'}-\Phi k_{t'}\|^2 \;+\; \frac{1}{2}\mathrm{Tr}(\Phi^\top\Lambda\Phi). \]

其中:

  • \(k_{t'}, v_{t'}\) 是由 token 嵌入生成的历史键和值;
  • \(\zeta_{t t'}\) 是一个可选的时间相关权重因子 (如来自遗忘门 \(\gamma\)) ;
  • \(\Lambda\) 是正定的正则化项 (实际中为可学习的对角矩阵) ;
  • \(\hat\Phi_t^{\text{mesa}}\) 是时间 t 的最优快速权重矩阵。

Mesa 对查询 \(q_t\) 的输出即为最优线性读出:

\[ \Delta e_t^{\text{mesa}} \;=\; \hat\Phi_t^{\text{mesa}} q_t. \]

由于 \(\mathcal{L}_t\) 对 \(\Phi\) 是二次函数,因此最优解 \(\hat\Phi_t\) 可以通过两个充分统计量表示为封闭形式:

\[ G_t = \sum_{t'=1}^t \zeta_{t t'} v_{t'} k_{t'}^\top, \qquad H_t = \sum_{t'=1}^t \zeta_{t t'} k_{t'} k_{t'}^\top, \]

所以形式上有:

\[ \hat\Phi_t^{\text{mesa}} = G_t (H_t + \Lambda)^{-1}. \]

直接计算并求逆 \(H_t + \Lambda\) 在每步的代价很高,也存在数值难题。MesaNet 的创新使这一过程变得实用。

数值稳定、分块并行的 Mesa 层

论文避免了显式的密集矩阵求逆,而是采用共轭梯度 (CG) 法求解线性系统。主要前向计算变为:

\[ \Delta e_t^{\text{mesa}} = G_t\,\mathrm{linsolve}(H_t + \Lambda,\, q_t) \;=\; G_t\,x_t^*,\quad\text{where }(H_t+\Lambda)x_t^*=q_t. \]

两点关键观察使高效实现成为可能:

  1. \(G_t\) 和 \(H_t\) 都遵循带门控的简单线性递归:

    \[ G_t = \gamma_t G_{t-1} + \beta_t v_t k_t^\top, \qquad H_t = \gamma_t H_{t-1} + \beta_t k_t k_t^\top. \]


    它们是每步秩一更新,可在不存储完整历史的情况下累积。

  2. CG 的主要操作是与 \(H_t\) 进行矩阵-向量乘法,即计算 \((\sum_i \zeta_{t i} k_i k_i^\top)p\)。该求和与门控线性注意力 (GLA) 具有相同的代数结构: 加权外积之和作用于向量。因此,可以重用加速线性注意力的分块并行矩阵原语,在多个时间步并行执行计算。

实际训练以分块形式运行: 将训练序列划分为大小为 C 的块,预先计算块级累加器,并在每块内进行并行矩阵乘法,用于 GLA 风格的前向传播及 CG 所需的矩阵-向量乘积。由于每次 CG 的矩阵-向量乘法都具备 GLA 结构,因此迭代同样可以分块并行执行。

总结如下:

  • 训练可跨时间分块并行,充分利用 GPU/TPU 矩阵乘单元。
  • 推理时可选择固定的 CG 步数 k (确定性成本) 或基于容忍度 ε (动态成本,取决于 \(H_t+\Lambda\) 的条件数与数据) 。
  • CG 的使用使该层即使在复杂的遗忘动态下也保持数值稳定,但迭代求解需额外计算量。

图: 训练/推理时间与吞吐量比较。
分块并行 Mesa (Mesa-CG) 与基线模型的训练时间、推理时间和 token 吞吐量比较。尽管有额外 CG 迭代,Mesa-CG 凭借分块矩阵核仍保持高竞争力。

图 2: 训练与推理时间及 token 吞吐量。左: TPUv5 上不同序列长度与 CG 步数的每层训练耗时。中: 每层推理耗时。右: 在 H100 GPU 上 1B 模型的 token 吞吐量。Mesa 使用 15–30 CG 步的分块并行化在吞吐量上仍具优势。

设计选择与稳定性改进

若要使 Mesa 层稳定,有几项设计要点至关重要:

  • 对键和查询进行归一化 (RMSNorm + SiLU + L2 归一化) ,以限制数值幅度并稳定条件数;
  • 通过 softplus 参数化 \(\Lambda\),并设置正的下界 (如 0.25) ,防止条件数失控;
  • 限制遗忘门 \(\gamma_t\) (略低于 1) ,并与输入门协同调节,以避免长串重复 token 的“尖叫”序列问题;
  • 用良好的初始猜测初始化 CG (使用对角预处理或热启动) ,并限制最大迭代步 K。论文采用固定 30 步训练获得稳健表现,并在推理时展示基于容忍度 ε 的动态停止。

概念吸引力

将序列层解释为局部学习器,能清晰归纳设计逻辑:

  • GLA 与 Hebb 快速权重对应简单的在线累加策略;
  • DeltaNet 式更新对应每步平方误差上的单步梯度下降 (一阶更新) ;
  • Mesa 对应完整的正则化最小二乘目标的求解——二阶、局部最优学习器。

从这一角度看,Mesa 是平方误差意义下最优的线性关联记忆: 它存储键和值的映射,并即时检索最优线性变换来回答查询,受限于键维度容量。

实验: 合成基准与语言建模

论文在广泛实验中评估 MesaNet: 包括合成算法任务、上下文学习基准、真实数据困惑度,以及下游推理/回忆任务。关键结论如下:

合成任务 (MAD & RegBench)

  • 在 MAD (token 操作与记忆任务) 中,MesaNet 达到最高平均准确率,与 Transformer 相当,并超过其他线性 RNN,表现出强大的上下文关联能力。
  • 在 RegBench (上下文语法推断任务) 上,MesaNet 超出其他线性架构,缩小与 Transformer 的差距,展现良好的上下文泛化。

图: MAD 与 RegBench 示例。
MesaNet 在 MAD 基准测试中取得最高平均性能,在 RegBench 上表现出色,缩小了与 Transformer 的差距。

图 3: MAD 基准测试 — MesaNet 在线性循环模型中取得最高平均准确率,与 Transformer 持平。
图 4: RegBench — MesaNet 优于其他线性架构,并接近 Transformer 的上下文语法推断水平。

大规模语言建模 (SlimPajama,约 10 亿参数)

  • MesaNet (及混合的 Hawk–Mesa 变体) 在约 9.4 亿参数规模上训练,并在 SlimPajama 及下游分集上评估。
  • 在标准验证集平均困惑度 (PPL) 方面,MesaNet 与多种线性 RNN 基线相当或略高,在总 PPL 上与 Transformer 持平。
  • 但更深入分析发现一致模式: MesaNet 和其他 RNN 在早期 token (如前 64 个) 上预测更好,而在后期 token 上落后。Transformer 的全局上下文访问在长序列时带来优势。

Token 位置分析: 模型差异来源

仅看平均 PPL 会掩盖这种特征。作者计算了位置依赖的负对数似然 (NLL_k) 并与 Transformer 比较:

\[ \Delta\mathrm{NLL}_k^{\text{model}} = \mathrm{NLL}_k^{\text{model}} - \mathrm{NLL}_k^{\text{MHA}}. \]
  • 对小 k (早期 token) ,Mesa 和多数 RNN 的 \(\Delta\mathrm{NLL}_k\) 为负,预测更优。
  • 对大 k,\(\Delta\mathrm{NLL}_k\) 转正,Transformer 在后期预测上超过 RNN。
  • Mesa 和 Hawk–Mesa 的早期优势持续更长 (可达 512 token 以上) ,但总体趋势一致: Transformer 在超长预测上占优。

这反映归纳偏置分野: RNN 风格线性快速权重模型擅长局部自适应与快速记忆,而当全局长程回忆成为关键时,Transformer 更具优势。

图: 与 Transformer 的按位置 NLL 差异。
按 token 位置的 NLL 差异: 循环模型更擅长预测早期 token 分布,而 Transformer 在序列后段占优。Mesa 和 Hawk–Mesa 在 RNN 中保持最长期优势。

图 5: 1B 模型相对于 Transformer 的按位置 NLL 差异。负值表示该位置上性能优于 Transformer。大多数 RNN 在早期表现更好;Transformer 在后期超越。MesaNet 延伸这一优势更远。

超长上下文外推与滑动窗口基线

  • 测试外推至超长上下文 (高达 32k token) ,部分 RNN 崩溃,而 MesaNet 保持稳定,优于多种 RNN 基线。
  • 但值得注意,一个窗口大小为 1024 的滑动窗口注意力 Transformer (SWA-1024) 在长序列评估下仍出乎意料地强,提示两点: 任务中局部上下文有时足够;困惑度指标未必能衡量真正的长程能力。

下游任务: 全局 vs 局部需求

作者根据随着注意力窗口增大性能提升的幅度,将基准任务分为“全局型”和“局部型”:

  • 在局部推理任务中 (短窗足够) ,MesaNet 与多数 RNN 表现与 Transformer 相当。
  • 在全局推理与上下文回忆任务中 (长窗有显著益处) ,所有 RNN 包括 MesaNet 都明显落后。MesaNet 虽是最强 RNN,但未能弥合差距。

图: 下游任务汇总。
推理、上下文回忆和少样本任务的分组结果。MesaNet 在循环模型中表现最佳,但在全局/回忆任务上仍逊于 Transformer。

图 6: 下游任务聚合得分 (400M 与 1B 模型) 。MesaNet 领先循环模型,但在需要长程上下文的任务上不及 Transformer。

少样本学习与字符扰动任务

  • MesaNet 在 token 操作类任务 (如单词扰乱) 上少样本表现很强,且在若干少样本任务上具竞争力,偶尔超过 Transformer。
  • 对于翻译等依赖全局对齐与复杂长程映射的任务,Transformer 仍更擅长。

动态测试时计算: 节省计算而不损失精度

Mesa 层的一个突出优势是可动态分配计算量。由于 CG 迭代特性,可:

  • 运行固定步数 k (确定成本) ,或
  • 使用停止容忍度 ε,让每个头/时间步在残差足够小时独立结束。

实验显示:

  • 全部头统一减少 CG 步数会提升 NLL,尤其在后期 token;
  • 使用动态停止 (ε) 可保持与固定 30 步相同精度,同时显著减少平均步数——如 ε = 1e−4 时,平均步从 30 降至约 9。

图: 动态 CG 停止与性能权衡。
推理时 CG 步数影响: 固定 vs 动态停止。使用容忍度 (ε = 1e-4) 可在保持性能的同时显著降低平均步数。

图 7: 测试时计算分配。固定 k 步与动态 ε 停止对比。当 ε = 1e−4 时,MesaNet 在 NLL 相当的情况下显著减少平均迭代次。

解读、局限与实践启示

MesaNet 展示了测试时训练视角的强大潜能: 设计在推理时显式求解优化问题的层极具效果。但实证结果具有复杂性:

优势

  • Mesa 在每步执行原则性的最优上下文回归,合成任务表现出强大上下文能力;
  • 分块并行化让训练吞吐量具竞争力,尽管推理更昂贵;
  • 动态停止在推理时实现性能与计算的柔性权衡。

局限与开放问题

  • 测试时计算: Mesa 层每个 token 需额外计算 (约与 CG 迭代数成正比) ,恒定内存优势仍在,但若要严格收敛需更多算力;
  • 全局长程能力: 该系列所有线性 RNN (含 Mesa) 在需全局访问历史的任务上仍不及 Transformer;
  • 工程复杂度: 分块并行 CG、稳定性措施 (λ 下界、门控限制、归一化) 增添工程负担。

未来方向

  • CG 热启动: 跨邻近时间步复用上次解减少迭代。
  • 混合运行策略: 训练用分块并行 CG,推理时在遗忘较小情形使用 Sherman–Morrison 更新。
  • 学习求解器: 以可学习神经解算器或参数化预处理器替代普通 CG。
  • 架构协同设计: 优化 RNN 式序列混合的主干结构 (如调整 MLP 或位置设计) 以挖掘更高潜力。

最终思考

MesaNet 是清晰的概念与工程突破: 它将数值稳定、分块并行的测试时优化层应用于大规模语言模型的训练与推理。Mesa 层具体实现了局部最优测试时训练的思想: 在每个 token 上寻找最能解释当前上下文的线性模型。

实验揭示更深层规律: RNN 式快速权重与线性注意力模型仍然在局部快速适应与早期预测方面表现突出,而 Transformer 在后期预测和全局回忆上持续占优——这恰是恒定内存 RNN 本应擅长的领域。这种张力不是 Mesa 的缺陷,而是一种诊断。MesaNet 指明了未来研究方向: 热启动或学习型求解器、循环与注意力混合架构、以及能更好保留全局信息的结构优化。

MesaNet 提醒我们: 序列建模的空间依旧广阔——优化问题、迭代推理与硬件感知并行化都是可重塑性能与效率权衡的关键要素。若你正构建超长上下文模型或寻求测试时自适应计算能力,MesaNet 是一个值得深入研究与借鉴的模型家族。

致谢与参考文献在此略去;论文为 “MesaNet: Sequence Modeling by Locally Optimal Test-Time Training” (von Oswald 等人,2024) ,其中包含完整实验细节、证明与开源实现说明。