跳转至

09-分布式训练

学习时间: 约5小时 难度级别: ⭐⭐⭐⭐ 中高级 前置知识: PyTorch基础、深度学习训练流程

📌 交叉引用:LLM预训练场景的分布式训练(DeepSpeed ZeRO实战、Megatron-LM)请参考 LLM学习/03-系统与工程/02-训练基础设施.md,本章侧重通用分布式训练原理与PyTorch原生方案。


本章目标

  • 理解分布式训练的核心概念(数据并行、模型并行、流水线并行)
  • 掌握PyTorch DDP(分布式数据并行)的完整使用
  • 掌握PyTorch FSDP(全分片数据并行)的配置与调优
  • 理解3D并行策略及其组合方式
  • 学会分布式训练的调试、性能分析与故障排除

1. 为什么需要分布式训练

1.1 单卡限制

Text Only
模型规模增长 vs 单卡显存:

  ResNet-50 (2015):    25M参数      ~100MB显存
  BERT-Large (2018):   340M参数     ~1.3GB显存
  GPT-2 (2019):        1.5B参数     ~6GB显存
  LLaMA-7B (2023):     7B参数       ~28GB显存 (FP32)
  LLaMA-70B (2023):    70B参数      ~280GB显存 (FP32)

  单卡A100 = 80GB → 连7B模型的FP32都放不下

  解决方案:
  1. 混合精度训练 (FP16/BF16): 显存减半
  2. 数据并行: 加速训练,不减少模型显存占用
  3. 模型并行: 将模型分到多张卡上
  4. ZeRO/FSDP: 分片优化器状态/梯度/参数

1.2 并行策略总览

数据并行vs模型并行

Text Only
                    分布式训练策略
                          |
          ┌───────────────┼───────────────┐
          |               |               |
     数据并行(DP)      模型并行(MP)     流水线并行(PP)
     ├─ DDP            ├─ 张量并行(TP)   └─ GPipe
     ├─ ZeRO-DP        │  列切分/行切分      Micro-batch
     └─ FSDP           └─ 序列并行(SP)

3D并行 = DP + TP + PP (同时使用三种)
  └─ 训练百亿级以上模型的标准方案

2. 数据并行(Data Parallelism)

2.1 原理

Text Only
数据并行是最简单的分布式策略:

  数据集 ────→ 分成N份 ────→ 每张GPU处理一份
  模型副本: GPU0 GPU1 GPU2 GPU3  (每张卡有完整模型)
  前向传播: 各自计算loss
  反向传播: 各自计算梯度
  AllReduce: 同步梯度(求平均)
  参数更新: 所有卡更新后参数一致

2.2 PyTorch DDP(推荐方案)

Python
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
import os

def setup(rank, world_size):
    """初始化分布式环境"""
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # 初始化进程组
    # nccl: NVIDIA GPU通信后端(最快)
    # gloo: CPU通信后端(兼容性好)
    dist.init_process_group(
        backend='nccl',  # GPU用nccl, CPU用gloo
        rank=rank,
        world_size=world_size
    )
    torch.cuda.set_device(rank)

def cleanup():
    dist.destroy_process_group()

