01 - 模型压缩技术¶
⚠️ 时效性说明:本章涉及前沿模型/价格/榜单等信息,可能随版本快速变化;请以论文原文、官方发布页和 API 文档为准。
让大模型"瘦身"的同时保持性能
📖 章节概述¶
本章将深入探讨模型压缩的核心技术,包括剪枝、量化、蒸馏和知识蒸馏等方法。这些技术可以显著减少模型的参数量和计算量,使其能够在有限的资源下运行。
🎯 学习目标¶
完成本章后,你将能够:
- 理解模型压缩的基本原理
- 掌握剪枝、量化、蒸馏等核心技术
- 能够应用这些技术压缩实际模型
- 评估压缩后模型的性能损失
1. 模型压缩概述¶
1.1 为什么需要模型压缩?¶
资源限制: - 显存限制:消费级显卡通常只有8-24GB显存 - 计算能力:推理速度和吞吐量受限 - 能耗限制:移动设备和边缘设备的功耗限制
成本考虑: - 云端推理成本:按GPU小时计费 - 硬件成本:高端GPU价格昂贵 - 运维成本:服务器维护和电力成本
1.2 模型压缩技术分类¶
模型压缩技术
├── 结构化压缩
│ ├── 剪枝 (Pruning)
│ ├── 知识蒸馏 (Knowledge Distillation)
│ └── 架构设计 (Architecture Design)
├── 量化 (Quantization)
│ ├── 训练后量化 (PTQ)
│ ├── 量化感知训练 (QAT)
│ └── 混合精度 (Mixed Precision)
└── 其他技术
├── 权重共享 (Weight Sharing)
├── 低秩分解 (Low-Rank Factorization)
└── 神经架构搜索 (NAS)
2. 剪枝技术¶
2.1 剪枝原理¶
剪枝通过移除模型中不重要的参数来减少模型大小和计算量。
核心思想: - 识别并移除冗余或贡献小的参数 - 保持模型的核心功能 - 减少存储和计算开销
2.2 剪枝类型¶
非结构化剪枝¶
import torch
import torch.nn.utils.prune as prune
# 非结构化剪枝示例
def unstructured_pruning_example():
# 创建一个简单的线性层
layer = torch.nn.Linear(10, 5)
# 随机剪枝30%的权重
prune.l1_unstructured(layer, name='weight', amount=0.3)
print("剪枝后的权重稀疏度:",
(layer.weight == 0).float().mean().item()) # 使用.float()确保布尔值正确计算均值
# 移除剪枝掩码,永久应用剪枝
prune.remove(layer, 'weight')
return layer
# 运行示例
layer = unstructured_pruning_example()
特点: - ✅ 灵活性高,可以剪枝任意参数 - ✅ 压缩效果好 - ❌ 需要特殊硬件支持才能加速
结构化剪枝¶
import torch
import torch.nn as nn
# 结构化剪枝示例(剪枝整个神经元)
def structured_pruning_example():
model = nn.Sequential(
nn.Linear(10, 20),
nn.ReLU(),
nn.Linear(20, 5)
)
# 计算每个神经元的平均权重
layer1_weights = model[0].weight.abs().mean(dim=1)
# 选择剪枝50%的神经元
num_prune = int(0.5 * layer1_weights.shape[0])
prune_indices = torch.argsort(layer1_weights)[:num_prune]
print(f"剪枝的神经元索引:{prune_indices.tolist()}")
# 创建新的层,移除被剪枝的神经元
new_layer1 = nn.Linear(10, 20 - num_prune)
keep_indices = torch.tensor([i for i in range(20)
if i not in prune_indices])
new_layer1.weight.data = model[0].weight[keep_indices]
new_layer1.bias.data = model[0].bias[keep_indices]
# 更新下一层的输入维度
new_layer2 = nn.Linear(20 - num_prune, 5)
new_layer2.weight.data = model[2].weight[:, keep_indices]
new_layer2.bias.data = model[2].bias
return nn.Sequential(new_layer1, nn.ReLU(), new_layer2)
# 运行示例
pruned_model = structured_pruning_example()
特点: - ✅ 可以在标准硬件上加速 - ✅ 减少实际计算量 - ❌ 剪枝粒度较粗
2.3 剪枝策略¶
基于幅度的剪枝¶
import torch
def magnitude_based_pruning(model, sparsity=0.3):
"""
基于权重幅度的剪枝
Args:
model: 要剪枝的模型
sparsity: 目标稀疏度 (0-1)
"""
for name, param in model.named_parameters():
if 'weight' in name:
# 计算权重幅度的阈值
weight_abs = param.data.abs()
threshold = torch.quantile(weight_abs, sparsity)
# 创建掩码
mask = weight_abs > threshold
# 应用剪枝
param.data = param.data * mask.float()
print(f"{name}: 剪枝率 {100*sparsity:.1f}%")
return model
基于梯度的剪枝¶
import torch
def gradient_based_pruning(model, dataloader, criterion,
sparsity=0.3):
"""
基于梯度重要性的剪枝
Args:
model: 要剪枝的模型
dataloader: 数据加载器
criterion: 损失函数
sparsity: 目标稀疏度
"""
# 计算每个参数的梯度重要性
importance_scores = {}
for batch in dataloader:
inputs, targets = batch
outputs = model(inputs)
loss = criterion(outputs, targets)
loss.backward() # 反向传播计算梯度
for name, param in model.named_parameters():
if 'weight' in name:
if name not in importance_scores:
importance_scores[name] = torch.zeros_like(param)
importance_scores[name] += param.grad.abs()
model.zero_grad() # 清零梯度
# 基于重要性分数进行剪枝
for name, param in model.named_parameters():
if 'weight' in name:
importance = importance_scores[name]
threshold = torch.quantile(importance, sparsity)
mask = importance > threshold
param.data = param.data * mask.float()
return model
2.4 剪枝实践¶
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
# 创建示例模型和数据
class SimpleModel(nn.Module): # 继承nn.Module定义网络层
def __init__(self):
super().__init__() # super()调用父类方法
self.fc1 = nn.Linear(784, 256)
self.fc2 = nn.Linear(256, 128)
self.fc3 = nn.Linear(128, 10)
self.relu = nn.ReLU()
def forward(self, x):
x = self.relu(self.fc1(x))
x = self.relu(self.fc2(x))
x = self.fc3(x)
return x
# 创建虚拟数据
X = torch.randn(1000, 784)
y = torch.randint(0, 10, (1000,))
dataset = TensorDataset(X, y)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True) # DataLoader批量加载数据
# 初始化模型
model = SimpleModel()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# 训练函数
def train_model(model, dataloader, epochs=5):
model.train() # train()训练模式
for epoch in range(epochs):
total_loss = 0
for batch_x, batch_y in dataloader:
optimizer.zero_grad()
outputs = model(batch_x)
loss = criterion(outputs, batch_y)
loss.backward()
optimizer.step() # 更新参数
total_loss += loss.item()
print(f"Epoch {epoch+1}, Loss: {total_loss/len(dataloader):.4f}")
# 训练原始模型
print("训练原始模型...")
train_model(model, dataloader)
# 应用剪枝
print("\n应用剪枝...")
pruned_model = magnitude_based_pruning(model, sparsity=0.4)
# 微调剪枝后的模型
print("\n微调剪枝后的模型...")
train_model(pruned_model, dataloader, epochs=3)
# 评估模型
def evaluate_model(model, dataloader):
model.eval()
correct = 0
total = 0
with torch.no_grad(): # 禁用梯度计算,节省内存
for batch_x, batch_y in dataloader:
outputs = model(batch_x)
_, predicted = torch.max(outputs.data, 1)
total += batch_y.size(0)
correct += (predicted == batch_y).sum().item()
return 100 * correct / total
print(f"\n原始模型准确率: {evaluate_model(model, dataloader):.2f}%")
print(f"剪枝后模型准确率: {evaluate_model(pruned_model, dataloader):.2f}%")
3. 量化技术¶
3.1 量化原理¶
量化通过减少参数的表示精度来减少模型大小和内存占用。
核心概念: - FP32: 32位浮点数(标准精度) - FP16: 16位浮点数(半精度) - INT8: 8位整数 - INT4: 4位整数
3.2 训练后量化 (PTQ)¶
import torch
import torch.quantization
# 训练后量化示例
def post_training_quantization(model, dataloader):
"""
训练后量化
Args:
model: 要量化的模型
dataloader: 用于校准的数据加载器
"""
# 设置模型为评估模式
model.eval()
# 配置量化
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
# 准备量化
model_prepared = torch.quantization.prepare(model)
# 使用校准数据进行校准
with torch.no_grad():
for batch_x, _ in dataloader:
model_prepared(batch_x)
# 转换为量化模型
quantized_model = torch.quantization.convert(model_prepared)
return quantized_model
# 使用示例
# quantized_model = post_training_quantization(model, dataloader)
3.3 量化感知训练 (QAT)¶
import torch
import torch.nn as nn
import torch.quantization
# 量化感知训练示例
def quantization_aware_training(model, dataloader, epochs=5):
"""
量化感知训练
Args:
model: 要训练的模型
dataloader: 数据加载器
epochs: 训练轮数
"""
# 配置量化
model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
# 准备量化感知训练
# 注意:inplace参数在PyTorch 2.0+已弃用,建议使用返回值
model_prepared = torch.quantization.prepare_qat(model)
# 训练
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model_prepared.parameters(), lr=0.001)
model_prepared.train()
for epoch in range(epochs):
for batch_x, batch_y in dataloader:
optimizer.zero_grad()
outputs = model_prepared(batch_x)
loss = criterion(outputs, batch_y)
loss.backward()
optimizer.step()
print(f"Epoch {epoch+1} completed")
# 转换为量化模型
quantized_model = torch.quantization.convert(model_prepared)
return quantized_model
3.4 INT4 量化¶
3.4.1 INT4 量化原理¶
INT4量化将模型权重从FP32/FP16压缩到4位整数表示,可以显著减少内存占用和提升推理速度。
核心优势: - 内存占用仅为FP32的12.5% - 推理速度提升8-16倍 - 可以在有限显存上运行大模型(如7B模型可在8GB显存上运行)
3.4.2 NF4量化类型详解¶
NF4(NormalFloat4)是bitsandbytes库中专门为神经网络权重设计的量化类型。
NF4的定义和原理: - NF4是一种基于正态分布优化的4位浮点数表示格式 - 它将权重映射到16个离散值(2^4=16),这些值经过特殊设计以匹配神经网络权重的统计分布 - NF4使用非线性量化间隔,能够更好地表示正态分布的权重值
NF4相对于FP4的优缺点:
| 特性 | NF4 | FP4 |
|---|---|---|
| 精度 | 更高(针对正态分布优化) | 较低(均匀量化) |
| 适用场景 | 神经网络权重(推荐) | 通用浮点数 |
| 数值稳定性 | 更好 | 一般 |
| 性能损失 | 1-3% | 3-5% |
| 内存占用 | 12.5% | 12.5% |
NF4的优势: 1. 针对神经网络优化:NF4的量化间隔是根据神经网络权重的正态分布特性设计的,能够更好地保留权重信息 2. 更低的精度损失:相比FP4,NF4通常能减少1-2%的精度损失 3. 更好的数值稳定性:在极端值情况下表现更稳定 4. 广泛兼容性:支持大多数Transformers模型
NF4的适用场景: - ✅ 推荐场景: - 大语言模型(LLM)的推理 - Transformer架构模型 - 权重呈正态分布的模型 - 需要在有限显存上运行大模型的场景
- ⚠️ 谨慎使用:
- 对精度要求极高的科学计算
- 需要精确数值控制的金融模型
- 权重分布非常特殊的模型
其他INT4量化类型:
- FP4(Float4):
- 传统的4位浮点数表示
- 使用均匀量化间隔
- 适用于通用浮点数场景
-
精度损失相对较大
-
INT4(整数):
- 纯整数表示
- 需要额外的缩放因子和零点
- 计算速度可能更快
- 精度损失较大
3.4.3 INT4量化精度损失与缓解策略¶
精度损失的典型范围:
| 模型类型 | NF4精度损失 | FP4精度损失 | 适用性 |
|---|---|---|---|
| 大语言模型(7B+) | 1-3% | 3-5% | ✅ 优秀 |
| 中等模型(1B-7B) | 2-4% | 4-6% | ✅ 良好 |
| 小型模型(<1B) | 3-5% | 5-8% | ⚠️ 需测试 |
| 特殊任务模型 | 4-8% | 6-10% | ❌ 不推荐 |
缓解精度损失的策略:
-
使用NF4量化类型:
-
启用双重量化:
-
使用合适的计算数据类型:
-
量化后微调(QLoRA):
-
混合精度策略:
- 关键层使用FP16/BF16
- 非关键层使用INT4
- 输入/输出层保持FP32
3.4.4 INT4量化最佳实践¶
✅ 推荐做法:
- 优先使用NF4:
- NF4是专门为神经网络优化的量化类型
- 在大多数情况下性能优于FP4
-
适合大多数Transformers模型
-
启用双重量化:
- 可以进一步减少内存占用
- 对精度影响很小
-
推荐配置:
bnb_4bit_use_double_quant=True -
选择合适的计算数据类型:
- FP16:速度更快,适合推理
- BF16:数值范围更大,适合训练
-
推荐配置:
bnb_4bit_compute_dtype=torch.float16 -
充分测试验证:
- 在多个任务上测试量化后的模型
- 对比量化前后的性能指标
-
关注实际应用场景的效果
-
考虑量化后微调:
- 使用QLoRA等技术进行微调
- 可以恢复部分精度损失
-
特别适合需要高精度的场景
-
渐进式应用:
- 先测试INT8
- 再尝试INT4
- 根据结果决定是否采用
❌ 避免做法:
- 不要盲目使用INT4:
- 不是所有模型都适合INT4量化
- 小型模型可能精度损失较大
-
需要充分测试验证
-
不要忽略计算数据类型:
bnb_4bit_compute_dtype的选择很重要- 使用错误的数据类型可能导致溢出
-
推荐使用FP16或BF16
-
不要跳过测试:
- INT4量化可能导致不可预测的行为
- 必须在目标场景中充分测试
-
记录量化前后的性能对比
-
不要在所有场景使用相同配置:
- 不同模型可能需要不同的量化策略
- 根据模型特点调整参数
- 考虑使用混合精度
3.4.5 INT4量化代码示例¶
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
# INT4量化示例
def int4_quantization(model_name="meta-llama/Llama-2-7b-hf"):
"""
INT4量化加载模型
Args:
model_name: 模型名称或路径
"""
# 加载分词器
tokenizer = AutoTokenizer.from_pretrained(model_name)
# 配置INT4量化
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.float16, # 计算数据类型:FP16
bnb_4bit_use_double_quant=True, # 启用双重量化
bnb_4bit_quant_type="nf4" # 使用NF4量化类型(推荐)
)
# 加载量化模型
model = AutoModelForCausalLM.from_pretrained(
model_name,
quantization_config=quantization_config,
device_map="auto"
)
print(f"模型量化完成!")
print(f"模型大小: {model.get_memory_footprint() / 1e9:.2f} GB")
return model, tokenizer
# 使用示例
# model, tokenizer = int4_quantization()
3.4.6 INT4量化配置详解¶
from transformers import BitsAndBytesConfig
def create_int4_config():
"""
创建INT4量化配置
返回:
BitsAndBytesConfig: INT4量化配置对象
"""
quantization_config = BitsAndBytesConfig(
# ========== 基础配置 ==========
load_in_4bit=True, # 启用4位量化
# ========== 量化类型选择 ==========
# "nf4": NormalFloat4(推荐,针对神经网络优化)
# "fp4": Float4(传统浮点数表示)
bnb_4bit_quant_type="nf4",
# ========== 计算数据类型 ==========
# torch.float16: 16位浮点数(速度更快,适合推理)
# torch.bfloat16: 16位脑浮点数(数值范围更大,适合训练)
bnb_4bit_compute_dtype=torch.float16,
# ========== 双重量化 ==========
# True: 对量化参数进行二次量化,进一步减少内存
# False: 不使用双重量化
bnb_4bit_use_double_quant=True,
# ========== 其他高级配置 ==========
# llm_int8_threshold: INT8量化的异常值阈值(默认6.0)
# llm_int8_skip_modules: 跳过量化的模块列表
# llm_int8_enable_fp32_cpu_offload: 启用FP32 CPU卸载
)
return quantization_config
# 使用示例
# config = create_int4_config()
# model = AutoModelForCausalLM.from_pretrained(
# model_name,
# quantization_config=config,
# device_map="auto"
# )
3.4.7 不同INT4量化类型对比¶
from transformers import BitsAndBytesConfig
import torch
def compare_quantization_types():
"""
对比不同INT4量化类型的配置
"""
configs = {
"NF4(推荐)": BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4"
),
"FP4": BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="fp4"
),
"NF4 + BF16": BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4"
),
"NF4 + 无双重量化": BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=False,
bnb_4bit_quant_type="nf4"
)
}
for name, config in configs.items():
print(f"{name}:")
print(f" - 量化类型: {config.bnb_4bit_quant_type}")
print(f" - 计算类型: {config.bnb_4bit_compute_dtype}")
print(f" - 双重量化: {config.bnb_4bit_use_double_quant}")
print()
return configs
# 使用示例
# configs = compare_quantization_types()
3.4.8 AWQ量化技术¶
AWQ(Activation-aware Weight Quantization)是一种先进的训练后量化技术,通过考虑激活值分布来优化权重量化。
AWQ核心原理¶
AWQ的核心思想是:并非所有权重对模型输出同等重要。通过分析激活值的分布,识别出对输出影响最大的权重通道,对这些关键通道保持较高精度,而对其他通道使用较低精度。
关键洞察: 1. 激活感知:通过分析激活值的分布来确定权重的重要性 2. 通道级量化:对不同的通道使用不同的量化策略 3. 保护关键权重:对激活值较大的通道对应的权重保持更高精度
AWQ vs GPTQ 对比¶
| 特性 | AWQ | GPTQ |
|---|---|---|
| 量化方法 | 激活感知权重量化 | 基于二阶信息的量化 |
| 校准数据需求 | 较少(128-512样本) | 中等(512-1024样本) |
| 量化速度 | 较快 | 较慢 |
| 精度保持 | 优秀(特别是4bit) | 良好 |
| 显存需求 | 较低 | 较高 |
| 适用场景 | 通用LLM量化 | 高精度需求场景 |
| 推理兼容性 | 广泛支持 | 需要特定内核 |
AWQ优势: - ✅ 更好的精度-压缩权衡 - ✅ 校准数据需求少 - ✅ 量化速度快 - ✅ 对异常值更鲁棒 - ✅ 广泛的硬件支持
GPTQ优势: - ✅ 理论基础更扎实 - ✅ 在某些模型上精度更高 - ✅ 社区成熟度高
AWQ量化实现¶
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from awq import AutoAWQForCausalLM
def awq_quantization(model_path, output_path, calib_data="pileval"):
"""
AWQ量化示例
Args:
model_path: 原始模型路径
output_path: 量化后模型保存路径
calib_data: 校准数据集名称
"""
# 加载模型和分词器
model = AutoAWQForCausalLM.from_pretrained(
model_path,
torch_dtype=torch.float16,
device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_path)
# AWQ量化配置
quant_config = {
"zero_point": True, # 使用零点量化
"q_group_size": 128, # 量化组大小
"w_bit": 4, # 权重位数
"version": "GEMM" # 量化版本(GEMM/GEMV)
}
# 执行量化
model.quantize(
tokenizer,
quant_config=quant_config,
calib_data=calib_data, # 校准数据
n_samples=128 # 校准样本数
)
# 保存量化模型
model.save_quantized(output_path)
tokenizer.save_pretrained(output_path)
print(f"AWQ量化完成!模型保存至: {output_path}")
return model, tokenizer
# 使用示例
# model, tokenizer = awq_quantization(
# "meta-llama/Llama-2-7b-hf",
# "./llama2-7b-awq"
# )
AWQ推理使用¶
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer
def load_awq_model(model_path):
"""
加载AWQ量化模型进行推理
Args:
model_path: AWQ量化模型路径
"""
# 加载量化模型
model = AutoAWQForCausalLM.from_quantized(
model_path,
torch_dtype=torch.float16,
device_map="auto",
fuse_layers=True # 启用层融合加速
)
tokenizer = AutoTokenizer.from_pretrained(model_path)
# 推理
prompt = "请介绍一下人工智能的发展历程"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
outputs = model.generate(
**inputs,
max_length=200,
do_sample=True,
temperature=0.7
)
result = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(result)
return model, tokenizer
# 使用示例
# model, tokenizer = load_awq_model("./llama2-7b-awq")
AWQ与vLLM集成¶
from vllm import LLM, SamplingParams
def awq_vllm_inference(model_path):
"""
使用vLLM加载AWQ量化模型进行高效推理
Args:
model_path: AWQ量化模型路径
"""
# vLLM加载AWQ模型
llm = LLM(
model=model_path,
quantization="awq",
tensor_parallel_size=1, # GPU数量
gpu_memory_utilization=0.9
)
# 采样参数
sampling_params = SamplingParams(
temperature=0.7,
top_p=0.9,
max_tokens=200
)
# 批量推理
prompts = [
"什么是机器学习?",
"解释一下深度学习的概念",
"自然语言处理有哪些应用?"
]
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
print(f"提示: {output.prompt}")
print(f"生成: {output.outputs[0].text}\n")
return llm
# 使用示例
# llm = awq_vllm_inference("./llama2-7b-awq")
AWQ最佳实践¶
推荐配置:
# 推荐的AWQ量化配置
recommended_config = {
"w_bit": 4, # 4位量化(推荐)
"q_group_size": 128, # 组大小128(平衡精度和速度)
"zero_point": True, # 启用零点
"version": "GEMM" # GEMM版本(更好的兼容性)
}
校准数据选择: - 使用与目标任务相似的数据 - 128-512个样本通常足够 - 数据多样性比数量更重要
性能优化建议: 1. 启用层融合(fuse_layers=True) 2. 使用Flash Attention加速 3. 配合vLLM进行批量推理 4. 选择合适的量化组大小
4. 知识蒸馏¶
4.1 蒸馏原理¶
知识蒸馏通过让一个"学生"模型学习"教师"模型的知识,在保持性能的同时减少模型大小。
核心思想: - 教师模型:大型、性能好的模型 - 学生模型:小型、高效的模型 - 软标签:教师模型的输出概率分布
4.2 蒸馏损失函数¶
import torch
import torch.nn as nn
import torch.nn.functional as F
class DistillationLoss(nn.Module):
"""
知识蒸馏损失函数
"""
def __init__(self, temperature=5.0, alpha=0.5):
super().__init__()
self.temperature = temperature
self.alpha = alpha
self.kl_div = nn.KLDivLoss(reduction='batchmean')
self.ce_loss = nn.CrossEntropyLoss()
def forward(self, student_outputs, teacher_outputs, targets):
"""
计算蒸馏损失
Args:
student_outputs: 学生模型的输出
teacher_outputs: 教师模型的输出
targets: 真实标签
"""
# 软标签损失(蒸馏损失)
soft_loss = self.kl_div(
F.log_softmax(student_outputs / self.temperature, dim=1), # F.xxx PyTorch函数式API
F.softmax(teacher_outputs / self.temperature, dim=1)
) * (self.temperature ** 2)
# 硬标签损失(标准交叉熵)
hard_loss = self.ce_loss(student_outputs, targets)
# 组合损失
loss = self.alpha * soft_loss + (1 - self.alpha) * hard_loss
return loss
# 使用示例
# criterion = DistillationLoss(temperature=5.0, alpha=0.5)
# loss = criterion(student_outputs, teacher_outputs, targets)
4.3 蒸馏训练¶
import torch
import torch.nn as nn
import torch.optim as optim
# 定义教师和学生模型
class TeacherModel(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(784, 512)
self.fc2 = nn.Linear(512, 256)
self.fc3 = nn.Linear(256, 128)
self.fc4 = nn.Linear(128, 10)
self.relu = nn.ReLU()
def forward(self, x):
x = self.relu(self.fc1(x))
x = self.relu(self.fc2(x))
x = self.relu(self.fc3(x))
x = self.fc4(x)
return x
class StudentModel(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(784, 128)
self.fc2 = nn.Linear(128, 10)
self.relu = nn.ReLU()
def forward(self, x):
x = self.relu(self.fc1(x))
x = self.fc2(x)
return x
# 蒸馏训练函数
def distillation_train(teacher, student, dataloader, epochs=10):
"""
知识蒸馏训练
Args:
teacher: 教师模型(冻结参数)
student: 学生模型
dataloader: 数据加载器
epochs: 训练轮数
"""
# 冻结教师模型参数
teacher.eval()
for param in teacher.parameters():
param.requires_grad = False
# 初始化学生模型和优化器
student.train()
optimizer = optim.Adam(student.parameters(), lr=0.001)
criterion = DistillationLoss(temperature=5.0, alpha=0.5)
# 训练循环
for epoch in range(epochs):
total_loss = 0
for batch_x, batch_y in dataloader:
# 教师模型前向传播
with torch.no_grad():
teacher_outputs = teacher(batch_x)
# 学生模型前向传播
student_outputs = student(batch_x)
# 计算损失
loss = criterion(student_outputs, teacher_outputs, batch_y)
# 反向传播和优化
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
print(f"Epoch {epoch+1}, Loss: {total_loss/len(dataloader):.4f}")
return student
# 使用示例
# teacher = TeacherModel()
# student = StudentModel()
# trained_student = distillation_train(teacher, student, dataloader)
5. 综合实践¶
5.1 完整的模型压缩流程¶
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
class ModelCompressor:
"""
模型压缩工具类
"""
def __init__(self, model_name):
self.model_name = model_name
self.tokenizer = None
self.model = None
def load_model(self, quantization="int4"):
"""
加载模型
Args:
quantization: 量化类型 ("int4", "int8", "fp16", "fp32")
"""
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
if quantization == "int4":
# INT4量化配置(推荐使用NF4)
# NF4是专门为神经网络优化的量化类型,相比FP4有更好的精度表现
quantization_config = BitsAndBytesConfig(
load_in_4bit=True, # 启用4位量化
bnb_4bit_compute_dtype=torch.float16, # 计算数据类型:FP16(速度更快)
bnb_4bit_use_double_quant=True, # 启用双重量化,进一步减少内存
bnb_4bit_quant_type="nf4" # 使用NF4量化类型(推荐,针对神经网络优化)
)
self.model = AutoModelForCausalLM.from_pretrained(
self.model_name,
quantization_config=quantization_config,
device_map="auto"
)
elif quantization == "int8":
quantization_config = BitsAndBytesConfig(
load_in_8bit=True
)
self.model = AutoModelForCausalLM.from_pretrained(
self.model_name,
quantization_config=quantization_config,
device_map="auto"
)
elif quantization == "fp16":
self.model = AutoModelForCausalLM.from_pretrained(
self.model_name,
torch_dtype=torch.float16,
device_map="auto"
)
else:
self.model = AutoModelForCausalLM.from_pretrained(
self.model_name,
device_map="auto"
)
print(f"模型加载完成!量化类型: {quantization}")
print(f"模型大小: {self.model.get_memory_footprint() / 1e9:.2f} GB")
def prune_model(self, sparsity=0.3):
"""
剪枝模型
Args:
sparsity: 剪枝稀疏度
"""
for name, param in self.model.named_parameters():
if 'weight' in name and len(param.shape) > 1:
weight_abs = param.data.abs()
threshold = torch.quantile(weight_abs, sparsity)
mask = weight_abs > threshold
param.data = param.data * mask.float()
print(f"模型剪枝完成!剪枝率: {100*sparsity:.1f}%")
def save_model(self, output_path):
"""
保存模型
Args:
output_path: 保存路径
"""
self.model.save_pretrained(output_path)
self.tokenizer.save_pretrained(output_path)
print(f"模型保存到: {output_path}")
# 使用示例
compressor = ModelCompressor("meta-llama/Llama-2-7b-hf")
compressor.load_model(quantization="int4")
compressor.prune_model(sparsity=0.2)
compressor.save_model("./compressed_model")
6. 练习题¶
基础练习¶
-
实现简单的非结构化剪枝
-
实现INT8量化
进阶练习¶
-
实现完整的知识蒸馏训练流程
-
实现自适应剪枝策略
项目练习¶
- 创建一个模型压缩工具
- 支持多种压缩技术
- 提供压缩前后性能对比
- 可视化压缩效果
7. 最佳实践¶
✅ 推荐做法¶
- 渐进式压缩
- 先进行轻度压缩,评估效果
- 逐步增加压缩强度
-
每次压缩后进行微调
-
性能评估
- 在多个数据集上评估
- 记录压缩前后的性能指标
-
关注实际应用场景
-
保存检查点
- 保存压缩前的原始模型
- 保存中间压缩结果
- 便于回滚和对比
❌ 避免做法¶
- 过度压缩
- 不要一次性压缩太多
- 避免性能大幅下降
-
保持模型可用性
-
忽略微调
- 压缩后一定要微调
- 避免直接使用压缩后的模型
-
给模型恢复时间
-
单一指标
- 不要只看压缩率
- 综合考虑性能和效率
- 关注实际应用效果
8. 总结¶
本章介绍了模型压缩的核心技术:
- 剪枝: 移除不重要的参数
- 量化: 降低参数精度
- 蒸馏: 让小模型学习大模型的知识
这些技术可以单独使用,也可以组合使用,以达到最佳的压缩效果。
9. 下一步¶
继续学习02-低精度推理,深入了解低精度推理的技术细节。
