05 - 知识蒸馏¶
⚠️ 时效性说明:本章涉及前沿模型/价格/榜单等信息,可能随版本快速变化;请以论文原文、官方发布页和 API 文档为准。
学习目标:深入理解知识蒸馏的原理、方法与实践,掌握如何将大模型知识迁移到小模型。
目录¶
知识蒸馏概述¶
1.1 什么是知识蒸馏¶
Text Only
知识蒸馏(Knowledge Distillation)
核心思想:将大型教师模型(Teacher)的知识迁移到小型学生模型(Student)
为什么需要蒸馏?
├── 大模型性能好但推理慢、部署难
├── 小模型速度快但性能差
└── 蒸馏:让小模型学习大模型的"暗知识"
暗知识(Dark Knowledge):
├── 不仅学习正确标签
├── 还学习类别间的相似关系
└── 例如:狗→狼的概率比狗→汽车高
类比:
├── 教师:经验丰富的专家
├── 学生:初学者
└── 蒸馏:学生学习专家的思考方式,不只是答案
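上面的"暗知识"可以用一个极简数值例子直观验证(logits 为假设值,类别顺序假定为狗、狼、汽车):

```python
import math

def softmax_with_temperature(logits, T=1.0):
    """带温度的 softmax:T 越大,输出分布越平滑"""
    exps = [math.exp(z / T) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]

# 假设教师模型对一张"狗"的图片输出如下 logits(类别顺序:狗、狼、汽车)
logits = [8.0, 3.0, -2.0]

hard = softmax_with_temperature(logits, T=1.0)  # 接近 one-hot,类别间相似关系几乎不可见
soft = softmax_with_temperature(logits, T=4.0)  # 高温下"狗→狼 远大于 狗→汽车"被放大
```

one-hot 标签只告诉学生"这是狗",而软标签还告诉它"狼比汽车更像狗"——这正是暗知识的来源。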
1.2 蒸馏的基本框架¶
Text Only
┌─────────────────────────────────────────────────────────────────┐
│ 知识蒸馏框架 │
├─────────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────────────────┐ ┌─────────────────────┐ │
│ │ 教师模型 (T) │ │ 学生模型 (S) │ │
│ │ (大、复杂) │ │ (小、简单) │ │
│ │ │ │ │ │
│ │ Input ──▶ T ──▶ │ │ Input ──▶ S ──▶ │ │
│ │ Soft logits │ │ Soft logits │ │
│ │ │ │ │ │ │ │
│ │ ▼ │ │ ▼ │ │
│ │ Softmax │ │ Softmax │ │
│ │ (高温τ) │ │ (高温τ) │ │
│ │ │ │ │ │ │ │
│ │ ▼ │ │ ▼ │ │
│ │ Soft Targets │ │ Soft Predictions │ │
│ │ │ │ │ │ │ │
│ │ └───────────┼────────┘ │ │ │
│ │ │ │ │ │
│ │ ▼ │ │ │
│ │ KL Divergence │ │ │
│ │ (蒸馏损失) │ │ │
│ │ │ │ │ │
│ │ ▼ ▼ │ │
│ │ Hard Targets ───────▶ CE Loss │ │
│ │ (真实标签) (学生损失) │ │
│ │ │ │
│ └──────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ Total Loss = α·L_soft + β·L_hard │
│ │
└─────────────────────────────────────────────────────────────────┘
基础蒸馏方法¶
2.1 软标签蒸馏(Hinton Distillation)¶
Python
import torch
import torch.nn as nn
import torch.nn.functional as F
class KnowledgeDistillationLoss(nn.Module):
"""
Hinton知识蒸馏损失
"Distilling the Knowledge in a Neural Network" (Hinton et al., 2015)
"""
def __init__(self, temperature=4.0, alpha=0.7):
super().__init__() # super()调用父类方法
self.temperature = temperature
self.alpha = alpha # 蒸馏损失权重
def forward(self, student_logits, teacher_logits, true_labels):
"""
Args:
student_logits: 学生模型输出 [batch, num_classes]
teacher_logits: 教师模型输出 [batch, num_classes]
true_labels: 真实标签 [batch]
Returns:
total_loss: 总损失
distill_loss: 蒸馏损失(KL散度)
ce_loss: 交叉熵损失
"""
# 1. 软目标损失(KL散度)
soft_student = F.log_softmax(student_logits / self.temperature, dim=1)
soft_teacher = F.softmax(teacher_logits / self.temperature, dim=1)
# KL散度 = Σ p_teacher * log(p_teacher / p_student)
distill_loss = F.kl_div(
soft_student,
soft_teacher,
reduction='batchmean'
) * (self.temperature ** 2) # 缩放因子
# 2. 硬目标损失(交叉熵)
ce_loss = F.cross_entropy(student_logits, true_labels)
# 3. 总损失
total_loss = self.alpha * distill_loss + (1 - self.alpha) * ce_loss
return total_loss, distill_loss, ce_loss
# 温度参数的作用
def demonstrate_temperature():
"""
演示温度参数对软标签的影响
"""
import matplotlib.pyplot as plt
# 假设logits
logits = torch.tensor([2.0, 1.0, 0.1])
temperatures = [0.5, 1.0, 2.0, 4.0, 8.0]
plt.figure(figsize=(10, 6))
for T in temperatures:
probs = F.softmax(logits / T, dim=0)
plt.plot(probs.numpy(), label=f'T={T}', marker='o')
plt.xlabel('Class Index')
plt.ylabel('Probability')
plt.title('Effect of Temperature on Softmax Distribution')
plt.legend()
plt.grid(True)
plt.show()
"""
结论:
- T→0: 趋近于one-hot(硬标签)
- T→∞: 趋近于均匀分布
- 适中的T(2-4): 保留类别间关系,但不过于平滑
"""
# demonstrate_temperature()
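下面是一个不依赖上文类定义的最小组合示例(logits 为随机假设值),验证蒸馏损失的两个基本性质:KL 散度非负,且学生与教师输出一致时为 0:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
T, alpha = 4.0, 0.7
student_logits = torch.randn(8, 10)   # [batch=8, num_classes=10],假设值
teacher_logits = torch.randn(8, 10)
labels = torch.randint(0, 10, (8,))

# 蒸馏损失:KL(teacher ‖ student),乘 T^2 保持梯度量级
distill = F.kl_div(
    F.log_softmax(student_logits / T, dim=1),
    F.softmax(teacher_logits / T, dim=1),
    reduction='batchmean'
) * (T ** 2)
ce = F.cross_entropy(student_logits, labels)
total = alpha * distill + (1 - alpha) * ce

# 学生与教师 logits 完全一致时,KL 散度应为 0
zero_kl = F.kl_div(
    F.log_softmax(teacher_logits / T, dim=1),
    F.softmax(teacher_logits / T, dim=1),
    reduction='batchmean'
)
```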
2.2 特征蒸馏(Feature Distillation)¶
Python
class FeatureDistillationLoss(nn.Module):
"""
特征蒸馏
让学生模型学习教师模型的中间层特征表示
"""
def __init__(self, student_dim, teacher_dim, hidden_dim=None):
super().__init__()
if hidden_dim is None:
hidden_dim = student_dim
# 适配层:将学生特征映射到教师特征空间
self.adaptation_layer = nn.Sequential(
nn.Linear(student_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, teacher_dim)
)
def forward(self, student_features, teacher_features):
"""
Args:
student_features: [batch, student_dim]
teacher_features: [batch, teacher_dim]
"""
# 适配学生特征
adapted_student = self.adaptation_layer(student_features)
# 均方误差
loss = F.mse_loss(adapted_student, teacher_features)
return loss
class MultiLayerFeatureDistillation(nn.Module):
"""
多层特征蒸馏(FitNets)
在多个中间层进行特征蒸馏
"""
def __init__(self, student_dims, teacher_dims, weights=None):
super().__init__()
assert len(student_dims) == len(teacher_dims) # assert断言:条件False时抛出AssertionError
self.num_layers = len(student_dims)
# 为每层创建适配层
self.adaptation_layers = nn.ModuleList([
nn.Linear(s_dim, t_dim)
for s_dim, t_dim in zip(student_dims, teacher_dims) # zip按位置配对多个可迭代对象
])
# 每层的权重
if weights is None:
weights = [1.0] * self.num_layers
self.weights = weights
def forward(self, student_features_list, teacher_features_list):
"""
Args:
student_features_list: 学生模型多层特征
teacher_features_list: 教师模型多层特征
"""
total_loss = 0
for i, (s_feat, t_feat, adapt_layer) in enumerate(
zip(student_features_list, teacher_features_list, self.adaptation_layers)
):
# 适配并计算损失
adapted = adapt_layer(s_feat)
loss = F.mse_loss(adapted, t_feat)
total_loss += self.weights[i] * loss
return total_loss
2.3 注意力蒸馏¶
Python
class AttentionDistillationLoss(nn.Module):
"""
注意力蒸馏
让学生模型学习教师模型的注意力模式
"""
def __init__(self, temperature=1.0):
super().__init__()
self.temperature = temperature
def forward(self, student_attention, teacher_attention):
"""
Args:
student_attention: [batch, num_heads, seq_len, seq_len]
teacher_attention: [batch, num_heads, seq_len, seq_len]
"""
# 方法1: KL散度
student_attn = F.log_softmax(student_attention / self.temperature, dim=-1)
teacher_attn = F.softmax(teacher_attention / self.temperature, dim=-1)
kl_loss = F.kl_div(
student_attn,
teacher_attn,
reduction='batchmean'
)
# 方法2: 均方误差(有时效果更好)
mse_loss = F.mse_loss(student_attention, teacher_attention)
# 组合
total_loss = kl_loss + mse_loss
return total_loss
class MultiHeadAttentionTransfer(nn.Module):
"""
多头注意力迁移
TinyBERT等方法使用
"""
def __init__(self, student_num_heads, teacher_num_heads):
super().__init__()
self.student_num_heads = student_num_heads
self.teacher_num_heads = teacher_num_heads
def forward(self, student_attn_list, teacher_attn_list):
"""
对齐不同层数的注意力
"""
# 选择教师模型的特定层进行蒸馏
# 例如:学生6层,教师12层,用 linspace 均匀选取教师层(0-indexed: 第0,2,4,6,8,11层)
selected_teacher_layers = self.select_layers(
len(teacher_attn_list),
len(student_attn_list)
)
total_loss = 0
for student_idx, teacher_idx in enumerate(selected_teacher_layers): # enumerate同时获取索引和元素
s_attn = student_attn_list[student_idx]
t_attn = teacher_attn_list[teacher_idx]
# 如果头数不同,需要平均或插值
if s_attn.size(1) != t_attn.size(1):
t_attn = t_attn.mean(dim=1, keepdim=True).expand(-1, s_attn.size(1), -1, -1)
loss = F.mse_loss(s_attn, t_attn)
total_loss += loss
return total_loss / len(student_attn_list)
def select_layers(self, teacher_layers, student_layers):
"""
选择对应的教师层
"""
import numpy as np
indices = np.linspace(0, teacher_layers - 1, student_layers, dtype=int)
return indices.tolist()
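select_layers 的均匀层映射可以单独验证(注意返回的是 0-indexed 索引,首尾层总会被选中):

```python
import numpy as np

def select_layers(teacher_layers, student_layers):
    """均匀选择与学生各层对应的教师层索引(0-indexed),与上文逻辑相同"""
    return np.linspace(0, teacher_layers - 1, student_layers, dtype=int).tolist()

mapping_6_of_12 = select_layers(12, 6)   # 12层教师 -> 6层学生
mapping_4_of_12 = select_layers(12, 4)   # 12层教师 -> 4层学生
print(mapping_6_of_12, mapping_4_of_12)
```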
高级蒸馏技术¶
3.1 自蒸馏(Self-Distillation)¶
Python
class SelfDistillationLoss(nn.Module):
"""
自蒸馏
模型自己蒸馏自己,不需要单独的教师模型
"""
def __init__(self, temperature=4.0):
super().__init__()
self.temperature = temperature
def forward(self, logits, labels, epoch, total_epochs):
"""
随着训练进行,逐渐增加软标签的影响
"""
# 硬标签损失
hard_loss = F.cross_entropy(logits, labels)
# 生成软标签(使用当前模型自己的预测)
with torch.no_grad(): # 禁用梯度计算,节省内存(推理时使用)
soft_targets = F.softmax(logits / self.temperature, dim=1)
# 软标签损失
# ⚠️ 演示说明: 此处 soft_targets 与 logits 来自同一分布,理论上 KL(p‖p)≡0。
# 实际自蒸馏中,soft_targets 应来自上一个 epoch 保存的模型输出(即“历史版本”作为教师)。
# 这里仅作为代码结构示例。
soft_loss = F.kl_div(
F.log_softmax(logits / self.temperature, dim=1),
soft_targets,
reduction='batchmean'
) * (self.temperature ** 2)
# 动态权重:早期主要用硬标签,后期增加软标签
alpha = min(0.5, epoch / (total_epochs * 0.5))
total_loss = (1 - alpha) * hard_loss + alpha * soft_loss
return total_loss
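针对上面 ⚠️ 指出的问题,正确做法是用历史快照(上一个 epoch 保存的输出)作为软目标,并 detach 使其不参与当前反向传播。以下为最小示意(logits 为随机假设值):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
T = 4.0
# 假设:上一个 epoch 保存的模型输出快照("历史教师")
prev_epoch_logits = torch.randn(4, 5)
# 当前模型的输出与快照略有偏移(仅作模拟)
current_logits = prev_epoch_logits + 0.1 * torch.randn(4, 5)

# 关键:软目标来自历史快照并 detach,不参与当前反向传播
soft_targets = F.softmax(prev_epoch_logits.detach() / T, dim=1)
soft_loss = F.kl_div(
    F.log_softmax(current_logits / T, dim=1),
    soft_targets,
    reduction='batchmean'
) * (T ** 2)
```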
3.2 在线蒸馏(Online Distillation)¶
Python
class OnlineDistillation:
"""
在线蒸馏(Deep Mutual Learning)
多个学生模型互相学习,没有固定的教师
"""
def __init__(self, models, temperature=4.0):
self.models = models # 多个学生模型
self.temperature = temperature
self.optimizers = [torch.optim.AdamW(m.parameters()) for m in models]
def train_step(self, batch):
"""
所有模型互相蒸馏
"""
inputs, labels = batch
# 获取所有模型的预测
logits_list = [model(inputs) for model in self.models]
# 每个模型向其余模型的平均预测("软教师")学习
for i, (model, optimizer) in enumerate(zip(self.models, self.optimizers)):
optimizer.zero_grad()
# 硬标签损失
hard_loss = F.cross_entropy(logits_list[i], labels)
# 向其他模型学习(排除自己;detach 防止梯度流入其他模型的计算图,也避免重复反向传播报错)
other_logits = torch.stack(
[logits_list[j].detach() for j in range(len(self.models)) if j != i]
).mean(dim=0)
soft_loss = F.kl_div(
F.log_softmax(logits_list[i] / self.temperature, dim=1),
F.softmax(other_logits / self.temperature, dim=1),
reduction='batchmean'
) * (self.temperature ** 2)
# 总损失
loss = hard_loss + soft_loss
loss.backward()
optimizer.step()
3.3 任务特定蒸馏¶
Python
class TaskSpecificDistillation:
"""
任务特定蒸馏
针对不同任务设计特定的蒸馏策略
"""
@staticmethod # @staticmethod无需实例即可调用
def classification_distillation(student_logits, teacher_logits, labels, alpha=0.7, T=4.0):
"""
分类任务蒸馏
"""
soft_loss = F.kl_div(
F.log_softmax(student_logits / T, dim=1),
F.softmax(teacher_logits / T, dim=1),
reduction='batchmean'
) * (T ** 2)
hard_loss = F.cross_entropy(student_logits, labels)
return alpha * soft_loss + (1 - alpha) * hard_loss
@staticmethod
def ner_distillation(student_logits, teacher_logits, labels, attention_mask, alpha=0.7, T=4.0):
"""
命名实体识别任务蒸馏
每个token独立计算
"""
batch_size, seq_len, num_classes = student_logits.shape
# 展平
student_logits = student_logits.view(-1, num_classes) # view重塑张量形状
teacher_logits = teacher_logits.view(-1, num_classes)
labels = labels.view(-1)
attention_mask = attention_mask.view(-1)
# 只计算有效token
valid_indices = attention_mask.bool()
student_logits = student_logits[valid_indices]
teacher_logits = teacher_logits[valid_indices]
labels = labels[valid_indices]
soft_loss = F.kl_div(
F.log_softmax(student_logits / T, dim=1),
F.softmax(teacher_logits / T, dim=1),
reduction='batchmean'
) * (T ** 2)
hard_loss = F.cross_entropy(student_logits, labels)
return alpha * soft_loss + (1 - alpha) * hard_loss
@staticmethod
def generation_distillation(student_logits, teacher_logits, labels, alpha=0.7, T=4.0):
"""
生成任务蒸馏
序列生成任务的蒸馏
"""
# 移位:预测下一个token
shift_student = student_logits[..., :-1, :].contiguous()
shift_teacher = teacher_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# 展平
shift_student = shift_student.view(-1, shift_student.size(-1))
shift_teacher = shift_teacher.view(-1, shift_teacher.size(-1))
shift_labels = shift_labels.view(-1)
soft_loss = F.kl_div(
F.log_softmax(shift_student / T, dim=1),
F.softmax(shift_teacher / T, dim=1),
reduction='batchmean'
) * (T ** 2)
hard_loss = F.cross_entropy(shift_student, shift_labels)
return alpha * soft_loss + (1 - alpha) * hard_loss
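generation_distillation 中的移位与展平是常见出错点,可以用形状断言单独检查:

```python
import torch

torch.manual_seed(0)
batch, seq_len, vocab = 2, 6, 11
logits = torch.randn(batch, seq_len, vocab)
labels = torch.randint(0, vocab, (batch, seq_len))

# 移位:位置 t 的 logits 预测位置 t+1 的 token
shift_logits = logits[..., :-1, :].contiguous()   # [2, 5, 11]
shift_labels = labels[..., 1:].contiguous()       # [2, 5]

# 展平后共 batch*(seq_len-1) = 10 个预测位置
flat_logits = shift_logits.view(-1, vocab)
flat_labels = shift_labels.view(-1)
```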
白盒蒸馏与黑盒蒸馏¶
4.1 白盒蒸馏¶
Python
class WhiteBoxDistillation:
"""
白盒蒸馏
可以访问教师模型的内部参数和特征
"""
def __init__(self, teacher_model, student_model):
self.teacher_model = teacher_model
self.student_model = student_model
# 注册hook获取中间特征
self.teacher_features = []
self.student_features = []
self._register_hooks()
def _register_hooks(self):
"""
注册前向hook获取特征
"""
self.hook_handles = [] # 保存hook句柄,用于后续清理
def get_features(name, storage):
def hook(module, input, output):
storage.append(output)
return hook
# 在特定层注册hook
for i, layer in enumerate(self.teacher_model.transformer_layers):
handle = layer.register_forward_hook(
get_features(f'teacher_layer_{i}', self.teacher_features)
)
self.hook_handles.append(handle)
for i, layer in enumerate(self.student_model.transformer_layers):
handle = layer.register_forward_hook(
get_features(f'student_layer_{i}', self.student_features)
)
self.hook_handles.append(handle)
def forward(self, inputs, labels):
"""
前向传播并计算蒸馏损失
"""
# 教师前向
with torch.no_grad():
teacher_logits = self.teacher_model(inputs)
# 学生前向
student_logits = self.student_model(inputs)
# 1. 输出蒸馏
output_loss = F.kl_div(
F.log_softmax(student_logits / 4.0, dim=1),
F.softmax(teacher_logits / 4.0, dim=1),
reduction='batchmean'
) * 16  # T^2 = 4.0^2 缩放
# 2. 特征蒸馏
feature_loss = 0
for s_feat, t_feat in zip(self.student_features, self.teacher_features):
# 适配维度
if s_feat.shape != t_feat.shape:
t_feat = F.adaptive_avg_pool1d(
t_feat.transpose(1, 2),
s_feat.size(1)
).transpose(1, 2)
feature_loss += F.mse_loss(s_feat, t_feat)
# 3. 硬标签损失
ce_loss = F.cross_entropy(student_logits, labels)
# 总损失
total_loss = 0.5 * output_loss + 0.3 * feature_loss + 0.2 * ce_loss
# 清空特征缓存
self.teacher_features.clear()
self.student_features.clear()
return total_loss
4.2 黑盒蒸馏¶
Python
class BlackBoxDistillation:
"""
黑盒蒸馏
只能访问教师模型的输出(如API),无法访问内部参数
"""
def __init__(self, student_model, teacher_api=None):
self.student_model = student_model
self.teacher_api = teacher_api # 教师模型API接口
def get_teacher_predictions(self, inputs):
"""
通过API获取教师模型预测
实际应用中可能是:
- OpenAI API
- 本地教师模型推理
- 预计算的教师预测
"""
if self.teacher_api is None:
raise ValueError("Teacher API not provided")
# 调用API获取预测
teacher_logits = self.teacher_api.predict(inputs)
return teacher_logits
def train_step(self, batch):
"""
训练步骤
"""
inputs, labels = batch
# 获取教师预测(黑盒)
with torch.no_grad():
teacher_logits = self.get_teacher_predictions(inputs)
# 学生前向
student_logits = self.student_model(inputs)
# 只能使用输出蒸馏
distill_loss = F.kl_div(
F.log_softmax(student_logits / 4.0, dim=1),
F.softmax(teacher_logits / 4.0, dim=1),
reduction='batchmean'
) * 16
# 硬标签损失
ce_loss = F.cross_entropy(student_logits, labels)
# 总损失
total_loss = 0.7 * distill_loss + 0.3 * ce_loss
return total_loss
class DataFreeDistillation:
"""
无数据蒸馏
不需要原始训练数据,通过生成数据蒸馏
"""
def __init__(self, student_model, generator_model):
self.student_model = student_model
self.generator_model = generator_model # 数据生成器
def generate_synthetic_data(self, num_samples, sequence_length):
"""
生成合成数据
"""
# 使用生成器或随机生成
synthetic_inputs = self.generator_model.generate(
num_samples=num_samples,
sequence_length=sequence_length
)
return synthetic_inputs
def train_step(self, teacher_model):
"""
训练步骤
"""
# 生成合成数据
synthetic_inputs = self.generate_synthetic_data(
num_samples=32,
sequence_length=128
)
# 获取教师预测
with torch.no_grad():
teacher_logits = teacher_model(synthetic_inputs)
# 学生预测
student_logits = self.student_model(synthetic_inputs)
# 蒸馏损失
loss = F.kl_div(
F.log_softmax(student_logits / 4.0, dim=1),
F.softmax(teacher_logits / 4.0, dim=1),
reduction='batchmean'
) * 16
return loss
蒸馏与量化的结合¶
5.1 量化感知蒸馏¶
Python
class QuantizationAwareDistillation:
"""
量化感知蒸馏
在蒸馏过程中考虑量化误差
"""
def __init__(self, teacher_model, student_model, num_bits=8):
self.teacher_model = teacher_model
self.student_model = student_model
self.num_bits = num_bits
def quantize_tensor(self, tensor, num_bits):
"""
模拟量化
"""
# 计算缩放因子
qmin = -(2 ** (num_bits - 1))
qmax = 2 ** (num_bits - 1) - 1
min_val = tensor.min()
max_val = tensor.max()
scale = (max_val - min_val) / (qmax - qmin)
zero_point = qmin - min_val / scale
# 量化
quantized = torch.clamp(
torch.round(tensor / scale + zero_point),
qmin, qmax
)
# 反量化
dequantized = (quantized - zero_point) * scale
return dequantized
def forward(self, inputs, labels):
"""
前向传播
"""
# 教师预测(全精度)
with torch.no_grad():
teacher_logits = self.teacher_model(inputs)
# 学生预测(模拟量化)
student_logits = self.student_model(inputs)
# 模拟量化学生输出
quantized_student_logits = self.quantize_tensor(
student_logits, self.num_bits
)
# 蒸馏损失(使用量化后的输出)
distill_loss = F.kl_div(
F.log_softmax(quantized_student_logits / 4.0, dim=1),
F.softmax(teacher_logits / 4.0, dim=1),
reduction='batchmean'
) * 16
# 硬标签损失(使用原始输出)
ce_loss = F.cross_entropy(student_logits, labels)
return 0.7 * distill_loss + 0.3 * ce_loss
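上述 quantize_tensor 的模拟量化有一个可直接验证的性质:在 min/max 校准下,逐元素重建误差不超过 scale/2。以下为独立的数值检查(逻辑与上文一致):

```python
import torch

def fake_quantize(tensor, num_bits=8):
    """模拟量化-反量化,逻辑与上文 quantize_tensor 相同"""
    qmin = -(2 ** (num_bits - 1))
    qmax = 2 ** (num_bits - 1) - 1
    min_val, max_val = tensor.min(), tensor.max()
    scale = (max_val - min_val) / (qmax - qmin)
    zero_point = qmin - min_val / scale
    quantized = torch.clamp(torch.round(tensor / scale + zero_point), qmin, qmax)
    return (quantized - zero_point) * scale, scale.item()

torch.manual_seed(0)
x = torch.randn(1000)
x_hat, scale = fake_quantize(x, num_bits=8)
max_err = (x - x_hat).abs().max().item()  # 理论上不超过 scale / 2
```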
5.2 渐进式压缩¶
Python
class ProgressiveCompression:
"""
渐进式压缩
先蒸馏,再量化,最后剪枝
"""
def __init__(self, teacher_model):
self.teacher_model = teacher_model
self.student_model = None
def stage1_distillation(self, train_loader, num_epochs=10):
"""
第一阶段:知识蒸馏
"""
# 创建学生模型(比教师小)
self.student_model = self.create_student_model()
# 蒸馏训练
distillation_trainer = KnowledgeDistillationTrainer(
teacher=self.teacher_model,
student=self.student_model
)
for epoch in range(num_epochs):
for batch in train_loader:
loss = distillation_trainer.train_step(batch)
return self.student_model
def stage2_quantization(self, model, num_bits=8):
"""
第二阶段:量化
"""
quantized_model = torch.quantization.quantize_dynamic(
model,
{nn.Linear},
dtype=torch.qint8
)
return quantized_model
def stage3_pruning(self, model, sparsity=0.3):
"""
第三阶段:剪枝
"""
import torch.nn.utils.prune as prune
# 对线性层进行剪枝
for name, module in model.named_modules():
if isinstance(module, nn.Linear): # isinstance检查类型
prune.l1_unstructured(
module,
name='weight',
amount=sparsity
)
return model
def full_pipeline(self, train_loader):
"""
完整压缩流程
"""
# Stage 1: 蒸馏
print("Stage 1: Distillation...")
distilled_model = self.stage1_distillation(train_loader)
# Stage 2: 量化
print("Stage 2: Quantization...")
quantized_model = self.stage2_quantization(distilled_model)
# Stage 3: 剪枝
print("Stage 3: Pruning...")
final_model = self.stage3_pruning(quantized_model)
return final_model
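阶段3的剪枝效果可以用权重稀疏度直接核对。下面的小例子展示 l1_unstructured 按预期将约 30% 的权重置零(剪枝后原始权重保存在 weight_orig,掩码保存在 weight_mask 中):

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Linear(64, 64)

# L1非结构化剪枝:将绝对值最小的30%权重置零
prune.l1_unstructured(layer, name='weight', amount=0.3)

# layer.weight 现在等于 weight_orig * weight_mask
sparsity = (layer.weight == 0).float().mean().item()
print(f"稀疏度: {sparsity:.4f}")
```

若要使剪枝永久生效并移除掩码,可调用 prune.remove(layer, 'weight')。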
闭源模型蒸馏¶
6.1 API调用式蒸馏原理(黑盒蒸馏)¶
Python
from concurrent.futures import ThreadPoolExecutor
import torch
import torch.nn as nn
from openai import OpenAI  # OpenAI 官方 SDK
class BlackBoxDistillation:
"""
黑盒蒸馏(API调用式蒸馏)
适用于无法访问模型内部结构的闭源模型(如GPT-4、Claude等)
通过API调用获取教师模型的输出,仅利用输入-输出对进行蒸馏
"""
def __init__(self, api_key, model_name="gpt-4"):
self.api_key = api_key
self.model_name = model_name
self.client = OpenAI(api_key=api_key)
def generate_teacher_response(self, prompt, temperature=0.7, max_tokens=2048):
"""
通过API调用获取教师模型输出
Args:
prompt: 输入提示
temperature: 生成温度
max_tokens: 最大token数
Returns:
教师模型的响应内容
"""
response = self.client.chat.completions.create(
model=self.model_name,
messages=[{"role": "user", "content": prompt}],
temperature=temperature,
max_tokens=max_tokens
)
return response.choices[0].message.content
def batch_generate_distillation_data(self, prompts, batch_size=100):
"""
批量生成蒸馏数据
使用闭源API批量生成训练数据
"""
distillation_data = []
for i in range(0, len(prompts), batch_size):
batch = prompts[i:i + batch_size]
# 并发调用API
futures = []
with ThreadPoolExecutor(max_workers=10) as executor:
for prompt in batch:
future = executor.submit(
self.generate_teacher_response, prompt
)
futures.append((prompt, future))
# 收集结果
for prompt, future in futures:
try:
response = future.result(timeout=30)
distillation_data.append({
"input": prompt,
"teacher_output": response,
"label": self.extract_label(response)
})
except Exception as e:
print(f"Error generating response: {e}")
continue
return distillation_data
def extract_label(self, response):
"""
从教师输出中提取标签
对于分类任务,提取分类标签
对于生成任务,提取目标输出
"""
# 简单的标签提取逻辑,实际应用中需要根据任务定制
return response.strip()
def distill_with_api_data(self, student_model, distillation_data, epochs=3):
"""
使用API生成的数据蒸馏学生模型
"""
optimizer = torch.optim.Adam(student_model.parameters(), lr=1e-4)
for epoch in range(epochs):
for data in distillation_data:
# 学生模型前向传播
student_output = student_model(data["input"])
# 教师输出编码
teacher_encoding = self.encode_response(data["teacher_output"])
# 计算蒸馏损失
loss = self.compute_distillation_loss(
student_output, teacher_encoding
)
# 反向传播
optimizer.zero_grad()
loss.backward()
optimizer.step()
return student_model
def encode_response(self, response):
"""
将教师响应编码为训练信号
"""
# 实际应用中可能需要更复杂的编码逻辑
return response
def compute_distillation_loss(self, student_output, teacher_encoding):
"""
计算蒸馏损失
"""
# 简化版本,实际需要根据任务定制
return nn.functional.cross_entropy(
student_output,
teacher_encoding
)
6.2 思维链蒸馏技术(CoT Distillation)¶
Python
class ChainOfThoughtDistillation:
"""
思维链蒸馏(CoT Distillation)
将大模型的思维链推理能力迁移到小模型
核心思想:让学生模型学习教师的推理过程,而非仅学习答案
"""
def __init__(self, teacher_model, student_model):
self.teacher = teacher_model
self.student = student_model
def generate_cot_data(self, problems, use_cot_prompt=True):
"""
生成思维链训练数据
Args:
problems: 问题列表
use_cot_prompt: 是否使用CoT提示
Returns:
包含问题、思维链、答案的数据列表
"""
cot_data = []
for problem in problems:
# 使用教师模型生成思维链
if use_cot_prompt:
prompt = f"{problem}\n\n请逐步思考并展示你的推理过程。"
else:
prompt = problem
# 获取教师响应
teacher_response = self.teacher.generate(prompt)
# 解析思维链和答案
parsed = self.parse_cot_response(teacher_response)
cot_data.append({
"problem": problem,
"reasoning": parsed["reasoning"],
"answer": parsed["answer"]
})
return cot_data
def parse_cot_response(self, response):
"""
解析思维链响应
提取推理步骤和最终答案
"""
# 简单的解析逻辑
lines = response.split("\n")
reasoning_steps = []
answer = ""
for line in lines:
if "答案:" in line or "Answer:" in line:
answer = line.split("答案:")[-1].split("Answer:")[-1].strip()
else:
reasoning_steps.append(line)
return {
"reasoning": "\n".join(reasoning_steps),
"answer": answer
}
def distill_cot(self, cot_data, method="full_cot"):
"""
执行思维链蒸馏
Args:
cot_data: 思维链训练数据
method: 蒸馏方法
- "full_cot": 完整思维链蒸馏
- "answer_only": 仅蒸馏答案
- "summary_only": 仅蒸馏思维链摘要
"""
if method == "full_cot":
return self._distill_full_cot(cot_data)
elif method == "answer_only":
return self._distill_answer_only(cot_data)
elif method == "summary_only":
return self._distill_summary_only(cot_data)
def _distill_full_cot(self, cot_data):
"""
完整思维链蒸馏
学生模型学习完整的推理过程和最终答案
"""
total_loss = 0
for data in cot_data:
# 学生模型生成完整响应
student_output = self.student.generate(
f"{data['problem']}\n\n请逐步思考。"
)
# 计算响应级别的蒸馏损失
response_loss = self.compute_response_loss(
student_output,
data["reasoning"] + "\n\n答案:" + data["answer"]
)
# 计算token级别的损失
token_loss = self.compute_token_level_loss(
student_output,
data["reasoning"] + " " + data["answer"]
)
# 综合损失
loss = 0.5 * response_loss + 0.5 * token_loss
total_loss += loss
return total_loss / len(cot_data)
def _distill_answer_only(self, cot_data):
"""
仅蒸馏答案
简单有效,但丢失推理过程信息
"""
total_loss = 0
for data in cot_data:
# 仅使用答案作为标签
student_output = self.student.generate(data["problem"])
# 提取学生模型的答案
student_answer = self.extract_answer(student_output)
# 计算答案损失
loss = self.compute_answer_loss(
student_answer,
data["answer"]
)
total_loss += loss
return total_loss / len(cot_data)
def _distill_summary_only(self, cot_data):
"""
Summary-only CoT蒸馏
仅蒸馏思维链摘要,不包含完整推理细节
减少学生模型的学习负担,同时保留核心推理能力
"""
total_loss = 0
for data in cot_data:
# 生成思维链摘要
summary = self.summarize_reasoning(data["reasoning"])
# 学生学习摘要
student_summary = self.student.generate(
f"{data['problem']}\n\n请简要思考关键步骤。"
)
# 计算摘要级别损失
loss = self.compute_summary_loss(
student_summary,
summary
)
total_loss += loss
return total_loss / len(cot_data)
def summarize_reasoning(self, reasoning):
"""
生成思维链摘要
提取核心推理步骤,移除冗余信息
"""
# 简化版本,实际应用中可能需要更复杂的摘要逻辑
lines = reasoning.split("\n")
key_steps = []
for line in lines:
# 保留关键推理步骤
if any(keyword in line for keyword in ["因此", "所以", "第一步", "第二步", "关键"]):
key_steps.append(line)
return " ".join(key_steps)
def compute_response_loss(self, student_response, teacher_response):
"""
计算响应级别的损失
"""
# 使用文本嵌入的余弦相似度作为损失(1 - cos_sim)
student_emb = self.embed(student_response)
teacher_emb = self.embed(teacher_response)
return 1 - F.cosine_similarity(student_emb, teacher_emb, dim=0)
def compute_token_level_loss(self, student_output, teacher_output):
"""
计算token级别的损失
"""
student_tokens = self.tokenize(student_output)
teacher_tokens = self.tokenize(teacher_output)
# 按位置对齐token序列,教师分布作为软目标(cross_entropy 支持概率型目标)
aligned_loss = 0
for s_tok, t_tok in zip(student_tokens, teacher_tokens):
aligned_loss += nn.functional.cross_entropy(s_tok.unsqueeze(0), t_tok.softmax(dim=-1).unsqueeze(0))
return aligned_loss / max(len(student_tokens), len(teacher_tokens))
def compute_answer_loss(self, student_answer, teacher_answer):
"""
计算答案损失
"""
# 简化为文本匹配损失
return 0 if student_answer == teacher_answer else 1.0
def compute_summary_loss(self, student_summary, teacher_summary):
"""
计算摘要损失
"""
student_emb = self.embed(student_summary)
teacher_emb = self.embed(teacher_summary)
return 1 - F.cosine_similarity(student_emb, teacher_emb, dim=0)
def embed(self, text):
"""
文本嵌入
"""
# 实际应用中需要使用适当的嵌入模型
return torch.randn(768)
def tokenize(self, text):
"""
分词
"""
# 简化版本
return [torch.randn(50257) for _ in range(len(text.split()))]
def extract_answer(self, response):
"""
从响应中提取答案
"""
if "答案:" in response:
return response.split("答案:")[-1].strip()
elif "Answer:" in response:
return response.split("Answer:")[-1].strip()
return response.strip()
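parse_cot_response 的解析行为可以用一条构造的响应快速验证(示例文本为假设输入):

```python
def parse_cot_response(response):
    """与上文相同的解析逻辑:提取答案行,其余行视为推理步骤"""
    reasoning_steps, answer = [], ""
    for line in response.split("\n"):
        if "答案:" in line or "Answer:" in line:
            answer = line.split("答案:")[-1].split("Answer:")[-1].strip()
        else:
            reasoning_steps.append(line)
    return {"reasoning": "\n".join(reasoning_steps), "answer": answer}

sample = "第一步:设未知数x\n第二步:解方程得x=7\n答案:7"
parsed = parse_cot_response(sample)
```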
6.3 PRM与ORM在蒸馏中的应用¶
Python
class RewardGuidedDistillation:
"""
基于奖励模型的蒸馏
PRM(过程奖励模型):对每个推理步骤打分
ORM(结果奖励模型):对整个响应结果打分
在蒸馏过程中引入奖励信号,引导学生模型学习更好的推理路径
"""
def __init__(self, teacher_model, student_model, reward_model_type="prm"):
self.teacher = teacher_model
self.student = student_model
self.reward_model_type = reward_model_type
if reward_model_type == "prm":
self.reward_model = ProcessRewardModel()
else:
self.reward_model = OutcomeRewardModel()
def distill_with_rewards(self, training_data, use_rewards=True):
"""
使用奖励信号进行蒸馏
Args:
training_data: 训练数据
use_rewards: 是否使用奖励信号
"""
total_loss = 0
for data in training_data:
# 学生模型生成响应
student_output = self.student.generate(
data["problem"],
return_reasoning_steps=True
)
# 获取教师响应
teacher_output = self.teacher.generate(
data["problem"],
return_reasoning_steps=True
)
# 计算基础蒸馏损失
distill_loss = self.compute_distillation_loss(
student_output, teacher_output
)
if use_rewards:
# 计算奖励
if self.reward_model_type == "prm":
student_reward = self.compute_prm_reward(
student_output["reasoning_steps"]
)
teacher_reward = self.compute_prm_reward(
teacher_output["reasoning_steps"]
)
else:
student_reward = self.reward_model.score(
data["problem"], student_output["answer"]
)
teacher_reward = self.reward_model.score(
data["problem"], teacher_output["answer"]
)
# 奖励引导的蒸馏损失
reward_weight = abs(student_reward - teacher_reward)
total_loss += distill_loss * (1 + reward_weight)
else:
total_loss += distill_loss
return total_loss / len(training_data)
def compute_prm_reward(self, reasoning_steps):
"""
计算过程奖励
对每个推理步骤打分,综合得到总奖励
"""
step_rewards = []
for step in reasoning_steps:
reward = self.reward_model.score_step(step)
step_rewards.append(reward)
# 加权平均作为总奖励(靠后的步骤权重更高,并按权重和归一化)
weights = [0.1 * (i + 1) for i in range(len(step_rewards))]
total_reward = sum(w * r for w, r in zip(weights, step_rewards)) / max(sum(weights), 1e-8)
return total_reward
class ProcessRewardModel:
"""
过程奖励模型(PRM)
对每个推理步骤独立打分
"""
def __init__(self):
# 实际应用中需要加载预训练的PRM
self.model = None
def score_step(self, step):
"""
对单个推理步骤打分
Returns:
float: 步骤分数 [0, 1]
"""
# 简化实现
# 实际应用中需要使用真实的PRM模型
if any(keyword in step for keyword in ["正确", "对", "是的"]):
return 0.9
elif any(keyword in step for keyword in ["错误", "不对"]):
return 0.3
return 0.7
class OutcomeRewardModel:
"""
结果奖励模型(ORM)
对整个响应结果打分
"""
def __init__(self):
self.model = None
def score(self, problem, answer):
"""
对响应结果打分
Returns:
float: 响应分数 [0, 1]
"""
# 简化实现
# 实际应用中需要使用真实的ORM模型
ground_truth = self.extract_ground_truth(problem)
if answer == ground_truth:
return 1.0
elif self.is_partial_match(answer, ground_truth):
return 0.5
return 0.0
def extract_ground_truth(self, problem):
"""
从问题中提取标准答案
"""
# 简化实现
return ""
def is_partial_match(self, answer, ground_truth):
"""
判断部分匹配
"""
return False
class BestOfNDistillation:
"""
Best-of-N 蒸馏
生成N个候选响应,使用奖励模型选择最佳响应进行蒸馏
"""
def __init__(self, teacher_model, student_model, n_candidates=8):
self.teacher = teacher_model
self.student = student_model
self.n_candidates = n_candidates
self.reward_model = ProcessRewardModel()
def distill_best_of_n(self, problem):
"""
Best-of-N 蒸馏
1. 学生模型生成N个候选响应
2. 使用PRM/ORM选择最佳响应
3. 使用最佳响应进行蒸馏
"""
# 生成N个候选
candidates = []
for _ in range(self.n_candidates):
candidate = self.student.generate(
problem,
temperature=0.8,
return_reasoning_steps=True
)
candidates.append(candidate)
# 使用奖励模型选择最佳
best_candidate = self.select_best(candidates)
# 获取教师响应
teacher_response = self.teacher.generate(
problem,
return_reasoning_steps=True
)
# 计算蒸馏损失
loss = self.compute_loss(best_candidate, teacher_response)
return loss
def select_best(self, candidates):
"""
使用奖励模型选择最佳候选
"""
best_score = -float('inf')
best_candidate = None
for candidate in candidates:
# 计算候选的奖励分数:PRM 对每个推理步骤打分后取平均
if "reasoning_steps" in candidate:
steps = candidate["reasoning_steps"]
score = sum(self.reward_model.score_step(s) for s in steps) / max(len(steps), 1)
else:
score = 0.5 # 默认分数
if score > best_score:
best_score = score
best_candidate = candidate
return best_candidate
def compute_loss(self, student_output, teacher_output):
"""
计算蒸馏损失
"""
# 简化的损失计算
return 0.1
6.4 DeepSeek事件技术分析¶
Python
class DeepSeekDistillationAnalysis:
"""
DeepSeek蒸馏技术分析
2025年初DeepSeek-R1的发布引发了业界对模型蒸馏技术的广泛关注
以下分析基于DeepSeek官方技术报告和公开信息
"""
@staticmethod
def analyze_r1_distillation_approach():
"""
分析DeepSeek-R1的蒸馏方法
核心创新:
1. 使用强化学习训练大推理模型
2. 将推理能力蒸馏到小模型
3. 开源多个蒸馏版本
"""
analysis = """
DeepSeek-R1 蒸馏技术分析
════════════════════════════════════════════════════════════
1. 教师模型选择
├── DeepSeek-R1 (671B MoE) 作为教师
├── 蒸馏到多个小规模密集模型
└── Qwen系列和Llama系列作为学生基座
2. 蒸馏数据生成
├── 使用DeepSeek-R1生成高质量推理数据
├── 拒绝采样过滤低质量样本
└── 保留包含正确思维链的数据
3. 蒸馏训练方法
├── SFT(监督微调)作为主要方法
├── 学生模型学习教师的思维链格式
└── 不使用强化学习,仅用SFT
4. 蒸馏效果
├── DeepSeek-R1-Distill-Qwen-7B:
│ 在AIME 2024上胜过GPT-4o等通用模型
├── DeepSeek-R1-Distill-Qwen-32B:
│ 在AIME 2024上胜过o1-mini
└── 证明小模型可通过蒸馏获得推理能力
════════════════════════════════════════════════════════════
"""
return analysis
@staticmethod
def compare_with_openai_distillation():
"""
DeepSeek vs OpenAI 蒸馏方法对比
"""
comparison = """
蒸馏方法对比
════════════════════════════════════════════════════════════
| 维度 | DeepSeek | OpenAI |
├─────────────────────────────────────────────────────────┤
| 教师模型 | DeepSeek-R1 (开源) | GPT-4/o1 (闭源) |
| 蒸馏方式 | SFT + CoT数据 | API调用 + 黑盒蒸馏 |
| 思维链 | 原生支持 | 提示工程 |
| 开源程度 | 完全开源 | 不开源 |
| 成本 | 低(自托管) | 高(API费用) |
DeepSeek的优势:
1. 完全开源,可自由使用和修改
2. 思维链原生集成,推理能力强
3. 训练成本低,适合资源有限的团队
OpenAI的优势:
1. 模型能力领先(GPT-4系列)
2. API稳定可靠
3. 支持多模态
════════════════════════════════════════════════════════════
"""
return comparison
@staticmethod
def implementation_recommendations():
"""
蒸馏实践建议
基于DeepSeek的经验总结
"""
recommendations = """
蒸馏实践建议
════════════════════════════════════════════════════════════
1. 数据准备
✅ 使用高质量推理数据
✅ 过滤错误思维链样本
✅ 保持数据多样性
2. 蒸馏训练
✅ 使用SFT作为基础
✅ 添加格式遵循损失
✅ 渐进式训练(简单→复杂)
3. 评估验证
✅ 在多个基准上评估
✅ 关注思维链质量
✅ 对比原教师模型
4. 生产部署
✅ 量化压缩模型
✅ 优化推理速度
✅ 监控模型质量
════════════════════════════════════════════════════════════
"""
return recommendations
def implement_deepseek_style_distillation():
"""
实现DeepSeek风格的蒸馏流程
"""
distillation_pipeline = """
# DeepSeek风格蒸馏流程
1. 数据生成
```python
# 使用DeepSeek-R1生成推理数据
teacher = DeepSeekR1Model()
for problem in problems:
response = teacher.generate_with_cot(problem)
if verify_answer(response):
save_to_dataset(problem, response)
```
2. 数据过滤
```python
# 拒绝采样过滤
for sample in raw_data:
student_response = student.generate(sample.problem)
if student_response.answer == sample.answer:
# 保留学生也能正确回答的样本
keep(sample)
else:
# 丢弃学生无法正确回答的样本
discard(sample)
```
3. 蒸馏训练
```python
# SFT蒸馏
for epoch in range(num_epochs):
for batch in dataloader:
student_output = student(batch.problems)
# 匹配教师思维链格式
loss = compute_cot_loss(
student_output.reasoning,
batch.teacher_reasoning
)
student.backward(loss)
student.step()
```
4. 评估
```python
# 多基准评估
benchmarks = ["AIME", "MATH", "GSM8K", "HumanEval"]
for bench in benchmarks:
score = evaluate(student, bench)
print(f"{bench}: {score}")
```
"""
return distillation_pipeline
6.5 主流蒸馏框架介绍¶
Python
class DistillationFrameworkGuide:
"""
主流蒸馏框架对比与使用指南
2025-2026年主流开源蒸馏工具
"""
@staticmethod
def compare_frameworks():
"""
蒸馏框架对比
"""
comparison = """
蒸馏框架对比
════════════════════════════════════════════════════════════
1. DistillFlow (HorusAILabs)
├── 多策略蒸馏:logits、注意力、层蒸馏
├── 多GPU支持,动态资源分配
├── 支持Unsloth、Liger Kernel优化
└── 适合大规模蒸馏任务
2. DistillKit (Arcee AI)
├── 轻量级、易用
├── 支持多种蒸馏策略
├── 与HuggingFace生态集成
└── 适合快速实验
3. TextBrewer (哈工大讯飞联合实验室 HFL)
├── 通用蒸馏框架
├── 支持多种蒸馏配置
├── 多教师蒸馏支持
└── 适合研究用途
4. MiniMind
├── 完整训练流程
├── 单卡低成本
├── 适合教学和实践
└── 代码透明
════════════════════════════════════════════════════════════
"""
return comparison
@staticmethod
def get_distillflow_usage():
"""
DistillFlow使用示例
"""
usage = """
# DistillFlow 使用示例
```python
from distillflow import DistillFlow, DistillationConfig
# 配置蒸馏
config = DistillationConfig(
teacher_model="deepseek-ai/DeepSeek-R1",
student_model="Qwen/Qwen2.5-7B",
distillation_strategies=["logits", "attention"],
temperature=4.0,
alpha=0.7
)
# 初始化蒸馏器
df = DistillFlow(config)
# 执行蒸馏
df.distill(
train_data=train_dataset,
output_dir="./distilled_model"
)
```
"""
return usage
@staticmethod
def get_distillkit_usage():
"""
DistillKit使用示例
"""
usage = """
# DistillKit 使用示例
```python
from distillkit import DistillKit
# 初始化
dk = DistillKit(
teacher="gpt-4",
student="gpt-3.5-turbo",
method="cot_distillation"
)
# 生成蒸馏数据
dk.generate_distillation_data(
prompts=problem_list,
api_key=OPENAI_API_KEY
)
# 执行蒸馏
dk.distill(
epochs=3,
batch_size=8,
learning_rate=1e-4
)
```
"""
return usage
def create_distillation_pipeline():
    """
    Build a complete distillation pipeline
    that combines several distillation techniques
    """
    pipeline = """
Complete Distillation Pipeline
════════════════════════════════════════════════════════════
Stage 1: Data preparation
├── Collect a problem dataset
├── Generate CoT data with the teacher model
├── Filter and validate data quality
└── Data augmentation (if needed)
Stage 2: Distillation training
├── Choose a suitable distillation method
├── Configure the temperature and loss weights
├── Progressive training
└── Periodic evaluation on a validation set
Stage 3: Post-processing
├── Model quantization (INT8/INT4)
├── Format conversion (GGUF/ONNX)
└── Performance benchmarking
Recommended tools:
├── Data generation: DeepSeek API / OpenAI API
├── Distillation training: DistillFlow / DistillKit
├── Model quantization: llama.cpp / GPTQ
└── Evaluation: lm-evaluation-harness
════════════════════════════════════════════════════════════
"""
    return pipeline
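Stage 2 above hinges on the temperature-scaled soft-label loss. As a framework-free illustration (all logits and numbers below are made up for the example), the combined loss from the Hinton recipe earlier in this chapter can be computed by hand:

```python
import math

def softmax(logits, T=1.0):
    """Temperature-scaled softmax over a list of logits."""
    exps = [math.exp(z / T) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def kd_loss(student_logits, teacher_logits, label, T=4.0, alpha=0.7):
    """Hinton-style KD loss: alpha * T^2 * KL(teacher || student) + (1 - alpha) * CE."""
    p_t = softmax(teacher_logits, T)
    p_s = softmax(student_logits, T)
    # KL(teacher || student), scaled by T^2 to keep gradient magnitudes comparable
    kl = sum(pt * math.log(pt / ps) for pt, ps in zip(p_t, p_s))
    # Hard-label cross entropy at T = 1
    ce = -math.log(softmax(student_logits, 1.0)[label])
    return alpha * (T ** 2) * kl + (1 - alpha) * ce

loss = kd_loss([2.0, 1.0, 0.1], [3.0, 1.5, 0.2], label=0)
print(round(loss, 4))
```

When the student matches the teacher exactly, the KL term vanishes and only the weighted cross-entropy term remains, which is a handy sanity check for any implementation.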
Classic Cases and Best Practices¶
7.1 Classic Cases¶
Text Only
Classic Knowledge Distillation Cases
═══════════════════════════════════════════════════════════════════
1. DistilBERT (Hugging Face, 2019)
├── Teacher: BERT-base (110M parameters)
├── Student: DistilBERT (66M parameters, 40% smaller)
├── Method: soft-label distillation + hidden-state distillation + cosine embedding loss
├── Result: retains 97% of performance, 60% faster inference
└── Code: transformers.DistilBert
2. TinyBERT (Huawei, 2020)
├── Teacher: BERT-base (110M parameters)
├── Student: TinyBERT-4L (14.5M parameters, 4 layers, hidden size 312)
├── Method: three-level distillation over embeddings, Transformer layers, and predictions
├── Result: ~96% of BERT-base on GLUE at roughly 1/7.5 the size
└── Notable: two-stage distillation (general distillation + task-specific distillation)
3. MobileBERT (Google, 2020)
├── Teacher: BERT-large
├── Student: MobileBERT (25.3M parameters, 24 layers, inverted-bottleneck blocks)
├── Method: inverted-bottleneck architecture + feature distillation
├── Result: comparable to BERT-base, 4.3x faster
└── Notable: optimized specifically for mobile devices
4. MiniLM (Microsoft, 2020)
├── Method: deep self-attention distillation
├── Key idea: distill only the attention distributions and value relations
├── Result: a 6-layer model reaches ~99% of BERT-base
└── Advantage: does not need the teacher's hidden states
5. Black-box distillation from GPT-3
├── Teacher: GPT-3 (175B parameters), accessed only through its API
├── Students: various smaller models
├── Method: black-box distillation (training on API outputs only)
└── Applications: community instruction-following models trained on GPT-generated data
═══════════════════════════════════════════════════════════════════
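The compression figures quoted in the case list are easy to verify; a quick arithmetic check (parameter counts in millions, taken from the entries above):

```python
def compression_stats(teacher_params_m, student_params_m):
    """Return (parameter-reduction fraction, compression ratio)."""
    reduction = (teacher_params_m - student_params_m) / teacher_params_m
    ratio = teacher_params_m / student_params_m
    return reduction, ratio

# DistilBERT: BERT-base 110M -> 66M
red, ratio = compression_stats(110, 66)
print(f"DistilBERT: {red:.0%} fewer params, {ratio:.1f}x smaller")
# → DistilBERT: 40% fewer params, 1.7x smaller

# TinyBERT-4L vs BERT-base: 110M -> 14.5M (the "roughly 1/7.5" figure)
red, ratio = compression_stats(110, 14.5)
print(f"TinyBERT:   {red:.0%} fewer params, {ratio:.1f}x smaller")
# → TinyBERT:   87% fewer params, 7.6x smaller
```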
7.2 Best Practices¶
Python
class DistillationBestPractices:
    """
    Knowledge distillation best practices
    """

    @staticmethod
    def temperature_tuning(teacher_model, student_model, val_loader):
        """
        Temperature tuning: run a short distillation trial at each
        candidate temperature and keep the best one
        """
        temperatures = [1.0, 2.0, 4.0, 6.0, 8.0, 10.0]
        best_temp = 4.0
        best_acc = 0.0
        for T in temperatures:
            # Measure validation accuracy at temperature T
            # (evaluate_with_temperature is assumed to wrap one epoch of
            # distillation training followed by evaluation)
            acc = evaluate_with_temperature(
                teacher_model, student_model, val_loader, T
            )
            if acc > best_acc:
                best_acc = acc
                best_temp = T
        return best_temp

    @staticmethod
    def layer_mapping_strategy(teacher_layers, student_layers):
        """
        Layer-mapping strategies:
        decide which teacher layer each student layer learns from
        """
        import numpy as np
        strategies = {
            'uniform': np.linspace(0, teacher_layers - 1, student_layers, dtype=int).tolist(),
            'first_k': list(range(student_layers)),
            'last_k': list(range(teacher_layers - student_layers, teacher_layers)),
            'skip': list(range(0, teacher_layers, teacher_layers // student_layers))[:student_layers]
        }
        return strategies

    @staticmethod
    def dynamic_alpha_scheduler(epoch, total_epochs, alpha_range=(0.3, 0.7)):
        """
        Dynamically adjust the distillation weight:
        emphasize distillation early, ground-truth labels late
        """
        alpha_min, alpha_max = alpha_range
        # Linear decay
        alpha = alpha_max - (alpha_max - alpha_min) * (epoch / total_epochs)
        return alpha

    @staticmethod
    def data_augmentation_for_distillation(raw_data, teacher_model, num_augmentations=5):
        """
        Data augmentation for distillation:
        use the teacher model to label additional training data
        """
        augmented_data = []
        for sample in raw_data:
            # Keep the original sample
            augmented_data.append(sample)
            # Generate variants
            for _ in range(num_augmentations):
                # Noise injection, back-translation, synonym replacement, etc.
                # (augment_sample is assumed to be defined elsewhere)
                variant = augment_sample(sample)
                # Label the variant with the teacher
                with torch.no_grad():
                    teacher_pred = teacher_model(variant)
                augmented_data.append({
                    'input': variant,
                    'teacher_logits': teacher_pred,
                    'label': sample['label']
                })
        return augmented_data
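To make the alpha schedule concrete, here is a standalone copy of the linear decay used by `dynamic_alpha_scheduler` above, together with the values it produces over a 10-epoch run:

```python
def dynamic_alpha(epoch, total_epochs, alpha_range=(0.3, 0.7)):
    """Linear decay of the distillation weight (same formula as above)."""
    alpha_min, alpha_max = alpha_range
    return alpha_max - (alpha_max - alpha_min) * (epoch / total_epochs)

# Starts at 0.7 (lean on the teacher), ends at 0.3 (lean on hard labels)
schedule = [round(dynamic_alpha(e, 10), 2) for e in range(11)]
print(schedule)
# → [0.7, 0.66, 0.62, 0.58, 0.54, 0.5, 0.46, 0.42, 0.38, 0.34, 0.3]
```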
Evaluating Distillation Results¶
8.1 Evaluation Metrics¶
Python
class DistillationEvaluator:
    """
    Evaluate the outcome of distillation
    """

    @staticmethod
    def compute_metrics(teacher_model, student_model, test_loader):
        """
        Compute distillation-quality metrics
        """
        metrics = {}
        # 1. Accuracy comparison
        teacher_acc = evaluate_accuracy(teacher_model, test_loader)
        student_acc = evaluate_accuracy(student_model, test_loader)
        metrics['teacher_accuracy'] = teacher_acc
        metrics['student_accuracy'] = student_acc
        metrics['accuracy_retention'] = student_acc / teacher_acc
        # 2. Inference-speed comparison (speeds are throughputs in
        #    batches/second, so speedup = student speed / teacher speed)
        teacher_speed = measure_inference_speed(teacher_model, test_loader)
        student_speed = measure_inference_speed(student_model, test_loader)
        metrics['teacher_speed'] = teacher_speed
        metrics['student_speed'] = student_speed
        metrics['speedup'] = student_speed / teacher_speed
        # 3. Model-size comparison
        teacher_params = count_parameters(teacher_model)
        student_params = count_parameters(student_model)
        metrics['teacher_params'] = teacher_params
        metrics['student_params'] = student_params
        metrics['compression_ratio'] = teacher_params / student_params
        # 4. Prediction agreement
        agreement = compute_prediction_agreement(
            teacher_model, student_model, test_loader
        )
        metrics['prediction_agreement'] = agreement
        # 5. Output-distribution similarity
        kl_div = compute_average_kl_divergence(
            teacher_model, student_model, test_loader
        )
        metrics['kl_divergence'] = kl_div
        return metrics

    @staticmethod
    def visualize_distillation_effect(teacher_logits, student_logits, labels):
        """
        Visualize the effect of distillation
        """
        import matplotlib.pyplot as plt
        # 1. Confidence-distribution comparison
        teacher_conf = F.softmax(teacher_logits, dim=1).max(dim=1)[0]
        student_conf = F.softmax(student_logits, dim=1).max(dim=1)[0]
        plt.figure(figsize=(12, 4))
        plt.subplot(1, 3, 1)
        plt.hist(teacher_conf.cpu().numpy(), bins=50, alpha=0.5, label='Teacher')
        plt.hist(student_conf.cpu().numpy(), bins=50, alpha=0.5, label='Student')
        plt.xlabel('Confidence')
        plt.ylabel('Count')
        plt.legend()
        plt.title('Confidence Distribution')
        # 2. Prediction agreement
        plt.subplot(1, 3, 2)
        teacher_preds = teacher_logits.argmax(dim=1)
        student_preds = student_logits.argmax(dim=1)
        agreement = (teacher_preds == student_preds).float()
        plt.bar(['Disagree', 'Agree'],
                [1 - agreement.mean().item(), agreement.mean().item()])
        plt.ylabel('Proportion')
        plt.title('Prediction Agreement')
        # 3. Per-class accuracy comparison
        plt.subplot(1, 3, 3)
        num_classes = teacher_logits.size(1)
        teacher_acc_per_class = []
        student_acc_per_class = []
        for c in range(num_classes):
            mask = (labels == c)
            if mask.sum() > 0:
                teacher_acc_per_class.append(
                    (teacher_preds[mask] == labels[mask]).float().mean().item()
                )
                student_acc_per_class.append(
                    (student_preds[mask] == labels[mask]).float().mean().item()
                )
        x = range(len(teacher_acc_per_class))
        plt.plot(x, teacher_acc_per_class, label='Teacher', marker='o')
        plt.plot(x, student_acc_per_class, label='Student', marker='s')
        plt.xlabel('Class')
        plt.ylabel('Accuracy')
        plt.legend()
        plt.title('Per-Class Accuracy')
        plt.tight_layout()
        plt.show()
def evaluate_accuracy(model, data_loader):
    """Evaluate classification accuracy."""
    model.eval()
    device = next(model.parameters()).device
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in data_loader:
            inputs, labels = batch
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            preds = outputs.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    return correct / total

def measure_inference_speed(model, data_loader, num_batches=100):
    """Measure inference throughput in batches per second."""
    import time
    model.eval()
    device = next(model.parameters()).device
    with torch.no_grad():
        # Warm-up (fills caches, triggers lazy CUDA initialization)
        for i, batch in enumerate(data_loader):
            if i >= 10:
                break
            inputs, _ = batch
            _ = model(inputs.to(device))  # .to(device) moves data onto the GPU/CPU
        # Timed run
        if device.type == 'cuda':
            torch.cuda.synchronize()
        start_time = time.time()
        measured = 0
        for i, batch in enumerate(data_loader):
            if i >= num_batches:
                break
            inputs, _ = batch
            _ = model(inputs.to(device))
            measured += 1
        if device.type == 'cuda':
            torch.cuda.synchronize()
        end_time = time.time()
    return measured / (end_time - start_time)  # batches per second

def count_parameters(model):
    """Count the model's parameters."""
    return sum(p.numel() for p in model.parameters())

def compute_prediction_agreement(teacher_model, student_model, data_loader):
    """Fraction of samples on which teacher and student predict the same class."""
    teacher_model.eval()
    student_model.eval()
    agreements = []
    with torch.no_grad():
        for batch in data_loader:
            inputs, _ = batch
            teacher_outputs = teacher_model(inputs)
            student_outputs = student_model(inputs)
            teacher_preds = teacher_outputs.argmax(dim=1)
            student_preds = student_outputs.argmax(dim=1)
            agreements.append((teacher_preds == student_preds).float())
    return torch.cat(agreements).mean().item()

def compute_average_kl_divergence(teacher_model, student_model, data_loader):
    """Average KL(teacher || student) over the data loader."""
    teacher_model.eval()
    student_model.eval()
    kl_divs = []
    with torch.no_grad():
        for batch in data_loader:
            inputs, _ = batch
            teacher_outputs = teacher_model(inputs)
            student_outputs = student_model(inputs)
            teacher_probs = F.softmax(teacher_outputs, dim=1)
            student_log_probs = F.log_softmax(student_outputs, dim=1)
            kl = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean')
            kl_divs.append(kl.item())
    return sum(kl_divs) / len(kl_divs)
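Note the argument order in the KL metric above: `F.kl_div(student_log_probs, teacher_probs)` takes the student's log-probabilities first, yet the quantity it computes is the divergence of the student *from* the teacher, KL(teacher ‖ student). A torch-free sketch of the same quantity on toy distributions:

```python
import math

def kl_divergence(p, q):
    """KL(p || q) for discrete distributions; mirrors what
    F.kl_div(q_log_probs, p) computes per sample above."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

teacher = [0.7, 0.2, 0.1]
aligned = [0.68, 0.22, 0.10]   # student close to the teacher
shifted = [0.10, 0.20, 0.70]   # student disagrees with the teacher

print(kl_divergence(teacher, aligned))  # small
print(kl_divergence(teacher, shifted))  # much larger
```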
Summary¶
Comparison of Knowledge Distillation Methods¶
| Method | Use case | Pros | Cons |
|---|---|---|---|
| Soft-label distillation | General purpose | Simple and effective | Temperature needs tuning |
| Feature distillation | White-box settings | Transfers richer information | Requires dimension alignment |
| Attention distillation | Transformers | Captures structural information | Computationally heavier |
| Self-distillation | No teacher available | Needs no teacher | Results may be weaker |
| Online distillation | Multiple models | Models improve each other | Complex training |
Key Hyperparameters¶
Text Only
Temperature:
├── Recommended range: 2.0 - 8.0
├── Harder task → higher temperature
└── More classes → higher temperature
Distillation weight (Alpha):
├── Recommended range: 0.5 - 0.9
├── Early training: high weight (learn from the teacher)
└── Late training: low weight (fine-tune on hard labels)
Layer mapping:
├── Uniform mapping: general purpose
├── Last k layers: task-specific
└── Learned mapping: best results but complex
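The temperature guidance above can be seen directly in numbers: raising T flattens the softmax and exposes the inter-class similarities ("dark knowledge") described at the start of the chapter. A self-contained check (the logits are illustrative):

```python
import math

def softmax(logits, T):
    """Temperature-scaled softmax, rounded for display."""
    exps = [math.exp(z / T) for z in logits]
    total = sum(exps)
    return [round(e / total, 3) for e in exps]

logits = [6.0, 2.0, 1.0]       # e.g. dog, wolf, car
print(softmax(logits, T=1.0))   # sharp: the top class takes nearly all the mass
print(softmax(logits, T=4.0))   # soft: the dog-wolf similarity becomes visible
```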
Best-Practice Checklist¶
- Pick a sensible teacher-to-student size ratio (typically 2-10x)
- Tune the temperature to find the best value
- Consider multi-layer distillation, not just output logits
- Combine with data augmentation for better results
- Adjust the distillation weight dynamically during training
- Evaluate along all three axes: accuracy, speed, and model size
Next step: continue with 01 - Data Engineering and Preprocessing (01-数据工程与预处理) to enter the systems and engineering stage!
Last updated: 2026-03-26 · Applies to: LLM Study Tutorial v2026