想象一下教一个孩子认识猫。他们在这方面变得非常擅长。然后,你开始教他们认识狗。突然之间,他们似乎忘记了猫长什么样。这对人类来说听起来很荒谬,但对于人工神经网络而言,却是一个非常真实且令人沮丧的问题。这种现象被称为灾难性遗忘 (catastrophic forgetting) ,是构建能够在整个生命周期中持续学习、真正智能且可适应的人工智能系统的主要障碍。

当一个在任务 A 上训练好的神经网络接着在任务 B 上进行训练时,它往往会覆盖任务 A 学到的知识,导致性能急剧下降。研究人员已经提出了许多聪明的策略来应对这一问题,主要是试图“保护”那些对先前任务至关重要的网络参数。但是,如果我们不是去手动设计这些保护规则,而是教会神经网络如何学习而不遗忘呢?

这就是来自 SK T-Brain 的研究人员在论文 《元持续学习》 (“Meta Continual Learning”) 中提出的核心思想。他们提出了一种引人入胜的方法,利用元学习 (meta-learning) ——即“学习如何学习”——来训练一个独立的神经网络,它唯一的任务就是指导另一个网络的学习过程,确保它在掌握新任务的同时还能记住旧的任务。这不仅仅是关于学习,更是关于学习不遗忘的艺术


问题所在: 为什么神经网络会遗忘

在探讨解决方案之前,让我们先明确问题。神经网络本质上是一个由相互连接的节点组成的系统,节点之间的连接具有可调参数或“权重”。训练过程通过调整这些权重,来最小化一个损失函数——即任务的预测值与真实值之间的差异。

当你在任务 A (例如分类数字 0–4) 上训练一个网络时,权重会收敛到一个对该任务最优的配置。当引入任务 B (例如分类数字 5–9) 时,训练又开始调整这些权重,以最小化新任务的损失。如果没有额外约束,优化器可能会完全覆盖之前的配置。结果,网络在任务 B 上表现良好,却完全忘记了任务 A。

现有的解决灾难性遗忘的方法通常分为三类:

  1. 网络扩展: 为每个新任务添加新的神经元、层或模块。有效但效率低,因为网络会无限增长。
  2. 排练/回放: 存储先前任务的数据 (或使用生成模型重新生成) ,并与新任务数据混合训练。虽有效,却占用大量内存。
  3. 基于正则化的方法: 在损失函数中添加惩罚项,以减缓那些对旧任务关键的权重变化。例如 弹性权重巩固 (EWC)突触智能 (SI) 。 然而,这些方法依赖于手工制定的启发式规则来评估参数的重要性。

SK T-Brain 的研究人员认为,手工设计的正则化项虽强大,但其固有的局限性难以避免。对于一个真正通用的持续学习系统,我们需要一种能够自动学习如何保存有价值知识的机制。


核心思想: 学习如何在不遗忘的情况下学习

这项研究的核心创新是用一个学习到的正则化规则取代手动设计的规则。作者引入了第二个、更小的神经网络,称为更新步长预测器 (Update Step Predictor) 。

让主网络——即学习实际任务 (如数字分类) 的网络——表示为 \( f_{\theta} \),其参数为 \( \theta \)。更新预测器 \( h_{\phi} \) (其参数为 \( \phi \)) 并不学习分类图像,而是学习为 \( f_{\theta} \) 输出理想的参数更新 \( \Delta\theta \)。

目标是训练 \( h_{\phi} \) 使其对旧任务关键参数输出较小更新 (保护它们) ,而对可自由调整的参数输出较大更新以适应新任务。

元学习过程示意图。过去和当前任务输入到一个“更新预测器”中,预测器学习生成谨慎的更新,以防止灾难性遗忘。

图 1: 在元学习过程中,更新步长预测器学习如何调整序列任务的参数更新,从而防止灾难性遗忘。


更新预测器的工作原理

在学习新任务 \( T_j \) 时,更新预测器会接收针对每个参数的几个关键输入:

  • 当前梯度 (\( g_j \)) —— 表示当前任务损失函数的梯度,指示参数该如何变化以改进任务 \( T_j \) 的表现。
  • 前一任务的重要性 (\( g_{j-1}^{*} \)) —— 来自前一个任务的平均平方梯度,反映每个参数的敏感度 (即重要性) 。
  • 模型参数 (\( \theta, \theta^{*} \)) —— 当前和先前的参数值,为跨任务的一致性提供上下文。

基于这些输入,\( h_{\phi} \) 为每个参数计算出一个最优更新 \( \Delta\theta \)。主网络的参数更新如下:

