训练大型语言模型 (LLM) 是一项昂贵且风险极高的工作。想象一下,你分配了数千个 GPU 和数百万美元来训练像 LLaMA 或 GPT 这样的模型,结果训练运行到一半时发散了。损失值突然飙升,这种现象被称为损失尖峰 (Loss Spike) , 数周的进度可能因此毁于一旦。

损失尖峰是深度学习中的一个根本性问题,对于 Transformer 模型尤为如此。虽然工程师们开发了各种“创可贴”式的补救措施——比如从之前的检查点重新开始训练或跳过某些数据批次——但其根本原因仍然部分不为人知。

在一篇题为 “Initialization of Large Language Models via Reparameterization to Mitigate Loss Spikes” 的精彩论文中,来自 NTT 人类信息学实验室的研究人员提出了一个新颖的解释和一个令人惊讶的优雅解决方案。他们认为问题在于参数范数的非均匀性 。 简单来说: 网络中的某些权重被初始化得比其他权重小得多。当这些微小的权重被更新时,它们会发生巨大的相对变化,从而破坏网络的稳定性。

他们的解决方案是一种称为 WeSaR (权重缩放即重参数化,Weight Scaling as Reparameterization) 的技术。在这篇文章中,我们将拆解损失尖峰背后的数学原理,理解为什么标准的初始化方法会失败,并探讨 WeSaR 如何稳定训练,从而实现更快的收敛和更好的性能。

隐藏的罪魁祸首: 更新比率 (Update Ratio)

要理解为什么 LLM 训练会变得不稳定,我们首先需要看看训练过程中参数是如何变化的。在神经网络中,我们通过反向传播计算出的微小变化 \(\Delta W\) 来更新参数矩阵 \(W\)。

研究人员提出,更新的绝对大小 \(\|\Delta W\|\) 并不是唯一重要的事情。真正重要的是更新比率 (Update Ratio) :

\[ \text{Update Ratio} = \frac{\|\Delta W\|}{\|W\|} \]

这个比率代表了更新相对于参数本身的大小。

微小权重的问题

深度神经网络需要特定的初始化策略 (如 He Initialization 或 Xavier Initialization) ,以防止梯度在通过深层网络时消失 (变为零) 或爆炸 (变为无穷大) 。

在 Transformer 模型中,为了满足这些要求,特定的层——特别是前馈网络 (Feed-Forward Networks) 中的投影层 (\(W_d\)) ——必须用非常小的值进行初始化。

问题就在这里: 如果一个参数 \(\|W\|\) 被初始化得非常小,即使是一个适度的更新 \(\|\Delta W\|\) 也会导致巨大的更新比率。

拥有 130 亿参数的 Transformer 模型在训练初期的损失情况。上图显示了损失尖峰。下图显示 W_d 的更新比率远高于 W_u。

如上方的 图 1 所示,请看底部的图表。\(W_d\) (向下投影层,红色曲线) 的更新比率在训练开始时明显高于其他参数。现在再看顶部的图表: 这种不稳定性与损失的巨大尖峰直接相关。

基线方法 (蓝线) 在这些巨大的相对变化中挣扎。这本质上就像是网络正在用“大锤”给这些微小的参数做“心脏手术”。

背景: 初始化的微妙平衡

在深入探讨解决方案之前,我们必须了解为什么这些权重最初要设置得很小。

避免梯度爆炸

在深度网络中,我们希望梯度的尺度从最后一层回传到第一层时保持恒定。如果梯度在反向传播过程中增大,训练就会发散。如果它们缩小,模型就会停止学习。

从数学上讲,对于将 \(x\) 映射到 \(y\) 的层,我们需要梯度范数的期望值相匹配:

显示网络中梯度范数保持恒定要求的公式。

Transformer 的挑战

Transformer 使用残差连接 (跳跃连接) 。一层的输出是 \(y = f(x) + x\)。当你对求和进行反向传播时,梯度会流经这两条路径。

显示通过带有层归一化的残差连接的梯度流动的公式。

