在人工智能 (AI) 领域——尤其是在自然语言处理 (NLP) 领域——过去几年的口号一直是 “越大越好”。我们见证了 GPT-3、T5 和 Megatron 等一系列巨型语言模型的诞生,每一个都在不断刷新规模和性能的极限。扩展这些模型释放了令人惊叹的能力,从撰写连贯的文章到生成代码。但这背后是高昂的代价: 天文数字般的计算成本。训练这些庞大的密集模型——即每个参数在处理每一个输入时都会被使用——需要超级计算机并消耗巨量能源。

这就引出了一个关键问题: 我们能否在不承受高昂计算费用的情况下,继续享受规模带来的好处?如果我们不通过扩展模型的每一个部分来使其变大,而是增加更多专门化的部分,并且只在特定输入时使用相关部分,会有什么效果?

这正是谷歌研究院一篇开创性论文 《Switch Transformers: 利用简单高效的稀疏性扩展至万亿参数模型》 背后的核心思想。作者们提出了一种新架构,它能扩展到超过 一万亿个参数,同时保持每个输入的计算成本不变。结果如何?一个比密集模型快 4–7 倍的模型,从根本上改变了训练大型 AI 模型的经济学。

在这篇文章中,我们将深入探讨 Switch Transformer。我们将解析稀疏性和混合专家 (Mixture of Experts, MoE) 的概念,探索作者们引入的巧妙架构与训练简化方法,并分析其惊人的成果——让万亿参数模型比以往任何时候都更高效。


背景: 密集模型的问题所在

在领略 Switch Transformer 的精妙之前,我们需要回顾一下它所改进的架构: 标准的密集 Transformer。

一个密集模型就像一个委员会,其中每个成员都必须对每一项决策进行投票。当 Transformer 处理一个句子时,每个词元 (token) 都会经过模型中每一个参数的处理。将参数数量加倍,所需计算量 (以 FLOPs——浮点运算次数衡量) 也会大致加倍。这种“暴力”扩展方式正是 T5 和 GPT-3 等模型所采用的。虽然有效,但成本极高。

另一种选择是稀疏性,或称条件计算。想象一个由专家组成的委员会: 当出现财务问题时,只有经济学家参与;当涉及法律问题时,只有律师响应。这种方式效率高得多。在稀疏神经网络中,对于任何给定输入,只有一部分参数——即专家——会被激活。总参数量可以“大得惊人”,但由于每次输入只激活少量参数,计算成本保持稳定。

这个想法并不新鲜。它源于混合专家 (MoE) 模型的原理,该模型在 20 世纪 90 年代被提出,并在 2017 年由 Noam Shazeer 等人为深度学习进行了现代化改造。在 MoE 模型中,一个路由器网络会学习将每个输入发送给少量专家子网络。然而,过去的 MoE 模型一直受到结构复杂、专家间通信开销大以及训练不稳定等问题困扰。

Switch Transformer 论文直面这些问题,简化了 MoE 概念,创造出一个稳定、高效且可无限扩展的模型。


核心方法: Switch Transformer 的工作原理

作者们的指导原则是以一种简单且计算高效的方式最大化参数数量。他们通过基于标准 Transformer,将其前馈网络 (Feed-Forward Network, FFN) 层替换为新的 Switch FFN 层来实现这一目标。

Switch Transformer 编码器模块示意图。词元 “More” 和 “Parameters” 被路由到不同的专家;输出由路由器的置信度分数进行缩放。

图 1: 一个 Switch Transformer 模块。标准密集 FFN 层被一个包含多个专家的稀疏 Switch FFN 层所取代,路由器为每个词元选择一个专家。


1. 简化路由: 从 Top-K 到 Top-1

最初的 MoE 层会将每个词元路由到排名前 K 的专家 (通常 K=2) ,并将其输出通过加权和进行组合:

\[ p_i(x) = \frac{e^{h(x)_i}}{\sum_j^N e^{h(x)_j}}, \quad y = \sum_{i \in \text{TopK}} p_i(x) \cdot \text{Expert}_i(x) \]

