Modern machine learning thrives on the idea of amortization, that is, training large models once so they can be applied instantly to many new problems. Pre-trained models like GPT-4 or Stable Diffusion embody this principle: by learning general structures from vast data, they enable fast adaptation to diverse tasks. Neural Processes, particularly their transformer-based variants, extend this notion to probabilistic meta-learning, allowing uncertainty-aware predictions across different domains.

Yet, these methods face a major limitation: rigidity. Most models are constrained to tasks of the form “given X, predict Y.” Real-world problems are rarely so simple—sometimes we may know partial data and have beliefs about hidden parameters, and we want to predict both unseen data and those latent quantities. Traditional approaches rarely permit dynamic incorporation of such knowledge (so-called priors) during inference. For instance, in Bayesian optimization we seek the location and value of a minimum, while in scientific modeling we infer simulator parameters. Each case typically demands a bespoke, computationally expensive solution.

Researchers from the University of Helsinki and Aalto University tackle this problem head-on in their paper on the Amortized Conditioning Engine (ACE). ACE is a transformer-based architecture that unifies probabilistic conditioning and prediction into a single, flexible operation. It allows conditioning on any mix of observed data and interpretable latent variables—and even accepts probabilistic inputs at runtime to predict any other combination of data and latent variables. In essence, ACE is a universal engine for probabilistic reasoning.

Figure 1 illustrates how diverse tasks like image classification, Bayesian optimization, and simulation-based inference can all be framed as problems of probabilistic conditioning and prediction.

Figure 1: Diverse tasks—from image completion to Bayesian optimization and simulation-based inference—are interpreted as conditioning on known quantities (data or latents) to predict unknown ones.


Background: From Neural Processes to Transformers

To appreciate ACE’s novelty, we must first understand its lineage. Neural Processes (NPs) learn distributions over functions. Given a few observed input-output pairs, called a context set \(\mathcal{D}_N = \{(\mathbf{x}_1, y_1), \dots, (\mathbf{x}_N, y_N)\}\), and a set of target inputs \(\mathbf{x}_{1:M}^{\star}\), NPs predict the distribution over the unknown target outputs \(y_{1:M}^{\star}\).

A core property of NPs is permutation invariance: the model’s prediction should not depend on the order of context points. Early variants like Conditional Neural Processes (CNPs) achieved this with simple averaging—compressing the context into a single embedding vector.
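
To make the averaging idea concrete, here is a minimal Python sketch. It replaces the learned network of a real CNP with a hand-written toy encoder, `encode_pair`, which is hypothetical. The sketch shows only that taking the mean over per-point encodings yields the same representation for any ordering of the context.

```python
import random


def encode_pair(x, y):
    """Toy per-point encoder (hypothetical stand-in for a learned MLP)."""
    return [x + y, x * y, x - y]


def cnp_representation(context):
    """Mean-pool per-point encodings into one fixed-size vector."""
    encodings = [encode_pair(x, y) for x, y in context]
    n = len(encodings)
    return [sum(e[i] for e in encodings) / n for i in range(len(encodings[0]))]


context = [(0.0, 1.0), (0.5, 0.2), (1.0, -0.3)]
shuffled = context[:]
random.Random(0).shuffle(shuffled)

r_original = cnp_representation(context)
r_shuffled = cnp_representation(shuffled)
# Same vector (up to floating-point rounding) regardless of context order.
print(all(abs(a - b) < 1e-12 for a, b in zip(r_original, r_shuffled)))  # True
```

The flip side of this simplicity is the bottleneck mentioned above: however many context points there are, everything must fit into one fixed-size vector.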

Transformers changed the game. Their attention mechanism naturally respects permutation structure and uncovers complex dependencies between data points. This led to advances such as Transformer Neural Processes (TNPs) and Prior-Fitted Networks (PFNs). These use self-attention to encode the context and cross-attention to query it for predictions.

These transformer-based Neural Process models are diagonal, predicting each target point independently:

\[ \pi(y_{1:M}^{\star}|\mathbf{x}_{1:M}^{\star};\mathcal{D}_N) = \prod_{m=1}^{M} p\big(y_m^{\star}|\mathbf{r}(\mathbf{x}_m^{\star},\mathbf{r}_{\mathcal{D}}(\mathcal{D}_N))\big) \]

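Although these models predict each target independently given the context, joint distributions can still be obtained autoregressively: sample one target, add it to the context, and predict the next. This makes the family of diagonal Transformer Prediction Maps (TPM-D) both flexible and powerful. ACE extends this architecture with new capabilities.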
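
The autoregressive trick can be sketched as follows. `diagonal_predict` is a hypothetical stand-in for a trained diagonal model: it uses a kernel-weighted average with a Gaussian output. Each sampled target value is appended to the context before the next target is predicted. As a result, later draws depend on earlier ones, even though every individual prediction is made independently given its context.

```python
import math
import random


def diagonal_predict(x_star, context):
    """Hypothetical diagonal model: (mean, std) of p(y* | x*, context)
    from a kernel-weighted average of the context values."""
    weights = [math.exp(-((x_star - x) ** 2) / 0.1) for x, _ in context]
    total = sum(weights)
    if total < 1e-12:
        return 0.0, 1.0  # no nearby evidence: fall back to a broad prediction
    mean = sum(w * y for w, (_, y) in zip(weights, context)) / total
    std = 0.05 + 1.0 / (1.0 + total)  # more nearby evidence -> smaller std
    return mean, std


def sample_joint_autoregressively(targets, context, rng):
    """One joint sample over all targets: each draw joins the context
    before the next target is predicted."""
    ctx = list(context)
    samples = []
    for x_star in targets:
        mean, std = diagonal_predict(x_star, ctx)
        y = rng.gauss(mean, std)
        samples.append(y)
        ctx.append((x_star, y))
    return samples


context = [(0.0, 0.0), (1.0, 1.0)]
print(sample_joint_autoregressively([0.5, 0.52], context, random.Random(0)))
```

Because 0.52 is close to 0.5, the second draw is pulled toward whatever value was sampled first. This correlation is what a purely independent (diagonal) prediction would miss.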


The Core Concept: Inside the Amortized Conditioning Engine

ACE generalizes the notion of data points to include both observed measurements and latent variables that characterize a task.

Suppose a problem involves \(L\) latent variables, \(\boldsymbol{\theta} = (\theta_1, \ldots, \theta_L)\).
Each pair \((\boldsymbol{\xi}, z)\) in ACE can represent either:

  1. A data point—\((\mathbf{x}, y)\), where \(\mathbf{x}\) is an input and \(y\) a value.
  2. A latent variable—\((\ell_l, \theta_l)\), where \(\ell_l\) is a token identifying latent \(l\).
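
One way to picture the unified \((\boldsymbol{\xi}, z)\) representation is as a single list of tokens. Some tokens have known values and others are to be predicted. The classes below (`DataToken`, `LatentToken`, `split_task`) are hypothetical illustrations, not the paper's code. The same token list can express posterior inference or data prediction simply by changing which values are left unknown.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class DataToken:
    x: Tuple[float, ...]          # input location (xi = x)
    y: Optional[float] = None     # observed value (z = y), or None for a target


@dataclass
class LatentToken:
    index: int                    # which latent theta_l (plays the role of ell_l)
    value: Optional[float] = None  # known theta_l, or None for a target


def known_value(token):
    return token.y if isinstance(token, DataToken) else token.value


def split_task(tokens):
    """Context = tokens whose value is known; targets = tokens to predict."""
    context = [t for t in tokens if known_value(t) is not None]
    targets = [t for t in tokens if known_value(t) is None]
    return context, targets


# Posterior inference: data observed, both latents unknown.
posterior_task = [DataToken((0.1,), 0.3), DataToken((0.7,), -0.2),
                  LatentToken(0), LatentToken(1)]
# Prediction with a known latent: theta_1 fixed, one data value unknown.
mixed_task = [DataToken((0.1,), 0.3), LatentToken(1, 2.0), DataToken((0.5,))]

ctx, tgt = split_task(posterior_task)
print(len(ctx), len(tgt))  # 2 2
```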

This unification means ACE can predict any combination of targets (data or latents) conditioned on any combination of knowns. Formally:

\[ \pi(z_{1:M}^{\star}|\boldsymbol{\xi}_{1:M}^{\star};\boldsymbol{\mathfrak{D}}_N) = \prod_{m=1}^{M} p\big(z_m^{\star}|\mathbf{r}(\boldsymbol{\xi}_m^{\star},\mathbf{r}_{\mathcal{D}}(\boldsymbol{\mathfrak{D}}_N))\big) \]

This formulation turns probabilistic reasoning into a single structured computation shared across all tasks.


Architecture: The Devil in the Details

ACE builds upon TPM-D with critical upgrades for handling latent information and user-provided priors.

Figure 2: ACE architecture: embeddings handle both data and latent variables, while output heads adapt for continuous or discrete predictions.

1. Universal Embedding

All inputs—data points, latents, or priors—are first encoded into a common embedding space of dimension \(D_{\text{emb}}\):

  • Data point: \((\mathbf{x}_n, y_n)\) → \(f_{\mathbf{x}}(\mathbf{x}_n) + f_{\text{val}}(y_n) + \mathbf{e}_{\text{data}}\)
  • Latent variable: \(\theta_l\) → \(f_{\text{val}}(\theta_l) + \mathbf{e}_l\)
  • Unknown targets: value embedding replaced with learned \(\mathbf{e}_{?}\)
  • Prior information: a probability vector \(\mathbf{p}_l\) over latent values passes through \(f_{\text{prob}}(\mathbf{p}_l) + \mathbf{e}_l\)

This treatment allows any element—observation, latent, or prior—to be embedded seamlessly.
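
The embedding rules listed above can be sketched with tiny linear maps. Everything here is an assumption for illustration: `D_EMB = 4`, a 5-bin prior histogram, random (untrained) weights, and plain linear layers in place of whatever networks \(f_{\mathbf{x}}\), \(f_{\text{val}}\), and \(f_{\text{prob}}\) ACE actually uses.

```python
import random

D_EMB, X_DIM, N_BINS, N_LATENTS = 4, 1, 5, 2  # illustrative sizes (assumed)
rng = random.Random(0)


def rand_vec(n):
    return [rng.gauss(0.0, 1.0) for _ in range(n)]


def linear(weights, inputs):
    """Matrix-vector product; weights has D_EMB rows of len(inputs)."""
    return [sum(w * v for w, v in zip(row, inputs)) for row in weights]


def add(*vecs):
    return [sum(parts) for parts in zip(*vecs)]


# Random stand-ins for learned parameters (untrained, for illustration only).
W_x = [rand_vec(X_DIM) for _ in range(D_EMB)]           # f_x
W_val = [rand_vec(1) for _ in range(D_EMB)]             # f_val
W_prob = [rand_vec(N_BINS) for _ in range(D_EMB)]       # f_prob
e_data = rand_vec(D_EMB)                                # e_data
e_latent = [rand_vec(D_EMB) for _ in range(N_LATENTS)]  # e_l, one per latent
e_unknown = rand_vec(D_EMB)                             # e_? for target values


def embed_data(x, y=None):
    """f_x(x) + f_val(y) + e_data, with e_? replacing f_val(y) for targets."""
    value_part = e_unknown if y is None else linear(W_val, [y])
    return add(linear(W_x, x), value_part, e_data)


def embed_latent(l, theta=None, prior_probs=None):
    """f_prob(p_l) + e_l for a prior, f_val(theta_l) + e_l if known, else e_? + e_l."""
    if prior_probs is not None:
        value_part = linear(W_prob, prior_probs)
    elif theta is not None:
        value_part = linear(W_val, [theta])
    else:
        value_part = e_unknown
    return add(value_part, e_latent[l])
```

Every branch produces a vector of the same length `D_EMB`. That shared size is what lets observations, latents, and priors sit side by side in one attention sequence.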

2. Attention Layers

ACE processes embeddings through stacked transformer layers with self-attention (encoding) and cross-attention (decoding). Context self-attention builds joint representations of the observed elements, while target cross-attention links them to the prediction queries. Because targets attend only to the context and not to one another, the cost is \(O(N^2 + NM)\) for \(N\) context and \(M\) target elements, better than the naïve \(O((N+M)^2)\) of full self-attention over all elements.
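
A quick way to check the complexity claim is to count attention pairs. The sketch below is a simplified view. It represents context self-attention and target-to-context cross-attention as one boolean mask in which every token, context or target, attends only to context tokens. A real implementation may instead compute the two attention operations separately.

```python
def ace_attention_mask(n_context, n_target):
    """mask[i][j] is True if token i may attend to token j.
    Tokens 0..N-1 are context elements, N..N+M-1 are targets."""
    total = n_context + n_target
    return [[j < n_context for j in range(total)] for _ in range(total)]


def count_attended_pairs(mask):
    return sum(sum(row) for row in mask)


N, M = 100, 50
mask = ace_attention_mask(N, M)
print(count_attended_pairs(mask))  # 15000 = N*N + N*M
print((N + M) ** 2)                # 22500 for full self-attention over all tokens
```

The same masking choice explains why the model is diagonal: no target ever sees another target, so each prediction depends only on the context.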

3. Output Heads

ACE outputs a predictive distribution for each target element:

  • Continuous outputs: A Gaussian Mixture Model (GMM), whose multiple components let the head represent multimodal distributions.
  • Discrete outputs: A categorical distribution via softmax probabilities.
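
Both output heads are easy to write down. The sketch below assumes one-dimensional continuous targets and mixture weights produced by a softmax over logits, a common choice assumed here rather than taken from the paper. It evaluates the log-density of a Gaussian mixture with the log-sum-exp trick and the categorical probabilities of a discrete head.

```python
import math


def softmax(logits):
    """Numerically stable softmax: categorical probabilities for a discrete head."""
    top = max(logits)
    exps = [math.exp(v - top) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def gmm_log_density(z, weights, means, stds):
    """log p(z) under a 1-D Gaussian mixture, via log-sum-exp for stability."""
    log_terms = [
        math.log(w) - 0.5 * math.log(2 * math.pi * s * s) - 0.5 * ((z - m) / s) ** 2
        for w, m, s in zip(weights, means, stds)
    ]
    top = max(log_terms)
    return top + math.log(sum(math.exp(t - top) for t in log_terms))


# A bimodal prediction: two equally weighted components at -1 and +1.
weights = softmax([0.0, 0.0])
print(gmm_log_density(1.0, weights, [-1.0, 1.0], [0.3, 0.3]))
print(gmm_log_density(0.0, weights, [-1.0, 1.0], [0.3, 0.3]))  # lower: between modes
```

A single Gaussian could not put low density at 0 and high density at both -1 and +1. The mixture handles this naturally.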

Learning to Reason with Priors

Perhaps ACE’s most remarkable ability is its support for runtime probabilistic priors. Users can specify beliefs about latent values (e.g., “the optimum is near 0.5”) as probability distributions \(p(\theta_l)\). These are discretized into probability histograms and embedded like other information.
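
A belief such as “the optimum is near 0.5” can be turned into such a histogram along the lines of the sketch below. It assumes the belief is expressed as a Gaussian, truncated to a known range [lo, hi], and split into equal-width bins. The bin count, range, and Gaussian form are illustrative assumptions, not the paper's exact settings.

```python
import math


def gaussian_cdf(x, mu, sigma):
    return 0.5 * (1.0 + math.erf((x - mu) / (sigma * math.sqrt(2.0))))


def discretize_prior(mu, sigma, lo, hi, n_bins):
    """Mass of N(mu, sigma^2) in each of n_bins equal-width bins on [lo, hi],
    renormalized so the histogram sums to 1 (i.e. truncated to [lo, hi])."""
    edges = [lo + (hi - lo) * k / n_bins for k in range(n_bins + 1)]
    masses = [gaussian_cdf(b, mu, sigma) - gaussian_cdf(a, mu, sigma)
              for a, b in zip(edges[:-1], edges[1:])]
    total = sum(masses)
    return [m / total for m in masses]


# "The optimum is near 0.5": most mass lands in the two bins around 0.5.
p = discretize_prior(mu=0.5, sigma=0.1, lo=0.0, hi=1.0, n_bins=10)
print([round(v, 3) for v in p])
```

A vague belief would use a larger `sigma`, spreading mass across more bins. Either way the result is a fixed-length probability vector that can be embedded like any other input.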

Figure 3 shows prior amortization. (a) A user provides a prior over the mean (μ) and standard deviation (σ) of a Gaussian. (b) Observed data yields a likelihood. (c) Ground-truth Bayesian posterior combines prior and likelihood. (d) ACE’s inferred posterior matches the truth.

Figure 3: ACE performs prior amortization—combining user-provided priors with data to approximate the true Bayesian posterior in one forward pass.

During training, ACE encounters randomly generated priors and learns how they combine with evidence from data. It is trained to minimize the negative log-likelihood of the target values under its predictive distributions:

\[ \mathcal{L}(\mathbf{w}) = \mathbb{E}_{\mathbf{p}\sim\mathcal{P}}\Big[\mathbb{E}_{\boldsymbol{\mathfrak{D}}_N,\boldsymbol{\xi}^{\star}_{1:M},\mathbf{z}^{\star}_{1:M}}\Big[-\sum_{m=1}^M \log q\big(z_m^{\star}|\mathbf{r}_{\mathbf{w}}(\boldsymbol{\xi}_m^{\star},\boldsymbol{\mathfrak{D}}_N)\big)\Big]\Big] \]

Minimizing this loss drives the model’s predictions toward the Bayesian posteriors of the family of generative processes (and priors) seen during training. At test time it can therefore perform inference on new tasks from that family without retraining.
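
A toy version of this objective can be sketched end to end. Everything here is a hypothetical illustration, not ACE's actual training setup:

- the random-prior generator
- the one-parameter linear simulator y = θx + noise
- the Gaussian predictive distributions
- the `naive_predict` baseline

The sketch only evaluates a Monte Carlo estimate of the loss. Training would adjust the model's weights to minimize it.

```python
import math
import random


def sample_random_prior(n_bins, rng):
    """Hypothetical prior generator: a random normalized histogram on [0, 1)."""
    w = [rng.random() + 1e-6 for _ in range(n_bins)]
    s = sum(w)
    return [v / s for v in w]


def sample_task(rng, n_bins=10, n_points=8, n_context=5):
    """One training task: prior p -> latent theta ~ p -> data -> context/targets."""
    prior = sample_random_prior(n_bins, rng)
    b = rng.choices(range(n_bins), weights=prior)[0]
    theta = (b + rng.random()) / n_bins  # uniform within the chosen bin
    xs = [rng.random() for _ in range(n_points)]
    data = [(x, theta * x + rng.gauss(0.0, 0.1)) for x in xs]  # toy simulator
    return prior, theta, data[:n_context], data[n_context:]


def gaussian_nll(z, mean, std):
    return 0.5 * math.log(2 * math.pi * std * std) + 0.5 * ((z - mean) / std) ** 2


def loss_estimate(predict, tasks):
    """Average over tasks of -sum_m log q(z_m* | ...); the targets are the
    latent theta and the held-out data values."""
    total = 0.0
    for prior, theta, context, targets in tasks:
        mean, std = predict("theta", None, prior, context)
        total += gaussian_nll(theta, mean, std)
        for x, y in targets:
            mean, std = predict("y", x, prior, context)
            total += gaussian_nll(y, mean, std)
    return total / len(tasks)


def naive_predict(kind, x, prior, context):
    """Hypothetical baseline that ignores both the prior and the context."""
    return (0.5, 0.3) if kind == "theta" else (0.5 * x, 0.3)


rng = random.Random(0)
tasks = [sample_task(rng) for _ in range(100)]
print(loss_estimate(naive_predict, tasks))
```

A model that actually uses the prior histogram and the context can lower this number. Pressure of that kind is what pushes its outputs toward the posterior.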


ACE in Action: Experiments Across Domains

1. Vision — Image Completion and Classification

In computer vision, image completion can be framed as regression: given the coordinates and values of a subset of pixels, predict the values of the missing ones. The latent variables correspond to class labels (MNIST) or attributes (CelebA).

ACE can:

  • Complete images: Predict missing pixels given partial context.
  • Conditionally generate images: Produce images given features (“Bald = True”).
  • Classify: Infer latent attributes from partial data.
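
The sketch below shows how one image can feed each of these tasks just by choosing what goes into the context and what becomes a target. It makes three assumptions:

- the image is grayscale and stored as a list of rows;
- coordinates are normalized to [0, 1];
- there is a single class-label latent.

`image_to_points` and `make_task` are hypothetical helpers, not part of any published code.

```python
import random


def image_to_points(image):
    """Turn an H x W grayscale image (list of rows) into (x, y) pairs:
    x = (row, col) scaled to [0, 1], y = pixel intensity."""
    h, w = len(image), len(image[0])
    return [((r / max(h - 1, 1), c / max(w - 1, 1)), image[r][c])
            for r in range(h) for c in range(w)]


def make_task(points, label, mode, context_frac, rng):
    """Return (context, targets) for one of four illustrative task modes."""
    shuffled = points[:]
    rng.shuffle(shuffled)
    n_ctx = int(context_frac * len(shuffled))
    observed, missing = shuffled[:n_ctx], shuffled[n_ctx:]
    latent = ("label", label)
    if mode == "complete":             # observed pixels -> missing pixels
        return observed, missing
    if mode == "complete_with_label":  # also condition on the known label
        return observed + [latent], missing
    if mode == "generate":             # label only -> every pixel
        return [latent], shuffled
    if mode == "classify":             # observed pixels -> label
        return observed, [latent]
    raise ValueError(f"unknown mode: {mode}")


image = [[0.0, 0.5, 1.0], [0.2, 0.4, 0.6], [1.0, 0.0, 0.3]]
points = image_to_points(image)
ctx, tgt = make_task(points, label=7, mode="classify", context_frac=0.5,
                     rng=random.Random(0))
print(len(ctx), tgt)  # 4 [('label', 7)]
```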

Figure 4 shows image completion results on the CelebA dataset. ACE (d) generates more realistic completions than the baseline TNP-D (c), and conditioning on the true attributes (e) adds further detail. Plot (f) shows ACE’s lower negative log-probability across contexts.

Figure 4: ACE outperforms TNP-D on image completion tasks. Conditioning on latent attributes further improves quality and likelihood scores.

ACE not only surpasses baselines in reconstruction quality but also flexibly switches between conditional generation and classification, depending on what is treated as “context” and “target.”


2. Optimization — Bayesian Search with Contextual Priors

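Bayesian Optimization (BO) aims to locate the global minimum of an unknown, expensive function using few evaluations. Traditional BO fits a Gaussian Process surrogate to the evaluations so far and picks the next query by maximizing an acquisition function, which often relies on sampling-based approximations.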

ACE recasts BO within its probabilistic conditioning framework: both the optimum location \(\mathbf{x}_{\text{opt}}\) and the optimum value \(y_{\text{opt}}\) are explicit latents. Because the model outputs closed-form predictive distributions for them, it can skip the complex approximation steps these quantities usually require.

Figure 5 demonstrates ACE predictions in BO. (a) The model infers function values (purple) and distributions over the optimal value (orange, left) and location (orange, bottom). (b) Conditioning on a better value refocuses predictions on promising regions.

Figure 5: ACE predicts function behavior and optimal locations directly, enabling efficient optimization.

ACE can implement acquisition functions elegantly:

  • Thompson Sampling (TS): Sample an optimistic \(y_{\text{opt}}\) below the current best observed value, then sample \(\mathbf{x}_{\text{opt}}\) conditioned on it and propose that location as the next query.
  • Max-Value Entropy Search (MES): Compute information gain about \(y_{\text{opt}}\) directly, using analytic distributions instead of expensive approximations.
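
Both acquisition strategies can be sketched against a model interface that returns samples and predictive distributions for the latents. `ToyModel` below is a hypothetical stand-in with hand-made behavior. The MES estimate assumes Gaussian predictive distributions so that entropies have a closed form; ACE's actual heads are mixtures, and the paper's estimator may differ.

The structure is what matters here. TS draws an optimistic \(y_{\text{opt}}\) and then \(\mathbf{x}_{\text{opt}}\) given it. MES scores a candidate \(\mathbf{x}\) by how much knowing \(y_{\text{opt}}\) would reduce uncertainty about \(y(\mathbf{x})\).

```python
import math
import random


class ToyModel:
    """Hypothetical stand-in for a trained ACE model on 1-D inputs in [0, 1].
    context is a list of (x, y) evaluations."""

    def sample_yopt(self, context, rng):
        best = min(y for _, y in context)
        return rng.gauss(best - 0.2, 0.3)

    def sample_xopt(self, context, y_opt, rng):
        x_best, best = min(context, key=lambda p: p[1])
        spread = 0.05 + 0.1 * max(0.0, best - y_opt)  # bolder y_opt -> wider search
        return min(1.0, max(0.0, rng.gauss(x_best, spread)))

    def predict_y(self, context, x, y_opt=None):
        """(mean, std) of y(x); conditioning on y_opt shrinks the std far from data."""
        d = min(abs(x - xc) for xc, _ in context)
        std = 0.1 + (0.5 * d if y_opt is not None else d)
        return 0.0, std


def thompson_step(model, context, rng, max_tries=1000):
    """TS: sample y_opt below the incumbent, then x_opt conditioned on it."""
    best = min(y for _, y in context)
    y_opt = best  # fallback if no optimistic draw is found
    for _ in range(max_tries):
        draw = model.sample_yopt(context, rng)
        if draw < best:
            y_opt = draw
            break
    return model.sample_xopt(context, y_opt, rng)


def gaussian_entropy(std):
    return 0.5 * math.log(2 * math.pi * math.e * std * std)


def mes_score(model, context, x, rng, n_samples=16):
    """MES: H[y(x)] - E_{y_opt}[H[y(x) | y_opt]], with Gaussian entropies (assumed)."""
    _, std = model.predict_y(context, x)
    h_cond = 0.0
    for _ in range(n_samples):
        y_opt = model.sample_yopt(context, rng)
        _, std_c = model.predict_y(context, x, y_opt)
        h_cond += gaussian_entropy(std_c)
    return gaussian_entropy(std) - h_cond / n_samples


model, rng = ToyModel(), random.Random(0)
context = [(0.2, 0.5), (0.6, -0.1), (0.9, 0.3)]
print(thompson_step(model, context, rng))
candidates = [i / 20 for i in range(21)]
best_x = max(candidates, key=lambda x: mes_score(model, context, x, rng))
print(best_x)
```

In this toy, MES favors points far from existing evaluations, because that is where knowing \(y_{\text{opt}}\) is assumed to shrink uncertainty most.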

Figure 6 shows BO results across benchmarks. ACE-MES (blue dashed) and ACE-TS (blue solid) match or exceed GP-based baselines (orange) and outperform TNP-D (green).

Figure 6: ACE’s performance rivals or surpasses Gaussian Process benchmarks for black-box optimization tasks.

Adding priors enhances ACE further. When users express beliefs about the optimum’s location, the ACEP variant uses these at inference to guide exploration.

Figure 7 shows that providing a prior over the optimum’s location boosts performance. ACEP-TS (blue dashed) converges faster than ACE-TS (solid) and rivals πBO-TS, a specialized prior-informed baseline.

Figure 7: With priors injected, ACEP accelerates convergence and performs competitively with πBO-TS, designed for prior-informed search.


3. Scientific Models — Simulation-Based Inference

Simulation-Based Inference (SBI) tackles the task of identifying model parameters that could generate observed data—essential for scientific applications that rely on simulators but lack tractable likelihoods.

ACE elegantly merges forward and inverse inference within one framework:

  • Predict \(p(\boldsymbol{\theta}|y)\) — posterior over parameters.
  • Predict \(p(y|\boldsymbol{\theta})\) — data given parameters.
  • Fill missing data — conditional data prediction.
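
The three uses listed above differ only in which tokens are known. The sketch makes two assumptions:

- a toy Ornstein-Uhlenbeck simulator with the hypothetical parameterization \(dX = \theta_1(\theta_2 - X)\,dt + \sigma\,dW\), integrated with Euler-Maruyama;
- hypothetical tuple tokens `("latent", l, value)` and `("data", t, y)`.

It is not the benchmark setup used in the paper.

```python
import math
import random


def simulate_ou(theta1, theta2, n_steps=50, dt=0.1, sigma=0.5, x0=0.0, rng=None):
    """Euler-Maruyama discretization of dX = theta1 * (theta2 - X) dt + sigma dW."""
    rng = rng or random.Random()
    series = [x0]
    for _ in range(n_steps - 1):
        x = series[-1]
        noise = sigma * math.sqrt(dt) * rng.gauss(0.0, 1.0)
        series.append(x + theta1 * (theta2 - x) * dt + noise)
    return series


def sbi_task(theta, series, mode, n_observed=0, rng=None):
    """Return (context, targets) token lists for one inference mode."""
    latents = [("latent", l, v) for l, v in enumerate(theta)]
    data = [("data", t, y) for t, y in enumerate(series)]  # t = time-step index
    if mode == "posterior":   # p(theta | y): data known, latents predicted
        return data, latents
    if mode == "forward":     # p(y | theta): latents known, data predicted
        return latents, data
    if mode == "fill":        # some data known, the rest predicted
        keep = set((rng or random.Random()).sample(range(len(data)), n_observed))
        return ([d for i, d in enumerate(data) if i in keep],
                [d for i, d in enumerate(data) if i not in keep])
    raise ValueError(f"unknown mode: {mode}")


theta = (1.0, 0.5)
series = simulate_ou(*theta, rng=random.Random(0))
ctx, tgt = sbi_task(theta, series, "posterior")
print(len(ctx), len(tgt))  # 50 2
```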

| Model | Metrics | NPE | NRE | Simformer | ACE | ACEP (weak) | ACEP (strong) |
|---|---|---|---|---|---|---|---|
| OUP | log-probs ↑ / RMSE ↓ / MMD ↓ | 1.09 (0.10) / 0.48 (0.01) / - | 1.07 (0.13) / 0.49 (0.00) / - | 1.03 (0.04) / 0.50 (0.02) / 0.43 (0.02) | 1.03 (0.02) / 0.48 (0.00) / 0.51 (0.00) | 1.05 / 0.43 / 0.37 | 1.44 / 0.27 / 0.35 |
| SIR | log-probs ↑ / RMSE ↓ / MMD ↓ | 6.53 / 0.02 / - | 6.24 / 0.03 / - | 6.89 / 0.02 / 0.02 | 6.78 / 0.02 / 0.02 | 6.62 / 0.02 / 0.02 | 6.69 / 0.02 / 0.00 |
| Turin | log-probs ↑ / RMSE ↓ / MMD ↓ | 1.99 / 0.26 / - | 2.33 / 0.28 / - | 3.16 / 0.25 / 0.35 | 3.14 / 0.24 / 0.35 | 3.58 / 0.21 / 0.35 | 4.87 / 0.13 / 0.34 |

Table 1: ACE matches or improves upon specialist SBI models. When priors (ACEP) are available, parameter inference improves further.

ACE’s strength lies not just in accuracy but also in efficiency: generating 1,000 samples takes milliseconds, compared with minutes for the diffusion-based Simformer. With informative priors, ACEP delivers even better calibration and inference quality.


Reflections and Future Directions

The Amortized Conditioning Engine offers a unifying paradigm for tasks once thought distinct—vision, optimization, and scientific inference. By treating data points and latent variables as equivalent entities, it enables fluid conditioning and prediction across them, all via one amortized transformer model.

Takeaways:

  • Versatility: ACE handles regression, classification, optimization, and simulation-based inference seamlessly.
  • Flexibility: Users can define arbitrary conditioning/prediction combinations without changing architecture.
  • Human collaboration: Experts can supply probabilistic beliefs (priors) to guide the model’s inference interactively.

Limitations and prospects:
Like other models built on standard attention, ACE’s cost grows quadratically with context size. Future work may leverage sub-quadratic attention or sparse conditioning. Extending prior injection to many latents and discovering new interpretable latents are promising paths ahead.

Ultimately, ACE stands as a powerful demonstration of unified amortized reasoning. It points toward a future in which a single general model can flexibly serve optimization, vision, and scientific discovery—simply by conditioning on the right information.