训练像 Transformer 这样的大规模神经网络是现代人工智能的基石,但这也是整个过程中最困难的部分之一。这一挑战的核心在于优化器——诸如 AdamSGD 这样的算法,它们通过逐步微调模型参数来最小化损失函数。要达到顶级性能,通常需要一个耗时且资源密集的试错循环,为每个新架构或任务进行无休止的超参数调整。

但如果我们能学习优化器本身呢? 如果不再手动设计新的训练策略,而是构建一个能够学习如何训练其他神经网络的神经网络,会怎么样?

这一设想正是学习如何学习 (Learning to Learn, L2L) ——也称为*元学习 (meta-learning) *——的核心理念。DeepMind 的最新研究论文 《Mnemosyne: 用 Transformer 学习训练 Transformer》 将这一思想推向了前所未有的高度。它介绍了 Mnemosyne , 一个基于专门设计的 Transformer 架构构建的可学习优化器。与以往依赖 LSTM 等循环模型的学习型优化器不同,Mnemosyne 采用了创新的时空注意力机制,使其能够跨越空间 (参数) 和时间 (训练历史) 学习优化策略。

结果令人印象深刻: Mnemosyne 可以在无需任何手动超参数调节的情况下训练整个神经网络——包括大型 Transformer。它的性能媲美甚至超越了像 Adam 这样的顶级优化器,同时保持了类似的内存和计算效率。


从手动调优到学习型优化器

要理解 Mnemosyne 的突破,首先需要形式化定义优化器的作用。训练神经网络是一个序列决策过程。在每次迭代 \( t \) 中,优化器更新模型的参数 \( \mathbf{x}_t \) 以最小化损失 \( f(\mathbf{x}_t) \)。

对于一个可学习优化器,我们可以将该更新规则表示为:

\[ \mathbf{x}_{t+1} = g_\theta(f, \mathbf{x}_0, ..., \mathbf{x}_t) \]

这里,\( g_\theta \) 是优化器本身——一个由参数 \( \theta \) 表征的神经网络。在*元训练 (meta-training) *期间,\( g_\theta \) 学习最小化一个“元损失”,通常是被优化对象在若干步训练中的损失总和。简而言之,我们训练优化器,让它学会如何训练其他网络。

传统的学习型优化器通常采用 LSTM 作为记忆模块,能够在优化过程中记住短期模式。然而,LSTM 容易出现灾难性遗忘 , 丢失有用的长期信息。 而 Transformer 在通过注意力机制建模长程依赖方面表现卓越,但其复杂度随序列长度呈二次增长,这对跨越数千步的优化历史来说是一个严重的挑战。Mnemosyne 巧妙地解决了这一问题。


Mnemosyne 内部: 一个优化 Transformer 的 Transformer

Mnemosyne 并非单一模块,而是一个模块化架构,结合了两个关键编码器: 一个用于空间结构,另一个用于时间记忆。它支持两种运行模式——逐坐标 (coordinate-wise)逐张量 (tensor-wise) , 取决于是分别优化单个参数还是整块张量。

Mnemosyne 优化的两种模式。(a) 逐坐标模式: 每个参数 \\( r_i \\) 通过紧凑关联记忆 (CAM) 和多层感知机 (MLP) 生成其更新。(b) 逐张量模式: 一个张量 \\( S_T \\) 经分层池化编码器 (HPE) 、CAM 和空间编码器 (SPE) 管道处理,产生紧凑表示 \\( e \\),用于指导最终更新。

图 2: Mnemosyne 的两种互补应用模式。

让我们来具体分析这些组件。


拓扑编码器: 学习模型结构

与将网络中数百万个参数视为扁平列表不同,Mnemosyne 尊重张量和层的自然层次结构——权重矩阵、偏置向量和模块共同组成了一棵“参数树”。 在逐张量模式下,Mnemosyne 一次性处理整个张量,将每个参数转换为一个简单的二维表示 (通常是梯度的大小与符号) 。这一长序列的 token 很容易超过百万级元素,即使是高效的线性 Transformer 也难以直接处理。

为此,Mnemosyne 引入了分层池化编码器 (Hierarchical Pooling Encoder, HPE) , 该编码器通过多层池化将张量逐步压缩为少量高级*元 token (meta-tokens) *。