传统观点认为,K>1 对于确保路由器获得有意义的梯度流是必要的。

Switch Transformer 团队提出了挑战:** 如果我们只将每个词元发送给它唯一的最佳专家 (K=1) 会怎样?** 这种“Switch”路由带来了以下好处:

  • **减少了路由器计算量 **(无需组合输出) 。
  • 减小了每个专家的批处理规模,因为每个词元只会去往一个专家。
  • 降低了托管不同专家的设备之间的通信开销

尽管非常简单,这种 top-1 路由依然保留了——甚至有时提升了——模型质量。


2. 在分布式环境中实现高效路由

只将词元路由到一个专家很好,但在跨越多个 TPU/GPU 核心、每个核心都有自己的专家的真实分布式设置中,硬件限制使问题更加复杂。

像 TPU 这样的加速器需要静态大小的张量。这意味着必须预先定义每个专家可处理的词元数量,即**专家容量 **(expert capacity) 。

\[ \text{expert capacity} = \left( \frac{\text{tokens per batch}}{\text{number of experts}} \right) \times \text{capacity factor} \]

*容量因子 *(capacity factor,例如 1.25) 会增加缓冲空间。如果路由到某专家的词元数量超过了其容量 (溢出) ,多余的词元会被“丢弃”,即在该层跳过,并通过残差连接向后传递。

词元路由动态示意图。溢出的词元被丢弃。更大的容量因子可以减少丢弃,但可能浪费计算资源 (空槽) 。

图 2: 每个专家处理固定大小的批次。溢出导致词元被丢弃 (红色虚线) 。更大的容量因子可以减少丢弃,但会增加计算量。

为尽量减少丢弃现象,作者们引入了一个辅助负载均衡损失:

\[ \text{loss} = \alpha \cdot N \cdot \sum_{i=1}^N f_i \cdot P_i \]

其中:

  • \( f_i \) = 分派给专家 i 的词元比例
  • \( P_i \) = 路由器分配给专家 i 的概率比例
    \( \alpha \) 是一个很小的系数 (例如 0.01) 。

该损失使两种分布趋向均匀,确保专家被均匀使用。


3. 克服训练不稳定性

稀疏模型容易因为硬性路由决策而出现训练不稳定,尤其是在使用像 bfloat16 这样的低精度格式时。作者们采用了以下方法修复:

选择性精度训练

float32 精度下执行路由器计算 (以确保稳定性) ,然后在设备间传输前将输出转换为 bfloat16
这样可以在保留稳定性的同时,获得 bfloat16 带来的速度与低内存优势。

精度比较表: bfloat16 发散;float32 稳定但较慢;选择性精度既稳定又快速。

表 1: 选择性精度在具备 float32 稳定性的同时,兼具 bfloat16 的速度。

更小的参数初始化

Switch 模型对初始化规模很敏感。将其缩小 10 倍可以提高稳定性并改善训练早期的质量。

初始化规模表: 0.1 倍的规模比 1.0 倍的初始化产生更高质量和更低方差。

表 2: 更小的初始化规模可以稳定训练并提升质量。

专家 Dropout

为避免微调时过拟合,对非专家层使用较低的 dropout (例如 0.1) ,而在专家内部使用较高的 dropout (例如 0.4) 。

Dropout 表: 低全局 dropout + 高专家 dropout 产生最佳综合效果。

表 3: 有针对性的专家 dropout 有助于提升小数据集上的微调性能。


实验与结果: 成效显著

这个设计表现如何?一句话: 惊人


前所未有的扩展与速度

训练步数为基准的情况下,当每词元的 FLOPs 固定时,增加专家数量 (即增加参数量) 可以在相同步数后持续提升质量。

扩展性图: 更多专家 → 更低的测试损失 (左) ,更快达到更低的负对数困惑度 (右) 。

图 3: 增加专家数量能在不增加单位词元计算的情况下改善损失和样本效率。

