引言: 对齐复杂数据分布的挑战

想象一下,你有两组图像: 一组是模糊的照片,另一组是清晰的高分辨率照片。你将如何教一个模型将任何一张模糊照片转换成逼真的清晰版本?或者考虑将夏日风景图转换为冬日雪景图。这些都是现代机器学习中根本性挑战的例子: 找到一种有意义的方式,将一个复杂的概率分布映射到另一个。

最优传输 (Optimal Transport, OT) 为这项任务提供了严谨的数学框架。OT 旨在寻找从一个分布到另一个分布的最高效映射,以最小化给定的运输成本。尽管功能强大,但标准 OT 给出的映射是单一且确定的。对于像图像超分辨率这样的不适定问题,同一个低分辨率图像可能对应许多合理的高分辨率输出。我们需要的是随机映射——一对多的变换,既能生成多样化的结果,又保持真实感。

熵正则化最优传输 (Entropic Optimal Transport, EOT) 通过在 OT 问题中加入随机性 (熵) 来解决这一问题,其程度由正则化参数 \( \varepsilon \) 控制。较大的 \( \varepsilon \) 会增加多样性;较小的 \( \varepsilon \) 则趋近于确定性 OT。但问题在于,小 \( \varepsilon \) 恰恰是在高质量、可控生成中最有用的区间,而大多数现有的 EOT 算法在这一情况下会变得不稳定或不可行。

NeurIPS 2023 论文 “Entropic Neural Optimal Transport via Diffusion Processes” 提出了一种名为 ENOT 的方法,它是稳健的端到端 EOT 神经求解器,即使在小 \( \varepsilon \) 情况下也能保持稳定。作者通过将 EOT 与统计物理中的薛定谔桥 (Schrödinger Bridge) 问题联系起来,重构了 EOT,并设计出一种优雅的鞍点优化方案。该方法可扩展到大规模任务,并在合成数据和大规模图像任务上取得了当前最佳表现。


背景知识: OT、EOT 与薛定谔桥

最优传输 —— 高效移动概率质量

OT 将从一个分布 \( \mathbb{P}_0 \) 到另一个分布 \( \mathbb{P}_1 \) 移动“物质” (概率质量) 并最小化成本的问题形式化。使用二次代价的 Kantorovich 公式如下:

带二次代价的 Kantorovich 最优传输公式。

公式 (1): OT 目标函数是在所有具有正确边缘分布的传输方案中,最小化将质量从 \( x \) 移动到 \( y \) 的平均平方距离。

这里,\( \Pi(\mathbb{P}_0, \mathbb{P}_1) \) 是所有可能的传输方案 \( \pi(x,y) \) 的集合,这些方案的边缘分布为 \( \mathbb{P}_0 \) 和 \( \mathbb{P}_1 \)。最优方案 \( \pi^* \) 告诉我们如何将质量从每个 \( x \) 传输到每个 \( y \)。

熵正则化 OT —— 鼓励随机性

EOT 通过添加熵项来修改 OT,使传输方案更“分散”:

熵正则化最优传输的两种常见公式。

公式 (2)–(3): 用熵或 KL 散度对 OT 进行正则化的两种等价形式,由 \( \varepsilon > 0 \) 控制。

其中,\( -\varepsilon H(\pi) \) 用于鼓励随机性;较大的 \( \varepsilon \) → 映射多样性更高。小的 \( \varepsilon \) 会产生接近确定性 OT 的映射,但数值上更具挑战性。

薛定谔桥 —— 寻找最可能的路径

薛定谔桥 (Schrödinger’s Bridge, SB) 考虑的是在某个随机过程中,从 \( t=0 \) 时的 \( \mathbb{P}_0 \) 到 \( t=1 \) 时的 \( \mathbb{P}_1 \) 的整个路径。常用的“先验”过程是布朗运动 (维纳过程) :

方差为 ε 的维纳过程的随机微分方程。

公式 (4): 方差为 \( \varepsilon \) 的先验扩散过程。

SB 要问的是: 在所有与期望的起始和结束分布相匹配的随机过程中,哪一个与这个简单先验过程*最接近 *(以 KL 散度衡量) ?

薛定谔桥问题,最小化与维纳过程的 KL 散度。

公式 (5): SB 最小化候选过程与维纳先验之间的 KL 散度。

关键联系 —— SB 即 EOT

关键洞见是: SB 和 EOT 是等价的。两个过程之间的 KL 散度可以分解为:

两个随机过程之间 KL 散度的分解。

公式 (6): KL 散度可分解为起始/结束分布的散度,加上路径内部的条件散度。

对于最优 SB 过程 \( T^* \),条件项消失,仅剩起始-结束联合分布与先验联合分布之间的 KL 散度:

最小化过程间的 KL 散度等价于最小化它们起始-结束联合分布间的 KL 散度。

公式 (8): 优化过程等价于优化起始-结束联合分布——这就是 EOT 问题。

因此,求解 SB 就能得到 EOT 方案 \( \pi^* \)。在动态 SB 形式中,最优过程是一个带漂移 \( f(X_t, t) \) 的扩散过程,它最小化期望漂移能量:

动态薛定谔桥问题,最小化漂移函数的期望能量。

公式 (11): SB 的能量最小化形式。


ENOT 方法 —— 鞍点形式重构

通过拉格朗日松弛移除硬约束

直接约束 \( \pi_1^{T_f} = \mathbb{P}_1 \) 是困难的。ENOT 引入了一个类拉格朗日泛函:

