别再对数据取平均了: 最优传输如何彻底改变数据集蒸馏

在深度学习的当今时代,我们目睹了对数据的极度渴求。像 CLIP 或现代大型语言模型 (LLM) 这样的模型会消耗数百万,有时甚至是数十亿个数据点。虽然效果显著,但这种规模在存储和计算方面造成了巨大的瓶颈。在这些海量数据集上从头开始训练模型,正逐渐成为只有那些拥有超级计算集群的人才能享有的特权。

数据集蒸馏 (Dataset Distillation) (也称为数据集浓缩 Dataset Condensation) 应运而生。理想情况下,这一过程就像训练数据的“压缩包”。其目标是将一个海量数据集 (如 ImageNet) 压缩成一个微小的合成集,每类仅包含几张图像。如果操作得当,在这个微小合成集上训练出的模型,其表现应该与在原始数百万张图像上训练出的模型几乎一样好。

然而,现有的蒸馏方法存在一个根本缺陷: 它们倾向于将数据“平均化”,从而丢失了原始图像丰富的几何结构。

在这篇文章中,我们将深入探讨一篇 CVPR 论文,题为 “OPTICAL: Leveraging Optimal Transport for Contribution Allocation in Dataset Distillation” (OPTICAL: 利用最优传输进行数据集蒸馏中的贡献分配) 。我们将探索研究人员如何发现“同质距离 (Homogeneous Distance) ”陷阱,以及他们如何利用 最优传输 (Optimal Transport, OT) 理论来解决这个问题,从而创建一个能全面提升性能的即插即用模块。

问题所在: 同质距离的陷阱

要理解 OPTICAL 的创新之处,我们首先需要了解数据集蒸馏通常是如何工作的。大多数现代方法都属于 子集生成 (Subset Synthesis) 类别。它们不是从真实图像中进行选择,而是生成新的合成像素模式。

优化过程通常如下所示: 你拥有大型真实数据集 (\(\mathcal{T}\)) 和微小的合成数据集 (\(\mathcal{S}\)) 。你定义一个数学函数来测量它们之间的“距离”或差异,然后更新合成像素以最小化这个距离。

标准的目标函数通常如下所示:

显示标准最小化目标的方程。

这看起来很直接,但在通常计算这个距离 (\(\mathbf{D}\)) 的方式中存在一个隐患。大多数方法将每张真实图像视为对合成图像具有 同等贡献 。 它们最小化的是一种“同质距离”。

从数学上讲,梯度 (告诉合成数据如何变化的信号) 通常均匀地累加误差:

显示具有均匀求和的梯度计算方程。

注意 \(1/|\mathcal{S}|\) 和 \(1/|\mathcal{T}|\) 这两项。这些分数意味着每一张真实图像在塑造合成数据时都贡献均等。

为什么同等贡献是不好的?

真实数据集是杂乱的。在同一个类别 (例如“狗”) 中,你可能有金毛寻回犬、巴哥犬和哈士奇。有些很容易分类;有些则是离群值。如果你强迫你的合成数据同等地匹配所有这些真实图像的 平均值,你就会丢失分布中特定的几何结构。

作者在下图中精彩地展示了这一点。

分布匹配的视觉对比。灰点是真实数据,彩色点是合成数据。上排显示原始分布,下排显示蒸馏结果。

在上面的 (b)(d) 中,注意合成数据 (蓝点) 是如何紧密聚集在中心周围的。无论真实数据 (灰色) 是一个宽环还是一个密集的簇都不重要;合成结果看起来几乎完全相同,因为它只是被吸引到了平均中心。这就是 同质距离最小化 陷阱。合成数据未能捕捉到现实世界的多样性和类内方差。

解决方案: OPTICAL

研究人员提出了一个名为 OPTICAL (OPTImal transport for Contribution ALlocation,用于贡献分配的最优传输) 的新框架。

核心思想简单而强大: 并非所有的真实图像都应该对每一张合成图像有同等的贡献。 我们需要的不是统一的一对一匹配,而是一个动态系统,它能决定哪些真实图像与特定的合成图像最相关,并据此分配“贡献”。

他们将距离最小化重新表述为一个包含两个步骤的 双层优化 问题: 匹配 (Matching)逼近 (Approximating)

OPTICAL 流程图,显示从真实/合成数据通过投影、代价矩阵计算、Sinkhorn 归一化,最后到更新循环的流程。

如上图所示,该流程如下运作:

  1. 投影 (Project) 数据到特征空间。
  2. 使用最优传输 匹配 (Match) 真实数据和合成数据,创建一个“贡献矩阵”。
  3. 利用该矩阵 逼近 (Approximate/Update) 合成数据以最小化距离。

第一步: 贡献矩阵

研究人员重写了距离方程,引入了一个加权矩阵 \(\mathbf{P}\):

显示带有 P 矩阵的新距离公式的方程。

这里,\(\mathbf{P}_{ij}\) 代表第 \(i\) 张真实图像对第 \(j\) 张合成图像的贡献程度。这不再是一个统一的平均值。如果一张合成图像类似于真实图像的一个特定子集 (例如狗中的“哈士奇”子集) ,\(\mathbf{P}\) 将赋予这些配对更高的权重,而降低其他配对的权重。

第二步: 最优传输与 Sinkhorn

我们要如何找到完美的矩阵 \(\mathbf{P}\)?这正是 最优传输 (OT) 发挥作用的地方。OT 是一个数学框架,用于确定将“质量”从一个分布移动到另一个分布的最有效方式。在这种语境下,我们计算的是将合成数据的分布移动以覆盖真实数据分布的成本。

寻找 \(\mathbf{P}\) 的优化问题如下所示:

