如何以无限 Batch Size 训练 CLIP: 突破显存瓶颈
在现代 AI 领域,特别是表征学习 (Representation Learning) 中,有一个反复出现的主题: 越大通常越好 。 对于像 CLIP (对比语言-图像预训练) 这样的对比学习模型来说尤其如此。这些模型背后的秘诀不仅仅是架构,更是数据,最重要的是模型一次能看到多少数据。
研究一致表明,更大的 batch size (批大小) 能带来更好的性能。更大的 batch size 提供了更多样化的“负”样本集 (即与文本不匹配的图像) ,迫使模型学习更清晰、更具判别性的特征。
但存在一个问题。这是一堵由硬件强加的巨大墙壁。
当你增加 batch size 时,计算损失函数所需的显存会爆炸式增长。它不是线性增长,而是二次方级增长。如果你将 batch size 翻倍,你的显存使用量可能会变成原来的四倍。最终,即使在 NVIDIA A100 或 A800 这样的顶级硬件上,你也会遇到“显存溢出” (OOM) 错误。
今天,我们将深入探讨一篇 CVPR 论文,名为 “Breaking the Memory Barrier of Contrastive Loss via Tile-Based Strategy” (通过分块策略突破对比损失的显存屏障) 。 这项背后的研究人员开发了一种名为 Inf-CL 的方法,从根本上改变了对比损失的计算方式。他们成功地将二次方显存曲线转变为线性曲线,从而实现了仅用 8 张 GPU 就能以高达 400 万 的 batch size 进行训练。

如上图 1 所示,当标准方法 (CLIP 和 OpenCLIP) 迅速触及显存天花板时,Inf-CL 几乎保持平稳。让我们来看看他们是如何实现显存成本缩减 78 倍的。
背景: 二次方陷阱
要理解解决方案,首先需要了解瓶颈所在。对比学习的工作原理是获取一批图像及其对应的文本描述。目标是最大化正确图像-文本对 (矩阵对角线) 之间的相似度,并最小化错误配对的相似度。
如果你有一个 batch size 为 \(b\) 的批次,你就拥有 \(b\) 张图像和 \(b\) 段文本。为了计算损失,模型必须计算该批次中每一张图像与每一段文本之间的相似度。这就产生了一个大小为 \(b \times b\) 的相似度矩阵。
朴素方法 (Vanilla Approach)
在标准的分布式训练设置中 (如原始 CLIP 论文中所使用的) ,流程如下:
- 每张 GPU 处理一小块图像和文本。
- 所有 GPU 通信以收集来自所有其他 GPU 的全部特征 (使用
AllGather操作) 。 - 每张 GPU 现在都持有整个全局 batch 的完整特征集。
- GPU 计算完整的 \(b \times b\) 相似度矩阵。
- 执行 Softmax 操作并计算交叉熵损失。
损失函数的公式如下:

这里,\(x_{i,j}\) 是图像 \(i\) 和文本 \(j\) 之间的相似度分数。\(\log\) 内部的项涉及对整个 batch \(b\) 的求和。
问题在哪? 为了执行该求和以及对数运算 (Log-Sum-Exp 或 LSE 操作) ,你需要在 GPU 的高带宽内存 (HBM) 中实例化那个巨大的 \(b \times b\) 矩阵。
如果 \(b = 64,000\),一个 \(64k \times 64k\) 的浮点数矩阵仅矩阵本身就需要约 16 GB 的显存。如果执行反向传播,你需要存储中间状态,这会使其膨胀到 66 GB 。 如果你试图达到 \(b=128k\),需求会翻四倍,瞬间让市面上几乎所有的 GPU 崩溃。

图 2(a) 展示了这个瓶颈。“Gather”步骤收集了所有数据,巨大的矩阵消耗了所有可用显存。
核心方法: Inf-CL
研究人员提出了 Inf-CL (无限对比学习) 。其核心洞察简单而深刻: 我们不需要一次性看到整个矩阵来计算总和。
像求和这样的数学运算是可累积的。你可以分块计算它们。Inf-CL 将巨大的矩阵计算划分为小的“分块 (tiles) ”,按顺序处理它们,累积结果,然后从内存中丢弃分块数据。
这种方法将空间复杂度从 \(O(b^2)\) (二次方) 改变为线性甚至更好,具体取决于实现方式。
1. 数学技巧: 分块 LSE
为了打破对完整矩阵的依赖,作者重写了损失函数。他们将正样本对与归一化项 (Log-Sum-Exp 部分) 分离开来:

