大型语言模型 (LLM) 更大上下文窗口的竞赛是目前 AI 领域最令人兴奋的发展之一。我们已经迅速从那些只能记住寥寥数段的模型,迈向了像 GPT-4 和 Gemini 1.5 这样能够通过单个提示处理整本小说、代码库或法律合同的系统。

然而,这种能力伴随着巨大的计算成本。瓶颈往往在于内存。具体来说,就是 键值 (KV) 缓存 (Key-Value Cache)

当 LLM 生成文本时,它必须存储每一个先前 Token 的“键 (Key) ”和“值 (Value) ”表示,以避免重复计算。随着上下文长度的增加,这个缓存的体积急剧膨胀,消耗数千兆字节的高带宽内存 (HBM) 。这限制了批处理大小 (batch size) ,减慢了生成速度,并使得部署长上下文模型对许多用户来说成本极其高昂。

为了解决这个问题,研究人员一直在寻找“压缩”这个缓存的方法——即丢弃我们不需要的信息。但是,如何在不破坏模型的情况下知道该丢弃什么呢?

在这篇文章中,我们将深入探讨一篇引人入胜的论文,题为 “A Simple and Effective \(L_{2}\) Norm-Based Strategy for KV Cache Compression” (一种基于 L2 范数的简单有效 KV 缓存压缩策略) 。作者提出了一个反直觉但非常有效的解决方案: 我们可以仅通过查看 Token 键嵌入 (Key embeddings) 的幅度 (\(L_2\) 范数) 来识别最重要的 Token。

关键结论是? Token 越“安静” (范数越低) ,它就越重要。

问题所在: KV 缓存瓶颈

在介绍解决方案之前,让我们先建立背景。在仅解码器架构的 Transformer (如 Llama、GPT 等) 中,注意力机制计算当前正在生成的 Token (Query) 与所有先前 Token (Keys 和 Values) 之间的关系。

为了生成下一个 Token \(x_{n+1}\),模型需要关注 \(x_1\) 到 \(x_n\)。为了不每次都让 \(x_1...x_n\) 通过深度神经网络重新计算,我们计算一次它们的 Key 和 Value 向量,并将它们存储在 KV 缓存中。

为什么压缩很难

压缩这个缓存的标准方法是 驱逐 (eviction) 。 我们希望保留“重要”的 Token 并驱逐“不重要”的。但是什么定义了重要性?

直觉上,如果模型对某个 Token 给予了大量的 注意力 (attention) , 那么这个 Token 就是重要的。许多现有的压缩方法 (如 H2O 或 Scissorhands) 都会查看注意力分数。如果一个 Token 通常获得的注意力分数很低,它就会被驱逐。

然而,这里有一个陷阱。现代推理引擎使用 FlashAttention , 这是一种通过避免在高带宽内存中显式生成完整注意力矩阵来加速注意力计算的算法。如果你的压缩算法要求你检查注意力分数来决定删除什么,你就实际上破坏了 FlashAttention 提供的优化。你将被迫去计算那些你本想避免计算的分数。

这篇论文的作者提出了一个关键问题: 我们能否在不计算注意力分数的情况下估算 KV 对的重要性?

惊人的发现: 范数 vs. 注意力

研究人员分析了像 Llama-2-7b 这样的模型的内部表示。他们特别关注了存储在缓存中的 键嵌入 (Key Embeddings)

他们发现 Key 向量的 \(L_2\) 范数 (欧几里得模长) 与它最终获得的注意力分数之间存在着强烈且一致的相关性。

图 2: Llama2-7b 第 9 层的 5 个注意力头。注意力分数 (上) 和 L2 范数 (下) 高度相关。

仔细观察上面的 图 2 。 这个可视化对比了 Llama-2 第 9 层中五个不同注意力头的注意力分数 (上排) 和 \(L_2\) 范数 (下排) 。

  • 上排 (注意力) : 亮点表示模型正在密切关注的 Token。
  • 下排 (\(L_2\) 范数) : 暗点表示向量模长 的 Token。

注意到这个规律了吗? 注意力最高的 Token (顶部最亮) 具有最低的 \(L_2\) 范数 (底部最暗) 。

