chapter11.md — 端到端:从零预训练(1T tokens)
开篇段落
本章是整个教程的实战巅峰,我们将前面章节中探讨的 Scaling Laws、优化器策略、并行技术和数据方案,融合成一套针对 3B、7B 和 13B LLaMA 风格模型的、可在 64x H100 80GB 集群上直接运行的端到端预训练“配方”。这不仅仅是一张参数表,更是对大规模训练中无数权衡与决策的总结。学习本章后,你将获得一套经过验证的、可用于启动 1T tokens 从零预训练的基线配置。更重要的是,你将深入理解每个关键超参背后的“为什么”,学会如何通过解读复杂的训练日志来诊断训练过程的健康状况,并具备将这些“配方”根据自身求进行调整和优化的能力。这些配置是稳健的起点,而非一成不变的终点,旨在助你自信地迈出从零训练的第一步。
1. 通用配置原则与环境假设
在深入具体参数之前,我们必须明确所有尺寸模型都将遵循的通用原则和环境基线。这些选择共同构成了我们训练框架的“骨架”,确保了效率、稳定性和可扩展性。
-
硬件与并行策略 (回顾 Chapter 09)
- 集群规模: 8 节点 × 8 卡/节点 = 64 × H100 80GB SXM。我们假设节点内通过高速 NVLink/NVSwitch 连接,节点间通过 InfiniBand/400GbE 连接。
- 张量并行 (TP, Tensor Parallelism):
TP=8。这是最大化单节点性能的基石。Transformer 中的nn.Linear和Attention模块的计算可以被优雅地沿特定维度切分。将 TP 设置为节点内的 GPU 数量(8),可以确保通信量最大、最频繁的张量切片交换(all-reduce, all-gather)发生在速度最快的 NVLink ,从而最小化通信开销。 - 流水线并行 (PP, Pipeline Parallelism):
PP=1(即不启用)。对于 3B-13B 规模的模型,在 H100 80GB 的显存下,TP 结合 ZeRO 已足够容纳模型。引入 PP(例如PP=2)虽然能进一步切分模型,但会带来“流水线气泡”(pipeline bubble)的额外开销,即流水线中部分 GPU 处于空闲等待状态,降低了硬件利用率。因此,在此规模下,我们选择不使用 PP。 - 数据并行 (DP, Data Parallelism):
DP=8(64 GPUs / (TP=8 * PP=1))。数据并行是扩展训练吞吐量的主要方式。每个 DP rank 拥有一套完整的模型(在 ZeRO-3 下是分片的),处理不同批次的数据。 - ZeRO 策略: ZeRO Stage 3 + Paged AdamW Optimizer。这是最大化显存优化的终极方案。ZeRO-3 将模型参数、梯度和优化器状态全部分片到所有 DP rank 上,极大地降低了单卡显存峰值。其代价是更高的通信量(每次前向/反向传播需要 all-gather 对应的模型参数)。配合 Paged AdamW,可以将易被换出的优化器状态 offload 到 CPU 内存,为更大的
micro-batch或模型腾出宝贵的 HBM 空间。
-
数据 (回顾 Chapter 02, 08)
- 总 Tokens:
T_tokens = 1T(1,000,000,000,000)。这个数字并非随意设定,它源于 Chapter 4 讨论的 Chinchilla-style Scaling Laws,对于 7B-13B 级别的模型,1T-2T tokens 是一个接近“计算最优”的训练数据量。 - 格式与策略: 预分词、打包 (packed) 并切分为 shards 的 WebDataset 或 Parquet 格式,存储于 CPFS。打包至关重要:它将多个短文档拼接成一个
L_ctx长度的序列,中间用特殊 token 分隔,并构建相应的 attention mask。这确保了每个 token 都在参与计算,消除了因 padding 带来的巨大计算浪费,是提升训练效率的关键技巧。 - 加载: 采用流式加载(streaming),每个 DP rank 读取独立的 shard 子集并通过 PyTorch DataLoader 的
prefetch_factor和num_workers机制确保数据供给始终快于 GPU 计算,避免 I/O 成为瓶颈。
- 总 Tokens:
-
数值精度与性能优化 (回顾 Chapter 03, 07)
- 训练精度:
bf16(Brain Floating Point)。H100 Tensor Core 对bf16提供原生硬件加速。相较于fp16,bf16拥有与fp32相同的 8 位指数位,动态范围更广,极大降低了训练中梯度下溢(underflow)或溢出(overflow)的风险,使得混合精度训练更加稳定,通常不再需要复杂的动态损失缩放(Dynamic Loss Scaling)。 - 核心算子: 全面启用 FlashAttention v2、fused RMSNorm、fused SwiGLU 和 fused RoPE。这些算子通过将多个操作合并到一个 CUDA kernel 中执行,大幅减少了 HBM 显存读写和 kernel launch 开销,是榨干 H100 算力的关键。它们不仅提升了
tokens/sec吞吐率,有时还能因减少中间步骤的舍入误差而改善数值稳定性。
- 训练精度:
2. 基线配置表:理论与实践的结合
以下配置表是本章的核心。这些数字并非孤立存在,而是 Scaling Laws、硬件限制和稳定性考量三者博弈的结果。
2.1 配置表 A: 4k 上下文 (L_ctx = 4096)
这是大多数开源模型起步的经典配置,在通用能力和训练成本之间取得了良好平衡。
| 参数 (Parameter) | 3B 模型 | 7B 模型 | 13B 模型 | 说明/权衡 (Explanation/Trade-off) |
| 参数 (Parameter) | 3B 模型 | 7B 模型 | 13B 模型 | 说明/权衡 (Explanation/Trade-off) |
|---|---|---|---|---|
| 模型架构 (LLaMA-style) | 沿用 LLaMA 论文的成功设计,确保了结构的稳定性和效率。 | |||
d_model |
3200 | 4096 | 5120 | 隐藏层维度,模型容量的核心。 |
n_layers |
32 | 32 | 40 | 模型深度。更深的模型通常能学习更复杂的层次化特征。 |
n_heads |
32 | 32 | 40 | 多头注意力机制的头数。 |
n_kv_heads |
(可选 GQA) 8 | (可选 GQA) 8 | (可选 GQA) 8 | Grouped-Query Attention。在训练阶段,其对性能影响不大,但能显著减少推理时 K/V 缓存大小,是现代模型设计的趋势。 |
intermediate_size (SwiGLU) |
8640 | 11008 | 13824 | FFN 中间层大小。LLaMA 使用 SwiGLU 激活函数,其公式为 2/3 * 4 * d_model 并向上取整到 128 的倍数,以优化硬件利用率。 |
| 训练目标 | 固定的训练预算,用于横向对比不同模型。 | |||
T_tokens (总训练 Token) |
1T | 1T | 1T | |
L_ctx (上下文长度) |
4096 | 4096 | 4096 | |
| 批量与扩展 | 核心权衡区:在 Chinchilla 最优的统计效率、硬件可承受的显存压力和通信开销之间寻找最佳实践点。 | |||
Global Batch Size (GB_tok) |
2.1M (2,097,152) | 4.2M (4,194,304) | 4.2M (4,194,304) | 关键参数。遵循 Scaling Law,更大模型需要更大 batch 来稳定梯度、充分利用并行计算并获得更好的最终性能。2M-4M token 是当前大模型训练的“甜蜜点”。 |
μ_batch (单卡样本数) |
8 | 4 | 2 | 显存调优的第一旋钮。此值直接决定了单次前向/反向传播的显存峰值。必须确保 μ_batch * L_ctx 对应的激活和梯度能被 HBM 容纳。 |
μ_tok (单卡 micro-batch tokens) |
32,768 | 16,384 | 8,192 | μ_tok = μ_batch * L_ctx。这个指标反映了单次 GPU kernel 计算的规模。 |
A (梯度累积步数) |
2.1M / (32k * 64) ≈ 1 |
4.2M / (16k * 64) ≈ 4 |
4.2M / (8k * 64) ≈ 8 |
A = GB_tok / (μ_tok * #GPUs)。梯度累积是用计算时间换取显存的技巧,它在不增加显存占用的情况下模拟了超大 batch size,但代价是参数更新频率降低,可能影响收敛动态。 |
| 优化器与学习率 | 学习率和优化器参数的设定直接决定训练的成败。 | |||
Optimizer |
AdamW | AdamW | AdamW | 默认且最稳健的选择。β₁=0.9, β₂=0.95, ε=1e-8。β₂=0.95 是针对大模型训练稳定性进行的常见调整(相较于默认的 0.999)。 |
η (Peak Learning Rate) |
3.0e-4 | 3.0e-4 | 1.5e-4 | 根据 large-batch scaling 原则调整,但非严格线性。经验法则是:模型越大,LR 越小;Batch越大,LR 越大。13B 模型使用更低的 LR 是为了抑制潜在的不稳定性。 |
min_lr |
η / 10 |
η / 10 |
η / 10 |
Cosine schedule 的终点学习率,确保训练末期模型仍在微调。 |
Warmup Tokens |
2B | 3B | 3B | 训练初期,模型权重随机,梯度巨大且方向不定。Warmup 阶段让 LR 从零缓慢增长,给予模型一个“热身”期来稳定梯度,防止初期就“跑飞”。 |
LR Schedule |
Cosine Decay | Cosine Decay | Cosine Decay | 经过 warmup 后,学习率按余弦曲线平滑衰减至 min_lr。这是目前最常用且效果最好的调度策略。 |
wd (Weight Decay) |
0.1 | 0.1 | 0.1 | L2 正则化,防止模型参数过大,提高泛化能力。 |
grad_clip |
1.0 | 1.0 | 1.0 | 按全局梯度范数进行裁剪,这是防止梯度爆炸、确保训练稳定的最后一道防线。 |
| 内存优化 | ||||
Activation Checkpointing |
Off | On | On | 空间换时间。开启后,在前向传播时不再存储所有中间激活,而是在反向传播时重新计算它们。这会增加约 20-30% 的计算耗时,但能极大降低显存占用,对于 7B+ 模型几乎是必需的。 |
2.2 配置表 B: 8k 上下文 (L_ctx = 8192)
扩展到 8k 上下文是提升模型长文本能力的关键。这主要对显存和位置编码提出挑战,配置需相应调整。
| 参数 (Parameter) | 3B 模型 | 7B 模型 | 13B 模型 | 说明/权衡 (Explanation/Trade-off) |
| 参数 (Parameter) | 3B 模型 | 7B 模型 | 13B 模型 | 说明/权衡 (Explanation/Trade-off) |
|---|---|---|---|---|
L_ctx (上下文长度) |
8192 | 8192 | 8192 | 上下文长度翻倍,Attention 矩阵的计算和存储开销增长为 平方级别 (L_ctx²),这是主要的性能和显存瓶颈。 |
| RoPE Scaling (回顾 Chapter 03) | ||||
RoPE Scaling Type |
YaRN | YaRN | YaRN | YaRN (Yet another RoPE extensioN) 通过插值和外推的结合,在扩展长上下文时,相比 PI 或 NTK-aware scaling 能更好地保持模型的短文本性能和长文本 PPL,是当前的主流选择。 |
RoPE Scaling Factor |
2.0 | 2.0 | 2.0 | (目标长度 / 原始长度),即 8192 / 4096 = 2.0。 |
RoPE Original Max Pos |
4096 | 4096 | 4096 | 告知 YaRN 算法原始 RoPE 的设计基线。 |
| 批量与扩展 (调整后) | ||||
Global Batch Size (GB_tok) |
2.1M | 4.2M | 4.2M | 保持不变。Scaling Law 关注的是参数更新的信噪比,主要与模型参数量和 tokens 总量相关,上下文长度对其影响较小,因此我们保持 GB_tok 来维持相似的收敛动态。 |
μ_batch (单卡样本数) |
4 (vs 8) | 2 (vs 4) | 1 (vs 2) | 核心调整。由于 L_ctx 翻倍导致 Attention 显存开销 4 倍增长,我们必须将单卡处理的样本数减半,以避免 OOM。这是最直接有效的应对策略。 |
μ_tok (单卡 micro-batch tokens) |
32,768 | 16,384 | 8,192 | μ_tok 保持不变,这是通过 μ_batch 的减半来抵消 L_ctx 翻倍的结果。维持 μ_tok 有助于保持计算核的利用率。 |
A (梯度累积步数) |
2.1M / (32k * 64) ≈ 1 |
4.2M / (16k * 64) ≈ 4 |
4.2M / (8k * 64) ≈ 8 |
由于 μ_tok 和 GB_tok 均未变,A 也不需要改变。 |
Activation Checkpointing |
On (vs Off) | On | On | 强制开启。对于 8k 上下文,即使是 3B 模型,激活值的显存占用也会非常巨大,开启激活重计算是避免 OOM 的必要手段。 |
| 其它参数 | ||||
| 其它所有参数 | 同 4k 配置 | 同 4k 配置 | 同 4k 配置 | LR、优化器、warmup 等策略通常可直接沿用。但需更密切关注稳定性,因为更长的上下文意味着更长的依赖链,能放大数值误差,增加梯度消失/爆炸的风险。 |
3. 训练日志样例解读:成为“炼丹宗师”的第一步
启动训练后,持续监控日志是你作为 AI Scientist 最重要的工作。日志是模型的“心电图”,能揭示其内部状态。
- 损失曲线 (Loss Curve)
- 健康状态: 整体呈平滑下降趋势,符合幂律分布(在 log-log 图上接近一条直线)。曲线在放大后有正常的随机噪声波动,但整体趋势稳定向下。训练初期下降快,后期逐渐放缓。
- 异常信号与诊断:
- 尖峰 (Spike): 损失突然暴增后回落。
- 日志表现:
loss从2.5突然跳到5.0,几步后又回到2.5附近。 - 可能原因: 1. 数据问题: 遇到了一个包含乱码、超长数字序列或格式错误的“毒样本”。2. 学习率过高: 在某个特定的梯度方向上子迈得太大。3. 数值不稳定:
bf16在极端情况下精度不足。 - 行动: 如果偶尔出现且能恢复,可忽略。若频繁出现,应降低 LR 或检查数据清洗流程。
- 日志表现:
- 停滞 (Plateau): 损失在数千步内几乎不再下降。
- 日志表现:
loss长期在1.85 ± 0.02范围内波动。 - 可能原因: 1. 学习率衰减过快/过低: 模型失去了进一步优化的动力。2. 数据耗尽/多样性不足: 模型已“背熟”了训练数据。3. 陷入局部最优。
- 行动: 检查 LR schedule,考虑“重启” LR (LR restart)。检查数据混比策略,是否某个高质量数据源已用尽。
- 日志表现:
- 发散 (Divergence): 损失持续增大,最终变为
NaN。- 日志表现:
loss从3.0->4.5->10.2->NaN。 - 可能原因: 灾难性问题。通常是学习率过高、梯度爆炸、RoPE scaling 实现有 bug、或严重的硬件/CUDA 问题。
- 行动: 立即停止训练。从最后一个正常的 checkpoint 回滚,大幅降低学习率(例如减半),并仔细检查代码和数据。
- 日志表现:
- 尖峰 (Spike): 损失突然暴增后回落。
[健康 log-log 图] [尖峰] [停滞] [发散]
log(loss)| loss| loss| loss|
| | spike | |
\ | / \ | |
\ | / \ | | .----
\ | / \ |----. .----'
\_____| /_______\_ | `----.----'
-----------+----->-----------+-----> -----------+-----> -----------+----->
log(iter) iter iter iter
-
梯度范数与优化器状态 (Gradient Norm & Optimizer States)
- 监控点:
- 全局梯度范数 (Global Grad Norm): 在梯度裁剪前,所有参数梯度的 L2 范数。一个健康的梯度范数应在 warmup 后稳定在一个数量级内(例如 1.0-10.0 之间),有波动但无持续增长趋势。如果该值频繁地、远超
grad_clip(1.0),说明学习过程非常不稳定,是 LR 过高的明确信号。 - DeepSpeed 溢出监控:
overflow计数器。如果该计数器不为零,说明在bf16梯度计算中出现了inf或nan。偶尔的溢出可以被跳过,但如果持续增加,说明存在严重的数值稳定性问题。
- 全局梯度范数 (Global Grad Norm): 在梯度裁剪前,所有参数梯度的 L2 范数。一个健康的梯度范数应在 warmup 后稳定在一个数量级内(例如 1.0-10.0 之间),有波动但无持续增长趋势。如果该值频繁地、远超
- 监控点:
-
吞吐率与硬件利用率 (Throughput & Utilization)
- 核心指标:
samples/sec,tokens/sec/gpu,TFLOPs/gpu。 - 目标: 对于 H100,一个经过优化的 7B 模型训练,单卡 TFLOPs 理论峰值约为 989 (
bf16),实际利用率(MFU, Model FLOPs Utilization)能达到 50-60% 即为优秀,对应约 500 TFLOPs。这转化为一个非常可观的tokens/sec。 - 诊断: 如果吞率远低于预期(例如 MFU < 30%),使用 PyTorch Profiler 分析
step时间构成。常见瓶颈包括:data_loading: I/O 或 CPU 预处理成为瓶颈。all_reduce/all_gather: ZeRO 或 TP 的通信开销过大。optimizer.step: 优化器更新步骤耗时过长,可能是 CPU offload 带来的瓶颈。
- 核心指标:
本章小结
- 配方是科学与艺术的结合: 本章提供的配置表是基于 Scaling Laws 的科学推导和大量工程实践经验的结合,是启动大规模训练的坚实基础。
- 权衡无处不在:
- 模型尺寸 vs. Batch Size: 更大模型需要更大
Global Batch Size来稳定收敛。 - 上下文长度 vs. 显存:
L_ctx增加,μ_batch必须相应减小,并强制开启激活重计算。 - 速度 vs. 内存: ZeRO-3 和激活重计算用通信/计算开销换取了宝贵的显存空间,使得大模型训练成为可能。
- 模型尺寸 vs. Batch Size: 更大模型需要更大
- 并行策略是根基:
TP=8最大节点内效率,ZeRO-3 解决跨节点显存扩展问题,这套组合拳是当前 H100 集群的标准打法。 - 监控日志是航海图: 深刻理解并持续监控损失、梯度范数和吞吐率,是诊断训练问题、避免在错误航向上浪费数百万 GPU 小时的关键。
常见陷阱与错误 (Gotchas)
-
配置“一抄到底”不调整 (Blindly Copying Configs)
- 症状: 将 13B 模型的 LR (
1.5e-4) 用于 3B 模型,发现损失下降极其缓慢;反之,将 3B 的 LR (3.0e-4) 用于 13B,训练在几百步内loss变为NaN。 - 分析: 不同规模的模型对学习率的敏感度不同。这是一个必须根据模型规模和
GB_tok进行调整的核心参数。 - 对策: 遵循“模型越大、LR 越小”的原则。调整
GB_tok时,参考 Chapter 5 的linear或sqrtscaling 法则来调整 LR。
- 症状: 将 13B 模型的 LR (
-
忽视早期的损失尖峰 (Ignoring Early Loss Spikes)
- 症状: 训练前 1k 步,
loss出现数次 > 5.0 的尖峰,但最终都“恢复”了。你认为这很正常并继续训练。 - 分析: 早期的尖峰是训练不稳定的强烈信号。即使恢复,也可能已经对模型权重造成了不可逆的“伤害”,影响最终性能。
- 对策: 立即暂停。抓取导致尖峰的那个 batch 的数据样本,进行人工检查。90% 的可能是数据质量问题。同时,可以考虑延长 warmup 步数或降低初始 LR。
- 症状: 训练前 1k 步,
-
8k 上下文 OOM 的隐蔽原因 (Subtle OOMs with 8k Context)
- 症状: 将配置从 4k 改为 8k,
μ_batch也已减半,但依然在训练中途随机 OOM。 - 分析: 除了 Attention 矩阵,某些中间激活或临时变量也可能与
L_ctx相关。例如,如果某个自定义的fused操作内部实现不佳,可能会产生一个巨大的临时张量。 - 对策: 使用
torch.cuda.memory_summary()或torch.cuda.max_memory_allocated()在代码关键位插入打印,精确定位 OOM 发生在哪个操作之后。实在不行,只能进一步减小μ_batch。
- 症状: 将配置从 4k 改为 8k,
-
无声的数据加载瓶颈 (The Silent Data Loading Bottleneck)
- 症状:
loss曲线正常,没有错误,但TFLOPs利用率只有 20%,远低于预期。GPU 利用率在监控工具中显示为周期性的“波谷”。 - 分析: 这是典型的数据供给慢于 GPU 计算。GPU 在每个
step开始时都在等待数据。 - 对策: 1. 增加
DataLoader的num_workers。2. 开启pin_memory=True。3. 检查数据预处理逻辑是否有 Python 全局锁(GIL)瓶颈。4. 压测 CPFS 的并发读取性能,确认不是存储端的问题。
- 症状:
-
RoPE Scaling 的精度陷阱 (Precision Pitfalls in RoPE Scaling)
- 症状: 训练 8k 模型时,偶尔出现
NaN,尤其是在梯度累积步数较多的情况下。 - 分析: RoPE 的计算涉及三角函数和旋转操作。在长序列和
bf16精度下,些中间计算结果可能超出bf16的表示范围或损失过多精度,最终导致inf或nan在 attention score 中传播。 - 对策: 强制 RoPE 的部分计算(例如生成
cos和sin值)在fp32下进行,然后再 cast 回bf16。这会带来微小的性能开销,但能显著提升数值稳定性。
- 症状: 训练 8k 模型时,偶尔出现