第 8 章 数值实验与实践指南(chapter8.md

8.0 开篇段落

神经切线核(NTK)理论提供了理解深度网络在无限宽度和微小学习率下的行为的数学框架。本章旨在将这些理论工具转化为可操作的实践指南。我们将详细探讨如何在有限宽度网络上精确计算经验 NTK($\hat{\Theta}$),如何利用其特征谱来预测训练收敛、泛化性能,以及如何设计和复现关键的 NTK 驱动实验,如双下降和长度外推分析。通过本章的学习,读者将掌握在实践中使用 NTK 视角进行模型诊断和结构优化的方法。


8.1 经验 NTK 的高效计算方法

经验 NTK 矩阵 $\hat{\Theta}$ 是理论应用于实践的桥梁。对于 $N$ 个数据点,其大小为 $N \times N$,且在训练开始时(参数 $\theta_0$)计算。

8.1.1 经验 NTK 的定义与计算瓶颈

给定数据集 $X = \{x_1, \ldots, x_N\}$,经验 NTK 矩阵 $\hat{\Theta} \in \mathbb{R}^{N \times N}$ 定为:

$$ \hat{\Theta}_{ij} = \nabla_{\theta} f_{\theta_0}(x_i)^\top \nabla_{\theta} f_{\theta_0}(x_j) $$

定义 $J \in \mathbb{R}^{N \times P}$ 为 Jacobian 矩阵,其中 $J_{i, p} = \frac{\partial f_{\theta_0}(x_i)}{\partial \theta_p}$。则 $\hat{\Theta} = J J^\top$。

计算瓶颈: 在深度学习中,参数数量 $P$ 远大于数据点 $N$。

  1. 存储瓶颈: 存储完整的 Jacobian 矩阵 $J$ 需要 $O(N \cdot P)$ 的内存,对于 $N=10^5, P=10^8$ 的情况,这是无法承受的。
  2. 计算瓶颈: 即使能存储 $J$,计算 $J J^\top$ 仍需要 $O(N^2 P)$ 的乘加运算,效率极低。

8.1.2 基于自动微分的 R-Operator 技巧(GVP/VGP)

为了克服 $P$ 的依赖,我们利用现代自动微分 (AD) 框架中的高效原语:梯度-向量积 (GVP) 和向量-梯度积 (VGP)。

  1. VGP (Vector-Gradient Product) / R-Operator for $J^\top v$: 对于任意向量 $v \in \mathbb{R}^N$,计算 $u = J^\top v \in \mathbb{R}^P$。这可以通过计算加损失 $L_{weighted} = \sum_{i=1}^N v_i f_{\theta}(x_i)$ 对参数 $\theta$ 的梯度来实现: $$ u = \nabla_{\theta} L_{weighted} = \sum_{i=1}^N v_i \nabla_{\theta} f_{\theta}(x_i) $$ 这仅需一次后向传播(Backpropagation),计算复杂度为 $O(P_{comp})$,与 $P$ 无关。

  2. GVP (Gradient-Vector Product) / $J u$: 计算 $J u \in \mathbb{R}^N$。这本质上是 $\nabla_\theta f(X)$ 在方向 $u$ 上的方向导数。这可以通过 R-operator 或 JVP (Jacobian-Vector Product) 操作高效实现,计算复杂度也为 $O(P_{comp})$。

  3. 计算 $\hat{\Theta} v$: NTK 矩阵与向量 $v$ 的乘积 $\hat{\Theta} v = J (J^\top v)$ 可以分解为两个高效步骤: $$ \text{Step 1 (VGP): } u = J^\top v \\ \text{Step 2 (GVP): } \hat{\Theta} v = J u $$ 整个操作的复杂度是 $O(P_{comp})$。这使得我们可以利用迭代方法(如 Lanczos 或 Arnoldi 算法)来估计 $\hat{\Theta}$ 的特征值和特征向量,而无需显式构造矩阵。

8.1.3 数值稳定与精度

NTK 计算对数值精度高度敏感。

  1. FP64 的必要性: 当计算深层网络的 NTK 时,中间梯度(尤其是权重梯度)可能非常小,相乘后再求和很容易导致数值下溢或累积误差。使用双精度浮点数(FP64)是计算经验 NTK 的标准实践,尤其在 JAX 等功能性框架中。
  2. NTK Parameterization 的重要性: 正确的初始化缩放(确保 $\hat{\Theta}$ 的元素是 $O(1)$)对于数值稳定至关重要。如果 $\hat{\Theta}$ 的范数过大或过小,会导致特征值计算不稳定。

Rule-of-Thumb 8.1 (NTK 计算的存储与精度):

  1. 如果 $N < 5000$,且计算资源允许,可直接构造并存储 $\hat{\Theta}$ (使用 FP64)。
  2. 如果 $N > 10^5$,必须使用 $O(P_{comp})$ 的 R-Operator 技巧,并依赖迭代算法来分析谱。
  3. 始终使用 NTK Parameterization 和 FP64 精度。

8.2 使用 NTK 预测训练动态

NTK 理论的核心价值在于将复杂的非线性训练态简化为线性的核回归过程,从而实现可预测性。

8.2.1 NTK 谱与学习模式分解

如前所述,训练残差 $r(t) = f(t) - y$ 沿特征基 $u_k$ 衰减: $$ r(t) = \sum_{k=1}^N \alpha_k e^{-\lambda_k t} u_k $$

  • 高频/低频模式: 具有大特征值 $\lambda_k$ 的特征向量 $u_k$ 对应于网络最容易学习的数据模式(通常是数据中的低频或简单结构)。它们衰减快。具有小特征值 $\lambda_k$ 的模式对应于最难学习的高频或复杂结构。
  • 训练速度: 整体训练速度由最大的特征值 $\lambda_{\max}$(决定初始下降)和最小的非零特征值 $\lambda_{\min}$(决定收敛尾部)共同控制。
  • 泛化与谱: 经验发现,如果 NTK 谱衰减得快(即大特征值占主导),模型倾向于学习简单模式,泛化性能通常更好,这对应于一种“频谱偏置”(Spectral Bias)。

8.2.2 有效秩与模型容量

在 NTK 视角下,网络的有效容量由 NTK 矩阵的有效秩(Effective Rank)决定。有效秩 $R_{eff}$ 可以通过 $\sum_{k=1}^N \lambda_k / \lambda_{\max}$ 来近似。

  • 当 $W \to \infty$,$\text{Rank}(\hat{\Theta})$ 趋于 $N$。
  • 如果 $\hat{\Theta}$ 的秩很低,即使 $W$ 很大,网络也只能拟合低维子空间中的函数,限制了其表达能力。

8.2.3 线性化近似的验证与偏差分析

为了评估 NTK 理论的适用性,我们需要量化实际训练轨迹 $f_{\theta_t}$ 与 NTK 线性预测 $\hat{f}_t$ 之间的偏差。

  1. 预测函数轨迹: $$ \hat{f}_t = f_{\theta_0} - (\text{I} - e^{-\hat{\Theta} t}) (y - f_{\theta_0}) $$

2. 测量偏差: 在训练过程中定期采样 $\theta_t$,计算实际输出 $f_{\theta_t}$,并计算相对偏差: $$ \text{Deviation}(t) = \frac{| f_{\theta_t} - \hat{f}_t |_2}{| f_{\theta_t} |_2} $$

  • 当 $\text{Deviation}(t)$ 保持很小时(例如 $< 5\%$),说明 NTK 线性化近似有效。
  • 当偏差开始显著增长时表明网络的梯度特征 $\nabla_\theta f(x)$ 已经开始随训练发生变化,即开始进入特征学习阶段。

8.3 开源工具与代码实践

8.3.1 neural-tangents (JAX/Python)

neural-tangents (nt) 是目前最成熟的、用于计算无限宽度解析核的库。

核心优势:

  1. 解析计算: 避免了 $O(N^2 P)$ 的复杂性,直接计算 $W \to \infty$ 时的确定性核。
  2. 递归结构支持: 支持深层 MLP、CNN、残差网络和部分 Transformer 模块的精确解析核。
  3. 计算速度快: 对于 $N$ 适中的情况(例如 $N < 10^4$),其速度远超经验计算。

使用场景: 用于快速探索不同深度和激活函数对 NTK 谱和泛化性能的影响。

8.3.2 基于 PyTorch/TensorFlow 的经验 NTK 流程(概念示例)

当解析核不可用(如自定义层、复杂归一化)时,我们需要计算经验 NTK。

| 步骤 | 目标 | PyTorch/TensorFlow 概念实现 |

步骤 目标 PyTorch/TensorFlow 概念实现
Step 1 初始化与缩放 使用 NTK parameterization 初始化网络 $\theta_0$。
Step 2 准备 R-Operator 定义一个辅助函数 compute_grad_vec_prod(outputs, vec, params)
Step 3 循环计算内积 遍历 $i=1 \dots N$,计算 $J_i$: J_i = autograd.grad(outputs[i], params)
Step 4 构造 $\hat{\Theta}$ $\hat{\Theta}_{ij} = J_i^\top J_j$。
优化 (R-Operator) 避免循环 使用 torch.autograd.functional.vjpjvp 组合,直接计算 $\hat{\Theta} v$。

8.3.3 经验计算的陷阱:内存管理

计算经验 NTK 时,最常见的错误是尝试在单个设备上存储所有 $N$ 个梯度向量。

正确做法: 即使 $N$ 很大,也应将计算分解:

  1. 计算 $J_i$ 时,只存储 $J_i$,并立即计算与所有 $J_j$ 的内积。
  2. 也可以将数据 $X$ 分割成小批次,并行计算并组装 $\hat{\Theta}$。

8.4 复现实验:双下降与长度外推

NTK 理论提供了对现代深度学习现象的精确定量分析能力。

8.4.1 复现双下降现象的定量设计

双下降的核心是泛化性能与模型容量(或 NTK 矩阵的性质)之间的关系。

实验设置:

  1. 数据与噪声: 使用低维数据(如 MNIST 降维或合成回归),固定训练集 $N=1000$。加入适度的标签噪声。
  2. 模型容量 ($C$): 通过网络宽度 $W$ 调节模型容量 $C \propto W P$.
  3. 观察指标: 训练误差(必须降到零)和测试误差(泛化误差)。
  4. NTK 分析:
    • 临界点识别: 临界插值阈值 $W_{interp}$ 是 $W$ 使得 $\hat{\Theta}$ 矩阵刚好满秩(或其前 $N$ 个特征值显著大于零)的点。
    • 现象: 在 $W \approx W_{interp}$ 时,网络必须拟合噪声,且由于 $\hat{\Theta}$ 的条件数达到峰值,最小范数解被噪声严重扰动,导致测试误差最大化(第一个下降结束)。
    • 过参数化区: 当 $W \gg W_{interp}$ 时,$\hat{\Theta}$ 的特征值谱变得更稳定,最小范数解倾向于平滑函数,泛化能力恢复(第二次下降)。

8.4.2 序列模型长度外推的 NTK 几何分析

在序列建模中,外推失败通常是由于核函数 $\Theta(x, x')$ 依赖于序列的绝对位置,导致训练集和外推集上的核结构不匹配。

核结构观察: 我们关注 $\hat{\Theta}$ 矩阵中的块结构。对于序列 $X = \{x^{(1)}, \ldots, x^{(L)}\}$,其中每个 $x^{(l)}$ 是一个 token,NTK 矩阵的块 $\hat{\Theta}_{l, l'} = \nabla_\theta f(x^{(l)})^\top \nabla_\theta f(x^{(l')})$ 衡量了在不同位置上的梯度相关性。