def train_ddp(rank, world_size, epochs=10):
    """DDP训练主函数"""
    setup(rank, world_size)

    # 1. 模型 → 指定GPU → DDP包装
    model = nn.Sequential(
        nn.Linear(784, 256), nn.ReLU(),
        nn.Linear(256, 128), nn.ReLU(),
        nn.Linear(128, 10)
    ).to(rank)

    ddp_model = DDP(model, device_ids=[rank])

    # 2. DistributedSampler确保每张卡看到不同的数据
    train_dataset = ...  # your dataset
    sampler = DistributedSampler(
        train_dataset,
        num_replicas=world_size,
        rank=rank,
        shuffle=True
    )
    dataloader = DataLoader(  # DataLoader批量加载数据,支持shuffle和多进程
        train_dataset,
        batch_size=64,
        sampler=sampler,
        num_workers=4,
        pin_memory=True
    )

    optimizer = torch.optim.AdamW(ddp_model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()

    # 3. 训练循环
    for epoch in range(epochs):
        sampler.set_epoch(epoch)  # 重要!确保每epoch数据打乱方式不同
        ddp_model.train()  # train()开启训练模式

        for batch_idx, (data, target) in enumerate(dataloader):  # enumerate同时获取索引和元素
            data, target = data.to(rank), target.to(rank)

            optimizer.zero_grad()  # 清零梯度,防止梯度累积
            output = ddp_model(data)
            loss = criterion(output, target)
            loss.backward()  # DDP自动在backward时同步梯度
            optimizer.step()  # 根据梯度更新模型参数

            if rank == 0 and batch_idx % 100 == 0:
                print(f"Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}")  # .item()将单元素张量转为Python数值

    # 4. 只在rank 0保存模型
    if rank == 0:
        torch.save(ddp_model.module.state_dict(), "model.pth")

    cleanup()

# 启动: torchrun --nproc_per_node=4 train.py
# 注意: torch.distributed.launch 已废弃,请统一使用 torchrun
if __name__ == "__main__":
    import torch.multiprocessing as mp
    world_size = torch.cuda.device_count()
    mp.spawn(train_ddp, args=(world_size,), nprocs=world_size, join=True)

2.3 DDP核心机制

Ring AllReduce通信

Python
# DDP关键技术点:

# 1. 梯度同步: AllReduce操作
#    所有GPU在backward()结束时自动同步梯度
#    使用Ring-AllReduce算法(带宽最优)

# 2. 梯度桶(Gradient Buckets):
#    不是每个参数单独同步,而是打包成桶
#    桶大小默认25MB,可调:
ddp_model = DDP(model, device_ids=[rank], bucket_cap_mb=25)

# 3. 计算与通信重叠:
#    DDP在backward()期间就开始同步已计算好的梯度
#    → 通信和计算并行进行

# 4. 广播初始参数:
#    DDP初始化时将rank 0的参数广播到所有其他rank
#    → 确保所有卡从相同的参数开始训练

2.4 torchrun启动方式

Bash
# 单机多卡
torchrun --nproc_per_node=4 train.py

# 多机多卡 (2台机器,每台4卡)
# 机器1:
torchrun --nproc_per_node=4 --nnodes=2 --node_rank=0 \
    --master_addr=192.168.1.1 --master_port=12355 train.py

# 机器2:
torchrun --nproc_per_node=4 --nnodes=2 --node_rank=1 \
    --master_addr=192.168.1.1 --master_port=12355 train.py

3. FSDP(全分片数据并行)

3.1 原理

Text Only
DDP的问题: 每张卡都存储完整模型参数 + 优化器状态 + 梯度
  → 对于大模型,仅优化器状态就可能超过单卡显存

  Adam优化器显存计算 (以7B FP32模型为例):
    参数: 7B × 4 bytes = 28GB
    梯度: 7B × 4 bytes = 28GB
    优化器状态 (m + v): 7B × 4 × 2 = 56GB
    总计: 112GB → 单张A100(80GB)放不下

FSDP (= ZeRO-3 的PyTorch原生实现):
  将参数、梯度、优化器状态都分片到各GPU

  分片策略:
  GPU0: 参数[0:1/4], 梯度[0:1/4], 优化器[0:1/4]
  GPU1: 参数[1/4:2/4], 梯度[1/4:2/4], 优化器[1/4:2/4]
  GPU2: 参数[2/4:3/4], 梯度[2/4:3/4], 优化器[2/4:3/4]
  GPU3: 参数[3/4:4/4], 梯度[3/4:4/4], 优化器[3/4:4/4]

  计算时: AllGather收集完整参数 → 前向/反向 → ReduceScatter分发梯度
  显存: 几乎线性减少(4卡≈4倍减少)

3.2 FSDP实现

Python
import torch
import torch.nn as nn
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
    BackwardPrefetch,
    CPUOffload,
)
from torch.distributed.fsdp.wrap import (
    size_based_auto_wrap_policy,
    transformer_auto_wrap_policy,
)
from functools import partial

def setup_fsdp_model(model, rank):
    """配置FSDP模型"""

    # 混合精度策略
    mixed_precision_policy = MixedPrecision(
        param_dtype=torch.bfloat16,      # 参数存储精度
        reduce_dtype=torch.bfloat16,     # 梯度通信精度
        buffer_dtype=torch.bfloat16,     # Buffer精度
    )

    # 自动包装策略 (按参数量)
    auto_wrap_policy = partial(
        size_based_auto_wrap_policy,
        min_num_params=1e6  # 参数>1M的子模块单独分片
    )

    # 对Transformer模型,更推荐按层包装:
    # from transformers.models.llama.modeling_llama import LlamaDecoderLayer
    # auto_wrap_policy = partial(
    #     transformer_auto_wrap_policy,
    #     transformer_layer_cls={LlamaDecoderLayer}
    # )

    # FSDP包装
    fsdp_model = FSDP(
        model,

        # 分片策略
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        # FULL_SHARD: 完全分片(= ZeRO-3,显存最省)
        # SHARD_GRAD_OP: 分片梯度和优化器(= ZeRO-2)
        # NO_SHARD: 不分片(= DDP)

        # 混合精度
        mixed_precision=mixed_precision_policy,

        # 自动包装
        auto_wrap_policy=auto_wrap_policy,

        # 预取优化(提前收集下一层参数,计算通信重叠)
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,

        # CPU Offload (极端显存不足时使用,会降低速度)
        # cpu_offload=CPUOffload(offload_params=True),

        device_id=rank,
        use_orig_params=True,  # 兼容非FSDP优化器
    )

    return fsdp_model

