想象一下,你是一位天体物理学家,任务是为不同行星上的物体运动建模。你可以为每颗行星——地球、火星、木星——分别建立一个模拟器。但这太浪费了。物理定律是普适的;只有一个参数,即引力常数,会因行星而异。一个更聪明的策略是建立一个通用模型,并通过从少量样本中估计每颗行星的引力来使其适应。

这种在相关任务间复用共享结构的思想,正是**摊销学习 **(amortized learning) 的核心。摊销学习不再从零开始学习一切,而是捕捉共性模式并复用它们,从而更高效地解决新问题。这一原则推动了现代人工智能的发展,从学会优化的元学习系统到通过上下文示例解决新任务的大型语言模型 (LLM) 。

虽然这些方法理念相通,但历史上它们似乎彼此独立。2025 年的一篇论文《迭代式摊销推断: 统一上下文学习与学习型优化器》弥合了这一鸿沟,提出了统一的数学框架、清晰的分类体系以及一种可扩展的方案,克服了现有方法的关键限制。

本文将带你深入探讨这项工作——展示从 MAML 到 GPT 风格的上下文学习如何都融入一个优雅的方程中,以及一个受随机梯度下降启发的简单思想如何让摊销模型能够优雅地扩展到大规模数据集。


追求快速适应: 一个统一的视角

在传统的机器学习中,我们为每个任务训练一个独立的模型。给定任务 \( \mathcal{T} \) 的数据集 \( \mathcal{D}_{\mathcal{T}} \),训练过程会找到使损失最小化的模型参数 \( \hat{\boldsymbol{\theta}}_{\mathcal{T}} \):

\[ \hat{\boldsymbol{\theta}}_{\mathcal{T}} = \arg\min_{\boldsymbol{\theta}} \sum_{(\mathbf{x}, \mathbf{y}) \in \mathcal{D}_{\mathcal{T}}} \mathcal{L}\left(\mathbf{y}, f(\mathbf{x}, \boldsymbol{\theta})\right) \]

这种方法适用于单个任务,但无法复用知识。例如,一个被训练来识别猫和狗的分类器,在被要求识别马和斑马时,几乎学不到任何帮助性的东西。它忽略了“视觉分类”这一共享结构。

摊销学习通过在任务分布上进行训练来改进这一点,从而可以用极少的数据快速适应未见任务。作者通过一个统一的方程整合了这些系统:

\[ \min_{\gamma, \boldsymbol{\varphi}} \mathbb{E}_{\mathcal{T}} \mathbb{E}_{\mathbf{x}, \mathbf{y}, \mathcal{D}_{\mathcal{T}}} \left[ \mathcal{L} \left( \mathbf{y}, f_{\gamma}\left( \mathbf{x}, g_{\boldsymbol{\varphi}}(\mathcal{D}_{\mathcal{T}}) \right) \right) \right] \]

其中:

  1. 适应函数 \( g_{\boldsymbol{\varphi}} \): 接收任务的训练数据 \( \mathcal{D}_{\mathcal{T}}^{\text{train}} \),生成任务表示 \( \boldsymbol{\theta}_{\mathcal{T}} \) (可以是权重、潜在向量或原始样本) 。它具有可学习参数 \( \boldsymbol{\varphi} \)。

  2. 预测函数 \( f_{\gamma} \): 接收查询 \( \mathbf{x} \) 和任务表示 \( \boldsymbol{\theta}_{\mathcal{T}} \) 来生成预测。它拥有共享参数 \( \gamma \),这些参数编码了跨任务的归纳偏置。

根据 \( f_{\gamma} \) 与 \( g_{\boldsymbol{\varphi}} \) 的配置方式,这一统一框架可涵盖几乎所有主要学习范式——从普通监督学习到元学习和上下文学习。

摊销学习器的功能分解

表 1. 摊销学习器的功能分解。每种方法都可用适应函数 \( g_\varphi \) (生成 \(\theta_T\)) 与预测函数 \(f_\gamma\) 来描述。

