引言

在过去几年里,计算机视觉领域已经被生成式 AI 完全颠覆。像 Stable Diffusion 和 DALL-E 这样的模型展示了从简单的文本提示生成逼真图像的惊人能力。它们“知道”狗长什么样,夕阳如何在水面上反射,以及宇航员骑马是什么样子。这是通过在包含数十亿图像-文本对 (如 LAION-5B) 的海量数据集上进行训练实现的。

然而,有一个领域在很大程度上错过了这场革命: 医学影像

为什么?问题是双重的。首先,医疗数据稀缺,标注昂贵,并且受严格的隐私法律保护。你不能简单地从互联网上抓取数十亿张带标签的 MRI 扫描图。其次,存在巨大的“分布偏移 (distribution shift) ”。一个在猫和汽车的互联网照片上训练的模型,并不天生理解人体解剖结构的严格和结构性约束。当你要求一个标准的扩散模型修改一张脑部 MRI 时,它可能会把头骨当作一个柔软的物体,以生物学上不可能的方式扭曲它。

在一篇题为 “Latent Drifting in Diffusion Models for Counterfactual Medical Image Synthesis” (用于反事实医学图像合成的扩散模型中的潜在漂移) 的精彩研究论文中,来自慕尼黑工业大学和斯坦福大学的研究人员提出了一个解决方案。他们介绍了 Latent Drifting (LD) , 这项技术允许通用扩散模型在小型医疗数据集上进行微调,并保持惊人的准确性。

这篇文章将详细拆解 LD 的工作原理,“漂移 (drift) ”背后的数学原理,以及它如何使我们能够生成“反事实 (counterfactual) ”医学图像——回答诸如*“如果这位病人患上阿尔茨海默病,他的大脑会是什么样子?”*这类问题。

背景: 适应性挑战

要理解 Latent Drifting,我们需要先快速回顾一下扩散模型的工作原理,以及为什么标准微调在医学领域会失败。

扩散模型入门

扩散模型基于噪声原理工作。在前向过程 (forward process) 中,模型逐渐向图像添加高斯噪声,直到它变成纯粹的静噪。这个过程在数学上被定义为一个马尔可夫链。

前向过程中联合分布的因式分解。

在这里,\(x_0\) 是清晰图像,\(x_T\) 是最终的噪声状态。模型在训练期间的任务是学习逆向过程 (reverse process) : 即获取噪声并迭代地去除它以恢复图像。

训练目标本质上是一个去噪任务。模型 (\(\hat{x}_\theta\)) 试图预测在特定时间步长 \(t\) 添加到图像 \(x\) 中的噪声 \(\epsilon\)。

训练扩散模型的去噪目标函数。

分布偏移问题

当我们采用一个在自然图像 (“源”分布) 上预训练的模型,并试图在医学图像 (“目标”分布) 上对其进行微调时,我们会碰壁。

医学图像有严格的模板。脑部 MRI 总是具有特定配置的头骨、脑室和灰质。自然图像则千差万别。当标准的微调方法 (如 Dreambooth 或 Textual Inversion) 应用于医疗数据时,它们往往会陷入困境。它们可能会捕捉到 MRI 的纹理 (灰度噪声) ,但无法捕捉到几何结构 (大脑的形状) ,从而导致产生幻觉或解剖学上不正确的输出。

研究人员意识到,与其强迫模型重新学习一切,不如在潜在空间 (latent space) 中引入一种“漂移”,以弥合自然图像分布和医学图像分布之间的鸿沟。

核心方法: Latent Drifting (LD)

这篇论文的核心创新是将适应过程视为一个涉及漂移参数的极小极大优化问题 (min-max optimization problem)

什么是 Latent Drift?

在标准扩散模型中,逆向过程 (从噪声生成图像) 被建模为从噪声状态 \(x_t\) 到较少噪声状态 \(x_{t-1}\) 的转换。这种转换通常是一个围绕学习到的均值 \(\mu_\theta\) 的正态分布。

