
01-Model Compression and Acceleration

⚠️ Timeliness note: this chapter references cutting-edge models, prices, leaderboards, and similar information that may change quickly between versions; defer to the original papers, official release pages, and API documentation.

Study time: about 6-8 hours | Difficulty: ⭐⭐⭐⭐ upper-intermediate | Prerequisites: deep learning fundamentals, CNN/Transformer architectures, PyTorch | Learning goals: master the principles and implementation of mainstream model compression techniques such as knowledge distillation, pruning, and quantization


Contents


1. Why Model Compression Is Needed

1.1 Deployment Challenges

Model       Parameters   Size     FLOPs    Inference latency
ResNet-50   25M          98MB     4.1G     ~10ms
ViT-Base    86M          330MB    17.6G    ~30ms
BERT-Base   110M         420MB    -        ~50ms
GPT-2       1.5B         6GB      -        ~200ms

Mobile and edge devices impose strict limits on model size, compute, memory, and latency.
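
The parameter count maps almost directly onto the size column: at FP32 every parameter takes 4 bytes. A quick sanity check of the table (a minimal sketch, assuming torchvision is installed):

Python
import torch
from torchvision.models import resnet50

model = resnet50()
num_params = sum(p.numel() for p in model.parameters())  # ~25.6M for ResNet-50

fp32_mb = num_params * 4 / 1024 / 1024  # 4 bytes per parameter at FP32
int8_mb = num_params * 1 / 1024 / 1024  # 1 byte per parameter at INT8

print(f"Parameters: {num_params/1e6:.1f}M")
print(f"FP32 size: {fp32_mb:.0f} MB, INT8 size: {int8_mb:.0f} MB")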

1.2 Overview of Compression Techniques

Comparison of model compression techniques

Text Only
                          Model Compression
            /                 |               |                \
   Knowledge Distillation   Pruning      Quantization    Low-Rank Decomposition
            |                 |               |                  |
       Teacher →         Structured /       INT8 /             SVD /
        Student          Unstructured       INT4               Tucker

2. Knowledge Distillation

Knowledge distillation schematic

2.1 Basic Idea

Hinton et al. (2015): train a small model (the Student) on the "soft labels" produced by a large model (the Teacher).

Soft labels carry information about similarity between classes (e.g., "3" and "8" may both receive fairly high probability in the Teacher's softmax output).
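
This is easiest to see by softening the Teacher's output with a temperature \(T > 1\) (defined in the next subsection). A minimal sketch with made-up logits:

Python
import torch
import torch.nn.functional as F

# Hypothetical teacher logits for a handwritten "3"; class 8 gets a moderately large logit
logits = torch.tensor([0., 0., 0., 10., 0., 0., 0., 0., 5., 0.])

p_t1 = F.softmax(logits / 1.0, dim=0)  # T=1: almost all probability mass on class 3
p_t4 = F.softmax(logits / 4.0, dim=0)  # T=4: class 8 stands out clearly from the other wrong classes

print(f"T=1: P(3)={p_t1[3]:.3f}, P(8)={p_t1[8]:.4f}, P(0)={p_t1[0]:.4f}")
print(f"T=4: P(3)={p_t4[3]:.3f}, P(8)={p_t4[8]:.4f}, P(0)={p_t4[0]:.4f}")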

2.2 Distillation Loss

\[\mathcal{L} = (1-\alpha) \cdot \mathcal{L}_{CE}(y, p_s) + \alpha \cdot T^2 \cdot KL(p_t^T \| p_s^T)\]

where \(p^T = \text{softmax}(z/T)\) are the temperature-scaled soft labels and \(T\) is the temperature parameter.

Python
import torch
import torch.nn as nn
import torch.nn.functional as F

class DistillationLoss(nn.Module):  # subclass nn.Module so the loss can be used like a layer
    """Knowledge distillation loss"""
    def __init__(self, temperature=4.0, alpha=0.7):  # constructor, called automatically on instantiation
        super().__init__()  # initialize the parent nn.Module
        self.T = temperature
        self.alpha = alpha
        self.ce_loss = nn.CrossEntropyLoss()

    def forward(self, student_logits, teacher_logits, labels):
        # Hard-label loss (standard cross-entropy against the ground truth)
        hard_loss = self.ce_loss(student_logits, labels)

        # Soft-label loss (KL divergence between temperature-scaled distributions)
        soft_student = F.log_softmax(student_logits / self.T, dim=1)
        soft_teacher = F.softmax(teacher_logits / self.T, dim=1)
        soft_loss = F.kl_div(soft_student, soft_teacher, reduction='batchmean')

        # Total loss (T^2 rescales the soft-label gradients back to a comparable magnitude)
        loss = (1 - self.alpha) * hard_loss + self.alpha * (self.T ** 2) * soft_loss
        return loss

2.3 Full Distillation Training Loop

Python
def knowledge_distillation(teacher, student, train_loader, epochs=50, device='cuda'):
    """知识蒸馏训练"""
    teacher.eval()  # eval()开启评估模式(关闭Dropout等)
    student.to(device)  # .to(device)将数据移至GPU/CPU
    teacher.to(device)

    optimizer = torch.optim.Adam(student.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    distill_loss = DistillationLoss(temperature=4.0, alpha=0.7)

    for epoch in range(epochs):
        student.train()  # train() switches to training mode
        total_loss = 0
        correct = 0
        total = 0

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            with torch.no_grad():  # no gradients needed for the teacher; saves memory
                teacher_logits = teacher(images)

            student_logits = student(images)
            loss = distill_loss(student_logits, teacher_logits, labels)

            optimizer.zero_grad()  # clear accumulated gradients
            loss.backward()  # backpropagate to compute gradients
            optimizer.step()  # update the student's parameters

            total_loss += loss.item()  # .item() converts a one-element tensor to a Python number
            _, predicted = student_logits.max(1)
            correct += predicted.eq(labels).sum().item()  # method chaining: compare, sum, convert
            total += labels.size(0)

        scheduler.step()
        acc = 100. * correct / total
        print(f"Epoch {epoch+1}: Loss={total_loss/len(train_loader):.4f}, Acc={acc:.2f}%")

    return student

2.4 Feature Distillation

Besides distilling the output logits, intermediate-layer features can also be distilled:

Python
class FeatureDistillationLoss(nn.Module):
    """特征蒸馏 — 匹配中间层特征"""
    def __init__(self, student_channels, teacher_channels):
        super().__init__()
        # channel adapter (student and teacher channel counts may differ)
        self.adaptor = nn.Conv2d(student_channels, teacher_channels, 1)

    def forward(self, student_feat, teacher_feat):
        adapted = self.adaptor(student_feat)
        # L2 distance between adapted student features and teacher features
        loss = F.mse_loss(adapted, teacher_feat)
        return loss
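
FeatureDistillationLoss needs the intermediate activations of both networks. A common way to capture them without modifying forward() is a forward hook; a minimal sketch ('layer3' is a placeholder name and depends on the actual architectures):

Python
def register_feature_hook(model, layer_name, storage):
    """Store the output of the named submodule in storage['feat'] on every forward pass."""
    layer = dict(model.named_modules())[layer_name]

    def hook(module, inputs, output):
        storage['feat'] = output

    return layer.register_forward_hook(hook)

# Usage sketch (assumes teacher/student both expose a submodule named 'layer3'):
# teacher_feats, student_feats = {}, {}
# register_feature_hook(teacher, 'layer3', teacher_feats)
# register_feature_hook(student, 'layer3', student_feats)
# ... run a forward pass, then:
# feat_loss = feature_loss(student_feats['feat'], teacher_feats['feat'])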

3. Model Pruning

Pruning comparison: structured vs. unstructured

3.1 Unstructured Pruning

Remove individual weights (set them to 0), producing sparse weight matrices.

Python
import torch.nn.utils.prune as prune

def unstructured_pruning(model, amount=0.3):
    """非结构化剪枝 — 按权重大小"""
    for name, module in model.named_modules():
        if isinstance(module, (nn.Conv2d, nn.Linear)):  # isinstance checks the layer type
            prune.l1_unstructured(module, name='weight', amount=amount)

    # Report the resulting sparsity
    total = 0
    pruned = 0
    for name, module in model.named_modules():
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            total += module.weight.nelement()
            pruned += (module.weight == 0).sum().item()

    print(f"全局稀疏度: {100. * pruned / total:.1f}%")
    return model

def make_pruning_permanent(model):
    """使剪枝永久化(移除 mask)"""
    for name, module in model.named_modules():
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            try:  # prune.remove raises ValueError for modules that were never pruned
                prune.remove(module, 'weight')
            except ValueError:
                pass
    return model

3.2 Structured Pruning

Remove entire filters/channels/attention heads, directly reducing model size and computation.

Python
def structured_pruning_conv(model, amount=0.3):
    """结构化剪枝 — 按通道的 L1 范数"""
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d):
            prune.ln_structured(module, name='weight', amount=amount, n=1, dim=0)
    return model

class ChannelPruner:
    """基于通道重要性的结构化剪枝"""
    def __init__(self, model):
        self.model = model

    def compute_channel_importance(self):
        """计算每个通道的重要性(L1范数)"""
        importances = {}
        for name, module in self.model.named_modules():
            if isinstance(module, nn.Conv2d):
                weight = module.weight.data
                # L1 norm of each output channel
                importance = weight.abs().sum(dim=(1, 2, 3))
                importances[name] = importance
        return importances

    def prune(self, ratio=0.3):
        """按比例剪枝最不重要的通道"""
        importances = self.compute_channel_importance()

        for name, module in self.model.named_modules():
            if isinstance(module, nn.Conv2d) and name in importances:
                imp = importances[name]
                num_prune = int(len(imp) * ratio)
                if num_prune > 0:
                    _, indices = imp.sort()
                    prune_indices = indices[:num_prune]
                    # Zero the weights of the selected channels (this only masks them;
                    # truly shrinking the model requires removing channels and rebuilding layers)
                    module.weight.data[prune_indices] = 0

        return self.model

3.3 The Lottery Ticket Hypothesis

Frankle & Carbin(2019):随机初始化的网络中存在子网络(“中奖彩票”),该子网络独立训练能达到与完整网络相同的性能。

Python
def lottery_ticket_experiment(model_fn, train_fn, prune_ratio=0.2, rounds=5):
    """彩票假说实验(迭代剪枝)"""
    model = model_fn()
    original_state = {k: v.clone() for k, v in model.state_dict().items()}

    for round_idx in range(rounds):
        # Train
        model = train_fn(model)

        # Prune by weight magnitude
        for name, module in model.named_modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                prune.l1_unstructured(module, name='weight', amount=prune_ratio)
                prune.remove(module, 'weight')

        # Reset the surviving weights to their original initialization (the key step!)
        current_state = model.state_dict()
        for key in current_state:
            if 'weight' in key:
                mask = (current_state[key] != 0).float()
                current_state[key] = original_state[key] * mask
        model.load_state_dict(current_state)

        total = sum(p.nelement() for p in model.parameters())
        nonzero = sum((p != 0).sum().item() for p in model.parameters())
        print(f"Round {round_idx+1}: 剩余参数 {100.*nonzero/total:.1f}%")

    return model

4. Model Quantization

Quantization comparison: FP32 vs INT8 vs INT4

4.1 Quantization Basics

Map FP32 weights/activations to low precision (INT8/INT4):

\[q = \text{round}\left(\frac{x}{s}\right) + z\]

where \(s\) is the scale factor and \(z\) is the zero point.
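
A minimal numeric sketch of this mapping, using an asymmetric per-tensor scheme where \(s\) and \(z\) are derived from the tensor's min/max (one common choice; real backends differ in the details):

Python
import torch

def quantize_dequantize_int8(x):
    """Quantize a tensor to unsigned 8-bit and back, for illustration."""
    qmin, qmax = 0, 255
    s = (x.max() - x.min()) / (qmax - qmin)      # scale
    z = qmin - torch.round(x.min() / s)          # zero point
    q = torch.clamp(torch.round(x / s) + z, qmin, qmax)
    x_hat = (q - z) * s                          # dequantize
    return q, x_hat

x = torch.randn(6)
q, x_hat = quantize_dequantize_int8(x)
print("original:   ", x)
print("quantized:  ", q)
print("dequantized:", x_hat)
print("max abs error:", (x - x_hat).abs().max().item())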

4.2 Dynamic Quantization in PyTorch

Python
import torch.ao.quantization

def dynamic_quantization(model):
    """动态量化 — 权重静态量化,激活动态量化"""
    quantized_model = torch.ao.quantization.quantize_dynamic(
        model,
        {nn.Linear, nn.LSTM},  # layer types to quantize
        dtype=torch.qint8
    )
    return quantized_model

# Example
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 256)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))

model = SimpleModel()
quantized = dynamic_quantization(model)

# Size comparison
import os
def get_model_size(model, path="temp.pt"):
    torch.save(model.state_dict(), path)
    size = os.path.getsize(path) / 1024 / 1024
    os.remove(path)
    return size

print(f"原始模型: {get_model_size(model):.2f} MB")
print(f"量化模型: {get_model_size(quantized):.2f} MB")

4.3 Static Quantization (PTQ)

Python
def static_quantization(model, calibration_loader, device='cpu'):
    """训练后静态量化(Post-Training Quantization)"""
    model.eval()
    model.to(device)

    # 1. Fuse modules (e.g., Linear + ReLU)
    model_fused = torch.ao.quantization.fuse_modules(model, [['fc1', 'relu']])

    # 2. Set the quantization configuration
    model_fused.qconfig = torch.ao.quantization.get_default_qconfig('fbgemm')

    # 3. Prepare the model (insert observers)
    model_prepared = torch.ao.quantization.prepare(model_fused)

    # 4. Calibrate (run forward passes over a small amount of data)
    with torch.no_grad():
        for data, _ in calibration_loader:
            data = data.view(-1, 784).to(device)  # view reshapes the tensor (requires contiguous memory)
            model_prepared(data)

    # 5. Convert to a quantized model
    model_quantized = torch.ao.quantization.convert(model_prepared)

    return model_quantized
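
One caveat with eager-mode static quantization: the model must mark where tensors enter and leave the quantized region with QuantStub/DeQuantStub, otherwise the converted quantized layers receive FP32 inputs and fail at runtime. A minimal sketch of a stub-wrapped version of SimpleModel (assuming the standard torch.ao.quantization stubs):

Python
class QuantizableSimpleModel(nn.Module):
    """SimpleModel with explicit quant/dequant stubs for eager-mode static quantization."""
    def __init__(self):
        super().__init__()
        self.quant = torch.ao.quantization.QuantStub()
        self.fc1 = nn.Linear(784, 256)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(256, 10)
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)       # FP32 -> quantized at the model input
        x = self.fc2(self.relu(self.fc1(x)))
        return self.dequant(x)  # quantized -> FP32 at the model output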

4.4 Quantization-Aware Training (QAT)

Python
def quantization_aware_training(model, train_loader, epochs=10, device='cpu'):
    """量化感知训练 — 训练时模拟量化噪声"""
    model.train()
    model.to(device)

    model.qconfig = torch.ao.quantization.get_default_qat_qconfig('fbgemm')
    model_prepared = torch.ao.quantization.prepare_qat(model)

    optimizer = torch.optim.Adam(model_prepared.parameters(), lr=1e-4)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        for data, target in train_loader:
            data = data.view(-1, 784).to(device)
            target = target.to(device)

            output = model_prepared(data)
            loss = criterion(output, target)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    model_quantized = torch.ao.quantization.convert(model_prepared.eval())
    return model_quantized

5. Low-Rank Decomposition

5.1 SVD Decomposition

Decompose the weight matrix \(W \in \mathbb{R}^{m \times n}\) as \(W \approx U_r \Sigma_r V_r^T\), where \(r \ll \min(m, n)\).

Python
def svd_decompose_linear(layer, rank):
    """用 SVD 分解线性层"""
    W = layer.weight.data
    U, S, Vt = torch.linalg.svd(W, full_matrices=False)

    U_r = U[:, :rank]
    S_r = S[:rank]
    Vt_r = Vt[:rank, :]

    # W ≈ (U_r * S_r) @ Vt_r = W1 @ W2
    W1 = U_r * S_r.unsqueeze(0)  # (m, r); unsqueeze adds a broadcasting dimension
    W2 = Vt_r                      # (r, n)

    # Replace the original layer with two smaller linear layers
    layer1 = nn.Linear(W2.size(1), rank, bias=False)
    layer2 = nn.Linear(rank, W1.size(0), bias=layer.bias is not None)

    layer1.weight.data = W2
    layer2.weight.data = W1
    if layer.bias is not None:
        layer2.bias.data = layer.bias.data

    compressed = nn.Sequential(layer1, layer2)

    # Compression ratio
    original_params = W.numel()
    compressed_params = W1.numel() + W2.numel()
    print(f"SVD: {original_params}{compressed_params} "
          f"({100.*compressed_params/original_params:.1f}%)")

    return compressed
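
A quick check of the decomposition on a standalone layer (shapes and rank here are arbitrary): the two-layer replacement should produce outputs close to the original, with the gap shrinking as the rank grows; trained weight matrices are typically closer to low-rank than this random example.

Python
fc = nn.Linear(512, 512)
compressed = svd_decompose_linear(fc, rank=64)

x = torch.randn(8, 512)
err = (fc(x) - compressed(x)).abs().max().item()
print(f"max abs output difference at rank 64: {err:.4f}")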

6. Efficient Architecture Design

6.1 Depthwise Separable Convolution
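
A standard convolution with kernel size \(k\), \(C_{in}\) input channels, and \(C_{out}\) output channels has \(k^2 \cdot C_{in} \cdot C_{out}\) weights. A depthwise separable convolution factors it into a per-channel \(k \times k\) depthwise convolution plus a \(1 \times 1\) pointwise convolution:

\[k^2 \cdot C_{in} + C_{in} \cdot C_{out} \ll k^2 \cdot C_{in} \cdot C_{out}\]

For \(k=3\), \(C_{in}=64\), \(C_{out}=128\) this gives \((576 + 8192)/73728 \approx 11.9\%\) of the parameters, which matches the ratio computed by the code below.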

Python
class DepthwiseSeparableConv(nn.Module):
    """深度可分离卷积(MobileNet 核心模块)"""
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()
        self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size,
                                    stride, padding, groups=in_channels)
        self.pointwise = nn.Conv2d(in_channels, out_channels, 1)
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU6(inplace=True)

    def forward(self, x):
        x = self.relu(self.bn1(self.depthwise(x)))
        x = self.relu(self.bn2(self.pointwise(x)))
        return x

    def param_ratio(self, in_channels, out_channels, kernel_size=3):
        standard = in_channels * out_channels * kernel_size * kernel_size
        sep = in_channels * kernel_size * kernel_size + in_channels * out_channels
        return sep / standard

# Parameter comparison
dsc = DepthwiseSeparableConv(64, 128)
ratio = dsc.param_ratio(64, 128)
print(f"深度可分离卷积参数量是标准卷积的 {ratio:.1%}")

7. Inference Optimization in Practice

7.1 Exporting to ONNX

Python
def export_to_onnx(model, input_shape, path="model.onnx"):
    """导出为 ONNX 格式"""
    model.eval()
    dummy_input = torch.randn(*input_shape)

    torch.onnx.export(
        model, dummy_input, path,
        input_names=['input'],
        output_names=['output'],
        dynamic_axes={'input': {0: 'batch'}, 'output': {0: 'batch'}},
        opset_version=13
    )
    print(f"导出到 {path}")

7.2 TorchScript

Python
def optimize_with_torchscript(model, input_shape):
    """用 TorchScript 优化推理"""
    model.eval()

    # Option 1: tracing
    example = torch.randn(*input_shape)
    traced = torch.jit.trace(model, example)

    # Option 2: scripting
    # scripted = torch.jit.script(model)

    # Compare inference speed
    import time

    with torch.no_grad():
        # Warm-up
        for _ in range(10):
            model(example)
            traced(example)

        # Original model
        start = time.time()
        for _ in range(100):
            model(example)
        original_time = time.time() - start

        # TorchScript
        start = time.time()
        for _ in range(100):
            traced(example)
        traced_time = time.time() - start

    print(f"原始模型: {original_time*10:.1f} ms/iter")
    print(f"TorchScript: {traced_time*10:.1f} ms/iter")
    print(f"加速比: {original_time/traced_time:.2f}x")

    return traced

7.3 End-to-End Compression Pipeline

Python
def full_compression_pipeline(teacher, student_fn, train_loader, cal_loader, device='cuda'):
    """完整的模型压缩流水线"""
    print("=" * 50)
    print("Step 1: 知识蒸馏")
    student = student_fn()
    student = knowledge_distillation(teacher, student, train_loader, epochs=30, device=device)

    print("\nStep 2: 结构化剪枝")
    student = structured_pruning_conv(student, amount=0.2)
    # Fine-tune to recover accuracy
    # student = fine_tune(student, train_loader, epochs=10)

    print("\nStep 3: 量化")
    student_quantized = dynamic_quantization(student.cpu())

    print("\nStep 4: 导出 ONNX")
    # export_to_onnx(student, (1, 3, 224, 224))

    print("\n压缩完成!")
    return student_quantized

8. Exercises and Self-Check

Exercises

  1. Knowledge distillation: use ResNet-50 as the Teacher and ResNet-18 as the Student, distill on CIFAR-10, and compare with training ResNet-18 from scratch.
  2. Pruning experiment: apply unstructured pruning at several ratios to a trained ResNet and plot the sparsity vs. accuracy curve.
  3. Quantization comparison: apply dynamic quantization, static quantization, and QAT to the same model and compare accuracy and inference speed.
  4. SVD compression: apply low-rank SVD decomposition to BERT's fully connected layers and analyze how the rank affects performance.
  5. End-to-end project: design a full compression pipeline (distillation → pruning → quantization) and deploy the model on a mobile device.

Self-Check Checklist

  • Understand the role of the temperature T in knowledge distillation
  • Can tell structured and unstructured pruning apart
  • Know the difference between PTQ and QAT
  • Can use PyTorch's quantization tooling
  • Understand the principle of low-rank SVD decomposition
  • Can design a complete model compression scheme

Next: 02-Self-Supervised Learning