以小博大: MGD³ 如何在无需微调的情况下进行数据集蒸馏

在深度学习的现代纪元,人们主要信奉“越大越好”。我们构建巨大的模型,并为其提供更庞大的数据集。然而,当涉及到计算资源和存储时,这种轨迹便触碰到了瓶颈。并非每位研究人员或学生都能使用 H100 GPU 集群。这种瓶颈催生了一个迷人的研究领域: 数据集蒸馏 (Dataset Distillation)

想象一下,将 ImageNet 这样包含超过一百万张图像的数据集,压缩到每类仅 10 或 50 张图像。目标是什么?是为了在这个微小的合成数据集上训练神经网络,并达到与在完整的一百万张图像上训练相媲美的准确率。

今天,我们将深入探讨一篇新论文: “MGD³: Mode-Guided Dataset Distillation using Diffusion Models” (MGD³: 使用扩散模型的模式引导数据集蒸馏) 。 这项研究提出了一种聪明的方法,利用预训练的扩散模型来生成这些蒸馏数据集,而且无需像以前的方法那样进行昂贵的微调。如果你对生成式 AI、数据效率感兴趣,或者只是想知道我们如何让深度学习变得更触手可及,这篇文章就是为你准备的。


问题所在: 为什么蒸馏很难

数据集蒸馏不仅仅是选择“最好”的图像;它是关于合成具有最高信息密度的新图像。

传统上,主要有两种方法:

  1. 基于优化的蒸馏 (Optimization-based Distillation) : 这些方法将合成图像的像素视为可学习的参数。它们试图最小化在真实数据上训练的模型与在合成数据上训练的模型之间的梯度或特征差距。虽然对微小的数据集有效,但这在计算上非常耗时,且难以扩展到高分辨率图像。
  2. 生成式数据集蒸馏 (Generative Dataset Distillation) : 这种较新的方法使用生成模型 (如 GAN 或扩散模型) 来合成数据。你存储的是生成模型参数中的“知识”,而不是直接存储像素。

图 1. 基于优化的数据集蒸馏 vs 生成式数据集蒸馏。

图 1 所示,优化方法 (顶部) 不断循环回溯,根据匹配损失更新合成数据集。生成方法 (底部) 学习一次分布,然后合成数据集。

扩散模型中的“多样性”陷阱

扩散模型是当前图像生成的王者。然而,当用于数据集蒸馏时,它们有一个缺陷: 模式崩塌 (Mode Collapse) 。

扩散模型的训练目标是最大化似然 (maximize likelihood) ,这意味着它们倾向于生成数据分布密集区域 (即“平均”外观) 的图像。如果你让一个标准的扩散模型生成一只“狗”,它可能会给你最常见的品种和最常见的姿势。它可能会忽略罕见的品种或不寻常的角度。

对于数据集蒸馏来说,这是一场灾难。为了训练一个鲁棒的分类器,你需要多样性 。 你需要模型看到狗的正面、侧面、近景和远景。如果你蒸馏的数据集只包含“平均”视角,学生模型将无法泛化。

以前的解决方案 (如 MinMax 扩散) 试图通过显式微调扩散模型来强制增加多样性。但微调既昂贵又缓慢,且违背了使用现成预训练模型的初衷。


MGD³ 登场: 模式引导扩散

MGD³ 的作者提出了一种无需任何微调的解决方案。他们不是重新训练模型,而是通过操纵采样过程来确保多样性。

其核心思想依赖于三个阶段:

  1. 模式发现 (Mode Discovery) : 找出数据的不同“聚类” (模式) 在哪里。
  2. 模式引导 (Mode Guidance) : 强制扩散模型生成属于这些特定聚类的图像。
  3. 停止引导 (Stop Guidance) : 知道何时停止强制,以保持图像的高质量。

让我们可视化一下标准扩散和这种新方法在轨迹上的差异。

图 2. 扩散轨迹对比。

