Introduction
We are living in the golden age of environmental data. From satellites orbiting the Earth to sensors floating in the oceans and weather stations dotting the landscape, we are collecting information about our planet at an unprecedented rate. Simultaneously, scientific computing models are generating massive simulation datasets of fluid dynamics and atmospheric processes.
For machine learning researchers, this data explosion presents a massive opportunity: building models that can accurately predict weather, simulate physics, and interpolate sparse measurements. We have seen the rise of “foundation models” for weather, such as GraphCast and Aurora, which leverage massive compute to predict global weather patterns.
However, there is a catch. Most of these state-of-the-art models suffer from a rigidity problem. They typically require data to be “structured”—specifically, lying on a fixed, regular grid. But real-world data is messy. Weather stations aren’t placed in a perfect lattice; they are clustered in cities and sparse in deserts. Observations come from different sources (modalities) at different locations and times.
To handle this, we need a framework that is flexible enough to ingest unstructured data but efficient enough to process it at scale.
Enter the Neural Process (NP). NPs are a family of meta-learning models that map arbitrary context data to probability distributions at arbitrary target locations. They are fantastic at handling irregular data and providing uncertainty estimates. However, historically, they haven’t scaled well. The most powerful variants, Transformer Neural Processes (TNPs), suffer from the quadratic computational cost of attention mechanisms (\(O(N^2)\)), making them unusable for large-scale spatio-temporal datasets with tens of thousands of points.
In this post, we will dive deep into a new architecture that solves this scalability paradox: Gridded Transformer Neural Processes (Gridded TNPs). This approach combines the flexibility of Neural Processes with the raw power and efficiency of modern Vision Transformers (ViTs) and Swin Transformers.
Background: The Challenge of Spatio-Temporal Data
Before dissecting the new architecture, let’s establish the playing field. We are dealing with spatio-temporal regression.
Imagine you have a set of temperature readings from 500 specific weather stations across the US (the context set). You want to predict the temperature at 1,000 other locations (the target set), or perhaps on a fine grid covering the whole country.
The Neural Process Framework
Neural Processes tackle this by learning a mapping from the context set directly to a predictive distribution. Conditional Neural Processes (CNPs) generally follow a three-step recipe, as visualized below and sketched in code after the list:

- Encoder: Takes context points \((x_c, y_c)\) and maps them to a latent representation (tokens).
- Processor: Aggregates these tokens. In early versions, this was a simple mean over the tokens. In Transformer NPs (TNPs), this involves self-attention layers where every point attends to every other point.
- Decoder: Takes the processed representation and the target location \(x_t\) to output a prediction (e.g., a mean and variance for a Gaussian distribution).
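To make the recipe concrete, here is a minimal CNP sketch in PyTorch, assuming a simple MLP encoder, mean-pooling processor, and Gaussian output head; the class and layer names are illustrative rather than any paper's exact implementation.

```python
# Minimal CNP sketch (illustrative): MLP encoder, mean-pooling processor, Gaussian decoder.
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyCNP(nn.Module):
    def __init__(self, x_dim=2, y_dim=1, hidden=128):
        super().__init__()
        # 1. Encoder: maps each context pair (x_c, y_c) to a token.
        self.encoder = nn.Sequential(nn.Linear(x_dim + y_dim, hidden), nn.ReLU(),
                                     nn.Linear(hidden, hidden))
        # 3. Decoder: maps (aggregated representation, x_t) to Gaussian parameters.
        self.decoder = nn.Sequential(nn.Linear(hidden + x_dim, hidden), nn.ReLU(),
                                     nn.Linear(hidden, 2 * y_dim))

    def forward(self, x_c, y_c, x_t):
        tokens = self.encoder(torch.cat([x_c, y_c], dim=-1))   # [B, N_c, hidden]
        r = tokens.mean(dim=1, keepdim=True)                    # 2. Processor: plain mean aggregation
        r = r.expand(-1, x_t.shape[1], -1)                      # broadcast to every target location
        out = self.decoder(torch.cat([r, x_t], dim=-1))         # [B, N_t, 2 * y_dim]
        mean, raw_sigma = out.chunk(2, dim=-1)
        return mean, 0.01 + F.softplus(raw_sigma)               # predictive mean and standard deviation
```

A TNP replaces the mean aggregation with self-attention over the tokens, which is exactly where the quadratic cost discussed next comes from.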
The Bottleneck
The problem lies in step 2. If you use a standard Transformer as the processor, you are calculating attention between every pair of points.
\[ \text{Cost} \propto (\text{Number of points})^2 \]
If you have 10,000 observations, a standard TNP is prohibitively slow and memory-hungry.
On the other hand, models like Convolutional CNPs (ConvCNPs) try to solve this by projecting data onto a grid and using Convolutional Neural Networks (CNNs). This is efficient (\(O(N)\)), but CNNs lack the global receptive field and modeling power of Transformers. Furthermore, the projection method used by ConvCNPs (Kernel Interpolation) can be lossy and rigid.
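For intuition, here is a rough sketch of ConvCNP-style kernel interpolation in one dimension, assuming a fixed Gaussian (RBF) kernel; the function name and lengthscale are illustrative.

```python
# Rough sketch of ConvCNP-style kernel interpolation onto a 1D grid (fixed Gaussian kernel).
import torch

def kernel_interpolate(x_c, y_c, grid, lengthscale=0.1):
    """x_c: [N_c, 1] locations, y_c: [N_c, 1] values, grid: [M, 1] -> gridded values [M, 1]."""
    dist2 = (grid[:, None, :] - x_c[None, :, :]).pow(2).sum(-1)   # [M, N_c] squared distances
    w = torch.exp(-0.5 * dist2 / lengthscale**2)                  # fixed, non-learned kernel weights
    density = w.sum(-1, keepdim=True).clamp_min(1e-6)             # how much data lands near each cell
    return (w @ y_c) / density                                    # kernel-weighted average per cell

x_c = torch.rand(500, 1)                        # 500 scattered observations
y_c = torch.sin(6.28 * x_c)
grid = torch.linspace(0, 1, 64).unsqueeze(-1)   # 64 regular grid cells
gridded = kernel_interpolate(x_c, y_c, grid)    # now a CNN (or Transformer) can process it
```

The pseudo-token grid encoder described below replaces these fixed kernel weights with learned cross-attention.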
The goal of Gridded TNPs is to get the best of both worlds: the efficiency of grid-based processing and the flexibility of attention-based encoding.
The Core Method: Gridded Transformer Neural Processes
The authors of “Gridded Transformer Neural Processes for Spatio-Temporal Data” propose a unified architecture that decomposes the problem into three distinct, optimized stages:
- Grid Encoder: Moves unstructured data onto a latent grid.
- Grid Processor: Uses efficient Transformers (ViT or Swin) on that grid.
- Grid Decoder: Moves from the latent grid back to arbitrary target locations.
Let’s look at the complete pipeline:

As shown in Figure 1, the model takes context observations (blue circles), encodes them into pseudo-tokens on a grid (red squares), processes them into deeper representations (dark red diamonds), and finally decodes them to the target locations (green crosses).
Let’s break down each component in detail.
1. The Pseudo-Token Grid Encoder
The first challenge is getting unstructured data (scattered points) onto a structured grid. The standard approach (used in ConvCNPs) is Kernel Interpolation, where you take a kernel-weighted average of nearby points to fill each grid cell.
Gridded TNPs introduce a smarter way: the Pseudo-Token Grid Encoder.
Instead of a fixed mathematical average, the model uses Cross-Attention. We define a set of learnable “pseudo-tokens” located at regular grid positions. These pseudo-tokens “query” the nearby real data points to gather information.
The mathematical operation for a pseudo-token \(u_m\) at grid location \(v_m\) is:
\[ u_m = \operatorname{MHCA}\!\left( u^0_m,\ \{\, z_{c,n} : x_{c,n} \in \mathfrak{N}(v_m; k) \,\} \right) \]
where \(\operatorname{MHCA}\) denotes multi-head cross-attention with the pseudo-token as the query.
Here, the pseudo-token \(u_m\) attends only to the context points \(z_{c,n}\) that are within its local neighborhood \(\mathfrak{N}(v_m; k)\).
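Below is a minimal sketch of this encoder in PyTorch, under the assumption that the context points have already been embedded into tokens \(z_c\); the class name, grid layout, and neighborhood radius are illustrative choices, not the paper's implementation.

```python
# Sketch of the pseudo-token grid encoder: learnable pseudo-tokens at fixed grid locations
# cross-attend to the context tokens that fall inside their neighborhood.
import torch
import torch.nn as nn

class PseudoTokenGridEncoder(nn.Module):
    def __init__(self, grid_size=32, dim=128, heads=4, radius=0.1):
        super().__init__()
        self.u0 = nn.Parameter(torch.randn(grid_size * grid_size, dim))   # learnable initial values u^0_m
        ys, xs = torch.meshgrid(torch.linspace(0, 1, grid_size),
                                torch.linspace(0, 1, grid_size), indexing="ij")
        self.register_buffer("v", torch.stack([xs, ys], dim=-1).reshape(-1, 2))  # grid locations v_m
        self.xattn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.heads, self.radius = heads, radius

    def forward(self, x_c, z_c):
        """x_c: [B, N_c, 2] context locations, z_c: [B, N_c, dim] embedded context tokens."""
        B = x_c.shape[0]
        queries = self.u0.unsqueeze(0).expand(B, -1, -1)                  # pseudo-tokens act as queries
        dist = torch.cdist(self.v.unsqueeze(0).expand(B, -1, -1), x_c)    # [B, M, N_c]
        mask = dist > self.radius                                         # True = outside the neighborhood
        mask.scatter_(2, dist.argmin(-1, keepdim=True), False)            # keep at least the nearest point
        u, _ = self.xattn(queries, z_c, z_c,
                          attn_mask=mask.repeat_interleave(self.heads, dim=0))
        return u                                                          # [B, M, dim] gridded pseudo-tokens
```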
Why is this better?
- Learnable: The model learns how to aggregate data, rather than relying on a fixed kernel (like a Gaussian curve).
- Topography awareness: Since pseudo-tokens have their own learnable initial values (\(u^0_m\)), the model can learn fixed geographical features (like “this grid cell is usually a mountain peak”) even if no data is observed there in a specific instance.
To make this efficient, the authors use a clever “padding” trick to batch these operations, as illustrated here:

By padding neighborhoods with dummy tokens, the cross-attention can be parallelized efficiently across the GPU.
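A hypothetical sketch of the gathering step: each grid cell collects its \(k\) nearest context tokens into a fixed-size tensor plus a padding mask, so a single batched, masked cross-attention call covers every cell; the function name and radius are illustrative.

```python
# Hypothetical sketch of the gathering step behind the padding trick.
import torch

def gather_padded_neighborhoods(v, x_c, z_c, k=8, radius=0.1):
    """v: [M, 2] grid locations, x_c: [N_c, 2] context locations, z_c: [N_c, D] context tokens.
    Returns per-cell neighbor tokens [M, k, D] and a padding mask [M, k]."""
    dist = torch.cdist(v, x_c)                          # [M, N_c] grid-to-context distances
    nn_dist, nn_idx = dist.topk(k, largest=False)       # k nearest context points per grid cell
    neighbors = z_c[nn_idx]                             # [M, k, D] gathered context tokens
    pad_mask = nn_dist > radius                         # far-away "neighbors" act as dummy padding
    return neighbors, pad_mask

v = torch.rand(1024, 2); x_c = torch.rand(500, 2); z_c = torch.randn(500, 128)
neighbors, pad_mask = gather_padded_neighborhoods(v, x_c, z_c)
# One masked cross-attention over [1024, k, 128] now handles every grid cell in parallel.
```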
2. The Grid Processor: Unleashing Efficient Transformers
Once the data is encoded into tokens \(U\) on a regular grid, the heavy lifting begins. Because the data is now structured, we don’t need expensive full attention. We can use architectures designed for images.
The authors explore two main backbones:
- Vision Transformer (ViT): Patches the grid and applies attention.
- Swin Transformer: Computes attention within local windows that shift, allowing for interaction between neighboring windows.
The Swin Transformer proves to be particularly effective here. It scales linearly with the number of grid points (rather than quadratically) while maintaining the ability to capture complex, non-local dependencies through its hierarchical structure.
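To illustrate why this is cheap, here is a minimal windowed self-attention block over the latent grid, capturing the core Swin idea; the real Swin Transformer additionally shifts windows between layers and downsamples hierarchically, both omitted in this sketch.

```python
# Minimal windowed self-attention over the latent grid: attention runs only inside w x w windows,
# so the cost grows linearly with the number of grid tokens. Assumes H and W are divisible by w.
import torch
import torch.nn as nn

class WindowedSelfAttention(nn.Module):
    def __init__(self, dim=128, window=8, heads=4):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.window = window

    def forward(self, u):
        """u: [B, H, W, D] latent grid tokens."""
        B, H, W, D = u.shape
        w = self.window
        # Partition the grid into non-overlapping w x w windows and flatten each window.
        u = u.reshape(B, H // w, w, W // w, w, D).permute(0, 1, 3, 2, 4, 5)
        u = u.reshape(B * (H // w) * (W // w), w * w, D)
        u = u + self.attn(u, u, u)[0]                   # attention is restricted to each window
        # Undo the partition to recover the [B, H, W, D] grid layout.
        u = u.reshape(B, H // w, W // w, w, w, D).permute(0, 1, 3, 2, 4, 5)
        return u.reshape(B, H, W, D)
```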
3. The Cross-Attention Grid Decoder
After processing, we have a grid of rich, contextualized tokens. But our target predictions might be anywhere—not necessarily on the grid points. We need to decode back to the continuous domain.
The authors propose a Nearest-Neighbor Cross-Attention (NN-CA) decoder.
For a target location \(x_{t,n}\), we identify the \(k\) nearest grid tokens and allow the target to attend to them:
\[ z_{t,n} = \operatorname{MHCA}\!\left( z^0_{t,n},\ \{\, u_m : v_m \in \mathfrak{N}(x_{t,n}; k) \,\} \right) \]
where \(z^0_{t,n}\) is an initial embedding of the target input and \(\mathfrak{N}(x_{t,n}; k)\) contains the \(k\) nearest grid locations.
This is a crucial design choice. A “Full Attention” decoder would allow every target point to look at every grid point, which would be computationally massive (\(O(M \cdot N_t)\) for \(M\) grid tokens and \(N_t\) targets). By restricting attention to the \(k\) nearest neighbors, the complexity drops significantly.
Furthermore, this restriction acts as a beneficial inductive bias. In physics and weather, the state at a specific location is usually most influenced by its immediate surroundings. Forcing the model to look locally often improves accuracy compared to letting it look everywhere.
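A simplified sketch of nearest-neighbor cross-attention decoding, assuming a flat (non-wrapping) grid and a plain linear embedding of the target coordinate as the query; names are illustrative.

```python
# Simplified nearest-neighbor cross-attention decoder: each target attends to its k nearest
# grid tokens only.
import torch
import torch.nn as nn

class NearestNeighborDecoder(nn.Module):
    def __init__(self, dim=128, heads=4, k=9):
        super().__init__()
        self.xattn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.query_embed = nn.Linear(2, dim)   # embed the raw target coordinate as the query
        self.k = k

    def forward(self, x_t, v, u):
        """x_t: [N_t, 2] target locations, v: [M, 2] grid locations, u: [M, D] processed grid tokens."""
        nn_idx = torch.cdist(x_t, v).topk(self.k, largest=False).indices   # [N_t, k] nearest grid cells
        keys = u[nn_idx]                                                    # [N_t, k, D]
        queries = self.query_embed(x_t).unsqueeze(1)                        # [N_t, 1, D]
        z_t, _ = self.xattn(queries, keys, keys)                            # attend over k neighbors only
        return z_t.squeeze(1)                                               # [N_t, D] -> fed to output head
```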
The authors handle the “nearest neighbor” search carefully, accounting for different grid geometries. For example, on a global weather map, the longitude wraps around (cylindrical geometry). The model knows that the far left of the map is a neighbor of the far right:

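As a small illustration of that wraparound, the sketch below uses a longitude-aware distance when searching for neighbors (Euclidean in degrees for simplicity, not great-circle distance).

```python
# Longitude-aware neighbor search: longitude differences wrap at 360 degrees, so cells at the far
# left and far right of the map are treated as neighbors.
import torch

def wrapped_distance(x_t, v):
    """x_t: [N_t, 2], v: [M, 2], both as (latitude, longitude) in degrees."""
    dlat = x_t[:, None, 0] - v[None, :, 0]
    dlon = (x_t[:, None, 1] - v[None, :, 1]).abs()
    dlon = torch.minimum(dlon, 360.0 - dlon)     # wrap: 359 deg and 1 deg are only 2 deg apart
    return torch.sqrt(dlat**2 + dlon**2)         # [N_t, M]

x_t = torch.tensor([[10.0, 359.0]])              # a target just west of the dateline
v = torch.tensor([[10.0, 1.0], [10.0, 180.0]])   # two candidate grid cells
print(wrapped_distance(x_t, v))                  # the cell at lon=1 is by far the closer neighbor
```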
Handling Multiple Modalities
Real-world data is rarely just “temperature.” It’s wind, pressure, humidity, and topography. Often, these variables are measured at different locations (e.g., pressure at one station, wind at another).
Gridded TNPs handle this elegantly. You can have separate encoders for each data source (modality) that all feed into the same latent grid. This allows the model to perform sensor fusion naturally, integrating disparate data sources into a unified state estimation.
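One way this could look in code, reusing the grid encoder sketched earlier (passed in as `grid_encoder`) and fusing modalities by simple addition on the latent grid; the fusion rule and all names here are illustrative assumptions, not the paper's exact mechanism.

```python
# Illustrative multi-modal fusion: one point embedder per modality, every modality encoded onto
# the same latent grid, per-modality grids fused by addition.
import torch
import torch.nn as nn

class MultiModalGridEncoder(nn.Module):
    def __init__(self, modality_dims, dim=128):
        super().__init__()
        # e.g. modality_dims = {"temperature": 1, "wind": 2, "pressure": 1}
        self.embedders = nn.ModuleDict(
            {name: nn.Linear(d + 2, dim) for name, d in modality_dims.items()})

    def forward(self, batches, grid_encoder):
        """batches: {modality: (x_c [B, N, 2], y_c [B, N, d])} -> fused latent grid [B, M, dim]."""
        fused = None
        for name, (x_c, y_c) in batches.items():
            z_c = self.embedders[name](torch.cat([x_c, y_c], dim=-1))   # modality-specific tokens
            u = grid_encoder(x_c, z_c)                                   # shared latent grid for all modalities
            fused = u if fused is None else fused + u                    # simple additive sensor fusion
        return fused
```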
Incorporation of Translation Equivariance
A major property of spatio-temporal data is Translation Equivariance (TE). If a weather system moves 100km east, the prediction should simply shift 100km east; the physics doesn’t change.
Standard Transformers are not naturally equivariant (they rely on absolute positional embeddings). The authors adapt the Translation Equivariant TNP (TE-TNP) framework to the gridded setting.
They replace standard attention with TE-Attention, which depends only on the relative distance between points:

However, strict equivariance can be too restrictive. Real-world data often has “symmetry breaking” features (like the shape of continents or mountain ranges) that are fixed in space.
To address this, the authors implement Approximate Equivariance. They feed the model additional “fixed” inputs (like position embeddings or terrain maps) but allow the model to ignore them via a specialized training mechanism (dropout on the symmetry-breaking features). This allows the model to be equivariant where it matters (the fluid dynamics) while respecting fixed geography.
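Here is a simplified sketch of attention whose logits depend only on relative offsets \(x_i - x_j\) (via a small MLP bias) rather than absolute coordinates; it captures the spirit of TE-attention but is not the paper's exact parameterization, and it omits the approximate-equivariance dropout mechanism.

```python
# Simplified translation-equivariant attention: logits combine a content term with a bias that
# depends only on relative offsets x_i - x_j, never on absolute coordinates.
import torch
import torch.nn as nn

class TEAttention(nn.Module):
    def __init__(self, dim=128, x_dim=2):
        super().__init__()
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.rel_bias = nn.Sequential(nn.Linear(x_dim, 64), nn.ReLU(), nn.Linear(64, 1))

    def forward(self, x, z):
        """x: [B, N, x_dim] token locations, z: [B, N, dim] token contents."""
        q, k, v = self.q(z), self.k(z), self.v(z)
        logits = q @ k.transpose(-1, -2) / q.shape[-1] ** 0.5     # content-based term
        rel = x[:, :, None, :] - x[:, None, :, :]                 # [B, N, N, x_dim] relative offsets
        logits = logits + self.rel_bias(rel).squeeze(-1)          # shift-invariant positional bias
        return torch.softmax(logits, dim=-1) @ v                  # shifting all x leaves this unchanged
```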
Experiments and Results
The authors subjected Gridded TNPs to a battery of tests, ranging from synthetic Gaussian Processes to massive real-world weather datasets.
1. Synthetic Gaussian Processes (Proof of Scalability)
The first test was “Meta-Learning Gaussian Process Regression.” They generated complex 2D functions and tasked the models with interpolating them.
The results highlight the efficiency breakthrough. The plot below shows Accuracy (Log-Likelihood) vs. Speed (Forward Pass Time).

Key Takeaways from the Chart:
- Top Right is Best: We want high accuracy and low time.
- Gridded TNPs (Stars/Diamonds): They cluster at the top. The Swin-TNP (stars) provides the best trade-off.
- Baselines (Circles/Triangles): The ConvCNP (green circles) is fast but less accurate on complex data. The standard pseudo-token TNP (black triangles) is much slower for the same accuracy.
Qualitatively, the Gridded TNP recovers the ground truth much more sharply than the baselines:

Notice how the Swin-TNP (b) captures the sharp, high-frequency peaks of the Ground Truth (a) much better than the smoother, blurrier predictions of the ConvCNP (c) or PT-TNP (d).
2. Real-World Weather: Combining Stations and Satellites
This is the “killer app” for this architecture. The task: predict 2-meter temperature (\(t2m\)) at roughly 10,000 weather stations across the globe.
- Context: Structured skin temperature data (from satellites/reanalysis) + sparse temperature readings from a random subset of stations.
- Target: Predict temperature at all station locations.
The geographic distribution of these stations is highly irregular:

Quantitative Results:
The table below summarizes the performance. Higher Log-Likelihood is better; lower RMSE is better.

Analysis:
- The Swin-TNP with Pseudo-Token Grid Encoder (PT-GE) achieves a Log-Likelihood of 1.819, significantly outperforming the ConvCNP (1.689) and standard TNP (1.344).
- Scale matters: Even a smaller Swin-TNP outperforms a significantly larger ConvCNP.
- Encoder Choice: The Pseudo-Token Grid Encoder (PT-GE) consistently beats the Kernel Interpolation (KI-GE) method, proving that learning how to grid the data is better than mathematically interpolating it.
Visualizing Errors:
If we map the errors, we can see the improvement. Lighter colors mean lower error.

The Swin-TNP (top map) shows visibly lower errors (more white/pale regions) across North America and Europe compared to the baselines.
3. Multi-Modal Wind Speed
In this experiment, the model had to predict wind speed components (\(u\) and \(v\)) at three different atmospheric pressure levels. This is a “multi-modal” task because the inputs for different variables might not align perfectly.
The model used the Translation Equivariant (TE) inductive bias here.

The results (Table 2) show that adding Translation Equivariance (Swin-TNP (\(T\))) boosts performance significantly over the non-equivariant version. Relaxing it to Approximate Equivariance (\(\tilde{T}\)) improves it even further, likely because it allows the model to learn local geographic quirks while maintaining general physical rules.
Figure 22 visualizes the wind vectors. The Swin-TNP captures the flow dynamics much more accurately, with fewer large-scale error artifacts than the PT-TNP.

4. Fluid Dynamics on Meshes (EAGLE Dataset)
Finally, to prove this isn’t just a weather model, they applied it to the EAGLE dataset—simulations of drones flying through 2D scenes. This data lives on irregular meshes, not grids.

The Swin-TNP successfully modeled the complex fluid interactions (velocity and pressure fields), demonstrating that the “Pseudo-Token Grid Encoder” can effectively translate irregular mesh data into a format that Transformers can understand and predict.
Conclusion: A New Standard for Spatio-Temporal Modeling?
The Gridded Transformer Neural Process represents a significant step forward in our ability to model the physical world. By acknowledging that data comes in messy, unstructured formats but computing is most efficient on structured grids, the authors have built a bridge between the two.
Key Takeaways:
- Flexibility meets Efficiency: The architecture handles arbitrary inputs (via the Grid Encoder) and arbitrary outputs (via the Grid Decoder) but does the heavy processing on a latent grid using efficient Swin Transformers.
- Learn, Don’t Interpolate: The attention-based Pseudo-Token Grid Encoder outperforms traditional kernel interpolation, allowing the model to “learn” how to structure the data.
- Local is Good: Using Nearest-Neighbor cross-attention in the decoder is not just faster; it acts as a beneficial inductive bias for physical systems.
- Scalability: This framework allows Neural Processes to scale to datasets with hundreds of thousands of points, a realm previously reserved for rigid, grid-only models.
As we move toward a future of “Digital Twins” and AI-driven weather forecasting that integrates data from smartphones, cars, and IoT sensors alongside traditional satellites, architectures like the Gridded TNP will likely form the backbone of these systems. They provide the necessary mathematical translation layer to turn the chaos of real-world observations into the structured understanding of foundation models.