元学习——通常被称为学习如何学习——彻底改变了我们处理数据有限任务的方式。像 模型无关元学习 (MAML) 这样的经典方法在小样本学习场景中取得了显著成效,在这些场景下,模型仅需少量样本即可适应新任务。这些技术会学习一组初始参数,使模型能够通过极少的梯度更新在不同任务之间快速适应。

但当我们走出小样本学习的舒适区,进入多样本领域——每个任务包含成千上万甚至数百万个样本时,会发生什么?在标准深度学习中,良好的初始化仍然至关重要: 想想 ImageNet 预训练的广泛成功。元学习能否为各种大规模问题找到一个更好、更通用的起点?

不幸的是,传统的元学习方法遭遇了计算瓶颈。让 MAML 有效的机制——通过学习过程进行反向传播——当这一过程跨越数千个优化步骤时,会变得极其昂贵。单次元更新可能需要数天的计算,这使得大规模元学习几乎不可行。

2021 年,来自 KAIST 和谷歌的论文 《通过持续轨迹迁移进行大规模元学习》(Large-Scale Meta-Learning with Continual Trajectory Shifting) 直接应对了这一问题。作者提出了一种简单而强大的方法,通过实现频繁的元更新,无需等待漫长且昂贵的训练轨迹结束,从而打破了计算障碍。该方法不仅加快了收敛速度,还带来了更平滑、更具泛化能力的模型初始化。

在这篇文章中,我们将深入探讨大规模元学习的核心思想,理解 持续轨迹迁移 (Continual Trajectory Shifting, CTS) 的原理,并了解这项技术如何提升大规模任务中元学习的效率与稳定性。


元学习中的规模扩展问题

像 MAML 和 Reptile 这样的元学习算法通过两个嵌套循环运作: 内循环外循环 (元循环)

  1. 内循环: 对于每个任务,从共享初始化参数 \( \phi \) 出发,模型通过标准梯度下降进行 \( K \) 步优化,得到任务特定的参数 \( \theta_K \)。

  2. 外循环 (元循环) : 元学习器随后根据 \( \theta_K \) 在任务上的表现来更新 \( \phi \)。元梯度指向一个使模型在更少更新步骤内获得更好任务性能的初始化方向。

对于小样本学习来说,\( K \) 较小——仅需几个步骤——因此元更新可以频繁进行。

然而,在如 Aircraft 或 Stanford Dogs 分类这类的大规模任务中,收敛可能需要 \( K = 1{,}000 \) 甚至更多内循环步骤。元学习器必须等待所有任务完成后才能进行一次元更新。如果每个元批次 (meta-batch) 有 \( T = 10 \) 个任务,那么一次元更新需要超过 10,000 次梯度操作 。 这种延迟极大地减慢了训练速度。

大规模元学习的概念。(a) 传统元学习需等待漫长的内循环轨迹完成后才进行一次元更新,导致收敛缓慢。(b) 持续轨迹迁移允许频繁的元更新与内循环步骤交错进行,避免陷入局部最小值。

图 1. 在传统元学习中,学习器会等待完整的内循环学习轨迹 (左图) 。在 CTS 中,元更新更频繁地发生,使不同任务间的收敛更加平滑。


核心方法: 持续轨迹迁移

论文的核心洞见既直接又强大: 打破迫使元学习器等待的依赖。

设想我们在每个内循环步骤后都更新初始化 \( \phi \)。在第 \( k \) 步后,任务特定的参数为 \( \theta_k = U_k(\phi) \),它们源于当前初始化。此时我们计算一个元更新 \( \Delta_k \) 并调整初始化: \( \phi_{new} = \phi + \Delta_k \)。

问题是: 现有的 \( \theta_k \) 已基于旧的 \( \phi \),因此不再一致。若要保持正确性,我们需要从新的初始化 \( \phi_{new} \) 重新进行所有 \( k \) 步优化,这在计算上几乎不可能实现。

持续轨迹迁移 (CTS) 巧妙地解决了这一问题。我们不重新计算轨迹,而是将现有参数迁移与元更新相同的量:

\[ \theta_k^{new} \approx \theta_k + \Delta_k \]

这种近似让所有任务的学习器与不断变化的初始化保持同步,从而使元学习器能在每个内循环步骤后进行元更新。任务的优化轨迹会随着训练逐步被迁移,因此得名持续轨迹迁移

