如果你曾经尝试过微调像 LLaMA 或 RoBERTa 这样的大语言模型 (LLM) ,你很可能撞上过“显存墙”。当你下载好模型,设置好 PyTorch 训练循环,点击运行,结果立刻弹出了令人绝望的 CUDA 显存溢出 (Out of Memory, OOM) 错误。

罪魁祸首通常是全参数微调 (Full-Parameter Fine-Tuning, FPFT) 。 虽然 FPFT 是使模型适应新任务的金标准——它允许模型调整每一个权重来学习新模式——但它的代价极其昂贵。它不仅需要存储模型权重,还需要同时存储梯度,以及至关重要的、针对每个参数的优化器状态 (如 AdamW 中的动量) 。

多年来,NLP 社区一直依赖于妥协方案。我们使用 参数高效微调 (Parameter-Efficient Fine-Tuning, PEFT) 方法,如 LoRA 或 Adapter,这些方法冻结主模型,只训练微小的附加层。虽然效率很高,但这只是近似方法。它们有时无法捕捉复杂的行为,特别是在推理或数学任务中。

但是,如果你不需要妥协呢?如果你能在单张 24GB 的消费级 GPU 上微调一个 70 亿参数 (7B) 的模型的所有参数,并达到与大规模集群相同的性能呢?

这就是 HiFT (Hierarchical Full Parameter Fine-Tuning,分层全参数微调) 的承诺,这是由中国东北大学和慕尼黑大学 (LMU Munich) 的研究人员提出的一种新策略。在这篇文章中,我们将解构 HiFT 的工作原理,它为何改变了高效训练的格局,以及它如何成功地将“大象装进冰箱”。

问题所在: “全”训练的高昂代价

要理解为什么 HiFT 是必要的,我们首先需要看看训练期间显存消耗的数学原理。当你使用像 AdamW 这样的标准优化器训练模型时,你的 GPU 显存主要被以下四个部分消耗:

  1. 模型参数: 模型本身的权重 (例如,16 位精度的 7B 模型约为 14GB) 。
  2. 梯度: 每个参数计算出的变化方向。
  3. 优化器状态: 用于稳定训练的梯度历史 (动量、方差) 。
  4. 激活值/残差状态: 前向传播过程中产生的中间数据,用于反向传播。

在标准的 FPFT 中,第 2 项和第 3 项是显存杀手。对于 AdamW,你需要为每一个参数存储两个状态变量。如果你有一个 7B 模型,仅这些状态就会消耗数十 GB,远超标准 RTX 3090 或 4090 的容量。

以往的尝试及其缺陷

研究人员以前曾试图解决这个问题:

  • PEFT (例如 LoRA, Prefix-Tuning) : 这些方法冻结模型并添加低秩矩阵。它们节省了显存,但可能导致信息丢失或性能下降,因为基础模型从未进化。
  • 零阶优化 (例如 MeZO) : 这些方法在不完全计算梯度的情况下估计梯度。虽然显存效率高,但它们出了名的不稳定,且性能通常明显不如基于梯度的方法。
  • LOMO: 这种方法融合了梯度计算和更新以节省显存,但需要两次前向传播,并且通常强制使用特定的量化,限制了其灵活性。

HiFT 的作者认为,我们不必放弃基于动量的优化器 (如 AdamW) 已被证实的稳定性,也不必放弃更新所有参数带来的性能优势。

解决方案: 分层微调 (HiFT)

HiFT 的核心洞见简单而精彩: 你不需要在毫秒不差的同一时刻更新所有参数。

HiFT 采用了一种“逐块”训练策略。它不是将整个神经网络的优化器状态加载到显存中,而是将模型划分为若干层的组 (块) 。在任何给定的训练步骤中,HiFT 选择一组层作为激活状态。它只更新该特定组的参数,而保持网络的其余部分冻结。在训练过程中,它循环遍历所有组,确保网络中的每个参数最终都得到更新。

工作原理: 架构

让我们利用论文中的示意图来可视化这个过程。