def train_fsdp(rank, world_size):
    """FSDP训练示例"""
    setup(rank, world_size)

    # 创建模型(先在CPU上,FSDP会自动管理设备)
    model = create_large_model()
    fsdp_model = setup_fsdp_model(model, rank)

    optimizer = torch.optim.AdamW(fsdp_model.parameters(), lr=1e-4)

    for epoch in range(num_epochs):
        for batch in dataloader:
            batch = {k: v.to(rank) for k, v in batch.items()}

            loss = fsdp_model(**batch).loss
            loss.backward()

            # 梯度裁剪(FSDP需要特殊处理)
            fsdp_model.clip_grad_norm_(max_norm=1.0)

            optimizer.step()
            optimizer.zero_grad()

    # FSDP模型保存
    # 方法1: 全量保存(收集到rank 0)
    from torch.distributed.fsdp import FullStateDictConfig, StateDictType

    full_state_cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT, full_state_cfg):
        state_dict = fsdp_model.state_dict()
        if rank == 0:
            torch.save(state_dict, "model_fsdp.pth")

    cleanup()

3.3 FSDP vs DDP vs DeepSpeed ZeRO

特性 DDP FSDP DeepSpeed ZeRO-3
参数分片
梯度分片
优化器分片
实现方 PyTorch原生 PyTorch原生 微软
混合精度 AMP 内置 内置
CPU Offload ✓ (ZeRO-Infinity)
NVMe Offload
生态集成 所有框架 HF Trainer HF/Accelerate
适用规模 <10B 10-100B 10-1000B

4. 模型并行与3D并行

4.1 张量并行(Tensor Parallelism)

Text Only
将单个矩阵运算拆分到多个GPU上:

  线性层: Y = XW + b

  列切分 (Column Parallelism):
    W = [W1 | W2]  (按列分)
    GPU0: Y1 = X @ W1
    GPU1: Y2 = X @ W2
    Y = [Y1 | Y2] → AllGather

  行切分 (Row Parallelism):
    W = [W1]  (按行分)
        [W2]
    X = [X1 | X2]
    GPU0: Z1 = X1 @ W1
    GPU1: Z2 = X2 @ W2
    Y = Z1 + Z2 → AllReduce

  Transformer中的张量并行:
    MHA: Q,K,V按头数切分 → 每张卡处理部分头 → AllReduce
    FFN: 第一个线性层列切分 → 第二个线性层行切分 → AllReduce
Python
import torch
import torch.nn as nn
import torch.distributed as dist

class ColumnParallelLinear(nn.Module):  # 继承nn.Module定义神经网络层
    """列并行线性层 (简化实现)"""

    def __init__(self, in_features, out_features, world_size, rank):  # __init__构造方法,创建对象时自动调用
        super().__init__()  # super()调用父类方法
        assert out_features % world_size == 0  # assert断言,条件为False时抛出异常
        self.local_out = out_features // world_size
        self.rank = rank
        self.world_size = world_size

        # 每张卡只存储 1/world_size 的权重列
        self.weight = nn.Parameter(
            torch.randn(in_features, self.local_out) * 0.01
        )
        self.bias = nn.Parameter(torch.zeros(self.local_out))

    def forward(self, x):
        # 本地计算
        local_output = x @ self.weight + self.bias

        # AllGather收集所有卡的输出
        output_list = [torch.zeros_like(local_output) for _ in range(self.world_size)]
        dist.all_gather(output_list, local_output)

        return torch.cat(output_list, dim=-1)

