在过去几年里,Transformer 架构一直是语言建模领域无可争议的王者。从 GPT-3 到 PaLM,大规模 Transformer 模型重新定义了业界的顶尖水平。但这种强大力量是有代价的: 作为 Transformer 核心的注意力机制,其计算和内存开销随序列长度呈二次方增长。处理一个两倍长的序列需要四倍的计算和内存。这使得处理超长文档、代码库或音频文件成为一项重大挑战。

有没有另一种方式?一种架构,其扩展性几乎与序列长度呈线性关系——在处理长序列时极为高效——同时仍然能匹敌注意力机制的建模能力?

状态空间模型 (SSM) 登场了。SSM 在音频生成和时间序列分析等领域表现出色,但在复杂的语言领域一直落于 Transformer 之后。斯坦福大学的一篇新论文《饿饿河马: 迈向基于状态空间模型的语言建模》(Hungry Hungry Hippos: Towards Language Modeling with State Space Models) 深入剖析了这一性能差距,诊断了其根本原因,并提出了一种新颖架构,不仅弥补了差距,在某些情况下甚至超越 Transformer。

这篇论文做出了两大贡献:

  1. H3 (Hungry Hungry Hippo): 一个基于 SSM 的新网络层,旨在解决以往 SSM 在语言任务中的特定短板。
  2. FLASHCONV: 一种硬件感知的算法,使 SSM 的训练和运行速度显著提升,克服了长期以来阻碍其发展的效率瓶颈。

让我们深入了解,看看这些饿饿河马是如何挑战 Transformer 的统治地位的。


背景: 什么是状态空间模型?

在介绍河马之前,我们需要了解状态空间模型的基本原理。SSM 源于控制理论,是用于对随时间演变的系统进行建模的强大工具。

离散时间 SSM 的核心,是通过一个隐藏的“状态”向量 \(x_i\),将输入序列 \(u_1, u_2, \ldots\) 映射到输出序列 \(y_1, y_2, \ldots\)。状态是序列历史到当前时刻的压缩摘要。该过程由两个简单的线性方程控制:

离散时间状态空间表示通过一个状态变量,定义了从离散输入信号到离散输出信号的线性映射。

一个离散时间 SSM: 隐藏状态 \(x_i\) 使用当前输入 \(u_i\) 从 \(x_{i-1}\) 更新而来,输出 \(y_i\) 则由该状态计算得出。矩阵 ABCD 是通过学习得到的。

这种形式本质上是循环的。计算第 \(i\) 步的输出时,只需要当前输入 \(u_i\) 和上一步的状态 \(x_{i-1}\)。这使得 SSM 在**推理 **(如一次生成一个词) 时非常高效,因为它们无需在每一步都重新读取整个历史。


与卷积的联系

循环视角非常适合生成任务,但在训练中,处理整段序列时会很慢。幸运的是,有一个巧妙的数学技巧: 任何线性时不变系统 (如 SSM) 也可以表示为卷积

展开循环,输出 \(y\) 可以写成输入序列 \(u\) 与一个由学习到的矩阵确定的特殊滤波器 \(f\) 的卷积:

SSM 卷积滤波器 f 由学习到的矩阵 A、B 和 C 构建。

卷积滤波器 \(f = [CB, CAB, CA^2B, \dots]\) 编码了每个过去输入对当前输出的影响。

这意味着,我们可以通过执行 \(y = f * u\) 并行计算整个输出序列。而高效计算卷积的方法是什么?快速傅里叶变换 (FFT) 。借助 FFT,计算复杂度随序列长度 \(N\) 变化为 \(O(N \log N)\)——相对于 Transformer 的 \(O(N^2)\) 是巨大的渐进性提升。

那么,如果 SSM 如此高效,为什么它们没有在语言建模领域取而代之呢?


问题所在: 为何 SSM 在语言任务中表现欠佳

