循环神经网络 (RNN) 、卷积神经网络 (CNN) 和 Transformer 已经彻底改变了我们处理文本、音频和时间序列等序列数据的方式。每种范式都很强大,但也都有其自身的局限性:

  • RNN 在推理时效率很高,但在长序列上训练缓慢,并且存在梯度消失问题。
  • CNN 可以并行训练,速度快,但它们在固定感受野之外表现不佳,且推理成本高。
  • Transformer 能够捕捉全局上下文,但其内存和计算量会随序列长度呈二次方增长。

如果我们能将这些方法的优点结合起来会怎样?想象一个具有以下特性的模型:

  • CNN 的可并行化训练速度。
  • RNN 的快速且有状态的推理能力。
  • 神经微分方程 (NDE) 的连续时间灵活性

这正是斯坦福大学和布法罗大学研究人员在 2021 年一篇论文中的宏伟目标。他们提出了线性状态空间层 (Linear State-Space Layer, LSSL) ——一个看似简单却功能强大的构建模块,结合了上述三种视角。本文将剖析 LSSL 的特别之处、工作原理,以及它为何能在长达数万步的序列上取得顶尖 (state-of-the-art) 的成果。


长序列的问题

无论是预测医疗传感器数据、分类语音,还是解读视频,序列模型都必须捕捉跨越多个时间步的依赖关系。在长序列下,主流方法各自面临不同的瓶颈:

  • RNN: 像 LSTM 和 GRU 这样的模型按步处理序列,并维护一个隐藏状态。虽然每步的内存和推理成本是常数级的,但它们训练缓慢且无法并行,长期依赖的梯度容易消失。
  • CNN: 时间卷积网络可并行应用滤波器,从而加快训练。然而固定的卷积核限制了上下文范围,并且在推理阶段需重新处理大量数据。
  • NDE: 这类模型将隐藏状态建模为连续时间函数,可处理不规则采样,并支持有理论基础的数学建模。但数值求解代价昂贵。

理想的模型应能结合并行训练、高效的循环推理以及连续时间的适应性——且不过丢长依赖的建模能力。LSSL 的目标正是如此。


背景: 状态空间与连续时间记忆

理解 LSSL 需要掌握两方面的理论:** 线性状态空间模型**以及 HiPPO 连续时间记忆。

状态空间模型: 源自控制理论的理念

状态空间模型不是直接将输入映射到输出,而是通过一个隐藏状态向量 \( x(t) \) 传递信息。其动态公式如下:

线性状态空间模型的核心方程。

连续时间状态空间表示。 \(u(t)\) 是输入,\(x(t)\) 是内部状态,\(y(t)\) 是输出。矩阵 ABCD 定义了系统动态。

这里:

  • \( u(t) \): 时间 \( t \) 的输入
  • \( x(t) \): 概括历史信息的隐藏状态
  • \( y(t) \): 输出
  • ABCD: 分别控制状态演化、输入影响、状态到输出映射以及直接输入输出路径的参数。

这些模型是连续时间的——非常适合不规则数据,但在深度学习中使用之前必须将其离散化


离散化: 从连续到离散

为了在常规硬件和离散数据上运行,我们选择步长 \( \Delta t \) 及一个近似连续动态的更新规则。论文采用了广义双线性变换 (Generalized Bilinear Transform, GBT) :

用于离散化线性常微分方程的广义双线性变换 (GBT) 更新规则。

GBT 公式。 当选择 \( \alpha = 1/2 \) 时,即为稳定离散化的经典双线性方法。

使用离散更新矩阵 \( \overline{A} \) 和 \( \overline{B} \),模型转化为线性递推式:

离散化后的状态空间模型。

离散时间递推。 这是 LSSL 的计算核心。

A 和 \( \Delta t \) 决定了模型能够记住什么以及在何种时间尺度上记住。


HiPPO: 有理论保障的长期记忆

随机生成的 A 矩阵无法解决梯度消失问题。高阶多项式投影算子 (High-order Polynomial Projection Operator, HiPPO) 框架构建状态向量 \( x(t) \),其各分量通过投影到多项式基上来近似输入历史 \( u(s) \)。

这种投影可以得到对过去的最优低维总结,并且——关键在于——系数的动态形式与线性状态空间模型一致:

\[ \dot{x}(t) = A x(t) + B u(t) \]

因此,HiPPO 提供了有理论保障且具备记忆特性的 A 矩阵。


线性状态空间层: 三种范式合一

LSSL 通过模拟离散化的状态空间模型,将输入序列 \( u \) 映射为输出 \( y \)。这一单一公式可以按三种方式计算,分别对应主流的序列模型范式。

LSSL 的三种视角。

