在过去十年中,深度学习的发展始终围绕一个主题——规模: 更大的模型、更长的上下文、更宽的特征空间。规模化带来了令人惊叹的能力,但也遇到了一个现实瓶颈——计算成本。许多当前最先进的模型,其核心操作在时间和内存上的开销均呈二次方增长: 序列长度为 \(N\) 的注意力机制,以及特征维度为 \(d\) 的密集型 MLP。序列长度或模型宽度翻倍,计算量可能增加四倍。随着我们迈向更长的上下文和更宽的网络,二次方的增长成本很快会变得难以承受。

MONARCH MIXER (M2) 是斯坦福大学和纽约州立大学研究人员提出的一种新架构,它提出了一个简单而大胆的问题: 我们能否设计一个单一的、硬件友好的原语,同时在序列维度和特征维度上混合信息,并在这两个维度上都实现亚二次方扩展?答案是:** 可以**。它的构建模块是 Monarch 矩阵,一种结构化、对 GEMM (通用矩阵乘法) 友好的对象,能够推广像 FFT 这样的变换,并可用块矩阵乘法高效实现。将基于 Monarch 的混合器堆叠构建模型 (即 M2) 后,我们得到了具备亚二次方复杂度、在现代 GPU 上高效、且在语言和视觉任务中与 Transformer 竞争力不相上下的架构。

本文将解析 M2 背后的直觉、数学原理与系统设计思路,重点关注三个方面:

  • 什么是 Monarch 矩阵,以及它们为何既有表达力又能在硬件上高效运行。
  • M2 层如何以亚二次方成本混合序列和特征信息。
  • 如何通过关键理论技巧,使 M2 在保持亚二次方扩展的同时实现因果性 (自回归) 。

同时,我们也会看到实证结果: M2 在多项任务中与 Transformer 持平或超越,同时参数更少,或在长上下文中显著提升吞吐量。

Monarch 矩阵和 Monarch Mixer (M2) 的高层概述。Monarch 矩阵是一种亚二次方的、硬件高效的原语,用于沿序列和特征维度混合信息。

图 1: Monarch 矩阵由块对角因子与置换交错相乘构成。M2 先沿序列轴进行 Monarch 混合,再沿特征轴混合,仅使用矩阵乘法、重塑、转置和逐点操作 (这些都是对 GEMM 友好的原语) 。

为什么关注结构化矩阵?
一个神经网络原语在加速器上的性能取决于两个关键且相互独立的因素: 算法的渐进复杂度,以及它与硬件高效原语 (如 GEMM、张量核心) 的适配程度。注意力和密集 MLP 功能强大,但在它们混合的维度上具有二次方复杂度。基于 FFT 的长卷积在理论复杂度上很诱人 (\(O(N \log N)\)) ,但在实践中常受限于内存带宽,难以充分利用 GPU 的计算能力。Monarch 矩阵则试图在二者间找到平衡: 既有亚二次方的复杂度,又能用高 FLOPs GEMM 高效实现。

基本原语在 GPU 上的表现回顾:

  • 计算密集型操作: 算术运算占主导,例如密集 GEMM,可在现代张量核心上实现极高的 FLOP 利用率。
  • 内存密集型操作: 数据搬运 (读/写) 占主导,当计算强度低时性能受限。在实践中,注意力和 FFT 往往属于此类。

M2 设计目标: 亚二次方复杂度 + 高计算强度 (便于高效使用 GEMM) 。

常见混合器层的 FLOPs 成本和硬件利用率对比。M2 卷积提供了介于两者之间的方案,兼具亚二次方扩展和高利用率。

图 2: 在大输入场景下,不同混合器层的 FLOPs 成本与利用率对比。MLP 属于计算密集型,高效但按维度呈二次方扩展;FFT 卷积拥有良好的复杂度,但硬件利用率低;M2 则兼具亚二次方扩展和更高的 GEMM 利用率。

什么是 Monarch 矩阵?

高层定义: 一个 p 阶 Monarch 矩阵 \(\mathbf{M} \in \mathbb{R}^{N\times N}\) 可以写为置换矩阵和块对角矩阵的乘积:

\[ \mathbf{M} = \left(\prod_{i=1}^{p} \mathbf{P}_i \mathbf{B}_i\right)\mathbf{P}_0. \]

每个 \(\mathbf{B}_i\) 是块对角矩阵: 将输入拆分成多个独立的小块,并对每个小块施加一个小的密集矩阵。置换矩阵 \(\mathbf{P}_i\) 在各次块乘法之间重新排列元素。若块大小选得适当 (例如 \(b = \sqrt[p]{N}\),此时块数也为 \(b\)) ,计算 \(\mathbf{M}\mathbf{x}\) 的总成本将小于二次方;当 \(p=2\) 且 \(b=\sqrt{N}\) 时,复杂度为 \(O(N^{3/2})\)。更大的 p 可使复杂度逼近 \(O(N \log N)\),同时保持 GEMM 友好性。

