引言: “水鸟”问题

想象一下,你正在训练一个 AI 来对鸟类进行分类。你给它喂了数千张水鸟 (如鸭子) 和陆鸟 (如麻雀) 的图片。模型在验证集上达到了 99% 的准确率。你准备好部署了。

但随后,灾难发生了。你给模型展示了一只站在草地上的鸭子,它自信地大喊: “陆鸟!”你给它展示了一只飞过湖面的麻雀,它预测: “水鸟!”

哪里出错了?

这是一个典型的伪相关 (Spurious Correlation) 案例。在训练数据中,95% 的水鸟出现在水背景下。模型作为一个“懒惰的学习者”,并没有学会识别喙或羽毛;它只是学会了 蓝色背景 = 水鸟。它依赖了一个“捷径”特征。

对于像 CLIP 这样在海量网络规模数据上训练的现代基础模型来说,这个问题普遍存在。虽然这些模型具有很高的平均准确率,但它们的最差群组准确率 (Worst-Group Accuracy, WGA) (例如,分类在陆地上的水鸟) 往往惨不忍睹。

解决这个问题通常需要做两件昂贵的事情之一:

  1. 重新训练整个模型以消除偏差 (计算成本高昂) 。
  2. 手动标记群组 (告诉模型“这是一只在陆地上的鸭子”) ,以便它可以显式地学习 (标注成本高昂) 。

在这篇文章中,我们将深入探讨一篇引人入胜的论文,题为 “Project-Probe-Aggregate: Efficient Fine-Tuning for Group Robustness” (投影-探测-聚合: 面向群组鲁棒性的高效微调) 。作者提出了一种名为 PPA 的巧妙三步法,可以在重新训练整个模型且需要昂贵的群组标签的情况下修复这些偏差。它通过微调不到 0.01% 的参数,达到了最先进的结果。

背景: 鲁棒性与基于失败的策略

在确定解决方案之前,我们必须将问题形式化。在标准机器学习中,我们通常最小化数据分布 \(\mathbb{P}\) 上的平均误差:

标准期望误差

然而,当存在伪相关时,最小化平均误差允许模型忽略少数群组 (如水面上的陆鸟) 。为了衡量真正的鲁棒性,我们查看最差群组准确率 (WGA) 。 我们将数据分为多个群组 \(g \in \mathcal{G}\) (例如,水鸟+水背景、水鸟+陆地背景、陆鸟+水背景、陆鸟+陆地背景) ,并衡量最难群组上的误差:

最差群组误差

无监督挑战

这个问题的最难版本是无监督群组鲁棒性 。 这意味着我们没有标签来告诉我们要哪张图片属于哪个群组 (\(g\)) 。我们只有类别标签 \(y\) (鸟的类型) 和图片 \(x\)。

解决这个问题的一个常见策略是基于失败的去偏方案 (Failure-Based Debiasing Scheme) 。 其逻辑简单而强大:

  1. 训练一个标准的、“懒惰的”模型 (通常称为经验风险最小化或 ERM) 。
  2. 让它过拟合伪特征。
  3. 识别模型做错的样本。这些很可能是少数群组 (“反偏差”样本) 。
  4. 训练第二个模型,让其额外关注这些困难样本。

我们要分析的这篇论文极大地改进了这个框架。作者认为,标准模型并没有“足够偏见”到可以清晰地识别少数群组,而且标准的重加权并不是最优的。

PPA 方法

作者提出了 Project-Probe-Aggregate (PPA) (投影-探测-聚合) 。这是一种参数高效的微调方法,意味着巨大的预训练骨干网络 (如 CLIP) 被冻结,我们只训练一个小的线性层。

该方法由三个不同的步骤组成。让我们逐一分解。

第 1 步: 投影 (Project,创建“极度偏见”模型)

为了找到少数群组 (例如,陆地上的鸭子) ,首先我们需要一个严重依赖背景 (水 vs. 陆地) 的模型。如果我们能建立一个看背景的模型,它肯定会把陆地上的鸭子分错。

标准训练试图同时学习物体背景。作者提出了一种方法,强迫模型忽略物体 (类别) 而只关注伪特征。

