第十二章 端到端:CPT / 继续预训练
开篇段落
欢迎来到第十二章,本章将深入探讨继续预训练(Continued Pre-Training, CPT)的端到端实践,这是一项在现代大模型生命周期管理中愈发核心的技术。与从零预训练(Pre-training from scratch)那种“一次性”的巨大投入不同,CPT 是一种更具经济效益和战略价值的“增量式”演进。它允许我们为一个已经强大的基座模型注入新领域的专业知识、更新过时的世界信息,或强化特定的技能,从而打造一个“活的”、持续进化的模型。本章将系统性地拆解 CPT 的核心哲学,深入剖析其在数据策略、优化器与学习率调度、监控与评等方面的关键差异,并提供一套可立即上手的配置指南与诊断清单。我们的目标是让你掌握在“新域能力拉升”与“基座泛化保持”之间取得精妙平衡的艺术,从而高效地实现模型价值的最大化。
CPT 的哲学:为何选择 CPT 而非 SFT 或从零训练?
在深入技术细节前,我们必须明确 CPT 的定位。它处在从零预训练和指令微调(Supervised Fine-Tuning, SFT)之间,解决的是不同的问题。
- 从零预训练:构建模型的基础世界模型和语言能力。这是一个无监督或自监督的过程,需要海量的通用数据和巨大的计算资源。
- 指令微调(SFT):教会模型如何与人对话,遵循指令的格式和行为。SFT 的数据通常是
(prompt, response)对,其核心是塑造模型的行为模式,而非注入大规模的事实性知识。 - 继续预训练(CPT):在不破坏基础世界模型的前提下,扩展型的知识边界和适应新的数据分布。CPT 的数据仍然是海量的非结构化文本,其目标是让模型“阅读”和“吸收”新领域的知识,使其语言模型本身(即对下一个 token 的预测分布)适应新的领域。
一句话总结:CPT 负责“学什么”(知识),SFT 负责“怎么说”(行为)。 当你的目标是让模型成为一个“医学专家”而不仅仅是“扮演医学专家的聊天机器人”时,CPT 是不可或缺的一步。
核心挑战:驾驭灾难性遗忘的艺术
CPT 的核心挑战是灾难性遗忘(Catastrophic Forgetting)。我们可以从优化理论的角度理解它:
一个预训练好的基座模型,其权重 θ_base 处在通用数据分布 P_general 的一个宽阔、平坦的损失极小值区域(a wide, flat loss minimum)。这个区域的平坦性赋予了模型良好的泛化能力。
CPT 的目标是找到一个新的权重 θ_cpt,使其在混合数据分布 P_mix = (1-α)P_general + αP_new上表现更优。然而,新领域数据 P_new 对应的损失曲面可能非常陡峭或与 P_general 的极小值区域相距甚远。如果训练策略过于激进(如学习率过高、新数据比例过大),优化器会把模型权重 θ 猛地“拽”出 P_general 的平坦区域,掉入一个只对 P_new 友好的狭窄深谷。模型在新领域上表现优异,但彻底忘记了如何在通用领域上泛化——这就是灾难性遗忘。
Loss Landscape
▲ Loss
|
| /-----\
| / \ <-- P_new's sharp minimum
| / * θ_cpt_bad
/---\ /
/ \--------/
/ * θ_base
* θ_cpt_good
--------------------------------------------> Weight Space
CPT 的所有策略,都是为了让 θ 从 θ_base 温和地移动到 θ_cpt_good,一个同时兼顾两个分布的新平衡点,而不是被拉到灾难性的 θ_cpt_bad。
一、数据策略CPT 的灵魂与基石
数据是 CPT 中最关键的变量,它直接决定了学习的效果和遗忘的程度。
1. 混比策略:黄金比例的探索
单纯使用新数据是 CPT 的大忌。数据混合是必须的。
- 混比比例(Mixing Ratio,
α):即新数据在新旧混合数据中的占比。- 经验法则:
α通常设定在 5% 到 20% 之间。一个稳妥的起点是 10%。 - 权衡:
- 低
α(如 5%):学习新知识较慢,需要更多训练步数,但遗忘风险极低,非常安全。适用于基座能力绝对不能受损的关键任务。 - 高
α(如 20%):学习新知识速度快,但遗忘风险显著增高。适用于新领域与旧领域差异不大,或对基座泛化能力有一定容忍度的场景。
- 低
- 经验法则:
- "Replay Buffer" 思想:作为旧数据的
P_general,不一定需要使用全部的原始预训练语料。可以精心挑选一个高质量、多样化的子集(例如 1T tokens 的原始数据中,精选 200B tokens),作为“记忆回放缓冲池”。这可以显著降低数据存储和 I/O 负担。
2. 混比粒度:动态采样的优越性
如何实现数据混合?有两种主要方式:
- 静态混合(预先混合):在训练开始前,将新旧数据按比例混合并打乱,制作成新的训练数据集。
- 优点:实现简单,数据加载逻辑不变。
- 缺点:不够灵活,无法在训练中调整比例。可能出现连续多个 batch 来自同一数据源的“数据热点”问题。
- 动态混合(在线采样):在训练时,数据加载器从不同的数据源(旧数据流、新数据流)中,按预设的概率或温度动态采样,实时组合成一个 batch。
- 优点:
- 均匀混合:确保每个 Global Batch 的数据构成都接近目标比例,训练过程更平滑。
- 灵活性:可以轻松实现课程学习(Curriculum Learning),例如在训练初期使用 10% 的新数据,后期逐步提升到 15%。
- 实现:在 PyTorch Lightning 中,可以通过自定义
DataModule或IterableDataset来实现,根据设定的温度τ对不同数据集的采样概率进行加权。
- 优点:
3. 数据质量与序列长度
- 新数据质量:CPT 对新领域数据的质量极其敏感。低质量、噪声大的新数据会严重干扰模型的学习过程,甚至污染已有知识。必须进行严格的清洗、去重和过滤。
- 序列长度分布:检查新旧数据的序列长度分布是否匹配。如果新领域(如法律文书)平均长度远超通用领域,可能会导致模型在长文本建模上产生偏向。此时,需要确保 Packer 能有效处理,并且在评估时也关注不同长度文本的表现。
二、优化器与学习率:精细的手术刀
如果说数据是药物,那么优化器和学习率就是控制剂量的手术刀。
1. 学习率(LR:低、缓、稳
- 峰值学习率(Peak LR):
- 核心法则:CPT 的峰值 LR 应显著低于从零预训练,通常是原始峰值 LR 的 1/10 到 1/5。例如,7B 模型从零训练的 peak LR 为
3e-4,CPT 时应选择3e-5到6e-5。 - 原因:模型权重已处于良好状态,梯度更新应是微调而非重塑。过高的 LR 会产生巨大的梯度,将权重推出稳定区域。
- 核心法则:CPT 的峰值 LR 应显著低于从零预训练,通常是原始峰值 LR 的 1/10 到 1/5。例如,7B 模型从零训练的 peak LR 为
- 学习率调度器(LR Scheduler):
- Warmup:可以非常短,甚至省略。一个几百步的 warmup 足以让重置后的优化器稳定下来。
- Decay 策略:Cosine Decay 依然是黄金标准。它能确保学习率在训练结束时平滑地降至接近零,有助于模型收敛到更稳定的点。
- 总步数(Total Steps):LR 调度器的总步数必须根据 CPT 的总训练 tokens(例如 200B)重新计算,而不是沿用原始的 1T tokens 对应的步数。这是一个非常常见的错误。
2. 优化器状态:必须从零开始
这是一个非黑即白的规则:绝对不要加载 Checkpoint 中的优化器状态。
- 技术解释:AdamW 等优化器维护着梯度的一阶矩(
m,动量)和二阶矩(v,自适应学习率的分母)。这些状态编码了基于旧数据分布和旧学习率的梯度历史。- 旧的
m会带来巨大的惯性,可能在 CPT 初期将模型推向错误的方向。 - 旧的
v适应了旧的梯度尺度,直接用于新的梯度可能会导致某些参数的学习率过大或过小,造成不稳定。
- 旧的
- 正确操作:加载模型权重后,重新初始化一个全新的优化器实例。让它在新的数据混合和学习率下,从头开始累积梯度统计信息。
3. 模型冻结策略:原则上不冻结
- 主流选择:全参数训练。CPT 的目标是让知识渗透到模型的每一部分,从底层的词嵌入到高层的概念推理。冻结任何层都可能阻碍这种全局性的适应。
- 实验性探索(高风险):在极少数情况下,如果极端担心某些基础能力(如语法结构)受损,可以尝试冻结模型的最底层几层 Transformer Block。但这通常会牺牲在新领域上的学习效果,不推荐作为首选策略。
三、监控与评估:CPT 的仪表盘
CPT 的成功无法仅凭训练损失来判断,必须建立一个多维度的评估体系。
- 核心监控指标:
- 训练损失(Training Loss):应平稳下降。若出现剧烈抖动,检查数据混合或学习率。
- 领域内验证困惑度(In-Domain Val PPL):使用一个 held-out 的新领域验证集。这是衡量 CPT 学习效果的核心指标,应持续下降。
- 通用域验证困惑度(General Val PPL):使用一个 held-out 的通用领域验证集(如 C4, Pile 的一部分)。这是衡量遗忘程度的核心“护栏”指标,应保持稳定或仅轻微上涨。
- 诊断图谱:
(A) 理想的 CPT (B) 灾难性遗忘 (C) 学习不足
▲ PPL ▲ PPL ▲ PPL
| | |
|---- Gen PPL ---- |---- Gen PPL ---▲ |---- Gen PPL ----
| \ | \ / | \
| \ | \ / | \ In-Domain PPL
| \ In-Domain PPL | \ / | (下降缓慢)
| \ | \ In-Domain PPL |
+-----------> Time +-----------> Time +-----------> Time
定期绘制这两条 PPL 曲线,上图 (A) 是我们的目标。出现 (B) 则立即降低 LR 或新数据比例;出现 (C) 则可考虑适度增加它们。
端到端配置清单:一个 7B 模型的 CPT 示例(扩展示例)
| 参数项 | 从零预训练(参考基线) | CPT(推荐配置) | 理由与深度剖析 |
| 参数项 | 从零预训练(参考基线) | CPT(推荐配置) | 理由与深度剖析 |
|---|---|---|---|
| 基座模型 | None (机初始化) |
path/to/pretrained_7B_model.pt |
CPT 的基础是加载一个训练成熟的模型权重。 |
| 总训练 Tokens | 1T |
200B |
CPT 是短程训练,目标是适应而非重塑。200B tokens 通常足以在一个新领域达到不错的性能。 |
数据混比 (α) |
100% 通用数据 | 90% 通用数据 + 10% 新领域数据 | CPT 核心。10% 是一个平衡的起点,在安全和效率之间取得平衡。 |
| 数据混合方式 | N/A | 动态采样 (Dynamic Sampling) | 保证每个 batch 的数据分布均匀,避免训练不稳,并为课程学习提供可能。 |
| Global Batch Size | 4M tokens |
2M - 4M tokens |
可以沿用,或适当减小。因为 LR 已经降低,过大的 batch 带来的梯度信噪比增益不再是首要矛盾。 |
| 峰值学习率 (Peak LR) | 3.0e-4 |
3.0e-5 |
CPT 核心。1/10 的原始 LR 是一个安全的起点,它以足够小的步长探索新的损失极小值区域。 |
| LR Warmup 步数 | 2000 |
200 |
模型已处在稳定状态,一个非常短的 warmup 即可让新优化器适应梯度。 |
| LR Scheduler | Cosine Decay (over ~250k steps) | Cosine Decay (over ~50k steps) | 易错点:调度器的总步数必须基于新的、更短的训练计划重新计算。 |
| 优化器状态加载 | False |
False (强制) |
CPT 核心。必须丢弃旧的 m 和 v 向量,它们包含着不适用于新任务的梯度历史。 |
| AdamW Betas | (0.9, 0.95) |
(0.9, 0.95) |
通常无需更改。AdamW 的 betas 对 CPT 不如 LR 敏感。 |
| Weight Decay | 0.1 |
0.1 |
可以保持不变。过高的 weight decay 反而可能阻碍模型向新权重空间的适应。 |
| 模型层冻结 | 无 | 无(全参数训练) | 默认进行全参数更新,以实现知识在整个模型中的充分渗透。 |
本章小结
- CPT 是介于预训练和微调之间的关键技术,旨在扩展知识而非改变行为,其核心挑战是克服灾性遗忘。
- 成功的 CPT 依赖于三大支柱的精细调校:
- 数据策略:新旧数据混合是铁律,
5%-20%的新数据比例是常见区间。推荐使用动态采样以获得更平滑的训练过程。 - 学习率与优化器:采用远低于原始训练的峰值学习率(如 1/10),配合短 warmup 和对齐新总步数的 Cosine Decay。必须重置优化器状态。
- 监控体系:同时跟踪领域内和通用域的验证 PPL,是诊断 CPT 是否成功的“仪表盘”。
- 数据策略:新旧数据混合是铁律,
- CPT 是一场精细的平衡实验,需要耐心调参和细致观察,但其带来的模型能力提升和资源节约是巨大的。
常见陷阱与错误 (Gotchas)
-
陷阱:灾难性遗忘
- 症状:通用域验证 PPL 急剧上升,模型在通用任务上表现得像一个“傻瓜”。
- 诊断与修复:降低学习率(减半),或降低新数据混合比例(例如从 15% 降至 8%。这是最常见的 CPT 问题。
-
陷阱:学习停滞
- 症状:领域内验证 PPL 下降缓慢或完全不下降。
- 诊断与修复:首先检查新数据质量。若数据无误,可尝试适度提高学习率或新数据比例。也可能是总训练 tokens 不足,模型需要“阅读”更多新数据。
-
陷阱:加载了优化器状态
- 症状:训练初期的 loss 出现剧烈震荡、
NaN或不收敛。 - 诊断与修复:检查你的模型加载和优化器初始化代码。确保优化器是在加载模型权重 之后 重新创建的。
- 症状:训练初期的 loss 出现剧烈震荡、
-
陷阱:Tokenizer 不匹配
- 症状:模型在处理新领域术语时表现很差,频繁出现 OOV (Out-of-Vocabulary) 或将专有名词切分成无意义的碎片。
- 诊断与修复:这是一个更深层次的问题。如果新领域引入了大量新词汇,可能需要扩展 Tokenizer。这需要:1) 训练新的 Tokenizer;2) 扩展模型的
token_embedding和lm_head层;3) 特殊处理新旧 embedding 的初始化。这已经超出了标准 CPT 的范畴,需要专门的模型手术,务必谨慎。
-
陷阱:隐蔽的能力衰退
- 症状:通用 PPL 保持稳定,但模型在某些特定能力(如数学推理、代码生成)上表现变差。
- 诊断与修复:PPL 是一个宏观指标。建立一个包含多种能力的评估基准测试集(Benchmark Suite),在 CPT 前后运行,以量化特定能力的损益。如果发现关键能力受损,可能需要在“旧数据”的 replay buffer 中,有针对性地增加该能力对应的数据比例。