引言: 稀疏性与保真度的两难困境

机理可解释性 (Mechanistic interpretability) 研究人员长期以来一直致力于揭示大型语言模型 (LM) ,如 Gemma 或 GPT-4,是如何组织其内部表征的。在这项探索中,一个强大的工具是 稀疏自编码器 (Sparse Autoencoder, SAE) ——一种将密集的激活向量分解为称为 特征 的更简单构建块的模型。想象一下,一个语言模型的激活代表其“思维”,由数千个数值组成。SAE 将这种复杂性简化为几个成分,例如 70% “语法”40% “计算机代码”10% “正式写作”

这种稀疏分解帮助研究人员追踪信息在模型中的流动,理解因果子电路,甚至可以引导模型行为。然而,每个 SAE 都会面临持续的冲突:

  1. 稀疏性 (Sparsity): 为了具有可解释性,每个激活只应有少数特征处于激活状态。
  2. 保真度 (Fidelity): 为了有用,重建的激活必须与原始激活高度匹配。

增强稀疏性必然会牺牲保真度,而提高保真度则通常会增加特征的使用。这种权衡形成了一个 帕累托前沿 (Pareto frontier) ——研究人员不断努力推动其向外移动。

近期,一篇题为 《Jumping Ahead: Improving Reconstruction Fidelity with JumpReLU Sparse Autoencoders》 的论文提出了一种简单却强大的方法来实现这一目标。作者们提出了 JumpReLU SAE , 在给定的稀疏度水平下实现了最先进的重建保真度。这一突破源于一种新的训练方法,可优化曾被视为 不可训练 的激活函数,从而在保持高保真度的同时直接优化稀疏性。

本文介绍 JumpReLU SAE 的工作原理,阐述这一创新的意义,并探讨其对人工智能特征级可解释性未来的启示。


背景知识: 稀疏自编码器速览

一个标准的自编码器由两部分组成——一个 编码器,将输入数据压缩成低维表示;一个 解码器,用于重建原始数据。

相比之下, 稀疏自编码器是扩展而非压缩。它学习一个 超完备字典 (overcomplete dictionary) 的特征方向,其中编码器仅激活少数条目来表示任何给定输入。

形式上,对于语言模型的激活 \( \mathbf{x} \in \mathbb{R}^n \):

  1. 编码器 (Encoder): 将激活转换为特征 \( \mathbf{f}(\mathbf{x}) = \sigma(\mathbf{W}_{\text{enc}} \mathbf{x} + \mathbf{b}_{\text{enc}}) \)
  2. 解码器 (Decoder): 重建激活 \( \hat{\mathbf{x}} = \mathbf{W}_{\text{dec}} \mathbf{f} + \mathbf{b}_{\text{dec}} \)

稀疏自编码器的编码器和解码器方程。

“编码器将语言模型的激活转换为特征值,解码器则使用学习到的字典向量重建激活。”

其中 \( \mathbf{W}_{\text{enc}} \) 与 \( \mathbf{W}_{\text{dec}} \) 是权重矩阵,\( \mathbf{b}_{\text{enc}} \)、\( \mathbf{b}_{\text{dec}} \) 是偏置项,\( \sigma \) 是一个非线性激活函数——传统上为 ReLU 。 为了训练 SAE,研究人员最小化一个包含两部分的损失:

稀疏自编码器的一般损失函数。

“SAE 的总损失在重建保真度 (L2 误差) 与稀疏性 (激活字典特征数量) 之间取得平衡。”

系数 \( \lambda \) 用以调节模型因激活过密而受罚的强度。


传统 ReLU SAE 的问题

基于 ReLU 的 SAE 通常使用 L1 惩罚项来实现稀疏性,但激活函数与惩罚项都带来了不良副作用。

请看下面的玩具示例:

一个玩具模型,说明了 ReLU 的问题以及 JumpReLU 的优势。

“误报与幅度收缩: ReLU 会保留微弱的正激活 (本应为零) ,并惩罚较大的激活,从而降低重建保真度。”

当某个特征的编码器预激活值本应为非激活状态、但略为正时,ReLU 会让它通过。降低编码器偏置可减少这些“误报” (false positives) ,但同时也会压缩真实激活值。L1 惩罚进一步促使激活变小,从而系统性地低估特征幅度并削弱重建能力。

近期的一些变体,如 Gated SAETopK SAE , 引入了额外的阈值机制以更精确地控制激活。 JumpReLU SAE 将这一思想更进一步——为每个特征直接学习阈值,并优化真实的稀疏度水平,而不仅是代理指标。


核心方法: 借助 JumpReLU 实现飞跃

JumpReLU SAE 融合了两个核心理念——一种 新的激活函数 与一种 新的训练方法: 直接的 L0 稀疏性优化。

1. JumpReLU 激活函数

每个特征不再使用 ReLU,而是采用 JumpReLU 激活函数,定义如下:

JumpReLU 激活函数的方程。

“JumpReLU 用门控恒等函数取代 ReLU: 低于阈值的输入置零,高于阈值的则保持不变。”

JumpReLU 激活函数图像。对于输入小于阈值 θ 时输出为零,大于 θ 时保持恒等。

“每个特征都有自己的阈值 θ: 低于 θ 的预激活值被设为零,从而避免误报与幅度收缩。”

形式上: \( \text{JumpReLU}_\theta(z) = z \cdot H(z - \theta) \), 其中 \( H \) 为亥维赛阶跃函数 (Heaviside step function,低于 0 时取 0,高于 0 时取 1) 。当 \( z \ge \theta \) 时,激活保持不变,否则为零。

这既保留了真实激活,又干净地去除噪声——完美解决误报和幅度收缩问题。

