大型语言模型 (LLM) 席卷全球,其卓越能力源于一个看似简单的原则:** 预测下一个词**。这种被称为自回归生成或输入空间重构的方法,已成为GPT、Llama和Gemma等模型的基石。
但如果LLM训练的这块基石也是一种限制呢?
在计算机视觉领域,研究人员发现,摒弃原始的像素重构,转而在更抽象的嵌入空间中进行训练,可以获得更优的效果。这里的主流范式是联合嵌入预测架构 (Joint Embedding Predictive Architecture, JEPA) ,它鼓励模型理解图像的本质,而非记忆表面的细节。
这一在视觉领域的成功引出了一个关键问题:
LLM能从它们的视觉同行那里学几招吗?
最近的一篇论文——《LLM-JEPA: 当大型语言模型遇上联合嵌入预测架构》——迈出了弥合这一差距的坚实第一步。作者提出了LLM-JEPA,一种将JEPA的预测能力整合到LLM训练中的新方法。结果是: 模型不仅保留了生成能力,还发展出更深、更稳健的表示,从而在各种任务中取得显著性能提升。
让我们深入了解其原理。
背景: 什么是JEPA?语言中的“视图”是什么?
要理解LLM-JEPA,我们首先需要清楚JEPA的原始概念。
假设你有同一只猫的两张照片——一张是正面,一张是侧面。传统的基于重构的模型可能会尝试根据第一张照片预测第二张照片的精确像素值,这是一项艰难的任务,且会浪费模型容量在诸如地毯纹理或光照这样的无关细节上。
JEPA则跳过像素预测,而是:
- 将每张图像编码为一个嵌入——即高维向量表示;
- 根据一张图像的嵌入预测另一张图像的嵌入。
这样,模型就能捕捉“猫性”的本质——形态、姿势、毛发纹理——并忽略无关噪音。
这些相关的输入被称为视图 (views) 。
在视觉领域,通过数据增强 (裁剪、旋转、重色) 很容易创建多个视图。但我们该如何为文本定义“视图”呢?
这正是 LLM-JEPA 的核心洞见:
许多自然语言任务天生就能提供同一底层概念的多个不同视图。
以软件开发者的工作流为例:
- 文本视图: 一份用自然语言写的错误报告,例如: “登录按钮在移动应用上无法使用。”
- 代码视图: 修复该错误的代码差异或补丁。
这是两个视图——对同一解决方案的不同表达。在其他领域也存在类似的组合: 自然语言 ➜ SQL查询,自然语言 ➜ 正则表达式等。
图2: 左: 以文本和代码作为同一概念两个视图的JEPA框架。右: 来自NL-RX-SYNTH (自然语言 ↔ 正则表达式) 和Spider (自然语言 ↔ SQL) 任务的示例。
通过将(文本, 代码)
视作同一底层知识的两个视图,我们可以将JEPA的理念应用到LLM上。其思路是:** 根据文本的嵌入预测代码的嵌入**。
LLM-JEPA的目标: 两种损失,一个目的
LLM-JEPA的精妙之处在于它是对标准LLM损失的增强,而不是替代。
1. 保留生成能力
LLM仍使用下一个词元预测损失进行训练,表示为:
\[ \mathcal{L}_{\text{LLM}}(\text{Text}_{1:L-1}, \text{Text}_L) = \text{XEnt}\left(\text{Classifier}\left(\text{Enc}(\text{Text}_{1:L-1})\right), \text{Text}_L\right) \]该交叉熵损失保证模型能生成连贯的文本,保持原有的生成能力。
2. 增加抽象能力
JEPA部分增加了一个嵌入预测项,总损失为:
\[ \mathcal{L}_{\text{LLM-JEPA}} = \underbrace{\sum_{\ell=2}^{L} \mathcal{L}_{\text{LLM}}(\text{Text}_{1:\ell-1}, \text{Text}_\ell)}_{\text{Generative (LLM)}} + \lambda \times \underbrace{d\left(\text{Pred}(\text{Enc}(\text{Text})), \text{Enc}(\text{Code})\right)}_{\text{Predictive (JEPA)}} \]方程2: LLM-JEPA将词元级生成与跨视图的嵌入预测相结合,并由 \(\lambda\) 平衡。
JEPA项拆解如下:
- 编码器 (
Enc
): 最后一层最后一个词元的隐藏状态表示输入的嵌入。Enc(Text)
和Enc(Code)
通过独立的两次前向传播计算。 - 预测器 (
Pred
): 一个权重共享的预测器复用LLM自身的层。添加特殊[PRED]
词元可让模型在内部预测代码嵌入。[PRED]
词元数量 \(k\) 可调。 - 度量 (
d
): 用余弦相似度 (或L2距离) 来衡量预测的嵌入与真实代码嵌入的接近程度。
因此,在训练过程中模型会:
- 通过下一个词元预测生成文本;
- 预测配对视图的嵌入。
实验与结果
作者在**四类LLM家族 (Llama3、Gemma2、OpenELM、OLMo) 和多个数据集 **(NL-RX-SYNTH、NL-RX-TURK、GSM8K、Spider) 上验证了LLM-JEPA。
JEPA损失是必需的吗?
如果下一个词元预测已能最小化嵌入预测误差,那么JEPA就是冗余的。
图4证明事实并非如此:
图4: 在基线模型中,JEPA损失 (红线) 基本不变,而LLM损失下降。LLM-JEPA (绿线) 主动降低JEPA损失,说明它提供了额外的训练信号。
更强的微调结果
图1: 左: LLM-JEPA在各数据集上的准确率提升 (如NL-RX-SYNTH上提升约15%) 。右: 在基线模型达到峰值后,LLM-JEPA依然能抵抗过拟合并持续提升。
LLM-JEPA在所有任务上都稳定优于基线模型。例如:
- NL-RX-SYNTH (Llama3): 基线约57% → LLM-JEPA约72%
- GSM8K (Llama3): 基线约32% → LLM-JEPA约36%
其正则化效果在LoRA微调中更为突出:
图5: LLM-JEPA保持了准确率的上升趋势,而基线模型则在过拟合后精度下降。
更好的预训练
即使在微调时仅使用标准LLM损失,使用LLM-JEPA进行预训练也能改进下游任务表现。
- 在NL-RX-SYNTH上从零开始预训练可提升准确率 (表1) 。
- 在释义数据集上,将不同释义视作视图,JEPA预训练提升了在Rotten Tomatoes和Yelp情感分类上的准确率 (表4) 。
洞察: 嵌入空间发生了什么变化?
作者使用t-SNE对嵌入进行了可视化:
图6: 基线模型的嵌入形成分离且无结构的簇;LLM-JEPA将文本与代码嵌入对齐为连贯且结构化的空间。
进一步分析表明,LLM-JEPA促使文本到代码嵌入之间形成近似线性映射。
这一点通过如下结果得以验证:
图7: LLM-JEPA (蓝/绿) 的奇异值比基线小数个数量级——文本与代码的映射受到了严格约束。
表10报告了LLM-JEPA嵌入的最小二乘回归误差接近零,再次印证了线性假设。
核心要点
LLM-JEPA:
- 在不同模型规模与数据集的微调任务中显著提升性能;
- 作为强力正则项,在全参数微调和参数高效微调中都能抵抗过拟合;
- 构建结构化的表示空间,将不同视图线性地对齐到低维子空间中。
局限性与未来方向
主要的权衡是训练过程中更高的计算量: 为多个视图生成嵌入目前需要约3倍的前向传播次数。作者建议使用掩码自注意力机制进行优化,以便一次性计算所有视图。
最令人期待的前沿是大规模预训练。如果JEPA式目标能持续带来这些提升,它们很可能成为训练方案的标准组件,推动LLM向更深、更抽象、更接近人类理解的方向演进。
总结:
借鉴计算机视觉的成功经验,LLM-JEPA将联合嵌入预测整合进LLM训练,带来了更强的泛化能力、更丰富的表示以及抗过拟合的特性——且并未牺牲文本生成能力。结果充分证明,嵌入空间目标可能是LLM训练的下一步进化方向。