松弛后的动态薛定谔桥问题的类拉格朗日泛函。

公式 (12): 目标函数中,势函数 \( \beta(y) \) 作匹配最终边缘分布的拉格朗日乘子。

具体而言:

  • 漂移网络 \( f_\theta \): 用于最小化 \( \mathcal{L} \),减少漂移能量并生成在势函数评分下得分高的最终样本。
  • 势函数网络 \( \beta_\phi \): 用于最大化 \( \mathcal{L} \),推动生成样本更接近目标分布。

该对抗性设置形成了一个鞍点问题:

鞍点优化问题。

公式 (13): 对势函数最大化、对漂移最小化可求解松弛的 SB,从而解出 EOT。

边缘约束在均衡时隐式满足,避免了高成本的强制约束。所有项均可通过采样估计,从而支持随机梯度训练。

实用算法 —— 熵正则化神经最优传输

ENOT 算法交替进行更新:

  1. 判别器更新 (\( \beta_\phi \)):

    • 从 \( \mathbb{P}_0 \) 中采样 \( X_0 \),模拟过程得到 \( X_1 \)。
    • 从 \( \mathbb{P}_1 \) 中采样 \( Y \)。
    • 更新 \( \beta_\phi \) 以最大化 \( \frac{1}{|Y|}\sum\beta(Y) - \frac{1}{|X_1|}\sum\beta(X_1) \)。
  2. 生成器更新 (\( f_\theta \)):

    • 采样新的 \( X_0 \)。
    • 模拟 \( X_t \) 和漂移 \( f_\theta(X_t, t) \)。
    • 最小化能量项 (漂移均方值) 与对抗项 \(-\frac{1}{|X_1|}\sum\beta(X_1)\)。

重复上述过程直至收敛。学到的 \( f_\theta \) 同时定义了 SB 和 EOT 解。


实验

2D 玩具示例 —— 从高斯分布到 8 个高斯分布

将一个方形分布映射到一个由 8 个高斯分布组成的环,针对不同的 ε 值。上排显示学习到的最终分布,下排显示样本轨迹。

图 2: ENOT 在 \( \varepsilon=0,0.01,0.1 \) 时的映射。小的 \( \varepsilon \) → 笔直的确定性路径;大的 \( \varepsilon \) → 更多随机性。

高维高斯分布 —— 定量评估

对于高斯分布 \( \mathbb{P}_0, \mathbb{P}_1 \),EOT/SB 存在闭式解。ENOT 在匹配目标分布和恢复真实传输方案方面优于基线方法:

维度ENOTLSOTSCONESMLE-SBDiffSBFB-SDE-AFB-SDE-J
20.011.821.740.410.700.870.03
160.096.421.870.501.110.940.05
640.2332.186.271.161.981.850.19
1280.5064.326.882.132.201.950.39

表 1: 匹配边缘分布的误差 (越低越好) 。

维度ENOTLSOTSCONESMLE-SBDiffSBFB-SDE-AFB-SDE-J
20.0126.770.920.300.880.750.07
160.0514.561.360.901.701.360.22
640.1325.564.621.342.322.450.34
1280.2947.115.331.802.432.640.58

表 2: 恢复真实传输方案的误差 (越低越好) 。

性能差距随维度升高而增大。

彩色 MNIST —— 多样性控制

彩色 MNIST 的定性比较。ENOT (a-c) 生成高质量、多样化的样本,并显著优于 DiffSB (d-e) 和 SCONES (i-j) 的 FID 分数。

图 3: 即使在较小 \( \varepsilon \) 情况下,ENOT 也能从“2”生成清晰、多样的“3”。

ENOT 在 \( \varepsilon = 1.0 \) 时 FID = 6.28,优于 SCONES (14.73) 和 DiffSB (93),而后者在小 \( \varepsilon \) 时性能急剧下降。

CelebA 人脸 —— 非配对超分辨率

CelebA 超分辨率定性结果。(a) 降质测试输入,(b) 真实标签。SCONES (c) 在大 ε 下生成随机人脸。ENOT (f-h) 在不同 ε 下生成逼真细节并保持身份特征。

图 4: ENOT 在添加逼真细节的同时保持身份特征;而 SCONES 在 \( \varepsilon=100 \) 时丢失了输入结构。

ENOT 的 FID 得分: 3.78 (\( \varepsilon=0 \)),7.63 (\( \varepsilon=1 \));SCONES 为 14.8。多样性随 \( \varepsilon \) 增加,与保真度相互平衡。

图 1 展示了 ENOT 学到的渐进式去模糊过程:

为 CelebA 超分辨率学习到的扩散过程。每行展示从模糊输入 (左) 到清晰输出 (右) 随时间 t 的变化。更高的 ε (下方) 为过程引入更多噪声和多样性。

图 1: 在 \( \varepsilon=0,1,10 \) 时 CelebA 去模糊任务样本的轨迹。


结论与启示

ENOT 框架的贡献:

  • 新的视角: 将 EOT 视作动态薛定谔桥问题,实现高效、基于样本的鞍点优化。
  • 小 \( \varepsilon \) 时的稳定性: 对于可控多样性且高保真的映射至关重要。
  • 顶尖表现: 高维与复杂图像到图像转换任务中取得最佳效果。

除超分辨率与数字转换外,ENOT 的可调随机映射还可应用于风格迁移、领域自适应,以及其他需要灵活且忠实分布对齐的任务。该方法将物理启发的理论与神经优化相结合,推进了熵正则化最优传输在方法论与应用上的发展。