马尔可夫链蒙特卡洛 (MCMC) 是现代贝叶斯统计的中流砥柱。无论我们要在这个硬件格局剧变的时代建模房价、生物系统还是股票市场,我们都依赖 MCMC 来从复杂的后验分布中进行采样。

近年来,硬件领域发生了翻天覆地的变化。我们已经从在 CPU 上运行单链,转变为利用 JAX、PyTorch 和 TensorFlow 等库在 GPU 上运行数千条并行链。实现这一点的工具是自动向量化 (automatic vectorization,例如 JAX 的 vmap) 。它允许我们编写针对单链的函数,然后神奇地将其转换为能同时处理批量数据的形式。

但这其中有一个陷阱。

虽然 vmap 很强大,但它引入了同步壁垒 (synchronization barrier) 。 如果你并行运行 1000 条链,其中一条需要 1000 步才能完成,而其他的只需要 10 步,GPU 就必须等待那条缓慢的链。巨大的并行潜力被浪费在等待上。

在本文中,我们将探讨一篇新的研究论文**“Efficiently Vectorized MCMC on Modern Accelerators”** , 该论文提出了一个巧妙的解决方案: 将 MCMC 算法重构为有限状态机 (Finite State Machines, FSMs) 。 这种方法使链去同步化,从而实现了高达一个数量级的加速。

问题: 同步陷阱

要理解这个解决方案,我们首先需要理解自动向量化如何处理控制流,特别是 while 循环。

许多先进的 MCMC 算法——如 No-U-Turn Sampler (NUTS)、切片采样 (Slice Sampling) 和延迟拒绝 (Delayed Rejection) ——都是迭代式的。它们依赖 while 循环,且迭代次数取决于链的当前状态。这是随机的,并且每个样本都不尽相同。

当你对包含 while 循环的函数使用 vmap 时,编译器生成的代码会让批次中的所有链步调一致地执行循环体。它会持续运行,直到批次中每一条链都满足终止条件。如果链 \(A\) 在 5 步内完成,而链 \(B\) 需要 100 步,硬件会为两者都执行 100 步。对于链 \(A\),最后 95 步只是被“掩码 (masked out) ”了——计算照常进行,但结果被丢弃。

低效的数学原理

让我们将其形式化。假设我们有 \(m\) 条链。设 \(N_{i,j}\) 为第 \(j\) 条链生成第 \(i\) 个样本所需的循环迭代次数。

在标准的向量化方法 (“步调一致”法) 中,生成样本所需的时间由批次中最慢的那条链决定。生成 \(n\) 个样本的总成本 \(C_0(n)\) 如下所示:

标准向量化的成本与最大值的总和成正比。

这里,我们是将每个样本所有链中的最大迭代次数相加。

现在,想象一个理想世界,链之间不需要互相等待。它们可以完全独立运行。总时间将由整个运行过程中耗时最长的那条链决定,而不是每一步都受限。成本 \(C_*(n)\) 如下所示:

理想成本与总和的最大值成正比。

这种差异虽然微妙,但影响巨大。在第一个方程中,我们在每一次迭代都要付出“最大值惩罚”。在第二个方程中,我们只需要在最后付出一次。

瓶颈可视化

为了观察实际情况,让我们看看椭圆切片采样 (Elliptical Slice Sampling) 。作者在真实数据集上对该算法进行了性能分析。

直方图显示切片收缩次数。左: 每个样本的分布。右: 每条链的平均收敛情况。

图 1 的左侧,请观察每个样本所需的“切片收缩” (循环迭代) 分布。平均值约为 6。然而,虚线垂直线显示了一批链的平均最大值。如果你运行 1024 条链,批次并不会在 6 步内完成;它要等待那个耗时约 19 步的离群值。你做了 3 倍于必要的工作。

这种低效随着链的数量增加而扩大。添加的并行链越多,遇到“长尾”事件 (即生成一个样本需要很长时间) 的概率就越接近 100%。

解决方案: 将 MCMC 视为有限状态机

研究人员提议从根本上重构编写 MCMC 代码的方式。不再编写带有嵌套 while 循环的函数,而是将算法分解为有限状态机 (FSM)

这里的 FSM 是什么?

FSM 由一组状态和转换组成。在 MCMC 的语境下:

  • 状态是代码块 (没有循环的指令序列) 。
  • 转换根据当前变量决定接下来运行哪个代码块。

这种转换将控制流扁平化。算法不再是深层嵌套的结构,而是变成了一个由小步骤组成的扁平图。

将包含 while 循环的代码块转换为有限状态机图。

图 2 展示了一个包含单个 while 循环的简单算法的转换过程。

  • 代码块 1 变为状态 \(S_1\)。
  • 代码块 2 (循环体) 变为状态 \(S_2\)。
  • 代码块 3 (循环后) 变为状态 \(S_3\)。
  • 决定是循环还是退出的逻辑变为转换函数 \(\delta\)。

通过递归应用此逻辑,即使是像 HMC-NUTS (具有嵌套循环) 这样复杂的算法也可以转换为 FSM。

延迟拒绝、切片采样、椭圆切片和 HMC-NUTS 的 FSM 图。

图 3 展示了几种流行 MCMC 算法的最终 FSM 图。请注意 NUTS (右下角) ,通常是一个复杂的递归算法,现在被表示为 5 个离散状态之间的结构化流转。

去同步化的运行时

一旦算法变成了 FSM,我们就改变运行它的方式。我们定义一个名为 step 的单一函数。该函数:

  1. 接收当前状态索引 \(k\) 和变量 \(z\)。
  2. 执行对应于状态 \(k\) 的代码块。
  3. 使用转换逻辑计算下一个状态。

