How to Train CLIP with Infinite Batch Sizes: Breaking the Memory Barrier

In modern AI, and in representation learning especially, there is a recurring theme: bigger is usually better. This is particularly true for contrastive learning models like CLIP (Contrastive Language-Image Pre-training). The secret sauce behind these models isn’t just the architecture; it’s the data, and, more importantly, how much data the model sees at once.

Research has consistently shown that larger batch sizes lead to better performance. A larger batch provides a more diverse set of “negative” samples (images that don’t match the text), forcing the model to learn much sharper, more discriminative features.

But there is a problem. A massive, hardware-imposed wall.

As you increase the batch size, the memory required to compute the loss function explodes. It doesn’t grow linearly; it grows quadratically, because the loss needs a similarity score for every image-text pair in the batch. Double the batch size and the memory for those scores quadruples. Eventually, you hit an “Out of Memory” (OOM) error, even on state-of-the-art hardware like NVIDIA A100s or A800s.

Today, we are doing a deep dive into a CVPR paper titled “Breaking the Memory Barrier of Contrastive Loss via Tile-Based Strategy.” The researchers behind this work have developed a method called Inf-CL, which fundamentally changes how contrastive loss is calculated. They managed to turn that quadratic memory curve into a linear one, allowing them to train with batch sizes of up to 4 million on just 8 GPUs.

Figure 1: GPU memory usage comparison between Inf-CL and previous methods. The dashed line marks the common GPU memory limit.

As shown in Figure 1 above, while standard methods (CLIP and OpenCLIP) hit the memory ceiling rapidly, Inf-CL stays nearly flat. Let’s explore how they achieved this 78x reduction in memory costs.

Background: The Quadratic Trap

To understand the solution, we first need to understand the bottleneck. Contrastive learning works by taking a batch of images and their corresponding text captions. The goal is to maximize the similarity between the correct image-text pairs (the diagonal of a matrix) and minimize the similarity of incorrect pairs.

If you have a batch size of \(b\), you have \(b\) images and \(b\) texts. To compute the loss, the model must calculate the similarity between every image and every text in that batch. This results in a Similarity Matrix of size \(b \times b\).

The Vanilla Approach

In a standard distributed training setup (like the one used in the original CLIP paper), the process looks like this:

  1. Each GPU processes a small chunk of images and texts.
  2. All GPUs communicate to gather all features from every other GPU (using an AllGather operation).
  3. Each GPU now holds the full feature set for the entire global batch.
  4. The GPU computes the full \(b \times b\) similarity matrix.
  5. It performs a Softmax operation and calculates the Cross-Entropy loss.

The equation for this loss looks like this:

\[\mathcal{L} = -\frac{1}{b}\sum_{i=1}^{b} \log \frac{e^{x_{i,i}}}{\sum_{j=1}^{b} e^{x_{i,j}}}\]

Here, \(x_{i,j}\) is the similarity score between image \(i\) and text \(j\). The term inside the \(\log\) involves a summation over the entire batch \(b\).
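To make the bottleneck concrete, here is a minimal pure-Python sketch of this loss. It covers only the image-to-text direction; CLIP also adds the symmetric text-to-image term. The sketch assumes the features are already normalized and temperature-scaled, so \(x_{i,j}\) is a plain dot product. The helper names are illustrative and do not come from the paper’s code.

```python
import math
from typing import List

Matrix = List[List[float]]


def dot(u: List[float], v: List[float]) -> float:
    """Similarity score between one image feature and one text feature."""
    return sum(a * b for a, b in zip(u, v))


def vanilla_contrastive_loss(images: Matrix, texts: Matrix) -> float:
    """Image-to-text contrastive loss, computed the memory-hungry way."""
    b = len(images)
    # Materialize the full b x b similarity matrix: O(b^2) memory.
    x = [[dot(images[i], texts[j]) for j in range(b)] for i in range(b)]
    total = 0.0
    for i in range(b):
        denom = sum(math.exp(v) for v in x[i])  # sum over the whole batch
        total += math.log(math.exp(x[i][i]) / denom)
    return -total / b


# Two perfectly aligned pairs: each row is softmax([1, 0]).
print(vanilla_contrastive_loss([[1.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]]))  # ≈ 0.3133
```

Everything this sketch does is easy except the line that builds `x`. That single list of lists is what grows with \(b^2\).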

The problem? To perform that summation and the logarithm (the Log-Sum-Exp or LSE operation), you need to instantiate that massive \(b \times b\) matrix in the GPU’s High Bandwidth Memory (HBM).

If \(b = 64,000\), a \(64k \times 64k\) matrix of 32-bit floating-point numbers requires about 16 GB of memory just for the matrix itself (\(64{,}000^2 \times 4\) bytes \(\approx 16.4\) GB). If you perform backpropagation, you need to store intermediate states, ballooning this to 66 GB. If you try to go to \(b=128k\), the requirement quadruples, instantly crashing almost any GPU on the market.
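The 16 GB figure follows directly from the matrix size. Here is a quick check. It assumes 4-byte (fp32) entries and counts only the forward similarity matrix, not the extra backward-pass buffers behind the 66 GB figure.

```python
def similarity_matrix_gb(b: int, bytes_per_element: int = 4) -> float:
    """Size in GB (10^9 bytes) of one dense b x b matrix."""
    return b * b * bytes_per_element / 1e9


for b in (64_000, 128_000):
    print(f"b = {b}: {similarity_matrix_gb(b):.1f} GB")
# b = 64000: 16.4 GB
# b = 128000: 65.5 GB
```

Doubling \(b\) multiplies the result by exactly four, which is the quadratic trap in one line of arithmetic.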

Figure 2: Comparison of Vanilla Implementation vs Inf-CL.

Figure 2(a) illustrates this bottleneck. The “Gather” step collects everything, and the massive matrix consumes all available memory.

The Core Method: Inf-CL

The researchers propose Inf-CL (Infinite Contrastive Learning). The core insight is simple yet profound: We do not need to see the entire matrix at once to calculate the sum.

Operations like summation are associative, so they can be computed in pieces and combined afterwards. Inf-CL partitions the massive matrix calculation into small “tiles,” processes them sequentially, accumulates the results, and then discards the tile data from memory.

This approach changes the memory needed for the loss from \(O(b^2)\) (quadratic) to linear in \(b\). At any time, only one tile and one running value per row need to be kept.

1. The Mathematical Trick: Tiled LSE

To break the dependency on the full matrix, the authors rewrite the loss function. They separate the positive pairs from the normalization term (the Log-Sum-Exp part):

\[\mathcal{L} = -\frac{1}{b}\sum_{i=1}^{b} x_{i,i} + \frac{1}{b}\sum_{i=1}^{b} \log \sum_{j=1}^{b} e^{x_{i,j}}\]
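The decomposition is easy to check numerically. This pure-Python sketch (illustrative names) computes the positive-pair term and the per-row Log-Sum-Exp term separately on a small hand-written similarity matrix. It then compares their combination with the original softmax form. The row LSE subtracts the row maximum for numerical stability.

```python
import math
from typing import List


def log_sum_exp(row: List[float]) -> float:
    """log(sum(exp(row))), with the row max subtracted for stability."""
    m = max(row)
    return m + math.log(sum(math.exp(v - m) for v in row))


def decomposed_loss(x: List[List[float]]) -> float:
    b = len(x)
    positive_term = sum(x[i][i] for i in range(b)) / b       # diagonal only
    lse_term = sum(log_sum_exp(x[i]) for i in range(b)) / b  # one LSE per row
    return -positive_term + lse_term


def softmax_form_loss(x: List[List[float]]) -> float:
    b = len(x)
    return -sum(
        math.log(math.exp(x[i][i]) / sum(math.exp(v) for v in x[i]))
        for i in range(b)
    ) / b


x = [[0.5, -1.2, 2.0], [0.3, 0.1, -0.7], [1.5, 0.2, 0.9]]
print(decomposed_loss(x), softmax_form_loss(x))  # equal up to rounding
```

The positive term touches only the \(b\) diagonal entries. The LSE term is the only part that needs all \(b^2\) scores, which is why it is the one that gets tiled.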

The difficult part is the second term: \(\log \sum e^{x_{i,j}}\). To compute this without storing the whole row, they use a streaming update rule.

Imagine you want to calculate the Log-Sum-Exp of a row, but you only receive the data in chunks (tiles). You can maintain a running LSE value. When a new tile arrives, you update your running value using this formula:

\[l^{i} \leftarrow \log\left(e^{l^{i}} + e^{l^{i,j}}\right) = l^{i} + \log\left(1 + e^{\,l^{i,j} - l^{i}}\right)\]

Here:

  • \(l^i\) is the current accumulated LSE value.
  • \(l^{i,j}\) is the LSE value of the current small tile being processed.

By iterating through \(j=1\) to \(n_c\) (the number of column tiles), you can build the final global result step-by-step. You only ever need to store the current tile and the running accumulation vector.
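The streaming update can be sketched for a single row as follows. This is a minimal pure-Python illustration, not the authors’ implementation, and `tile_size` is a hypothetical parameter. `merge_lse` evaluates the update rule with the larger of the two values factored out, so the exponent it passes to `exp` is never positive.

```python
import math
from typing import List


def tile_lse(values: List[float]) -> float:
    """LSE of one tile (max-subtracted)."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def merge_lse(running: float, tile: float) -> float:
    """Return log(e^running + e^tile) without overflow."""
    if running == -math.inf:  # nothing accumulated yet
        return tile
    hi, lo = max(running, tile), min(running, tile)
    return hi + math.log1p(math.exp(lo - hi))


def streaming_row_lse(row: List[float], tile_size: int) -> float:
    running = -math.inf  # LSE of an empty set
    for start in range(0, len(row), tile_size):
        tile = row[start:start + tile_size]  # only this tile is "in memory"
        running = merge_lse(running, tile_lse(tile))
    return running


row = [0.1 * k - 2.0 for k in range(10)]
for size in (1, 3, 4, 10):
    print(size, streaming_row_lse(row, size))  # same value for every tile size
```

The result does not depend on the tile size. That is the property that lets the tile size be chosen purely for memory and speed.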

Numerical Stability: Calculating exponentials (\(e^x\)) can easily lead to overflow (numbers too big for the computer to handle). Standard practice is to subtract the maximum value in the row before exponentiating. The authors incorporate this into their tiling strategy as well:

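\[m^{i,j} = \operatorname{rowmax}\left(X^{i,j}\right), \qquad l^{i,j} = m^{i,j} + \log \sum_{k} e^{X^{i,j}_{:,k} - m^{i,j}}\]

Here \(X^{i,j}\) is the tile of similarity scores for row block \(i\) and column block \(j\), and \(m^{i,j}\) holds the maximum of each of its rows.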

This ensures that the calculation remains stable even when processed in small chunks.
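The effect of max subtraction is easy to see with large scores. In Python, `math.exp` raises `OverflowError` rather than returning infinity. On a GPU, the same overflow would silently produce `inf`. In this sketch, the naive version fails, while the max-subtracted tile LSE returns \(1000 + \log 2\).

```python
import math
from typing import List


def naive_lse(values: List[float]) -> float:
    return math.log(sum(math.exp(v) for v in values))  # exp(1000) overflows


def stable_tile_lse(values: List[float]) -> float:
    m = max(values)  # row max of the tile
    return m + math.log(sum(math.exp(v - m) for v in values))  # exponents <= 0


scores = [1000.0, 1000.0]
try:
    naive_lse(scores)
except OverflowError:
    print("naive LSE overflowed")
print(stable_tile_lse(scores))  # ≈ 1000.693, i.e. 1000 + log(2)
```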

2. The System Architecture: Multi-Level Tiling

The mathematical trick is great, but implementing it efficiently on a cluster of GPUs requires clever engineering. The authors introduce a Multi-Level Tiling Strategy that optimizes for both memory and speed.

Figure 3: Multi-level tiling strategy.

As visualized in Figure 3, the strategy operates on two levels: Cross-GPU (between devices) and In-GPU (inside the chip).

Level 1: Cross-GPU Tiling (The Ring)

In the vanilla method, every GPU downloads everyone else’s data immediately. This causes a massive spike in memory usage for data storage.

Inf-CL uses a Ring Topology.

  1. Partition Rows: Each GPU is responsible for computing the loss for a specific subset of images (rows).
  2. Rotate Columns: The text features (columns) are passed from GPU to GPU in a ring.
  3. Compute & Pass: GPU 1 calculates the similarity between its images and its own text features. Then, it sends its text features to GPU 2 and receives text features from GPU 3 (in a 3-GPU setup).
  4. Accumulate: It updates the running LSE value with the new data.
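The ring schedule can be simulated on one machine. In this hypothetical sketch, each “GPU” is just a list index and the send/receive step is a list rotation. No real communication library is involved. After \(n\) steps, every GPU has seen every text shard once. Its running LSE values should match a full, non-tiled computation. The positive-pair term is left out, assuming, as in a standard data-parallel setup, that each image’s matching text lives in the same shard.

```python
import math
from typing import List

Vec = List[float]


def dot(u: Vec, v: Vec) -> float:
    return sum(a * b for a, b in zip(u, v))


def merge_lse(a: float, b: float) -> float:
    """log(e^a + e^b), stable; a = -inf means 'nothing accumulated yet'."""
    if a == -math.inf:
        return b
    hi, lo = max(a, b), min(a, b)
    return hi + math.log1p(math.exp(lo - hi))


def ring_lse(image_shards: List[List[Vec]], text_shards: List[List[Vec]]) -> List[List[float]]:
    n = len(image_shards)
    # One running LSE value per image row, per GPU.
    running = [[-math.inf] * len(shard) for shard in image_shards]
    held = list(range(n))  # held[r]: index of the text shard currently on GPU r
    for step in range(n):
        for r in range(n):  # on real hardware these run in parallel
            texts = text_shards[held[r]]
            for i, img in enumerate(image_shards[r]):
                scores = [dot(img, t) for t in texts]
                m = max(scores)
                tile = m + math.log(sum(math.exp(s - m) for s in scores))
                running[r][i] = merge_lse(running[r][i], tile)
        if step < n - 1:
            # Ring pass: GPU r receives the shard previously held by GPU r-1.
            held = [held[(r - 1) % n] for r in range(n)]
    return running
```

With three GPUs, GPU 0 receives from GPU 2 and GPU 1 receives from GPU 0. This matches the “send to the next, receive from the previous” pattern described above.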

This means a GPU only holds a fraction of the data at any given moment. Furthermore, the communication (sending and receiving data) happens asynchronously while the GPU is busy computing. This “overlap” hides the communication latency, so the system doesn’t sit idle waiting for data.
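The overlap is a double-buffering pattern. Each step starts receiving the next shard, computes on the current one, and waits for the transfer only when its data is actually needed. The sketch below imitates this with a background thread. `fetch_shard` is a hypothetical stand-in for an asynchronous receive; in PyTorch, non-blocking point-to-point calls such as `torch.distributed.isend`/`irecv` could fill this role. The compute step is a placeholder function.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, List, Sequence


def fetch_shard(shards: Sequence[List[float]], idx: int) -> List[float]:
    """Hypothetical stand-in for an asynchronous receive from the ring neighbor."""
    return list(shards[idx])


def process_with_prefetch(
    shards: Sequence[List[float]],
    order: List[int],
    compute: Callable[[List[float]], float],
) -> List[float]:
    results = []
    with ThreadPoolExecutor(max_workers=1) as pool:
        pending = pool.submit(fetch_shard, shards, order[0])
        for k in range(len(order)):
            current = pending.result()  # blocks only if the transfer is still running
            if k + 1 < len(order):
                # Start the next transfer before computing on the current shard.
                pending = pool.submit(fetch_shard, shards, order[k + 1])
            results.append(compute(current))  # runs while the next fetch proceeds
    return results


print(process_with_prefetch([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], [2, 0, 1], sum))
# [11.0, 3.0, 7.0]
```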

Level 2: In-GPU Tiling (Fused Kernels)

Even inside a single GPU, memory is hierarchical. You have the massive but slower HBM (High Bandwidth Memory) and the tiny but ultra-fast SRAM (Static Random Access Memory, or shared memory).

Moving data back and forth between HBM and SRAM is expensive in both time and energy. A sequence of standard, unfused PyTorch operations would first load the tile, compute the matrix multiplication, and write the result to HBM. It would then load that result back to exponentiate it, write it out again, and load it once more to sum it. Each of those round trips wastes memory bandwidth.

Inf-CL uses Kernel Fusion. They wrote custom CUDA kernels that:

  1. Load a small tile of image/text features into SRAM.
  2. Perform matrix multiplication, maximum subtraction, exponentiation, and summation entirely within SRAM.
  3. Only write the single accumulated result vector back to HBM.
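This pure-Python sketch illustrates what “staying inside SRAM” means for one row. It uses the online-softmax normalization trick, a standard technique in fused kernels that keeps a running maximum and a rescaled running sum. It illustrates the idea only; it is not the authors’ CUDA code. Each score is produced, folded into two scalars, and discarded. The only value “written back” is the row’s final LSE.

```python
import math
from typing import List


def fused_row_lse(image: List[float], texts: List[List[float]]) -> float:
    m = -math.inf  # running max of the scores seen so far
    s = 0.0        # running sum of exp(score - m)
    for t in texts:
        x = sum(a * b for a, b in zip(image, t))  # one matmul entry, never stored
        if x > m:
            s = s * math.exp(m - x) + 1.0  # rescale the old sum to the new max
            m = x
        else:
            s += math.exp(x - m)
    return m + math.log(s)  # the single value written back to HBM


img = [0.6, -0.8]
txts = [[0.2, 0.1], [-0.5, 0.4], [0.9, -0.3], [0.0, 1.0]]
print(fused_row_lse(img, txts))
```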

This drastically reduces memory I/O, so the tile-based approach runs at roughly the same speed as the memory-hungry vanilla approach.

3. Tiled Backpropagation

We cannot forget the backward pass (gradient calculation). In standard training, you have to store the similarity matrix from the forward pass to compute gradients later. Since Inf-CL never materializes the full matrix, how do we compute gradients?

The answer is Gradient Checkpointing, adapted for tiles. During the backward pass, the system re-computes the similarity scores for the specific tile being processed, calculates the gradient contribution, accumulates it, and then discards the scores again.

The gradient equations are derived to support this accumulation:

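\[I'_i \leftarrow I'_i + \sum_{k \in \mathcal{C}_j} e^{x_{i,k} - l_i}\, T_k \quad (j = 1, \dots, n_c), \qquad \frac{\partial \mathcal{L}}{\partial I_i} = -\frac{1}{b}\left(T_i - I'_i\right)\]

Here \(I_i\) and \(T_k\) are the image and text features (with \(x_{i,k} = I_i^\top T_k\)), \(l_i\) is the final LSE of row \(i\) saved from the forward pass, and \(\mathcal{C}_j\) is the set of columns in tile \(j\).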

Here, \(I'_i\) is a temporary variable used to accumulate gradients for the image encoder. By re-computing the similarity \(x_{i,j}\) on the fly during the backward pass, the memory cost remains linear, at the expense of a slight increase in computation (which is offset by the optimized kernels).
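Here is a minimal sketch of the tiled backward pass for the image features. It assumes \(x_{i,j} = I_i \cdot T_j\) (normalized, temperature-scaled features) and covers only the image-to-text loss. The forward step keeps only the per-row LSE values. The backward loop recomputes each score tile by tile, converts it into the softmax weight \(e^{x_{i,j} - l_i}\), and accumulates the weighted text features into \(I'_i\). The function names and `tile_size` are illustrative.

```python
import math
from typing import List

Vec = List[float]


def dot(u: Vec, v: Vec) -> float:
    return sum(a * b for a, b in zip(u, v))


def row_lse(img: Vec, texts: List[Vec]) -> float:
    scores = [dot(img, t) for t in texts]
    m = max(scores)
    return m + math.log(sum(math.exp(s - m) for s in scores))


def tiled_image_grads(images: List[Vec], texts: List[Vec], tile_size: int) -> List[Vec]:
    """dL/dI_i for L = -(1/b) * sum_i (x_ii - l_i), never storing the b x b matrix."""
    b, d = len(images), len(images[0])
    lse = [row_lse(img, texts) for img in images]  # forward-pass state: O(b)
    grads = []
    for i, img in enumerate(images):
        acc = [0.0] * d  # I'_i, accumulated over column tiles
        for start in range(0, b, tile_size):
            for j in range(start, min(start + tile_size, b)):
                p = math.exp(dot(img, texts[j]) - lse[i])  # recomputed, then dropped
                for k in range(d):
                    acc[k] += p * texts[j][k]
        grads.append([-(texts[i][k] - acc[k]) / b for k in range(d)])
    return grads
```

A finite-difference check against the full loss is a cheap way to confirm that recomputation gives the same gradient as storing the matrix would.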

Experiments & Results

So, does it work? The results are startlingly positive.

Memory Consumption

The primary goal of this paper was to reduce memory usage. Table 1 presents the comparison.

Table 1: Training memory cost.

Look at the 128k batch size column on 8 A800 GPUs:

  • CLIP (Vanilla): Fails (Out of Memory).
  • OpenCLIP: Uses 62.37 GB (per GPU).
  • Inf-CL: Uses 0.81 GB for the loss calculation.

Because the loss memory footprint is so small, the main memory consumer becomes the data itself (storing the input images and model weights). This allowed the researchers to push batch sizes to 4,096k (4 million) on a single node of 8 GPUs by using “data offloading” (moving data to CPU RAM when not in use).

The memory complexity drops from \(O(b^2)\) to \(O(b/n^2)\) (where \(n\) is the number of GPUs), which grows only linearly with the batch size.

Speed and Efficiency

Usually, saving memory comes at the cost of time. Re-computing gradients and processing in tiles sounds slower. However, Figure 4 shows a different story.

Figure 4: Training speed comparison.

  • Left Chart: The time per iteration (in seconds). Inf-CL (blue bars) is virtually identical to OpenCLIP and vanilla CLIP at lower batch sizes.
  • Right Chart: Total training time per epoch. It remains stable at around 59 hours regardless of batch size scaling.

Why is it fast? Two reasons:

  1. Kernel Fusion: The custom CUDA kernels are highly optimized, reducing memory bandwidth bottlenecks.
  2. Overlapping: The ring communication happens at the same time as the computation.

Maximum Batch Size

How far can we push it? Table 2 provides the “breaking points” for different methods.

Table 2: Maximum batch size.

On 32 GPUs, OpenCLIP maxes out at a batch size of 352k. Inf-CL, utilizing data offloading, can reach 12,288k (12 million). This effectively removes the loss function as the limiting factor in training large models.

Does it affect Accuracy?

A common fear with approximation or tiling methods is that numerical precision might suffer, leading to worse models.

Table 3: Performance verification.

Table 3 compares the zero-shot accuracy on ImageNet.

  • Vanilla (64k): 74.74%
  • Inf-CL (64k): 74.93%

The two results are essentially the same (a 0.19-point difference). That is what you would expect: Inf-CL only reorders the computation of the exact loss, so it is not an approximation.

Interestingly, the authors note that simply increasing the batch size to 1 million doesn’t automatically grant “superpowers.” As shown in the table, performance saturated and even dipped slightly at 1024k batch size. This suggests that while the hardware barrier is broken, we now face a tuning barrier—hyperparameters for such massive batches need to be re-investigated by the community.

Figure 5: Optimal batch size analysis.

Figure 5 (top) reinforces this. Larger datasets (like LAION-400M) benefit more from larger batch sizes than smaller datasets. As dataset sizes continue to grow, the ability to scale batch size linearly will become increasingly vital.

Conclusion

The paper “Breaking the Memory Barrier of Contrastive Loss via Tile-Based Strategy” presents a significant engineering breakthrough for Large Multimodal Models.

By rethinking the implementation of the contrastive loss function, the authors successfully decoupled memory usage from batch size.

  1. Decomposition: They broke the Log-Sum-Exp operation into independent, accumulating tiles.
  2. Distribution: They utilized a ring topology to distribute computation across GPUs without massive memory spikes.
  3. Optimization: They used low-level kernel fusion to minimize memory bandwidth usage.

The result is Inf-CL, a method that transforms the memory cost of contrastive learning from a quadratic bottleneck into a manageable, linear expense. For students and researchers, this implies that the days of “Out Of Memory” errors on contrastive loss are numbered. We can now focus on the science of learning from massive batches, rather than the logistics of fitting them onto a chip.

This work paves the way for the next generation of foundation models, where batch sizes in the millions could become the new standard.