引言
我们正处于环境数据的黄金时代。从环绕地球的卫星到漂浮在海洋中的传感器,再到点缀在陆地上的气象站,我们正以前所未有的速度收集关于地球的信息。与此同时,科学计算模型正在生成海量数据集,模拟流体动力学和大气变化。
对于机器学习研究人员而言,这种数据爆炸带来了一个巨大的机遇: 构建能够准确预测天气、模拟物理过程并对稀疏测量值进行插值的模型。我们已经看到了针对天气的“基础模型”的兴起,例如 GraphCast 和 Aurora,它们利用巨大的算力来预测全球天气模式。
然而,这其中存在一个难题。大多数这些最先进的模型都面临着僵化的问题。它们通常要求数据是“结构化的”——具体来说,就是位于固定的、规则的网格上。但现实世界的数据是杂乱的。气象站并非布置在完美的晶格中;它们聚集在城市中,而在沙漠中则很稀疏。观测数据来自不同的来源 (模态) ,位于不同的地点和时间。
为了解决这个问题,我们需要一个既足够灵活以接收非结构化数据,又足够高效以进行大规模处理的框架。
神经过程 (Neural Process, NP) 登场了。这是一类元学习模型,能够将任意上下文数据映射到任意目标位置的概率分布。它们非常擅长处理不规则数据并提供不确定性估计。然而,从历史上看,它们的扩展性并不好。最强大的变体——Transformer 神经过程 (TNPs) ,受制于注意力机制的二次计算成本 (\(O(N^2)\)) ,导致它们无法用于包含数万个点的大规模时空数据集。
在这篇文章中,我们将深入探讨一种解决这一可扩展性悖论的新架构: 网格化 Transformer 神经过程 (Gridded Transformer Neural Processes, Gridded TNPs) 。 这种方法结合了神经过程的灵活性与现代视觉 Transformer (ViT) 和 Swin Transformer 的强大算力及效率。
背景: 时空数据的挑战
在剖析新架构之前,让我们先明确一下应用场景。我们要处理的是时空回归问题。
想象一下,你有一组来自美国各地 500 个特定气象站的温度读数 (上下文集) 。你想预测其他 1,000 个位置 (目标集) 的温度,或者可能是覆盖整个国家的精细网格上的温度。
神经过程框架
神经过程通过学习从上下文集直接到预测分布的映射来解决这个问题。通常,CNP (条件神经过程) 遵循如下所示的三步法:

- 编码器 (Encoder) : 获取上下文点 \((x_c, y_c)\) 并将其映射为潜在表示 (token) 。
- 处理器 (Processor) : 聚合这些 token。在早期版本中,这只是简单的求和 (平均信息) 。在 Transformer NPs (TNPs) 中,这涉及自注意力层,其中每个点都会关注其他所有点。
- 解码器 (Decoder) : 利用处理后的表示和目标位置 \(x_t\) 输出预测 (例如,高斯分布的均值和方差) 。
瓶颈
问题出在第 2 步。如果你使用标准的 Transformer 作为处理器,你需要计算每对点之间的注意力。
\[ \text{Cost} \propto (\text{Number of points})^2 \]如果你有 10,000 个观测值,标准的 TNP 速度慢得令人望而却步,且极其消耗内存。
另一方面,像卷积 CNP (ConvCNPs) 这样的模型试图通过将数据投影到网格上并使用卷积神经网络 (CNN) 来解决这个问题。这很高效 (\(O(N)\)) ,但 CNN 缺乏 Transformer 的全局感受野和建模能力。此外,ConvCNP 使用的投影方法 (核插值) 可能是有损且僵化的。
网格化 TNP 的目标是鱼和熊掌兼得: 既有基于网格处理的效率,又有基于注意力编码的灵活性。
核心方法: 网格化 Transformer 神经过程
“Gridded Transformer Neural Processes for Spatio-Temporal Data” 的作者们提出了一种统一的架构,将问题分解为三个独特的、经过优化的阶段:
- 网格编码器 (Grid Encoder) : 将非结构化数据移动到潜在网格上。
- 网格处理器 (Grid Processor) : 在该网格上使用高效的 Transformer (ViT 或 Swin) 。
- 网格解码器 (Grid Decoder) : 从潜在网格移回任意目标位置。
让我们看看完整的流程:

