Transformer 架构凭借其强大的自注意力机制,彻底改变了机器学习领域。从 GPT 模型生成媲美人类的文本,到创造令人惊叹的图像,其影响力毋庸置疑。自注意力的核心在于: 在处理某个输入片段时,模型能够权衡所有其他输入片段的重要性;这赋予了它对数据全面、全局的理解能力。

但这种能力也伴随着高昂的代价: 自注意力的计算和内存成本会随着序列长度二次方增长——即 \(O(n^2)\)。这意味着序列长度翻倍,成本就会增加到原来的四倍。对于几千个词元 (token) 的序列,这尚可接受。但如果要处理一整本书、一张高分辨率图像或一首完整的交响乐呢?二次方扩展会迅速成为令人望而却步的瓶颈,使得将 Transformer 应用于真正的长序列变得极其困难。

研究人员尝试过使用一些巧妙的技巧来规避这一限制。一些方法,如局部注意力,将模型的注意力范围限制在一个固定大小的近期输入窗口内。这种方法效率高,但牺牲了连接远处重要信息的能力。另一些方法探索基于内容的稀疏模式,但它们通常需要先计算出完整且密集的注意力矩阵再进行稀疏化——这违背了提升效率的初衷。

这正是论文 Efficient Content-Based Sparse Attention with Routing Transformers 的切入点。作者提出了一种结合两者优势的方案: 兼具基于内容的注意力的建模灵活性与原生稀疏、局部方法的效率。他们的模型——路由 Transformer (Routing Transformer) ——能够在不计算庞大的完整注意力矩阵的情况下,学习动态、依赖数据的注意力模式。它将复杂度从 \(O(n^2d)\) 降至更易管理的 \(O(n^{1.5}d)\),使得处理比以往长几个数量级的序列成为可能,并在多个具有挑战性的基准测试中创造了新的 SOTA 纪录。

让我们深入看看他们是如何做到的。


快速回顾: 自注意力机制

在介绍新机制之前,我们先简单回顾一下标准自注意力的工作原理,尤其是在自回归任务中,比如逐字生成文本或逐像素生成图像。

自回归模型一次生成序列 \(x = (x_1, \dots, x_n)\) 的一个元素,建模在给定所有先前元素的条件下每个元素的概率:

自回归序列建模的公式。

在每一步中,Transformer 会将输入矩阵 \(X\) (形状为 \(n \times d\)) 输入一系列自注意力层。在每个层内,输入被投影为三组矩阵——查询 (Queries, Q)键 (Keys, K)值 (Values, V) :

生成查询、键和值的公式。

可以这样理解:

  • 查询 (Query) — 当前词元在问: “过去的序列中,谁与我相关?”
  • 键 (Key) — 过去的某个词元在说: “我代表的是这个信息。”
  • 值 (Value) — 过去的某个词元在说: “如果你觉得我相关,这是我的信息。”

为了确定相关性,模型会计算每个查询 \(Q_i\) 与每个键 \(K_j\) 的点积。结果缩放后经过 softmax 函数,得到注意力矩阵 \(A\)。在自回归任务中,会使用**因果掩码 **(下三角矩阵) 来避免看到未来的词元:

因果注意力矩阵的公式。

最后,每个位置的输出通过对值向量进行加权求和得到:

将输出计算为值向量加权和的公式。

之后是残差连接和层归一化:

残差连接和层归一化的公式。

核心问题在于,对于长度为 \(n\) 的序列,注意力矩阵 \(A\) 的大小为 \(n \times n\): 存储和计算它会遇到可怕的 \(O(n^2)\) 瓶颈。


核心思想: 基于内容的稀疏注意力

如果每个查询不是关注所有之前的键,而是只关注一个经过精心挑选的小子集呢?我们可以定义一个集合 \(S_i\),其中包含位置 \(i\) 的查询可以关注的键的索引:

稀疏注意力的公式。

局部注意力使用固定的近期窗口。步进注意力则以固定间隔采样键。下图展示了这些模式以及路由 Transformer 学到的模式:

不同注意力模式的示意图。局部注意力呈对角带状。步进注意力呈多条对角条纹状。路由注意力则是在对角线周围散布着彩色的块。

图 1: 注意力模式对比。(a) 局部注意力只关注邻近窗口。(b) 步进注意力按固定步长回顾。(c) 路由注意力动态学习基于内容的簇。

固定的稀疏模式过于死板。如果重要信息出现在很早的上下文中,局部注意力就会错过它。路由 Transformer 的目标是在不计算完整矩阵的情况下,使 \(S_i\) 依赖内容。


使用 k-Means 聚类的路由注意力

直觉是: 如果查询和键在语义上相似,它们很可能应该相互关注。为了在超长序列中高效找到这些匹配,分组是有效的方法。

路由 Transformer 会学习 \(k\) 个质心向量,代表嵌入空间中的簇。在每一步中:

  1. 将每个查询分配给距离最近的质心。
  2. 将每个键分配给距离最近的质心。

然后,注意力被限制在同一簇内的查询和键之间:

路由注意力的公式,其中一个查询只关注同一簇中的键。

这样就产生了一个如图 1(c) 所示的动态、内容感知的稀疏模式。相关的词元——无论时间距离多远——都会被聚类到一起并相互关注。


