引言
我们正处于大语言模型 (LLMs) 的黄金时代。从 GPT-4 到 Llama 3,这些模型充当着推理引擎,能够表现出令人惊叹的类人行为。然而,每一位开发者、学生和研究人员都面临着一个持续存在的瓶颈: 延迟 。
核心问题在于自回归解码。为了生成一个句子,LLM 必须预测一个 token,将其附加到序列中,反馈给自己,然后再预测下一个。对于一个 100 token 的回复,模型必须按顺序运行其整个庞大的架构 100 次。这个过程未能充分利用现代 GPU 的并行处理优势,使得实时应用既昂贵又迟缓。
针对这一问题,一个流行的解决方案是投机解码 (Speculative Decoding) 。 这种技术使用一个更小、更快的“草稿模型”来猜测接下来的几个 token,然后由大的“目标模型”并行验证这些猜测。这就像让实习生起草邮件,而经理快速批准或修改它。
但问题在于: 投机解码只有在实习生 (草稿模型) 擅长特定任务时才有效。如果你让一个专门从事翻译的草稿模型去协助解决数学问题,它会猜错,目标模型会拒绝所有建议,结果你比根本不使用草稿模型还要慢。
在论文 《Context-Aware Assistant Selection for Improved Inference Acceleration with Large Language Models》 中,研究人员 Jerry Huang、Prasanna Parthasarathi、Mehdi Rezagholizadeh 和 Sarath Chandar 提出了一个巧妙的解决方案。如果我们不依赖单个草稿模型,而是拥有一个专家库呢?如果我们能训练一个“经理”来观察用户的查询并立即挑选出最适合这项工作的助手呢?
这篇博客文章将剖析他们的方法,该方法将 LLM 加速框架化为一个上下文多臂老虎机 (Contextual Bandits) 问题 , 展示了我们如何动态选择草稿模型,以在不牺牲质量的情况下最大化速度。
背景: 静态草稿生成的局限性
为了理解这里的创新点,我们首先需要巩固对现状的理解。
投机解码回顾
在标准的投机解码中,你有两个模型:
- 目标模型 (\(M_e\)) : 大、聪明但慢的模型 (例如 Llama-70B) 。
- 草稿模型 (\(M_d\)) : 小、快但准确率较低的模型 (例如 Llama-7B) 。
草稿模型快速生成一小段 token 序列 (比如 5 个 token) 。目标模型在一个前向传播中处理所有 5 个 token 以验证它们。如果它们与目标模型将要生成的内容匹配,我们就保留它们。如果不匹配,我们就丢弃错误的并恢复标准生成。
对齐问题
这个过程的效率完全取决于接受率 (Acceptance Rate) ——即目标模型同意草稿 token 的百分比。
- 高接受率: 目标模型跳过了许多生成步骤。速度大幅提升。
- 低接受率: 目标模型做了额外的工作来验证错误的猜测。零加速甚至变慢。
问题在于,单个小型草稿模型不可能精通所有事情。一个小模型可能非常擅长总结新闻 (高接受率) ,但在 Python 编程方面却很糟糕 (低接受率) 。在拥有多样化用户查询的生产环境中,单个静态草稿模型会成为域外任务的瓶颈。
核心方法: 上下文感知助手选择
作者提出了一个系统,其中推理流水线可以访问多个草稿候选者 。 这些候选者可以是完全不同的模型,也可以是在不同领域微调过的相同架构模型 (例如,一个用于编程,一个用于聊天,一个用于数学) 。
核心挑战是决策制定 : 给定一个特定的输入查询 (上下文) ,我们应该拉动哪个草稿模型 (臂) 来获得最佳的加速 (奖励) ?
研究人员将其建模为上下文多臂老虎机问题。以下是高层级的工作流程:

如上图 1 所示,该过程分为两个阶段: 离线训练和在线推理 。
1. 离线数据收集与评分
在线训练策略 (在为用户服务的同时进行学习) 既有风险又缓慢,因为你必须通过糟糕的决策来吸取教训。相反,作者采用了一种离线方法。
他们获取一个查询数据集,并通过所有可用的草稿模型和目标模型独立运行这些查询。然后,他们计算一个“分数”,代表草稿模型的帮助有多大。
对齐分数
衡量草稿模型效用的最直接指标是其输出 (\(o_{i}^{j}\)) 与目标模型输出 (\(o_{i}^{e}\)) 的相似程度。作者使用 ROUGE-L 等指标来计算相似度分数 \(f\):

纳入成本
然而,对齐并不是唯一的因素。一个草稿模型可能非常准确,但太大 (太慢) 而无法提供净加速。相反,一个极小的模型可能非常快,但准确率稍低。
为了考虑到这一点,研究人员引入了一个成本感知奖励函数。他们根据草稿模型相对于最大候选者的参数数量定义了一个成本 \(c_{i}^{j}\)。最终分数使用权衡参数 \(\alpha\) 结合了对齐度和成本:

