第16章 Plug-and-Play (PnP)/RED 与“算法展开”:优化与深度的结合
开篇段落
在前十五章中,我们构建了一个严谨的数学世界:图像复原问题被建模为“数据保真项 + 正则项”的变分问题。我们学习了如何设计 TV、稀疏、低秩等正则项来描述“什么是好的图像”。然而,当我们面对真实世界的复杂退化(如非高斯噪声、复杂的运动模糊、JPEG 压缩伪影)时,手工设计的正则项往往显得力不从心——它们太简单了,无法捕捉自然图像高维流形上的所有细节。
与此同时,深度学习(DL)以一种“暴力美学”的方式席卷了该领域。端到端(End-to-End)的神经网络(如 UNet, ResNet)通过海量数据训练,效果往往碾压传统方法。但 DL 也有其阿喀琉斯之踵:黑盒性质导致缺乏可解释性,且泛化能力差——一个针对高斯模糊训练的网络,面对运动模糊可能完全失效,需要重新训练。
本章将介绍连接这两个世界的“中间道路”:Plug-and-Play (PnP), Regularization by Denoising (RED) 以及 算法展开 (Deep Unrolling/Unfolding)。这些方法的核心哲学是:“让上帝的归上帝,凯撒的归凯撒”——用传统的优化算法框架(如 ADMM/HQS)处理物理退化模型(已知且可靠),用深度神经网络处理图像先验(复杂且难懂)。这不仅赋予了深度学习以数学结构,也让传统优化获得了数据驱动的强大表现力。
16.1 PnP 思想:从“手工正则”到“隐式去噪”
16.1.1 为什么是去噪器?(The Denoiser Intuition)
让我们重新审视近端算子(Proximal Operator)的定义。对于正则项 $\lambda R(x)$,其 Prox 定义为: $$ \text{prox}_{\lambda R}(v) = \arg\min_x \underbrace{\frac{1}{2}|x - v|_2^2}_{\text{输入} v \text{的保真}} + \underbrace{\lambda R(x)}_{\text{先验约束}} $$ 如果你仔细观察这个公式,这严格等价于一个贝叶斯去噪问题:假设观测信号 $v = x + n$,其中 $n$ 是高斯白噪声,方差与 $\lambda$ 相关,我们需要从 $v$ 中恢复 $x$。
核心洞察: 传统方法中,我们显式定义 $R(x) = |x|_{\text{TV}}$,然后费力推导其 Prox 算子。 而在 PnP 中,我们反其道而行之:既然 Prox 等价于去噪,为什么不直接用一个现成的、性能最强的去噪器(Denoiser)来替换这一步计算?
我们不再关心 $R(x)$ 的数学表达式是什么,我们只关心去噪器 $\mathcal{D}(\cdot)$ 能不能把图像“推”回到自然图像流形上。
16.1.2 PnP 的通用工作流
PnP 使得我们可以解决任意线性逆问题 $y = Ax + n$,只要我们有求解器(如 ADMM 或 HQS)和一个去噪器(如 DnCNN, DRUNet)。
+---------------------------+
| 物理模型 (退化算子 A) |
| 负责:数据一致性 |
+------------+--------------+
|
(把当前估计推向 y 的观测流形)
|
+------- 迭代循环 (Loop) --------+
| |
v v
[数据项子问题] [先验项子问题]
Inversion Step Denoising Step
x = (A'A + ρI)^-1 (...) z = Denoiser(x)
| |
^ ^
| |
+-----------(变量耦合)-----------+
16.2 两种主流 PnP 框架:HQS 与 ADMM
虽然 ADMM 很强,但在 PnP 领域,半二次分裂(Half-Quadratic Splitting, HQS) 因其简洁性同样非常流行。
16.2.1 PnP-HQS:简单即是美
求解 $\min_x \frac{1}{2}|Ax - y|^2 + \lambda R(x)$。 引入辅助变量 $z$,构建目标函数: $$ \min_{x,z} \frac{1}{2}|Ax - y|^2 + \lambda R(z) + \frac{\mu}{2}|x - z|^2 $$ 当 $\mu \to \infty$ 时,$x$ 逼近 $z$,问题回归原问题。HQS 采用交替最小化:
-
数据项更新 ($x$-step):这是一个最小二乘问题。 $$ x_{k+1} = \arg\min_x |Ax - y|^2 + \mu |x - z_k|^2 $$ 解:$x_{k+1} = (A^T A + \mu I)^{-1} (A^T y + \mu z_k)$ 注:对于卷积或 FFT 可对角化的 $A$,这一步有闭式解且极快。
-
先验项更新 ($z$-step):这是一个去噪问题。 $$ z_{k+1} = \arg\min_z \frac{\mu}{2}|x_{k+1} - z|^2 + \lambda R(z) $$ 等价于: $$ z_{k+1} = \arg\min_z \frac{1}{2}|z - x_{k+1}|^2 + \frac{\lambda}{\mu} R(z) $$ PnP 替换:直接调用去噪器,输入噪声水平 $\sigma_k = \sqrt{\lambda / \mu}$。 $$ z_{k+1} = \mathcal{D}_{\sigma_k}(x_{k+1}) $$
16.2.2 PnP-ADMM:更严格的约束
相比 HQS,ADMM 引入了拉格朗日乘子 $u$,对约束的处理更稳健,尤其是在 $\mu$ 不需要趋于无穷大的情况下也能收敛。
- 区别:ADMM 的去噪步输入是 $(x_{k+1} + u_k)$,即包含累积误差的修正。
- 适用场景:当退化算子 $A$ 很难求逆(比如非线性算子),或者需要处理硬约束(如 $x \in [0,1]$)时,ADMM 及其线性化变体更有优势。
16.2.3 关键技巧:参数 $\mu$ 与 $\sigma$ 的联动 (Rule-of-Thumb)
PnP 成功的秘诀在于退火策略。
- 在迭代初期,图像误差很大,我们希望去噪器“下手重一点”以消除大幅伪影,同时 $\mu$ 较小(允许 $x$ 和 $z$ 不太一致)。
- 随着迭代进行,图像变清晰,我们减小去噪强度($\sigma \downarrow$),增大 $\mu$(强制一致)。
- 经验公式:设定一组下降的噪声水平 $\sigma \in [50, 30, 10, 5]$,根据 $\sigma_k = \sqrt{\lambda/\mu_k}$ 反推对应的惩罚参数 $\mu_k$。
16.3 RED: Regularization by Denoising
PnP 虽然好用,但它是“算法层面的替换”,我们失去了原本优化的目标函数 $J(x)$。RED 试图找回这个目标函数。
16.3.1 显式正则项的构造
RED 提出了一个惊人的定义: $$ R_{RED}(x) = \frac{1}{2} x^T (x - \mathcal{D}(x)) $$ 直觉解释:
- $\mathcal{D}(x)$ 是对 $x$ 去噪后的结果。
- $(x - \mathcal{D}(x))$ 是图像中的“噪声残差”。
- $x^T (\text{noise})$:我们希望图像信号与噪声残差正交(或者说内积最小)。对于干净图像,$x \approx \mathcal{D}(x)$,则 $R_{RED} \approx 0$。
16.3.2 梯度即残差
RED 理论证明,在去噪器满足局部均匀性和雅可比对称性(虽然现实网络不完全满足,但近似成立)的条件下: $$ \nabla R_{RED}(x) = x - \mathcal{D}(x) $$ 这一结论极大地简化了优化。我们不需要对复杂的 CNN 进行反向传播(Backprop)来求梯度,残差本身就是梯度方向。
16.3.3 RED 算法实现
最简单的 RED 求解是最速下降法: $$ x_{k+1} = x_k - \eta \underbrace{(A^T(Ax_k - y))}_{\text{物理梯度}} - \eta \lambda \underbrace{(x_k - \mathcal{D}(x_k))}_{\text{先验梯度}} $$ 这赋予了去噪器一个非常清晰的物理意义:在每一步迭代中,把图像往“去噪后的方向”拉一把,同时保持在观测数据的一致性范围内。
16.4 算法展开 (Algorithm Unrolling / Deep Unfolding)
PnP 和 RED 是测试时(Test-time)的方法:去噪器是预先训练好的(比如用 BSD500 数据集训练的高斯去噪器),在复原过程中参数固定。 算法展开则是训练时(Train-time)的方法:我们将整个优化迭代过程看作一个网络,端到端训练。
16.4.1 从 Loop 到 Layer (网络架构设计)
假设我们用梯度下降法解 Lasso 问题(ISTA 算法),迭代 $K$ 次。 $$ x_{k+1} = \text{soft}(x_k - \rho A^T(Ax_k - y), \theta) $$ 我们将这 $K$ 步迭代展开为 $K$ 层神经网络:
-
数据一致性模块 (DC Module): 对应 $r_k = x_k - \rho A^T(Ax_k - y)$。这一步包含物理模型 $A$。在网络中,这通常实现为一个固定的计算层(不含可学习参数),或者如果 $A$ 未知,可以用卷积层模拟。
-
近端/去噪模块 (Prox Module): 对应 $\text{soft}(\cdot)$。在深度展开中,我们不再使用简单的软阈值,而是用一个小的 CNN(如 ResBlock)来替代,记为 $\mathcal{Net}_{\theta_k}(\cdot)$。 $$ x_{k+1} = \mathcal{Net}_{\theta_k}(r_k) $$
-
可学习参数:
- 每一层的网络权重 $\theta_k$(甚至可以层间共享权重)。
- 步长 $\rho_k$ 也可以设为可学习的参数。
16.4.2 典型案例:ADMM-Net 与 USRNet
- ADMM-Net (NIPS 2016):将 ADMM 的 $x, z, u$ 更新展开。它是最早将压缩感知 MRI 重建做成深度网络的经典工作。它证明了只需 10-15 个阶段(Stage)就能达到传统算法几百次迭代的精度。
- USRNet (CVPR 2020):基于 HQS 的展开网络,用于超分辨率。它显式地将退化核 $k$ 和噪声水平 $\sigma$ 作为额外的输入传入网络,实现了单个网络处理多种退化(Non-blind SR)。
16.4.3 为什么 Unrolling 是目前的 SOTA?
- 可解释性:相比于直接用 UNet 从 $y$ 映射到 $x$,Unrolling 网络的每一层都有明确的物理含义(去伪影 vs 找回细节)。
- 轻量化:因为物理模型 $A$ 已经处理了数据匹配部分,神经网络只需要学习“残差”或“先验”。这使得所需的参数量远小于纯黑盒网络。
- 小样本泛化:由于强制了数据一致性(Data Consistency),Unrolling 网络在训练数据较少时也不容易过拟合。
16.5 核心模块详解:数据一致性层 (Data Consistency Layer)
在实现 Unrolling 或 PnP 时,很多初学者会忽略或写错数据一致性层。这是物理模型发挥作用的地方。
对于二次保真项 $\frac{1}{2}|Ax-y|^2$,其 Prox 算子(即 DC 层)通常有以下几种实现方式:
-
梯度下降式 (Gradient Block): 最简单,不需要求逆。
x_new = x - step_size * A_adjoint(A(x) - y)优点:通用,适用于任何 $A$。 缺点:收敛慢,需要更多层。 -
闭式解式 (Closed-form Block): 适用于去模糊、超分(循环边界条件)。利用 FFT。 $$ x_{new} = \mathcal{F}^{-1} \left( \frac{\mathcal{F}(z) + \rho \mathcal{F}(A^T y)}{\rho + |\mathcal{F}(A)|^2} \right) $$ 优点:一步到位,精度高。 缺点:只适用于 $A$ 是卷积的情况。
-
共轭梯度式 (CG Block): 当 $A$ 很大且不可对角化(如大型 MRI 非均匀采样),网络层内部可以内嵌 3-5 步共轭梯度下降来近似求解逆问题。
16.6 收敛性与谱归一化 (Spectral Normalization)
把深度网络插入优化算法,收敛性是数学家最担心的问题。如果去噪器 $\mathcal{D}$ 是一个“扩张”算子(即输出的差异比输入的差异还大),迭代可能会发散。
Lipschitz 约束: 为了保证不动点迭代收敛,我们需要去噪器是非扩张的 (Non-expansive),即 Lipschitz 常数 $L \le 1$。 $$ |\mathcal{D}(x_1) - \mathcal{D}(x_2)| \le |x_1 - x_2| $$
工程对策:
- RealSN (Real Spectral Normalization):在训练去噪网络时,对每一层卷积核 $W$ 进行谱归一化(除以其最大奇异值),强制每一层的 $L \le 1$。
- 残差学习技巧:训练 $\mathcal{D}(x) = x + \mathcal{R}(x)$,其中 $\mathcal{R}$ 是残差网络。
- 实用主义:在实际 PnP 应用中,即使 $L > 1$,只要 $\rho$ 足够大(步长足够小),算法往往也能收敛。但在高风险领域(医疗),必须使用 RealSN 保证稳定性。
16.7 本章小结
- PnP (Plug-and-Play) 是一种模块化思想,将优化算法中的“先验子问题”替换为“深度去噪器”。它不需要训练,利用现有去噪器即可解决去模糊、超分等多种问题。
- RED 通过正则化残差能量 $x^T(x-\mathcal{D}(x))$,提供了一种显式的变分解释,且其梯度计算极其简单。
- 算法展开 (Unrolling) 将迭代步骤映射为网络层,通过端到端训练学习参数。它是目前图像复原领域性能与可解释性平衡最好的方法。
- 数据一致性 (DC) 是连接物理与AI的纽带。无论用什么网络,永远不要丢掉 $A^T(Ax-y)$ 这一项。
16.8 练习题
基础题
- PnP 参数调试:在使用 PnP-ADMM 进行去模糊时,如果你发现恢复的图像非常锐利但充满了类似噪声的杂点,你应该增大还是减小惩罚参数 $\rho$?(提示:$\rho$ 越大,去噪强度相对越小)。
- RED 梯度:证明若 $\mathcal{D}(x)$ 是线性算子 $\mathcal{D}(x) = Wx$ 且 $W$ 对称,则 $R(x) = \frac{1}{2}x^T(x - Wx)$ 的梯度确实是 $x - \mathcal{D}(x)$。
- DC 层实现:在 PyTorch 中,如果 $A$ 是一个 $3\times3$ 的均值模糊核,请写出计算数据项梯度 $A^T(Ax - y)$ 的伪代码(使用
F.conv2d)。
挑战题
- 架构设计:针对 压缩感知 (CS) 问题,设计一个展开网络。
- 输入:测量值 $y$ 和测量矩阵 $\Phi$。
- 结构:使用 ISTA 展开。
- 难点:$\Phi$ 通常是大矩阵,不能写成卷积。如何在网络层中高效实现 $\Phi x$ 和 $\Phi^T r$?(提示:全连接层 vs 自定义矩阵乘法层)。
- 开放思考:PnP 方法通常假设噪声是高斯的(因为 Prox 等价于高斯去噪)。如果实际观测噪声是 泊松噪声 (Poisson Noise),你应该如何修改 PnP 框架?
- 提示 A:修改数据项子问题。
- 提示 B:使用 Variance Stabilizing Transform (VST) 配合高斯去噪器。
点击查看练习题提示
- 答案:应该减小 $\rho$。$\rho$ 越小,等效噪声标准差 $\sigma = \sqrt{\lambda/\rho}$ 越大,去噪力度越强。或者直接增大输入的 $\sigma$ 参数。
- 提示:$\nabla (\frac{1}{2}x^T x - \frac{1}{2}x^T W x) = x - \frac{1}{2}(W + W^T)x$。如果 $W$ 对称,则为 $x - Wx$。
- 提示:需要注意 padding 和
groups参数。$A^T$ 对应的是翻转核的卷积(即conv_transpose2d或者手动翻转核后conv2d)。
# Pseudo-code
def data_grad(x, y, kernel):
# Ax
Ax = F.conv2d(x, kernel, padding=1)
# Residual
res = Ax - y
# A^T * res
# For symmetric kernel, A^T = A. Otherwise flip kernel.
grad = F.conv2d(res, kernel.flip(2,3), padding=1)
return grad
- 提示:如果 $\Phi$ 是随机高斯矩阵,可以用全连接层(Linear)。如果 $\Phi$ 是部分傅里叶采样,可以用 FFT 实现。关键是这一层不仅要有 weights,还要能接受外部输入的 mask。
- 提示:方案一:保留高斯去噪器,但在数据一致性步骤使用泊松似然项(KL 散度)进行更新(Prox_KL)。方案二:Anscombe 变换将泊松转为高斯,用 PnP 复原后再逆变换。
16.9 常见陷阱与错误 (Gotchas)
1. 维度灾难与边界效应
- 问题:在实现数据一致性层(特别是 FFT 版)时,忘记处理图像边界。
- 后果:复原图像边缘出现严重的振铃(Ringing)或卷绕伪影。
- 对策:使用
pad_circular(如果用 FFT)或pad_reflection。确保 $A$ 和 $A^T$ 的边界处理是伴随一致的。
2. PnP 的“伪收敛”
- 现象:PnP 迭代曲线显示残差在下降,但图像质量(PSNR)在中间某一步达到峰值后开始下降。
- 原因:过度迭代。去噪器虽然去掉了噪声,但也逐渐抹平了纹理。PnP 的收敛点(Fixed Point)不一定是视觉质量最好的点。
- 对策:实施 Early Stopping(早停),或者精心调节 $\sigma$ 的衰减策略(Exponential Decay)。
3. Unrolling 的训练显存爆炸
- 问题:展开 20 层 ADMM,每层包含一个 UNet,反向传播时需要存储所有层的梯度,显存瞬间爆炸。
- 对策:
- 使用 Gradient Checkpointing(以时间换空间)。
- 减少每层子网络的深度(例如每层只用 3 个 ResBlock,而不是完整的 UNet)。
- 权重共享:所有迭代阶段共用同一个 CNN 的权重(性能略降,但参数量大减)。
4. 忽略了输入数据的归一化
- 问题:直接把 [0, 255] 的图像塞进 PnP。
- 后果:如果去噪器是针对 [0, 1] 训练的,输出全是白屏或黑屏。
- 对策:严格检查预训练模型的输入规范。$\sigma$ 值也要相应缩放(如 $\sigma=25$ 对应 [0, 255],在 [0, 1] 下应为 $25/255$)。