一个孩子在书中看到一张长颈鹿的图片,基本上就能在野外、动物园或另一本书中正确地认出长颈鹿。这种仅凭一两个例子就能学习一个新概念的非凡能力,是人类轻而易举就能做到的。然而,对于我们最先进的机器学习模型来说,这仍然是一个巨大的挑战。

传统的深度学习模型——图像识别和自然语言处理领域取得突破的幕后功臣——是出了名的数据贪婪。它们通常需要成千上万,甚至数百万个标注样本才能有效地学习一个新概念。这一限制使得它们在许多数据稀缺的真实场景中显得不切实际,例如诊断罕见疾病、识别新产品缺陷或为单个用户定制个性化系统。

这就是 单样本学习 (one-shot learning) 的挑战: 让模型在只看到一个类别的单个样本后就能识别出该新类别。

2016 年,谷歌 DeepMind 的研究人员在论文 Matching Networks for One Shot Learning (匹配网络用于单样本学习) 中提出了针对这一挑战的解决方案。他们提出了一种新颖且具有影响力的方法,该方法不仅学习分类什么,更重要的是学习如何从少量样本中进行学习。

在本文中,我们将深入探索匹配网络。我们将看到它们如何巧妙地结合度量学习与记忆增强网络的思想,构建一个无需重新训练即可快速适应新类别的系统。


参数化与非参数化思维

在深入模型本身之前,先了解机器学习框架中的两大基本类别将有助于理解。

参数化模型 , 例如标准的卷积神经网络 (CNN) ,从大型训练数据集中学习一组固定的参数——权重和偏置。为了对新图像进行分类,模型会将其输入经过已学习的层;模型的知识完全由这些参数编码。缺点是什么?新增一个类别通常需要重新训练或微调整个网络——这是一个缓慢且数据密集的过程。

相比之下, 非参数化模型 不依赖于固定数量的参数。典型例子是 k‑近邻算法 (k‑NN) , 它通过在训练集中找到与新样本距离最近的 k 个样本,并根据它们的标签进行投票来完成分类。这类模型可以立即适应新样本,但性能高度依赖于良好的距离度量,而且随着数据集的增长,推理效率会显著下降。

匹配网络 旨在兼具两者的优势: 利用深度神经网络进行强大的特征提取 (参数化部分) ,并结合非参数化预测方法,将测试样本直接与一组有限的标注样本进行比较。

其核心思想简单却深刻——训练一个神经网络,不是让它成为固定的分类器,而是让它成为一个擅长比较的专家。


匹配网络的工作原理

匹配网络将单样本学习建模为一个映射问题。给定一个小规模、带标签的 支持集 \( S = \{(x_i, y_i)\}_{i=1}^k \),以及一个未标注的测试样本 \( \hat{x} \),网络会生成标签的概率分布 \( P(\hat{y}|\hat{x}, S) \)。支持集本身为分类提供了上下文。

匹配网络的整体架构。它通过编码器 g 处理支持集图像,通过编码器 f 处理测试图像,然后进行比较以生成最终输出。

图 1: 匹配网络的架构对一个小型支持集和一个测试图像进行编码,然后通过特征匹配来预测标签。

特征空间中的加权投票

匹配网络通过支持集中各个标签的加权和来预测 \( \hat{x} \) 的标签:

公式 1: 匹配网络的核心预测公式。

公式 1: 预测由对支持集标签的加权求和组成,权重由学习到的注意力得分决定。

其中:

  • \( y_i \): 第 \( i \) 个支持样本的独热 (one‑hot) 标签向量;
  • \( a(\hat{x}, x_i) \): 注意力权重 , 衡量测试样本与支持样本之间的相似度;
  • \( \hat{y} \): 支持集标签的加权平均值,有效地汇总来自最相似样本的证据。

这种方法类似于 k‑NN,但它采用了一个可学习的度量 。 注意力机制不使用固定的欧几里得距离,而是基于余弦相似度的 softmax 函数进行端到端训练:

\[ a(\hat{x}, x_i) = \frac{e^{c(f(\hat{x}), g(x_i))}}{\sum_{j=1}^k e^{c(f(\hat{x}), g(x_j))}} \]

其中 \( f \) 和 \( g \) 是将图像映射到嵌入空间的深度编码器。目标是学习一种嵌入,使简单的余弦相似度即可正确地识别匹配类别。


全上下文嵌入: 在上下文中学习

基本形式下,模型是独立地嵌入每张图像——每个 \( g(x_i) \) 或 \( f(\hat{x}) \) 仅依赖于自身特征。但上下文是关键。如果支持集中包含相似类别 (例如不同犬种) ,这些内部关系应当影响嵌入结果。

为了让嵌入依赖于整个支持集,论文提出了 全上下文嵌入 (Full Context Embeddings, FCE) :

  1. 支持集情境化: 每个 \( g(x_i) \) 都通过一个双向 LSTM 在集合上进行扫描,将 \( S \) 中所有样本的信息整合进来,使每个样本的嵌入能感知集合中其他样本。
  2. 测试样本情境化: 测试图像的嵌入 \( f(\hat{x}) \) 通过一个带有读取注意力机制的 LSTM 进行优化,该机制会在多次“扫视”中查阅支持集。

公式 2: 测试图像的全上下文嵌入使用一个带注意力的 LSTM。

公式 2: 带注意力的 LSTM 根据与整个支持集的交互来优化测试样本嵌入。

这种 FCE 机制允许模型聚焦于与当前单样本任务最相关的特征,从而在复杂的细粒度分类任务中显著提升准确率。


训练策略——“即测即训”