class RowParallelLinear(nn.Module):
    """行并行线性层 (简化实现)"""

    def __init__(self, in_features, out_features, world_size, rank):
        super().__init__()
        assert in_features % world_size == 0
        self.local_in = in_features // world_size
        self.rank = rank
        self.world_size = world_size

        self.weight = nn.Parameter(
            torch.randn(self.local_in, out_features) * 0.01
        )
        self.bias = nn.Parameter(torch.zeros(out_features))

    def forward(self, x):
        # x已经是切分后的 (每张卡只有部分特征)
        local_output = x @ self.weight

        # AllReduce求和
        dist.all_reduce(local_output, op=dist.ReduceOp.SUM)

        return local_output + self.bias

4.2 流水线并行(Pipeline Parallelism)

Text Only
将模型按层切分到不同GPU:

  GPU0: Layer 0-7    (前1/4层)
  GPU1: Layer 8-15   (前2/4层)
  GPU2: Layer 16-23  (前3/4层)
  GPU3: Layer 24-31  (后1/4层)

  问题: 朴素流水线会导致GPU空闲(气泡)

  解决: Micro-batch

  将一个mini-batch分成多个micro-batch:

  时间 →→→→→→→→→→
  GPU0: [F1][F2][F3][F4]   [B4][B3][B2][B1]
  GPU1:     [F1][F2][F3][F4]   [B4][B3][B2][B1]
  GPU2:         [F1][F2][F3][F4]   [B4][B3][B2][B1]
  GPU3:             [F1][F2][F3][F4]   [B4][B3][B2][B1]

  F=前向, B=反向, 数字=micro-batch编号
  气泡(idle)只在开始和结束 → micro-batch越多效率越高

4.3 3D并行

混合并行策略

Text Only
3D并行 = 数据并行 × 张量并行 × 流水线并行

例: 64张GPU训练100B模型
  - 张量并行(TP=8): 同一台机器内的8张GPU, NVLink高带宽互联
  - 流水线并行(PP=4): 跨4组机器,模型分4段
  - 数据并行(DP=2): 2份数据副本
  - 总GPU: 8 × 4 × 2 = 64

配置原则:
  1. TP在机器内(NVLink带宽高) — TP对带宽要求最高
  2. PP跨机器(只需传递激活值) — 通信量相对小
  3. DP与PP正交 — 各PP stage内部做DP

5. 通信原语与性能优化

5.1 集合通信原语

Python
# 核心通信操作

# 1. Broadcast: 一个发送,所有接收
#    rank 0: [A, B, C, D] → 所有rank: [A, B, C, D]
dist.broadcast(tensor, src=0)

# 2. AllReduce: 所有归约,所有接收结果
#    rank 0: [1, 2]  rank 1: [3, 4]
#    → 所有rank: [4, 6]  (sum)
dist.all_reduce(tensor, op=dist.ReduceOp.SUM)

# 3. AllGather: 所有收集
#    rank 0: [A]  rank 1: [B]  rank 2: [C]
#    → 所有rank: [A, B, C]
output_list = [torch.zeros_like(tensor) for _ in range(world_size)]
dist.all_gather(output_list, tensor)

# 4. ReduceScatter: 归约后分发
#    rank 0: [A0, A1]  rank 1: [B0, B1]
#    → rank 0: [A0+B0]  rank 1: [A1+B1]
dist.reduce_scatter(output, input_list, op=dist.ReduceOp.SUM)

# 5. AllToAll: 全交换
#    rank 0: [A0, A1]  rank 1: [B0, B1]
#    → rank 0: [A0, B0]  rank 1: [A1, B1]
dist.all_to_all(output_list, input_list)

5.2 通信后端对比

后端 适用 带宽 特点
NCCL GPU-GPU 最高 NVIDIA专用,支持NVLink/InfiniBand
Gloo CPU/GPU 中等 跨平台,CPU训练首选
MPI CPU HPC标准,功能最全

5.3 性能分析

分布式训练性能

Python
import torch.profiler

# 使用PyTorch Profiler分析分布式训练性能
with torch.profiler.profile(
    activities=[
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA
    ],
    schedule=torch.profiler.schedule(
        wait=1, warmup=1, active=3, repeat=1
    ),
    on_trace_ready=torch.profiler.tensorboard_trace_handler('./logs'),
    record_shapes=True,
    with_stack=True
) as prof:
    for step, (data, target) in enumerate(dataloader):
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        prof.step()

# 查看结果: tensorboard --logdir=./logs
# 重点关注:
# 1. 计算 vs 通信时间比(理想 > 80% 计算)
# 2. GPU利用率(理想 > 90%)
# 3. 通信等待时间(有无通信瓶颈)

6. 实战:HuggingFace Accelerate

