大语言模型 (LLM) 的版图目前主要由 Transformer 占据。然而,任何尝试过将一整本教科书输入标准聊天机器人的人都知道,Transformer 有一个弱点: “二次复杂度瓶颈”。随着输入文本长度的增加,计算成本呈爆炸式增长。这引发了人们对 状态空间模型 (SSM) (如 Mamba) 的极大兴趣。SSM 承诺提供一种“次二次 (sub-quadratic) ”的替代方案,从理论上讲允许模型高效地处理海量序列。

但这其中存在一个陷阱。虽然 SSM 速度很快,但它们通常患有一种形式的“健忘症”。因为它们依赖固定大小的隐藏状态来压缩历史信息,所以当上下文长度超过训练时的长度时,它们往往会忘记早期的信息。

如果在不进行昂贵的重新训练的情况下解决这个问题?

LAMB (Long-context extension driven by Attention-guided token filtering in MamBa,基于 Mamba 中注意力引导的 Token 过滤的长上下文扩展) 应运而生。在最近的一篇论文中,来自佐治亚理工学院和英特尔实验室的研究人员提出了这个新颖的、无需训练的框架。LAMB 能够智能地选择将哪些 Token 保留在内存中,哪些丢弃,从而允许 SSM 在保持更高准确率的同时处理显著更长的上下文。

不同上下文长度下 RULER 上的平均性能比较。

如图 1 所示,差异是巨大的。当原版 Mamba 模型 (红线) 随着上下文长度增加准确率暴跌至接近零时,LAMB (蓝线) 保持了稳健的性能,甚至可以与先前的增强技术相媲美乃至超越。

在这篇深度文章中,我们将拆解 SSM 衰减的数学原理,探索 Mamba 内部的“隐式注意力”机制,并详细介绍 LAMB 如何通过过滤 Token 来保留长距离依赖。

背景: 为什么 SSM 会遗忘

要理解 LAMB,我们首先需要理解它所改进的模型的运作机制。Mamba 和类似的 SSM 是按顺序处理文本的。与同时查看所有 Token 的 Transformer (全局注意力) 不同,SSM 一次查看一个 Token 并更新一个运行中的“隐藏状态”。

Mamba 块的核心更新规则如下所示:

描述隐藏状态更新规则的公式。

这里,\(h_t\) 是时间 \(t\) 的隐藏状态。矩阵 \(\bar{A}\) 决定了保留多少前一状态 (\(h_{t-1}\)) ,而 \(\bar{B}\) 决定了添加多少输入 (\(x_t\)) 。

这里的关键项是 \(\bar{A}\)。在 Mamba 中,该项源自一个负矩阵,意味着 \(\bar{A}\) 通常小于 1。这起到了 衰减因子 的作用。随着每一个新 Token 的处理,之前步骤的信息都会乘以这个衰减因子。在长序列中,这种重复的乘法会导致早期信号呈指数级消失。这就是为什么标准的 Mamba 模型难以回答关于一本长书开头部分的问题——数学信号已经被侵蚀了。

以前的尝试: LongMamba 的局限性

研究人员曾试图解决这个问题。一个著名的先行者是 LongMamba , 它引入了 Token 过滤 的概念。逻辑很简单: 如果内存有限,我们就不应该保存每一个 Token。我们应该只为“重要”的 Token 更新隐藏状态。

LongMamba 基于更新项 \(\Delta_t\) 的幅度来决定重要性。如果 \(\Delta_t\) 很小,模型就假设该 Token 没有怎么改变状态,因此将其丢弃。虽然这有一定帮助,但这是一种启发式方法。它假设 更新的幅度 等同于 信息的重要性。正如我们将看到的,这种假设通常是有缺陷的。

洞察: 揭示隐式注意力

LAMB 的创造者退后一步问道: 究竟什么构成了 SSM 中重要的 Token?

为了回答这个问题,他们使用了一种从 Mamba 中提取“注意力图”的方法。尽管 Mamba 不像 Transformer 那样使用显式的注意力头,但输入 \(x\) 和输出 \(y\) 之间的关系可以在数学上展开。

任何步骤 \(t\) 的输出计算如下:

基于隐藏状态 h 的输出 y 的公式。