我们来看一些例子:

  • 监督学习: \( g_{\boldsymbol{\varphi}} \) 是标准的 SGD,\( f_{\gamma} \) 是固定架构 (如 ResNet) ,任务之间不存在摊销。
  • MAML: \( g_{\boldsymbol{\varphi}} \) 是从可学习元权重 \( \boldsymbol{\theta}_0 \in \boldsymbol{\varphi} \) 初始化的 SGD。
  • 学习型优化器: \( g_{\boldsymbol{\varphi}} \) 本身是神经网络,根据梯度提出参数更新。
  • 上下文学习 (ICL): 适应函数为恒等映射——上下文示例直接输入到 Transformer \( f_{\gamma} \) 中,所有适应过程都在前向传播中隐式完成。

这一统一视角提供了理解学习系统复用知识的共同语言,并揭示出它们的核心差异在于学习的内容: 是初始化、更新还是映射。


摊销的分类: 参数化、隐式与显式

基于上述统一观点,作者进一步给出了摊销学习的三大类型。

1. 参数化摊销 (Parametric Amortization)

此时 \( f \) 是固定的,而 \( g_{\boldsymbol{\varphi}} \) 是可学习的。系统学习如何为一个预定义模型推断其参数。

  • 示例: 学习型优化器、超网络。
  • 机制: \( g_{\boldsymbol{\varphi}} \) 将数据映射为参数,例如线性模型的权重。
  • 优点: 高效利用梯度,参数具可解释性。
  • 缺点: 因 \( f \) 固定,表达能力受限。

2. 隐式摊销 (Implicit Amortization)

相反,\( f_{\gamma} \) 是可学习的,而 \( g \) 是**固定的 **(通常为恒等函数) 。模型自身内化了适应过程。

  • 示例: 上下文学习、先验拟合网络 (Prior-Fitted Networks)。
  • 机制: 单个网络 \( f_{\gamma} \) 联合处理查询与上下文来预测输出。
  • 优点: 表达能力强,可学习复杂行为。
  • 缺点: 计算昂贵且不透明;每次查询都需重新处理整个数据集。

3. 显式摊销 (Explicit Amortization)

此时 \( f_{\gamma} \) 与 \( g_{\boldsymbol{\varphi}} \) 均可学习,是一种混合方法。

  • 示例: 条件神经过程 (CNPs)。
  • 机制: \( g_{\boldsymbol{\varphi}} \) 将任务数据集压缩成潜在嵌入,\( f_{\gamma} \) 使用该嵌入进行预测。
  • 优点: 平衡了灵活性与可解释性。
  • 缺点: 训练更困难,因为两部分会动态地相互影响。

可扩展性问题

尽管形式不同,大多数摊销学习器共享一个弱点:** 可扩展性不足**。

  • 隐式模型受限于 Transformer 的上下文长度。
  • 参数化和显式模型依赖于池化摘要或梯度,丢失了细粒度的任务信息。

当数据集规模变得庞大时,这些模型就会不堪重负。标准优化通过在随机梯度下降 (SGD) 中使用*小批量 *(mini-batches) 来解决规模问题——逐步优化参数。我们能否将同样的思想应用于摊销学习?


迭代式摊销推断: 在小批量中学习

作者提出了**迭代式摊销推断 **(Iterative Amortized Inference,IAI) ——一种可扩展的方法,其中摊销过程本身以小批量的方式迭代进行,就像 SGD 逐步优化参数一样。

对于参数化和显式模型

模型从初始状态 \( \boldsymbol{\theta}^{(0)} \) 开始,并通过学习到的更新函数 \( h_{\boldsymbol{\varphi}} \) 逐步优化:

\[ \boldsymbol{\theta}^{(0)} \xrightarrow{h_{\boldsymbol{\varphi}}(\cdot, \mathcal{B}_{\text{train}}^{(0)})} \boldsymbol{\theta}^{(1)} \xrightarrow{h_{\boldsymbol{\varphi}}(\cdot, \mathcal{B}_{\text{train}}^{(1)})} \dots \xrightarrow{h_{\boldsymbol{\varphi}}(\cdot, \mathcal{B}_{\text{train}}^{(k-1)})} \boldsymbol{\theta}^{(k)} \eqqcolon \boldsymbol{\theta}_{\mathcal{T}} \]