硬件友好的原因: 尽管 \(\mathbf{M}\) 在效果上是密集 \(N\times N\) 变换,但计算中无需构造整个密集矩阵,而是执行许多小块的密集乘法 (块 GEMM) 和置换。这类操作可很好地映射到高度优化的 BLAS 库和张量核心,保持高计算强度,并在大输入规模下降低内存密集型置换的影响。

M2 层: 一个简单的两阶段混合器

M2 采用“先序列混合再特征混合”的设计。设 \(\mathbf{X} \in \mathbb{R}^{N \times d}\) 为输入 (序列长度 \(N\),模型维度 \(d\)) 。一个二阶 M2 层可描述为:

序列混合 (广义门控卷积) :

\[ \tilde{\mathbf{X}} = \mathbf{M}_2\big(\mathbf{K}_1 \odot (\mathbf{M}_1 \mathbf{X})\big), \]

维度混合 (Monarch 替代 MLP) :

\[ \mathbf{Y}^\top = \mathbf{M}_4\;\sigma\big(\mathbf{M}_3 \tilde{\mathbf{X}}^\top\big). \]

其中 \(\mathbf{M}_1, \mathbf{M}_2\) 是 \(N\times N\) 的 Monarch 矩阵,\(\mathbf{M}_3, \mathbf{M}_4\) 是 \(d\times d\) 的 Monarch 矩阵,阶数由设计决定。\(\mathbf{K}_1\) 是核张量,\(\odot\) 表示逐元素乘法,\(\sigma\) 是逐点非线性函数 (如 ReLU 或 GeLU) 。若将 \(\mathbf{M}_1\) 设为 DFT,\(\mathbf{M}_2\) 设为逆 DFT,则第一个公式就是频域滤波实现的长卷积;用可学习的 Monarch 因子替换 DFT 可推广此方法并保持 GEMM 友好性。

实现非常简洁: M2 层仅需重塑、转置、块 GEMM 和逐元素操作。由于计算主体为块 GEMM,实现可在现代 GPU 上获得高 FLOP 利用率。

性能分析——何时亚二次方胜出?

当 \(N\) 较大时,Monarch 的渐进性优势最明显。作者使用 cuBLAS 实现块 GEMM,并配合自定义置换,测量 M2 与密集矩阵乘法的 FLOP 利用率和运行时间。在大输入下,亚二次方扩展的优势压倒了置换开销,M2 比密集乘法快几个数量级。

M2 与密集 MLP的 FLOPs 成本和利用率对比。随输入 N 增大,M2 的亚二次方扩展带来了显著加速,并在大 N 时更高效。

图 3: 在 NVIDIA A100 和 RTX 4090 上,不同输入规模下 M2 与密集乘法的 FLOP 成本与实测利用率对比。小规模时置换开销明显,但大规模时亚二次方扩展优势显现,M2 速度显著领先。

要点:

  • 短序列时,经过高度优化的注意力和密集 MLP 更胜一筹,因为置换的固定成本占主导。
  • 长序列 (数千至数万长度) 下,M2 的 \(O(N^{3/2})\) (或更好,取决于 p) 复杂度领先密集操作很大幅度。在大缓存 GPU (如 RTX 4090) 上,优化的 M2 在极大输入下可实现超过 40% 的 FLOP 利用率。

使 M2 具备因果性: 多项式视角

在自回归建模中,未来信息屏蔽若处理不当,会将亚二次方操作变成二次方瓶颈。作者的解决方案是将 Monarch 乘法重新解释为多元多项式的求值与插值,从而在代数上自然处理因果性。

核心思想:

  • 二阶 Monarch 乘法 \(\mathbf{M}\mathbf{u}\) 可看作在 \(\sqrt{N}\times\sqrt{N}\) 的点网格 (如单位根) 上对二元多项式 \(u(X,Y)\) 求值,Monarch 因子定义了基多项式。
  • M2 卷积类操作 \[ \mathbf{f} = \mathbf{M}_0^{-1}\big((\mathbf{M}_1\mathbf{k}) \odot (\mathbf{M}_2\mathbf{u})\big) \] 对应于构造多项式 \(k(X,Y)\)、\(u(X,Y)\),相乘得 \(h(X,Y)=k(X,Y)u(X,Y)\) (模求值多项式) ,再插值到系数空间。
  • 模 \(X^b-1, Y^b-1\) 的乘法可能导致循环回绕,造成输出系数依赖未来输入系数。为避免此问题,作者约束基多项式的次数,使得两个基多项式的乘积仅影响索引满足 \(a \ge j+j'\) 的基元素,从而形成三角依赖模式,实现因果性。

多项式视角下通过次数约束实现因果性。

图 4: 将 Monarch 乘法视为多项式求值/插值,通过约束多项式次数,保证乘积仅影响未来系数,从而满足因果性。

最终得到一个构造性定理: 在特定的最小/最大多项式次数分配下,并将输入按一定方式嵌入更大向量空间,M2 卷积可被严格证明为因果的,且可用同样的块 GEMM 原语计算——即依然是亚二次方复杂度。

实验: M2 能否匹敌 Transformer?

作者在三个 Transformer 表现出色的任务上评估了 M2: 非因果掩码语言建模 (BERT) 、图像分类 (ViT) 、因果语言建模 (类 GPT 预训练) 。

1) 非因果语言建模 — M2-BERT