理论基础: 近似最大内积搜索 (MIPS)

在点积注意力中,得分 \(Q_i^\top K_j\) 衡量相关性。对每个查询找到最高分便是在求解最大内积搜索 (Maximum Inner Product Search, MIPS) 问题:

最大内积搜索 (MIPS) 问题的公式。

MIPS 代价昂贵。但如果向量被归一化为单位长度 (位于超球面上) ,MIPS 等价于最近邻搜索 (Nearest Neighbor Search) :

展示单位向量的欧氏距离和点积之间关系的公式。

对于单位向量,最小化欧氏距离 \(\|Q_i - K_j\|^2\) 与最大化点积等价——而 k-means 聚类是分组最近邻的天然工具。根据三角不等式:

展示向量与其质心之间三角不等式的公式。

如果 \(Q_i\) 和 \(K_j\) 都接近同一质心,它们之间也会很接近,从而保证较高的点积。因此,聚类为 MIPS 提供了一种高效且有理论依据的近似方法。


复杂性与实现细节

新的注意力复杂性如下:

  1. 聚类: 将 \(n\) 个查询和 \(n\) 个键分配给 \(k\) 个质心: \(O(nkd)\)。
  2. 簇内注意力: 每个查询关注约 \(n/k\) 个键: \(O(n \cdot (n/k) \cdot d)\)。

总计: \(O(nkd + n^2d/k)\)。选择 \(k = \sqrt{n}\) 可平衡两项,最终得到总复杂度 \(O(n^{1.5}d)\)

为保持簇大小平衡以便高效并行计算,作者为每个质心选择前 \(w\) 个最接近的查询和键,其中 \(w = n/k\),而不是分配所有最近向量。质心为可学习参数,通过指数移动平均更新:

使用指数移动平均更新簇质心的公式。


实验: 从 CIFAR-10 到 PG-19

在大多数实验中 (PG-19 除外) ,一半注意力头使用局部注意力,另一半使用路由注意力。这种混合保持了强大的局部结构,同时实现了全局、依赖内容的连接。

CIFAR-10 消融研究
在 CIFAR-10 小型图像数据集上,消融研究比较了不同数量的路由头/层及不同注意力窗口大小:

展示 CIFAR-10 消融研究结果的表格。

表 1: 消融研究表明,路由注意力优于局部注意力和全注意力。

主要发现:

  1. 局部注意力很强大: 3.009 bits/dim 的性能高于全注意力的 2.983,且速度更快。
  2. 路由注意力有帮助: 增加路由头/层能降低 bits/dim,最佳为 2.971。
  3. 内容很重要: 随机选择会损害性能 (3.076 bits/dim) 。

Wikitext-103 (语言建模)
在此长上下文基准中,路由 Transformer 取得了 15.8 的困惑度——击败 Transformer-XL 的 18.3,且层数更少:

展示 Wikitext-103 实验结果的表格。路由 Transformer 取得了新的 SOTA 成绩。

表 2: 路由 Transformer 在 Wikitext-103 上创造了新的 SOTA 纪录。


ImageNet-64 (自回归图像生成)
将图像视作 12,288 像素序列: 路由 Transformer 获得 3.43 bits/dim,超越之前最佳的 3.44:

展示 ImageNet-64 实验结果的表格。路由 Transformer 取得了新的 SOTA 成绩。

表 4: SOTA 图像生成。


PG-19 (文档级语言建模)
整本书平均 6.9 万词——经典长序列压力测试。22 层的路由 Transformer 取得 33.2 困惑度,击败更重的 36 层 Compressive Transformer:

展示 PG-19 实验结果的表格。路由 Transformer 取得了新的 SOTA 成绩。

表 5: 路由 Transformer 在超长文本建模方面表现突出。


为什么混合模式有效: 局部 + 路由

为验证混合模式的有效性,作者测量了局部注意力与路由注意力头的注意力分布之间的 Jensen-Shannon 散度 (JS Divergence) :

展示局部注意力和路由注意力头之间 JS 散度的表格。

表 6: 局部注意力与路由注意力头之间的高散度证实了它们的互补性。

结论:

  • 局部注意力头: 彼此散度低——学到相似的局部邻近模式。
  • 路由注意力头: 与局部注意力差异大——捕捉长距离、基于内容的连接。

局部注意力构建流畅的短程结构 (语法、局部衔接) 。
路由注意力确保全局一致性 (保持主题、跨数千词元的指代解析) 。


结论与启示

路由 Transformer 提供了一种强大高效的方法来克服自注意力的二次方瓶颈:

  1. 打破二次方壁垒: 将复杂度从 \(O(n^2d)\) 降至 \(O(n^{1.5}d)\),可直接在超长序列上训练。
  2. 两全其美: 结合局部注意力的效率与归纳偏置,以及基于内容的注意力的灵活性。
  3. 性能卓越: 在语言与图像生成的多个大规模基准上创造新的 SOTA 纪录。
  4. 拓展可能性: 为 Transformer 应用于长篇数据打开大门——包括文档摘要、翻译、基因组学、高分辨率视频等。

路由 Transformer 优雅地驯服了自注意力的二次方猛兽,让模型能够真正从全局视角进行理解与生成。