The Transformer architecture, with its powerful self-attention mechanism, has revolutionized machine learning. From generating human-like text with GPT models to creating stunning images, its impact is undeniable. At its heart, self-attention allows a model to weigh the importance of every single piece of input when processing any other piece. This gives it a comprehensive, global understanding of the data.
But this power comes at a steep price: the computational and memory costs of self-attention grow quadratically with sequence length — \(O(n^2)\). This means that doubling the sequence length quadruples the cost. For sequences of a few thousand tokens, this is manageable. But what about modeling an entire book, a high-resolution image, or a full-length symphony? The quadratic scaling quickly becomes a prohibitive bottleneck, making it incredibly difficult to apply Transformers to truly long sequences.
Researchers have tried to sidestep this limitation with clever tricks. Some methods, like local attention, restrict the model to look at only a small, fixed window of recent inputs. This is efficient but sacrifices the ability to connect distant, important pieces of information. Other methods have explored content-based sparsity patterns, but often they require calculating the full, dense attention matrix before sparsifying — defeating the purpose of efficiency.
This is where the paper Efficient Content-Based Sparse Attention with Routing Transformers comes in. The authors propose a solution that combines the best of both worlds: the modeling flexibility of content-based attention with the efficiency of sparse, local methods. Their model — the Routing Transformer — learns dynamic, data-dependent attention patterns without ever computing the enormous full attention matrix. It reduces complexity from \(O(n^2d)\) to a much more manageable \(O(n^{1.5}d)\), enabling processing of sequences orders of magnitude longer than before and setting new state-of-the-art results on several challenging benchmarks.
Let’s dive in and see how they did it.
A Quick Refresher on Self-Attention
Before we get to the new mechanism, let’s quickly recap how standard self-attention works, especially in autoregressive tasks like generating text or images pixel-by-pixel.
An autoregressive model generates a sequence \(x = (x_1, \dots, x_n)\) one step at a time, modeling the probability of each element given all previous ones:
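\[
p(x) = \prod_{i=1}^{n} p(x_i \mid x_1, \dots, x_{i-1}).
\]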
At each step, the Transformer processes the input matrix \(X\) (shape \(n \times d\)) through a series of self-attention layers. Inside each layer, the input is projected into three matrices — Queries (Q), Keys (K), and Values (V):
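\[
Q = XW_Q, \qquad K = XW_K, \qquad V = XW_V,
\]
where \(W_Q, W_K, W_V \in \mathbb{R}^{d \times d}\) are learned projection matrices.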
Think of it like this:
- Query — the current token asking, “Who in the past is relevant to me?”
- Key — a past token saying, “This is what I represent.”
- Value — a past token saying, “If you find me relevant, here’s my information.”
To determine relevance, the model computes dot products between every query \(Q_i\) and every key \(K_j\). The results are scaled and passed through a softmax, creating an attention matrix \(A\). In autoregressive settings, a causal mask (lower triangular) prevents looking at future tokens:
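\[
A = \mathrm{softmax}\!\left(\mathrm{mask}\!\left(\frac{QK^\top}{\sqrt{d}}\right)\right),
\]
where the mask sets every entry with \(j > i\) to \(-\infty\) before the softmax.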
Finally, the output for each position is computed as a weighted sum over the Value vectors:
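\[
X'_i = \sum_{j \le i} A_{ij} V_j, \qquad \text{or in matrix form, } X' = AV.
\]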
This is followed by residual connections and layer normalization:
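\[
X_{\text{out}} = \mathrm{LayerNorm}(X + X')
\]
(the exact placement of layer normalization, before or after the residual addition, varies across Transformer variants).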
The core problem is that the attention matrix \(A\) for a length-\(n\) sequence has size \(n \times n\): storing and computing it incurs the dreaded \(O(n^2)\) bottleneck.
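To make the bottleneck concrete, here is a minimal NumPy sketch of dense causal self-attention for a single head (an illustrative toy, not any particular library's implementation); the \(n \times n\) score matrix it materializes is exactly the object the Routing Transformer avoids:

```python
import numpy as np

def causal_self_attention(X, W_q, W_k, W_v):
    """Minimal dense causal self-attention for one head; illustrative sketch only."""
    n, d = X.shape
    Q, K, V = X @ W_q, X @ W_k, X @ W_v            # project the input into queries, keys, values
    scores = Q @ K.T / np.sqrt(d)                  # (n, n) pairwise scores: the O(n^2) object
    mask = np.tril(np.ones((n, n), dtype=bool))    # causal mask: position i sees only j <= i
    scores = np.where(mask, scores, -1e9)
    A = np.exp(scores - scores.max(axis=-1, keepdims=True))
    A = A / A.sum(axis=-1, keepdims=True)          # row-wise softmax gives the attention matrix
    return A @ V                                   # weighted sum of value vectors
```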
The Core Idea: Content-Based Sparse Attention
Instead of every query attending to every previous key, what if it could attend to only a small, carefully chosen subset? We can define a set \(S_i\) containing the indices of the keys that the query at position \(i\) is allowed to attend to:
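\[
X'_i = \sum_{j \in S_i} A_{ij} V_j, \qquad S_i \subseteq \{\, j : j \le i \,\}.
\]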
Local attention uses a fixed recent window. Strided attention samples keys at regular intervals. The figure below visualizes these and the Routing Transformer’s learned pattern:
Figure 1: Visual comparison of attention patterns. (a) Local attention only looks at a nearby window. (b) Strided attention looks at regular steps into the past. (c) Routing attention dynamically learns content-based clusters.
Fixed sparsity patterns are rigid. If an important fact appeared far back in the context, local attention would miss it. Routing Transformer’s goal is to make \(S_i\) content-dependent without computing the full matrix.
Routing Attention with k-Means Clustering
The intuition: if a query and a key are semantically similar, they should probably attend to each other. To find these matches efficiently in very long sequences, the model groups similar vectors together rather than comparing every query against every key.
The Routing Transformer learns \(k\) centroid vectors that define clusters in the shared query/key embedding space. Within each routing attention layer:
- Assign each query to its nearest centroid.
- Assign each key to its nearest centroid.
Attention is then restricted to queries and keys in the same cluster:
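Writing \(\mu(\cdot)\) for the function that maps a vector to its nearest centroid, the attended set becomes (in the paper's notation, up to minor details):
\[
S_i = \{\, j : j \le i,\ \mu(Q_i) = \mu(K_j) \,\}.
\]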
This creates a dynamic, content-aware sparsity pattern like Figure 1(c). Related tokens — regardless of temporal distance — are clustered and attend to each other.
The Theory: Approximating Maximum Inner Product Search (MIPS)
In dot-product attention, the score \(Q_i^\top K_j\) measures relevance. Finding the highest scores amounts to solving Maximum Inner Product Search for each query:
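\[
j^\star = \arg\max_{j \le i} \; Q_i^\top K_j.
\]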
MIPS is expensive. But if vectors are normalized to unit length (on a hypersphere), MIPS becomes equivalent to Nearest Neighbor Search:
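\[
\|Q_i - K_j\|^2 = \|Q_i\|^2 + \|K_j\|^2 - 2\,Q_i^\top K_j = 2 - 2\,Q_i^\top K_j,
\]
so for unit-norm vectors, \(\arg\max_j Q_i^\top K_j = \arg\min_j \|Q_i - K_j\|^2\).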
Minimizing Euclidean distance \(\|Q_i - K_j\|^2\) is the same as maximizing the dot product for unit vectors — and k-means clustering is a natural tool for grouping nearest neighbors. By triangle inequality:
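\[
\|Q_i - K_j\| \;\le\; \|Q_i - \mu\| + \|\mu - K_j\|,
\]
where \(\mu\) is a centroid close to both.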
If both \(Q_i\) and \(K_j\) are close to the same centroid, they’re close to each other, ensuring high dot product. Thus, clustering gives an efficient, principled approximation to MIPS.
Complexity & Implementation Details
New attention complexity:
- Clustering: Assigning \(n\) queries and \(n\) keys to \(k\) centroids: \(O(nkd)\).
- Attention within clusters: Each query attends to \(\approx n/k\) keys: \(O(n \cdot (n/k) \cdot d)\).
Total: \(O(nkd + n^2d/k)\). Choosing \(k = \sqrt{n}\) balances the terms, yielding overall \(O(n^{1.5}d)\).
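Setting the two terms equal gives the balance point:
\[
nkd = \frac{n^2 d}{k} \;\Rightarrow\; k = \sqrt{n}, \qquad \text{so each term is } O(n^{1.5} d).
\]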
To keep clusters balanced for efficient parallel computation, the authors have each centroid select its top-\(w\) closest queries and keys, where \(w = n/k\), rather than letting cluster sizes vary. The centroids themselves are not trained by gradient descent; they are updated during training with an exponential moving average of the queries and keys assigned to them:
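In the paper this update takes roughly the following form, where \(\lambda\) is a decay rate close to 1 and the sums run over the queries and keys currently assigned to centroid \(\mu\) (the exact normalization is an implementation detail):
\[
\mu \;\leftarrow\; \lambda\,\mu \;+\; \frac{1-\lambda}{2} \sum_{i:\,\mu(Q_i)=\mu} Q_i \;+\; \frac{1-\lambda}{2} \sum_{j:\,\mu(K_j)=\mu} K_j .
\]

To see how these pieces fit together, here is a minimal NumPy sketch of a single routing attention head: unit-normalize queries and keys, let each centroid keep its top-\(w\) queries and keys so clusters stay balanced, and run causal attention inside each cluster. It is an illustrative re-implementation under those assumptions, not the authors' code (which additionally handles multi-head batching, the centroid EMA updates, and the local-attention heads):

```python
import numpy as np

def routing_attention(Q, K, V, centroids, w):
    """Illustrative single-head causal routing attention (not the authors' implementation).

    Q, K, V: (n, d) arrays. centroids: (k, d) cluster centers. w: tokens kept per cluster.
    """
    n, d = Q.shape
    k = centroids.shape[0]

    # Unit-normalize so nearest-centroid assignment approximates maximum inner product search.
    def unit(x):
        return x / (np.linalg.norm(x, axis=-1, keepdims=True) + 1e-6)

    q_sim = unit(Q) @ unit(centroids).T   # (n, k) query-to-centroid similarities
    k_sim = unit(K) @ unit(centroids).T   # (n, k) key-to-centroid similarities

    # Balanced assignment: each centroid keeps its top-w queries and top-w keys,
    # so every cluster has the same size and the per-cluster attention can be batched.
    q_members = np.argsort(-q_sim, axis=0)[:w]   # (w, k): query indices per cluster
    k_members = np.argsort(-k_sim, axis=0)[:w]   # (w, k): key indices per cluster

    out = np.zeros_like(V)
    for c in range(k):
        qi, kj = q_members[:, c], k_members[:, c]
        scores = Q[qi] @ K[kj].T / np.sqrt(d)          # (w, w) block instead of (n, n)
        causal = qi[:, None] >= kj[None, :]            # a query only sees keys at its own or earlier positions
        scores = np.where(causal, scores, -1e9)
        weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
        weights = np.where(causal, weights, 0.0)       # rows with no valid key end up all-zero
        weights = weights / (weights.sum(axis=-1, keepdims=True) + 1e-9)
        # A token selected by several centroids simply accumulates contributions here;
        # the paper treats overlapping assignments more carefully.
        out[qi] += weights @ V[kj]
    return out
```

With \(k = \sqrt{n}\) and \(w = n/k\), each of the \(k\) inner attention blocks is only \(w \times w\), which is exactly where the \(O(n^{1.5}d)\) total cost comes from.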
Experiments: From CIFAR-10 to PG-19
In most experiments (except PG-19), half the attention heads use local attention, half use routing. This hybrid maintains strong local structure while enabling global, content-dependent links.
CIFAR-10 Ablations
On the small CIFAR-10 image dataset, ablation studies compared different numbers of routing heads/layers and attention windows:
Table 1: Ablations show routing attention improves performance over both local and full attention.
Key findings:
- Local attention is surprisingly competitive: 3.009 bits/dim vs. full attention's 2.983, while being much faster.
- Routing helps: Adding routing heads/layers lowers bits/dim, best at 2.971.
- Content matters: Random selection hurts performance (3.076 bits/dim).
Wikitext-103 (Language Modeling)
On this long-context benchmark, Routing Transformer achieved 15.8 perplexity — beating Transformer-XL’s 18.3, with fewer layers:
Table 2: Routing Transformer sets a new state-of-the-art on Wikitext-103.
ImageNet-64 (Autoregressive Image Generation)
Treating each 64 × 64 × 3 image as a sequence of 12,288 tokens, the Routing Transformer achieved 3.43 bits/dim, edging out the prior best of 3.44:
Table 4: State-of-the-art image generation.
PG-19 (Document-Level Language Modeling)
Full books averaging roughly 69k words each — a demanding long-sequence stress test. The 22-layer Routing Transformer achieved 33.2 perplexity, beating the heavier 36-layer Compressive Transformer:
Table 5: Routing Transformer excels at ultra-long text modeling.
Why the Hybrid Works: Local + Routing
To understand the hybrid’s effectiveness, the authors measured the Jensen-Shannon Divergence between attention distributions of local vs. routing heads:
Table 6: High divergence between local and routing heads confirms complementarity.
Findings:
- Local Heads: Low divergence with each other; they learn similar patterns over nearby positions.
- Routing Heads: Very different from local — capture long-range, content-based links.
Local builds fluent short-range structure (syntax, local cohesion).
Routing ensures global consistency (maintaining themes, resolving references over thousands of tokens).
Conclusion & Takeaways
The Routing Transformer offers a powerful, efficient way to overcome the quadratic bottleneck of self-attention:
- Breaking the Quadratic Barrier: From \(O(n^2d)\) to \(O(n^{1.5}d)\), enabling training directly on very long sequences.
- Best of Both Worlds: Combines local efficiency and inductive bias with the flexibility of content-based attention.
- Proven Performance: Sets new state-of-the-art results on multiple large-scale benchmarks in language and image generation.
- Expanding Possibilities: Opens the door to Transformer applications on long-form data — document summarization, translation, genomics, high-res video, and more.
The Routing Transformer elegantly tames the quadratic beast of self-attention, enabling models to see, understand, and generate with a truly global perspective.