如图 1 所示,模型获取上下文观测值 (蓝色圆圈) ,将其编码为网格上的伪 token (红色方块) ,将其处理为更深层的表示 (深红色菱形) ,最后将其解码到目标位置 (绿色十字) 。
让我们详细分解每个组件。
1. 伪 Token 网格编码器
第一个挑战是将非结构化数据 (散点) 放到结构化网格上。标准方法 (用于 ConvCNP) 是核插值 , 基本上是利用附近点的加权平均值来填充网格单元。
网格化 TNP 引入了一种更聪明的方法: 伪 Token 网格编码器 (Pseudo-Token Grid Encoder) 。
模型不使用固定的数学平均值,而是使用交叉注意力 。 我们定义了一组位于规则网格位置的可学习“伪 token”。这些伪 token 会“查询”附近的真实数据点以收集信息。
位于网格位置 \(v_m\) 的伪 token \(u_m\) 的数学运算如下:

在这里,伪 token \(u_m\) 仅 对其局部邻域 \(\mathfrak{N}(v_m; k)\) 内的上下文点 \(z_{c,n}\) 进行注意力计算。
为什么这样更好?
- 可学习: 模型学习如何聚合数据,而不是依赖固定的核 (如高斯曲线) 。
- 地形感知: 由于伪 token 拥有自己的可学习初始值 (\(u^0_m\)) ,模型可以学习固定的地理特征 (比如“这个网格单元通常是一个山峰”) ,即使在特定实例中没有在该处观测到数据。
为了提高效率,作者使用了一个巧妙的“填充”技巧来批量处理这些操作,如下图所示:

通过用虚拟 token 填充邻域,交叉注意力可以在 GPU 上高效并行化。
2. 网格处理器: 释放高效 Transformer 的威力
一旦数据被编码为规则网格上的 token \(U\),繁重的工作就开始了。因为数据现在是结构化的,我们不需要昂贵的全注意力计算。我们可以使用专为图像设计的架构。
作者探索了两个主要的骨干网络:
- Vision Transformer (ViT): 对网格进行分块 (patch) 并应用注意力。
- Swin Transformer (Swin): 在移动的局部窗口内计算注意力,允许相邻窗口之间进行交互。
Swin Transformer 在这里证明特别有效。它随网格点数量线性扩展 (而不是二次方) ,同时通过其层级结构保持捕捉复杂、非局部依赖关系的能力。
3. 交叉注意力网格解码器
处理之后,我们得到了一个包含丰富上下文信息的 token 网格。但我们的目标预测可能在任何地方——不一定在网格点上。我们需要解码回连续域。
作者提出了最近邻交叉注意力 (NN-CA) 解码器。
对于目标位置 \(x_{t,n}\),我们要识别 \(k\) 个最近的网格 token,并允许目标对它们进行关注:

这是一个至关重要的设计选择。“全注意力”解码器将允许每个目标点查看每个网格点,这将导致巨大的计算量 (\(O(M \cdot N_t)\)) 。通过将注意力限制在 \(k\) 个最近邻,复杂度显著下降。
此外,这种限制充当了一种有益的归纳偏置 。 在物理学和天气学中,特定位置的状态通常受其直接周围环境的影响最大。与允许模型查看所有地方相比,强迫模型关注局部通常能提高准确性。
作者仔细处理了“最近邻”搜索,考虑了不同的网格几何形状。例如,在全球天气地图上,经度是环绕的 (圆柱几何结构) 。模型知道地图的最左端与最右端是邻居:

处理多种模态
现实世界的数据很少只是“温度”。它包括风、气压、湿度和地形。通常,这些变量是在不同位置测量的 (例如,气压在一个站点,风速在另一个站点) 。
网格化 TNP 优雅地处理了这个问题。你可以为每个数据源 (模态) 设置单独的编码器,它们都馈送到同一个潜在网格中。这使得模型能够自然地执行传感器融合 , 将不同的数据源集成到一个统一的状态估计中。
引入平移等变性
时空数据的一个主要属性是平移等变性 (Translation Equivariance, TE) 。 如果一个天气系统向东移动 100 公里,预测结果也应该简单地向东移动 100 公里;物理规律不会改变。
标准的 Transformer 天生不具备等变性 (它们依赖于绝对位置嵌入) 。作者将平移等变 TNP (TE-TNP) 框架调整到了网格化设置中。
他们用 TE-Attention 替换了标准注意力,前者仅依赖于点之间的相对距离:

然而,严格的等变性可能限制性太强。现实世界的数据通常具有空间上固定的“对称破缺”特征 (如大陆的形状或山脉) 。
为了解决这个问题,作者实现了近似等变性 。 他们向模型提供额外的“固定”输入 (如位置嵌入或地形图) ,但通过专门的训练机制 (对对称破缺特征进行 dropout) 允许模型忽略它们。这使得模型在关键部分 (流体动力学) 保持等变性,同时尊重固定的地理特征。
实验与结果
作者对网格化 TNP 进行了一系列测试,范围从合成高斯过程到大规模真实世界天气数据集。
1. 合成高斯过程 (可扩展性证明)
第一个测试是“元学习高斯过程回归”。他们生成了复杂的 2D 函数,并要求模型对其进行插值。
结果凸显了效率的突破。下图展示了准确率 (对数似然) 与速度 (前向传播时间) 的关系。

