chapter11.md — 端到端:从零预训练(1T tokens)

开篇段落

本章是整个教程的实战巅峰,我们将前面章节中探讨的 Scaling Laws、优化器策略、并行技术和数据方案,融合成一套针对 3B、7B 和 13B LLaMA 风格模型的、可在 64x H100 80GB 集群上直接运行的端到端预训练“配方”。这不仅仅是一张参数表,更是对大规模训练中无数权衡与决策的总结。学习本章后,你将获得一套经过验证的、可用于启动 1T tokens 从零预训练的基线配置。更重要的是,你将深入理解每个关键超参背后的“为什么”,学会如何通过解读复杂的训练日志来诊断训练过程的健康状况,并具备将这些“配方”根据自身求进行调整和优化的能力。这些配置是稳健的起点,而非一成不变的终点,旨在助你自信地迈出从零训练的第一步。

1. 通用配置原则与环境假设

在深入具体参数之前,我们必须明确所有尺寸模型都将遵循的通用原则和环境基线。这些选择共同构成了我们训练框架的“骨架”,确保了效率、稳定性和可扩展性。

  • 硬件与并行策略 (回顾 Chapter 09)

    • 集群规模: 8 节点 × 8 卡/节点 = 64 × H100 80GB SXM。我们假设节点内通过高速 NVLink/NVSwitch 连接,节点间通过 InfiniBand/400GbE 连接。
    • 张量并行 (TP, Tensor Parallelism): TP=8。这是最大化单节点性能的基石。Transformer 中的 nn.LinearAttention 模块的计算可以被优雅地沿特定维度切分。将 TP 设置为节点内的 GPU 数量(8),可以确保通信量最大、最频繁的张量切片交换(all-reduce, all-gather)发生在速度最快的 NVLink ,从而最小化通信开销。
    • 流水线并行 (PP, Pipeline Parallelism): PP=1 (即不启用)。对于 3B-13B 规模的模型,在 H100 80GB 的显存下,TP 结合 ZeRO 已足够容纳模型。引入 PP(例如 PP=2)虽然能进一步切分模型,但会带来“流水线气泡”(pipeline bubble)的额外开销,即流水线中部分 GPU 处于空闲等待状态,降低了硬件利用率。因此,在此规模下,我们选择不使用 PP。
    • 数据并行 (DP, Data Parallelism): DP=8 (64 GPUs / (TP=8 * PP=1))。数据并行是扩展训练吞吐量的主要方式。每个 DP rank 拥有一套完整的模型(在 ZeRO-3 下是分片的),处理不同批次的数据。
    • ZeRO 策略: ZeRO Stage 3 + Paged AdamW Optimizer。这是最大化显存优化的终极方案。ZeRO-3 将模型参数、梯度和优化器状态全部分片到所有 DP rank 上,极大地降低了单卡显存峰值。其代价是更高的通信量(每次前向/反向传播需要 all-gather 对应的模型参数)。配合 Paged AdamW,可以将易被换出的优化器状态 offload 到 CPU 内存,为更大的 micro-batch 或模型腾出宝贵的 HBM 空间。
  • 数据 (回顾 Chapter 02, 08)

    • 总 Tokens: T_tokens = 1T (1,000,000,000,000)。这个数字并非随意设定,它源于 Chapter 4 讨论的 Chinchilla-style Scaling Laws,对于 7B-13B 级别的模型,1T-2T tokens 是一个接近“计算最优”的训练数据量。
    • 格式与策略: 预分词、打包 (packed) 并切分为 shards 的 WebDataset 或 Parquet 格式,存储于 CPFS。打包至关重要:它将多个短文档拼接成一个 L_ctx 长度的序列,中间用特殊 token 分隔,并构建相应的 attention mask。这确保了每个 token 都在参与计算,消除了因 padding 带来的巨大计算浪费,是提升训练效率的关键技巧。
    • 加载: 采用流式加载(streaming),每个 DP rank 读取独立的 shard 子集并通过 PyTorch DataLoader 的 prefetch_factornum_workers 机制确保数据供给始终快于 GPU 计算,避免 I/O 成为瓶颈。
  • 数值精度与性能优化 (回顾 Chapter 03, 07)

    • 训练精度: bf16 (Brain Floating Point)。H100 Tensor Core 对 bf16 提供原生硬件加速。相较于 fp16bf16 拥有与 fp32 相同的 8 位指数位,动态范围更广,极大降低了训练中梯度下溢(underflow)或溢出(overflow)的风险,使得混合精度训练更加稳定,通常不再需要复杂的动态损失缩放(Dynamic Loss Scaling)。
    • 核心算子: 全面启用 FlashAttention v2fused RMSNormfused SwiGLUfused RoPE。这些算子通过将多个操作合并到一个 CUDA kernel 中执行,大幅减少了 HBM 显存读写和 kernel launch 开销,是榨干 H100 算力的关键。它们不仅提升了 tokens/sec 吞吐率,有时还能因减少中间步骤的舍入误差而改善数值稳定性。