分层池化编码器 (HPE) 与紧凑关联记忆 (CAM) 的示意图。HPE 通过带有 Performer 的多层池化将大型张量转换为紧凑的元 token,这些 token 输入到 CAM 记忆中,为优化器生成学习得来的读出。

图 1: Mnemosyne 的分层池化编码器 (HPE) 与紧凑关联记忆 (CAM) 构建模块。

HPE 的工作机制如下:

  1. 将张量的参数编码序列划分为易于处理的小段。
  2. 使用双向 Performer (一种具线性注意力的高效 Transformer 变体) 并行编码每个小段。
  3. 对编码后的片段进行池化,生成更短的元 token 序列。
  4. 多层级重复此过程,直到只剩下少量固定数量的 token。

其结果是一个紧凑的张量表示,在恒定内存成本下捕捉参数间的空间相关性。


时间编码器: 永不遗忘的记忆机制

Mnemosyne 的时间编码器旨在对优化过程的时间维度进行建模。它保留所有过去步骤的记忆,并学习利用这段历史来指导未来的更新。

该架构实现了一个紧凑关联记忆 (Compact Associative Memory, CAM) ——结合了类 Hopfield 能量记忆与 Transformer 式注意力。在每个训练步骤中,CAM 接收来自空间编码器的元 token,并更新一个固定大小的隐藏状态,其中包括两个关键矩阵 \( \mathbf{N}_t \) 和 \( \Psi_t \),定义为:

\[ \mathbf{N}_t = \sum_{\mu=1}^{t} \lambda_t(\mu)\,\phi(\mathbf{k}^\mu)(\mathbf{v}^\mu)^\top,\quad \Psi_t = \sum_{\mu=1}^{t} \lambda_t(\mu)\,\phi(\mathbf{k}^\mu) \]

这里,\( \phi(\cdot) \) 表示通过 Performer 框架近似注意力核的随机低秩特征映射,\( \lambda_t \) 是衰减旧记忆的折扣因子。 这些量可以高效地在线更新:

\[ \mathbf{N}_{t+1} = e^{-\tau}\mathbf{N}_t + \phi(\mathbf{k}^{t+1})(\mathbf{v}^{t+1})^\top,\quad \Psi_{t+1} = e^{-\tau}\Psi_t + \phi(\mathbf{k}^{t+1}) \]

为了生成下一个参数更新,CAM 计算:

\[ \Delta\xi = \frac{\mathbf{N}_t^\top\phi(\mathbf{q})}{\phi(\mathbf{q})^\top\Psi_t} \]

该式得到过去值向量的凸组合,权重由当前查询 \( \mathbf{q} \) 与存储键 \( \mathbf{k}^{\mu} \) 之间的相似度决定。 与标准注意力不同,CAM 无需显式存储所有历史键和值。它实现了随训练历史长度增长仍保持恒定时间与恒定内存复杂度——这正是可扩展学习型优化器的理想目标。


理论支柱: 紧凑形式中的指数级记忆

Mnemosyne 的 CAM 不仅在实验中表现优异——理论上也得到证明。作者证明这种紧凑关联记忆能够存储与其维度成倍比例的指数级模式数量 (定理 4.3) 。

换言之,尽管 CAM 仅保留了过往优化状态的浓缩表示,它仍然具有完整注意力和类 Hopfield 网络所展现的指数级记忆容量。这意味着 Mnemosyne 在保持小而固定的隐藏状态的同时,仍可“记住”丰富多样的优化轨迹,从而解释了其出色的泛化能力。


实验: Mnemosyne 的实际表现

Mnemosyne 的效能不仅止于理论,它在多个实际场景中都取得了验证——包括微调视觉模型、预训练语言模型,以及扩展到超大规模 Transformer。


热身实验: 与传统优化器对比

研究团队首先将 Mnemosyne 与常见优化器 (Adam、RMSProp、SGD) 以及基于 LSTM 的学习型优化器进行了对比。目标模型是在 MNIST 与 CIFAR 等标准数据集上训练的小型 Vision Transformer。

