Transformer 在许多序列建模任务中占据主导地位,但其核心的自注意力机制在计算上与上下文长度成二次方关系。这种设计选择使得处理超长上下文在计算和内存方面代价高昂。与此同时,以 S4 和 Mamba 为代表的结构化状态空间模型 (SSM) 在序列长度上实现了线性扩展,并在自回归生成中保持恒定的状态维度。两种模型体系在发展过程中几乎沿着完全独立的路径成熟: 数学理论不同,优化方法不同,工程权衡也不同。
论文《Transformers 是 SSMs》在这两种范式之间架起了桥梁,并基于这种联系设计出一种更快、更加硬件友好的 SSM 变体: Mamba-2。核心洞见是,在以下三者之间存在原则性的对应关系: (i) SSM、(ii) 一类称为结构化掩码注意力 (SMA) 的广义注意力机制家族、(iii) 一类被称为半可分矩阵的经典结构化矩阵。这种“结构化状态空间对偶性” (SSD) 框架产生了既能兼具 SSM 递推的渐近效率,又能享受密集矩阵乘法硬件友好性的算法。
本文将逐步介绍主要思想,展示这种对偶性的来源,并解释驱动 Mamba-2 的 SSD 算法。在此过程中,我会引用论文中的图示使概念更直观。
图 1: 高层路线图。结构化矩阵在状态空间模型 (SSM) 与注意力机制之间架起桥梁,形成结构化状态空间对偶性 (SSD) 框架,催生 Mamba-2 架构。
本文的整体安排:
- 以能够显现联系的方式重温 SSM 与注意力机制。
- 介绍半可分矩阵,并展示 SSM 如何映射到结构化矩阵。
- 将线性注意力推广为结构化掩码注意力 (SMA) 。
- 阐述对偶性: 特定 SSM ↔ 特定 SMA。
- 解释混合型 SSD 算法 (分块分解) ,既在线性序列长度上工作,又对矩阵乘法友好。
- 介绍 Mamba-2 的架构设计、系统优化以及实验证据。
1 — 技术速览
1.1 什么是结构化 SSM?
结构化状态空间模型 (SSM) 是一种线性递推,更新内部状态并输出结果。在离散时间下,一个选择性 (时变) SSM可以写成:
\[ \begin{aligned} h_t &= A_t h_{t-1} + B_t x_t, \\ y_t &= C_t^\top h_t, \end{aligned} \]其中 \(x_t, y_t \in \mathbb{R}^P\) (P 为每个头或通道的维度) ,\(h_t \in \mathbb{R}^N\) 是内部状态 (状态扩展因子) 。SSM 称为“结构化”,是因为它的转移矩阵 \(A_t\) (或参数化形式) 受到特定约束,使得整体计算效率很高。
两种常见计算模式:
- 递推 (线性) 模式: 逐步更新 \(h_t\),复杂度为 \(O(TN)\) (每个头的序列长度为 \(T\)) 。
- 卷积 (并行) 模式: 当 \(A_t\) 是时不变的 (LTI) 时,SSM 等价于卷积,可通过 FFT 等方法并行计算。
选择性 SSM (时变 \(A_t\)) 在语言建模上更具表现力,但因高效实现常依赖专用核函数,硬件优化更困难。
1.2 什么是注意力 / 线性注意力?
标准 softmax 自注意力计算公式:
\[ Y = \operatorname{softmax}(Q K^\top)\,V, \]其中 \(Q,K,V \in \mathbb{R}^{T\times N}\),而 \(T\times T\) 的 Gram 矩阵带来二次方的计算成本。线性注意力移除 softmax 或用核特征映射 \(\psi(\cdot)\) 代替,从而:
\[ QK^\top \approx \psi(Q)\psi(K)^\top, \]然后利用结合律重排矩阵乘法:
\[ (QK^\top)V = Q(K^\top V) \quad\longrightarrow\quad Q \underbrace{(K^\top V)}_{\text{仅计算一次}}. \]再在 Gram 矩阵上加上因果掩码后,这种重新排序可形成递推 (累加和或其他结构化乘法) ,实现自回归生成的线性复杂度。
2 — 将 SSM 看作矩阵变换 (半可分视角)
除了递推,我们还可将整体的 \(x \mapsto y\) 转换看成一个在时间轴上的 \(T\times T\) 矩阵 \(M\):
\[ y = Mx, \qquad M_{j,i} = C_j^\top A_j A_{j-1}\cdots A_{i+1} B_i \quad\text{for } j\ge i. \]一个 SSM 定义了一个结构化序列到序列矩阵 \(M\)。关键性质是: 当内部维度为 \(N\) 时,\(M\) 主对角线及以下的任意子矩阵秩至多为 \(N\),这正是 (下三角) 半可分矩阵的定义特征。
直观理解:
- SSM 递推可用 \(O(TN)\) 线性复杂度计算 \(Mv\)。
- 显式构造 \(M\) 并计算 \(Mx\) 是 \(T^2\) 复杂度,但用的是密集矩阵乘法 (硬件友好) 。
这种双模式特性 (线性递推 vs 显式矩阵乘法) 就是 SSD 利用的计算对偶性。
图 2: 矩阵变换视角的状态空间模型。序列混合器 \(M \in \mathbb{R}^{T\times T}\) 是半可分的: 主对角线及以下的子矩阵秩最多为 \(N\),可通过递推或结构化矩阵算法计算 \(Mx\)。
特殊情况: 标量 SSM (\(N=1\)) ,即 \(A_t\) 简化为标量 \(a_t\)。此时 \(M_{j,i} = a_j\cdots a_{i+1}\) (\(j \ge i\)) ,与 \(M\) 相乘对应于简单递推:
\[ y_t = a_t y_{t-1} + x_t, \]即累乘-累加 (cumprodsum) 运算。1-SS (标量半可分) 乘法已有多种高效原语,并可推广到 \(N>1\) 半可分情形。
3 — 结构化掩码注意力 (SMA) : 推广线性注意力
将掩码注意力改写为单一张量收缩,可澄清多种算法顺序。掩码核注意力:
\[ Y = (L \circ (QK^\top))\,V \]是四个张量 \((Q,K,V,L)\) 在适当轴上的收缩。两种自然收缩顺序:
- 二次方 (注意力) 模式: 先算 \(G=QK^\top\),再 \(M=L\circ G\),最后 \(Y=MV\)。
- 线性 (扫描) 模式: 先算 \(Z=K^\top V\),再 \(H=LZ\),最后 \(Y=QH\)。瓶颈在 \(H=LZ\)。
线性注意力对应 \(L\) 为下三角全 1 因果掩码的情况,此时与掩码相乘就是累加。论文的洞见在于: 可用任何支持亚二次矩阵-向量乘法的结构化矩阵替换 \(L\),得到结构化掩码注意力 (SMA) :
定义 (非正式) : SMA 计算同样的四路收缩 \(Y = \operatorname{contract}(Q,K,V,L)\),但允许 \(L\) 为任意支持快速 matvec 的结构化掩码。两种算法模式 (二次方与线性) 是不同收缩顺序;当 \(L\) 结构化时,线性模式高效。
图 3: SMA 中,对于任意结构化掩码 \(L\) 构建 \(M = QK^\top \circ L\)。不同 \(L\) 会产生不同行为 (因果、衰减、Toeplitz、1-SS、傅里叶等) 。1-SS 掩码 (紫色行) 是 SSD 的核心。
\(L\) 的重要示例:
- 因果全一 \(L\): 标准线性注意力。
- 指数衰减 / Toeplitz 掩码: 类似 RetNet 或相对位置衰减。
- 1-半可分掩码: 源自标量 SSM (标量 \(a_t\) 的乘积) 。
4 — 对偶性: 当 SSM 成为注意力 (反之亦然)
标量单位阵 SSM (\(A_t = a_t I\)) 生成的半可分矩阵 \(M\) 可分解为:
\[ M = L \circ (C B^\top), \quad L_{j,i} = a_j a_{j-1}\cdots a_{i+1}. \]这正是二次方掩码核注意力形式,取 \(Q\leftarrow C\)、\(K\leftarrow B\)、掩码 \(L\)。因此,该 SSM 的二次方形式就是带 1-半可分 \(L\) 的 SMA。
反之,若 SMA 的结构化因果掩码 \(L\) 支持有界阶高效自回归更新,则 \(L\) 必是半可分的。因此,高效的自回归掩码注意力 ↔ 半可分 SMA,其中许多是 SSM 或变体。
简言之: 存在一个大类模型,既有线性时间 SSM 递推,又有二次时间注意力形式——这类模型构成状态空间对偶 (SSD) 类。
推论 (非正式) : 1-SS SMA 是对角 SSM (标量单位阵转移) 的特例。两种算法形式对偶: SSM 递推是线性算法;构造 \(M\) 做 \(Mx\) 是二次注意力算法。
这种等价是论文的核心,解释了为何能在注意力与 SSM 阵营间互译算法与优化。
5 — SSD: 实用且硬件高效的算法
两种算法形式互补:
- SSM 递推在 \(T\) 上渐近线性,但多为标量/逐元素操作,不总能充分利用 GPU 张量核心。
- 注意力 (二次) 形式硬件友好: 大规模批矩阵乘法,适合 GPU/TPU,但复杂度 \(T^2\)。
SSD 通过对半可分矩阵 \(M\) 做分块分解结合二者:
- 将序列分成 \(B = T/Q\) 个长度 \(Q\) 的块,等价将 \(M\) 刻成 \(B\times B\) 块下三角矩阵。
- 对角块 (块内交互) 是小 \(Q\times Q\) 半可分矩阵,用二次 SMA 模式批矩阵乘法计算 (硬件利用高) 。
- 非对角块 (跨块交互) 因半可分而低秩,分解为左/中/右因子:
- 用 matmul 计算每块的右因子 (紧凑块状态) ,
- 对这些块状态运行较短序列的标量 1-SS 扫描 (线性) ,
- 用 matmul 将状态转回输出 (左因子) 。
- 将块内与跨块输出相加,得最终结果。
该混合策略:
- 保持 \(T\) 上的线性复杂度,
- 大部分计算用大矩阵乘法 (张量核心友好) ,
- 将扫描长度由 \(T\) 缩短至 \(T/Q\),明显减少标量操作。
图 4: SSD 分块分解算法。对角块用二次注意力模式;非对角块分解为低秩项,中间部分需要短 (跨块) 递推。
论文给出简洁 PyTorch 版本 SSD (重排+批 einsum) :
|
|
复杂度总结 (适当选择 \(N\)、\(P\)、\(Q\)) :
- 训练 FLOPs: \(O(TN^2)\) (主要在块内 matmul) 。
- 推理 FLOPs: 每步 \(O(N^2)\) (取决于状态大小) 。
- 内存: \(O(TN)\)。
- 重计算多为批 matmul——理想的张量核心负载。
基准测试表明: SSD 比优化的 Mamba 融合扫描快 2–8× (大 \(N\) 时) ,并在长度约 \(2\text{k}\) 起超过 FlashAttention-2。相比扫描实现,SSD 对增加 \(N\) 不敏感。
6 — Mamba-2: 架构与设计
SSD 是快速核心序列混合器。Mamba-2 架构用 SSD 为核心层,并借鉴 Transformer 设计,使块更易于张量并行 (TP) 且稳健。
主要块级变化 (相较原 Mamba) :
- 并行参数投影: 块输入端并行生成 \(A,B,C,X\) (类 Q,K,V) ,而非从 \(X\) 顺序算 SSM 参数,减少同步并支持标准 TP 分片。
- 额外归一化: 输出前加归一化层 (LayerNorm/GroupNorm/RMSNorm) ,提升稳定性 (受 NormFormer 启发) 。
- 多头模式: 作者形式化了多头 SSM 类比:
- 多头 SSM (MHS) : 独立头 (如 MHA) 。
- 多收缩 SSM (MCS) : 类似多查询注意力 (共享 B) 。
- 多输入 SSM (MIS) : 类似多值注意力 (Mamba 用此模式) ,消融显示 MIS 在语言任务最佳。
- 分组头变体 (GVA/GIS) : 平衡速度、并行与准确。
图 5: Mamba-2 块设计。在块开始处并行计算 SSM 参数的投影,增加归一化改善训练稳定性与 TP 兼容性。
系统优势:
- 张量并行: 从原始输入 \(u\) 计算 \(A,B,C,X\) 并分片,Mamba-2 仅需块末一次 all-reduce (如 Transformer) 。用 GroupNorm 便于在每个 TP 分片内局部归一化。
- 序列 (上下文) 并行: SSD 天然支持跨设备序列分块,每台设备输出本地结果并传递紧凑递推状态到下一设备 (线性带宽) ,不同于注意力需全互通信。
- 变长批处理: 可通过在序列边界设 \(A_t=0\) 控制块互不交互,从而高效处理变长样本,避免重填充。
7 — 实验亮点
7.1 合成关联记忆 (MQAR)
多查询关联回忆任务 (MQAR) 测试模型在长上下文中记忆并检索多组 key/value 的能力,是 RNN-类递推的挑战。Mamba-2 因 SSD 支持更大状态 \(N\) (如 64、256) ,明显优于原 Mamba,并在多场景下匹敌注意力。增大状态可稳定提升召回力。
图 6: MQAR 性能。SSD 层加 Mamba-2 块允许更大状态空间与更强召回。
7.2 语言建模与缩放定律
Mamba-2 在 The Pile 数据集训练,规模从约 1.25 亿到 27 亿参数。在 Chinchilla 策略下,论文报告 Mamba-2 在一定计算范围与原 Mamba、强 Transformer++ 持平或更优。多情况下 Mamba-2 在相同计算预算下困惑度更低。
图 7: 缩放行为。在相同计算预算下,Mamba-2 达到更低困惑度。
下游零样本任务 (LAMBADA、HellaSwag、PIQA、ARC、WinoGrande、OpenBookQA) 显示,同等参数量下,Mamba-2 常匹敌或超越两倍规模的 Pythia 模型。
7.3 速度基准
SSD 是系统亮点: 在 A100 GPU 上对 SSD、Mamba 融合扫描、FlashAttention-2 基准测评:
- 大 \(N\) 时,SSD 比 Mamba 扫描快 2–8×。
- 序列长约 \(2\text{k}\) 起,SSD 快于 FlashAttention-2。
- 相比扫描随 \(N\) 线性增时变慢,SSD 运行对 \(N\) 更平稳。
图 8: 效率。SSD (紫) 相比融合扫描 (红) 显著加速,随长度增长在性能上可与 FlashAttention-2 (蓝) 匹敌。
8 — 消融与实践
消融实验验证:
- 块设计: ABCX 并行投影略减参数,同时改善 TP 兼容性与困惑度。
- 归一化: 最终归一化 (如 NormFormer) 稳定大模型训练,略提质量。
- 头模式: 固定参数/状态预算下,多输入 SSM (MIS) 效果优于多查询、多键。
- 核特征映射 (Swish、ReLU、随机特征、cosFormer 风格) : 无明显优于 Swish/SI-LU 默认,部分近似可能引数值不稳,需配合归一化。
作者还表明,混合模型 (SSD 为主 + 少量注意力层,约10%) 常优于纯 SSD 或纯注意力,因为注意力提供快速检索,SSD 提供压缩状态记忆。
9 — 意义与展望
SSD 价值:
- 概念统一: SSM 与大量掩码注意力是结构化矩阵的两面,可移植系统、并行、头模式、核技巧。
- 实用算法: SSD 将短块 matmul 与短扫描结合,提升硬件效率,保留亚二次扩展性,接近 Transformer 友好性。
- 架构灵活: 从 SMA 视角可构造多种掩码 \(L\),在效率与归纳偏置间权衡。
- 系统友好: SSD 天生支持 Transformer 训练中的张量/序列并行,并支持变长序列与高效自回归解码。
开放问题:
- 能否进一步靠近完整 softmax 注意力?SSD 未直接模拟 softmax 非线性,能否找到既近似 softmax 标准化又高效自回归的结构化矩阵?
- 是否有更丰富的半可分/秩结构分解,得到更优硬件权衡?
- 注意力的可解释性、机理洞见能否迁移到 SSD/SSM?
- SSD 算法与分块分解能否扩展到其他递推与结构化矩阵族?
10 — 总结
- 论文经半可分矩阵,将选择性 SSM 与线性注意力推广 (SMA) 严谨关联。
- 这种结构化状态空间对偶性 (SSD) 提供了概念词典与实用工具: 新算法、更高硬件利用率、强大的 SSM 架构 Mamba-2。
- SSD 显示递推的渐近效率与矩阵乘法的硬件效率可兼得。
若你关注长上下文效率、SSM 实际训练,或要构建压缩状态与快速检索平衡的混合模型,SSD 视角 (及 Mamba-2) 值得理解与实验。论文附有可运行代码、模型和丰富的消融,是在自家技术栈试用 SSD 的好资源。
(欲深入,可参阅论文中的正式证明: 连接半可分性、SSS 表示、线性注意力张量收缩推导,以及实验所用 SSD 全 PyTorch 实现。)