他们利用了基础模型 (CLIP) 的文本编码器。他们获取类别名称的文本嵌入 (例如,“一张水鸟的照片”) 来创建一个矩阵 \(Z\)。这些嵌入代表了类别的“核心”特征。

为了强迫模型查看除类别概念以外的所有内容,他们在数学上将图像特征投影到这些类代理 (class proxies) 的零空间 (nullspace) 上。

投影矩阵 \(\Pi\) 计算如下:

投影矩阵方程

这里,\(Z\) 代表类别特征。乘以 \(\Pi\) 有效地移除了特征空间中对应于类别定义的那个方向。

然后,我们在这些“投影后”的特征上训练一个线性分类器 \(f_b\) (偏见模型) :

偏见分类器方程

为什么这样做有效? 通过移除描述实际物体 (鸟) 的信号,分类器别无选择,只能依赖剩下的信号来预测标签。在具有伪相关的数据集中,剩下的信号就是强烈的伪特征 (背景) 。

作者在数学上证明了 (我们稍后会提到) 这种投影放大了模型对伪相关的依赖。

为了验证这一点,请看下面识别少数群组的精确率和召回率。 PPA 方法 (绿色) 在寻找“最差群组”样本方面显著优于标准 ERM (橙色) 和其他方法。

群组识别的精确率和召回率

第 2 步: 探测 (Probe,为群组评分)

现在我们有了一个“极度偏见”的模型 \(f_b\),我们用它来创建伪群组标签

如果偏见模型预测类别正确,该图像可能遵循伪相关 (多数群组) 。如果它预测错误,该图像可能违反了相关性 (少数群组) 。

我们根据偏见模型是否出错来定义一个伪属性 \(\hat{a}\):

伪属性定义

现在,每张图像都有一个伪群组标签 \(\hat{g} = (y, \hat{a})\)。

接下来,我们训练一个探测器 (Probe) ——一个新的线性分类器 \(h_d\)——来预测这些伪群组标签。但我们不只是使用标准的交叉熵。我们需要考虑到这些群组是严重不平衡的。

作者引入了群组 Logit 调整 (Group Logit Adjustment, GLA) 。 这个损失函数根据估计的群组先验 \(\hat{\beta}\) (每个群组出现的频率) 为 Logit 添加一个边际 (margin) 。

群组 Logit 调整损失

这里,\(\tau\) 是一个超参数,控制我们对不平衡的修正程度。这一步有效地训练了一个非常擅长区分四种情况 (水上的水鸟、陆地上的水鸟等) 的模型,即使它从未见过真实的群组标签。

第 3 步: 聚合 (Aggregate,最终分类器)

我们现在有一个预测群组标签的探测器 \(h_d\)。但对于我们的最终任务,我们不想预测“陆地上的水鸟”;我们只想预测“水鸟”。

在推理过程中,计算群组概率并将它们加起来可能会带来计算开销。作者提出了一种巧妙的简化方法,称为权重空间聚合 (Weight-Space Aggregation)

由于探测器是一个线性模型 (\(W_d\)) ,特定类别 \(y\) 的权重可以简单地是属于该类别的所有群组权重的总和。

聚合方程 权重空间聚合细节

这产生了一个最终的去偏分类器 \(f_d\),它只是一个标准的线性层,与标准模型相比零额外推理成本。


理论分析: 为什么这行得通

论文为第 1 步和第 2 步中使用的直觉提供了严格的理论支持。

为什么投影会放大偏见 (命题 1)

作者分析了一个线性回归设置,其中目标 \(y\) 依赖于核心特征 \(c\) 和伪特征 \(s\)。

线性回归模型