为了找出 SSM 逊于注意力机制的原因,研究人员使用了合成语言任务——这些简单的实验任务用于测试模型的特定能力。如果模型在这些任务上失败,则可能缺乏应对真实语言复杂度所需的基础构件。

他们聚焦于两个任务:

表 1 展示了两个合成语言建模任务: 归纳头 (Induction Head) 和关联回忆 (Associative Recall) 。

合成语言任务针对特定记忆和比较能力。

  1. 归纳头 (Induction Head): 模型看到一串随机词元,后跟一个特殊词元 (如 vdash) ,再跟更多随机词元。目标是预测在前面序列中紧跟在该特殊词元之后的词元。
  2. 关联回忆 (Associative Recall): 模型看到一系列键–值对 (如 a 2 c 4 b 3) 。末尾给一个键 (如 a) ,模型必须输出对应的值 (2) 。

这些任务对 Transformer 来说轻而易举,因为注意力机制能够直接“回看”并比较词元。那么现有的 SSM 表现如何呢?

表 2 显示,现有的 SSM (如 S4D 和 GSS) 在合成任务上表现不佳,而 H3 和注意力机制则取得了近乎完美的分数。

现有的 SSM 在需要精确回忆和跨序列比较的任务上表现失败。

常见 SSM 变体,如 S4D门控状态空间 (GSS) ,表现惨淡。这揭示了两项缺失能力:

  1. 回顾过去的特定词元
  2. 跨序列比较不同的词元。

注意力机制在这两方面都表现卓越: 它构建了完整的 \(N \times N\) 比较矩阵 (\(QK^T\)) 并通过值向量 (\(V\)) 直接复制信息。目标是让 SSM 获得类似能力。


解决方案之一: 饿饿河马 (H3)

H3 层是一种全新的 SSM 架构,专为解决这些短板而生。核心思想是: 将两种不同 SSM 结合乘法交互,模拟注意力机制的比较–存储–回忆过程。

图 1 (左) 展示了 H3 层的架构,它将一个移位 SSM 和一个对角 SSM 堆叠在一起,并在 Q、K、V 投影之间进行乘法交互。

在 H3 中,一个移位 SSM 和一个对角 SSM 通过乘法门协同工作,在序列中存储和检索信息。

直观解释:

  • 移位 SSM: 应用于 K 投影,其中 A 是一个移位矩阵。它形成对于近期输入的短期记忆——非常适合记住当前词元之前的内容。
  • 对角 SSM: 基于 S4D,可记忆整段序列的信息。
  • 乘法门控 (\(\odot\)) 提供选择性存储与检索:
    • 第一层门: SSM_shift(K) ⊙ V 仅在上一个词元匹配特定键时才存储值。
    • 第二层门: Q ⊙ SSM_diag(...) 仅在当前词元匹配查询键时才回忆该值。

核心 H3 公式,展示了一个 Q 投影与一个对角 SSM 的输出相乘,该对角 SSM 的输入是移位 SSM 在 K 上的输出与 V 相乘的结果。

核心 H3 操作: 乘法交互让记忆以键匹配为条件

如图 1 (中) 所示,这个机制精准解决了关联回忆任务: 移位 SSM + 第一门充当存储,对角 SSM + 第二门充当回忆


真实语言基准

这种增强的表达能力对实际语言建模影响巨大。

表 3 展示了在 OpenWebText 上的困惑度。H3 显著优于其他 SSM,与 Transformer 的差距仅为 0.4 PPL。H3 混合模型比 Transformer 的 PPL 低 1.0。

在 OpenWebText 上,H3 的困惑度几乎与 Transformer 持平,而混合模型则超越 Transforme。

纯 H3 模型将与 Transformer 的差距缩小到仅 0.4 困惑度点,较此前的 SSM 进步巨大。一个简单的 H3–注意力混合模型,仅有 2 层注意力,其余都是 H3,性能比纯 Transformer 高出整整一个点。


解决方案之二: 面向硬件效率的 FLASHCONV

