AI 为何遗忘: 解决医学影像中的灾难性遗忘问题

人工智能在医学诊断领域取得了巨大进步,特别是在病理切片的分析方面。然而,在这些系统的部署过程中存在一个隐性问题: 它们是静态的。在快速发展的医学领域,新疾病不断被发现,新的亚型被分类,扫描设备也在不断升级。

理想情况下,我们希望 AI 模型能够持续学习,在适应新数据的同时不丧失识别既往病症的能力。这就是 持续学习 (Continual Learning, CL) 的领域。但是,当研究人员将标准的持续学习技术应用于病理学时,他们遇到了一堵被称为 灾难性遗忘 (catastrophic forgetting) 的墙。模型学会了新任务,却完全忘记了旧任务。

在这篇深度文章中,我们将探讨一篇研究论文,它揭示了 为什么 这种情况在医学影像中尤为突出,并提出了一种新颖的解决方案。作者认为,在病理模型中,AI 的“大脑”没有遗忘——遗忘的是它的“注意力”。

背景: 全切片成像与多示例学习 (MIL)

要理解这个问题,我们首先需要了解数据。病理学家分析的是 全切片图像 (Whole Slide Images, WSIs) 。 这些是组织样本的数字扫描图,其大小达到吉像素级 (gigapixels) ——对于标准的神经网络来说,一次性处理这么大的图像是不可能的。

为了处理这个问题,研究人员使用 多示例学习 (Multiple Instance Learning, MIL) 。 标准流程如下:

  1. Patch 包 (Bag of Patches): 巨大的 WSI 被切割成成千上万个称为“Patch” (小块) 的小方块。整个切片被视为一个“包”。
  2. 特征提取: 神经网络从每个 Patch 中提取特征。
  3. 注意力机制: 这是关键部分。一个 注意力网络 会查看所有的 Patch 并为每一个分配一个“重要性评分”。它本质上是在决定组织的哪些部分是可疑的 (肿瘤) ,哪些是正常的。
  4. 聚合与分类: 模型将加权后的特征聚合成单一的切片级表示,并做出最终诊断 (例如: 癌症 vs. 正常) 。

从数学角度来看,切片级特征 \(\mathbf{z}\) 是使用注意力分数 \(a_n\) 对 Patch 特征 \(\mathbf{h}_n\) 进行加权求和计算得出的:

切片级特征聚合的公式。

注意力分数 \(a_n\) 是通过特定的网络架构 (通常使用 tanh 和 sigmoid 激活函数) 推导出来的:

计算注意力分数的公式。

这种架构对于静态数据集非常有效。但是当我们引入 类增量学习 (Class-Incremental Learning, CIL) 时会发生什么?这是一种场景,我们在“任务 1” (例如: 检测乳腺癌) 上训练模型,随后在“任务 2” (例如: 检测肺癌) 上训练它,而此时无法再访问原始的乳腺癌数据。

在标准的计算机视觉 (如分类猫与狗) 中,灾难性遗忘通常发生在最后的分类层。这篇论文的作者发现,在 MIL 中,情况有着根本的不同。

核心洞察: 不是分类器,而是注意力

研究人员进行了一项引人入胜的“解耦实验”,以调查记忆丢失发生在哪里。他们采用了一个在一系列任务上训练过的模型,并将其部分组件与仅在第一个任务上训练过的原始模型进行交换。

他们的问题是: 如果我们保留旧的分类器,但使用新的注意力网络,准确率会下降吗?反之亦然呢?

结果令人震惊。

表格显示交换注意力层与分类器层时的准确率下降情况。

如上方的 表 1 所示,请看 “Attention \(\theta_t\)” 列。当模型使用来自后续阶段 (\(t=2\) 或 \(t=3\)) 的注意力网络但保留原始分类器时,任务 1 的准确率直线下降 (降至接近 0%) 。然而,如果他们保留原始的注意力网络 (\(t=1\)) 并使用较新的分类器,准确率仍然保持得非常高 (~86-89%) 。

