附录 C:实验与代码模板 (Experiments & Code Templates)
1. 开篇段落
在之前的章节中,我们推导了诸如 ADMM、PDHG 和 Split Bregman 等算法的数学形式。公式是优雅且紧凑的,例如 $x_{k+1} = \text{prox}_{\gamma g}(x_k - \gamma \nabla f(x_k))$。然而,在实际编程中,这一行公式背后隐藏着巨大的工程挑战:如何高效处理百万级像素的矩阵向量乘法?如何管理数十个超参数?如何保证每一次实验都是可复现的?
本章不再讨论数学证明,而是转向软件工程。我们将建立一套面向对象的、模块化的代码架构,用于解决大规模逆问题。无论你使用 Python (NumPy/PyTorch) 还是 MATLAB,本章提供的设计模式(Design Patterns)都能帮助你摆脱“脚本式编程”的混乱,构建出清晰、可扩展且易于调试的算法库。我们将重点讨论无矩阵算子设计、解耦求解器架构以及科学实验的配置管理。
2. 文字论述
2.1 核心基石:线性算子抽象 (Linear Operator Abstraction)
在图像复原与重建中,观测模型通常写作 $y = Ax + n$。虽然 $A$ 在数学上是一个矩阵,但在工程实现中,$A$ 永远不应该以显式矩阵的形式存在。 例如,对于一张 $1024 \times 1024$ 的图像,即使是简单的恒等变换,$A$ 也是一个 $10^6 \times 10^6$ 的单位阵,直接存储需要数 TB 的内存。
因此,我们采用无矩阵 (Matrix-free) 编程范式。我们将 $A$ 视为一个黑盒对象,只暴露两个接口:
- Forward (正向映射): $u \to A(u)$。对应物理成像过程(如模糊、降采样)。
- Adjoint (伴随/转置映射): $v \to A^*(v)$。对应梯度的反向传播方向(如反投影、零填充)。
关键原则:任何涉及 $A^T A$ 或 $A^{-1}$ 的操作,都必须分解为对 $A$ 和 $A^*$ 的一系列调用,或者使用专门的迭代解法(如共轭梯度法 CG)。
2.1.1 算子类的标准接口设计
一个健壮的算子库应该像乐高积木一样可组合。
ASCII 示意图:算子类的继承结构
[ BaseLinearOperator ]
+ shape: (out_dim, in_dim)
+ forward(x) -> raise NotImplemented
+ adjoint(y) -> raise NotImplemented
+ dot_test() -> Boolean (自动自检)
^
|
+-------+-------+------------------+
| | |
[ Identity ] [ Scaling ] [ Composite ] (组合算子)
( y = x ) ( y = s*x ) ( A = B * C )
( A.fwd = B.fwd(C.fwd(x)) )
( A.adj = C.adj(B.adj(y)) ) 注意顺序!
2.1.2 为什么不用自动微分 (AutoDiff)?
现代深度学习框架(PyTorch/TensorFlow)提供了 autograd,可以自动计算 $A^T v$。为什么我们还需要手动编写 adjoint?
- 内存效率:自动微分通常需要存储前向传播的中间图,内存消耗大。手动伴随通常是“原地”操作 (in-place)。
- 复数与 FFT 支持:许多经典优化(如卷积稀疏编码)在频域求解,手动处理复数共轭转置比自动微分更可控。
- 精确性:某些算子(如循环移位)的自动微分可能不如手动实现的 FFT 版本高效。
2.1.3 唯一的真理:点积测试 (The Dot-Product Test)
在调试优化算法时,90% 的错误源于 $A^*$ 写错了(例如卷积核忘了翻转、边缘填充弄反了)。点积测试是验证算子正确性的数学黄金标准。
对于任何线性算子 $A: \mathcal{X} \to \mathcal{Y}$,其伴随算子 $A^*: \mathcal{Y} \to \mathcal{X}$ 必须满足: $$ \langle A(x), y \rangle_{\mathcal{Y}} = \langle x, A^*(y) \rangle_{\mathcal{X}}, \quad \forall x \in \mathcal{X}, y \in \mathcal{Y} $$
代码逻辑模板:
def dot_product_test(operator, tol=1e-5):
# 1. 生成随机向量(注意:如果是复数域,需生成复数高斯噪声)
x = random_tensor(operator.input_shape)
y = random_tensor(operator.output_shape)
# 2. 计算正向与伴随
Ax = operator.forward(x)
Aty = operator.adjoint(y)
# 3. 计算内积
# <Ax, y> = sum(Ax * y)
inner_1 = np.vdot(Ax, y)
# <x, Aty> = sum(x * Aty)
inner_2 = np.vdot(x, Aty)
# 4. 比较相对误差
error = abs(inner_1 - inner_2) / max(abs(inner_1), abs(inner_2))
if error < tol:
print(f"[PASS] Dot-test passed. Error: {error:.2e}")
return True
else:
print(f"[FAIL] Dot-test failed! Error: {error:.2e}")
print("Hint: Check padding, kernel flipping, or conjugation.")
return False
2.2 求解器架构:解耦与回调 (Solver Architecture)
不要编写名为 def image_denoising(...) 的函数。应该编写名为 def admm_solver(...) 的通用求解器。
2.2.1 依赖注入 (Dependency Injection)
求解器不应包含具体的物理模型知识。它只接受以下“零件”:
- LinearOperator:包含 $A$ 和 $A^*$。
- Proximal Operators:函数句柄,如
prox_f(x, tau)。 - Measurement:观测数据 $y$。
- Callback:每一步迭代后执行的函数(用于画图、保存日志)。
ASCII 流程图:ADMM 求解器内部数据流
初始化 x0, z0, u0
|
[ 迭代循环 k = 1 to MaxIter ]
|
+---> [子问题 1: x-update] (最小化 f(x) + Q(x))
| 调用: conjugate_gradient(A, b) 或 FFT_solver
|
+---> [子问题 2: z-update] (Prox 算子)
| 调用: z = prox_g(Ax + u, rho)
| (例如: 软阈值 soft_threshold)
|
+---> [子问题 3: u-update] (对偶更新)
| 计算: u = u + Ax - z
|
+---> [ 回调 Callback ]
| 记录: loss, primal_res, dual_res
| 操作: if k % 10 == 0: save_image(x)
|
[ 检查停止准则 ]
|
返回 x_final, stats
2.2.2 处理 Prox 算子的技巧
Prox 算子通常带有参数(如 $\lambda$)。在代码中,使用 Python 的 partial 或闭包(Closure)来固定这些参数,保持求解器接口整洁。
- Bad:
solver(..., lambda_reg=0.1, reg_type='l1') - Good:
prox = lambda v, s: soft_threshold(v, 0.1 * s)然后传入solver(..., prox_g=prox)
2.3 工程实践:配置与复现 (Configs & Reproducibility)
学术代码最怕“魔法数字”(Magic Numbers)散落在代码各处(例如:第 30 行 lambda = 0.05,第 80 行 iter = 200)。这使得实验无法追踪。
2.3.1 配置即代码 (Configuration as Code)
所有控制实验行为的参数必须抽离到独立的配置文件(YAML/JSON)中。
YAML 模板示例 (experiment_config.yaml):
meta:
project: "Video_Deblurring_Review"
experiment_id: "exp_001_ADMM_TV"
seed: 42 # 固定随机种子!
data:
dataset_path: "./data/GoPro/"
batch_size: 1 # 传统优化通常一次处理一张图/一个视频
preprocessing:
normalize: true # [0,1] float
gamma_correction: false
operator:
type: "convolution"
kernel_size: 25
kernel_type: "motion_blur_45deg"
boundary: "wrap" # 周期边界,对应 FFT
model:
algorithm: "ADMM"
hyperparams:
rho: 1.5
lambda_tv: 0.01
max_iter: 500
tol_abs: 1e-4
tol_rel: 1e-3
logging:
save_dir: "./results/"
print_freq: 10
save_image_freq: 50
2.3.2 实验记录的最佳实践
在代码运行开始时,执行以下操作:
- 创建带时间戳的目录:如
results/2023-10-27_exp_001/。 - 备份配置:将 YAML 文件复制到该目录。
- 备份代码状态:记录当前的 Git commit hash,甚至
diff信息。确保半年后你能找回生成这张图的确切代码版本。
3. 本章小结
- 抽象带来自由:通过定义
LinearOperator类,你的求解器可以无缝切换去噪($A=I$)、去模糊($A=H$)和压缩感知($A=M \cdot F$)任务。 - 点积测试是底线:不要信任你的直觉,只信任
<Ax, y> == <x, A*y>。这是开发新算子后的第一步,也是唯一一步验证。 - 求解器仅仅是调度器:优秀的求解器代码像是一个调度员,它只负责协调算子、Prox 和数据之间的流动,不包含具体算法逻辑。
- 配置决定成败:使用配置文件管理实验,严禁在循环内部硬编码参数。
4. 练习题
基础题
-
[代码实现] 基础算子库 实现以下算子的
forward和adjoint,并通过点积测试:- (a) Identity: 恒等算子。
- (b) Scaling: 标量乘法 $A(x) = \alpha x$。
- (c) Masking: 给定二值掩码 $M$,计算 $A(x) = M \odot x$。
- Hint: 掩码算子的伴随是什么?它是否是自伴随的(Self-adjoint)?
-
[调试] 卷积算子的陷阱 给定一个卷积核 $k$,算子定义为 $A(x) = k * x$(使用
scipy.signal.convolve2d,模式为'same')。 请写出其伴随算子 $A^*(y)$ 的具体实现。- Hint: 伴随卷积需要对核 $k$ 进行什么操作?边界处理
'same'对应的伴随还是'same'吗?(这是一个非常常见的陷阱,建议用 MATLABconv2或 Python 测试)。
- Hint: 伴随卷积需要对核 $k$ 进行什么操作?边界处理
-
[架构] 回调函数设计 设计一个回调函数
monitor_convergence(k, x, z, u)。要求:- 计算当前的 PSNR。
- 如果目标函数值在 10 次迭代内变化小于 $10^{-5}$,抛出
StopIteration异常以提前终止求解器。
挑战题
-
[算法扩展] 复合算子的伴随 在压缩感知 MRI 中,观测算子为 $A = M \cdot F \cdot S$,其中 $S$ 是线圈敏感度图(点乘),$F$ 是傅里叶变换,$M$ 是采样掩码。 请推导并实现 $A^*$ 的代码逻辑。
- Hint: 矩阵乘积的转置 $(ABC)^T = C^T B^T A^T$。注意复数域的共轭。
-
[系统设计] 自动超参调节器 (Auto-Tuner) 在 ADMM 中,参数 $\rho$ 严重影响收敛速度。实现一个“自适应 ADMM 求解器”外壳,根据原始残差 $r$ 和对偶残差 $s$ 的比值,每隔 10 次迭代自动调整 $\rho$。
- 规则:如果 $|r| > 10 |s|$,则 $\rho \leftarrow 2\rho$;如果 $|s| > 10 |r|$,则 $\rho \leftarrow \rho/2$。
- 思考:调整 $\rho$ 后,u 变量需要怎么调整以保持缩放一致性?
-
[性能优化] 算子缓存机制 对于某些算子(如 $A = \mathcal{F}$),如果不涉及 $A^T A + \rho I$ 的求逆,可能不需要缓存。但在涉及 $A^T A$ 求逆时(如去模糊的 x-update),通常需要预计算 $A$ 的特征值。 修改
LinearOperator类,增加eigenvalues()方法。如果是卷积算子,返回fft2(kernel);如果是恒等算子,返回ones。确保求解器只计算一次特征值。
点击查看练习题提示与答案要点
- Masking: 掩码算子是自伴随的,即 $A^*(y) = M \odot y$(假设 $M$ 是实数)。
- 卷积伴随: $A^*(y) = \text{flip}(k) * y$。如果使用
'same'填充,Python/MATLAB 的实现通常隐含了零填充。严格来说,valid卷积的伴随是full卷积,same的伴随近似于same但在边界处有细微差别(这也是为什么推荐在 FFT 域做周期卷积,因为 $A^*$ 仅仅是conj(K_hat))。 - Callback: 回调是观察优化过程的窗口。利用异常处理机制(
raise StopIteration)来控制流是一种优雅的 Python 模式。 - MRI 伴随: $A^* = S^* \cdot F^* \cdot M^*$。
- $M^*$:Mask 通常是实对角阵,伴随不变。
- $F^*$:逆傅里叶变换(注意归一化因子)。
- $S^*$:敏感度图的共轭点乘。
- 顺序:先 Mask,再 IFFT,再乘以敏感度共轭。
- 自适应 Rho: 参考 Boyd 的 ADMM 论文。调整 $\rho$ 时,为了保持对偶变量 $u$ 的物理意义(缩放后的拉格朗日乘子),通常 $u$ 不需要缩放,或者如果 $u$ 定义为 $y/\rho$,则需要缩放。标准实现中 $u$ 是缩放形式,更新 $\rho$ 时 $u$ 不变,但数值稳定性会有波动。
- 缓存: 这是“一次配置,多次运行”的关键。在 FFT 去模糊中,分母 $ |K|^2 + \rho $ 只需要计算一次。不要在循环里反复做
fft2(kernel)。
5. 常见陷阱与错误 (Gotchas)
5.1 频域计算的归一化灾难
- 现象:使用 FFT 实现的算子,经过
forward和adjoint后,数值变得极大(如放大 $N$ 倍)或极小。 - 原因:离散傅里叶变换 (DFT) 的定义在不同库中不同。
- NumPy/MATLAB:
fft无缩放,ifft乘 $1/N$。这意味着 $F$ 不是酉算子(Unitary)。
- NumPy/MATLAB:
- 对策:
- 方案 A (推荐):始终使用
norm='ortho'选项(如果在 Python 中)。这样 $F^T = F^{-1}$。 - 方案 B:手动定义算子 $A = \frac{1}{\sqrt{N}} \text{FFT}$。
- 检查:运行点积测试。如果误差是 $N$ 的倍数,就是归一化问题。
- 方案 A (推荐):始终使用
5.2 图像的内存布局 (HWC vs CHW)
- 现象:去噪结果看起来像重影或条纹,或者 RGB 通道颜色错乱。
- 原因:
- OpenCV/ImageIO 读取为
(Height, Width, Channel)。 - PyTorch/许多优化算法偏好
(Channel, Height, Width)或纯粹的(Height, Width)处理。 - 当你对一个 HWC 张量做
flatten再reshape时,如果没有转置,数据顺序就会乱。
- OpenCV/ImageIO 读取为
- 对策:在加载数据的第一步,统一转换为
(C, H, W)格式,并在保存时转回。在整个求解器内部保持格式一致。
5.3 浮点数精度截断
- 现象:算法在 PSNR 达到 30dB 后不再上升,或者梯度全为 0。
- 原因:输入图像是
uint8类型,取值 0-255。- 如果你做 $x - y$,结果可能下溢(变成 255 或 0)。
- 如果你做乘法,可能溢出。
- 对策:永远在入口处将图像转换为
float32或float64并归一化到 $[0, 1]$。
img = io.imread('path.png').astype(np.float32) / 255.0
5.4 伴随算子的边界陷阱
- 现象:恢复的图像中心很好,但在四周边缘有非常明亮或黑暗的方框/振铃。
- 原因:
- 正向算子 $A$ 使用了零填充(Zero-padding)。
- 伴随算子 $A^*$ 在边界处会累加零填充区域的值,导致边界像素值异常大。
- 或者,使用了 FFT(隐含周期边界),但图像是非周期的。
- 对策:
- 使用
Edge Tapering技术预处理图像。 - 在计算误差或 PSNR 时,裁掉边缘的 10-20 个像素(
crop_boundary)。 - 在数学模型中引入掩码算子 $M$,只在有效区域计算数据项:$|M(Ax - y)|^2$。
- 使用