Python
# Accelerate: 最简单的分布式训练方式(支持DDP/FSDP/DeepSpeed)
# pip install accelerate

from accelerate import Accelerator
from torch.utils.data import DataLoader

def train_with_accelerate():
    # 一行初始化
    accelerator = Accelerator(
        mixed_precision='bf16',  # 混合精度
        gradient_accumulation_steps=4,
    )

    model = create_model()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    dataloader = DataLoader(dataset, batch_size=32)
    scheduler = get_scheduler("cosine", optimizer, ...)

    # Accelerate自动处理:
    # - 模型放到正确的GPU
    # - 数据自动分片
    # - 梯度自动同步
    model, optimizer, dataloader, scheduler = accelerator.prepare(
        model, optimizer, dataloader, scheduler
    )

    for epoch in range(num_epochs):
        model.train()
        for batch in dataloader:
            with accelerator.accumulate(model):
                outputs = model(**batch)
                loss = outputs.loss
                accelerator.backward(loss)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

            if accelerator.is_main_process:
                print(f"Loss: {loss.item():.4f}")

    # 保存
    accelerator.wait_for_everyone()
    unwrapped = accelerator.unwrap_model(model)
    accelerator.save(unwrapped.state_dict(), "model.pth")

# 配置: accelerate config  (交互式配置DDP/FSDP/DeepSpeed)
# 启动: accelerate launch train.py

Accelerate FSDP配置文件

YAML
# accelerate_config.yaml
compute_environment: LOCAL_MACHINE
distributed_type: FSDP
fsdp_config:
  fsdp_auto_wrap_policy: SIZE_BASED_WRAP
  fsdp_min_num_params: 1000000
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_backward_prefetch_policy: BACKWARD_PRE
  fsdp_offload_params: false
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_use_orig_params: true
mixed_precision: bf16
num_machines: 1
num_processes: 4

7. 常见问题与调试

7.1 死锁排查

Python
# 分布式训练最常见的bug: 进程死锁

# 原因1: 不同rank执行了不同的代码路径
# 错误示例:
if rank == 0:
    output = model(data)  # 只在rank 0执行 → 其他rank等待AllReduce → 死锁!

# 正确: 所有rank必须执行相同的集合操作

# 原因2: 数据不一致导致不同rank的iteration数不同
# 解决: 使用DistributedSampler + drop_last=True

# 调试工具:
# NCCL_DEBUG=INFO torchrun ... → 打印详细的NCCL通信日志
# TORCH_DISTRIBUTED_DEBUG=DETAIL → PyTorch分布式调试信息

7.2 显存优化清单

Text Only
显存不够?按优先级尝试:
  1. ✅ 混合精度 (BF16/FP16) → 显存减半,几乎无精度损失
  2. ✅ 梯度累积 → 等效大batch,不增加显存
  3. ✅ 梯度检查点 (Gradient Checkpointing) → 用时间换显存
  4. ✅ FSDP/ZeRO-3 → 分片参数+梯度+优化器
  5. ⚠️ CPU Offload → 大幅降速,最后手段
  6. ⚠️ 减小模型/batch → 影响训练效果
Python
# 梯度检查点: 不保存中间激活值,反向传播时重新计算
from torch.utils.checkpoint import checkpoint

class TransformerBlock(nn.Module):
    def forward(self, x):
        # 用checkpoint包装 → 不保存此层的激活值
        return checkpoint(self._forward_impl, x, use_reentrant=False)

    def _forward_impl(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.ffn(self.norm2(x))
        return x

8. 练习题

代码实践

  1. 入门:使用DDP在多GPU上训练ResNet-18(CIFAR-10),对比单卡和多卡训练速度
  2. 进阶:使用FSDP训练一个1B参数的GPT模型(模拟显存不足场景)
  3. 高级:使用Accelerate配置FSDP+BF16+梯度累积,在4卡上微调LLaMA-7B

面试题

  1. DDP和FSDP的核心区别是什么?各自的通信开销?
  2. AllReduce是如何实现的?Ring-AllReduce的通信复杂度?
  3. 为什么张量并行适合放在机器内,而流水线并行适合跨机器?
  4. FSDP的FULL_SHARD和SHARD_GRAD_OP(ZeRO-2)各适用什么场景?
  5. 梯度累积和增大batch size在数学上等价吗?有什么差异?
  6. 如何计算训练一个10B模型需要多少GPU显存?
  7. 3D并行中TP/PP/DP的最优配比如何确定?

最后更新:2026年2月