如果你曾经尝试过微调像 LLaMA 或 RoBERTa 这样的大语言模型 (LLM) ,你很可能撞上过“显存墙”。当你下载好模型,设置好 PyTorch 训练循环,点击运行,结果立刻弹出了令人绝望的 CUDA 显存溢出 (Out of Memory, OOM) 错误。
罪魁祸首通常是全参数微调 (Full-Parameter Fine-Tuning, FPFT) 。 虽然 FPFT 是使模型适应新任务的金标准——它允许模型调整每一个权重来学习新模式——但它的代价极其昂贵。它不仅需要存储模型权重,还需要同时存储梯度,以及至关重要的、针对每个参数的优化器状态 (如 AdamW 中的动量) 。
多年来,NLP 社区一直依赖于妥协方案。我们使用 参数高效微调 (Parameter-Efficient Fine-Tuning, PEFT) 方法,如 LoRA 或 Adapter,这些方法冻结主模型,只训练微小的附加层。虽然效率很高,但这只是近似方法。它们有时无法捕捉复杂的行为,特别是在推理或数学任务中。
但是,如果你不需要妥协呢?如果你能在单张 24GB 的消费级 GPU 上微调一个 70 亿参数 (7B) 的模型的所有参数,并达到与大规模集群相同的性能呢?
这就是 HiFT (Hierarchical Full Parameter Fine-Tuning,分层全参数微调) 的承诺,这是由中国东北大学和慕尼黑大学 (LMU Munich) 的研究人员提出的一种新策略。在这篇文章中,我们将解构 HiFT 的工作原理,它为何改变了高效训练的格局,以及它如何成功地将“大象装进冰箱”。
问题所在: “全”训练的高昂代价
要理解为什么 HiFT 是必要的,我们首先需要看看训练期间显存消耗的数学原理。当你使用像 AdamW 这样的标准优化器训练模型时,你的 GPU 显存主要被以下四个部分消耗:
- 模型参数: 模型本身的权重 (例如,16 位精度的 7B 模型约为 14GB) 。
- 梯度: 每个参数计算出的变化方向。
- 优化器状态: 用于稳定训练的梯度历史 (动量、方差) 。
- 激活值/残差状态: 前向传播过程中产生的中间数据,用于反向传播。
在标准的 FPFT 中,第 2 项和第 3 项是显存杀手。对于 AdamW,你需要为每一个参数存储两个状态变量。如果你有一个 7B 模型,仅这些状态就会消耗数十 GB,远超标准 RTX 3090 或 4090 的容量。
以往的尝试及其缺陷
研究人员以前曾试图解决这个问题:
- PEFT (例如 LoRA, Prefix-Tuning) : 这些方法冻结模型并添加低秩矩阵。它们节省了显存,但可能导致信息丢失或性能下降,因为基础模型从未进化。
- 零阶优化 (例如 MeZO) : 这些方法在不完全计算梯度的情况下估计梯度。虽然显存效率高,但它们出了名的不稳定,且性能通常明显不如基于梯度的方法。
- LOMO: 这种方法融合了梯度计算和更新以节省显存,但需要两次前向传播,并且通常强制使用特定的量化,限制了其灵活性。
HiFT 的作者认为,我们不必放弃基于动量的优化器 (如 AdamW) 已被证实的稳定性,也不必放弃更新所有参数带来的性能优势。
解决方案: 分层微调 (HiFT)
HiFT 的核心洞见简单而精彩: 你不需要在毫秒不差的同一时刻更新所有参数。
HiFT 采用了一种“逐块”训练策略。它不是将整个神经网络的优化器状态加载到显存中,而是将模型划分为若干层的组 (块) 。在任何给定的训练步骤中,HiFT 选择一组层作为激活状态。它只更新该特定组的参数,而保持网络的其余部分冻结。在训练过程中,它循环遍历所有组,确保网络中的每个参数最终都得到更新。
工作原理: 架构
让我们利用论文中的示意图来可视化这个过程。

如图 1 所示,模型被切分为 \(k\) 个组。在一个训练步骤中,算法选择一个特定的组 (棕色高亮显示) 为“激活”状态。
- 前向传播: 数据流经整个模型。
- 反向传播 (BP) : 计算梯度。然而,优化器只“看到”激活组的参数。
- 更新: 更新激活组的权重。该组的优化器状态被临时加载到 GPU,使用后即卸载 (如果不需要立即再次使用则丢弃) 。
- 切换: 在下一步中,选择不同的组。
这种方法显著降低了“峰值”显存需求。你在任何时候只需要存储 \(\frac{1}{k}\) 模型的梯度和优化器状态。
算法
研究人员将此形式化为一个特定的训练循环。

该算法管理着一个层队列。它支持不同的更新策略:
- 自底向上 (bottom2up) : 从嵌入层开始向上更新到头部。
- 自顶向下 (top2down) : 从头部向下更新到嵌入层。
- 随机 (Random) : 打乱更新顺序。
至关重要的是,该算法采用了延迟学习率更新 (Delayed Learning Rate Update) 。 由于不同层是在不同时间更新的,如果在每一步后都改变全局学习率,可能会导致不稳定 (在同一个 epoch 内,有些层可能以高学习率更新,有些则以低学习率更新) 。HiFT 仅在所有层都更新过一次后才更新学习率调度。这确保了整个模型深度的更新幅度一致。
效率的数学原理
为什么这能节省这么多显存?让我们看看作者提供的公式。
首先,考虑标准全参数微调 (FPFT) 的显存成本 (\(\zeta\)) 。我们将模型权重的显存记为 \(\zeta_1\)。
- AdamW 的优化器状态 (\(\zeta_2\)) 通常是模型权重的 \(2 \times\)。
- 梯度 (\(\zeta_3\)) 是模型权重的 \(1 \times\)。

