In the rapidly evolving world of Artificial Intelligence, model size has become a “survival of the fittest” contest. Large Language Models (LLMs) like GPT-4 exhibit an emergent ability known as Chain-of-Thought (CoT) reasoning: instead of jumping straight to an answer, they break complex problems into intermediate steps, much like a student showing their work on a math test.
However, running these massive models is expensive and computationally heavy. This has led to a surge in research focused on Knowledge Distillation—teaching smaller, more efficient “Student” models (SLMs) to mimic the reasoning capabilities of “Teacher” LLMs.
But there is a major catch. While student models are great at passing tests they have seen before (In-Domain), they often fail miserably when faced with new, unfamiliar problems (Out-Of-Domain). Why? Because they aren’t actually learning to reason; they are learning to take shortcuts.
In this post, we will take a deep dive into a fascinating paper titled “Improve Student’s Reasoning Generalizability through Cascading Decomposed CoTs Distillation.” We will explore why standard distillation fails, the phenomenon of “spurious correlations,” and how a new method called CasCoD forces models to truly think before they speak.
The Problem: The “Clever Hans” Effect in AI
To understand the core problem, we first need to look at how we currently teach small models. In Standard CoT Distillation (Std-CoT), we take a dataset of questions (\(q\)) and use a Teacher LLM to generate a chain of thought (\(CoT\)) followed by an answer (\(a\)). We then fine-tune the student model to output the entire sequence (\(q \to CoT \to a\)) in one go.
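As a rough illustration of the data format, the sketch below assembles one Std-CoT training example. The `TrainingExample` class, the `build_std_cot_example` helper, the answer template, and the sample strings are all assumptions made for this post, not the paper's actual prompt format.

```python
from dataclasses import dataclass


@dataclass
class TrainingExample:
    prompt: str  # text the student conditions on (no loss computed here)
    target: str  # text the student must generate (loss computed here)


def build_std_cot_example(question: str, rationale: str, answer: str) -> TrainingExample:
    """Std-CoT: a single example whose target is the rationale followed by the answer."""
    # Assumed template; the real wording of the teacher output may differ.
    target = f"{rationale} Therefore, the answer is {answer}."
    return TrainingExample(prompt=question, target=target)


# Toy data, invented for illustration only.
std_example = build_std_cot_example(
    question="Q: Which option is correct? (A) 3 (B) 5",
    rationale="The slope of the line is 3.",
    answer="(A)",
)
print(std_example.target)
```

Note that the answer sits in the same target string as the rationale, so a single loss covers both.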
It sounds perfect on paper. The student should learn the reasoning logic, right?
Unfortunately, neural networks are lazy learners. They look for the easiest path to minimize loss. When a model is trained to output both the reasoning and the answer simultaneously, it often spots spurious correlations between the question and the answer, effectively ignoring the reasoning steps.

As shown in Figure 1 above, researchers found a startling paradox. A model trained simply to guess the answer (Answer SFT) often outperformed models trained with Chain-of-Thought distillation on Out-Of-Domain (OOD) tasks.
Look at the example in the bottom half of the image. The question asks why someone brought a “swimsuit.” The model sees the word “swimsuit” and immediately locks onto the answer option containing “swim,” ignoring the context of a “ski resort.” This is a spurious correlation. The model “presets” the answer based on keywords and then hallucinates a rationale to justify it. It’s not reasoning; it’s rationalizing a guess.
The Solution: CasCoD
The researchers propose a method called Cascading Decomposed CoTs Distillation (CasCoD). The intuition is simple yet profound: if the model is cheating by looking at the answer too early, hide the answer.
CasCoD breaks the learning process into two distinct, cascaded steps:
- Rationale Learning: Teach the model to generate only the reasoning, without the final answer.
- Answer Learning: Teach the model to derive the answer based only on the question and the rationale that precedes it.
By decoupling these steps, the model cannot shortcut from the question to the answer. It is forced to traverse the reasoning path.
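The decomposition above can be sketched as a function that turns one teacher triple into two separate training pairs. The helper name `decompose_cot`, the newline separator, and the answer prefix are hypothetical choices, not the paper's exact templates.

```python
from typing import NamedTuple, Tuple


class Pair(NamedTuple):
    source: str  # model input
    target: str  # tokens that receive loss


ANSWER_PREFIX = "Therefore, the answer is"  # assumed template


def decompose_cot(question: str, rationale: str, answer: str) -> Tuple[Pair, Pair]:
    """Split one (q, r, a) triple into a rationale pair and an answer pair."""
    # Step 1 (q -> r): the target stops before any answer text.
    rationale_pair = Pair(source=question, target=rationale)
    # Step 2 (q + r -> a): the rationale moves to the input side.
    answer_pair = Pair(
        source=f"{question}\n{rationale}\n{ANSWER_PREFIX}",
        target=f" {answer}",
    )
    return rationale_pair, answer_pair


# Toy triple, invented for illustration.
r_pair, a_pair = decompose_cot(
    "Which gas is most abundant in Earth's atmosphere? (A) Oxygen (B) Nitrogen",
    "Nitrogen makes up about 78% of the atmosphere, more than any other gas.",
    "(B)",
)
print(r_pair.target)
print(a_pair.source)
```

The rationale pair never contains the answer label in its target, which is the point of hiding the answer during the first step.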

As illustrated in Figure 2, the standard approach (top) pushes the whole sequence at once. The CasCoD approach (bottom) forces a structural separation. Let’s break down the mathematics and mechanics of how this works.
Deep Dive: The Methodology
1. The Flaw in Standard Distillation
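In standard distillation, with \(r\) denoting the teacher's rationale and \(a\) the answer, the loss function looks like this:
\[ \mathcal{L}_{\text{Std-CoT}} = \mathbb{E}_{(q, r, a)} \left[ \ell(q,\; r \oplus a) \right] \]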
Here, the model minimizes the negative log-likelihood of the entire sequence (Rationale + Answer). Because the Answer is part of the same generation stream, the model implicitly learns to predict the answer tokens based on the question tokens, often treating the intermediate reasoning tokens as mere noise or filler.
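To make the negative log-likelihood concrete, the toy calculation below sums per-token negative log-probabilities over a target sequence. The probabilities are made up, and `sequence_nll` is a hypothetical helper; a real implementation would read them from the student's softmax output.

```python
import math
from typing import Sequence


def sequence_nll(token_probs: Sequence[float]) -> float:
    """Negative log-likelihood of a target sequence.

    token_probs[t] is the probability the model assigned to the correct
    token at position t, given the input and all earlier target tokens.
    """
    return -sum(math.log(p) for p in token_probs)


# Made-up probabilities for a 4-token rationale and a 1-token answer.
rationale_probs = [0.5, 0.25, 0.5, 0.25]
answer_probs = [0.9]

# Std-CoT scores rationale and answer tokens in one undifferentiated sum.
std_cot_loss = sequence_nll(rationale_probs + answer_probs)
print(round(std_cot_loss, 3))
```

Nothing in this sum distinguishes the answer token from the rationale tokens, so the optimizer is free to lower the answer term by whatever cue works, including surface features of the question.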
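The loss function \(\ell\) is calculated as follows, where \(x\) is the input and \(y\) the target sequence:
\[ \ell(x, y) = -\sum_{t=1}^{|y|} \log P(y_t \mid x, y_{<t}) \]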
2. Step One: Rationale Learning (\(q \to r\))
In the first step of CasCoD, the researchers modify the training data. They strip away the final answer. The input is the Question (\(q\)), and the target label is only the Rationale (\(r\)).
The objective is defined as:
\[ \mathcal{L}_{q \to r} = \mathbb{E}_{(q, r)} \left[ \ell(q,\; r) \right] \]
Crucially, the answer is removed from the output. The model has no “target” to cheat toward. It must focus entirely on the logic required to analyze the question. It learns to construct a path without knowing the destination yet.
3. Step Two: Answer Learning (\(q, r \to a\))
Once the model learns to reason, it needs to learn to conclude. In the second step, the input is the Question concatenated with the Rationale (\(q \oplus r\)). The target label is the Answer (\(a\)).
The objective becomes:
\[ \mathcal{L}_{q, r \to a} = \mathbb{E}_{(q, r, a)} \left[ \ell(q \oplus r,\; a) \right] \]
Here, the model learns that the answer is a direct consequence of the reasoning, not just a statistical correlation with the question words.
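Assuming inference mirrors the two training steps, the cascade can be sketched as two generation calls. The `generate` callable stands in for any text-generation function, and `fake_student` is a stub invented here so the example runs without a model; neither comes from the paper.

```python
from typing import Callable, Tuple

ANSWER_PREFIX = "Therefore, the answer is"  # assumed template


def cascaded_answer(question: str, generate: Callable[[str], str]) -> Tuple[str, str]:
    """Run q -> r, then (q + r) -> a, returning (rationale, answer)."""
    rationale = generate(question).strip()
    # The second call sees the question plus the rationale just produced.
    answer_prompt = f"{question}\n{rationale}\n{ANSWER_PREFIX}"
    answer = generate(answer_prompt).strip()
    return rationale, answer


def fake_student(prompt: str) -> str:
    """Stub standing in for a fine-tuned student model."""
    if prompt.endswith(ANSWER_PREFIX):
        return " (B)"
    return "Nitrogen makes up about 78% of the atmosphere."


rationale, answer = cascaded_answer("Which gas is most abundant?", fake_student)
print(rationale)
print(answer)
```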
4. The Cascading Combination
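While these are conceptually two steps, the researchers optimize them simultaneously using a weighted loss function:
\[ \mathcal{L}_{\text{CasCoD}} = (1 - \alpha)\, \mathcal{L}_{q \to r} + \alpha\, \mathcal{L}_{q, r \to a} \]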
The hyperparameter \(\alpha\) (alpha) balances the two objectives. As we will see in the experiments, finding the right balance between learning to reason and learning to answer is key.
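As a worked example of the weighting, the sketch below combines two made-up loss values, assuming (as the later discussion of \(\alpha\) indicates) that \(\alpha\) multiplies the answer loss and \(1 - \alpha\) the rationale loss. The function name `cascod_loss` is hypothetical.

```python
def cascod_loss(rationale_loss: float, answer_loss: float, alpha: float) -> float:
    """Weighted sum: (1 - alpha) * rationale loss + alpha * answer loss."""
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must lie in [0, 1]")
    return (1.0 - alpha) * rationale_loss + alpha * answer_loss


# Made-up per-batch losses for illustration.
rationale_loss, answer_loss = 2.0, 1.0

for alpha in (0.1, 0.3, 0.9):
    print(alpha, round(cascod_loss(rationale_loss, answer_loss, alpha), 2))
```

With a small \(\alpha\), most of the training signal comes from the rationale tokens; as \(\alpha\) approaches 1, the answer term dominates.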
Experiments and Results
The researchers tested CasCoD using LLaMA-2-7B as the student model and ChatGPT as the teacher. They used BIG-Bench Hard (BBH) as the In-Domain (IND) dataset (where the student practiced) and tested generalization on four distinct Out-Of-Domain (OOD) benchmarks, including AGIEval and ARC (science exams).
Main Performance
The results were overwhelmingly positive.

Table 1 highlights several key takeaways:
- Std-CoT Struggles: Standard distillation (Std-CoT) often performed worse than simple Answer-SFT (fine-tuning without reasoning) on OOD tasks. This supports the hypothesis that standard CoT distillation induces overfitting to shortcuts.
- CasCoD Dominates: CasCoD (specifically with \(\alpha=0.3\)) achieved the highest performance across the board. On OOD tasks like ARC-Easy (ARC-E) and ARC-Challenge (ARC-C), it significantly outperformed other distillation methods.
- Closing the Gap: CasCoD allowed the small 7B student model to recover a significant portion of the Teacher LLM’s performance, even in zero-shot settings.
Is the “Two-Step” Process Necessary?
You might wonder: Can’t we just mask the loss in a single forward pass? Do we really need two distinct calculations? The researchers tested a “Single-Step” implementation of CasCoD against the full “Two-Step” version.

Figure 3 shows that the physical decomposition matters. The two-step process (pink bars) consistently beats the single-step implementation (blue bars). This suggests that the internal state of the model needs to be “reset” or distinct between reasoning and answering to fully break the spurious correlations.
Robustness: Model Size and Data Efficiency
Does this only work for 7B models? What if we have very little data?
1. Model Size: The researchers tested CasCoD on TinyLLaMA (1.1B), LLaMA-2 (7B), and LLaMA-2 (13B).

As shown in Figure 4, CasCoD (the red line) consistently outperforms the baselines across all model sizes. Interestingly, as the model gets bigger (13B), the gap between CasCoD and standard methods widens on OOD tasks. This suggests that larger models may be even more prone to learning shortcuts unless training is structured the way CasCoD does it.
2. Data Efficiency: Training data is expensive. A great method should work with less of it.

Figure 5 reveals a large efficiency gain. CasCoD trained on just 12.5% of the data (red line, far left) often outperforms Std-CoT trained on 100% of the data (blue dotted line). That is a substantial saving when training data is expensive to produce.
Why Does It Work? Analysis.
Faithfulness: Walking the Walk
One of the biggest criticisms of CoT in small models is unfaithfulness—the model generates a correct reasoning path but then outputs a completely unrelated answer, or vice versa.
To measure this, the researchers used the LAS (Leakage-Adjusted Simulatability) metric. Essentially, this asks: “Does the generated rationale actually help predict the answer?”

Table 3 shows that CasCoD produces highly faithful rationales. Its score of 36.2 is comparable to that of the Teacher LLM itself (38.7). This suggests that CasCoD students aren’t just parroting text; they rely on their generated reasoning to find the answer.
The LAS metric is calculated as:
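In words, and as a hedged sketch of the leakage-adjusted definition as we understand it from the original LAS work: a simulator model predicts the student's answer with and without the rationale, examples are split by whether the rationale alone leaks the answer, and the two group-wise gains are averaged. The `SimRecord` fields, the `las_score` helper, and the toy records below are hypothetical.

```python
from dataclasses import dataclass
from statistics import mean
from typing import List


@dataclass
class SimRecord:
    correct_with_rationale: bool  # simulator matches student given question + rationale
    correct_question_only: bool   # simulator matches student given question alone
    leaked: bool                  # simulator matches student given rationale alone


def las_score(records: List[SimRecord]) -> float:
    """Average the simulatability gain over non-leaking and leaking groups."""
    group_gains = []
    for leaked in (False, True):
        group = [r for r in records if r.leaked == leaked]
        if group:  # skip an empty group rather than divide by zero
            gains = [int(r.correct_with_rationale) - int(r.correct_question_only) for r in group]
            group_gains.append(mean(gains))
    return float(mean(group_gains))


# Toy records, invented for illustration.
records = [
    SimRecord(True, False, False),  # non-leaking, rationale helped
    SimRecord(True, True, False),   # non-leaking, rationale made no difference
    SimRecord(True, True, True),    # leaking, rationale made no difference
    SimRecord(True, False, True),   # leaking, rationale helped
]
print(las_score(records))
```

Averaging the two groups separately keeps rationales that simply restate the answer from inflating the score.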

The Balance of Power (\(\alpha\))
The hyperparameter \(\alpha\) controls how much the model focuses on the Answer (\(q, r \to a\)) versus the Rationale (\(q \to r\)).

Figure 6 provides a crucial insight for practitioners. The performance (y-axis) peaks when \(\alpha\) is small (around 0.1 to 0.3). This means the model should spend the majority of its learning capacity on generating the rationale. If you weight the answer loss too heavily (high \(\alpha\)), the model starts shortcutting again, and performance drops.
Case Studies: Seeing is Believing
Let’s look at concrete examples where Std-CoT fails and CasCoD succeeds.
Example 1: Math Word Problems
In this AGIEval example, the model must calculate the yearly growth of a boy’s height based on a formula.

In Table 16, notice the Std-CoT response. It correctly guesses “(A)”, but the reasoning is nonsense: “…height of a boy… is approximately 36 inches. Therefore… 3 inches.” It hallucinated numbers to force the answer to be A. CasCoD, however, correctly identifies that the slope term of the equation (\(3a\)) represents the yearly increase, and reasons from there to the answer.
Example 2: Scientific Knowledge
Here, the question asks about the most abundant gas in Earth’s atmosphere.

In Table 17, Std-CoT generates a table where Oxygen is 20.95% and Nitrogen is 78.09%, but then concludes: “According to this table, oxygen is the most abundant.” Its conclusion contradicts its own numbers, most likely because it picked up the shortcut “Atmosphere” + “Life” \(\to\) “Oxygen” during training. CasCoD correctly retrieves the knowledge that Nitrogen is 78% and identifies it as the most abundant.
Conclusion
The “black box” nature of neural networks often leads them to solve problems in unexpected ways—including cheating by finding statistical shortcuts between questions and answers. While this works for familiar data, it creates brittle models that fail in the real world.
The CasCoD method presented in this paper offers a robust solution by:
- Decomposing the thinking process from the answering process.
- Cascading the outputs so the answer is strictly dependent on the rationale.
- Restructuring the loss function so that shortcuts from question to answer are no longer rewarded.
The results are clear: to build small models that generalize well, we must force them to slow down and think. By prioritizing the process of reasoning over the result, we end up with students that don’t just memorize the textbook, but actually understand the subject.