循环神经网络 (RNN) 、卷积神经网络 (CNN) 和 Transformer 已经彻底改变了我们处理文本、音频和时间序列等序列数据的方式。每种范式都很强大,但也都有其自身的局限性:
- RNN 在推理时效率很高,但在长序列上训练缓慢,并且存在梯度消失问题。
- CNN 可以并行训练,速度快,但它们在固定感受野之外表现不佳,且推理成本高。
- Transformer 能够捕捉全局上下文,但其内存和计算量会随序列长度呈二次方增长。
如果我们能将这些方法的优点结合起来会怎样?想象一个具有以下特性的模型:
- CNN 的可并行化训练速度。
- RNN 的快速且有状态的推理能力。
- 神经微分方程 (NDE) 的连续时间灵活性。
这正是斯坦福大学和布法罗大学研究人员在 2021 年一篇论文中的宏伟目标。他们提出了线性状态空间层 (Linear State-Space Layer, LSSL) ——一个看似简单却功能强大的构建模块,结合了上述三种视角。本文将剖析 LSSL 的特别之处、工作原理,以及它为何能在长达数万步的序列上取得顶尖 (state-of-the-art) 的成果。
长序列的问题
无论是预测医疗传感器数据、分类语音,还是解读视频,序列模型都必须捕捉跨越多个时间步的依赖关系。在长序列下,主流方法各自面临不同的瓶颈:
- RNN: 像 LSTM 和 GRU 这样的模型按步处理序列,并维护一个隐藏状态。虽然每步的内存和推理成本是常数级的,但它们训练缓慢且无法并行,长期依赖的梯度容易消失。
- CNN: 时间卷积网络可并行应用滤波器,从而加快训练。然而固定的卷积核限制了上下文范围,并且在推理阶段需重新处理大量数据。
- NDE: 这类模型将隐藏状态建模为连续时间函数,可处理不规则采样,并支持有理论基础的数学建模。但数值求解代价昂贵。
理想的模型应能结合并行训练、高效的循环推理以及连续时间的适应性——且不过丢长依赖的建模能力。LSSL 的目标正是如此。
背景: 状态空间与连续时间记忆
理解 LSSL 需要掌握两方面的理论:** 线性状态空间模型**以及 HiPPO 连续时间记忆。
状态空间模型: 源自控制理论的理念
状态空间模型不是直接将输入映射到输出,而是通过一个隐藏状态向量 \( x(t) \) 传递信息。其动态公式如下:
连续时间状态空间表示。 \(u(t)\) 是输入,\(x(t)\) 是内部状态,\(y(t)\) 是输出。矩阵 A、B、C、D 定义了系统动态。
这里:
- \( u(t) \): 时间 \( t \) 的输入
- \( x(t) \): 概括历史信息的隐藏状态
- \( y(t) \): 输出
- A、B、C、D: 分别控制状态演化、输入影响、状态到输出映射以及直接输入输出路径的参数。
这些模型是连续时间的——非常适合不规则数据,但在深度学习中使用之前必须将其离散化。
离散化: 从连续到离散
为了在常规硬件和离散数据上运行,我们选择步长 \( \Delta t \) 及一个近似连续动态的更新规则。论文采用了广义双线性变换 (Generalized Bilinear Transform, GBT) :
GBT 公式。 当选择 \( \alpha = 1/2 \) 时,即为稳定离散化的经典双线性方法。
使用离散更新矩阵 \( \overline{A} \) 和 \( \overline{B} \),模型转化为线性递推式:
离散时间递推。 这是 LSSL 的计算核心。
A 和 \( \Delta t \) 决定了模型能够记住什么以及在何种时间尺度上记住。
HiPPO: 有理论保障的长期记忆
随机生成的 A 矩阵无法解决梯度消失问题。高阶多项式投影算子 (High-order Polynomial Projection Operator, HiPPO) 框架构建状态向量 \( x(t) \),其各分量通过投影到多项式基上来近似输入历史 \( u(s) \)。
这种投影可以得到对过去的最优低维总结,并且——关键在于——系数的动态形式与线性状态空间模型一致:
\[ \dot{x}(t) = A x(t) + B u(t) \]因此,HiPPO 提供了有理论保障且具备记忆特性的 A 矩阵。
线性状态空间层: 三种范式合一
LSSL 通过模拟离散化的状态空间模型,将输入序列 \( u \) 映射为输出 \( y \)。这一单一公式可以按三种方式计算,分别对应主流的序列模型范式。
图 1: LSSL 可视作连续时间模型、循环模型以及卷积模型。
- 连续时间视角: LSSL 由常微分方程 (ODE) 定义,可通过调整 \( \Delta t \) 来适应不规则数据或时间尺度变化。在 100Hz 下训练?在 200Hz 下测试?只需将 \( \Delta t \) 减半即可。
- 循环视角: 由方程 4 可知,LSSL 可像高效 RNN 一样运行: 逐步处理输入、维护隐藏状态 \( x_t \),且每步所需内存是常数级。
- 卷积视角: 从 \( x_{-1} = 0 \) 展开,可得 \( y_k \) 是所有过去 \( u_t \) 的加权和——即一次卷积。
输出即卷积。 这使得我们能够并行训练。
卷积核 (滤波器) 由以下公式定义:
Krylov 函数定义了卷积核。 可利用快速傅里叶变换 (FFT) 计算,以实现并行训练。
表现力
尽管递推形式是线性的,LSSL 的表现力却意外强大:
- 泛化卷积: 任意卷积滤波器都可以通过相应的状态空间模型近似。
- 泛化 RNN: 流行 RNN 的门控机制在数学上对应于通过后向欧拉离散化学习 \( \Delta t \)。简单 LSSL 的深度堆叠可近似非线性 ODE,将非线性从时间步转移到网络深度。
让 LSSL 发挥威力: 长程记忆与高效计算
主要挑战有两点:
- 记忆: 若没有理论支撑的 A,反复乘以 \( \overline{A} \) 会导致状态爆炸或消失。
解决方案: 采用基于 HiPPO 的结构化矩阵 (准可分矩阵) 来约束 A,已经证明对此类记忆任务是最优的。 - 计算: 直接学习 A 和 \( \Delta t \) 速度过慢;循环视角需要矩阵求逆,卷积视角在 \( k \) 很大时需计算 \( \overline{A}^k \)。
解决方案: 利用准可分结构,可实现近线性复杂度的高效卷积核计算。
LSSL 实战: 顶尖成果
研究人员在多个基准测试中验证了 LSSL——涵盖标准数据集到极端长序列。
图像与时间序列基准
在逐像素图像分类任务上,LSSL 超越了此前的顶尖模型,尤其在顺序化 CIFAR-10 上提升超过 10%。
表 1: LSSL 在 sMNIST、pMNIST 和 sCIFAR 上的性能。
在医疗时间序列回归任务 (序列长度 4000) 中,LSSL 将均方根误差 (RMSE) 降低了超过三分之二。
表 2: BIDMC 生命体征回归。
极端长序列: 最长可达 38,000 步
- 顺序化 CelebA: 将 \( 178 \times 218 \) 像素的图像展平成长度为 38,800 步的序列。LSSL 的表现几乎可与专门的 ResNet-18 媲美,同时参数量减少 10 倍。
表 3: 顺序化 CelebA 分类。
- 原始语音指令: 直接处理包含 16,000 个采样点的音频。LSSL 比现有模型高出 20 多个百分点,甚至在使用 MFCC 特征 (序列缩短 100 倍) 的情况下也能超越所有基线。
表 4: 原始语音 vs MFCC 语音分类,以及时间尺度自适应能力。
在测试时采样率翻倍的场景下,LSSL 只需调整 \( \Delta t \) 即可平稳适配,而许多其他模型则完全失效。
混合优势
- 快速收敛: 强归纳偏置与卷积并行性结合,使得达到顶尖性能所需的训练轮数更少、总耗时更短。
表 5: LSSL 在训练轮数与分钟数上更快达到 SOTA。
- 可学习的记忆与时间尺度: 消融实验显示,学习 A 和 \( \Delta t \) 能显著提升表现。使用随机 A 或固定 \( \Delta t \) 会降低准确率。
图 2: 学习到的逆时间尺度 \(1/\Delta t\) 在 Speech Commands 数据集覆盖相关范围。
结论与启示
线性状态空间层独特地统一了序列建模的三大范式:
- CNN 的可并行化训练
- RNN 的高效有状态推理
- NDE 的连续时间适应性
- HiPPO 提供的理论长程记忆保障
其在极长序列上的出色表现,超越了人工设计的流水线 (如语音中的 MFCC) ,预示着一个值得期待的未来: 模型可以直接从原始且复杂的信号中学习,而无需领域特定的预处理。
尽管早期的快速算法存在稳定性问题且内存开销较高,但后续研究已解决了多种限制,并催生了影响深远的结构化状态空间 (S4) 模型。
根植于控制理论与连续时间数学,LSSL 为建模超长序列提供了一个有理论基础的解决方案——让我们得以触达此前无法处理的数据。