opt_vision_tutorial

第16章 Plug-and-Play (PnP)/RED 与“算法展开”:优化与深度的结合

开篇段落

在前十五章中,我们构建了一个严谨的数学世界:图像复原问题被建模为“数据保真项 + 正则项”的变分问题。我们学习了如何设计 TV、稀疏、低秩等正则项来描述“什么是好的图像”。然而,当我们面对真实世界的复杂退化(如非高斯噪声、复杂的运动模糊、JPEG 压缩伪影)时,手工设计的正则项往往显得力不从心——它们太简单了,无法捕捉自然图像高维流形上的所有细节。

与此同时,深度学习(DL)以一种“暴力美学”的方式席卷了该领域。端到端(End-to-End)的神经网络(如 UNet, ResNet)通过海量数据训练,效果往往碾压传统方法。但 DL 也有其阿喀琉斯之踵:黑盒性质导致缺乏可解释性,且泛化能力差——一个针对高斯模糊训练的网络,面对运动模糊可能完全失效,需要重新训练。

本章将介绍连接这两个世界的“中间道路”:Plug-and-Play (PnP), Regularization by Denoising (RED) 以及 算法展开 (Deep Unrolling/Unfolding)。这些方法的核心哲学是:“让上帝的归上帝,凯撒的归凯撒”——用传统的优化算法框架(如 ADMM/HQS)处理物理退化模型(已知且可靠),用深度神经网络处理图像先验(复杂且难懂)。这不仅赋予了深度学习以数学结构,也让传统优化获得了数据驱动的强大表现力。


16.1 PnP 思想:从“手工正则”到“隐式去噪”

16.1.1 为什么是去噪器?(The Denoiser Intuition)

让我们重新审视近端算子(Proximal Operator)的定义。对于正则项 $\lambda R(x)$,其 Prox 定义为: \(\text{prox}_{\lambda R}(v) = \arg\min_x \underbrace{\frac{1}{2}\|x - v\|_2^2}_{\text{输入} v \text{的保真}} + \underbrace{\lambda R(x)}_{\text{先验约束}}\) 如果你仔细观察这个公式,这严格等价于一个贝叶斯去噪问题:假设观测信号 $v = x + n$,其中 $n$ 是高斯白噪声,方差与 $\lambda$ 相关,我们需要从 $v$ 中恢复 $x$。

核心洞察: 传统方法中,我们显式定义 $R(x) = |x|_{\text{TV}}$,然后费力推导其 Prox 算子。 而在 PnP 中,我们反其道而行之:既然 Prox 等价于去噪,为什么不直接用一个现成的、性能最强的去噪器(Denoiser)来替换这一步计算?

我们不再关心 $R(x)$ 的数学表达式是什么,我们只关心去噪器 $\mathcal{D}(\cdot)$ 能不能把图像“推”回到自然图像流形上。

16.1.2 PnP 的通用工作流