失败的 NTK 几何(绝对位置编码): 若使用绝对位置编码,每个位置的梯度特征 $J_l$ 都会不同。在训练长度 $L_{train}$ 内,核矩阵 $K_{train}$ 结构复杂,但可学习。当测试长度 $L_{test} > L_{train}$ 时:

$$\Theta \text{ (Length } L_{train} \text{)}: \begin{pmatrix} \Theta_{1,1} & \cdots & \Theta_{1, L} \\ \vdots & \ddots & \vdots \\ \Theta_{L, 1} & \cdots & \Theta_{L, L} \end{pmatrix}$$ 若 $L_{test}$ 位置上的 $\Theta_{L_{test}, L_{test}}$ 明显偏离 $L_{train}$ 边界处的结构,则网络无法在外推长度上维持训练得到的函数关系,导致性能崩溃。

Rule-of-Thumb 8.3 (核的平移不变性): 要实现成功的长度外推,序列模型的 NTK 必须表现出近似平移不变性局部自相似性。例如,$\hat{\Theta}_{l, l'}$ 应该主要依赖于相对位置 $|l - l'|$,而不是绝对位置 $l, l'$。这正是 RoPE、ALiBi 等改进位置编码的设计目标。


8.5 实践建议与常见坑

8.5.1 将 NTK 作为“调参指南”的方式

NTK 理论提供了一种快速原型设计的工具,用于指导我们在不进行全量训练的情况下做出结构选择。

  1. 激活函数选择: 比较不同激活函数(如 ReLU vs. GELU)计算出的解析 NTK 谱。谱衰减更快的核通常暗示更好的泛化潜力(偏好单函数)。
  2. 结构诊断: 对于新的网络层(如自定义注意力),检查其是否显著改变了 NTK 的条件数。如果条件数过大,暗示训练难度高。
  3. 初始化优化: 实验不同的初始化缩放。最佳的初始化应该能产生 $O(1)$ 范数且条件数适中的 NTK 矩阵。

8.5.2 偏离 NTK 假设的训练策略

如果目标是实现特征学习(而非核回归),我们必须故意打破 NTK 假设:

  1. 使用大步长学习率: 梯度步长 $\eta$ 越大,参数偏离 $\theta_0$ 越远,线性化失效越快,特征学习效应越强。
  2. 采用 $\mu$-Parameterization ($\mu$P): 这种参数化方法通过对某些层(如输出层)的缩放,使得网络在宽度 $W \to \infty$ 时,其梯度特征的范数保持 $O(W)$,从而在训练中持续发生特征学习,而不是收敛到固定核。
  3. 使用非线性优化器: Adam 等优化器通过局部缩放梯度,有效地改变了函数空间中的度量,使得训练路径不再是 NTK 梯度流。

8.5.3 隐式正则化与最小范数解

在过参数化区域,NTK 理论预测的解是训练数据上的最小 RKHS 范数插值解: $$ \min_{f \in \mathcal{H}_\Theta} |f|_{\mathcal{H}_\Theta}^2 \quad \text{s.t.} \quad f(x_i) = y_i $$ 这被称为隐式正则化。在实践中:

  • 如果你的目标是找到最平滑(最小范数)的插值函数,应确保训练处于 NTK regime(小学习率)。
  • 如果你观察到泛化性能很好,但网络宽度有限,很可能是 NTK 理论的隐式正则化效应在起作用。

本章小结

| 概念 | 实践作用 | 关键实现技巧 |

概念 实践作用 关键实现技巧
经验 NTK ($\hat{\Theta}$) 衡量网络在初始化时对数据点 $(x_i, x_j)$ 的梯度相关性。 必须使用 R-Operator/GVP/VGP 技巧避免 $O(N^2 P)$ 计算瓶颈。
NTK 谱分析 预测训练收敛速度 ($\lambda_{\min}, \lambda_{\max}$),诊断学习模式。 快速衰减的谱利于泛化;高条件数暗示训练不稳定。
双下降实验 定量观察模型容量与泛化误差的关系。 通过调节 $W$ 接近 $N$,观察 $\hat{\Theta}$ 条件数的峰值。
长度外推 分析序列模型在外推长度上的核函数结构。 检查 NTK 矩阵是否保持局部平移不变性。
NTK Parameterization 保证 $\hat{\Theta}$ 的范数与 $W$ 无关,确保线性化近似有效。 在实践中必须严格遵守,通常要求权重缩放为 $O(1/W)$ 或 $O(1/\sqrt{n_{in}})$。

练习题

基础题

E8.1 NTK 矩阵的计算选择 你正在处理一个图像分类任务,数据集 $N=5000$,网络 $P=2000$ 万参数。你是否应该显式计算并存储 $N \times N$ 的 $\hat{\Theta}$ 矩阵?请简要说明理由。

Hint: 计算 $N^2 P$ 和 $N^2$ 的存储量。

E8.2 梯度流模拟 假设你计算了 $\hat{\Theta}$ 并得到了特征值 $\Lambda=\{10, 1, 0.1, 0.01\}$,对应的初始残差投影 $\alpha=\{5, 3, 2, 1\}$。请写出 $t=10$ 时,残差 $|r(10)|^2$ 的大致表达式。哪个模式在此时对残差贡献最大?

Hint: $r(t) = \sum \alpha_k e^{-\lambda_k t} u_k$,且 $|r|^2 = \sum \alpha_k^2 e^{-2 \lambda_k t}$。

E8.3 NTK Parameterization 检查 在一个两层 ReLU MLP 中,如果所有权重 $W_i$ 采用标准 Kaiming 初始化(方差为 $1/n_{in}$),且偏置为零。这个网络是否满足 NTK parameterization 的要求?

Hint: 检查 $f(x)$ 和 $\nabla_\theta f(x)$ 的方差随 $W$ 的缩放。

E8.4 解析核的局限 描述在什么情况下,即使 neural-tangents 库能够计算网络的解析 NTK,你仍然需要计算经验 NTK?

Hint: 关注无限宽假设和随机性。

挑战题

C8.5 Jacobian 矩阵的秩(开放性) 在双下降实验中,临界插值阈值 $W_{interp}$ 是当网络宽度 $W$ 使得 Jacobian 矩阵 $J \in \mathbb{R}^{N \times P}$ 满足何种条件时发生的?从线性代数的角度解,为何 $W_{interp}$ 附近的泛化最差?

Hint: 关注 $J$ 的列空间和行空间,以及 $N$ 与 $P$ 的关系。

C8.6 $\mu$P 带来的理论挑战 $\mu$-Parameterization 旨在确保网络的输出方差 $\text{Var}(f_\theta(x))$ 和梯度范数 $| \nabla_\theta f(x) |$ 在 $W \to \infty$ 时都保持 $O(W)$。为什么这种缩放会导致网络的行为显著偏离 NTK 理论的预测?

Hint: 回忆 NTK 理论要求 $\hat{\Theta}$ 在 $W \to \infty$ 时保持不变。

C8.7 诊断长度外推失败的核 对于一个训练长度 $L_{train}=10$ 的序列模型,其 NTK 矩阵 $\hat{\Theta}$ 表现出强烈的块对角结构。如果在测试 $L_{test}=20$ 时失败,请定性描述 $\hat{\Theta}$ 矩阵的 $10 \times 10$ 块和 $20 \times 20$ 块之间可能存在的差异。

Hint: $L_{test} > L_{train}$ 时的 $\hat{\Theta}$ 块如何反映模型泛化能力不足?

C8.8 基于 NTK 谱的正则化策略 如果你发现 $\hat{\Theta}$ 的最小特征值 $\lambda_{\min}$ 极小,这表明对应的函数模式难以学习。如果你不能改变网络结构,基于 NTK 理论,提出一种隐式或显式正则化策略,以加速训练和提高泛化。

Hint: Tikhonov 正则化(岭回归)与特征值缩放的关系。


答案

E8.1 NTK 矩阵的计算选择 应该显式计算并存储 $\hat{\Theta}$ 矩阵。

  • $N^2$ 存储:$5000^2 = 25 \text{M}$ 个元素,使用 FP64 大约 $200$ MB,内存消耗可控。
  • $N^2 P$ 计算:$25 \times 10^6 \times 20 \times 10^6 = 5 \times 10^{14}$ FLOPs。这是巨大的。 然而,更优的实践是: 存储 $\hat{\Theta}$,但利用 R-Operator 技巧 (GVP/VGP) 来计算 $J_i$ 和 $J_j$ 的内积。由于 $N$ 适中,我们可以存储 $N$ 个梯度,然后计算内积。或者,如果使用 neural-tangents 且结构简单,则可直接使用解析核。

E8.2 梯度流模拟 残差范数平方:$|r(10)|^2 \approx \sum_{k=1}^4 \alpha_k^2 e^{-2 \lambda_k t}$。

  • $k=1 (\lambda=10): 5^2 e^{-20} \approx 0$
  • $k=2 (\lambda=1): 3^2 e^{-2} \approx 9 \times 0.135 \approx 1.22$
  • $k=3 (\lambda=0.1): 2^2 e^{-0.2} \approx 4 \times 0.82 \approx 3.28$
  • $k=4 (\lambda=0.01): 1^2 e^{-0.02} \approx 1 \times 0.98 \approx 0.98$ 贡献最大模式: $k=3$ ($\lambda=0.1$),其残差为 3.28。这是因为这个模式的特征值相对较小,衰减最慢。

E8.3 NTK Parameterization 检查 该网络不满足标准的 NTK parameterization 要求。

  • 标准 Kaiming 确保 $f(x)$ 的方差在 $W \to \infty$ 时是 $O(1)$(好)。
  • 但是,它会导致 $\nabla_\theta f(x)$ 的范数是 $O(W)$。
  • 因此,$\hat{\Theta}_{ij} = \nabla_\theta f(x_i)^\top \nabla_\theta f(x_j)$ 将是 $O(W)$。
  • 为了满足 NTK parameterization,权重 $W_i$ 应该被缩放为 $O(1/W)$ 或 $O(1/\sqrt{W n_{in}})$,以保证 $\hat{\Theta}$ 矩阵的范数是 $O(1)$。

E8.4 解析核的局限 当网络的宽度 $W$ 相对较小(例如 $W < 1000$时,经验 NTK $\hat{\Theta}$ 仍包含显著的随机波动,且会随着训练时间略微变化。解析 NTK 仅是 $W \to \infty$ 时的确定性极限。如果实验重点是分析有限宽度网络(具有内在随机性)的行为,或者网络结构包含随机元素(如 Dropout),则需要计算经验 NTK。

C8.5 Jacobian 矩阵的秩(开放性) $W_{interp}$ 发生时,Jacobian 矩阵 $J \in \mathbb{R}^{N \times P}$ 的行秩 $\text{rank}(J)$ 约为 $N$。

  • 当 $\text{rank}(J) < N$,网络无法完美拟合所有数据(欠参数化)。
  • 当 $\text{rank}(J) \approx N$,网络刚好处在插值边界。此时 $J$ 几乎满秩,但 $J J^\top = \hat{\Theta}$ 接近奇异,导致条件数爆炸。最小二乘解 $\hat{w} = (J^\top J)^{-1} J^\top y$ 对噪声高度敏感,泛化最差。

C8.6 $\mu$P 带来的理论挑战 NTK 理论要求网络的函数空间变化是线性的,即 $f_{\theta_t} \approx f_{\theta_0} + \nabla_\theta f_{\theta_0}^\top (\theta_t - \theta_0)$。这要求训练过程中的梯度范数不能过度增长。 如果 $\mu$P 导致 $| \nabla_\theta f(x) | \sim O(W)$,则 $\hat{\Theta} \sim O(W)$。在梯度流 $d\theta/dt = -\nabla L$ 中,参数变化 $\Delta \theta$ 会导致函数变化 $\Delta f \sim \nabla_\theta f \Delta \theta$. 如果 $\Delta \theta$ 缩放合适,$\Delta f$ 仍可能是 $O(1)$,但由于 $O(W)$ 的梯度范数,函数空间中的变化(切线方向)不再由初始化主导,网络能够沿着比 NTK 理论预测的更快的方向进行特征学习。

C8.7 诊断长度外推失败的核 如果外推失败,$\hat{\Theta}$ 的 $20 \times 20$ 块将显示出与 $10 \times 10$ 训练块不一致的结构。

  • 训练块 ($10 \times 10$): 结构稳定,梯度相关性高。
  • 外推块 ($L > 10$ 的区域): 梯度相关性可能急剧下降(例如,所有 $\hat{\Theta}_{i, j}$ 趋近于零),或者呈现出与训练集上学到的模式不匹配的新的周期性或随机性这表明网络无法“识别”超出训练长度的位置,无法维持有效的特征映射,从而无法进行可靠预测。

C8.8 基于 NTK 谱的正则化策略 极小的 $\lambda_{\min}$ 意味着网络容易在对应模式上产生大的权重变化 $\Delta \theta$,从而拟合噪声。 策略:Tikhonov 正则化(岭回归): 将梯度流解修改为核岭回归解: $$ f_{\text{ridge}} = y - (\hat{\Theta} + \lambda I)^{-1} (y - f_0) $$ 这相当于在损失函数中增加一个 $\lambda |\theta - \theta_0|^2$ 的正则项。在 NTK 视角下,它将特征值 $\lambda_k$ 替换为 $\lambda_k + \lambda$。这有效地提升了所有特征值,特别是将 $\lambda_{\min}$ 抬升至 $\lambda$,从而稳定了训练,并降低了对噪声的敏感性(提高了泛化)。