第十二章 端到端:CPT / 继续预训练

开篇段落

欢迎来到第十二章,本章将深入探讨继续预训练(Continued Pre-Training, CPT)的端到端实践,这是一项在现代大模型生命周期管理中愈发核心的技术。与从零预训练(Pre-training from scratch)那种“一次性”的巨大投入不同,CPT 是一种更具经济效益和战略价值的“增量式”演进。它允许我们为一个已经强大的基座模型注入新领域的专业知识、更新过时的世界信息,或强化特定的技能,从而打造一个“活的”、持续进化的模型。本章将系统性地拆解 CPT 的核心哲学,深入剖析其在数据策略、优化器与学习率调度、监控与评等方面的关键差异,并提供一套可立即上手的配置指南与诊断清单。我们的目标是让你掌握在“新域能力拉升”与“基座泛化保持”之间取得精妙平衡的艺术,从而高效地实现模型价值的最大化。

CPT 的哲学:为何选择 CPT 而非 SFT 或从零训练?

在深入技术细节前,我们必须明确 CPT 的定位。它处在从零预训练和指令微调(Supervised Fine-Tuning, SFT)之间,解决的是不同的问题。

  • 从零预训练:构建模型的基础世界模型语言能力。这是一个无监督或自监督的过程,需要海量的通用数据和巨大的计算资源。
  • 指令微调(SFT):教会模型如何与人对话,遵循指令的格式和行为。SFT 的数据通常是 (prompt, response) 对,其核心是塑造模型的行为模式,而非注入大规模的事实性知识。
  • 继续预训练(CPT):在不破坏基础世界模型的前提下,扩展型的知识边界适应新的数据分布。CPT 的数据仍然是海量的非结构化文本,其目标是让模型“阅读”和“吸收”新领域的知识,使其语言模型本身(即对下一个 token 的预测分布)适应新的领域。

一句话总结:CPT 负责“学什么”(知识),SFT 负责“怎么说”(行为)。 当你的目标是让模型成为一个“医学专家”而不仅仅是“扮演医学专家的聊天机器人”时,CPT 是不可或缺的一步。

核心挑战:驾驭灾难性遗忘的艺术

CPT 的核心挑战是灾难性遗忘(Catastrophic Forgetting)。我们可以从优化理论的角度理解它:

一个预训练好的基座模型,其权重 θ_base 处在通用数据分布 P_general 的一个宽阔、平坦的损失极小值区域(a wide, flat loss minimum)。这个区域的平坦性赋予了模型良好的泛化能力。

CPT 的目标是找到一个新的权重 θ_cpt,使其在混合数据分布 P_mix = (1-α)P_general + αP_new上表现更优。然而,新领域数据 P_new 对应的损失曲面可能非常陡峭或与 P_general 的极小值区域相距甚远。如果训练策略过于激进(如学习率过高、新数据比例过大),优化器会把模型权重 θ 猛地“拽”出 P_general 的平坦区域,掉入一个只对 P_new 友好的狭窄深谷。模型在新领域上表现优异,但彻底忘记了如何在通用领域上泛化——这就是灾难性遗忘。

Loss Landscape
      ▲ Loss
      |
      |             /-----\
      |            /       \      <-- P_new's sharp minimum
      |           /         * θ_cpt_bad
  /---\          /
 /     \--------/
/       * θ_base

* θ_cpt_good

--------------------------------------------> Weight Space

CPT 的所有策略,都是为了让 θθ_base 温和地移动到 θ_cpt_good,一个同时兼顾两个分布的新平衡点,而不是被拉到灾难性的 θ_cpt_bad

一、数据策略CPT 的灵魂与基石

数据是 CPT 中最关键的变量,它直接决定了学习的效果和遗忘的程度。

1. 混比策略:黄金比例的探索