逆向扩散过程中标准的可学习转移核。

研究人员建议在这个均值上增加一个学习到的标量值 \(\delta\) (delta) 。这个 \(\delta\) 代表“Latent Drift (潜在漂移) ”。它充当一种偏差,在不破坏模型预训练知识的情况下,将生成的样本向目标分布 (医学图像) “偏移”。

修改后的逆向过程如下所示:

包含潜在漂移参数 delta 的修改后的转移核。

通过将这个 \(\delta\) 注入到逆向过程中,模型可以“引导”生成过程。

可视化漂移

为了让你直观地感受这种数学上的“漂移”对图像实际做了什么,作者使用标准提示词 (埃隆·马斯克和巴拉克·奥巴马) 提供了一个可视化示例。

使用不同的潜在漂移值生成的样本,展示了视觉内容如何变化。

在上面的 Figure 2 中,观察图像随着 \(\delta\) 从 -0.1 变为 0.1 时的变化。

  • 在 \(\delta = 0\) 时,我们得到标准输出。
  • 随着 \(\delta\) 变为负数或正数,主体仍然可识别,但上下文风格发生了显著漂移。
  • 在医学背景下,优化这种“漂移”不是为了改变照片的风格,而是为了将分布从“自然图像统计数据”转移到“医学图像统计数据”。

弥合分布鸿沟

论文中最令人信服的可视化之一确切地展示了为什么这种漂移对于医学微调是必要的。

有无 Latent Drifting 的图像和潜在空间分布比较。

Figure 3 中,我们看到了标准微调 (左) 与使用 Latent Drifting 的微调 (右) 之间的比较。

  • 第 3 行 (通道分布) : 注意左侧像素值的分布波动剧烈且方差很大。在右侧,有了 LD,分布更加紧凑和受控。
  • 第 4 行 (潜在空间) : 右侧的潜在空间分布更接近标准高斯分布,这是扩散模型的理想状态。

漂移参数 \(\delta\) 本质上充当了一个超参数,在多样性 (预训练模型想要做的) 和条件化 (医疗数据所要求的) 之间进行权衡。

反事实优化

研究人员将其形式化为一个反事实生成问题。他们希望生成一个图像 \(x'\),它与原始图像 \(x\) 相似,但具有不同的标签 \(y'\) (例如,将诊断从“健康”改为“阿尔茨海默病”) 。

这被制定为一个包含两个竞争项的损失函数:

