大型语言模型 (LLM) 是令人印象深刻的通才。它们在像 Common Crawl 这样的大规模语料库上训练,对万事万物都略知一二。然而,在现实世界中,“略知一二”往往是不够的。无论是需要专门分析合同的律师事务所,还是需要编码助手的软件公司,我们经常需要采用一个通用模型并教授它特定的领域知识。

这个过程被称为持续预训练 (Continual Pre-Training, CPT) 。 这听起来很直接: 拿一个预训练好的模型,在新的特定领域数据上继续训练它。但 CPT 引入了一种众所周知的张力。随着模型学习新领域 (下游性能) ,它往往会遗忘最初学到的东西 (通用性能) 。这种现象被称为灾难性遗忘 (catastrophic forgetting) , 它为研究人员创造了一个微妙的平衡难题。

直到最近,大多数研究都集中在这个过程的结果上——即结束时学到了多少以及丢失了多少。但是在训练期间发生了什么?学习动态是如何一步步演变的?

在论文*“Learning Dynamics in Continual Pre-Training for Large Language Models”*中,Wang 等人 (2025) 提出了一个全面的数学框架来回答这些问题。他们推导出了一个CPT 缩放定律 (CPT Scaling Law) , 模拟了验证损失在整个训练过程中的确切轨迹。通过解耦学习率退火和数据分布偏移的影响,这项工作使我们能够预测 LLM 在其持续训练旅程中任何一步的性能。

核心问题: CPT 拉锯战

要理解这篇论文的贡献,我们首先需要将问题可视化。当你采用一个在通用数据集 (我们称之为 \(D_{pt}\),代表预训练) 上训练的模型,并将其切换到一个新数据集 (\(D_{cpt}\),代表持续预训练) 时,损失曲线会出现分歧。

不同学习率调度 (LRS) 下的 CPT 损失曲线: 常数 (a-c) 和预热-稳定-衰减 (WSD) 。

如上图 Figure 1 所示,在 CPT 期间同时发生了两件事:

  1. 通用性能下降: 蓝色虚线代表模型预测原始数据 (\(D_{pt}\)) 的能力。注意当数据切换时 (垂直虚线处) ,损失是如何激增的。这是分布偏移对模型的冲击。
  2. 下游适应: 橙色虚线显示模型在新领域 (\(D_{cpt}\)) 上变得更好。

研究人员将 CPT 过程概念化为一种过渡——一座从“初始预训练轨迹”跨越到“新领域特定轨迹”的桥梁。这项研究的目标就是用数学描述这座桥梁。

解构迁移曲线

作者首先定义了两个理论基准,他们称之为“隐藏 PT 曲线 (Hidden PT Curves) ”。想象一下平行宇宙:

  1. \(D_{pt}\) 上的隐藏 PT 曲线: 在这个宇宙中,我们从未切换数据。我们只是使用 CPT 的学习率调度继续在原始数据集上训练。
  2. \(D_{cpt}\) 上的隐藏 PT 曲线: 在这个宇宙中,我们从一开始就完全在数据集上从头开始训练。

实际的 CPT 过程是在这两种隐藏状态之间移动的迁移曲线 (Transfer Curve) 。 它从第一条曲线剥离并收敛向第二条曲线。

分布偏移的不变性

论文先导观察中最惊人的发现之一是,“分布偏移 (Distribution Shift) ”——即模型因切换数据源而付出的代价——遵循一种可预测的模式,而无论你何时开始迁移。

不同迁移起点下 Dpt 和 Dcpt 验证集中的迁移损失曲线。

Figure 2 所示,无论你是在第 10,000、20,000 还是 30,000 步切换数据集,偏差的形状 (红色箭头) 都保持一致。这表明分布偏移是两个数据集之间距离的基本属性,在很大程度上独立于模型的当前状态。

CPT 缩放定律

这篇论文的核心是推导支配这一过程的数学定律。作者通过将训练动态分为两个不同的部分来实现这一点: 学习率 (LR) 退火分布偏移

组件 1: 学习率退火

首先,我们需要一种方法来描述当数据没有改变时模型是如何学习的。作者建立在 Tissue 等人 (2024) 之前工作的基础上,该工作将损失描述为不仅是计算量的函数,也是学习率调度的函数。

基础损失定义为:

