Markov Chain Monte Carlo (MCMC) is the workhorse of modern Bayesian statistics. Whether we are modeling house prices, biological systems, or stock markets, we rely on MCMC to sample from complex posterior distributions.

In recent years, the hardware landscape has shifted dramatically. We have moved from running single chains on CPUs to running thousands of parallel chains on GPUs using libraries like JAX, PyTorch, and TensorFlow. The tool that makes this possible is automatic vectorization (e.g., JAX’s vmap). It allows us to write a function for a single chain and magically transform it to run on a batch of data simultaneously.

But there is a catch.

While vmap is powerful, it introduces a synchronization barrier. If you run 1,000 chains in parallel, and one chain takes 1,000 steps to finish while the others only need 10, the GPU waits for that single slow chain. The massive parallel potential is wasted on waiting.

In this post, we explore a new research paper, “Efficiently Vectorized MCMC on Modern Accelerators,” which proposes a clever solution: restructuring MCMC algorithms as Finite State Machines (FSMs). This approach de-synchronizes the chains, allowing for speed-ups of up to an order of magnitude.

The Problem: The Synchronization Trap

To understand the solution, we first need to understand how automatic vectorization handles control flow, specifically while loops.

Many advanced MCMC algorithms, like the No-U-Turn Sampler (NUTS), Slice Sampling, and Delayed Rejection, are iterative. They rely on while loops where the number of iterations depends on the current state of the chain. That number is random and varies from sample to sample.

When you use vmap on a function with a while loop, the compiler generates code that executes the loop body for all chains in lock-step. It continues until the termination condition is met for every chain in the batch. If chain \(A\) finishes in 5 steps and chain \(B\) needs 100 steps, the hardware executes 100 steps for both. For chain \(A\), the last 95 steps are simply “masked out”—the computations happen, but the results are discarded.
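To make the masking concrete, here is a minimal sketch that emulates lock-step execution with NumPy rather than JAX (NumPy stands in for vmap here, and the function name `lockstep_while` is hypothetical). Each chain needs a given number of loop iterations, the batch keeps running the body while any chain is unfinished, and finished chains are masked out:

```python
import numpy as np

def lockstep_while(n_iters):
    """Emulate a vmapped while loop: the loop body keeps running for the whole
    batch until every chain is done; finished chains are masked out."""
    n_iters = np.asarray(n_iters)
    done = np.zeros_like(n_iters)                # iterations completed per chain
    executed = 0                                 # batched loop-body executions
    while np.any(done < n_iters):                # batch condition: any chain still running
        active = done < n_iters                  # mask of chains that still need work
        done = np.where(active, done + 1, done)  # masked update; inactive chains unchanged
        executed += 1
    wasted = executed * n_iters.size - int(n_iters.sum())  # masked-out body executions
    return executed, wasted

print(lockstep_while([5, 100]))  # (100, 95)
```

For chains \(A\) and \(B\) from the example above, the batch runs the body 100 times, and 95 of chain \(A\)'s executions are thrown away.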

The Math of Inefficiency

Let’s formalize this. Suppose we have \(m\) chains. Let \(N_{i,j}\) be the number of loop iterations required by the \(j\)-th chain to generate its \(i\)-th sample.

In the standard vectorized approach (the “lock-step” method), the time taken to generate a sample is determined by the slowest chain in the batch. The total cost \(C_0(n)\) to generate \(n\) samples looks like this:

\[ C_0(n) \propto \sum_{i=1}^{n} \max_{1 \le j \le m} N_{i,j} \]

Here, we are summing up the maximum iterations across all chains for each sample.

Now, imagine an ideal world where chains didn’t have to wait for each other. They could run completely independently. The total time would be determined by the chain that takes the longest total time over the entire run, not step-by-step. The cost \(C_*(n)\) would look like this:

\[ C_*(n) \propto \max_{1 \le j \le m} \sum_{i=1}^{n} N_{i,j} \]

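The difference is subtle but massive. In the first equation, we pay the “max penalty” for every single sample. In the second, we pay it only once, at the end of the run. Because a maximum of sums can never exceed the corresponding sum of maximums, \(C_*(n) \le C_0(n)\) always holds.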
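A quick way to feel the gap is to simulate a matrix of iteration counts \(N_{i,j}\) and evaluate both cost expressions directly. This is a sketch under an assumed workload: the geometric distribution with mean 5 is our own choice, not taken from the paper, and one loop iteration is counted as one unit of cost.

```python
import numpy as np

rng = np.random.default_rng(0)
n, m = 1000, 1024                      # samples per chain, number of chains
# Hypothetical workload: N[i, j] = iterations chain j needs for its i-th sample
N = rng.geometric(p=0.2, size=(n, m))

cost_lockstep = N.max(axis=1).sum()    # C_0(n): sum over samples of the max over chains
cost_ideal = N.sum(axis=0).max()       # C_*(n): max over chains of each chain's total

print(cost_lockstep / n, cost_ideal / n, cost_lockstep / cost_ideal)
```

With many chains, the lock-step cost per sample tracks the typical batch maximum, while the ideal cost per sample stays close to the mean of a single chain.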

Visualizing the Bottleneck

To see this in practice, let’s look at Elliptical Slice Sampling. The authors profiled this algorithm on a real-world dataset.

Figure 1: Histograms of slice shrinks. Left: per-sample distribution. Right: average per-chain convergence.

The left panel of Figure 1 shows the distribution of “slice shrinks” (loop iterations) required per sample; the average is around 6. The dotted vertical lines, however, mark the average maximum across a batch of chains. If you run 1,024 chains, the batch doesn’t finish in 6 steps; it waits for the outlier that takes about 19. You are doing roughly 3x more work than necessary.

This inefficiency grows with the number of chains. As you add more parallel chains, the probability that at least one of them hits a “long tail” event (a sample that takes a long time) on any given sample approaches 100%.
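To put a number on this, assume (as a simplification of our own) that each chain's sample is a long-tail event independently with some small probability \(p\). The chance that a batch of \(m\) chains contains at least one such event is

\[ P(\text{at least one long-tail event}) = 1 - (1 - p)^m \longrightarrow 1 \quad \text{as } m \to \infty. \]

For example, with \(p = 0.01\) and \(m = 1{,}024\), \((0.99)^{1024} \approx e^{-10.29} \approx 3.4 \times 10^{-5}\), so almost every batched sample ends up waiting on some chain's long-tail event.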

The Solution: MCMC as Finite State Machines

The researchers propose a fundamental restructuring of how we write MCMC code. Instead of writing a function with nested while loops, we break the algorithm down into a Finite State Machine (FSM).

What is an FSM in this context?

An FSM consists of a set of states and transitions. In the context of MCMC:

  • States are code blocks (sequences of instructions without loops).
  • Transitions determine which code block runs next based on the current variables.

This transformation flattens the control flow. Instead of a deep, nested structure, the algorithm becomes a flat graph of small steps.

Figure 2: Converting a code block with a while loop into a Finite State Machine graph.

Figure 2 illustrates this transformation for a simple algorithm with one while loop.

  • Code Block 1 becomes state \(S_1\).
  • Code Block 2 (the loop body) becomes state \(S_2\).
  • Code Block 3 (after the loop) becomes state \(S_3\).
  • The logic that decides whether to loop or exit becomes the transition function \(\delta\).
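The decomposition can be sketched for a single chain in plain Python. The toy algorithm below (a rejection loop that redraws until \(u < c\)) and the names `S1`, `S2`, `S3`, `delta`, and `fsm_step` are hypothetical illustrations, not the paper's code. Block 1 initializes, block 2 is the loop body, block 3 runs after the loop, and `delta` plays the role of \(\delta\):

```python
import numpy as np

S1, S2, S3, DONE = 0, 1, 2, 3   # state indices; DONE marks a finished sample

def loop_version(rng, c=0.3):
    u, tries = rng.uniform(), 1              # code block 1
    while u >= c:                            # loop condition
        u, tries = rng.uniform(), tries + 1  # code block 2 (loop body)
    return u / c, tries                      # code block 3

def delta(k, u, c):
    """Transition function: choose the next state from the current one."""
    if k in (S1, S2):
        return S2 if u >= c else S3          # loop again, or exit the loop
    return DONE                              # S3 is the last block

def fsm_step(k, z, rng, c=0.3):
    """Execute the loop-free block for state k, then apply delta."""
    u, tries, out = z
    if k == S1:
        u, tries = rng.uniform(), 1
    elif k == S2:
        u, tries = rng.uniform(), tries + 1
    elif k == S3:
        out = u / c
    return delta(k, u, c), (u, tries, out)

def fsm_version(rng, c=0.3):
    k, z = S1, (np.nan, 0, np.nan)
    while k != DONE:                         # one flat loop drives the FSM
        k, z = fsm_step(k, z, rng, c)
    return z[2], z[1]

print(loop_version(np.random.default_rng(1)), fsm_version(np.random.default_rng(1)))
```

Both versions consume random numbers in the same order, so with the same seed they return the same result; only the control-flow structure differs.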

By applying this logic recursively, even complex algorithms like HMC-NUTS (which has nested loops) can be converted into FSMs.

Figure 3: FSM graphs for Delayed Rejection, Slice Sampling, Elliptical Slice, and HMC-NUTS.

Figure 3 shows the resulting FSM diagrams for several popular MCMC algorithms. Notice how NUTS (bottom right), usually a complex recursive algorithm, is represented as a structured flow between 5 discrete states.

The De-Synchronized Runtime

Once the algorithm is an FSM, we change how we run it. We define a single function called step. This function:

  1. Takes the current state index \(k\) and variables \(z\).
  2. Executes the code block corresponding to state \(k\).
  3. Calculates the next state using the transition logic.

Crucially, when we vectorize this step function using vmap, every chain runs step simultaneously. However, Chain 1 might be in the “Shrink” state while Chain 2 is in the “Propose” state.

They are executing the same program (step), but they are effectively in different parts of the logical algorithm. We wrap this step function in a single outer loop that runs until all chains have collected their desired number of samples.

This effectively pushes the synchronization barrier from “every loop iteration” to “the very end of sampling.”
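The runtime can be sketched in NumPy with the same toy rejection loop, here restarting at \(S_1\) after each emitted sample. This is an illustration under our own assumptions: `np.where` stands in for a vmapped switch (every block is computed for every chain, then each chain keeps the result of its own state), and the names are hypothetical.

```python
import numpy as np

S1, S2, S3 = 0, 1, 2  # toy rejection-loop FSM: S1 init, S2 loop body, S3 emit sample

def vectorized_step(k, u, tries, count, rng, c=0.3):
    """One transition for every chain at once. All blocks are computed for all
    chains (as a vmapped switch would do) and np.where keeps, per chain, only
    the result of the block that chain is actually in."""
    fresh = rng.uniform(size=k.shape[0])                 # draws used by S1 and S2
    u = np.where((k == S1) | (k == S2), fresh, u)
    tries = np.where(k == S1, 1, np.where(k == S2, tries + 1, tries))
    count = np.where(k == S3, count + 1, count)          # S3 emits one sample
    # delta: S1/S2 -> S2 (reject, loop again) or S3 (accept); S3 -> S1 (next sample)
    k = np.where(k == S3, S1, np.where(u >= c, S2, S3))
    return k, u, tries, count

def run_fsm(m, n_samples, seed=0):
    rng = np.random.default_rng(seed)
    k = np.full(m, S1)                                   # every chain starts in S1
    u = np.zeros(m)
    tries = np.zeros(m, dtype=int)
    count = np.zeros(m, dtype=int)
    outer_iters = 0
    # The only synchronization point: stop once every chain has enough samples.
    while np.any(count < n_samples):
        k, u, tries, count = vectorized_step(k, u, tries, count, rng)
        outer_iters += 1
    return outer_iters, count

iters, count = run_fsm(m=1024, n_samples=50)
print(iters, count.min())
```

Chains in different states share each outer iteration, so a chain stuck in a long run of \(S_2\) no longer holds up the others' next samples. In this sketch, chains that reach their quota early simply keep producing extra samples, which a real implementation could mask or discard.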

Theoretical Speed-Up

How much faster is this approach? The paper provides a rigorous bound.

The relative efficiency of the FSM approach depends on the distribution of the workload. We can define a “Theoretical Efficiency Bound,” \(R(m)\), which represents the maximum speed-up possible for \(m\) chains.

\[ R(m) = \frac{\mathbb{E}\left[\max_{1 \le j \le m} N_{1,j}\right]}{\mathbb{E}\left[N_{1,1}\right]} \]

This ratio compares the expected maximum work of a batch against the expected average work of a single chain.
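One way to see where this ratio comes from uses a simplifying assumption of our own (the paper's exact conditions may differ): the \(N_{i,j}\) are i.i.d. with finite mean, and one loop iteration counts as one unit of cost in both \(C_0\) and \(C_*\). For fixed \(m\), the law of large numbers gives

\[ \frac{C_0(n)}{n} = \frac{1}{n} \sum_{i=1}^{n} \max_{1 \le j \le m} N_{i,j} \longrightarrow \mathbb{E}\left[\max_{1 \le j \le m} N_{1,j}\right], \qquad \frac{C_*(n)}{n} = \max_{1 \le j \le m} \frac{1}{n} \sum_{i=1}^{n} N_{i,j} \longrightarrow \mathbb{E}\left[N_{1,1}\right], \]

so \(C_0(n) / C_*(n) \to R(m)\) as \(n \to \infty\). For long runs, \(R(m)\) is the speed-up of the ideal de-synchronized cost over lock-step execution.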

If the number of steps required (\(N\)) follows a distribution with a “long tail”—meaning most samples are fast, but occasionally one is very slow—\(R(m)\) becomes very large.

Figure 4: \(R(m)\) increasing with the number of chains, especially for skewed distributions.

Figure 4 demonstrates this behavior. For distributions with high skewness (right-hand panel), the potential speed-up grows dramatically as you add more chains (increasing \(m\)). This is intuitive: the more chains you have, the more likely it is that one of them is stuck in a long loop, dragging down the standard implementation. The FSM implementation does not make the other chains wait for that slow one; they simply move on to their next samples.
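A Monte Carlo estimate of \(R(m)\) reproduces this qualitative trend. The two workloads below are hypothetical choices made only to contrast tail behavior: geometric counts have a light tail, while rounded-up log-normal counts are mostly small with occasional very large values.

```python
import numpy as np

def estimate_R(sampler, m, trials=2000, seed=0):
    """Monte Carlo estimate of R(m) = E[max_j N_j] / E[N] for i.i.d. N."""
    rng = np.random.default_rng(seed)
    N = sampler(rng, (trials, m))          # N[t, j]: iterations of chain j in trial t
    return N.max(axis=1).mean() / N.mean()

def light_tail(rng, size):
    # Geometric iteration counts (mean 5): a light, exponentially decaying tail.
    return rng.geometric(0.2, size=size)

def heavy_tail(rng, size):
    # Rounded-up log-normal counts: most are small, a few are very large.
    return np.ceil(rng.lognormal(1.0, 1.5, size=size)).astype(int)

for m in (1, 16, 256, 1024):
    print(m, round(estimate_R(light_tail, m), 2), round(estimate_R(heavy_tail, m), 2))
```

\(R(1) = 1\) by construction. For the light tail the ratio grows only slowly (roughly logarithmically) with \(m\), while the heavy tail pulls it up much faster.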

Optimizing the FSM

Simply converting code to an FSM isn’t enough. In a naive implementation, the step function has to switch between all possible states. In a vectorized environment (SIMD), this often means executing all branches and masking out the ones that don’t apply. This adds overhead.

The authors introduce two key optimizations to make FSMs practical:

1. Step Bundling

Instead of taking just one transition per step call, we can “bundle” sequential steps. If we know that State A is often followed immediately by State B, we can write a bundled_step that tries to execute both in one go. This reduces the number of outer-loop iterations, and with them the fixed overhead each iteration incurs.
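Here is a sketch of the idea on the toy rejection-loop FSM, assuming our own bundling choice (the paper's scheme may differ): after a chain accepts a proposal and lands in \(S_3\), it runs \(S_3\) in the same call instead of waiting for the next outer iteration. The names `step`, `bundled_step`, and `run` are hypothetical.

```python
import numpy as np

S1, S2, S3 = 0, 1, 2   # toy rejection-loop FSM: S1 init, S2 loop body, S3 emit

def step(k, u, count, rng, c=0.3):
    """Single transition for all chains (every block computed, then selected)."""
    fresh = rng.uniform(size=k.shape[0])
    u = np.where(k != S3, fresh, u)                 # S1 and S2 draw a proposal
    count = np.where(k == S3, count + 1, count)     # S3 emits a sample
    k = np.where(k == S3, S1, np.where(u >= c, S2, S3))
    return k, u, count

def bundled_step(k, u, count, rng, c=0.3):
    """Take one transition, then let chains that just landed in S3 run S3
    within the same call rather than on the next outer iteration."""
    k, u, count = step(k, u, count, rng, c)
    in_s3 = k == S3
    count = np.where(in_s3, count + 1, count)
    k = np.where(in_s3, S1, k)
    return k, u, count

def run(step_fn, m=1024, n_samples=50, seed=0):
    rng = np.random.default_rng(seed)
    k, u, count = np.full(m, S1), np.zeros(m), np.zeros(m, dtype=int)
    iters = 0
    while np.any(count < n_samples):                # outer loop: one step_fn call each
        k, u, count = step_fn(k, u, count, rng)
        iters += 1
    return iters

print(run(step), run(bundled_step))
```

Bundling saves one outer iteration per sample here, because the emit block no longer needs an iteration of its own. The trade-off is that each bundled call does more (partly masked) work.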

2. Cost Amortization

Some operations, like calculating the log-probability density (\(\log p(x)\)), are computationally expensive. If vmap executes all branches of a switch statement, we risk calculating \(\log p(x)\) several times per step, or for chains that don’t need it. The authors design an amortized step in which a branch only flags that an expensive computation is needed. The runtime then calls the expensive function once per step, vectorized over the whole batch, and only the chains that raised the flag use the result.
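The contrast can be sketched in NumPy with a call-counting stand-in for the expensive density. Everything here is a hypothetical illustration of the pattern, not the authors' code: two branches each need \(\log p\) at a different proposal, and a third bookkeeping branch needs none.

```python
import numpy as np

class CountingLogProb:
    """Stand-in for an expensive batched log-density that counts its calls."""

    def __init__(self):
        self.calls = 0

    def __call__(self, x):
        self.calls += 1
        return -0.5 * np.sum(x ** 2, axis=-1)  # standard normal, up to a constant

def proposals(x, rng):
    """Hypothetical proposals for branch 0 (wide) and branch 1 (narrow)."""
    wide = x + 1.0 * rng.standard_normal(x.shape)
    narrow = x + 0.1 * rng.standard_normal(x.shape)
    return wide, narrow

def pick_point(x, branch, wide, narrow):
    # Branch 2 is hypothetical bookkeeping that keeps x and needs no density.
    return np.where((branch == 0)[:, None], wide,
                    np.where((branch == 1)[:, None], narrow, x))

def naive_step(x, branch, log_prob, rng):
    """Each branch evaluates log_prob itself. A vectorized switch runs every
    branch, so the density is evaluated twice for the whole batch."""
    wide, narrow = proposals(x, rng)
    lp_wide, lp_narrow = log_prob(wide), log_prob(narrow)
    lp = np.where(branch == 0, lp_wide, np.where(branch == 1, lp_narrow, np.nan))
    return pick_point(x, branch, wide, narrow), lp

def amortized_step(x, branch, log_prob, rng):
    """Branches only choose a point and raise a flag; the runtime then calls
    log_prob once on the batch, and only flagged chains keep the result."""
    wide, narrow = proposals(x, rng)
    point = pick_point(x, branch, wide, narrow)
    needs_logp = branch != 2
    lp = np.where(needs_logp, log_prob(point), np.nan)
    return point, lp

x = np.zeros((6, 3))
branch = np.array([0, 1, 2, 0, 1, 2])
naive_f, amort_f = CountingLogProb(), CountingLogProb()
p1, lp1 = naive_step(x, branch, naive_f, np.random.default_rng(0))
p2, lp2 = amortized_step(x, branch, amort_f, np.random.default_rng(0))
print(naive_f.calls, amort_f.calls)  # 2 1
```

Both steps produce the same points and log-densities. The amortized version simply moves the single expensive call outside the branches.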

Experimental Results

The researchers implemented this framework in JAX and tested it against standard implementations (like those in BlackJAX).

Elliptical Slice Sampling

In a Gaussian Process regression task, the FSM implementation showed a flat cost per sample regardless of the number of chains, while the standard implementation’s cost rose logarithmically.

Figure 5: Elliptical Slice Sampling results: walltime and ESS/second improvements.

As shown in Figure 5:

  • Left: The standard implementation (blue) requires more iterations per sample as chains increase (due to synchronization). The FSM (red) remains constant.
  • Middle/Right: The FSM achieves significantly higher Effective Samples Per Second (ESS/S), approaching the theoretical limit.

Delayed Rejection

For the Delayed Rejection algorithm, which tries multiple proposals if the first is rejected, the FSM approach yielded nearly an order of magnitude speed-up.

Figure 6: Walltimes and ESS for Delayed Rejection.

Figure 6 shows that the “Condensed FSM” (using step bundling) tracks the ideal performance curve, while the standard implementation suffers as the chain count (\(m\)) grows.

HMC-NUTS

The No-U-Turn Sampler is the industry standard for correlated high-dimensional distributions. It is also notoriously difficult to vectorize efficiently because the number of integration steps varies wildly.

Figure 7: HMC-NUTS performance. The histogram shows a long tail of integration steps.

Figure 7 (Middle) shows the distribution of steps for NUTS. Most samples need very few steps, but there is a long tail (the rare samples needing >800 steps).

  • Result: The FSM implementation (Red bars, Right) achieves massive throughput gains compared to the standard version (Blue bars), especially at 100+ chains.

Challenging Geometries

Finally, the authors tested the method on “funnel” distributions and other challenging geometries where gradients are tricky.

Table 1: Speed-ups on benchmark distributions.

Table 1 confirms that across various difficult datasets (Predator Prey, Google Stock), the FSM versions of NUTS and TESS (Transport Elliptical Slice Sampling) consistently outperform state-of-the-art baselines, offering speed-ups ranging from 1.5x to 3.5x.

Conclusion

The shift to GPUs and automatic vectorization offered a promise of massive parallelism for Bayesian statistics. However, the rigid “lock-step” execution of tools like vmap clashed with the variable runtime of iterative MCMC algorithms.

By reframing these algorithms as Finite State Machines, the authors of this paper have bridged that gap. Their method allows chains to operate asynchronously in logic while remaining synchronous in hardware execution.

For students and practitioners, the takeaway is clear: when working with massive parallelization, control flow is expensive. Flattening your algorithms into state machines is a powerful design pattern that can unlock the full potential of modern accelerators.


The images used in this post are derived from the paper “Efficiently Vectorized MCMC on Modern Accelerators” (2025).