定义带有熵正则化的 P 优化问题的方程。

其中 \(\mathbf{C}\) 是 代价矩阵 (Cost Matrix) , 用于衡量配对之间的差异,\(H(\mathbf{P})\) 是一个熵正则化项。

Sinkhorn 算法 求解精确的最优传输计算成本高昂 (\(O(N^3)\)) ,这将使训练变得极其缓慢。为了解决这个问题,作者使用了 Sinkhorn 算法 , 它提供了一种快速的迭代近似解法。

他们初始化一个核矩阵 \(\mathbf{K}\),然后迭代地归一化行和列,以确保它们总和符合正确的边缘约束 (即真实图像和合成图像的总数) 。

显示 Sinkhorn 迭代更新的方程。

这个迭代过程在 GPU 上非常高效。经过固定次数的迭代 (\(T\)) 后,得到的矩阵 \(\mathbf{K}^T\) 就变成了我们的贡献矩阵 \(\mathbf{P}^\lambda\)。

第三步: 希尔伯特空间

OPTICAL 的最后一个技术创新在于用于测量距离的“尺子”。简单的欧几里得距离或余弦相似度往往会遗漏复杂数据中的高阶结构细微差别。

为了解决这个问题,作者将数据表示投影到 再生核希尔伯特空间 (RKHS) 中。他们使用高斯核来测量相似度:

显示高斯核函数的方程。

通过结合多个不同尺度 (\(\sigma_k\)) 的核,他们计算出一个 相关性矩阵 (Relevance Matrix, \(\mathbf{R}\)) :

显示计算相关性矩阵 R 的方程。

最优传输步骤中使用的代价矩阵 (\(\mathbf{C}\)) 随后简单地推导为 \(\mathbf{J} - \mathbf{R}\) (其中 \(\mathbf{J}\) 是全 1 矩阵) 。这意味着相关性 (相似度) 高的配对具有较低的传输成本。

实验结果

OPTICAL 框架被设计为 即插即用 的。它可以添加到几乎任何现有的数据集蒸馏方法中,无论是面向优化的 (如 DC 或 DREAM) 还是基于分布匹配的 (如 DM 或 M3D) 。

1. 跨数据集的性能提升

作者在从 MNIST 到 ImageNet 的各种数据集上测试了 OPTICAL。结果是一致的: 添加 OPTICAL 提高测试准确率,且提升往往非常显著。

表格显示了在 CIFAR-10 等低分辨率数据集上的性能比较。红色文字表示显著提升。

表 2 中,查看 DM (Distribution Matching) 所在的行。在 CIFAR-100 (IPC=50) 上,准确率从 43.6% 跃升至 44.5% 。 虽然这看起来可能很小,但在数据集蒸馏领域,在所有设置下严格且持续地击败基线是鲁棒性的强烈信号。在 IDM 方法中,收益更为明显,在 CIFAR-10 (IPC=1) 上看到了 4.3% 的跃升。

2. 更快更好的收敛

动态贡献分配真的有助于训练过程吗?

折线图比较了 CIFAR-10 和 CIFAR-100 训练步骤中的测试准确率。

图 3 显示了随训练步骤变化的测试准确率。红线 (M3D + OPTICAL) 始终位于绿线 (基线 M3D) 之上。

  • 起步更快: OPTICAL 在训练早期就达到了可用的准确率水平。
  • 上限更高: 它避免了基线遇到的“瓶颈平台期”,证明合成数据捕捉到了更多有用的信息,而不仅仅是停留在平均值上。

3. 跨架构泛化能力

对数据集蒸馏的一个主要批评是,合成数据往往过拟合于用于生成它的特定神经网络架构。如果你使用 ConvNet 生成数据,在训练 ResNet 时表现可能会很差。

OPTICAL 极大地缓解了这个问题。通过捕捉数据的底层几何结构,而不仅仅是匹配梯度或平均统计数据,合成数据变得更加“通用”。

表格显示跨架构性能。在 ConvNet-3 上蒸馏的数据在 ResNet 和 DenseNet 上进行评估。

表 6 展示了这种跨架构的鲁棒性。当使用简单的 ConvNet-3 生成合成数据但在 DenseNet-121 上进行评估时:

  • 基线 DM 方法达到 39.0% 的准确率。
  • DM + OPTICAL 达到 41.9% 的准确率。
  • DANCE 方法在使用 OPTICAL 后从 64.5% 大幅跃升至 66.8%

这表明 OPTICAL 生成的合成数据包含真实的、可迁移的特征,而不是特定于架构的伪影。

意义何在

“同质距离最小化”问题是数据合成中一个微妙但普遍存在的问题。通过将每个数据点视为同等贡献者,我们无意中抹平了那些使深度学习模型具有鲁棒性的纹理和不规则性。

OPTICAL 证明了我们不需要发明全新的蒸馏范式来解决这个问题。通过借用 最优传输 理论——具体来说是计算将一个分布映射到另一个分布的最有效方法的概念——我们可以让现有方法变得更聪明。

关键要点:

  1. 不再取平均: 合成数据不应只是真实数据的平均值。它需要反映几何结构。
  2. 动态分配: 使用最优传输允许系统动态决定哪些真实图像应该影响哪些合成图像。
  3. 高效率: 归功于 Sinkhorn 算法,这种复杂的数学计算可以以极小的计算开销 (每次迭代仅增加毫秒级时间) 集成到训练循环中。
  4. 通用性: 它在现有方法之上工作,同样提升了低分辨率和高分辨率数据集的结果。

随着数据集持续增长,像 OPTICAL 这样的技术对于使 AI 更可持续、更易于访问和更高效将至关重要。