Figure 1: Schematic diagram of our HiFT. group represents the grouping operation of the layers. bottom2up, top2down and random are training strategies. Gray indicates that the corresponding parameters are in the frozen state,and brown indicates that the corresponding parameters are in the activated state. k is the number of groups, n is the number of layers of the given model,and BP denotes parameter update through back propagation.

如图 1 所示,模型被切分为 \(k\) 个组。在一个训练步骤中,算法选择一个特定的组 (棕色高亮显示) 为“激活”状态。

  1. 前向传播: 数据流经整个模型。
  2. 反向传播 (BP) : 计算梯度。然而,优化器只“看到”激活组的参数。
  3. 更新: 更新激活组的权重。该组的优化器状态被临时加载到 GPU,使用后即卸载 (如果不需要立即再次使用则丢弃) 。
  4. 切换: 在下一步中,选择不同的组。

这种方法显著降低了“峰值”显存需求。你在任何时候只需要存储 \(\frac{1}{k}\) 模型的梯度和优化器状态。

算法

研究人员将此形式化为一个特定的训练循环。

Algorithm 1: HiFT Training Algorithm

该算法管理着一个层队列。它支持不同的更新策略:

  • 自底向上 (bottom2up) : 从嵌入层开始向上更新到头部。
  • 自顶向下 (top2down) : 从头部向下更新到嵌入层。
  • 随机 (Random) : 打乱更新顺序。

至关重要的是,该算法采用了延迟学习率更新 (Delayed Learning Rate Update) 。 由于不同层是在不同时间更新的,如果在每一步后都改变全局学习率,可能会导致不稳定 (在同一个 epoch 内,有些层可能以高学习率更新,有些则以低学习率更新) 。HiFT 仅在所有层都更新过一次后才更新学习率调度。这确保了整个模型深度的更新幅度一致。

效率的数学原理

为什么这能节省这么多显存?让我们看看作者提供的公式。

首先,考虑标准全参数微调 (FPFT) 的显存成本 (\(\zeta\)) 。我们将模型权重的显存记为 \(\zeta_1\)。

  • AdamW 的优化器状态 (\(\zeta_2\)) 通常是模型权重的 \(2 \times\)。
  • 梯度 (\(\zeta_3\)) 是模型权重的 \(1 \times\)。

Equation 1: Memory cost of standard FPFT

因此,标准训练所需的显存大约是模型权重本身的 4 倍 (加上激活开销) 。这就是为什么一个 14GB 的模型需要 60GB+ 显存才能训练的原因。

现在,看看 HiFT。因为我们将模型分成了 \(k\) 个组,我们只存储激活组的优化器状态和梯度。

Equation 2: Memory cost of HiFT

随着 \(k\) (组数) 的增加,相对于总模型大小,状态和梯度所需的显存趋近于零。显存节省是巨大的:

Equation 3: Memory savings delta

对于一个大型模型,如果你将层分为 \(k=32\) 组,你本质上几乎从 GPU 显存中移除了优化器状态的全部负担。

理论严谨性

你可能会担心异步更新层会破坏模型的稳定性或导致无法收敛。作者通过提供泛化界解决了这个问题,证明了 HiFT 的测试损失与最佳参数之间的差距是有界的。

Equation 4: Generalization bound of HiFT

虽然数学推导很复杂,但结论令人安心: HiFT 在理论上是合理的。在标准假设下,它保证收敛,这意味着你并没有为了节省显存而牺牲数学上的有效性。

实验结果

理论听起来很棒,但实际效果如何?研究人员在广泛的模型 (RoBERTa, GPT-2, LLaMA, OPT) 和任务 (NLU, 指令微调, 数学) 上测试了 HiFT。

1. 它学得和 FPFT 一样好吗?

任何近似方法的主要担忧都是性能损失。作者在 GLUE 和 SuperGLUE 基准测试中对比了 HiFT 与标准 FPFT 及各种 PEFT 方法。

Figure 5: RoBERTa results on different fine-tuning strategies.