图 1: LSSL 可视作连续时间模型、循环模型以及卷积模型。

  1. 连续时间视角: LSSL 由常微分方程 (ODE) 定义,可通过调整 \( \Delta t \) 来适应不规则数据或时间尺度变化。在 100Hz 下训练?在 200Hz 下测试?只需将 \( \Delta t \) 减半即可。
  2. 循环视角: 由方程 4 可知,LSSL 可像高效 RNN 一样运行: 逐步处理输入、维护隐藏状态 \( x_t \),且每步所需内存是常数级。
  3. 卷积视角: 从 \( x_{-1} = 0 \) 展开,可得 \( y_k \) 是所有过去 \( u_t \) 的加权和——即一次卷积

LSSL 输出的展开递推。

输出即卷积。 这使得我们能够并行训练。

卷积核 (滤波器) 由以下公式定义:

LSSL 的卷积核。

Krylov 函数定义了卷积核。 可利用快速傅里叶变换 (FFT) 计算,以实现并行训练。


表现力

尽管递推形式是线性的,LSSL 的表现力却意外强大:

  • 泛化卷积: 任意卷积滤波器都可以通过相应的状态空间模型近似。
  • 泛化 RNN: 流行 RNN 的门控机制在数学上对应于通过后向欧拉离散化学习 \( \Delta t \)。简单 LSSL 的深度堆叠可近似非线性 ODE,将非线性从时间步转移到网络深度。

让 LSSL 发挥威力: 长程记忆与高效计算

主要挑战有两点:

  1. 记忆: 若没有理论支撑的 A,反复乘以 \( \overline{A} \) 会导致状态爆炸或消失。
    解决方案: 采用基于 HiPPO 的结构化矩阵 (准可分矩阵) 来约束 A,已经证明对此类记忆任务是最优的。
  2. 计算: 直接学习 A 和 \( \Delta t \) 速度过慢;循环视角需要矩阵求逆,卷积视角在 \( k \) 很大时需计算 \( \overline{A}^k \)。
    解决方案: 利用准可分结构,可实现近线性复杂度的高效卷积核计算。

LSSL 实战: 顶尖成果

研究人员在多个基准测试中验证了 LSSL——涵盖标准数据集到极端长序列。

图像与时间序列基准

在逐像素图像分类任务上,LSSL 超越了此前的顶尖模型,尤其在顺序化 CIFAR-10 上提升超过 10%。

表 1 结果。

表 1: LSSL 在 sMNIST、pMNIST 和 sCIFAR 上的性能。

在医疗时间序列回归任务 (序列长度 4000) 中,LSSL 将均方根误差 (RMSE) 降低了超过三分之二。

医疗保健 RMSE 结果。

表 2: BIDMC 生命体征回归。


极端长序列: 最长可达 38,000 步

  1. 顺序化 CelebA: 将 \( 178 \times 218 \) 像素的图像展平成长度为 38,800 步的序列。LSSL 的表现几乎可与专门的 ResNet-18 媲美,同时参数量减少 10 倍。

表 3 CelebA 结果。

表 3: 顺序化 CelebA 分类。

  1. 原始语音指令: 直接处理包含 16,000 个采样点的音频。LSSL 比现有模型高出 20 多个百分点,甚至在使用 MFCC 特征 (序列缩短 100 倍) 的情况下也能超越所有基线。

表 4 Speech Commands 结果。

表 4: 原始语音 vs MFCC 语音分类,以及时间尺度自适应能力。

在测试时采样率翻倍的场景下,LSSL 只需调整 \( \Delta t \) 即可平稳适配,而许多其他模型则完全失效。


混合优势

  • 快速收敛: 强归纳偏置与卷积并行性结合,使得达到顶尖性能所需的训练轮数更少、总耗时更短。

训练速度比较。

表 5: LSSL 在训练轮数与分钟数上更快达到 SOTA。

  • 可学习的记忆与时间尺度: 消融实验显示,学习 A 和 \( \Delta t \) 能显著提升表现。使用随机 A 或固定 \( \Delta t \) 会降低准确率。

时间尺度可视化。

图 2: 学习到的逆时间尺度 \(1/\Delta t\) 在 Speech Commands 数据集覆盖相关范围。


结论与启示

线性状态空间层独特地统一了序列建模的三大范式:

  • CNN 的可并行化训练
  • RNN 的高效有状态推理
  • NDE 的连续时间适应性
  • HiPPO 提供的理论长程记忆保障

其在极长序列上的出色表现,超越了人工设计的流水线 (如语音中的 MFCC) ,预示着一个值得期待的未来: 模型可以直接从原始且复杂的信号中学习,而无需领域特定的预处理。

尽管早期的快速算法存在稳定性问题且内存开销较高,但后续研究已解决了多种限制,并催生了影响深远的结构化状态空间 (S4) 模型。

根植于控制理论与连续时间数学,LSSL 为建模超长序列提供了一个有理论基础的解决方案——让我们得以触达此前无法处理的数据。