Imagine you’re an astrophysicist tasked with modeling the motion of objects on different planets. You could build a separate simulator for each planet: one for Earth, one for Mars, one for Jupiter. But that would be wasteful. The laws of physics are universal; only a single parameter, the surface gravitational acceleration, varies from planet to planet. A smarter strategy is to build one general model and adapt it to each planet by estimating its gravity from a few examples.

This idea—reusing shared structure across related tasks—is the essence of amortized learning. Instead of learning everything from scratch, amortized learning captures common patterns and reuses them to solve new problems more efficiently. This principle powers modern AI advances, from meta-learning systems that learn to optimize to large language models (LLMs) that solve new problems using in-context examples.

While these methods share the same philosophy, they’ve historically been treated as separate lines of work. A 2025 paper titled “Iterative Amortized Inference: Unifying In-Context Learning and Learned Optimizers” closes this gap. It offers a unified mathematical framework, a clear taxonomy, and a scalable approach that overcomes a limitation common to existing methods: difficulty scaling to large task datasets.

This article walks through that work—showing how everything from MAML to GPT-style in-context learning fits into one elegant equation, and how a simple insight inspired by stochastic gradient descent enables amortized models to scale gracefully to large datasets.


The Quest for Rapid Adaptation: A Unified View

In conventional machine learning, we train one model per task. Given a dataset \( \mathcal{D}_{\mathcal{T}} \) for task \( \mathcal{T} \), training finds model parameters \( \hat{\boldsymbol{\theta}}_{\mathcal{T}} \) that minimize the loss:

\[ \hat{\boldsymbol{\theta}}_{\mathcal{T}} = \arg\min_{\boldsymbol{\theta}} \sum_{(\mathbf{x}, \mathbf{y}) \in \mathcal{D}_{\mathcal{T}}} \mathcal{L}\left(\mathbf{y}, f(\mathbf{x}, \boldsymbol{\theta})\right) \]

This works for a single task but reuses no knowledge across tasks. A classifier trained to distinguish cats from dogs gives no head start when we later need to distinguish horses from zebras. Training starts over from scratch and ignores the shared structure of “visual classification.”

Amortized learning does better by training across a distribution of tasks, so it can rapidly adapt to unseen tasks with minimal data. The authors unify these systems via a single equation:

\[ \min_{\gamma, \boldsymbol{\varphi}} \mathbb{E}_{\mathcal{T}} \mathbb{E}_{\mathbf{x}, \mathbf{y}, \mathcal{D}_{\mathcal{T}}} \left[ \mathcal{L} \left( \mathbf{y}, f_{\gamma}\left( \mathbf{x}, g_{\boldsymbol{\varphi}}(\mathcal{D}_{\mathcal{T}}) \right) \right) \right] \]

Where:

  1. Adaptation Function \( g_{\boldsymbol{\varphi}} \): Takes the task’s training data \( \mathcal{D}_{\mathcal{T}} \) and produces a task representation \( \boldsymbol{\theta}_{\mathcal{T}} \) (weights, latent vectors, or even raw examples). It has learnable parameters \( \boldsymbol{\varphi} \). The loss is evaluated on query pairs \( (\mathbf{x}, \mathbf{y}) \) from the same task that are not part of \( \mathcal{D}_{\mathcal{T}} \).

  2. Prediction Function \( f_{\gamma} \): Takes a query \( \mathbf{x} \) and the task representation \( \boldsymbol{\theta}_{\mathcal{T}} \) to produce predictions. It has shared parameters \( \gamma \) that encode cross-task inductive bias.
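
To make the two roles concrete, here is a minimal Python sketch that estimates the unified objective by Monte Carlo on a toy task family. Several choices are illustrative assumptions, not the paper’s setup: the task family (1-D linear maps y = a * x with a task-specific slope a), the squared loss, and the helper names `sample_task`, `amortized_objective`, `g_least_squares`, and `f_linear`. The sketch only evaluates the objective for a given f and g. It does not train \( \gamma \) or \( \boldsymbol{\varphi} \).

```python
import random


def sample_task(rng):
    """Hypothetical task family: 1-D linear maps y = a * x with a task-specific slope a."""
    a = rng.uniform(0.5, 2.0)

    def draw(n):
        xs = [rng.uniform(-1.0, 1.0) for _ in range(n)]
        return [(x, a * x) for x in xs]

    return draw


def amortized_objective(f, g, n_tasks=200, n_context=5, n_query=5, seed=0):
    """Monte Carlo estimate of E_T E_{x, y, D_T}[ L(y, f(x, g(D_T))) ] with squared loss L."""
    rng = random.Random(seed)
    total, count = 0.0, 0
    for _ in range(n_tasks):
        draw = sample_task(rng)
        context = draw(n_context)      # D_T: the data the adaptation function sees
        theta = g(context)             # task representation theta_T = g(D_T)
        for x, y in draw(n_query):     # held-out query pairs from the same task
            total += (y - f(x, theta)) ** 2
            count += 1
    return total / count


def g_least_squares(context):
    """Adaptation function with no learned parameters: fit the slope by least squares."""
    sxy = sum(x * y for x, y in context)
    sxx = sum(x * x for x, _ in context)
    return sxy / sxx


def f_linear(x, theta):
    """Prediction function: a linear model parameterized by the task representation."""
    return theta * x


print(amortized_objective(f_linear, g_least_squares))   # near zero: g recovers each slope
print(amortized_objective(f_linear, lambda ctx: 1.25))  # larger: this g ignores the task data
```

Swapping in different (f, g) pairs changes only the two function arguments. The outer expectation over tasks stays the same, and that is the point of the unified view.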

Depending on how \( f_{\gamma} \) and \( g_{\boldsymbol{\varphi}} \) are configured, this single framework recovers a wide range of learning methods, from vanilla supervised learning to meta-learning and in-context learning.

Table 1. Functional decomposition of amortized learners. Each method is expressed in terms of an adaptation function \( g_{\boldsymbol{\varphi}} \) producing \( \boldsymbol{\theta}_{\mathcal{T}} \), consumed by a prediction function \( f_{\gamma} \).

Let’s look at some examples:

  • Supervised Learning: \( g_{\boldsymbol{\varphi}} \) is standard SGD; \( f_{\gamma} \) is a fixed architecture like ResNet. No amortization happens across tasks.
  • MAML: \( g_{\boldsymbol{\varphi}} \) is SGD initialized from learnable meta-weights \( \boldsymbol{\theta}_0 \in \boldsymbol{\varphi} \).
  • Learned Optimizers: \( g_{\boldsymbol{\varphi}} \) itself is a neural network that proposes parameter updates based on gradients.
  • In-Context Learning (ICL): The adaptation function is the identity—context examples are directly fed into a Transformer \( f_{\gamma} \). All adaptation happens implicitly in its forward pass.
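
The four instantiations above differ only in what g does and what is learned. The sketch below writes each one as a plain Python function on the same toy problem: 1-D linear tasks, with θ a scalar slope. The step counts are illustrative assumptions. So is the two-coefficient update rule that stands in for a learned-optimizer network. The least-squares fit that stands in for a Transformer’s forward pass is likewise an assumption.

```python
import math


def mse_grad(theta, data):
    """Gradient of the mean squared error of y_hat = theta * x with respect to theta."""
    return sum(2 * x * (theta * x - y) for x, y in data) / len(data)


def sgd(theta0, data, lr=0.5, steps=20):
    """Plain full-batch gradient descent from the initialization theta0."""
    theta = theta0
    for _ in range(steps):
        theta -= lr * mse_grad(theta, data)
    return theta


def f_linear(x, theta):
    """Fixed architecture shared by the first three methods."""
    return theta * x


# Supervised learning: g is SGD from a fixed init; nothing is reused across tasks.
def g_supervised(data):
    return sgd(0.0, data)


# MAML: g is a few SGD steps from a meta-learned init theta_0 (theta_0 is part of phi).
def make_g_maml(theta_0, steps=1):
    return lambda data: sgd(theta_0, data, steps=steps)


# Learned optimizer: the update rule itself has parameters phi. A hypothetical
# two-coefficient rule on the gradient stands in for the optimizer network.
def make_g_learned_opt(phi, steps=5):
    def g(data):
        theta = 0.0
        for _ in range(steps):
            grad = mse_grad(theta, data)
            theta -= phi[0] * grad + phi[1] * math.tanh(grad)
        return theta
    return g


# In-context learning: g is the identity and f reads the raw context. Here f fits a
# slope inside its "forward pass", a stand-in for what a Transformer might compute.
def g_identity(data):
    return data


def f_icl(x, context):
    slope = sum(cx * cy for cx, cy in context) / sum(cx * cx for cx, _ in context)
    return slope * x


data = [(x, 1.5 * x) for x in (-1.0, -0.5, 0.5, 1.0)]
print(f_linear(2.0, g_supervised(data)))       # close to 3.0 after 20 steps from zero
print(f_linear(2.0, make_g_maml(1.4)(data)))   # one step from a good init is already close
print(f_icl(2.0, g_identity(data)))            # 3.0
```

With φ = (0.5, 0.0) the "learned" rule reduces to plain SGD. This reflects the framework’s point: the methods differ in which part of the pipeline carries learnable parameters.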

This unified viewpoint offers a common language for understanding how learning systems reuse knowledge—and clarifies that the central differences lie in what they choose to learn: initialization, updates, or mappings.


A Taxonomy of Amortization: Parametric, Implicit, and Explicit

Expanding this unified view, the authors define three broad regimes of amortized learning.

1. Parametric Amortization

Here, \( f \) is fixed and \( g_{\boldsymbol{\varphi}} \) is learned. The system learns how to infer parameters for a pre-defined model.

  • Examples: Learned optimizers, hypernetworks.
  • Mechanism: \( g_{\boldsymbol{\varphi}} \) maps from data to parameters like linear model weights.
  • Benefits: Efficient use of gradients and interpretable parameters.
  • Trade-offs: Limited expressivity since \( f \) is fixed.
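
Here is a deliberately tiny parametric learner. f is a fixed linear model with no shared parameters, and \( g_{\boldsymbol{\varphi}} \) maps a task’s dataset directly to that model’s weight. Three choices are assumptions made for illustration: the ridge-regression form of \( g_{\boldsymbol{\varphi}} \), the single scalar φ (its regularizer), and the grid search standing in for meta-training. In practice \( g_{\boldsymbol{\varphi}} \) would be a hypernetwork or a learned optimizer.

```python
def f_fixed(x, theta):
    """Fixed prediction function: a linear model; nothing in f is learned across tasks."""
    return theta * x


def g_phi(data, phi):
    """Map a dataset to the weight argmin_t sum((t * x - y) ** 2) + phi * t ** 2."""
    sxy = sum(x * y for x, y in data)
    sxx = sum(x * x for x, _ in data)
    return sxy / (sxx + phi)


def meta_learn_phi(tasks, candidates):
    """Pick the phi with the lowest query loss summed over training tasks.

    Each task is a (context, queries) pair. Grid search is a stand-in for
    gradient-based meta-training of g_phi.
    """
    def query_loss(phi):
        return sum(
            (y - f_fixed(x, g_phi(context, phi))) ** 2
            for context, queries in tasks
            for x, y in queries
        )
    return min(candidates, key=query_loss)


tasks = [
    ([(1.0, 2.0), (2.0, 4.0)], [(3.0, 6.0)]),   # slope 2
    ([(1.0, -1.0)], [(2.0, -2.0)]),             # slope -1
]
print(g_phi(tasks[0][0], phi=0.0))                         # 2.0
print(meta_learn_phi(tasks, candidates=[0.0, 0.5, 1.0]))   # 0.0 on these noise-free tasks
```

The output of \( g_{\boldsymbol{\varphi}} \) is directly readable as a slope, which illustrates the interpretability benefit. The trade-off is also visible: no choice of φ can make f anything other than a line.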

2. Implicit Amortization

In contrast, \( f_{\gamma} \) is learned and \( g \) is fixed (often identity). The model itself internalizes adaptation.

  • Examples: In-context learning, Prior-Fitted Networks.
  • Mechanism: A single network \( f_{\gamma} \) jointly processes query and context to predict outputs.
  • Benefits: Very expressive; the model can learn complex behavior directly.
  • Trade-offs: Expensive and opaque; the entire dataset must be reprocessed for every query.
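
An implicit learner in miniature: g is the identity, and a single function \( f_{\gamma} \) reads the query and the whole raw context together. Two stand-ins are assumptions. A softmax over negative squared distances plays the role of Transformer attention, and γ, a log-bandwidth, plays the role of its shared weights. `predict_batch` re-reads the entire context for every query, which is the cost the trade-off above refers to.

```python
import math


def f_gamma(x_query, context, gamma):
    """Predict y for x_query by attending over all (x, y) context pairs.

    gamma[0] is a log-bandwidth shared across tasks (the meta-learned part).
    Work is O(len(context)) per query because nothing is cached between queries.
    """
    bandwidth = math.exp(gamma[0])
    scores = [-((x_query - x) ** 2) / bandwidth for x, _ in context]
    top = max(scores)                               # subtract the max for numerical stability
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    return sum(w * y for w, (_, y) in zip(weights, context)) / total


def predict_batch(queries, context, gamma):
    # g is the identity, so each query triggers a fresh pass over the full context.
    return [f_gamma(x, context, gamma) for x in queries]


context = [(i / 10, 1.5 * i / 10) for i in range(-10, 11)]
print(predict_batch([0.3, -0.2], context, gamma=(math.log(0.01),)))   # roughly [0.45, -0.3]
```

A much larger bandwidth makes every context point count equally, so predictions collapse toward the context mean. The shared parameter γ is what encodes the cross-task inductive bias.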

3. Explicit Amortization

Both \( f_{\gamma} \) and \( g_{\boldsymbol{\varphi}} \) are learned—a hybrid approach.

  • Examples: Conditional Neural Processes (CNPs).
  • Mechanism: \( g_{\boldsymbol{\varphi}} \) compresses the task dataset into a latent embedding; \( f_{\gamma} \) uses this embedding to predict.
  • Benefits: Balances flexibility and interpretability.
  • Trade-offs: Harder to train since both parts influence each other dynamically.
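
Below is a minimal sketch in the style of a conditional neural process. \( g_{\boldsymbol{\varphi}} \) encodes each context pair and mean-pools the codes into a fixed-size task embedding. \( f_{\gamma} \) then decodes a query against that embedding. The hand-picked features (x * y and x * x), the scalar weights in φ and γ, and the function names are assumptions standing in for the encoder and decoder MLPs. The point is the shape of the computation. Unlike the implicit case, the embedding is computed once per task and reused for every query.

```python
def encode_pair(x, y, phi):
    """Per-pair encoder (stand-in for an MLP): a 2-D code scaled by phi."""
    return (phi[0] * x * y, phi[1] * x * x)


def g_phi(context, phi):
    """Permutation-invariant task embedding: the mean of the per-pair codes."""
    codes = [encode_pair(x, y, phi) for x, y in context]
    n = len(codes)
    return tuple(sum(code[i] for code in codes) / n for i in range(2))


def f_gamma(x, embedding, gamma):
    """Decoder (stand-in for an MLP): reads a query and the pooled task embedding."""
    r_xy, r_xx = embedding
    return x * r_xy / (r_xx + gamma)


context = [(1.0, 2.0), (2.0, 4.0)]
r = g_phi(context, phi=(1.0, 1.0))                        # computed once per task
print([f_gamma(x, r, gamma=0.0) for x in (3.0, -1.0)])    # [6.0, -2.0]
```

With γ = 0, scaling both entries of φ by the same constant leaves every prediction unchanged. Even in this toy, the encoder’s and decoder’s parameters are entangled.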

The Scalability Problem

Despite their differences, most amortized learners share a common weakness: scalability.

  • Implicit models are limited by Transformer context length.
  • Parametric and explicit systems rely on pooled summaries or gradients, which lose fine-grained task information.

Large datasets overwhelm these models. Standard optimization solves this problem using mini-batches in stochastic gradient descent (SGD), iteratively refining parameters step by step. Can we bring the same principle to amortized learning?


Iterative Amortized Inference: Learning in Mini-Batches

The authors propose Iterative Amortized Inference (IAI)—a scalable approach where amortization itself occurs iteratively across mini-batches, much like SGD refines parameters over time.

For Parametric and Explicit Models

The model starts with an initial state \( \boldsymbol{\theta}^{(0)} \) and refines it step by step via a learned update function \( h_{\boldsymbol{\varphi}} \):

\[ \boldsymbol{\theta}^{(0)} \xrightarrow{h_{\boldsymbol{\varphi}}(\cdot, \mathcal{B}_{\text{train}}^{(0)})} \boldsymbol{\theta}^{(1)} \xrightarrow{h_{\boldsymbol{\varphi}}(\cdot, \mathcal{B}_{\text{train}}^{(1)})} \dots \xrightarrow{h_{\boldsymbol{\varphi}}(\cdot, \mathcal{B}_{\text{train}}^{(k-1)})} \boldsymbol{\theta}^{(k)} \eqqcolon \boldsymbol{\theta}_{\mathcal{T}} \]

Unlike learned optimizers that operate only on gradients, IAI allows updates based on raw data, gradients, or both—making it more flexible.
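
The recursion above is a loop over mini-batches. The sketch below assumes a scalar state θ for 1-D linear tasks. It uses a hypothetical update `h_phi` that combines a gradient signal with a raw-data signal (the batch’s own least-squares slope), weighted by two coefficients in φ. In IAI, \( h_{\boldsymbol{\varphi}} \) would be a trained network. With φ = (lr, 0), this toy reduces to mini-batch SGD, the analogy the method builds on.

```python
def minibatches(data, batch_size):
    """Yield consecutive mini-batches B^(0), B^(1), ... of the task's training data."""
    for start in range(0, len(data), batch_size):
        yield data[start:start + batch_size]


def mse_grad(theta, batch):
    """Gradient of the batch mean squared error of y_hat = theta * x."""
    return sum(2 * x * (theta * x - y) for x, y in batch) / len(batch)


def h_phi(theta, batch, phi):
    """Hypothetical learned update mixing a gradient signal and a raw-data signal.

    phi = (gradient_step, pull_toward_batch_fit); both would be meta-learned.
    """
    grad = mse_grad(theta, batch)
    sxx = sum(x * x for x, _ in batch)
    batch_fit = sum(x * y for x, y in batch) / sxx if sxx > 0 else theta
    return theta - phi[0] * grad + phi[1] * (batch_fit - theta)


def iterative_amortized_inference(data, phi, theta0=0.0, batch_size=8):
    """theta^(k+1) = h_phi(theta^(k), B^(k)); the final state is the task representation."""
    theta = theta0
    for batch in minibatches(data, batch_size):
        theta = h_phi(theta, batch, phi)
    return theta


xs = [-1 + 2 * i / 63 for i in range(64)]
data = [(x, 1.5 * x) for x in xs]
print(iterative_amortized_inference(data, phi=(0.5, 0.0)))   # gradient-only: mini-batch SGD
print(iterative_amortized_inference(data, phi=(0.0, 1.0)))   # data-only: adopts each batch fit
```

Each call to `h_phi` touches only one mini-batch, so memory per step does not depend on the dataset size. Only the number of steps grows with it.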

For Implicit Models

The model refines predictions directly. Starting with \( \hat{\mathbf{y}}^{(0)} \), a recurrent Transformer \( r_{\gamma} \) updates it successively:

\[ \hat{\mathbf{y}}^{(0)} \xrightarrow{r_{\gamma}([\mathbf{x}, \hat{\mathbf{y}}^{(0)}], \mathcal{B}_{\text{train}}^{(0)})} \hat{\mathbf{y}}^{(1)} \xrightarrow{r_{\gamma}([\mathbf{x}, \hat{\mathbf{y}}^{(1)}], \mathcal{B}_{\text{train}}^{(1)})} \cdots \xrightarrow{r_{\gamma}([\mathbf{x}, \hat{\mathbf{y}}^{(k-1)}], \mathcal{B}_{\text{train}}^{(k-1)})} \hat{\mathbf{y}}^{(k)} \]
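
In the implicit case, the state being refined is the prediction itself. The sketch assumes a hypothetical refinement step `r_gamma` that reads the query, the current prediction, and one mini-batch. It moves the prediction a fraction γ toward an estimate computed from that batch alone. The per-batch least-squares estimate stands in for whatever a trained recurrent Transformer would compute. The fixed blending rule is an assumption, chosen so the behavior is easy to check.

```python
def minibatches(data, batch_size):
    """Yield consecutive mini-batches of the task's training data."""
    for start in range(0, len(data), batch_size):
        yield data[start:start + batch_size]


def r_gamma(query_and_prediction, batch, gamma):
    """One refinement step: (x, y_hat^(k)) and B^(k) -> y_hat^(k+1).

    gamma in (0, 1] controls how far the prediction moves toward the batch's estimate.
    """
    x, y_hat = query_and_prediction
    sxx = sum(bx * bx for bx, _ in batch)
    if sxx == 0:
        return y_hat                      # uninformative batch: keep the current prediction
    batch_estimate = x * sum(bx * by for bx, by in batch) / sxx
    return (1 - gamma) * y_hat + gamma * batch_estimate


def refine_prediction(x, data, gamma, y_hat0=0.0, batch_size=8):
    """Apply r_gamma once per mini-batch, carrying only the latest prediction forward."""
    y_hat = y_hat0
    for batch in minibatches(data, batch_size):
        y_hat = r_gamma((x, y_hat), batch, gamma)
    return y_hat


xs = [-1 + 2 * i / 63 for i in range(64)]
data = [(x, 1.5 * x) for x in xs]
print(refine_prediction(0.4, data, gamma=0.5))   # approaches 0.6 as batches accumulate
```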

Figure 1. Iterative Amortized Inference for parametric, explicit, and implicit setups. Parametric/explicit models refine shared task states, while implicit models iteratively update predictions.

To train these models efficiently, the authors optimize for one-step improvements—without backpropagating through past iterations. This greedy approach is simple, stable, and scalable.
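
One-step training shows up clearly in a toy where \( h_{\boldsymbol{\varphi}} \) is a gradient step with a single meta-learned step size φ. This is an assumption, far simpler than the paper’s networks. At each refinement step the current state θ^(k) is treated as a constant. The gradient of the query loss with respect to φ therefore needs only the chain rule through that one step, and no history is stored or backpropagated. The task family, learning rate, and batch sizes are illustrative choices.

```python
import random


def draw(rng, a, n):
    """Sample n pairs from the hypothetical task y = a * x with x uniform in [-1, 1]."""
    return [(x, a * x) for x in (rng.uniform(-1.0, 1.0) for _ in range(n))]


def mse_grad(theta, batch):
    """d/dtheta of the batch mean squared error of y_hat = theta * x."""
    return sum(2 * x * (theta * x - y) for x, y in batch) / len(batch)


def h(theta, batch, phi):
    """Toy learned update: one gradient step whose size phi is meta-learned."""
    return theta - phi * mse_grad(theta, batch)


def train_phi_greedy(n_tasks=200, n_steps=4, batch_size=8, meta_lr=0.05, seed=0):
    rng = random.Random(seed)
    phi = 0.0
    for _ in range(n_tasks):
        a = rng.uniform(0.5, 2.0)
        query = draw(rng, a, batch_size)        # held-out pairs used to score each step
        theta = 0.0
        for _ in range(n_steps):
            batch = draw(rng, a, batch_size)
            g = mse_grad(theta, batch)          # theta^(k) is treated as a constant here
            theta_next = h(theta, batch, phi)
            # Chain rule through this single step only:
            # dL/dphi = dL/dtheta_next * dtheta_next/dphi = mse_grad(theta_next, query) * (-g)
            phi -= meta_lr * mse_grad(theta_next, query) * (-g)
            theta = theta_next                  # carried forward, never differentiated again
    return phi


print(train_phi_greedy())   # a positive step size learned without unrolling past steps
```

Because each update needs only the current state and one batch, memory stays constant in the number of refinement steps. Full backpropagation through an unrolled trajectory would grow with every step.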


Experiments: Refinement Pays Off

Extensive experiments across regression, classification, and generative tasks show that multiple iterative refinement steps consistently improve performance.

Parametric Amortization

Table 2. Parametric amortization results. Error decreases with more steps across tasks. Out-of-distribution (OoD) columns in gray show improved transfer to new datasets.

Explicit Amortization

Table 3. Explicit amortization consistently benefits from additional refinement steps, especially when leveraging gradient signals.

Implicit Amortization

Table 4. Iterative prediction refinement significantly lowers error across varied classification tasks.

Generative Modeling

In more complex generative settings, the iterative process improves sample quality. The model learns to reconstruct the underlying distribution of alphabets and Gaussian mixtures, with visibly clearer structure after 10 refinement steps.

Figure 2. Implicit generative models for Gaussian mixtures and alphabets after 1 vs 10 steps show progressive refinement toward ground truth.


Key Insights from the Analysis

  1. Gradients vs. Data: Gradient-only signals are efficient but limited. Combining gradients with raw observations gives richer information and better generalization.

  2. Recent History Suffices: Providing multiple past states yields little benefit; relying on the latest one (a Markovian design) works well.

Figure 3. Using 3 or 5 past states does not outperform simple one-state updates.

  3. Parametric Wins on Stability: Models with a fixed \( f \) often outperform their explicit counterparts, highlighting the optimization difficulties of co-training both networks.

  4. Efficiency: For attention-based models, iterative amortization makes cost grow linearly with the number of batches \(K\) rather than quadratically. Processing \(K\) batches of size \(B\) costs \(O(KB^2)\) instead of the \(O((KB)^2)\) of a single pass over all \(KB\) examples, making IAI more compute- and memory-efficient.
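
Here is the arithmetic behind the efficiency claim. It counts only pairwise attention interactions, an assumption that ignores constant factors and any extra state tokens. Attending over all \(KB\) examples at once costs \((KB)^2\). \(K\) steps that each attend within one batch cost \(KB^2\). The saving factor is therefore exactly \(K\).

```python
def full_context_cost(num_batches, batch_size):
    """Pairwise interactions when all examples sit in one context: (K * B) ** 2."""
    return (num_batches * batch_size) ** 2


def iterative_cost(num_batches, batch_size):
    """K refinement steps, each attending within a single batch of B examples: K * B ** 2."""
    return num_batches * batch_size ** 2


B = 100
for K in (1, 10, 100):
    full, it = full_context_cost(K, B), iterative_cost(K, B)
    print(K, full, it, full // it)
# 1 10000 10000 1
# 10 1000000 100000 10
# 100 100000000 1000000 100
```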


Conclusion: Toward Scalable, Adaptive Learning Systems

The paper’s framework unifies meta-learning, in-context learning, and learned optimizers under one mathematical umbrella:

\[ f_{\gamma}(\mathbf{x}, g_{\boldsymbol{\varphi}}(\mathcal{D}_{\mathcal{T}})) \]

By extending amortization into an iterative process—mirroring the success of stochastic optimization—the authors enable models to scale to large datasets while retaining fast adaptation. This bridges optimization-based and forward-pass paradigms, revealing them as complementary approaches to the same goal: reusing inductive bias efficiently across tasks.

Looking ahead, this iterative perspective opens the door to richer amortized learners—ones that refine themselves continuously, just as human reasoning does, learning to learn one batch at a time.