05 - 知识蒸馏(全面版)¶
⚠️ 时效性说明:本章涉及前沿模型/价格/榜单等信息,可能随版本快速变化;请以论文原文、官方发布页和 API 文档为准。
学习目标:深入理解知识蒸馏的原理、方法与实践,掌握如何将大模型知识迁移到小模型。
目录¶
知识蒸馏概述¶
1.1 什么是知识蒸馏¶
Text Only
知识蒸馏(Knowledge Distillation)
核心思想:将大型教师模型(Teacher)的知识迁移到小型学生模型(Student)
为什么需要蒸馏?
├── 大模型性能好但推理慢、部署难
├── 小模型速度快但性能差
└── 蒸馏:让小模型学习大模型的"暗知识"
暗知识(Dark Knowledge):
├── 不仅学习正确标签
├── 还学习类别间的相似关系
└── 例如:狗→狼的概率比狗→汽车高
类比:
├── 教师:经验丰富的专家
├── 学生:初学者
└── 蒸馏:学生学习专家的思考方式,不只是答案
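下面用一个极简示例直观感受"暗知识":软标签除了指出正确类别,还携带类别间的相似关系(示例中的 logits 数值为虚构,仅作演示)。

Python
import torch
import torch.nn.functional as F

# 假设教师模型对一张"狗"的图片输出如下logits(类别顺序:狗、狼、汽车;数值为虚构示例)
teacher_logits = torch.tensor([6.0, 3.5, -2.0])

hard_label = torch.tensor([1.0, 0.0, 0.0])            # 硬标签:只告诉你"是狗"
soft_label = F.softmax(teacher_logits / 4.0, dim=0)   # 软标签(温度T=4):还保留了"狼比汽车更像狗"的信息

print("硬标签:", hard_label.tolist())
print("软标签:", [round(p, 3) for p in soft_label.tolist()])
# 软标签中"狼"的概率明显高于"汽车",这部分类间相似信息就是暗知识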
1.2 蒸馏的基本框架¶
Text Only
┌─────────────────────────────────────────────────────────────────┐
│ 知识蒸馏框架 │
├─────────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────────────────┐ ┌─────────────────────┐ │
│ │ 教师模型 (T) │ │ 学生模型 (S) │ │
│ │ (大、复杂) │ │ (小、简单) │ │
│ │ │ │ │ │
│ │ Input ──▶ T ──▶ │ │ Input ──▶ S ──▶ │ │
│ │ Soft logits │ │ Soft logits │ │
│ │ │ │ │ │ │ │
│ │ ▼ │ │ ▼ │ │
│ │ Softmax │ │ Softmax │ │
│ │ (高温τ) │ │ (高温τ) │ │
│ │ │ │ │ │ │ │
│ │ ▼ │ │ ▼ │ │
│ │ Soft Targets │ │ Soft Predictions │ │
│ │ │ │ │ │ │ │
│ │ └───────────┼────────┘ │ │ │
│ │ │ │ │ │
│ │ ▼ │ │ │
│ │ KL Divergence │ │ │
│ │ (蒸馏损失) │ │ │
│ │ │ │ │ │
│ │ ▼ ▼ │ │
│ │ Hard Targets ───────▶ CE Loss │ │
│ │ (真实标签) (学生损失) │ │
│ │ │ │
│ └──────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ Total Loss = α·L_soft + β·L_hard │
│ │
└─────────────────────────────────────────────────────────────────┘
基础蒸馏方法¶
2.1 软标签蒸馏(Hinton Distillation)¶
Python
import torch
import torch.nn as nn
import torch.nn.functional as F
class KnowledgeDistillationLoss(nn.Module):
"""
Hinton知识蒸馏损失
"Distilling the Knowledge in a Neural Network" (Hinton et al., 2015)
"""
def __init__(self, temperature=4.0, alpha=0.7):
super().__init__() # super()调用父类方法
self.temperature = temperature
self.alpha = alpha # 蒸馏损失权重
def forward(self, student_logits, teacher_logits, true_labels):
"""
Args:
student_logits: 学生模型输出 [batch, num_classes]
teacher_logits: 教师模型输出 [batch, num_classes]
true_labels: 真实标签 [batch]
Returns:
total_loss: 总损失
distill_loss: 蒸馏损失(KL散度)
ce_loss: 交叉熵损失
"""
# 1. 软目标损失(KL散度)
soft_student = F.log_softmax(student_logits / self.temperature, dim=1)
soft_teacher = F.softmax(teacher_logits / self.temperature, dim=1)
# KL散度 = Σ p_teacher * log(p_teacher / p_student)
distill_loss = F.kl_div(
soft_student,
soft_teacher,
reduction='batchmean'
) * (self.temperature ** 2) # 缩放因子
# 2. 硬目标损失(交叉熵)
ce_loss = F.cross_entropy(student_logits, true_labels)
# 3. 总损失
total_loss = self.alpha * distill_loss + (1 - self.alpha) * ce_loss
return total_loss, distill_loss, ce_loss
# 温度参数的作用
def demonstrate_temperature():
"""
演示温度参数对软标签的影响
"""
import matplotlib.pyplot as plt
# 假设logits
logits = torch.tensor([2.0, 1.0, 0.1])
temperatures = [0.5, 1.0, 2.0, 4.0, 8.0]
plt.figure(figsize=(10, 6))
for T in temperatures:
probs = F.softmax(logits / T, dim=0)
plt.plot(probs.numpy(), label=f'T={T}', marker='o')
plt.xlabel('Class Index')
plt.ylabel('Probability')
plt.title('Effect of Temperature on Softmax Distribution')
plt.legend()
plt.grid(True)
plt.show()
"""
结论:
- T→0: 趋近于one-hot(硬标签)
- T→∞: 趋近于均匀分布
- 适中的T(2-4): 保留类别间关系,但不过于平滑
"""
# demonstrate_temperature()
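下面给出 KnowledgeDistillationLoss 的一个最小训练步示意:教师、学生都用随意构造的线性层代替,数据为随机张量,仅用于演示损失的调用方式与返回值。

Python
# KnowledgeDistillationLoss 的最小使用示意(模型与数据均为随机构造,仅演示调用方式)
torch.manual_seed(0)
teacher = nn.Linear(32, 10)   # 假想的"教师"
student = nn.Linear(32, 10)   # 假想的"学生"
criterion = KnowledgeDistillationLoss(temperature=4.0, alpha=0.7)
optimizer = torch.optim.AdamW(student.parameters(), lr=1e-3)

inputs = torch.randn(8, 32)
labels = torch.randint(0, 10, (8,))

with torch.no_grad():                      # 教师只做推理,不更新参数
    teacher_logits = teacher(inputs)
student_logits = student(inputs)

total_loss, distill_loss, ce_loss = criterion(student_logits, teacher_logits, labels)
optimizer.zero_grad()
total_loss.backward()
optimizer.step()
print(f"total={total_loss.item():.4f}, kd={distill_loss.item():.4f}, ce={ce_loss.item():.4f}")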
2.2 特征蒸馏(Feature Distillation)¶
Python
class FeatureDistillationLoss(nn.Module):
"""
特征蒸馏
让学生模型学习教师模型的中间层特征表示
"""
def __init__(self, student_dim, teacher_dim, hidden_dim=None):
super().__init__()
if hidden_dim is None:
hidden_dim = student_dim
# 适配层:将学生特征映射到教师特征空间
self.adaptation_layer = nn.Sequential(
nn.Linear(student_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, teacher_dim)
)
def forward(self, student_features, teacher_features):
"""
Args:
student_features: [batch, student_dim]
teacher_features: [batch, teacher_dim]
"""
# 适配学生特征
adapted_student = self.adaptation_layer(student_features)
# 均方误差
loss = F.mse_loss(adapted_student, teacher_features)
return loss
class MultiLayerFeatureDistillation(nn.Module):
"""
多层特征蒸馏(FitNets)
在多个中间层进行特征蒸馏
"""
def __init__(self, student_dims, teacher_dims, weights=None):
super().__init__()
assert len(student_dims) == len(teacher_dims) # assert断言:条件False时抛出AssertionError
self.num_layers = len(student_dims)
# 为每层创建适配层
self.adaptation_layers = nn.ModuleList([
nn.Linear(s_dim, t_dim)
for s_dim, t_dim in zip(student_dims, teacher_dims) # zip按位置配对多个可迭代对象
])
# 每层的权重
if weights is None:
weights = [1.0] * self.num_layers
self.weights = weights
def forward(self, student_features_list, teacher_features_list):
"""
Args:
student_features_list: 学生模型多层特征
teacher_features_list: 教师模型多层特征
"""
total_loss = 0
for i, (s_feat, t_feat, adapt_layer) in enumerate(
zip(student_features_list, teacher_features_list, self.adaptation_layers)
):
# 适配并计算损失
adapted = adapt_layer(s_feat)
loss = F.mse_loss(adapted, t_feat)
total_loss += self.weights[i] * loss
return total_loss
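下面是 MultiLayerFeatureDistillation 的调用示意:特征张量用随机数代替,"学生3层×256维、教师3层×768维"的设定只是演示用的假设。

Python
# MultiLayerFeatureDistillation 的调用示意(特征为随机张量,维度仅为演示假设)
student_dims = [256, 256, 256]
teacher_dims = [768, 768, 768]
feat_distill = MultiLayerFeatureDistillation(student_dims, teacher_dims, weights=[0.5, 1.0, 1.5])

batch = 4
student_feats = [torch.randn(batch, d) for d in student_dims]
teacher_feats = [torch.randn(batch, d) for d in teacher_dims]   # 实际中应取自教师前向的中间层输出

loss = feat_distill(student_feats, teacher_feats)
print("feature distillation loss:", loss.item())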
2.3 注意力蒸馏¶
Python
class AttentionDistillationLoss(nn.Module):
"""
注意力蒸馏
让学生模型学习教师模型的注意力模式
"""
def __init__(self, temperature=1.0):
super().__init__()
self.temperature = temperature
def forward(self, student_attention, teacher_attention):
"""
Args:
student_attention: [batch, num_heads, seq_len, seq_len]
teacher_attention: [batch, num_heads, seq_len, seq_len]
"""
# 方法1: KL散度
student_attn = F.log_softmax(student_attention / self.temperature, dim=-1)
teacher_attn = F.softmax(teacher_attention / self.temperature, dim=-1)
kl_loss = F.kl_div(
student_attn,
teacher_attn,
reduction='batchmean'
)
# 方法2: 均方误差(有时效果更好)
mse_loss = F.mse_loss(student_attention, teacher_attention)
# 组合
total_loss = kl_loss + mse_loss
return total_loss
class MultiHeadAttentionTransfer(nn.Module):
"""
多头注意力迁移
TinyBERT等方法使用
"""
def __init__(self, student_num_heads, teacher_num_heads):
super().__init__()
self.student_num_heads = student_num_heads
self.teacher_num_heads = teacher_num_heads
def forward(self, student_attn_list, teacher_attn_list):
"""
对齐不同层数的注意力
"""
# 选择教师模型的特定层进行蒸馏
        # 例如:学生6层、教师12层时,select_layers 用均匀间隔(linspace)选层,得到教师层索引 [0, 2, 4, 6, 8, 11]
selected_teacher_layers = self.select_layers(
len(teacher_attn_list),
len(student_attn_list)
)
total_loss = 0
for student_idx, teacher_idx in enumerate(selected_teacher_layers): # enumerate同时获取索引和元素
s_attn = student_attn_list[student_idx]
t_attn = teacher_attn_list[teacher_idx]
# 如果头数不同,需要平均或插值
if s_attn.size(1) != t_attn.size(1):
t_attn = t_attn.mean(dim=1, keepdim=True).expand(-1, s_attn.size(1), -1, -1)
loss = F.mse_loss(s_attn, t_attn)
total_loss += loss
return total_loss / len(student_attn_list)
def select_layers(self, teacher_layers, student_layers):
"""
选择对应的教师层
"""
import numpy as np
indices = np.linspace(0, teacher_layers - 1, student_layers, dtype=int)
return indices.tolist()
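可以单独验证上面的均匀选层策略:学生6层、教师12层时,select_layers 给出的教师层索引如下(示意代码,头数参数随意填写即可)。

Python
# 层映射示意:学生6层、教师12层时的均匀选层结果
mapper = MultiHeadAttentionTransfer(student_num_heads=12, teacher_num_heads=12)
print(mapper.select_layers(teacher_layers=12, student_layers=6))
# 输出: [0, 2, 4, 6, 8, 11](0起始索引,近似均匀覆盖教师各层)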
高级蒸馏技术¶
3.1 自蒸馏(Self-Distillation)¶
Python
class SelfDistillationLoss(nn.Module):
"""
自蒸馏
模型自己蒸馏自己,不需要单独的教师模型
"""
def __init__(self, temperature=4.0):
super().__init__()
self.temperature = temperature
def forward(self, logits, labels, epoch, total_epochs):
"""
随着训练进行,逐渐增加软标签的影响
"""
# 硬标签损失
hard_loss = F.cross_entropy(logits, labels)
# 生成软标签(使用当前模型自己的预测)
with torch.no_grad(): # 禁用梯度计算,节省内存(推理时使用)
soft_targets = F.softmax(logits / self.temperature, dim=1)
# 软标签损失
# ⚠️ 演示说明: 此处 soft_targets 与 logits 来自同一分布,理论上 KL(p‖p)≡0。
# 实际自蒸馏中,soft_targets 应来自上一个 epoch 保存的模型输出(即“历史版本”作为教师)。
# 这里仅作为代码结构示例。
soft_loss = F.kl_div(
F.log_softmax(logits / self.temperature, dim=1),
soft_targets,
reduction='batchmean'
) * (self.temperature ** 2)
# 动态权重:早期主要用硬标签,后期增加软标签
alpha = min(0.5, epoch / (total_epochs * 0.5))
total_loss = (1 - alpha) * hard_loss + alpha * soft_loss
return total_loss
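与上面的 ⚠️ 说明对应,下面给出一个"用上一个 epoch 的模型快照作为教师"的自蒸馏训练循环示意:模型、数据与超参数均为演示用的假设,仅展示快照何时保存、何时作为软目标使用。

Python
import copy

# 自蒸馏训练循环示意:上一epoch的模型快照充当"教师"(模型/数据均为演示假设)
model = nn.Linear(32, 10)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
T, alpha, total_epochs = 4.0, 0.3, 5
prev_model = None   # 第一个epoch还没有"历史教师"

for epoch in range(total_epochs):
    inputs = torch.randn(8, 32)
    labels = torch.randint(0, 10, (8,))

    logits = model(inputs)
    loss = F.cross_entropy(logits, labels)

    if prev_model is not None:                       # 从第二个epoch起加入自蒸馏项
        with torch.no_grad():
            soft_targets = F.softmax(prev_model(inputs) / T, dim=1)
        kd_loss = F.kl_div(F.log_softmax(logits / T, dim=1), soft_targets,
                           reduction='batchmean') * (T ** 2)
        loss = (1 - alpha) * loss + alpha * kd_loss

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    prev_model = copy.deepcopy(model).eval()         # 保存当前快照,供下一个epoch充当教师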
3.2 在线蒸馏(Online Distillation)¶
Python
class OnlineDistillation:
"""
在线蒸馏(Deep Mutual Learning)
多个学生模型互相学习,没有固定的教师
"""
def __init__(self, models, temperature=4.0):
self.models = models # 多个学生模型
self.temperature = temperature
self.optimizers = [torch.optim.AdamW(m.parameters()) for m in models]
def train_step(self, batch):
"""
所有模型互相蒸馏
"""
inputs, labels = batch
# 获取所有模型的预测
logits_list = [model(inputs) for model in self.models]
        # 每个模型以"其余模型(排除自己)的平均预测"作为软教师,互相学习
for i, (model, optimizer) in enumerate(zip(self.models, self.optimizers)):
optimizer.zero_grad()
# 硬标签损失
hard_loss = F.cross_entropy(logits_list[i], labels)
            # 向其他模型学习(排除自己);detach让教师信号不回传梯度,
            # 也避免循环内多次backward重复释放同一计算图而报错
            other_logits = torch.stack(
                [logits_list[j].detach() for j in range(len(self.models)) if j != i]
            ).mean(dim=0)
soft_loss = F.kl_div(
F.log_softmax(logits_list[i] / self.temperature, dim=1),
F.softmax(other_logits / self.temperature, dim=1),
reduction='batchmean'
) * (self.temperature ** 2)
# 总损失
loss = hard_loss + soft_loss
loss.backward()
optimizer.step()
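OnlineDistillation 的调用示意如下:两个结构不同的小模型在同一个 batch 上互相蒸馏并各自更新一次参数(模型与数据均为演示用的玩具构造)。

Python
# OnlineDistillation 的调用示意(模型与数据均为玩具构造,仅演示数据流)
models = [
    nn.Linear(32, 10),
    nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 10)),
]
dml = OnlineDistillation(models, temperature=4.0)

batch = (torch.randn(8, 32), torch.randint(0, 10, (8,)))
dml.train_step(batch)   # 两个模型各自前向、互相蒸馏并更新一次参数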
3.3 任务特定蒸馏¶
Python
class TaskSpecificDistillation:
"""
任务特定蒸馏
针对不同任务设计特定的蒸馏策略
"""
@staticmethod # @staticmethod无需实例即可调用
def classification_distillation(student_logits, teacher_logits, labels, alpha=0.7, T=4.0):
"""
分类任务蒸馏
"""
soft_loss = F.kl_div(
F.log_softmax(student_logits / T, dim=1),
F.softmax(teacher_logits / T, dim=1),
reduction='batchmean'
) * (T ** 2)
hard_loss = F.cross_entropy(student_logits, labels)
return alpha * soft_loss + (1 - alpha) * hard_loss
@staticmethod
def ner_distillation(student_logits, teacher_logits, labels, attention_mask, alpha=0.7, T=4.0):
"""
命名实体识别任务蒸馏
每个token独立计算
"""
batch_size, seq_len, num_classes = student_logits.shape
# 展平
student_logits = student_logits.view(-1, num_classes) # view重塑张量形状
teacher_logits = teacher_logits.view(-1, num_classes)
labels = labels.view(-1)
attention_mask = attention_mask.view(-1)
# 只计算有效token
valid_indices = attention_mask.bool()
student_logits = student_logits[valid_indices]
teacher_logits = teacher_logits[valid_indices]
labels = labels[valid_indices]
soft_loss = F.kl_div(
F.log_softmax(student_logits / T, dim=1),
F.softmax(teacher_logits / T, dim=1),
reduction='batchmean'
) * (T ** 2)
hard_loss = F.cross_entropy(student_logits, labels)
return alpha * soft_loss + (1 - alpha) * hard_loss
@staticmethod
def generation_distillation(student_logits, teacher_logits, labels, alpha=0.7, T=4.0):
"""
生成任务蒸馏
序列生成任务的蒸馏
"""
# 移位:预测下一个token
shift_student = student_logits[..., :-1, :].contiguous()
shift_teacher = teacher_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# 展平
shift_student = shift_student.view(-1, shift_student.size(-1))
shift_teacher = shift_teacher.view(-1, shift_teacher.size(-1))
shift_labels = shift_labels.view(-1)
soft_loss = F.kl_div(
F.log_softmax(shift_student / T, dim=1),
F.softmax(shift_teacher / T, dim=1),
reduction='batchmean'
) * (T ** 2)
hard_loss = F.cross_entropy(shift_student, shift_labels)
return alpha * soft_loss + (1 - alpha) * hard_loss
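以生成任务为例,下面演示 generation_distillation 期望的 [batch, seq_len, vocab] 形状约定(logits 与标签均为随机张量,仅用于跑通接口)。

Python
# generation_distillation 的调用示意(数据为随机张量,仅演示形状约定)
batch, seq_len, vocab = 2, 16, 100
student_logits = torch.randn(batch, seq_len, vocab, requires_grad=True)
teacher_logits = torch.randn(batch, seq_len, vocab)
labels = torch.randint(0, vocab, (batch, seq_len))

loss = TaskSpecificDistillation.generation_distillation(
    student_logits, teacher_logits, labels, alpha=0.7, T=4.0
)
loss.backward()
print("generation distillation loss:", loss.item())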
白盒蒸馏与黑盒蒸馏¶
4.1 白盒蒸馏¶
Python
class WhiteBoxDistillation:
"""
白盒蒸馏
可以访问教师模型的内部参数和特征
"""
def __init__(self, teacher_model, student_model):
self.teacher_model = teacher_model
self.student_model = student_model
# 注册hook获取中间特征
self.teacher_features = []
self.student_features = []
self._register_hooks()
def _register_hooks(self):
"""
注册前向hook获取特征
"""
self.hook_handles = [] # 保存hook句柄,用于后续清理
def get_features(name, storage):
def hook(module, input, output):
storage.append(output)
return hook
# 在特定层注册hook
for i, layer in enumerate(self.teacher_model.transformer_layers):
handle = layer.register_forward_hook(
get_features(f'teacher_layer_{i}', self.teacher_features)
)
self.hook_handles.append(handle)
for i, layer in enumerate(self.student_model.transformer_layers):
handle = layer.register_forward_hook(
get_features(f'student_layer_{i}', self.student_features)
)
self.hook_handles.append(handle)
def forward(self, inputs, labels):
"""
前向传播并计算蒸馏损失
"""
# 教师前向
with torch.no_grad():
teacher_logits = self.teacher_model(inputs)
# 学生前向
student_logits = self.student_model(inputs)
# 1. 输出蒸馏
output_loss = F.kl_div(
F.log_softmax(student_logits / 4.0, dim=1),
F.softmax(teacher_logits / 4.0, dim=1),
reduction='batchmean'
) * 16
# 2. 特征蒸馏
feature_loss = 0
for s_feat, t_feat in zip(self.student_features, self.teacher_features):
# 适配维度
if s_feat.shape != t_feat.shape:
t_feat = F.adaptive_avg_pool1d(
t_feat.transpose(1, 2),
s_feat.size(1)
).transpose(1, 2)
feature_loss += F.mse_loss(s_feat, t_feat)
# 3. 硬标签损失
ce_loss = F.cross_entropy(student_logits, labels)
# 总损失
total_loss = 0.5 * output_loss + 0.3 * feature_loss + 0.2 * ce_loss
# 清空特征缓存
self.teacher_features.clear()
self.student_features.clear()
return total_loss
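上面的类把 hook 句柄保存在 hook_handles 中;蒸馏结束后应移除这些 hook,否则后续推理仍会不断向特征列表追加数据。下面是一个清理示意(distiller 为 WhiteBoxDistillation 实例的假设变量名)。

Python
# 蒸馏完成后移除forward hook(distiller 为 WhiteBoxDistillation 实例的假设变量名)
for handle in distiller.hook_handles:
    handle.remove()
distiller.hook_handles.clear()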
4.2 黑盒蒸馏¶
Python
class BlackBoxDistillation:
"""
黑盒蒸馏
只能访问教师模型的输出(如API),无法访问内部参数
"""
def __init__(self, student_model, teacher_api=None):
self.student_model = student_model
self.teacher_api = teacher_api # 教师模型API接口
def get_teacher_predictions(self, inputs):
"""
通过API获取教师模型预测
实际应用中可能是:
- OpenAI API
- 本地教师模型推理
- 预计算的教师预测
"""
if self.teacher_api is None:
raise ValueError("Teacher API not provided")
# 调用API获取预测
teacher_logits = self.teacher_api.predict(inputs)
return teacher_logits
def train_step(self, batch):
"""
训练步骤
"""
inputs, labels = batch
# 获取教师预测(黑盒)
with torch.no_grad():
teacher_logits = self.get_teacher_predictions(inputs)
# 学生前向
student_logits = self.student_model(inputs)
# 只能使用输出蒸馏
distill_loss = F.kl_div(
F.log_softmax(student_logits / 4.0, dim=1),
F.softmax(teacher_logits / 4.0, dim=1),
reduction='batchmean'
) * 16
# 硬标签损失
ce_loss = F.cross_entropy(student_logits, labels)
# 总损失
total_loss = 0.7 * distill_loss + 0.3 * ce_loss
return total_loss
class DataFreeDistillation:
"""
无数据蒸馏
不需要原始训练数据,通过生成数据蒸馏
"""
def __init__(self, student_model, generator_model):
self.student_model = student_model
self.generator_model = generator_model # 数据生成器
def generate_synthetic_data(self, num_samples, sequence_length):
"""
生成合成数据
"""
# 使用生成器或随机生成
synthetic_inputs = self.generator_model.generate(
num_samples=num_samples,
sequence_length=sequence_length
)
return synthetic_inputs
def train_step(self, teacher_model):
"""
训练步骤
"""
# 生成合成数据
synthetic_inputs = self.generate_synthetic_data(
num_samples=32,
sequence_length=128
)
# 获取教师预测
with torch.no_grad():
teacher_logits = teacher_model(synthetic_inputs)
# 学生预测
student_logits = self.student_model(synthetic_inputs)
# 蒸馏损失
loss = F.kl_div(
F.log_softmax(student_logits / 4.0, dim=1),
F.softmax(teacher_logits / 4.0, dim=1),
reduction='batchmean'
) * 16
return loss
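黑盒蒸馏常见的工程做法是:离线调用教师 API 并把输出缓存起来,训练时直接读缓存,减少重复计费与限流。下面是一个极简的缓存式 teacher_api 示意,其中 CachedTeacherAPI、remote_predict_fn 等名字与字段均为假设,并用一个本地线性层"假扮"远程教师来跑通 BlackBoxDistillation 的数据流。

Python
# 缓存教师预测的极简示意(类名、字段均为假设,仅演示思路)
class CachedTeacherAPI:
    def __init__(self, remote_predict_fn):
        self.remote_predict_fn = remote_predict_fn   # 真正调用教师API的函数(假设)
        self.cache = {}

    def predict(self, inputs):
        key = hash(inputs.cpu().numpy().tobytes())   # 用输入内容做缓存键(示意)
        if key not in self.cache:
            self.cache[key] = self.remote_predict_fn(inputs)  # 只有未命中时才真正调用API
        return self.cache[key]

# 用本地小模型假扮"远程教师API",演示 BlackBoxDistillation 的完整数据流
fake_remote_teacher = nn.Linear(32, 10)
api = CachedTeacherAPI(lambda x: fake_remote_teacher(x))
distiller = BlackBoxDistillation(student_model=nn.Linear(32, 10), teacher_api=api)

loss = distiller.train_step((torch.randn(8, 32), torch.randint(0, 10, (8,))))
print("black-box distillation loss:", loss.item())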
蒸馏与量化的结合¶
5.1 量化感知蒸馏¶
Python
class QuantizationAwareDistillation:
"""
量化感知蒸馏
在蒸馏过程中考虑量化误差
"""
def __init__(self, teacher_model, student_model, num_bits=8):
self.teacher_model = teacher_model
self.student_model = student_model
self.num_bits = num_bits
def quantize_tensor(self, tensor, num_bits):
"""
模拟量化
"""
# 计算缩放因子
qmin = -(2 ** (num_bits - 1))
qmax = 2 ** (num_bits - 1) - 1
min_val = tensor.min()
max_val = tensor.max()
scale = (max_val - min_val) / (qmax - qmin)
zero_point = qmin - min_val / scale
# 量化
quantized = torch.clamp(
torch.round(tensor / scale + zero_point),
qmin, qmax
)
# 反量化
dequantized = (quantized - zero_point) * scale
return dequantized
def forward(self, inputs, labels):
"""
前向传播
"""
# 教师预测(全精度)
with torch.no_grad():
teacher_logits = self.teacher_model(inputs)
# 学生预测(模拟量化)
student_logits = self.student_model(inputs)
# 模拟量化学生输出
quantized_student_logits = self.quantize_tensor(
student_logits, self.num_bits
)
# 蒸馏损失(使用量化后的输出)
distill_loss = F.kl_div(
F.log_softmax(quantized_student_logits / 4.0, dim=1),
F.softmax(teacher_logits / 4.0, dim=1),
reduction='batchmean'
) * 16
# 硬标签损失(使用原始输出)
ce_loss = F.cross_entropy(student_logits, labels)
return 0.7 * distill_loss + 0.3 * ce_loss
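需要注意,上面 quantize_tensor 中的 torch.round 梯度几乎处处为 0,直接用于训练时蒸馏损失无法把梯度回传给学生;实践中通常配合直通估计器(Straight-Through Estimator, STE)。下面是一个 STE 版本假量化的示意实现(函数名与写法为本文假设,可用来替换 quantize_tensor)。

Python
def quantize_tensor_ste(tensor, num_bits=8):
    """带直通估计器(STE)的假量化:前向输出量化后的值,反向把梯度按恒等映射直接透传"""
    qmin, qmax = -(2 ** (num_bits - 1)), 2 ** (num_bits - 1) - 1
    min_val, max_val = tensor.min(), tensor.max()
    scale = (max_val - min_val) / (qmax - qmin)
    zero_point = qmin - min_val / scale

    quantized = torch.clamp(torch.round(tensor / scale + zero_point), qmin, qmax)
    dequantized = (quantized - zero_point) * scale

    # STE技巧:x + (f(x) - x).detach() 的前向值等于 f(x),但对 x 的梯度为 1
    return tensor + (dequantized - tensor).detach()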
5.2 渐进式压缩¶
Python
class ProgressiveCompression:
"""
渐进式压缩
    先蒸馏,再剪枝,最后量化(动态量化会把 nn.Linear 替换为量化模块,之后无法再做常规剪枝,故量化放在最后)
"""
def __init__(self, teacher_model):
self.teacher_model = teacher_model
self.student_model = None
def stage1_distillation(self, train_loader, num_epochs=10):
"""
第一阶段:知识蒸馏
"""
        # 创建学生模型(比教师小;create_student_model 与 KnowledgeDistillationTrainer 为示意用的辅助实现,需自行提供)
self.student_model = self.create_student_model()
# 蒸馏训练
distillation_trainer = KnowledgeDistillationTrainer(
teacher=self.teacher_model,
student=self.student_model
)
for epoch in range(num_epochs):
for batch in train_loader:
loss = distillation_trainer.train_step(batch)
return self.student_model
    def stage2_pruning(self, model, sparsity=0.3):
        """
        第二阶段:剪枝
        """
        import torch.nn.utils.prune as prune
        # 对线性层进行剪枝(此时仍是 nn.Linear,可正常剪枝)
        for name, module in model.named_modules():
            if isinstance(module, nn.Linear):  # isinstance检查类型
                prune.l1_unstructured(
                    module,
                    name='weight',
                    amount=sparsity
                )
        return model

    def stage3_quantization(self, model, num_bits=8):
        """
        第三阶段:量化(动态量化会替换 nn.Linear,故放在剪枝之后)
        """
        quantized_model = torch.quantization.quantize_dynamic(
            model,
            {nn.Linear},
            dtype=torch.qint8
        )
        return quantized_model
    def full_pipeline(self, train_loader):
        """
        完整压缩流程:蒸馏 → 剪枝 → 量化
        """
        # Stage 1: 蒸馏
        print("Stage 1: Distillation...")
        distilled_model = self.stage1_distillation(train_loader)
        # Stage 2: 剪枝
        print("Stage 2: Pruning...")
        pruned_model = self.stage2_pruning(distilled_model)
        # Stage 3: 量化
        print("Stage 3: Quantization...")
        final_model = self.stage3_quantization(pruned_model)
        return final_model
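可以先在一个玩具模型上单独验证"先剪枝、再动态量化"这两步的调用顺序(模型结构随意构造,仅演示 API 用法)。

Python
import torch.nn.utils.prune as prune

# 玩具模型上演示"先剪枝、再动态量化"(结构随意,仅作演示)
toy = nn.Sequential(nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 10))

# 1) 剪枝:对所有 nn.Linear 做30%的L1非结构化剪枝,并把掩码固化回weight
for module in toy.modules():
    if isinstance(module, nn.Linear):
        prune.l1_unstructured(module, name='weight', amount=0.3)
        prune.remove(module, 'weight')

# 2) 动态量化:nn.Linear 会被替换成量化模块,因此必须放在剪枝之后
toy_int8 = torch.quantization.quantize_dynamic(toy, {nn.Linear}, dtype=torch.qint8)
print(toy_int8)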
经典案例与最佳实践¶
6.1 经典案例¶
Text Only
经典知识蒸馏案例
═══════════════════════════════════════════════════════════════════
1. DistilBERT (Hugging Face, 2019)
├── 教师: BERT-base (110M参数)
├── 学生: DistilBERT (66M参数,减少40%)
├── 方法: 软标签蒸馏 + 掩码语言模型(MLM)损失 + 隐藏状态余弦嵌入损失
├── 效果: 保留97%性能,推理速度提升60%
└── 代码: transformers.DistilBert
2. TinyBERT (华为, 2020)
├── 教师: BERT-base (110M参数)
├── 学生: TinyBERT-4L (14.5M参数,4层312维)
├── 方法: 嵌入层 + Transformer层 + 预测层 三层蒸馏
├── 效果: 在GLUE上达到BERT-base的96%,体积仅约1/7.5
└── 特点: 两阶段蒸馏(通用蒸馏+任务蒸馏)
3. MobileBERT (Google, 2020)
├── 教师: BERT-large
├── 学生: MobileBERT (25.3M参数,24层,倒瓶颈结构)
├── 方法: 倒瓶颈结构 + 特征蒸馏
├── 效果: 与BERT-base相当,速度快4.3倍
└── 特点: 专为移动设备优化
4. MiniLM (微软, 2020)
├── 方法: 深度自注意力蒸馏
├── 关键: 只蒸馏最后一层自注意力的Query-Key分布与Value-Value关系
├── 效果: 6层学生即可保留BERT-base 99%以上的性能
└── 优势: 不需要逐层对齐教师的隐藏状态
5. 大模型API黑盒蒸馏(社区实践)
├── 教师: 仅提供API的大模型(如175B参数的GPT-3系列)
├── 学生: 各类开源小模型
├── 方法: 黑盒蒸馏(仅使用API返回的输出作为训练信号)
└── 代表: Stanford Alpaca(用text-davinci-003的输出微调LLaMA-7B)等指令蒸馏工作
═══════════════════════════════════════════════════════════════════
6.2 最佳实践¶
Python
class DistillationBestPractices:
"""
知识蒸馏最佳实践
"""
@staticmethod
def temperature_tuning(teacher_model, student_model, val_loader):
"""
温度参数调优
"""
temperatures = [1.0, 2.0, 4.0, 6.0, 8.0, 10.0]
best_temp = 4.0
best_acc = 0.0
for T in temperatures:
            # 用当前温度做蒸馏训练并在验证集上评估(evaluate_with_temperature 为示意用的辅助函数)
acc = evaluate_with_temperature(
teacher_model, student_model, val_loader, T
)
if acc > best_acc:
best_acc = acc
best_temp = T
return best_temp
@staticmethod
def layer_mapping_strategy(teacher_layers, student_layers):
"""
层映射策略
决定学生模型的哪层学习教师模型的哪层
"""
import numpy as np
strategies = {
'uniform': np.linspace(0, teacher_layers-1, student_layers, dtype=int),
'first_k': list(range(student_layers)),
'last_k': list(range(teacher_layers - student_layers, teacher_layers)),
'skip': list(range(0, teacher_layers, teacher_layers // student_layers))[:student_layers]
}
return strategies
@staticmethod
def dynamic_alpha_scheduler(epoch, total_epochs, alpha_range=(0.3, 0.7)):
"""
动态调整蒸馏权重
早期侧重蒸馏,后期侧重真实标签
"""
alpha_min, alpha_max = alpha_range
# 线性递减
alpha = alpha_max - (alpha_max - alpha_min) * (epoch / total_epochs)
return alpha
@staticmethod
def data_augmentation_for_distillation(raw_data, teacher_model, num_augmentations=5):
"""
数据增强辅助蒸馏
使用教师模型生成更多训练数据
"""
augmented_data = []
for sample in raw_data:
# 原始样本
augmented_data.append(sample)
# 生成变体
for _ in range(num_augmentations):
# 添加噪声、回译、同义词替换等
variant = augment_sample(sample)
# 使用教师标注
with torch.no_grad():
teacher_pred = teacher_model(variant)
augmented_data.append({
'input': variant,
'teacher_logits': teacher_pred,
'label': sample['label']
})
return augmented_data
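dynamic_alpha_scheduler 的行为可以直接打印验证:蒸馏权重随 epoch 从 0.7 线性下降、逐渐趋向 0.3,对应"早期侧重蒸馏、后期侧重真实标签"的策略。

Python
# 打印 dynamic_alpha_scheduler 在10个epoch中的取值
for epoch in range(10):
    alpha = DistillationBestPractices.dynamic_alpha_scheduler(epoch, total_epochs=10)
    print(f"epoch {epoch}: alpha = {alpha:.2f}")
# 输出从 0.70 线性降到 0.34(若计算到第10个epoch则为 0.30)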
蒸馏效果评估¶
7.1 评估指标¶
Python
class DistillationEvaluator:
"""
蒸馏效果评估
"""
@staticmethod
def compute_metrics(teacher_model, student_model, test_loader):
"""
计算蒸馏效果指标
"""
metrics = {}
# 1. 准确率对比
teacher_acc = evaluate_accuracy(teacher_model, test_loader)
student_acc = evaluate_accuracy(student_model, test_loader)
metrics['teacher_accuracy'] = teacher_acc
metrics['student_accuracy'] = student_acc
metrics['accuracy_retention'] = student_acc / teacher_acc
# 2. 推理速度对比
teacher_speed = measure_inference_speed(teacher_model, test_loader)
student_speed = measure_inference_speed(student_model, test_loader)
metrics['teacher_speed'] = teacher_speed
metrics['student_speed'] = student_speed
        metrics['speedup'] = student_speed / teacher_speed  # 速度为吞吐量(batch/s),学生更快时该值大于1
# 3. 模型大小对比
teacher_params = count_parameters(teacher_model)
student_params = count_parameters(student_model)
metrics['teacher_params'] = teacher_params
metrics['student_params'] = student_params
metrics['compression_ratio'] = teacher_params / student_params
# 4. 预测一致性
agreement = compute_prediction_agreement(
teacher_model, student_model, test_loader
)
metrics['prediction_agreement'] = agreement
# 5. 输出分布相似度
kl_div = compute_average_kl_divergence(
teacher_model, student_model, test_loader
)
metrics['kl_divergence'] = kl_div
return metrics
@staticmethod
def visualize_distillation_effect(teacher_logits, student_logits, labels):
"""
可视化蒸馏效果
"""
import matplotlib.pyplot as plt
# 1. 置信度分布对比
teacher_conf = F.softmax(teacher_logits, dim=1).max(dim=1)[0]
student_conf = F.softmax(student_logits, dim=1).max(dim=1)[0]
plt.figure(figsize=(12, 4))
plt.subplot(1, 3, 1)
plt.hist(teacher_conf.cpu().numpy(), bins=50, alpha=0.5, label='Teacher')
plt.hist(student_conf.cpu().numpy(), bins=50, alpha=0.5, label='Student')
plt.xlabel('Confidence')
plt.ylabel('Count')
plt.legend()
plt.title('Confidence Distribution')
# 2. 预测一致性矩阵
plt.subplot(1, 3, 2)
teacher_preds = teacher_logits.argmax(dim=1)
student_preds = student_logits.argmax(dim=1)
agreement = (teacher_preds == student_preds).float()
plt.bar(['Disagree', 'Agree'], [1-agreement.mean(), agreement.mean()])
plt.ylabel('Proportion')
plt.title('Prediction Agreement')
# 3. 正确率对比(按类别)
plt.subplot(1, 3, 3)
num_classes = teacher_logits.size(1)
teacher_acc_per_class = []
student_acc_per_class = []
for c in range(num_classes):
mask = (labels == c)
if mask.sum() > 0:
teacher_acc_per_class.append(
(teacher_preds[mask] == labels[mask]).float().mean().item()
)
student_acc_per_class.append(
(student_preds[mask] == labels[mask]).float().mean().item()
)
x = range(len(teacher_acc_per_class))
plt.plot(x, teacher_acc_per_class, label='Teacher', marker='o')
plt.plot(x, student_acc_per_class, label='Student', marker='s')
plt.xlabel('Class')
plt.ylabel('Accuracy')
plt.legend()
plt.title('Per-Class Accuracy')
plt.tight_layout()
plt.show()
def evaluate_accuracy(model, data_loader):
"""评估准确率"""
model.eval()
correct = 0
total = 0
with torch.no_grad():
for batch in data_loader:
inputs, labels = batch
outputs = model(inputs)
preds = outputs.argmax(dim=1)
correct += (preds == labels).sum().item()
total += labels.size(0)
return correct / total
def measure_inference_speed(model, data_loader, num_batches=100):
"""测量推理速度"""
import time
model.eval()
device = next(model.parameters()).device
# 预热
for i, batch in enumerate(data_loader):
if i >= 10:
break
inputs, _ = batch
_ = model(inputs.to(device)) # .to(device)将数据移至GPU/CPU
    # 测量(按实际跑过的batch数计算吞吐量,避免数据不足num_batches时结果失真)
    start_time = time.time()
    measured_batches = 0
    for i, batch in enumerate(data_loader):
        if i >= num_batches:
            break
        inputs, _ = batch
        _ = model(inputs.to(device))
        measured_batches += 1
    end_time = time.time()
    return measured_batches / (end_time - start_time)  # batches per second
def count_parameters(model):
"""统计参数量"""
return sum(p.numel() for p in model.parameters())
def compute_prediction_agreement(teacher_model, student_model, data_loader):
"""计算预测一致性"""
teacher_model.eval()
student_model.eval()
agreements = []
with torch.no_grad():
for batch in data_loader:
inputs, _ = batch
teacher_outputs = teacher_model(inputs)
student_outputs = student_model(inputs)
teacher_preds = teacher_outputs.argmax(dim=1)
student_preds = student_outputs.argmax(dim=1)
agreements.append((teacher_preds == student_preds).float())
return torch.cat(agreements).mean().item()
def compute_average_kl_divergence(teacher_model, student_model, data_loader):
"""计算平均KL散度"""
teacher_model.eval()
student_model.eval()
kl_divs = []
with torch.no_grad():
for batch in data_loader:
inputs, _ = batch
teacher_outputs = teacher_model(inputs)
student_outputs = student_model(inputs)
teacher_probs = F.softmax(teacher_outputs, dim=1)
student_log_probs = F.log_softmax(student_outputs, dim=1)
kl = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean')
kl_divs.append(kl.item())
return sum(kl_divs) / len(kl_divs)
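把上面的评估工具串起来,可以用两个玩具模型和随机数据跑通 compute_metrics 的完整流程(数值本身没有意义,仅验证代码可运行)。

Python
from torch.utils.data import DataLoader, TensorDataset

# compute_metrics 的调用示意(玩具模型 + 随机数据,结果无实际意义)
torch.manual_seed(0)
toy_teacher = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 10))
toy_student = nn.Linear(32, 10)

dataset = TensorDataset(torch.randn(256, 32), torch.randint(0, 10, (256,)))
test_loader = DataLoader(dataset, batch_size=32)

metrics = DistillationEvaluator.compute_metrics(toy_teacher, toy_student, test_loader)
for k, v in metrics.items():
    print(f"{k}: {v:.4f}" if isinstance(v, float) else f"{k}: {v}")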
总结¶
知识蒸馏方法对比¶
| 方法 | 适用场景 | 优点 | 缺点 |
|---|---|---|---|
| 软标签蒸馏 | 通用场景 | 简单有效 | 需要调整温度 |
| 特征蒸馏 | 白盒场景 | 传递更多信息 | 需要维度对齐 |
| 注意力蒸馏 | Transformer | 捕捉结构信息 | 计算复杂 |
| 自蒸馏 | 无教师模型 | 不需要教师 | 效果可能较差 |
| 在线蒸馏 | 多个模型 | 互相提升 | 训练复杂 |
关键超参数¶
Text Only
温度 (Temperature):
├── 推荐范围: 2.0 - 8.0
├── 任务复杂 → 温度高
└── 类别多 → 温度高
蒸馏权重 (Alpha):
├── 推荐范围: 0.5 - 0.9
├── 早期: 高权重(学习教师)
└── 后期: 低权重(微调细节)
层映射:
├── 均匀映射: 通用
├── 最后k层: 任务相关
└── 学习映射: 效果最佳但复杂
最佳实践清单¶
- 选择合适的教师-学生模型尺寸比(通常2-10倍)
- 使用温度调优找到最佳温度
- 考虑多层蒸馏而不仅是输出
- 结合数据增强提升效果
- 动态调整蒸馏权重
- 评估时关注准确率、速度和大小三个维度
下一步:学习01-数据工程与预处理,进入系统与工程阶段!
最后更新日期:2026-02-12 适用版本:LLM学习教程 v2026