真实时间为基准的情况下,这些提升转化成了巨大的加速:

训练时间图: 64 专家版 Switch-Base 达到目标质量所需时间仅为密集 T5-Base 的 1/7。

图 4: Switch-Base 仅用七分之一训练时间便达到了与 T5-Base 相同的质量。

相比 **T5-Large **(单位词元 FLOPs 是其 3.5 倍) ,Switch-Base 依然快了 2.5 倍,且样本效率更高:

双面板图: Switch-Base 在训练步与时间效率上均优于 T5-Base 和 T5-Large。

图 5: 稀疏扩展优于密集扩展——即便相比更大的密集模型。


在下游任务中的强劲表现

在 SQuAD、XSum 和 SuperGLUE 等任务上的微调表明,预训练优势能很好地迁移到下游任务。

微调表: Switch 模型在多种 NLP 基准测试中优于密集基线模型。

表 4: 在推理和知识任务中,Switch 模型始终优于 T5 基线模型。


用蒸馏让巨型模型更实用

为了部署,作者们采用了蒸馏技术,将大型稀疏教师模型压缩成小型密集学生模型。
巧妙的初始化与混合损失函数,在压缩超过 95% 的同时保留了教师模型约 30% 的性能提升。

蒸馏表: 将 38 亿参数压缩至 2.23 亿,保留了 29% 性能提升。

表 5: 蒸馏后的 Switch 模型可得到紧凑且高性能的学生模型。


在多语言场景中的卓越表现

在 mC4 数据集 (包含 101 种语言) 上的训练结果显示,mSwitch-Base 在全部语言中优于 mT5-Base,其中 91% 的语言实现了 ≥4× 加速。

多语言图: Switch 在全部 101 种语言中的负对数困惑度均优于密集基线模型。

图 6: Switch 架构在多语言任务中实现了全局性提升。

速度提升直方图: 大部分语言训练加速 4–5 倍。

图 7: 多语言速度提升分布——平均加速 5 倍。


迈向万亿参数模型: 并行化艺术

如何构建一个拥有 1.6 万亿参数的模型?需要结合以下策略:

  1. 数据并行 – 复制模型,让不同核心处理不同输入批次。
  2. 模型并行 – 当模型过大无法放入单个核心时,将其拆分到多个核心上。
  3. 专家并行 – 在 MoE 架构中,将不同专家放到不同核心上。

并行策略示意图: 数据并行、模型并行、模型+数据并行、专家+数据并行、专家+模型+数据并行。

图 8: 在核心之间划分模型权重/数据的方式。Switch 模型结合这些方式进行扩展。

通过结合这三种并行方式,最大模型 (Switch-C) 达到了 1.6 万亿参数和 2048 个专家,而其训练的计算预算与 110 亿参数的 T5-XXL 相当。

超参数/预训练表: Switch-C 达到固定质量的速度比 T5-XXL 快 4 倍。

表 6: 万亿参数的 Switch-C 达到目标质量的速度比密集 T5-XXL 快 4 倍。


结论: 规模化的新范式

Switch Transformer 不只是一个大模型——它代表了我们扩展神经网络方式的转变。通过简化 MoE 路由并解决训练稳定性问题,作者们提供了一个实用、高效的蓝图,指向下一代 AI。

核心要点:

  • 稀疏性驱动高效扩展 —— 在不增加单位词元计算量的情况下扩大参数规模。
  • 大道至简 —— top-1 路由让模型更快、更简单、更优。
  • 训练稳定性可解决 —— 选择性精度、合理初始化和专家 dropout 均有效。
  • 优势普遍适用 —— 无论是预训练加速、下游任务提升还是多语言鲁棒性。

这项工作让大规模建模更易触及。虽然万亿参数模型仍属少见,但即便是拥有 2–4 个专家的小型 Switch 变体,也能超越密集基线模型。本文所述原则为任何规模的研究人员和工程师提供了建模新途径,以打造更强大、更高效的模型——一个将影响未来多年 AI 系统设计的范式转变。