LaCT: 为什么对于测试时训练和长上下文AI,越大越好

处理和理解长序列信息的能力——无论是长篇文档、高分辨率图像集,还是长视频——是人工智能发展的关键前沿之一。Transformer彻底改变了神经网络处理序列数据的方式,但其核心的自注意力机制的计算复杂度随序列长度呈二次方增长,使其在处理长上下文时效率低下。这推动了一系列寻找更快且内存更高效架构的研究。

一个有前途的方向是测试时训练 (Test-Time Training, TTT),其灵感来源于循环神经网络 (RNN)。TTT模型包含一个小型自适应子网络,其参数——称为快速权重——会在推理过程中动态更新。这些快速权重充当上下文记忆,帮助模型保留序列中之前标记 (token) 的信息。不幸的是,大多数现有方法更新这些快速权重的频率过高——每几个标记就更新一次——导致GPU利用率很低,严重限制了其扩展性。硬件峰值利用率常降至5%以下,使得长上下文建模极其缓慢。

在论文《正确实现测试时训练》(Test-Time Training Done Right)中,来自麻省理工学院与Adobe研究院的研究者提出了颠覆性的理念。他们的方法称为大块测试时训练 (Large Chunk Test-Time Training, LaCT),通过使用范围从数千到数百万标记的大数据块来更新内存,而非微小批次。这个看似简单却极具变革性的改动极大提高了效率,使得更大的记忆状态成为可能,并能扩展到多模态任务,如百万标记图像合成和140亿参数视频生成。接下来,让我们探究LaCT的工作原理,以及为何这种“要么做大,要么回家”的方法革新了长上下文AI。


快速回顾: 什么是测试时训练?

传统神经网络在训练阶段学习参数 (“慢速权重”) ,在推理阶段保持这些权重不变。TTT引入了第二个带有快速适应能力参数的神经网络,称为快速权重,其在推理过程中不断变化。这个快速权重学习器相当于一个临时的记忆缓冲区,用于存储当前序列中元素之间的关系。

TTT的核心是反复执行两个操作:

  1. 更新操作:
    快速权重网络 \(f_W\) 更新其权重,使输入向量 \(k\) 与相应的向量 \(v\) 对齐:

    \[ W \leftarrow W - \eta \nabla_W \mathcal{L}(f_W(k), v) \]

    其中 \(\mathcal{L}\) 是自监督损失 (通常为均方误差或点积损失) ,\(\eta\) 是学习率。这一步将信息“写入”内存。

  2. 应用操作:
    更新后,\(f_W\) 处理一个查询向量 \(q\),以产生输出:

    \[ o = f_W(q) \]

    这对应于从内存中读取信息。

与RNN类似,该过程使模型能够在长序列上积累上下文知识。然而,每隔几个标记就更新会导致严重的效率问题,因为GPU更擅长处理大型并行工作负载,而小批量TTT无法提供这种负载。


传统TTT的瓶颈

传统TTT的硬件利用率低下源于其极小的更新批次。一个简单的快速权重更新的计算与内存访问比可表示为:

\[ r = \frac{2h^2b}{2h^2 + 4hb} = \frac{b}{1 + \frac{2b}{h}} \leq \min\left(\frac{h}{2}, b\right) \]

其中 \(h\) 是隐藏维度,\(b\) 是数据块大小。当 \(b\) 很小的时候,GPU大部分时间都花在数据传输上,而非计算。

相比之下,LaCT将 \(b\) 大幅提高至数千甚至数百万标记,使操作变为计算密集型而非内存密集型。结果是: 吞吐量提升了几个数量级,并完全可以在原生PyTorch代码中实现——无需任何自定义内核。

LaCT实现了比传统TTT方法高得多的GPU吞吐量。(a) 硬件利用率随块大小急剧增长,(b) 从而能够高效扩展到更大的内存状态,(c) 改善验证损失,以及(d) 在更短的训练时间内获得更高的峰值信噪比 (PSNR) 。

在LaCT中,更大的数据块显著提高了GPU利用率,支持更大的快速权重内存状态,并在多个基准上同时提升速度与准确性。


LaCT架构: 融合局部与全局上下文的混合设计