完整的 SAE 前向传播公式变为:

JumpReLU SAE 的前向传播方程。

“JumpReLU SAE 在标准 SAE 基础上增加了一个逐特征阈值向量 θ,用于控制激活。”

2. L0 损失与训练难题

用直接的 L0 稀疏项替代 L1 惩罚在概念上是理想的——模型应最小化激活特征的数量,而非强度。

带有 L0 稀疏惩罚项的 JumpReLU SAE 损失函数。

“JumpReLU SAE 结合了 L2 重建损失和一个精确的 L0 稀疏项。”

然而,问题显而易见: JumpReLU 与 L0 范数都在阈值处 不连续。轻微移动 \( \theta \) 很少影响输出,因此关于 \( \theta \) 的梯度为零,使得基于梯度的训练无法进行。

我们如何训练没有梯度的参数?关键洞见随之而来。

3. 解决方案: 期望损失 的梯度

虽然瞬时损失几乎处处平坦,但数据分布上的 期望损失 是可微的,其解析梯度如下:

期望损失相对于阈值 θ 的解析导数。

“期望损失的梯度取决于阈值附近激活对重建误差与稀疏惩罚的影响。”

该公式表示: 应根据阈值变动对平均重建质量和稀疏性的影响来调整阈值。如果提高阈值会移除对重建贡献显著的特征,模型就会学会降低它;反之亦然。

为从小批量数据中估计该梯度,作者引入了 直通估计器 (Straight‑Through Estimator, STE)

4. 直通估计器 (STE)

STE 通过将不可微函数的真实导数 (狄拉克 δ 脉冲) 替换为阈值附近的一小段非零窗口,来近似梯度。

JumpReLU 函数相对于其阈值的伪导数。 亥维赛阶跃函数相对于其阈值的伪导数。 伪导数的可视化解释,在阈值附近的小窗口内生成梯度信号。

“对于接近阈值的预激活,伪导数提供梯度信号,使得反向传播能够穿过跳跃点。”

通过这种方式,JumpReLU 在训练期间获得平滑的代理梯度。更令人瞩目的是,作者证明这些伪梯度精确对应于真实期望梯度的 核密度估计 (kernel density estimator, KDE) , 为该技巧提供了坚实的理论基础。

总的来说,JumpReLU 使不可微函数变得可微,从而可以直接实现 L0 稀疏训练。


实验与结果: 检验 JumpReLU

作者在 Gemma 2 9B 模型的多个位置与层 (包括残差流、注意力输出、MLP 输出) 上,将 JumpReLU 与 GatedTopK SAE 进行了对比评估。

稀疏性–保真度权衡

比较 JumpReLU、Gated 和 TopK SAE 在残差流上的 Delta LM Loss 与 L0 的散点图。

“在 Gemma 2 9B 各层中,JumpReLU 在任意稀疏度下都能实现与 TopK 与 Gated SAE 相等或更高的保真度。”

指标 Delta LM Loss 衡量当重建激活替代真实激活时语言模型预测准确性下降的程度,数值越低越好。在这些图中, 绿色的 JumpReLU 曲线 始终位于竞争者下方,表明在相同稀疏度下具备更好的重建保真度。结果在残差流、MLP 与注意力激活上均保持一致。


学习到的特征动态

高频特征 —— 某些架构会产生“始终激活”的特征,这些特征在许多 token 上都活跃,因此更难解释。

展示各 SAE 类型高频特征比例的图表。

“JumpReLU 与 TopK SAE 的高频特征略多于 Gated SAE,但这些特征占总字典不足 0.06%。”

JumpReLU 与 TopK 呈现相似的高频行为,但绝大部分特征仍保持稀疏且罕见。关键在于 JumpReLU 避免了“特征死亡”问题,无需在训练中重新采样。


可解释性研究

人工研究: 人类评估者检查特征激活与对应解释,以评估单义性 (monosemanticity) ——即每个特征是否代表单一且连贯的概念。

人类评估者对特征可解释性打分的条形图。

“人类评估者认为 JumpReLU 特征的可解释性与 Gated 或 TopK 相当。”

三种 SAE 类型评分相近,验证了 JumpReLU 的优势未以牺牲可解释性为代价。

自动研究: 使用 Gemini Flash,研究人员为每个特征生成文本解释,并测试预测激活是否与真实激活一致。

模拟与真实激活之间皮尔逊相关系数的提琴图。

“在自动评估中,JumpReLU 特征与语言模型模拟激活高度相关,表现与 Gated 和 TopK SAE 相当或更优。”

JumpReLU 在文本描述与真实激活模式之间实现了有意义的对应,验证了其语义清晰性。


结论: JumpReLU 的意义

JumpReLU 稀疏自编码器 超越了以往 SAE 设计,解决了稀疏训练中长期存在的问题。通过结合阈值线性激活与坚实的直通梯度估计,它带来了以下优势:

  • 最先进的保真度: 在各个稀疏度水平上实现比 Gated 或 TopK SAE 更精确的重建。
  • 高效率: 使用简单的逐元素操作——无需昂贵的排序或辅助损失——训练更快。
  • 可解释性: 保持与最佳替代方案相当的人工和自动评估清晰度。
  • 理论基础: 通过核密度估计,将 STE 训练与期望损失的真实梯度直接关联。

这一理念的意义远不止于 JumpReLU 本身。用期望损失的梯度训练不可微模型,可能开启新的架构——那些能直接优化诸如 L0 稀疏性或定制门控策略等不连续目标的模型。

随着社区持续探索大型模型如何表征抽象概念,像 JumpReLU 这样的技术让我们更接近实现高保真、高效率、真正可解释的特征发现——一次真正的飞跃。