← 上一章 | 第2章 / 共14章 | 下一章 → |
扩散模型的成功离不开强大的神经网络架构。有趣的是,扩散模型并没有发明全新的网络结构,而是巧妙地借用了计算机视觉领域的两个里程碑式架构:U-Net和Vision Transformer (ViT)。本章将追溯这两种架构的历史发展,理解它们的设计初衷,并剖析它们为何能与扩散模型的去噪任务完美契合。这种“历史的巧合”不仅展示了深度学习领域知识迁移的魅力,也为我们设计未来更高效的生成模型提供了深刻的启示。
2015年,深度学习正在快速改变计算机视觉的格局。然而,在医学图像分析领域,研究者们面临着独特的挑战:标注数据极其稀缺(医学专家的时间宝贵),图像分辨率高,细节至关重要,且分割边界往往模糊不清。当时流行的全卷积网络(FCN)虽然在自然图像分割上取得了成功,但在医学图像上的表现并不理想。
正是在这样的背景下,来自弗莱堡大学的Olaf Ronneberger、Philipp Fischer和Thomas Brox提出了U-Net。他们的灵感来自一个朴素但深刻的观察:医学图像分割需要两种看似矛盾的能力——既要理解全局的语义信息(这是什么器官?),又要精确定位每个像素(边界在哪里?)。传统的编码器-解码器架构在解码过程中丢失了太多空间信息,而U-Net通过引入跳跃连接,优雅地解决了这个问题。
定义:历史脉络的详细时间线
- 2012-2014年:全卷积网络(FCN)的兴起,Long等人证明了CNN可以进行像素级预测,但在细节保留上存在不足。
- 2015年5月:U-Net在ISBI细胞追踪挑战赛中首次亮相,以大幅领先的成绩震撼了医学图像界。原始论文展示了仅用30张训练图像就能达到出色性能的能力。
- 2016-2017年:U-Net的变体开始涌现——3D U-Net用于体积数据、V-Net引入残差连接、Attention U-Net加入注意力机制。每个变体都针对特定应用场景进行了优化。
- 2017-2019年:U-Net架构被广泛应用于各种像素级预测任务,从卫星图像分析到自动驾驶的道路分割,成为该领域的事实标准。其PyTorch和TensorFlow实现成为GitHub上最受欢迎的开源项目之一。
- 2020年6月:Ho等人发表DDPM论文,首次将U-Net用作扩散模型的去噪网络。他们的关键洞察是:去噪本质上也是一个像素到像素的映射问题。
- 2021年:Dhariwal和Nichol在论文《Diffusion Models Beat GANs on Image Synthesis》中提出了改进的U-Net架构(ADM),加入了自注意力层和自适应归一化,将扩散模型的生成质量推向新高度。
- 2022年:Stable Diffusion的发布让U-Net架构走向大众。其高效的潜在空间U-Net设计使得高质量图像生成首次可以在消费级GPU上运行。
- 2023年至今:U-Net继续演进,如加入更多的条件机制(ControlNet)、与Transformer混合(U-ViT)、针对视频生成的时空U-Net等。
为什么一个为医学图像分割设计的架构能够如此完美地适用于扩散模型?答案隐藏在这两个看似不同的任务的数学本质中。
图像分割的数学表述:给定输入图像 $\mathbf{x} \in \mathbb{R}^{H \times W \times 3}$,预测每个像素的类别标签 $\mathbf{y} \in {0,1,…,C-1}^{H \times W}$。这是一个确定性的映射:$f_{\text{seg}}: \mathbb{R}^{H \times W \times 3} \rightarrow {0,1,…,C-1}^{H \times W}$。
扩散模型去噪的数学表述:给定带噪声的图像 $\mathbf{x}t = \sqrt{\bar{\alpha}_t}\mathbf{x}_0 + \sqrt{1-\bar{\alpha}_t}\boldsymbol{\epsilon}$,预测噪声 $\boldsymbol{\epsilon} \in \mathbb{R}^{H \times W \times 3}$。这同样是一个确定性的映射:$f{\text{denoise}}: \mathbb{R}^{H \times W \times 3} \times \mathbb{R} \rightarrow \mathbb{R}^{H \times W \times 3}$(额外的输入是时间步$t$)。
两者的共同点在于:
U-Net的成功不仅仅是技术上的胜利,更体现了深刻的设计哲学。让我们深入剖析其核心设计原则:
1. 对称性的美学与功能 U-Net的U形结构不仅在视觉上优雅,更重要的是体现了信息处理的对称性。编码器逐步压缩空间维度、提取抽象特征的过程,与解码器逐步恢复空间维度、重建具体细节的过程,形成了完美的镜像。这种对称性在扩散模型中获得了新的诠释:编码器理解”现在有什么噪声”,解码器决定”如何去除这些噪声”。
2. 跳跃连接:信息高速公路 原始的编码器-解码器架构存在一个致命弱点:信息瓶颈。当特征图被压缩到最小尺寸时(如原始尺寸的1/32),大量的空间信息已经无可挽回地丢失了。U-Net的跳跃连接就像在山谷两侧架起的桥梁,让高分辨率的信息可以直接”跳过”瓶颈,到达需要它的地方。
在扩散模型的语境下,这一点尤为关键。考虑去噪过程的两个极端情况:
3. 计算效率的权衡艺术 U-Net的金字塔结构带来了计算上的巨大优势。大部分的计算(自注意力、复杂的卷积)发生在低分辨率的特征图上,而高分辨率层只进行相对简单的操作。这种设计使得U-Net可以在有限的计算资源下处理高分辨率图像,这也是为什么Stable Diffusion能够在个人电脑上运行的关键因素之一。
U-Net的基本思想激发了无数的变体和改进。每一个成功的变体都代表了对特定问题的深刻理解:
3D U-Net (2016):将2D卷积替换为3D卷积,用于处理CT、MRI等体积数据。关键创新是各向异性的卷积核(如3×3×1),以处理医学图像中常见的各向异性分辨率。
Attention U-Net (2018):在跳跃连接中加入注意力门控(attention gates),让模型学习”哪些跳跃连接的信息是重要的”。这在医学图像中特别有用,因为病变区域往往只占整个图像的一小部分。
U-Net++ (2018):通过密集的跳跃连接创建了一个”嵌套”的U-Net结构,让解码器可以从多个尺度的编码特征中选择信息。这种设计虽然增加了计算量,但在某些任务上显著提升了性能。
TransUNet (2021):将CNN编码器的瓶颈部分替换为Transformer,结合了CNN的局部特征提取能力和Transformer的全局建模能力。这为后来的混合架构铺平了道路。
当Ho等人在2020年为DDPM选择网络架构时,他们面临着多种选择:ResNet、VGG、甚至当时新兴的Vision Transformer。为什么最终选择了U-Net?
1. 归纳偏置的匹配 扩散模型的去噪任务具有特殊的性质:输出必须与输入在空间上严格对齐。U-Net的架构天然地保证了这一点,而其他架构(如将图像展平后输入全连接网络)则会破坏这种空间结构。
2. 多时间尺度的处理能力 在扩散过程的不同阶段,去噪的重点是不同的:
U-Net的多尺度特性完美匹配了这种需求,不同的层级自然地专注于不同尺度的特征。
3. 实践中的鲁棒性 医学图像分割领域的严苛要求(小数据集、高精度需求)锻造了U-Net的鲁棒性。这种鲁棒性在扩散模型的训练中同样重要,因为去噪网络需要处理从纯噪声到清晰图像的整个谱系。
当DDPM的作者们在2020年选择U-Net作为去噪网络时,他们面临着与原始分割任务完全不同的需求。因此,一个为扩散模型“现代化”的U-Net诞生了,它融合了自2015年以来深度学习架构的诸多进展。
定义:扩散U-Net的关键改进 | 组件 | 原始U-Net (2015) | 扩散U-Net (2020+) | | :— | :— | :— | | 卷积类型 | Valid卷积 (无padding) | Same卷积 (保持尺寸) | | 归一化 | 无 (或后期加入BatchNorm) | GroupNorm (小批量稳定) | | 激活函数 | ReLU | SiLU / Swish (更平滑) | | 残差连接 | 无 | 每个块内部都有 (类ResNet) | | 注意力机制 | 无 | 多分辨率自注意力 | | 条件机制 | 无需条件 | 时间嵌入 (必需) |
让我们深入理解几个关键改进:
现代U-Net的基本构建单元不再是简单的卷积层,而是借鉴了ResNet的残差块。一个典型的块流程如下:
x
首先通过 GroupNorm
和 SiLU
激活函数。Conv2d
层。GroupNorm
和 SiLU
。Conv2d
层。x
相加(残差连接)。时间步 t
的信息至关重要。它通常通过一个小型MLP从正弦编码转换为嵌入向量,然后通过自适应归一化层(Adaptive Group Normalization, AdaGN)注入到每个残差块中。其核心思想是调制残差块的统计特性:
h_out = GroupNorm(h_in) * (1 + scale(t)) + shift(t)
其中 scale(t)
和 shift(t)
是从时间嵌入向量线性变换得到的。
为了捕获长程依赖关系,自注意力机制被引入到U-Net中。但由于其计算复杂度与像素数的平方成正比,它通常只在特征图分辨率较低的层级(如16x16或8x8)使用,以在计算效率和全局建模能力之间取得平衡。
“如何正确地降低和恢复分辨率”是U-Net设计的核心问题之一,其演进过程反映了深度学习架构设计的范式转变。这个问题看似简单,实则深刻影响着模型的表现力和生成质量。
下采样不仅仅是减少计算量的技术手段,更是一种信息抽象的过程。每次下采样,我们都在回答一个问题:如何用更少的数字表示更大的区域?
1. 最大池化时代(2012-2015)
最大池化(nn.MaxPool2d
)曾是卷积神经网络的标配。其背后的假设是:在一个局部区域内,最强的激活值代表了最重要的特征。这种假设在分类任务中很合理——我们关心的是”是否存在某个特征”,而不是”特征在哪里”。
输入: [[1, 2], MaxPool2d 输出: [4]
[3, 4]] (2x2)
然而,对于生成任务,这种”赢者通吃”的策略是灾难性的:
2. 步进卷积革命(2015-2018)
DCGAN论文提出了一个革命性的想法:让网络自己学习如何下采样。步进卷积(stride=2
的nn.Conv2d
)将下采样和特征提取合二为一:
# 传统方法:先卷积,后池化
conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
pool = nn.MaxPool2d(2)
output = pool(conv(input)) # 两步操作
# 现代方法:步进卷积一步到位
strided_conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)
output = strided_conv(input) # 一步操作,可学习
这种方法的优势在于:
3. 现代最佳实践:分而治之(2018至今) 随着模型规模的增长,训练稳定性成为关键考虑。现代架构倾向于将”改变分辨率”和”提取特征”解耦:
# 第一步:在当前分辨率提取特征
conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
# 第二步:专门的下采样层
downsample = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1)
这种设计的智慧在于:
如果说下采样是”压缩”,那么上采样就是”解压缩”。但与信息压缩不同,神经网络的上采样需要”创造”原本不存在的细节。
1. 转置卷积的诱惑与陷阱
转置卷积(nn.ConvTranspose2d
)在数学上是步进卷积的精确逆操作。它通过在输入之间插入零值,然后进行常规卷积来实现上采样:
输入: [a, b] → 插零: [a, 0, b] → 卷积: 生成更大的输出
然而,这种方法存在一个致命问题:棋盘效应(Checkerboard Artifacts)。当kernel_size
不能被stride
整除时,输出像素接收到的”贡献”不均匀:
kernel_size=3, stride=2 的情况:
某些输出像素被1个输入像素影响
某些输出像素被2个输入像素影响
→ 产生棋盘状的明暗模式
这个问题在2016年被Odena等人系统分析后,引发了社区的广泛讨论。
2. 插值+卷积:简单但有效的解决方案 为了避免棋盘效应,现代架构采用了一个看似”倒退”但实际上更稳健的方法:
# 方法1:最近邻插值 + 卷积
upsample = nn.Upsample(scale_factor=2, mode='nearest')
conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
output = conv(upsample(input))
# 方法2:双线性插值 + 卷积
upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
output = conv(upsample(input))
这种方法的优势:
3. 亚像素卷积:另一种优雅的方案 亚像素卷积(Pixel Shuffle)提供了另一种思路:先在低分辨率空间生成多个通道,然后重新排列成高分辨率输出:
# 输入: [B, C, H, W]
# 先扩展通道: [B, C*r², H, W]
conv = nn.Conv2d(in_channels, out_channels * scale_factor**2, kernel_size=3, padding=1)
# 然后重排: [B, C, H*r, W*r]
pixel_shuffle = nn.PixelShuffle(scale_factor)
output = pixel_shuffle(conv(input))
这种方法在超分辨率任务中特别流行,因为它允许网络在低分辨率空间进行大部分计算。
在扩散模型中,采样方式的选择有着特殊的重要性:
1. 信息保真度 扩散模型需要在多个时间步之间传递信息。任何信息损失都会在迭代过程中被放大。因此,可逆或近似可逆的采样方式(如步进卷积配合适当的上采样)特别重要。
2. 多尺度一致性 去噪过程需要在不同尺度上保持一致性。粗糙的采样方式可能导致不同分辨率层之间的特征不匹配,影响最终的生成质量。
3. 计算效率的关键 U-Net的大部分计算发生在低分辨率层。高效的采样策略可以显著减少计算量,这是Stable Diffusion能够在消费级硬件上运行的关键因素之一。
归一化技术的演进史,是深度学习社区对”如何让深层网络稳定训练”这一核心问题不断探索的历史。在扩散模型中,归一化不仅影响训练稳定性,更成为了注入条件信息的关键机制。
2015年,Ioffe和Szegedy提出BatchNorm时,他们的核心观察是:深层网络训练困难的一个重要原因是内部协变量偏移(Internal Covariate Shift)——即每层的输入分布在训练过程中不断变化,导致后续层需要不断适应新的输入分布。
归一化的基本思想很简单:
归一化输出 = γ × (输入 - 均值) / 标准差 + β
其中γ和β是可学习的缩放和偏移参数。关键在于:如何计算均值和标准差?
BatchNorm在许多任务上取得了巨大成功,但在扩散模型中却遇到了前所未有的挑战:
1. 批次依赖性带来的不一致 BatchNorm在训练时使用当前批次的统计量,在推理时使用移动平均。这导致:
# 训练时:使用批次统计
mean = x.mean(dim=[0, 2, 3]) # 跨批次维度计算
var = x.var(dim=[0, 2, 3])
x_norm = (x - mean) / sqrt(var + eps)
# 推理时:使用移动平均
x_norm = (x - running_mean) / sqrt(running_var + eps)
对于扩散模型,这种不一致是致命的:
2. 时间步混淆问题 扩散模型的一个批次中,不同样本可能处于不同的时间步:
批次 = [x_t1, x_t2, x_t3, x_t4] # t1, t2, t3, t4可能完全不同
BatchNorm会将这些处于不同噪声水平的样本混合计算统计量,这就像把苹果和橙子混在一起求平均——毫无意义。
3. 小批量训练的灾难 高分辨率的扩散模型因为内存限制,批次大小通常很小(如2或4)。在如此小的批次上估计统计量,方差极大,训练极不稳定。
2018年,何恺明等人提出的GroupNorm巧妙地解决了这些问题。其核心思想是:不跨样本计算统计量,而是在每个样本内部,将通道分组后计算。
# GroupNorm的计算方式
# 假设输入 x 的形状为 [B, C, H, W]
# 将 C 个通道分成 G 组
x = x.view(B, G, C//G, H, W)
mean = x.mean(dim=[2, 3, 4]) # 在每组内计算
var = x.var(dim=[2, 3, 4])
x = (x - mean) / sqrt(var + eps)
x = x.view(B, C, H, W)
GroupNorm的优势:
GroupNorm实际上是一个统一框架:
传统的归一化使用固定的γ和β参数。但StyleGAN的成功启发了一个革命性的想法:让这些参数根据外部条件动态变化。
AdaGN(Adaptive Group Normalization)的工作原理:
class AdaGN(nn.Module):
def __init__(self, num_features, num_groups=32, time_emb_dim=128):
super().__init__()
self.norm = nn.GroupNorm(num_groups, num_features)
# 从时间嵌入预测 scale 和 shift
self.time_mlp = nn.Sequential(
nn.SiLU(),
nn.Linear(time_emb_dim, num_features * 2)
)
def forward(self, x, time_emb):
# 计算动态的 scale 和 shift
scale_shift = self.time_mlp(time_emb)
scale, shift = scale_shift.chunk(2, dim=1)
# 应用 GroupNorm
x = self.norm(x)
# 应用动态调制
x = x * (1 + scale[:, :, None, None]) + shift[:, :, None, None]
return x
为什么AdaGN对扩散模型如此有效?
时间感知的去噪:不同时间步需要不同的去噪策略。早期(高噪声)可能需要更强的归一化来稳定训练,后期(低噪声)可能需要更弱的归一化来保留细节。
高效的条件注入:相比于将条件信息拼接到特征图(增加计算量),AdaGN通过调制现有特征实现条件控制,几乎不增加计算成本。
分层的控制粒度:每一层可以根据时间步独立调整其行为,这种细粒度的控制对于处理不同尺度的噪声至关重要。
在Transformer的发展过程中,Layer Normalization的位置引发了激烈讨论。这个讨论同样适用于U-Net:
Post-Norm(传统方式):
x → Conv → ReLU → Norm → + → 输出
↑
x (残差连接)
Pre-Norm(现代方式):
x → Norm → Conv → ReLU → + → 输出
↑
x (残差连接)
Pre-Norm的优势:
这就是为什么现代扩散模型普遍采用Pre-Norm设计。
最近,RMSNorm作为LayerNorm的简化版本引起了关注:
# LayerNorm: 减均值,除标准差
x_norm = (x - mean) / std
# RMSNorm: 只除以均方根,不减均值
x_norm = x / sqrt(mean(x²))
RMSNorm的优势:
虽然RMSNorm在扩散模型中的应用还不广泛,但它代表了一个重要趋势:不断简化和优化基础组件。
Transformer架构由Vaswani等人在2017年的论文《Attention Is All You Need》中为自然语言处理提出。2020年,Dosovitskiy等人的ViT论文证明了纯Transformer架构在图像分类上可以达到甚至超越顶尖的CNN,开启了Transformer在计算机视觉领域的革命。
ViT的核心思想极其简洁:
[CLS]
token输入到标准的Transformer编码器中。[CLS]
token进行分类。这种设计的优雅之处在于它为视觉问题引入了新的归纳偏置:世界是由可组合的“部件”构成的。
2022年,Peebles和Xie在论文《Scalable Diffusion Models with Transformers》中提出了DiT,成功将ViT架构应用于扩散模型。DiT对ViT进行了关键改造以适应去噪任务:
[CLS]
Token:生成任务需要对每个patch进行预测,因此去除了分类任务专用的[CLS]
token。t
和类别标签c
的嵌入向量被视为额外的条件tokens,通过自适应LayerNorm(AdaLN)或交叉注意力(cross-attention)注入到模型中。DiT的成功,特别是其卓越的可扩展性(scaling law),使其迅速成为SOTA文生图模型(如Sora, Stable Diffusion 3)的首选架构。
理论架构和实际部署之间往往存在巨大鸿沟。本节分享一些在实践中积累的优化技巧。
训练扩散模型时,内存的最大消耗通常来自激活值,特别是U-Net中为跳跃连接而保存的各层特征图。
梯度检查点 (Gradient Checkpointing):核心思想是“用计算换内存”。通过torch.utils.checkpoint.checkpoint
包裹模型的一部分(如一个ResBlock),在前向传播时不保存其内部的激活值,而在反向传播时重新计算它们。这可以显著降低内存占用(约30-50%),但会增加训练时间(约20-30%)。
混合精度训练 (Mixed Precision):使用torch.cuda.amp
(自动混合精度)可以利用现代GPU的Tensor Cores,将大部分计算从FP32转为FP16或BF16,内存减半,速度翻倍。关键是使用GradScaler
来防止FP16梯度下溢。
注意力优化:标准自注意力的内存和计算复杂度与序列长度的平方成正比。对于高分辨率图像,这很快会成为瓶颈。FlashAttention等库通过融合内核操作,避免将巨大的注意力矩阵写入和读出GPU内存,从而实现显著的加速和内存节省。
初始化策略:一个关键技巧是将输出层的权重和偏置初始化为零。这确保模型在训练开始时输出为零,即预测的噪声为零。这是一种“无为而治”的初始化,使得模型在学习初期不会对输入造成巨大扰动,有助于稳定训练。
数值稳定性:
torch.nn.utils.clip_grad_norm_
来防止梯度爆炸,是训练大模型的标配。本章我们追溯了扩散模型中两种主流架构的历史渊源,并深入分析了它们的设计细节和演进过程。
这两种架构能够成功应用于扩散模型并非偶然,而是经过了精心的改造和适配:
扩散模型的架构演进史启示我们:创新并不总是需要“从零开始”。善于发现和利用已有技术的潜力,通过巧妙的改造和组合,往往能产生意想不到的突破。
下一章,我们将深入DDPM的数学原理,看看这些强大的架构是如何在一个清晰的概率框架下进行训练和优化的。