这有点反直觉。在深度学习的许多领域,我们通常将较大的激活幅度与“更强”的信号联系起来。而在这里,情况恰恰相反。最“安静”的向量反而是充当 注意力汇聚点 (attention sinks) 的角色——即模型重点关注的锚点。

量化相关性

为了证明这不仅仅是视觉上的巧合,作者定义了一个称为 ALR (参考注意力损失,Attention Loss Reference) 的指标。

首先,他们定义了因压缩缓存而产生的 注意力损失 (Attention Loss) (\(\mathcal{L}\))。如果你丢弃了 \(m\) 个 Token,损失就是那些被丢弃的 Token 如果被保留 本应 获得的注意力分数的总和。

公式 1: 注意力损失定义

这里,\(a_{l,h,p}\) 是第 \(p\) 个 Token 的注意力分数。

接下来,他们计算了他们的方法 (丢弃高范数) 造成的损失与一个通过知晓确切注意力分数来“作弊”的“理想”预言机造成的损失之间的差值 (\(\mathcal{Y}\))。

公式 2: ALR 差值计算

最后,他们在不同的压缩量上对这个差值求和,得到模型中每个头的 ALR 分数。

公式 3: ALR 求和

低 ALR 值 意味着 \(L_2\) 范数方法非常接近理想方法。高值则意味着相关性较弱。

图 1: Llama2-7b 各层各头的 ALR 值。

图 1 展示了 Llama-2-7b 中每一层每一个头的 ALR 分数。

  • 紫色区域: 低 ALR (高相关性) 。这意味着 \(L_2\) 范数是重要性的极佳预测指标。
  • 红色/橙色区域: 高 ALR (低相关性) 。

热力图的关键要点:

  1. 大多数层 (大片的紫色) 显示出非常强的相关性。
  2. 第 0 层和第 1 层 (最底部) 以及中间的一些头 (第 10-15 层左右) 显示出较低的相关性。 这表明对于网络的大部分,我们可以安全地基于 \(L_2\) 范数进行压缩,但我们可能需要对最初的几层小心处理。

提出的方法: 保留低范数 (Keep Low Norm)

基于这一发现,该策略极其简单。它不需要训练,不需要微调,也不需要计算注意力矩阵。

算法步骤:

  1. 当 KV 缓存达到大小限制 (预算) 时,查看缓存中当前的 Key 向量。
  2. 计算每个 Key 向量的 \(L_2\) 范数。
  3. 保留 前 \(k\) 个范数 最低 的 Token。
  4. 驱逐 范数 最高 的 Token。

这种启发式方法允许压缩完全基于 Key 嵌入的静态属性进行。它与 FlashAttention 完全兼容,因为它不需要中间的注意力概率矩阵。

为什么这行得通?“汇聚点 Token”假设

为什么较小的向量会吸引更多的注意力?作者假设这与 注意力汇聚点 (Attention Sinks) 现象有关,这是以前的研究 (如 StreamingLLM) 探索过的一个概念。

模型通常会将大量的注意力倾注到特定的 Token 上 (如句首的 <s> Token 或标点符号) ,有效地将它们用作“无操作 (no-op) ”或当没有其他 Token 相关时的休息处。

作者通过分析嵌入的具体维度进行了更深入的挖掘。

图 25: 显示尖峰激活的键投影

正如 图 25 (以及论文中的图 6) 所示,“重要”的 Token (如 BOS Token) 通常具有 稀疏激活 。 它们的嵌入大部分接近于零,但在特定维度上有巨大的尖峰。

当作者尝试“置零”这些特定的尖峰维度时,注意力图发生了巨大的变化。当他们随机置零其他维度时,什么也没有发生。这表明低范数向量并不“弱”;它们是高度专业化的向量,与特定的查询方向完美对齐,尽管它们的整体模长很低,却能触发巨大的注意力分数。

实验结果

理论听起来很可靠,但在实践中效果如何?作者在语言建模 (困惑度) 和严格的长上下文任务上测试了该方法。

1. 语言建模 (困惑度)

他们在 Wikipedia 数据集上测试了 Llama-2、Llama-3 和 Gemma。他们将 KV 缓存限制在 2,000 个 Token (即使输入变得更长) ,并比较了不同的驱逐策略。

图 3: 各模型的困惑度比较