用于反事实生成的极小极大目标函数。

  1. 期望结果保真度 (Desired Outcome Fidelity): 此项确保生成的图像实际上看起来像目标类别 \(y'\) (例如,它实际上看起来像阿尔茨海默病的大脑) 。
  2. 反事实保真度 (Counterfactual Fidelity): 此项确保生成的图像 \(x'\) 尽可能接近原始图像 \(x\)。我们不想生成一个病人;我们想要的是同一个病人,但具有不同的疾病状态。

参数 \(\lambda\) 控制这种平衡。如果 \(\lambda > 0\),模型会搜索最佳的 \(\delta\) (漂移) ,以最小化生成分布与目标医学分布之间的距离。

实验与结果

研究人员在两个主要任务上测试了 Latent Drifting: 文本到图像生成 (创建合成数据) 和图像到图像处理 (编辑现有扫描图) 。他们使用了包含脑部 MRI (阿尔茨海默病 vs. 健康) 和胸部 X 光片 (肺炎、胸腔积液等) 的数据集。

1. 文本到图像生成

这里的目标很简单: 我们能否通过像*“一位患有阿尔茨海默病的 70 岁女性的脑部 MRI”*这样的文本提示生成逼真的医学图像?

他们将 Latent Drifting (LD) 与流行的微调方法进行了比较: Textual InversionDreamBoothCustom Diffusion

使用不同微调方法生成 MRI 切片的视觉比较。

Figure 5 展示了质量上的鲜明对比:

  • (a) 不使用 LD: 图像充满噪点。“头骨”边界模糊或完全破碎。大脑内部结构通常看起来像通用的纹理而不是解剖结构。
  • (b) 使用 LD: 对比度清晰。背景是完美的黑色 (理应如此) 。解剖结构 (脑室、白质) 清晰且逼真。

定量成功

视觉上的改进得到了数据的支持。他们使用 FID (Fréchet Inception Distance) 来衡量合成图像与真实数据的相似程度,这是一个标准指标,数值越低越好。

比较不同方法的 FID 和 KID 分数的表格。

Table 1 所示,在“Stable Diffusion Basic FT”中加入 LD 后,脑部 MRI 的 FID 分数从 92.13 降至 49.68 。 这是一个巨大的保真度提升。它还提高了在这个合成数据上训练的模型的分类准确率 (AUC) ,证明生成的图像包含医学上相关的特征。

2. 反事实图像处理

这可能是最令人兴奋的应用。我们能否拍摄一位健康病人的图像,并可视化如果他们患病会是什么样子?或者反过来,从图像中移除疾病?

阿尔茨海默病进展

团队使用了一种称为 Pix2Pix Zero 的方法结合 Latent Drifting 来执行这些编辑。

展示阿尔茨海默病和健康状态之间转换的反事实 MRI 切片。

Figure 7 中,我们看到了双向编辑:

  • 顶部 (健康 \(\to\) 阿尔茨海默病) : 模型成功扩大了脑室 (中心的黑暗空腔) ,这是与阿尔茨海默病相关的脑萎缩的临床标志。
  • 底部 (阿尔茨海默病 \(\to\) 健康) : 模型通过缩小脑室和恢复组织体积来“治愈”大脑。
  • 差异图 (Diff Map): 绿色和红色的叠加层清晰地显示模型仅更改了相关的解剖区域,保留了病人的其他身份特征。

衰老模拟

他们还使用 InstructPix2Pix 应用 LD 来模拟衰老。

脑部衰老示例,将 70 岁 CN 大脑转换为 77 岁 MCI 大脑。

Figure 6 展示了一个年龄递增的请求: *“Age this CN 70 years old female brain MRI into a 77 brain MRI with MCI.” (将这张 70 岁 CN 女性脑部 MRI 老化为 77 岁患有 MCI 的脑部 MRI。) * 红框突出了转换过程中发生的细微结构变化,模拟了脑组织随时间的自然退化。

胸部 X 光片

该方法不仅限于大脑。他们将其应用于胸部 X 光片,以添加或移除肺炎和心脏肥大 (心脏扩大) 等病症。

胸部 X 光片上的反事实样本,展示了疾病的添加和移除。

Figure 8 中,你可以看到模型成功地处理了肺部的特定区域。例如,在右上角,它通过增加下肺野的不透明度来添加“胸腔积液” (肺部积液) ,这通过“Diff”列中的热力图进行了可视化。

结论与启示

论文 “Latent Drifting in Diffusion Models” 为医学 AI 迈出了重要一步。通过承认医学图像存在于与自然图像不同的“世界”中,并通过 漂移 (\(\delta\)) 参数在数学上解释这种距离,作者解锁了大型基础模型在医疗保健领域的潜力。

主要收获:

  1. 无需从头训练: 我们可以利用 Stable Diffusion 见过的数十亿张图像,即使是对于小众的医学任务,只要我们正确地调整潜在空间。
  2. 几何结构很重要: 标准微调在医学中失败是因为它忽略了结构约束。LD 保留了这些约束。
  3. 可解释性 AI: 生成反事实的能力 (“给我看这位病人没有肿瘤的样子”) 是可解释性 AI 的强大工具,可以帮助临床医生理解模型正在关注哪些特征。

这种方法降低了医学图像生成的门槛。它允许研究人员创建高保真合成数据集来训练诊断工具,从而绕过隐私问题和数据稀缺性。随着扩散模型的不断发展,像 Latent Drifting 这样的技术对于确保这些强大的工具能够可靠地服务于医学界至关重要。

医学图像生成和处理任务概览。