第16章 Plug-and-Play (PnP)/RED 与“算法展开”:优化与深度的结合

开篇段落

在前十五章中,我们构建了一个严谨的数学世界:图像复原问题被建模为“数据保真项 + 正则项”的变分问题。我们学习了如何设计 TV、稀疏、低秩等正则项来描述“什么是好的图像”。然而,当我们面对真实世界的复杂退化(如非高斯噪声、复杂的运动模糊、JPEG 压缩伪影)时,手工设计的正则项往往显得力不从心——它们太简单了,无法捕捉自然图像高维流形上的所有细节。

与此同时,深度学习(DL)以一种“暴力美学”的方式席卷了该领域。端到端(End-to-End)的神经网络(如 UNet, ResNet)通过海量数据训练,效果往往碾压传统方法。但 DL 也有其阿喀琉斯之踵:黑盒性质导致缺乏可解释性,且泛化能力差——一个针对高斯模糊训练的网络,面对运动模糊可能完全失效,需要重新训练。

本章将介绍连接这两个世界的“中间道路”:Plug-and-Play (PnP), Regularization by Denoising (RED) 以及 算法展开 (Deep Unrolling/Unfolding)。这些方法的核心哲学是:“让上帝的归上帝,凯撒的归凯撒”——用传统的优化算法框架(如 ADMM/HQS)处理物理退化模型(已知且可靠),用深度神经网络处理图像先验(复杂且难懂)。这不仅赋予了深度学习以数学结构,也让传统优化获得了数据驱动的强大表现力。


16.1 PnP 思想:从“手工正则”到“隐式去噪”

16.1.1 为什么是去噪器?(The Denoiser Intuition)

让我们重新审视近端算子(Proximal Operator)的定义。对于正则项 $\lambda R(x)$,其 Prox 定义为: $$ \text{prox}_{\lambda R}(v) = \arg\min_x \underbrace{\frac{1}{2}|x - v|_2^2}_{\text{输入} v \text{的保真}} + \underbrace{\lambda R(x)}_{\text{先验约束}} $$ 如果你仔细观察这个公式,这严格等价于一个贝叶斯去噪问题:假设观测信号 $v = x + n$,其中 $n$ 是高斯白噪声,方差与 $\lambda$ 相关,我们需要从 $v$ 中恢复 $x$。

核心洞察: 传统方法中,我们显式定义 $R(x) = |x|_{\text{TV}}$,然后费力推导其 Prox 算子。 而在 PnP 中,我们反其道而行之:既然 Prox 等价于去噪,为什么不直接用一个现成的、性能最强的去噪器(Denoiser)来替换这一步计算?

我们不再关心 $R(x)$ 的数学表达式是什么,我们只关心去噪器 $\mathcal{D}(\cdot)$ 能不能把图像“推”回到自然图像流形上。

16.1.2 PnP 的通用工作流

