Quantifying uncertainty is one of the biggest hurdles in building truly trustworthy AI systems. For a model to be reliable, it needs to recognize what it doesn’t know. Whether it’s a self-driving car encountering an unusual obstacle or a medical AI analyzing a rare condition, we need our models to respond with, “I’m not sure” rather than making a confident but wrong prediction.

Over the years, the machine learning community has developed several distinct approaches to this problem. On one side, we have the principled, probability-first world of Bayesian methods. These approaches, including techniques like Variational Inference (VI) and Langevin sampling, treat model parameters not as single point estimates, but as entire probability distributions—naturally capturing uncertainty.

On the other side, we have a surprisingly effective and simple heuristic: Deep Ensembles—training the same neural network multiple times with different random initializations and averaging their predictions.

Despite their practical success, deep ensembles have been a bit of a theoretical puzzle. Are they secretly Bayesian? Are they something else entirely? The debate has been lively.

A recent NeurIPS 2023 paper, “A Rigorous Link between Deep Ensembles and (Variational) Bayesian Methods”, introduces a unified theory that places deep ensembles, variational inference, and even new algorithms under a single conceptual roof. The key insight is to reframe the entire problem: instead of optimizing a model’s parameters in a complex, non-convex landscape, they lift the problem into the infinite-dimensional space of probability measures, where the optimization becomes convex and well-behaved.

Guided by the mathematical machinery of Wasserstein gradient flows (WGF), this perspective not only explains why deep ensembles work but also leads to brand-new ensemble algorithms with proven convergence properties.


From Bumpy Roads to Smooth Highways: Convexity via Probabilistic Lifting

In deep learning, the losses we minimize are often horrendously non-convex. Imagine a vast mountainous landscape, full of valleys and peaks. Standard gradient descent is like rolling a ball downhill—it will settle into the nearest valley (local minimum), which might be far from the deepest valley (global minimum). This complexity makes theoretical guarantees difficult.

The authors propose “flattening” this landscape by lifting the optimization problem from the parameter space \(\theta \in \mathbb{R}^J\) to the space of probability measures \(\mathcal{P}(\mathbb{R}^J)\) over parameters.

Figure 1: The paper’s three-step process for transforming a non-convex optimization problem into a strictly convex one on the space of probability measures.

The process works as follows:

  1. Standard problem:

    \[ \min_{\theta \in \Theta} \ell(\theta) \]


    Find one optimal set of parameters \(\theta\).

  2. Probabilistic lifting:
    Replace the search for a single \(\theta\) with a search for a distribution \(Q\) over \(\theta\):

    \[ \min_{Q \in \mathcal{P}(\mathbb{R}^J)} \int \ell(\theta) \, dQ(\theta) \]


    This objective is linear, and therefore convex, in \(Q\) (see the identity just after this list).

  3. Strict convexification:
    Add a strictly convex regularizer \(D(Q, P)\) that measures divergence from a fixed reference distribution \(P\):

    \[ \min_{Q} \left[ \int \ell(\theta) \, dQ(\theta) \;+\; \lambda\,D(Q, P) \right] \]


    This yields a unique global minimizer \(Q^*\).
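
The linearity claimed in step 2 is immediate from the definition of the integral: for any \(\alpha \in [0, 1]\),

\[ \int \ell(\theta)\, d\big(\alpha Q_1 + (1-\alpha) Q_2\big)(\theta) \;=\; \alpha \int \ell(\theta)\, dQ_1(\theta) \;+\; (1-\alpha) \int \ell(\theta)\, dQ_2(\theta). \]

Linear functionals are convex but not strictly convex, which is exactly why step 3 adds the strictly convex regularizer.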

Figure 2: The Generalized Variational Inference objective. \(L(Q)\) is the total loss for a distribution \(Q\): the expected loss under \(Q\) plus a regularization term measuring the divergence \(D\) between \(Q\) and a reference measure \(P\). \(Q^*\) is the unique distribution minimizing this objective.

This final form—involving a loss term plus a regularizer—is the Generalized Variational Inference (GVI) objective. By appropriate choice of \(\ell\) and \(D\), it can represent many paradigms:

  • Bayesian inference: \(\ell\) is a negative log-likelihood, \(D\) is the KL divergence, \(P\) is the prior, and \(Q^*\) is the posterior (made explicit below).
  • PAC-Bayes bounds: \(\ell\) is a predictive loss, and \(D(Q,P)\) measures the complexity of \(Q\) relative to \(P\).
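
To make the first bullet explicit: with \(\ell(\theta) = -\log p(\text{data} \mid \theta)\), \(D = \mathrm{KL}\), and \(\lambda = 1\), a standard result (the Gibbs variational principle) identifies the unique minimizer in closed form:

\[ Q^* \;=\; \operatorname*{arg\,min}_{Q} \left[ \int -\log p(\text{data} \mid \theta)\, dQ(\theta) \;+\; \mathrm{KL}(Q \,\|\, P) \right], \qquad Q^*(d\theta) \;\propto\; p(\text{data} \mid \theta)\, P(d\theta), \]

which is exactly Bayes’ rule: the posterior reweights the prior by the likelihood.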

The lifting turns a messy non-convex landscape into a clean, strictly convex problem in a richer space, with a single optimal solution \(Q^*\) summarizing the best local and global solutions.


Gradient Descent in Infinite Dimensions

We now have a convex \(L(Q)\)—but \(Q\) is a probability measure, an infinite-dimensional object. How do we minimize it directly?

Finite-Dimensional GVI (FD-GVI):

Restrict \(Q\) to a parameterized family, e.g., Gaussians \(Q_\nu\) with parameters \(\nu = (\mu, \sigma^2)\). Optimize \(L(Q_\nu)\) over \(\nu\) with standard gradient descent. This is classical VI.
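
As a minimal sketch of what FD-GVI looks like in practice (not the paper’s code; `loss_fn` stands in for \(\ell\), and the KL term assumes a standard normal reference \(P\)), here is mean-field Gaussian VI trained with the reparameterization trick:

```python
import torch

def fd_gvi(loss_fn, dim, steps=2000, lr=1e-2, lam=1.0, n_samples=8):
    """Finite-dimensional GVI: restrict Q to diagonal Gaussians N(mu, diag(sigma^2))
    and run ordinary gradient descent on the variational parameters."""
    mu = torch.zeros(dim, requires_grad=True)
    log_sigma = torch.zeros(dim, requires_grad=True)
    opt = torch.optim.Adam([mu, log_sigma], lr=lr)
    for _ in range(steps):
        opt.zero_grad()
        sigma = log_sigma.exp()
        eps = torch.randn(n_samples, dim)
        theta = mu + sigma * eps              # reparameterized samples from Q
        expected_loss = loss_fn(theta).mean() # Monte Carlo estimate of the expected loss
        # KL(N(mu, sigma^2) || N(0, I)) in closed form for diagonal Gaussians
        kl = 0.5 * (sigma**2 + mu**2 - 1.0 - 2.0 * log_sigma).sum()
        (expected_loss + lam * kl).backward()
        opt.step()
    return mu.detach(), log_sigma.exp().detach()

# Example: a bimodal 1D loss; a single Gaussian can only cover one well.
mu, sigma = fd_gvi(lambda th: ((th**2 - 1.0)**2 + 0.1 * th).sum(-1), dim=1)
```

This toy already illustrates both drawbacks listed next: the Gaussian family cannot represent a bimodal \(Q^*\), and the objective is non-convex in \((\mu, \sigma)\).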

Drawbacks:

  1. Approximation error: Simple families can’t capture complex multi-modal \(Q^*\).
  2. Lost convexity: \(L(Q_\nu)\) is often non-convex in \(\nu\), reintroducing the original problem.

Infinite-Dimensional GVI (ID-GVI):

Instead of restricting \(Q\), follow gradient descent in the space of probability measures itself.

This is done via gradient flows: continuous-time limits of gradient descent trajectories. In our case, the space has its own geometry defined by the 2-Wasserstein distance—measuring the “transport cost” between distributions.
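
For reference, the 2-Wasserstein distance between two distributions is the cost of optimally transporting the mass of one onto the other:

\[ W_2^2(Q_1, Q_2) \;=\; \inf_{\pi \in \Pi(Q_1, Q_2)} \int \|\theta_1 - \theta_2\|^2 \, d\pi(\theta_1, \theta_2), \]

where \(\Pi(Q_1, Q_2)\) denotes the set of couplings, i.e., joint distributions with marginals \(Q_1\) and \(Q_2\).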

The discrete-time update in Wasserstein space is:

\[ Q_{k+1} \;=\; \operatorname*{arg\,min}_{Q \in \mathcal{P}(\mathbb{R}^J)} \left[ L(Q) \;+\; \frac{1}{2\eta}\, W_2^2(Q, Q_k) \right] \]

Figure 3: Iterative WGF update—the loss term plus a Wasserstein-distance penalty.

Taking \(\eta \to 0\) yields the Wasserstein Gradient Flow (WGF) PDE:

\[ \partial_t\, q(t, \theta) \;=\; \operatorname{div}\!\left( q(t, \theta)\, \nabla_\theta \frac{\delta L}{\delta Q}\big(q(t, \theta)\big) \right) \]

Figure 4: PDE for the evolution of the density \(q(t,\theta)\) along steepest descent in probability space.


From Abstract Flows to Concrete Particles

Solving the WGF PDE in high dimensions is infeasible. The breakthrough: for a broad class of objectives, WGF is equivalent to simulating interacting particles.

Consider the free energy functional:

\[ L(Q) \;=\; \int V(\theta)\, dQ(\theta) \;+\; \lambda_1 \iint \kappa(\theta, \theta')\, dQ(\theta)\, dQ(\theta') \;+\; \lambda_2 \int \log q(\theta)\, dQ(\theta) \]

Figure 5: Free energy functional: external potential \(V\), pairwise interaction \(\kappa\), and entropy term.

The WGF of this objective corresponds to a McKean–Vlasov stochastic process, approximable by \(N_E\) interacting particles evolving via:

\[ d\theta^n_t \;=\; -\left[ \nabla V(\theta^n_t) \;+\; \frac{\lambda_1}{N_E} \sum_{m=1}^{N_E} \nabla_{\theta^n} \kappa(\theta^n_t, \theta^m_t) \right] dt \;+\; \sqrt{2\lambda_2}\; dB^n_t, \qquad n = 1, \dots, N_E \]

Figure 6: Drift from \(\nabla V\) and particle interactions; diffusion from Brownian motion.

Each particle \(\theta_n\) moves under:

  1. External drift: \(-\nabla V(\theta_n)\)
  2. Interaction: Repulsion or attraction via \(\kappa\)
  3. Noise: Brownian kicks scaled by \(\lambda_2\)

Different choices of \(D(Q,P)\) translate into different \(V, \lambda_1, \lambda_2\) and thus different ensemble algorithms.
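
To make the particle picture concrete, here is a minimal Euler–Maruyama discretization of the SDE above (a sketch, not the paper’s implementation; `grad_V` and `grad_kappa` are hypothetical callables for \(\nabla V\) and the gradient of \(\kappa\) in its first argument):

```python
import numpy as np

def simulate_particles(grad_V, grad_kappa, theta0, lam1, lam2,
                       eta=1e-3, steps=10_000, rng=None):
    """Euler-Maruyama simulation of the interacting-particle SDE.

    theta0: (N_E, J) array of initial particles drawn from Q_0.
    """
    if rng is None:
        rng = np.random.default_rng(0)
    theta = theta0.copy()
    n_particles = theta.shape[0]
    for _ in range(steps):
        # 1. External drift: -grad V, evaluated at each particle
        drift = -np.array([grad_V(t) for t in theta])
        # 2. Interaction: each particle feels the kernel gradient from all others
        if lam1 > 0:
            for i in range(n_particles):
                drift[i] -= lam1 / n_particles * sum(
                    grad_kappa(theta[i], theta[j]) for j in range(n_particles))
        # 3. Noise: Brownian increments with diffusion coefficient lam2
        noise = np.sqrt(2.0 * lam2 * eta) * rng.standard_normal(theta.shape)
        theta += eta * drift + noise
    return theta
```

Setting `lam1 = lam2 = 0` recovers independent gradient descent (Case 1 below); `lam1 = 0, lam2 > 0` gives Langevin dynamics (Case 2); making both positive gives the repulsive variant (Case 3).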


A Unified Family of Ensemble Algorithms

Case 1: No Regularizer → Deep Ensembles (DE)

\(\lambda_1 = 0, \lambda_2 = 0\) → independent deterministic gradient descent:

\[ d\theta^n_t \;=\; -\nabla \ell(\theta^n_t)\, dt \]

Figure 7: DE particles follow pure gradient descent from independent initializations drawn from \(Q_0\).

Theory shows DEs converge to a mixture of local minima, weighted by how much initial mass \(Q_0\) assigns to each basin of attraction, not to the true \(Q^*\).
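
In symbols (a restatement of that result, with notation not taken from the paper): writing \(\theta^*_i\) for the local minima and \(A_i\) for their basins of attraction,

\[ Q_t \;\xrightarrow{\; t \to \infty \;}\; \sum_i Q_0(A_i)\, \delta_{\theta^*_i}, \]

so the ensemble’s limit is dictated by the initialization \(Q_0\), not by the lifted objective.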


Case 2: KL Regularizer → Deep Langevin Ensembles (DLE)

KL regularization (\(D = \mathrm{KL}\)) → no interactions, but noise from Brownian motion:

\[ d\theta^n_t \;=\; -\nabla\big( \ell(\theta^n_t) - \lambda \log p(\theta^n_t) \big)\, dt \;+\; \sqrt{2\lambda}\; dB^n_t \]

Figure 8: DLE particles follow Langevin dynamics with drift from \(\ell\) and \(\log p\), plus isotropic noise.

Result: particles sample from the unique global minimizer \(Q^*\) of the KL-regularized objective.


Case 3: MMD + KL Regularizer → Deep Repulsive Langevin Ensembles (DRLE)

MMD induces repulsion between particles; KL ensures density existence. All three forces active:
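
Schematically, with constants and the attraction-to-\(P\) cross term suppressed (a sketch, not the paper’s exact equation), each DRLE particle follows

\[ d\theta^n_t \;=\; -\nabla \big( \ell(\theta^n_t) - \lambda_2 \log p(\theta^n_t) \big)\, dt \;-\; \frac{\lambda_1}{N_E} \sum_{m=1}^{N_E} \nabla_{\theta^n} \kappa(\theta^n_t, \theta^m_t)\, dt \;+\; \sqrt{2\lambda_2}\, dB^n_t. \]

The repulsion comes from the MMD itself: \(\mathrm{MMD}^2_\kappa(Q, P) = \iint \kappa\, dQ\,dQ - 2\iint \kappa\, dQ\,dP + \iint \kappa\, dP\,dP\), whose first term pushes particles apart while the cross term pulls them toward \(P\).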

This repulsion encourages exploration of multiple modes and keeps particles from clustering. DRLE is a new algorithm, and it provably converges to its \(Q^*\).


Theory Meets Reality: Experiments

Illustrating the Theory

1D Toy Example:

Figure 9: A 1D toy experiment comparing DE, DLE, and DRLE. DLE and DRLE match the theoretical \(Q^*\); DE settles in both local minima.

2D Multi-modal Loss:

Figure 10: A 2D experiment on a multi-modal loss. ID-GVI methods capture all four modes; FD-GVI with a Gaussian family misses the multi-modal structure.


Real Data — UCI Regression

Table 1: Negative log-likelihood (NLL) on UCI regression datasets for DE, DLE, and DRLE. No single method consistently wins; performance varies by dataset.


Why DEs Remain Competitive

Loss landscapes in deep learning can have millions of minima. With small ensemble sizes (\(N_E \ll\) number of minima), each particle falls into its own well-separated local basin and rarely escapes—even with noise or repulsion.

Figure 11: A loss landscape with thousands of minima leaves every particle trapped in its own basin; DE, DLE, and DRLE behave nearly identically when modes vastly outnumber particles.

This parallels MCMC’s multi-modality trap: samplers stuck in one mode can’t explore others effectively.


Conclusion: A New Language for Uncertainty

By lifting optimization into probability measure space and applying Wasserstein gradient flows, this paper unifies Bayesian inference and deep ensembles under a single theoretical framework.

Key takeaways:

  • Unified View: DEs, VI, Langevin, and DRLE are instances of particle-based WGF algorithms.
  • Rigorous DE Explanation: DEs = infinite-dimensional gradient descent on unregularized lifted loss.
  • Generative Theory: Framework inspired the new DRLE algorithm, with convergence guarantees.
  • Practical Insight: In mode-rich landscapes with few particles, DE ≈ DLE ≈ DRLE.

The WGF perspective is a powerful, generative lens for uncertainty quantification—offering both theoretical clarity and inspiration for novel, robust algorithms in deep learning.