解构上下文学习: 隐藏在 LLM 内部的双塔机制

像 GPT-4 和 Llama 这样的大型语言模型 (LLM) 展现出一种迷人的涌现能力,即上下文学习 (In-Context Learning, ICL) 。 这种现象是指,当你在提示词中提供少量示例 (演示) 时——比如“英语: Cat,法语: Chat”——模型能够立即学会这种模式并完成一个新的例子,而无需任何参数更新或重新训练。

虽然我们每天都在使用 ICL,但其潜在机制仍然有点像一个“黑盒”。模型究竟是如何将信息从演示示例转移到最终预测的?它是真的在“学习”任务,还是仅仅依靠预先存在的知识?

在一篇引人入胜的论文《How do Large Language Models Learn In-Context? Query and Key Matrices of In-Context Heads are Two Towers for Metric Learning》 (大型语言模型如何进行上下文学习?上下文头的查询和键矩阵是度量学习的双塔) 中,来自曼彻斯特大学的研究员 Zeping Yu 和 Sophia Ananiadou 剥开了 Transformer 架构的层层外衣。他们提出了一个令人信服的假设: ICL 通过一组特定的注意力头运作,其功能类似于双塔度量学习系统 (Two-Tower metric learning system)

在这篇深度文章中,我们将梳理他们的方法论,了解他们发现的“上下文头 (In-Context Heads)”,以及这一新理解如何解释诸如多数标签偏差和近因偏差等令人费解的行为。

引言: 机制之谜

要理解模型如何进行上下文学习,我们首先需要将学习过程与模型的先验知识隔离开来。如果你要求模型将电影评论分类为“正面”或“负面”,它可能只是依赖于它知道“优秀”这个词是正面的这一事实。很难判断模型是在看你的例子,还是仅仅在使用它的训练数据。

为了解决这个问题,研究人员专注于使用语义无关标签的任务学习 (Task Learning, TL) 。 他们不再使用“正面/负面”,而是强制模型将输入映射到任意标签,如 “foo”“bar”

例如:

  • 演示 1: “经济正在繁荣” : bar
  • 演示 2: “股市正在崩盘” : bar
  • 演示 3: “团队赢得了比赛” : foo
  • 查询: “球员进了一个球” : [预测]

如果模型预测 “foo”,它就没有使用先验知识 (因为 “foo” 与体育无关) 。它必须是在观察上下文。这种设置允许作者从机械层面精确追踪模型是如何将信息从演示上下文移动到最终输出的。

寻找“上下文头”

Transformer 由许多层组成,每一层都有多个“注意力头”。一个标准的 70 亿参数模型可能有 32 层,每层 32 个头——总共超过一千个头。它们对上下文学习的贡献是平等的吗?

研究人员利用因果追踪和干预方法——本质上是关闭特定的头——来观察哪些头实际上影响了模型执行 “foo/bar” 任务的能力。

1% 的发现

结果令人震惊。他们发现 ICL 的性能并不是分布在模型的整个“大脑”中。相反,它依赖于一小部分“上下文头”——大约占所有头的 1% (在测试的模型中约为 12 个头) 。

当他们仅干预这 12 个头时:

  • 准确率从 87.6% 骤降至 24.4%
  • 其余的头对于这个特定机制来说基本上是无关紧要的。

他们进一步将这些头分为两组:

  1. Foo 头 (Fooheads): 当处于活跃状态时,专门增加标签 “foo” 概率的头。
  2. Bar 头 (Barheads): 专门增加标签 “bar” 概率的头。

这种定位使我们可以不再将模型视为一个巨大的、混乱的单体,而完全专注于这些特定头内部发生的事情。

核心假设: 双塔系统

为了理解作者的主要贡献,我们需要看看单个注意力头的解剖结构。在 Transformer 中,一个注意力头通常由四个矩阵描述: 查询 (Query, Q)、键 (Key, K)、值 (Value, V) 和输出 (Output, O)

标准的解释是, 查询请求信息, 定义可用的信息,而是被传递的实际内容。