通过展开循环项,我们可以将输出表示为所有先前输入的加权和。这揭示了一个隐式的“注意力分数” \(\alpha_{i,j}\):

显示展开后注意力求和的公式。

这个公式允许我们可视化模型在生成新 Token 时实际上正在“关注”哪些过去的 Token。

分析注意力图

当研究人员可视化这些注意力图时,他们发现了构成 LAMB 基础的两个关键洞察。

1. 稀疏性是真实的: 如下图 2(a) 所示,注意力图不是均匀的。它呈现出垂直条纹模式。这意味着极小一部分 Token (特定的列) 主导了注意力。模型非常关心几个关键词,而很大程度上忽略了其余部分。

2. 注意力 > \(\Delta_t\): 研究人员针对“预言机 (Oracle) ” (理论上限) 测试了不同的过滤策略。他们发现,基于 注意力分数 (模型对它们的关注程度) 过滤 Token 远优于基于 \(\Delta_t\) (LongMamba 方法) 的过滤。

注意力图可视化和保留率比较图。

在图 2(b) 中,请看实蓝线 (Attention-Guided,注意力引导) 。它紧跟虚红线 (Oracle) 。这证明,如果我们能在内存耗尽之前准确计算出 Token 的注意力分数,就可以安全地丢弃其余部分而不影响性能。

核心方法: LAMB

挑战在于,我们需要在生成完整序列 之前 知道 Token 的重要性。我们需要一种方法来估计 Token 相对于未来生成的注意力分数。

LAMB (Long-context extension driven by Attention-guided token filtering in MamBa) 引入了一个管道来精确地做到这一点。然而,直接使用标准 Mamba 公式中的原始注意力分数会带来两个具体问题: 偏差 (Bias)噪声 (Noise)

第 1 步: 去偏注意力 (Debiased Attention)

在标准 Mamba 注意力中,衰减因子 \(\bar{A}\) 会随时间累积。这这就产生了强烈的 近因偏差 (Recency Bias) 。 最近出现的 Token 分数很高,仅仅是因为它们还没有衰减,而不一定是因为它们在语义上很重要。相反,来自提示词开头的关键信息由于经过了数千步的衰减,看起来分数很低。

为了解决这个问题,LAMB 引入了 去偏注意力 。 研究人员将导致偏差的累积衰减因子替换为一个常数因子。

去偏注意力的公式。

通过从测量中移除时间依赖的衰减,度量标准 \(\alpha^D\) 揭示了 Token 的 内在 重要性,使早期 Token 与近期 Token 处于同等地位。

第 2 步: 对比注意力 (Contrastive Attention)

第二个问题是噪声。正如可视化图中所示,注意力图可能是模糊的。为了使过滤决策稳健,我们需要清晰地区分“信号”和“噪声”。

LAMB 应用了一种 对比 机制。它从当前分数中减去最大注意力分数的一部分,并应用 ReLU (整流线性单元) 函数。

对比注意力的公式。

这里,\(\gamma\) 是一个超参数 (通常约为 0.9) 。此操作会抑制微小的波动。如果一个 Token 的分数不接近峰值分数 (在 \(\gamma\) 定义的范围内) ,它将被归零。这就像一个高通滤波器,只留下高重要性的独特“峰值”。

图 3 展示了这些转换的视觉效果。注意“对比注意力 (Contrastive Attention) ” (第三行) 比原始的噪声注意力要干净和清晰得多。

注意力类型的视觉比较和 LAMB 管道。

第 3 步: 聚合与池化

SSM 在其隐藏状态中使用多个“通道”运作。LongMamba 识别出某些通道是“全局”的 (长期记忆) ,而其他通道是“局部”的。LAMB 将过滤工作集中在这些全局通道上。

为了获得 Token \(t\) 的单一重要性分数,LAMB 将所有全局通道以及“观察窗口” (模型最近看到的几个 Token) 内的对比注意力分数相加。

原始重要性聚合的公式。

最后,还有一个巧妙的技巧。语言很少是关于单个孤立单词的;它是关于短语和局部上下文的。如果我们只挑选单个峰值,可能会丢失赋予单词意义的周围上下文。

为了解决这个问题,LAMB 对重要性分数应用 平均池化 (Mean Pooling)

平均池化的公式。

这平滑了选择过程,确保如果一个 Token 被选中,它的直接邻居也很可能被选中,从而保持局部语义的完整性。

