为 Transformer 瘦身: 通过边剪枝揭示隐藏电路

像 GPT-4 和 Llama 这样的大型语言模型 (LLM) 功能强大,但同样神秘莫测。我们可以用它们写文章、生成代码、解决谜题,却很少知道它们是如何得出结论的。这种“黑箱”特性使得构建更安全、更可靠的 AI 系统变得艰难。

机制可解释性 (Mechanistic Interpretability) 旨在打开这个黑箱。它不再将模型视为不透明的整体,而是研究内部组件——如 Transformer 中的注意力头和 MLP——如何协同完成计算。该领域的一个核心概念是电路 (circuits) : 模型中捕捉特定行为的稀疏、集中的子图。

想象一下,你能单独分离出汽车中用来控制雨刷开关的微小电路。如果我们能够在 Transformer 内部找到负责语言任务 (如代词解析) 的对应子图,我们就可以对它进行研究、调试,甚至改进。

然而,寻找这些电路并不容易。早期的方法依赖人工检查,无法扩展。自动化工具如 ACDC 和 EAP 推动了该领域的发展,但要么对大型模型而言太慢,要么近似程度太高而不够稳定。

普林斯顿大学研究人员最近发表的论文 《Finding Transformer Circuits with Edge Pruning》 提出了一个更优雅且可扩展的解决方案。他们将电路发现问题重新定义为一个基于模型剪枝的优化问题。由此得到的算法称为 边剪枝 (Edge Pruning) , 能高效地找到精简、高质量的电路,并首次扩展到130 亿参数的模型


理解电路及其发现的挑战

Transformer 通过层级结构处理信息——在注意力和 MLP 模块间交替——并都通过一个残差流 (residual stream) 连接。在每一层 \(i\),模型通过如下方式更新内部状态:

\[ h_{i+1} = h_i + f_i(h_i), \]

其中 \(h_i\) 是残差流,\(f_i\) 是该层执行的操作。

我们可以将这些层及其交互看作一个计算图 。 每个组件 (注意力头或 MLP 块) 对应一个节点,节点之间的连接对应。一个电路就是这些边的一个子集——整个模型图的稀疏版本——仍能执行某种特定功能。

标准的 Transformer 将输出直接加到残差流中。边剪枝解耦了此流,以允许可学习的掩码识别关键边。

图 1: 边剪枝的核心概念。与密集的 Transformer (a) 不同,可学习的掩码沿模型边进行优化 (b),从而得到稀疏且忠实的电路 (c)。

为了确定哪些边重要,可解释性研究者采用一种称为交换消融 (interchange ablation) (又称激活补丁) 的因果技术。该过程包括:

  1. 使用一个干净输入 (如“Mary gave the ball to John.”) 运行模型并保存所有激活值。
  2. 使用一个扰动输入 (如“Amy gave the ball to David.”) 运行模型,这应产生不同输出。
  3. 将扰动输入的激活值替换到干净输入运行的特定边上。
  4. 观察模型输出变化——例如,如果预测从 John 变为 David,则该边至关重要。

最终目标是找到一个稀疏子图 \( \mathcal{C} \),使得电路输出分布 \( p_{\mathcal{C}}(y | x, \tilde{x}) \) 与完整模型的分布 \( p_{\mathcal{G}}(y | x) \) 尽量相近。优化目标是在保持稀疏性的同时最小化两者差异:

\[ \arg\min_{\mathcal{C}} \mathbb{E}_{(x, \tilde{x})}\left[D\left(p_{\mathcal{G}}(y | x) \parallel p_{\mathcal{C}}(y | x, \tilde{x})\right)\right], \quad \text{subject to } 1 - \frac{|\mathcal{C}|}{|\mathcal{G}|} \ge c. \]

以往方法的局限性

ACDC 使用贪婪搜索: 逐一通过消融测试每条边。该方法准确但代价高昂,尤其对数十亿参数的模型而言速度极慢。

