TinyFusion: 如何在不失魔力的情况下压缩扩散 Transformer

如果你最近关注生成式 AI 领域,你会知道 扩散 Transformer (Diffusion Transformers, DiTs) 是目前的重量级选手。从 OpenAI 的 Sora 到 Stable Diffusion 3,用 Transformer 架构取代传统的 U-Net 骨干网络,解锁了图像和视频生成方面惊人的能力。

但这里有个问题: 这些模型体量巨大。它们伴随着过多的参数量,使得在现实世界的应用中运行缓慢且昂贵。如果你想在移动设备或标准消费级 GPU 上部署高质量的图像生成器,通常会束手无策。

标准的解决方案是 剪枝 (Pruning)——移除神经网络的部分内容以使其变小。但如何决定切掉哪些部分呢?传统观点建议根据误差指标移除“最不重要”的层。然而,一篇名为 “TinyFusion: Diffusion Transformers Learned Shallow” 的新论文指出,这种传统观点对于扩散模型是错误的。

在这篇文章中,我们将深入探讨 TinyFusion。我们将探索为什么传统的剪枝方法在 DiT 上会失败,作者如何引入一种能够预测未来性能的“可学习”剪枝方法,以及他们如何实现了 2 倍的加速而几乎没有损失图像质量。


问题所在: 为什么“聪明”的剪枝通常会失败

要让模型更快,我们通常有两种选择:

  1. 宽度剪枝 (Width Pruning): 让层变窄 (减少通道/神经元) 。
  2. 深度剪枝 (Depth Pruning): 从网络中移除整个层。

对于 Transformer 而言, 深度剪枝通常是速度方面的更好选择。因为 GPU 是大规模并行处理器,它们可以轻松处理宽层。然而,Transformer 中的层必须按顺序处理。如果你有 28 层,GPU 必须先完成第 1 层才能开始第 2 层。因此,从理论上讲,将层数减半可以使推理速度加倍。

作者在下图中清楚地展示了这一优势。虽然宽度剪枝 (蓝线) 难以获得速度提升,但深度剪枝 (红色虚线) 提供了几乎线性的加速比。

深度剪枝与宽度剪枝的加速比对比。

校准损失悖论

所以,深度剪枝是正确的方向。挑战在于决定删除哪些层。

大多数现有的方法使用一种称为 校准损失 (Calibration Loss) 的指标。逻辑很简单: 移除一层,检查误差 (损失) 增加了多少,并保留那些如果被移除会导致最大误差的层。你希望剪枝后的模型在当下看起来尽可能接近原始模型。

TinyFusion 背后的研究人员发现了一个悖论: 一个具有低初始误差的剪枝模型并不一定能在微调过程中学得很好。

他们进行了一项实验,随机以 100,000 种不同的方式对 DiT 模型进行剪枝,然后对它们进行微调。他们发现,初始误差最低 (Min. Loss) 的模型在微调后的表现实际上比初始误差较高的模型更差

显示校准损失分布的直方图。

如上所示,通过标准灵敏度分析 (即最小化损失) 找到的模型,最终性能平平。而 TinyFusion 运作的“可学习 (Learnable)”区域——虽然初始损失较高,但恢复效果却好得多。

结论: 我们不应该寻找剪枝后懂得最多的模型。我们应该寻找在微调期间学得最快的模型。这种属性称为 恢复能力 (Recoverability)


解决方案: TinyFusion

TinyFusion 是一个新的框架,它不将剪枝视为一次性的计算,而是一个 可学习的过程 。 该方法不使用启发式评分,而是训练剪枝选择本身。

核心思想是同时优化两件事:

  1. 掩码 (The Mask): 保留哪些层,丢弃哪些层。
  2. 权重 (The Weights): 模拟如果丢弃这些层,模型将如何适应。

1. 可微采样 (Differentiable Sampling)

研究人员将层选择视为一个概率分布。对于每一块层,都有一个概率被分配给不同的剪枝配置 (掩码) 。

问题在于“选择一层”是一个离散的决定 (要么保留,要么不保留) ,这阻断了反向传播所需的梯度流。为了解决这个问题,他们使用了 Gumbel-Softmax 采样 。 这允许网络在训练期间“软性”地采样掩码,使选择过程变得可微。

\[ y = { \mathrm { o n e - h o t } } \left( { \frac { \exp ( ( g _ { i } + \log p _ { i } ) / \tau ) } { \sum _ { j } \exp ( ( g _ { j } + \log p _ { j } ) / \tau ) } } \right) . \]

这个方程本质上是向概率 (\(p\)) 添加噪声 (\(g\)),并使用温度参数 (\(\tau\)) 逐渐从软概率过渡到硬决策 (0 或 1) 。

2. 使用 LoRA 进行恢复能力估计

这是 TinyFusion 的天才之处。要知道一个剪枝后的模型是否“可恢复”,通常需要微调数小时。你不可能在剪枝算法的每一步训练中都这样做。

为了解决这个问题,作者引入了一种轻量级的、协同优化的权重更新。他们不更新整个庞大的模型来测试恢复能力,而是使用 LoRA (低秩适应)

带有可微剪枝掩码和 LoRA 的前向传播。

如上图 3 所示,系统将掩码 (\(m_i\)) 应用于层。同时,它将 LoRA 更新 (橙色块 \(B\) 和 \(A\)) 应用于权重。

优化目标变为:

\[ \underset { \{ p ( \mathfrak { m } _ { k } ) \} } { \operatorname* { m i n } } \underset { \Delta \Phi } { \underbrace { \operatorname* { m i n } } } \mathbb { E } _ { \boldsymbol { x } , \{ \mathfrak { m } _ { k } \sim p ( \mathfrak { m } _ { k } ) \} } \big [ \mathcal { L } ( \boldsymbol { x } , \Phi + \Delta \Phi , \{ \mathfrak { m } _ { k } \} \big ] , \]

通俗地说: “找到掩码的概率分布 (\(p\)),使得如果我们稍微更新模型 (\(\Delta \Phi\)),损失最小化。”

这本质上是在剪枝循环中模拟了“未来的微调”。模型学会了偏好那些对权重更新反应良好的掩码。

3. 工作流程

整个过程分为两个阶段:

  1. 搜索 (训练) : 模型使用可学习的掩码和 LoRA 适配器运行。随着时间的推移,概率分布发生变化。好的配置获得更高的概率;坏的被丢弃。
  2. 微调: 一旦确定了最佳层,掩码就被固定,LoRA 被丢弃,随后对生成的较小模型进行适当的微调。

可学习剪枝工作流程图。

在上面的可视化 (图 2) 中,你可以看到从“混合采样 (Mixed Sampling)” (探索不同的层组合) 到“置信采样 (Confident Sampling)” (确定最佳的浅层架构) 的转变。


加速恢复: 掩码知识蒸馏

一旦 TinyFusion 确定了要保留的最佳层,模型就被剪枝了。现在,它需要重新训练以恢复其原始质量。这是标准程序,通常通过 知识蒸馏 (Knowledge Distillation, KD) 完成——即小的“学生”模型试图模仿大的“教师”模型。

然而,作者遇到了一个扩散 Transformer 特有的问题: 巨量激活 (Massive Activations)

在大型 Transformer 中,某些神经元有时会以巨大的数值被激活 (离群值) 。虽然教师模型可以很好地处理这些情况,但强迫一个较小的、剪枝后的学生模型去模仿这些精确的巨大数值,可能会破坏训练稳定性并导致损失爆炸。

DiTs 中巨量激活的可视化。

在上图 8 中,你可以看到激活值的尖峰。为了解决这个问题,作者提出了 掩码表征 KD (Masked Representation KD)

他们简单地应用了一个阈值。如果教师或学生中的激活值过大 (离群值) ,则在损失计算中将其掩盖 (忽略) 。

掩码知识蒸馏示意图。

这确保了学生专注于学习数据的核心结构,而不是追逐数值异常,从而实现显著更快且更稳定的收敛。


实验与结果

效果如何?结果令人印象深刻。

研究人员在 DiT-XL/2 (一个在 ImageNet 上训练的标准扩散 Transformer) 上测试了 TinyFusion。他们的目标是将 28 层的模型压缩到 14 层 (50% 剪枝) 。

定量结果

显示 TinyFusion 与基线方法性能对比的表格。

查看 表 1 :

  • 原始 DiT-XL/2: FID 为 2.27 (越低越好) 。
  • 现有方法 (ShortGPT, Flux-Lite): 当剪枝到 14 层时,它们的 FID 分数飙升至 20 以上。它们基本上把模型搞坏了。
  • TinyFusion (TinyDiT-D14): 取得了 2.86 的 FID。

这是一个巨大的提升。TinyDiT 模型的运行速度为 13.54 次迭代/秒 (几乎是原始速度 6.91 it/s 的 2 倍) ,同时保持了与原始模型肉眼难以区分的图像质量。

此外,TinyFusion 仅用了 原始预训练成本的 7% 就实现了这一目标。

学习过程的可视化

观察模型如何“决定”保留哪些层非常有趣。下图跟踪了训练迭代中的剪枝决策。

随时间变化的剪枝决策可视化。

  • 底层 (索引 0-3) : 模型很快决定这些是必不可少的 (实线) 。
  • 中间层: 有一段探索期 (模糊部分) ,模型在犹豫。
  • 收敛: 到第 10,000 步时,模型做出了硬性决定,有效地锁定了最终架构。

定性结果

数字固然重要,但对于图像生成,我们需要看图。以下是由剪枝后的 TinyDiT-D14 生成的样本。

TinyDiT-D14 生成的图像。

这些图像清晰、连贯,与更大模型生成的图像无法区分。

该方法也具有良好的泛化能力。作者将其应用于其他架构,如 SiT (Scalable Interpolant Transformers)MAR (Masked Autoregressive models) , 取得了类似的成功。

TinySiT 和 TinyMAR 生成的图像。


结论与关键要点

TinyFusion 代表了我们在压缩生成模型思维方式上的重大转变。

  1. 不要相信即时损失: 对于扩散模型,一个看起来“坏掉”的剪枝模型实际上可能是微调的最佳候选者。
  2. 让剪枝变得可学习: 通过将层选择视为一个可微采样问题,我们可以利用梯度下降来找到最佳架构。
  3. 模拟未来: 在搜索阶段使用 LoRA 允许剪枝算法“窥视”未来,看看某种配置的恢复能力如何。

随着我们迈向在笔记本电脑和手机上运行复杂的 AI 模型,像 TinyFusion 这样的技术将变得至关重要。它们让我们能够去除这些庞大神经网络的冗余 (脂肪) ,只留下创造魔力所需的精髓 (肌肉) 。