2. 基线配置表:理论与实践的结合

以下配置表是本章的核心。这些数字并非孤立存在,而是 Scaling Laws、硬件限制和稳定性考量三者博弈的结果。

2.1 配置表 A: 4k 上下文 (L_ctx = 4096)

这是大多数开源模型起步的经典配置,在通用能力和训练成本之间取得了良好平衡。

| 参数 (Parameter) | 3B 模型 | 7B 模型 | 13B 模型 | 说明/权衡 (Explanation/Trade-off) |

参数 (Parameter) 3B 模型 7B 模型 13B 模型 说明/权衡 (Explanation/Trade-off)
模型架构 (LLaMA-style) 沿用 LLaMA 论文的成功设计,确保了结构的稳定性和效率。
d_model 3200 4096 5120 隐藏层维度,模型容量的核心。
n_layers 32 32 40 模型深度。更深的模型通常能学习更复杂的层次化特征。
n_heads 32 32 40 多头注意力机制的头数。
n_kv_heads (可选 GQA) 8 (可选 GQA) 8 (可选 GQA) 8 Grouped-Query Attention。在训练阶段,其对性能影响不大,但能显著减少推理时 K/V 缓存大小,是现代模型设计的趋势。
intermediate_size (SwiGLU) 8640 11008 13824 FFN 中间层大小。LLaMA 使用 SwiGLU 激活函数,其公式为 2/3 * 4 * d_model 并向上取整到 128 的倍数,以优化硬件利用率。
训练目标 固定的训练预算,用于横向对比不同模型。
T_tokens (总训练 Token) 1T 1T 1T
L_ctx (上下文长度) 4096 4096 4096
批量与扩展 核心权衡区:在 Chinchilla 最优的统计效率、硬件可承受的显存压力和通信开销之间寻找最佳实践点。
Global Batch Size (GB_tok) 2.1M (2,097,152) 4.2M (4,194,304) 4.2M (4,194,304) 关键参数。遵循 Scaling Law,更大模型需要更大 batch 来稳定梯度、充分利用并行计算并获得更好的最终性能。2M-4M token 是当前大模型训练的“甜蜜点”。
μ_batch (单卡样本数) 8 4 2 显存调优的第一旋钮。此值直接决定了单次前向/反向传播的显存峰值。必须确保 μ_batch * L_ctx 对应的激活和梯度能被 HBM 容纳。
μ_tok (单卡 micro-batch tokens) 32,768 16,384 8,192 μ_tok = μ_batch * L_ctx。这个指标反映了单次 GPU kernel 计算的规模。
A (梯度累积步数) 2.1M / (32k * 64) ≈ 1 4.2M / (16k * 64) ≈ 4 4.2M / (8k * 64) ≈ 8 A = GB_tok / (μ_tok * #GPUs)。梯度累积是用计算时间换取显存的技巧,它在不增加显存占用的情况下模拟了超大 batch size,但代价是参数更新频率降低,可能影响收敛动态。
优化器与学习率 学习率和优化器参数的设定直接决定训练的成败。
Optimizer AdamW AdamW AdamW 默认且最稳健的选择。β₁=0.9, β₂=0.95, ε=1e-8β₂=0.95 是针对大模型训练稳定性进行的常见调整(相较于默认的 0.999)。
η (Peak Learning Rate) 3.0e-4 3.0e-4 1.5e-4 根据 large-batch scaling 原则调整,但非严格线性。经验法则是:模型越大,LR 越小;Batch越大,LR 越大。13B 模型使用更低的 LR 是为了抑制潜在的不稳定性。
min_lr η / 10 η / 10 η / 10 Cosine schedule 的终点学习率,确保训练末期模型仍在微调。
Warmup Tokens 2B 3B 3B 训练初期,模型权重随机,梯度巨大且方向不定。Warmup 阶段让 LR 从零缓慢增长,给予模型一个“热身”期来稳定梯度,防止初期就“跑飞”。
LR Schedule Cosine Decay Cosine Decay Cosine Decay 经过 warmup 后,学习率按余弦曲线平滑衰减至 min_lr。这是目前最常用且效果最好的调度策略。
wd (Weight Decay) 0.1 0.1 0.1 L2 正则化,防止模型参数过大,提高泛化能力。
grad_clip 1.0 1.0 1.0 按全局梯度范数进行裁剪,这是防止梯度爆炸、确保训练稳定的最后一道防线。
内存优化
Activation Checkpointing Off On On 空间换时间。开启后,在前向传播时不再存储所有中间激活,而是在反向传播时重新计算它们。这会增加约 20-30% 的计算耗时,但能极大降低显存占用,对于 7B+ 模型几乎是必需的。

2.2 配置表 B: 8k 上下文 (L_ctx = 8192)

扩展到 8k 上下文是提升模型长文本能力的关键。这主要对显存和位置编码提出挑战,配置需相应调整。

| 参数 (Parameter) | 3B 模型 | 7B 模型 | 13B 模型 | 说明/权衡 (Explanation/Trade-off) |

参数 (Parameter) 3B 模型 7B 模型 13B 模型 说明/权衡 (Explanation/Trade-off)
L_ctx (上下文长度) 8192 8192 8192 上下文长度翻倍,Attention 矩阵的计算和存储开销增长为 平方级别 (L_ctx²),这是主要的性能和显存瓶颈。
RoPE Scaling (回顾 Chapter 03)
RoPE Scaling Type YaRN YaRN YaRN YaRN (Yet another RoPE extensioN) 通过插值和外推的结合,在扩展长上下文时,相比 PI 或 NTK-aware scaling 能更好地保持模型的短文本性能和长文本 PPL,是当前的主流选择。
RoPE Scaling Factor 2.0 2.0 2.0 (目标长度 / 原始长度),即 8192 / 4096 = 2.0
RoPE Original Max Pos 4096 4096 4096 告知 YaRN 算法原始 RoPE 的设计基线。
批量与扩展 (调整后)
Global Batch Size (GB_tok) 2.1M 4.2M 4.2M 保持不变。Scaling Law 关注的是参数更新的信噪比,主要与模型参数量和 tokens 总量相关,上下文长度对其影响较小,因此我们保持 GB_tok 来维持相似的收敛动态。
μ_batch (单卡样本数) 4 (vs 8) 2 (vs 4) 1 (vs 2) 核心调整。由于 L_ctx 翻倍导致 Attention 显存开销 4 倍增长,我们必须将单卡处理的样本数减半,以避免 OOM。这是最直接有效的应对策略。
μ_tok (单卡 micro-batch tokens) 32,768 16,384 8,192 μ_tok 保持不变,这是通过 μ_batch 的减半来抵消 L_ctx 翻倍的结果。维持 μ_tok 有助于保持计算核的利用率。
A (梯度累积步数) 2.1M / (32k * 64) ≈ 1 4.2M / (16k * 64) ≈ 4 4.2M / (8k * 64) ≈ 8 由于 μ_tokGB_tok 均未变,A 也不需要改变。
Activation Checkpointing On (vs Off) On On 强制开启。对于 8k 上下文,即使是 3B 模型,激活值的显存占用也会非常巨大,开启激活重计算是避免 OOM 的必要手段。
其它参数
其它所有参数 同 4k 配置 同 4k 配置 同 4k 配置 LR、优化器、warmup 等策略通常可直接沿用。但需更密切关注稳定性,因为更长的上下文意味着更长的依赖链,能放大数值误差,增加梯度消失/爆炸的风险。

3. 训练日志样例解读:成为“炼丹宗师”的第一步

启动训练后,持续监控日志是你作为 AI Scientist 最重要的工作。日志是模型的“心电图”,能揭示其内部状态。

  • 损失曲线 (Loss Curve)
    • 健康状态: 整体呈平滑下降趋势,符合幂律分布(在 log-log 图上接近一条直线)。曲线在放大后有正常的随机噪声波动,但整体趋势稳定向下。训练初期下降快,后期逐渐放缓。
    • 异常信号与诊断:
      • 尖峰 (Spike): 损失突然暴增后回落。
        • 日志表现: loss2.5 突然跳到 5.0,几步后又回到 2.5 附近。
        • 可能原因: 1. 数据问题: 遇到了一个包含乱码、超长数字序列或格式错误的“毒样本”。2. 学习率过高: 在某个特定的梯度方向上子迈得太大。3. 数值不稳定: bf16 在极端情况下精度不足。
        • 行动: 如果偶尔出现且能恢复,可忽略。若频繁出现,应降低 LR 或检查数据清洗流程。
      • 停滞 (Plateau): 损失在数千步内几乎不再下降。
        • 日志表现: loss 长期在 1.85 ± 0.02 范围内波动。
        • 可能原因: 1. 学习率衰减过快/过低: 模型失去了进一步优化的动力。2. 数据耗尽/多样性不足: 模型已“背熟”了训练数据。3. 陷入局部最优
        • 行动: 检查 LR schedule,考虑“重启” LR (LR restart)。检查数据混比策略,是否某个高质量数据源已用尽。
      • 发散 (Divergence): 损失持续增大,最终变为 NaN
        • 日志表现: loss3.0 -> 4.5 -> 10.2 -> NaN
        • 可能原因: 灾难性问题。通常是学习率过高、梯度爆炸、RoPE scaling 实现有 bug、或严重的硬件/CUDA 问题。
        • 行动: 立即停止训练。从最后一个正常的 checkpoint 回滚,大幅降低学习率(例如减半),并仔细检查代码和数据。
[健康 log-log 图]  [尖峰]             [停滞]             [发散]
  log(loss)|         loss|              loss|              loss|
           |           | spike          |                |
  \        |          / \               |                |
   \       |         /   \              |                |     .----
    \      |        /     \             |----.           .----'
     \_____|       /_______\_           |     `----.----'
-----------+----->-----------+----->   -----------+----->   -----------+----->
      log(iter)         iter             iter             iter
  • 梯度范数与优化器状态 (Gradient Norm & Optimizer States)

    • 监控点:
      1. 全局梯度范数 (Global Grad Norm): 在梯度裁剪前,所有参数梯度的 L2 范数。一个健康的梯度范数应在 warmup 后稳定在一个数量级内(例如 1.0-10.0 之间),有波动但无持续增长趋势。如果该值频繁地、远超 grad_clip (1.0),说明学习过程非常不稳定,是 LR 过高的明确信号。
      2. DeepSpeed 溢出监控: overflow 计数器。如果该计数器不为零,说明在 bf16 梯度计算中出现了 infnan。偶尔的溢出可以被跳过,但如果持续增加,说明存在严重的数值稳定性问题。
  • 吞吐率与硬件利用率 (Throughput & Utilization)

    • 核心指标: samples/sec, tokens/sec/gpu, TFLOPs/gpu
    • 目标: 对于 H100,一个经过优化的 7B 模型训练,单卡 TFLOPs 理论峰值约为 989 (bf16),实际利用率(MFU, Model FLOPs Utilization)能达到 50-60% 即为优秀,对应约 500 TFLOPs。这转化为一个非常可观的 tokens/sec
    • 诊断: 如果吞率远低于预期(例如 MFU < 30%),使用 PyTorch Profiler 分析 step 时间构成。常见瓶颈包括:
      • data_loading: I/O 或 CPU 预处理成为瓶颈。
      • all_reduce/all_gather: ZeRO 或 TP 的通信开销过大。
      • optimizer.step: 优化器更新步骤耗时过长,可能是 CPU offload 带来的瓶颈。

本章小结

  • 配方是科学与艺术的结合: 本章提供的配置表是基于 Scaling Laws 的科学推导和大量工程实践经验的结合,是启动大规模训练的坚实基础。
  • 权衡无处不在:
    • 模型尺寸 vs. Batch Size: 更大模型需要更大 Global Batch Size 来稳定收敛。
    • 上下文长度 vs. 显存: L_ctx 增加,μ_batch 必须相应减小,并强制开启激活重计算。
    • 速度 vs. 内存: ZeRO-3 和激活重计算用通信/计算开销换取了宝贵的显存空间,使得大模型训练成为可能。
  • 并行策略是根基: TP=8 最大节点内效率,ZeRO-3 解决跨节点显存扩展问题,这套组合拳是当前 H100 集群的标准打法。
  • 监控日志是航海图: 深刻理解并持续监控损失、梯度范数和吞吐率,是诊断训练问题、避免在错误航向上浪费数百万 GPU 小时的关键。

常见陷阱与错误 (Gotchas)

  1. 配置“一抄到底”不调整 (Blindly Copying Configs)

    • 症状: 将 13B 模型的 LR (1.5e-4) 用于 3B 模型,发现损失下降极其缓慢;反之,将 3B 的 LR (3.0e-4) 用于 13B,训练在几百步内 loss 变为 NaN
    • 分析: 不同规模的模型对学习率的敏感度不同。这是一个必须根据模型规模和 GB_tok 进行调整的核心参数。
    • 对策: 遵循“模型越大、LR 越小”的原则。调整 GB_tok 时,参考 Chapter 5 的 linearsqrt scaling 法则来调整 LR。
  2. 忽视早期的损失尖峰 (Ignoring Early Loss Spikes)

    • 症状: 训练前 1k 步,loss 出现数次 > 5.0 的尖峰,但最终都“恢复”了。你认为这很正常并继续训练。
    • 分析: 早期的尖峰是训练不稳定的强烈信号。即使恢复,也可能已经对模型权重造成了不可逆的“伤害”,影响最终性能。
    • 对策: 立即暂停。抓取导致尖峰的那个 batch 的数据样本,进行人工检查。90% 的可能是数据质量问题。同时,可以考虑延长 warmup 步数或降低初始 LR。
  3. 8k 上下文 OOM 的隐蔽原因 (Subtle OOMs with 8k Context)

    • 症状: 将配置从 4k 改为 8k,μ_batch 也已减半,但依然在训练中途随机 OOM。
    • 分析: 除了 Attention 矩阵,某些中间激活或临时变量也可能与 L_ctx 相关。例如,如果某个自定义的 fused 操作内部实现不佳,可能会产生一个巨大的临时张量。
    • 对策: 使用 torch.cuda.memory_summary()torch.cuda.max_memory_allocated() 在代码关键位插入打印,精确定位 OOM 发生在哪个操作之后。实在不行,只能进一步减小 μ_batch
  4. 无声的数据加载瓶颈 (The Silent Data Loading Bottleneck)

    • 症状: loss 曲线正常,没有错误,但 TFLOPs 利用率只有 20%,远低于预期。GPU 利用率在监控工具中显示为周期性的“波谷”。
    • 分析: 这是典型的数据供给慢于 GPU 计算。GPU 在每个 step 开始时都在等待数据。
    • 对策: 1. 增加 DataLoadernum_workers。2. 开启 pin_memory=True。3. 检查数据预处理逻辑是否有 Python 全局锁(GIL)瓶颈。4. 压测 CPFS 的并发读取性能,确认不是存储端的问题。
  5. RoPE Scaling 的精度陷阱 (Precision Pitfalls in RoPE Scaling)

    • 症状: 训练 8k 模型时,偶尔出现 NaN,尤其是在梯度累积步数较多的情况下。
    • 分析: RoPE 的计算涉及三角函数和旋转操作。在长序列和 bf16 精度下,些中间计算结果可能超出 bf16 的表示范围或损失过多精度,最终导致 infnan 在 attention score 中传播。
    • 对策: 强制 RoPE 的部分计算(例如生成 cossin 值)在 fp32 下进行,然后再 cast 回 bf16。这会带来微小的性能开销,但能显著提升数值稳定性。