M2-BERT 用基于 Monarch 的双向门控长卷积替换了注意力,用小块对角 Monarch 矩阵替换了 MLP 的密集矩阵。
在标准语料上训练并评估 GLUE:

  • 紧凑型 M2-BERT-base (80M 参数) 在 GLUE 平均分上与 BERT-base (110M) 持平,参数减少约 27%。
  • 参数量匹配时,M2-BERT-base 比 BERT-base 高 1.3 分。
  • 大模型下,M2-BERT-large 参数更少但质量持平或略胜 BERT-large。

M2-BERT 架构示意。

图 5: M2-BERT 在序列混合器中使用基于 Monarch 的门控长卷积替代注意力,在维度混合器中用块对角 Monarch 矩阵替代 MLP 的密集矩阵。

在长上下文下的吞吐量提升显著: A100 上,序列长度 4096 时,M2-BERT-base 对比 HuggingFace BERT 提速 9.1 倍,对比 FlashAttention 提速约 3 倍;CPU 推理中,长度超 1K token 即超过 BERT,极长上下文下提速达 6.5 倍。

GLUE 平均分对比。

图 6: M2-BERT 用更少的参数达到或超过 BERT 性能,参数匹配时可进一步超越。

吞吐量随长度变化对比。

图 7: M2-BERT 随上下文长度增长保持高吞吐,体现亚二次方扩展优势。

2) 视觉 — M2-ViT

作者将 M2 应用于 Vision Transformer: 替换 Hyena 长卷积 (FFT) 与 ViT MLP 为基于 Monarch 的实现。
结果: 在 ImageNet-1k 上,M2-ViT 精度略高于 ViT-base,但参数仅为其一半;同时优于 HyenaViT-b 及其他 Monarch 增强版本,表明方法可迁移到视觉领域。

ImageNet-1k 精度对比。

图 8: M2-ViT 在 ImageNet 上性能有竞争力甚至更好,并显著减少参数数量。

3) 因果语言建模 — M2-GPT

构建完全无注意力、无 MLP 的自回归模型 M2-GPT,使用 Monarch 矩阵的因果多项式参数化,在 The PILE 上预训练并评估困惑度:

  • 约 360M 参数下,M2-GPT 的困惑度低于 Transformer 和 Hyena 基线,最终结果接近或略优于它们。
  • 小模型下,M2-GPT 一贯优于 Hyena,并与 Transformer 持平。

困惑度对比。

图 9: M2-GPT 在困惑度上可匹敌甚至超越 Transformer 与 Hyena,证明无注意力、无 MLP 设计可实现 GPT 级预训练效果。

为什么这些结果重要

  • 亚二次方原语的实用性: M2 展示了结构化变换的渐进优势可与高硬件利用率兼得,并在大规模下转化为实际速度提升。
  • 统一原语替代双轴混合: 用同一结构化构建块替换注意力 (序列混合) 与密集 MLP (特征混合) ,简化扩展与优化。
  • 因果性无二次方瓶颈: 多项式视角提供了概念清晰的因果操作参数化方法,避免了朴素掩码的 \(O(N^2)\) 成本。
  • 跨模态通用性: 语言、视觉均适用,可结合近期无注意力模型常见的卷积与门控手段。

注意事项与未来方向

  • 实现成熟度: 现有原型利用率不错,但通过内核融合、优化置换核、融合块 GEMM,可提升短序列性能、降低开销。
  • 内存与常数因素: 亚二次方性能需精心选取块大小,硬件缓存行为会影响最佳设计。
  • 更广泛适用性: 需评估在更多 Transformer 应用 (多模态、强化学习、极大规模生成) 中的表现。
  • 因果性填充: 因果参数化有时需嵌入更大空间,p>2 时嵌入增长值得优化。

结论

MONARCH MIXER 采用系统导向的架构设计: 从一个既有表达力又易于实现为高吞吐 GEMM 的结构化线性代数原语出发,构建模型;结合块 GEMM 与多项式视角实现因果性,证明了亚二次方扩展可落地并在长上下文中带来巨大吞吐提升,同时性能与 Transformer 基线持平或更优。

如果你对注意力替代方案或为现代加速器优化的原语设计感兴趣,M2 是设计空间中的重要新节点: 正确的结构 + 正确的系统思维,可同时降低序列长度和模型维度的扩展成本。

进一步阅读与代码指引

  • 完整论文包含证明、精简 M2 实现的 PyTorch 代码、更多实验 (语音、CIFAR、roofline 分析) 等,深入探讨权衡与实现细节。欲探索或复现,可从论文链接的代码仓库入手。

致谢: 本文总结并解析了 “MONARCH MIXER: A Simple Sub-Quadratic GEMM-Based Architecture” (Fu et al., 2023) 的核心思想。完整技术细节、证明与实验设置,请参阅原论文。