图 2 中:

  • (a) 标准扩散 (DiT): 所有生成的样本 (红色 X) 都聚集在密集的橙色区域。多样性低。
  • (b) 微调后 (MinMax): 多样性较好,但需要昂贵的训练。
  • (c) MGD³ (Ours): 该方法识别出不同的目标 (绿色星号) ,并在去噪过程 (绿色线条) 中引导生成方向,然后让模型自然完成 (黑色线条) 。

让我们详细分解 MGD³ 流程的三个阶段。

图 3. MGD3 流程概览,展示了模式发现、引导和停止引导。

第一阶段: 模式发现 (Mode Discovery)

在生成多样化数据之前,我们需要知道对于特定类别来说“多样化”长什么样。研究人员使用预训练的变分自编码器 (VAE) 将原始数据集编码到潜在空间 (latent space) 。

为什么是潜在空间?因为像素空间噪声太大。潜在空间捕捉的是语义内容 (形状、物体类型) ,而不是高频细节。

一旦数据被映射到潜在空间,他们就应用 K-Means 聚类为每个类别找到 \(N\) 个质心 (模式) 。如果预算是每类 10 张图像 (IPC = 10) ,他们就找到 10 个不同的模式。每个模式 (\(m_i\)) 代表该类别的一种不同“原型”——例如,一种模式可能是“坐着的金毛寻回犬”,另一种可能是“奔跑的哈巴狗”。

第二阶段: 模式引导 (Mode Guidance)

现在我们有了目标模式,我们需要生成落在它们附近的图像。标准扩散包括从随机噪声开始并逐渐去噪。

MGD³ 干预了这个过程。在每个时间步 \(t\),模型预测去噪后的图像 \(\hat{x}_0\)。作者根据图像当前的去向与目标模式 \(m_i\) 之间的差异计算引导信号

引导信号向量计算如下:

引导信号向量公式。

该向量将生成过程指向目标模式。然后将此信号注入扩散模型的噪声预测步骤中。修改后的噪声预测如下所示:

模式引导噪声预测公式。

在这里,\(\lambda\) 是一个控制我们推动模型向模式靠拢力度的标量。通过应用这种引导,随机噪声被专门引导成为类似于第一阶段中发现的聚类质心的图像。

第三阶段: 停止引导 (Stop Guidance)

这可能是论文中最直观的贡献。

扩散生成过程通常分为三个阶段:

  1. 混沌阶段 (Chaotic): 决定全局结构的早期步骤。
  2. 语义阶段 (Semantic): 物体和形状形成的中期步骤。
  3. 精炼阶段 (Refinement): 纹理和细节被打磨的最后步骤。

作者发现,如果你一直应用模式引导直到结束,图像看起来会很奇怪。引导会迫使像素值严格遵守质心,这会破坏扩散模型擅长生成的自然纹理和高频细节。

解决方案? 停止引导。 他们只在生成的初始部分 (混沌和语义阶段) 应用引导信号。一旦总体结构被锁定 (例如,在 50 步中的第 25 步) ,他们就会关闭引导 (\(\lambda = 0\)) ,让标准扩散过程完成剩下的工作。

这确保了图像具有多样化模式的结构,同时也拥有自然图像的质量


可视化过程

为了真正理解停止引导的影响,让我们看看随时间变化的生成过程。

图 11. 不同停止引导时间下的去噪轨迹可视化。

图 11 中,X 轴代表去噪时间步长 (从 \(t=50\) 的噪声到 \(t=0\) 的图像) 。Y 轴代表引导停止的时间 (\(t_{stop}\)) 。

  • 顶行 (\(t_{stop}=50\)): 这实际上意味着无引导 。 模型生成了一只普通的狗。
  • 底行 (\(t_{stop}=0\)): 全程引导直到结束。图像通常带有伪影或看起来“过度约束”。
  • 中间行: 通过中途停止,模型引导噪声进入特定的姿势/品种 (与顶行不同) ,但随后将其精炼成清晰的图像。

作者发现 \(t_{stop}\) 在 20-30 左右 (在 50 步的过程中) 提供了完美的平衡。


实验与结果