在图 5 中,我们看到了 RoBERTa 在各种数据集上的准确率。HiFT 变体 (橙色、黄色、粉色柱状图) 始终与标准 FPFT (蓝色柱状图) 的性能相匹配。在许多情况下,HiFT 优于 BitFit 或 Adapter 等 PEFT 方法。

此外,观察损失曲线,我们可以看到 HiFT (尽管采用了“切分”式的训练风格) 收敛得很平滑。

Figure 3: Loss curves of OPT-13B on different datasets. The parameter m of HiFT is set to 1.

2. 指令微调与推理

现代 LLM 的真正考验是指令遵循和复杂推理。研究人员使用 HiFT 微调了 LLaMA-7B 和 Mistral-7B 等模型,并在 MT-Bench 上进行了测试,这是一个评估编码、推理和角色扮演能力的具有挑战性的基准。

Figure 2: Category-wise scores of diferent fine-tuning methods on MT-bench.

上面的雷达图 (图 2) 很有说服力。

  • HiFT (橙线) : 始终延伸至最外沿,匹配或击败 FPFT (蓝线) 。
  • LoRA (绿线) : 经常落后,特别是在推理和编码等复杂类别中。

这支持了作者的观点,即全参数微调——即使是分层进行的——比低秩近似能更好地捕捉复杂的模式。

LLaMA-7B 和 13B 模型的具体结果进一步证实了这一点,特别是在需要多步推理的 GSM8K 数据集 (小学数学) 上。

Table 4: Performance comparison of different finetuning methods for LLaMA-7B and 13B.

在表 4 中,请看 GSM8K 这一列。HiFT 在 LLaMA-7B 上达到了 29.85 , 几乎与 FPFT 的 30.00 相同,而 LoRA 则显著下降至 22.87 。 这表明,对于需要深度改变模型行为的任务,更新所有参数至关重要。

3. 显存效率: 饼图

我们已经在数学上确定了 HiFT 能节省显存,但在实际 GPU 上它的细分情况如何?

Figure 6: (a), (b), (c) and (d)represent the proportion of parameters occupied by diferent parts when fine-tuning LLaMA-2 (7B).

图 6 比较了标准 FPFT (a) 和 HiFT (b)。

  • FPFT (a): “优化器状态” (黄色部分) 占据了 35.3% 的显存,梯度 (橙色) 占据了另外 17.7%
  • HiFT (b): 优化器状态缩减至仅 2.7%

这种大幅缩减使得系统能够支持更大的批次大小 (batch size) ,或者仅仅是在原本会崩溃的硬件上运行。

4. 更新顺序重要吗?

HiFT 的一个显著特点是可以选择哪些块以什么顺序更新。先训练底层还是先训练顶层有关系吗?

Figure 4: The left shows the performance of HiFT of RoBERTa_base under B2U, T2D and RAN strategies.

令人惊讶的是,图 4 (左) 显示策略 (自底向上、自顶向下或随机) 对最终准确率几乎没有影响。这种鲁棒性对并行化来说是个好消息: 它意味着未来的工作可能在不同的设备上同时训练不同的块,而无需严格强制特定的顺序。

结论与启示

HiFT 代表了向大语言模型研究普及化迈出的重要一步。通过将“全参数”微调的需求与“全显存”分配的需求解耦,它打破了阻碍许多学生和独立研究人员使用最先进模型的硬件壁垒。

主要收获:

  1. 24GB 跑 7B: HiFT 允许在单张消费级 GPU (如 RTX 3090 或 4090) 上对 LLaMA-7B 级别的模型进行全参数微调。
  2. 无性能妥协: 与 MeZO 或某些 PEFT 方法不同,HiFT 取得了与标准微调相当的结果,特别是在推理任务上。
  3. 与优化器无关: 它适用于 AdamW、SGD 或任何其他优化器,让你保留你信任的训练动态。

随着模型规模的不断增长,像 HiFT 这样优化我们如何更新权重——而不仅仅是优化哪些权重被更新——的策略将变得至关重要。对于学生和研究人员来说,这意味着对 LLM 进行深入、有意义的实验不再是那些拥有大规模 H100 集群的人的专利。你现在可以翻新整栋房子——只不过是一次翻新一个房间。