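A maintainable, extensible experiment codebase is the foundation of successful LLM post-training. This chapter covers how to design and implement a robust experiment infrastructure, including configuration management, version control, and experiment tracking. We focus on practical engineering challenges: keeping code quality high while iterating quickly, managing the configurations and results of hundreds of experiments, and preventing technical debt from accumulating.
Configuration management in LLM post-training projects is far more complex than in traditional deep learning projects. A typical experiment may involve more than a hundred hyperparameters spanning model architecture, training strategy, data processing, and evaluation metrics, so the choice of configuration format matters.
Strengths and weaknesses of YAML configuration
YAML is popular for its readability and is particularly well suited to expressing nested structures: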
model:
  architecture: llama2
  hidden_size: 4096
  num_layers: 32
  attention:
    num_heads: 32
    head_dim: 128
    rotary_embedding:
      base: 10000
      scaling_factor: 1.0
Advantages:
Disadvantages:
TOML trade-offs
TOML is increasingly popular in the Rust and Python communities and offers a stricter syntax:
[model]
architecture = "llama2"
hidden_size = 4096
num_layers = 32
[model.attention]
num_heads = 32
head_dim = 128
[model.attention.rotary_embedding]
base = 10000
scaling_factor = 1.0
Advantages:
Disadvantages:
The flexibility of Python configuration
Using a Python file directly as the configuration offers the most flexibility:
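from dataclasses import dataclass


@dataclass
class ModelConfig:
    architecture: str = "llama2"
    hidden_size: int = 4096
    num_layers: int = 32

    @property
    def total_params(self) -> int:
        """Parameter count computed from the current fields."""
        return self.calculate_params()

    def calculate_params(self) -> int:
        # Assumption: a rough decoder-only estimate (about 12 * L * d^2)
        # that ignores embeddings and normalization layers.
        return 12 * self.num_layers * self.hidden_size ** 2

    def scale_model(self, factor: float) -> None:
        """Scale model width and depth in place."""
        self.hidden_size = int(self.hidden_size * factor)
        self.num_layers = int(self.num_layers * factor)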
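Advantages:
Disadvantages:
In practice we usually need one base configuration and many experiment variants. A well-designed inheritance mechanism greatly reduces configuration redundancy: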
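import copy
from typing import Any, Dict

import yaml  # PyYAML; assumes the configs are YAML files


class ConfigManager:
    def __init__(self, base_config_path: str):
        self.base_config = self.load_config(base_config_path)
        self.inheritance_chain = [base_config_path]

    @staticmethod
    def load_config(path: str) -> Dict[str, Any]:
        """Load a YAML file into a dict (an empty file yields {})."""
        with open(path, encoding="utf-8") as f:
            return yaml.safe_load(f) or {}

    def inherit_from(self, parent_config_path: str) -> None:
        """Support multi-level inheritance: the current config overrides the parent."""
        parent_config = self.load_config(parent_config_path)
        self.base_config = self.deep_merge(parent_config, self.base_config)
        self.inheritance_chain.append(parent_config_path)

    def override(self, overrides: Dict[str, Any]) -> None:
        """Apply command-line overrides given as dotted paths."""
        for key_path, value in overrides.items():
            self.set_nested_value(key_path, value)

    def set_nested_value(self, key_path: str, value: Any) -> None:
        """Set a value addressed by a dotted path such as "model.hidden_size"."""
        *parents, leaf = key_path.split(".")
        node = self.base_config
        for key in parents:
            node = node.setdefault(key, {})
        node[leaf] = value

    def deep_merge(self, base: Dict, override: Dict) -> Dict:
        """Recursively merge two config dicts without mutating either input."""
        result = copy.deepcopy(base)
        for key, value in override.items():
            if key in result and isinstance(result[key], dict) and isinstance(value, dict):
                result[key] = self.deep_merge(result[key], value)
            else:
                result[key] = copy.deepcopy(value)
        return result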
Designing override precedence
A clear precedence order prevents configuration conflicts:
Precedence chain (highest priority first):
CLI Args > ENV > experiment.yaml > user.yaml > default.yaml
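The precedence chain can be sketched as a fold over configuration layers, merging from lowest to highest priority so that later layers win. This is a minimal illustration, not a prescribed implementation: the helper names `deep_merge` and `resolve_config` and the contents of each layer are hypothetical.

```python
from typing import Any, Dict, List


def deep_merge(base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]:
    """Return a new dict in which `override` wins, recursing into nested dicts."""
    result = dict(base)
    for key, value in override.items():
        if isinstance(result.get(key), dict) and isinstance(value, dict):
            result[key] = deep_merge(result[key], value)
        else:
            result[key] = value
    return result


def resolve_config(layers: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Merge layers that are listed from lowest to highest priority."""
    resolved: Dict[str, Any] = {}
    for layer in layers:
        resolved = deep_merge(resolved, layer)
    return resolved


# Hypothetical layer contents, lowest priority first.
default_cfg = {"training": {"lr": 1e-4, "batch_size": 32}, "seed": 0}
user_cfg = {"training": {"batch_size": 64}}
experiment_cfg = {"training": {"lr": 5e-5}}
env_cfg = {"seed": 1234}
cli_cfg = {"training": {"lr": 2e-5}}

final_cfg = resolve_config([default_cfg, user_cfg, experiment_cfg, env_cfg, cli_cfg])
print(final_cfg)
```

Here the CLI value for `lr` overrides both the experiment file and the default, while `batch_size` comes from the user file because no higher layer sets it.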
Runtime validation with Pydantic or attrs catches configuration errors early:
from typing import Literal

from pydantic import BaseModel, Field, model_validator


class TrainingConfig(BaseModel):
    learning_rate: float = Field(gt=0, le=1.0)
    batch_size: int = Field(gt=0, multiple_of=8)
    optimizer: Literal["adam", "sgd", "adamw"]
    gradient_accumulation_steps: int = Field(gt=0)

    # Cross-field checks run after every field has been validated (Pydantic v2 API).
    # A per-field validator cannot see fields declared after the one it validates,
    # so checks that combine several fields belong in a model-level validator.
    @model_validator(mode="after")
    def validate_cross_field_constraints(self) -> "TrainingConfig":
        effective_batch = self.batch_size * self.gradient_accumulation_steps
        if effective_batch > 65536:
            raise ValueError(f"Effective batch size {effective_batch} too large")
        if self.optimizer == "sgd" and self.learning_rate > 0.1:
            raise ValueError("SGD learning rate should typically be <= 0.1")
        return self
Command-line arguments are the first point of contact with an experiment's configuration, and a good design noticeably improves experiment throughput. The following design principles have held up in large-scale experimentation:
Hierarchical parameter organization
Avoid a flat list of parameters; organize them by functional domain instead:
import argparse


def create_parser():
    parser = argparse.ArgumentParser()

    # Argument groups make the --help output easier to read.
    # Dotted names are stored under the same dotted key, so read them
    # with getattr(args, "model.name") rather than attribute access.
    model_group = parser.add_argument_group("model")
    model_group.add_argument("--model.name", default="llama2-7b")
    model_group.add_argument("--model.checkpoint", type=str)
    model_group.add_argument("--model.dtype", choices=["fp32", "fp16", "bf16"])

    training_group = parser.add_argument_group("training")
    training_group.add_argument("--training.batch_size", type=int, default=32)
    training_group.add_argument("--training.learning_rate", type=float, default=1e-4)
    training_group.add_argument("--training.warmup_steps", type=int, default=1000)

    data_group = parser.add_argument_group("data")
    data_group.add_argument("--data.train_path", required=True)
    data_group.add_argument("--data.val_path", required=True)
    data_group.add_argument("--data.num_workers", type=int, default=4)

    return parser
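Because argparse stores a flag such as `--model.name` under the literal key `"model.name"`, the parsed namespace usually needs to be folded back into a nested dictionary before it can be merged with file-based configuration. The sketch below shows one way to do that; the helper name `nest_dotted_args` is hypothetical and the three flags are a reduced version of the parser above.

```python
import argparse
from typing import Any, Dict


def nest_dotted_args(namespace: argparse.Namespace) -> Dict[str, Any]:
    """Turn {"model.name": x} style keys into {"model": {"name": x}}."""
    nested: Dict[str, Any] = {}
    for dotted_key, value in vars(namespace).items():
        *parents, leaf = dotted_key.split(".")
        node = nested
        for key in parents:
            node = node.setdefault(key, {})
        node[leaf] = value
    return nested


parser = argparse.ArgumentParser()
parser.add_argument("--model.name", default="llama2-7b")
parser.add_argument("--training.batch_size", type=int, default=32)
parser.add_argument("--training.learning_rate", type=float, default=1e-4)

# Parse an explicit argument list so the example does not depend on sys.argv.
args = parser.parse_args(["--training.batch_size", "64"])
config = nest_dotted_args(args)
print(config)
```

The resulting dictionary has the same shape as a parsed YAML or TOML file, so it can serve as the highest-priority layer in a config merge.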
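Smart defaults and required parameters
Separate required parameters from optional ones, and provide sensible defaults for common scenarios:
import logging
import os

import torch

logger = logging.getLogger(__name__)


class FlagValidator:
    @staticmethod
    def validate_flags(args):
        """Fill in derived values; assumes args defines the attributes used below."""
        # Infer related parameters automatically
        if args.distributed and args.local_rank is None:
            args.local_rank = int(os.environ.get("LOCAL_RANK", 0))

        # Pick the device based on available hardware
        if args.device == "auto":
            args.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Cap the batch size when gradient checkpointing is enabled
        if args.gradient_checkpointing and args.batch_size > 16:
            logger.warning(
                f"Reducing batch size from {args.batch_size} to 16 due to gradient checkpointing"
            )
            args.batch_size = 16

        return args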
Parameter aliases and short forms
Provide short forms for frequently used parameters to speed up command-line work:
parser.add_argument("-b", "--batch-size", "--training.batch_size",
                    dest="batch_size", type=int, default=32,
                    help="Training batch size per device")
parser.add_argument("-lr", "--learning-rate", "--training.lr",
                    dest="learning_rate", type=float, default=1e-4,
                    help="Peak learning rate")
parser.add_argument("-e", "--epochs", "--num-epochs",
                    dest="num_epochs", type=int, default=3,
                    help="Number of training epochs")
Environment variables are well suited to global settings shared across experiments and to sensitive information:
A layered environment variable scheme
# System level (cluster configuration)
export CUDA_VISIBLE_DEVICES=0,1,2,3
export OMP_NUM_THREADS=8
export NCCL_DEBUG=INFO
# Project level (paths and credentials)
export LLM_DATA_ROOT=/mnt/data/llm_datasets
export LLM_CHECKPOINT_DIR=/mnt/checkpoints
export WANDB_API_KEY=your_api_key_here
export HF_TOKEN=your_huggingface_token
# Experiment level (runtime configuration)
export LLM_EXPERIMENT_NAME=dpo_ablation_v3
export LLM_RUN_ID=$(date +%Y%m%d_%H%M%S)
export LLM_DEBUG_MODE=1
Best practices for environment variables
import os
from pathlib import Path


class EnvConfig:
    """Central access point for environment variables."""

    @staticmethod
    def get_data_root() -> Path:
        """Return the data root directory, trying several fallbacks in order."""
        candidates = [
            os.environ.get("LLM_DATA_ROOT"),
            os.environ.get("DATA_ROOT"),
            "/data/llm",
            "./data",
        ]
        for path in candidates:
            if path and Path(path).exists():
                return Path(path)
        raise ValueError("No valid data root found")

    @staticmethod
    def get_wandb_config() -> dict:
        """Read W&B settings, including only the variables that are set."""
        config = {}
        if api_key := os.environ.get("WANDB_API_KEY"):
            config["api_key"] = api_key
        if project := os.environ.get("WANDB_PROJECT"):
            config["project"] = project
        if entity := os.environ.get("WANDB_ENTITY"):
            config["entity"] = entity
        return config

    @staticmethod
    def is_debug_mode() -> bool:
        """Check whether debug mode is enabled."""
        return os.environ.get("LLM_DEBUG_MODE", "0").lower() in ("1", "true", "yes")
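In a fast-moving experimental environment, the Git branching strategy has to balance experimental freedom against code quality:
Naming conventions for experiment branches
# Feature development branches
feature/distributed-dpo
feature/multimodal-alignment
# Experiment branches (short-lived)
exp/20250105-lr-sweep
exp/20250106-batch-size-ablation
# Personal experiment branches (fewer restrictions)
dev/alice/rope-scaling
dev/bob/attention-variants
# Long-running research branches
research/constitutional-ai
research/online-rlhf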
Branch protection and merge strategy
# .github/branch_protection.yml
# Illustrative schema: GitHub does not read this file natively, so it assumes
# a sync tool that applies the rules through repository settings or the API.
protection_rules:
  main:
    required_reviews: 2
    dismiss_stale_reviews: true
    require_code_owner_reviews: true
    required_status_checks:
      - lint
      - type-check
      - unit-tests
    enforce_admins: false
  release/*:
    required_reviews: 3
    require_code_owner_reviews: true
    required_status_checks:
      - all-tests
      - integration-tests
      - benchmark-regression
Lifecycle management for experiment branches
#!/bin/bash
# scripts/manage_exp_branches.sh

# Delete experiment branches whose last commit is older than N days (default 30)
cleanup_old_exp_branches() {
    local days_old=${1:-30}
    git for-each-ref --format='%(refname:short) %(committerdate:unix)' refs/heads/exp/ | \
    while read -r branch timestamp; do
        age_days=$(( ($(date +%s) - timestamp) / 86400 ))
        if [ "$age_days" -gt "$days_old" ]; then
            echo "Deleting old experimental branch: $branch (${age_days} days old)"
            git branch -D "$branch"
        fi
    done
}

# Archive an important experiment branch as an annotated tag
archive_exp_branch() {
    local branch=$1
    local archive_tag="archive/$(date +%Y%m%d)/${branch##*/}"
    git tag -a "$archive_tag" "$branch" -m "Archived experimental branch $branch"
    git push origin "$archive_tag"
    # -D rather than -d: archived branches are usually unmerged, and the tag keeps the commits reachable
    git branch -D "$branch"
    echo "Archived $branch as $archive_tag"
}
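Choosing a suitable experiment tracking tool is key to building a reproducible research workflow. The mainstream tools each have their strengths, and the choice depends on team size and requirements.
MLflow: the open-source standard
MLflow provides end-to-end experiment lifecycle management:
import hashlib
import json
import subprocess
from typing import Optional

import mlflow
from mlflow.tracking import MlflowClient


class MLflowExperimentTracker:
    def __init__(self, experiment_name: str, tracking_uri: str = "file:./mlruns"):
        mlflow.set_tracking_uri(tracking_uri)
        mlflow.set_experiment(experiment_name)
        self.client = MlflowClient()

    def start_run(self, config: dict, tags: Optional[dict] = None) -> str:
        """Start a run that stays active until end_run() is called."""
        # Use a hash of the config as part of the run name
        config_hash = hashlib.md5(
            json.dumps(config, sort_keys=True, default=str).encode()
        ).hexdigest()[:8]
        run_name = f"{config.get('model_name', 'unknown')}_{config_hash}"

        # No `with` block here: it would end the run on exit, and later
        # log_metric calls would land in a new, unrelated run.
        run = mlflow.start_run(run_name=run_name)
        mlflow.log_params(self.flatten_dict(config))
        if tags:
            mlflow.set_tags(tags)
        # Record the code version
        mlflow.set_tag("git_commit", self.get_git_commit())
        mlflow.set_tag("git_branch", self.get_git_branch())
        return run.info.run_id

    def end_run(self) -> None:
        """Close the active run."""
        mlflow.end_run()

    def log_metrics_batch(self, metrics: dict, step: int) -> None:
        """Log several metrics in one call."""
        mlflow.log_metrics(metrics, step=step)

    def log_artifact_with_metadata(self, file_path: str, metadata: dict) -> None:
        """Log a file together with a JSON sidecar describing it."""
        mlflow.log_artifact(file_path)
        metadata_path = f"{file_path}.metadata.json"
        with open(metadata_path, "w") as f:
            json.dump(metadata, f, indent=2)
        mlflow.log_artifact(metadata_path)

    @staticmethod
    def flatten_dict(d: dict, prefix: str = "") -> dict:
        """Flatten nested dicts into dotted keys, e.g. {"a": {"b": 1}} -> {"a.b": 1}."""
        flat = {}
        for key, value in d.items():
            full_key = f"{prefix}{key}"
            if isinstance(value, dict):
                flat.update(MLflowExperimentTracker.flatten_dict(value, f"{full_key}."))
            else:
                flat[full_key] = value
        return flat

    @staticmethod
    def get_git_commit() -> str:
        return subprocess.check_output(["git", "rev-parse", "HEAD"], text=True).strip()

    @staticmethod
    def get_git_branch() -> str:
        return subprocess.check_output(
            ["git", "rev-parse", "--abbrev-ref", "HEAD"], text=True
        ).strip()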
Weights & Biases: a cloud-native, feature-rich option
W&B offers richer visualization and collaboration features:
from typing import Optional

import numpy as np
import wandb


class WandBTracker:
    def __init__(self, project: str, entity: Optional[str] = None):
        self.project = project
        self.entity = entity

    @staticmethod
    def generate_tags(config: dict) -> list:
        """Placeholder tagging rule: tag runs by model name and optimizer when present."""
        return [str(config[k]) for k in ("model_name", "optimizer") if k in config]

    def init_run(self, config: dict, name: Optional[str] = None,
                 resume: Optional[str] = None):
        """Initialize a W&B run."""
        run = wandb.init(
            project=self.project,
            entity=self.entity,
            config=config,
            name=name,
            resume=resume,   # supports resuming interrupted training
            save_code=True,  # snapshot the launching code
            tags=self.generate_tags(config),
        )

        # Give train and eval metrics their own step axes
        wandb.define_metric("train/step")
        wandb.define_metric("train/*", step_metric="train/step")
        wandb.define_metric("eval/step")
        wandb.define_metric("eval/*", step_metric="eval/step")
        return run

    def log_distribution(self, name: str, data: np.ndarray, step: int):
        """Log summary statistics and a histogram for an array."""
        wandb.log({
            f"{name}/mean": np.mean(data),
            f"{name}/std": np.std(data),
            f"{name}/min": np.min(data),
            f"{name}/max": np.max(data),
            f"{name}/histogram": wandb.Histogram(data),
        }, step=step)

    def log_gradient_flow(self, model, step: int):
        """Log per-parameter gradient norms as a table."""
        gradients = []
        for name, param in model.named_parameters():
            if param.grad is not None:
                grad_norm = param.grad.detach().norm(2).item()
                gradients.append({"name": name, "grad_norm": grad_norm})

        grad_table = wandb.Table(
            columns=["layer", "gradient_norm"],
            data=[[g["name"], g["grad_norm"]] for g in gradients],
        )
        wandb.log({"gradients": grad_table}, step=step)
TensorBoard: a lightweight local option
When no cloud service is needed, TensorBoard remains a reliable choice:
from pathlib import Path

import matplotlib.pyplot as plt
import torch
from torch.utils.tensorboard import SummaryWriter


class TensorBoardTracker:
    def __init__(self, log_dir: str, comment: str = ""):
        self.log_dir = Path(log_dir)
        self.writer = SummaryWriter(
            log_dir=str(self.log_dir),
            comment=comment,  # ignored by SummaryWriter when log_dir is given
            flush_secs=30,    # flush to disk periodically
        )

    def log_model_architecture(self, model: torch.nn.Module, input_shape: tuple):
        """Log the model graph using a random dummy input."""
        dummy_input = torch.randn(input_shape)
        self.writer.add_graph(model, dummy_input)

    def log_attention_weights(self, attention_weights: torch.Tensor,
                              step: int, head_idx: int = 0):
        """Visualize attention weights; assumes shape (batch, heads, query, key)."""
        # Select one head of the first sample in the batch
        attn = attention_weights[0, head_idx].detach().cpu().numpy()

        # Draw a heatmap
        fig = plt.figure(figsize=(10, 8))
        plt.imshow(attn, cmap="hot", interpolation="nearest")
        plt.colorbar()
        plt.title(f"Attention Weights - Head {head_idx}")
        self.writer.add_figure(f"attention/head_{head_idx}", fig, step)
        plt.close(fig)

    def log_learning_rate_schedule(self, optimizer, step: int):
        """Log the learning rate of every parameter group."""
        for i, param_group in enumerate(optimizer.param_groups):
            lr = param_group["lr"]
            self.writer.add_scalar(f"learning_rate/group_{i}", lr, step)
Complete metadata records are the basis of experiment reproducibility:
import datetime
import os
import platform
import subprocess
import uuid
from pathlib import Path

import GPUtil
import psutil


class ExperimentMetadata:
    @staticmethod
    def collect_system_info() -> dict:
        """Collect system information."""
        return {
            "platform": platform.platform(),
            "python_version": platform.python_version(),
            "cpu_count": psutil.cpu_count(),
            "memory_gb": psutil.virtual_memory().total / (1024**3),
            "hostname": platform.node(),
            "user": os.environ.get("USER", "unknown"),
        }

    @staticmethod
    def collect_gpu_info() -> list:
        """Collect GPU information."""
        gpus = GPUtil.getGPUs()
        # GPUtil does not report compute capability; if it is needed, query it
        # separately (for example with torch.cuda.get_device_capability()).
        return [{
            "id": gpu.id,
            "name": gpu.name,
            "memory_total": gpu.memoryTotal,
            "driver": gpu.driver,
        } for gpu in gpus]

    @staticmethod
    def collect_dependencies() -> dict:
        """Collect dependency versions."""
        deps = {}
        try:
            import torch
            deps["torch"] = torch.__version__
            deps["cuda"] = torch.version.cuda if torch.cuda.is_available() else None
        except ImportError:
            pass

        try:
            import transformers
            deps["transformers"] = transformers.__version__
        except ImportError:
            pass

        # Read pinned versions from requirements.txt
        if Path("requirements.txt").exists():
            with open("requirements.txt") as f:
                for line in f:
                    # Drop comments and environment markers before splitting
                    requirement = line.split("#", 1)[0].split(";", 1)[0].strip()
                    if "==" in requirement:
                        pkg, version = requirement.split("==", 1)
                        deps[pkg.strip()] = version.strip()
        return deps

    @staticmethod
    def create_experiment_card(config: dict) -> dict:
        """Create an experiment card."""
        def git(*args: str) -> str:
            return subprocess.check_output(["git", *args]).decode()

        return {
            "experiment_id": str(uuid.uuid4()),
            "timestamp": datetime.datetime.now().isoformat(),
            "config": config,
            "system": ExperimentMetadata.collect_system_info(),
            "gpus": ExperimentMetadata.collect_gpu_info(),
            "dependencies": ExperimentMetadata.collect_dependencies(),
            "git": {
                "commit": git("rev-parse", "HEAD").strip(),
                "branch": git("branch", "--show-current").strip(),
                "diff": git("diff", "HEAD"),
            },
        }
Efficient checkpoint management is essential for long training runs:
import datetime
from pathlib import Path

import torch


class CheckpointManager:
    def __init__(self, checkpoint_dir: Path, max_checkpoints: int = 5):
        self.checkpoint_dir = Path(checkpoint_dir)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        self.max_checkpoints = max_checkpoints

    def save_checkpoint(self, model, optimizer, epoch: int,
                        metrics: dict, is_best: bool = False):
        """Save a checkpoint."""
        checkpoint = {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "metrics": metrics,
            "timestamp": datetime.datetime.now().isoformat(),
        }

        # Regular checkpoint
        checkpoint_path = self.checkpoint_dir / f"checkpoint_epoch_{epoch}.pt"
        torch.save(checkpoint, checkpoint_path)

        # Best model so far
        if is_best:
            best_path = self.checkpoint_dir / "best_model.pt"
            torch.save(checkpoint, best_path)

        # Remove old checkpoints
        self._cleanup_old_checkpoints()

    def _cleanup_old_checkpoints(self):
        """Keep only the N most recent checkpoints (best_model.pt is never removed)."""
        checkpoints = sorted(
            self.checkpoint_dir.glob("checkpoint_epoch_*.pt"),
            key=lambda x: x.stat().st_mtime,
        )
        if len(checkpoints) > self.max_checkpoints:
            for ckpt in checkpoints[:-self.max_checkpoints]:
                ckpt.unlink()

    def resume_from_checkpoint(self, checkpoint_path: Path,
                               model, optimizer) -> dict:
        """Restore model and optimizer state from a checkpoint."""
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        return {
            "epoch": checkpoint["epoch"],
            "metrics": checkpoint["metrics"],
        }
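The rapid iteration typical of LLM post-training projects makes it easy to accumulate technical debt. Managing that debt actively is key to keeping a project healthy over the long term.
Quantifying and tracking technical debt
import ast
import datetime
from pathlib import Path
from typing import Dict, List


class TechnicalDebtAnalyzer:
    # Node types counted as decision points and as nesting levels
    BRANCH_NODES = (ast.If, ast.For, ast.While, ast.Try, ast.ExceptHandler, ast.BoolOp)
    NESTING_NODES = (ast.If, ast.For, ast.While, ast.Try, ast.With)

    def __init__(self, codebase_path: Path):
        self.codebase_path = Path(codebase_path)
        self.debt_markers = ["TODO", "FIXME", "HACK", "XXX", "DEPRECATED"]

    def estimate_priority(self, line: str) -> str:
        """Placeholder heuristic: FIXME, HACK, and XXX count as high priority."""
        return "high" if any(m in line for m in ("FIXME", "HACK", "XXX")) else "normal"

    def scan_codebase(self) -> Dict[str, List[Dict]]:
        """Scan the codebase for technical debt markers."""
        debt_items = {marker: [] for marker in self.debt_markers}
        for py_file in self.codebase_path.rglob("*.py"):
            with open(py_file, "r", encoding="utf-8", errors="replace") as f:
                for line_num, line in enumerate(f, 1):
                    for marker in self.debt_markers:
                        if marker in line:
                            debt_items[marker].append({
                                "file": str(py_file.relative_to(self.codebase_path)),
                                "line": line_num,
                                "content": line.strip(),
                                "priority": self.estimate_priority(line),
                            })
        return debt_items

    def calculate_cyclomatic_complexity(self, tree: ast.AST) -> int:
        """Approximation: one plus the number of branching nodes."""
        return 1 + sum(isinstance(n, self.BRANCH_NODES) for n in ast.walk(tree))

    def calculate_max_nesting(self, node: ast.AST, depth: int = 0) -> int:
        """Deepest nesting of control-flow blocks (an elif counts as one level deeper)."""
        if isinstance(node, self.NESTING_NODES):
            depth += 1
        children = [self.calculate_max_nesting(c, depth) for c in ast.iter_child_nodes(node)]
        return max([depth] + children)

    def calculate_complexity_metrics(self, file_path: Path) -> dict:
        """Compute code complexity metrics for one file."""
        with open(file_path, "r", encoding="utf-8") as f:
            source = f.read()
        tree = ast.parse(source)
        return {
            "cyclomatic_complexity": self.calculate_cyclomatic_complexity(tree),
            "lines_of_code": len(source.splitlines()),
            "num_functions": len([n for n in ast.walk(tree) if isinstance(n, ast.FunctionDef)]),
            "num_classes": len([n for n in ast.walk(tree) if isinstance(n, ast.ClassDef)]),
            "max_nesting_depth": self.calculate_max_nesting(tree),
        }

    def generate_debt_report(self) -> str:
        """Generate a technical debt report in Markdown."""
        debt_items = self.scan_codebase()
        total_debt = sum(len(items) for items in debt_items.values())
        lines = [
            "# Technical Debt Report",
            f"Generated at: {datetime.datetime.now().isoformat()}",
            f"Total debt items: {total_debt}",
            "## By type",
        ]
        lines += [f"- {marker}: {len(items)} items" for marker, items in debt_items.items()]

        # High-priority items
        high_priority = [
            item for items in debt_items.values() for item in items
            if item["priority"] == "high"
        ]
        if high_priority:
            lines += ["", "## High-priority debt"]
            for item in high_priority[:10]:  # show the first 10
                lines.append(f"- {item['file']}:{item['line']} - {item['content']}")
        return "\n".join(lines) + "\n"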
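Code quality gates
import re


class CodeQualityGate:
    def __init__(self, thresholds: dict):
        self.thresholds = thresholds

    def check_diff_quality(self, diff: str) -> bool:
        """Check the quality of a code change given as unified diff text."""
        checks = {
            "no_print_statements": self.check_no_prints(diff),
            "has_tests": self.check_has_tests(diff),
            "docstring_coverage": self.check_docstrings(diff),
            "type_hints": self.check_type_hints(diff),
        }
        failures = [name for name, passed in checks.items() if not passed]
        if failures:
            print(f"Quality gate failed: {', '.join(failures)}")
            return False
        return True

    def check_no_prints(self, diff: str) -> bool:
        """Check that no added line contains a debugging print call."""
        # Anchor on added lines only; without MULTILINE the "+" could match anywhere
        pattern = r"^\+.*\bprint\("
        return not re.search(pattern, diff, flags=re.MULTILINE)

    def check_has_tests(self, diff: str) -> bool:
        """Check that source changes come with test changes."""
        # A change under src/ should be accompanied by a change under test/ or tests/
        src_modified = "src/" in diff
        test_modified = "test/" in diff or "tests/" in diff
        return not (src_modified and not test_modified)

    def check_docstrings(self, diff: str) -> bool:
        """Placeholder: plug in a docstring coverage check here."""
        return True

    def check_type_hints(self, diff: str) -> bool:
        """Placeholder: plug in a type hint coverage check here."""
        return True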
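Good modular design is the basis for preventing code rot:
import logging
from abc import ABC, abstractmethod
from typing import Generic, TypeVar

logger = logging.getLogger(__name__)

T = TypeVar("T")


class BaseExperiment(ABC, Generic[T]):
    """Experiment base class that enforces a standard experiment workflow."""

    def __init__(self, config: dict):
        self.config = config
        self.setup()

    @abstractmethod
    def setup(self):
        """Initialize the experiment environment."""

    @abstractmethod
    def prepare_data(self) -> T:
        """Prepare the data."""

    @abstractmethod
    def build_model(self):
        """Build the model."""

    @abstractmethod
    def train_step(self, batch: T) -> dict:
        """Run a single training step."""

    @abstractmethod
    def evaluate(self) -> dict:
        """Run evaluation."""

    def log_metrics(self, metrics: dict) -> None:
        """Default logging hook; override to send metrics to a tracker."""
        logger.info("train metrics: %s", metrics)

    def log_eval_metrics(self, metrics: dict) -> None:
        """Default logging hook for evaluation metrics."""
        logger.info("eval metrics: %s", metrics)

    def run(self):
        """Standardized experiment workflow."""
        data = self.prepare_data()
        self.model = self.build_model()
        for epoch in range(self.config["num_epochs"]):
            for batch in data:
                metrics = self.train_step(batch)
                self.log_metrics(metrics)
            eval_metrics = self.evaluate()
            self.log_eval_metrics(eval_metrics)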
Component registry
class ComponentRegistry:
    """A single registry for components, so that lookups are not scattered across the code."""

    _registry = {
        "models": {},
        "datasets": {},
        "trainers": {},
        "evaluators": {},
    }

    @classmethod
    def register(cls, category: str, name: str):
        """Decorator that registers a component."""
        def decorator(component_cls):
            if category not in cls._registry:
                raise ValueError(f"Unknown category: {category}")
            cls._registry[category][name] = component_cls
            return component_cls
        return decorator

    @classmethod
    def get(cls, category: str, name: str):
        """Look up a registered component."""
        if category not in cls._registry:
            raise ValueError(f"Unknown category: {category}")
        if name not in cls._registry[category]:
            available = list(cls._registry[category].keys())
            raise ValueError(f"Unknown {category}: {name}. Available: {available}")
        return cls._registry[category][name]


# Usage example
@ComponentRegistry.register("models", "llama2")
class LLaMA2Model:
    pass


@ComponentRegistry.register("datasets", "alpaca")
class AlpacaDataset:
    pass
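A registry pays off when components are selected by name from configuration rather than by import. The following self-contained sketch shows that lookup path with a reduced, function-based registry. The names `register_model`, `build_model`, and `TinyModel`, and the convention that the `name` key selects the class while the remaining keys become constructor arguments, are all assumptions made for illustration.

```python
from typing import Any, Callable, Dict, Type

_MODELS: Dict[str, Type] = {}


def register_model(name: str) -> Callable[[Type], Type]:
    """Decorator that records a class under `name`, rejecting duplicates."""
    def decorator(cls: Type) -> Type:
        if name in _MODELS:
            raise ValueError(f"Model {name!r} is already registered")
        _MODELS[name] = cls
        return cls
    return decorator


def build_model(config: Dict[str, Any]) -> Any:
    """Instantiate the class named by config["name"] with the remaining keys as kwargs."""
    name = config["name"]
    if name not in _MODELS:
        raise ValueError(f"Unknown model: {name}. Available: {sorted(_MODELS)}")
    kwargs = {key: value for key, value in config.items() if key != "name"}
    return _MODELS[name](**kwargs)


@register_model("tiny")
class TinyModel:
    def __init__(self, hidden_size: int = 64):
        self.hidden_size = hidden_size


model = build_model({"name": "tiny", "hidden_size": 128})
print(type(model).__name__, model.hidden_size)
```

Rejecting duplicate names at registration time is a design choice: it surfaces accidental name collisions at import time instead of letting a later registration silently replace an earlier one.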
Layered testing strategy
# Layered testing strategy. pytest only collects module-level test_* functions,
# so each layer is shown as a plain function. deep_merge, create_model,
# create_dataloader, get_minimal_config, and Trainer are project-specific
# helpers assumed to be importable from the code under test.


def test_config_merge():
    """Unit test: exercise an isolated function."""
    base = {"a": 1, "b": {"c": 2}}
    override = {"b": {"c": 3, "d": 4}}
    result = deep_merge(base, override)
    assert result == {"a": 1, "b": {"c": 3, "d": 4}}


def test_model_with_dataloader(config, expected_shape):
    """Integration test: exercise component interaction.

    `config` and `expected_shape` are assumed to be pytest fixtures.
    """
    model = create_model(config)
    dataloader = create_dataloader(config)
    batch = next(iter(dataloader))
    output = model(batch)
    assert output.shape == expected_shape


def test_training_loop_runs():
    """Smoke test: quickly verify that the basics work."""
    config = get_minimal_config()
    trainer = Trainer(config)
    # Run only a few steps
    trainer.train(max_steps=10)
    assert trainer.global_step == 10
CI/CD configuration
# .github/workflows/ci.yml
name: CI Pipeline
on:
  push:
    branches: [main, develop]
  pull_request:
    branches: [main]
jobs:
  quality-checks:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.9'
      - name: Install dependencies
        run: |
          pip install -r requirements-dev.txt
      - name: Run linting
        run: |
          flake8 src/ --max-line-length=100
          black --check src/
          isort --check-only src/
      - name: Type checking
        run: |
          mypy src/ --ignore-missing-imports
      - name: Run tests
        run: |
          pytest tests/ -v --cov=src --cov-report=xml
      - name: Check code complexity
        run: |
          radon cc src/ -s -nb
Automated documentation generation
class DocumentationGenerator:
    """Generate experiment documentation automatically."""

    def generate_experiment_doc(self, experiment_class):
        """Generate documentation from an experiment class.

        Assumes the class exposes get_config_schema(), returning a mapping of
        parameter name to a dict with "type" and optional "default" and
        "description" entries.
        """
        doc = f"# {experiment_class.__name__}\n\n"

        # Class docstring
        if experiment_class.__doc__:
            doc += f"{experiment_class.__doc__}\n\n"

        # Configuration parameters
        doc += "## Configuration Parameters\n\n"
        config_schema = experiment_class.get_config_schema()
        for param, schema in config_schema.items():
            doc += f"- **{param}**: {schema['type']} "
            if "default" in schema:
                doc += f"(default: {schema['default']})"
            doc += f"\n  {schema.get('description', '')}\n"

        # Public method docstrings
        doc += "\n## Methods\n\n"
        for method_name in dir(experiment_class):
            if not method_name.startswith("_"):
                method = getattr(experiment_class, method_name)
                if callable(method) and method.__doc__:
                    doc += f"### {method_name}\n"
                    doc += f"{method.__doc__}\n\n"
        return doc
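This chapter examined how to build the experiment code infrastructure for LLM post-training. We covered:
📌 Layered configuration management: by choosing sensibly among YAML, TOML, and Python configuration files, and by adding configuration inheritance and runtime validation, we built a flexible and robust configuration system. Remember: configuration complexity should match experiment complexity; both over-engineering and under-engineering reduce efficiency.
📌 Multi-dimensional management of the experiment environment: command-line arguments, environment variables, and Git branches work together to provide experiment isolation and reproducibility. The key principle: flags carry experiment-specific configuration, environment variables carry system-level settings, and Git branches manage code versions.
📌 Full-lifecycle experiment tracking: MLflow, W&B, and TensorBoard each fit different scenarios. The core requirement is to record enough metadata to reproduce an experiment, including the code version, dependency environment, and hardware configuration.
📌 Active technical debt management: code quality gates, modular design, automated testing, and documentation generation form several lines of defense against code rot. Remember: technical debt is unavoidable; what matters is that it is visible, controllable, and repayable.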
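Technical debt interest = $\sum_{i=1}^{n} \text{complexity}_i \times \text{change\_frequency}_i$
Experiment reproducibility score = $\frac{\text{number of successfully reproduced experiments}}{\text{total number of experiments}} \times \text{metadata completeness}$
Configuration complexity = $\log_2(\text{number of configuration parameters}) \times \text{nesting depth}$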
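The three formulas above are simple enough to compute directly, which makes them easy to track over time. The sketch below evaluates each one on made-up inputs. The function names and all the numbers are hypothetical, and metadata completeness is assumed to be a fraction between 0 and 1.

```python
import math
from typing import Sequence


def debt_interest(complexities: Sequence[float], change_frequencies: Sequence[float]) -> float:
    """Sum of complexity_i * change_frequency_i over all modules."""
    if len(complexities) != len(change_frequencies):
        raise ValueError("Both sequences must describe the same modules")
    return sum(c * f for c, f in zip(complexities, change_frequencies))


def reproducibility_score(reproduced: int, total: int, metadata_completeness: float) -> float:
    """Fraction of reproduced experiments, scaled by metadata completeness."""
    if total <= 0:
        raise ValueError("total must be positive")
    return reproduced / total * metadata_completeness


def config_complexity(num_params: int, nesting_depth: int) -> float:
    """log2 of the parameter count, multiplied by the nesting depth."""
    return math.log2(num_params) * nesting_depth


# Three modules with complexities 12, 4, 30 changed 5, 1, 2 times per month.
print(debt_interest([12, 4, 30], [5, 1, 2]))   # 12*5 + 4*1 + 30*2 = 124

# 18 of 20 experiments reproduced, with 90% of metadata fields recorded.
print(reproducibility_score(18, 20, 0.9))

# 128 parameters nested 3 levels deep.
print(config_complexity(128, 3))               # log2(128) * 3 = 21.0
```

The first example shows why a moderately complex but frequently changed module can contribute as much debt interest as a very complex module that is rarely touched.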
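⚠️ Configuration hell
⚠️ Too much or too little experiment tracking
⚠️ Chaotic Git branch management
⚠️ Hard-coded paths and configuration
⚠️ Ignoring growth in code complexity
⚠️ Poor checkpoint management
💡 Practical tips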
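{date}_{model}_{dataset}_{key_hyperparam}
Exercise 2.1: Choosing a configuration file format. Your team is starting a new LLM post-training project. The project must support: (1) hyperparameter tuning by non-technical staff; (2) complex nested configuration; (3) dynamically computed parameters. Choose a configuration file format for this project and justify your choice.
Hint: Consider a hybrid approach that uses different formats at different layers.
Exercise 2.2: Integrating experiment tracking tools. Design a unified interface that logs experiment metrics to both MLflow and W&B. It must support batched logging and asynchronous writes.
Hint: Use the adapter pattern and a queue.
Exercise 2.3: Git branch cleanup policy. Write a script that cleans up experiment branches automatically. Requirements: (1) keep branches from the last 30 days; (2) keep branches with unmerged commits; (3) archive important experiment results.
Hint: Use the git for-each-ref and git cherry commands.
Exercise 2.4: Configuration diff analysis. Implement a tool that can: (1) compare the configurations of two experiments; (2) identify which configuration changes led to performance gains; (3) generate configuration tuning suggestions.
Hint: Consider decision trees or SHAP values for analyzing configuration importance.
Exercise 2.5: Isolating experiment code versions. Design a system that creates an isolated code environment for each experiment and supports: (1) code snapshots; (2) pinned dependency versions; (3) fast switching and restoring.
Hint: Combine Git worktree, Docker, or Python virtual environments.
Exercise 2.6: Coordinating distributed experiments. Design a distributed experiment management system that supports: (1) scheduling experiments across multiple machines; (2) allocating resources (GPUs); (3) automatically retrying failed experiments.
Hint: Consider a message queue and a state machine.
Exercise 2.7: Measuring code rot and raising alerts. Build a system that can: (1) quantify the degree of code rot; (2) predict the growth trend of technical debt; (3) generate refactoring suggestions automatically.
Hint: Combine static analysis, Git history, and complexity metrics.
Exercise 2.8: Automated analysis of experiment results. Implement a system that analyzes experiment results automatically and identifies: (1) anomalous experiments; (2) performance bottlenecks; (3) the best configuration combinations.
Hint: Use anomaly detection, performance profiling, and Bayesian optimization.
Completing these exercises will give you the core skills for building a robust LLM post-training experiment infrastructure. Remember that good infrastructure is the foundation of efficient experimentation and is worth the time invested at the start of a project.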