公式 1: 带有 LR 退火的缩放定律。

以下是各项的细分:

  • \(S_1\) (前向区域) : 迄今为止使用的学习率之和。这代表了模型在优化景观中行进的总“距离”。
  • \(S_2\) (退火区域) : 一个捕捉降低学习率所带来收益的项。当 LR 下降 (退火) 时,模型会进入更尖锐的极小值,从而降低损失。
  • \(L_0, A, C\): 特定于模型和数据的常数。

当我们将其应用于没有任何分布偏移的 CPT 时,基础损失考虑了来自预训练 (\(pt\)) 和持续 (\(cpt\)) 两个阶段的累积训练:

公式 2: 结合 PT 和 CPT 阶段的基础损失。

组件 2: 分布偏移

接下来,作者对切换数据分布的“惩罚”进行建模。根据他们的先导观察,这遵循一种幂律形式,并随着 CPT 阶段的训练量 (\(S_1^{cpt}\)) 而缩放。

公式 3: 分布偏移项。

这个项 \(\Delta L(t)\) 代表实际损失与理论基准之间的差距。它从 0 开始 (迁移前) ,随着模型适应新数据而增长 (或缩小,取决于观察的领域) 。

组合定律

通过将这两个组件缝合在一起,作者提出了最终的 CPT 缩放定律 。 这个单一的方程可以预测持续预训练过程中任何步骤 \(t\) 的损失:

公式 4: 完整的 CPT 缩放定律。

这个方程之所以优雅,是因为它隔离了性能的具体驱动因素:

  1. \(S_1\) (训练量) : 已经进行了多少训练?
  2. \(S_2\) (退火) : 学习率衰减了多少?
  3. 分布偏移 (第二行) : 新数据与旧数据有多大差异?

验证定律

这个方程真的有效吗?作者将此曲线拟合到了各种学习率调度中,包括流行的余弦 (Cosine) 调度和预热-稳定-衰减 (WSD) 调度。

使用公式 4 拟合不同 LRS (WSD 和 Cosine) 下的所有 PT 和 CPT 损失曲线。

Figure 3 展示了拟合结果。线条代表使用方程预测的损失,数据点代表实际的训练运行。对于通用领域 (\(D_{pt}\)) 和下游领域 (\(D_{cpt}\)) 的验证损失,匹配都几近完美。这证实了无论使用特定的哪种调度,该定律都成立。

“滑梯”类比与损失潜力

为了建立直觉,作者为他们的发现提供了解释的几何视角。你可以将损失景观可视化为一个曲面。CPT 过程就像是从一个曲面过渡到另一个曲面的滑梯。

CPT 过程的损失曲面和两个方向视图。

Figure 4 中,请看面板 (c)。这个“退火视图”引入了一个关键概念: 损失潜力 (Loss Potential)

损失潜力本质上衡量了预训练模型的“未完成”程度。一个尚未衰减其学习率的模型具有高损失潜力——它位于损失曲线的高处,准备随着学习率的退火而大幅下降。相反,一个完全收敛的模型 (LR 已经接近零) 具有低损失潜力。

为什么“欠火候”的模型迁移效果更好

这篇论文的一个主要实践见解是损失潜力与下游性能之间的关系。

研究人员发现, 具有更高损失潜力的 PT 模型能更好地适应新领域。 如果你在预训练期间完全退火你的模型 (榨干通用领域上的每一滴性能) ,它就会变得僵化。它已经深深地陷入了一个极小值,使得它很难跨越到新领域的极小值。

图 5: 损失潜力的影响。

Figure 5 清楚地说明了这一点。

  • 面板 (b) 和 (e): “真实损失”曲线显示,具有较高损失潜力的模型 (紫色线) 在新领域 (\(D_{cpt}\)) 上实现的最终损失要低于完全退火的模型 (浅蓝色线) 。
  • 面板 (c) 和 (f): 缩放定律的预测证实了这一趋势。

发现: 如果你正在发布一个旨在供他人微调或继续训练的开源模型, 不要完全退火它。 发布一个具有高损失潜力的检查点。

CPT 中的关键因素

利用他们的缩放定律,作者分析了其他几个对持续预训练成功至关重要的超参数。

1. 峰值学习率

