
03 - Large Language Model Pretraining (Comprehensive Edition)

⚠️ Timeliness note: this chapter references cutting-edge models, pricing, leaderboards, and similar information that can change quickly between versions; always defer to the original papers, official release pages, and API documentation.

Learning objectives: gain a deep understanding of the complete LLM pretraining workflow, including data preparation, training strategies, distributed training, and training-stability techniques.


Table of Contents

  1. Pretraining Overview
  2. Data Engineering and Processing
  3. Training Strategies and Optimization
  4. Distributed Training
  5. Training Stability
  6. Training Monitoring and Debugging
  7. Evaluation and Checkpoint Management

Pretraining Overview

1.1 Pretraining vs. Fine-tuning

Text Only
Pre-training
├── Data: large-scale unlabeled text (TB scale)
├── Goal: learn general-purpose language representations
├── Compute: thousands of GPUs training for weeks
├── Result: a base model (e.g., GPT-3, LLaMA)
└── Cost: millions of dollars

Fine-tuning
├── Data: task-specific labeled data (MB-GB scale)
├── Goal: adapt the model to a specific task
├── Compute: one or a few GPUs training for hours
├── Result: a task-specific model
└── Cost: hundreds to thousands of dollars

Why pretrain first and then fine-tune?
├── Pretraining captures general knowledge (grammar, commonsense, reasoning)
├── Fine-tuning only has to learn the task-specific mapping
├── Good results even with little data (transfer learning)
└── Avoids the huge cost of training from scratch

1.2 Pretraining Objective Functions

Python
import torch
import torch.nn.functional as F

class PretrainingObjectives:
    """
    Pretraining objective functions in detail
    """

    @staticmethod  # @staticmethod: callable without creating an instance
    def causal_language_modeling(logits, labels):
        """
        Causal language modeling (Causal LM / autoregressive)

        Used by: the GPT series, LLaMA, Claude, etc.
        Goal: predict the next token

        Loss = -Σ log P(x_t | x_<t)
        """
        # Shift: position t predicts position t+1
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()

        loss = F.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),  # view reshapes the tensor
            shift_labels.view(-1)
        )
        return loss

    @staticmethod
    def masked_language_modeling(logits, labels, mask_positions):
        """
        Masked language modeling (Masked LM)

        Used by: BERT, RoBERTa, etc.
        Goal: predict the masked tokens

        Typically 15% of tokens are masked:
        - 80% replaced with [MASK]
        - 10% replaced with a random token
        - 10% left unchanged
        """
        # Compute the loss only at masked positions
        masked_logits = logits[mask_positions]
        masked_labels = labels[mask_positions]

        loss = F.cross_entropy(masked_logits, masked_labels)
        return loss

    @staticmethod
    def prefix_lm_loss(logits, labels, prefix_len):
        """
        Prefix language modeling (Prefix LM)

        Used by: T5, UL2, etc.
        Goal: bidirectional attention over the prefix, unidirectional generation
        """
        # Compute the loss only on the generated part
        gen_logits = logits[:, prefix_len:]
        gen_labels = labels[:, prefix_len:]

        loss = F.cross_entropy(
            gen_logits.reshape(-1, gen_logits.size(-1)),
            gen_labels.reshape(-1)
        )
        return loss

# Comparison of objective functions
"""
Objective            Representative models   Pros                          Cons
──────────────────────────────────────────────────────────────────────────────────────
Causal LM            GPT, LLaMA              Strong generation, simple     Unidirectional context
Masked LM            BERT                    Bidirectional context         Pretrain-finetune mismatch
Prefix LM            T5, UL2                 Flexible, unified framework   More complex to implement
Span Corruption      T5                      Good for transfer learning    Needs encoder-decoder
"""

Data Engineering and Processing

2.1 Data Pipeline

Python
class DataPipeline:
    """
    Pretraining data pipeline
    """

    def __init__(self, config):
        self.config = config
        self.tokenizer = None

    def load_data(self, data_paths):
        """
        Load the raw data
        """
        from datasets import load_dataset

        # Support multiple formats
        if data_paths[0].endswith('.jsonl'):
            dataset = load_dataset('json', data_files=data_paths)
        elif data_paths[0].endswith('.txt'):
            dataset = load_dataset('text', data_files=data_paths)
        else:
            # Otherwise treat it as a dataset name on the Hub
            dataset = load_dataset(data_paths[0])

        return dataset

    def preprocess(self, dataset):
        """
        Preprocess the data
        """
        def tokenize_function(examples):
            # Tokenize
            return self.tokenizer(
                examples['text'],
                truncation=True,
                max_length=self.config.max_seq_length,
                return_special_tokens_mask=True
            )

        # Apply tokenization
        tokenized = dataset.map(
            tokenize_function,
            batched=True,
            num_proc=self.config.num_workers,
            remove_columns=dataset['train'].column_names
        )

        return tokenized

    def create_dataloader(self, dataset, batch_size):
        """
        Create the DataLoader
        """
        from torch.utils.data import DataLoader
        from torch.utils.data.distributed import DistributedSampler

        # Use a distributed sampler when training across multiple processes
        sampler = DistributedSampler(dataset) if self.config.distributed else None

        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            sampler=sampler,
            collate_fn=self.data_collator,
            num_workers=self.config.num_workers,
            pin_memory=True
        )

        return dataloader

    def data_collator(self, features):
        """
        Collate a batch
        """
        import torch

        # Pad to the same length
        batch = {}

        max_length = max(len(f['input_ids']) for f in features)

        for key in ['input_ids', 'attention_mask']:
            padded = []
            for f in features:
                padding_length = max_length - len(f[key])
                # Pad input_ids with pad_token_id, attention_mask with 0
                pad_value = self.tokenizer.pad_token_id if key == 'input_ids' else 0
                padded.append(f[key] + [pad_value] * padding_length)
            batch[key] = torch.tensor(padded)

        # Create labels (identical to input_ids)
        batch['labels'] = batch['input_ids'].clone()

        # Set padding positions to -100 (ignored by the loss)
        # ⚠️ Caveat: if pad_token_id == eos_token_id (the default for many models, e.g. GPT-2),
        # the line below also sets every EOS token to -100, so the model never learns when to stop generating.
        # It is better to register a dedicated pad token first
        # (e.g. tokenizer.add_special_tokens({"pad_token": "[PAD]"})).
        batch['labels'][batch['labels'] == self.tokenizer.pad_token_id] = -100

        return batch
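
Following the caveat in data_collator above, here is a minimal sketch (the model name "gpt2" is just a placeholder) of registering a dedicated pad token and resizing the embeddings to match:

Python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")            # placeholder model name
model = AutoModelForCausalLM.from_pretrained("gpt2")

if tokenizer.pad_token is None or tokenizer.pad_token_id == tokenizer.eos_token_id:
    # Register a dedicated [PAD] token so EOS tokens are never masked out of the loss
    tokenizer.add_special_tokens({"pad_token": "[PAD]"})
    # The vocabulary grew by one token, so the embedding matrix must grow too
    model.resize_token_embeddings(len(tokenizer))

print(tokenizer.pad_token_id, tokenizer.eos_token_id)  # now two different ids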

2.2 Data Deduplication and Cleaning

Python
class DataDeduplication:
    """
    Data deduplication
    """

    def __init__(self, threshold=0.9):
        self.threshold = threshold
        self.seen_hashes = set()

    def exact_deduplication(self, texts):
        """
        Exact deduplication
        """
        import hashlib

        unique_texts = []

        for text in texts:
            text_hash = hashlib.md5(text.encode()).hexdigest()

            if text_hash not in self.seen_hashes:
                self.seen_hashes.add(text_hash)
                unique_texts.append(text)

        return unique_texts

    def minhash_deduplication(self, texts, num_perm=128):
        """
        Approximate deduplication with MinHash
        """
        from datasketch import MinHash, MinHashLSH

        # Build the LSH index
        lsh = MinHashLSH(threshold=self.threshold, num_perm=num_perm)

        unique_texts = []

        for i, text in enumerate(texts):  # enumerate yields the index and the element together
            # Build the MinHash signature
            m = MinHash(num_perm=num_perm)

            # Add shingles
            for shingle in self.get_shingles(text, k=5):
                m.update(shingle.encode('utf8'))

            # Check for near-duplicates
            if not lsh.query(m):
                lsh.insert(f"doc_{i}", m)
                unique_texts.append(text)

        return unique_texts

    def get_shingles(self, text, k=5):
        """
        Get k-gram shingles
        """
        words = text.split()
        return [' '.join(words[i:i+k]) for i in range(len(words)-k+1)]

class DataCleaning:
    """
    Data cleaning
    """

    def __init__(self):
        self.min_length = 100
        self.max_length = 100000

    def clean_text(self, text):
        """
        Clean a single document
        """
        import re

        # Length filter
        if len(text) < self.min_length or len(text) > self.max_length:
            return None

        # Drop text with long runs of the same character
        if re.search(r'(.)\1{10,}', text):  # re.search looks for a regex match anywhere in the string
            return None

        # Drop text with too many newlines
        if text.count('\n') > len(text) / 10:
            return None

        # Drop text with too many uppercase letters (likely noise or garbled content)
        if sum(1 for c in text if c.isupper()) / len(text) > 0.5:
            return None

        # Strip HTML tags
        text = re.sub(r'<[^>]+>', '', text)

        # Normalize whitespace
        text = ' '.join(text.split())

        return text
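
A minimal sketch of chaining the two classes above (load_raw_corpus is a hypothetical loader that returns a list of strings; it is not defined in this chapter):

Python
cleaner = DataCleaning()
deduper = DataDeduplication(threshold=0.9)

raw_texts = load_raw_corpus()  # hypothetical helper
cleaned = [t for t in (cleaner.clean_text(x) for x in raw_texts) if t is not None]
unique = deduper.exact_deduplication(cleaned)    # cheap exact pass first
unique = deduper.minhash_deduplication(unique)   # then near-duplicate removal
print(f"{len(raw_texts)} -> {len(unique)} documents after cleaning and deduplication")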

Training Strategies and Optimization

3.1 Optimizer and Learning-Rate Schedule

Python
import torch

class TrainingOptimizer:
    """
    Optimizer configuration for training
    """

    @staticmethod
    def create_adamw_optimizer(model, lr=1e-4, weight_decay=0.01):
        """
        Create an AdamW optimizer
        """
        # Separate parameters by whether they receive weight decay
        no_decay = ['bias', 'LayerNorm.weight', 'layer_norm']

        optimizer_grouped_parameters = [
            {
                'params': [p for n, p in model.named_parameters()
                          if not any(nd in n for nd in no_decay)],  # any() is True if any element is True
                'weight_decay': weight_decay
            },
            {
                'params': [p for n, p in model.named_parameters()
                          if any(nd in n for nd in no_decay)],
                'weight_decay': 0.0
            }
        ]

        optimizer = torch.optim.AdamW(
            optimizer_grouped_parameters,
            lr=lr,
            betas=(0.9, 0.95),
            eps=1e-8
        )

        return optimizer

    @staticmethod
    def create_scheduler(optimizer, num_warmup_steps, num_training_steps, scheduler_type='cosine'):
        """
        Create a learning-rate scheduler
        """
        from transformers import get_scheduler

        scheduler = get_scheduler(
            scheduler_type,
            optimizer=optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps
        )

        return scheduler

# Visualizing learning-rate schedules
def visualize_lr_schedule():
    """
    Visualize different learning-rate schedules
    """
    import matplotlib.pyplot as plt
    from transformers import get_scheduler

    num_steps = 10000
    warmup_steps = 1000
    base_lr = 1e-4

    # Create a fresh optimizer for each schedule to avoid sharing state
    model = torch.nn.Linear(10, 10)

    schedule_configs = {
        'linear': {},
        'cosine': {},
        'polynomial': {},  # defaults to power=1.0
        'constant_with_warmup': {},
    }

    plt.figure(figsize=(12, 6))

    for name, kwargs in schedule_configs.items():
        optimizer = torch.optim.AdamW(model.parameters(), lr=base_lr)
        scheduler = get_scheduler(name, optimizer, warmup_steps, num_steps, **kwargs)

        lrs = []
        for step in range(num_steps):
            lrs.append(optimizer.param_groups[0]['lr'])
            scheduler.step()

        plt.plot(lrs, label=name)

    plt.xlabel('Step')
    plt.ylabel('Learning Rate')
    plt.title('Learning Rate Schedules')
    plt.legend()
    plt.grid(True)
    plt.show()

3.2 Gradient Accumulation and Clipping

Python
import torch

class GradientManagement:
    """
    Gradient management
    """

    def __init__(self, accumulation_steps=4, max_grad_norm=1.0):
        self.accumulation_steps = accumulation_steps
        self.max_grad_norm = max_grad_norm
        self.step_count = 0

    def backward(self, loss, optimizer):
        """
        Backward pass (with gradient accumulation)
        """
        # Scale the loss
        scaled_loss = loss / self.accumulation_steps
        scaled_loss.backward()

        self.step_count += 1

        # Update once the accumulation steps are reached
        if self.step_count % self.accumulation_steps == 0:
            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(
                self.get_trainable_params(optimizer),
                self.max_grad_norm
            )

            optimizer.step()
            optimizer.zero_grad()

    def get_trainable_params(self, optimizer):
        """
        Collect the trainable parameters
        """
        params = []
        for group in optimizer.param_groups:
            params.extend(group['params'])
        return params
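
A minimal sketch of wiring the optimizer, scheduler, and GradientManagement together in one loop (model, dataloader, and the step counts are assumed to exist):

Python
optimizer = TrainingOptimizer.create_adamw_optimizer(model, lr=3e-4, weight_decay=0.1)
scheduler = TrainingOptimizer.create_scheduler(optimizer, num_warmup_steps=2000,
                                               num_training_steps=100_000)
grad_mgr = GradientManagement(accumulation_steps=8, max_grad_norm=1.0)

for batch in dataloader:
    loss = model(**batch).loss
    grad_mgr.backward(loss, optimizer)   # optimizer.step() fires every 8 micro-batches
    if grad_mgr.step_count % grad_mgr.accumulation_steps == 0:
        scheduler.step()                 # advance the LR schedule once per optimizer step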

Distributed Training

4.1 DDP Training

Python
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader

class DistributedTrainer:
    """
    Distributed trainer
    """

    def __init__(self, rank, world_size):
        self.rank = rank
        self.world_size = world_size

        # Initialize the process group
        dist.init_process_group(
            backend='nccl',
            init_method='env://',
            world_size=world_size,
            rank=rank
        )

        torch.cuda.set_device(rank)

    def setup_model(self, model):
        """
        Wrap the model for distributed training
        """
        model = model.to(self.rank)
        model = DDP(
            model,
            device_ids=[self.rank],
            output_device=self.rank,
            find_unused_parameters=False
        )
        return model

    def setup_dataloader(self, dataset, batch_size):
        """
        Build the distributed dataloader
        """
        from torch.utils.data.distributed import DistributedSampler

        sampler = DistributedSampler(
            dataset,
            num_replicas=self.world_size,
            rank=self.rank,
            shuffle=True
        )

        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            sampler=sampler,
            num_workers=4,
            pin_memory=True
        )

        return dataloader, sampler

    def train_step(self, model, batch, optimizer, scaler=None):
        """
        Single training step
        """
        model.train()

        # Forward pass
        if scaler is not None:
            with torch.amp.autocast('cuda'):
                outputs = model(**batch)
                loss = outputs.loss  # DDP averages gradients automatically; no manual division by world_size
        else:
            outputs = model(**batch)
            loss = outputs.loss  # DDP averages gradients automatically

        # Backward pass
        if scaler is not None:
            scaler.scale(loss).backward()
        else:
            loss.backward()

        # Print only on the main process
        if self.rank == 0:
            print(f"Loss: {loss.item():.4f}")

        return loss

    def cleanup(self):
        """
        Tear down the process group
        """
        dist.destroy_process_group()
4.2 DeepSpeed Integration

Python
import json

class DeepSpeedTrainer:
    """
    DeepSpeed trainer
    """

    def __init__(self, model, config_path):
        import deepspeed

        # Load the config
        with open(config_path, 'r') as f:  # with closes the file automatically
            ds_config = json.load(f)

        # Initialize DeepSpeed
        model_engine, optimizer, _, _ = deepspeed.initialize(
            model=model,
            model_parameters=model.parameters(),
            config=ds_config
        )

        self.model = model_engine
        self.optimizer = optimizer

    def train_step(self, batch):
        """
        Training step
        """
        # Forward pass
        outputs = self.model(**batch)
        loss = outputs.loss

        # Backward pass (DeepSpeed handles gradient accumulation and mixed precision)
        self.model.backward(loss)

        # Parameter update
        self.model.step()

        return loss.item()

    def save_checkpoint(self, save_dir, tag):
        """
        Save a checkpoint
        """
        self.model.save_checkpoint(save_dir, tag)

    def load_checkpoint(self, load_dir, tag):
        """
        Load a checkpoint
        """
        self.model.load_checkpoint(load_dir, tag)
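
An illustrative DeepSpeed config for the trainer above (a ZeRO stage-2 sketch; the values are assumptions, not recommendations from this chapter):

Python
import json

ds_config = {
    "train_micro_batch_size_per_gpu": 4,
    "gradient_accumulation_steps": 8,
    "bf16": {"enabled": True},
    "zero_optimization": {
        "stage": 2,               # shard optimizer states and gradients across GPUs
        "overlap_comm": True,
        "contiguous_gradients": True
    },
    "gradient_clipping": 1.0
}

with open("ds_config.json", "w") as f:
    json.dump(ds_config, f, indent=2)

# trainer = DeepSpeedTrainer(model, "ds_config.json")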

Training Stability

5.1 Handling Loss Spikes

Python
import torch

class LossSpikeHandler:
    """
    Loss-spike handling
    """

    def __init__(self, spike_threshold=5.0, skip_threshold=10.0):
        self.spike_threshold = spike_threshold
        self.skip_threshold = skip_threshold
        self.loss_history = []

    def check_loss(self, loss):
        """
        Check whether the loss looks normal

        Returns:
            'normal': normal
            'spike': a spike, but acceptable
            'skip': skip this step
        """
        self.loss_history.append(loss)

        if len(self.loss_history) < 10:
            return 'normal'

        # Moving average over the last 10 steps
        recent_avg = sum(self.loss_history[-10:]) / 10

        # Compare against the thresholds
        if loss > recent_avg * self.skip_threshold:
            return 'skip'
        elif loss > recent_avg * self.spike_threshold:
            return 'spike'

        return 'normal'

    def get_loss_scale(self, loss_status):
        """
        Scale the update according to the loss status
        """
        if loss_status == 'skip':
            return 0.0  # skip this step
        elif loss_status == 'spike':
            return 0.5  # damp the update
        else:
            return 1.0  # normal

class GradientStabilizer:
    """
    Gradient stabilizer
    """

    def __init__(self, max_grad_norm=1.0, clip_value=1.0):
        self.max_grad_norm = max_grad_norm
        self.clip_value = clip_value

    def stabilize(self, model):
        """
        Stabilize the gradients
        """
        # 1. Gradient clipping
        torch.nn.utils.clip_grad_norm_(
            model.parameters(),
            self.max_grad_norm
        )

        # 2. Check for NaN/Inf
        has_nan = False
        for param in model.parameters():
            if param.grad is not None:
                if torch.isnan(param.grad).any() or torch.isinf(param.grad).any():
                    has_nan = True
                    param.grad.zero_()

        return has_nan
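
A minimal sketch of using the two helpers inside a training step; applying the scale returned by get_loss_scale to the loss before backward is one possible interpretation, and model/optimizer are assumed to exist:

Python
spike_handler = LossSpikeHandler(spike_threshold=5.0, skip_threshold=10.0)
stabilizer = GradientStabilizer(max_grad_norm=1.0)

def guarded_step(model, batch, optimizer):
    loss = model(**batch).loss
    status = spike_handler.check_loss(loss.item())
    scale = spike_handler.get_loss_scale(status)
    if scale == 0.0:                       # severe spike: skip this batch entirely
        optimizer.zero_grad()
        return None
    (loss * scale).backward()              # mild spike: damp the update
    stabilizer.stabilize(model)            # clip gradients and zero out NaN/Inf gradients
    optimizer.step()
    optimizer.zero_grad()
    return loss.item()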

5.2 Checkpointing and Recovery

Python
import torch
from pathlib import Path

class CheckpointManager:
    """
    Checkpoint manager
    """

    def __init__(self, checkpoint_dir, keep_last_n=3):
        self.checkpoint_dir = Path(checkpoint_dir)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        self.keep_last_n = keep_last_n
        self.checkpoints = []

    def save(self, model, optimizer, scheduler, step, loss, metrics=None):
        """
        Save a checkpoint
        """
        checkpoint = {
            'step': step,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
            'loss': loss,
            'metrics': metrics or {}
        }

        # Write to disk
        checkpoint_path = self.checkpoint_dir / f'checkpoint_step_{step}.pt'
        torch.save(checkpoint, checkpoint_path)

        # Record it
        self.checkpoints.append((step, checkpoint_path))

        # Remove old checkpoints
        self._cleanup_old_checkpoints()

        # Update the "latest" symlink
        self._update_latest_link(checkpoint_path)

        print(f"Checkpoint saved: {checkpoint_path}")

    def load(self, model, optimizer, scheduler=None, checkpoint_path=None):
        """
        Load a checkpoint
        """
        if checkpoint_path is None:
            checkpoint_path = self.checkpoint_dir / 'checkpoint_latest.pt'

        checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=True)

        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        if scheduler and checkpoint['scheduler_state_dict']:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

        step = checkpoint['step']
        loss = checkpoint['loss']

        print(f"Checkpoint loaded: {checkpoint_path}, step={step}, loss={loss}")

        return step, loss

    def _cleanup_old_checkpoints(self):
        """
        Remove old checkpoints
        """
        if len(self.checkpoints) > self.keep_last_n:
            # Sort by step
            self.checkpoints.sort(key=lambda x: x[0])  # lambda: anonymous function

            # Delete the oldest ones
            while len(self.checkpoints) > self.keep_last_n:
                step, path = self.checkpoints.pop(0)
                if path.exists():
                    path.unlink()
                    print(f"Removed old checkpoint: {path}")

    def _update_latest_link(self, checkpoint_path):
        """
        Update the "latest checkpoint" symlink
        """
        latest_path = self.checkpoint_dir / 'checkpoint_latest.pt'
        if latest_path.exists() or latest_path.is_symlink():
            latest_path.unlink()
        latest_path.symlink_to(checkpoint_path.name)
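
A minimal resume-after-crash sketch built on the manager above (model, optimizer, scheduler, total_steps, and train_one_step are assumed to exist):

Python
ckpt_mgr = CheckpointManager("checkpoints/", keep_last_n=3)

start_step = 0
latest = ckpt_mgr.checkpoint_dir / "checkpoint_latest.pt"
if latest.exists():
    start_step, _ = ckpt_mgr.load(model, optimizer, scheduler)   # pick up where training stopped

for step in range(start_step, total_steps):
    loss = train_one_step(model, optimizer)          # hypothetical helper
    if step > 0 and step % 1000 == 0:
        ckpt_mgr.save(model, optimizer, scheduler, step, loss)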

Training Monitoring and Debugging

6.1 Training Logs and Monitoring

Python
import torch
from collections import defaultdict
from datetime import datetime
from pathlib import Path

class TrainingMonitor:
    """
    Training monitor
    """

    def __init__(self, log_dir, use_wandb=False):
        self.log_dir = Path(log_dir)
        self.log_dir.mkdir(parents=True, exist_ok=True)

        self.use_wandb = use_wandb
        if use_wandb:
            import wandb
            wandb.init(project='llm-pretraining')

        # Create the log file
        self.log_file = self.log_dir / 'training.log'

        # Metric history
        self.metrics_history = defaultdict(list)  # defaultdict supplies a default value for missing keys

    def log_step(self, step, metrics):
        """
        Log per-step metrics
        """
        # Append to history
        for key, value in metrics.items():
            self.metrics_history[key].append((step, value))

        # Append to the log file
        with open(self.log_file, 'a') as f:
            timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
            metrics_str = ', '.join([f"{k}={v:.4f}" for k, v in metrics.items()])
            f.write(f"[{timestamp}] Step {step}: {metrics_str}\n")

        # Send to wandb
        if self.use_wandb:
            import wandb
            wandb.log(metrics, step=step)

    def log_epoch(self, epoch, metrics):
        """
        Log per-epoch metrics
        """
        print(f"Epoch {epoch}: {metrics}")

        if self.use_wandb:
            import wandb
            wandb.log({f"epoch/{k}": v for k, v in metrics.items()}, step=epoch)

    def plot_metrics(self, metrics_to_plot=None):
        """
        Plot metric curves
        """
        import matplotlib.pyplot as plt

        if metrics_to_plot is None:
            metrics_to_plot = list(self.metrics_history.keys())

        num_metrics = len(metrics_to_plot)
        fig, axes = plt.subplots(
            (num_metrics + 1) // 2, 2,
            figsize=(15, 5 * ((num_metrics + 1) // 2))
        )
        axes = axes.flatten() if num_metrics > 1 else [axes]

        for idx, metric_name in enumerate(metrics_to_plot):
            if metric_name not in self.metrics_history:
                continue

            steps, values = zip(*self.metrics_history[metric_name])

            axes[idx].plot(steps, values)
            axes[idx].set_xlabel('Step')
            axes[idx].set_ylabel(metric_name)
            axes[idx].set_title(f'{metric_name} over time')
            axes[idx].grid(True)

        plt.tight_layout()
        plt.savefig(self.log_dir / 'metrics.png')
        plt.show()

class GPUMonitor:
    """
    GPU monitor
    """

    @staticmethod
    def get_gpu_info():
        """
        Collect GPU information
        """
        if not torch.cuda.is_available():
            return {}

        info = {}
        for i in range(torch.cuda.device_count()):
            info[f'gpu_{i}'] = {
                'name': torch.cuda.get_device_name(i),
                'memory_allocated': torch.cuda.memory_allocated(i) / 1e9,  # GB
                'memory_reserved': torch.cuda.memory_reserved(i) / 1e9,
                'memory_total': torch.cuda.get_device_properties(i).total_memory / 1e9,
                'utilization': torch.cuda.utilization(i) if hasattr(torch.cuda, 'utilization') else None  # hasattr checks whether the attribute exists
            }

        return info

    @staticmethod
    def log_gpu_usage(log_file='gpu_usage.log'):
        """
        Log GPU usage
        """
        import subprocess

        # Query via nvidia-smi
        result = subprocess.run(
            ['nvidia-smi', '--query-gpu=timestamp,name,utilization.gpu,memory.used,memory.total',
             '--format=csv,noheader'],
            capture_output=True,
            text=True
        )

        with open(log_file, 'a') as f:
            f.write(result.stdout)
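
A minimal sketch of feeding the monitor from a training loop (the metric values below are placeholders):

Python
monitor = TrainingMonitor(log_dir="logs/", use_wandb=False)

for step in range(1, 101):
    metrics = {
        "loss": 10.0 / step,   # placeholder value
        "lr": 3e-4,            # placeholder value
        "gpu_mem_gb": GPUMonitor.get_gpu_info().get("gpu_0", {}).get("memory_allocated", 0.0),
    }
    monitor.log_step(step, metrics)

monitor.plot_metrics(["loss", "lr"])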

Evaluation and Checkpoint Management

7.1 Model Evaluation

Python
import math
import torch
from torch.utils.data import DataLoader

class ModelEvaluator:
    """
    Model evaluator
    """

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer

    def evaluate_perplexity(self, eval_dataset, batch_size=8):
        """
        Evaluate perplexity
        """
        self.model.eval()

        total_loss = 0
        total_tokens = 0

        dataloader = DataLoader(eval_dataset, batch_size=batch_size)

        with torch.no_grad():  # disable gradient tracking to save memory during inference
            for batch in dataloader:
                outputs = self.model(**batch)

                # Count the non-padding tokens
                labels = batch['labels']
                num_tokens = (labels != -100).sum().item()

                total_loss += outputs.loss.item() * num_tokens
                total_tokens += num_tokens

        # Perplexity = exp(average per-token loss)
        avg_loss = total_loss / total_tokens
        perplexity = math.exp(avg_loss)

        return perplexity

    def evaluate_generation(self, prompts, max_length=100, num_return_sequences=1):
        """
        Evaluate generation quality
        """
        self.model.eval()

        results = []

        with torch.no_grad():
            for prompt in prompts:
                inputs = self.tokenizer(prompt, return_tensors='pt')

                outputs = self.model.generate(
                    **inputs,
                    max_length=max_length,
                    num_return_sequences=num_return_sequences,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.9
                )

                generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
                results.append({
                    'prompt': prompt,
                    'generated': generated_text
                })

        return results

    def evaluate_downstream_tasks(self, task_names=None):
        """
        Evaluate downstream tasks
        """
        # Note: the exact simple_evaluate signature varies across lm-eval versions; check the installed version's docs
        from lm_eval import evaluator

        if task_names is None:
            task_names = ['hellaswag', 'arc_easy', 'winogrande']

        results = evaluator.simple_evaluate(
            model=self.model,
            tokenizer=self.tokenizer,
            tasks=task_names,
            batch_size=8
        )

        return results
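
A minimal sketch of running the evaluator (eval_dataset and the prompt are placeholders); note that perplexity is simply exp of the average per-token loss, so an average loss of 2.0 nats corresponds to a perplexity of about 7.39:

Python
model_eval = ModelEvaluator(model, tokenizer)

ppl = model_eval.evaluate_perplexity(eval_dataset, batch_size=8)
print(f"Perplexity: {ppl:.2f}")

samples = model_eval.evaluate_generation(["Once upon a time"], max_length=50)
print(samples[0]["generated"])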

7.2 Selecting the Best Checkpoint

Python
class BestCheckpointSelector:
    """
    Best-checkpoint selector
    """

    def __init__(self, metric='perplexity', mode='min'):
        """
        Args:
            metric: the metric to monitor
            mode: 'min' or 'max'
        """
        self.metric = metric
        self.mode = mode
        self.best_value = float('inf') if mode == 'min' else float('-inf')
        self.best_checkpoint = None

    def update(self, checkpoint_path, metrics):
        """
        Update the best checkpoint
        """
        current_value = metrics.get(self.metric)

        if current_value is None:
            return

        is_better = (
            (self.mode == 'min' and current_value < self.best_value) or
            (self.mode == 'max' and current_value > self.best_value)
        )

        if is_better:
            self.best_value = current_value
            self.best_checkpoint = checkpoint_path

            print(f"New best checkpoint: {checkpoint_path}")
            print(f"{self.metric}: {current_value:.4f}")

    def get_best(self):
        """
        Return the best checkpoint
        """
        return self.best_checkpoint, self.best_value

Summary

Key Pretraining Configuration

Python
# Recommended configuration (7B model)
CONFIG_7B = {
    # Model configuration
    'vocab_size': 32000,
    'hidden_size': 4096,
    'num_layers': 32,
    'num_heads': 32,
    'intermediate_size': 11008,

    # Training configuration
    'batch_size': 4,  # per GPU
    'gradient_accumulation_steps': 8,
    'max_seq_length': 2048,
    'learning_rate': 3e-4,
    'warmup_steps': 2000,
    'total_steps': 100000,

    # Optimizer configuration
    'optimizer': 'adamw',
    'weight_decay': 0.1,
    'beta1': 0.9,
    'beta2': 0.95,
    'eps': 1e-8,

    # Mixed precision
    'fp16': True,
    'bf16': False,

    # Distributed training
    'distributed': True,
    'world_size': 8,  # 8 GPUs

    # Checkpointing
    'checkpoint_steps': 1000,
    'keep_last_n': 3
}
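
The effective batch size implied by the configuration above (simple arithmetic, not stated in the original text):

Python
cfg = CONFIG_7B
seqs_per_step = cfg['batch_size'] * cfg['gradient_accumulation_steps'] * cfg['world_size']
tokens_per_step = seqs_per_step * cfg['max_seq_length']
total_tokens = tokens_per_step * cfg['total_steps']

print(seqs_per_step)          # 4 * 8 * 8 = 256 sequences per optimizer step
print(tokens_per_step)        # 256 * 2048 = 524,288 tokens per step
print(f"{total_tokens:.2e}")  # ≈ 5.24e+10 tokens over 100,000 steps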

Pretraining Checklist

  • Data cleaning and deduplication done
  • Tokenizer training done
  • Distributed environment configured correctly
  • Mixed precision configured correctly
  • Learning-rate schedule configured correctly
  • Checkpoint-saving strategy configured
  • Monitoring and logging in place
  • Failure-recovery mechanism in place

Next step: move on to 04 - Alignment Techniques and master alignment methods such as RLHF and DPO!


Last updated: 2026-02-12 | Applicable version: LLM Learning Tutorial v2026