引言: 对齐复杂数据分布的挑战
想象一下,你有两组图像: 一组是模糊的照片,另一组是清晰的高分辨率照片。你将如何教一个模型将任何一张模糊照片转换成逼真的清晰版本?或者考虑将夏日风景图转换为冬日雪景图。这些都是现代机器学习中根本性挑战的例子: 找到一种有意义的方式,将一个复杂的概率分布映射到另一个。
最优传输 (Optimal Transport, OT) 为这项任务提供了严谨的数学框架。OT 旨在寻找从一个分布到另一个分布的最高效映射,以最小化给定的运输成本。尽管功能强大,但标准 OT 给出的映射是单一且确定的。对于像图像超分辨率这样的不适定问题,同一个低分辨率图像可能对应许多合理的高分辨率输出。我们需要的是随机映射——一对多的变换,既能生成多样化的结果,又保持真实感。
熵正则化最优传输 (Entropic Optimal Transport, EOT) 通过在 OT 问题中加入随机性 (熵) 来解决这一问题,其程度由正则化参数 \( \varepsilon \) 控制。较大的 \( \varepsilon \) 会增加多样性;较小的 \( \varepsilon \) 则趋近于确定性 OT。但问题在于,小 \( \varepsilon \) 恰恰是在高质量、可控生成中最有用的区间,而大多数现有的 EOT 算法在这一情况下会变得不稳定或不可行。
NeurIPS 2023 论文 “Entropic Neural Optimal Transport via Diffusion Processes” 提出了一种名为 ENOT 的方法,它是稳健的端到端 EOT 神经求解器,即使在小 \( \varepsilon \) 情况下也能保持稳定。作者通过将 EOT 与统计物理中的薛定谔桥 (Schrödinger Bridge) 问题联系起来,重构了 EOT,并设计出一种优雅的鞍点优化方案。该方法可扩展到大规模任务,并在合成数据和大规模图像任务上取得了当前最佳表现。
背景知识: OT、EOT 与薛定谔桥
最优传输 —— 高效移动概率质量
OT 将从一个分布 \( \mathbb{P}_0 \) 到另一个分布 \( \mathbb{P}_1 \) 移动“物质” (概率质量) 并最小化成本的问题形式化。使用二次代价的 Kantorovich 公式如下:
公式 (1): OT 目标函数是在所有具有正确边缘分布的传输方案中,最小化将质量从 \( x \) 移动到 \( y \) 的平均平方距离。
这里,\( \Pi(\mathbb{P}_0, \mathbb{P}_1) \) 是所有可能的传输方案 \( \pi(x,y) \) 的集合,这些方案的边缘分布为 \( \mathbb{P}_0 \) 和 \( \mathbb{P}_1 \)。最优方案 \( \pi^* \) 告诉我们如何将质量从每个 \( x \) 传输到每个 \( y \)。
熵正则化 OT —— 鼓励随机性
EOT 通过添加熵项来修改 OT,使传输方案更“分散”:
公式 (2)–(3): 用熵或 KL 散度对 OT 进行正则化的两种等价形式,由 \( \varepsilon > 0 \) 控制。
其中,\( -\varepsilon H(\pi) \) 用于鼓励随机性;较大的 \( \varepsilon \) → 映射多样性更高。小的 \( \varepsilon \) 会产生接近确定性 OT 的映射,但数值上更具挑战性。
薛定谔桥 —— 寻找最可能的路径
薛定谔桥 (Schrödinger’s Bridge, SB) 考虑的是在某个随机过程中,从 \( t=0 \) 时的 \( \mathbb{P}_0 \) 到 \( t=1 \) 时的 \( \mathbb{P}_1 \) 的整个路径。常用的“先验”过程是布朗运动 (维纳过程) :
公式 (4): 方差为 \( \varepsilon \) 的先验扩散过程。
SB 要问的是: 在所有与期望的起始和结束分布相匹配的随机过程中,哪一个与这个简单先验过程*最接近 *(以 KL 散度衡量) ?
公式 (5): SB 最小化候选过程与维纳先验之间的 KL 散度。
关键联系 —— SB 即 EOT
关键洞见是: SB 和 EOT 是等价的。两个过程之间的 KL 散度可以分解为:
公式 (6): KL 散度可分解为起始/结束分布的散度,加上路径内部的条件散度。
对于最优 SB 过程 \( T^* \),条件项消失,仅剩起始-结束联合分布与先验联合分布之间的 KL 散度:
公式 (8): 优化过程等价于优化起始-结束联合分布——这就是 EOT 问题。
因此,求解 SB 就能得到 EOT 方案 \( \pi^* \)。在动态 SB 形式中,最优过程是一个带漂移 \( f(X_t, t) \) 的扩散过程,它最小化期望漂移能量:
公式 (11): SB 的能量最小化形式。
ENOT 方法 —— 鞍点形式重构
通过拉格朗日松弛移除硬约束
直接约束 \( \pi_1^{T_f} = \mathbb{P}_1 \) 是困难的。ENOT 引入了一个类拉格朗日泛函:
公式 (12): 目标函数中,势函数 \( \beta(y) \) 作匹配最终边缘分布的拉格朗日乘子。
具体而言:
- 漂移网络 \( f_\theta \): 用于最小化 \( \mathcal{L} \),减少漂移能量并生成在势函数评分下得分高的最终样本。
- 势函数网络 \( \beta_\phi \): 用于最大化 \( \mathcal{L} \),推动生成样本更接近目标分布。
该对抗性设置形成了一个鞍点问题:
公式 (13): 对势函数最大化、对漂移最小化可求解松弛的 SB,从而解出 EOT。
边缘约束在均衡时隐式满足,避免了高成本的强制约束。所有项均可通过采样估计,从而支持随机梯度训练。
实用算法 —— 熵正则化神经最优传输
ENOT 算法交替进行更新:
判别器更新 (\( \beta_\phi \)):
- 从 \( \mathbb{P}_0 \) 中采样 \( X_0 \),模拟过程得到 \( X_1 \)。
- 从 \( \mathbb{P}_1 \) 中采样 \( Y \)。
- 更新 \( \beta_\phi \) 以最大化 \( \frac{1}{|Y|}\sum\beta(Y) - \frac{1}{|X_1|}\sum\beta(X_1) \)。
生成器更新 (\( f_\theta \)):
- 采样新的 \( X_0 \)。
- 模拟 \( X_t \) 和漂移 \( f_\theta(X_t, t) \)。
- 最小化能量项 (漂移均方值) 与对抗项 \(-\frac{1}{|X_1|}\sum\beta(X_1)\)。
重复上述过程直至收敛。学到的 \( f_\theta \) 同时定义了 SB 和 EOT 解。
实验
2D 玩具示例 —— 从高斯分布到 8 个高斯分布
图 2: ENOT 在 \( \varepsilon=0,0.01,0.1 \) 时的映射。小的 \( \varepsilon \) → 笔直的确定性路径;大的 \( \varepsilon \) → 更多随机性。
高维高斯分布 —— 定量评估
对于高斯分布 \( \mathbb{P}_0, \mathbb{P}_1 \),EOT/SB 存在闭式解。ENOT 在匹配目标分布和恢复真实传输方案方面优于基线方法:
维度 | ENOT | LSOT | SCONES | MLE-SB | DiffSB | FB-SDE-A | FB-SDE-J |
---|---|---|---|---|---|---|---|
2 | 0.01 | 1.82 | 1.74 | 0.41 | 0.70 | 0.87 | 0.03 |
16 | 0.09 | 6.42 | 1.87 | 0.50 | 1.11 | 0.94 | 0.05 |
64 | 0.23 | 32.18 | 6.27 | 1.16 | 1.98 | 1.85 | 0.19 |
128 | 0.50 | 64.32 | 6.88 | 2.13 | 2.20 | 1.95 | 0.39 |
表 1: 匹配边缘分布的误差 (越低越好) 。
维度 | ENOT | LSOT | SCONES | MLE-SB | DiffSB | FB-SDE-A | FB-SDE-J |
---|---|---|---|---|---|---|---|
2 | 0.012 | 6.77 | 0.92 | 0.30 | 0.88 | 0.75 | 0.07 |
16 | 0.05 | 14.56 | 1.36 | 0.90 | 1.70 | 1.36 | 0.22 |
64 | 0.13 | 25.56 | 4.62 | 1.34 | 2.32 | 2.45 | 0.34 |
128 | 0.29 | 47.11 | 5.33 | 1.80 | 2.43 | 2.64 | 0.58 |
表 2: 恢复真实传输方案的误差 (越低越好) 。
性能差距随维度升高而增大。
彩色 MNIST —— 多样性控制
图 3: 即使在较小 \( \varepsilon \) 情况下,ENOT 也能从“2”生成清晰、多样的“3”。
ENOT 在 \( \varepsilon = 1.0 \) 时 FID = 6.28,优于 SCONES (14.73) 和 DiffSB (93),而后者在小 \( \varepsilon \) 时性能急剧下降。
CelebA 人脸 —— 非配对超分辨率
图 4: ENOT 在添加逼真细节的同时保持身份特征;而 SCONES 在 \( \varepsilon=100 \) 时丢失了输入结构。
ENOT 的 FID 得分: 3.78 (\( \varepsilon=0 \)),7.63 (\( \varepsilon=1 \));SCONES 为 14.8。多样性随 \( \varepsilon \) 增加,与保真度相互平衡。
图 1 展示了 ENOT 学到的渐进式去模糊过程:
图 1: 在 \( \varepsilon=0,1,10 \) 时 CelebA 去模糊任务样本的轨迹。
结论与启示
ENOT 框架的贡献:
- 新的视角: 将 EOT 视作动态薛定谔桥问题,实现高效、基于样本的鞍点优化。
- 小 \( \varepsilon \) 时的稳定性: 对于可控多样性且高保真的映射至关重要。
- 顶尖表现: 高维与复杂图像到图像转换任务中取得最佳效果。
除超分辨率与数字转换外,ENOT 的可调随机映射还可应用于风格迁移、领域自适应,以及其他需要灵活且忠实分布对齐的任务。该方法将物理启发的理论与神经优化相结合,推进了熵正则化最优传输在方法论与应用上的发展。