如果你曾经尝试过训练深度神经网络——特别是大型语言模型 (LLM) ——你很可能经历过不稳定性带来的噩梦。你盯着损失曲线,看着它平滑下降,然后突然间,毫无征兆地,它飙升了。损失发散,梯度爆炸,几天的计算时间瞬间化为乌有。

当我们为了获得“缩放定律”的收益而推动模型向更深层发展时,这种不稳定性成为了主要的瓶颈。模型越深,就越难保持从输入到输出的信号清晰。

在这篇文章中,我们将深入探讨一篇引人入胜的论文,题为 “Stable Language Model Pre-training by Reducing Embedding Variability” (通过降低嵌入变异性实现稳定的语言模型预训练) 。 作者提出了两个主要贡献,改变了我们对训练稳定性的思考方式:

  1. TEV (词元嵌入变异性) : 一种新的、计算成本低廉的指标,用于监控稳定性,取代了昂贵的梯度方差计算。
  2. MLRA (多头低秩注意力) : 一种架构上的改变,从数学上保证了前向传播中的低方差,从而从根本上防止梯度爆炸。

让我们来拆解其中的数学原理、方法和结果。


深度 Transformer 中的稳定性问题

要理解解决方案,我们首先需要了解故障的机制。现代 LLM (如 GPT-2、Llama-2 等) 通常使用 前置层归一化 (Pre-Layer Normalization, Pre-LN) 架构。虽然与后置层归一化 (Post-LN) 相比,Pre-LN 有助于优化,但它引入了一个特定的副作用: 浅层中的梯度爆炸

在残差网络 (如 Transformer) 中,信号通过一系列层传递。如果我们将词元嵌入层表示为 \(\mathbf{E}\),它将我们的词汇表映射到一个向量空间。

嵌入矩阵 E 将词汇表映射为向量表示。

这里,\(|V|\) 是词汇表大小,\(\mathbf{e}_i\) 是特定词元的向量。

单个词元的向量表示。

在反向传播 (Backward Pass) 期间,梯度从最终损失流回这个嵌入层。由于残差连接的存在,梯度会累积。第一层 (\(\nabla X_0\)) 的梯度是所有后续层梯度的乘积。

显示梯度如何通过链式法则在各层中累积的公式。

在深度模型 (例如 48、96 或 100+ 层) 中,这个乘积项呈指数级增长。这种现象被称为 梯度爆炸 。 当梯度爆炸时,模型权重的更新变得巨大且不稳定,导致那些破坏训练过程的损失飙升。

监控的代价

传统上,研究人员通过监控 梯度方差 来检测这种不稳定性。如果梯度的方差很高,说明训练是不稳定的。然而,在每一步计算梯度方差的成本高得令人望而却步——它需要对梯度矩阵进行 \(O(nd)\) 次运算,这会显著拖慢训练速度。

我们需要一个更好、更便宜的“速度计”。


第一部分: 词元嵌入变异性 (TEV)

研究人员提出,与其关注 (计算昂贵的) 梯度,不如关注嵌入层本身的权重 (这是我们现成的) 。

具体来说,他们观察到 嘈杂的梯度词元嵌入的变异性 之间存在强烈的联系。如果训练不稳定,更新嵌入层的梯度将是不稳定的。这导致不同词元的嵌入向量在幅度和分布上剧烈波动。

定义 TEV

作者引入了 词元嵌入变异性 (Token Embedding Variability, TEV) 。 首先,让我们看看单个词元嵌入向量 \(\mathbf{e}_i\) 的标准差。

单个词元的词元嵌入变异性 (TEV) 公式。

这里,\(\bar{e}_i\) 是该词元向量中元素的均值。

为了获得系统层面的视图,我们计算这个指标在整个词汇表 \(|V|\) 上的均值 (\(\mu_{\text{TEV}}\)) 和标准差 (\(\sigma_{\text{TEV}}\)) 。

整个词汇表上 TEV 的均值和标准差公式。

为什么这行得通?

在稳定的训练运行中,词元嵌入不应表现得像离群值。论文通过查看预训练的开源模型 (OPT, Pythia, Llama-2, GPT-2) 中嵌入的均值,提供了一个经验性的“健全性检查”。

小提琴图显示更大、更稳定的模型具有更低的 TEV。

如上方的 图 1 所示,有一个明显的趋势: 更大、性能更好的模型表现出更低的 TEV。 更大参数量 (如 Llama-2 70B 或 GPT-2 XL) 中“更瘦”且更低的分布表明,随着模型变得更加稳定和强大,其嵌入权重的变异性会降低。

这证实了 \(\mu_{\text{TEV}}\) 是稳定性的有效代理指标。如果这个数字飙升,你的梯度很可能正在爆炸。


第二部分: 解决方案——多头低秩注意力 (MLRA)

用 TEV 诊断问题很有用,但解决问题更好。作者对多头注意力机制提出了一种架构上的改变,称为 多头低秩注意力 (Multi-head Low-Rank Attention, MLRA)

概念

在标准 Transformer 中,注意力机制使用权重矩阵 \(W_Q, W_K, W_V\) 投影输入 \(X\)。这些通常是满秩方阵。

MLRA 提议将这些投影矩阵 分解 为两个更小的低秩矩阵。与其学习一个大矩阵 \(W\),我们学习两个矩阵 \(W^U\) (上投影) 和 \(W^D\) (下投影) ,使得:

\[W \approx W^U W^D\]