PnP 使得我们可以解决任意线性逆问题 $y = Ax + n$,只要我们有求解器(如 ADMM 或 HQS)和一个去噪器(如 DnCNN, DRUNet)。

             +---------------------------+
             |  物理模型 (退化算子 A)    |
             |  负责:数据一致性         |
             +------------+--------------+
                          |
             (把当前估计推向 y 的观测流形)
                          |
          +------- 迭代循环 (Loop) --------+
          |                                |
          v                                v
  [数据项子问题]                     [先验项子问题]
  Inversion Step                     Denoising Step
  x = (A'A + ρI)^-1 (...)            z = Denoiser(x)
          |                                |
          ^                                ^
          |                                |
          +-----------(变量耦合)-----------+

16.2 两种主流 PnP 框架:HQS 与 ADMM

虽然 ADMM 很强,但在 PnP 领域,半二次分裂(Half-Quadratic Splitting, HQS) 因其简洁性同样非常流行。

16.2.1 PnP-HQS:简单即是美

求解 $\min_x \frac{1}{2}|Ax - y|^2 + \lambda R(x)$。 引入辅助变量 $z$,构建目标函数: \(\min_{x,z} \frac{1}{2}\|Ax - y\|^2 + \lambda R(z) + \frac{\mu}{2}\|x - z\|^2\) 当 $\mu \to \infty$ 时,$x$ 逼近 $z$,问题回归原问题。HQS 采用交替最小化:

  1. 数据项更新 ($x$-step):这是一个最小二乘问题。 \(x_{k+1} = \arg\min_x \|Ax - y\|^2 + \mu \|x - z_k\|^2\) :$x_{k+1} = (A^T A + \mu I)^{-1} (A^T y + \mu z_k)$ 注:对于卷积或 FFT 可对角化的 $A$,这一步有闭式解且极快。

  2. 先验项更新 ($z$-step):这是一个去噪问题。 \(z_{k+1} = \arg\min_z \frac{\mu}{2}\|x_{k+1} - z\|^2 + \lambda R(z)\) 等价于: \(z_{k+1} = \arg\min_z \frac{1}{2}\|z - x_{k+1}\|^2 + \frac{\lambda}{\mu} R(z)\) PnP 替换:直接调用去噪器,输入噪声水平 $\sigma_k = \sqrt{\lambda / \mu}$。 \(z_{k+1} = \mathcal{D}_{\sigma_k}(x_{k+1})\)

16.2.2 PnP-ADMM:更严格的约束

相比 HQS,ADMM 引入了拉格朗日乘子 $u$,对约束的处理更稳健,尤其是在 $\mu$ 不需要趋于无穷大的情况下也能收敛。

16.2.3 关键技巧:参数 $\mu$ 与 $\sigma$ 的联动 (Rule-of-Thumb)

PnP 成功的秘诀在于退火策略


16.3 RED: Regularization by Denoising

PnP 虽然好用,但它是“算法层面的替换”,我们失去了原本优化的目标函数 $J(x)$。RED 试图找回这个目标函数。

16.3.1 显式正则项的构造

RED 提出了一个惊人的定义: \(R_{RED}(x) = \frac{1}{2} x^T (x - \mathcal{D}(x))\) 直觉解释

16.3.2 梯度即残差

RED 理论证明,在去噪器满足局部均匀性雅可比对称性(虽然现实网络不完全满足,但近似成立)的条件下: \(\nabla R_{RED}(x) = x - \mathcal{D}(x)\) 这一结论极大地简化了优化。我们不需要对复杂的 CNN 进行反向传播(Backprop)来求梯度,残差本身就是梯度方向

16.3.3 RED 算法实现

最简单的 RED 求解是最速下降法: \(x_{k+1} = x_k - \eta \underbrace{(A^T(Ax_k - y))}_{\text{物理梯度}} - \eta \lambda \underbrace{(x_k - \mathcal{D}(x_k))}_{\text{先验梯度}}\) 这赋予了去噪器一个非常清晰的物理意义:在每一步迭代中,把图像往“去噪后的方向”拉一把,同时保持在观测数据的一致性范围内。


16.4 算法展开 (Algorithm Unrolling / Deep Unfolding)

PnP 和 RED 是测试时(Test-time)的方法:去噪器是预先训练好的(比如用 BSD500 数据集训练的高斯去噪器),在复原过程中参数固定。 算法展开则是训练时(Train-time)的方法:我们将整个优化迭代过程看作一个网络,端到端训练。

16.4.1 从 Loop 到 Layer (网络架构设计)

假设我们用梯度下降法解 Lasso 问题(ISTA 算法),迭代 $K$ 次。 \(x_{k+1} = \text{soft}(x_k - \rho A^T(Ax_k - y), \theta)\) 我们将这 $K$ 步迭代展开为 $K$ 层神经网络:

  1. 数据一致性模块 (DC Module): 对应 $r_k = x_k - \rho A^T(Ax_k - y)$。这一步包含物理模型 $A$。在网络中,这通常实现为一个固定的计算层(不含可学习参数),或者如果 $A$ 未知,可以用卷积层模拟。

  2. 近端/去噪模块 (Prox Module): 对应 $\text{soft}(\cdot)$。在深度展开中,我们不再使用简单的软阈值,而是用一个小的 CNN(如 ResBlock)来替代,记为 $\mathcal{Net}_{\theta_k}(\cdot)$。 \(x_{k+1} = \mathcal{Net}_{\theta_k}(r_k)\)

  3. 可学习参数

    • 每一层的网络权重 $\theta_k$(甚至可以层间共享权重)。
    • 步长 $\rho_k$ 也可以设为可学习的参数。

16.4.2 典型案例:ADMM-Net 与 USRNet

16.4.3 为什么 Unrolling 是目前的 SOTA?

  1. 可解释性:相比于直接用 UNet 从 $y$ 映射到 $x$,Unrolling 网络的每一层都有明确的物理含义(去伪影 vs 找回细节)。
  2. 轻量化:因为物理模型 $A$ 已经处理了数据匹配部分,神经网络只需要学习“残差”或“先验”。这使得所需的参数量远小于纯黑盒网络。
  3. 小样本泛化:由于强制了数据一致性(Data Consistency),Unrolling 网络在训练数据较少时也不容易过拟合。

16.5 核心模块详解:数据一致性层 (Data Consistency Layer)

在实现 Unrolling 或 PnP 时,很多初学者会忽略或写错数据一致性层。这是物理模型发挥作用的地方。

对于二次保真项 $\frac{1}{2}|Ax-y|^2$,其 Prox 算子(即 DC 层)通常有以下几种实现方式:

  1. 梯度下降式 (Gradient Block): 最简单,不需要求逆。 x_new = x - step_size * A_adjoint(A(x) - y) 优点:通用,适用于任何 $A$。 缺点:收敛慢,需要更多层。

  2. 闭式解式 (Closed-form Block): 适用于去模糊、超分(循环边界条件)。利用 FFT。 \(x_{new} = \mathcal{F}^{-1} \left( \frac{\mathcal{F}(z) + \rho \mathcal{F}(A^T y)}{\rho + |\mathcal{F}(A)|^2} \right)\) 优点:一步到位,精度高。 缺点:只适用于 $A$ 是卷积的情况。

  3. 共轭梯度式 (CG Block): 当 $A$ 很大且不可对角化(如大型 MRI 非均匀采样),网络层内部可以内嵌 3-5 步共轭梯度下降来近似求解逆问题。


16.6 收敛性与谱归一化 (Spectral Normalization)

把深度网络插入优化算法,收敛性是数学家最担心的问题。如果去噪器 $\mathcal{D}$ 是一个“扩张”算子(即输出的差异比输入的差异还大),迭代可能会发散。

Lipschitz 约束: 为了保证不动点迭代收敛,我们需要去噪器是非扩张的 (Non-expansive),即 Lipschitz 常数 $L \le 1$。 \(\|\mathcal{D}(x_1) - \mathcal{D}(x_2)\| \le \|x_1 - x_2\|\)

工程对策

  1. RealSN (Real Spectral Normalization):在训练去噪网络时,对每一层卷积核 $W$ 进行谱归一化(除以其最大奇异值),强制每一层的 $L \le 1$。
  2. 残差学习技巧:训练 $\mathcal{D}(x) = x + \mathcal{R}(x)$,其中 $\mathcal{R}$ 是残差网络。
  3. 实用主义:在实际 PnP 应用中,即使 $L > 1$,只要 $\rho$ 足够大(步长足够小),算法往往也能收敛。但在高风险领域(医疗),必须使用 RealSN 保证稳定性。

16.7 本章小结

  1. PnP (Plug-and-Play) 是一种模块化思想,将优化算法中的“先验子问题”替换为“深度去噪器”。它不需要训练,利用现有去噪器即可解决去模糊、超分等多种问题。
  2. RED 通过正则化残差能量 $x^T(x-\mathcal{D}(x))$,提供了一种显式的变分解释,且其梯度计算极其简单。
  3. 算法展开 (Unrolling) 将迭代步骤映射为网络层,通过端到端训练学习参数。它是目前图像复原领域性能与可解释性平衡最好的方法。
  4. 数据一致性 (DC) 是连接物理与AI的纽带。无论用什么网络,永远不要丢掉 $A^T(Ax-y)$ 这一项。

16.8 练习题

基础题

  1. PnP 参数调试:在使用 PnP-ADMM 进行去模糊时,如果你发现恢复的图像非常锐利但充满了类似噪声的杂点,你应该增大还是减小惩罚参数 $\rho$?(提示:$\rho$ 越大,去噪强度相对越小)。
  2. RED 梯度:证明若 $\mathcal{D}(x)$ 是线性算子 $\mathcal{D}(x) = Wx$ 且 $W$ 对称,则 $R(x) = \frac{1}{2}x^T(x - Wx)$ 的梯度确实是 $x - \mathcal{D}(x)$。
  3. DC 层实现:在 PyTorch 中,如果 $A$ 是一个 $3\times3$ 的均值模糊核,请写出计算数据项梯度 $A^T(Ax - y)$ 的伪代码(使用 F.conv2d)。

挑战题

  1. 架构设计:针对 压缩感知 (CS) 问题,设计一个展开网络。
    • 输入:测量值 $y$ 和测量矩阵 $\Phi$。
    • 结构:使用 ISTA 展开。
    • 难点:$\Phi$ 通常是大矩阵,不能写成卷积。如何在网络层中高效实现 $\Phi x$ 和 $\Phi^T r$?(提示:全连接层 vs 自定义矩阵乘法层)。
  2. 开放思考:PnP 方法通常假设噪声是高斯的(因为 Prox 等价于高斯去噪)。如果实际观测噪声是 泊松噪声 (Poisson Noise),你应该如何修改 PnP 框架?
    • 提示 A:修改数据项子问题。
    • 提示 B:使用 Variance Stabilizing Transform (VST) 配合高斯去噪器。
点击查看练习题提示 1. **答案**:应该**减小** $\rho$。$\rho$ 越小,等效噪声标准差 $\sigma = \sqrt{\lambda/\rho}$ 越大,去噪力度越强。或者直接增大输入的 $\sigma$ 参数。 2. **提示**:$\nabla (\frac{1}{2}x^T x - \frac{1}{2}x^T W x) = x - \frac{1}{2}(W + W^T)x$。如果 $W$ 对称,则为 $x - Wx$。 3. **提示**:需要注意 padding 和 `groups` 参数。$A^T$ 对应的是翻转核的卷积(即 `conv_transpose2d` 或者手动翻转核后 `conv2d`)。 ```python # Pseudo-code def data_grad(x, y, kernel): # Ax Ax = F.conv2d(x, kernel, padding=1) # Residual res = Ax - y # A^T * res # For symmetric kernel, A^T = A. Otherwise flip kernel. grad = F.conv2d(res, kernel.flip(2,3), padding=1) return grad ``` 4. **提示**:如果 $\Phi$ 是随机高斯矩阵,可以用全连接层(Linear)。如果 $\Phi$ 是部分傅里叶采样,可以用 FFT 实现。关键是这一层不仅要有 weights,还要能接受外部输入的 mask。 5. **提示**:方案一:保留高斯去噪器,但在数据一致性步骤使用泊松似然项(KL 散度)进行更新(Prox_KL)。方案二:Anscombe 变换将泊松转为高斯,用 PnP 复原后再逆变换。

16.9 常见陷阱与错误 (Gotchas)

1. 维度灾难与边界效应

2. PnP 的“伪收敛”

3. Unrolling 的训练显存爆炸

4. 忽略了输入数据的归一化