在过去几年里,Transformer 架构一直是语言建模领域无可争议的王者。从 GPT-3 到 PaLM,大规模 Transformer 模型重新定义了业界的顶尖水平。但这种强大力量是有代价的: 作为 Transformer 核心的注意力机制,其计算和内存开销随序列长度呈二次方增长。处理一个两倍长的序列需要四倍的计算和内存。这使得处理超长文档、代码库或音频文件成为一项重大挑战。
有没有另一种方式?一种架构,其扩展性几乎与序列长度呈线性关系——在处理长序列时极为高效——同时仍然能匹敌注意力机制的建模能力?
状态空间模型 (SSM) 登场了。SSM 在音频生成和时间序列分析等领域表现出色,但在复杂的语言领域一直落于 Transformer 之后。斯坦福大学的一篇新论文《饿饿河马: 迈向基于状态空间模型的语言建模》(Hungry Hungry Hippos: Towards Language Modeling with State Space Models) 深入剖析了这一性能差距,诊断了其根本原因,并提出了一种新颖架构,不仅弥补了差距,在某些情况下甚至超越 Transformer。
这篇论文做出了两大贡献:
- H3 (Hungry Hungry Hippo): 一个基于 SSM 的新网络层,旨在解决以往 SSM 在语言任务中的特定短板。
- FLASHCONV: 一种硬件感知的算法,使 SSM 的训练和运行速度显著提升,克服了长期以来阻碍其发展的效率瓶颈。
让我们深入了解,看看这些饿饿河马是如何挑战 Transformer 的统治地位的。
背景: 什么是状态空间模型?
在介绍河马之前,我们需要了解状态空间模型的基本原理。SSM 源于控制理论,是用于对随时间演变的系统进行建模的强大工具。
离散时间 SSM 的核心,是通过一个隐藏的“状态”向量 \(x_i\),将输入序列 \(u_1, u_2, \ldots\) 映射到输出序列 \(y_1, y_2, \ldots\)。状态是序列历史到当前时刻的压缩摘要。该过程由两个简单的线性方程控制:
一个离散时间 SSM: 隐藏状态 \(x_i\) 使用当前输入 \(u_i\) 从 \(x_{i-1}\) 更新而来,输出 \(y_i\) 则由该状态计算得出。矩阵 A、B、C 和 D 是通过学习得到的。
这种形式本质上是循环的。计算第 \(i\) 步的输出时,只需要当前输入 \(u_i\) 和上一步的状态 \(x_{i-1}\)。这使得 SSM 在**推理 **(如一次生成一个词) 时非常高效,因为它们无需在每一步都重新读取整个历史。
与卷积的联系
循环视角非常适合生成任务,但在训练中,处理整段序列时会很慢。幸运的是,有一个巧妙的数学技巧: 任何线性时不变系统 (如 SSM) 也可以表示为卷积。
展开循环,输出 \(y\) 可以写成输入序列 \(u\) 与一个由学习到的矩阵确定的特殊滤波器 \(f\) 的卷积:
卷积滤波器 \(f = [CB, CAB, CA^2B, \dots]\) 编码了每个过去输入对当前输出的影响。
这意味着,我们可以通过执行 \(y = f * u\) 并行计算整个输出序列。而高效计算卷积的方法是什么?快速傅里叶变换 (FFT) 。借助 FFT,计算复杂度随序列长度 \(N\) 变化为 \(O(N \log N)\)——相对于 Transformer 的 \(O(N^2)\) 是巨大的渐进性提升。
那么,如果 SSM 如此高效,为什么它们没有在语言建模领域取而代之呢?
问题所在: 为何 SSM 在语言任务中表现欠佳
为了找出 SSM 逊于注意力机制的原因,研究人员使用了合成语言任务——这些简单的实验任务用于测试模型的特定能力。如果模型在这些任务上失败,则可能缺乏应对真实语言复杂度所需的基础构件。
他们聚焦于两个任务:
合成语言任务针对特定记忆和比较能力。
- 归纳头 (Induction Head): 模型看到一串随机词元,后跟一个特殊词元 (如
vdash
) ,再跟更多随机词元。目标是预测在前面序列中紧跟在该特殊词元之后的词元。 - 关联回忆 (Associative Recall): 模型看到一系列键–值对 (如
a 2 c 4 b 3
) 。末尾给一个键 (如a
) ,模型必须输出对应的值 (2
) 。
这些任务对 Transformer 来说轻而易举,因为注意力机制能够直接“回看”并比较词元。那么现有的 SSM 表现如何呢?
现有的 SSM 在需要精确回忆和跨序列比较的任务上表现失败。
常见 SSM 变体,如 S4D 和门控状态空间 (GSS) ,表现惨淡。这揭示了两项缺失能力:
- 回顾过去的特定词元。
- 跨序列比较不同的词元。
注意力机制在这两方面都表现卓越: 它构建了完整的 \(N \times N\) 比较矩阵 (\(QK^T\)) 并通过值向量 (\(V\)) 直接复制信息。目标是让 SSM 获得类似能力。
解决方案之一: 饿饿河马 (H3)
H3 层是一种全新的 SSM 架构,专为解决这些短板而生。核心思想是: 将两种不同 SSM 结合乘法交互,模拟注意力机制的比较–存储–回忆过程。
在 H3 中,一个移位 SSM 和一个对角 SSM 通过乘法门协同工作,在序列中存储和检索信息。
直观解释:
- 移位 SSM: 应用于 K 投影,其中 A 是一个移位矩阵。它形成对于近期输入的短期记忆——非常适合记住当前词元之前的内容。
- 对角 SSM: 基于 S4D,可记忆整段序列的信息。
- 乘法门控 (\(\odot\)) 提供选择性存储与检索:
- 第一层门:
SSM_shift(K) ⊙ V
仅在上一个词元匹配特定键时才存储值。 - 第二层门:
Q ⊙ SSM_diag(...)
仅在当前词元匹配查询键时才回忆该值。
- 第一层门:
核心 H3 操作: 乘法交互让记忆以键匹配为条件。
如图 1 (中) 所示,这个机制精准解决了关联回忆任务: 移位 SSM + 第一门充当存储,对角 SSM + 第二门充当回忆。
真实语言基准
这种增强的表达能力对实际语言建模影响巨大。
在 OpenWebText 上,H3 的困惑度几乎与 Transformer 持平,而混合模型则超越 Transforme。
纯 H3 模型将与 Transformer 的差距缩小到仅 0.4 困惑度点,较此前的 SSM 进步巨大。一个简单的 H3–注意力混合模型,仅有 2 层注意力,其余都是 H3,性能比纯 Transformer 高出整整一个点。
解决方案之二: 面向硬件效率的 FLASHCONV
H3 拥有了建模能力——但速度呢?从渐近复杂度上讲,SSM 更快 (\(O(N \log N)\) 对比 \(O(N^2)\)) ,但在真实 GPU 上往往更慢。原因是硬件抽签: 现代 GPU 针对注意力机制的巨型矩阵乘法进行了优化,而非 FFT。
解决办法是 FLASHCONV,一种 I/O 感知的卷积算法,灵感源自 FlashAttention,包含两大核心:
- 融合块 FFTConv —— 适用不超过 8K 词元的序列。
- 核函数融合: 将 FFT、乘法与逆 FFT 合并为单个核函数,最大限度减少高速 SRAM 与较慢 HBM 之间的数据传输开销。
- 块 FFT: 将 FFT 重写为一系列小矩阵乘法,充分利用 Tensor Core。
- 状态传递 —— 适用于超过 8K 词元的序列。
- 将输入切分成符合 SRAM 容量的块。
- 将每个块的最终状态作为下一块的初始状态向前传递。
- 重复此过程直至处理完整个序列。
状态传递让 FLASHCONV 能够高效处理任意长的序列。
FLASHCONV 性能
性能提升巨大。在长程竞技场 (LRA) 基准中,采用 FLASHCONV 的 S4 比 Transformer 快 5.8 倍。
在将 H3 与优化版 FlashAttention 对比的基准测试中,FLASHCONV 保持了近线性扩展,并在长序列上超越了注意力机制。
核函数融合加速短序列,块 FFT 优化中等序列,状态传递在长序列中占据优势。
对于长度 ≥16K 的序列,带状态传递的 SSM 甚至比最快的注意力实现快几十倍。
整合: 扩展 H3
有了表现力更强的新层 (H3) 和高速算法 (FLASHCONV) ,团队将 H3–注意力混合模型扩展到 1.25 亿至 27 亿参数规模,并在大规模 Pile 数据集上完成训练。
困惑度优势
在 The Pile、OpenWebText 和 WikiText-103 数据集上,H3 混合模型稳压同规模 Transformer。
零样本/少样本 SuperGLUE
在多数 SuperGLUE 任务上,H3 混合模型表现持平或超过 Transformer。
更快的推理
由于循环特性,H3 混合模型生成文本速度比 Transformer 快 2.4 倍——这对实时系统至关重要。
结论: 王座的新挑战者
《饿饿河马》提出了有力论证: 纯注意力模型可能不再是序列建模领域的唯一统治者。通过诊断 SSM 弱点并设计 H3 予以弥补,研究人员为语言模型构建了强大的新组件。
借助 FLASHCONV,他们清除了效率障碍——让 SSM 不仅在理论上更快,在实践中也更快。
一个简单的混合模型——多为 SSM,少量注意力层——性能就能优于纯 Transformer。未来或许属于融合互补优势的架构,而非固守单一路线。
凭借更好扩展性、更快推理速度和顶级表现,这些饿饿河马已对序列建模的王冠发起严肃挑战。