在面向科学领域的深度学习世界里,结构决定一切。无论是蛋白质的折叠、RNA 链的扭曲,还是粒子系统的动力学,原子的几何排列决定了其功能。为了有效地模拟这些系统,神经网络必须理解两件事: 全局上下文 (分子的远端部分如何相互作用) 和等变性 (物理定律不会仅仅因为你旋转了分子而改变) 。

多年来,捕捉全局交互的黄金标准一直是 Transformer 及其自注意力机制。但这里有个问题。自注意力机制的复杂度呈二次方增长 (\(O(N^2)\)) 。如果你将分子的大小增加一倍,计算成本就会增加四倍。对于拥有数万个原子的庞大生物系统来说,标准的 Transformer 简直会耗尽内存。

Geometric Hyena (G-Hyena) 应运而生。

在这篇文章中,我们将探索一种新的架构,它挑战了注意力机制在几何深度学习中的统治地位。通过将长卷积模型 (特别是 Hyena 算子) 调整到 3D 空间,研究人员创建了一个比起前身更快、更精简且更有能力处理大规模数据的模型。

几何学习中的扩展性问题

要理解为什么 Geometric Hyena 是必要的,我们首先需要看看现有方法的瓶颈。

在对 3D 结构 (如蛋白质) 进行建模时,我们要处理两种类型的数据:

  1. 不变特征 (标量) : 不随旋转而改变的属性,例如原子类型 (碳、氮) 或电荷。
  2. 等变特征 (矢量/向量) : 随分子旋转而旋转的属性,例如 3D 坐标、速度或力。

一个稳健的模型必须能够同时处理这两者。标准方法通常分为两大阵营:

  • 消息传递神经网络 (MPNNs): 它们关注局部邻域 (例如,一个原子及其最近的邻居) 。它们效率很高,但难以“看清”全貌。信息从蛋白质的一端传播到另一端需要很长时间。
  • 等变 Transformer: 它们使用注意力机制让每个原子与所有其他原子进行对话。这完美地捕捉了全局上下文,但成本极其高昂。

下图生动地说明了这一瓶颈。

图1。左: GPU 前向传播运行时间对比。Geometric Hyena 呈次二次方扩展,与具有全局上下文的其他等变模型相比实现了显著加速。右: G-Hyena 的峰值 GPU 显存消耗对于长序列来说是最高效的。

随着序列长度 (原子/Token 的数量) 接近 30,000,基于标准 Transformer 的模型 (如红线和紫线所示的 VNT 或 Equiformer) 的运行时间和内存使用量呈爆炸式增长。相比之下, Geometric Hyena (青色线) 几乎保持平坦。事实上,G-Hyena 可以在单个 GPU 上处理高达 270 万个 Token 的上下文长度——比等变 Transformer 能处理的长度长 72 倍

Geometric Hyena 架构

G-Hyena 是如何实现这种效率的?秘密在于用长卷积 (Long Convolutions) 取代了昂贵的“全对全”注意力矩阵。

受大型语言模型 (LLMs) 中 Hyena 算子成功的启发,该架构以一种允许并行训练和次二次方推理的方式处理数据。但与标准文本处理不同,G-Hyena 必须遵循物理世界的 3D 对称性。

以下是 Geometric Hyena 模块的高级视图:

图2。Geometric Hyena 模块。(a) Geometric Hyena 模块包括 SE(3)-Hyena 算子和等变投影。(b) SE(3)-Hyena 算子包括查询、键、值投影,用于全局上下文聚合的几何长卷积,以及门控。

该架构总体上遵循 Transformer 熟悉的工作流程:

  1. 投影: 输入数据被投影为查询 (\(Q\))、键 (\(K\)) 和值 (\(V\))。
  2. 混合: 全局操作在序列间混合信息。
  3. 门控: 一种控制信息流动的机制。

然而,“混合”步骤 (在图 2b 中以绿色高亮显示) 正是创新发生的地方。G-Hyena 使用几何长卷积代替了注意力机制。

模型在整个过程中保持严格的等变性。在数学上,模型 \(\Psi\) 满足以下性质,确保如果你旋转输入几何 Token \(\mathbf{x}\),输出 \(\hat{\mathbf{x}}\) 也会以完全相同的方式旋转:

定义模型 Psi 关于群作用 Lg 的等变性属性的方程。

让我们拆解这个新算子的核心组件。

1. 投影层 (The Projection Layer)

在卷积之前,我们需要嵌入我们的标量 (不变) 和矢量 (等变) 输入。作者使用了一个受等变图神经网络 (EGNN) 启发的层。该层处理局部邻域,为全局卷积生成丰富的嵌入。它充当了一座桥梁,将原始原子数据转换为 Hyena 算子所需的 \(Q, K, V\) 格式。

2. 标量长卷积 (Scalar Long Convolution)

对于标量特征 (“不变”流) ,该模型利用了现代序列模型中常见的标准长卷积。它没有执行朴素卷积 (速度较慢) ,而是利用了快速傅里叶变换 (FFT)

根据卷积定理,时域/空域中的卷积等同于频域中的乘法。这使得模型能够以 \(O(N \log N)\) 而非 \(O(N^2)\) 的时间复杂度计算全局交互。

方程 2 展示了通过快速傅里叶变换计算的标量长卷积。

在这里,\(\mathbf{q}\) 和 \(\mathbf{k}\) 是查询和键序列。它们被变换到频域 (\(\mathbf{F}\)),进行逐元素相乘,然后变换回来。这非常高效。

3. 等变矢量长卷积 (Equivariant Vector Long Convolution)

这是论文的主要贡献。标准 FFT 卷积适用于标量,但是如何在保持旋转等变性的同时对 3D 矢量进行卷积呢?