第二个主要创新在于 情景式训练 (episodic training) , 这是一种元学习的方式。

在训练过程中,网络不会单独处理样本,而是反复面对模拟的单样本任务,以匹配评估条件:

  1. 创建情景: 随机选择 N 个类别;
  2. 采样支持集: 从每个类别抽取 k 个带标签的样本组成 \( S \);
  3. 采样查询批次: 从相同类别中抽取查询样本 \( B \);
  4. 优化: 在 \( S \) 条件下预测 \( B \) 的标签,并根据准确率更新参数。

公式 3: 训练目标是在许多情景中最大化正确预测的对数概率。

公式 3: 模型被训练以在多个采样的单样本情景中最大化正确标签的期望对数概率。

这种“即测即训” (train as you test) 的训练理念避免了网络对固定类别的记忆。网络学习到的是一个可复用的过程——如何从小规模标注集构建分类器。在推理阶段,它无需微调即可处理全新的类别。


实验与结果

作者在需要从极少数据中实现泛化的视觉和语言任务上评估了匹配网络。

Omniglot——手写字符挑战

Omniglot 数据集包含来自 50 种字母表的 1,623 个字符,每个字符仅有 20 个样本。这使它成为单样本分类任务的理想试验场。

表 1: 在 Omniglot 数据集上的单样本分类准确率。

表 1: Omniglot 数据集上的单样本分类准确率 (5‑way 与 20‑way 任务) 。

在挑战性较高的 20‑way 1‑shot 任务中,匹配网络取得了 93.8% 的准确率 , 超越了卷积孪生网络 (88.0%) 。它们在 5‑shot 版本中也表现出色,显示出在极少监督下的强大适应能力。

在 Omniglot 上训练的模型甚至可以泛化到未见过的数据集——在 10‑way MNIST 单样本分类任务中取得了 72% 的准确率,超越了此前的基线模型。


ImageNet——真实世界规模测试

为了走出小型灰度字符的范畴,作者将研究扩展到更复杂的 ImageNet 数据集,该数据集包含数百个类别。他们设计了两种评估设置:

  • rand: 在 118 个随机保留的类别上进行测试;
  • dogs: 在非狗类上训练,并在 118 个狗的品种上测试 (一个困难的细粒度任务) 。

为提高实验效率,他们还构建了 miniImageNet,这是一个包含 100 个类别、共 60 000 张图像的子集,方便快速原型化。

表 2: 在 miniImageNet 数据集上的结果。

表 2: miniImageNet 上的单样本分类结果。全上下文嵌入持续提升了准确率。

miniImageNet 上,FCE 将 5‑way 1‑shot 准确率从 41.2% 提升至 44.2%。

当扩展到完整的 ImageNet 时,匹配网络再次超过了强大的基线模型:

表 3: 在完整 ImageNet 数据集上 5-way, 1-shot 任务的结果。

表 3: 完整 ImageNet 上的比较。匹配网络在未见过的随机类别上比 Inception 分类器高出近六个百分点。

在随机类别任务中,匹配网络达到了 93.2% 的准确率,而 Inception 模型为 87.6%——错误率几乎减半。

图 2: ImageNet 示例中,Inception 基线模型失败而匹配网络成功。基线模型被支持集中杂乱的图像干扰,而匹配网络正确识别出红色汽车。

图 2: 在 ImageNet 的单样本分类中,匹配网络能更可靠地从干扰和杂乱中恢复结果,相比 Inception 基线模型表现更稳健。

有趣的是,在细粒度的 “dogs” 任务上性能略有下降。作者将其归因于训练与测试分布间的不匹配——训练情景是跨多种类别随机采样的,而非针对细粒度子集。这提示我们,根据预期的部署条件定制情景可能进一步提高性能。


单样本语言建模

将该思想扩展到文本领域,论文在 Penn Treebank 数据集上引入了新颖的 句子补全 任务。目标是: 通过一个包含五个句子的支持集 (每个句子缺少不同单词) 来预测查询句子中缺失的词语。

表 4: 5‑way, 1‑shot 句子补全任务示例。模型必须根据支持集提供的上下文确定“dollar”是查询句子的正确词。

表 4: 5‑way, 1‑shot 句子预测任务示例。模型需要推断出 “dollar” 适用于查询句子。

匹配网络在 1‑shot 任务中取得了 32.4% 的准确率,显著高于 20% 的随机基线。虽然仍远不及完全监督的 LSTM 语言模型 (72.8%) ,但这验证了非参数化匹配机制可从图像推广到文本。


关键要点

匹配网络代表了迈向能够从极少数据中高效学习的模型的重要一步。它们的成功源于两个核心原则:

  1. 直接建模学习过程: 将单样本分类视为基于条件分布的推理 \( P(\hat{y}|\hat{x}, S) \),通过注意力机制在学习到的嵌入空间中计算相似度。
  2. 训练方式与测试方式保持一致: 情景式训练构建的是一种可迁移的学习能力——从小型支持集中学习的元知识——而非固定标签的记忆。

这些思想奠定了现代 元学习 (meta‑learning) 的基础,并启发了后续工作,如 *原型网络 (Prototypical Networks) * 和 *关系网络 (Relation Networks) *,进一步完善了嵌入与比较的范式。

尽管仍存在局限——计算成本随支持集规模增长,性能依赖于训练与测试分布的匹配——匹配网络证明了深度模型确实可以 学会学习

它们提醒我们,人工智能的进步不一定来自更大的网络,而可能来自更聪明的训练策略——那些像人类一样,只需少数几个样本,甚至只需一瞥就能学习的策略。