01-模型压缩与加速¶
⚠️ 时效性说明:本章涉及前沿模型/价格/榜单等信息,可能随版本快速变化;请以论文原文、官方发布页和 API 文档为准。
学习时间: 约6-8小时 难度级别: ⭐⭐⭐⭐ 中高级 前置知识: 深度学习基础、CNN/Transformer架构、PyTorch 学习目标: 掌握知识蒸馏、模型剪枝、量化等主流模型压缩技术的原理与实现
目录¶
1. 为什么需要模型压缩¶
1.1 部署挑战¶
| 模型 | 参数量 | 大小 | FLOPs | 推理延迟 |
|---|---|---|---|---|
| ResNet-50 | 25M | 98MB | 4.1G | ~10ms |
| ViT-Base | 86M | 330MB | 17.6G | ~30ms |
| BERT-Base | 110M | 420MB | — | ~50ms |
| GPT-2 | 1.5B | 6GB | — | ~200ms |
移动端/边缘设备对模型大小、计算量、内存、延迟都有严格限制。
1.2 压缩技术全景¶
2. 知识蒸馏¶
2.1 基本思想¶
Hinton et al.(2015):用大模型(Teacher)的"软标签"来训练小模型(Student)。
软标签包含类别之间的相似度信息(如"3"和"8"在 Teacher 的 softmax 输出中可能都有较高概率)。
2.2 蒸馏损失¶
\[\mathcal{L} = (1-\alpha) \cdot \mathcal{L}_{CE}(y, p_s) + \alpha \cdot T^2 \cdot KL(p_t^T \| p_s^T)\]
其中 \(p^T = \text{softmax}(z/T)\) 是温度缩放后的软标签,\(T\) 是温度参数。
Python
import torch
import torch.nn as nn
import torch.nn.functional as F
class DistillationLoss(nn.Module):
    """Knowledge-distillation loss: hard-label CE plus soft-label KL.

    L = (1 - alpha) * CE(student, labels)
        + alpha * T^2 * KL(softmax(teacher/T) || softmax(student/T))
    """

    def __init__(self, temperature=4.0, alpha=0.7):
        super().__init__()
        self.T = temperature   # softening temperature for the logits
        self.alpha = alpha     # weight of the soft-label term
        self.ce_loss = nn.CrossEntropyLoss()

    def forward(self, student_logits, teacher_logits, labels):
        # Hard-label term: standard cross-entropy against ground truth.
        hard = self.ce_loss(student_logits, labels)

        # Soft-label term: KL divergence between temperature-scaled
        # distributions; F.kl_div expects log-probs as its first argument.
        log_p_student = F.log_softmax(student_logits / self.T, dim=1)
        p_teacher = F.softmax(teacher_logits / self.T, dim=1)
        soft = F.kl_div(log_p_student, p_teacher, reduction='batchmean')

        # T^2 restores the gradient magnitude reduced by temperature scaling.
        return (1 - self.alpha) * hard + self.alpha * (self.T ** 2) * soft