H3 拥有了建模能力——但速度呢?从渐近复杂度上讲,SSM 更快 (\(O(N \log N)\) 对比 \(O(N^2)\)) ,但在真实 GPU 上往往更慢。原因是硬件抽签: 现代 GPU 针对注意力机制的巨型矩阵乘法进行了优化,而非 FFT。

解决办法是 FLASHCONV,一种 I/O 感知的卷积算法,灵感源自 FlashAttention,包含两大核心:

  1. 融合块 FFTConv —— 适用不超过 8K 词元的序列。
    • 核函数融合: 将 FFT、乘法与逆 FFT 合并为单个核函数,最大限度减少高速 SRAM 与较慢 HBM 之间的数据传输开销。
    • 块 FFT: 将 FFT 重写为一系列小矩阵乘法,充分利用 Tensor Core。
  2. 状态传递 —— 适用于超过 8K 词元的序列。
    • 将输入切分成符合 SRAM 容量的块。
    • 将每个块的最终状态作为下一块的初始状态向前传递。
    • 重复此过程直至处理完整个序列。

图 1 (右) 展示了 FlashConv 的状态传递算法,该算法以 8K 大小的块处理长序列,并将一个压缩状态从一个块传递到下一个块。

状态传递让 FLASHCONV 能够高效处理任意长的序列。


FLASHCONV 性能

性能提升巨大。在长程竞技场 (LRA) 基准中,采用 FLASHCONV 的 S4 比 Transformer 快 5.8 倍

表 8 显示,在 LRA 基准测试中,使用 FlashConv 的 S4 比标准 Transformer 实现了 5.8 倍的加速。

在将 H3 与优化版 FlashAttention 对比的基准测试中,FLASHCONV 保持了近线性扩展,并在长序列上超越了注意力机制。

图 2 展示了不同卷积算法的基准测试。FlashConv 几乎呈线性扩展,并且在序列长度超过 8K 时,速度显著快于 FlashAttention。

核函数融合加速短序列,块 FFT 优化中等序列,状态传递在长序列中占据优势。

对于长度 ≥16K 的序列,带状态传递的 SSM 甚至比最快的注意力实现快几十倍


整合: 扩展 H3

有了表现力更强的新层 (H3) 和高速算法 (FLASHCONV) ,团队将 H3–注意力混合模型扩展到 1.25 亿27 亿参数规模,并在大规模 Pile 数据集上完成训练。


困惑度优势

表 4 展示了不同大小的 H3 混合模型与 GPT-Neo 和 GPT-2 模型的困惑度对比。H3 混合模型始终表现更优。

在 The Pile、OpenWebText 和 WikiText-103 数据集上,H3 混合模型稳压同规模 Transformer。


零样本/少样本 SuperGLUE

表 5 展示了在 SuperGLUE 上的零样本准确率。在每个规模上,H3 混合模型在大多数任务上都优于或持平于 Transformer 基线模型。 表 6 展示了在 SuperGLUE 上的 3 样本准确率,H3 混合模型同样表现出色。

在多数 SuperGLUE 任务上,H3 混合模型表现持平或超过 Transformer。


更快的推理

表 7 展示了 13 亿参数模型的推理吞吐量。H3 混合模型生成词元的速度比 Transformer 快了 2.4 倍。

由于循环特性,H3 混合模型生成文本速度比 Transformer 快 2.4 倍——这对实时系统至关重要。


结论: 王座的新挑战者

《饿饿河马》提出了有力论证: 纯注意力模型可能不再是序列建模领域的唯一统治者。通过诊断 SSM 弱点并设计 H3 予以弥补,研究人员为语言模型构建了强大的新组件。

借助 FLASHCONV,他们清除了效率障碍——让 SSM 不仅在理论上更快,在实践中也更快。

一个简单的混合模型——多为 SSM,少量注意力层——性能就能优于纯 Transformer。未来或许属于融合互补优势的架构,而非固守单一路线。

凭借更好扩展性、更快推理速度和顶级表现,这些饿饿河马已对序列建模的王冠发起严肃挑战。