利用 Wasserstein 重心和幂均值驾驭 MCTS 中的不确定性

蒙特卡洛树搜索 (MCTS) 是现代人工智能中一些最令人印象深刻的壮举背后的引擎,其中最著名的是 AlphaGo 和 AlphaZero 在围棋和国际象棋等游戏中取得的超人类表现。这些算法的工作原理是构建一个可能性的搜索树,模拟未来的结果,并回溯这些价值以便在根节点做出最佳决策。

但这其中有一个陷阱。传统的 MCTS 擅长 确定性 环境——即把棋子移动到 E4 格子就一定会在 E4 格子上。然而,现实世界是混乱的。它是随机的 (动作会有随机结果) 且部分可观测的 (我们无法看到一切) 。在这些“战争迷雾”场景中,标准的 MCTS 表现挣扎。它通常依赖简单的平均值来估计状态的价值,有效地掩盖了定义随机环境的风险和高方差结果。

在这篇文章中,我们将深入探讨一种称为 Wasserstein MCTS (W-MCTS) 的新方法。这种方法重新构想了搜索树中的节点,不再视其为单一的数值,而是 概率分布 。 通过使用来自最优传输理论 (Optimal Transport theory) 的复杂工具——特别是 \(L^1\)-Wasserstein 重心——它不仅将估计值,还将该值的 不确定性 一路传播到树的根部。

问题所在: 当平均值撒谎时

在标准的 MCTS (如流行的 UCT 算法) 中,当搜索访问一个状态时,它会运行一次模拟,获得一个奖励,并更新该状态的“平均值”。如果一个状态有一半的时间带来 +100 的奖励,另一半时间带来 -100 的奖励,那么平均值就是 0。而一个总是返回 0 的状态,其平均值也是 0。

对于标准的 MCTS 来说,这两个状态看起来完全一样。但对于处在风险环境中的智能体来说,它们有着天壤之别。一个是避风港;另一个则是拿命赌博的抛硬币。

这种无法区分“安全”与“风险”或无法处理高方差的能力,导致标准 MCTS 在随机环境中失效。它会导致不稳定的估计以及糟糕的探索-利用 (exploration-exploitation) 平衡。

解决方案: Wasserstein MCTS

研究人员提出了一种根本性的转变: 将树中的每个节点建模为高斯分布。

我们不再仅仅存储一个均值 \(\mu\),而是存储一个均值和一个标准差 \((\mu, \sigma)\)。这使得算法能够追踪它对特定状态价值的“确信”程度。

魔力在于这些分布是如何结合的。当父节点查看其子节点时,它不仅仅是对它们的数值求平均。它计算的是子节点分布的 Wasserstein 重心 (Wasserstein Barycenter)

1. 数学基础: 最优传输

要理解 W-MCTS 如何结合分布,我们需要谈谈 Wasserstein 距离。通常被称为“推土机距离 (Earth Mover’s Distance) ”,它衡量的是将一个概率分布转换为另一个概率分布的成本。

分布 \(\mu\) 和 \(\nu\) 之间的 \(L^q\)-Wasserstein 距离的一般定义为:

Wasserstein 距离公式

这里,\(\Gamma(\mu, \nu)\) 代表将质量从分布 \(\mu\) 传输到 \(\nu\) 的所有可能方式。“重心”仅仅是使得与一组其他分布的总 Wasserstein 距离最小化的那个分布。它是概率分布空间中的几何“质心”。

Wasserstein 重心公式

虽然 \(L^2\)-Wasserstein 距离在机器学习中很常见,但这篇论文选择了 \(L^1\)-Wasserstein 距离 并结合了 \(\alpha\)-散度 (\(\alpha\)-divergence)

2. 为什么选择 \(L^1\) 和 \(\alpha\)-散度?

作者选择 \(L^1\) 度量是因为它对异常值和大的偏差更具鲁棒性,这在随机强化学习中很常见。此外,他们将其与一种称为 \(\alpha\)-散度 的特定成本函数配对。

\(\alpha\)-散度是一族距离度量,允许我们调整比较两个值的方式。其定义为:

alpha-散度公式

通过调整参数 \(\alpha\),我们可以改变距离度量的行为。这种组合导致了论文中使用的特定 \(L^1\)-Wasserstein 公式:

L1 Wasserstein 定义

这看起来可能很繁重,但它引出了我们接下来要讨论的一个优美且实用的结果: 幂均值 (Power Mean)

核心贡献: 幂均值回溯

这就论文最重要的贡献在于证明了: 如果你假设节点是高斯分布,并且使用 \(\alpha\)-散度计算 \(L^1\)-Wasserstein 重心,那么这个复杂的优化问题会坍缩成一个简洁的闭式解。

父节点更新后的均值和标准差变成了其子节点的 幂均值

让我们定义一个参数 \(p = 1 - \alpha\)。价值节点 (value node) 的均值 (\(\overline{m}\)) 和标准差 (\(\overline{\delta}\)) 的更新规则为:

幂均值更新规则

解读幂均值 (\(p\))

这个结果之所以强大,是因为它将不同的回溯策略统一到了单一参数 \(p\) 中:

  • 如果 \(p = 1\): 更新变为标准的算术 平均 (Average) 。 这是风险中性的。
  • 如果 \(p \to \infty\): 更新行为类似于 最大值 (Max) 算子。这是高度乐观的 (类似于标准的贝尔曼最优方程) 。
  • 如果 \(p\) 适中: 它在平均值和最大值之间进行插值。

