Introduction

We are currently in the golden age of Large Language Models (LLMs). From GPT-4 to Llama 3, these models act as reasoning engines capable of astonishingly human-like behavior. However, there is a persistent bottleneck that every developer, student, and researcher faces: latency.

The core issue lies in auto-regressive decoding. To generate a sentence, an LLM must predict one token, append it to the sequence, feed it back into itself, and predict the next. For a 100-token response, the model must run its entire massive architecture 100 times sequentially. This process under-utilizes modern GPUs, which thrive on parallel processing, and makes real-time applications expensive and sluggish.

A popular solution to this is Speculative Decoding. This technique uses a smaller, faster “draft model” to guess the next few tokens, which the large “target model” then verifies in parallel. It’s like having an intern draft an email while the manager quickly approves or corrects it.

But here is the catch: Speculative decoding only works if the intern (the draft model) is good at the specific task. If you ask a translation-specialized draft model to help with a math problem, it will guess wrong, the target model will reject everything, and you end up slower than if you hadn’t used the draft model at all.

In the paper “Context-Aware Assistant Selection for Improved Inference Acceleration with Large Language Models,” researchers Jerry Huang, Prasanna Parthasarathi, Mehdi Rezagholizadeh, and Sarath Chandar propose a clever solution. Instead of relying on a single draft model, what if we had a pool of specialists? And what if we could train a “manager” to look at the user’s query and instantly pick the best assistant for the job?

This blog post breaks down their method, which frames LLM acceleration as a contextual bandits problem, demonstrating how we can dynamically select draft models to maximize speed without sacrificing quality.


Background: The Limitations of Static Drafting

To understand the innovation here, we first need to solidify our understanding of the status quo.

Speculative Decoding Recap

In standard speculative decoding, you have two models:

  1. Target Model (\(M_e\)): The big, smart, slow model (e.g., Llama-70B).
  2. Draft Model (\(M_d\)): A small, fast, less accurate model (e.g., Llama-7B).

The draft model rapidly generates a short sequence of tokens (say, 5 tokens). The target model then processes all 5 tokens in a single forward pass to verify them. Drafted tokens that match what the target model would have generated are kept for free; at the first mismatch, the remaining guesses are discarded, the target's own token is used instead, and drafting resumes from the corrected sequence.
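
To make the hand-off concrete, here is a toy, greedy-only sketch of one speculative step in Python. Everything here is illustrative: `draft_next` and `target_next` stand in for real models, and a real implementation compares full probability distributions and runs the target's checks as a single batched forward pass.

```python
# A toy, greedy-only speculative step with stand-in "models".

def speculative_step(prefix, draft_next, target_next, k=5):
    """Draft k tokens, then keep the longest prefix the target agrees with."""
    # 1. Draft phase: the small model guesses k tokens, one at a time.
    seq = list(prefix)
    drafted = []
    for _ in range(k):
        tok = draft_next(seq)
        drafted.append(tok)
        seq.append(tok)

    # 2. Verify phase: the target checks every drafted position.
    accepted = []
    seq = list(prefix)
    for tok in drafted:
        target_tok = target_next(seq)
        if target_tok == tok:          # agreement: the drafted token is "free"
            accepted.append(tok)
            seq.append(tok)
        else:                          # first mismatch: take the target's token and stop
            accepted.append(target_tok)
            break
    else:                              # every draft accepted: target adds one bonus token
        accepted.append(target_next(seq))
    return accepted

# Toy usage with integer "tokens": the draft always increments by 1,
# the target increments by 1 on odd-length prefixes and by 2 otherwise.
draft_next = lambda seq: seq[-1] + 1
target_next = lambda seq: seq[-1] + (1 if len(seq) % 2 else 2)
print(speculative_step([0], draft_next, target_next))   # -> [1, 3]
```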

The Alignment Problem

The efficiency of this process depends primarily on the Acceptance Rate: the percentage of drafted tokens the target model agrees with.

  • High Acceptance: The target model skips many generation steps. Huge speedup.
  • Low Acceptance: The target model does extra work verifying bad guesses. Zero speedup or even slowdowns.

The problem is that a single small draft model cannot be an expert at everything. A small model might be great at summarizing news (high acceptance) but terrible at Python coding (low acceptance). In a production environment with diverse user queries, a single static draft model acts as a bottleneck for out-of-domain tasks.
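
A rough cost model makes this dependence on acceptance rate visible. The formula below is a standard back-of-envelope approximation for speculative decoding, not something taken from this paper: it assumes each drafted token is accepted independently with probability `a` and that each draft forward pass costs a fixed fraction of a target pass.

```python
# Back-of-envelope model (an assumption, not from the paper): with acceptance
# probability `a` and draft length `k`, the expected number of tokens produced
# per target verification pass is (1 - a**(k+1)) / (1 - a).

def rough_speedup(a: float, k: int, draft_cost: float) -> float:
    """Speedup over plain auto-regressive decoding, assuming each draft
    forward pass costs `draft_cost` of one target forward pass."""
    expected_tokens = (k + 1) if a >= 1.0 else (1 - a ** (k + 1)) / (1 - a)
    cost_per_step = 1.0 + k * draft_cost   # one target pass + k draft passes
    return expected_tokens / cost_per_step

for a in (0.3, 0.6, 0.9):
    print(f"acceptance={a:.1f}  speedup={rough_speedup(a, k=5, draft_cost=0.2):.2f}x")
# acceptance=0.3  speedup=0.71x   <- slower than not drafting at all
# acceptance=0.6  speedup=1.19x
# acceptance=0.9  speedup=2.34x
```

With a drafter that costs 20% of a target pass per token, an acceptance rate of 0.3 already yields a net slowdown, which is exactly the failure mode a mismatched draft model creates.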


Core Method: Context-Aware Assistant Selection

The authors propose a system where the inference pipeline has access to multiple draft candidates. These candidates could be different models entirely, or the same model architecture fine-tuned on different domains (e.g., one for coding, one for chat, one for math).

The core challenge is decision making: Given a specific input query (Context), which draft model (Arm) should we pull to get the best speedup (Reward)?

The researchers model this as a Contextual Bandits problem. Here is the high-level workflow:

Figure 1: Overview of the method. The authors first train a policy using offline data collected from greedily decoded outputs of each model, which are scored to produce reward samples. At test time, the policy takes in a query \(q'\) to select a draft candidate model, which is then used for assisted generation with the target model.

As shown in Figure 1 above, the process is split into two phases: Offline Training and Online Inference.
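
Before diving into the offline phase, here is roughly what the finished system looks like at test time. This is a sketch under assumptions: it presumes a Hugging Face-style `generate(..., assistant_model=...)` interface for assisted generation, and the names `embed_query`, `policy`, and `draft_models` are placeholders for components described in the sections below.

```python
# A sketch of the online phase: embed the query, let the policy pick a
# drafter, then run assisted generation with the chosen draft model.

import torch

def route_and_generate(query, tokenizer, target_model, draft_models, policy, embed_query):
    inputs = tokenizer(query, return_tensors="pt")
    with torch.no_grad():
        context = embed_query(query)                 # the bandit "context"
        choice = int(torch.argmax(policy(context)))  # pull an "arm" = pick a drafter
    return target_model.generate(
        **inputs,
        assistant_model=draft_models[choice],        # speculative decoding with the chosen drafter
        max_new_tokens=128,
    )
```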

1. Offline Data Collection & Scoring

Training a policy online (learning while serving users) is risky and slow because you’d have to suffer through bad decisions to learn from them. Instead, the authors use an offline approach.

They take a dataset of queries and run them through all available draft models and the target model independently. They then calculate a “score” representing how helpful a draft model would have been.

The Alignment Score

The most direct measure of a draft model’s utility is how similar its output (\(o_{i}^{j}\)) is to the target model’s output (\(o_{i}^{e}\)). The authors use a text-similarity metric such as ROUGE-L as the scoring function \(f\):

\[ s_{i}^{j} = f\!\left(o_{i}^{j},\, o_{i}^{e}\right) \]
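
In code, the alignment score can be as simple as an LCS-based ROUGE-L F-measure between the two greedy outputs. This is a self-contained sketch rather than the paper's exact implementation, which may use a library scorer and other similarity metrics:

```python
# LCS-based ROUGE-L F-measure between a draft output and the target's output.

def rouge_l(draft_out: str, target_out: str) -> float:
    a, b = draft_out.split(), target_out.split()
    # Longest common subsequence via dynamic programming.
    dp = [[0] * (len(b) + 1) for _ in range(len(a) + 1)]
    for i, x in enumerate(a):
        for j, y in enumerate(b):
            dp[i + 1][j + 1] = dp[i][j] + 1 if x == y else max(dp[i][j + 1], dp[i + 1][j])
    lcs = dp[len(a)][len(b)]
    if lcs == 0:
        return 0.0
    precision, recall = lcs / len(a), lcs / len(b)
    return 2 * precision * recall / (precision + recall)

# Alignment of drafter j with the target on query i: s_ij = rouge_l(o_ij, o_ie)
print(rouge_l("the cat sat on the mat", "the cat lay on the mat"))  # ~0.83
```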

Incorporating Cost

However, alignment isn’t the only factor. A draft model might be highly accurate but too large (slow) to provide a net speedup. Conversely, a tiny model might be very fast but slightly less accurate.

To account for this, the researchers introduce a cost-aware reward function. They define a cost \(c_{i}^{j}\) based on the number of parameters in the draft model relative to the largest candidate. The final score combines alignment and cost using a trade-off parameter \(\alpha\):

\[ r_{i}^{j} = \alpha \, s_{i}^{j} \;-\; (1 - \alpha)\, c_{i}^{j} \]

Here, \(\alpha\) allows us to tune the system. An \(\alpha\) of 1 cares only about accuracy; an \(\alpha\) near 0 prioritizes the smallest/fastest model. The cost \(c_{i}^{j}\) is derived from the parameter counts \(p_{j}\) of the candidate models:

\[ c_{i}^{j} = \frac{p_{j}}{\max_{k} p_{k}} \]

This setup creates a labelled dataset where every query is associated with the “rewards” (scores) of using different draft models.
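
Putting the pieces together, building the offline reward table is just a loop over queries and drafters. The sketch below follows the definitions above; the paper's exact functional form may differ, and the numbers (parameter counts, ROUGE-L values, \(\alpha\)) are illustrative.

```python
# Building the offline reward table for one query.

def costs(param_counts):
    biggest = max(param_counts)
    return [p / biggest for p in param_counts]           # c_j = p_j / max_k p_k

def reward(alignment, cost, alpha):
    return alpha * alignment - (1 - alpha) * cost        # r = alpha*s - (1-alpha)*c

# Example: three hypothetical drafters (60M, 220M, 770M parameters) on one query.
param_counts = [60e6, 220e6, 770e6]
alignments = [0.41, 0.55, 0.62]                          # ROUGE-L vs. the target's output
for j, (s, c) in enumerate(zip(alignments, costs(param_counts))):
    print(f"drafter {j}: reward = {reward(s, c, alpha=0.8):.3f}")
# drafter 0: reward = 0.312
# drafter 1: reward = 0.383   <- best trade-off at alpha = 0.8
# drafter 2: reward = 0.296
```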

2. Policy Training

With the dataset ready, the goal is to train a Policy Network (\(\pi\)). This is a lightweight Multi-Layer Perceptron (MLP).

  • Input: An embedding of the user’s query (usually the sentence embedding from the target model).
  • Output: A probability distribution over the available draft models (actions).
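
A policy this small is almost trivial to implement. Below is a minimal PyTorch sketch; the embedding dimension, hidden width, and number of drafters are placeholder values, not the paper's settings.

```python
# A minimal draft-selection policy: query embedding in, distribution over drafters out.

import torch
import torch.nn as nn

class DraftSelectionPolicy(nn.Module):
    def __init__(self, embed_dim: int, num_drafters: int, hidden: int = 256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(embed_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, num_drafters),      # one logit per draft candidate
        )

    def forward(self, query_embedding: torch.Tensor) -> torch.Tensor:
        # Probability distribution over the available draft models (the "arms").
        return torch.softmax(self.net(query_embedding), dim=-1)

policy = DraftSelectionPolicy(embed_dim=512, num_drafters=3)
print(policy(torch.randn(512)))   # e.g. tensor([0.31, 0.42, 0.27], ...)
```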

The objective function \(J^{\pi}\) aims to maximize the expected reward:

\[ J^{\pi} = \mathbb{E}_{q \sim \mathcal{D}} \, \mathbb{E}_{a \sim \pi(\cdot \mid q)} \big[ r(q, a) \big] \]

Since the action space (which model to pick) is discrete, the integration over actions is a summation:

\[ J^{\pi} = \mathbb{E}_{q \sim \mathcal{D}} \left[ \sum_{a \in \mathcal{A}} \pi(a \mid q) \, r(q, a) \right] \]

To optimize this, the authors use the REINFORCE algorithm (a standard policy gradient method). The gradient update looks like this:

\[ \nabla_{\theta} J^{\pi_{\theta}} = \mathbb{E}_{q \sim \mathcal{D}} \, \mathbb{E}_{a \sim \pi_{\theta}(\cdot \mid q)} \big[ r(q, a) \, \nabla_{\theta} \log \pi_{\theta}(a \mid q) \big] \]

Essentially, if a draft model yields a high reward for a specific type of query, the network updates its weights to increase the probability of selecting that model for similar queries in the future.
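
A compact REINFORCE loop over the offline data might look like the following. Each dataset entry is assumed to be a (query embedding, reward vector) pair, where `rewards[j]` is the precomputed score \(r_{i}^{j}\) for drafter \(j\) on that query; hyperparameters are illustrative, not the paper's.

```python
# REINFORCE over the offline dataset of precomputed rewards.

import torch
from torch.distributions import Categorical

def train_policy(policy, dataset, epochs=10, lr=1e-3):
    opt = torch.optim.Adam(policy.parameters(), lr=lr)
    for _ in range(epochs):
        for query_emb, rewards in dataset:       # query_emb: [embed_dim], rewards: [num_drafters]
            probs = policy(query_emb)            # distribution over drafters
            dist = Categorical(probs)
            action = dist.sample()               # pull one arm
            # REINFORCE: push up the log-prob of actions in proportion to their reward.
            loss = -rewards[action] * dist.log_prob(action)
            opt.zero_grad()
            loss.backward()
            opt.step()
    return policy

# Usage with the policy defined earlier and a toy dataset of 100 examples:
toy_data = [(torch.randn(512), torch.rand(3)) for _ in range(100)]
train_policy(policy, toy_data)
```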


Experiments & Results

The authors validated this approach using T5 and Flan-T5 models across various tasks like Translation (IWSLT), Summarization (XSUM), and Math (GSM8K).

1. Domain Specialization

The most compelling test is whether the policy can distinguish between domain experts. They set up a scenario with two draft models of the same size:

  1. T5-Small: Standard pre-trained model.
  2. T5-Small-XSum: Fine-tuned specifically for summarization.

They tested these on translation and summarization tasks.

Table 1: Quality, decoding speeds and acceptance rates.

Table 1 (above) reveals the critical insight:

  • The T5-Small model speeds up translation (1.10x) but slows down summarization (0.97x).
  • The T5-Small-XSum model is great at summarization (1.21x) but terrible at translation (a 0.83x slowdown).
  • The Policy (\(\pi_\theta\)) achieves speedups on both tasks (1.09x and 1.17x). It identifies the context and routes each query to the correct expert, avoiding the pitfalls of using the wrong draft model.

2. Balancing Speed vs. Accuracy

By adjusting the \(\alpha\) parameter in the reward function, the policy can be tuned to prefer “safer” (more accurate) models or “riskier” (faster/smaller) models.

Figure 2: Effect of varying the tradeoff between output alignment and draft model size.

Figure 2 visualizes this trade-off. As \(\alpha\) increases (prioritizing alignment), the policy shifts its preference toward larger, more accurate draft models. As \(\alpha\) decreases (prioritizing low cost), it favors the smallest valid draft model. This shows the system can be tuned to the specific latency and compute constraints of the deployment environment.

3. Data Efficiency

A major concern with training auxiliary policies is the data requirement. Do we need millions of examples to train this router?

Figure 3: Decoding speed using a dynamic policy as a function of training examples.

Figure 3 shows the decoding speed as a function of training examples (log scale). Remarkably, the policy becomes effective with as few as 1,000 to 10,000 examples. This implies that creating a custom selector for a specific set of models is computationally cheap and doesn’t require massive datasets.

4. The “Do Nothing” Option

Sometimes, no draft model is good enough. For example, complex math reasoning often requires the full capacity of the target model; a small draft model will just hallucinate, causing rejections and slowing things down.

The authors added an “Auto-Regressive” action to the policy’s choices—effectively letting the manager decide to “do it myself.”
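
In implementation terms, the "do nothing" option is just one extra index in the action space that maps to no assistant. Assuming the same Hugging Face-style interface as before, where `assistant_model=None` means ordinary decoding, the dispatch is a short wrapper:

```python
# The "do nothing" arm: the last action index maps to no draft model at all.

def decode_with_choice(inputs, target_model, draft_models, choice, max_new_tokens=256):
    drafter = draft_models[choice]        # e.g. draft_models[-1] is None: the AR fallback
    return target_model.generate(
        **inputs,
        assistant_model=drafter,          # None -> ordinary decoding, no speculation
        max_new_tokens=max_new_tokens,
    )
```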

Table 4: Decoding speeds on GSM8K.

In Table 4, tested on the GSM8K math dataset, standard drafting slows inference down to ~0.75x because the draft models fail. However, the Policy (\(\pi_\theta\)) achieves ~0.95x speed. It learns to recognize math queries and avoids using the draft models, defaulting close to standard auto-regressive speed. While it doesn’t gain speed here, it prevents the catastrophic slowdowns that blind speculative decoding would suffer.

5. Self-Drafting (Early Exits)

The method also applies to Self-Speculative Decoding, where the draft “model” is just the earlier layers of the target model (an “early exit”).
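
Conceptually, nothing about the selection machinery changes in this setting: the "arms" simply become candidate exit layers of the target model itself, plus the auto-regressive fallback. A hypothetical action table, purely for illustration:

```python
# Illustrative action space for self-drafting: exit layers instead of separate models.
actions = [
    {"kind": "early_exit", "layer": 6},    # draft with the first 6 layers
    {"kind": "early_exit", "layer": 12},   # draft with the first 12 layers
    {"kind": "auto_regressive"},           # no drafting at all
]
```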

Table 6: Results in a scenario for deciding when to early exit.

Table 6 demonstrates that even when selecting between different “layers” as drafters, the policy helps maintain performance, particularly when the Auto-Regressive option is available.


Conclusion & Implications

This research moves speculative decoding from a static, “hope-it-works” optimization to a dynamic, intelligent system. By framing assistant selection as a contextual bandits problem, the authors provide a framework that:

  1. Reduces the risk of slowdowns caused by misaligned draft models.
  2. Enables modularity, allowing systems to combine multiple specialized small models rather than relying on one generalist drafter.
  3. Requires minimal overhead, with a lightweight policy that is easy to train and fast to run.

As LLMs continue to grow, the “one model to rule them all” approach is becoming computationally unsustainable for inference. We are likely moving toward ecosystems of models—massive reasoning engines supported by swarms of specialized, lightweight assistants. This paper provides the blueprint for the routing logic that will make such an ecosystem efficient.

For students and engineers, the takeaway is clear: Optimization isn’t just about making individual models faster; it’s about making the decision-making process between models smarter.