关键在于,当我们使用 vmap 向量化这个 step 函数时,每条链同时运行 step。但是, 链 1 可能处于“收缩 (Shrink) ”状态,而链 2 可能处于“提议 (Propose) ”状态。

它们在执行同一个程序 (step) ,但实际上处于逻辑算法的不同部分。我们将这个 step 函数包裹在一个单一的外部循环中,直到所有链都收集到了所需数量的样本。

这有效地将同步壁垒从“每次循环迭代”推迟到了“采样的最后”。

理论加速比

这种方法快多少?论文提供了一个严格的界限。

FSM 方法的相对效率取决于工作负载的分布。我们可以定义一个“理论效率界限” \(R(m)\),它代表 \(m\) 条链可能的最大加速比。

理论效率界限公式。

该比率比较了一批链的预期最大工作量与单条链的预期平均工作量。

如果所需的步骤数 (\(N\)) 服从“长尾”分布——意味着大多数样本很快,但偶尔会有一个非常慢——\(R(m)\) 就会变得非常大。

图表显示 R(m) 随链数量增加而增加,特别是对于偏斜分布。

图 4 展示了这种行为。对于高偏斜度的分布 (右侧) ,随着增加更多链 (增加 \(m\)) ,潜在的加速比急剧增加。这很直观: 链越多,其中一条链陷入长循环并拖慢标准实现的可能性就越大。FSM 实现则忽略那条慢链,让其他链继续运行。

优化 FSM

仅仅将代码转换为 FSM 是不够的。在朴素实现中,step 函数必须在所有可能的状态之间切换。在向量化环境 (SIMD) 中,这通常意味着执行 switch 语句的所有分支,并掩码掉不适用的分支。这会增加开销。

作者引入了两个关键优化来使 FSM 变得实用:

1. 步骤捆绑 (Step Bundling)

我们不是每次调用 step 只进行一次转换,而是可以“捆绑”连续的步骤。如果我们知道状态 A 通常紧接着状态 B,我们可以编写一个 bundled_step 尝试一次性执行两者。这减少了我们需要调用主循环开销的次数。

2. 成本摊销 (Cost Amortization)

有些操作,如计算对数概率密度 (\(\log p(x)\)) ,计算成本很高。如果 vmap 执行 switch 语句的所有分支,我们可能会冒着多次计算 \(\log p(x)\) 或在不需要时计算它的风险。作者设计了一个摊销步骤 (amortized step) , 用于标记何时需要昂贵的计算。然后,运行时会为整个批次执行一次昂贵的函数,仅针对那些请求它的链。

实验结果

研究人员在 JAX 中实现了该框架,并将其与标准实现 (如 BlackJAX 中的实现) 进行了测试。

椭圆切片采样

在高斯过程回归任务中,无论链的数量如何,FSM 实现都显示出平坦的单样本成本,而标准实现的成本则呈对数增长。

椭圆切片采样的结果显示运行时间 (walltime) 和每秒有效样本数 (ESS/second) 的改进。

如图 5 所示:

  • 左图: 随着链数量增加,标准实现 (蓝色) 每个样本需要更多的迭代 (由于同步) 。FSM (红色) 保持不变。
  • 中图/右图: FSM 实现了显著更高的每秒有效样本数 (ESS/S) ,接近理论极限。

延迟拒绝

对于延迟拒绝算法 (如果第一个提议被拒绝,则尝试多个提议) ,FSM 方法产生了接近一个数量级的加速。

延迟拒绝的运行时间和 ESS。

图 6 显示,“压缩版 FSM (Condensed FSM) ” (使用步骤捆绑) 紧贴理想性能曲线,而标准实现随着链数量 (\(m\)) 的增加而性能下降。

HMC-NUTS

No-U-Turn Sampler 是相关高维分布的行业标准。由于积分步数变化很大,它也是出了名地难以有效向量化。

HMC-NUTS 性能。直方图显示积分步骤的长尾分布。

图 7 (中) 展示了 NUTS 的步数分布。大多数样本需要的步数很少,但存在长尾 (极少数样本需要 >800 步) 。

  • 结果: 与标准版本 (蓝色柱) 相比,FSM 实现 (红色柱,右侧) 实现了巨大的吞吐量增益,尤其是在 100 条以上的链时。

具有挑战性的几何形状

最后,作者在“漏斗”分布和其他梯度棘手的具有挑战性的几何形状上测试了该方法。

基准分布上的加速比表格。

表 1 证实,在各种困难的数据集 (Predator Prey, Google Stock) 中,NUTS 和 TESS (传输椭圆切片采样) 的 FSM 版本始终优于最先进的基线,提供的加速比范围从 1.5 倍到 3.5 倍不等。

结论

向 GPU 和自动向量化的转变许下了贝叶斯统计大规模并行化的承诺。然而,像 vmap 这样的工具僵化的“步调一致”执行方式与迭代 MCMC 算法的可变运行时间发生了冲突。

通过将这些算法重构为有限状态机 , 这篇论文的作者弥合了这一鸿沟。他们的方法允许链在逻辑上异步操作,同时在硬件执行上保持同步。

对于学生和从业者来说,结论很清楚: 在进行大规模并行化时, 控制流是昂贵的 。 将算法扁平化为状态机是一种强大的设计模式,可以释放现代加速器的全部潜力。


本文中使用的图片源自论文 “Efficiently Vectorized MCMC on Modern Accelerators” (2025)。