Chapter 05 — 大批量训练与学习率策略

开篇段落

本章将深入探讨大规模语言模型训练中两个相互交织、至关重要的核心要素:大批量训练(Large-batch Training)学习率(Learning Rate)策略。这二者共同决定了训练的收敛速度、稳定性和最终模型的性能。我们将从第一性原理出发,明确定义“全局批量”(Global Batch Size),并讨论其在百万级 token 规模下的理论基础与实践选择。核心目标是让你不仅知其然,更知其所以然:深刻理解批量大小、学习率、训练步数与总训练 token 数之间的联动关系,并能熟练运用学习率缩放法则(LR Scaling)和高级调度策略(LR Schedule)来驾驭 64x H100 规模的训练集群。本章将为你提供一套经过业界验证的经验法则、理论依据和一套完整的操作心法。

文字论述

1. 全局批量大小(Global Batch Size, GB_tok)的深度解析

在 LLM 训练的语境下,批量大小的唯一通用“货币”是 token 数量,而非样本数。这为我们讨论不同序列长度、不同打包策略下的训练提供了一个统一的基准。

全局批量大小(GB_tok 的严格定义是:在执行一次模型权重更新(即 optimizer.step()时,所使用的全部训练 token 的总和。它是一个逻辑上的概念,由物理限制和并行策略共同决定:

$$ GB_{\text{tok}} = \mu_{\text{tok}} \times D \times A $$

让我们逐一拆解这个公式的每个组成部分:

  • μ_tok (Micro-batch Size in tokens):单张 GPU 在一次完整的前向和后向传播(a single forward/backward pass)中处理的 token 总数。这个值是显存瓶颈直接体现。其上限由模型大小、激活值、优化器状态(若未 offload)以及上下文长度 L_ctx 共同决定。

    • 例如:对于一个 7B 模型,在 bf16 精度和 4k 上下文长度下,一次前向传播的激活值可能就高达数十 GB。考虑到反向传播的梯度和 AdamW 优化器的状态(每个参数需要额外的 8 字节,对于 7B 模型约 56GB),一张 80GB 的 H100 卡在不使用 ZeRO-3 或 offload 的情况下,可能只能容纳 μ_batch_size=4 个样本,即 μ_tok = 4 * 4096 = 16,384
  • D (Data Parallel Factor):数据并行的副本数量。在我们的 64x H100 设置中,这个值取决于张量并行(TP)和流水线并行(PP)的配置。总 GPU 数 N_gpus = D \times TP \times PP

    • 例如:若 TP=4, PP=1,则 D = 64 / 4 / 1 = 16。这意味着有 16 个模型副本在同时处理不同的数据。
  • A (Gradient Accumulation Steps):梯度累积步数。这是一种用时间换间的技巧,允许我们在不增加显存占用的情况下,模拟出更大的批量。计算机会执行 A 次前向/后向传播,将每次计算出的梯度累加起来,最后用这个累加的梯度执行一次优化器更新。

1.1 为什么必须追求大批量?

追求数百万 token 级别的全局批量,并非单纯的工程选择,而是基于深刻的统计学和硬件效率考量。

  1. 降低梯度方差,稳定训练方向: 随机梯度下降(SGD)及其变体(如 AdamW)的核心思想是用一个小批量数据的梯度来近似整个训练集的梯度。批量越小,这个梯度的估计噪声就越大,方向摆动越剧烈。大批量则像是进行一次更精准的“民意调查”,其计算出的平均梯度方向更接近“真实”的全局梯度方向。 $$ \mathbb{E}[g(x; B)] = \nabla L(x) \quad \text{Var}(g(x; B)) \propto \frac{1}{|B|} $$ 其中 g(x; B) 是批量 B 的梯度。梯度方差与批量大小 |B| 成反比。更小的方差意味着我们可以更自信地朝着最优解前进,从而能够承受更大的学习率,加速收敛。

  2. 摊薄通信开销,提升硬件利用率: 在分布式训练中,每次权重更新都需要在 D 个数据并行副本之间进行梯度同步(All-Reduce 操作)。这是一个通信密集型操作。如果批量太小,计算时间(computation time)可能远小于通信时间(communication time),导致 GPU 大量时间在“等待”而非“计算”,即所谓的“通信墙”(communication wall)。增大 μ_tokD 可以有效提高计算通信比(Compute-to-Communication Ratio),让 GPU 的 TFLOPS 得到更充分的利用。梯度累积 A 虽然不能减少通信次数,但它通过减少优化器更新的频率,也间接降低了总训练时间中通信开销的占比。

1.2 GB_tok 的“甜点区”:2M-4M Tokens

LLaMA 系列论文的成功,将 2M-4M tokens 的全局批量大小确立为一个黄金标准。

  • 低于 1M tokens:训练依然可行,但通常需要更小的学习率峰值和更长的 warmup 阶段来抑制梯度噪声。整体收敛速度可能不是最优的。
  • 2M-44M tokens:这是当前大规模预训练的“舒适区”。它在梯度估计的准确性、训练稳定性、收敛速度和硬件效率之间取得了极佳的平衡。对于我们的 1T token 训练目标,这是一个强力推荐的起点。
  • 远超 4M tokens:当批量大到一定程度,会遇到收益递减的现象。理论上,梯度方差的下降速度会放缓。实践中,可能会出现“泛化差距”(generalization gap)扩大的问题,即模型在训练集上表现优异,但在验证集上性能停滞。此外,极大的批量也意味着极少的总训练步数,这使得学习率调度变得异常敏感。

实战计算示例: 假设我们要在 64x H100 (TP=4, PP=1, so D=16) 上训练一个 13B 模型,目标 GB_tok 为 4M tokens,上下文 L_ctx=4k

  1. 确定 μ_tok:通过实验我们发现,为了防止 OOM,单卡 micro-batch 样本数最多为 μ_batch_size=4。因此 μ_tok = 4 \times 4096 = 16,384
  2. 计算 A:根据公式 GB_tok = μ_tok * D * A,我们有: A = 4,194,304 / (16,384 * 16) = 4,194,304 / 262,144 = 16

  3. 最终配置:我们需要设置梯度累积步数为 16。

2. GB_tok、迭代步数与总 tokens 的铁三角关系

这三个变量构成了一个不可违背的约束方程,是所有训练计划的数学基础:

$$ T_{\text{tokens}} = GB_{\text{tok}} \times \text{iters} $$

对于我们 1T tokens (10^{12}) 的目标,这个关系意味着 GB_tok 的选择直接决定了训练的总步数,进而深刻影响学习率调度的设计。

| 全局批量大小 (GB_tok) | 总训练步数 (iters for 1T tokens) | 对 LR Schedule 的影响 | 优缺点 |

全局批量大小 (GB_tok) 总训练步数 (iters for 1T tokens) 对 LR Schedule 的影响 优缺点
1.05M (~$2^{20}$) ~953,674 较长的衰减周期,对超参不敏感,warmup 占比小 训练步数多,优化器更新频繁,可能总训练时间更长
2.10M (~$2^{21}$) ~476,837 标准,平衡良好 业界标准配置,易于复现和调试
4.19M (~$2^{22}$) ~238,418 非常短的衰减周期,warmup 阶段和峰值 LR 极为敏感 优化器更新次数少,理论上墙上时间(wall-clock time)可能更短,但调试难度大

核心启示:选择一个大的 GB_tok 是一种高风险高回报的策略。它要求你对学习率、warmup 等超参有更准的把握,因为任何早期的不稳定都会在短暂的训练步数中被不成比例地放大。

3. 学习率缩放(LR Scaling)经验法则

当你改变 GB_tok 时,为了保持相似的训练动态,学习率 η 必须相应调整。

  1. 线性缩放法则 (Linear Scaling Rule) $$ \eta_{\text{new}} = \eta_{\text{base}} \times \frac{GB_{\text{tok, new}}}{GB_{\text{tok, base}}} $$ 直觉:批量加倍,梯度估计的信噪比加倍,所以我们可以迈出两倍大的步子。 适用场景:在批量变化范围较小(如 2x-4x)时,此法则是合理的近似。但在从非常小扩展到非常大的批量时,它过于激进,容易导致训练在 warmup 结束后立即发散。

  2. 平方根缩放法则 (Square Root Scaling Rule) $$ \eta_{\text{new}} = \eta_{\text{base}} \times \sqrt{\frac{GB_{\text{tok, new}}}{GB_{\text{tok, base}}}} $$ 直觉:基于中心极限定理,梯度的信噪比(signal-to-noise ratio)与 $\sqrt{|B|}$ 成正比。因此,学习率也应按此比例调整,以维持恒定的有效学习“强度”。 适用场景这是大规模训练中更安全、更受推荐的黄金法则。几乎所有成功的 LLM 训练都遵循或接近这个规律。

实践建议 (Rule-of-thumb):

  • 始终以平方根缩放作为你的默认策略。例如,社区中一个常见的基线是 Chinchilla 论文中的 GB_tok=262,144 (256k) 对应 η_peak=6.0e-4。若要扩展到我们的目标 GB_tok=4.19M,则: η_new = 6.0e-4 * sqrt(4,194,304 / 262,144) = 6.0e-4 * sqrt(16) = 2.4e-3

  • “激进的平方根”:在实践中,有时会使用一个介于 sqrtlinear 之间的指数,如 B^{0.6}B^{0.75},但这需要更多实验来验证。

  • 注意:µP (Maximal Update Parametrization) 等更前沿的理论试图从根本上解决超参随规模变化的问题,但目前 sqrt 法则仍是最实用的工程捷径。

4. 学习率调度器(LR Schedule):驾驭整个训练过程的艺术

一个精心设计的学习率调度器是引导模型走向优秀收敛的关键。它通常由三个阶段组成:

      Peak LR (η_peak)
        . . . . . . . . . . . . . . . . . . .
      .                 ' ''''-.
    .                    '      ''
  .                       '        ''
.                          '          '
                          '            '.
                         '               '
                        '                 ''
Learning Rate (η) ---> '                   ''-.
                      '                       ' ' ' ' . . . . . Min LR (η_final)
 ___________________ . . . . . . . . . . . . . . . . . . . . . . . . . . .
|   Warmup Phase    |         Cosine Decay Phase          | Cooldown |
<-------------------><-------------------------------------><-------->
0               T_warmup                                T_total
                         Training Steps (iters)
  1. Warmup (热):

    • 目的:在训练初期,模型权重是随机的,优化器(尤其是 Adam 的二阶动量 v)也未积累有意义的统计信息。此时若直接使用峰值学习率,梯度会异常巨大且方向不稳定,极易导致“梯度爆炸”和数值溢出(NaN)。Warmup 通过在最初几千步内将学习率从 0 线性增加到 η_peak,为模型和优化器提供了一个平稳的“冷启动”过程。
    • 时长建议:通常设置为 2000 到 5000 步。一个稳健的经验法则是,让 warmup 阶段覆盖总训练步数的 1%-5%。对于 iters=250k 的短程训练,2000-3000 步的 warmup 是一个合理的选择。
  2. Decay (衰减):

    • 目的:当模型度过不稳定的初期后,需要一个较高的学习率来快速探索损失函数的广阔空间。随着训练的进行,模型逐渐接近最优解所在的“山谷”,此时需要降低学习率以进行“精细微调”,避免在谷底附近反复振荡而无法敛。
    • 余弦衰减 (Cosine Decay)这是 LLM 预训练的绝对主流选择。其数学形式保证了学习率平滑地从 η_peak 下降到 η_final(通常是 η_peak 的 10%)。相比线性衰减,它在训练中后期能维持相对较高的学习率更长时间,被认为有助于模型逃离不良的局部最小值,找到更优的泛化解。
  3. Cooldown (冷却,可选):

    • 目的:在余弦衰减结束后,将学习率固定在 η_final 再训练一小段时间(如几千步)。这可以给模型最后的机会来稳定和巩固其学到的知识,尤其是在总训练步数非常多的时候。

5. 梯度裁剪、权重衰减与稳定性观测

这些是保障训练过程不“脱轨”的必要安全措施。

  • 梯度裁剪 (Gradient Clipping)

    • 机制:在优化器更新权重之前,计算所有模型参数梯度的全局 L2 范数。如果该范数超过预设阈值(clip_grad_norm),则所有梯度都会被同比例缩小,使得最终的范数恰好等于该阈值。
    • 为什么是必须的:在使用 bf16fp16 混合精度训练时,数值的动态范围远小于 fp32。梯度裁剪是防止因偶尔出现的巨大梯度导致数值上溢(overflow)变成 NaNInf生命线
    • 推荐值clip_grad_norm = 1.0 是一个非常标准且效果良好的选择。
  • 权重衰减 (Weight Decay, wd):

    • 机制:作为一种 L2 正则化手段,它通过惩罚较大的权重值来防止过拟合。在 AdamW 优化器中,权重衰减被“解耦”了——它不再是梯度的一部分,而是直接从权重中减去一个与其大小成比例的值。W_t = W_{t-1} - η * (grad_t + wd * W_{t-1}) (Adam) vs W_t = (1 - η * wd) * W_{t-1} - η * update_t (AdamW)。这种方式在 Adam 这类自适应优化器上表现更好。
    • 推荐值:对于大型 Transformer 模型,wd = 0.1 是一个常见的有效值。
  • 定性观测清单

    • 损失曲线:寻找尖峰(spikes)。偶尔的小尖峰可能无伤大雅,但频繁或剧烈的尖峰是严重不稳定性的信号。
    • 梯度范数(裁剪前):在监控工具中绘制此指标。如果它频繁地撞上 1.0 的裁剪天花板,说明你的学习率可能过高。健康的曲线应该在 warmup 后稳定在一个远低于 1.0 的水平,偶尔触及上限。
    • 学习率曲线:确保实际执行的学习率与你设计的调度器完全一致。
    • 激活值范数:通过在模型中添加 hook,可以监控各层激活值的范数。如果某些层的激活值持续增长并趋于无穷,那么离出现 NaN 就不远了。

本章小结

  • 规划起点:先确定总训练 T_tokens(如 1T),然后选择全局批量 GB_tok(推荐 2M-4M tokens),由此计算出总步数 iters
  • 核心参数GB_tok 是通过 μ_tok (受显存限制), D (并行策略), 和 A (梯度累积) 共同实现的。
  • 学习率核心法则:当 GB_tok 变化时,学习率峰值 η_peak 应遵循平方根缩放法则进行调整。
  • 调度器标准模板:采用 Warmup + Cosine Decay 策略。Warmup 步数覆盖总步数的 1-5%,衰减终点 η_final 设为 η_peak 的 10%。
  • 安全带必须开启梯度裁剪clip_grad_norm=1.0),并使用合理的权重衰减(如 wd=0.1)。
  • 监控是关键:持续观察损失、梯度范数和学习率曲线,将它们作为诊断训练健康状况的核心仪表盘。

常见陷阱与错误 (Gotchas)

  1. 混淆不同层级的批量大小:在计算 LR Scaling 时,误用了 μ_tok(单卡 micro-batch)或 μ_tok * D(未计入累积的分布式批量),而非完整的 GB_tok = μ_tok * D * A调试技巧:在日志中明确打印出这三个值以及最终的 GB_tok,确保计算无误。

  2. “静态”的学习率:调整了 GB_tok(如,为了适配更长的上下文而减小 μ_batch_size,并用增加 A 来补偿),但忘记了重新计算和应用 LR Scaling。这会导致实际的训练动态与预期严重不符。调试技巧:将 LR Scaling 的计算逻辑脚本化,使其与批量配置联动。

  3. Warmup 与总步数不匹配:Warmup 的步数是绝对值。当你因为 GB_tok 翻倍而导致 iters 减半时,固定的 warmup 步数(如 3000 步)在总训练中的占比会翻倍,这可能不是你想要的效果。调试技巧:考虑将 warmup 时长定义为总步数的一个百分比,或者在每次调整 iters 后都重新审视 warmup 步数的绝对值。

  4. 调度器总步数 (T_max) 设置错误:在 PyTorch 的 CosineAnnealingLR 等调度器中,你需要传入一个总步数参数。这个参数必须等于你计算出的 iters。如果设置错误,你的学习率可能会提前衰减到零,或者在训练结束时还远未衰减到位。

  5. “批量尺寸陷阱” (The Micro-batch Size Trap):为了节省内存,将 μ_batch_size 设置为 1。这虽然能跑起来,但极大地损害了 GPU 的计算效率,因为无法利用张量核心(Tensor Cores)进行高效的矩阵运算,导致 tokens/s 吞吐量暴跌。调试技巧:进行一次短期的性能剖析,寻找能最大化 tokens/s 吞吐量且不 OOM 的最大 μ_batch_size,然后用梯度累积 A 来补足剩下的 GB_tok