想象一下学习骑自行车。现在,再想象一下学会骑车这件事导致你立刻忘记了如何走路。这种荒谬的情况对于许多人工智能模型来说却是现实。这种现象被称为灾难性遗忘 (Catastrophic Forgetting) , 它是持续学习 (Continual Learning, CL) 领域的一个主要障碍,在这个领域中,模型必须学习一系列任务而不清除其先前的知识。

当你没有太多数据可供学习时,这个问题会变得更加困难——这种场景被称为小样本持续关系抽取 (Few-shot Continual Relation Extraction, FCRE) 。 在这里,模型必须仅根据少量示例识别文本中的关系 (例如,“A 是 B 的母亲”) ,同时还要处理随时间推移出现的新关系。

通常的方法是采用一个强大的预训练模型 (如 BERT) ,切断它的“头” (用于通用语言理解的输出层) ,然后用一个新的、随机初始化的分类器取而代之。最近的一篇研究论文认为这是一个错误。通过丢弃预训练语言模型 (LM) 的头,我们丢掉了大量的通用知识。

在这篇文章中,我们将探讨一种称为互信息最大化 (Mutual Information Maximization, MIM) 的新方法。我们将看到保留“旧头”如何引导“新头”,从而显著提高标准模型和现代大型语言模型 (LLMs) 的性能和记忆力。

背景: FCRE 和丢失的头

在深入探讨解决方案之前,让我们先明确问题所在。

持续关系抽取 (CRE) 要求模型处理一系列任务。在任务 1 中,它可能学习识别“雇主/雇员”关系。在任务 2 中,它学习“出生地”关系。目标是在擅长任务 2 的同时不忘记任务 1。

加上“小样本 (Few-shot) ”这个条件,模型可能只能看到 5 或 10 个“出生地”的例子。这种稀缺性导致了两个主要问题:

  1. 灾难性遗忘: 模型为了适应新的、紧迫的信息而覆盖了旧的权重。
  2. 过拟合 (Overfitting) : 由于例子太少,模型死记硬背了特定的训练数据,而不是学习关系的一般概念。

标准方法 vs. 新思路

大多数现有的解决方案使用记忆缓冲区 (保存少量旧样本以便稍后重演) 或原型学习。然而,它们几乎普遍遵循一种特定的架构模式: 采用预训练骨干网络 (如 BERT) ,丢弃原始输出层 (LM 头) ,并从头开始训练一个新的分类头。

这篇论文背后的研究人员认为, LM 头——最初训练用于预测句子中下一个词的模型部分——包含了丰富、通用的语义知识。丢弃它是浪费。

现有 FCRE 方法与本文提出的 MIM 策略的区别。左侧显示了使用提示设计和记忆管理的标准方法。右侧显示了结合 LM 头和互信息损失的新策略。

如上图 Figure 3 所示,标准方法 (左) 仅依赖于新的表征学习损失 (\(\mathcal{L}_0\)) 。提出的框架 (右) 保持 LM 头处于活跃状态。它利用 LM 头的输出,通过互信息来“监督”或对齐主分类器。

核心方法: 互信息最大化 (MIM)

核心假设很简单: 由新分类器生成的句子表征 (容易过拟合) 应该与预训练 LM 头生成的表征 (鲁棒且通用) 共享高度的信息。

为了实现这一目标,作者引入了互信息最大化 (MIM) 策略。

目标函数

目标是最大化两个潜在表征之间的互信息 (MI) :

  1. \(g_{\phi}(\mathbf{x})\): 来自主分类头的特征表征。
  2. \(g_{\Phi}^{LM}(\mathbf{x})\): 来自预训练 LM 头的表征。

在数学上,我们要最大化:

Equation 1: 互信息公式

然而,在高维空间中计算精确的互信息是困难的。为了解决这个问题,作者使用了被称为 InfoNCE 的下界。这是自监督学习中常用的一种对比学习目标。

Equation 2: 使用 InfoNCE 的互信息下界

理解 InfoNCE

InfoNCE 损失本质上是将“正样本对”拉近,将“负样本对”推开。在这个语境下:

  • 正样本对是分类器表征与 LM 头对同一输入句子的表征。
  • 负样本对是分类器对当前句子的表征与 LM 头对批次 (batch) 中其他句子的表征。

InfoNCE 的计算公式如下:

Equation 3: InfoNCE 计算细节

这里,\(W\) 是一个可训练参数,帮助将两个不同的向量空间相互映射,\(\tau\) 是一个温度参数,用于控制概率分布的尖锐程度。

MIM 策略的最终损失函数是对训练数据的求和:

Equation 4: 互信息损失函数

总损失

这个新的 MI 损失被添加到所使用的任何基线模型的标准损失函数中 (表示为 \(\mathcal{L}_0\)) 。这使得 MIM 策略具有“即插即用”的特性——它可以被添加到几乎现有的 FCRE 方法中以进行改进。