这使得 W-MCTS 可以作为一个“广义均值”算法运行。在“最大值”很危险 (高估幸运的转换) 的高度随机环境中,你可以将 \(p\) 设置得接近 1。在希望找到最优路径的确定性环境中,你可以将 \(p\) 设置得更高。

推广到粒子滤波

论文指出,这种数学方法不仅适用于高斯分布。如果你使用粒子 (样本) 来建模不确定性,更新规则在结构上保持一致:

粒子更新规则

W-MCTS 算法实战

既然我们有了结合分布的数学方法,那么实际的 MCTS 循环是如何工作的呢?

第一步: 分布回溯

当搜索到达叶子节点并回溯价值时,它使用上面推导出的幂均值公式更新均值 (\(V_m\)) 和标准差 (\(V_{std}\)) 估计。

均值和标准差定义

回溯算子使用访问计数 \(n(s,a)\) 对不同路径的贡献进行加权。这确保了根节点的方差准确地反映了从树深处传播上来的不确定性。

回溯算子方程

第二步: 动作选择 (探索)

标准的 MCTS 使用 UCT 公式 (置信区间上界) 来选择动作。W-MCTS 引入了两种基于统计学的方法,利用传播上来的分布来选择动作。

策略 A: 乐观选择 (W-MCTS-OS)

这与 UCT 类似,但用传播的标准差 \(\sigma\) 替换了标准的探索奖励项。这意味着探索是由子树的 实际 不确定性驱动的,而不仅仅是访问计数。

乐观选择方程

策略 B: 汤普森采样 (W-MCTS-TS)

汤普森采样 (Thompson Sampling) 是一种强大的贝叶斯方法。我们不再计算边界,而是从每个动作的分布 \(\mathcal{N}(m, \sigma^2)\) 中 采样 一个值,并选择样本值最高的动作。

汤普森采样方程

这种方法自然地平衡了探索和利用。如果一个节点方差 (不确定性) 很高,它可能会产生一个很高的样本值,促使智能体去探索它。随着方差减小,样本会更紧密地聚集在均值周围。

理论保证

这项工作的一个主要优势是其严谨的理论支持。许多对 MCTS 的修改缺乏收敛性证明。作者证明了 W-MCTS,特别是汤普森采样变体,能够收敛到最优策略。

根节点估计价值函数的误差以多项式速率下降:

收敛界限

这个 \(\mathcal{O}(n^{-1/2})\) 的速率与此类问题的已知下界相匹配,证实了传播不确定性并不是以牺牲渐近最优性为代价的。

实验结果

理论听起来很棒,但实际效果如何?作者在那些会让标准 MCTS 失效的挑战性领域测试了 W-MCTS: 高度随机的 MDP (如 FrozenLakeRiverSwim) 以及部分可观测 MDP (Pocman, Rocksample) 。

完全可观测的随机环境

在像 FrozenLake (冰面很滑,移动是随机的) 和 RiverSwim 这样的环境中,尽管存在持续的噪声,智能体仍必须规划长序列的动作。

下图比较了 W-MCTS (红线) 与标准 UCT (蓝色虚线) 以及其他基线算法如 DNG (贝叶斯 MCTS) 的表现。

随机 MDP 上的性能图表

分析:

  • W-MCTS-TS (红色) 始终比 UCT 学习得更快并获得更高的回报。
  • 鲁棒性:SixArmsTaxi 中,标准 UCT 表现非常挣扎,而 W-MCTS 保持了高性能。这证实了传播方差有助于智能体避开那些平均看起来不错但实际很危险的“陷阱”。

部分可观测环境 (POMDPs)

POMDPs 出了名的难,因为智能体不知道真实状态。它只有一个信念 (belief) 。作者在 Rocksample (采集样本的漫游者) 和 Pocman (视野受限的吃豆人) 上进行了测试。

Rocksample 上的性能图表

Rocksample (图 2) 中,W-MCTS-TS (红色) 在基线中占据主导地位,尤其是当网格尺寸和难度增加时 (从左到右) 。W-MCTS 和标准 UCT (蓝色) 之间的差距巨大,凸显了标准平均值根本无法处理部分可观测性的不确定性。

表 1 重点展示了 Pocman 中的结果。

Pocman 结果表

W-MCTS-TS 获得了 77.70 的分数,显著高于 UCT 的 28.5 。 它甚至优于 D2NG,这是一种专门针对 POMDP 的贝叶斯算法。

结论

Wasserstein MCTS 代表了规划算法向前迈出的重要一步。通过接受世界是不确定的,并利用高斯分布和最优传输显式地对这种不确定性进行建模,它弥合了“平均”情况与“最坏/最好”情况之间的鸿沟。

主要收获如下:

  1. 不要忽视方差: 在树中传播标准差可以实现更明智的探索。
  2. 幂均值很灵活: \(p\) 参数 (源自 \(\alpha\)-散度) 允许算法将其行为从风险中性调整为乐观。
  3. 最优传输是实用的工具: 使用 \(L^1\)-Wasserstein 重心提供了一种数学上合理的方法来平均分布,从而产生了简单、可实现的更新规则。

对于强化学习的学生和研究人员来说,这篇论文展示了引入其他数学领域 (如最优传输) 的概念如何解决经典 AI 算法中的根本局限性。