引言

在现代深度学习时代,我们正见证着一场“规模之战”。一方面,模型的规模呈指数级增长——拥有数十亿参数的大型语言模型 (LLM) 已成为常态。另一方面,部署这些模型所需的资源却是有限的。我们希望在手机、笔记本电脑和边缘设备上运行这些智能系统,但这些设备根本无法承受大规模密集网络的内存和计算负载。

这引发了稀疏化研究的热潮——这是一门在不破坏网络智能的前提下移除绝大部分参数的艺术。我们将这个过程称为剪枝 (Pruning)

然而,这里有一个陷阱。当你对网络进行激进剪枝时,性能几乎不可避免地会下降。主流观点认为这是一个容量问题;更小的大脑自然无法进行同样复杂的思考。但是,如果问题不仅仅在于连接的数量,而在我们找到的解的几何形状呢?

最近的研究表明,“平坦”极小值——即即使权重发生微小扰动,误差也能保持在较低水平的损失景观区域——其泛化能力远胜于“尖锐”极小值。

在这篇文章中,我们将深入探讨一篇名为 “SAFE: Finding Sparse and Flat Minima to Improve Pruning” 的论文。研究人员提出了一种新颖的优化框架,它不仅是在寻找一个稀疏网络,而是在损失景观的平坦、鲁棒区域中搜寻稀疏网络。

背景: 稀疏性 vs. 平坦性

要理解 SAFE (Sparsification via ADMM with Flatness Enforcement,通过带有平坦性强制的 ADMM 进行稀疏化) ,我们需要回顾两个通常被分开处理的基础概念: 稀疏性和锐度感知最小化。

稀疏性的挑战

从数学角度看,剪枝是一个优化问题。我们希望最小化损失函数 \(f(x)\) (如交叉熵) ,同时受限于一个约束条件,即权重向量 \(x\) 中的非零元素数量必须小于特定限制 \(d\)。

带有稀疏性约束的优化问题。

这里,\(\|x\|_0\) 是 \(L_0\)-范数,用于计算非零条目。直接求解这个问题非常困难,因为 \(L_0\)-范数是离散且不可微的。你无法为它计算梯度。标准方法通常依赖于启发式算法,比如训练一个密集模型然后切除最小的权重 (幅度剪枝) ,或者使用像 \(L_1\) 正则化 (LASSO) 这样的松弛方法。

平坦极小值的重要性

为什么有些神经网络比其他网络更能泛化到未见过的数据,即使它们的训练准确率相同?损失景观的几何形状提供了一条线索。

如果模型收敛到一个尖锐极小值 , 输入分布的微小偏移 (或噪声) 可能会将模型推向损失函数陡峭的边缘,导致高误差。相反, 平坦极小值是稳定的。小的扰动不会对损失产生太大影响。

为了显式地寻找这些平坦区域,研究人员开发了锐度感知最小化 (SAM) 。 SAM 不仅仅是最小化损失 \(f(x)\),而是最小化 \(x\) 周围一个小邻域 \(\epsilon\) 内的最大损失。

SAM 最小-最大目标函数。

这迫使优化器寻找一个不仅在单点上损失低,而且在附近所有点损失都低的区域。

SAFE 方法

SAFE 的核心创新在于将这两个截然不同的目标——稀疏性和平坦性——结合到一个单一、严谨的数学框架中。作者没有简单地拼凑一个启发式方法;他们制定了一个受约束的最小-最大优化问题。

1. 问题公式化

目标是找到既满足稀疏性约束 (\(L_0 \le d\)) 又能最小化邻域内最坏情况损失 (平坦性) 的参数 \(x\)。

联合目标: 稀疏约束的最小-最大优化。

这个方程代表了鲁棒剪枝的“圣杯”。我们想要一个既微小 (稀疏) 又非常鲁棒 (平坦) 的模型。

2. 使用增广拉格朗日法求解

由于离散的稀疏性约束,试图使用标准的梯度下降法来求解上述方程是不可能的。为了解决这个问题,作者利用了一种来自优化理论的技术,称为增广拉格朗日法 (Augmented Lagrangian) , 特别受到了 ADMM (交替方向乘子法) 框架的启发。

诀窍在于变量分裂 (Variable Splitting) 。 我们不再试图优化一个必须同时处理损失拓扑结构和硬稀疏性约束的变量 \(x\),而是引入第二个变量 \(z\)。

  • \(x\) 将专注于损失景观 (平坦性) 。
  • \(z\) 将满足稀疏性约束。
  • 我们严格强制 \(x = z\)。