Equation 5: 结合原始损失和 MI 损失的总损失

通过最小化这个组合损失,模型被迫学习一个对特定任务准确,但同时也与预训练骨干网络的通用语言知识保持一致的分类器。

适配大型语言模型 (LLMs)

虽然 BERT 类模型是此任务的标准,但研究人员也想探索现代大型语言模型 (LLMs) 如 LLaMA-2Mistral 的潜力。

然而,存在技术上的不匹配。BERT 是“仅编码器 (encoder-only) ”模型 (擅长分类) ,而 LLaMA 是“仅解码器 (decoder-only) ”自回归模型 (擅长生成文本) 。

为了弥合这一差距,作者修改了输入结构。他们没有像 BERT 那样使用 [MASK] 标记,而是重新构建了提示 (prompt) 来提出问题,如下图 Figure 6 所示。

展示如何为 LLM 调整提示的图解。输入被转换为问题格式: ‘Entity 1 和 Entity 2 之间的关系是…’。

他们提取单词“is” (答案前的最后一个词元) 的嵌入作为特征表征 (\(g_{\phi}\)) 。这使得他们能够将完全相同的 MIM 策略应用于这些庞大的生成模型。

实验与结果

团队在两个主要基准上测试了他们的方法: FewRelTACRED 。 他们将 MIM 策略应用于三个最先进的基线模型: SCKD、ConPL 和 CPL。

1. 它能提高准确率吗?

结果是一致的。添加 MIM 策略 (+MI) 全面提高了性能。

Table 1: 准确率比较。带有 ‘+ MI’ 的行在所有任务中均显示出比原始对应项更高的数值。

Table 1 中,我们可以看到对于基于 BERT 的方法,+MI 变体 (突出显示的行) 始终优于原始版本。例如,在具有挑战性的 TACRED 数据集上, ConPL+MI 比标准 ConPL 实现了显著更好的跨任务稳定性。

2. 它能减少遗忘吗?

该领域的主要敌人是“准确率下降 (accuracy drop) ”——即模型刚学会任务时的表现与最后记得多少之间的差异。

Figure 1: 显示准确率下降的条形图。红色条 (Ours) 低于蓝色条 (Origin) ,表明遗忘较少。

Figure 1 生动地说明了这一点。蓝色条代表原始方法的准确率下降。红色条代表增强了 MIM 的方法。在每种情况下,红色条都更低,这意味着模型保留了更多的过去知识。这证实了 LM 头起到了锚点的作用,防止模型偏离其通用知识库太远。

3. 可视化改进

数字虽好,但在空间中观察数据往往更直观。研究人员使用 t-SNE (一种用于可视化高维数据的技术) 来绘制模型如何对不同的关系进行分组。

Figure 4: t-SNE 可视化。右侧的簇 (CPL+MI) 比左侧的簇 (CPL) 更紧密且更独特。

Figure 4 中,比较左图 (原始 CPL) 和右图 (CPL+MI) 。不同颜色的簇代表不同的关系。

  • 左侧: 簇有些分散,边界模糊。
  • 右侧: 簇更紧密,分离得更好。

这种更紧密的聚类意味着模型在不同关系之间的混淆更少,从而导致更高的准确率和更好的泛化能力。

4. LLM 的表现如何?

研究发现,像 LLaMA-2-7B 和 Mistral-7B 这样的 LLM 由于其巨大的规模和预训练深度,在 FCRE 任务中通常优于基于 BERT 的模型。然而,它们仍然容易发生遗忘。

关键在于, MIM 策略对 LLM 也有效

Table 6: LLM 的详细准确率表。底部行显示 Mistral-7B-CPL + MI 实现了最高的性能和最低的准确率下降。

查看 Table 6 中的数据 (特别是最右侧的 \(\Delta \downarrow\) 列) ,我们看到了准确率下降的情况。 Mistral-7B-CPL + MI 实现了令人难以置信的低准确率下降,约为 21.5%,而标准基于 BERT 的方法则超过 30%。这表明,结合 LLM 的海量知识与 MIM 的对齐策略是未来研究的一个强有力的方向。

结论

这项研究的关键结论是,在我们寻求针对特定任务专门化 AI 模型时,不应过快地抛弃它们的通用能力。

预训练语言模型的“头”不仅仅是预测下一个词的机制;它是通用语言理解的宝库。通过使用互信息最大化 , 我们可以强制新的、专门的分类器与这种深层知识保持同步。

这种方法提供了一个“两全其美”的解决方案:

  1. 灵活性: 模型可以从极少的例子中学习新的、特定的关系。
  2. 稳定性: 模型保留了预训练骨干网络的鲁棒、通用特征,防止了灾难性遗忘。

随着我们转向使用像 LLaMA 和 Mistral 这样更大的模型来执行专门任务,像 MIM 这样的技术将对于确保这些巨人在学习新技能的同时不忘记旧技能至关重要。