In the world of Large Language Models (LLMs), Supervised Fine-Tuning (SFT) is the standard procedure for adapting a pre-trained base model to a specific task, whether it’s mathematical reasoning, coding, or following instructions. The general consensus has long been that, provided we shuffle our training data and run enough epochs, the model will learn effectively.

But what if the order in which the model sees the data matters more than we thought? What if the samples seen at the very beginning of training are consistently learned “worse” than those seen later, creating a hidden imbalance in your model’s performance?

A recent paper titled “Mitigating Training Imbalance in LLM Fine-Tuning via Selective Parameter Merging” uncovers exactly this phenomenon. The researchers demonstrate that data ordering creates significant training imbalances. More importantly, they propose a novel solution: instead of training one model, train multiple variations with different data orders and merge them. However, they don’t just average the weights. They introduce a new technique called Parameter-Selection Merging that builds a “mosaic” of the best parameters, outperforming traditional merging methods.

In this deep dive, we will explore the problem of training imbalance, the mathematics behind parameter-selection merging, and the experimental results that validate this new approach.

The Invisible Problem: Training Imbalance

When training a neural network, we typically assume that the specific position of a data sample in the training queue becomes irrelevant after several epochs of Stochastic Gradient Descent (SGD). We shuffle the dataset, process it in batches, and assume the model averages out the learning signal.

However, the authors of this paper conducted a preliminary investigation that challenges this assumption. They tracked the relationship between the position of a training sample in the very first epoch and its final loss value after training was complete.

The results, shown in Figure 1 below, are striking.

Figure 1: Impact of training sample position at first epoch on final model losses of these samples (after 3 epochs of training). Panels (a) and (b) present the results on the GSM8k and Alpaca tasks, respectively. Panels (c) and (d) show the corresponding results from multiple experiments with different training orders.

Let’s break down what Figure 1 is telling us:

  • Panels (a) and (b): The x-axis represents the position of the sample during the first epoch (from start to finish). The y-axis represents the loss (error) of those specific samples after training is finished (3 epochs). There is a clear downward trend. Samples that appeared early in the first epoch have higher final loss compared to samples that appeared later.
  • Panels (c) and (d): This isn’t a fluke. When repeating the experiment with different random orders (represented by different colors), the pattern holds. The early samples consistently end up with higher loss.

This phenomenon indicates a training imbalance. The model is biased against the data it sees first. If your most critical instruction-tuning examples happen to fall in the first few batches, your model might significantly underperform on those specific types of queries, regardless of how many subsequent epochs you run.
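If you want to check for this bias in your own runs, a minimal diagnostic is to record each sample’s position in the first epoch and, after training, recompute its loss with the final model. The sketch below assumes a hypothetical `loss_fn(model, sample)` that returns a per-sample loss; the plotting step is left to your setup.

```python
import torch

@torch.no_grad()
def per_sample_final_losses(model, samples, loss_fn):
    """Compute each sample's loss under the *final* model so it can be
    plotted against the position the sample held in the first epoch."""
    model.eval()
    return [loss_fn(model, sample).item() for sample in samples]

# Usage sketch (names are illustrative):
# first_epoch_position[i] = index at which sample i was seen in epoch 1
# losses = per_sample_final_losses(model, samples, loss_fn)
# A scatter of losses vs. first_epoch_position reproduces the trend in Figure 1.
```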

The Solution: Merging Diverse Models

If the order of data creates bias, how do we fix it? We can’t simply put “everything last.” The authors propose a solution rooted in Model Merging.

The core idea is simple:

  1. Take your dataset and shuffle it into \(K\) different random orders.
  2. Fine-tune \(K\) separate models (starting from the same pre-trained base), each using one of those specific data orders.
  3. Merge these \(K\) models into a single, final model.

By training multiple times with different permutations, a sample that appeared “early” in Model A (and was thus poorly learned) might appear “late” in Model B (and be well learned). By merging them, we hope to smooth out the positional biases.
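In code, the recipe is little more than a loop over shuffles. Below is a minimal sketch, assuming you already have a standard SFT loop (passed in here as a callable `sft_train`) and a `merge_fn` such as the ones discussed next; all names are illustrative, not the authors’ implementation.

```python
import copy
import random

def train_and_merge(base_state_dict, dataset, sft_train, merge_fn, k=3, seed=0):
    """Fine-tune K copies of the same pre-trained model, each on a different
    permutation of the same dataset, then merge them into one model."""
    rng = random.Random(seed)
    sub_state_dicts = []
    for _ in range(k):
        order = list(range(len(dataset)))
        rng.shuffle(order)                                   # a fresh data order per run
        reordered = [dataset[i] for i in order]
        # Every run starts from the *same* pre-trained weights.
        sub_state_dicts.append(sft_train(copy.deepcopy(base_state_dict), reordered))
    return merge_fn(sub_state_dicts)
```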

However, the way we merge these models makes all the difference.

Traditional Approach: Weighted-Average Merging

The standard way to merge models—often seen in techniques like Model Soups or Federated Learning—is Weighted-Average Merging.

Imagine you have two models. To merge them, you take the value of a specific parameter in Model 1, add it to the corresponding parameter in Model 2, and divide by two. You do this for every single parameter in the neural network (which can be billions of parameters).

Mathematically, if we have \(K\) sub-models, and \(w_i\) is the weight (importance) we assign to the \(i\)-th model, the merged parameter for any dimension \(j\) is calculated as:

\[
\theta_{\mathrm{merged},j} = \sum_{i=1}^{K} w_i \, \theta_{i,j}, \quad \forall j \in \{1, \dots, d\}
\]
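As a concrete sketch, weighted-average merging over PyTorch state dicts is just a per-parameter weighted sum. The function below assumes all sub-models share the same architecture and parameter names; it is illustrative, not the paper’s code.

```python
import torch

def weighted_average_merge(state_dicts, weights=None):
    """Weighted-average merging: each merged parameter is the weighted sum of
    the corresponding parameters across the K sub-models."""
    k = len(state_dicts)
    if weights is None:
        weights = [1.0 / k] * k                      # uniform weights by default
    merged = {}
    for name in state_dicts[0]:
        merged[name] = sum(w * sd[name] for w, sd in zip(weights, state_dicts))
    return merged
```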

While this method works reasonably well to reduce variance, it creates a “blended” model. It pulls the parameter values toward a mean. The authors argue that this might not be optimal. If Model A found a perfect parameter value for a specific feature, and Model B didn’t learn that feature well, averaging them simply dilutes Model A’s success.

The Innovation: Parameter-Selection Merging

The researchers introduce Parameter-Selection Merging. Instead of averaging the parameters, this method acts as a selector. For every single parameter dimension in the model, it looks at the pool of \(K\) available values from the fine-tuned sub-models and picks one.

It is a stochastic process. The merged parameter is set to the value of the \(i\)-th sub-model with a probability \(p_i\):

\[
\theta_{\mathrm{merged},j} = \theta_{i,j} \ \text{with probability } p_i, \quad \forall j \in \{1, \dots, d\}
\]

Since all sub-models are trained on the same data (just different orders), they are assumed to be equally good on average. Therefore, the probability is usually set uniformly (\(p_i = 1/K\)).
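A minimal sketch of this selection step, assuming PyTorch state dicts and uniform probabilities \(p_i = 1/K\): for every parameter dimension we draw an index over the K sub-models and copy that sub-model’s value. Again, this illustrates the idea rather than reproducing the authors’ code.

```python
import torch

def parameter_selection_merge(state_dicts, seed=0):
    """Parameter-selection merging: for every single parameter dimension, pick
    the value of one sub-model uniformly at random instead of averaging."""
    k = len(state_dicts)
    gen = torch.Generator().manual_seed(seed)
    merged = {}
    for name in state_dicts[0]:
        stacked = torch.stack([sd[name] for sd in state_dicts])       # (K, *shape)
        # One sub-model index per parameter dimension, drawn uniformly.
        choice = torch.randint(k, stacked.shape[1:], generator=gen)
        merged[name] = torch.gather(stacked, 0, choice.unsqueeze(0)).squeeze(0)
    return merged
```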

Why is this better? Think of weighted averaging as mixing colors of paint; if you mix red and blue, you get purple. If red was the correct answer, you are now wrong. Parameter-selection is like a mosaic; you keep the distinct red tile or the blue tile. By selecting discrete parameters, the method preserves the specific “knowledge” encoded in that parameter by a specific model, rather than blurring it.

Figure 2 provides a visual comparison of the two workflows.

Figure 2: Illustration comparing weighted-average method and the proposed parameter-selection method. Weighted-average merging calculates the weighted sum of all sub-model parameters at each parameter dimension, whereas parameter-selection merging selects parameters from a single sub-model. In the resampling module, parameters that equal those of the base model are replaced with parameters from alternative models.

As shown in the diagram, the weighted-average method creates a blend (represented by the gradient colors). The parameter-selection method creates a composite matrix where each cell comes distinctly from one of the source models (blue, pink, or cyan).

Enhancing the Method: The Resampling Strategy

The authors identified a potential weakness in the random selection process. What if the selector picks a parameter from a model that didn’t actually learn anything for that specific feature?

In Fine-Tuning, we often look at the Task Vector (\(\tau\)). The task vector is the difference between the fine-tuned weights (\(\theta_{SFT}\)) and the original pre-trained weights (\(\theta_{pre}\)):

\[
\tau = \theta_{SFT} - \theta_{pre}
\]

If \(\tau_{i,j} = 0\), it means the fine-tuning process did not change that parameter at all for that model. If our random selection picks this “unchanged” parameter, we are missing out on an opportunity to use a parameter from a different model that did learn something (where \(\tau \neq 0\)).

To fix this, the authors introduce a Resampling Strategy.

If the chosen parameter has a task vector of 0 (meaning no change occurred), the algorithm discards it and resamples from the pool of sub-models. This can be repeated \(n\) times until a “changed” parameter is found or the attempts run out.

The formula for this recursive selection is:

\[
\theta_{\mathrm{merged},j}^{(n)} =
\begin{cases}
\theta_{i,j} & \text{if } \tau_{i,j} \neq 0 \ \text{or}\ n = 0 \\
\theta_{\mathrm{merged},j}^{(n-1)} & \text{otherwise}
\end{cases}
\]

This ensures that the final merged model is densely populated with parameters that have actually been adapted to the task, maximizing the information gain from the ensemble of sub-models.
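A hedged sketch of how this could look in code, extending the selection function above: compute the task vectors \(\tau_i = \theta_{SFT,i} - \theta_{pre}\), and wherever the sampled parameter is unchanged (\(\tau_{i,j} = 0\)), redraw from the sub-model pool up to a fixed number of attempts. This is an illustration under those assumptions, not the authors’ implementation.

```python
import torch

def selection_merge_with_resampling(state_dicts, base_state_dict, max_resamples=3, seed=0):
    """Parameter-selection merging with resampling: if a selected parameter equals
    the base model's value (task vector == 0), draw again for that dimension,
    up to `max_resamples` extra attempts."""
    k = len(state_dicts)
    gen = torch.Generator().manual_seed(seed)
    merged = {}
    for name in state_dicts[0]:
        stacked = torch.stack([sd[name] for sd in state_dicts])          # (K, *shape)
        tau = stacked - base_state_dict[name]                            # per-model task vectors
        choice = torch.randint(k, stacked.shape[1:], generator=gen)
        for _ in range(max_resamples):
            picked_tau = torch.gather(tau, 0, choice.unsqueeze(0)).squeeze(0)
            unchanged = picked_tau == 0                                  # dims the chosen model never updated
            if not unchanged.any():
                break
            redraw = torch.randint(k, stacked.shape[1:], generator=gen)
            choice = torch.where(unchanged, redraw, choice)              # resample only those dims
        merged[name] = torch.gather(stacked, 0, choice.unsqueeze(0)).squeeze(0)
    return merged
```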

Experimental Results

The theory sounds solid, but does it translate to better performance? The researchers tested this on several mainstream LLM benchmarks using Llama-2-7b as the base model. They looked at instruction following (Alpaca), mathematical reasoning (GSM8K, MATH), and coding (HumanEval).

1. Superiority over Single SFT and Weighted Averaging

Table 1 summarizes the main results.

Table 1: Performance comparison of weighted-average and parameter-selection merging based on Llama-2-7b. “weighted-avg” means weighted-average merging and “param-selection” means parameter-selection merging.

Key takeaways from Table 1:

  • Merging beats Single SFT: Simply merging multiple models (even with weighted averaging) improves over a single fine-tuned model. For example, GSM8K accuracy jumps from 41.29% to 44.35% with weighted averaging.
  • Selection beats Averaging: The proposed param-selection method outperforms weighted averaging across almost all metrics.
  • Resampling is King: When the resampling strategy is added (denoted as + resample), the performance peaks. On GSM8K, the accuracy reaches 45.26%, a massive improvement over the single SFT baseline of 41.29%.

2. Does Model Size Matter?

One might wonder if this effect is exclusive to 7-billion parameter models. The authors extended their experiments to models of varying sizes, ranging from BERT-base (0.11b) to TinyLlama (1.1b) and Llama-2 (7b).

Table 2: Performance comparison between single SFT model and merged models across pre-trained models with various model sizes.

Table 2 shows a fascinating trend. While the method works for all sizes, the Average Delta (improvement) increases as the model size grows.

  • BERT-base gained +0.75 points on average.
  • Llama-2-7b gained +2.11 points on average.

This suggests that Parameter-Selection Merging is particularly scalable and beneficial for the massive Large Language Models we use today.

3. Analyzing the “Why”: Loss Visualization

To verify that the merging process actually solved the original “training imbalance” problem (where early samples had high loss), the authors plotted the loss of the merged model against the sample position.

Figure 3: Comparison of training losses across different models, with the first epoch sample position of the anchor model as the X-axis. Green lines represent final training losses of the anchor model; blue “×” markers indicate losses of SFT models trained with various data orders; red dots show losses of the merged model.

In Figure 3, the green line represents the “Anchor Model” (a single SFT run). It shows high variance and the positional bias discussed earlier. The red dots represent the merged model. Notice how the red dots are consistently lower than the peaks of the green line and do not show a correlation with the data position. The merging process has successfully smoothed out the positional disadvantages.

Furthermore, this improvement isn’t just on the training set. Figure 4 shows the validation loss (performance on unseen data) during training steps.

Figure 4: Comparison of validation loss between single and merged SFT models at various training steps.

The merged model (red line) maintains a consistently lower loss than the single model (blue line) throughout the training process. This indicates that the method improves generalization, not just memorization.

4. Ablation: Is it Position or Batch Composition?

A skeptic might argue: “When you shuffle data, you aren’t just changing the position (early vs late), you are also changing which samples appear together in a batch.” Maybe the improvement comes from better batch diversity?

The authors anticipated this. They ran an ablation study where they fixed the combinations of samples within batches but changed the order in which those batches were fed to the model.

Table 4: Performance comparison of standard merged models and models with fixed intra-batch combinations.

As Table 4 shows, the performance of the “fixed intra-batch” model is nearly identical to the standard resampling model. This confirms that the position of the data (the order) is indeed the primary factor driving the training imbalance, and correcting it is the source of the performance gain.

Implementation Details

For those interested in reproducing these results, the paper provides detailed hyperparameters. For the main LLM experiments (Llama-2), they utilized standard learning rates and batch sizes, as seen below.

Table 6: Hyperparameters for training Llama-2-7b on LLM tasks.

For the smaller traditional models (BERT, TinyLlama), the settings were adjusted accordingly:

Table 5: Hyperparameters for training models on traditional tasks.

Conclusion and Future Implications

This research highlights a crucial but often overlooked aspect of training Deep Learning models: Data Order Matters. The “Training Imbalance” phenomenon shows that samples seen early in the fine-tuning process are at a disadvantage, leading to sub-optimal final models.

The proposed solution, Parameter-Selection Merging, offers a compelling alternative to standard weighted averaging. By treating the model parameters as a selection pool rather than a mathematical average, and by intelligently resampling to avoid unchanged parameters, we can construct a “super-model” that combines the best learned features from multiple training runs.

Key Takeaways:

  1. Don’t trust a single seed: A single training run is biased by its data order.
  2. Select, Don’t Average: When merging identical architectures trained on the same data, stochastic selection preserves distinct information better than averaging.
  3. Resampling Adds Value: Checking if a parameter actually changed (Task Vector \(\neq\) 0) before selecting it is a simple heuristic that significantly boosts performance.

As LLMs continue to grow in size and the cost of data acquisition rises, techniques like this—which squeeze more performance out of existing data and models without requiring new datasets—will become increasingly vital in the AI engineer’s toolkit.