Transformers have taken the world by storm, powering everything from ChatGPT to advanced code completion tools. One of their most magical abilities is in-context learning (ICL) — the power to learn from examples provided in their input prompt, without any weight updates. If you show a large language model a few examples of a task, it can often perform that task on a new example instantly.
For a long time, how this works has been a bit of a mystery. Recent research has started to peel back the layers, suggesting that for simple tasks like linear regression, transformers internally run a form of gradient descent. Each attention layer acts like a single optimization step, refining an internal “solution” based on the data in the prompt.
But this raises a tantalizing question: is that all transformers can do? Are they limited to mimicking simple, well-known algorithms? A new paper from Google Research, “Linear Transformers are Versatile In-Context Learners,” argues that the answer is a resounding no. The researchers demonstrate that even simplified “linear” transformers can discover and implement remarkably sophisticated optimization algorithms — ones that adapt dynamically to data noise and outperform standard methods.
This work suggests that transformers aren’t just learners; they may be algorithmic inventors. Let’s explore how.
Setting the Stage: Linear Transformers and Noisy Problems
To isolate the core mechanisms behind in-context learning, the researchers focus on a minimalist model: the linear transformer. Unlike standard architectures, these models strip away MLP layers and LayerNorm, leaving only the linear self-attention layers.
A linear attention layer processes a sequence of tokens \( e_1, e_2, ..., e_n \). For each token \( e_i \), it computes an update by attending to all other tokens in the sequence:
\[ e_i \leftarrow e_i + \frac{1}{n} \sum_{k} \sum_{j=1}^{n} P_k\, e_j \left( e_j^{\top} Q_k\, e_i \right), \]
where each head \(k\) contributes an update through its learned matrices \(Q_k\) and \(P_k\), and all heads' updates are summed.
Each token \( e_i \) combines a feature vector and a label, \( e_i = (x_i, y_i) \). The input sequence contains \(n\) examples plus one query token \( e_{n+1} = (x_t, 0) \). The transformer’s goal is to predict the correct label \(y_t\) for query \(x_t\).
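To make this concrete, here is a minimal NumPy sketch of a single linear-attention layer acting on a stack of tokens. The single-head restriction, the function name, and the 1/n scaling convention are simplifying assumptions for illustration, not code from the paper.

```python
import numpy as np

def linear_attention_layer(E, P, Q):
    """One linear self-attention update (single head for simplicity):
        e_i <- e_i + (1/n) * P @ (sum_j e_j e_j^T) @ Q @ e_i
    E is the (n, d) matrix whose rows are the tokens e_i = (x_i, y_i);
    P and Q are the learned (d, d) matrices of the head.
    For multiple heads, sum one such delta per (P_k, Q_k) pair.
    """
    n = E.shape[0]
    S = E.T @ E                    # sum_j e_j e_j^T: the sequence statistics
    delta = E @ Q.T @ S @ P.T / n  # row i equals (1/n) * P @ S @ Q @ e_i
    return E + delta
```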
The researchers explore two problem setups:
- Fixed Noise Variance: Every sequence is generated with the same noise level \( \sigma^2 \). Previous work showed that transformers trained on this type of data learn an algorithm similar to gradient descent, called GD++.
- Mixed Noise Variance: Each sequence has a different noise level, drawn from a distribution like \( U(0, \sigma_{\max}) \). This scenario is harder: the optimal solution requires ridge regression, which adds a regularization term that depends on the noise variance.
Ridge Regression adds a penalty to large weights, improving robustness to data noise.
Because the noise level changes with every prompt, the transformer must infer it from the data and adapt — a much more difficult task than simple least squares.
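For reference, here is a small NumPy sketch contrasting plain least squares with ridge regression. Tying the regularizer to the noise variance (lam = noise_var, the optimal choice under a unit-variance Gaussian prior on the weights) is an illustrative assumption; the point is that the right amount of regularization depends on a noise level the transformer is never told.

```python
import numpy as np

def least_squares(X, y):
    # Plain least squares: ignores the noise level and can overfit noisy labels.
    return np.linalg.solve(X.T @ X, X.T @ y)

def ridge(X, y, noise_var):
    # Ridge regression: the penalty lam * ||w||^2 shrinks the weights.
    # Illustrative choice: lam = noise_var (optimal for a unit-variance prior on w).
    lam = noise_var
    d = X.shape[1]
    return np.linalg.solve(X.T @ X + lam * np.eye(d), X.T @ y)
```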
The Core Insight: Each Layer Updates an Internal Model
The researchers’ first major theoretical result explains what the linear transformer is doing internally. Each layer maintains an implicit linear regression model and updates it as data flows through.
At layer \(l\), each token’s output label can be written as
\[ y_i^{l+1} = a^l y_i - \langle w^l, x_i \rangle, \]
where \(w^l\) is an implicit weight vector and \(a^l\) is a learned scaling coefficient. These values evolve across layers — the model effectively performs iterative updates to its internal regression parameters.
Each layer computes updates using quantities like feature covariances and cross-correlations from the previous layer.
Under simplifying assumptions (restricting to diagonal attention matrices), the updates for \(w^l\) and a helper "momentum" vector \(u^l\) echo gradient descent with momentum.
In compact form, they parallel classic momentum dynamics:
\[ u^{l+1} = (1-\beta)u^l + \nabla f(w^l), \]
\[ w^{l+1} = w^l - \eta u^l. \]
The analogy isn't perfect — the transformer's coefficients are matrices, not scalars — but the similarity is striking. Each layer acts like a sophisticated optimization step on an internal model, not merely a data transformation.
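For comparison, a scalar-coefficient version of that iteration looks like the sketch below; the values of eta and beta are illustrative, and in the transformer both are replaced by layer-dependent matrices.

```python
import numpy as np

def momentum_gd(grad_f, w0, eta=0.1, beta=0.9, steps=20):
    """Momentum-style iteration matching the compact form above:
        u_{l+1} = (1 - beta) * u_l + grad_f(w_l)
        w_{l+1} = w_l - eta * u_l
    """
    w, u = w0.copy(), np.zeros_like(w0)
    for _ in range(steps):
        u_next = (1 - beta) * u + grad_f(w)
        w = w - eta * u          # uses the previous u, exactly as written above
        u = u_next
    return w
```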
Dissecting the Learned Algorithm
Surprisingly, models with diagonal attention weights performed nearly as well as models with full matrices. This allowed the team to analyze the learned algorithm piece by piece using four scalar parameters:
\( \omega_{xx}, \omega_{xy}, \omega_{yx}, \omega_{yy} \).
These parameters describe the flow of information between features and labels across layers.
Each parameter controls specific dynamics:
1. Preconditioned Gradient Descent (\( \omega_{yx} \), \( \omega_{xx} \))
When only these two components are active, the transformer performs GD++ — a form of gradient descent enhanced with preconditioning.
GD++ updates \(x\) through preconditioning while adjusting \(y\) via gradient steps.
The authors prove that GD++ operates as a second-order optimization method, akin to Newton’s algorithm. It converges in \(O(\log\log (1/\epsilon))\) steps, explaining its efficiency on simpler regression tasks.
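A single GD++ layer can be sketched as follows, acting directly on the in-context examples; w_yx and w_xx stand in for the two active diagonal parameters, and their values here are illustrative.

```python
import numpy as np

def gd_plus_plus_layer(X, Y, w_yx=0.5, w_xx=0.1):
    """One GD++ layer: a gradient-style step on the labels plus feature preconditioning.
    X: (n, d) in-context features; Y: (n,) in-context labels.
    """
    n, d = X.shape
    Sigma = X.T @ X / n                       # feature covariance
    xy = X.T @ Y / n                          # feature-label correlation
    Y_next = Y - w_yx * X @ xy                # gradient step adjusts the labels
    X_next = X @ (np.eye(d) - w_xx * Sigma)   # preconditioning contracts the features
    return X_next, Y_next
```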
2. Adaptive Rescaling (\( \omega_{yy} \))
The component \( \omega_{yy} \) introduces noise-aware scaling. Its update rule simplifies to a uniform rescaling of the labels, roughly \( y_i^{l+1} = (1 + \omega_{yy} \lambda^l)\, y_i^l \), where \( \lambda^l = \sum_i (y_i^l)^2 \) is the total label energy: the larger the label energy, the stronger the scaling adjustment.
When the data exhibits high variance (larger \(\lambda^l\)), a negative \(\omega_{yy}\) shrinks outputs — mirroring ridge regression’s regularization. It automatically adjusts based on noise intensity, implementing an adaptive correction to stabilize learning.
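In code, the effect of \( \omega_{yy} \) alone amounts to a data-dependent rescaling of the labels; the coefficient value below is a placeholder.

```python
import numpy as np

def adaptive_rescale(Y, w_yy=-0.01):
    """Rescale every label by (1 + w_yy * label_energy).
    With w_yy < 0, noisier prompts (larger label energy) are shrunk more,
    mimicking the shrinkage effect of ridge regression.
    """
    label_energy = float(np.sum(Y ** 2))  # lambda^l in the text
    return (1.0 + w_yy * label_energy) * Y
```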
3. Adaptive Step Sizes (\( \omega_{xy} \))
The last term, \( \omega_{xy} \), allows dynamic step-size control. Its effect unfolds in two layers:
- A first layer alters \(x_i\) according to residual errors.
- A second layer performs gradient descent using the adjusted features.
The model automatically modulates its effective step size using data residuals.
The step size becomes proportional to \( (1 + \omega_{xy} \sum_i r_i^2) \), where \( \sum_i r_i^2 \) estimates noise variance. A negative \( \omega_{xy} \) yields smaller steps in noisy regimes — an adaptive “early stopping” mechanism learned from experience.
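A sketch of that mechanism, with the in-context residuals standing in for the noise estimate and both coefficients as illustrative placeholders:

```python
import numpy as np

def adaptive_step(X, Y, w, base_step=0.5, w_xy=-0.01):
    """Gradient step whose effective size scales with (1 + w_xy * sum_i r_i^2).
    The squared residuals act as a noise estimate; with w_xy < 0, noisier data
    yields smaller steps, an implicit early-stopping effect.
    """
    n = X.shape[0]
    r = Y - X @ w                      # residuals on the in-context examples
    step = base_step * (1.0 + w_xy * float(np.sum(r ** 2)))
    return w + step * X.T @ r / n      # one gradient step on the squared loss
```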
Together, these elements form a finely tuned optimization procedure that combines advanced gradient, scaling, and noise-sensitive adjustments — discovered automatically by the transformer.
Experiments: Testing Emergent Optimization
To evaluate these insights, the team trained three linear transformer variants on mixed-noise regression tasks:
- FULL: Full-matrix attention parameters.
- DIAG: Diagonal matrix restriction.
- GD++: Simplified variant emulating standard preconditioned gradient descent.
They compared these to Ridge Regression baselines, including a tuned version (TUNEDRR) with idealized noise estimation.
For larger noise ranges and deeper models, DIAG and FULL outperform the baselines, while GD++ fails to adapt to variable noise.
With increasing layer depth and noise diversity (\(\sigma_{\max}\)), DIAG and FULL achieved near-perfect noise adaptation, even exceeding tuned Ridge Regression for high-noise settings.
Analyzing performance across specific noise levels reveals a clear split: GD++ stagnates near constant performance, revealing that it effectively assumes a single average noise variance, while DIAG and FULL learn flexible strategies that remain effective across the full range of noise levels.
Examining the learned weights shows that the attention matrices become almost perfectly diagonal. This heavy diagonal dominance confirms that the diagonal analysis accurately captures the transformer's essential behavior.
Conclusion: Transformers as Algorithm Discoverers
This research opens a window into transformers’ inner workings, revealing that even simple linear variants can discover powerful optimization algorithms through training alone.
Key insights include:
- Implicit Optimization: Each layer performs iterative updates akin to gradient descent with momentum — but richer, involving matrix dynamics and adaptive control.
- Algorithm Discovery under Noise: Faced with noisy data, transformers spontaneously invent strategies like adaptive rescaling and variable step-size control, crucial for robust optimization.
- Empirical Superiority: The learned algorithms rival or surpass closed-form solutions like tuned Ridge Regression, highlighting transformers' ability to craft effective solutions autonomously.
In a broader sense, this work positions transformers not only as learners but as emergent algorithm designers. By training on well-chosen problems, we could harness them to discover new classes of optimization and learning algorithms — extending beyond the boundaries of conventional machine learning design.
This study shines light on how complex reasoning and computation can emerge from simple architectural principles, and it invites future research toward automatic algorithm discovery through neural networks.