结论: 模型并没有忘记如何对特征进行分类;它忘记的是 看哪里

可视化漂移

当观察模型关注点的热力图时,这一理论发现变得显而易见。

热力图显示注意力从肿瘤区域漂移开。

图 1 中,面板 (2) 显示了学习任务 1 后的注意力图: 模型正确地聚焦于肿瘤 (红色区域) 。然而,在面板 (3) 和 (4) 中,随着模型学习新任务 (微调) ,注意力发生了“漂移”。红色的热点区域从肿瘤部位消失了。模型实际上对它之前知道如何发现的癌症变得视而不见。

为什么注意力会漂移?

作者基于梯度分析为这一现象提供了数学解释。他们分析了分类器和注意力网络的梯度 (用于更新模型权重的信号) 。

对于分类器权重 \(\phi\),梯度是有界的。它依赖于聚合特征 \(z_j\),这是一个无法无限增长的加权和。

分类器梯度平方范数的公式。

然而,注意力分数 \(a_i\) 的梯度取决于项 \(\phi^\top \mathbf{h}_i\):

注意力梯度平方范数的公式。

这一项代表特定 Patch 的原始 Logit 分数。这是 无界的 。 如果模型遇到的新任务特征触发了分类器的极高响应,注意力层的梯度可能会爆炸或剧烈波动。

这种不稳定性得到了实证确认。下图追踪了训练过程中梯度值的分布。

图表显示注意力网络与分类器网络中的梯度波动情况。

图 2 (上图) 中,你可以看到注意力梯度在整个学习过程中在很大的范围内振荡。相比之下,分类器梯度 (下图) 随着时间的推移趋于稳定并收缩。这种不稳定性使得注意力网络极易用新信息覆盖旧知识——这就是灾难性遗忘的定义。

解决方案: 双管齐下的方法

基于注意力层是薄弱环节这一认识,研究人员提出了一个包含两个主要组件的新框架:

  1. 注意力知识蒸馏 (Attention Knowledge Distillation, AKD) 用于解决遗忘问题。
  2. 伪包内存池 (Pseudo-Bag Memory Pool, PMP) 用于处理巨大的数据量。

以下是他们提出的系统的高级架构:

框架架构图,显示学生和教师网络。

1. 注意力知识蒸馏 (AKD)

标准的持续学习通常使用“Logit 蒸馏”,即强制新模型对旧数据输出与旧模型相同的最终分类分数。然而,由于这里的问题在于 注意力,作者引入了 注意力知识蒸馏

他们强制新模型 (学生) 模仿旧模型 (教师) 的注意力分布。目标是最小化前一个模型 (\(f_{\theta_{t-1}}\)) 和当前模型 (\(f_{\theta_t}\)) 的注意力权重之间的 Kullback-Leibler (KL) 散度。

注意力知识蒸馏损失的公式。

通过锁定注意力模式,模型确保即使在学习新疾病特征时,它本质上也“记住”了哪些 Patch 对之前的疾病是重要的。

最终的损失函数结合了标准交叉熵损失 (用于学习新任务) 以及注意力蒸馏和 Logit 蒸馏:

总损失函数的公式。

2. 伪包内存池 (PMP)

第二个挑战是病理学特有的: 存储。要使用蒸馏,通常需要回放一些旧数据 (回放缓冲区) 。但 WSI 大小达数千兆字节。存储数百张旧切片在计算成本和内存上都是令人望而却步的。

作者意识到他们不需要存储整个切片。由于 MIL 依赖于特定的实例,他们可以提炼一个“伪包 (Pseudo-Bag)”。

他们不存储 \(N\) 个 Patch (\(N\) 可能超过 10,000) ,而是存储一小部分 \(K\) 个 Patch (例如,只有几十个) 。但选哪些呢? 如果你只存储“重要的” (高注意力) Patch,你就会失去背景上下文,而这是模型学习 看哪里所必需的。

