Robots That Remember: Improving Diffusion Policies with Past-Token Prediction

In the world of robotics, memory is everything. Imagine trying to make a cup of coffee without remembering if you’ve already added the sugar, or trying to unlock a door without recalling which key you just tried. For humans, these temporal dependencies—the relationship between what happened five seconds ago and what should happen next—are intuitive. For robots, they are a massive computational and algorithmic headache.

In recent years, Diffusion Policies have emerged as a dominant approach for robotic imitation learning. They generate smooth, complex behaviors by learning from human demonstrations. However, a strange paradox exists: while we want robots to “remember” more history to make better decisions, feeding longer sequences of observations into these models often makes them perform worse, and training them becomes prohibitively expensive.

This blog post breaks down a fascinating new paper, “Learning Long-Context Diffusion Policies via Past-Token Prediction.” The researchers propose a clever regularization technique and a training pipeline that not only allows robots to effectively utilize long histories but also speeds up training by over 10x.

Overview of the proposed framework showing performance gains and training speedups.

As shown in Figure 1, the proposed method (green line) achieves significantly higher performance in less training time compared to naive approaches. Let’s dive into why this problem exists and how they solved it.


The Background: The Paradox of Robot Memory

The Challenge of Imitation Learning

Imitation learning (IL) is the process of training a robot to mimic expert behavior. A standard IL policy takes an observation (\(o_t\))—usually an image from a camera—and predicts an action (\(a_t\)).

However, many tasks are non-Markovian, meaning the current image doesn’t tell the whole story.

  • Occlusion: If a robot arm blocks the camera view of an object, the robot needs to remember where the object was.
  • Multi-stage tasks: If a robot is packing a box, it needs to know which items have already been packed to decide what to grab next.

To solve this, researchers typically feed a history of observations (\(o_{t-k} \dots o_t\)) into the model. Theoretically, more history should equal better context and better performance. In practice, it often leads to failure.
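To make the setup concrete, here is a minimal sketch (in PyTorch) of a history-conditioned policy: instead of mapping a single observation to an action, it consumes a stack of the last several observations. The class name, dimensions, and architecture are purely illustrative, not the paper's model, and the observations are assumed to already be encoded into feature vectors.

```python
# Minimal sketch (PyTorch) of a history-conditioned policy. Names and
# dimensions are illustrative, not the paper's architecture; observations
# are assumed to already be encoded into feature vectors.
import torch
import torch.nn as nn

class HistoryPolicy(nn.Module):
    def __init__(self, obs_dim=512, act_dim=7, context_len=16):
        super().__init__()
        # Flatten the stacked observation features o_{t-k:t} and map them to an action.
        self.net = nn.Sequential(
            nn.Linear(obs_dim * context_len, 256),
            nn.ReLU(),
            nn.Linear(256, act_dim),
        )

    def forward(self, obs_history):
        # obs_history: (batch, context_len, obs_dim) -> action: (batch, act_dim)
        return self.net(obs_history.flatten(start_dim=1))

policy = HistoryPolicy()
obs_history = torch.randn(8, 16, 512)   # a batch of 8 histories, 16 steps each
action = policy(obs_history)            # predicted a_t, shape (8, 7)
```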

The “Copycat” vs. “Amnesia” Problem

Historically, older regression-based policies suffered from the Copycat Problem. When given a history of actions, the model would learn a spurious correlation: “The best prediction for the next action is simply the previous action.” The robot would just repeat its last move, ignoring the visual changes in the environment.

However, the authors of this paper discovered that modern Diffusion Policies suffer from the exact opposite problem. Instead of over-relying on history, they often ignore it entirely.

Comparison of regression-based and diffusion-based policies regarding temporal action dependency.

Figure 2 illustrates this phenomenon. The Y-axis represents the “Action Predictability Ratio”—essentially, how much the model relies on past actions compared to a human expert.

  • Regression (White bars): These models rely too much on the past (ratios > 1), leading to copycat behavior.
  • Diffusion (Green hatched bars): These models rely too little on the past (ratios < 1).

Despite being fed a long history of images, the diffusion model fails to capture the essential dependencies between past and future actions. It essentially suffers from amnesia, treating the input history as noise rather than a structured sequence of events.


The Core Method: Past-Token Prediction (PTP)

To fix this “amnesia,” the researchers introduce Past-Token Prediction (PTP).

The idea is simple yet profound. Standard policies are trained to predict the future trajectory of actions (\(a_t, a_{t+1}, \dots\)). PTP modifies the loss function so that the policy must also predict the past actions (\(a_{t-k}, \dots, a_{t-1}\)) based on the current context.

Why does this work?

If a neural network is forced to accurately reconstruct what it just did based on its internal state, it is mathematically forced to encode and retain information about the history. It cannot ignore the past observations if its loss function depends on predicting the past actions correctly.

Illustration of past-token prediction showing bidirectional arrows through the policy.

As visualized in Figure 3, the policy head (denoted by \(\pi\)) now has bidirectional outputs. It looks at the history of observations and generates a sequence that spans from the past into the future.

Formally, this changes the prediction target. Instead of just predicting \(\mathbf{a}_{t:t+l}\) (future), the model predicts:

\[ \hat{\mathbf{a}}_{t-k:t+l} = \pi_{\theta}\bigl(\mathbf{o}_{t-k:t}\bigr). \]

By explicitly regularizing the model to retain past information, PTP ensures that the “style” and “strategy” of the past actions influence the future actions, effectively bridging the gap between historical context and future planning.
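The sketch below illustrates the PTP training target. It is a simplification that uses a plain MSE regression loss in place of the full diffusion denoising objective, and it assumes a `policy` callable that maps an observation history to a \((k+l)\)-step action sequence; any weighting between the past and future terms is omitted.

```python
# Sketch of the Past-Token Prediction target (PyTorch). For clarity this
# uses a plain MSE regression loss rather than the full diffusion denoising
# objective; the essential change is that the supervision target spans
# the k past actions as well as the l future actions.
import torch
import torch.nn.functional as F

def ptp_loss(policy, obs_history, past_actions, future_actions):
    """
    obs_history:    (B, k+1, obs_dim) -- o_{t-k:t}
    past_actions:   (B, k,   act_dim) -- a_{t-k:t-1}, already executed
    future_actions: (B, l,   act_dim) -- expert actions to imitate
    """
    # The policy predicts a single sequence covering past AND future actions.
    pred = policy(obs_history)                            # (B, k + l, act_dim)
    target = torch.cat([past_actions, future_actions], dim=1)
    k = past_actions.shape[1]
    # Splitting the loss makes the past-token regularization term explicit.
    past_loss = F.mse_loss(pred[:, :k], target[:, :k])    # forces history retention
    future_loss = F.mse_loss(pred[:, k:], target[:, k:])  # standard imitation term
    return future_loss + past_loss
```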


Solving the Compute Bottleneck: Multi-Stage Training

Training a model on long sequences of high-dimensional images (e.g., 16 frames of video per step) is computationally expensive. It eats up GPU memory (VRAM) and slows down training because the visual encoder (usually a ResNet or Vision Transformer) has to process every frame in the history for every training step.

The authors observed that the benefits of PTP primarily come from the policy head (the sequence modeling part), not the visual encoder (the image processing part).

Based on this, they propose a Multi-Stage Training Strategy that drastically reduces overhead.

Overview of multistage training with embedding caching.

Figure 4 breaks down this efficient pipeline:

  1. Stage 1: Encoder Training (Short Context). First, they train a standard policy with a short context (e.g., 2 frames). This is fast and cheap. The goal here is simply to learn a good visual encoder that understands the scene.

  2. Stage 2: Feature Caching. Once the encoder is trained, they freeze it. They run the entire dataset of demonstration videos through this frozen encoder. Instead of storing images, they store the compact embeddings (feature vectors) in a disk cache.

  3. Stage 3: Policy Training (Long Context). Finally, they train the PTP policy head using the cached embeddings. Since the visual encoder is no longer running, the memory footprint drops massively. The model can now ingest long histories (e.g., 16 steps) of pre-computed embeddings instantly.

This approach allows them to train long-context policies 10x faster than end-to-end approaches, without sacrificing performance.
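Here is a minimal sketch of the caching idea, assuming a frozen `encoder` module and per-demonstration image tensors; the file path, shapes, and helper names are illustrative rather than the paper's exact pipeline.

```python
# Minimal sketch (PyTorch) of embedding caching for multi-stage training:
# freeze the Stage-1 visual encoder, precompute per-frame embeddings once,
# then train the long-context policy head on the cached features.
# Paths, shapes, and function names are illustrative assumptions.
import torch

@torch.no_grad()
def cache_embeddings(encoder, frames, cache_path="embeddings.pt"):
    """frames: (num_steps, C, H, W) images from one demonstration."""
    encoder.eval()                       # Stage 2: the encoder stays frozen.
    embeddings = encoder(frames)         # (num_steps, feat_dim), computed once.
    torch.save(embeddings.cpu(), cache_path)
    return embeddings

def long_context_input(embeddings, t, context_len=16):
    """Stage 3: slice a long history of cached features -- no encoder pass needed."""
    start = max(0, t - context_len + 1)
    return embeddings[start : t + 1]     # (<= context_len, feat_dim)
```

The design point is that the expensive encoder forward pass happens once per frame during caching, so the policy-training stage only iterates over long histories of small feature vectors instead of raw images.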

Graph showing the effect of feature caching on training speed.

Figure 8 demonstrates this efficiency. The standard method (black dashed line) barely makes progress over time because each epoch takes so long. The caching method (black solid line) learns rapidly, achieving high success rates in a fraction of the time.


Smart Inference: Self-Verification

The benefits of Past-Token Prediction extend beyond training. During deployment (test time), PTP gives the robot a superpower: Self-Verification.

Diffusion policies are probabilistic—they sample actions from a distribution. Sometimes they sample a “bad” plan that doesn’t align with what the robot was just doing. Because the PTP model predicts both past and future actions, we can use the past prediction as a quality check.

Here is the logic:

  1. The robot knows exactly what actions it took in the previous steps (Ground Truth Past).
  2. The robot generates \(B\) different candidate plans for the future. Each plan also contains a “reconstructed past.”
  3. The robot compares the “reconstructed past” of each candidate against the “Ground Truth Past.”
  4. The candidate with the most accurate past reconstruction is likely to have the most consistent future plan.

Diagram of test-time verification showing selection of the best action sequence.

Figure 5 illustrates this selection process. The policy samples a batch of sequences (\(\mathcal{A}\)):

\[ \mathcal{A} = \{ \hat{\mathbf{a}}^{(1)}, \ldots, \hat{\mathbf{a}}^{(B)} \}, \qquad \hat{\mathbf{a}}^{(i)} \sim \pi_{\theta}(\mathbf{o}_{t-k:t}). \]

It then selects the best one (\(\hat{\mathbf{a}}^*\)) by minimizing the error between predicted past actions (\(\hat{a}_\tau\)) and actual past actions (\(a_\tau\)):

\[ \hat{\mathbf{a}}^{*} = \arg\min_{\hat{\mathbf{a}} \in \mathcal{A}} \sum_{\tau = t-k}^{t-1} \| \hat{a}_{\tau} - a_{\tau} \|^{2}. \]

This acts as a runtime filter, discarding hallucinations or inconsistent plans before they are executed on the hardware.
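Below is a hedged sketch of that filter, assuming a `policy.sample` method (a hypothetical name) that returns one full past-plus-future action sequence for a given observation history.

```python
# Sketch of PTP self-verification at inference time (PyTorch).
# Sample B candidate sequences, score each by how well its reconstructed
# past matches the actions the robot actually executed, and keep the best.
# `policy.sample` is an assumed interface, not the paper's exact API.
import torch

@torch.no_grad()
def select_action_plan(policy, obs_history, executed_past, num_candidates=8):
    """
    obs_history:   (1, k+1, obs_dim) -- o_{t-k:t}
    executed_past: (k, act_dim)      -- ground-truth a_{t-k:t-1}
    """
    k = executed_past.shape[0]
    # Sample B candidate sequences; each spans past + future actions.
    candidates = [policy.sample(obs_history)[0] for _ in range(num_candidates)]
    candidates = torch.stack(candidates)              # (B, k + l, act_dim)
    # Score each candidate by how well it reconstructs the known past.
    errors = ((candidates[:, :k] - executed_past) ** 2).sum(dim=(1, 2))
    best = torch.argmin(errors)
    # Execute only the future portion of the winning candidate.
    return candidates[best, k:]
```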

Effect of PTP self-verification on success rates.

As seen in Figure 10, this inference-time verification boosts success rates by roughly 5% on challenging tasks simply by sampling more candidates (increasing the budget) and picking the best one.


Experiments and Results

Does forcing the robot to predict the past actually help it perform better in the future? The results are compelling.

Simulation Results

The researchers tested the method on six simulated tasks, including “Square,” “Tool Hang,” and complex “Long-Horizon” variants that specifically require memory.

Comparison of policies across six simulation tasks.

Figure 9 compares three approaches:

  1. Gray: No History (Context length = 1 or 2).
  2. Red: Long History without PTP (Standard Diffusion).
  3. Green: Long History with PTP (Ours).

On standard tasks (like Push-T), PTP performs on par with the best baselines. However, on Long-Horizon tasks (the right side of the chart), the difference is night and day. Standard diffusion policies (Red) fail almost completely because they cannot maintain context over time. The PTP method (Green) achieves near-perfect success rates.

Furthermore, looking at the dependency analysis:

Graph showing the effect of PTP on temporal action dependency.

Figure 6 confirms that increasing the PTP supervision (moving right on the x-axis) directly correlates with higher success rates. By forcing the model to respect temporal dependencies, the robot becomes a better imitator.

Real-World Performance

The team also validated the approach on real robots (Franka Emika Panda and ALOHA) performing tasks like counting scoops of powder or replacing objects. These tasks are impossible without memory—if you forget how many scoops you’ve done, you fail.

Comparison of different policies on real-world tasks.

Figure 11 shows the real-world results. The “No History” and “No PTP” baselines struggle significantly (15-20% success). The PTP method jumps to over 70% success, proving that the technique is robust enough for physical deployment.

The Role of Context Length

Finally, does adding more history help once PTP is enabled?

Effect of history observations on PTP-trained policies.

Figure 7 shows that as context length increases (up to 16 frames), success rates climb, particularly for the “Long-Horizon Square” task (orange line). This validates that PTP successfully unlocks the ability to use long contexts, turning what used to be a liability (more data causing confusion) into an asset.


Conclusion

The paper “Learning Long-Context Diffusion Policies via Past-Token Prediction” offers a masterclass in diagnosing and fixing a subtle failure mode in deep learning.

The key takeaways are:

  1. Diagnosis: Modern diffusion policies tend to ignore history, failing to capture temporal consistency.
  2. Solution: Training the model to predict the past forces it to encode history, improving its future predictions.
  3. Efficiency: Decoupling the visual encoder from the policy head via caching makes training 10x faster.
  4. Verification: The ability to predict the past allows the robot to self-verify its plans at runtime.

By treating the past as a supervision signal rather than just an input, this method enables robots to perform complex, multi-stage tasks that were previously out of reach for diffusion-based imitation learning. It turns out, for a robot to move forward effectively, it helps to look back.