Large Language Models (LLMs) like GPT-4 and Claude are remarkable not just for their ability to generate text, but for their ability to follow instructions and adhere to human values—a process known as alignment. However, this alignment has a hidden cost. When we use Reinforcement Learning from Human Feedback (RLHF) to teach a model to be “helpful, honest, and harmless,” it often suffers from catastrophic forgetting: it becomes more polite, but suddenly performs worse on translation, reading comprehension, or commonsense reasoning.

This phenomenon is known as the Alignment Tax.

In this post, we will take a deep dive into a fascinating paper titled “Mitigating the Alignment Tax of RLHF”. We will explore why this tax exists, why traditional methods fail to fix it, and how a surprisingly simple technique, Model Averaging, together with the paper’s novel extension, Heterogeneous Model Averaging (HMA), provides a state-of-the-art solution.

The Problem: Alignment vs. Capability

To understand the alignment tax, we first need to look at the standard pipeline for training modern LLMs. It typically involves three stages:

  1. Pre-training: The model consumes vast amounts of data to learn general language abilities (\(\theta_{pre}\)).
  2. Instruction Tuning (SFT): The model is fine-tuned on instruction-response pairs to learn how to follow instructions (\(\theta_0\)).
  3. RLHF: The model is further tuned using human preferences to maximize a reward function, ensuring safety and helpfulness (\(\theta\)).

The conflict arises in stage 3. As the model optimizes for the human-preference reward (being nice and safe), its weights shift away from the configurations that made it good at specific NLP tasks (being smart).

Figure 1: The RLHF pipeline, from the pre-trained model through instruction tuning to RLHF. A red loop marks ‘Forgetting (Alignment Tax)’: Helpfulness improves (+56%) while Common Sense (-5%), Translation (-45%), and Comprehension (-15%) degrade.

As shown in Figure 1, the trade-off is stark. A 56% improvement in “Helpfulness” might cost a staggering 45% drop in translation ability. This is the Alignment-Forgetting Trade-off.

The Landscape of Mitigation Strategies

Researchers have attempted various methods to reduce this tax. Most of these techniques fall into regularization or “stay close to the original” strategies. The goal is to maximize the reward \(r^*(x, a)\) while keeping the new model \(\theta\) close to the original model \(\theta_0\).
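Schematically (a generic form for orientation, not a formula taken from the paper), these methods all optimize something like

\[ \max_{\theta} \; \mathbb{E}_{x,\; a \sim \pi_\theta(\cdot \mid x)} \big[ r^*(x, a) \big] \;-\; \lambda \, d(\theta, \theta_0), \]

where \(\pi_\theta\) is the policy defined by weights \(\theta\), \(d(\cdot, \cdot)\) is whatever notion of “distance to the original model” a given method uses (a weight-space distance, a KL divergence between output distributions, or an implicit constraint), and \(\lambda\) trades reward off against staying close to \(\theta_0\).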

Common approaches include:

  • Early Stopping: Simply stopping the RLHF process before the model changes too much.
  • Weight Regularization (L1/L2): Adding a penalty term to the loss function based on the distance between the new weights and the old weights (a sketch follows this list).
  • Knowledge Distillation: Forcing the new model’s output distribution to match the old model’s distribution.
  • LoRA (Low-Rank Adaptation): Freezing the main model and only training small adapter layers.
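To make the weight-regularization idea concrete, here is a minimal PyTorch sketch (our illustration, not the paper’s implementation) of an L2 penalty that pulls the policy back toward its SFT initialization; `policy`, `reference_state`, and `rl_loss` are placeholder names:

```python
import torch

def l2_to_reference_penalty(policy, reference_state, coeff=0.01):
    """L2 penalty pulling the RLHF policy back toward its SFT weights (theta_0)."""
    penalty = 0.0
    for name, param in policy.named_parameters():
        # reference_state is a frozen copy of the SFT state dict (theta_0)
        penalty = penalty + (param - reference_state[name]).pow(2).sum()
    return coeff * penalty

# During RLHF, the penalty is simply added to the usual objective:
#   loss = rl_loss + l2_to_reference_penalty(policy, reference_state)
```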

However, as the researchers discovered, these methods often amount to a strict trade-off: you can preserve the model’s intelligence but fail to align it properly, or you align it well and it becomes “dumber.”

Figure 3: Task performance versus alignment reward for methods such as Regularization-L1/L2, LoRA, and Early Stopping. The x-axis is the alignment reward; the y-axes are Reading Comprehension, Commonsense QA, and Translation. Most methods decline in task performance as the reward increases.

Figure 3 illustrates this frustration. Most methods (colored lines) trend downwards: as the Reward (x-axis) goes up, the Task Performance (y-axis) goes down. However, notice the orange line labeled MA (RSF). It consistently sits higher than the others, offering a better Pareto frontier. This is Model Averaging, the foundation of this paper’s contribution.

The Surprising Power of Model Averaging

Model Averaging (MA) is conceptually simple. You take the weights of the model before RLHF (\(\theta_0\)) and the weights after RLHF (\(\theta\)), and you linearly interpolate them using a ratio \(\alpha \in [0, 1]\).

\[ \theta_{avg} = \alpha \theta + (1 - \alpha) \theta_0 \]

If \(\alpha=1\), you have the fully aligned (but forgetful) model. If \(\alpha=0\), you have the smart (but unaligned) model. Surprisingly, a mix (e.g., \(\alpha=0.5\)) retains high alignment rewards while recovering a significant portion of the lost capabilities.
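In code, vanilla model averaging is nothing more than an element-wise interpolation of two checkpoints’ state dicts. A minimal sketch, assuming both checkpoints share the same architecture (the file names are placeholders):

```python
import torch

def average_models(state_sft, state_rlhf, alpha=0.5):
    """theta_avg = alpha * theta (RLHF) + (1 - alpha) * theta_0 (SFT), per parameter."""
    return {
        name: alpha * state_rlhf[name] + (1.0 - alpha) * state_sft[name]
        for name in state_rlhf
    }

# Usage sketch: load both checkpoints, merge, and save the interpolated model.
#   state_sft  = torch.load("sft_model.pt")    # theta_0, before RLHF
#   state_rlhf = torch.load("rlhf_model.pt")   # theta, after RLHF
#   torch.save(average_models(state_sft, state_rlhf, alpha=0.5), "merged_model.pt")
```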

Why does Model Averaging work?

The authors provide a theoretical framework to explain this, drawing from Out-Of-Distribution (OOD) generalization theory. The core insight revolves around Feature Diversity.

When a neural network learns two different tasks (e.g., Task A is Translation, Task B is Alignment), it learns specific features for each.

  1. Overlapped Features: Some features are useful for both tasks (e.g., basic grammar, word definitions).
  2. Task-Specific Features: Some features are only useful for one.

The theoretical analysis suggests that model averaging effectively increases the “feature diversity” in the layers where tasks share an overlapped feature space. By averaging, we are essentially ensembling the feature detectors of the pre-RLHF and post-RLHF models.

Crucially, the effectiveness of averaging depends on the similarity of the tasks.

  • High Similarity: If Task A and Task B are similar, their feature spaces overlap significantly. Averaging boosts performance by making the feature representation more robust.
  • Low Similarity: If the tasks are disjoint, averaging might dilute the specialized features of both, helping neither.

This theory leads to a pivotal hypothesis: Different layers of a Transformer process different levels of abstraction.

  • Low-level layers (Input) tend to process syntax and basic semantics—features shared across almost all language tasks.
  • High-level layers (Output) tend to be more task-specific (e.g., specific formatting for a safety refusal vs. translating a sentence).

Empirical Validation: Layer-Wise Analysis

To test this, the researchers split the Transformer into three parts: Input, Middle, and Output. They tried averaging only one part at a time while keeping the others fully aligned.

Figure 4: Left: the model split into Input, Middle, and Output parts. Right: Reading Comprehension versus RLHF reward; the Input-part MA curve (green) traces a distinctly different path from the Output-part MA curve (purple).

The results in Figure 4 confirm the theory. Averaging the Input Part (green line) behaves very differently from averaging the Output Part (purple line). Specifically, averaging the lower (Input) layers preserves capabilities better than averaging the higher layers, sometimes improving reward and task performance simultaneously, likely because the lower layers contain the shared, fundamental linguistic features that benefit both alignment and reasoning.

The Solution: Heterogeneous Model Averaging (HMA)

The observation above—that different layers contribute differently to the alignment-forgetting trade-off—implies that using a single mixing ratio \(\alpha\) for the entire model is suboptimal.

The authors propose Heterogeneous Model Averaging (HMA). Instead of a single scalar \(\alpha\), HMA assigns a unique mixing ratio \(\alpha_k\) to each block \(k\) of the model.

Figure 2: Heterogeneous Model Averaging. The model is divided into Input, Middle, and Output blocks, each merged with its own mixing equation (e.g., \(0.7\,\theta_0 + 0.3\,\theta\) for the input block, with different ratios for the subsequent blocks).

As shown in Figure 2, the Input part might be mixed with a ratio of 0.3, while the Output part uses 0.7. This allows the model to retain strong pre-trained features in the lower layers (where overlap is high) while allowing the higher layers to adapt more aggressively to the alignment task.

Formally, for \(K\) parts, the \(k\)-th part of the merged model is defined as:

\[ \theta_{avg}^{(k)} = \alpha_k\, \theta^{(k)} + (1 - \alpha_k)\, \theta_0^{(k)}, \qquad k = 1, \dots, K \]
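In code, HMA differs from vanilla averaging only in how each parameter picks its ratio. A minimal sketch, assuming a decoder-only model whose parameter names contain a layer index (e.g. `model.layers.12.`); the block boundaries and example ratios below are illustrative, not the paper’s tuned values:

```python
import re

def heterogeneous_average(state_sft, state_rlhf, block_alphas, num_layers):
    """Per-block merge: theta_avg^(k) = alpha_k * theta^(k) + (1 - alpha_k) * theta_0^(k)."""
    num_blocks = len(block_alphas)   # e.g. 3 blocks: input / middle / output
    merged = {}
    for name, w_rlhf in state_rlhf.items():
        match = re.search(r"layers\.(\d+)\.", name)
        if match:
            layer = int(match.group(1))
            block = min(layer * num_blocks // num_layers, num_blocks - 1)
        else:
            block = 0  # embeddings, final norm, lm_head: block assignment is a design choice
        alpha = block_alphas[block]
        merged[name] = alpha * w_rlhf + (1.0 - alpha) * state_sft[name]
    return merged

# Example: keep the input block close to theta_0, let the output block stay mostly aligned.
#   merged = heterogeneous_average(state_sft, state_rlhf,
#                                  block_alphas=[0.3, 0.5, 0.7], num_layers=32)
```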

Optimizing the Ratios

How do we find the best set of ratios \((\alpha_1, ..., \alpha_K)\)? We cannot simply run RLHF training for every possible combination—that would be computationally prohibitive.

Instead, the authors use a clever proxy distillation method.

  1. Take the fully aligned model \(\theta\).
  2. Generate a dataset of responses \(\mathcal{D}_\theta\) using this model. Since \(\theta\) is aligned, these responses have high rewards.
  3. Optimize the mixing ratios \((\alpha_1, ..., \alpha_K)\) to maximize the likelihood of generating these high-reward responses.

The optimization objective becomes:

\[ \max_{(\alpha_1, \dots, \alpha_K) \in \Omega} \; \sum_{(x, a) \in \mathcal{D}_\theta} \log p_{\theta_{avg}}(a \mid x), \]

where \(\theta_{avg}\) is the heterogeneously merged model defined above and \(\Omega\) is the set of admissible mixing ratios.

This effectively treats the mixing ratios as trainable parameters, optimizing them to capture the alignment behavior while the constraint (staying close to the base model via the averaging structure) mitigates forgetting.
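Here is a sketch of how this ratio search could be implemented, treating the per-block alphas as the only trainable parameters and minimizing the negative log-likelihood of responses sampled from the aligned model. Everything here is an assumption for illustration: a Hugging Face-style causal LM whose forward returns a cross-entropy `.loss` when given `labels`, a `loader` of pre-tokenized responses from \(\mathcal{D}_\theta\), a hypothetical `block_of(name)` helper mapping parameter names to block indices, and a sigmoid parametrization we chose to keep each ratio in (0, 1):

```python
import torch
from torch.func import functional_call

def merged_params(state_sft, state_rlhf, alphas, block_of):
    """Rebuild theta_avg differentiably from the learnable per-block ratios."""
    ratios = torch.sigmoid(alphas)  # keeps each alpha_k in (0, 1)
    return {
        name: ratios[block_of(name)] * state_rlhf[name]
        + (1.0 - ratios[block_of(name)]) * state_sft[name]
        for name in state_rlhf
    }

def optimize_ratios(model, state_sft, state_rlhf, block_of, loader,
                    num_blocks=3, steps=200):
    """Fit the mixing ratios by maximizing the likelihood of D_theta under the merged model."""
    alphas = torch.zeros(num_blocks, requires_grad=True)
    opt = torch.optim.Adam([alphas], lr=1e-2)
    for _, batch in zip(range(steps), loader):
        params = merged_params(state_sft, state_rlhf, alphas, block_of)
        # Run the model with the merged weights; labels give the NLL of the sampled responses.
        out = functional_call(model, params, (batch["input_ids"],),
                              kwargs={"labels": batch["input_ids"]})
        out.loss.backward()
        opt.step()
        opt.zero_grad()
    return torch.sigmoid(alphas).detach()  # the learned (alpha_1, ..., alpha_K)
```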

Experiments and Key Results

The researchers validated HMA across two major model families (OpenLLaMA-3B and Mistral-7B) and three different RLHF algorithms:

  1. RSF (Rejection Sampling Finetuning)
  2. PPO (Proximal Policy Optimization)
  3. DPO (Direct Preference Optimization)

HMA vs. Vanilla Model Averaging

The primary metric for success is the Pareto Frontier. We want the curve to be as far to the top-right as possible (High Reward + High Task Performance).

Figure 5: HMA versus vanilla MA. Top: RSF results, with HMA (red) consistently above MA (orange). Bottom: DPO results show the same trend, with HMA maintaining higher reading-comprehension scores at the same reward levels.

Figure 5 demonstrates the superiority of HMA.

  • Top (RSF): The red curve (HMA) is consistently above the orange curve (Standard MA). For the same level of alignment reward, HMA retains significantly higher Reading Comprehension scores.
  • Bottom (DPO): The trend holds even for Direct Preference Optimization. HMA pushes the boundary of what’s possible, allowing for “cheaper” alignment in terms of capability loss.

Generalization to Larger Models (Mistral-7B)

To ensure this wasn’t just a quirk of smaller models, they applied HMA to the Zephyr-7B-beta model (a DPO-aligned version of Mistral). They evaluated it using GPT-4 as a judge (AlpacaEval 2.0 win rates) and standard NLP benchmarks.

Table 1: Zephyr-7B-beta and Zephyr-7B-Gemma compared against their HMA-merged counterparts. The HMA versions show higher win rates, reading-comprehension scores, commonsense accuracy, and translation BLEU across the board.

Table 1 is particularly telling. The HMA-enhanced Zephyr model achieves a 9.32% win rate (vs. 8.10% for the original) while simultaneously improving Reading Comprehension, Common Sense, and Translation scores. This defies the usual trade-off: HMA improved both alignment (win rate) and capabilities.

Conclusion and Implications

The “Alignment Tax” has long been viewed as an unavoidable cost of doing business with LLMs. If you want a safe model, you have to accept a slightly dumber one.

This research challenges that assumption. It reveals that:

  1. Catastrophic forgetting is non-uniform: Different layers of the transformer forget differently and contribute differently to alignment.
  2. Model Averaging is powerful: Simple interpolation beats complex regularization techniques.
  3. Heterogeneity is key: By treating layers differently via HMA, we can actively manage the trade-off.

For students and practitioners, this offers a practical, computationally efficient tool. HMA doesn’t require retraining the model from scratch; it operates on the weights after training, finding the optimal combination to recover lost knowledge. As models get larger and alignment becomes more critical, techniques like HMA will be essential to ensure our AI assistants are not just polite, but also smart.