单纯使用新数据是 CPT 的大忌。数据混合是必须的

  • 混比比例(Mixing Ratio, α:即新数据在新旧混合数据中的占比。
    • 经验法则α 通常设定在 5% 到 20% 之间。一个稳妥的起点是 10%。
    • 权衡
      • α(如 5%):学习新知识较慢,需要更多训练步数,但遗忘风险极低,非常安全。适用于基座能力绝对不能受损的关键任务。
      • α(如 20%):学习新知识速度快,但遗忘风险显著增高。适用于新领域与旧领域差异不大,或对基座泛化能力有一定容忍度的场景。
  • "Replay Buffer" 思想:作为旧数据的 P_general,不一定需要使用全部的原始预训练语料。可以精心挑选一个高质量、多样化的子集(例如 1T tokens 的原始数据中,精选 200B tokens),作为“记忆回放缓冲池”。这可以显著降低数据存储和 I/O 负担。

2. 混比粒度:动态采样的优越性

如何实现数据混合?有两种主要方式:

  • 静态混合(预先混合):在训练开始前,将新旧数据按比例混合并打乱,制作成新的训练数据集。
    • 优点:实现简单,数据加载逻辑不变。
    • 缺点:不够灵活,无法在训练中调整比例。可能出现连续多个 batch 来自同一数据源的“数据热点”问题。
  • 动态混合(在线采样):在训练时,数据加载器从不同的数据源(旧数据流、新数据流)中,按预设的概率或温度动态采样,实时组合成一个 batch。
    • 优点
      1. 均匀混合:确保每个 Global Batch 的数据构成都接近目标比例,训练过程更平滑。
      2. 灵活性:可以轻松实现课程学习(Curriculum Learning),例如在训练初期使用 10% 的新数据,后期逐步提升到 15%。
    • 实现:在 PyTorch Lightning 中,可以通过自定义 DataModuleIterableDataset 来实现,根据设定的温度 τ 对不同数据集的采样概率进行加权。

3. 数据质量与序列长度

  • 新数据质量:CPT 对新领域数据的质量极其敏感。低质量、噪声大的新数据会严重干扰模型的学习过程,甚至污染已有知识。必须进行严格的清洗、去重和过滤。
  • 序列长度分布:检查新旧数据的序列长度分布是否匹配。如果新领域(如法律文书)平均长度远超通用领域,可能会导致模型在长文本建模上产生偏向。此时,需要确保 Packer 能有效处理,并且在评估时也关注不同长度文本的表现。

二、优化器与学习率:精细的手术刀

如果说数据是药物,那么优化器和学习率就是控制剂量的手术刀。

1. 学习率(LR:低、缓、稳

  • 峰值学习率(Peak LR)
    • 核心法则:CPT 的峰值 LR 应显著低于从零预训练,通常是原始峰值 LR 的 1/10 到 1/5。例如,7B 模型从零训练的 peak LR 为 3e-4,CPT 时应选择 3e-56e-5
    • 原因:模型权重已处于良好状态,梯度更新应是微调而非重塑。过高的 LR 会产生巨大的梯度,将权重推出稳定区域。
  • 学习率调度器(LR Scheduler)
    • Warmup:可以非常短,甚至省略。一个几百步的 warmup 足以让重置后的优化器稳定下来。
    • Decay 策略Cosine Decay 依然是黄金标准。它能确保学习率在训练结束时平滑地降至接近零,有助于模型收敛到更稳定的点。
    • 总步数(Total Steps):LR 调度器的总步数必须根据 CPT 的总训练 tokens(例如 200B)重新计算,而不是沿用原始的 1T tokens 对应的步数。这是一个非常常见的错误。

2. 优化器状态:必须从零开始

这是一个非黑即白的规则:绝对不要加载 Checkpoint 中的优化器状态

  • 技术解释:AdamW 等优化器维护着梯度的一阶矩(m,动量)和二阶矩(v,自适应学习率的分母)。这些状态编码了基于旧数据分布旧学习率的梯度历史。
    • 旧的 m 会带来巨大的惯性,可能在 CPT 初期将模型推向错误的方向。
    • 旧的 v 适应了旧的梯度尺度,直接用于新的梯度可能会导致某些参数的学习率过大或过小,造成不稳定。
  • 正确操作:加载模型权重后,重新初始化一个全新的优化器实例。让它在新的数据混合和学习率下,从头开始累积梯度统计信息。

3. 模型冻结策略:原则上不冻结

  • 主流选择全参数训练。CPT 的目标是让知识渗透到模型的每一部分,从底层的词嵌入到高层的概念推理。冻结任何层都可能阻碍这种全局性的适应。
  • 实验性探索(高风险):在极少数情况下,如果极端担心某些基础能力(如语法结构)受损,可以尝试冻结模型的最底层几层 Transformer Block。但这通常会牺牲在新领域上的学习效果,不推荐作为首选策略。

三、监控与评估:CPT 的仪表盘

CPT 的成功无法仅凭训练损失来判断,必须建立一个多维度的评估体系。

  • 核心监控指标
    1. 训练损失(Training Loss):应平稳下降。若出现剧烈抖动,检查数据混合或学习率。
    2. 领域内验证困惑度(In-Domain Val PPL):使用一个 held-out 的新领域验证集。这是衡量 CPT 学习效果的核心指标,应持续下降。
    3. 通用域验证困惑度(General Val PPL):使用一个 held-out 的通用领域验证集(如 C4, Pile 的一部分)。这是衡量遗忘程度的核心“护栏”指标,应保持稳定或仅轻微上涨。
  • 诊断图谱
(A) 理想的 CPT        (B) 灾难性遗忘         (C) 学习不足
 PPL                  PPL                   PPL
|                     |                      |
|---- Gen PPL ----    |---- Gen PPL ---     |---- Gen PPL ----
|   \                 |   \              /   |   \
|    \                |    \            /    |    \ In-Domain PPL
|     \ In-Domain PPL |     \          /     |     (下降缓慢)
|      \              |      \ In-Domain PPL |      
+-----------> Time    +-----------> Time     +-----------> Time
定期绘制这两条 PPL 曲线,上图 (A) 是我们的目标。出现 (B) 则立即降低 LR 或新数据比例;出现 (C) 则可考虑适度增加它们。

端到端配置清单:一个 7B 模型的 CPT 示例(扩展示例)

| 参数项 | 从零预训练(参考基线) | CPT(推荐配置) | 理由与深度剖析 |

参数项 从零预训练(参考基线) CPT(推荐配置) 理由与深度剖析
基座模型 None (机初始化) path/to/pretrained_7B_model.pt CPT 的基础是加载一个训练成熟的模型权重。
总训练 Tokens 1T 200B CPT 是短程训练,目标是适应而非重塑。200B tokens 通常足以在一个新领域达到不错的性能。
数据混比 (α) 100% 通用数据 90% 通用数据 + 10% 新领域数据 CPT 核心。10% 是一个平衡的起点,在安全和效率之间取得平衡。
数据混合方式 N/A 动态采样 (Dynamic Sampling) 保证每个 batch 的数据分布均匀,避免训练不稳,并为课程学习提供可能。
Global Batch Size 4M tokens 2M - 4M tokens 可以沿用,或适当减小。因为 LR 已经降低,过大的 batch 带来的梯度信噪比增益不再是首要矛盾。
峰值学习率 (Peak LR) 3.0e-4 3.0e-5 CPT 核心。1/10 的原始 LR 是一个安全的起点,它以足够小的步长探索新的损失极小值区域。
LR Warmup 步数 2000 200 模型已处在稳定状态,一个非常短的 warmup 即可让新优化器适应梯度。
LR Scheduler Cosine Decay (over ~250k steps) Cosine Decay (over ~50k steps) 易错点:调度器的总步数必须基于新的、更短的训练计划重新计算。
优化器状态加载 False False (强制) CPT 核心。必须丢弃旧的 mv 向量,它们包含着不适用于新任务的梯度历史。
AdamW Betas (0.9, 0.95) (0.9, 0.95) 通常无需更改。AdamW 的 betas 对 CPT 不如 LR 敏感。
Weight Decay 0.1 0.1 可以保持不变。过高的 weight decay 反而可能阻碍模型向新权重空间的适应。
模型层冻结 无(全参数训练) 默认进行全参数更新,以实现知识在整个模型中的充分渗透。

本章小结

  • CPT 是介于预训练和微调之间的关键技术,旨在扩展知识而非改变行为,其核心挑战是克服灾性遗忘
  • 成功的 CPT 依赖于三大支柱的精细调校:
    1. 数据策略新旧数据混合是铁律,5%-20% 的新数据比例是常见区间。推荐使用动态采样以获得更平滑的训练过程。
    2. 学习率与优化器:采用远低于原始训练的峰值学习率(如 1/10),配合短 warmup 和对齐新总步数的 Cosine Decay。必须重置优化器状态
    3. 监控体系:同时跟踪领域内通用域的验证 PPL,是诊断 CPT 是否成功的“仪表盘”。
  • CPT 是一场精细的平衡实验,需要耐心调参和细致观察,但其带来的模型能力提升和资源节约是巨大的。

常见陷阱与错误 (Gotchas)

  1. 陷阱:灾难性遗忘

    • 症状:通用域验证 PPL 急剧上升,模型在通用任务上表现得像一个“傻瓜”。
    • 诊断与修复:降低学习率(减半),或降低新数据混合比例(例如从 15% 降至 8%。这是最常见的 CPT 问题。
  2. 陷阱:学习停滞

    • 症状:领域内验证 PPL 下降缓慢或完全不下降。
    • 诊断与修复:首先检查新数据质量。若数据无误,可尝试适度提高学习率或新数据比例。也可能是总训练 tokens 不足,模型需要“阅读”更多新数据。
  3. 陷阱:加载了优化器状态

    • 症状:训练初期的 loss 出现剧烈震荡、NaN 或不收敛。
    • 诊断与修复:检查你的模型加载和优化器初始化代码。确保优化器是在加载模型权重 之后 重新创建的。
  4. 陷阱:Tokenizer 不匹配

    • 症状:模型在处理新领域术语时表现很差,频繁出现 OOV (Out-of-Vocabulary) 或将专有名词切分成无意义的碎片。
    • 诊断与修复:这是一个更深层次的问题。如果新领域引入了大量新词汇,可能需要扩展 Tokenizer。这需要:1) 训练新的 Tokenizer;2) 扩展模型的 token_embeddinglm_head 层;3) 特殊处理新旧 embedding 的初始化。这已经超出了标准 CPT 的范畴,需要专门的模型手术,务必谨慎。
  5. 陷阱:隐蔽的能力衰退

    • 症状:通用 PPL 保持稳定,但模型在某些特定能力(如数学推理、代码生成)上表现变差。
    • 诊断与修复:PPL 是一个宏观指标。建立一个包含多种能力的评估基准测试集(Benchmark Suite),在 CPT 前后运行,以量化特定能力的损益。如果发现关键能力受损,可能需要在“旧数据”的 replay buffer 中,有针对性地增加该能力对应的数据比例。