因此,标准训练所需的显存大约是模型权重本身的 4 倍 (加上激活开销) 。这就是为什么一个 14GB 的模型需要 60GB+ 显存才能训练的原因。
现在,看看 HiFT。因为我们将模型分成了 \(k\) 个组,我们只存储激活组的优化器状态和梯度。

随着 \(k\) (组数) 的增加,相对于总模型大小,状态和梯度所需的显存趋近于零。显存节省是巨大的:

对于一个大型模型,如果你将层分为 \(k=32\) 组,你本质上几乎从 GPU 显存中移除了优化器状态的全部负担。
理论严谨性
你可能会担心异步更新层会破坏模型的稳定性或导致无法收敛。作者通过提供泛化界解决了这个问题,证明了 HiFT 的测试损失与最佳参数之间的差距是有界的。

虽然数学推导很复杂,但结论令人安心: HiFT 在理论上是合理的。在标准假设下,它保证收敛,这意味着你并没有为了节省显存而牺牲数学上的有效性。
实验结果
理论听起来很棒,但实际效果如何?研究人员在广泛的模型 (RoBERTa, GPT-2, LLaMA, OPT) 和任务 (NLU, 指令微调, 数学) 上测试了 HiFT。
1. 它学得和 FPFT 一样好吗?
任何近似方法的主要担忧都是性能损失。作者在 GLUE 和 SuperGLUE 基准测试中对比了 HiFT 与标准 FPFT 及各种 PEFT 方法。

在图 5 中,我们看到了 RoBERTa 在各种数据集上的准确率。HiFT 变体 (橙色、黄色、粉色柱状图) 始终与标准 FPFT (蓝色柱状图) 的性能相匹配。在许多情况下,HiFT 优于 BitFit 或 Adapter 等 PEFT 方法。
此外,观察损失曲线,我们可以看到 HiFT (尽管采用了“切分”式的训练风格) 收敛得很平滑。

2. 指令微调与推理
现代 LLM 的真正考验是指令遵循和复杂推理。研究人员使用 HiFT 微调了 LLaMA-7B 和 Mistral-7B 等模型,并在 MT-Bench 上进行了测试,这是一个评估编码、推理和角色扮演能力的具有挑战性的基准。

上面的雷达图 (图 2) 很有说服力。
- HiFT (橙线) : 始终延伸至最外沿,匹配或击败 FPFT (蓝线) 。
- LoRA (绿线) : 经常落后,特别是在推理和编码等复杂类别中。
这支持了作者的观点,即全参数微调——即使是分层进行的——比低秩近似能更好地捕捉复杂的模式。
LLaMA-7B 和 13B 模型的具体结果进一步证实了这一点,特别是在需要多步推理的 GSM8K 数据集 (小学数学) 上。

在表 4 中,请看 GSM8K 这一列。HiFT 在 LLaMA-7B 上达到了 29.85 , 几乎与 FPFT 的 30.00 相同,而 LoRA 则显著下降至 22.87 。 这表明,对于需要深度改变模型行为的任务,更新所有参数至关重要。
3. 显存效率: 饼图
我们已经在数学上确定了 HiFT 能节省显存,但在实际 GPU 上它的细分情况如何?

图 6 比较了标准 FPFT (a) 和 HiFT (b)。
- FPFT (a): “优化器状态” (黄色部分) 占据了 35.3% 的显存,梯度 (橙色) 占据了另外 17.7% 。
- HiFT (b): 优化器状态缩减至仅 2.7% 。
这种大幅缩减使得系统能够支持更大的批次大小 (batch size) ,或者仅仅是在原本会崩溃的硬件上运行。
4. 更新顺序重要吗?
HiFT 的一个显著特点是可以选择哪些块以什么顺序更新。先训练底层还是先训练顶层有关系吗?

令人惊讶的是,图 4 (左) 显示策略 (自底向上、自顶向下或随机) 对最终准确率几乎没有影响。这种鲁棒性对并行化来说是个好消息: 它意味着未来的工作可能在不同的设备上同时训练不同的块,而无需严格强制特定的顺序。
结论与启示
HiFT 代表了向大语言模型研究普及化迈出的重要一步。通过将“全参数”微调的需求与“全显存”分配的需求解耦,它打破了阻碍许多学生和独立研究人员使用最先进模型的硬件壁垒。
主要收获:
- 24GB 跑 7B: HiFT 允许在单张消费级 GPU (如 RTX 3090 或 4090) 上对 LLaMA-7B 级别的模型进行全参数微调。
- 无性能妥协: 与 MeZO 或某些 PEFT 方法不同,HiFT 取得了与标准微调相当的结果,特别是在推理任务上。
- 与优化器无关: 它适用于 AdamW、SGD 或任何其他优化器,让你保留你信任的训练动态。
随着模型规模的不断增长,像 HiFT 这样优化我们如何更新权重——而不仅仅是优化哪些权重被更新——的策略将变得至关重要。对于学生和研究人员来说,这意味着对 LLM 进行深入、有意义的实验不再是那些拥有大规模 H100 集群的人的专利。你现在可以翻新整栋房子——只不过是一次翻新一个房间。
](https://deep-paper.org/en/paper/2401.15207/images/cover.png)