困难的部分是第二项: \(\log \sum e^{x_{i,j}}\)。为了在不存储整行的情况下计算它,他们使用了流式更新规则。
想象一下,你想计算一行的 Log-Sum-Exp,但你只能分块接收数据。你可以维护一个运行中的 LSE 值。当新的分块到达时,你使用以下公式更新你的运行值:

这里:
- \(l^i\) 是当前累积的 LSE 值。
- \(l^{i,j}\) 是当前正在处理的小分块的 LSE 值。
通过从 \(j=1\) 到 \(n_c\) (列分块的数量) 进行迭代,你可以逐步构建最终的全局结果。你只需要存储当前的分块和运行中的累积向量。
数值稳定性: 计算指数 (\(e^x\)) 很容易导致溢出 (数字太大,计算机无法处理) 。标准做法是在求指数之前减去行中的最大值。作者也将这一点融入到了他们的分块策略中:

这确保了即使以小块进行处理,计算也能保持稳定。
2. 系统架构: 多级分块
数学技巧很棒,但在 GPU 集群上高效实现它需要巧妙的工程设计。作者引入了一种多级分块策略 , 以优化显存和速度。

如图 3 所示,该策略在两个层面上运作: 跨 GPU (设备间) 和 GPU 内 (芯片内) 。
第 1 级: 跨 GPU 分块 (环形结构)
在朴素方法中,每个 GPU 都会立即下载其他所有人的数据。这会导致数据存储的显存使用量激增。
Inf-CL 使用环形拓扑 (Ring Topology) 。
- 行分区: 每个 GPU 负责计算特定图像子集 (行) 的损失。
- 列轮转: 文本特征 (列) 在 GPU 之间形成环形传递。
- 计算与传递: GPU 1 计算其图像与其自身文本特征之间的相似度。然后,它将其文本特征发送给 GPU 2,并从 GPU 3 接收文本特征 (在 3-GPU 设置中) 。
- 累积: 它使用新数据更新运行中的 LSE 值。
这意味着在任何特定的毫秒内,GPU 仅持有一小部分数据。此外,通信 (发送/接收数据) 是在 GPU 忙于计算数学运算时异步发生的。这种“重叠”隐藏了通信延迟,因此系统不会因等待数据而闲置。
第 2 级: GPU 内分块 (融合算子)
即使在单个 GPU 内部,内存也是分层的。你有巨大但较慢的 HBM (高带宽内存) 和微小但超快的 SRAM (静态随机存取存储器,或共享内存) 。
在 HBM 和 SRAM 之间来回移动数据是昂贵的 (耗时且耗能) 。 标准的 PyTorch 操作会加载分块,计算矩阵乘法,写入 HBM,加载回来做指数运算,写入 HBM,加载回来求和……这非常低效。
Inf-CL 使用算子融合 (Kernel Fusion) 。 他们编写了自定义的 CUDA 核函数,能够:
- 将一小块图像/文本特征加载到 SRAM 中。
- 完全在 SRAM 内执行矩阵乘法、最大值减法、指数运算和求和。
- 仅将单个累积结果向量写回 HBM。
这大大减少了内存 I/O,使得分块方法与消耗大量显存的朴素方法一样快。
3. 分块反向传播
我们不能忘记反向传播 (梯度计算) 。在标准训练中,你必须存储前向传播的相似度矩阵以便稍后计算梯度。既然 Inf-CL 从不实例化完整矩阵,我们如何计算梯度?
答案是适用于分块的梯度检查点 (Gradient Checkpointing) 。 在反向传播期间,系统会为正在处理的特定分块重计算相似度分数,计算梯度贡献,累积它,然后再次丢弃分数。
梯度公式经过推导以支持这种累积:

这里,\(I'_i\) 是一个用于累积图像编码器梯度的临时变量。通过在反向传播期间动态重计算相似度 \(x_{i,j}\),显存成本保持线性,代价是计算量略有增加 (但这被优化的核函数所抵消) 。
实验与结果
那么,它有效吗?结果令人惊讶地好。
显存消耗
这篇论文的主要目标是减少显存使用。表 1 展示了对比结果。

看一看 8 张 A800 GPU 上的 128k batch size 列:
- CLIP (Vanilla): 失败 (显存溢出) 。
- OpenCLIP: 使用 62.37 GB (每张 GPU) 。
- Inf-CL: 仅使用 0.81 GB 用于损失计算。
由于损失计算的显存占用如此之小,主要的显存消耗者变成了数据本身 (存储输入图像和模型权重) 。这使得研究人员能够通过使用“数据卸载” (在不使用时将数据移动到 CPU 内存) 在单个 8 GPU 节点上将 batch size 推高至 4,096k (400 万) 。
显存复杂度从 \(O(b^2)\) 下降到 \(O(b/n^2)\) (其中 \(n\) 是 GPU 数量) ,这实际上是相对于每张 GPU 的 batch size 的线性扩展。
速度与效率
通常,节省显存是以时间为代价的。重计算梯度和分块处理听起来更慢。然而,图 4 展示了不同的情况。

- 左图: 每次迭代的时间 (秒) 。在较低 batch size 下,Inf-CL (蓝色柱) 与 OpenCLIP 和朴素 CLIP 几乎完全相同。
- 右图: 每个 epoch 的总训练时间。无论 batch size 如何扩展,它都稳定保持在约 59 小时左右。
为什么它这么快?两个原因:
- 算子融合: 自定义的 CUDA 核函数经过高度优化,减少了内存带宽瓶颈。
- 重叠: 环形通信与计算同时发生。
最大 Batch Size
我们能推多远?表 2 提供了不同方法的“极限点”。

在 32 张 GPU 上,OpenCLIP 的最大 batch size 为 352k 。 Inf-CL 利用数据卸载,可以达到 12,288k (1200 万) 。 这实际上消除了损失函数作为训练大模型限制因素的影响。
这会影响精度吗?
对于近似或分块方法,一个常见的担忧是数值精度可能会受损,导致模型变差。

表 3 比较了 ImageNet 上的零样本 (zero-shot) 精度。
- Vanilla (64k): 74.74%
- Inf-CL (64k): 74.93%
结果在统计上是等效的。这证实了 Inf-CL 在数学上是精确的;它不是近似值。
有趣的是,作者指出,简单地将 batch size 增加到 100 万并不会自动赋予“超能力”。如表所示,在 1024k batch size 时,性能饱和甚至略有下降。这表明,虽然硬件障碍已被打破,但我们现在面临着调优障碍——社区需要重新研究这种超大 batch 下的超参数。

图 5 (上图) 强化了这一点。较大的数据集 (如 LAION-400M) 比较小的数据集更能从大 batch size 中受益。随着数据集规模的持续增长,线性扩展 batch size 的能力将变得至关重要。
结论
论文 “Breaking the Memory Barrier of Contrastive Loss via Tile-Based Strategy” 为大型多模态模型带来了重大的工程突破。
通过重新思考对比损失函数的实现,作者成功地将显存使用与 batch size 解耦。
- 分解: 他们将 Log-Sum-Exp 操作分解为独立的、可累积的分块。
- 分发: 他们利用环形拓扑在 GPU 之间分发计算,而没有产生巨大的显存峰值。
- 优化: 他们使用底层算子融合来最小化内存带宽使用。
其结果是 Inf-CL , 一种将对比学习的显存成本从二次方瓶颈转变为可控的线性开销的方法。对于学生和研究人员来说,这意味着对比损失出现“显存溢出”错误的日子屈指可数了。我们现在可以专注于从海量 batch 中学习的科学,而不是纠结于如何将它们塞进芯片的后勤工作。
这项工作为下一代基础模型铺平了道路,使得数百万级的 batch size 成为新标准。
](https://deep-paper.org/en/paper/file-1943/images/cover.png)