由于这种残差路径,当你深入网络 (可能有几十层) 时,梯度的方差往往会累积。为了抵消这一点,标准的初始化技术 (如 GPT-2 中使用的技术) 会将残差层的权重缩小 \(1/\sqrt{2N}\) 倍,其中 \(N\) 是层数。

对于一个 40 层的 130 亿参数模型,这个因子大约是 \(0.11\)。这迫使像 \(W_d\) 这样的矩阵被初始化为比其他矩阵小得多的范数。这满足了梯度稳定性的要求,但违反了更新稳定性的要求,导致了我们在图 1 中看到的尖峰。

解决方案: WeSaR

研究人员提出了 WeSaR 来解决这一冲突。目标是同时满足两个相互冲突的要求:

  1. 梯度要求: 有效权重必须很小,以防止梯度爆炸。
  2. 更新要求: 实际权重参数必须足够大,以便更新 (\(\Delta W\)) 不会引起巨大的相对变化。

解耦尺度与方向

WeSaR 通过重参数化 (reparameterizing) 权重来实现这一点。模型不再直接学习矩阵 \(W\),而是学习一个标量门控 \(\alpha\) 和一个矩阵 \(W\)。

前向传播如下所示:

定义重参数化的公式: W_bar 等于 (sigma/sigma) 乘以 W,即等于 alpha 乘以 W。

这里:

  • \(W_{\cdot}\) 是实际参数 (Actual Parameter)
  • \(\bar{W}_{\cdot}\) 是计算中使用的虚拟参数 (Virtual Parameter)
  • \(\alpha\) 是一个可训练的门控参数 (Gate Parameter)

它是如何工作的

1. 实际参数 (\(W\)) 的统一初始化: 在标准方法中,不同的层根据其大小或深度获得不同的标准差。在 WeSaR 中, 所有实际参数矩阵 (\(W\)) 都使用相同的标准差 \(\sigma\) 进行初始化。

显示 W 使用通用高斯分布进行初始化的公式。

至关重要的是,这个 \(\sigma\) 可以选择得很小但很统一,确保没有任何一层相对于其他层是“脆弱”的。

2. 通过门控 (\(\alpha\)) 进行调整: 稳定反向传播所需的缩放 (前面提到的 \(1/\sqrt{2N}\) 因子) 被应用于门控参数 \(\alpha\) , 而不是矩阵 \(W\)。

这意味着有效权重 \(\bar{W}\) 很小 (满足梯度约束) ,但实际权重 \(W\) 是标准的 (满足更新稳定性) 。

为什么 Adam 优化器不会破坏这一点

你可能会问: 如果我们缩放权重,梯度难道不会也随之缩放,从而抵消掉这个好处吗?这就是 Adam 优化器 发挥作用的地方。

Adam 基于梯度动量 (\(M_t\)) 与梯度方差平方根 (\(V_t\)) 的比率来更新参数。

显示 Adam 优化器更新规则的公式。

当我们重参数化 \(\bar{W} = \alpha W\) 时,关于 \(W\) 的梯度被 \(\alpha\) 缩放。然而,由于 Adam 将动量除以方差的根,这个标量缩放因子在更新步骤中很大程度上被抵消了。

显示损失与参数之间梯度关系的公式。

这一理论见解证实了 WeSaR 改变了优化景观的几何形状,从而在不破坏学习过程的情况下稳定了相对更新。

实验结果

团队在 1.3 亿到 130 亿参数的 Transformer 模型上测试了 WeSaR。他们将其与流行的 Small Initialization (小初始化) 方法 (用于 GPT-J 和其他开源模型) 以及其他重参数化技术 (如权重归一化) 进行了比较。

消除损失尖峰

首要目标是消除尖峰。观察 13B 参数模型的训练损失,差异是显而易见的。

13B 模型训练期间的损失。红线 (Proposed) 平滑且低于蓝线 (Baseline) 。

图 2 所示,基线 (Small Init) 在训练早期遭受了剧烈的尖峰。WeSaR 模型 (红色) 不仅避免了这些尖峰,而且在整个运行过程中实现了持续更低的损失值。