2.3 完整蒸馏训练¶
Python
def knowledge_distillation(teacher, student, train_loader, epochs=50, device='cuda'):
    """Train `student` to mimic `teacher` via logit distillation.

    The teacher stays frozen in eval mode; the student is optimized with
    Adam and a cosine-annealed learning rate on the combined distillation
    loss (T=4.0, alpha=0.7). Returns the trained student.
    """
    teacher.eval()  # freeze dropout / BN statistics of the teacher
    student.to(device)
    teacher.to(device)

    optimizer = torch.optim.Adam(student.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    distill_loss = DistillationLoss(temperature=4.0, alpha=0.7)

    for epoch in range(epochs):
        student.train()
        running_loss = 0.0
        n_correct = 0
        n_seen = 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            # The teacher only provides targets — no gradients needed.
            with torch.no_grad():
                teacher_logits = teacher(images)

            student_logits = student(images)
            loss = distill_loss(student_logits, teacher_logits, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            n_correct += (student_logits.argmax(dim=1) == labels).sum().item()
            n_seen += labels.size(0)
        scheduler.step()

        acc = 100. * n_correct / n_seen
        print(f"Epoch {epoch+1}: Loss={running_loss/len(train_loader):.4f}, Acc={acc:.2f}%")
    return student
2.4 特征蒸馏¶
除了输出层蒸馏,还可以对中间层特征进行蒸馏:
Python
class FeatureDistillationLoss(nn.Module):
    """Feature distillation: match a student feature map to the teacher's.

    A 1x1 conv adapts the student's channel count to the teacher's before
    the two maps are compared with mean-squared error.
    """

    def __init__(self, student_channels, teacher_channels):
        super().__init__()
        # 1x1 conv bridges the (possibly different) channel dimensions.
        self.adaptor = nn.Conv2d(student_channels, teacher_channels, 1)

    def forward(self, student_feat, teacher_feat):
        # Project student features into the teacher's channel space,
        # then penalize the L2 distance between the maps.
        return F.mse_loss(self.adaptor(student_feat), teacher_feat)
3. 模型剪枝¶
3.1 非结构化剪枝¶
移除单个权重(设为 0),产生稀疏矩阵。
Python
import torch.nn.utils.prune as prune
def unstructured_pruning(model, amount=0.3):
    """Magnitude-based unstructured pruning of every Conv2d/Linear layer.

    Zeroes the `amount` fraction of smallest-|w| weights per layer via a
    pruning mask (L1 criterion), then reports the combined sparsity of
    all prunable layers. Returns the same model object.
    """
    prunable = (nn.Conv2d, nn.Linear)
    for _, module in model.named_modules():
        if isinstance(module, prunable):
            prune.l1_unstructured(module, name='weight', amount=amount)

    # Measure sparsity across all prunable layers together.
    total = 0
    pruned = 0
    for _, module in model.named_modules():
        if isinstance(module, prunable):
            total += module.weight.nelement()
            pruned += (module.weight == 0).sum().item()
    print(f"全局稀疏度: {100. * pruned / total:.1f}%")
    return model
def make_pruning_permanent(model):
    """Fold pruning masks into the weights and drop the reparametrization.

    `prune.remove` raises ValueError for layers that were never pruned;
    those layers are skipped silently (intentional best-effort behavior).
    """
    for _, module in model.named_modules():
        if not isinstance(module, (nn.Conv2d, nn.Linear)):
            continue
        try:
            prune.remove(module, 'weight')
        except ValueError:
            pass  # layer had no pruning applied
    return model
3.2 结构化剪枝¶
移除整个滤波器/通道/注意力头,直接减小模型大小和计算量。
Python
def structured_pruning_conv(model, amount=0.3):
    """Structured pruning: zero whole Conv2d output channels by L1 norm.

    `ln_structured` with n=1, dim=0 masks the `amount` fraction of output
    filters whose weights have the smallest L1 norm. Returns the model.
    """
    for _, module in model.named_modules():
        if isinstance(module, nn.Conv2d):
            prune.ln_structured(module, name='weight', amount=amount,
                                n=1, dim=0)
    return model
class ChannelPruner:
    """Structured pruning driven by per-channel importance scores.

    Importance is the L1 norm of each Conv2d output filter; `prune`
    zeroes (does not physically remove) the least important fraction
    of filters in every conv layer.
    """

    def __init__(self, model):
        self.model = model

    def compute_channel_importance(self):
        """Return {module_name: per-output-channel L1 norms} for all convs."""
        scores = {}
        for name, module in self.model.named_modules():
            if isinstance(module, nn.Conv2d):
                # Sum |w| over (in_channels, kH, kW) per output filter.
                scores[name] = module.weight.data.abs().sum(dim=(1, 2, 3))
        return scores

    def prune(self, ratio=0.3):
        """Zero out the `ratio` fraction of least-important filters per conv."""
        scores = self.compute_channel_importance()
        for name, module in self.model.named_modules():
            if not (isinstance(module, nn.Conv2d) and name in scores):
                continue
            channel_scores = scores[name]
            k = int(len(channel_scores) * ratio)
            if k > 0:
                # Ascending sort: the first k indices are the weakest filters.
                weakest = channel_scores.sort().indices[:k]
                module.weight.data[weakest] = 0
        return self.model
3.3 彩票假说(Lottery Ticket Hypothesis)¶
Frankle & Carbin(2019):随机初始化的网络中存在子网络(“中奖彩票”),该子网络独立训练能达到与完整网络相同的性能。
Python
def lottery_ticket_experiment(model_fn, train_fn, prune_ratio=0.2, rounds=5):
    """Lottery Ticket Hypothesis experiment via iterative magnitude pruning.

    Each round: train, prune `prune_ratio` of the smallest weights, then
    rewind the surviving weights to their original initialization. The
    order (train -> prune -> rewind) is essential to the procedure.

    NOTE(review): the returned model holds rewound (untrained) weights of
    the final ticket; callers are expected to retrain it — confirm intent.
    """
    model = model_fn()
    # Snapshot the random initialization; winning tickets are rewound to it.
    original_state = {k: v.clone() for k, v in model.state_dict().items()}
    for round_idx in range(rounds):
        # Train the current (possibly sparse) network.
        model = train_fn(model)
        # Prune by weight magnitude, then bake the mask in immediately so
        # the zeros are visible in the plain state_dict below.
        for name, module in model.named_modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                prune.l1_unstructured(module, name='weight', amount=prune_ratio)
                prune.remove(module, 'weight')
        # Rewind surviving weights to their original init (the key LTH step!).
        # NOTE(review): `'weight' in key` also matches BatchNorm weights,
        # which would be rewound/masked too — confirm this is intended.
        current_state = model.state_dict()
        for key in current_state:
            if 'weight' in key:
                mask = (current_state[key] != 0).float()
                current_state[key] = original_state[key] * mask
        model.load_state_dict(current_state)
        total = sum(p.nelement() for p in model.parameters())
        nonzero = sum((p != 0).sum().item() for p in model.parameters())
        print(f"Round {round_idx+1}: 剩余参数 {100.*nonzero/total:.1f}%")
    return model
4. 模型量化¶
4.1 量化基础¶
将 FP32 权重/激活值映射到低精度(INT8/INT4):
\[q = \text{round}\left(\frac{x}{s}\right) + z\]
其中 \(s\) 是缩放因子,\(z\) 是零点。
4.2 PyTorch 动态量化¶
Python
import torch.ao.quantization as quant
def dynamic_quantization(model):
    """Apply dynamic quantization: INT8 weights, on-the-fly activation quant.

    Only Linear and LSTM layers are converted; every other module stays in
    floating point. Returns a new quantized model (input is not mutated
    in place by the default call).
    """
    target_layers = {nn.Linear, nn.LSTM}
    return torch.ao.quantization.quantize_dynamic(
        model, target_layers, dtype=torch.qint8
    )
# 示例
class SimpleModel(nn.Module):
    """Minimal MLP (784 -> 256 -> 10) used to demo quantization.

    Submodule names (fc1 / relu / fc2) are referenced by the fusion step
    in static quantization, so they must not be renamed.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 256)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)
# Build the demo MLP and its dynamically-quantized counterpart.
model = SimpleModel()
quantized = dynamic_quantization(model)
# Compare the on-disk sizes of the two models.
import os, tempfile
def get_model_size(model, path=None):
    """Return the size of `model`'s serialized state_dict in megabytes.

    The state_dict is saved to `path`, measured, and the file is removed
    afterwards. When `path` is omitted, a secure temporary file is used
    instead of writing "temp.pt" into the current directory (which could
    collide across concurrent runs or fail on a read-only CWD).
    """
    if path is None:
        # mkstemp creates the file atomically with a unique name.
        fd, path = tempfile.mkstemp(suffix=".pt")
        os.close(fd)
    try:
        torch.save(model.state_dict(), path)
        return os.path.getsize(path) / 1024 / 1024
    finally:
        # Always clean up, even if save/getsize raises.
        if os.path.exists(path):
            os.remove(path)
# Dynamic quantization stores Linear weights in INT8, so the quantized
# checkpoint should be roughly a quarter of the FP32 size.
print(f"原始模型: {get_model_size(model):.2f} MB")
print(f"量化模型: {get_model_size(quantized):.2f} MB")
4.3 静态量化(PTQ)¶
Python
def static_quantization(model, calibration_loader, device='cpu'):
    """Post-Training static Quantization (PTQ).

    Quantizes both weights and activations to INT8; activation ranges are
    estimated from the calibration data. The five steps below must run in
    this exact order: fuse -> qconfig -> prepare -> calibrate -> convert.

    NOTE(review): the fuse list hard-codes submodule names 'fc1'/'relu',
    so this assumes a SimpleModel-like architecture — confirm for others.
    """
    model.eval()  # observers must see inference-mode behavior
    model.to(device)
    # 1. Fuse Linear+ReLU into a single module for better quantized kernels.
    model_fused = torch.ao.quantization.fuse_modules(model, [['fc1', 'relu']])
    # 2. Choose the backend config ('fbgemm' targets x86 server CPUs).
    model_fused.qconfig = torch.ao.quantization.get_default_qconfig('fbgemm')
    # 3. Insert observers that will record activation ranges.
    model_prepared = torch.ao.quantization.prepare(model_fused)
    # 4. Calibrate: one forward pass over a small dataset fills the observers.
    with torch.no_grad():
        for data, _ in calibration_loader:
            # Flatten to (batch, 784); assumes MNIST-shaped input — TODO confirm.
            data = data.view(-1, 784).to(device)
            model_prepared(data)
    # 5. Replace observed modules with real INT8 quantized modules.
    model_quantized = torch.ao.quantization.convert(model_prepared)
    return model_quantized
4.4 量化感知训练(QAT)¶
Python
def quantization_aware_training(model, train_loader, epochs=10, device='cpu'):
    """Quantization-Aware Training (QAT): simulate INT8 noise while training.

    Fake-quantize modules are inserted so the network learns weights that
    are robust to quantization; afterwards the model is converted to real
    INT8 kernels. Order matters: qconfig -> prepare_qat -> train -> convert.
    """
    model.train()
    model.to(device)
    # QAT config inserts fake-quantize ops for the fbgemm (x86) backend.
    model.qconfig = torch.ao.quantization.get_default_qat_qconfig('fbgemm')
    model_prepared = torch.ao.quantization.prepare_qat(model)
    optimizer = torch.optim.Adam(model_prepared.parameters(), lr=1e-4)
    criterion = nn.CrossEntropyLoss()
    for epoch in range(epochs):
        for data, target in train_loader:
            # Flatten to (batch, 784); assumes MNIST-like input — TODO confirm.
            data = data.view(-1, 784).to(device)
            target = target.to(device)
            output = model_prepared(data)
            loss = criterion(output, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    # Convert in eval mode: fake-quant observers become real INT8 kernels.
    model_quantized = torch.ao.quantization.convert(model_prepared.eval())
    return model_quantized
5. 低秩分解¶
5.1 SVD分解¶
将权重矩阵 \(W \in \mathbb{R}^{m \times n}\) 分解为 \(W \approx U_r \Sigma_r V_r^T\),其中 \(r \ll \min(m, n)\)。
Python
def svd_decompose_linear(layer, rank):
    """Compress a Linear layer by truncated SVD into two smaller layers.

    W (m x n) is approximated as W1 @ W2 with W1 = U_r * S_r (m x r) and
    W2 = V_r^T (r x n). Returns nn.Sequential(Linear(n->r, no bias),
    Linear(r->m, bias copied from the original layer if present)).
    """
    W = layer.weight.data
    U, S, Vt = torch.linalg.svd(W, full_matrices=False)
    # Keep only the top-`rank` singular triplets: W ≈ W1 @ W2.
    W1 = U[:, :rank] * S[:rank].unsqueeze(0)   # (m, r)
    W2 = Vt[:rank, :]                          # (r, n)

    has_bias = layer.bias is not None
    first = nn.Linear(W2.size(1), rank, bias=False)
    second = nn.Linear(rank, W1.size(0), bias=has_bias)
    first.weight.data = W2
    second.weight.data = W1
    if has_bias:
        second.bias.data = layer.bias.data
    compressed = nn.Sequential(first, second)

    # Report the compression ratio.
    original_params = W.numel()
    compressed_params = W1.numel() + W2.numel()
    print(f"SVD: {original_params} → {compressed_params} "
          f"({100.*compressed_params/original_params:.1f}%)")
    return compressed
6. 高效架构设计¶
6.1 深度可分离卷积¶
Python
class DepthwiseSeparableConv(nn.Module):
    """Depthwise-separable convolution (the MobileNet building block).

    A depthwise conv (groups=in_channels) filters each channel on its own,
    then a 1x1 pointwise conv mixes channels; each stage is followed by
    BatchNorm + ReLU6.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()
        self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size,
                                   stride, padding, groups=in_channels)
        self.pointwise = nn.Conv2d(in_channels, out_channels, 1)
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU6(inplace=True)

    def forward(self, x):
        out = self.relu(self.bn1(self.depthwise(x)))
        out = self.relu(self.bn2(self.pointwise(out)))
        return out

    def param_ratio(self, in_channels, out_channels, kernel_size=3):
        """Parameter count of a separable conv relative to a standard conv."""
        dense = in_channels * out_channels * kernel_size * kernel_size
        separable = (in_channels * kernel_size * kernel_size
                     + in_channels * out_channels)
        return separable / dense
# Parameter comparison: separable conv vs a standard 3x3 conv.
dsc = DepthwiseSeparableConv(64, 128)
ratio = dsc.param_ratio(64, 128)
print(f"深度可分离卷积参数量是标准卷积的 {ratio:.1%}")
7. 推理优化实战¶
7.1 ONNX导出¶
Python
def export_to_onnx(model, input_shape, path="model.onnx"):
"""导出为 ONNX 格式"""
model.eval()
dummy_input = torch.randn(*input_shape)
torch.onnx.export(
model, dummy_input, path,
input_names=['input'],
output_names=['output'],
dynamic_axes={'input': {0: 'batch'}, 'output': {0: 'batch'}},
opset_version=13
)
print(f"导出到 {path}")
7.2 TorchScript¶
Python
def optimize_with_torchscript(model, input_shape):
    """Trace `model` with TorchScript and benchmark it against eager mode.

    Returns the traced module. Scripting (torch.jit.script) is the
    alternative for models with data-dependent control flow.
    """
    model.eval()

    # Approach 1: tracing records the ops executed on an example input.
    example = torch.randn(*input_shape)
    traced = torch.jit.trace(model, example)
    # Approach 2: scripting (handles Python control flow)
    # scripted = torch.jit.script(model)

    # Benchmark eager vs traced execution.
    import time
    with torch.no_grad():
        # Warm up caches/JIT before timing.
        for _ in range(10):
            model(example)
            traced(example)

        def _bench(fn, iters=100):
            # Wall-clock time for `iters` forward passes.
            t0 = time.time()
            for _ in range(iters):
                fn(example)
            return time.time() - t0

        original_time = _bench(model)
        traced_time = _bench(traced)

    # 100 iters -> *10 converts seconds/100-iters into ms per iteration.
    print(f"原始模型: {original_time*10:.1f} ms/iter")
    print(f"TorchScript: {traced_time*10:.1f} ms/iter")
    print(f"加速比: {original_time/traced_time:.2f}x")
    return traced
7.3 综合压缩流水线¶
Python
def full_compression_pipeline(teacher, student_fn, train_loader, cal_loader, device='cuda'):
    """End-to-end compression: distillation -> pruning -> quantization -> export.

    `student_fn` is a zero-argument factory that builds a fresh student
    model. Returns the final dynamically-quantized student.
    """
    print("=" * 50)

    print("Step 1: 知识蒸馏")
    fresh_student = student_fn()
    student = knowledge_distillation(teacher, fresh_student, train_loader,
                                     epochs=30, device=device)

    print("\nStep 2: 结构化剪枝")
    student = structured_pruning_conv(student, amount=0.2)
    # Fine-tune to recover accuracy (left as an exercise):
    # student = fine_tune(student, train_loader, epochs=10)

    print("\nStep 3: 量化")
    # Dynamic quantization runs on CPU models only.
    student_quantized = dynamic_quantization(student.cpu())

    print("\nStep 4: 导出 ONNX")
    # export_to_onnx(student, (1, 3, 224, 224))

    print("\n压缩完成!")
    return student_quantized
8. 练习与自我检查¶
练习题¶
- 知识蒸馏:用 ResNet-50 作为 Teacher,ResNet-18 作为 Student,在 CIFAR-10 上做蒸馏,对比直接训练 ResNet-18。
- 剪枝实验:对训练好的 ResNet 进行不同比例的非结构化剪枝,画出稀疏度 vs 精度曲线。
- 量化对比:在同一模型上分别应用动态量化、静态量化、QAT,对比精度和推理速度。
- SVD压缩:对 BERT 的全连接层做 SVD 低秩分解,分析不同秩对性能的影响。
- 综合实战:设计一个完整的压缩流水线(蒸馏 → 剪枝 → 量化),在移动端部署模型。
自我检查清单¶
- 理解知识蒸馏中温度参数 T 的作用
- 能区分结构化剪枝和非结构化剪枝
- 了解 PTQ 和 QAT 的区别
- 能使用 PyTorch 的量化工具
- 理解 SVD 低秩分解的原理
- 能设计完整的模型压缩方案
下一篇: 02-自监督学习



