大型语言模型 (LLM) 的爆发使强大的 AI 变得触手可及,但定制这些模型仍然是一场硬件噩梦。虽然使用像 Llama-2 或 GPT-3 这样的预训练模型相对便宜,但微调 (Fine-tuning) ——即针对医疗数据、代码生成或特定写作风格对其进行专业化训练——需要巨大的计算资源。

例如,微调一个 650 亿参数的模型可能需要高达 780 GB 的 GPU 内存。这实际上为定制最先进模型的能力设置了企业级门槛。

在本文中,我们将深入探讨 Meta AI 的一篇研究论文 TokenTune , 它提出了一个反直觉的解决方案: 为了有效地学习,模型不需要对输入中的每一个 Token 都进行反向传播。 通过随机选择一部分 Token 进行梯度计算,TokenTune 在不牺牲性能的情况下大幅削减了内存需求。

微调的瓶颈

要理解为什么需要 TokenTune,首先需要剖析训练期间 GPU 内存的去向。当你微调一个 Transformer 模型时,内存主要被三个部分消耗:

  1. 模型参数 (Model Parameters) : 神经网络本身的权重。
  2. 梯度与优化器状态 (Gradients & Optimizer States) : 更新这些权重所需的数据。
  3. 中间激活值 (Intermediate Activations) : 在前向传播过程中每一层计算出的值,这些值必须被存储 (缓存) 以用于在反向传播过程中计算梯度。

AI 社区在缩减前两部分方面已经做得非常出色。 参数高效微调 (PEFT) 方法,如 LoRA (低秩自适应) ,冻结主模型并仅训练微小的适配器层,从而减少梯度存储。 量化方法,如 QLoRA,通过使用低精度数值 (例如 4 位整数代替 16 位浮点数) 来压缩模型参数本身。

然而,机器中还潜伏着一个“幽灵”: 中间激活值 。 即使使用 LoRA 和 QLoRA,模型仍然必须缓存整个序列长度的激活值来执行反向传播。对于长序列 (现代 LLM 的标配) ,这种激活值内存成为了新的瓶颈。

如下图所示,虽然 QLoRA 等方法减少了内存,但结合 TokenTune 可以实现更大幅度的削减。

图 1: 不同微调方法的内存使用情况比较。TokenTune 结合 QLoRA 的内存使用量明显少于单独使用 QLoRA。

这张图突出了该论文的主要贡献: 简单地将 TokenTune 添加到现有方法 (如 QLoRA) 中,可以将内存使用量减少到仅为 QLoRA 单独所需内存的近三分之一。

核心概念: Token 选择

TokenTune 背后的假设植根于稀疏性 。 先前的研究表明,并非所有的神经元或 Token 对模型的学习过程都有同等的贡献。有些 Token 携带任务的“信号”,而其他 Token 只是噪音或结构填充物。

TokenTune 利用这一点引入了 Token 选择策略。

  1. 前向传播 (上下文是关键) : 模型读取整个输入序列。所有 Token 都会被处理,以便模型理解完整的上下文 (自注意力机制需要看到整个句子) 。
  2. 反向传播 (选择性学习) : 模型为随机选择的 Token 子集 (\(k\)) 计算梯度。未选中的 Token 被“冻结”,实际上充当了旁观者的角色。

因为我们不计算未选中 Token 的梯度,所以我们不需要缓存它们的中间激活值。我们只存储那少数被选中的 Token 所需的数据。

图 2: TokenTune 架构。蓝点代表活跃的梯度/缓存的激活值;灰色圆圈代表冻结的 Token。

图 2 可视化了这一过程。注意输入 \(x\) 是如何被拆分的。一部分 Token (蓝色) 流经完整的计算图,梯度被跟踪。其余的 (灰色) 流经 no_grad (无梯度) 路径。它们参与注意力机制——允许蓝色 Token “关注”它们——但它们不会触发用于反向传播的密集内存缓存。

数学公式表达

让我们从数学角度拆解它是如何工作的。我们将输入隐藏状态 \(h\) 分为两组:

  • \(\mathcal{G}\): 选中的 Token 组 (大小为 \(k\)) 。
  • \(\bar{\mathcal{G}}\): 未选中的 Token 组。