持续轨迹迁移示意图。通过根据元更新逐步迁移内循环轨迹,元学习器保持一致性并提高更新频率。

图 2. 持续轨迹迁移让频繁的元更新与内循环优化交错进行。

每个内循环步骤现在包括:

  1. 对所有任务执行一次内循环梯度更新;
  2. 计算元更新 \( \Delta_k \);
  3. 更新 \( \phi \leftarrow \phi + \Delta_k \);
  4. 将每个任务的参数迁移 \( \Delta_k \)。

这种简单的交错方式显著提高了更新频率,加速了收敛。


为什么这个简单的近似奏效?

作者通过一阶近似来解释 CTS 的合理性。 令 \( U_k(\phi) \) 表示从初始化 \( \phi \) 开始经过 \( k \) 步优化后得到的参数。他们证明:

\[ U_k(\phi + \Delta) \approx U_k(\phi) + \frac{\partial U_k(\phi)}{\partial \phi}\Delta + O(\beta^2) \]

其中 \( \beta \) 为元学习率,雅可比项 \( \frac{\partial U_k(\phi)}{\partial \phi} \) 表示优化后参数对初始化变化的敏感程度。在一阶元学习中,当内循环步长较小时,该雅可比项通常可近似为单位矩阵 \( I \),因此:

\[ U_k(\phi + \Delta) \approx U_k(\phi) + \Delta \]

这验证了迁移规则: 每个任务的参数可通过加上元迁移量 \( \Delta_k \) 与 \( \phi \) 同步更新。

CTS 的近似原理: 从一个迁移后的初始点出发,经 k 步优化得到的参数约等于原始参数加上该迁移量。

图 3. 一阶泰勒展开与雅可比近似共同支持 CTS 的迁移规则。

误差分析

这种近似带来一个误差,该误差与内循环学习率 \( \alpha \)、元学习率 \( \beta \) 以及轨迹长度 \( k \) 成正比:

\[ U_k(\phi + \Delta) = U_k(\phi) + \Delta + O(\beta \alpha h k + \beta^2) \]

累计误差大致呈 \( O(\beta \alpha h k^2 + \beta^2 k) \) 增长。

近似误差的实证分析。误差随着 α、β 和 k 的增加而增长,ReLU 网络的误差比 Softplus 网络更高。

图 4. 近似误差随学习率和轨迹长度变化趋势。使用更平滑的激活函数 (如 Softplus) 可显著减少误差。

虽然误差随更大的 \( k \) 增长,但作者观察到 CTS 在实践中仍然表现出色,即使在不理想的条件下也是如此。其背后的原因在于一个隐含的优势: 元层级的课程学习


意外的课程: 元层级正则化

课程学习通常由易到难地引入任务,使模型逐步建立理解。CTS 自然在元层面形成了一个课程

  1. 训练早期 (较小的 \( k \)) 更新基于较短的轨迹。此时期元学习器关注短期改进——优化空间更简单,局部最小值更少。这种“短视偏差”起到预热作用,有助于避开初始化不佳的区域。

  2. 训练后期 (较大的 \( k \)) 随着训练推进,\( k \) 增加。更长的轨迹揭示更复杂的元损失曲面,使元学习器能利用更丰富的任务反馈进一步优化初始化。

课程学习下的损失曲面简化示意: 短轨迹长度导致更平滑的最低点,并逐步过渡至复杂曲面。

图 5. 随着 \( k \) 增加,元损失曲面变得更复杂。短轨迹在训练早期帮助模型更有效地搜索初始化。

这种自然形成的课程使 CTS 即使在近似误差逐渐增大的情况下也能保持稳健。初期的准确迁移帮助模型快速进入优质区域,而后期更长的优化轨迹则进一步微调结果。


实验

论文通过合成测试与大规模图像分类任务验证了 CTS 的有效性。

1. 合成实验

作者首先构建了一个由八个异质任务组成的玩具数据分布,这些任务源自一个具有多个极值点的二维函数。他们比较了多个方法:

  • Reptile: 基线元学习方法。
  • CTS (我们的方法): 采用持续轨迹迁移。
  • CTS 精确版: 一个计算成本高的版本,重新精确计算轨迹。

合成任务设置。通过对一个基础函数进行旋转和平移以构造多样化任务。