我们可以这样重写这个问题:

带有指示函数的变量分裂公式。

在这里,\(I(z)\) 是一个指示函数,如果 \(z\) 足够稀疏则为 0,否则为无穷大。

指示函数定义。

现在,我们构建增广拉格朗日量 。 我们添加一个对偶变量 \(u\) (作用类似于拉格朗日乘子) 和一个二次惩罚项 \(\frac{\lambda}{2}\|x - z + u\|^2\)。这个惩罚项就像一根橡皮筋,将 \(x\) 和 \(z\) 拉在一起。如果它们分开,惩罚就会增加。

增广拉格朗日方程。

3. 迭代算法

这个公式的美妙之处在于,它允许我们要迭代地求解 \(x\)、\(z\) 和 \(u\)。我们在保持其他变量不变的情况下更新其中一个。这将一个巨大的、不可能的问题分解成了更小的、可解的子问题。

x, z 和 u 的迭代更新步骤。

让我们分解这三个具体步骤,因为它们是 SAFE 算法的引擎。

步骤 A: x-最小化 (学习平坦性)

在第一步中,我们希望找到最好的权重 \(x\),使其既能最小化损失,又能保持接近我们的稀疏目标 \(z\)。

x-最小化子问题。

注意这里的目标函数。它包含 SAM 目标 (最小化最大损失) 加上一个将 \(x\) 拉向 \(z\) 的二次项。这意味着 \(x\) 被允许是密集的,可以探索损失景观以寻找平坦区域,但它被“拴”在了稀疏解 \(z\) 上。

为了解决这个问题,作者使用梯度上升步骤近似内部的最大化 (寻找最坏情况的扰动 \(\epsilon\)) ,类似于标准的 SAM。

x 的梯度更新规则。

这个更新规则看起来像标准的随机梯度下降 (SGD) ,但有两个变化:

  1. 梯度是在扰动点 \(x + \epsilon\) 处计算的 (以确保平坦性) 。
  2. 有一个“衰减”项 \(\lambda(x - z + u)\) 引导权重朝向稀疏配置。

步骤 B: z-最小化 (强制稀疏性)

一旦我们有了更新后的 \(x\),我们需要更新 \(z\)。由于 \(z\) 的唯一工作是满足稀疏性约束并接近 \(x\),这一步实际上有一个闭式解 (Closed-form solution) !

在标准的 SAFE 版本中,这是一个欧几里得投影 。 我们只需查看向量 \((x + u)\) 并保留幅度最大的前 \(d\) 个元素,将其余设为零。这在数学上等同于“硬阈值”算子。

步骤 C: u-最大化 (对偶更新)

最后,我们更新对偶变量 \(u\)。这个变量累积了 \(x\) 和 \(z\) 之间的误差。如果 \(x\) 和 \(z\) 持续存在差异,\(u\) 就会增长,有效地增加下一轮的惩罚力度,迫使它们结合在一起。

4. 扩展: SAFE+ 与广义投影

标准的 SAFE 方法假设投影到稀疏约束的最佳方式是保留具有最大幅度的权重。然而,在许多情况下 (尤其是 LLM) ,幅度并不是衡量重要性的最佳代理指标。

作者提出了 SAFE+ , 引入了广义投影 (Generalized Projection) 。 他们不再使用标准的欧几里得距离,而是使用由正定矩阵 \(\mathbf{P}\) 定义的加权距离度量。

广义投影公式。

这使得 SAFE+ 能够整合先进的剪枝指标:

  • Optimal Brain Damage: 将 \(\mathbf{P}\) 设置为海森矩阵 (Hessian) 的对角线。
  • Wanda (用于 LLM) : 根据激活幅度设置 \(\mathbf{P}\)。

这种灵活性使得 SAFE+ 对于大型语言模型的训练后剪枝非常强大,因为计算梯度很昂贵,但使用激活统计数据却很便宜。

它真的有效吗?

理论听起来很可靠,但实验出真知。作者在图像分类 (CIFAR) 和大型语言模型 (LLaMA) 上进行了广泛的测试。

可视化景观

首先,让我们确认 SAFE 是否做到了它所宣称的: 找到稀疏平坦的极小值。

权重分布和损失景观比较。

图 1 (上图) 中,请看面板 (c) 和 (d)。

  • 面板 (c) 显示了 ADMM (没有平坦性强制) 找到的解。等高线很紧密;山谷很狭窄。最大海森特征值 (锐度的度量) 是 0.2
  • 面板 (d) 显示了 SAFE 找到的解。等高线间距很大,表明是一个宽阔、平坦的山谷。锐度显著降低,为 0.09