EAP (Edge Attribution Patching) 则使用线性、基于梯度的近似方法来估计边的重要性。它更快,但忽略边之间的依赖,使得在复杂任务中准确度下降。

边剪枝通过将离散搜索转化为连续优化 , 解决了这些问题,结合了直接测试的准确性与基于梯度学习的高效性。


边剪枝方法

边剪枝在 Transformer 计算图的每条边上引入可训练掩码 (trainable masks) 。 它不是删除整个神经元,而是选择性剪去组件之间的连接

对于每条边 \( j \to i \),一个连续掩码 \( z_{ji} \in [0,1] \) 决定来自节点 \( j \) 的信号是否影响节点 \( i \)。在训练过程中:

  • \( z_{ji} = 1 \): 边保持激活,使用干净激活 \( y_j \)。
  • \( z_{ji} = 0 \): 边被剪除,替换为扰动激活 \( \tilde{y}_j \)。
  • 中间值将两者进行插值。
\[ y_i = f_i\!\left(z_{0i}y_0 + (1 - z_{0i})\tilde{y}_0 + \sum_{j这种结构允许边剪枝同时对所有掩码进行梯度下降,学习出哪些边是重现模型行为所必需的。

解耦的残差流

一个新挑战是: 每个组件都需要基于掩码对过去的激活进行独特组合。为提供这种灵活性,模型的残差流被解耦 (disentangled) ——不再维护单一向量,而是保存整个激活历史。当新层开始时,它动态聚合由边掩码加权的相关输入。此过程会增加内存需求,但对准确且细致的剪枝至关重要。

优化稀疏电路

为鼓励稀疏性,论文使用了 \( L_0 \) 正则化——一种惩罚非零掩码的技术。由于 \( L_0 \) 不可微,作者采用硬具体分布 (hard concrete distribution) , 使掩码值能在连续训练中向二元极端逼近:

\[ \begin{array}{c} u \sim \text{Uniform}(\epsilon, 1-\epsilon), \quad s = \sigma\!\left(\frac{1}{\beta}\log\frac{u}{1-u} + \log\alpha\right), \\ \tilde{s} = s(r - l) + l, \quad z = \min(1, \max(0,\tilde{s})). \end{array} \]

为达到目标稀疏度 \( t \),拉格朗日项用于调整训练压力:

\[ \mathcal{L}_s = \lambda_1 (t - s) + \lambda_2 (t - s)^2. \]

最终的损失函数将此正则项与忠实度损失 (即与完整模型的 KL 散度) 结合。训练后,连续掩码会被阈值化,形成二元连接——定义最终稀疏电路。


边剪枝的实验验证

作者在四个标准电路发现任务上使用 GPT-2 Small 对边剪枝进行了验证,并与 ACDC 和 EAP 方法进行比较。

衡量忠实度

忠实度反映电路行为与完整模型的吻合程度,以 KL 散度衡量——值越小越好。

ACDC、EAP 和边剪枝方法的 KL 散度与边稀疏度关系。

图 2: 在复杂任务 (如间接宾语识别 IOI 和“大于” GT) 上,边剪枝 (绿色) 显著降低 KL 散度——即比 ACDC (蓝色) 和 EAP (橙色) 忠实度更高。

在性别代词解析等简单任务上,边剪枝表现不俗;而在更复杂的 IOIGT 任务中,其性能远超其他方法,即使在高度稀疏的情况下仍具高忠实度。

电路性能

忠实电路固然重要——但它们是否仍能正确完成任务?

各方法性能比较。越高越好。

图 3: 边剪枝得到的电路不仅保持甚至提升了任务性能,同时显著提升了稀疏度。

结果显示,边剪枝明显优于其他方法。在 IOI 任务中,它以98.8% 稀疏度实现了与 ACDC 电路 (96.8% 稀疏度) 相同的性能。这意味着它利用少 2.65 倍的边达成同等保真度,从而生成更清晰、易解释的电路。

数据规模扩展

为测试可扩展性,作者将 IOI 数据集从几百条样本扩展至 100,000 条 。 运行时间与忠实度结果 (表 1) 显示,边剪枝不仅能胜任大型数据集,且扩展后表现卓越。

方法稀疏度 (%)↑KL ↓ (200)时间 (s) ↓KL ↓ (400)时间 (s) ↓KL ↓ (100K)时间 (s) ↓
ACDC96.6 ± 0.10.9218,7830.8840,759
EAP96.6 ± 0.13.47213.66433.7812,260
边剪枝96.6 ± 0.10.252,7560.222,9310.203,042

表 1: 边剪枝在 10 万样本规模下依旧快速且忠实度高,显著优于以往方法。

ACDC 随数据增加略有提升但耗时长;EAP 速度快但精度不足;边剪枝则兼具高效与质量,堪称大规模可解释性任务的理想选择。

恢复基准真值电路

作者进一步在 Tracr 框架中验证结果——该框架能将人工编写的算法编译为微型 Transformer 模型。这些“编译的 Transformer”拥有已知的基准真值电路,适宜用于验证。

边剪枝恢复的基准真值电路。

图 4: 边剪枝成功重建了两个 Tracr 编译程序——比例计数与列表反转的基准真值电路。

边剪枝完美恢复了这两个电路,证明其优化程序能稳定地发现真实底层计算机制。


扩展至 130 亿参数: CodeLlama 案例研究

绝大多数可解释性工具难以适配超过数亿参数的模型。为展示其实用扩展能力,研究人员将边剪枝应用于 CodeLlama-13B ——其规模比 GPT-2 大逾百倍

任务设定

研究考察了 CodeLlama 处理布尔表达式 (如“((not False) and True) or False is → False”) 的方式,并比较两种提示风格:

  • 指令提示 (IP) : 模型接受直接的指令。
  • 少样本提示 (FS) : 模型接受若干示例。

作者分别在两种设定下应用边剪枝,以识别负责推理的电路。

电路边数量 ↓准确率 (%) ↑完全匹配率 (%) ↑
IPFSIPFS
完整模型3,872,82082.0089.25100.00100.00
指令提示 (IP)1,04179.2574.5090.0079.00
少样本 (FS)1,46475.7587.2584.5091.25
IP ∩ FS65372.5068.2579.7572.50

表 2: 边剪枝分离出的电路稀疏度超过 99.96%,性能与完整模型几乎一致。

主要发现

  1. 极高稀疏性: 电路仅保留不到 0.04% 的边,但性能几乎不下降。
  2. 共享机制: 两种电路高度重叠——共有边达 62.7%——表明少样本和指令提示背后存在共享的推理结构。
  3. 跨场景鲁棒性: 少样本电路在指令提示下同样表现良好,暗示存在统一内部机制。

这一案例体现了突破性进展: 在万亿次计算规模上实现可扩展的可解释性,揭示大型模型中不同提示方式背后的共同神经“布线”。


意义与展望

边剪枝是机制可解释性领域的重要里程碑:

  • 准确性: 能比以往方法更精准地识别忠实电路。
  • 效率: 支持与数据线性扩展,并保持在大模型上可行。
  • 有效性: 可完美恢复已知电路,确保结果可靠。
  • 可扩展性: 能在前所未有的规模运行,揭示 130 亿参数 Transformer 内部的结构。

其主要限制在于内存开销——解耦残差流需额外 GPU 资源。此外,即使高度稀疏的大模型电路仍包含数千条边,使人工解析仍具挑战。但这是值得的权衡: 我们获得了模型内部机制前所未有的可视化洞察。

随着模型愈加庞大,诸如边剪枝之类的工具将成为研究人员的关键助手,帮助他们看见信息在神经网络中的流动与决策形成过程。能够在任意规模上识别并分析电路,让我们更接近一个既能高效使用 AI,又能深刻理解其运作的未来。