研究人员引入了基于叉积 (cross product)矢量长卷积 。 与点积 (结果为标量) 不同,两个矢量的叉积会产生一个新的矢量,该矢量与输入的旋转是等变的。

该操作被定义为序列上的叉积之和:

方程 3 将矢量长卷积定义为查询矢量和键矢量之间的叉积之和。

朴素地计算这个和会让我们回到二次复杂度。为了解决这个问题,作者将叉积卷积分解为更简单的组件。叉积可以写成逐元素乘法的特定组合。因此,矢量卷积可以分解为六个标量卷积之和:

该方程使用列维-奇维塔符号将矢量叉积卷积分解为标量卷积之和。

通过将矢量交互分解为标量分量 (由 \(h\) 和 \(p\) 索引) ,模型本质上可以对矢量的分量运行高效的基于 FFT 的卷积,然后将它们重新组合。这保留了 \(O(N \log N)\) 的速度,同时正确处理了 3D 矢量几何。

4. 几何长卷积: 混合标量和矢量

在物理系统中,几何形状 (矢量) 往往决定属性 (标量) ,反之亦然。如果一个模型将它们视为两个独立的流,那将是受限的。G-Hyena 引入了“几何长卷积”,允许这些子空间进行交互。

交互过程如下图所示:

图5。几何长卷积中的标量-矢量交互。蓝线表示导致标量输出 alpha 3 的交互,红线表示导致矢量输出 r3 的交互。

模型通过以每种旋转安全的方式混合输入,计算出一个新的标量输出 (\(\alpha_3\)) 和一个新的矢量输出 (\(\mathbf{r}_3\)):

  • 标量 \(\times\) 标量
  • 矢量 \(\cdot\) 矢量 (点积 \(\rightarrow\) 标量)
  • 标量 \(\times\) 矢量
  • 矢量 \(\times\) 矢量 (叉积 \(\rightarrow\) 矢量)

在数学上,输出标量是通过结合标量卷积和矢量流的点积来计算的:

方程 17 显示几何卷积的标量输出是标量卷积和矢量分量卷积的混合。

而输出矢量则结合了标量-矢量调制以及我们之前推导的叉积卷积:

方程 19 显示几何卷积的矢量输出结合了标量-矢量交互和矢量-矢量叉积。

这种全面的混合使 G-Hyena 能够学习复杂的依赖关系,其中分子的形状会影响其化学性质,反之亦然。

实验结果

研究人员在各种任务上验证了 Geometric Hyena,范围从合成基准测试到现实世界的分子动力学。

1. 几何联想回忆 (Geometric Associative Recall)

为了证明 G-Hyena 确实能够“学习”几何序列,作者设计了一个名为几何联想回忆的新任务。

图4。解释几何联想回忆任务的图表,模型必须根据序列中先前的出现情况,检索与查询键对应的向量值。

在这个任务中,模型看到一系列矢量对 (键,值) 。在序列的末尾,它会被展示一个特定的键 (已旋转) ,并且必须预测相应的值 (也同样旋转) 。这测试了模型在长上下文中存储和检索几何信息的能力。

结果是决定性的:

图3。上: 不同序列长度下,几何联想回忆任务的检索向量与目标向量之间的 MSE。下: 不同隐藏维度和词汇量下的性能研究。

如上图所示,标准 Transformer (绿色) 和纯 Hyena (红色) 表现不佳,因为它们缺乏适当的几何归纳偏置。 G-Hyena (青色三角形) 实现了接近零的误差,与等变注意力模型的理论性能相匹配,但效率要高得多。

2. 大分子属性预测 (RNA)

真正的考验在于生物数据。作者在 Open VaccineRibonanza-2k 数据集上测试了 G-Hyena。这些数据集需要预测大分子 RNA 的稳定性和降解概况,这些分子可能包含数千个原子。

表1。两种大型 RNA 分子稳定性和降解预测任务在全原子和骨架表示下的 RMSE。

在上表中,仅依赖局部上下文的方法 (红色) 通常比具有全局上下文的方法 (青色) 表现更差。然而,请注意 Equiformer 这一项: 它在较大的任务上出现了 OOM (内存溢出)

G-Hyena 在几乎所有类别中都实现了最低的 RMSE (误差) , 胜过了局部方法和沉重的基于注意力的全局方法。它在不挤爆 GPU 的情况下有效地捕捉了决定 RNA 稳定性的长程相互作用。

3. 蛋白质分子动力学

最后,该模型在预测蛋白质动力学 (预测原子下一步将移动到哪里) 方面进行了测试。

表3。预测的和真实的蛋白质全原子及骨架 MD 轨迹的 MSE。

我们再次看到了效率差距。FastEGNN 因数值不稳定 (NAN) 而失败,Equiformer 耗尽了内存。 G-Hyena 以最低的均方误差 (骨架上 1.80,全原子上 2.49) 完成了任务,证明它足够稳健,可以用于复杂的物理模拟。

结论与启示

Geometric Hyena Network 代表了我们处理几何深度学习方式的重大转变。多年来,社区已经接受了 Transformer 的二次方成本作为获取全局上下文的代价。这篇论文证明了这种权衡不再是必须的。

通过巧妙地调整快速傅里叶变换和矢量叉积,G-Hyena 实现了:

  1. 次二次方扩展: 使得百万级 Token 的上下文长度成为可能。
  2. 严格的等变性: 遵循 3D 旋转和平移的物理规律。
  3. 丰富的交互: 深度混合不变和等变数据。

这为在原子分辨率下模拟整个基因组、巨大的蛋白质复合物和宏观材料性质打开了大门——这些任务在以前是计算上不可能完成的。随着我们迈向生物学和化学领域的更大型“基础模型”,像 Geometric Hyena 这样高效的架构很可能成为下一代科学 AI 的支柱。