那么,这种方法真的有效吗?作者在多个基准上测试了 MGD³,包括 ImageNette、ImageNet-100 和庞大的 ImageNet-1K。

性能 vs. SOTA (当前最佳)

主要指标是验证准确率: 如果我们在 MGD³ 生成的合成图像上训练一个新的 ResNet-18,它在真实测试数据上的表现如何?

图 4. 准确率对比柱状图。

图 4 所示,MGD³ (Ours) 始终优于之前的 SOTA 方法。

  • 图表 (c): 在使用文生图模型的 ImageNet-1K 上,MGD³ 显著击败了标准的 Stable Diffusion。
  • 图表 (d): 与 DiT、SRe²L 和 MinMax 等方法相比,MGD³ 保持领先,特别是随着每类图像数量 (IPC) 的增加。

为了更细致地观察数字,我们可以查看表 1

表 1. ImageNet 子集上的性能对比。

在每类 10 张图像 (IPC 10) 的 ImageNette 上,MGD³ 达到了 66.4% 的准确率,大幅跃升超过了标准的预训练 DiT (59.1%) 甚至微调后的 MinMaxDiff (62.0%)。这证实了引导机制比简单的随机采样提取了更有用的训练信号。

多样性分析

假设 MGD³ 有效是因为它创建了更多样化的数据集。为了证明这一点,作者使用 t-SNE (一种将高维数据可视化的技术) 可视化了蒸馏后的数据集。

图 5. 展示多样性覆盖的 t-SNE 图。

图 5 :

  • 橙色三角形 (DiT): 标准模型的样本聚集在一起。它不断生成同一种“类型”的盒式磁带播放器或狗。
  • 蓝色圆形 (MGD³ - Ours): 样本广泛分布在数据分布中,覆盖了不同的聚类。

这一视觉证据证实了模式发现 + 引导成功地迫使模型探索了潜在空间。

多样性 vs. 代表性

生成建模中通常存在权衡。你可以拥有高多样性 (随机噪声非常多样化!) ,但代表性低 (它看起来不像该类别) 。或者你可以拥有高代表性 (看起来完美的狗) ,但多样性低 (所有看起来都一模一样) 。

理想情况下,你希望两者兼得

图 8. 按类别划分的代表性与多样性散点图。

图 8 绘制了不同类别的这种权衡。目标是位于右上角 (高多样性,高代表性) 。

  • DiT (橙色): 通常代表性高,但多样性较低。
  • MinMax (绿色): 多样性较高,但经常失去代表性 (图像可能看起来很奇怪) 。
  • MGD³ (蓝色): 始终占据右上区域,比替代方案更好地平衡了这两个指标。

为什么这很重要

MGD³ 的意义不仅仅在于在排行榜上获得略高的准确率分数。

  1. 计算效率: 以前的 SOTA 方法 (如 MinMax 扩散) 生成 ImageNet-100 的蒸馏数据集需要 10 小时 , 因为有微调要求。MGD³ 仅需 0.42 小时 。 这是一个巨大的提速。
  2. 易用性: 因为它使用的是无需修改的预训练模型,任何人都可以运行它。你不需要知道如何从头开始训练扩散模型;你只需要知道如何从中采样。
  3. 可扩展性: 该方法可以优雅地扩展到更大的数据集 (如 ImageNet-1K) 和更大的架构 (ResNet-101) ,证明生成式数据集蒸馏是通往高效 AI 未来的可行路径。

结论

MGD³ 提出了一个令人信服的论点: 我们并不总是需要重新训练生成模型来让它们做我们想做的事。有时候,我们只需要更好地引导它们。

通过将问题分解为识别模式 (发现) 、推动生成朝向它们 (引导) 以及知道何时让模型接管 (停止引导) ,研究人员在数据集蒸馏方面取得了最先进的结果。他们成功地将海量数据集的知识压缩成微小的合成集,并捕捉到了原始数据的完整多样性谱系。

对于学生和研究人员来说,这篇论文是一个很好的例子,展示了如何利用预训练模型的潜在空间来控制生成输出——这一技术应用前景将远远超出数据集蒸馏本身。