PnP 使得我们可以解决任意线性逆问题 $y = Ax + n$,只要我们有求解器(如 ADMM 或 HQS)和一个去噪器(如 DnCNN, DRUNet)。

             +---------------------------+
             |  物理模型 (退化算子 A)    |
             |  负责:数据一致性         |
             +------------+--------------+
                          |
             (把当前估计推向 y 的观测流形)
                          |
          +------- 迭代循环 (Loop) --------+
          |                                |
          v                                v
  [数据项子问题]                     [先验项子问题]
  Inversion Step                     Denoising Step
  x = (A'A + ρI)^-1 (...)            z = Denoiser(x)
          |                                |
          ^                                ^
          |                                |
          +-----------(变量耦合)-----------+

16.2 两种主流 PnP 框架:HQS 与 ADMM

虽然 ADMM 很强,但在 PnP 领域,半二次分裂(Half-Quadratic Splitting, HQS) 因其简洁性同样非常流行。

16.2.1 PnP-HQS:简单即是美

求解 $\min_x \frac{1}{2}|Ax - y|^2 + \lambda R(x)$。 引入辅助变量 $z$,构建目标函数: $$ \min_{x,z} \frac{1}{2}|Ax - y|^2 + \lambda R(z) + \frac{\mu}{2}|x - z|^2 $$ 当 $\mu \to \infty$ 时,$x$ 逼近 $z$,问题回归原问题。HQS 采用交替最小化:

  1. 数据项更新 ($x$-step):这是一个最小二乘问题。 $$ x_{k+1} = \arg\min_x |Ax - y|^2 + \mu |x - z_k|^2 $$ :$x_{k+1} = (A^T A + \mu I)^{-1} (A^T y + \mu z_k)$ 注:对于卷积或 FFT 可对角化的 $A$,这一步有闭式解且极快。

  2. 先验项更新 ($z$-step):这是一个去噪问题。 $$ z_{k+1} = \arg\min_z \frac{\mu}{2}|x_{k+1} - z|^2 + \lambda R(z) $$ 等价于: $$ z_{k+1} = \arg\min_z \frac{1}{2}|z - x_{k+1}|^2 + \frac{\lambda}{\mu} R(z) $$ PnP 替换:直接调用去噪器,输入噪声水平 $\sigma_k = \sqrt{\lambda / \mu}$。 $$ z_{k+1} = \mathcal{D}_{\sigma_k}(x_{k+1}) $$

16.2.2 PnP-ADMM:更严格的约束

相比 HQS,ADMM 引入了拉格朗日乘子 $u$,对约束的处理更稳健,尤其是在 $\mu$ 不需要趋于无穷大的情况下也能收敛。

  • 区别:ADMM 的去噪步输入是 $(x_{k+1} + u_k)$,即包含累积误差的修正。
  • 适用场景:当退化算子 $A$ 很难求逆(比如非线性算子),或者需要处理硬约束(如 $x \in [0,1]$)时,ADMM 及其线性化变体更有优势。

16.2.3 关键技巧:参数 $\mu$ 与 $\sigma$ 的联动 (Rule-of-Thumb)

PnP 成功的秘诀在于退火策略

  • 在迭代初期,图像误差很大,我们希望去噪器“下手重一点”以消除大幅伪影,同时 $\mu$ 较小(允许 $x$ 和 $z$ 不太一致)。
  • 随着迭代进行,图像变清晰,我们减小去噪强度($\sigma \downarrow$),增大 $\mu$(强制一致)。
  • 经验公式:设定一组下降的噪声水平 $\sigma \in [50, 30, 10, 5]$,根据 $\sigma_k = \sqrt{\lambda/\mu_k}$ 反推对应的惩罚参数 $\mu_k$。

16.3 RED: Regularization by Denoising

PnP 虽然好用,但它是“算法层面的替换”,我们失去了原本优化的目标函数 $J(x)$。RED 试图找回这个目标函数。

16.3.1 显式正则项的构造

RED 提出了一个惊人的定义: $$ R_{RED}(x) = \frac{1}{2} x^T (x - \mathcal{D}(x)) $$ 直觉解释

  • $\mathcal{D}(x)$ 是对 $x$ 去噪后的结果。
  • $(x - \mathcal{D}(x))$ 是图像中的“噪声残差”。
  • $x^T (\text{noise})$:我们希望图像信号与噪声残差正交(或者说内积最小)。对于干净图像,$x \approx \mathcal{D}(x)$,则 $R_{RED} \approx 0$。

16.3.2 梯度即残差

RED 理论证明,在去噪器满足局部均匀性雅可比对称性(虽然现实网络不完全满足,但近似成立)的条件下: $$ \nabla R_{RED}(x) = x - \mathcal{D}(x) $$ 这一结论极大地简化了优化。我们不需要对复杂的 CNN 进行反向传播(Backprop)来求梯度,残差本身就是梯度方向

16.3.3 RED 算法实现

最简单的 RED 求解是最速下降法: $$ x_{k+1} = x_k - \eta \underbrace{(A^T(Ax_k - y))}_{\text{物理梯度}} - \eta \lambda \underbrace{(x_k - \mathcal{D}(x_k))}_{\text{先验梯度}} $$ 这赋予了去噪器一个非常清晰的物理意义:在每一步迭代中,把图像往“去噪后的方向”拉一把,同时保持在观测数据的一致性范围内。


16.4 算法展开 (Algorithm Unrolling / Deep Unfolding)

PnP 和 RED 是测试时(Test-time)的方法:去噪器是预先训练好的(比如用 BSD500 数据集训练的高斯去噪器),在复原过程中参数固定。 算法展开则是训练时(Train-time)的方法:我们将整个优化迭代过程看作一个网络,端到端训练。

16.4.1 从 Loop 到 Layer (网络架构设计)

假设我们用梯度下降法解 Lasso 问题(ISTA 算法),迭代 $K$ 次。 $$ x_{k+1} = \text{soft}(x_k - \rho A^T(Ax_k - y), \theta) $$ 我们将这 $K$ 步迭代展开为 $K$ 层神经网络:

  1. 数据一致性模块 (DC Module): 对应 $r_k = x_k - \rho A^T(Ax_k - y)$。这一步包含物理模型 $A$。在网络中,这通常实现为一个固定的计算层(不含可学习参数),或者如果 $A$ 未知,可以用卷积层模拟。

  2. 近端/去噪模块 (Prox Module): 对应 $\text{soft}(\cdot)$。在深度展开中,我们不再使用简单的软阈值,而是用一个小的 CNN(如 ResBlock)来替代,记为 $\mathcal{Net}_{\theta_k}(\cdot)$。 $$ x_{k+1} = \mathcal{Net}_{\theta_k}(r_k) $$

  3. 可学习参数

    • 每一层的网络权重 $\theta_k$(甚至可以层间共享权重)。
    • 步长 $\rho_k$ 也可以设为可学习的参数。

16.4.2 典型案例:ADMM-Net 与 USRNet

  • ADMM-Net (NIPS 2016):将 ADMM 的 $x, z, u$ 更新展开。它是最早将压缩感知 MRI 重建做成深度网络的经典工作。它证明了只需 10-15 个阶段(Stage)就能达到传统算法几百次迭代的精度。
  • USRNet (CVPR 2020):基于 HQS 的展开网络,用于超分辨率。它显式地将退化核 $k$ 和噪声水平 $\sigma$ 作为额外的输入传入网络,实现了单个网络处理多种退化(Non-blind SR)。

16.4.3 为什么 Unrolling 是目前的 SOTA?

  1. 可解释性:相比于直接用 UNet 从 $y$ 映射到 $x$,Unrolling 网络的每一层都有明确的物理含义(去伪影 vs 找回细节)。
  2. 轻量化:因为物理模型 $A$ 已经处理了数据匹配部分,神经网络只需要学习“残差”或“先验”。这使得所需的参数量远小于纯黑盒网络。
  3. 小样本泛化:由于强制了数据一致性(Data Consistency),Unrolling 网络在训练数据较少时也不容易过拟合。

16.5 核心模块详解:数据一致性层 (Data Consistency Layer)

在实现 Unrolling 或 PnP 时,很多初学者会忽略或写错数据一致性层。这是物理模型发挥作用的地方。

对于二次保真项 $\frac{1}{2}|Ax-y|^2$,其 Prox 算子(即 DC 层)通常有以下几种实现方式:

  1. 梯度下降式 (Gradient Block): 最简单,不需要求逆。 x_new = x - step_size * A_adjoint(A(x) - y) 优点:通用,适用于任何 $A$。 缺点:收敛慢,需要更多层。

  2. 闭式解式 (Closed-form Block): 适用于去模糊、超分(循环边界条件)。利用 FFT。 $$ x_{new} = \mathcal{F}^{-1} \left( \frac{\mathcal{F}(z) + \rho \mathcal{F}(A^T y)}{\rho + |\mathcal{F}(A)|^2} \right) $$ 优点:一步到位,精度高。 缺点:只适用于 $A$ 是卷积的情况。

  3. 共轭梯度式 (CG Block): 当 $A$ 很大且不可对角化(如大型 MRI 非均匀采样),网络层内部可以内嵌 3-5 步共轭梯度下降来近似求解逆问题。


16.6 收敛性与谱归一化 (Spectral Normalization)

把深度网络插入优化算法,收敛性是数学家最担心的问题。如果去噪器 $\mathcal{D}$ 是一个“扩张”算子(即输出的差异比输入的差异还大),迭代可能会发散。

Lipschitz 约束: 为了保证不动点迭代收敛,我们需要去噪器是非扩张的 (Non-expansive),即 Lipschitz 常数 $L \le 1$。 $$ |\mathcal{D}(x_1) - \mathcal{D}(x_2)| \le |x_1 - x_2| $$

工程对策

  1. RealSN (Real Spectral Normalization):在训练去噪网络时,对每一层卷积核 $W$ 进行谱归一化(除以其最大奇异值),强制每一层的 $L \le 1$。
  2. 残差学习技巧:训练 $\mathcal{D}(x) = x + \mathcal{R}(x)$,其中 $\mathcal{R}$ 是残差网络。
  3. 实用主义:在实际 PnP 应用中,即使 $L > 1$,只要 $\rho$ 足够大(步长足够小),算法往往也能收敛。但在高风险领域(医疗),必须使用 RealSN 保证稳定性。

16.7 本章小结

  1. PnP (Plug-and-Play) 是一种模块化思想,将优化算法中的“先验子问题”替换为“深度去噪器”。它不需要训练,利用现有去噪器即可解决去模糊、超分等多种问题。
  2. RED 通过正则化残差能量 $x^T(x-\mathcal{D}(x))$,提供了一种显式的变分解释,且其梯度计算极其简单。
  3. 算法展开 (Unrolling) 将迭代步骤映射为网络层,通过端到端训练学习参数。它是目前图像复原领域性能与可解释性平衡最好的方法。
  4. 数据一致性 (DC) 是连接物理与AI的纽带。无论用什么网络,永远不要丢掉 $A^T(Ax-y)$ 这一项。

16.8 练习题

基础题

  1. PnP 参数调试:在使用 PnP-ADMM 进行去模糊时,如果你发现恢复的图像非常锐利但充满了类似噪声的杂点,你应该增大还是减小惩罚参数 $\rho$?(提示:$\rho$ 越大,去噪强度相对越小)。
  2. RED 梯度:证明若 $\mathcal{D}(x)$ 是线性算子 $\mathcal{D}(x) = Wx$ 且 $W$ 对称,则 $R(x) = \frac{1}{2}x^T(x - Wx)$ 的梯度确实是 $x - \mathcal{D}(x)$。
  3. DC 层实现:在 PyTorch 中,如果 $A$ 是一个 $3\times3$ 的均值模糊核,请写出计算数据项梯度 $A^T(Ax - y)$ 的伪代码(使用 F.conv2d)。

挑战题

  1. 架构设计:针对 压缩感知 (CS) 问题,设计一个展开网络。
    • 输入:测量值 $y$ 和测量矩阵 $\Phi$。
    • 结构:使用 ISTA 展开。
    • 难点:$\Phi$ 通常是大矩阵,不能写成卷积。如何在网络层中高效实现 $\Phi x$ 和 $\Phi^T r$?(提示:全连接层 vs 自定义矩阵乘法层)。
  2. 开放思考:PnP 方法通常假设噪声是高斯的(因为 Prox 等价于高斯去噪)。如果实际观测噪声是 泊松噪声 (Poisson Noise),你应该如何修改 PnP 框架?
    • 提示 A:修改数据项子问题。
    • 提示 B:使用 Variance Stabilizing Transform (VST) 配合高斯去噪器。
点击查看练习题提示
  1. 答案:应该减小 $\rho$。$\rho$ 越小,等效噪声标准差 $\sigma = \sqrt{\lambda/\rho}$ 越大,去噪力度越强。或者直接增大输入的 $\sigma$ 参数。
  2. 提示:$\nabla (\frac{1}{2}x^T x - \frac{1}{2}x^T W x) = x - \frac{1}{2}(W + W^T)x$。如果 $W$ 对称,则为 $x - Wx$。
  3. 提示:需要注意 padding 和 groups 参数。$A^T$ 对应的是翻转核的卷积(即 conv_transpose2d 或者手动翻转核后 conv2d)。
# Pseudo-code
def data_grad(x, y, kernel):
    # Ax
    Ax = F.conv2d(x, kernel, padding=1)
    # Residual
    res = Ax - y
    # A^T * res
    # For symmetric kernel, A^T = A. Otherwise flip kernel.
    grad = F.conv2d(res, kernel.flip(2,3), padding=1)
    return grad
  1. 提示:如果 $\Phi$ 是随机高斯矩阵,可以用全连接层(Linear)。如果 $\Phi$ 是部分傅里叶采样,可以用 FFT 实现。关键是这一层不仅要有 weights,还要能接受外部输入的 mask。
  2. 提示:方案一:保留高斯去噪器,但在数据一致性步骤使用泊松似然项(KL 散度)进行更新(Prox_KL)。方案二:Anscombe 变换将泊松转为高斯,用 PnP 复原后再逆变换。

16.9 常见陷阱与错误 (Gotchas)

1. 维度灾难与边界效应

  • 问题:在实现数据一致性层(特别是 FFT 版)时,忘记处理图像边界。
  • 后果:复原图像边缘出现严重的振铃(Ringing)或卷绕伪影。
  • 对策:使用 pad_circular(如果用 FFT)或 pad_reflection。确保 $A$ 和 $A^T$ 的边界处理是伴随一致的。

2. PnP 的“伪收敛”

  • 现象:PnP 迭代曲线显示残差在下降,但图像质量(PSNR)在中间某一步达到峰值后开始下降。
  • 原因:过度迭代。去噪器虽然去掉了噪声,但也逐渐抹平了纹理。PnP 的收敛点(Fixed Point)不一定是视觉质量最好的点。
  • 对策:实施 Early Stopping(早停),或者精心调节 $\sigma$ 的衰减策略(Exponential Decay)。

3. Unrolling 的训练显存爆炸

  • 问题:展开 20 层 ADMM,每层包含一个 UNet,反向传播时需要存储所有层的梯度,显存瞬间爆炸。
  • 对策
    1. 使用 Gradient Checkpointing(以时间换空间)。
    2. 减少每层子网络的深度(例如每层只用 3 个 ResBlock,而不是完整的 UNet)。
    3. 权重共享:所有迭代阶段共用同一个 CNN 的权重(性能略降,但参数量大减)。

4. 忽略了输入数据的归一化

  • 问题:直接把 [0, 255] 的图像塞进 PnP。
  • 后果:如果去噪器是针对 [0, 1] 训练的,输出全是白屏或黑屏。
  • 对策:严格检查预训练模型的输入规范。$\sigma$ 值也要相应缩放(如 $\sigma=25$ 对应 [0, 255],在 [0, 1] 下应为 $25/255$)。