\[ \begin{aligned} g_{j-1}^{*} &= \nabla_{\theta^{*}} \mathcal{L}(f_{\theta^{*}}(T_{j-1})) && \text{(来自前一任务的重要性)} \\ g_j &= \nabla_{\theta}\mathcal{L}(f_{\theta}(T_j)) && \text{(当前任务梯度)} \\ \Delta\theta &= h_{\phi}(g_{j-1}^{*}, g_j, \mathcal{I}) && \text{(预测器输出)} \\ \theta' &= \theta - \eta \Delta\theta && \text{(应用更新)} \end{aligned} \]

其中,\( \eta \) 是一个缩放超参数,\( \mathcal{I} \) 代表其他上下文输入。


元学习步骤: 训练预测器本身

我们如何训练 \( h_{\phi} \),使其能够平衡学习与记忆保持?

这就是元训练 (meta-training) 的作用所在。预测器在一个特别设计的数据集 \( \mathcal{T}_{0} \) 上进行训练,该数据集由模拟持续学习挑战的子任务组成。

通用流程 (论文中的算法 1) 如下:

  1. 从元训练集选取两个连续任务 \( \mathcal{T}_{0,j-1} \) 和 \( \mathcal{T}_{0,j} \)。
  2. 让 \( h_{\phi} \) 为主模型 \( f_{\theta} \) 提出参数更新。
  3. 在两个任务的联合数据集 (\( \mathcal{T}_{0,j-1} \cup \mathcal{T}_{0,j} \)) 上评估更新后的 \( f_{\theta'} \)。
  4. 计算该联合任务的元损失——既衡量当前任务的成功,也衡量对过去任务的保持。
  5. 反向传播更新步骤,并用 Adam 算法调整 \( \phi \)。

简写形式为:

\[ \phi \leftarrow \operatorname{Adam}\big(\nabla_{\phi}\mathcal{L}(f_{\theta}(T_{i-1}\cup T_i))\big) \]

这样,预测器便学会在连续任务间最小化遗忘。元训练结束后,我们冻结 \( \phi \)。此时,训练好的预测器 \( h_{\phi^{*}} \) 可使用算法 2 指导真实的持续学习任务。


实验: 检验元学习器的效果

为了验证这一想法,研究人员使用了MNIST 手写数字数据集的变体,这是持续学习领域的常用基准。

实验包含两种主要配置:

  • 不相交 MNIST (Disjoint MNIST): 将类别划分为任务 1 ({0–4}) 和任务 2 ({5–9}) 。
  • 像素重排 MNIST (Shuffled MNIST): 每个任务均包含所有 10 个数字,但像素被随机置换,每次置换构成一个全新的任务。

这种设置确保各任务难度相似但输入结构不同,非常适合研究灾难性遗忘现象。

分类网络与更新预测器均为小型全连接神经网络。分类网络包含两个隐藏层 (每层 800 个 ReLU 单元) ,更新预测器包含两个隐藏层 (每层 10 个神经元) 。


结果: 它真的有效吗?

在不相交 MNIST 和像素重排 MNIST 上的结果。左图显示了两任务的准确率,在学习任务 2 后任务 1 的性能仍保持较高;右图显示了三个像素重排任务中类似的保持效果。

图 2: (左) 不相交 MNIST 的测试准确率; (右) 三个像素重排 MNIST 任务的测试准确率。元学习的预测器在各任务上均保持了高性能。

方法不相交 MNIST (%)像素重排 MNIST (%)
SGD (未调优)47.7 ± 0.189.1 ± 2.3
我们的方法 (MLP)82.3 ± 0.995.5 ± 0.6

改进十分显著: 不相交 MNIST 的平均准确率从约 48% (SGD) 提升至超过 82% (元学习优化器) ,而在像素重排 MNIST 上,各任务的性能保持高水平,与传统 SGD 相似且更稳定。

与其他持续学习方法对比,结果进一步凸显了这一方法的潜力:

方法不相交 MNIST (%)像素重排 MNIST (%)
SGD (已调优)71.3 ± 1.5~95.5
EWC52.7 ± 1.4~98.2
IMM (最佳)94.1 ± 0.398.3 ± 0.1
我们的方法 (MLP)82.3 ± 0.995.5 ± 0.6

虽然并非表现最优,但该方法成功展示了: 优化器可以被训练来从数据中学习持续学习策略 , 而非依赖人工设计。


预测器学到了什么

为了理解更新预测器的演化过程,作者将其在元训练期间的输出进行了可视化。

一个 3D 图,展示预测的参数更新随训练步的演变。最初,所有更新接近零;随着训练进行,出现了三峰分布,峰值分别为负、零和正。

图 3: 预测器输出值 (按 \( \eta \) 缩放) 在元训练期间的演变。

在训练初期,预测器输出接近零的微小更新——由于知识不足而采取保守策略。随着训练的推进,出现了明显的三类聚集: 大幅正向更新、大幅负向更新以及微小 (近零) 更新。

这种三峰模式揭示了预测器已学会何时强烈调整 (针对可灵活调整的参数) ,何时冻结权重 (针对对旧任务至关重要的参数) 。换句话说,它已内化了在保持与适应之间的平衡。


结论与未来方向

元持续学习框架改变了我们看待灾难性遗忘的方式。我们不再需要设计静态规则来保存知识,而是可以让神经网络自己去学习这些规则。更新步长预测器就像一个智能优化器,引导另一个网络的参数持续演化而不抹去已学知识。

MNIST 实验显示出明显的成功迹象: 遗忘减少、旧任务保持能力增强、以及良好的适应性。作者指出,他们的实验仅限于短任务序列和相关数据集,是一个初步的概念验证。未来研究可以扩展至更广泛的任务和更长的序列;为更新预测器引入循环或记忆增强架构,也可能进一步提升其可扩展性。

通过结合元学习持续学习 , 这项研究向终身学习系统迈出了重要一步——这样的神经网络不仅能够学习,还能记忆、适应,并随经验不断进化。