如果你曾经训练过机器学习模型,那么标准的流程几乎已经形成了肌肉记忆: 设置数据加载器,定义随机梯度下降 (SGD) 优化器,并编写一个循环在数据集上迭代多个 epoch。直觉告诉我们: 模型看数据的次数越多,学习效果就应该越好。
但如果第二次看数据实际上会破坏模型呢?
在随机凸优化 (Stochastic Convex Optimization, SCO) 的基础理论中,我们知道第一遍遍历数据 (One-Pass) 有着某种“魔力”。众所周知,单 epoch 的 SGD 可以达到最优的误差率。然而,一篇引人入胜的新研究论文 “Rapid Overfitting of Multi-Pass SGD in Stochastic Convex Optimization” 揭示了一个令人震惊的现象: 在一般的凸优化设置中,仅仅对数据多进行一次遍历就可能导致灾难性的过拟合。
本文将探讨为什么会发生这种情况、其背后的数学原理,以及它揭示了学习算法的何种本质。

如上图 Figure 1 所示,使用标准步长 (蓝线) ,总体损失 (Population Loss) 在第一个 epoch 期间完美下降。但第二个 epoch 一开始,损失就急剧飙升,毁掉了之前的进展。让我们深入探讨原因。
设定: 随机凸优化
为了理解这个问题,我们需要首先定义目标。在随机凸优化 (SCO) 中,我们试图最小化总体损失 (或称风险) ,记为 \(F(w)\)。这是在真实数据分布 \(\mathcal{Z}\) 上的期望损失。我们希望模型能在从未见过的数据上表现良好。

然而,我们无法获得无限的分布 \(\mathcal{Z}\)。我们只有一个大小为 \(n\) 的训练集 \(S\)。通常,我们通过最小化经验损失来工作,即训练样本上的平均损失:

第一遍遍历的魔力
经典理论告诉我们, 单轮 SGD (One-Pass SGD) 是极小极大最优的 (minimax optimal) 。如果你运行 SGD 恰好一个 epoch (处理 \(n\) 个样本各一次) ,并配合 \(\eta \approx 1/\sqrt{n}\) 的步长衰减,你将获得 \(O(1/\sqrt{n})\) 的总体超额风险 (excess risk) 。这是理论上可能的最佳速率。
标准的单轮 SGD 算法如下所示:

因为第一遍中的每个样本都是“新鲜的” (独立同分布) ,梯度更新提供了真实总体梯度的无偏估计。这使我们能够利用被称为“在线转批处理 (online-to-batch) ”转换的强大统计保证。
问题: 多轮 SGD
在实践中,我们很少在一个 epoch 后停止。我们运行 多轮 SGD (Multi-Pass SGD) , 打乱数据并一遍又一遍地遍历它。

通常的假设是,更多的训练会进一步最小化经验风险 \(F_S(w)\),这有望转化为更低的总体风险 \(F(w)\)。虽然这对于平滑函数通常成立,但这篇论文的作者研究了一般凸情况 (即利普希茨连续但不一定平滑的函数) 。
他们提出了一个简单但尚未被回答的问题: 如果我们仅仅多训练几轮,总体风险会如何恶化?
结果: 相变
研究人员证明,在第一和第二个 epoch 之间存在一个剧烈的相变 (phase transition) 。
如果你将 SGD 的步长调整为对第一遍遍历最优 (\(\eta \approx 1/\sqrt{n}\)) ,那么同样的步长在第二遍遍历中就会变成“毒药”。
下界
论文为多轮 SGD 的总体误差建立了一个紧确的下界。如果你运行算法总共 \(T\) 步 (其中 \(T > n\)) ,超额总体损失的下界为:

让我们分解这个方程的含义,特别是 \(\eta \sqrt{T}\) 这一项。
- Epoch 1 (\(T=n\)): 如果我们设置 \(\eta = 1/\sqrt{n}\),误差大致为 \(1/\sqrt{n} \cdot \sqrt{n} = O(1)\)。等等,实际上,为了获得最优速率,我们需要误差很小。标准分析通过平衡各项来获得 \(O(1/\sqrt{n})\)。
- Epoch 2 (\(T=2n\)): 危险就在这里。如果我们保持步长 \(\eta = 1/\sqrt{n}\) 不变 (这很常见) ,并且再运行 \(n\) 步,下界意味着误差可能会显著增长。
具体来说,作者表明,使用典型的步长 \(\eta = \Theta(1/\sqrt{n})\),仅仅多进行一次遍历后,总体损失就可能变成常数级 (\(\Omega(1)\)) 。这实际上意味着模型“遗忘”了在第一遍遍历中学到的所有有价值的东西。
这个结果适用于各种采样方案:
- 单次洗牌 (Single-shuffle) : 排列一次,然后重复。
- 多次洗牌 (Multi-shuffle) : 每个 epoch 都重新洗牌。
- 任意排列。
放回采样的 SGD
论文还分析了放回采样 SGD (With-Replacement SGD) , 即每一步都从数据集中随机均匀采样数据点 (这意味着在看到其他点之前可能会再次看到同一个点) 。