作者分析了这些矩阵在已识别的“上下文头”中的数学行为,并提出了一个统一的假设,如下图所示:

图 1: ICL 机制假设。(a) 浅层将特征融合到标签位置和最后位置。在上下文头中,(b) 值-输出矩阵 VO 提取标签信息。(c) 查询矩阵 Q 和 (d) 键矩阵 K 计算 (e) 最后位置与每个演示之间的相似度分数,决定有多少标签信息被转移到最后一个 token。

让我们一步步分解这张图 (图 1) 。

1. 浅层: 特征融合

在我们到达上下文头 (通常位于更深层) 之前,模型的浅层执行了一个关键的预处理步骤 (图中的 (a) 部分) 。

模型聚合信息。位于标签位置 (例如提示词中的 “bar” 一词) 的向量从其对应的句子中收集语义信息。同时,位于最后位置 (模型需要做出预测的地方) 的向量收集关于当前输入文本的信息。

2. 值-输出 (VO) 矩阵: “是什么”

在更深层的上下文头中, 输出矩阵 (VO) 充当信息提取器。

作者分析了由这些矩阵生成的向量,并使用以下公式将它们投影到词汇空间中:

公式 1

这个公式本质上是在问: “如果我们把这个向量翻译回英语单词,它在说什么?”

表 1 (如下) 展示了这种投影的结果。请看 value 行。位于 “bar” 位置 (2-value, 5-value) 的向量在大喊 “BAR, Bars, Baron”。位于 “foo” 位置 (8-value, 11-value) 的向量在大喊 “foo, Foo”。

表 1: 标签位置和最后位置的 Top token。

关键洞察: 值-输出矩阵是“哑”管道。它们的工作仅仅是保存标签信息 (“foo” 或 “bar”) 。如果注意力关注到了这个位置,这个矩阵就确保 “foo” 或 “bar” 的概念被复制到最终预测中。

3. 查询-键 (QK) 矩阵: “双塔”

如果 VO 矩阵提供内容,那么查询和键矩阵则决定流向。这正是论文提出新颖的“双塔”解释的地方。

在机器学习中,“双塔”模型常用于推荐系统。一个塔处理用户 (User),另一个塔处理物品 (Item),然后计算它们之间的相似度 (点积) 来看看它们是否匹配。

作者建议上下文学习以同样的方式工作:

  • 塔 1 (查询): 代表最后位置 (我们想要分类的新输入句子) 。
  • 塔 2 (键): 代表演示中的标签位置 (包含示例句子的语义特征) 。

注意力机制计算查询 (输入) 和键 (演示) 之间的相似度。

  • 如果新输入句子在语义上与演示 #1 相似,相似度分数就会很高。
  • 模型“关注”演示 #1。
  • 来自演示 #1 的 VO 矩阵 (持有标签 “foo”) 被解锁。
  • “Foo” 流入最终预测。

这意味着 ICL 本质上是在注意力头内部执行度量学习 (Metric Learning) 。 它正在计算你的输入与提供的示例之间的距离度量。

实验证据: 转变

为了证明这一点,研究人员反转了标签。他们取了一个正确答案原本是 “foo” 的提示,并交换了标签,使答案变成了 “bar”。

如果他们的假设是正确的,“内容” (VO) 应该不会有太大变化,但“流向” (Attention) 应该会发生巨大的转变。

表 4: Llama (第一块) 和 GPT-J (第二块) 中 foo 头/bar 头 (fh/bh) 在 “foo”/“bar” 位置 (fp/bp) 的加权值-输出向量的 Logit 差值。

数据支持了这一点。如表所示,预测的转变是由 “foo” 位置注意力分数的显著减少和 “bar” 位置注意力分数的增加所驱动的。机器并没有“重读”文本;它只是重新加权了相似度,允许不同的标签塔占据主导地位。

解释 ICL 的偏差

验证一个新理论最有力的证据之一是它能够解释以前令人困惑的现象。众所周知,上下文学习深受多数标签偏差 (Majority Label Bias) (偏向出现频率高的标签) 和近因偏差 (Recency Bias) (偏向出现在提示末尾的标签) 的困扰。