稳定更新比率

WeSaR 真的解决了“更新比率”问题吗?

图 9: 13B 模型第 40 层的更新比率。与其基线相比,提出的方法显示出更低且更稳定的更新比率。

图 9 (上图) 显示了 13B 模型第 40 层的更新比率。虚线 (基线) 显示出巨大的波动。实线 (提出的 WeSaR) 则非常稳定。这种稳定性证实了统一初始化所有权重可以防止特定层对更新变得“过敏”。

探究参数范数

为了证明该机制按预期工作,研究人员绘制了训练期间权重的范数。

图 3: 最后一层参数 W_d 和 W_u 的范数。提出的方法使它们重叠且稳定。

图 3 中,请注意“Baseline” (蓝色) 线。它从接近零开始并线性增长。这种增长代表模型拼命试图增加那些初始化极小的权重的幅度。这种快速增长与不稳定性有关。

相比之下, Proposed 方法 (红色和粉色线) 保持了实际参数 \(W_d\) 和 \(W_u\) 的稳定、较高的范数。必要的缩放由 \(\alpha\) 处理,使矩阵 \(W\) 保持稳定。

性能对比

最后,这种稳定性是否转化为更好的语言建模能力?是的。

表 5: 主要结果显示 WeSaR 在 WikiText 和 LAMBADA 困惑度上优于 Small Init。

表 5 显示,在所有模型尺寸 (130M、1.3B 和 13B) 下,WeSaR 在 WikiText 和 LAMBADA 基准测试中都实现了更低 (更好) 的困惑度。

值得注意的是,WeSaR 的表现也优于其他复杂的重参数化技术。

表 6: 重参数化方法的比较。WeSaR 优于或堪比权重归一化,但速度更快。

表 6 所示,WeSaR 比 Weight Normalization (权重归一化)\(\sigma\)Reparam 更有效。此外,由于 WeSaR 仅添加单个标量乘法 (门控) ,而不是复杂的归一化操作 (如除以整个矩阵的标准差) ,因此计算成本更低。例如,权重归一化增加了 12.6% 的训练时间,而 WeSaR 增加的开销几乎可以忽略不计。

超参数鲁棒性

WeSaR 的一个隐藏好处是,它将初始化的标准差 (\(\sigma\)) 与隐藏层维度大小 (\(d\)) 解耦了。

在标准 Transformer 中,你通常使用 \(\sigma = \sqrt{1/d}\) 初始化权重。随着模型变宽 (\(d\) 变大) ,\(\sigma\) 必须变小。WeSaR 允许你设置一个固定的、较小的 \(\sigma\) (例如 \(\sigma^2 = 4e-5\)) ,而不管模型大小如何。

表 9: 鲁棒性与标准差和学习率的关系。

表 9 证明了 WeSaR 对不同的 \(\sigma\) 值具有鲁棒性,并且至关重要的是,它允许更高的学习率 (1e-3 对比标准的 5e-4) 。能够在不崩溃的情况下以更高的学习率进行训练,对于 LLM 预训练来说是“圣杯”,因为它可以显著加快收敛速度。

结论

长期以来,“损失尖峰”一直是 LLM 从业者机器中的幽灵——一个随机的、令人沮丧的事件,通常被归咎于“糟糕的数据”或“运气不好”。Nishida 等人的工作揭示了一个结构性原因: 梯度稳定性 (要求微小的权重) 和更新稳定性 (要求正常大小的权重) 之间的算术冲突。

WeSaR 提供了一个令人信服的解决方案:

  1. 重参数化: 使用门控 \(\alpha\) 来处理尺度。
  2. 统一初始化: 让实际权重 \(W\) 处于一个安全、统一的范围内。

通过这样做,WeSaR 稳定了更新比率,防止了损失尖峰,并实现了更快、更激进的训练计划。随着语言模型继续扩展到万亿参数,像 WeSaR 这样高效且稳定的初始化技术很可能会成为深度学习技术栈中的标准组件。