与仅基于梯度操作的学习型优化器不同,IAI 允许基于原始数据、梯度或两者进行更新,使其更为灵活。

对于隐式模型

模型直接优化预测。从 \( \hat{\mathbf{y}}^{(0)} \) 出发,循环 Transformer \( r_{\gamma} \) 连续更新:

\[ \hat{\mathbf{y}}^{(0)} \xrightarrow{r_{\gamma}([\mathbf{x}, \hat{\mathbf{y}}^{(0)}], \mathcal{B}_{\text{train}}^{(0)})} \hat{\mathbf{y}}^{(1)} \xrightarrow{r_{\gamma}([\mathbf{x}, \hat{\mathbf{y}}^{(1)}], \mathcal{B}_{\text{train}}^{(1)})} \cdots \xrightarrow{r_{\gamma}([\mathbf{x}, \hat{\mathbf{y}}^{(k-1)}], \mathcal{B}_{\text{train}}^{(k-1)})} \hat{\mathbf{y}}^{(k)} \]

迭代式摊销推断示意图

图 1. 参数化、显式及隐式设置下的迭代式摊销推断。参数化/显式模型迭代优化共享任务状态,而隐式模型则直接逐步更新预测。

为了高效训练这些模型,作者仅优化单步改进,不对历史迭代反向传播。该贪心策略简单、稳定且易扩展。


实验: 迭代优化的回报

大量回归、分类和生成实验表明,多次迭代优化步骤会持续提升性能。

参数化摊销

参数化摊销结果

表 2. 参数化摊销结果。随着迭代步数增加,各任务误差持续下降。灰色的分布外 (OoD) 列显示模型在新数据集上的更佳迁移性能。

显式摊销

显式摊销结果

表 3. 显式摊销在增加优化步骤后稳定收益,尤其在利用梯度信号时表现更佳。

隐式摊销

隐式摊销结果

表 4. 迭代预测优化显著降低了多种分类任务的误差。

生成建模

在更复杂的生成设定中,迭代过程提升了样本质量。模型能够更好地重建字母表和高斯混合的潜在分布,经过 10 步优化后结构更加清晰。

生成样本对比

图 2. 隐式生成模型在高斯混合与字母表任务中,1 步与 10 步结果对比,显示向真实分布逐步逼近的优化效果。


关键分析洞见

  1. 梯度 vs. 数据: 仅依赖梯度信号虽高效但有限,将梯度与原始观测结合可提供更丰富的信息,从而提升泛化性能。
  2. 近期历史足够: 提供多个历史状态几乎无额外收益;仅依赖最新状态 (马尔可夫式设计) 即能取得良好效果。

过去状态的影响

图 3. 使用 3 或 5 个过去状态的模型并未优于仅使用单状态更新。

  1. 参数化模型更稳定: 固定 \( f_{\gamma} \) 的模型通常胜过显式模型,凸显了同时训练两网络的优化挑战。
  2. 效率: 迭代式摊销的复杂度随批量数 \(K\) 线性增长。处理 \(K\) 个大小为 \(B\) 的批次,其成本为 \(O(KB^2)\) 而非 \(O((KB)^2)\),使 IAI 在数据与内存利用上更高效。

结论: 迈向可扩展的自适应学习系统

迭代式摊销推断框架优雅地将元学习、上下文学习、以及学习型优化器统一在一个数学公式下:

\[ f_{\gamma}(\mathbf{x}, g_{\boldsymbol{\varphi}}(\mathcal{D}_{\mathcal{T}})) \]

通过将摊销扩展为迭代过程——借鉴了随机优化的成功经验——作者令模型在保持快速适应的同时,能扩展至大规模数据集。这一思路桥接了基于优化和基于前向传播的两种范式,揭示它们是实现同一目标的互补途径: 高效地跨任务复用归纳偏置。

展望未来,这种迭代视角为更丰富的摊销学习器铺平了道路——这些模型能持续自我完善,正如人类推理般,一次一批地学会学习