在开始 CPT 时,你通常会再次“预热”学习率。它应该升到多高? 根据定律,CPT 阶段较高的峰值 LR 会加速对新领域 (\(D_{cpt}\)) 的适应,更快地降低其损失。然而,这也会导致通用领域损失 (\(D_{pt}\)) 出现更剧烈的峰值。这就是量化的经典“稳定性-可塑性困境”。

不同 CPT 步骤下的 Dcpt 预测损失与峰值 LR 的关系。

2. 重放比例 (Replay Ratio)

对抗灾难性遗忘的一个常用技术是重放 : 将一定比例的原始数据 (\(D_{pt}\)) 混合到新的训练批次中。

作者扩展了他们的缩放定律以考虑这一点。他们发现重放比例 (\(r\)) 以指数方式影响分布偏移。

重放比例的拟合方程。

这个看起来很复杂的修正允许缩放定律预测任何数据混合的性能。

不同重放比例下的损失曲线。

Figure 19 所示,即使加入少量的原始数据 (重放) 也会极大地改变曲线,抑制通用损失的激增。缩放定律准确地预测了这些轨迹,允许工程师模拟不同的比例,而无需运行昂贵的实验。

3. 临界点和转折长度

当你开始在新数据上训练时,旧数据的损失会上升。但它还会降回来吗?

作者确定了一个临界点 (Critical Point)

  • 临界前 (Pre-Critical) : 如果你停止训练得足够早,或者如果数据集足够相似,通用损失最终可能会恢复 (曲线先升后降) 。
  • 临界后 (Post-Critical) : 如果数据集差异太大或你训练得太久,你就会越过一个不归点。通用损失将稳定在一个比开始时更高的值。

Dpt 验证损失中的临界点和转折长度。

优化: 平衡权衡

缩放定律最强大的应用之一是超参数优化。你可以用数学定义你想要什么,而不是猜测。

如果我们定义目标为最小化通用损失和下游损失的加权组合:

公式 5: 最小化加权损失。

我们可以求解最优设置。

基于不同系数优化 CPT 超参数。

Figure 8 可视化了这些最优前沿:

  • 图表 (a): 显示了最优损失潜力。如果你主要关心新领域 (低 \(\lambda_1\)) ,你需要一个具有接近 100% 损失潜力的模型 (高可塑性) 。
  • 图表 (c): 显示了最优重放比例。有趣的是,最优重放比例并不总是线性的。

解决“黑盒”问题

最后,作者解决了从业者面临的一个主要障碍: 开源模型。

当你下载像 LLaMA-3 这样的模型时,你无法访问它的确切预训练数据或特定的损失轨迹。实际上你是在一个“黑盒”上开始 CPT。缩放定律还能起作用吗?

作者建议使用代理数据集 (Proxy Dataset) 。 例如,使用 RedPajama 的一部分 (Common Crawl 数据的副本) 作为未知 \(D_{pt}\) 的替身。

使用代理数据集进行拟合和预测。

Figure 18 (b) 证明,使用代理数据集允许缩放定律拟合损失曲线 (蓝线) ,其效果几乎与拥有真实数据一样好。这使得该方法对于使用各种开源基础模型的工程师来说非常实用。

此外,作者还展示了你甚至可以通过将分布外 (OOD) 数据集 (既不是原始数据集也不是目标数据集) 建模为两个已知损失的线性组合来预测其损失。

预测 OOD 损失。 公式 6: OOD 线性组合。

结论

Wang 等人的工作将持续预训练领域从炼金术推向了化学。通过建立严格的 CPT 缩放定律 , 他们提供了一种量化迁移学习动态的方法。

关键要点:

  1. CPT 是一种过渡: 它是数学上可预测的,作为两条隐藏学习曲线之间的转移。
  2. 不要让你的模型“火候太足”: 如果你计划微调或继续训练,具有高“损失潜力” (较少退火) 的模型更优越。
  3. 先预测再训练: 通过运行简短的试点运行来拟合公式 4 中的常数,研究人员可以预测大规模训练运行的性能,在投入大量计算资源之前优化峰值学习率、重放比例和训练步骤。

随着 LLM 继续在法律、医学、编码和科学领域专业化,理解它们如何学习——以及如何遗忘——的物理学比以往任何时候都更加关键。这个缩放定律为这种理解提供了蓝图。