LaCT架构将两个互补组件——局部注意力与长程记忆——融合为统一的框架。每个LaCT块包含三个主要部分:

LaCT块示意图,展示并行的窗口注意力、大块TTT层和前馈层,快速权重W在不同块之间流动。

一个LaCT块。窗口注意力捕捉局部细节;大块TTT以高效方式处理长程上下文;前馈层则像Transformer一样混合通道信息。

  1. 窗口注意力 (Window Attention): 使用标准的注意力机制,但仅在有限窗口内捕捉局部依赖,例如相邻的单词或像素。
  2. 大块TTT层 (Large-Chunk TTT Layer): LaCT的核心。它将输入分割成大数据块,使用整个块的所有标记更新快速权重,然后应用更新后的权重生成输出。
  3. 前馈网络 (Feed-Forward Network): 一个标准的Transformer通道混合层。

这种混合设计结合了注意力的局部精度与大块TTT的全局记忆,实现了高准确率并具备线性时间的可扩展性。


深入理解大块TTT层

与逐标记更新不同,大块TTT层在整个数据块上计算一个单一梯度,然后执行一次权重更新:

\[ g = \nabla_W \sum_{i=1}^{b} \eta_i \, \mathcal{L}(f_W(k_i), v_i) \]

\[ W \leftarrow \text{weight-update}(W, g) \]

在此过程中,块内的所有标记共享相同的更新后快速权重,用于处理其查询。快速权重网络采用 SwiGLU-MLP 结构:

\[ f_W(x) = W_2[\mathrm{SiLU}(W_1x) \circ (W_3x)] \]

并使用负点积损失将键和值关联:

\[ \mathcal{L}(f_W(k_i), v_i) = - f_W(k_i)^{\top} v_i \]

灵活的更新–应用顺序

LaCT的强大之处在于其灵活的更新应用操作顺序,可根据任务需求模拟不同的注意力模式。

不同更新-应用调度方式的图示,展示生成的不同有效注意力掩码 (全注意力、块级因果、移位因果和跨步因果) 。

不同的更新–应用顺序可生成适用于各类数据模态的有效注意力掩码。

  • 全掩码 (Full Mask): 更新后应用——在块内支持双向注意力。
  • 块级因果 (Block-Wise Causal): 更新与应用交替进行——在块之间保持因果关系。
  • 移位块级因果 (Shifted Block-Wise Causal): 先应用再更新——非常适合语言模型,避免未来标记泄漏。
  • 跨步块级因果 (Strided Block-Wise Causal): 仅在特定上下文块上更新——理想用于新视角合成等任务。

更智能的记忆更新: Muon优化器

由于数据块很大,更新频次较低但幅度很大——这使更复杂的非线性优化器成为可能,而不会带来性能开销。其中,Muon尤为突出。它在更新前将梯度正交化,并归一化其谱范数,以提升稳定性:

\[ \text{weight-update}(W, g) = \mathrm{L2\text{-}Normalize}(W - \mathrm{Muon}(g)) \]

\[ \mathrm{Muon}(g) \approx U V^{T} \quad \text{for } g = U \Sigma V^{T} \]

这种归一化有助于稳定学习过程,降低对步长的敏感性,并增强对长程信息的保留——所有这些得益于LaCT的数据块级效率。


进一步扩展: 上下文并行

LaCT的设计天然支持上下文并行 (Context Parallelism),即将长序列分割到多个GPU上。每个设备处理一个数据块中的部分标记,并通过分布式all-reduce聚合梯度:

\[ g = \sum_{j=1}^{\text{shards}} \nabla_W \sum_{i=1}^{s} \eta_i \, \mathcal{L}_i \]

这种机制可无缝扩展至百万标记级的上下文和超大模型,包括数十亿参数的视频Transformer——无需自定义内核或特定硬件优化。


LaCT在不同模态中的应用

作者在三个不同领域测试了LaCT,以展示其高度灵活性。

LaCT在不同任务 (新视角合成、语言建模和视频扩散) 中的设置摘要,展示块大小、状态大小和并行策略。

通过调整块大小、状态规模和并行模式,LaCT能高效适应各种截然不同的数据类型。

1. 新视角合成 (图像集)