这一视觉确认证明了该优化策略成功地导航到了参数空间中更平坦的区域。

图像分类结果

作者在 ResNet 和 VGG 架构上将 SAFE 与现有的剪枝基线进行了比较,如 GMP (渐进幅度剪枝) 、LTH (彩票假设) 和 ADMM。

CIFAR-10/100 的验证准确率图表。

图 2 中的图表讲述了一个令人信服的故事,特别是在极度稀疏的水平上。

  • 看红线 (SAFE) 。随着稀疏度的增加,它们始终高于基线。
  • 99% 稀疏度 (意味着只剩下 1% 的权重!) 时,像 MLPrune (绿色) 和 PBW 这样的标准方法呈断崖式下跌。SAFE 保持了显著更高的准确率。

这表明,当你只有极少的参数可用时,这些参数位于损失景观的稳定、平坦区域变得至关重要。

为了详细了解数值明细,我们可以查看验证准确率表:

CIFAR-10 和 CIFAR-100 的结果表。

在 ResNet-20 (CIFAR-10) 上,在 99.5% 的稀疏度下,SAFE 达到了 79.55% 的准确率,而标准的 ADMM 表现极差 (73.72%,且方差很高) ,其他方法则大幅下降。

剪枝大型语言模型 (LLM)

也许对当今 AI 领域最相关的测试是将 SAFE 应用于 Transformer。作者采用了 SAFE+ (在投影步骤中使用 Wanda 指标) 来剪枝 LLaMA-2 和 LLaMA-3 模型。

LLaMA 模型的困惑度结果。

表 1 中,困惑度 (perplexity) 越低越好。

  • SAFE+ 在不同的模型大小 (7B, 13B, 8B) 和稀疏度水平 (50%, 60%, 2:4 结构化) 上始终优于 SparseGPT 和标准 Wanda 等基线。
  • 例如,在 LLaMA-2-7B 的 50% 稀疏度下,SAFE+ 的困惑度达到了 6.56 , 击败了 SparseGPT (6.99),甚至在某些配置下击败了密集基线 (尽管由于校准数据的原因,直接与密集模型比较需要具体分析) 。

这证明了 SAFE 不仅仅是针对小型 CNN 的理论探索;它可以扩展到现代的基础模型。

对噪声的鲁棒性

平坦极小值的理论优势之一是鲁棒性。作者通过在噪声标签 (错误标记的数据) 上进行训练来测试这一点。

噪声标签训练结果。

表 2 的结果令人震惊。当 50% 的训练标签被破坏 (噪声比 50%) 时,标准 ADMM 剪枝仅产生 59-67% 的准确率 (取决于稀疏度) 。 SAFE 达到了大约 86% 的准确率。

因为 SAFE 寻找平坦区域,它自然会忽略由错误标记数据产生的“尖锐”且不稳定的极小值,转而关注代表真实信号的更广泛模式。

消融实验: 什么最重要?

作者进行了仔细的消融研究以理解超参数的影响。一个关键因素是惩罚参数 \(\lambda\) (lambda)。

惩罚参数 lambda 的影响。

图 3 所示,\(\lambda\) 控制着权衡。

  • 小 \(\lambda\): 约束较松。模型保持准确 (密集准确率高) ,但无法真正变得稀疏 (到约束的距离大) 。
  • 大 \(\lambda\): 模型被强制严格稀疏,但激进的惩罚会损害原始密集模型的性能。
  • 最佳平衡点位于中间,使用调度器 (逐渐增加 \(\lambda\)) 可以帮助模型先学习,然后再进行稀疏化。

结论

论文 “SAFE: Finding Sparse and Flat Minima to Improve Pruning” 为模型压缩问题提供了一个有原则的答案。它成功地论证了我们不应该孤立地优化稀疏性。

通过损失景观几何的视角审视剪枝,研究人员开发了一种方法:

  1. 数学严谨性: 使用增广拉格朗日法同时在数学上强制平坦性和稀疏性。
  2. 灵活性: 可以通过广义投影 (SAFE+) 进行扩展,以整合用于 LLM 的现代剪枝指标。
  3. 性能: 在高稀疏度下提供卓越的准确率,并对噪声数据具有显著的鲁棒性。

对于学生和从业者来说,SAFE 强调了一个至关重要的教训: 在深度学习中,你如何得到解 (穿过景观的路径) 往往与解本身一样重要。稀疏网络固然好,但稀疏且平坦的网络才更“安全” (SAFE) 。