多种优化器的训练损失曲线。Mnemosyne 的学习速度始终更快,最终损失也低于 Adam、RMSProp、SGD 及基于 LSTM 的优化器,即便嵌入 VeLO 架构中也是如此。

图 3: 在训练 ViT 和 MLP 时,Mnemosyne 优于传统和 LSTM 型优化器。

尽管元训练量极少,Mnemosyne 在所有任务中都实现了更快的收敛和更低的损失。它能够稳健地训练基于注意力机制的架构,而 LSTM 优化器在此类任务上表现欠佳,这验证了 Mnemosyne 记忆机制的鲁棒性。


规模扩展: 逐坐标 Mnemosyne

接下来是大规模实验。逐坐标版本 (每个参数拥有独立 CAM 模块) 被用于微调 Vision Transformer (ViT-H),以及对超大型 *T5XXL 模型 (110 亿以上参数) *进行软提示微调。

在 ViT 微调中,Mnemosyne 的表现与最优 Adam 变体相当——而后者已在不同学习率下进行了精细调节。

在多个数据集上微调 ViT-H 模型的准确率曲线。无需手动调参,Mnemosyne 的性能与最佳学习率的 Adam 相当或更优。

图 4: 逐坐标 Mnemosyne 在 ViT 微调任务中与顶级 Adam 变体表现相当。

对 T5XXL 模型的软提示微调进一步展示了 Mnemosyne 的超大规模适应能力。尽管可训练提示模块仅含约 1.2 万参数,优化空间仍十分复杂。Mnemosyne 在所有 Adam 变体中始终取得更低损失。

T5XXL 提示微调的损失与迭代次数关系图。Mnemosyne 的收敛更加平滑,最终损失低于所有 Adam 基线。

图 6: Mnemosyne 在 T5XXL 软提示微调任务中的表现。


逐张量 Mnemosyne: 高效内存优化

为了端到端训练大型模型,逐张量版本通过在张量级别而非参数级别操作来节省内存。作者使用该模式在掩码语言建模 (Masked Language Modeling, MLM) 任务中预训练了 BERT-base (8600 万参数)

BERT 预训练中掩码语言模型 (MLM) 损失随迭代变化曲线。Mnemosyne 的表现与最佳 Adam 变体相当,而其他设置收敛到更高损失。

图 7: 在 BERT 预训练任务中,逐张量 Mnemosyne 的性能与最佳 Adam 相当。

尽管在元训练阶段 (于小型 MNIST 分类器上进行) 未曾接触语言任务,Mnemosyne 仍成功泛化,取得了与精细调优的 Adam 相当甚至更好的结果。


超级 Mnemosyne: 两全其美的融合

最后,为结合逐坐标优化的灵活性与逐张量优化的高效性,研究团队提出了超级 Mnemosyne (Super-Mnemosyne) ——一种混合架构,对大张量采用逐坐标优化,对小张量采用逐张量优化。

在多个数据集上微调 ViT-B 模型的准确率曲线。超级 Mnemosyne 的表现始终达到或超过最高性能的 Adam 变体。

图 9: 超级 Mnemosyne 结合逐坐标与逐张量模式,实现顶级微调效果。

在各项 ViT 微调任务中,超级 Mnemosyne 无需额外超参数调节,即可稳定超越最优的手动调参 Adam 基线。


Mnemosyne 的意义

Mnemosyne 在优化研究领域开启了全新的范式:

  • 性能: 始终达到或超过当前最先进优化器 (如 Adam、RMSProp) 的水平。
  • 免调参: 开箱即用,无需繁琐的学习率与动量搜索。
  • 可扩展性: 高效的分层空间编码与紧凑的时间记忆使其可胜任数十亿参数的模型。
  • 泛化性: 在简单任务上训练即可跨架构、跨领域、跨规模泛化。

展望未来

通过连接 Transformer 与学习型优化,Mnemosyne 为未来打开了新篇章: 优化器本身将成为智能体——能够理解并适应所训练模型的动态特性。此类系统将极大减少实验时间、提升鲁棒性,乃至学习出超越梯度下降的全新训练范式。

Mnemosyne 框架将“优化”从手工工程转变为一种可学习的能力——一个真正“让 Transformer 教 Transformer 学习”的体系。