任务: 从多张带位姿信息的输入图像中渲染3D场景的新视角。
LaCT设计: 将所有输入图像视为一个巨大的数据块——最多包含一百万个标记。使用所有输入更新快速权重一次,然后应用这些权重生成新视角。
结果: 渲染质量与全注意力模型相当,但速度提高十倍以上,并在密集场景数据集上优于3D高斯溅射 (3D Gaussian Splatting)。

新视角合成性能比较——LaCT实现了与全注意力相当的质量,同时延迟更低,并能扩展至百万标记,超越高斯溅射。

LaCT在质量上与全注意力模型相当,同时显著减少了填充时间,并成功扩展至百万级标记的上下文。


2. 语言建模 (文本序列)

任务: 对长上下文的序列进行自回归下一标记预测。
LaCT设计: 将文本分割为较大的固定块 (2K–4K标记) 。采用“移位块级因果”调度以避免未来标记泄漏,并结合滑动窗口注意力获取局部信息。
结果: 在7.6亿和30亿参数模型中,LaCT (尤其是结合Muon优化) 在长序序列上取得更低损失,并在“大海捞针”任务中比高效基线模型DeltaNet和Gated Linear Attention达到更高的检索准确率。

语言建模结果: LaCT在较长序列位置保持更低验证损失,并在检索准确率上优于基线模型。

随着序列长度增加,LaCT模型保持更低的损失和更高的检索准确率,超越了当前高效的亚二次方基线模型。


3. 自回归视频扩散 (图像序列)

任务: 通过逐帧块的自回归去噪生成连贯的长视频。
LaCT设计: 交替排列干净与带噪声帧块:

\[ S = [X_1^{\mathrm{noise}}, X_1, X_2^{\mathrm{noise}}, X_2, \dots, X_N^{\mathrm{noise}}] \]

仅在干净帧上更新快速权重,然后应用这些权重去噪后续带噪帧块——确保因果一致性。
结果: 微调的140亿参数视频扩散模型在验证损失上与全注意力相当,但超越基于Mamba与滑动窗口的高效模型,并能高效生成包含高达5.6万视觉标记的稳定视频。

视频生成结果: LaCT在不同窗口和长视频长度上均取得与全注意力相当的损失,同时优于高效基线。

LaCT在保持高效率的同时实现了与全注意力模型相当的生成质量,可扩展至超过5万标记的长视频片段。


LaCT为何有效: 来自消融实验的洞见

分析结果揭示了LaCT成功的关键设计要素:

  • 状态越大越好: 扩展快速权重内存 (状态尺寸) 能显著提升各类任务表现。最大配置的快速权重占模型总参数的40%。
  • Muon优化器表现出色: Muon持续优于标准梯度下降与动量优化器,带来更快收敛与更好稳定性。
  • 非线性快速权重更优: SwiGLU-MLP结构的快速权重明显优于线性映射,即使线性模型拥有更多参数。
  • 大块循环优于逐标记循环: 在图像任务中,大块循环优于像Mamba-2这样的逐标记线性循环模型;结合Muon与非线性状态后,在语言建模中也超越逐标记基线。

(a) 性能随着快速权重状态增大而提升;(b) Muon优化器持续优于梯度下降方法。

扩展快速权重状态并使用Muon更新均显著提升了准确率。

快速权重设计与循环机制比较: 非线性SwiGLU快速权重优于线性权重;结合大型非线性状态与Muon时,大块循环性能超越逐标记循环。

非线性记忆网络与大数据块循环在各模态中均领先于传统逐标记循环。


结论: 重新思考长上下文AI的效率

LaCT重新定义了长上下文建模的效率。通过从频繁、微小的权重更新转向更少但更大的数据块更新,它显著提升了并行硬件的利用率。这种效率使更大、更具表现力的记忆状态成为可能——让非线性快速权重和先进优化器如Muon充分发挥作用。

LaCT的混合架构——结合局部注意力与大块记忆——已成功应用于图像、文本与视频领域,展现了强大的通用性。其简洁、无内核的PyTorch实现让研究者能够轻松实验,进一步推动这一新兴领域的普及。

《正确实现测试时训练》不仅是性能上的提升,更是我们理解记忆与效率方式的一次范式转变。LaCT证明,在探索更智能的长上下文AI之路上,做得更大才是最好的捷径。