“双塔”假设为这两者提供了清晰的解释。

1. 多数标签偏差

为什么模型偏向多数标签?

因为最终输出是注意力分数的总和 。 如果你有 10 个 “foo” 的例子和 1 个 “bar” 的例子,就有 10 个代表 “foo” 的“键塔”。即使相似度匹配一般,10 个普通分数的总和通常也会超过单个 “bar” 例子的分数。

作者通过创建不平衡数据集验证了这一点。

图 2: Llama (左) 和 GPT-J (右) 在原始数据集和不平衡数据集上,foo 头的 foo 位置和 bar 头的 bar 位置的注意力分数。

在图 2 中,我们可以看到当移除 “foo” 演示时 (不平衡数据集) ,“foo” 位置的总注意力权重显著下降 (对比蓝色和橙色条形图) 。模型并不是在心理学意义上的“偏见”;它只是一个累加相似度分数的求和机器。

2. 近因偏差

为什么模型偏向提示末尾的例子?

作者假设这是由于位置嵌入 (Positional Embeddings) 造成的。在 Transformer 中,每个词都被标记了其位置 (第 1 个、第 2 个、第 100 个) 。

  • “查询”总是在最末尾 (例如,位置 100) 。
  • 最近示例的“键”在位置 90, 80…
  • 早期示例的“键”在位置 10, 20…

因为位置数字更接近,查询 (位置 100) 与最近的键 (位置 90) 之间的数学相似度被人为地夸大了,相比之下,与遥远的键 (位置 10) 的相似度则较低。“位置”变成了双塔模型无意中匹配的一个特征。

作者通过反转演示的顺序测试了这一点 (图 3) 。

图 3: Llama (左) 和 GPT-J (右) 在原始数据集和反转数据集上,foo 头的 foo 位置和 bar 头的 bar 位置的注意力分数。

当顺序反转时,注意力权重发生了显著变化,证实了提示中的物理位置决定了注意力的强度。

下面的图 4 和图 5 进一步可视化了这一点,展示了注意力如何在不同的数据集配置 (原始、不平衡和近因/反转) 中变化。你可以看到基于提示结构的不同,注意力分布发生了明显的偏移。

图 4: Llama 在原始、不平衡和近因数据集上 “foo”/“bar” 位置的注意力分数。

图 5: GPT-J 在原始、不平衡和近因数据集上 “foo”/“bar” 位置的注意力分数。

工程解决方案: 减少偏差

有了这种机制性的理解,作者并没有止步于解释——他们提出了修复方案。

修复多数标签偏差: 由于偏差是由少数类别的注意力权重总和较低引起的,作者建议在上下文头中通过数学方法提升少数位置的注意力分数。他们引入了一个基于演示数量比率的乘数。

修复近因偏差: 由于偏差是由位置嵌入夸大了相似度引起的,他们建议专门在上下文头内的注意力计算中剥离位置信息。

这些干预的结果是积极的:

表 7: 应用我们的方法前后准确率的变化,Llama (第一块) 和 GPT-J (第二块) 。

如表 7 所示,应用修复多数偏差的方法显著减少了准确率的波动 (准确率变化下降了约 22%) 。同样,移除位置影响将近因偏差减少了约 17%。

结论

这篇论文在机械可解释性方面迈出了重要一步。它使我们不再将上下文学习视为魔法,而是将其视为在硅片上运行的结构化算法。

主要结论:

  1. 专业化: 只有极小部分的头 (约 1%) 负责上下文学习。
  2. 角色分离: 在这些头内部,值矩阵携带标签 (“foo”) ,而查询/键矩阵决定哪个标签适用。
  3. 度量学习: 查询/键的交互作为一个双塔模型发挥作用,计算当前输入与先前示例之间的相似度。
  4. 偏差是机械性的: 像近因和多数偏好这样的偏差是求和及位置编码的可预测数学产物。

通过将这些“上下文头”理解为用于度量学习的两个独特的塔,我们不仅能更好地理解我们的模型,还能主动对其进行工程设计,使其更加鲁棒、公平和准确。