他们提出了 MaxMinRand 策略:

  1. Max (最大): 选择注意力分数最高的 Patch (肿瘤) 。
  2. Min (最小): 选择注意力分数最低的 Patch (背景) 。
  3. Rand (随机): 选择随机 Patch (以捕获一般分布) 。

这创建了切片的浓缩表示:

伪包近似的公式。

这大大减少了内存占用,同时保留了注意力知识蒸馏所需的关键信息。

实验与结果

研究人员在两个主要基准上测试了他们的方法,并与最先进的持续学习方法 (如 EWC, LwF, DER++) 进行了对比: 一个是 皮肤癌 数据集,另一个是复合的 Camelyon-TCGA 数据集 (涵盖乳腺癌、肺癌和肾癌) 。

定量性能

结果显示,该方法在性能上大幅超越了现有方法。我们查看 AACC (平均准确率) 和 BWT (向后迁移——衡量在旧任务上保留了多少准确率的指标) 。

Camelyon-TCGA 数据集上的结果表。

表 6 中,请看“Ours” (我们的方法) 这一行与“Rehearsal ER” (经验回放) 或“Regularization LwF” (正则化 LwF) 的对比。

  • 在使用 30 个 WSI 缓冲区的 CLAM 模型上,所提出的方法达到了 0.754 的准确率 , 而标准 ER 仅达到 0.494
  • “Ours”的 BWT (向后迁移) 为 -0.177,而 ER 为 -0.565。数字越接近零越好;这意味着模型在以前的任务上损失的准确率非常少。

随着时间的稳定性

我们可以可视化模型从任务 1 转移到任务 3 时的学习轨迹。

准确率随时间下降的图表。

图 4 中,红线 (Ours) 明显高于绿线 (ER) 或紫线 (DER++) 。当其他方法在第三个任务后准确率跌至 40% 以下 (基本上在旧任务上变成了随机猜测) 时,所提出的方法保持了稳健的性能。

权衡: 可塑性 vs. 稳定性

AI 中的一个经典困境是稳定性-可塑性困境。如果一个模型太稳定 (记住旧东西) ,它就无法学习新东西 (可塑性) 。理想情况下,模型应该处于下图的左上角: 高向后迁移 (稳定性) 和低非瞬态性 (高可塑性) 。

BWT 与 IM 指标的散点图。

图 5 显示,所提出的方法 (由左上角簇中的星形/红点表示) 实现了最佳折衷。它在保留知识方面表现得更好 (Y 轴数值高) ,且没有牺牲学习新任务的能力 (X 轴靠左) 。

消融实验: 采样策略重要吗?

最后,作者验证了他们用于内存池的 MaxMinRand 策略是否真的是必要的。

蒸馏方法的消融实验表。

表 4 确认 MaxMinRand 产生了最高的准确率 (CLAM 上为 0.729) 。有趣的是,纯粹选择“Max” Patch 表现不佳 (0.595) 。这证明模型需要在内存池中看到“无聊”的背景 Patch,以维持对注意力分布的正确理解。

结论与启示

这项研究强调了将 AI 应用于医学影像时的一个关键细微差别: 不同的架构会以不同的方式失效。在多示例学习中,注意力机制是记忆的“阿喀琉斯之踵” (致命弱点) 。

通过诊断问题——注意力层中的无界梯度导致漂移——作者能够设计出一个精准的解决方案。 注意力知识蒸馏 起到了稳定剂的作用,固定了模型的关注点,而 伪包内存池 使得该解决方案在数字病理学典型的大文件尺寸下具有实用性。

对于医疗保健的未来而言,这是重要的一步。它为那些能够随着医学科学发展而进化的诊断系统铺平了道路,使其能够整合新的生物标志物和疾病亚型,而无需每次更新教科书时都从头开始重建。通过教导 AI 不仅要思考 什么,还要知道 看哪里,我们确保了它能记住过去的教训。