图 3 中,结果非常鲜明:

  • 红线 (无压缩) : 基线。
  • 绿线 (保留高范数) : 这是与提议方法相反的做法。困惑度激增 (越高越差) 。这证明了高范数 Token 确实是 最不 重要的。
  • 蓝线 (保留低范数) : 这是提议的方法。即使在大幅压缩缓存时,它也几乎完美地紧跟基线性能。

2. 长上下文压力测试

困惑度只是一个粗略的指标。长上下文模型的真正考验是检索深埋在文本中的特定信息。

大海捞针 (Needle-in-a-haystack): 给模型一段非常长的文本 (干草堆) ,里面藏着一个具体的随机事实 (针) 。模型必须回答关于那个事实的问题。

图 4: 大海捞针和密码检索分数

图 4(a) 中,我们看到了 Llama-2-7b-80k 在大海捞针任务上的准确率:

  • 保留低范数 (蓝色) : 即使在 50% 的压缩率下也能保持近乎 100% 的准确率 。 它的表现显著优于随机驱逐 (绿色) 。
  • 保留高范数 (深蓝色) : 几乎立即失败。

密码检索 (Passkey Retrieval): 一个类似的任务,模型必须检索一个特定的数字密码。 在 图 4(b) 中,该方法即使在压缩 90% 的缓存时也能保持 100% 的准确率 。 这是在零效用损失的情况下,对特定任务实现了巨大的内存使用量减少。

我们可以在下面看到密码任务更详细的细分:

图 11: 详细的密码检索结果

图 11(b) 显示了 Llama-2-7b-longlora 的密码检索结果。准确率保持完美 (1.0 的平直线) ,直到压缩率变得极其激进 (丢弃 >80% 的 Token) 。

3. 与 FastGen 的比较

FastGen 是一种最先进的压缩方法,但它依赖于注意力分析 (计算注意力分数) 来决定保留什么。这使得它与 FlashAttention 不兼容。

作者将他们简单的 \(L_2\) 范数策略与 FastGen (配置为不使用完整注意力分数以示公平) 进行了比较。

图 5: 与 FastGen 的比较

图 5 显示,在 Llama-3-8b 上,\(L_2\) 范数策略 (浅蓝色) 产生的困惑度比 FastGen (海军蓝) 更低,同时计算更简单,更容易集成到现代推理栈中。

消融实验: 我们应该跳过某些层吗?

回想一下 图 1 中的热力图,它显示第 0 层和第 1 层的范数与注意力之间的相关性较低。作者研究了我们是否应该跳过这些特定层的压缩。

图 10: 在不同层跳过压缩

图 10 中的结果有些微妙。

  • 图表 (c) 和 (d): 跳过前两层 (0 和 1) 有助于比不加区分地压缩所有层稍微更好地保持准确性,尤其是在压缩率很高 (Max KV 1000) 时。
  • 然而,对于适度的压缩,差异可以忽略不计。
  • 结论: 一个安全的默认设置是压缩所有层,但为了最大的性能稳定性,保持前两层不压缩是一个明智的优化。

结论与启示

论文 “A Simple and Effective \(L_{2}\) Norm-Based Strategy for KV Cache Compression” 为这个通常被日益增加的复杂性所主导的领域提供了一个令人耳目一新的见解。

我们不需要训练新的模块或在推理过程中计算昂贵的注意力矩阵,而是可以依赖嵌入空间的一个基本几何属性: 幅度与冗余度相关。

核心要点:

  1. 低范数 = 高重要性: 具有小 \(L_2\) 范数的 Key 嵌入是模型最关注的对象。
  2. FlashAttention 兼容性: 因为这种方法只查看 Key 向量 (一旦计算出来就是静态的) ,它不需要注意力矩阵。这使得它能够利用 FlashAttention 的速度优势。
  3. 巨大的节省: 实验表明,我们通常可以在检索任务上减少 50% 到 90% 的 KV 缓存,而性能下降微乎其微。

这一策略使长上下文推理变得触手可及。通过显著降低显存需求,它允许在消费级硬件上运行更长的上下文,并提高生产环境中服务系统的吞吐量。事实证明,有时候管理内存的最好方法就是检查谁最“安静”。


这篇博客文章总结了 Devoto 等人的研究。如果您对更深层次的数学证明或额外的可视化感兴趣,我鼓励您阅读完整的论文。