现代机器学习的蓬勃发展得益于*摊销 *(amortization) 思想——一次性训练大型模型,使其能够即时应用于许多新问题。像 GPT-4 或 Stable Diffusion 这样的预训练模型正是这一原则的体现: 通过从海量数据中学习通用结构,它们能够快速适应各种任务。基于 Transformer 的架构,如神经过程 (Neural Processes) ,将这一概念扩展到概率元学习领域,实现了跨不同领域的带不确定性预测。
然而,这些方法面临一个主要限制: 僵化。大多数模型被限定在“给定 X,预测 Y”这类任务形式中。现实世界的问题远不止如此——有时我们可能只掌握部分数据,并对隐藏参数有信念,希望同时预测未观测数据及这些潜在量。传统方法很少允许在推理过程中动态地融入这类知识 (即所谓的先验) 。例如,在贝叶斯优化中,我们寻找最小值的位置和数值;在科学建模中,我们推断模拟器的参数。每种情况通常都需要一个定制的、计算代价高昂的解决方案。
来自赫尔辛基大学和阿尔托大学的研究人员在他们的论文中提出了摊销式条件化引擎 (Amortized Conditioning Engine, ACE) ,正面解决这一问题。ACE 是一种基于 Transformer 的架构,它将概率的条件化与预测统一成一个灵活的操作。它能够以任意的观测数据与可解释隐变量的组合为条件——甚至可以在运行时接受概率性输入,以预测其他任意的数据或隐变量组合。本质上,ACE 是一个通用的概率推理引擎。
图 1: 各类任务——从图像补全到贝叶斯优化及基于模拟的推理——都可理解为对已知量 (数据或隐变量) 进行条件化来预测未知量。
背景: 从神经过程到 Transformer
要理解 ACE 的创新之处,我们必须先了解它的渊源。神经过程 (Neural Processes, NPs) 学习函数分布。给定若干观测输入输出对——称为上下文集
\(\mathcal{D}_N = \{(\mathbf{x}_1, y_1), \dots, (\mathbf{x}_N, y_N)\}\)—以及一组目标输入 \(\mathbf{x}_{1:M}^*\),NPs 能预测未知目标输出 \(y_{1:M}^*\) 的分布。
NPs 的核心性质是置换不变性: 模型的预测不应依赖上下文点的排列顺序。早期变体如条件神经过程 (Conditional Neural Processes, CNPs) 通过简单平均将上下文压缩成一个嵌入向量,从而实现这一点。
Transformer 改变了局面。其注意力机制天然符合置换结构,并能捕捉数据点之间复杂的依赖关系。这推动了Transformer 神经过程 (Transformer Neural Processes, TNPs) 和先验拟合网络 (Prior-Fitted Networks, PFNs) 等发展。这类模型利用自注意力编码上下文,并通过交叉注意力进行预测查询。
这些基于 Transformer 的神经过程模型是*对角式 *(diagonal) 的,即对每个目标点独立预测:
\[ \pi(y_{1:M}^{\star}|\mathbf{x}_{1:M}^{\star};\mathcal{D}_N) = \prod_{m=1}^{M} p\big(y_m^{\star}|\mathbf{r}(\mathbf{x}_m^{\star},\mathbf{r}_{\mathcal{D}}(\mathcal{D}_N))\big) \]尽管是独立预测,联合分布仍可通过自回归重建,使得这一系列模型——Transformer 预测映射 (Transformer Prediction Maps, TPM-D) ——既灵活又强大。ACE 在此体系之上引入了新的能力。
核心概念: 摊销式条件化引擎内部机制
ACE 将数据点的概念推广,既覆盖观测值,也涵盖描述任务的隐变量。
假设一个问题涉及 \(L\) 个隐变量,
\(\boldsymbol{\theta} = (\theta_1, \ldots, \theta_L)\)。
在 ACE 中,每个对 \((\boldsymbol{\xi}, z)\) 可表示以下两类:
- 数据点: \((\mathbf{x}, y)\),其中 \(\mathbf{x}\) 是输入,\(y\) 是其对应值。
- 隐变量: \((\ell_l, \theta_l)\),其中 \(\ell_l\) 是标识隐变量 \(l\) 的标记 (token) 。
这种统一表示使 ACE 能在任何已知组合 (数据或隐变量) 条件下预测任意目标组合。形式上:
\[ \pi(z_{1:M}^{\star}|\boldsymbol{\xi}_{1:M}^{\star};\boldsymbol{\mathfrak{D}}_N) = \prod_{m=1}^{M} p\big(z_m^{\star}|\mathbf{r}(\boldsymbol{\xi}_m^{\star},\mathbf{r}_{\mathcal{D}}(\boldsymbol{\mathfrak{D}}_N))\big) \]该形式将概率推理转化为一个任务间共享的统一结构化计算。
架构: 细节决定成败
ACE 在 TPM-D 之上进行了关键升级,以支持隐信息与用户提供的先验。
图 2: ACE 架构: 嵌入层同时处理数据与隐变量,而输出头可针对连续或离散预测灵活调整。
1. 通用嵌入
所有输入——数据点、隐变量或先验——均被编码到同一维度为 \(D_{\text{emb}}\) 的嵌入空间:
- 数据点: \((\mathbf{x}_n, y_n)\) → \(f_{\mathbf{x}}(\mathbf{x}_n) + f_{\text{val}}(y_n) + \mathbf{e}_{\text{data}}\)
- 隐变量: \(\theta_l\) → \(f_{\text{val}}(\theta_l) + \mathbf{e}_l\)
- 未知目标: 用学习到的 \(\mathbf{e}_{?}\) 替换值嵌入
- 先验信息: 关于隐变量值的概率向量 \(\mathbf{p}_l\) 经 \(f_{\text{prob}}(\mathbf{p}_l) + \mathbf{e}_l\) 处理
这使得任何元素——观测、隐变量或先验——都能无缝嵌入模型空间。
2. 注意力层
ACE 通过堆叠的 Transformer 层处理嵌入,包括自注意力 (编码) 与交叉注意力 (解码) 。上下文自注意力生成联合表示,而目标交叉注意力则连接这些表示与预测查询。设计具备高效计算性能,其复杂度为 \(O(N^2 + NM)\),优于朴素方案的 \(O((N+M)^2)\)。
3. 输出头
ACE 为每个目标元素输出预测分布:
- 连续输出: 使用高斯混合模型 (Gaussian Mixture Model, GMM) 以学习多峰分布。
- 离散输出: 通过 softmax 概率生成类别分布。
学会用先验推理
ACE 最引人注目的特性之一,是支持运行时概率先验。用户可将对隐变量值的信念 (如“最优值在 0.5 附近”) 以概率分布 \(p(\theta_l)\) 的形式输入。这些分布被离散化为概率直方图,并与其他信息同样嵌入。
图 3: ACE 执行先验摊销——将用户提供的先验与数据结合,在一次前向传播中近似真实贝叶斯后验。
在训练阶段,ACE 会遇到随机生成的先验,并学习它们如何与数据证据融合。它优化预测目标的负对数似然:
\[ \mathcal{L}(\mathbf{w}) = \mathbb{E}_{\mathbf{p}\sim\mathcal{P}}\Big[\mathbb{E}_{\mathfrak{D}_N,\boldsymbol{\xi}_{1:M},\mathbf{z}_{1:M}}\Big[-\sum_{m=1}^M \log q(z_m^{\star}|\mathbf{r}_{\mathbf{w}}(\boldsymbol{\xi}_m^{\star},\mathfrak{D}_N))\Big]\Big] \]最小化该损失使模型输出与其学习到的生成问题族的贝叶斯后验保持一致,从而能够在未知任务上无需重新训练即可进行推理。
ACE 的实践: 跨领域实验
1. 视觉——图像补全与分类
在计算机视觉中,预测可视为一种回归任务: 给定部分像素坐标与值,预测缺失像素。隐变量对应类别标签 (MNIST) 或属性 (CelebA) 。
ACE 能:
- 补全图像: 在部分上下文下预测缺失像素。
- 条件生成: 根据特征 (如“秃头 = 是”) 生成图像。
- 分类: 从部分数据中推断隐属性。
图 4: 在图像补全任务中,ACE 的表现优于 TNP-D。以隐属性为条件进一步提升了质量与似然得分。
ACE 不仅在重建质量上超越基线,还能在条件生成和分类间自由切换,取决于何者被设定为“上下文”或“目标”。
2. 优化——带上下文先验的贝叶斯搜索
贝叶斯优化 (Bayesian Optimization, BO) 旨在通过较少评估定位未知函数的全局最小值。传统方法依赖高斯过程与采集函数采样。
ACE 在其概率条件化框架中重构 BO: 最优位置 \(x_{\text{opt}}\) 与最优值 \(y_{\text{opt}}\) 被显式视为隐变量。模型学习它们的闭式预测分布,从而无需复杂的近似。
图 5: ACE 可直接预测函数行为与最优位置,实现高效优化。
ACE 能以优雅方式实现采集函数:
- 汤普森采样 (Thompson Sampling, TS) : 采样一个低于当前最佳值的乐观 \(y_{\text{opt}}\),再以此条件提出 \(\mathbf{x}_{\text{opt}}\)。
- 最大值熵搜索 (Max-Value Entropy Search, MES) : 直接计算关于 \(y_{\text{opt}}\) 的信息增益,使用解析分布替代昂贵近似。
图 6: 在黑盒优化任务中,ACE 的性能可与高斯过程基线持平或更佳。
引入先验进一步增强 ACE。用户若提供最优位置的信念,ACEP 变体可在推理时利用这些先验指导探索。
图 7: 注入先验后,ACEP 加快收敛,表现与 πBO-TS 相当。
3. 科学模型——基于模拟的推理
基于模拟的推理 (Simulation-Based Inference, SBI) 致力于识别能生成观测数据的模型参数——这是依赖模拟器但无可解似然的科学领域的关键。
ACE 在单一框架内融合前向与反向推理:
- 预测 \(p(\boldsymbol{\theta}|y)\): 参数后验分布。
- 预测 \(p(y|\boldsymbol{\theta})\): 给定参数的数据分布。
- 填补缺失数据: 条件数据预测。
模型 | 指标 | NPE | NRE | Simformer | ACE | ACEP弱 | ACEP强 |
---|---|---|---|---|---|---|---|
OUP | 对数概率 ↑ / RMSE ↓ / MMD ↓ | 1.09 (0.10) / 0.48 (0.01) / - | 1.07 (0.13) / 0.49 (0.00) / - | 1.03 (0.04) / 0.50 (0.02) / 0.43 (0.02) | 1.03 (0.02) / 0.48 (0.00) / 0.51 (0.00) | 1.05 / 0.43 / 0.37 | 1.44 / 0.27 / 0.35 |
SIR | 对数概率 ↑ / RMSE ↓ / MMD ↓ | 6.53 / 0.02 / - | 6.24 / 0.03 / - | 6.89 / 0.02 / 0.02 | 6.78 / 0.02 / 0.02 | 6.62 / 0.02 / 0.02 | 6.69 / 0.02 / 0.00 |
Turin | 对数概率 ↑ / RMSE ↓ / MMD ↓ | 1.99 / 0.26 / - | 2.33 / 0.28 / - | 3.16 / 0.25 / 0.35 | 3.14 / 0.24 / 0.35 | 3.58 / 0.21 / 0.35 | 4.87 / 0.13 / 0.34 |
表 1: ACE 的表现与专业 SBI 模型相当或更佳。当提供先验 (ACEP) 时,参数推断进一步提升。
ACE 优势不仅在准确性,也在效率: 生成 1000 个样本耗时仅毫秒级,而扩散模型 Simformer 需数分钟。结合信息充分的先验,ACEP 可实现更高精度与更好校准。
反思与未来方向
摊销式条件化引擎为曾被视为迥异的任务——视觉、优化与科学推理——提供了统一范式。通过将数据点与隐变量视为等价实体,它实现了二者间的流畅条件化与预测,全由一个摊销式 Transformer 模型完成。
要点总结:
- 通用性: ACE 可无缝处理回归、分类、优化与基于模拟的推理任务。
- 灵活性: 用户可定义任意条件与预测组合,无需更改架构。
- 人机协作: 专家可提供概率先验,引导模型交互式推理。
局限与展望:
与所有基于注意力的系统一样,ACE 的计算复杂度随上下文规模呈二次方增长。未来研究可能借助亚二次注意力或稀疏条件化。扩展多隐变量先验注入与发现可解释隐变量亦是值得探索的方向。
综上,ACE 展示了统一摊销推理的强大潜力,指向一个未来: 一个通用模型只需对正确的信息进行条件化,就能灵活服务于优化、视觉与科学发现。