为了适应这一点,目标函数略有变化。对于分类任务,我们可能会在 Transformer 之上使用 MLP (多层感知机) 头。TokenTune 不聚合所有 Token,而是只聚合选中的 Token \(\mathcal{G}\):

公式 1: 仅使用选中 Token 的分类目标函数。

同样,对于语言建模 (预测下一个 Token) ,交叉熵损失仅在选中的 Token 上计算:

公式 2: 语言建模目标函数,仅对选中 Token 求和损失。

这种简单的转变决定了哪些激活值必须存储在内存中。

优化密集层 (Dense Layers)

在标准的 Transformer 密集层 (前馈网络) 中,我们要利用权重 \(W\) 和偏置 \(b\) 从输入 \(h\) 计算输出 \(a\)。微积分的链式法则告诉我们,要更新权重 \(W\),我们需要存储的输入 \(h\):

公式 3: 密集层的标准梯度计算,显示对 h 的依赖。

优化的关键就在这里。如果我们决定不需要对未选中的 Token \(\bar{\mathcal{G}}\) 进行反向传播,那么损失相对于这些激活值的梯度就变为零。

公式 4: 损失相对于输出激活值的梯度对于未选中组为零。

因为未选中组的梯度为零,涉及其特定输入激活值 \(h_{\bar{\mathcal{G}}}\) 的项从权重更新方程中消失了。

公式 5: 最终梯度计算显示我们只需要缓存 h_G。

这意味着我们只需要缓存 \(h_\mathcal{G}\) 。 序列其余部分的激活值 (\(h_{\bar{\mathcal{G}}}\)) 可以在前向传播后立即丢弃,从而节省大量 GPU 内存。

在实践中 (例如在 PyTorch 中) ,这是通过显式地拆分前向传播来实现的。选中的 Token 正常处理,而未选中的 Token 则在 torch.no_grad() 上下文块中处理:

公式 6: 拆分密集层前向传播的实现逻辑。

优化注意力层 (Attention Layers)

注意力机制更复杂,因为 Token 之间会相互作用。一个选中的 Token 可能需要“关注”一个未选中的 Token 才能理解上下文。因此,我们不能简单地删除未选中的 Token。

TokenTune 通过拆分查询 (\(Q\))、键 (\(K\)) 和值 (\(V\)) 的投影来处理这个问题。

  1. 投影: 我们为两组都计算 \(Q, K, V\)。但是,未选中组 \(\bar{\mathcal{G}}\) 的计算不进行梯度跟踪。
  2. 注意力: 选中的 Token \(\mathcal{G}\) 关注所有内容 (包括 \(\mathcal{G}\) 和 \(\bar{\mathcal{G}}\)) 。

下面的方程展示了这种拆分。请注意,注意力计算 (softmax) 使用了键的级联 (\([K_{\bar{\mathcal{G}}}, K_{\mathcal{G}}]\)) 和值的级联 (\([V_{\bar{\mathcal{G}}}, V_{\mathcal{G}}]\)) ,从而保留了完整的上下文。

公式 7: 注意力机制方程,展示选中 Token 如何关注整个序列。

关键在于,对于选中组 \(h_\mathcal{G}\),我们按如下方式计算注意力。这是需要缓存的路径:

公式 8: 选中组的注意力计算。

对于未选中组 \(h_{\bar{\mathcal{G}}}\),我们在 no_grad 块内执行计算 (由下图中的括号符号表示) 。计算这些值是为了传递给下一层,但它们的中间状态不会被缓存。

公式 9: 包裹在 no_grad 中的未选中组的注意力计算。

最后,前馈步骤遵循前面讨论的密集层相同的模式:

公式 10: 选中组的前馈方程。

公式 11: 未选中组的前馈方程。

实验结果

理论听起来不错,但丢弃一半 (或更多) Token 的梯度会破坏模型性能吗?研究人员在中型模型 (BERT) 和大型模型 (Llama 2) 上对此进行了测试。

1. 中型模型 (BERT)

研究人员在 GLUE 基准测试 (一套标准的自然语言理解任务) 上测试了 TokenTune。他们将 TokenTune 与全量微调以及其他高效方法 (如 Adapters、BitFit 和 LoRA) 进行了比较。