完整流程

使用 LAMB 进行推理的完整工作流程如下:

  1. 处理提示词: 模型处理输入文本。
  2. 计算指标: 对于每个 Token,它计算相对于最近“观察窗口”的去偏、对比注意力分数。
  3. 选择 Top-K: 识别出聚合重要性分数最高的 \(K\) 个 Token。
  4. 过滤与更新:
  • 对于选中的 Top-K Token,模型正常更新隐藏状态。
  • 对于未选中的 Token,模型将 \(\Delta_t\) 设为 0。这强制隐藏状态保持不变 (\(h_t = h_{t-1}\)) ,有效地“跳过”了该 Token 对记忆状态的影响,同时保留了先前重要 Token 的记忆。

这完全是 无需训练 的。你可以采用一个预训练的 Mamba 模型,在推理过程中应用此逻辑,即可立即提升其长上下文能力。

实验与结果

研究人员在两个严格的长上下文理解基准测试上评估了 LAMB: HELMETRULER 。 他们在纯 SSM (Mamba2) 和混合模型 (Zamba2) 上进行了测试。

RULER 上的性能

RULER 是一个合成基准测试,旨在测试长距离下的精确检索。表 2 中的结果令人信服。

表 2: RULER 基准测试结果。

看最右边的 Avg. (平均) 列。

  • 原版 Mamba2-1.3B 的得分微不足道,仅为 0.27% 。 它在 16k 上下文长度的任务上实际上完全失败了。
  • LongMamba 将其提高到了 10.82%
  • LAMB 跃升至 33.96%

这代表了相对于先前最先进方法大约 3 倍的提升 。 这种提升在“多查询 (multiquery) ”和“变量跟踪 (vt) ”任务中尤为强劲,这些任务需要在长时间内保持特定信息而不受干扰。

HELMET 上的性能

HELMET 涵盖了更接近现实世界应用的任务,例如长文档问答和摘要。

表 1: HELMET 基准测试结果。

趋势保持不变。在 8k、16k 和 32k 上下文长度下,LAMB 始终优于原版模型和 LongMamba。例如,在 Zamba2 的 16k 上下文任务中,LAMB 达到了 12.35 , 而 LongMamba 为 11.35 , 原版为 6.76

为什么组件很重要 (消融实验)

去偏和池化的所有复杂性真的有必要吗?消融实验表明是的。

表 3: 消融实验。

  • 无去噪,无池化: 准确率仅为 3.40%
  • 仅池化: 准确率跃升至 27.22% 。 这表明保留关键 Token 周围的局部上下文至关重要。
  • 去噪 + 池化 (完整 LAMB) : 达到 33.96% 的峰值。去偏和对比步骤增加了实现顶级性能所需的最后一层精度。

延迟成本

有人可能会担心计算这些注意力矩阵会拖慢模型速度。然而,由于 LAMB 在“预填充”阶段 (处理提示词) 运行并使用高效的运算,开销微乎其微。

表 4: 延迟开销。

如表 4 所示,对于非常长的序列 (192k Token) ,开销仅约为 5.78% 。 关键是,在生成阶段 (当聊天机器人实际编写回复时) 零开销 , 因为过滤已经发生了。

结论

LAMB 框架代表了我们对状态空间模型理解的重大成熟。它让我们从对 Token 重要性的启发式猜测,转向了基于原则的、基于注意力的指标。

通过对注意力信号进行去偏并应用对比过滤,LAMB 允许 Mamba 模型在长序列的噪声中清晰地“看”到重点。它成功地在干草堆中找到了针——即必须记住的关键 Token——并丢弃了其余部分。

其意义令人兴奋:

  1. 效率: 我们可以在以前无法满足内存要求的硬件上运行长上下文任务。
  2. 可用性: 因为它是无需训练的,LAMB 可以立即应用于现有的 Mamba 部署。
  3. 理解: 这项工作加深了我们对 SSM 如何内部表示重要性的理论理解,弥合了 Transformer 注意力的透明度与循环神经网络效率之间的鸿沟。

随着我们致力于开发能够阅读整个图书馆文本的 LLM,像 LAMB 这样的技术对于确保这些模型不仅能阅读,而且能真正记住内容将至关重要。