其中 \(W^U \in \mathbb{R}^{d_{\text{model}} \times r}\) 且 \(W^D \in \mathbb{R}^{r \times d_{\text{model}}}\),\(r\) 为秩 (\(r < d_{\text{model}}\)) 。

数学原理: 分解如何降低方差

这是核心创新点。为什么把一个矩阵拆成两个有助于稳定性?这归结为方差如何在初始化权重中传播。

假设我们使用 Kaiming 均匀初始化 , 这是这些模型的标准配置。在这种初始化下,权重矩阵 \(W\) 的方差是 \(\frac{1}{3d_{\text{model}}}\)。

如果我们通过标准线性层投影输入 \(X\) (假设已归一化) ,输出的方差为:

标准线性层的方差计算。

上方公式显示标准注意力的方差是 1/3

现在,看一看分解版 (MLRA) 的下方公式。当信号通过 \(W^D\) 然后通过 \(W^U\) 时,方差相乘。因为我们乘以了两个初始方差都很小的矩阵,结果方差被显著抑制了。

\[ \frac{1}{3} \times \frac{1}{3} = \frac{1}{9} \]

结果是 1/9。

仅仅通过分解矩阵,输出的初始方差就降低了 3 倍 (从 1/3 降至 1/9) 。

显示随秩 r 变化的方差缩减的通用公式。

这种方差缩减起到了信号“阻尼器”的作用。它防止了信号在深层传播时方差呈指数级增长,直接抵消了前面描述的梯度爆炸问题。

避免低秩瓶颈

你可能会问: 低秩分解不会损害模型的表达能力吗?

这是一个合理的担忧。以前低秩训练的尝试往往失败,因为模型无法学习复杂的模式 (即“低秩瓶颈”) 。

然而,MLRA 在多头结构内部 应用这种分解。每个头学习不同的子空间。作者说明,虽然单个矩阵可能是低秩的,但它们的级联 (或在残差流中的求和) 保留了模型的满秩特性。

示例显示低秩向量如何构成满秩矩阵。

如上所示,即使向量由简单的基向量组成,组合它们也可以覆盖整个空间。通过在每个头上应用分解,MLRA 在保持必要的表达能力的同时,享受到了低方差带来的稳定性优势。


实验与结果

作者通过从头开始训练 GPT-2 模型来测试这一假设,使用了三种配置:

  1. 基线 GPT-2 (Baseline GPT-2)
  2. \(\sigma\)Reparam: 一种用于稳定性的最先进方法 (Zhai et al., 2023) 。
  3. MLRA (本文提出) : 分解注意力方法。

他们在不同的深度下进行了测试: 48 层、96 层和 192 层。

1. 稳定性分析 (梯度方差)

第一个问题是: MLRA 真的能稳定梯度吗?

梯度方差比较显示 MLRA 保持较低的方差。

图 2 讲述了一个令人信服的故事。

  • 左图 (48 层) : 所有模型都相对稳定,但 MLRA (绿线) 的梯度方差最低。
  • 中图 (96 层) : 基线 GPT-2 (蓝线) 开始遭遇更高的方差峰值。MLRA 依然是最稳定的。
  • 右图 (192 层) : 这是压力测试。基线 GPT-2 实际上 训练失败 (发散) 了 5 次并被排除。然而,MLRA 训练顺畅,梯度方差极低。

2. 验证 TEV 作为代理指标

廉价的 TEV 指标实际上与昂贵的梯度方差指标相关吗?

训练过程中的 TEV 和梯度方差趋势对比。

图 3 中,我们可以看到前 10 亿个词元的 TEV (上图) 和梯度方差 (下图) 。趋势几乎完全相同。当梯度方差飙升时,TEV 也会飙升。这验证了 TEV 是监控训练稳定性的可靠、轻量级代理指标。

3. 下游性能 (困惑度)

稳定性固然好,但模型真的学得更好吗?作者测量了标准基准 (Lambada, Wikitext, PTB) 上的零样本困惑度 (Zero-Shot Perplexity) 。困惑度越低越好。

表 1: 困惑度和 TEV 结果。MLRA 优于基线。

表 1 突显了 MLRA 的优势:

  • 更低的 TEV: MLRA 始终具有最低的 \(\mu_{\text{TEV}}\),证实它是最稳定的。
  • 更好的性能: 在几乎所有数据集和深度上,MLRA 都实现了最低的困惑度。
  • 随深度缩放: 随着模型变深,差距进一步扩大。在 192 层时,MLRA 在 Wikitext-103 上的困惑度为 44.17 , 而基线 (96 层时) 为 47.75 , 并且优于竞争对手 \(\sigma\)Reparam。

结论

训练深度 Transformer 就像是一场平衡游戏。我们想要深度以获得性能,但深度会以梯度爆炸的形式招致混乱。

这篇论文为我们提供了两个强大的工具来管理这种混乱:

  1. TEV: 一个源自嵌入权重的简单指标,充当训练稳定性的“检查引擎”指示灯,计算成本几乎为零。
  2. MLRA: 一个有原则的架构改变,利用矩阵分解从数学上抑制方差初始化。

通过将前向传播的方差从 \(1/3\) 降低到 \(1/9\),MLRA 防止了梯度爆炸,使我们能够训练明显更深的模型 (在本研究中高达 192 层) ,而无需担心发散。

对于学生和从业者来说,结论很明确: 稳定性不仅仅关乎超参数调整 (学习率、批大小) 。它从根本上关乎方差如何在你的架构中传播。有时候,一个简单的分解就是保持信号清晰所需的全部。