表 1: BERT-large 的 GLUE 基准测试结果。TokenTune 的表现与全量微调相当。

表 1 的结果令人信服。TokenTune 取得了 82.1 的平均分,与全量微调 (82.8) 和 LoRA (81.9) 几乎相同。这证实了即使只对一部分 Token 进行反向传播,模型也能学习到鲁棒的表示。

我们需要多少 Token? 最有趣的发现之一是实际上需要的 Token 是多么。研究人员改变了训练位置的数量 (\(k\)) ,并测量了在 MRPC 和 STS-B 任务上的性能。

图 3: 左: 内存扩展与批次大小的关系。右: 性能与训练 Token 数量的关系。

如图 3 (右) 所示,性能提升非常快。仅训练 32 个位置 (在更长的序列中) ,模型就能达到接近最优的性能。

图 3 (左) 展示了内存的节省。TokenTune (蓝色实线) 的增长比全量微调 (紫色虚线) 平缓得多。当与 LoRA 结合使用时 (青色线) ,内存占用只是基线的一小部分。

2. 大型语言模型 (Llama 2)

对于 LLM 来说,赌注更高。研究人员使用指令微调 (教导模型遵循命令) 对 Llama2-7B 进行了微调。他们在 MMLU、HellaSwag 和 TruthfulQA 等困难的推理基准上评估了模型。

表 2: Llama2-7B 的少样本评估。TokenTune 结合 LoRA/QLoRA 保持了高准确率。

表 2 显示, Llama 7B w/ TokenTune (平均 61.23) 实际上优于基础 Llama 7B 模型 (60.73) ,并且与 LoRA (62.20) 不相上下。

作者还探讨了“选择比例” (Selection Ratio) ——即被选中进行反向传播的 Token 百分比。

表 3: 选择比例对 TokenTune 和 TokenTune+LoRA 的性能和内存的影响。

表 3 续: TokenTune+QLoRA 的选择比例影响。

观察上面的表格,我们发现一个有趣的趋势: 更高的选择比例并不严格等同于更好的性能。 在许多情况下,仅选择 20% 到 30% 的 Token 就能获得最佳结果。这表明 TokenTune 可能还充当了正则化器,通过 (随机 Token 选择) 引入噪声来防止模型对微调数据过拟合。

3. 终极内存解锁

这篇论文最重要的结果是内存使用分析。通过将 TokenTune 与 QLoRA (量化 LoRA) 相结合,内存需求急剧下降。

图 4: Llama2-7B 的内存使用比较。TokenTune + QLoRA 是最高效的方法。

图 4 可视化了这种扩展情况。紫色虚线 (全量微调) 高居 ~90GB。红色柱状图 (TokenTune + QLoRA) 则大幅降低。

下表 4 详细列出的确切数字令人震惊。对于 25% 的选择比例,TokenTune + QLoRA 仅需要 17.2 GiB 的内存,而全量微调则需要 91.4 GiB 。 这使得微调 70 亿参数模型完全在消费级 GPU (如 NVIDIA RTX 3090 或 4090) 的范围内。

表 4: 详细的 GPU 内存使用数据。

结论

TokenTune 提出了一个简单而有力的观察: 我们通过为输入序列中的每一个 Token 计算梯度来浪费内存。通过将建立上下文的前向传播与更新权重的反向传播解耦,TokenTune 提供了一个新的优化维度。

主要收获如下:

  1. 上下文 \(\neq\) 学习: 模型需要看到所有 Token 才能理解句子,但它只需要在其中少数几个上进行纠正就能学会任务。
  2. 可组合性: TokenTune 不是 LoRA 或量化的替代品;它是一个效能倍增器。它解决了其他方法所忽略的一个内存消耗大户 (激活值) 。
  3. 可访问性: 通过将内存需求减少高达 79% (当与 QLoRA 结合时) ,TokenTune 让我们离一个高性能 LLM 微调可以在个人硬件上进行,而不需要庞大服务器农场的世界更近了一步。

对于研究 Transformer 的学生和研究人员来说,TokenTune 强调了质疑基本假设的重要性——比如反向传播需要完整计算图这一想法。有时候,少即是多。