The explosion of Large Language Models (LLMs) has democratized access to powerful AI, but customizing these models remains a hardware nightmare. While using a pre-trained model like Llama-2 or GPT-3 is relatively cheap, fine-tuning it—specializing it for medical data, code generation, or a specific writing style—requires massive computational resources.
For instance, fine-tuning a 65-billion parameter model can require upwards of 780 GB of GPU memory. This effectively gates the ability to customize state-of-the-art models behind an enterprise-level paywall.
In this post, we dive into TokenTune, a research paper from Meta AI that proposes a counter-intuitive solution: to learn effectively, the model does not need to backpropagate through every single token in the input. By randomly selecting a subset of tokens for gradient computation, TokenTune drastically slashes memory requirements without sacrificing performance.
The Fine-Tuning Bottleneck
To understand why TokenTune is necessary, we first need to dissect where GPU memory actually goes during training. When you fine-tune a transformer, memory is consumed by three main buckets:
- Model Parameters: The weights of the neural network itself.
- Gradients & Optimizer States: The data needed to update those weights.
- Intermediate Activations: The values calculated at each layer during the forward pass, which must be stored (cached) to calculate gradients during the backward pass.
The AI community has become very good at shrinking the first two buckets. Parameter-Efficient Fine-Tuning (PEFT) methods, like LoRA (Low-Rank Adaptation), freeze the main model and only train tiny adapter layers, reducing gradient storage. Quantization methods, like QLoRA, shrink the model parameters themselves by using lower-precision numbers (e.g., 4-bit integers instead of 16-bit floats).
However, there is a “ghost” in the machine: Intermediate Activations. Even with LoRA and QLoRA, the model still has to cache the activations for the entire sequence length to perform backpropagation. For long sequences (a staple of modern LLMs), this activation memory becomes the new bottleneck.
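To get a feel for the scale, here is a rough back-of-the-envelope estimate (not from the paper) of how the activation cache grows with sequence length. Every hyperparameter below is an assumption, chosen to be roughly Llama-2-7B-like, and the per-layer constant is a coarse approximation:

```python
# Rough, illustrative estimate of activation-cache size vs. sequence length.
# All hyperparameters are assumptions (roughly Llama-2-7B-like), not measurements.
def activation_cache_gib(seq_len: int, batch: int = 1, layers: int = 32,
                         hidden: int = 4096, acts_per_token_per_layer: int = 10,
                         bytes_per_value: int = 2) -> float:
    """Approximate GiB of activations cached for one forward pass."""
    values = batch * seq_len * hidden * acts_per_token_per_layer * layers
    return values * bytes_per_value / 1024**3

for seq_len in (512, 2048, 8192):
    print(f"{seq_len:>5} tokens -> ~{activation_cache_gib(seq_len):.1f} GiB cached")
```

Even under these simplified assumptions, the cache grows linearly with sequence length and quickly reaches tens of gibibytes, regardless of how small the trainable adapter weights are.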
As illustrated below, while methods like QLoRA reduce memory, combining them with TokenTune allows for significantly deeper cuts.

This graph highlights the main contribution of the paper: simply adding TokenTune to existing methods (like QLoRA) can reduce memory usage to nearly one-third of what QLoRA requires alone.
The Core Concept: Token Selection
The hypothesis behind TokenTune is rooted in sparsity. Previous research suggests that not all neurons or tokens contribute equally to a model’s learning process. Some tokens carry the “signal” for the task, while others are just noise or structural filler.
TokenTune leverages this by introducing a token selection strategy.
- Forward Pass (Context is Key): The model reads the entire input sequence. All tokens are processed so the model understands the full context (Self-Attention requires seeing the whole sentence).
- Backward Pass (Selective Learning): The model calculates gradients only for a randomly selected subset of \(k\) tokens. The unselected tokens are effectively “frozen,” acting as bystanders.
Because we don’t calculate gradients for the unselected tokens, we don’t need to cache their intermediate activations. We only store what we need for the selected few.

Figure 2 visualizes this process. Notice how the input \(x\) is split. A subset of tokens (blue) flows through the full computational graph where gradients are tracked. The rest (gray) flow through a no_grad path. They contribute to the attention mechanism—allowing the blue tokens to “attend” to them—but they do not trigger memory-intensive caching for backpropagation.
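As a concrete sketch of the selection step (illustrative sizes and variable names, not the authors' code), picking \(k\) positions uniformly at random and splitting the hidden states might look like this in PyTorch:

```python
import torch

batch, seq_len, hidden, k = 2, 128, 768, 32   # illustrative sizes

h = torch.randn(batch, seq_len, hidden)        # hidden states for the full sequence
idx = torch.randperm(seq_len)[:k]              # k randomly selected token positions
mask = torch.zeros(seq_len, dtype=torch.bool)
mask[idx] = True

h_sel = h[:, mask]    # selected tokens: gradients (and cached activations) needed
h_uns = h[:, ~mask]   # unselected tokens: later processed under no_grad
```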
Mathematical Formulation
Let’s break down how this works mathematically. We divide the input hidden states \(h\) into two groups:
- \(\mathcal{G}\): The group of selected tokens (size \(k\)).
- \(\bar{\mathcal{G}}\): The group of unselected tokens.
The objective functions change slightly to accommodate this. For a classification task, we might use an MLP (Multi-Layer Perceptron) head on top of the transformer. Instead of aggregating all tokens, TokenTune aggregates only the selected tokens \(\mathcal{G}\):
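One plausible form of this objective, mean-pooling only the selected hidden states (the paper's exact aggregation may differ), is:

$$\hat{y} = \mathrm{MLP}\!\left(\frac{1}{|\mathcal{G}|}\sum_{i \in \mathcal{G}} h_i\right), \qquad \mathcal{L}_{\text{cls}} = \mathrm{CE}(\hat{y},\, y)$$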

Similarly, for Language Modeling (predicting the next token), the Cross-Entropy loss is calculated only on the selected tokens:
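Hedging similarly on exact notation, the language-modeling loss restricted to the selected positions can be written as:

$$\mathcal{L}_{\text{LM}} = -\frac{1}{|\mathcal{G}|}\sum_{i \in \mathcal{G}} \log p_\theta\!\left(x_{i+1} \mid x_{\le i}\right)$$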

This simple shift in the loss function dictates which activations must be stored in memory.
Optimizing Dense Layers
In a standard Transformer Dense Layer (Feed Forward Network), we compute output \(a\) from input \(h\) using weights \(W\) and bias \(b\). The Chain Rule of calculus tells us that to update weights \(W\), we need the stored input \(h\):
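Spelled out per token \(i\), in notation consistent with the surrounding text:

$$a_i = W h_i + b, \qquad \frac{\partial \mathcal{L}}{\partial W} = \sum_i \frac{\partial \mathcal{L}}{\partial a_i}\, h_i^{\top}$$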

Here lies the optimization. If we decide that we do not need to backpropagate through the unselected tokens \(\bar{\mathcal{G}}\), the gradient of the loss with respect to those activations becomes zero.

Because the gradient is zero for the unselected group, the terms involving their specific input activations \(h_{\bar{\mathcal{G}}}\) disappear from the weight update equation.
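Concretely, keeping the notation above:

$$\frac{\partial \mathcal{L}}{\partial a_i} = 0 \;\;\text{for } i \in \bar{\mathcal{G}} \quad\Longrightarrow\quad \frac{\partial \mathcal{L}}{\partial W} = \sum_{i \in \mathcal{G}} \frac{\partial \mathcal{L}}{\partial a_i}\, h_i^{\top}$$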

This implies that we only need to cache \(h_\mathcal{G}\). The activations for the rest of the sequence (\(h_{\bar{\mathcal{G}}}\)) can be discarded immediately after the forward pass, saving massive amounts of GPU RAM.
In practice (e.g., in PyTorch), this is implemented by explicitly splitting the forward pass. The selected tokens are processed normally, while the unselected tokens are processed within a torch.no_grad() context block:
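Here is a minimal sketch of such a split dense layer (an illustration of the idea, not the paper's actual implementation):

```python
import torch
import torch.nn as nn

class SplitLinear(nn.Module):
    """Linear layer that only tracks gradients for the selected token positions.

    h_sel: selected tokens (autograd path, inputs cached for backprop).
    h_uns: unselected tokens (computed under no_grad, inputs never cached).
    """
    def __init__(self, d_in: int, d_out: int):
        super().__init__()
        self.linear = nn.Linear(d_in, d_out)

    def forward(self, h_sel: torch.Tensor, h_uns: torch.Tensor):
        a_sel = self.linear(h_sel)          # normal path: h_sel is cached
        with torch.no_grad():
            a_uns = self.linear(h_uns)      # no graph built, nothing cached
        return a_sel, a_uns
```

Because the unselected branch never builds a computational graph, its input activations can be freed as soon as the forward pass moves on to the next layer.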

Optimizing Attention Layers
The attention mechanism is more complex because tokens interact with each other. A selected token might need to “attend” to an unselected token to understand the context. Therefore, we cannot simply delete the unselected tokens.
TokenTune handles this by splitting the Query (\(Q\)), Key (\(K\)), and Value (\(V\)) projections.
- Projections: We compute \(Q, K, V\) for both groups. However, the calculation for the unselected group \(\bar{\mathcal{G}}\) is done without gradient tracking.
- Attention: The selected tokens \(\mathcal{G}\) attend to everything (both \(\mathcal{G}\) and \(\bar{\mathcal{G}}\)).
The equations below show the split. Note the attention calculation (softmax) uses the concatenation of keys (\([K_{\bar{\mathcal{G}}}, K_{\mathcal{G}}]\)) and values (\([V_{\bar{\mathcal{G}}}, V_{\mathcal{G}}]\)) so that full context is preserved.
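In notation consistent with the surrounding text, the split projections look roughly like this (a reconstruction, not a verbatim copy of the paper's equations):

$$Q_{\mathcal{G}} = h_{\mathcal{G}} W_Q,\quad K_{\mathcal{G}} = h_{\mathcal{G}} W_K,\quad V_{\mathcal{G}} = h_{\mathcal{G}} W_V \qquad \text{(gradients tracked)}$$

$$Q_{\bar{\mathcal{G}}} = h_{\bar{\mathcal{G}}} W_Q,\quad K_{\bar{\mathcal{G}}} = h_{\bar{\mathcal{G}}} W_K,\quad V_{\bar{\mathcal{G}}} = h_{\bar{\mathcal{G}}} W_V \qquad \text{(computed without gradient tracking)}$$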

Crucially, for the selected group \(h_\mathcal{G}\), we compute the attention as follows. This is the path that requires caching:
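A representative single-head form of this computation (simplified from the paper's notation) is:

$$h'_{\mathcal{G}} = \operatorname{softmax}\!\left(\frac{Q_{\mathcal{G}}\,[K_{\bar{\mathcal{G}}};\, K_{\mathcal{G}}]^{\top}}{\sqrt{d}}\right)[V_{\bar{\mathcal{G}}};\, V_{\mathcal{G}}]$$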

For the unselected group \(h_{\bar{\mathcal{G}}}\), we perform the calculation inside a no_grad block (represented by the bracket notation in the image below). These values are computed to pass to the next layer, but their intermediate states are not cached.
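Putting the pieces together, a simplified single-head, unmasked version of this split attention might look like the following sketch (again illustrative, not the authors' code):

```python
import math
import torch
import torch.nn as nn

class SplitSelfAttention(nn.Module):
    """Single-head self-attention where only the selected tokens keep a gradient path.

    h_sel: (batch, k, d) selected tokens; h_uns: (batch, T - k, d) unselected tokens.
    Simplifications vs. a real transformer block: one head, no causal mask, no dropout.
    """
    def __init__(self, d: int):
        super().__init__()
        self.q_proj = nn.Linear(d, d)
        self.k_proj = nn.Linear(d, d)
        self.v_proj = nn.Linear(d, d)
        self.scale = 1.0 / math.sqrt(d)

    def forward(self, h_sel: torch.Tensor, h_uns: torch.Tensor):
        # Projections for the unselected group: no gradient tracking, no cached inputs.
        with torch.no_grad():
            q_uns = self.q_proj(h_uns)
            k_uns = self.k_proj(h_uns)
            v_uns = self.v_proj(h_uns)

        # Projections for the selected group: normal autograd path.
        q_sel = self.q_proj(h_sel)
        k_sel = self.k_proj(h_sel)
        v_sel = self.v_proj(h_sel)

        # Full context is preserved: keys and values from BOTH groups are concatenated.
        k_all = torch.cat([k_uns, k_sel], dim=1)
        v_all = torch.cat([v_uns, v_sel], dim=1)

        # Selected queries attend to everything; this path is cached for backprop.
        attn_sel = torch.softmax(q_sel @ k_all.transpose(-2, -1) * self.scale, dim=-1)
        out_sel = attn_sel @ v_all

        # Unselected queries also attend to everything, but inside no_grad,
        # so their intermediate states are never stored.
        with torch.no_grad():
            attn_uns = torch.softmax(q_uns @ k_all.transpose(-2, -1) * self.scale, dim=-1)
            out_uns = attn_uns @ v_all

        return out_sel, out_uns
```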

Finally, the Feed-Forward steps follow the same pattern as the dense layers discussed earlier:


Experimental Results
The theory is sound, but does dropping gradients for half (or more) of the tokens destroy model performance? The researchers tested this on both medium-sized models (BERT) and large models (Llama 2).
1. Medium-Size Models (BERT)
The researchers tested TokenTune on the GLUE benchmark, a standard suite of Natural Language Understanding tasks. They compared TokenTune against full fine-tuning and other efficiency methods like Adapters, BitFit, and LoRA.

The results in Table 1 are compelling. TokenTune achieves an average score of 82.1, nearly identical to Full Fine-Tuning (82.8) and LoRA (81.9). This confirms that the model can learn robust representations even when backpropagating through only a subset of tokens.
How many tokens do we need? One of the most interesting findings is how few tokens are actually required. The researchers varied the number of trained positions (\(k\)) and measured performance on the MRPC and STS-B tasks.

As shown in Figure 3 (right), performance ramps up quickly. By training on just 32 positions (out of a much larger sequence), the model reaches near-optimal performance.
Figure 3 (left) illustrates the memory savings. TokenTune (solid blue line) scales much more gently than Full Fine-Tuning (purple dotted line). When combined with LoRA (cyan line), the memory footprint is a fraction of the baseline.
2. Large Language Models (Llama 2)
The stakes are higher for LLMs. The researchers fine-tuned Llama2-7B using instruction tuning (teaching the model to follow commands). They evaluated the models on difficult reasoning benchmarks like MMLU, HellaSwag, and TruthfulQA.

Table 2 shows that Llama 7B w/ TokenTune (61.23 average) actually outperforms the base Llama 7B model (60.73) and competes closely with LoRA (62.20).
The authors also explored the “Selection Ratio”—the percentage of tokens selected for backpropagation.


Looking at the tables above, we see a fascinating trend: higher selection ratios do not strictly equal better performance. In many cases, selecting just 20% to 30% of the tokens yielded the best results. This suggests that TokenTune might also act as a regularizer, preventing the model from overfitting to the fine-tuning data by introducing noise (via random token selection).
3. The Ultimate Memory Unlock
The most significant result of this paper is the memory usage analysis. By combining TokenTune with QLoRA (Quantized LoRA), the memory requirements drop precipitously.

Figure 4 visualizes the scaling. The purple dashed line (Full Fine-Tuning) sits high at ~90GB. The red bars (TokenTune + QLoRA) are drastically lower.
The exact numbers, detailed in Table 4 below, are staggering. For a selection ratio of 25%, TokenTune + QLoRA requires only 17.2 GiB of memory, compared to 91.4 GiB for full fine-tuning. This brings fine-tuning of 7B parameter models well within the range of consumer-grade GPUs (like an NVIDIA RTX 3090 or 4090).

Conclusion
TokenTune presents a simple yet powerful observation: we are wasting memory by calculating gradients for every single token in an input sequence. By decoupling the forward pass (which builds context) from the backward pass (which updates weights), TokenTune offers a new axis for optimization.
The key takeaways are:
- Context \(\neq\) Learning: The model needs to see all tokens to understand the sentence, but it only needs to be corrected on a few of them to learn the task.
- Combinability: TokenTune is not a replacement for LoRA or quantization; it is a force multiplier. It attacks the one memory bucket (activations) that other methods ignore.
- Accessibility: By reducing memory requirements by up to 79% (when combined with QLoRA), TokenTune moves us closer to a world where high-performance LLM fine-tuning can happen on personal hardware rather than massive server farms.
For students and researchers working with Transformers, TokenTune highlights the importance of questioning fundamental assumptions—like the idea that backpropagation requires the full computational graph. Sometimes, less really is more.