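Chapter 2: Experiment Code Infrastructure
A maintainable, extensible experiment codebase is the foundation of successful LLM post-training. This chapter shows how to design and implement robust experiment infrastructure, covering configuration management, version control, and experiment tracking. We focus on practical engineering challenges:
- keeping code quality high while iterating quickly
- managing the configs and results of hundreds of experiments
- stopping technical debt from piling up
2.1 Experiment Configuration Management
Choosing a Configuration File Format
Configuration management in LLM post-training projects is considerably more complex than in traditional deep learning projects. A typical experiment can involve a hundred or more hyperparameters. They span model architecture, training strategy, data processing, and evaluation metrics, so choosing the right configuration format matters.
Pros and Cons of YAML Configuration
YAML is popular because it is easy to read, and it is especially good at expressing nested structures: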
model:
  architecture: llama2
  hidden_size: 4096
  num_layers: 32
  attention:
    num_heads: 32
    head_dim: 128
    rotary_embedding:
      base: 10000
      scaling_factor: 1.0
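Pros:
- Best human readability, which makes configs easy to review
- Supports comments, so configs can document themselves
- Clear hierarchy, well suited to complex configurations
- Mature ecosystem and tooling
Cons:
- Indentation-sensitive and therefore error-prone
- Implicit typing can surprise you (for example, YAML 1.1 parsers such as PyYAML read an unquoted no as the boolean false)
- No built-in variable interpolation or computed expressions (anchors and aliases can reuse a node but cannot compute a value)
- Parsing large files is relatively slow
Trade-offs of TOML Configuration
TOML has been gaining ground in the Rust and Python communities (it is the format of Cargo.toml and pyproject.toml). It also has a stricter syntax: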
[model]
architecture = "llama2"
hidden_size = 4096
num_layers = 32
[model.attention]
num_heads = 32
head_dim = 128
[model.attention.rotary_embedding]
base = 10000
scaling_factor = 1.0
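Pros:
- Explicit syntax with little ambiguity
- Native date and time types
- Array-of-tables syntax is convenient for batches of experiment configs
Cons:
- Readability degrades with deep nesting
- Array and inline-table syntax is comparatively fiddly
- Younger ecosystem (for example, Python's standard library gained a TOML reader, tomllib, only in 3.11, and still has no writer)
The Flexibility of Python Configuration
Using Python files directly as configuration gives you maximum flexibility: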
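from dataclasses import dataclass


@dataclass
class ModelConfig:
    architecture: str = "llama2"
    hidden_size: int = 4096
    num_layers: int = 32
    vocab_size: int = 32000  # added so the estimate below is self-contained

    @property
    def total_params(self) -> int:
        # Derived value, computed on demand: a rough estimate of ~12 * h^2 parameters
        # per decoder layer (attention + MLP) plus input and output embeddings
        per_layer = 12 * self.hidden_size ** 2
        embeddings = 2 * self.vocab_size * self.hidden_size
        return per_layer * self.num_layers + embeddings

    def scale_model(self, factor: float):
        """Scale the model size dynamically"""
        self.hidden_size = int(self.hidden_size * factor)
        self.num_layers = int(self.num_layers * factor)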
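Pros:
- Supports computed values and conditional logic
- Full type checking and IDE support
- Can reuse code and import modules
- Supports validation and default values
Cons:
- Security risk: loading a config means executing arbitrary code
- Hard for non-engineers to edit
- Diffs are harder to read in version control
Configuration Inheritance and Overrides
In practice we usually need one base configuration plus many experiment variants. A well-designed inheritance mechanism greatly reduces duplication across configs: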
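import copy
from typing import Any, Dict

import yaml  # PyYAML; assumes the config files are YAML


class ConfigManager:
    def __init__(self, base_config_path: str):
        self.base_config = self.load_config(base_config_path)
        self.inheritance_chain = [base_config_path]

    @staticmethod
    def load_config(path: str) -> Dict[str, Any]:
        with open(path) as f:
            return yaml.safe_load(f) or {}

    def inherit_from(self, parent_config_path: str):
        """Multi-level inheritance: values already loaded override the parent's"""
        parent_config = self.load_config(parent_config_path)
        self.base_config = self.deep_merge(parent_config, self.base_config)
        self.inheritance_chain.append(parent_config_path)

    def override(self, overrides: Dict[str, Any]):
        """Apply command-line overrides given as dotted keys, e.g. {"model.hidden_size": 2048}"""
        for key_path, value in overrides.items():
            self.set_nested_value(key_path, value)

    def set_nested_value(self, key_path: str, value: Any):
        """Set a value by dotted path, creating intermediate dicts as needed"""
        node = self.base_config
        *parents, leaf = key_path.split(".")
        for key in parents:
            node = node.setdefault(key, {})
        node[leaf] = value

    def deep_merge(self, base: Dict, override: Dict) -> Dict:
        """Recursively merge config dicts; values in `override` win"""
        result = copy.deepcopy(base)  # deep copy so the inputs are never mutated
        for key, value in override.items():
            if key in result and isinstance(result[key], dict) and isinstance(value, dict):
                result[key] = self.deep_merge(result[key], value)
            else:
                result[key] = copy.deepcopy(value)
        return result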
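The override() method takes values that already have the right types, but anything read from the command line arrives as a string. Below is a minimal sketch of turning key=value tokens into typed overrides. It assumes a Hydra-like key=value syntax, and parse_cli_overrides is a hypothetical helper name. Each value is tried as a Python literal first, and anything that does not parse stays a plain string.

```python
import ast
from typing import Any, Dict, List


def parse_cli_overrides(tokens: List[str]) -> Dict[str, Any]:
    """Turn ["training.lr=3e-4", "model.name=llama2"] into typed dotted-key overrides."""
    overrides: Dict[str, Any] = {}
    for token in tokens:
        key, sep, raw = token.partition("=")
        if not sep or not key:
            raise ValueError(f"Expected key=value, got {token!r}")
        lowered = raw.strip().lower()
        if lowered in ("true", "false"):
            # YAML/JSON-style booleans; Python literal syntax only knows True/False
            value: Any = lowered == "true"
        else:
            try:
                # Numbers, lists, dicts, quoted strings, None
                value = ast.literal_eval(raw)
            except (ValueError, SyntaxError):
                # Bare words such as llama2-7b stay strings
                value = raw
        overrides[key.strip()] = value
    return overrides
```

Typical use is `manager.override(parse_cli_overrides(sys.argv[1:]))`. Validation still has to run afterwards, because a typo such as lr=3e-44 parses just fine.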
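Designing Override Precedence
A clear precedence order prevents configuration conflicts:
- Command-line arguments (highest priority)
- Environment variables
- Experiment-specific config file
- User config file
- Project default config (lowest priority)
Precedence chain: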
CLI Args > ENV > experiment.yaml > user.yaml > default.yaml
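One way to implement this precedence chain uses the standard library's collections.ChainMap. The sketch assumes each layer has already been flattened to dotted keys. The LLM_ prefix and the rule that maps a double underscore to a dot are hypothetical conventions.

```python
from collections import ChainMap
from typing import Any, Dict, Mapping


def env_layer(environ: Mapping[str, str], prefix: str = "LLM_") -> Dict[str, str]:
    """LLM_TRAINING__BATCH_SIZE=64 -> {"training.batch_size": "64"} (hypothetical convention)."""
    layer = {}
    for name, value in environ.items():
        if name.startswith(prefix):
            key = name[len(prefix):].lower().replace("__", ".")
            layer[key] = value  # still a string: cast or validate downstream
    return layer


def resolve_config(cli: Dict[str, Any], environ: Mapping[str, str],
                   experiment: Dict[str, Any], user: Dict[str, Any],
                   default: Dict[str, Any]) -> ChainMap:
    # ChainMap searches its maps left to right, so the first map has the highest priority
    return ChainMap(cli, env_layer(environ), experiment, user, default)
```

In practice you would pass `os.environ` as `environ`. Values that come from the environment are strings, so they must go through the same validation as every other layer. Nested (unflattened) configs need a deep merge instead of a ChainMap.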
Config Validation and Type Checking
Runtime validation with Pydantic or attrs catches configuration errors early:
from typing import Literal

from pydantic import BaseModel, Field, model_validator  # Pydantic v2 API


class TrainingConfig(BaseModel):
    learning_rate: float = Field(gt=0, le=1.0)
    batch_size: int = Field(gt=0, multiple_of=8)
    optimizer: Literal["adam", "sgd", "adamw"]
    gradient_accumulation_steps: int = Field(gt=0)

    # Cross-field checks run after every field has been validated. A per-field
    # validator only sees fields declared *before* it, so a check on batch_size
    # would never see gradient_accumulation_steps
    @model_validator(mode="after")
    def validate_effective_batch(self):
        effective_batch = self.batch_size * self.gradient_accumulation_steps
        if effective_batch > 65536:
            raise ValueError(f"Effective batch size {effective_batch} too large")
        return self

    @model_validator(mode="after")
    def validate_lr_for_optimizer(self):
        if self.optimizer == "sgd" and self.learning_rate > 0.1:
            raise ValueError("SGD learning rate typically should be < 0.1")
        return self
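If a project cannot take on the Pydantic dependency, the standard library alone can approximate the same checks. Below is a minimal sketch that uses a dataclass `__post_init__` hook. The class name TrainingConfigLite is hypothetical, and unlike Pydantic it does no type coercion. It collects every error before raising, so a single run reports all problems at once.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TrainingConfigLite:
    learning_rate: float
    batch_size: int
    optimizer: str
    gradient_accumulation_steps: int = 1

    def __post_init__(self):
        errors = []
        if not 0 < self.learning_rate <= 1.0:
            errors.append(f"learning_rate must be in (0, 1], got {self.learning_rate}")
        if self.batch_size <= 0 or self.batch_size % 8:
            errors.append(f"batch_size must be a positive multiple of 8, got {self.batch_size}")
        if self.optimizer not in ("adam", "sgd", "adamw"):
            errors.append(f"unknown optimizer {self.optimizer!r}")
        if self.gradient_accumulation_steps <= 0:
            errors.append("gradient_accumulation_steps must be positive")
        elif self.batch_size * self.gradient_accumulation_steps > 65536:
            errors.append("effective batch size too large")
        if errors:
            # Report every problem at once instead of failing on the first one
            raise ValueError("; ".join(errors))
```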
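2.2 Flags, Environment Variables, and Git Branching Strategy
Design Principles for Command-line Flags
Command-line flags are the first point of contact with an experiment's configuration, and good flag design noticeably speeds up experimentation. The following principles have held up well across large numbers of experiments:
Hierarchical parameter organization
Organize parameters by functional domain instead of listing them all flat: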
import argparse


def create_parser():
    parser = argparse.ArgumentParser()
    # Argument groups make the --help output easier to read
    model_group = parser.add_argument_group("model")
    model_group.add_argument("--model.name", default="llama2-7b")
    model_group.add_argument("--model.checkpoint", type=str)
    model_group.add_argument("--model.dtype", choices=["fp32", "fp16", "bf16"])

    training_group = parser.add_argument_group("training")
    training_group.add_argument("--training.batch_size", type=int, default=32)
    training_group.add_argument("--training.learning_rate", type=float, default=1e-4)
    training_group.add_argument("--training.warmup_steps", type=int, default=1000)

    data_group = parser.add_argument_group("data")
    data_group.add_argument("--data.train_path", required=True)
    data_group.add_argument("--data.val_path", required=True)
    data_group.add_argument("--data.num_workers", type=int, default=4)
    return parser
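argparse derives each dest from the first long option string. So `--training.batch_size` is stored under the attribute name `"training.batch_size"`, which you reach through `getattr` or `vars()`, not `args.training.batch_size`. Below is a small sketch that rebuilds the nested structure so it can be merged with file-based configs. namespace_to_nested is a hypothetical helper name.

```python
import argparse
from typing import Any, Dict


def namespace_to_nested(args: argparse.Namespace) -> Dict[str, Any]:
    """Turn dests such as 'training.batch_size' into {'training': {'batch_size': ...}}."""
    nested: Dict[str, Any] = {}
    for dotted_key, value in vars(args).items():
        node = nested
        *parents, leaf = dotted_key.split(".")
        for key in parents:
            node = node.setdefault(key, {})
        node[leaf] = value
    return nested


# Usage with a cut-down version of the parser above
parser = argparse.ArgumentParser()
training = parser.add_argument_group("training")
training.add_argument("--training.batch_size", type=int, default=32)
training.add_argument("--training.learning_rate", type=float, default=1e-4)
args = parser.parse_args(["--training.batch_size", "64"])
config = namespace_to_nested(args)
# config == {"training": {"batch_size": 64, "learning_rate": 0.0001}}
```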
Sensible defaults and required arguments
Separate required arguments from optional ones, and give common scenarios reasonable defaults:
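import logging
import os

import torch

logger = logging.getLogger(__name__)


class FlagValidator:
    @staticmethod
    def validate_flags(args):
        # Infer related arguments automatically
        if args.distributed and args.local_rank is None:
            args.local_rank = int(os.environ.get("LOCAL_RANK", 0))
        # Choose the device based on the available hardware
        if args.device == "auto":
            args.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Cap the per-device batch when gradient checkpointing is OFF: checkpointing
        # saves activation memory, so running without it is the OOM-prone case
        # (the threshold of 16 is illustrative)
        if not args.gradient_checkpointing and args.batch_size > 16:
            logger.warning(f"Reducing batch size from {args.batch_size} to 16 "
                           "because gradient checkpointing is disabled")
            args.batch_size = 16
        return args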
Argument aliases and short forms
Give frequently used arguments short forms so they are quicker to type on the command line:
import argparse

parser = argparse.ArgumentParser()
# Each argument gets a short flag, a readable long flag, and the dotted form;
# dest= makes all of them write to the same attribute
parser.add_argument("-b", "--batch-size", "--training.batch_size",
                    dest="batch_size", type=int, default=32,
                    help="Training batch size per device")
parser.add_argument("-lr", "--learning-rate", "--training.lr",
                    dest="learning_rate", type=float, default=1e-4,
                    help="Peak learning rate")
parser.add_argument("-e", "--epochs", "--num-epochs",
                    dest="num_epochs", type=int, default=3,
                    help="Number of training epochs")
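When to Use Environment Variables
Environment variables suit global settings shared across experiments, as well as secrets:
A layered environment-variable scheme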
# System level (cluster configuration)
export CUDA_VISIBLE_DEVICES=0,1,2,3
export OMP_NUM_THREADS=8
export NCCL_DEBUG=INFO
# Project level (paths and credentials; never commit real secrets)
export LLM_DATA_ROOT=/mnt/data/llm_datasets
export LLM_CHECKPOINT_DIR=/mnt/checkpoints
export WANDB_API_KEY=your_api_key_here
export HF_TOKEN=your_huggingface_token
# Experiment level (runtime configuration)
export LLM_EXPERIMENT_NAME=dpo_ablation_v3
export LLM_RUN_ID=$(date +%Y%m%d_%H%M%S)
export LLM_DEBUG_MODE=1
Best practices for environment variables
import os
from pathlib import Path


class EnvConfig:
    """Central access point for environment variables"""

    @staticmethod
    def get_data_root() -> Path:
        """Return the data root directory, trying several fallbacks in order"""
        candidates = [
            os.environ.get("LLM_DATA_ROOT"),
            os.environ.get("DATA_ROOT"),
            "/data/llm",
            "./data",
        ]
        for path in candidates:
            if path and Path(path).exists():
                return Path(path)
        raise FileNotFoundError(f"No valid data root found; tried {candidates}")

    @staticmethod
    def get_wandb_config() -> dict:
        """Collect W&B settings, including only the variables that are actually set"""
        config = {}
        if api_key := os.environ.get("WANDB_API_KEY"):  # walrus needs Python 3.8+
            config["api_key"] = api_key
        if project := os.environ.get("WANDB_PROJECT"):
            config["project"] = project
        if entity := os.environ.get("WANDB_ENTITY"):
            config["entity"] = entity
        return config

    @staticmethod
    def is_debug_mode() -> bool:
        """Check whether debug mode is enabled"""
        return os.environ.get("LLM_DEBUG_MODE", "0").lower() in ("1", "true", "yes")
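Git Branch Management in Practice
When experiments iterate quickly, the branching strategy has to balance freedom to experiment against code quality:
Naming conventions for experiment branches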
# Feature development branches
feature/distributed-dpo
feature/multimodal-alignment
# Experiment branches (short-lived)
exp/20250105-lr-sweep
exp/20250106-batch-size-ablation
# Personal experiment branches (looser rules)
dev/alice/rope-scaling
dev/bob/attention-variants
# Long-running research branches
research/constitutional-ai
research/online-rlhf
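A naming convention is easiest to keep when a machine checks it, for example in a pre-push hook or a CI job. Below is a minimal sketch. The regular expressions are one assumption about what the examples above imply: dates as YYYYMMDD and lowercase hyphenated slugs. Adjust them to your team's rules.

```python
import re

# Assumed patterns mirroring the examples above; adapt to your team's rules
BRANCH_PATTERNS = {
    "feature": re.compile(r"^feature/[a-z0-9]+(-[a-z0-9]+)*$"),
    "exp": re.compile(r"^exp/\d{8}-[a-z0-9]+(-[a-z0-9]+)*$"),
    "dev": re.compile(r"^dev/[a-z0-9_]+/[a-z0-9]+(-[a-z0-9]+)*$"),
    "research": re.compile(r"^research/[a-z0-9]+(-[a-z0-9]+)*$"),
}


def check_branch_name(name: str) -> bool:
    """True if the branch name matches the convention for its prefix."""
    prefix = name.split("/", 1)[0]
    pattern = BRANCH_PATTERNS.get(prefix)
    return bool(pattern and pattern.match(name))
```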
Branch protection and merge strategy
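# .github/branch_protection.yml
# Illustrative schema: GitHub does not read this file by itself (branch protection
# is configured in repository settings or through the REST API), so a script or
# app has to apply it
protection_rules:
  main:
    required_reviews: 2
    dismiss_stale_reviews: true
    require_code_owner_reviews: true
    required_status_checks:
      - lint
      - type-check
      - unit-tests
    enforce_admins: false
  "release/*":
    required_reviews: 3
    require_code_owner_reviews: true
    required_status_checks:
      - all-tests
      - integration-tests
      - benchmark-regression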
Lifecycle management for experiment branches
#!/bin/bash
# scripts/manage_exp_branches.sh

# Delete experiment branches whose last commit is older than N days (default 30)
cleanup_old_exp_branches() {
    local days_old=${1:-30}
    local now
    now=$(date +%s)
    git for-each-ref --format='%(refname:short) %(committerdate:unix)' refs/heads/exp/ |
    while read -r branch timestamp; do
        age_days=$(( (now - timestamp) / 86400 ))
        if [ "$age_days" -gt "$days_old" ]; then
            echo "Deleting old experimental branch: $branch (${age_days} days old)"
            git branch -D "$branch"
        fi
    done
}

# Archive an important experiment branch as an annotated tag, then delete it
archive_exp_branch() {
    local branch=$1
    local archive_tag="archive/$(date +%Y%m%d)/${branch##*/}"
    git tag -a "$archive_tag" "$branch" -m "Archived experimental branch $branch"
    git push origin "$archive_tag"
    # -D rather than -d: the branch may be unmerged, but the tag keeps its commits
    git branch -D "$branch"
    echo "Archived $branch as $archive_tag"
}
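2.3 Experiment Tracking and Version Control
Choosing an Experiment Tracking Tool
The right tracking tool is key to a reproducible research workflow. Each mainstream tool has its own strengths, so choose based on your team's size and needs.
MLflow: the open-source standard
MLflow covers the full experiment lifecycle: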
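import hashlib
import json
import subprocess

import mlflow
from mlflow.tracking import MlflowClient


class MLflowExperimentTracker:
    def __init__(self, experiment_name: str, tracking_uri: str = "file:./mlruns"):
        mlflow.set_tracking_uri(tracking_uri)
        mlflow.set_experiment(experiment_name)
        self.client = MlflowClient()

    def start_run(self, config: dict, tags: dict = None) -> str:
        """Start a new run; call end_run() when training finishes"""
        # Use a short config hash as the run identifier
        config_hash = hashlib.md5(
            json.dumps(config, sort_keys=True, default=str).encode()
        ).hexdigest()[:8]
        run_name = f"{config.get('model_name', 'unknown')}_{config_hash}"
        # No `with` block: leaving a context manager would end the run right away,
        # and later log calls would silently start a new run
        run = mlflow.start_run(run_name=run_name)
        mlflow.log_params(self.flatten_dict(config))
        if tags:
            mlflow.set_tags(tags)
        # Record the code version
        mlflow.set_tag("git_commit", self._git("rev-parse", "HEAD"))
        mlflow.set_tag("git_branch", self._git("rev-parse", "--abbrev-ref", "HEAD"))
        return run.info.run_id

    def end_run(self, status: str = "FINISHED"):
        mlflow.end_run(status=status)

    def log_metrics_batch(self, metrics: dict, step: int):
        """Log several metrics in a single call"""
        mlflow.log_metrics(metrics, step=step)

    def log_artifact_with_metadata(self, file_path: str, metadata: dict):
        """Log a file together with a sidecar metadata file"""
        mlflow.log_artifact(file_path)
        metadata_path = f"{file_path}.metadata.json"
        with open(metadata_path, "w") as f:
            json.dump(metadata, f, indent=2)
        mlflow.log_artifact(metadata_path)

    @staticmethod
    def flatten_dict(d: dict, parent_key: str = "") -> dict:
        """{'a': {'b': 1}} -> {'a.b': 1}, because MLflow params are flat key/value pairs"""
        items = {}
        for k, v in d.items():
            key = f"{parent_key}.{k}" if parent_key else k
            if isinstance(v, dict):
                items.update(MLflowExperimentTracker.flatten_dict(v, key))
            else:
                items[key] = v
        return items

    @staticmethod
    def _git(*args: str) -> str:
        return subprocess.check_output(["git", *args]).decode().strip()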
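The run name above hashes the entire config. Two runs that differ only in a bookkeeping field, such as an output directory, therefore get different names. Below is a minimal sketch of a fingerprint that skips such fields. The ignore list (output_dir, run_name) is a hypothetical, project-specific choice, and the sketch uses SHA-256 where the tracker uses MD5. Because `sort_keys` applies at every nesting level, key order never changes the result.

```python
import hashlib
import json
from typing import Any, Dict, Iterable


def config_fingerprint(config: Dict[str, Any],
                       ignore: Iterable[str] = ("output_dir", "run_name"),
                       length: int = 8) -> str:
    """Short, key-order-independent hash of a config, skipping top-level bookkeeping keys."""
    ignored = set(ignore)
    relevant = {k: v for k, v in config.items() if k not in ignored}
    # Canonical JSON: sorted keys at every level, no whitespace variation
    canonical = json.dumps(relevant, sort_keys=True, separators=(",", ":"), default=str)
    return hashlib.sha256(canonical.encode("utf-8")).hexdigest()[:length]
```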
Weights & Biases: powerful cloud-native features
W&B offers richer visualization and collaboration features:
from typing import Optional

import numpy as np
import wandb


class WandBTracker:
    def __init__(self, project: str, entity: Optional[str] = None):
        self.project = project
        self.entity = entity

    @staticmethod
    def generate_tags(config: dict) -> list:
        # Tag runs with a few high-signal config values; these keys are
        # illustrative and depend on your config schema
        return [f"{key}={config[key]}" for key in ("model_name", "optimizer") if key in config]

    def init_run(self, config: dict, name: Optional[str] = None,
                 resume: Optional[str] = None):
        """Initialize a W&B run"""
        run = wandb.init(
            project=self.project,
            entity=self.entity,
            config=config,
            name=name,
            resume=resume,   # e.g. "allow" or "must" to resume an interrupted run
            save_code=True,  # save the main script with the run
            tags=self.generate_tags(config),
        )
        # Give train/* and eval/* metrics their own step axes
        wandb.define_metric("train/step")
        wandb.define_metric("train/*", step_metric="train/step")
        wandb.define_metric("eval/step")
        wandb.define_metric("eval/*", step_metric="eval/step")
        return run

    def log_distribution(self, name: str, data: np.ndarray, step: int):
        """Log summary statistics and a histogram of an array"""
        wandb.log({
            f"{name}/mean": np.mean(data),
            f"{name}/std": np.std(data),
            f"{name}/min": np.min(data),
            f"{name}/max": np.max(data),
            f"{name}/histogram": wandb.Histogram(data),
        }, step=step)

    def log_gradient_flow(self, model, step: int):
        """Log per-parameter gradient norms (call every N steps: a Table per step is heavy)"""
        rows = [
            [name, param.grad.detach().norm(2).item()]
            for name, param in model.named_parameters()
            if param.grad is not None
        ]
        grad_table = wandb.Table(columns=["layer", "gradient_norm"], data=rows)
        wandb.log({"gradients": grad_table}, step=step)
TensorBoard: a lightweight local option
When you don't need a cloud service, TensorBoard remains a dependable choice:
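from pathlib import Path

import matplotlib.pyplot as plt
import torch
from torch.utils.tensorboard import SummaryWriter


class TensorBoardTracker:
    def __init__(self, log_dir: str):
        self.log_dir = Path(log_dir)
        # SummaryWriter ignores `comment` when log_dir is given explicitly,
        # so put any run description into log_dir itself
        self.writer = SummaryWriter(
            log_dir=str(self.log_dir),
            flush_secs=30,  # flush to disk every 30 seconds
        )

    def log_model_architecture(self, model: torch.nn.Module, input_shape: tuple):
        """Log the model graph by tracing it with a random float input
        (token-based LLMs need an integer input_ids tensor instead)"""
        device = next(model.parameters()).device
        dummy_input = torch.randn(input_shape, device=device)
        self.writer.add_graph(model, dummy_input)

    def log_attention_weights(self, attention_weights: torch.Tensor,
                              step: int, head_idx: int = 0):
        """Visualize attention weights; expects shape (batch, heads, seq, seq)"""
        # First example in the batch, chosen head; .float() because NumPy has no bf16
        attn = attention_weights[0, head_idx].detach().float().cpu().numpy()
        fig, ax = plt.subplots(figsize=(10, 8))
        im = ax.imshow(attn, cmap="hot", interpolation="nearest")
        fig.colorbar(im, ax=ax)
        ax.set_title(f"Attention Weights - Head {head_idx}")
        # add_figure closes the figure by default (close=True)
        self.writer.add_figure(f"attention/head_{head_idx}", fig, step)

    def log_learning_rate_schedule(self, optimizer, step: int):
        """Log the current learning rate of every parameter group"""
        for i, param_group in enumerate(optimizer.param_groups):
            self.writer.add_scalar(f"learning_rate/group_{i}", param_group["lr"], step)

    def close(self):
        self.writer.close()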
Experiment Metadata Management
Complete metadata is the foundation of reproducibility:
import datetime
import os
import platform
import subprocess
import uuid
from pathlib import Path

import psutil


class ExperimentMetadata:
    @staticmethod
    def collect_system_info() -> dict:
        """Collect host information"""
        return {
            "platform": platform.platform(),
            "python_version": platform.python_version(),
            "cpu_count": psutil.cpu_count(),
            "memory_gb": psutil.virtual_memory().total / (1024**3),
            "hostname": platform.node(),
            "user": os.environ.get("USER", "unknown"),
        }

    @staticmethod
    def collect_gpu_info() -> list:
        """Collect GPU information via torch (GPUtil's GPU objects expose no compute capability)"""
        try:
            import torch
        except ImportError:
            return []
        if not torch.cuda.is_available():
            return []
        gpus = []
        for i in range(torch.cuda.device_count()):
            props = torch.cuda.get_device_properties(i)
            gpus.append({
                "id": i,
                "name": props.name,
                "memory_total_gb": props.total_memory / (1024**3),
                "compute_capability": f"{props.major}.{props.minor}",
            })
        return gpus

    @staticmethod
    def collect_dependencies() -> dict:
        """Collect dependency versions"""
        deps = {}
        try:
            import torch
            deps["torch"] = torch.__version__
            deps["cuda"] = torch.version.cuda if torch.cuda.is_available() else None
        except ImportError:
            pass
        try:
            import transformers
            deps["transformers"] = transformers.__version__
        except ImportError:
            pass
        # Pinned versions from requirements.txt; versions detected at runtime above win
        if Path("requirements.txt").exists():
            for line in Path("requirements.txt").read_text().splitlines():
                # Drop comments and environment markers such as "; python_version<'3.10'"
                line = line.split("#", 1)[0].split(";", 1)[0].strip()
                if "==" in line:
                    pkg, _, version = line.partition("==")
                    deps.setdefault(pkg.strip(), version.strip())
        return deps

    @staticmethod
    def _git(*args: str) -> str:
        return subprocess.check_output(["git", *args]).decode().strip()

    @staticmethod
    def create_experiment_card(config: dict) -> dict:
        """Create an experiment card (raises if not run inside a Git repository)"""
        return {
            "experiment_id": str(uuid.uuid4()),
            "timestamp": datetime.datetime.now().isoformat(),
            "config": config,
            "system": ExperimentMetadata.collect_system_info(),
            "gpus": ExperimentMetadata.collect_gpu_info(),
            "dependencies": ExperimentMetadata.collect_dependencies(),
            "git": {
                "commit": ExperimentMetadata._git("rev-parse", "HEAD"),
                "branch": ExperimentMetadata._git("branch", "--show-current"),
                # Uncommitted changes, so a dirty working tree can still be reproduced
                "diff": subprocess.check_output(["git", "diff", "HEAD"]).decode(),
            },
        }
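create_experiment_card calls git through check_output, so it raises outside a repository. Storing the full `git diff HEAD` output can also bloat every card. Below is a defensive sketch. safe_git and git_snapshot are hypothetical helper names, and the 100,000-character cap is an arbitrary example. Note that `git diff HEAD` does not include untracked files, so "dirty" here means modified tracked files only.

```python
import hashlib
import subprocess
from typing import Optional


def safe_git(*args: str, cwd: Optional[str] = None) -> Optional[str]:
    """Run a git command; return None instead of raising outside a repo or without git."""
    try:
        out = subprocess.run(["git", *args], cwd=cwd, capture_output=True,
                             check=True, text=True)
    except (FileNotFoundError, subprocess.CalledProcessError):
        return None
    return out.stdout.strip()


def git_snapshot(cwd: Optional[str] = None, max_diff_chars: int = 100_000) -> dict:
    diff = safe_git("diff", "HEAD", cwd=cwd) or ""
    return {
        "commit": safe_git("rev-parse", "HEAD", cwd=cwd),
        "dirty": bool(diff),  # tracked files only; untracked files are not in `git diff`
        "diff_sha256": hashlib.sha256(diff.encode()).hexdigest() if diff else None,
        "diff": diff[:max_diff_chars],
        "diff_truncated": len(diff) > max_diff_chars,
    }
```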
Checkpointing Strategy
Efficient checkpoint management is essential for long training runs:
import datetime
import os
from pathlib import Path

import torch


class CheckpointManager:
    def __init__(self, checkpoint_dir: Path, max_checkpoints: int = 5):
        self.checkpoint_dir = Path(checkpoint_dir)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        self.max_checkpoints = max_checkpoints

    def _atomic_save(self, obj, path: Path):
        # Write to a temp file, then rename: a crash mid-write never leaves a
        # truncated file under the final name
        tmp_path = path.with_suffix(path.suffix + ".tmp")
        torch.save(obj, tmp_path)
        os.replace(tmp_path, path)

    def save_checkpoint(self, model, optimizer, epoch: int,
                        metrics: dict, is_best: bool = False):
        """Save a checkpoint"""
        checkpoint = {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "metrics": metrics,
            "timestamp": datetime.datetime.now().isoformat(),
        }
        # Regular checkpoint in the rolling window
        self._atomic_save(checkpoint, self.checkpoint_dir / f"checkpoint_epoch_{epoch}.pt")
        # Best model, kept outside the rolling window
        if is_best:
            self._atomic_save(checkpoint, self.checkpoint_dir / "best_model.pt")
        self._cleanup_old_checkpoints()

    def _cleanup_old_checkpoints(self):
        """Keep only the newest N regular checkpoints"""
        # Sort by the epoch in the filename: mtimes change when files are copied
        checkpoints = sorted(
            self.checkpoint_dir.glob("checkpoint_epoch_*.pt"),
            key=lambda p: int(p.stem.rsplit("_", 1)[1]),
        )
        if len(checkpoints) > self.max_checkpoints:
            for ckpt in checkpoints[:-self.max_checkpoints]:
                ckpt.unlink()

    def resume_from_checkpoint(self, checkpoint_path: Path,
                               model, optimizer) -> dict:
        """Restore model and optimizer state from a checkpoint"""
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        return {
            "epoch": checkpoint["epoch"],
            "metrics": checkpoint["metrics"],
        }
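resume_from_checkpoint needs an explicit path. A common companion is to find the newest checkpoint automatically when a job restarts, for example after preemption. Below is a minimal sketch that assumes the `checkpoint_epoch_{N}.pt` naming used above. find_latest_checkpoint is a hypothetical helper. It reads the epoch from the filename, so epoch 10 correctly beats epoch 9, and it ignores best_model.pt and leftover temporary files.

```python
import re
from pathlib import Path
from typing import Optional

_CKPT_RE = re.compile(r"^checkpoint_epoch_(\d+)\.pt$")


def find_latest_checkpoint(checkpoint_dir: Path) -> Optional[Path]:
    """Return the checkpoint with the highest epoch number, or None if there is none."""
    checkpoint_dir = Path(checkpoint_dir)
    if not checkpoint_dir.is_dir():
        return None
    best_epoch, best_path = -1, None
    for path in checkpoint_dir.iterdir():
        match = _CKPT_RE.match(path.name)
        if match and int(match.group(1)) > best_epoch:
            best_epoch, best_path = int(match.group(1)), path
    return best_path
```

A training script might then call `manager.resume_from_checkpoint(latest, model, optimizer)` only when `latest` is not None, and start from scratch otherwise.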
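2.4 Best Practices for Preventing Code Rot
Managing Technical Debt
Because LLM post-training projects iterate so quickly, technical debt builds up easily. Managing that debt actively is key to the project's long-term health.
Quantifying and tracking technical debt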
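import ast
import datetime
import re
from pathlib import Path
from typing import Dict, List


class TechnicalDebtAnalyzer:
    def __init__(self, codebase_path: Path):
        self.codebase_path = Path(codebase_path)
        self.debt_markers = ["TODO", "FIXME", "HACK", "XXX", "DEPRECATED"]

    def scan_codebase(self) -> Dict[str, List[Dict]]:
        """Scan the codebase for technical-debt markers"""
        debt_items = {marker: [] for marker in self.debt_markers}
        for py_file in self.codebase_path.rglob("*.py"):
            with open(py_file, "r", encoding="utf-8", errors="replace") as f:
                for line_num, line in enumerate(f, 1):
                    for marker in self.debt_markers:
                        # Whole-word match, so e.g. "TODOS" is not counted
                        if re.search(rf"\b{marker}\b", line):
                            debt_items[marker].append({
                                "file": str(py_file.relative_to(self.codebase_path)),
                                "line": line_num,
                                "content": line.strip(),
                                "priority": self.estimate_priority(marker, line),
                            })
        return debt_items

    @staticmethod
    def estimate_priority(marker: str, line: str) -> str:
        # Heuristic (tune per team): FIXME/HACK, or an explicit "urgent" or
        # "critical" in the comment, counts as high priority
        if marker in ("FIXME", "HACK") or re.search(r"urgent|critical", line, re.I):
            return "high"
        return "normal"

    def calculate_complexity_metrics(self, file_path: Path) -> dict:
        """Compute code complexity metrics for one file"""
        source = Path(file_path).read_text(encoding="utf-8")
        tree = ast.parse(source)
        funcs = (ast.FunctionDef, ast.AsyncFunctionDef)
        return {
            "cyclomatic_complexity": self.calculate_cyclomatic_complexity(tree),
            "lines_of_code": len(source.splitlines()),
            "num_functions": sum(isinstance(n, funcs) for n in ast.walk(tree)),
            "num_classes": sum(isinstance(n, ast.ClassDef) for n in ast.walk(tree)),
            "max_nesting_depth": self.calculate_max_nesting(tree),
        }

    @staticmethod
    def calculate_cyclomatic_complexity(tree: ast.AST) -> int:
        # Simplified McCabe count for the whole file: 1 + number of decision points
        decisions = (ast.If, ast.For, ast.While, ast.IfExp, ast.ExceptHandler, ast.comprehension)
        count = 1
        for node in ast.walk(tree):
            if isinstance(node, decisions):
                count += 1
            elif isinstance(node, ast.BoolOp):
                count += len(node.values) - 1  # each extra and/or operand is a branch
        return count

    @staticmethod
    def calculate_max_nesting(node: ast.AST, depth: int = 0) -> int:
        # Deepest nesting of compound statements (if/for/while/with/try)
        blocks = (ast.If, ast.For, ast.While, ast.With, ast.Try)
        child_depths = [
            TechnicalDebtAnalyzer.calculate_max_nesting(child, depth + isinstance(child, blocks))
            for child in ast.iter_child_nodes(node)
        ]
        return max(child_depths, default=depth)

    def generate_debt_report(self) -> str:
        """Generate a technical-debt report in Markdown"""
        debt_items = self.scan_codebase()
        total_debt = sum(len(items) for items in debt_items.values())
        lines = [
            "# Technical Debt Report",
            f"Generated at: {datetime.datetime.now().isoformat()}",
            f"Total debt items: {total_debt}",
            "## Breakdown by marker",
        ]
        lines += [f"- {marker}: {len(items)} items" for marker, items in debt_items.items()]
        # High-priority items
        high_priority = [item for items in debt_items.values()
                         for item in items if item["priority"] == "high"]
        if high_priority:
            lines.append("## High-priority debt")
            for item in high_priority[:10]:  # show the first 10
                lines.append(f"- {item['file']}:{item['line']} - {item['content']}")
        return "\n".join(lines) + "\n"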
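A line-based scan like the one above also matches markers inside string literals, for example a log message that mentions TODO. Below is a sketch that uses the standard tokenize module so that only markers in actual comments are counted. The function name debt_markers_in_comments is hypothetical.

```python
import io
import tokenize
from typing import List, Tuple

MARKERS = ("TODO", "FIXME", "HACK", "XXX", "DEPRECATED")


def debt_markers_in_comments(source: str) -> List[Tuple[int, str]]:
    """Return (line, marker) pairs for markers that appear in comments only."""
    found = []
    for tok in tokenize.generate_tokens(io.StringIO(source).readline):
        # COMMENT tokens cover "# ..." text; string literals are STRING tokens
        if tok.type == tokenize.COMMENT:
            for marker in MARKERS:
                if marker in tok.string:
                    found.append((tok.start[0], marker))
    return found
```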
Code quality gates
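import re


class CodeQualityGate:
    def __init__(self, thresholds: dict):
        self.thresholds = thresholds

    def check_diff_quality(self, diff: str) -> bool:
        """Check the quality of a change, given the text of a unified diff"""
        checks = {
            "no_print_statements": self.check_no_prints(diff),
            "has_tests": self.check_has_tests(diff),
            "type_hints": self.check_type_hints(diff),
            # Docstring coverage is better measured on whole files by a
            # dedicated tool than by inspecting diff hunks
        }
        failures = [name for name, passed in checks.items() if not passed]
        if failures:
            print(f"Quality gate failed: {', '.join(failures)}")
            return False
        return True

    @staticmethod
    def _added_lines(diff: str) -> list:
        # Added lines start with "+", but "+++" is the new-file header
        return [line[1:] for line in diff.splitlines()
                if line.startswith("+") and not line.startswith("+++")]

    def check_no_prints(self, diff: str) -> bool:
        """Reject added lines that call print() (likely leftover debugging)"""
        return not any(re.search(r"\bprint\(", line) for line in self._added_lines(diff))

    def check_has_tests(self, diff: str) -> bool:
        """If files under src/ changed, require changes under test/ or tests/ as well"""
        files = re.findall(r"^\+\+\+ b/(\S+)", diff, flags=re.MULTILINE)
        src_modified = any(f.startswith("src/") for f in files)
        test_modified = any(f.startswith(("test/", "tests/")) for f in files)
        return test_modified or not src_modified

    def check_type_hints(self, diff: str) -> bool:
        """Heuristic: every added single-line `def` declares a return type"""
        defs = [line.strip() for line in self._added_lines(diff)
                if line.lstrip().startswith("def ")]
        return all("->" in d for d in defs if d.endswith(":"))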
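check_no_prints works on raw diff text, so it also fires on "print(" inside a string or a comment. Once the changed files are checked out, parsing them with the standard ast module finds only genuine calls. Below is a sketch. find_print_calls is a hypothetical helper, and a real gate would intersect its line numbers with the lines the diff adds.

```python
import ast
from typing import List


def find_print_calls(source: str) -> List[int]:
    """Line numbers of real print(...) calls; strings and comments are ignored."""
    tree = ast.parse(source)
    return sorted(
        node.lineno
        for node in ast.walk(tree)
        if isinstance(node, ast.Call)
        and isinstance(node.func, ast.Name)
        and node.func.id == "print"
    )
```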
Code Reuse and Modularity
Good modular design is the foundation for preventing code rot:
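import logging
from abc import ABC, abstractmethod
from typing import Generic, Iterable, TypeVar

logger = logging.getLogger(__name__)

T = TypeVar("T")  # batch type


class BaseExperiment(ABC, Generic[T]):
    """Base class for experiments; enforces a standard experiment workflow"""

    def __init__(self, config: dict):
        self.config = config
        self.setup()

    @abstractmethod
    def setup(self):
        """Initialize the experiment environment"""

    @abstractmethod
    def prepare_data(self) -> Iterable[T]:
        """Prepare the training data as a re-iterable collection of batches (e.g. a DataLoader)"""

    @abstractmethod
    def build_model(self):
        """Build the model"""

    @abstractmethod
    def train_step(self, batch: T) -> dict:
        """Run one training step and return its metrics"""

    @abstractmethod
    def evaluate(self) -> dict:
        """Evaluate the current model"""

    # Default logging hooks; override them to send metrics to a tracker
    def log_metrics(self, metrics: dict):
        logger.info("train %s", metrics)

    def log_eval_metrics(self, metrics: dict):
        logger.info("eval %s", metrics)

    def run(self):
        """Standardized experiment loop"""
        data = self.prepare_data()
        self.model = self.build_model()  # stored so train_step/evaluate can use it
        for epoch in range(self.config["num_epochs"]):
            for batch in data:
                self.log_metrics(self.train_step(batch))
            self.log_eval_metrics(self.evaluate())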
Component registration
class ComponentRegistry:
    """A single registry for components, so they are not scattered across the codebase"""
    _registry = {
        "models": {},
        "datasets": {},
        "trainers": {},
        "evaluators": {},
    }

    @classmethod
    def register(cls, category: str, name: str):
        """Decorator that registers a component under (category, name)"""
        def decorator(component_cls):
            if category not in cls._registry:
                raise ValueError(f"Unknown category: {category}")
            if name in cls._registry[category]:
                # Fail loudly instead of silently replacing an existing component
                raise ValueError(f"{category}/{name} is already registered")
            cls._registry[category][name] = component_cls
            return component_cls
        return decorator

    @classmethod
    def get(cls, category: str, name: str):
        """Look up a registered component"""
        if category not in cls._registry:
            raise ValueError(f"Unknown category: {category}")
        if name not in cls._registry[category]:
            available = list(cls._registry[category].keys())
            raise ValueError(f"Unknown {category}: {name}. Available: {available}")
        return cls._registry[category][name]


# Usage
@ComponentRegistry.register("models", "llama2")
class LLaMA2Model:
    pass


@ComponentRegistry.register("datasets", "alpaca")
class AlpacaDataset:
    pass
Continuous Integration and Testing
A layered testing strategy
import pytest

# Project-specific helpers; this module layout is hypothetical
from myproject.config import deep_merge, get_minimal_config
from myproject.data import create_dataloader
from myproject.models import create_model
from myproject.training import Trainer


# Unit test: one function in isolation
def test_config_merge():
    base = {"a": 1, "b": {"c": 2}}
    override = {"b": {"c": 3, "d": 4}}
    assert deep_merge(base, override) == {"a": 1, "b": {"c": 3, "d": 4}}


# Integration test: components working together
# (register custom markers such as "integration" and "smoke" in pytest.ini)
@pytest.mark.integration
def test_model_with_dataloader():
    config = get_minimal_config()
    model = create_model(config)
    batch = next(iter(create_dataloader(config)))
    output = model(batch)
    # Assumes causal-LM logits of shape (batch, seq_len, vocab) and hypothetical config keys
    assert output.shape == (config["batch_size"], config["seq_len"], config["vocab_size"])


# Smoke test: quickly check that the basics run end to end
@pytest.mark.smoke
def test_training_loop_runs():
    trainer = Trainer(get_minimal_config())
    trainer.train(max_steps=10)  # run only a few steps
    assert trainer.global_step == 10
CI/CD configuration
# .github/workflows/ci.yml
name: CI Pipeline
on:
  push:
    branches: [main, develop]
  pull_request:
    branches: [main]
jobs:
  quality-checks:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.9'
      - name: Install dependencies
        run: |
          pip install -r requirements-dev.txt
      - name: Run linting
        run: |
          flake8 src/ --max-line-length=100
          black --check src/
          isort --check-only src/
      - name: Type checking
        run: |
          mypy src/ --ignore-missing-imports
      - name: Run tests
        run: |
          pytest tests/ -v --cov=src --cov-report=xml
      - name: Check code complexity
        # radon only reports; use a tool such as xenon to fail the build on thresholds
        run: |
          radon cc src/ -s -nb
Documentation and Knowledge Transfer
Automated documentation generation
import inspect


class DocumentationGenerator:
    """Generate experiment documentation automatically"""

    def generate_experiment_doc(self, experiment_class) -> str:
        """Build Markdown documentation from an experiment class"""
        doc = f"# {experiment_class.__name__}\n\n"
        # Class docstring (inspect.getdoc also strips the indentation)
        if class_doc := inspect.getdoc(experiment_class):
            doc += f"{class_doc}\n\n"
        # Config parameters; assumes the class provides a get_config_schema()
        # classmethod returning {name: {"type": ..., "default": ..., "description": ...}}
        if hasattr(experiment_class, "get_config_schema"):
            doc += "## Configuration Parameters\n\n"
            for param, schema in experiment_class.get_config_schema().items():
                doc += f"- **{param}**: {schema['type']} "
                if "default" in schema:
                    doc += f"(default: {schema['default']})"
                doc += f"\n  {schema.get('description', '')}\n"
        # Public methods with docstrings
        doc += "\n## Methods\n\n"
        for method_name, method in inspect.getmembers(experiment_class, callable):
            if not method_name.startswith("_") and (method_doc := inspect.getdoc(method)):
                doc += f"### {method_name}\n{method_doc}\n\n"
        return doc
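Chapter Summary
This chapter covered how to build the experiment code infrastructure for LLM post-training. We learned:
📌 Layered configuration management: choose sensibly among YAML, TOML, and Python configs, then add config inheritance and runtime validation, and you get a configuration system that is both flexible and robust. Remember: configuration complexity should match experiment complexity, because both over-engineering and under-engineering cost efficiency.
📌 Managing the experiment environment along several axes: using command-line flags, environment variables, and Git branches together keeps experiments isolated and reproducible. The key principle: flags for experiment-specific settings, environment variables for system-level settings, and Git branches for code versions.
📌 Tracking the full experiment lifecycle: MLflow, W&B, and TensorBoard each suit different scenarios. What matters most is recording enough metadata to reproduce an experiment, including the code version, dependency environment, and hardware configuration.
📌 Actively managing technical debt: code quality gates, modular design, automated tests, and generated documentation form several lines of defense against code rot. Remember: technical debt is unavoidable; what matters is keeping it visible, under control, and repayable.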
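Key Formulas and Metrics
- Technical-debt interest = $\sum_{i=1}^{n} \text{complexity}_i \times \text{change\_frequency}_i$
- Reproducibility score = $\frac{\text{number of successfully reproduced experiments}}{\text{total number of experiments}} \times \text{metadata completeness}$
- Configuration complexity = $\log_2(\text{number of config parameters}) \times \text{nesting depth}$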
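To make these heuristics concrete, here is a small sketch that evaluates them. The function names are hypothetical, and the inputs are made-up illustrative numbers, not measurements. Note that a simple but frequently changed module (complexity 3, changed 20 times) accrues more interest (60) than a complex but stable one (complexity 10, changed 5 times, interest 50).

```python
import math
from typing import Sequence, Tuple


def debt_interest(modules: Sequence[Tuple[float, float]]) -> float:
    """Sum of complexity_i * change_frequency_i over (complexity, change_frequency) pairs."""
    return sum(complexity * freq for complexity, freq in modules)


def reproducibility_score(reproduced: int, total: int, metadata_completeness: float) -> float:
    """Fraction of experiments reproduced, scaled by metadata completeness in [0, 1]."""
    return reproduced / total * metadata_completeness


def config_complexity(num_params: int, nesting_depth: int) -> float:
    return math.log2(num_params) * nesting_depth


print(debt_interest([(10, 5), (3, 20)]))  # 50 + 60
print(config_complexity(128, 3))          # log2(128) = 7, times depth 3
```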
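Common Pitfalls (Gotchas)
⚠️ Configuration Hell
- Mistake: creating a fully independent config file for every small experiment
- Consequence: config files multiply explosively and become unmaintainable
- Fix: use config inheritance and record only the differences from the baseline
⚠️ Tracking too much or too little
- Mistake: logging every conceivable metric, or logging only the final result
- Consequence: storage blows up, or there is too little information to debug
- Fix: a tiered logging strategy: log key metrics in full detail and sample auxiliary metrics
⚠️ Messy Git branch management
- Mistake: running every experiment on the main branch
- Consequence: a cluttered history that is hard to trace back
- Fix: strict branch naming conventions and lifecycle management
⚠️ Hard-coded paths and settings
- Mistake: hard-coding data paths and model paths in the code
- Consequence: the code cannot run in other environments
- Fix: manage all paths through configuration or environment variables
⚠️ Ignoring growing code complexity
- Mistake: adding ever more if-else branches to run experiments quickly
- Consequence: the code turns into unmaintainable spaghetti
- Fix: refactor regularly; use the strategy pattern or a registry
⚠️ Poor checkpoint management
- Mistake: saving every checkpoint, or only the last one
- Consequence: the disk fills up, or the best model cannot be recovered
- Fix: a rolling-window policy plus a separately saved best model
💡 Practical Tips
- Validate configs up front: check all settings before the experiment starts, and fail fast
- Experiment naming convention: {date}_{model}_{dataset}_{key_hyperparam}
- Automated cleanup: periodically remove stale experiment branches and checkpoints
- Incremental logging: use structured logs so they are easy to analyze later
- Config snapshots: save a complete snapshot of the config at the start of every experiment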
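Exercises
Basic
Exercise 2.1: Choosing a configuration format. Your team is starting a new LLM post-training project. The project must support (1) hyperparameter tuning by non-engineers, (2) complex nested configuration, and (3) dynamically computed parameters. Choose a configuration format for the project and justify your choice.
Hint: consider a hybrid approach that uses different formats at different layers.
Reference answer
We suggest a hybrid configuration scheme:
- Base layer (YAML): parameters that non-engineers may adjust, such as learning rate and batch size
- Advanced layer (Python): parameters that need dynamic computation, such as adjusting the batch size automatically to the available GPU memory
- User override layer (TOML): user-specific environment settings
Implementation:
- Load the YAML base configuration first
- Run dynamic computation and validation through the Python config classes
- Finally apply the TOML user overrides, and validate the merged result again
This keeps the system easy to use while still providing enough flexibility.
Exercise 2.2: Integrating experiment trackers. Design a unified interface that logs experiment metrics to both MLflow and W&B. It must support batched logging and asynchronous writes.
Hint: use the adapter pattern and a queue.
Reference answer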
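import logging
from abc import ABC, abstractmethod
from queue import Queue
from threading import Thread

logger = logging.getLogger(__name__)


class ExperimentTracker(ABC):
    @abstractmethod
    def log_metrics(self, metrics: dict, step: int): ...


# Adapters: each wraps one backend behind the common interface
class MLflowAdapter(ExperimentTracker):
    def log_metrics(self, metrics: dict, step: int):
        import mlflow
        mlflow.log_metrics(metrics, step=step)


class WandBAdapter(ExperimentTracker):
    def log_metrics(self, metrics: dict, step: int):
        import wandb
        wandb.log(metrics, step=step)


class UnifiedTracker:
    def __init__(self):
        self.trackers = []
        self.queue = Queue()
        # Daemon thread, so a forgotten close() cannot block interpreter exit
        self.worker = Thread(target=self._process_queue, daemon=True)
        self.worker.start()

    def add_tracker(self, tracker: ExperimentTracker):
        self.trackers.append(tracker)

    def log_metrics(self, metrics: dict, step: int):
        # Asynchronous: the training loop only enqueues a whole metrics dict (one batch)
        self.queue.put(("metrics", metrics, step))

    def close(self):
        """Flush everything already queued, then stop the worker"""
        self.queue.put(None)  # FIFO, so the sentinel arrives after pending items
        self.worker.join()

    def _process_queue(self):
        while True:
            item = self.queue.get()
            if item is None:
                break
            _event_type, metrics, step = item
            for tracker in self.trackers:
                try:
                    tracker.log_metrics(metrics, step)
                except Exception:
                    # Error isolation: one failing tracker does not affect the others
                    logger.exception("Tracker %r failed", tracker)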
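Key points:
- A single abstract interface
- An asynchronous queue, so logging never blocks training
- Error isolation: one failing tracker does not affect the others
Exercise 2.3: Git branch cleanup strategy. Write a script that cleans up experiment branches automatically. Requirements: (1) keep branches from the last 30 days; (2) keep branches with unmerged commits; (3) archive important experiment results.
Hint: use git for-each-ref and git cherry.
Reference answer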
#!/bin/bash
# Note: `date -d` is GNU date syntax (on macOS use `date -v-30d +%s`)
cleanup_experimental_branches() {
    local cutoff_date
    cutoff_date=$(date -d "30 days ago" +%s)
    git for-each-ref --format='%(refname:short) %(committerdate:unix)' refs/heads/exp/ |
    while read -r branch timestamp; do
        # Only consider branches older than the cutoff
        if [ "$timestamp" -lt "$cutoff_date" ]; then
            # Count commits on the branch whose changes are not yet in main
            unmerged=$(git cherry main "$branch" | grep -c "^+")
            if [ "$unmerged" -eq 0 ]; then
                # Branches marked as important get an archive tag before deletion
                if git tag --list "important/$branch" | grep -q .; then
                    git tag -a "archive/$(date +%Y%m)/$branch" "$branch" -m "Auto-archived"
                fi
                git branch -D "$branch"
                echo "Deleted: $branch"
            else
                echo "Kept (unmerged): $branch"
            fi
        fi
    done
}
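Key checks:
- Timestamp comparison
- Detection of unmerged commits
- Recognition of importance markers
Challenge
Exercise 2.4: Config diff analysis. Implement a tool that can (1) compare the configurations of two experiments; (2) identify which config changes led to performance gains; (3) generate config optimization suggestions.
Hint: consider a decision tree or SHAP values to analyze config importance.
Reference answer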
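from typing import Dict, List

import numpy as np
import shap
from sklearn.ensemble import RandomForestRegressor


class ConfigAnalyzer:
    def __init__(self, experiments: List[Dict]):
        # Each experiment is {"config": {...}, "metrics": {...}}
        self.experiments = experiments
        self.feature_names = self._extract_features()
        self._categories: Dict[str, Dict[str, int]] = {}

    def _extract_features(self):
        # Union of all flattened config keys
        all_keys = set()
        for exp in self.experiments:
            all_keys.update(self._flatten_dict(exp["config"]).keys())
        return sorted(all_keys)

    def _flatten_dict(self, d, parent_key=""):
        items = {}
        for k, v in d.items():
            new_key = f"{parent_key}.{k}" if parent_key else k
            if isinstance(v, dict):
                items.update(self._flatten_dict(v, new_key))
            else:
                items[new_key] = v
        return items

    def _encode(self, key: str, value) -> float:
        # The regressor needs numbers: numeric and bool values pass through, and
        # anything else (e.g. an optimizer name) gets a stable ordinal code per key
        if isinstance(value, (bool, int, float)):
            return float(value)
        codes = self._categories.setdefault(key, {})
        return float(codes.setdefault(str(value), len(codes)))

    def diff_configs(self, config_a: dict, config_b: dict) -> Dict[str, tuple]:
        """(1) Keys whose values differ between two configs, as {key: (a, b)}"""
        a, b = self._flatten_dict(config_a), self._flatten_dict(config_b)
        return {k: (a.get(k), b.get(k)) for k in sorted(a.keys() | b.keys())
                if a.get(k) != b.get(k)}

    def analyze_importance(self, metric: str = "accuracy"):
        """(2) Rank config keys by mean |SHAP value|; only meaningful with many runs"""
        rows = []
        for exp in self.experiments:
            flat = self._flatten_dict(exp["config"])
            rows.append([self._encode(k, flat.get(k, 0)) for k in self.feature_names])
        X = np.array(rows)
        y = np.array([exp["metrics"][metric] for exp in self.experiments])
        # Learn the config -> performance mapping
        model = RandomForestRegressor(n_estimators=100, random_state=0)
        model.fit(X, y)
        # SHAP analysis
        shap_values = shap.TreeExplainer(model).shap_values(X)
        importance = {
            name: float(np.abs(shap_values[:, i]).mean())
            for i, name in enumerate(self.feature_names)
        }
        return sorted(importance.items(), key=lambda x: x[1], reverse=True)

    def suggest_optimization(self, current_config: dict,
                             metric: str = "accuracy", top_k: int = 5):
        """(3) For the top_k most important keys, suggest the best run's value"""
        importance = self.analyze_importance(metric)
        best_exp = max(self.experiments, key=lambda x: x["metrics"][metric])
        best_config = self._flatten_dict(best_exp["config"])
        current_flat = self._flatten_dict(current_config)
        suggestions = []
        for param, imp_score in importance[:top_k]:
            if (param in best_config and param in current_flat
                    and best_config[param] != current_flat[param]):
                suggestions.append({
                    "parameter": param,
                    "current": current_flat[param],
                    "suggested": best_config[param],
                    "importance": imp_score,
                })
        return suggestions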
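Core idea:
- Use a random forest to learn the mapping from configuration to performance
- Use SHAP values to quantify each config parameter's contribution
- Generate optimization suggestions from the best historical runs
Exercise 2.5: Version isolation for experiment code. Design a system that creates an isolated code environment for each experiment and supports (1) code snapshots; (2) pinned dependency versions; (3) fast switching and restoring.
Hint: combine Git worktree with Docker or Python virtual environments.
Reference answer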
import json
import shutil
import subprocess
import venv
from datetime import datetime
from pathlib import Path


class ExperimentEnvironment:
    # Assumes it runs from inside the main Git repository
    def __init__(self, base_dir: Path):
        self.base_dir = Path(base_dir)
        self.envs_dir = self.base_dir / "environments"
        self.envs_dir.mkdir(parents=True, exist_ok=True)

    def create_environment(self, exp_id: str, config: dict) -> Path:
        env_path = self.envs_dir / exp_id
        env_path.mkdir(parents=True)  # fails if the exp_id already exists

        # 1. Code snapshot: a Git worktree checked out (detached) at the requested commit
        worktree_path = env_path / "code"
        subprocess.run([
            "git", "worktree", "add", "--detach",
            str(worktree_path),
            config.get("git_commit", "HEAD"),
        ], check=True)

        # 2. Isolated Python virtual environment
        venv_path = env_path / "venv"
        venv.create(venv_path, with_pip=True)

        # 3. Pin dependencies (POSIX layout; on Windows pip lives in Scripts\)
        pip_path = venv_path / "bin" / "pip"
        req_file = env_path / "requirements.txt"
        req_file.write_text("".join(f"{pkg}\n" for pkg in config.get("requirements", [])))
        subprocess.run([str(pip_path), "install", "-r", str(req_file)], check=True)

        # 4. Save environment metadata
        metadata = {
            "exp_id": exp_id,
            "created_at": datetime.now().isoformat(),
            "git_commit": subprocess.check_output(
                ["git", "rev-parse", "HEAD"],
                cwd=worktree_path,
            ).decode().strip(),
            "config": config,
        }
        with open(env_path / "metadata.json", "w") as f:
            json.dump(metadata, f, indent=2)
        return env_path

    def activate_environment(self, exp_id: str) -> Path:
        env_path = self.envs_dir / exp_id
        # The shebang must be the very first line. Use `source activate.sh`:
        # executing the script would only change a child shell
        activate_script = (
            "#!/bin/bash\n"
            f"export EXPERIMENT_ID={exp_id}\n"
            f"export PYTHONPATH={env_path}/code:$PYTHONPATH\n"
            f"source {env_path}/venv/bin/activate\n"
            f"cd {env_path}/code\n"
            f'echo "Environment {exp_id} activated"\n'
        )
        script_path = env_path / "activate.sh"
        script_path.write_text(activate_script)
        script_path.chmod(0o755)
        return script_path

    def cleanup_environment(self, exp_id: str, archive: bool = True):
        env_path = self.envs_dir / exp_id
        if archive:
            # Archive the important files (without the venv and .git)
            archive_path = self.base_dir / "archives" / f"{exp_id}.tar.gz"
            archive_path.parent.mkdir(parents=True, exist_ok=True)
            # --exclude must precede the paths it applies to (GNU tar treats it as positional)
            subprocess.run([
                "tar", "czf", str(archive_path),
                "--exclude", "venv",
                "--exclude", ".git",
                "-C", str(self.envs_dir),
                exp_id,
            ], check=True)
        # Unregister and remove the worktree, then delete the environment directory
        subprocess.run(["git", "worktree", "remove", "--force", str(env_path / "code")], check=True)
        shutil.rmtree(env_path)
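Key features:
- Git worktree provides code isolation
- Python venv provides dependency isolation
- Recorded metadata makes each environment traceable
- Archiving preserves important experiments
Exercise 2.6: Distributed experiment coordination. Design a distributed experiment management system that supports (1) scheduling experiments across multiple machines; (2) resource (GPU) allocation; (3) automatic retries for failed experiments.
Hint: consider a message queue and a state machine.
Reference answer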
import json
import time
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from typing import Optional

import redis


class ExperimentState(Enum):
    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    RETRYING = "retrying"


@dataclass
class ExperimentJob:
    exp_id: str
    config: dict
    state: ExperimentState
    assigned_worker: Optional[str] = None
    retry_count: int = 0
    max_retries: int = 3


class DistributedExperimentScheduler:
    def __init__(self, redis_host: str):
        self.redis = redis.Redis(host=redis_host, decode_responses=True)
        self.job_queue = "experiment_queue"              # list: LPUSH in, BRPOP out (FIFO)
        self.delayed_queue = "experiment_queue:delayed"  # sorted set scored by ready time
        self.worker_status = "worker_status"             # hash: worker_id -> JSON status
        self.resources = {}

    def submit_experiment(self, exp_id: str, config: dict):
        job = ExperimentJob(exp_id=exp_id, config=config, state=ExperimentState.PENDING)
        # Enqueue
        self.redis.lpush(self.job_queue, json.dumps({
            "exp_id": job.exp_id,
            "config": job.config,
            "state": job.state.value,
            "retry_count": job.retry_count,
            "max_retries": job.max_retries,
        }))
        # Record the job state
        self.redis.hset(f"job:{exp_id}", mapping={
            "state": job.state.value,
            "submitted_at": datetime.now().isoformat(),
        })

    def worker_loop(self, worker_id: str, resources: dict):
        """Main loop of a worker node"""
        self.resources = resources
        self._set_worker_status(worker_id, "idle")  # register the worker
        while True:
            self._promote_delayed_jobs()
            job_data = self.redis.brpop(self.job_queue, timeout=5)
            if job_data:
                _, job_str = job_data
                job = json.loads(job_str)
                if self._can_run(job["config"], resources):
                    self._run_experiment(worker_id, job)
                else:
                    # Put it back at the tail (the LPUSH side) for another worker,
                    # and back off briefly so this worker does not spin on it
                    self.redis.lpush(self.job_queue, job_str)
                    time.sleep(1)
            self._heartbeat(worker_id)

    def _set_worker_status(self, worker_id: str, status: str,
                           current_job: Optional[str] = None):
        # Rewrite the whole entry so resources and heartbeat are never dropped
        self.redis.hset(self.worker_status, worker_id, json.dumps({
            "status": status,
            "current_job": current_job,
            "resources": self.resources,
            "last_heartbeat": datetime.now().isoformat(),
        }))

    def _heartbeat(self, worker_id: str):
        entry = json.loads(self.redis.hget(self.worker_status, worker_id) or "{}")
        entry["last_heartbeat"] = datetime.now().isoformat()
        self.redis.hset(self.worker_status, worker_id, json.dumps(entry))

    def _promote_delayed_jobs(self):
        """Move retry jobs whose back-off delay has elapsed back into the main queue"""
        for job_str in self.redis.zrangebyscore(self.delayed_queue, 0, time.time()):
            # Only the worker whose ZREM succeeds re-enqueues, so no job is duplicated
            if self.redis.zrem(self.delayed_queue, job_str):
                self.redis.lpush(self.job_queue, job_str)

    def _can_run(self, config: dict, resources: dict) -> bool:
        """Check whether this worker's resources satisfy the job's requirements"""
        required_gpus = config.get("num_gpus", 1)
        available_gpus = resources.get("gpus", 0)
        required_memory = config.get("memory_gb", 16)
        available_memory = resources.get("memory_gb", 0)
        return (available_gpus >= required_gpus and
                available_memory >= required_memory)

    def _execute_experiment(self, config: dict) -> dict:
        # Hypothetical hook: launch the actual training run here (e.g. via subprocess)
        raise NotImplementedError

    def _run_experiment(self, worker_id: str, job: dict):
        exp_id = job["exp_id"]
        try:
            # Update the job state and the worker state
            self.redis.hset(f"job:{exp_id}", mapping={
                "state": ExperimentState.RUNNING.value,
                "worker": worker_id,
                "started_at": datetime.now().isoformat(),
            })
            self._set_worker_status(worker_id, "busy", current_job=exp_id)
            # Run the experiment
            result = self._execute_experiment(job["config"])
            # Mark it complete
            self.redis.hset(f"job:{exp_id}", mapping={
                "state": ExperimentState.COMPLETED.value,
                "completed_at": datetime.now().isoformat(),
                "result": json.dumps(result),
            })
        except Exception as e:
            self._handle_failure(exp_id, job, str(e))
        finally:
            self._set_worker_status(worker_id, "idle")

    def _handle_failure(self, exp_id: str, job: dict, error: str):
        retry_count = job.get("retry_count", 0)
        max_retries = job.get("max_retries", 3)
        if retry_count < max_retries:
            # Retry after an exponential back-off: 1s, 2s, 4s, ...
            job["retry_count"] = retry_count + 1
            job["state"] = ExperimentState.RETRYING.value
            delay = 2 ** retry_count
            self.redis.zadd(self.delayed_queue, {json.dumps(job): time.time() + delay})
            self.redis.hset(f"job:{exp_id}", mapping={
                "state": ExperimentState.RETRYING.value,
                "retry_count": retry_count + 1,
                "last_error": error,
            })
        else:
            # Final failure
            self.redis.hset(f"job:{exp_id}", mapping={
                "state": ExperimentState.FAILED.value,
                "failed_at": datetime.now().isoformat(),
                "error": error,
            })
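Key design points:
- Redis acts as the central coordinator
- Jobs are assigned according to each worker's available resources
- A state machine tracks each experiment's lifecycle
- Failed jobs are retried with exponential backoff
- Heartbeats let the scheduler detect unhealthy workers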
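To follow the retry path without a Redis server, here is a minimal in-memory sketch, assuming a single process. A `heapq` min-heap keyed by ready time stands in for the delayed queue; `RetryQueue`, `fail`, and `promote` are hypothetical names, not part of the scheduler above. The delay doubles per attempt, matching the `2 ** retry_count` backoff in `_handle_failure`.

```python
import heapq
from dataclasses import dataclass, field


@dataclass
class RetryQueue:
    """Hypothetical in-memory stand-in for a delayed-retry queue (single process)."""
    max_retries: int = 3
    ready: list = field(default_factory=list)    # jobs that can run now, FIFO
    delayed: list = field(default_factory=list)  # min-heap of (ready_at, seq, job)
    failed: list = field(default_factory=list)   # jobs that exhausted their retries
    _seq: int = 0                                # tie-breaker so dicts are never compared

    def fail(self, job: dict, now: float) -> None:
        """Record a failure: schedule a retry with exponential backoff, or give up."""
        retry_count = job.get("retry_count", 0)
        if retry_count < self.max_retries:
            retried = {**job, "retry_count": retry_count + 1}
            ready_at = now + 2 ** retry_count  # 1 s, 2 s, 4 s, ...
            heapq.heappush(self.delayed, (ready_at, self._seq, retried))
            self._seq += 1
        else:
            self.failed.append(job)

    def promote(self, now: float) -> None:
        """Move every delayed job whose backoff has expired into the ready list."""
        while self.delayed and self.delayed[0][0] <= now:
            _, _, job = heapq.heappop(self.delayed)
            self.ready.append(job)


q = RetryQueue(max_retries=2)
q.fail({"exp_id": "exp-1"}, now=0.0)  # first failure: retry becomes ready at t=1
q.promote(now=0.5)
print(len(q.ready))                   # 0 (still backing off)
q.promote(now=1.0)
job = q.ready.pop(0)
print(job["retry_count"])             # 1
q.fail(job, now=1.0)                  # second failure: ready at t=1+2=3
q.promote(now=3.0)
job = q.ready.pop(0)
q.fail(job, now=3.0)                  # retry_count == max_retries: give up
print(len(q.failed))                  # 1
```

The heap keeps the earliest-ready job on top, so promotion only inspects the head. Across processes, a Redis sorted set scored by ready time (`ZADD` / `ZRANGEBYSCORE`) provides the same ordering.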
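Exercise 2.7: Measuring and alerting on code rot. Build a system that can (1) quantify how far the code has rotted; (2) predict how technical debt will grow; (3) generate refactoring suggestions automatically.
Hint: combine static analysis, Git history, and complexity metrics.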
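The trend prediction in (2) boils down to fitting a line through weekly samples and extrapolating. A small worked example with `numpy.polyfit` (used here instead of scikit-learn only for brevity; the debt counts are made-up illustration values) shows the indexing: if the last of n samples has index n - 1, the estimate w weeks ahead sits at index n - 1 + w.

```python
import numpy as np

# Illustrative weekly counts of TODO/FIXME markers, oldest first
debt_counts = np.array([10, 12, 14, 16], dtype=float)
weeks = np.arange(len(debt_counts))                       # sample indices 0..n-1

slope, intercept = np.polyfit(weeks, debt_counts, deg=1)  # least-squares line

days_ahead = 28
future_index = len(debt_counts) - 1 + days_ahead // 7     # last index + 4 weeks
predicted = slope * future_index + intercept

print(f"growth: {slope:.1f} markers/week, predicted in {days_ahead} days: {predicted:.1f}")
# growth: 2.0 markers/week, predicted in 28 days: 24.0
```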
Reference solution
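import ast
import difflib
import json
import re
from datetime import datetime, timedelta
from pathlib import Path
import git
import numpy as np
from sklearn.linear_model import LinearRegression
class CodeHealthMonitor:
    # Comment markers counted as explicit technical debt
    DEBT_PATTERN = re.compile(r"#\s*(TODO|FIXME|HACK|XXX)\b")
    def __init__(self, repo_path: str):
        self.repo = git.Repo(repo_path)
        self.metrics_history = []
    def _python_files(self):
        """Yield (path, source, AST) for every parseable .py file in the working tree."""
        for py_file in Path(self.repo.working_dir).rglob("*.py"):
            try:
                source = py_file.read_text(encoding="utf-8")
                tree = ast.parse(source)
            except (SyntaxError, UnicodeDecodeError):
                continue  # skip files that cannot be read or parsed
            yield py_file, source, tree
    def calculate_health_score(self) -> float:
        """Compute a code-health score in [0, 100]."""
        metrics = {
            'complexity': self._measure_complexity(),
            'duplication': self._measure_duplication(),
            'test_coverage': self._measure_test_coverage(),
            'debt_density': self._measure_technical_debt(),
            'change_frequency': self._measure_change_frequency()
        }
        # Weighted sum; the raw metrics have different scales, so tune the weights per project
        weights = {
            'complexity': -0.3,       # higher complexity lowers the score
            'duplication': -0.2,      # more duplicated code lowers the score
            'test_coverage': 0.25,    # higher test coverage raises the score
            'debt_density': -0.15,    # more debt markers lower the score
            'change_frequency': -0.1  # frequently changed code lowers the score
        }
        score = 50  # baseline
        for metric, value in metrics.items():
            score += weights[metric] * value
        return max(0, min(100, score))
    def _measure_complexity(self) -> float:
        """Average cyclomatic complexity per file."""
        complexities = [self._calculate_cyclomatic_complexity(tree)
                        for _, _, tree in self._python_files()]
        return sum(complexities) / max(len(complexities), 1)
    def _calculate_cyclomatic_complexity(self, tree: ast.AST) -> int:
        """Approximate McCabe complexity: 1 + number of decision points."""
        complexity = 1
        for node in ast.walk(tree):
            if isinstance(node, (ast.If, ast.For, ast.While, ast.IfExp,
                                 ast.ExceptHandler, ast.comprehension)):
                complexity += 1
            elif isinstance(node, ast.BoolOp):
                complexity += len(node.values) - 1
        return complexity
    def _measure_duplication(self) -> float:
        """Similar function pairs per 100 functions (a rough duplication rate)."""
        # Simplified approach: compare function sources pairwise
        code_blocks = []
        for _, _, tree in self._python_files():
            for node in ast.walk(tree):
                if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                    code_blocks.append(ast.unparse(node))  # Python 3.9+
        # Pairwise comparison is O(n^2); sample functions in large repositories
        duplicates = 0
        for i, block1 in enumerate(code_blocks):
            for block2 in code_blocks[i + 1:]:
                if self._similarity(block1, block2) > 0.8:
                    duplicates += 1
        return duplicates / max(len(code_blocks), 1) * 100
    def _similarity(self, a: str, b: str) -> float:
        return difflib.SequenceMatcher(None, a, b).ratio()
    def _count_debt_markers(self) -> int:
        """Count TODO/FIXME/HACK/XXX comments."""
        return sum(len(self.DEBT_PATTERN.findall(source))
                   for _, source, _ in self._python_files())
    def _measure_technical_debt(self) -> float:
        """Debt markers per 1,000 lines of code."""
        total_lines = sum(source.count("\n") + 1 for _, source, _ in self._python_files())
        return self._count_debt_markers() * 1000 / max(total_lines, 1)
    def _measure_test_coverage(self) -> float:
        """Line coverage in percent, read from a coverage.py JSON report if present."""
        # Assumes `coverage json` has already written coverage.json at the repository root
        report = Path(self.repo.working_dir) / "coverage.json"
        if not report.exists():
            return 0.0
        return json.loads(report.read_text())["totals"]["percent_covered"]
    def _measure_change_frequency(self) -> float:
        """Number of commits in the last 30 days."""
        return float(sum(1 for _ in self.repo.iter_commits(since="30.days.ago")))
    def _get_commit_at_date(self, date: datetime, rev: str):
        """Latest commit reachable from `rev` made at or before `date`, or None."""
        return next(self.repo.iter_commits(rev, until=date.isoformat(), max_count=1), None)
    def predict_debt_growth(self, days_ahead: int = 30) -> dict:
        """Predict the technical-debt trend."""
        # Remember where HEAD is so it can be restored afterwards
        original_ref = (self.repo.head.commit.hexsha if self.repo.head.is_detached
                        else self.repo.active_branch.name)
        # Collect history: one sample per week over the past 90 days
        history = []
        try:
            for days_ago in range(90, 0, -7):
                date = datetime.now() - timedelta(days=days_ago)
                # Search from the original ref, not from the commit checked out last
                commit = self._get_commit_at_date(date, original_ref)
                if commit:
                    # checkout rewrites the working tree; run this on a clean clone
                    self.repo.git.checkout(commit.hexsha)
                    history.append({
                        'date': date,
                        'debt_count': self._count_debt_markers(),
                        'complexity': self._measure_complexity()
                    })
        finally:
            # Return to the original branch even if sampling fails
            self.repo.git.checkout(original_ref)
        if len(history) < 2:
            raise ValueError("need at least two historical samples to fit a trend")
        # Linear regression over the weekly sample index
        X = np.arange(len(history)).reshape(-1, 1)
        y = np.array([h['debt_count'] for h in history])
        model = LinearRegression()
        model.fit(X, y)
        # The last sample has index len(history) - 1; step forward days_ahead // 7 weeks
        future_x = len(history) - 1 + days_ahead // 7
        predicted_debt = model.predict([[future_x]])[0]
        return {
            'current_debt': history[-1]['debt_count'],  # most recent sample
            'predicted_debt': predicted_debt,
            'growth_rate': model.coef_[0],  # markers per week
            'confidence': model.score(X, y)  # R^2 of the linear fit
        }
    def generate_refactoring_suggestions(self) -> list:
        """Generate refactoring suggestions."""
        suggestions = []
        # Analyze hotspots (files that change often and are complex)
        hotspots = self._identify_hotspots()
        for file_path, metrics, _score in hotspots[:5]:  # top 5 hotspots
            suggestion = {
                'file': file_path,
                'reason': [],
                'actions': []
            }
            if metrics['complexity'] > 10:
                suggestion['reason'].append(f"high complexity: {metrics['complexity']}")
                suggestion['actions'].append("split large functions into smaller units")
            if metrics['change_frequency'] > 20:
                suggestion['reason'].append(
                    f"frequent changes: {metrics['change_frequency']} of the last 100 commits")
                suggestion['actions'].append("consider extracting a stable interface")
            if metrics['duplication'] > 20:
                suggestion['reason'].append(f"duplicated code: {metrics['duplication']}%")
                suggestion['actions'].append("move shared code into a utility module")
            # Only report coverage when it has actually been measured
            if metrics['test_coverage'] is not None and metrics['test_coverage'] < 50:
                suggestion['reason'].append(f"low test coverage: {metrics['test_coverage']}%")
                suggestion['actions'].append("add unit tests")
            suggestions.append(suggestion)
        return suggestions
    def _identify_hotspots(self) -> list:
        """Identify code hotspots as (file, metrics, score) tuples, highest score first."""
        file_metrics = {}
        # Analyze the last 100 commits reachable from HEAD
        for commit in self.repo.iter_commits(max_count=100):
            for file in commit.stats.files:
                if file.endswith('.py'):
                    if file not in file_metrics:
                        file_metrics[file] = {
                            'change_frequency': 0,
                            'complexity': 0,
                            'duplication': 0,      # per-file duplication is not computed here
                            'test_coverage': None  # unknown unless a per-file report is wired in
                        }
                    file_metrics[file]['change_frequency'] += 1
        # Compute current metrics
        for file_path, metrics in file_metrics.items():
            full_path = Path(self.repo.working_dir) / file_path
            if full_path.exists():
                try:
                    tree = ast.parse(full_path.read_text(encoding="utf-8"))
                except (SyntaxError, UnicodeDecodeError):
                    continue
                metrics['complexity'] = self._calculate_cyclomatic_complexity(tree)
        # Rank by combined score
        scored = []
        for file, metrics in file_metrics.items():
            score = (metrics['change_frequency'] * 0.4 +
                     metrics['complexity'] * 0.4 +
                     metrics['duplication'] * 0.2)
            scored.append((file, metrics, score))
        return sorted(scored, key=lambda x: x[2], reverse=True)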
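Features of the monitor:
- A multi-dimensional health score
- Trend prediction from historical data
- Hotspot analysis to locate problem areas
- Actionable refactoring suggestions
- A basis for continuous monitoring and alerting (for example, run on a schedule)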
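The hotspot score in `_identify_hotspots` adds raw change counts to raw complexity, so whichever metric produces larger numbers dominates the ranking. One common alternative, sketched here as an assumption rather than part of the reference solution, rescales each metric to [0, 1] before combining. `min_max`, `rank_hotspots`, and the equal weights are hypothetical, and the file names and numbers are illustrative.

```python
def min_max(values: dict) -> dict:
    """Rescale a {file: value} mapping to [0, 1]; a constant metric maps to 0."""
    lo, hi = min(values.values()), max(values.values())
    span = hi - lo
    return {k: (v - lo) / span if span else 0.0 for k, v in values.items()}


def rank_hotspots(changes: dict, complexity: dict,
                  w_change: float = 0.5, w_complexity: float = 0.5) -> list:
    """Rank files by a weighted sum of normalized change count and complexity."""
    files = changes.keys() & complexity.keys()  # only files that have both metrics
    c = min_max({f: changes[f] for f in files})
    x = min_max({f: complexity[f] for f in files})
    scores = {f: w_change * c[f] + w_complexity * x[f] for f in files}
    return sorted(scores.items(), key=lambda kv: kv[1], reverse=True)


changes = {"trainer.py": 40, "data.py": 5, "utils.py": 12}
complexity = {"trainer.py": 120, "data.py": 300, "utils.py": 15}
for name, score in rank_hotspots(changes, complexity):
    print(f"{name}: {score:.2f}")
# trainer.py: 0.68
# data.py: 0.50
# utils.py: 0.10
```

With the raw weighted sum, `data.py` would rank first (0.4 × 5 + 0.4 × 300 = 122 versus 0.4 × 40 + 0.4 × 120 = 64 for `trainer.py`) purely because complexity values are larger; after normalization the frequently changed and moderately complex `trainer.py` comes out on top.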
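Exercise 2.8: Automatic analysis of experiment results. Build a system that automatically analyzes experiment results and identifies (1) anomalous experiments; (2) performance bottlenecks; (3) the best-performing configuration combinations.
Hint: use anomaly detection, performance profiling, and Bayesian optimization.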
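The IQR rule used in the reference solution flags values outside `[Q1 - 1.5 * IQR, Q3 + 1.5 * IQR]`. A tiny worked example with made-up final losses (illustrative only) shows where the fences land:

```python
import numpy as np

# Illustrative final losses from five runs; the last one diverged
losses = np.array([1.0, 2.0, 3.0, 4.0, 100.0])

q1, q3 = np.percentile(losses, [25, 75])  # default linear interpolation
iqr = q3 - q1
lower, upper = q1 - 1.5 * iqr, q3 + 1.5 * iqr
outliers = losses[(losses < lower) | (losses > upper)]

print(q1, q3, lower, upper, outliers)
# 2.0 4.0 -1.0 7.0 [100.]
```

Because the fences come from quartiles, the diverged run barely moves them. A mean ± 3σ rule on the same data is dragged by the outlier itself: the mean is 22 and σ ≈ 39, so the upper bound of about 139 would not flag the run at 100.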
Reference solution
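from scipy import stats
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import Matern
import numpy as np
# Expected share (%) of step time per training stage; rough defaults, tune for your setup
EXPECTED_STAGE_PERCENTAGES = {
    'data_loading': 10,
    'forward_pass': 30,
    'backward_pass': 40,
    'optimizer_step': 20
}
class ExperimentAnalyzer:
    def __init__(self, experiment_history: list):
        self.history = experiment_history
    def detect_anomalies(self) -> list:
        """Detect anomalous experiments."""
        anomalies = []
        # Key metrics
        metrics = ['loss', 'accuracy', 'training_time']
        for metric in metrics:
            # Only consider experiments that actually report this metric
            exps = [exp for exp in self.history if metric in exp.get('metrics', {})]
            values = np.array([exp['metrics'][metric] for exp in exps], dtype=float)
            if len(values) < 3:
                continue
            # IQR rule
            q1, q3 = np.percentile(values, [25, 75])
            iqr = q3 - q1
            lower_bound = q1 - 1.5 * iqr
            upper_bound = q3 + 1.5 * iqr
            mean, std = values.mean(), values.std()
            for exp, value in zip(exps, values):
                if value < lower_bound or value > upper_bound:
                    anomalies.append({
                        'exp_id': exp['id'],
                        'metric': metric,
                        'value': float(value),
                        'expected_range': (lower_bound, upper_bound),
                        'severity': 'high' if abs(value - mean) > 3 * std else 'medium'
                    })
        # Training-curve anomalies
        for exp in self.history:
            if 'training_curve' in exp:
                anomalies.extend(self._detect_curve_anomalies(exp['id'], exp['training_curve']))
        return anomalies
    def _detect_curve_anomalies(self, exp_id, curve: list) -> list:
        """Detect anomalies in one experiment's training curve."""
        anomalies = []
        # Loss explosion: NaN or Inf anywhere in the curve (missing values count as NaN)
        losses = np.array([point.get('loss') for point in curve], dtype=float)
        if not np.all(np.isfinite(losses)):
            anomalies.append({
                'exp_id': exp_id,
                'type': 'loss_explosion',
                'severity': 'critical'
            })
        # Overfitting: compare train/val accuracy over the last 10 points
        if len(curve) > 10:
            recent = [p for p in curve[-10:] if 'train_acc' in p and 'val_acc' in p]
            if recent:
                gap = (np.mean([p['train_acc'] for p in recent]) -
                       np.mean([p['val_acc'] for p in recent]))
                if gap > 0.1:  # gap of 10 percentage points (accuracies in [0, 1])
                    anomalies.append({
                        'exp_id': exp_id,
                        'type': 'overfitting',
                        'severity': 'medium',
                        'train_val_gap': gap
                    })
        return anomalies
    def identify_bottlenecks(self) -> dict:
        """Identify performance bottlenecks."""
        bottlenecks = {stage: [] for stage in EXPECTED_STAGE_PERCENTAGES}
        for exp in self.history:
            if 'profiling' not in exp:
                continue
            prof = exp['profiling']
            # Share of total time spent in each stage
            total_time = sum(prof.values())
            if total_time <= 0:
                continue
            for stage, stage_time in prof.items():
                percentage = stage_time / total_time * 100
                expected = EXPECTED_STAGE_PERCENTAGES.get(stage)
                # Flag a stage whose share exceeds its expected share by more than 50%
                if expected is not None and percentage > expected * 1.5:
                    bottlenecks[stage].append({
                        'exp_id': exp['id'],
                        'percentage': percentage,
                        'expected': expected,
                        'suggestions': self._get_optimization_suggestions(stage)
                    })
        return bottlenecks
    def _get_optimization_suggestions(self, stage: str) -> list:
        """Return optimization suggestions for a training stage."""
        suggestions = {
            'data_loading': [
                "Increase DataLoader num_workers",
                "Use a more efficient data format (e.g. HDF5)",
                "Prefetch and cache data",
                "Consider NVIDIA DALI or another data-loading acceleration library"
            ],
            'forward_pass': [
                "Use mixed-precision training (AMP)",
                "Enable torch.backends.cudnn.benchmark",
                "Consider model pruning or quantization",
                "Use more efficient operator implementations"
            ],
            'backward_pass': [
                "Use mixed-precision training (AMP)",
                "Overlap gradient communication with the backward pass (e.g. DDP gradient bucketing)",
                "Consider a ZeRO optimizer",
                "Check for unnecessary gradient computation (e.g. frozen weights that still require grad)"
            ],
            'optimizer_step': [
                "Use a fused optimizer (e.g. FusedAdam)",
                "Update parameters less often (e.g. gradient accumulation)",
                "Consider LARS or LAMB",
                "Check the weight-decay settings"
            ]
        }
        return suggestions.get(stage, [])
    def find_optimal_config(self, n_candidates: int = 1000, xi: float = 0.01) -> dict:
        """Propose the next configuration with one step of Bayesian optimization."""
        # Prepare data: keep only runs that report the target metric
        configs, scores = [], []
        param_names = set()
        for exp in self.history:
            if 'accuracy' not in exp.get('metrics', {}):
                continue
            flat_config = self._flatten_config(exp['config'])
            param_names.update(flat_config.keys())
            configs.append(flat_config)
            scores.append(exp['metrics']['accuracy'])
        if len(configs) < 2:
            raise ValueError("need at least two experiments to fit the surrogate model")
        param_names = sorted(param_names)
        # Ordinal-encode string parameters with a deterministic mapping
        # (hash() is randomized per process, so it cannot be used here)
        categories = {p: sorted({c[p] for c in configs if isinstance(c.get(p), str)})
                      for p in param_names}
        X = np.array([[self._encode(c.get(p), categories[p]) for p in param_names]
                      for c in configs])
        y = np.array(scores, dtype=float)
        # Standardize features so a single length scale suits every parameter
        mean, std = X.mean(axis=0), X.std(axis=0)
        std[std == 0] = 1.0
        X_norm = (X - mean) / std
        # Gaussian-process surrogate
        kernel = Matern(length_scale=1.0, nu=2.5)
        gpr = GaussianProcessRegressor(kernel=kernel, alpha=1e-6, normalize_y=True)
        gpr.fit(X_norm, y)
        best_y = y.max()
        def expected_improvement(candidates):
            """Expected Improvement (maximization) for a batch of candidates."""
            mu, sigma = gpr.predict(candidates, return_std=True)
            sigma = np.maximum(sigma, 1e-9)  # avoid division by zero
            improvement = mu - best_y - xi
            z = improvement / sigma
            return improvement * stats.norm.cdf(z) + sigma * stats.norm.pdf(z)
        # Random search for the max EI: perturb existing configs in standardized space
        rng = np.random.default_rng()
        idx = rng.integers(len(X_norm), size=n_candidates)
        candidates = X_norm[idx] + rng.normal(scale=0.1, size=(n_candidates, len(param_names)))
        ei = expected_improvement(candidates)
        best = int(np.argmax(ei))
        best_candidate = candidates[best]
        # Map back to the original scale; integer and boolean parameters come back
        # as floats and must be rounded before use
        raw = best_candidate * std + mean
        optimal_config = {}
        for param, value in zip(param_names, raw):
            if categories[param]:
                # Decode an ordinal-encoded category back to its string value
                i = int(np.clip(round(value), 0, len(categories[param]) - 1))
                optimal_config[param] = categories[param][i]
            else:
                optimal_config[param] = float(value)
        # Predicted performance
        predicted_score, uncertainty = gpr.predict(
            best_candidate.reshape(1, -1),
            return_std=True
        )
        return {
            'config': optimal_config,
            'predicted_score': float(predicted_score[0]),
            'uncertainty': float(uncertainty[0]),
            'expected_improvement': float(ei[best])
        }
    @staticmethod
    def _encode(value, categories: list) -> float:
        """Convert one parameter value to a number (missing values become 0)."""
        if value is None:
            return 0.0
        if isinstance(value, bool):
            return float(value)
        if isinstance(value, str):
            return float(categories.index(value))
        return float(value)
    def _flatten_config(self, config: dict, prefix: str = '') -> dict:
        """Flatten a nested config into dotted keys."""
        flat = {}
        for key, value in config.items():
            full_key = f"{prefix}.{key}" if prefix else key
            if isinstance(value, dict):
                flat.update(self._flatten_config(value, full_key))
            else:
                flat[full_key] = value
        return flat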
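Capabilities of the analysis system:
- Anomaly detection: the IQR rule plus training-curve checks
- Bottleneck identification: profiling data broken down by training stage
- Configuration optimization: Bayesian optimization with a Gaussian-process surrogate
- Actionable advice: targeted optimization suggestions for each bottleneck
- Uncertainty quantification: each predicted score comes with a standard deviation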
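For reference, the acquisition function in `find_optimal_config` is the standard Expected Improvement for maximization. Here μ(x) and σ(x) are the Gaussian-process posterior mean and standard deviation, f* is the best score observed so far, ξ is a small exploration margin (0.01 in the code), and Φ and φ are the standard normal CDF and PDF:

```latex
\mathrm{EI}(x) = \bigl(\mu(x) - f^{*} - \xi\bigr)\,\Phi(z) + \sigma(x)\,\varphi(z),
\qquad
z = \frac{\mu(x) - f^{*} - \xi}{\sigma(x)},
\qquad
\mathrm{EI}(x) = 0 \ \text{if}\ \sigma(x) = 0.
```

The first term rewards candidates whose predicted mean already beats f*, and the second rewards candidates the surrogate is uncertain about, which is how the search trades exploitation against exploration.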
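Working through these exercises builds the core skills for constructing robust experiment infrastructure for LLM post-training. Keep in mind that good infrastructure is the foundation of efficient experimentation and is worth the time invested early in a project.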