他们在数学上证明,当你 (使用 \(\Pi\)) 剔除核心特征时,新模型中分配给伪特征的权重 (\(\gamma'\)) 会比原始权重 (\(\gamma\)) 增加。

伪特征权重增加

因为分母为正,且相关项 \(\mathbf{r}_s^\top \mathbf{r}_{y_o}\) 在伪相关数据集 (背景与标签相关) 中通常为正,所以 \(\gamma' > \gamma\)。这证明了投影模型在数学上被迫变得更加偏颇。

贝叶斯最优性 (命题 2)

作者还试图最小化平衡群组误差 (Balanced Group Error, BGE) , 该误差平等对待所有群组,无论其在训练集中的大小如何。

平衡群组误差定义

他们证明了他们的聚合策略 (将群组 Logit 求和并减去群组先验的对数) 是最小化 BGE 的贝叶斯最优分类器

贝叶斯最优分类器

这一理论结果解释了为什么第 2 步中的群组 Logit 调整至关重要——它使训练目标与平衡性能的最终目标保持一致。


实验与结果

作者在已知的存在伪相关的标准基准上评估了 PPA: Waterbirds (水鸟) 、CelebA (分类头发颜色,受性别偏差影响) 、MetaShiftLiving-17BAR

性能比较

结果令人印象深刻。如下表 2 所示,PPA (最后一行) 始终优于其他无监督方法 (如 JTT、CnC 和各种提示策略) 。

Waterbirds 和 CelebA 比较表

  • Waterbirds: 使用 CLIP ResNet-50,PPA 达到了 84.3% 的最差群组准确率。这以显著优势击败了之前的最先进水平 (CFR 的 76.9%) 。
  • CelebA: 改进更为明显,从约 77% (之前的方法) 跃升至 91.1%
  • 效率: 关键是,PPA 仅通过微调 <0.01% 的参数就实现了这一点,而许多基线方法需要更繁重的微调。

在 Living-17 和 BAR 数据集上也看到了类似的优势:

Living-17 和 BAR 比较表

消融实验: 我们真的需要投影和 Logit 调整吗?

你可能会想,第 1 步 (投影) 或第 2 步 (GLA) 是否真的是必要的。作者进行了消融实验来验证这一点。

消融实验表

  • 行 (a) vs (b): 仅添加群组 Logit 调整 (GLA) 提高了性能,但还不够。
  • 行 (b) vs (d): 添加投影步骤 (第 1 步) 提供了巨大的准确率提升 (例如,在 Waterbirds 上,从 54.4% 提升到 84.3%) 。这证实了标准模型并没有“足够偏见”来准确识别少数群组;投影是必不可少的。
  • 行 (d) vs (e): 行 (e) 使用*真实 (Ground Truth) *群组标签。PPA (行 d) 非常接近使用真实标签的性能,验证了伪标签的质量。

对超参数的敏感性

该方法在 Logit 调整损失中引入了 \(\tau\) (tau)。如果 \(\tau\) 不完美,方法会失效吗?

Tau 敏感性图

上图显示性能在 \(\tau=1.0\) 附近相对稳定,这与理论预测 (命题 2) 一致,即 \(\tau=1\) 对于最小化平衡群组误差是最优的。

可视化数据

为了使结果更具体,让我们看看“群组”实际上是什么样子的。

Waterbirds: 区别很明显——水上的水鸟 vs. 陆地上的水鸟,以及水上的陆鸟 vs. 陆地上的陆鸟。 Waterbirds 样本

CelebA: 任务是检测“金发”。伪特征是性别。少数群组是“金发男性” (在数据集中很少见) 和“黑发女性”。 CelebA 样本

结论与启示

Project-Probe-Aggregate (PPA) 论文为如何处理现代 AI 系统中的偏差提供了一堂大师课。它不是用蛮力 (更多数据、更多训练) 来对抗模型学习捷径的倾向,而是使用了一种类似柔道 (借力打力) 的方法:

  1. 顺应偏见: 使用投影使偏见变得更严重,从而更容易检测到。
  2. 数学修正: 使用 Logit 调整在理论上平衡群组。
  3. 简化: 将所有内容折叠回一个简单的线性分类器。

对于学生和从业者来说,主要的收获是:

  • 基础模型携带了网络的偏见。 你不能仅仅因为 CLIP 具有高准确率就假设它“知道”鸟是什么。
  • 无监督去偏是可能的。 我们并不总是需要昂贵的标注来修复公平性问题。
  • 线性探测器很强大。 你可以通过智能地训练最终分类层来修复模型中的深层结构问题。

随着 AI 模型不断发展,像 PPA 这样高效、有数学依据的技术对于确保这些系统在现实世界中的鲁棒性、公平性和可靠性将至关重要。