大型语言模型 (LLM) 正在迅速扩展能力边界,现在能够处理一百万甚至更多词元的上下文窗口。这带来了令人惊叹的应用场景——从理解整个代码仓库,到回答冗长法律文件中的细微问题,再到在庞大数据集上进行复杂推理。
然而,巨大的上下文也意味着巨大的计算成本。
想象一下,将一个 1M 词元的提示 (prompt) 输入给一款最先进的 LLM。即便是在强大的 Nvidia A100 GPU 上,你也可能需要等待 30 分钟才能看到模型生成第一个输出词元。这种初始延迟发生在预填充 (pre-filling) 阶段——即接收提示、计算每个词元之间的注意力,并为后续解码设置键值 (KV) 缓存的过程。主要瓶颈是什么?是 Transformer 的自注意力机制,其计算量随输入长度呈二次方增长。
现在,想象一下将等待时间从 30 分钟缩短至仅 3 分钟——而且无需牺牲准确率或重新训练模型。这正是微软研究人员在最新论文中通过 MInference 实现的: 一种利用注意力模式中隐藏结构的动态稀疏注意力技术,在预填充阶段获得高达 10 倍的加速。
图 1: (b) 在单张 A100 GPU 上,MInference 对 1M 词元上下文实现了高达 10 倍的加速;(a) 在 Needle In A Haystack 等检索密集型任务上的准确率与全注意力持平或更高。
本文将分解长上下文推理的挑战、MInference 的核心洞察,以及使其成为长上下文 LLM 突破性进展的惊人实验数据。
问题所在: 首个词元生成前的漫长等待
LLM 推理包括两个阶段:
- 预填充 (Pre-filling) —— 并行处理整个提示以计算 KV 缓存。
- 解码 (Decoding) —— 利用缓存的键和值,自回归地生成词元。
在短提示下,解码通常占据主要运行时间。但在百万词元级的上下文中,预填充阶段成为瓶颈。
自注意力需要计算一个 \(N \times N\) 的成对词元交互矩阵,其成本为 \(O(N^2)\)。当 \(N = 1{,}000{,}000\) 时,如果不进行优化,这是无法接受的。
图 2: (a) 延迟分解显示,注意力计算主导了预填充延迟。(b) 在 128k 上下文中,仅保留最重要的 4096 列 (约 3%) 即可保留 96.4% 的注意力分数——证实了稀疏性。(c) 将这些列索引应用于另一段上下文时,召回率下降到 83.7%,表明稀疏性是动态的。
核心洞察: 注意力既稀疏又动态
分析证实了两点:
- 稀疏性: 每个词元有意义地关注其他词元的比例很小。
- 动态性: 稀疏模式会随提示内容发生显著变化。
在 128k 上下文中,最重要的 4k 列几乎涵盖了全部注意力权重。但对不同提示沿用这些索引则效果不佳。要利用稀疏性,不能依赖固定掩码——必须实时高效预测。
三种结构化稀疏模式
MInference 团队发现,动态稀疏性并非随机,而是呈现出少数几种在注意力矩阵中反复出现的几何模式。
图 3: (a) 注意力矩阵可视化揭示了三种通用模式。(b) 各位置到最近非零邻居的距离验证了空间聚类特性。(c) 与通用 Top-K 稀疏相比,这些模式在 GPU 上每 FLOP 的召回率更高。
三种模式包括:
A 型 (A-shape)
静态结构: 对早期词元 (全局上下文) 和最近窗口 (局部上下文) 有强关注,适合捕捉基础信息和局部线索。垂直斜线型 (Vertical-Slash, VS)
动态结构: 垂直线表示序列中任意位置的特定重要词元,斜线表示周期性的相对位置。二者均随提示内容变化。块稀疏型 (Block-Sparse)
高度动态但呈聚类: 重要词元集中在连续块中。尽管位置分散,空间聚类特性使块级计算高效。
识别这些高层几何结构,将问题从“在百万词元中寻找重要词元”转化为“为当前注意力头和提示定位特定的线或块”。
MInference 的工作原理
MInference 通过三阶段流水线实现加速:
图 4: MInference 使用的三种稀疏注意力方法——按静态 (A 型) 到更动态 (VS、块稀疏型) 排列。
1. 离线阶段: 感知核函数的模式分配
每个注意力头进行一次离线分析,通过搜索三种模式及其配置,确定在单位 GPU 成本下召回率最高的选项。
关键在于感知核函数 (kernel-aware) ——测量 GPU 核函数中实际运行时的 FLOPs,而不仅是理论值。
2. 在线阶段: 动态索引构建
推理时,MInference 执行超轻量近似计算定位 VS 和块稀疏型注意力头的动态部分:
- 垂直斜线型: 仅将最后 64 个查询向量与所有键相乘,由此生成的部分注意力图可低成本找出 top-\(k\) 的垂直与斜线索引。
- 块稀疏型: 对 Q 和 K 进行 64 大小的均值池化,再计算一个小的块级注意力矩阵选出 top-\(k\) 块。
- A 型: 无需近似,窗口为固定。
近似计算开销仅占总计算量的 5–20%。
3. 自定义核函数实现稀疏注意力
稀疏掩码传递给基于 Triton 和 FlashAttention 的自定义核函数,这些核经过优化,可跳过不相关注意力区域——只计算选中的线或块。
实验结果
准确率保持——甚至提升
在多项基准上,MInference 与全注意力持平或略优。
**InfiniteBench **(平均 214k 词元) :
MInference 在所有方法中领先,即使在 StreamingLLM 变体性能崩溃的情况下,也能保持检索准确率。
表 2: MInference 准确率持平或超越 LLaMA-3 全注意力,远胜固定稀疏基线。
RULER:
MInference 扩展了 LLaMA-3-8B 和 GLM-4-9B 等模型的有效上下文窗口,在更长序列上超越全注意力。
表 3: 在 RULER 测试中,MInference 在 32k–64k 以上仍保持长上下文问答和多跳推理性能。
Needle In A Haystack:
与固定窗口稀疏不同,MInference 可在百万词元的“草堆”中检索任意位置的“针”。
图 6: 当“针”位于静态窗口外时,StreamingLLM 失败;MInference 保持全局覆盖。
消融研究: 三种模式的力量
仅用一种模式类型或静态索引会显著降低性能,尤其在动态检索任务中。
表 4: 移除模式或使用静态掩码会降低准确率,证明组合动态方法的必要性。
加速效果: 1M 词元下实现 10 倍加速
延迟收益随上下文增大而提升:
- 100k 词元: 1.8× 加速
- 500k 词元: 6.8× 加速
- 1M 词元:** 10× 加速**——从 30 分钟缩短至 3 分钟。
由于核函数基于 Triton 编写,该方法可轻松移植至 H100 或 MI300X 等 GPU。
结论: 打破长上下文 LLM 的瓶颈
自注意力的二次方开销长期困扰着长上下文 LLM。
MInference 为预填充阶段提供了优雅解决方案:
- 结构洞察: 动态稀疏性呈 A 型、垂直斜线型和块稀疏型三种模式。
- 高效设计: 离线模式分配 + 轻量在线近似。
- 显著增益: 保持准确率同时实现最高 10 倍加速。
在无需重新训练的条件下大幅削减预填充延迟,MInference 让百万词元交互不仅可行,而且快速。随着 LLM 迈向数百万词元时代,这类技术将成为释放其现实潜力的关键。