作者在这里证明了类似的下界。“过拟合”效应在算法“看过”整个数据集后开始显现。基于赠券收集者问题 (Coupon Collector’s problem) ,这大约发生在 \(O(n \log n)\) 步之后。一旦数据集被完全记忆,总体损失就会以相同的速率恶化:

过拟合是如何发生的: 机制
凸优化算法怎么会如此剧烈地过拟合?这种直觉依赖于记忆化 (memorization) 。
在第一遍遍历期间,算法看到的是数据流。到 epoch 1 结束时,算法已经隐式地“记住”了数据集 \(S\)。在高维空间中 (维度 \(d \approx n\)) ,算法可以利用这种记忆。
研究人员构建了一个特定的“困难”损失函数来证明这一点。该函数旨在惩罚那些试图过于激进地最小化经验误差的算法。
该函数由两部分组成:

- \(g(w, V)\) (Feldman 函数): 这个组件制造了“虚假”极小值。它确保存在特定的向量,这些向量在训练集上看起来是很棒的解 (经验损失很低) ,但在总体上实际上是很糟糕的解 (总体损失很高) 。
- \(h(w)\) (向导): 这个组件充当向导。它引导 SGD 更新指向那些虚假极小值,但只有在算法识别出数据点的位置之后才会起作用。
陷阱
该机制分两个阶段工作:
- 观察 (Epoch 1) : 在第一遍遍历期间,算法本质上是“安全”的,因为它正在处理新数据。然而,它正在收集关于数据点位置的信息。
- 执行 (Epoch 2) : 一旦数据集被固定且已知,梯度更新 (由 \(h(w)\) 驱动) 就会将模型引向位于数据点“之间”的特定方向。
在高维构造中,算法识别出一个与所有训练点正交的“坏”向量 \(u_0\)。然后它向 \(u_0\) 的方向移动。因为 \(u_0\) 与任何训练点都不对齐,所以它不会增加经验损失 (由于 \(g\) 的结构) 。然而,向这个方向移动会显著增加真实的总体损失。
这就创造了一个场景: 经验风险 (训练误差) 保持在低位或下降,但总体风险 (测试误差) 爆炸式增长。
匹配的上界
这个结果不仅仅是一个悲观的最坏情况;它与可实现的每一个上界相匹配。利用算法稳定性 (Algorithmic Stability) 的技术,作者提供了多轮 SGD 的匹配上界:

这证实了下界 \(\Omega(\eta\sqrt{T} + 1/(\eta T))\) 确实是正确的速率。当步长衰减不够激进时,非平滑凸函数的测试误差迅速增加是 SGD 的一个固有特征。
单轮 SGD 的泛化差距
除了多轮的结果外,该论文还对单轮 SGD 提供了引人入胜的见解。
有一种经典观点认为算法之所以能泛化,是因为它们具有很小的“泛化差距” (训练误差和测试误差之间的差异) 。然而,作者表明对于单轮 SGD,情况并非如此。
他们证明,即使单轮 SGD 达到了最优的总体损失,它也可能具有巨大的泛化差距。具体来说, 经验损失可能远高于总体损失。

这表明单轮 SGD 的成功不能用一致收敛或标准的泛化差距论点来解释。它的成功是因为随机逼近的魔力,而这种魔力在我们重复使用数据的那一刻就消失了。
结论与启示
这篇论文的发现强调了机器学习训练中第一个 epoch 与随后的所有 epoch 之间的关键区别。
- “相变”: SGD 的行为在第一遍遍历后发生了根本性的转变。在第一个 epoch 期间,我们受益于统计独立性。从第二个 epoch 开始,我们就进入了“经验风险最小化” (ERM) 机制,过拟合成为主导力量。
- 步长至关重要: 描述的“快速过拟合”发生在使用对单次遍历最优的步长 (\(\eta \approx 1/\sqrt{n}\)) 时。为了避免在后续 epoch 中过拟合,必须显著衰减步长 (例如 \(\eta \approx 1/\sqrt{T}\)) 。
- 平滑 vs 非平滑: 值得注意的是,这些结果特定于一般 (非平滑) 凸设置。在平滑优化中,损失地形更加宽容,过拟合发生得慢得多。
对于学生和从业者来说,这是一个理论警告: “更多的 epoch”不是免费的午餐。如果你的问题是非平滑的 (或者由于架构原因实际上是非平滑的) ,在不调整学习率的情况下重复使用数据可能会抵消模型在第一个 epoch 中所做的所有努力。模型不再学习世界,而是开始死记硬背数据。
](https://deep-paper.org/en/paper/10797_rapid_overfitting_of_mul-1668/images/cover.png)