在这里,\(\alpha\) 允许我们调整系统。\(\alpha\) 为 1 时只关心准确性;\(\alpha\) 接近 0 时优先考虑最小/最快的模型。成本函数 \(c_i\) 派生自模型的参数量 (\(p\)) :

这种设置创建了一个带标签的数据集,其中每个查询都关联了使用不同草稿模型的“奖励” (分数) 。
2. 策略训练
有了数据集,目标就是训练一个策略网络 (\(\pi\)) 。 这是一个轻量级的多层感知机 (MLP) 。
- 输入: 用户查询的嵌入 (通常是来自目标模型的句子嵌入) 。
- 输出: 可用草稿模型 (动作) 的概率分布。
目标函数 \(J^{\pi}\) 旨在最大化期望奖励:

由于动作空间 (选择哪个模型) 是离散的,因此对动作的积分是一个求和:

为了优化这一点,作者使用了 REINFORCE 算法 (一种标准的策略梯度方法) 。梯度更新如下所示:

本质上,如果某个草稿模型对特定类型的查询产生了高回报,网络就会更新其权重,以增加在未来对类似查询选择该模型的概率。
实验与结果
作者使用 T5 和 Flan-T5 模型在翻译 (IWSLT) 、摘要 (XSUM) 和数学 (GSM8K) 等各种任务上验证了这一方法。
1. 领域专业化
最有说服力的测试是策略能否区分领域专家。他们设置了一个场景,包含两个大小相同的草稿模型:
- T5-Small: 标准预训练模型。
- T5-Small-XSum: 专门针对摘要进行微调的模型。
他们在翻译和摘要任务上测试了这些模型。

表 1 (上图) 揭示了关键见解:
- T5-Small 模型加速了翻译 (1.10x) ,但减慢了摘要 (0.97x) 。
- T5-Small-XSum 模型在摘要方面表现出色 (1.21x) ,但在翻译方面很糟糕 (0.83x,变慢了) 。
- 策略 (\(\pi_\theta\)) : 它在两个任务上都实现了加速 (1.09x 和 1.17x) 。它成功地识别了上下文并将查询路由给正确的专家,避免了使用错误草稿模型的陷阱。
2. 平衡速度与准确性
通过调整奖励函数中的 \(\alpha\) 参数,可以将策略调整为更喜欢“更安全” (更准确) 的模型或“更冒险” (更快/更小) 的模型。

图 2 直观地展示了这种权衡。随着 \(\alpha\) 增加 (优先考虑对齐) ,策略转向偏好更大、更准确的草稿模型。随着 \(\alpha\) 减小 (优先考虑低成本) ,它倾向于选择最小的有效草稿模型。这证明了该系统可以灵活适应部署环境的特定延迟/计算约束。
3. 数据效率
训练辅助策略的一个主要担忧是数据需求。我们需要数百万个示例来训练这个路由器吗?

图 3 显示了解码速度随训练样本数量 (对数刻度) 的变化。值得注意的是,该策略在仅有 1,000 到 10,000 个示例时就变得有效。这意味着为一组特定模型创建一个自定义选择器在计算上是廉价的,并且不需要海量数据集。
4. “什么都不做”选项
有时,没有一个草稿模型是足够好的。例如,复杂的数学推理通常需要目标模型的全部能力;小型草稿模型只会产生幻觉,导致拒绝并减慢速度。
作者在策略的选项中添加了一个“自回归”动作——实际上是让经理决定“我自己来做”。

在表 4 中,在 GSM8K 数学数据集上进行测试,标准起草将推理速度降低到约 0.75x,因为草稿模型失败了。然而, 策略 (\(\pi_\theta\)) 达到了约 0.95x 的速度。它学会了识别数学查询并避免使用草稿模型,默认接近标准的自回归速度。虽然它在这里没有获得加速,但它防止了盲目投机解码会遭受的灾难性减速。
5. 自草稿生成 (提前退出)
该方法也适用于自投机解码 (Self-Speculative Decoding) , 即草稿“模型”只是目标模型的早期层 (“提前退出”) 。

表 6 表明,即使在选择不同的“层”作为起草者时,该策略也有助于保持性能,尤其是在可用自回归选项的情况下。
结论与启示
这项研究将投机解码从一种静态的、“碰运气式”的优化转变为一个动态的、智能的系统。通过将助手选择框架化为上下文多臂老虎机问题,作者提供了一个框架,该框架:
- 消除了由不匹配的草稿模型导致减速的风险 。
- 实现了模块化 , 允许系统组合多个专门的小模型,而不是依赖一个通用的起草者。
- 只需极小的开销 , 策略轻量级,易于训练且运行速度快。
随着 LLM 持续增长,“一个模型统治所有”的方法在推理计算上正变得难以为继。我们很可能会走向模型生态系统——由成群的专业化、轻量级助手支持的大型推理引擎。本文为使这种生态系统高效运行的路由逻辑提供了蓝图。
对于学生和工程师来说,结论很明确: 优化不仅仅是让单个模型更快;更是让模型之间的决策过程更聪明。
](https://deep-paper.org/en/paper/2408.08470/images/cover.png)