图表中的关键结论:
- 右上角最好: 我们希望准确率高且时间短。
- 网格化 TNP (星形/菱形) : 它们聚集在顶部。 Swin-TNP (星形) 提供了最佳的权衡。
- 基线 (圆形/三角形) : ConvCNP (绿色圆形) 速度快,但在复杂数据上准确率较低。标准的伪 token TNP (黑色三角形) 在相同准确率下速度要慢得多。
从定性上看,网格化 TNP 比基线更清晰地恢复了真实值:

请注意, Swin-TNP (b) 比 ConvCNP (c) 或 PT-TNP (d) 平滑、模糊的预测更好地捕捉到了 Ground Truth (a) 中尖锐的高频峰值。
2. 真实世界天气: 结合气象站与卫星数据
这是该架构的“杀手级应用”。任务: 预测全球约 10,000 个气象站的 2 米温度 (\(t2m\)) 。
- 上下文: 结构化的地表温度数据 (来自卫星/再分析) + 随机子集站点的稀疏温度读数。
- 目标: 预测所有站点位置的温度。
这些站点的地理分布非常不规则:

定量结果:
下表总结了性能。对数似然越高越好;RMSE 越低越好。

分析:
- 带有伪 Token 网格编码器 (PT-GE) 的 Swin-TNP 实现了 1.819 的对数似然,显著优于 ConvCNP (1.689) 和标准 TNP (1.344) 。
- 规模很重要: 即使是较小的 Swin-TNP 也优于明显更大的 ConvCNP。
- 编码器选择: 伪 Token 网格编码器 (PT-GE) 始终击败核插值 (KI-GE) 方法,证明学习如何网格化数据比数学插值更好。
误差可视化:
如果我们绘制误差图,我们可以看到改进。颜色越浅表示误差越低。

与基线相比,Swin-TNP (顶部地图) 在北美和欧洲显示出肉眼可见的更低误差 (更多白色/浅色区域) 。
3. 多模态风速
在这个实验中,模型必须预测三个不同大气压层的风速分量 (\(u\) 和 \(v\)) 。这是一个“多模态”任务,因为不同变量的输入可能无法完美对齐。
模型在这里使用了平移等变性 (TE) 归纳偏置。

结果 (表 2) 表明,添加平移等变性 (Swin-TNP (\(T\))) 比非等变版本显著提高了性能。将其放宽为近似等变性 (\(\tilde{T}\)) 会进一步改善性能,可能是因为它允许模型在保持一般物理规则的同时学习局部地理特征。
图 22 可视化了风矢量。Swin-TNP 比 PT-TNP 更准确地捕捉了流动动力学,大尺度的误差伪影更少。

4. 网格上的流体动力学 (EAGLE 数据集)
最后,为了证明这不仅仅是一个天气模型,他们将其应用于 EAGLE 数据集——无人机飞越 2D 场景的模拟。这些数据位于不规则网格上,而不是规则网格上。

Swin-TNP 成功模拟了复杂的流体相互作用 (速度场和压力场) ,证明了“伪 Token 网格编码器”可以有效地将不规则网格数据转换为 Transformer 可以理解和预测的格式。
结论: 时空建模的新标准?
网格化 Transformer 神经过程 代表了我们模拟物理世界能力的一次重大飞跃。通过承认数据以杂乱、非结构化的格式存在,但在结构化网格上计算效率最高,作者搭建了两者之间的桥梁。
主要结论:
- 灵活性遇上效率: 该架构处理任意输入 (通过网格编码器) 和任意输出 (通过网格解码器) ,但使用高效的 Swin Transformer 在潜在网格上进行繁重的处理。
- 学习,而非插值: 基于注意力的伪 Token 网格编码器优于传统的核插值,允许模型“学习”如何结构化数据。
- 局部即正义: 在解码器中使用最近邻交叉注意力不仅更快;它还充当了物理系统有益的归纳偏置。
- 可扩展性: 该框架允许神经过程扩展到包含数十万个点的数据集,这是以前仅限于僵化的、纯网格模型才能涉足的领域。
随着我们迈向“数字孪生”和 AI 驱动的天气预报未来,整合来自智能手机、汽车和物联网传感器的数据以及传统卫星数据,像网格化 TNP 这样的架构很可能构成这些系统的骨干。它们提供了必要的数学转换层,将现实世界观测的混乱转化为基础模型的结构化理解。
](https://deep-paper.org/en/paper/5244_gridded_transformer_neura-1656/images/cover.png)