Why AI Forgets: Solving Catastrophic Forgetting in Medical Imaging
Artificial Intelligence has made massive strides in medical diagnostics, particularly in the analysis of pathology slides. However, there is a hidden problem in the deployment of these systems: they are static. In the fast-moving world of medicine, new diseases are discovered, new subtypes are classified, and scanning equipment is upgraded.
Ideally, we want an AI model that learns continuously, adapting to new data without losing its ability to recognize previous conditions. This is the realm of Continual Learning (CL). But when researchers apply standard CL techniques to pathology, they run into a wall known as catastrophic forgetting. The model learns the new task but completely forgets the old one.
In this deep dive, we will explore a research paper that uncovers why this happens specifically in medical imaging and proposes a novel solution. The authors argue that in pathology models, the “brain” of the AI doesn’t forget—its “attention” does.
The Context: Whole Slide Imaging and MIL
To understand the problem, we first need to understand the data. Pathologists analyze Whole Slide Images (WSIs). These are digital scans of tissue samples that are gigapixels in size—far too large for a standard neural network to process all at once.
To handle this, researchers use Multiple Instance Learning (MIL). Here is how the standard pipeline works:
- Bag of Patches: The massive WSI is chopped into thousands of small squares called “patches.” The whole slide is treated as a “bag.”
- Feature Extraction: A neural network extracts features from every patch.
- Attention Mechanism: This is the crucial part. An Attention Network looks at all the patches and assigns an “importance score” to each one. It essentially decides which parts of the tissue are suspicious (tumor) and which are normal.
- Aggregation & Classification: The model aggregates the weighted features into a single slide-level representation and makes a final diagnosis (e.g., Cancer vs. Normal).
Mathematically, the slide-level feature \(\mathbf{z}\) is calculated as the weighted sum of patch features \(\mathbf{h}_n\) using attention scores \(a_n\):

$$\mathbf{z} = \sum_{n=1}^{N} a_n \mathbf{h}_n$$
The attention scores \(a_n\) are derived via a specific network architecture (often using tanh and sigmoid activations):

$$a_n = \frac{\exp\!\left\{\mathbf{w}^\top\left(\tanh(\mathbf{V}\mathbf{h}_n) \odot \sigma(\mathbf{U}\mathbf{h}_n)\right)\right\}}{\sum_{j=1}^{N}\exp\!\left\{\mathbf{w}^\top\left(\tanh(\mathbf{V}\mathbf{h}_j) \odot \sigma(\mathbf{U}\mathbf{h}_j)\right)\right\}}$$
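To make this pipeline concrete, here is a minimal PyTorch sketch of a gated-attention MIL head in the spirit of the formulas above. The layer sizes, variable names, and two-class setup are illustrative choices, not details taken from the paper.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GatedAttentionMIL(nn.Module):
    """Minimal attention-based MIL head: score patches, aggregate, classify."""

    def __init__(self, feat_dim=1024, hidden_dim=256, num_classes=2):
        super().__init__()
        self.V = nn.Linear(feat_dim, hidden_dim)   # tanh branch
        self.U = nn.Linear(feat_dim, hidden_dim)   # sigmoid gating branch
        self.w = nn.Linear(hidden_dim, 1)          # attention scoring vector
        self.classifier = nn.Linear(feat_dim, num_classes)

    def forward(self, h):                          # h: (N_patches, feat_dim)
        # a_n is proportional to exp(w^T (tanh(V h_n) * sigmoid(U h_n)))
        scores = self.w(torch.tanh(self.V(h)) * torch.sigmoid(self.U(h)))  # (N, 1)
        a = F.softmax(scores, dim=0)               # attention weights over patches
        z = (a * h).sum(dim=0)                     # slide-level feature z = sum_n a_n h_n
        return self.classifier(z), a.squeeze(-1)   # (num_classes,) logits, (N,) weights
```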
This architecture works beautifully for static datasets. But what happens when we introduce Class-Incremental Learning (CIL)? This is a scenario where we train the model on “Task 1” (e.g., detecting Breast Cancer), and later train it on “Task 2” (e.g., detecting Lung Cancer), without having access to the original Breast Cancer data anymore.
In standard computer vision (like classifying cats vs. dogs), catastrophic forgetting usually happens in the final classification layers. The authors of this paper discovered that in MIL, the behavior is fundamentally different.
The Core Insight: It’s Not the Classifier, It’s the Attention
The researchers performed a fascinating “Decoupling Experiment” to investigate where the memory loss was occurring. They took a model trained on a sequence of tasks and swapped parts of it with the original model trained only on the first task.
They asked: If we keep the old Classifier but use the new Attention network, does accuracy drop? What if we do the reverse?
The results were striking.

As shown in Table 1 above, look at the “Attention \(\theta_t\)” column. When the model uses the Attention network from later sessions (\(t=2\) or \(t=3\)) but keeps the original classifier, the accuracy on Task 1 plummets (down to near 0%). However, if they preserve the original Attention network (\(t=1\)) and use a newer classifier, the accuracy remains incredibly high (~86-89%).
Conclusion: The model isn’t forgetting how to classify the features; it is forgetting where to look.
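For intuition, here is a rough sketch of how such a decoupling experiment could be run with the MIL head sketched earlier: evaluate Task 1 slides using the attention network from one training session and the classifier from another. The helper name and checkpoint handling are hypothetical; the paper's exact protocol may differ.

```python
import copy
import torch

@torch.no_grad()
def decoupled_accuracy(attn_model, clf_model, task1_loader):
    """Task 1 accuracy using attn_model's attention network and clf_model's classifier."""
    hybrid = copy.deepcopy(attn_model)                                     # keep attention weights
    hybrid.classifier.load_state_dict(clf_model.classifier.state_dict())  # swap in other classifier
    hybrid.eval()
    correct, total = 0, 0
    for h, y in task1_loader:      # h: (N_patches, feat_dim) features, y: slide label
        logits, _ = hybrid(h)
        correct += int(logits.argmax().item() == int(y))
        total += 1
    return correct / total

# Attention from session t=3 + classifier from t=1 -> Task 1 accuracy collapses;
# attention from t=1 + classifier from t=3 -> Task 1 accuracy stays high.
```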
Visualizing the Drift
This finding becomes strikingly visible in heatmaps of the model’s focus.

In Figure 1, panel (2) shows the attention map after learning Task 1: the model correctly focuses on the tumor (red areas). However, in panels (3) and (4), as the model learns new tasks (fine-tuning), the attention “drifts.” The red hot spots vanish from the tumor area. The model effectively becomes blind to the cancer it previously knew how to find.
Why Does Attention Drift?
The authors provide a mathematical explanation for this phenomenon based on gradient analysis. They analyzed the gradients (the signals used to update the model weights) for both the classifier and the attention network.
For the classifier weights \(\phi\), the gradient is bounded. It depends on the aggregated feature \(\mathbf{z}\), which is a convex combination of patch features (the attention weights sum to one) and therefore cannot grow without limit:

$$\frac{\partial \mathcal{L}}{\partial \boldsymbol{\phi}} = (\hat{y} - y)\,\mathbf{z} = (\hat{y} - y)\sum_{n=1}^{N} a_n \mathbf{h}_n$$
However, the gradient for the attention scores \(a_i\) depends on the term \(\phi^\top \mathbf{h}_i\):

$$\frac{\partial \mathcal{L}}{\partial a_i} = (\hat{y} - y)\,\boldsymbol{\phi}^\top \mathbf{h}_i$$
This term represents the raw logit score of a specific patch. This is unbounded. If the model sees a new task with features that trigger a very high response in the classifier, the gradients for the attention layer can explode or fluctuate wildly.
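This asymmetry is easy to probe with a toy autograd experiment: the classifier gradient is tied to the bounded aggregate \(\mathbf{z}\), while the gradient reaching the attention scores carries the unbounded patch logits \(\phi^\top \mathbf{h}_i\). The shapes and random features below are made up purely for illustration.

```python
import torch
import torch.nn.functional as F

N, D = 2000, 1024
h = torch.randn(N, D)                        # patch features from a frozen extractor
scores = torch.randn(N, requires_grad=True)  # pre-softmax attention scores
phi = torch.randn(D, requires_grad=True)     # linear classifier weights (binary case)

a = torch.softmax(scores, dim=0)             # attention weights sum to 1
z = (a.unsqueeze(1) * h).sum(dim=0)          # bounded aggregate: convex combination of patches
loss = F.binary_cross_entropy_with_logits(phi @ z, torch.tensor(1.0))
loss.backward()

# The gradient on phi is proportional to z, whose norm is capped by the patch feature norms.
# The gradient flowing back to the attention scores carries phi @ h_i terms, which can be
# arbitrarily large when a patch triggers an extreme classifier response.
print(phi.grad.norm(), scores.grad.norm())
```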
This volatility is confirmed empirically. The graph below tracks the distribution of gradient values during training.

In Figure 2 (top), you can see the attention gradients oscillating with a high range throughout the learning process. In contrast, the classifier gradients (bottom) stabilize and shrink over time. This instability makes the attention network highly susceptible to overwriting old knowledge with new information—the definition of catastrophic forgetting.
The Solution: A Two-Pronged Approach
Armed with the knowledge that the attention layer is the weak link, the researchers proposed a new framework consisting of two main components:
- Attention Knowledge Distillation (AKD) to fix the forgetting.
- Pseudo-Bag Memory Pool (PMP) to handle the massive data size.
Here is the high-level architecture of their proposed system:

1. Attention Knowledge Distillation (AKD)
Standard Continual Learning often uses “Logit Distillation,” where the new model is forced to output the same final classification scores as the old model for old data. However, since the problem here is the attention, the authors introduce Attention Knowledge Distillation.
They force the new model (Student) to mimic the attention distribution of the old model (Teacher). The objective is to minimize the Kullback-Leibler (KL) divergence between the attention weights of the previous model (\(f_{\theta_{t-1}}\)) and the current model (\(f_{\theta_t}\)):

$$\mathcal{L}_{\text{AKD}} = D_{\mathrm{KL}}\!\left(\mathbf{a}^{(t-1)} \,\big\|\, \mathbf{a}^{(t)}\right) = \sum_{n=1}^{N} a_n^{(t-1)} \log \frac{a_n^{(t-1)}}{a_n^{(t)}}$$

where \(\mathbf{a}^{(t-1)}\) and \(\mathbf{a}^{(t)}\) are the attention distributions that \(f_{\theta_{t-1}}\) and \(f_{\theta_t}\) assign to the same bag of patches.
By locking in the attention patterns, the model ensures that even as it learns new disease features, it essentially “remembers” which patches were important for the previous diseases.
The final loss function combines standard cross-entropy loss (for learning the new task) with both attention distillation and logit distillation:

$$\mathcal{L} = \mathcal{L}_{\text{CE}} + \lambda_{\text{AKD}}\,\mathcal{L}_{\text{AKD}} + \lambda_{\text{LD}}\,\mathcal{L}_{\text{LD}}$$
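A minimal sketch of how these terms could be combined in code, assuming the teacher is a frozen copy of the previous session's model and the \(\lambda\) weights are tunable hyperparameters (the function and argument names are ours, not the paper's):

```python
import torch
import torch.nn.functional as F

def continual_mil_loss(student, teacher, h, y, lambda_akd=1.0, lambda_ld=1.0):
    """Cross-entropy on the current task plus attention and logit distillation, for one bag."""
    logits_s, attn_s = student(h)              # attn_*: (N_patches,) softmax attention weights
    with torch.no_grad():
        logits_t, attn_t = teacher(h)          # teacher = frozen copy of the previous model

    # Standard cross-entropy on the new task; y is a scalar class-index tensor
    ce = F.cross_entropy(logits_s.unsqueeze(0), y.unsqueeze(0))

    # Attention Knowledge Distillation: KL(teacher attention || student attention)
    akd = F.kl_div(attn_s.log(), attn_t, reduction="sum")

    # Logit distillation: also match the teacher's output distribution
    ld = F.kl_div(F.log_softmax(logits_s, dim=-1),
                  F.softmax(logits_t, dim=-1), reduction="sum")

    return ce + lambda_akd * akd + lambda_ld * ld
```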
2. Pseudo-Bag Memory Pool (PMP)
The second challenge is specific to pathology: storage. To use distillation, you generally need to replay some old data (Replay Buffer). But WSIs are gigabytes in size. Storing hundreds of old slides is computationally expensive and memory-prohibitive.
The authors realized they didn’t need to store the whole slide. Since MIL relies on specific instances, they could distill a “Pseudo-Bag.”
Instead of storing \(N\) patches (where \(N\) could be 10,000+), they store a small subset of \(K\) patches (e.g., just a few dozen). But which ones? If you only store the “important” (high attention) patches, you lose the context of the background, which is necessary for the model to learn what not to look at.
They proposed the MaxMinRand strategy:
- Max: Select patches with the highest attention scores (the tumor).
- Min: Select patches with the lowest attention scores (the background).
- Rand: Select random patches (to capture general distribution).
This creates a condensed representation of the slide:

$$\tilde{\mathbf{H}} = \mathbf{H}_{\text{max}} \cup \mathbf{H}_{\text{min}} \cup \mathbf{H}_{\text{rand}}, \qquad |\tilde{\mathbf{H}}| = K \ll N$$
This drastically reduces memory footprint while retaining the crucial information needed for the Attention Knowledge Distillation to work.
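A sketch of the MaxMinRand idea, assuming attention scores have already been computed for the slide by the current model; the split sizes and names are illustrative:

```python
import torch

def max_min_rand_pseudo_bag(h, attn, k_max=16, k_min=16, k_rand=16):
    """Condense an N-patch bag into a pseudo-bag of roughly k_max + k_min + k_rand patches."""
    n = h.shape[0]
    top_idx = attn.topk(k_max).indices                    # highest-attention patches (e.g. tumor)
    low_idx = attn.topk(k_min, largest=False).indices     # lowest-attention patches (background)
    rand_idx = torch.randperm(n)[:k_rand]                 # random patches (overall distribution)
    idx = torch.unique(torch.cat([top_idx, low_idx, rand_idx]))  # drop duplicate selections
    return h[idx]                                         # features stored in the memory pool
```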
Experiments and Results
The researchers tested their method against state-of-the-art Continual Learning methods (like EWC, LwF, DER++) on two major benchmarks: a Skin Cancer dataset and a composite Camelyon-TCGA dataset (spanning breast, lung, and kidney cancers).
Quantitative Performance
The results showed a massive improvement over existing methods. We look at AACC (Average Accuracy) and BWT (Backward Transfer—a measure of how much accuracy is retained on old tasks).

In Table 6, look at the “Ours” row compared to “Rehearsal ER” (Experience Replay) or “Regularization LwF”.
- On the CLAM model with a 30 WSI buffer, the proposed method achieves 0.754 Accuracy, whereas standard ER achieves only 0.494.
- The BWT (Backward Transfer) is -0.177 for “Ours” compared to -0.565 for ER. A number closer to zero is better; it means the model lost very little accuracy on previous tasks.
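For reference, both metrics can be computed from a matrix of per-task accuracies recorded after each training session. A common formulation is sketched below (the paper may differ in details such as normalization):

```python
import numpy as np

def average_accuracy_and_bwt(acc):
    """acc[i, j]: accuracy on task j measured right after training session i."""
    T = acc.shape[0]
    aacc = acc[T - 1, :].mean()   # average accuracy over all tasks after the final session
    # Backward transfer: how much accuracy on each earlier task changed by the end of training
    bwt = np.mean([acc[T - 1, j] - acc[j, j] for j in range(T - 1)])
    return aacc, bwt

# Toy example with three tasks; the numbers are made up.
acc = np.array([[0.90, 0.00, 0.00],
                [0.70, 0.88, 0.00],
                [0.60, 0.75, 0.85]])
print(average_accuracy_and_bwt(acc))   # roughly (0.733, -0.215)
```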
Stability Over Time
We can visualize the learning trajectory as the model moves from Task 1 to Task 3.

In Figure 4, the red line (Ours) stays significantly higher than the green (ER) or purple (DER++) lines. While other methods crash below 40% accuracy after the third task (essentially becoming random guessers on old tasks), the proposed method maintains robust performance.
The Trade-off: Plasticity vs. Stability
A classic tension in AI is the stability-plasticity dilemma: if a model is too stable (it remembers old tasks), it struggles to learn new ones (plasticity). Ideally, a model should sit in the top-left corner of the chart below: high Backward Transfer (stability) and low Intransigence (high plasticity).

Figure 5 shows that the proposed method (represented by the star/red dot in the top-left cluster) achieves the best compromise. It retains knowledge significantly better (high y-axis value) without sacrificing the ability to learn new tasks (left on the x-axis).
Ablation Study: Does the Sampling Strategy Matter?
Finally, the authors checked if their MaxMinRand strategy for the memory pool was actually necessary.

Table 4 confirms that MaxMinRand yields the highest accuracy (0.729 for CLAM). Interestingly, purely selecting “Max” patches performed poorly (0.595). This proves that the model needs to see “boring” background patches in the memory pool to maintain a proper understanding of the attention distribution.
Conclusion and Implications
This research highlights a critical nuance in applying AI to medical imaging: distinct architectures fail in distinct ways. In Multiple Instance Learning, the attention mechanism is the Achilles’ heel of memory.
By diagnosing the problem—unbounded gradients in the attention layer causing drift—the authors were able to engineer a surgical solution. Attention Knowledge Distillation acts as a stabilizer, pinning down the model’s focus, while the Pseudo-Bag Memory Pool makes the solution practical for the massive file sizes typical in digital pathology.
For the future of healthcare, this is a significant step. It paves the way for diagnostic systems that can evolve alongside medical science, integrating new biomarkers and disease subtypes without needing to be rebuilt from scratch every time the textbooks are updated. By teaching AI not just what to think, but where to look, we ensure it remembers the lessons of the past.