图 6. 合成任务设置,展示了旋转和平移后的损失表面。

发现:

  • 长轨迹至关重要: * 使用较小的 \( K \) 进行元训练,即使经过大量梯度步骤也会导致次优性能,直接体现了短视偏差*。
  • *课程效应提升质量: * CTS 由简单损失开始并逐步扩展轨迹长度,从而避免陷入糟糕的局部最小值。
  • *近似效果优异: * CTS 以远低于“精确版”的成本实现了相当的表现。

不同轨迹长度下的元学习轨迹。CTS 能更有效地走向更好的极值点,而 Reptile 容易陷入局部最小。

图 7. CTS 通过逐步延展轨迹避免差的最小值,性能优于基线方法。


2. 大规模图像分类

CTS 在多个多样化的视觉数据集上进行了评估。 元训练数据集: TinyImageNet、CIFAR100、Stanford Dogs、Aircraft、CUB、Fashion-MNIST、SVHN。 元测试数据集: Stanford Cars、QuickDraw、VGG Flowers、VGG Pets、STL10。

元学习性能比较。CTS 显示出比 MAML 变体和 Reptile 更快的收敛速度。

图 8. 与一阶基线方法相比,CTS 显著加速了元收敛。

结果:

  • 更快的元收敛: CTS 比 Reptile、Leap 以及 MAML 变体更早达到更低的训练损失。
  • 更好的泛化能力: 在细粒度与异质任务上,CTS 实现了更高的测试准确率。
  • 更高效率: CTS 用更少的累积内循环梯度步数达成更优性能。

CTS 元测试准确率与内循环步数关系。黑线 (CTS) 在所有数据集上表现均优于其他方法。

图 9. CTS 在每步内循环中保持更高准确率,表明其样本效率更优。

消融实验验证了轨迹迁移的关键作用。移除定向迁移 (“无迁移”) 或随机化迁移方向将完全消除 CTS 的性能优势。

消融实验显示,只有定向的轨迹迁移才会带来显著性能提升。

图 10. 消融实验证明 CTS 的定向迁移是性能提升的核心因素。


3. 改进 ImageNet 预训练模型

最后,作者测试了 CTS 是否能改进标准的 ImageNet 微调过程——尤其是在数据有限的场景下。

他们从 ImageNet 初始化开始,在基于 WordNet 层级划分的子集上进行元训练,并在九个分类数据集上测试,每个数据集仅有 1,000 个训练样本。

方法CIFAR100CIFAR10SVHNDogsPetsFlowersFoodCUBDTD平均值
ImageNet 预训练41.9581.6060.0955.5683.4887.0136.9534.3259.3960.04
+ MTL42.7982.3359.0555.0083.2987.0436.8434.1958.8659.93
+ Reptile47.9884.5862.3956.9784.2587.2237.3535.4458.9861.68
+ CTS (我们的方法)48.3484.4262.8257.5384.6587.5437.8436.4059.5362.12

表 1. 在低数据量场景下,通过 CTS 增强的初始化优于标准 ImageNet 微调和其他基线方法。

相较于 ImageNet 微调的性能提升。CTS 在数据最稀缺时取得最大改进。

图 11. 在小数据集上,CTS 实现了最大的性能提升,起到了强有力的正则化作用以防止过拟合。

这些结果表明,通过 CTS 重构的元学习即便对先进的预训练模型也能进一步改进。CTS 学到的平滑初始化可作为隐式正则化器,使模型在细粒度且样本稀缺的数据集上具备更好的泛化性能。


结论

长期以来,将元学习扩展到小样本任务之外一直受限于长优化轨迹带来的巨大计算负担。 持续轨迹迁移 (CTS) 通过一种一阶近似实现了突破,在频繁元更新的同时迁移内循环学习轨迹。

由此,CTS 能:

  1. 显著提升效率 —— 频繁元更新加速收敛。
  2. 发现更好的初始化 —— 隐式的元层面课程帮助避开糟糕的局部最小值。
  3. 超越强基线方法 —— 比多任务学习、标准微调及以往元学习方法效果更优。

该方法扩大了元学习在真实大规模场景中的应用范围,获得了一个能够快速、可靠地适应多样化高维任务的通用初始化,使元学习更接近在主流深度学习中的广泛落地。