If you have ever trained a machine learning model, the standard procedure is almost muscle memory: set up your data loader, define a Stochastic Gradient Descent (SGD) optimizer, and write a loop that iterates over your dataset for multiple epochs. The intuition is simple: the more the model sees the data, the better it should learn.

But what if seeing the data a second time actually breaks the model?

In the fundamental theory of Stochastic Convex Optimization (SCO), there is a known “magic” to the first pass over the data. We know that a single epoch of SGD achieves the optimal error rate. However, a fascinating new research paper, “Rapid Overfitting of Multi-Pass SGD in Stochastic Convex Optimization,” reveals a startling phenomenon: in general convex settings, performing just one additional pass over the data can lead to catastrophic overfitting.

This article explores why this happens, the mathematics behind it, and what it tells us about the nature of learning algorithms.

Figure 1. An illustration of the minimax rates for the population loss of multi-pass SGD established in Theorems 3.1 and 3.3, through \(K = 5\) epochs and for different step sizes \(\eta\).

As shown in Figure 1 above, using a standard step size (blue line), the population loss drops beautifully during the first epoch. But the moment the second epoch begins, the loss shoots up, ruining the progress. Let’s dig into why.

The Setup: Stochastic Convex Optimization

To understand the problem, we first need to define our goal. In Stochastic Convex Optimization (SCO), we are trying to minimize the Population Loss (or Risk), denoted as \(F(w)\). This is the expected loss over the true data distribution \(\mathcal{Z}\). We want a model that performs well on data it hasn’t seen yet.

\[
F(w) \;=\; \mathbb{E}_{z \sim \mathcal{Z}}\big[f(w, z)\big],
\]

where \(f(w, z)\) denotes the loss of the model \(w\) on the example \(z\).

However, we don’t have access to the underlying distribution \(\mathcal{Z}\) itself. We only have a training set \(S = \{z_1, \dots, z_n\}\) of size \(n\). Usually, we minimize the Empirical Loss, which is the average loss over our specific training samples:

\[
F_S(w) \;=\; \frac{1}{n} \sum_{i=1}^{n} f(w, z_i).
\]
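To make the two quantities concrete, here is a small sketch (toy loss and distribution of my own choosing, not from the paper) that contrasts the empirical loss on a finite sample with a Monte-Carlo estimate of the population loss:

```python
import numpy as np

rng = np.random.default_rng(0)

def loss(w, z):
    # A convex, Lipschitz toy loss in w: absolute value of <w, z>.
    return np.abs(w @ z)

d, n = 20, 100
w = rng.normal(size=d) / np.sqrt(d)      # some fixed model vector

# Training set S of n i.i.d. samples; the "population" is N(0, I_d).
S = rng.normal(size=(n, d))

# Empirical loss F_S(w): average over the n training samples.
F_S = np.mean([loss(w, z) for z in S])

# Population loss F(w): estimated with a large batch of fresh samples.
fresh = rng.normal(size=(100_000, d))
F = np.mean([loss(w, z) for z in fresh])

print(f"F_S(w) = {F_S:.3f}, F(w) ≈ {F:.3f}")
```

The two numbers are close here because \(w\) was chosen independently of \(S\); the whole story of this paper is about what happens when the algorithm producing \(w\) gets to look at \(S\) repeatedly.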

The Magic of the First Pass

Classic theory tells us that One-Pass SGD is minimax optimal. If you run SGD for exactly one epoch (processing each of the \(n\) samples once) with a step size of \(\eta \approx 1/\sqrt{n}\), you achieve a population excess risk of \(O(1/\sqrt{n})\), the best rate achievable in this setting.

The algorithm for standard One-Pass SGD looks like this:

One-Pass SGD Algorithm
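As a minimal sketch (with hypothetical helper names, a constant step size, and averaged-iterate output), one-pass SGD looks like this:

```python
import numpy as np

def one_pass_sgd(S, grad, eta, project=lambda w: w):
    """One-pass SGD: each sample z_i is used exactly once, in order.

    S       : array of n i.i.d. samples (one row per sample)
    grad    : (sub)gradient oracle grad(w, z) of the loss f(w, z)
    eta     : constant step size, eta ≈ 1/sqrt(n) for the optimal rate
    project : projection onto the constraint set (identity by default)
    """
    n = len(S)
    w = np.zeros(S.shape[1])
    iterates = [w.copy()]
    for i in range(n):                       # exactly one epoch: n fresh samples
        w = project(w - eta * grad(w, S[i]))
        iterates.append(w.copy())
    # Online-to-batch conversion: output the average iterate.
    return np.mean(iterates, axis=0)
```

For instance, with the subgradient \(\operatorname{sign}(w - z)\) of the absolute loss \(|w - z|\), the averaged iterate lands near the median of the sampled distribution.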

Because every sample in the first pass is “fresh” (independent and identically distributed), the gradient updates provide an unbiased estimate of the true population gradient. This allows us to use powerful statistical guarantees known as “online-to-batch” conversions.

The Problem: Multi-Pass SGD

In practice, we rarely stop after one epoch. We run Multi-Pass SGD, shuffling the data and going over it again and again.

Multi-Pass SGD Algorithm
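A sketch of the multi-pass variant (again with hypothetical names), covering both of the shuffling schemes discussed below:

```python
import numpy as np

def multi_pass_sgd(S, grad, eta, epochs, reshuffle=True, seed=0):
    """Multi-pass SGD: K epochs over the same n training samples.

    reshuffle=True  -> "multi-shuffle": draw a fresh permutation every epoch
    reshuffle=False -> "single-shuffle": permute once, reuse the same order
    """
    rng = np.random.default_rng(seed)
    n = len(S)
    w = np.zeros(S.shape[1])
    iterates = [w.copy()]
    order = rng.permutation(n)
    for _ in range(epochs):
        if reshuffle:
            order = rng.permutation(n)
        for i in order:            # after epoch 1, these samples are no longer fresh
            w = w - eta * grad(w, S[i])
            iterates.append(w.copy())
    return np.mean(iterates, axis=0)
```

The only structural difference from the one-pass version is the outer loop; the statistical difference, as we will see, is enormous.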

The assumption is usually that more training minimizes the empirical risk \(F_S(w)\) further, which hopefully translates to lower population risk \(F(w)\). While this often holds for smooth functions, the authors of this paper investigate the general convex case (Lipschitz functions that are not necessarily smooth).

They ask a simple but unanswered question: How does the population risk deteriorate if we continue training for just a few more passes?

The Result: A Phase Transition

The researchers prove that there is a sharp phase transition between the first and second epochs.

If you tune your SGD step size to be optimal for the first pass (\(\eta \approx 1/\sqrt{n}\)), that same step size becomes toxic in the second pass.

The Lower Bound

The paper establishes a tight lower bound for the population error of multi-pass SGD. If you run the algorithm for \(T\) total steps (where \(T > n\)), the excess population loss is bounded by:

\[
\mathbb{E}\big[F(w_T)\big] - \min_{w} F(w) \;=\; \Omega\!\left( \eta\sqrt{T} \;+\; \frac{1}{\eta T} \right).
\]

Let’s break down the implications of this equation, specifically the term \(\eta \sqrt{T}\).

  1. Epoch 1 (\(T=n\)): The lower bound above only kicks in once \(T > n\). During the first pass, the classical analysis balances an optimization term \(1/(\eta n)\) against a step-size term \(\eta\); setting \(\eta = 1/\sqrt{n}\) makes both of order \(1/\sqrt{n}\), which is the optimal rate.
  2. Epoch 2 (\(T=2n\)): Here is the danger. If we keep the step size \(\eta = 1/\sqrt{n}\) constant (which is common) and run for another \(n\) steps, the term \(\eta\sqrt{T} = \sqrt{2}\) is of constant order, so the lower bound allows the error to grow dramatically.

Specifically, the authors show that with the canonical step size \(\eta = \Theta(1/\sqrt{n})\), the population loss can become constant (\(\Omega(1)\)) after just one additional pass. This effectively means the model has unlearned everything valuable it acquired during the first pass.
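Plugging concrete numbers into the rate \(\Omega(\eta\sqrt{T} + 1/(\eta T))\) makes the contrast stark. A quick check (illustrative arithmetic only, with \(n = 10{,}000\)):

```python
import math

n = 10_000
eta = 1 / math.sqrt(n)              # canonical one-pass step size: 0.01 here

one_pass_rate = 1 / math.sqrt(n)    # optimal excess risk after epoch 1

# The lower bound Omega(eta*sqrt(T) + 1/(eta*T)) applies once T > n.
T = 2 * n                           # end of the second epoch
lower_bound = eta * math.sqrt(T) + 1 / (eta * T)

print(f"after epoch 1: ~{one_pass_rate:.3f}")   # ~0.010
print(f"after epoch 2: >= ~{lower_bound:.3f}")  # ≈ 1.42, constant order
```

The excess risk is allowed to jump from about \(0.01\) to constant order: two orders of magnitude worse, after a single extra pass.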

This result applies to various sampling schemes:

  • Single-shuffle: Permute once, repeat.
  • Multi-shuffle: Reshuffle every epoch.
  • Arbitrary permutations.

With-Replacement SGD

The paper also analyzes With-Replacement SGD, where data points are sampled randomly and uniformly from the dataset at every step (meaning you might see the same point twice before seeing others).

With-Replacement SGD Algorithm
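A sketch of this sampling scheme (hypothetical names, same conventions as the earlier sketches):

```python
import numpy as np

def with_replacement_sgd(S, grad, eta, T, seed=0):
    """With-replacement SGD: at each of T steps, pick one training point
    uniformly at random; the same point may repeat before others appear."""
    rng = np.random.default_rng(seed)
    n = len(S)
    w = np.zeros(S.shape[1])
    iterates = [w.copy()]
    for _ in range(T):
        i = rng.integers(n)          # uniform index, sampled with replacement
        w = w - eta * grad(w, S[i])
        iterates.append(w.copy())
    return np.mean(iterates, axis=0)
```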

The authors prove a similar lower bound here. The “overfitting” effect kicks in once the algorithm has “seen” the entire dataset. By the Coupon Collector’s problem, this happens after \(\Theta(n \log n)\) steps in expectation. Once the whole dataset has been observed, the population loss deteriorates at the same rate:
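The coupon-collector threshold is easy to check empirically. This small simulation (my own illustration) counts how many uniform draws are needed before every one of \(n\) points has appeared at least once, and compares against the classical estimate \(n H_n \approx n(\ln n + 0.577)\):

```python
import numpy as np

def steps_to_see_all(n, rng):
    """Number of uniform-with-replacement draws until all n items appear."""
    seen, steps = set(), 0
    while len(seen) < n:
        seen.add(int(rng.integers(n)))
        steps += 1
    return steps

rng = np.random.default_rng(0)
n = 1_000
trials = [steps_to_see_all(n, rng) for _ in range(50)]
avg = float(np.mean(trials))
estimate = n * np.log(n) + 0.577 * n    # n * H_n, the coupon-collector mean
print(f"n = {n}: average {avg:.0f} steps, theory predicts ≈ {estimate:.0f}")
```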

\[
\mathbb{E}\big[F(w_T)\big] - \min_{w} F(w) \;=\; \Omega\!\left( \eta\sqrt{T} + \frac{1}{\eta T} \right), \qquad \text{for } T = \Omega(n \log n).
\]

How Overfitting Happens: The Mechanism

How can a convex optimization algorithm overfit so drastically? The intuition relies on memorization.

During the first pass, the algorithm sees the data stream. By the end of epoch 1, the algorithm has implicitly “memorized” the dataset \(S\). In high-dimensional spaces (where dimension \(d \approx n\)), the algorithm can exploit this memory.

The researchers construct a specific “hard” loss function to prove this. This function is designed to punish the algorithm once it tries to minimize the empirical error too aggressively.

The function is composed of two parts:

\[
f(w, z) \;=\; g(w, V) \;+\; h(w).
\]

  1. \(g(w, V)\) (Feldman’s Function): This component creates “spurious” minima. It ensures that there are specific vectors that look like great solutions on the training set (empirical loss is low) but are actually terrible solutions for the population (population loss is high).
  2. \(h(w)\) (The Guide): This component acts as a guide. It directs the SGD updates toward those spurious minima, but only after the algorithm has identified where the data points lie.

The Trap

The mechanism works in two stages:

  1. Observation (Epoch 1): During the first pass, the algorithm is essentially “safe” because it is processing new data. However, it is collecting information about the location of the data points.
  2. Execution (Epoch 2): Once the dataset is fixed and known, the gradient updates (driven by \(h(w)\)) steer the model toward a specific direction that lies “between” the data points.

In the high-dimensional construction, the algorithm identifies a “bad” vector \(u_0\) that is orthogonal to all the training points. It then moves in the direction of \(u_0\). Because \(u_0\) doesn’t align with any training point, it doesn’t increase the empirical loss (thanks to the structure of \(g\)). However, moving in this direction increases the true population loss significantly.

This creates a scenario where the Empirical Risk (training error) stays low or decreases, but the Population Risk (test error) explodes.
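The geometric core of the trap, moving in a direction orthogonal to all training points, can be illustrated with a toy example (this is my own simplified demonstration, not the paper's construction, and the loss is a toy choice): when \(d \gg n\), a null-space direction \(u_0\) of the data matrix always exists, and sliding along it leaves any loss that depends on \(w\) only through \(\langle w, z\rangle\) unchanged on \(S\) while the population loss grows.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 50, 500                        # high-dimensional regime: d >> n

S = rng.normal(size=(n, d))           # training points z_1, ..., z_n
S /= np.linalg.norm(S, axis=1, keepdims=True)

# u0: a direction orthogonal to every training point. The trailing rows
# of Vt from the SVD span the null space of S, which is nonempty since d > n.
_, _, Vt = np.linalg.svd(S)
u0 = Vt[-1]

def losses(w, Z):
    # Toy convex loss depending on w only through <w, z>.
    return np.abs(Z @ w)

fresh = rng.normal(size=(20_000, d))  # fresh samples from the "population"
fresh /= np.linalg.norm(fresh, axis=1, keepdims=True)

for c in [0.0, 2.0, 5.0]:
    w = c * u0                        # slide along the orthogonal direction
    emp = losses(w, S).mean()         # empirical loss: stays at ~0
    pop = losses(w, fresh).mean()     # population loss: grows with c
    print(f"c = {c}: empirical = {emp:.6f}, population ≈ {pop:.4f}")
```

The empirical loss is numerically zero for every \(c\), while the population loss increases steadily: exactly the divergence between training and test error described above.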

Matching Upper Bounds

This result isn’t just a pessimistic worst-case scenario; it matches the achievable upper bounds. Using techniques from Algorithmic Stability, the authors provide a matching upper bound for multi-pass SGD:

\[
\mathbb{E}\big[F(w_T)\big] - \min_{w} F(w) \;=\; O\!\left( \eta\sqrt{T} + \frac{1}{\eta T} \right).
\]

This confirms that the lower bound \(\Omega(\eta\sqrt{T} + 1/(\eta T))\) is indeed the correct rate. The behavior of rapidly increasing test error is an inherent feature of SGD on non-smooth convex functions when the step size is not decayed aggressively enough.

The Generalization Gap of One-Pass SGD

In addition to the multi-pass results, the paper provides a fascinating insight into One-Pass SGD.

There is a classical view that algorithms generalize because they have a small “generalization gap” (the difference between training error and test error). However, the authors show that for One-Pass SGD, this is not the case.

They prove that even though One-Pass SGD achieves optimal population loss, it can have a massive generalization gap. Specifically, the Empirical Loss can be much higher than the population loss.

\[
\mathbb{E}\big[F_S(w_{\mathrm{out}})\big] - \min_{w} F_S(w) \;=\; \Omega(1).
\]

This suggests that the success of One-Pass SGD cannot be explained by uniform convergence or standard generalization-gap arguments. It succeeds because of stochastic approximation: each fresh sample yields an unbiased estimate of the population gradient, a property that vanishes the moment we reuse data.

Conclusion and Takeaways

The findings in this paper highlight a critical difference between the first epoch and all subsequent epochs in machine learning training.

  1. The “Phase Transition”: There is a fundamental shift in the behavior of SGD after the first pass. During the first epoch, we benefit from statistical independence. From the second epoch onward, we are in an “Empirical Risk Minimization” (ERM) regime where overfitting becomes the dominant force.
  2. Step Size Matters: The “rapid overfitting” described occurs when using the step size optimal for one pass (\(\eta \approx 1/\sqrt{n}\)). To avoid this overfitting in later epochs, one must decay the step size significantly (e.g., \(\eta \approx 1/\sqrt{T}\)).
  3. Smooth vs. Non-Smooth: It is important to note that these results are specific to the general (non-smooth) convex setting. In smooth optimization, the landscape is more forgiving, and overfitting happens much more slowly.
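Takeaway 2 suggests a simple rule of thumb: scale the step size with the total number of planned steps \(T = K \cdot n\) rather than with \(n\) alone. A sketch (hypothetical helper, following the \(\eta \approx 1/\sqrt{T}\) prescription above):

```python
import math

def step_size(n, epochs):
    """Scale eta with the planned total step count T = epochs * n,
    not just with the sample size n."""
    T = epochs * n
    return 1 / math.sqrt(T)

n = 10_000
print(step_size(n, 1))   # one pass: 1/sqrt(n) = 0.01
print(step_size(n, 5))   # five passes: a smaller step, curbing overfitting
```

The price, as the matching bounds make clear, is slower optimization per step; the benefit is that the \(\eta\sqrt{T}\) overfitting term stays bounded.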

For students and practitioners, this serves as a theoretical warning: “more epochs” is not a free lunch. If your problem is non-smooth (or effectively non-smooth due to architecture), reusing data without adjusting your learning rate can undo all the progress your model made in the first epoch. The model stops learning the world and starts memorizing the data.