Training large-scale neural networks like Transformers is a cornerstone of modern AI, but it’s also one of the hardest parts of the process. At the center of this challenge lies the optimizer — algorithms like Adam or SGD that iteratively adjust model parameters to minimize a loss function. Achieving top performance typically requires a laborious, resource-heavy cycle of trial and error, with endless hyperparameter tuning for each new architecture or task.

But what if we could learn the optimizer itself? What if, instead of hand-designing new training strategies, we could create a neural network that learns how to train other neural networks?

That vision lies at the heart of Learning to Learn (L2L) — also known as meta-learning. And the new DeepMind research paper “Mnemosyne: Learning to Train Transformers with Transformers” pushes this idea further than ever before. It introduces Mnemosyne, a learnable optimizer built on a specially designed Transformer architecture. Unlike previous learned optimizers that rely on recurrent models like LSTMs, Mnemosyne uses an innovative spatio-temporal attention mechanism that allows it to learn optimization strategies across space (parameters) and time (training history).

The results are striking: Mnemosyne can train entire neural networks — including massive Transformers — without any manual hyperparameter tuning. It rivals or even surpasses top-performing optimizers like Adam, while maintaining similar memory and computational efficiency.


From Hand-Tuned to Learned Optimizers

To grasp Mnemosyne’s breakthrough, it helps to formalize what an optimizer does. Training a neural network is a sequential decision process. At each iteration \( t \), the optimizer updates the model’s parameters \( \mathbf{x}_t \) to minimize its loss \( f(\mathbf{x}_t) \).

For a learnable optimizer, we can define this update rule as:

\[ \mathbf{x}_{t+1} = g_\theta(f, \mathbf{x}_0, ..., \mathbf{x}_t) \]

Here, \( g_\theta \) is the optimizer — itself a neural network parameterized by \( \theta \). During meta-training, \( g_\theta \) learns to minimize a “meta-loss,” usually the sum of the optimizee’s losses over several training steps. Put simply, we train the optimizer to become good at training other networks.
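
To make this concrete, here is a minimal sketch of that meta-objective in plain NumPy. It assumes a toy quadratic optimizee and a placeholder `learned_optimizer` function standing in for \( g_\theta \); the names and the sign-descent rule are illustrative, not taken from the paper.

```python
import numpy as np

def meta_loss(theta, f, grad_f, x0, num_steps, learned_optimizer):
    """Sum of the optimizee's losses over an unrolled training run.

    `learned_optimizer` stands in for g_theta: it maps (theta, gradient,
    state) to (parameter update, new state)."""
    x, state, total = x0, None, 0.0
    for _ in range(num_steps):
        update, state = learned_optimizer(theta, grad_f(x), state)
        x = x + update            # x_{t+1} produced by the learned optimizer
        total += f(x)             # accumulate the meta-loss
    return total

# Toy optimizee: a quadratic bowl. The "learned" optimizer is just a
# sign-descent rule scaled by theta -- a stand-in for a neural network.
f = lambda x: 0.5 * np.sum(x ** 2)
grad_f = lambda x: x
toy_optimizer = lambda theta, g, state: (-theta * np.sign(g), state)

print(meta_loss(0.1, f, grad_f, x0=np.ones(4), num_steps=20,
                learned_optimizer=toy_optimizer))
```

In real meta-training, the gradient of this meta-loss with respect to \( \theta \) is backpropagated through the unrolled optimization steps.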

Traditional learned optimizers have used LSTMs as memory components, capable of remembering short-term patterns during optimization. However, LSTMs can suffer from catastrophic forgetting, losing track of useful long-term information. Transformers, by contrast, thrive on modeling long-range dependencies through attention. The challenge is their quadratic complexity with respect to sequence length — a serious issue for optimization histories spanning thousands of steps. Mnemosyne solves this elegantly.


Inside Mnemosyne: A Transformer That Optimizes Transformers

Mnemosyne isn’t a single block but a modular architecture combining two key encoders: one for spatial structure and one for temporal memory. It supports two modes of operation — coordinate-wise and tensor-wise — depending on whether individual parameters or entire tensors are optimized together.

Figure 2: Two complementary application modes of Mnemosyne. (a) Coordinate-wise: each parameter \( r_i \) passes through a Compact Associative Memory (CAM) and an MLP to generate its update. (b) Tensor-wise: a tensor \( S_T \) goes through the Hierarchical Pooling Encoder (HPE), CAM, and Spatial Encoder (SPE) pipeline, producing a compact representation \( e \) that informs the final update.

Let’s break down these components.


Topological Encoder: Learning the Structure of the Model

Instead of viewing a network’s millions of parameters as a flat list, Mnemosyne respects the natural hierarchy of tensors and layers — weight matrices, bias vectors, and modules form a “parameter tree.” In tensor-wise mode, Mnemosyne looks at an entire tensor at once, converting each parameter into a simple 2D representation (often just the magnitude and sign of its gradient). This long sequence of tokens could easily exceed a million elements, too large even for efficient linear Transformers.

To tackle this, Mnemosyne introduces the Hierarchical Pooling Encoder (HPE), which progressively compresses the tensor into a few high-level meta-tokens.

Figure 1: Mnemosyne’s Hierarchical Pooling Encoder (HPE) and Compact Associative Memory (CAM) building blocks. The HPE compresses large tensors into compact meta-tokens through multi-layer pooling with Performers; these feed into the CAM, which produces a learned readout for the optimizer.

How HPE works:

  1. Split the tensor’s sequence of parameter encodings into manageable chunks.
  2. Use a bi-directional Performer (an efficient Transformer variant with linear attention) to encode each chunk in parallel.
  3. Pool the encoded chunks to form shorter sequences of meta-tokens.
  4. Repeat this process hierarchically until only a small fixed set of tokens remains.

The result is a compact representation of the tensor that captures spatial correlations among parameters with constant memory cost.
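
The sketch below illustrates the idea under simplifying assumptions: each parameter is encoded as \( [|g|, \mathrm{sign}(g)] \), the bi-directional Performer is replaced by a shared linear map with a nonlinearity, and pooling is a plain mean over each chunk. Chunk size, token count, and dimensions are arbitrary choices for illustration, not the paper's.

```python
import numpy as np

def encode_chunk(chunk, proj):
    """Stand-in for the bi-directional Performer block: a shared linear
    map plus nonlinearity. Purely illustrative."""
    return np.tanh(chunk @ proj)

def hpe(grad_tensor, chunk_size=256, num_meta_tokens=8, dim=16, seed=0):
    """Minimal Hierarchical Pooling Encoder sketch: encode each parameter
    as [|g|, sign(g)], then repeatedly chunk, encode, and mean-pool until
    only a handful of meta-tokens remain."""
    rng = np.random.default_rng(seed)
    g = grad_tensor.reshape(-1)
    tokens = np.stack([np.abs(g), np.sign(g)], axis=-1)        # (N, 2)
    tokens = np.tanh(tokens @ rng.normal(size=(2, dim)))       # lift to dim
    proj = rng.normal(size=(dim, dim)) / np.sqrt(dim)          # shared weights
    while tokens.shape[0] > num_meta_tokens:
        n_chunks = int(np.ceil(tokens.shape[0] / chunk_size))
        tokens = np.stack([encode_chunk(c, proj).mean(axis=0)
                           for c in np.array_split(tokens, n_chunks)])
    return tokens                              # at most num_meta_tokens rows

meta_tokens = hpe(np.random.randn(512, 768))   # e.g. one weight matrix
print(meta_tokens.shape)
```

Because each level shrinks the sequence by roughly the chunk size, even a tensor with millions of parameters collapses to a handful of meta-tokens after a few pooling levels.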


Temporal Encoder: A Memory That Never Forgets

Mnemosyne’s temporal encoder is designed to model the optimization process over time. It maintains a memory of all past steps and learns how to use this history to inform future updates.

The architecture implements a Compact Associative Memory (CAM) — a combination of Hopfield-like energy memory and Transformer-style attention. At each training step, CAM receives meta-tokens (from the spatial encoder) and updates a fixed-size hidden state composed of two key matrices, \( \mathbf{N}_t \) and \( \Psi_t \), defined as:

\[ \mathbf{N}_t = \sum_{\mu=1}^{t} \lambda_t(\mu)\,\phi(\mathbf{k}^\mu)(\mathbf{v}^\mu)^\top,\quad \Psi_t = \sum_{\mu=1}^{t} \lambda_t(\mu)\,\phi(\mathbf{k}^\mu) \]

Here, \( \phi(\cdot) \) represents randomized low-rank feature maps approximating the attention kernel (via the Performer framework), and \( \lambda_t(\mu) = e^{-\tau(t-\mu)} \) is a discount factor that exponentially attenuates older memories. These quantities can be updated efficiently online:

\[ \mathbf{N}_{t+1} = e^{-\tau}\mathbf{N}_t + \phi(\mathbf{k}^{t+1})(\mathbf{v}^{t+1})^\top,\quad \Psi_{t+1} = e^{-\tau}\Psi_t + \phi(\mathbf{k}^{t+1}) \]

To produce the next parameter update, the CAM computes:

\[ \Delta\xi = \frac{\mathbf{N}_t^\top\phi(\mathbf{q})}{\phi(\mathbf{q})^\top\Psi_t} \]

This gives a convex combination of past value vectors weighted by similarity between the current query \( \mathbf{q} \) and stored keys \( \mathbf{k}^{\mu} \). Unlike standard attention, CAM doesn’t need to store all past keys and values explicitly. It achieves constant-time and constant-memory complexity with respect to training history — the holy grail for learned optimizers that scale.
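
Below is a minimal NumPy sketch of these update and readout rules. The positive random-feature map, feature count, and decay rate \( \tau \) are illustrative stand-ins for the Performer machinery; only the recursions for \( \mathbf{N}_t \), \( \Psi_t \), and the readout mirror the formulas above.

```python
import numpy as np

class CompactAssociativeMemory:
    """Minimal sketch of a CAM with Performer-style positive random features.
    Feature map, dimensions, and decay are illustrative choices."""

    def __init__(self, key_dim, val_dim, num_features=64, tau=0.05, seed=0):
        rng = np.random.default_rng(seed)
        self.W = rng.normal(size=(num_features, key_dim))  # random projections
        self.tau = tau
        self.N = np.zeros((num_features, val_dim))   # N_t
        self.Psi = np.zeros(num_features)            # Psi_t

    def _phi(self, x):
        # Positive random features: exp(Wx - ||x||^2 / 2), normalized.
        return np.exp(self.W @ x - 0.5 * np.dot(x, x)) / np.sqrt(len(self.W))

    def write(self, key, value):
        # N_{t+1} = e^{-tau} N_t + phi(k) v^T ;  Psi_{t+1} = e^{-tau} Psi_t + phi(k)
        fk = self._phi(key)
        self.N = np.exp(-self.tau) * self.N + np.outer(fk, value)
        self.Psi = np.exp(-self.tau) * self.Psi + fk

    def read(self, query):
        # Delta xi = N_t^T phi(q) / (phi(q)^T Psi_t): a weighted average of
        # stored values, weighted by kernelized key-query similarity.
        fq = self._phi(query)
        return (self.N.T @ fq) / (fq @ self.Psi + 1e-9)

# Usage: the memory footprint stays fixed no matter how many steps are stored.
cam = CompactAssociativeMemory(key_dim=8, val_dim=4)
for t in range(1000):
    cam.write(np.random.randn(8), np.random.randn(4))
print(cam.read(np.random.randn(8)).shape)   # (4,)
```

Note that `N` and `Psi` never grow: a thousand writes cost the same memory as one, which is exactly the constant-memory property described above.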


Theoretical Backbone: Exponential Memory in a Compact Form

Mnemosyne’s CAM doesn’t just work empirically — it’s mathematically justified. The authors prove that this compact associative memory can store and retrieve an exponential number of distinct patterns relative to its dimensionality (Theorem 4.3).

In other words, even though the CAM only retains a condensed representation of past optimization states, it still exhibits the exponential memory capacity characteristic of full attention or Hopfield-like networks. This means Mnemosyne can “remember” a rich and diverse set of optimization trajectories while maintaining small, fixed-size hidden states — which helps explain its strong generalization ability.


Experiments: Mnemosyne in Action

Mnemosyne’s effectiveness isn’t just theoretical; it’s demonstrated on multiple fronts — fine-tuning vision models, pre-training language models, and scaling to enormous Transformers.


Warm-Up: Comparing Against Traditional Optimizers

The team started by benchmarking Mnemosyne against popular optimizers (Adam, RMSProp, SGD) and a recurrent LSTM-based learned optimizer. The target models were small Vision Transformers trained on standard datasets like MNIST and CIFAR.

Figure 3: Mnemosyne outperforms conventional and LSTM-based optimizers when training ViTs and MLPs, learning faster and reaching lower final loss than Adam, RMSProp, SGD, and LSTM-based learned optimizers, including when embedded within the VeLO architecture.

Mnemosyne achieved faster convergence and lower losses across all tasks, despite minimal meta-training. It effectively trained attention-based architectures that the LSTM optimizer struggled with, a sign of its robust memory mechanism.


Scaling Up: Coordinate-Wise Mnemosyne

Next came large-scale experiments. The coordinate-wise version, in which each parameter is processed individually by the CAM and MLP, was tested on fine-tuning Vision Transformers (ViT-H) and on soft prompt-tuning the massive T5XXL model (11B+ parameters).

For ViT fine-tuning, Mnemosyne matched the optimal Adam variant — even though Adam had been meticulously tuned across different learning rates.

Figure 4: Coordinate-wise Mnemosyne matches top-performing Adam variants for ViT-H fine-tuning across datasets, without any manual hyperparameter tuning.

Soft prompt-tuning the T5XXL model demonstrates Mnemosyne’s ability to handle extreme scale. Although the trainable prompt module contains only about 12K parameters, gradients flow through the frozen 11B-parameter model, making the optimization landscape complex. Mnemosyne consistently achieved lower losses than all of the Adam variants tested.

Figure 6: Mnemosyne in soft prompt-tuning of T5XXL, achieving smoother convergence and lower final loss than several Adam baselines.


Tensor-Wise Mnemosyne: Memory-Efficient Optimization

To train massive models end-to-end, the tensor-wise variant compresses memory by operating on tensors instead of individual parameters. The authors used this version to pre-train BERT-base (86M parameters) on the Masked Language Modeling (MLM) task.

Figure 7: Tensor-wise Mnemosyne matches the best Adam variant on masked language modeling (MLM) loss in BERT pre-training, while other Adam settings converge to higher losses.

Despite never seeing a language task during meta-training (which was done on small MNIST classifiers), Mnemosyne successfully generalized — achieving comparable or better results than finely tuned Adam optimizers.


Super-Mnemosyne: The Best of Two Worlds

Finally, to blend the simplicity of the coordinate-wise approach with the efficiency of the tensor-wise variant, the team created Super-Mnemosyne, a hybrid that uses coordinate-wise optimization for large tensors and tensor-wise for smaller ones.

Figure 9: Super-Mnemosyne combines coordinate- and tensor-wise modes, consistently performing at or above the best Adam variants when fine-tuning ViT-B on multiple datasets.

Across ViT fine-tuning tasks, Super-Mnemosyne consistently outperformed the best hand-tuned Adam baseline without any extra hyperparameter adjustments.


Why Mnemosyne Matters

Mnemosyne introduces an entirely new paradigm in optimization research:

  • Performance: It consistently matches or outperforms state-of-the-art optimizers like Adam and RMSProp.
  • No Hyperparameter Tuning: Works out-of-the-box, eliminating the extensive search for learning rates and momentum values.
  • Scalability: Efficient hierarchical spatial encoding and compact temporal memory make it viable for models with billions of parameters.
  • Generalization: Trained on simple tasks, Mnemosyne generalizes across architecture types, domains, and scales.

Looking Ahead

By bridging Transformers and learned optimization, Mnemosyne opens doors to a future where optimizers themselves are intelligent agents — architectures that understand and adapt to the dynamics of each model they train. Such systems could drastically reduce experimentation time, improve robustness, and even learn entirely new paradigms of training beyond gradient descent.

The Mnemosyne framework transforms optimization from a hand-engineered craft into a learnable capability — a true “Transformer that teaches Transformers how to learn.”