
05 - 知识蒸馏(全面版)

⚠️ 时效性说明:本章涉及前沿模型/价格/榜单等信息,可能随版本快速变化;请以论文原文、官方发布页和 API 文档为准。

学习目标:深入理解知识蒸馏的原理、方法与实践,掌握如何将大模型知识迁移到小模型。


目录

  1. 知识蒸馏概述
  2. 基础蒸馏方法
  3. 高级蒸馏技术
  4. 白盒蒸馏与黑盒蒸馏
  5. 蒸馏与量化的结合
  6. 经典案例与最佳实践
  7. 蒸馏效果评估

知识蒸馏概述

1.1 什么是知识蒸馏

Text Only
知识蒸馏(Knowledge Distillation)

核心思想:将大型教师模型(Teacher)的知识迁移到小型学生模型(Student)

为什么需要蒸馏?
├── 大模型性能好但推理慢、部署难
├── 小模型速度快但性能差
└── 蒸馏:让小模型学习大模型的"暗知识"

暗知识(Dark Knowledge):
├── 不仅学习正确标签
├── 还学习类别间的相似关系
└── 例如:狗→狼的概率比狗→汽车高

类比:
├── 教师:经验丰富的专家
├── 学生:初学者
└── 蒸馏:学生学习专家的思考方式,不只是答案
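
下面用一段极简的代码示意什么是"暗知识":同样的教师 logits,软标签(带温度的概率分布)比硬标签(one-hot)多出"狗和狼更相似"这样的类间关系信息。示例中的类别与数值均为虚构,仅作演示。

Python
import torch
import torch.nn.functional as F

# 假设教师模型对一张"狗"的图片输出如下 logits(类别顺序: 狗、狼、汽车;数值为虚构示例)
teacher_logits = torch.tensor([6.0, 3.0, -2.0])

hard_label = torch.tensor([1.0, 0.0, 0.0])            # 硬标签: 只包含"是狗"这一个信息
soft_label = F.softmax(teacher_logits / 4.0, dim=0)   # 软标签(温度 T=4)

print(hard_label)   # tensor([1., 0., 0.])
print(soft_label)   # 约为 [0.62, 0.29, 0.08]:狗→狼 的概率明显高于 狗→汽车,这就是"暗知识"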

1.2 蒸馏的基本框架

Text Only
知识蒸馏框架
═══════════════════════════════════════════════════════════════════

  教师模型 T (大、复杂)                  学生模型 S (小、简单)
  Input ──▶ T ──▶ logits                Input ──▶ S ──▶ logits
                  │                                     │
                  ▼                                     ▼
          Softmax (高温 τ)                      Softmax (高温 τ)
                  │                                     │
                  ▼                                     ▼
            Soft Targets ──────────┐   ┌──────── Soft Predictions
                                   ▼   ▼
                       KL Divergence(蒸馏损失 L_soft)

   Hard Targets (真实标签) + 学生 logits (τ=1) ──▶ CE Loss(学生损失 L_hard)

                 Total Loss = α·L_soft + β·L_hard

═══════════════════════════════════════════════════════════════════
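
上图中的总损失可以写得更精确一些(即 Hinton 蒸馏的标准形式,对应下一节的实现;蒸馏项乘以 τ² 是为了平衡梯度量级):

Text Only
L_soft  = τ² · KL( softmax(z_T / τ) ‖ softmax(z_S / τ) )    # z_T、z_S 为教师/学生 logits
L_hard  = CrossEntropy( y, softmax(z_S) )                   # 温度取 τ = 1
L_total = α · L_soft + β · L_hard                           # 常取 β = 1 - α,与 2.1 节代码一致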

基础蒸馏方法

2.1 软标签蒸馏(Hinton Distillation)

Python
import torch
import torch.nn as nn
import torch.nn.functional as F

class KnowledgeDistillationLoss(nn.Module):
    """
    Hinton知识蒸馏损失

    "Distilling the Knowledge in a Neural Network" (Hinton et al., 2015)
    """
    def __init__(self, temperature=4.0, alpha=0.7):
        super().__init__()  # super()调用父类方法
        self.temperature = temperature
        self.alpha = alpha  # 蒸馏损失权重

    def forward(self, student_logits, teacher_logits, true_labels):
        """
        Args:
            student_logits: 学生模型输出 [batch, num_classes]
            teacher_logits: 教师模型输出 [batch, num_classes]
            true_labels: 真实标签 [batch]

        Returns:
            total_loss: 总损失
            distill_loss: 蒸馏损失(KL散度)
            ce_loss: 交叉熵损失
        """
        # 1. 软目标损失(KL散度)
        soft_student = F.log_softmax(student_logits / self.temperature, dim=1)
        soft_teacher = F.softmax(teacher_logits / self.temperature, dim=1)

        # KL散度 = Σ p_teacher * log(p_teacher / p_student)
        distill_loss = F.kl_div(
            soft_student,
            soft_teacher,
            reduction='batchmean'
        ) * (self.temperature ** 2)  # 缩放因子

        # 2. 硬目标损失(交叉熵)
        ce_loss = F.cross_entropy(student_logits, true_labels)

        # 3. 总损失
        total_loss = self.alpha * distill_loss + (1 - self.alpha) * ce_loss

        return total_loss, distill_loss, ce_loss

# 温度参数的作用
def demonstrate_temperature():
    """
    演示温度参数对软标签的影响
    """
    import matplotlib.pyplot as plt

    # 假设logits
    logits = torch.tensor([2.0, 1.0, 0.1])

    temperatures = [0.5, 1.0, 2.0, 4.0, 8.0]

    plt.figure(figsize=(10, 6))
    for T in temperatures:
        probs = F.softmax(logits / T, dim=0)
        plt.plot(probs.numpy(), label=f'T={T}', marker='o')

    plt.xlabel('Class Index')
    plt.ylabel('Probability')
    plt.title('Effect of Temperature on Softmax Distribution')
    plt.legend()
    plt.grid(True)
    plt.show()

    """
    结论:
    - T→0: 趋近于one-hot(硬标签)
    - T→∞: 趋近于均匀分布
    - 适中的T(2-4): 保留类别间关系,但不过于平滑
    """

# demonstrate_temperature()
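
下面给出一个使用上述 KnowledgeDistillationLoss 的最小训练循环示意;其中 teacher_model、student_model、train_loader 均为假设已定义的对象,仅演示调用方式。

Python
# 最小蒸馏训练循环示意(teacher_model / student_model / train_loader 为假设对象)
criterion = KnowledgeDistillationLoss(temperature=4.0, alpha=0.7)
optimizer = torch.optim.AdamW(student_model.parameters(), lr=1e-4)

teacher_model.eval()    # 教师模型固定,只做前向
student_model.train()

for inputs, labels in train_loader:
    with torch.no_grad():                        # 教师前向不需要梯度
        teacher_logits = teacher_model(inputs)

    student_logits = student_model(inputs)

    total_loss, distill_loss, ce_loss = criterion(
        student_logits, teacher_logits, labels
    )

    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()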

2.2 特征蒸馏(Feature Distillation)

Python
class FeatureDistillationLoss(nn.Module):
    """
    特征蒸馏

    让学生模型学习教师模型的中间层特征表示
    """
    def __init__(self, student_dim, teacher_dim, hidden_dim=None):
        super().__init__()

        if hidden_dim is None:
            hidden_dim = student_dim

        # 适配层:将学生特征映射到教师特征空间
        self.adaptation_layer = nn.Sequential(
            nn.Linear(student_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, teacher_dim)
        )

    def forward(self, student_features, teacher_features):
        """
        Args:
            student_features: [batch, student_dim]
            teacher_features: [batch, teacher_dim]
        """
        # 适配学生特征
        adapted_student = self.adaptation_layer(student_features)

        # 均方误差
        loss = F.mse_loss(adapted_student, teacher_features)

        return loss

class MultiLayerFeatureDistillation(nn.Module):
    """
    多层特征蒸馏(FitNets)

    在多个中间层进行特征蒸馏
    """
    def __init__(self, student_dims, teacher_dims, weights=None):
        super().__init__()

        assert len(student_dims) == len(teacher_dims)  # assert断言:条件False时抛出AssertionError

        self.num_layers = len(student_dims)

        # 为每层创建适配层
        self.adaptation_layers = nn.ModuleList([
            nn.Linear(s_dim, t_dim)
            for s_dim, t_dim in zip(student_dims, teacher_dims)  # zip按位置配对多个可迭代对象
        ])

        # 每层的权重
        if weights is None:
            weights = [1.0] * self.num_layers
        self.weights = weights

    def forward(self, student_features_list, teacher_features_list):
        """
        Args:
            student_features_list: 学生模型多层特征
            teacher_features_list: 教师模型多层特征
        """
        total_loss = 0

        for i, (s_feat, t_feat, adapt_layer) in enumerate(
            zip(student_features_list, teacher_features_list, self.adaptation_layers)
        ):
            # 适配并计算损失
            adapted = adapt_layer(s_feat)
            loss = F.mse_loss(adapted, t_feat)
            total_loss += self.weights[i] * loss

        return total_loss
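
一个简单的使用示意:假设学生有4层、隐藏维度312,从教师(隐藏维度768)中选出4个对应层做特征蒸馏;维度与特征均为假设值,特征用随机张量代替真实中间层输出。

Python
# MultiLayerFeatureDistillation 使用示意(维度为假设值)
feature_kd = MultiLayerFeatureDistillation(
    student_dims=[312, 312, 312, 312],
    teacher_dims=[768, 768, 768, 768],
    weights=[1.0, 1.0, 1.0, 1.0]
)

batch_size = 8
student_feats = [torch.randn(batch_size, 312) for _ in range(4)]
teacher_feats = [torch.randn(batch_size, 768) for _ in range(4)]

loss = feature_kd(student_feats, teacher_feats)
print(loss)  # 一个标量损失,可与输出蒸馏损失加权求和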

2.3 注意力蒸馏

Python
class AttentionDistillationLoss(nn.Module):
    """
    注意力蒸馏

    让学生模型学习教师模型的注意力模式
    """
    def __init__(self, temperature=1.0):
        super().__init__()
        self.temperature = temperature

    def forward(self, student_attention, teacher_attention):
        """
        Args:
            student_attention: [batch, num_heads, seq_len, seq_len]
            teacher_attention: [batch, num_heads, seq_len, seq_len]
        """
        # 方法1: KL散度
        student_attn = F.log_softmax(student_attention / self.temperature, dim=-1)
        teacher_attn = F.softmax(teacher_attention / self.temperature, dim=-1)

        kl_loss = F.kl_div(
            student_attn,
            teacher_attn,
            reduction='batchmean'
        )

        # 方法2: 均方误差(有时效果更好)
        mse_loss = F.mse_loss(student_attention, teacher_attention)

        # 组合
        total_loss = kl_loss + mse_loss

        return total_loss

class MultiHeadAttentionTransfer(nn.Module):
    """
    多头注意力迁移

    TinyBERT等方法使用
    """
    def __init__(self, student_num_heads, teacher_num_heads):
        super().__init__()
        self.student_num_heads = student_num_heads
        self.teacher_num_heads = teacher_num_heads

    def forward(self, student_attn_list, teacher_attn_list):
        """
        对齐不同层数的注意力
        """
        # 选择教师模型的特定层进行蒸馏
        # 例如:学生6层,教师12层,选择教师第2,4,6,8,10,12层
        selected_teacher_layers = self.select_layers(
            len(teacher_attn_list),
            len(student_attn_list)
        )

        total_loss = 0
        for student_idx, teacher_idx in enumerate(selected_teacher_layers):  # enumerate同时获取索引和元素
            s_attn = student_attn_list[student_idx]
            t_attn = teacher_attn_list[teacher_idx]

            # 如果头数不同,需要平均或插值
            if s_attn.size(1) != t_attn.size(1):
                t_attn = t_attn.mean(dim=1, keepdim=True).expand(-1, s_attn.size(1), -1, -1)

            loss = F.mse_loss(s_attn, t_attn)
            total_loss += loss

        return total_loss / len(student_attn_list)

    def select_layers(self, teacher_layers, student_layers):
        """
        选择对应的教师层
        """
        import numpy as np
        indices = np.linspace(0, teacher_layers - 1, student_layers, dtype=int)
        return indices.tolist()
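
举例说明 select_layers 的均匀映射结果(教师12层、学生6层):

Python
import numpy as np

# 教师12层、学生6层时的均匀层映射示意
indices = np.linspace(0, 12 - 1, 6, dtype=int)
print(indices.tolist())  # [0, 2, 4, 6, 8, 11]:学生第0~5层分别对齐教师第0、2、4、6、8、11层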

高级蒸馏技术

3.1 自蒸馏(Self-Distillation)

Python
class SelfDistillationLoss(nn.Module):
    """
    自蒸馏

    模型自己蒸馏自己,不需要单独的教师模型
    """
    def __init__(self, temperature=4.0):
        super().__init__()
        self.temperature = temperature

    def forward(self, logits, labels, epoch, total_epochs):
        """
        随着训练进行,逐渐增加软标签的影响
        """
        # 硬标签损失
        hard_loss = F.cross_entropy(logits, labels)

        # 生成软标签(使用当前模型自己的预测)
        with torch.no_grad():  # 禁用梯度计算,节省内存(推理时使用)
            soft_targets = F.softmax(logits / self.temperature, dim=1)

        # 软标签损失
        # ⚠️ 演示说明: 此处 soft_targets 与 logits 来自同一分布,理论上 KL(p‖p)≡0。
        # 实际自蒸馏中,soft_targets 应来自上一个 epoch 保存的模型输出(即“历史版本”作为教师)。
        # 这里仅作为代码结构示例。
        soft_loss = F.kl_div(
            F.log_softmax(logits / self.temperature, dim=1),
            soft_targets,
            reduction='batchmean'
        ) * (self.temperature ** 2)

        # 动态权重:早期主要用硬标签,后期增加软标签
        alpha = min(0.5, epoch / (total_epochs * 0.5))

        total_loss = (1 - alpha) * hard_loss + alpha * soft_loss

        return total_loss
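
结合上面的说明,下面是一个修正后的自蒸馏训练示意:每个 epoch 结束时保存一份模型快照,下一个 epoch 用该快照的输出作为软目标。其中 model、train_loader、total_epochs 均为假设对象,copy.deepcopy 只是保存快照的一种简单做法。

Python
import copy

# 自蒸馏训练示意:用上一个 epoch 的模型快照作为"教师"
T = 4.0
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
prev_model = None  # 第一个 epoch 还没有"历史教师",只用硬标签

for epoch in range(total_epochs):
    for inputs, labels in train_loader:
        logits = model(inputs)
        loss = F.cross_entropy(logits, labels)

        if prev_model is not None:
            with torch.no_grad():
                teacher_logits = prev_model(inputs)  # 历史快照的输出作为软目标
            soft_loss = F.kl_div(
                F.log_softmax(logits / T, dim=1),
                F.softmax(teacher_logits / T, dim=1),
                reduction='batchmean'
            ) * (T ** 2)
            alpha = min(0.5, epoch / (total_epochs * 0.5))  # 与上面类中的动态权重一致
            loss = (1 - alpha) * loss + alpha * soft_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    prev_model = copy.deepcopy(model).eval()  # 保存当前 epoch 的快照,作为下一轮的"教师"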

3.2 在线蒸馏(Online Distillation)

Python
class OnlineDistillation:
    """
    在线蒸馏(Deep Mutual Learning)

    多个学生模型互相学习,没有固定的教师
    """
    def __init__(self, models, temperature=4.0):
        self.models = models  # 多个学生模型
        self.temperature = temperature
        self.optimizers = [torch.optim.AdamW(m.parameters()) for m in models]

    def train_step(self, batch):
        """
        所有模型互相蒸馏
        """
        inputs, labels = batch

        # 获取所有模型的预测
        logits_list = [model(inputs) for model in self.models]

        # 每个模型向其他模型的平均预测学习(没有固定教师)
        for i, (model, optimizer) in enumerate(zip(self.models, self.optimizers)):
            optimizer.zero_grad()

            # 硬标签损失
            hard_loss = F.cross_entropy(logits_list[i], labels)

            # 其他模型的平均预测作为软目标(排除自己)
            # detach: 软目标不参与反向传播,否则各模型共享计算图,第二次 backward 会报错
            other_logits = torch.stack(
                [logits_list[j].detach() for j in range(len(self.models)) if j != i]
            ).mean(dim=0)

            soft_loss = F.kl_div(
                F.log_softmax(logits_list[i] / self.temperature, dim=1),
                F.softmax(other_logits / self.temperature, dim=1),
                reduction='batchmean'
            ) * (self.temperature ** 2)

            # 总损失
            loss = hard_loss + soft_loss
            loss.backward()
            optimizer.step()

3.3 任务特定蒸馏

Python
class TaskSpecificDistillation:
    """
    任务特定蒸馏

    针对不同任务设计特定的蒸馏策略
    """

    @staticmethod  # @staticmethod无需实例即可调用
    def classification_distillation(student_logits, teacher_logits, labels, alpha=0.7, T=4.0):
        """
        分类任务蒸馏
        """
        soft_loss = F.kl_div(
            F.log_softmax(student_logits / T, dim=1),
            F.softmax(teacher_logits / T, dim=1),
            reduction='batchmean'
        ) * (T ** 2)

        hard_loss = F.cross_entropy(student_logits, labels)

        return alpha * soft_loss + (1 - alpha) * hard_loss

    @staticmethod
    def ner_distillation(student_logits, teacher_logits, labels, attention_mask, alpha=0.7, T=4.0):
        """
        命名实体识别任务蒸馏

        每个token独立计算
        """
        batch_size, seq_len, num_classes = student_logits.shape

        # 展平
        student_logits = student_logits.view(-1, num_classes)  # view重塑张量形状
        teacher_logits = teacher_logits.view(-1, num_classes)
        labels = labels.view(-1)
        attention_mask = attention_mask.view(-1)

        # 只计算有效token
        valid_indices = attention_mask.bool()
        student_logits = student_logits[valid_indices]
        teacher_logits = teacher_logits[valid_indices]
        labels = labels[valid_indices]

        soft_loss = F.kl_div(
            F.log_softmax(student_logits / T, dim=1),
            F.softmax(teacher_logits / T, dim=1),
            reduction='batchmean'
        ) * (T ** 2)

        hard_loss = F.cross_entropy(student_logits, labels)

        return alpha * soft_loss + (1 - alpha) * hard_loss

    @staticmethod
    def generation_distillation(student_logits, teacher_logits, labels, alpha=0.7, T=4.0):
        """
        生成任务蒸馏

        序列生成任务的蒸馏
        """
        # 移位:预测下一个token
        shift_student = student_logits[..., :-1, :].contiguous()
        shift_teacher = teacher_logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()

        # 展平
        shift_student = shift_student.view(-1, shift_student.size(-1))
        shift_teacher = shift_teacher.view(-1, shift_teacher.size(-1))
        shift_labels = shift_labels.view(-1)

        soft_loss = F.kl_div(
            F.log_softmax(shift_student / T, dim=1),
            F.softmax(shift_teacher / T, dim=1),
            reduction='batchmean'
        ) * (T ** 2)

        hard_loss = F.cross_entropy(shift_student, shift_labels)

        return alpha * soft_loss + (1 - alpha) * hard_loss
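
以生成任务蒸馏为例,下面用随机张量演示各输入的形状(batch=2、seq_len=5、vocab=100,数值无实际意义,仅演示接口):

Python
# generation_distillation 的形状示意(随机数据)
batch_size, seq_len, vocab_size = 2, 5, 100

student_logits = torch.randn(batch_size, seq_len, vocab_size)
teacher_logits = torch.randn(batch_size, seq_len, vocab_size)
labels = torch.randint(0, vocab_size, (batch_size, seq_len))

loss = TaskSpecificDistillation.generation_distillation(
    student_logits, teacher_logits, labels, alpha=0.7, T=4.0
)
print(loss)  # 标量: 0.7 * 软目标KL + 0.3 * 交叉熵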

白盒蒸馏与黑盒蒸馏

4.1 白盒蒸馏

Python
class WhiteBoxDistillation:
    """
    白盒蒸馏

    可以访问教师模型的内部参数和特征
    """
    def __init__(self, teacher_model, student_model):
        self.teacher_model = teacher_model
        self.student_model = student_model

        # 注册hook获取中间特征
        self.teacher_features = []
        self.student_features = []

        self._register_hooks()

    def _register_hooks(self):
        """
        注册前向hook获取特征
        """
        self.hook_handles = []  # 保存hook句柄,用于后续清理

        def get_features(name, storage):
            def hook(module, input, output):
                storage.append(output)
            return hook

        # 在特定层注册hook
        for i, layer in enumerate(self.teacher_model.transformer_layers):
            handle = layer.register_forward_hook(
                get_features(f'teacher_layer_{i}', self.teacher_features)
            )
            self.hook_handles.append(handle)

        for i, layer in enumerate(self.student_model.transformer_layers):
            handle = layer.register_forward_hook(
                get_features(f'student_layer_{i}', self.student_features)
            )
            self.hook_handles.append(handle)

    def forward(self, inputs, labels):
        """
        前向传播并计算蒸馏损失
        """
        # 教师前向
        with torch.no_grad():
            teacher_logits = self.teacher_model(inputs)

        # 学生前向
        student_logits = self.student_model(inputs)

        # 1. 输出蒸馏
        output_loss = F.kl_div(
            F.log_softmax(student_logits / 4.0, dim=1),
            F.softmax(teacher_logits / 4.0, dim=1),
            reduction='batchmean'
        ) * 16  # T=4.0,乘以 T^2 = 16

        # 2. 特征蒸馏
        feature_loss = 0
        for s_feat, t_feat in zip(self.student_features, self.teacher_features):
            # 适配维度
            if s_feat.shape != t_feat.shape:
                t_feat = F.adaptive_avg_pool1d(
                    t_feat.transpose(1, 2),
                    s_feat.size(1)
                ).transpose(1, 2)

            feature_loss += F.mse_loss(s_feat, t_feat)

        # 3. 硬标签损失
        ce_loss = F.cross_entropy(student_logits, labels)

        # 总损失
        total_loss = 0.5 * output_loss + 0.3 * feature_loss + 0.2 * ce_loss

        # 清空特征缓存
        self.teacher_features.clear()
        self.student_features.clear()

        return total_loss

4.2 黑盒蒸馏

Python
class BlackBoxDistillation:
    """
    黑盒蒸馏

    只能访问教师模型的输出(如API),无法访问内部参数
    """
    def __init__(self, student_model, teacher_api=None):
        self.student_model = student_model
        self.teacher_api = teacher_api  # 教师模型API接口

    def get_teacher_predictions(self, inputs):
        """
        通过API获取教师模型预测

        实际应用中可能是:
        - OpenAI API
        - 本地教师模型推理
        - 预计算的教师预测
        """
        if self.teacher_api is None:
            raise ValueError("Teacher API not provided")

        # 调用API获取预测
        teacher_logits = self.teacher_api.predict(inputs)

        return teacher_logits

    def train_step(self, batch):
        """
        训练步骤
        """
        inputs, labels = batch

        # 获取教师预测(黑盒)
        with torch.no_grad():
            teacher_logits = self.get_teacher_predictions(inputs)

        # 学生前向
        student_logits = self.student_model(inputs)

        # 只能使用输出蒸馏
        distill_loss = F.kl_div(
            F.log_softmax(student_logits / 4.0, dim=1),
            F.softmax(teacher_logits / 4.0, dim=1),
            reduction='batchmean'
        ) * 16

        # 硬标签损失
        ce_loss = F.cross_entropy(student_logits, labels)

        # 总损失
        total_loss = 0.7 * distill_loss + 0.3 * ce_loss

        return total_loss

class DataFreeDistillation:
    """
    无数据蒸馏

    不需要原始训练数据,通过生成数据蒸馏
    """
    def __init__(self, student_model, generator_model):
        self.student_model = student_model
        self.generator_model = generator_model  # 数据生成器

    def generate_synthetic_data(self, num_samples, sequence_length):
        """
        生成合成数据
        """
        # 使用生成器或随机生成
        synthetic_inputs = self.generator_model.generate(
            num_samples=num_samples,
            sequence_length=sequence_length
        )

        return synthetic_inputs

    def train_step(self, teacher_model):
        """
        训练步骤
        """
        # 生成合成数据
        synthetic_inputs = self.generate_synthetic_data(
            num_samples=32,
            sequence_length=128
        )

        # 获取教师预测
        with torch.no_grad():
            teacher_logits = teacher_model(synthetic_inputs)

        # 学生预测
        student_logits = self.student_model(synthetic_inputs)

        # 蒸馏损失
        loss = F.kl_div(
            F.log_softmax(student_logits / 4.0, dim=1),
            F.softmax(teacher_logits / 4.0, dim=1),
            reduction='batchmean'
        ) * 16

        return loss
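
补充:实际的商用大模型 API 一般只返回文本(最多提供 top-k logprobs),拿不到完整 logits,因此对 LLM 的黑盒蒸馏通常退化为"用教师生成的文本做监督微调"(序列级蒸馏)。下面是一个极简示意,其中 teacher_api.generate、tokenizer、student_model 均为假设对象,模型接口按 Hugging Face 风格假定。

Python
# 序列级黑盒蒸馏示意:用教师 API 的文本输出作为伪标签(接口均为假设)
def build_distillation_dataset(prompts, teacher_api):
    """调用教师 API,为每个 prompt 收集一条教师回答作为训练目标"""
    dataset = []
    for prompt in prompts:
        teacher_answer = teacher_api.generate(prompt)  # 只能拿到文本,拿不到 logits
        dataset.append({'prompt': prompt, 'target': teacher_answer})
    return dataset

def train_on_pseudo_labels(student_model, tokenizer, dataset, optimizer):
    """用教师生成的文本对学生做标准的自回归监督微调"""
    student_model.train()
    for sample in dataset:
        text = sample['prompt'] + sample['target']
        batch = tokenizer(text, return_tensors='pt')
        outputs = student_model(**batch, labels=batch['input_ids'])  # HF风格:内部计算移位交叉熵
        optimizer.zero_grad()
        outputs.loss.backward()
        optimizer.step()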

蒸馏与量化的结合

5.1 量化感知蒸馏

Python
class QuantizationAwareDistillation:
    """
    量化感知蒸馏

    在蒸馏过程中考虑量化误差
    """
    def __init__(self, teacher_model, student_model, num_bits=8):
        self.teacher_model = teacher_model
        self.student_model = student_model
        self.num_bits = num_bits

    def quantize_tensor(self, tensor, num_bits):
        """
        模拟量化
        """
        # 计算缩放因子
        qmin = -(2 ** (num_bits - 1))
        qmax = 2 ** (num_bits - 1) - 1

        min_val = tensor.min()
        max_val = tensor.max()

        scale = (max_val - min_val) / (qmax - qmin)
        zero_point = qmin - min_val / scale

        # 量化
        quantized = torch.clamp(
            torch.round(tensor / scale + zero_point),
            qmin, qmax
        )

        # 反量化
        dequantized = (quantized - zero_point) * scale

        # 直通估计器(STE):前向使用量化后的值,反向让梯度直接穿过 round/clamp,
        # 否则 torch.round 的梯度处处为0,蒸馏损失将无法更新学生模型
        return tensor + (dequantized - tensor).detach()

    def forward(self, inputs, labels):
        """
        前向传播
        """
        # 教师预测(全精度)
        with torch.no_grad():
            teacher_logits = self.teacher_model(inputs)

        # 学生预测(模拟量化)
        student_logits = self.student_model(inputs)

        # 模拟量化学生输出
        quantized_student_logits = self.quantize_tensor(
            student_logits, self.num_bits
        )

        # 蒸馏损失(使用量化后的输出)
        distill_loss = F.kl_div(
            F.log_softmax(quantized_student_logits / 4.0, dim=1),
            F.softmax(teacher_logits / 4.0, dim=1),
            reduction='batchmean'
        ) * 16

        # 硬标签损失(使用原始输出)
        ce_loss = F.cross_entropy(student_logits, labels)

        return 0.7 * distill_loss + 0.3 * ce_loss

5.2 渐进式压缩

Python
class ProgressiveCompression:
    """
    渐进式压缩

    先蒸馏,再剪枝,最后量化(动态量化会替换 nn.Linear,之后无法再剪枝)
    """
    def __init__(self, teacher_model):
        self.teacher_model = teacher_model
        self.student_model = None

    def stage1_distillation(self, train_loader, num_epochs=10):
        """
        第一阶段:知识蒸馏
        """
        # 创建学生模型(比教师小);create_student_model 为示意性的工厂方法,需按需实现
        self.student_model = self.create_student_model()

        # 蒸馏训练;KnowledgeDistillationTrainer 为示意性的训练器封装(可参考2.1节的损失实现)
        distillation_trainer = KnowledgeDistillationTrainer(
            teacher=self.teacher_model,
            student=self.student_model
        )

        for epoch in range(num_epochs):
            for batch in train_loader:
                loss = distillation_trainer.train_step(batch)

        return self.student_model

    def stage2_pruning(self, model, sparsity=0.3):
        """
        第二阶段:剪枝(在全精度模型上进行)
        """
        import torch.nn.utils.prune as prune

        # 对线性层进行L1非结构化剪枝
        for name, module in model.named_modules():
            if isinstance(module, nn.Linear):  # isinstance检查类型
                prune.l1_unstructured(
                    module,
                    name='weight',
                    amount=sparsity
                )

        return model

    def stage3_quantization(self, model, num_bits=8):
        """
        第三阶段:量化

        注意:动态量化会把 nn.Linear 替换为量化算子,替换后无法再用 prune 剪枝,
        因此量化放在剪枝之后。
        """
        quantized_model = torch.quantization.quantize_dynamic(
            model,
            {nn.Linear},
            dtype=torch.qint8
        )

        return quantized_model

    def full_pipeline(self, train_loader):
        """
        完整压缩流程
        """
        # Stage 1: 蒸馏
        print("Stage 1: Distillation...")
        distilled_model = self.stage1_distillation(train_loader)

        # Stage 2: 剪枝
        print("Stage 2: Pruning...")
        pruned_model = self.stage2_pruning(distilled_model)

        # Stage 3: 量化
        print("Stage 3: Quantization...")
        final_model = self.stage3_quantization(pruned_model)

        return final_model

经典案例与最佳实践

6.1 经典案例

Text Only
经典知识蒸馏案例
═══════════════════════════════════════════════════════════════════

1. DistilBERT (Hugging Face, 2019)
├── 教师: BERT-base (110M参数)
├── 学生: DistilBERT (66M参数,减少40%)
├── 方法: 软标签蒸馏 + 掩码语言模型(MLM)损失 + 隐藏状态余弦嵌入损失
├── 效果: 保留97%性能,推理速度提升60%
└── 代码: transformers.DistilBert

2. TinyBERT (华为, 2020)
├── 教师: BERT-base (110M参数)
├── 学生: TinyBERT-4L (14.5M参数,4层312维)
├── 方法: 嵌入层 + Transformer层 + 预测层 三层蒸馏
├── 效果: 在GLUE上达到BERT-base的96%,体积仅约1/7.5
└── 特点: 两阶段蒸馏(通用蒸馏+任务蒸馏)

3. MobileBERT (Google, 2020)
├── 教师: BERT-large
├── 学生: MobileBERT (25.3M参数,24层,倒瓶颈结构)
├── 方法: 倒瓶颈结构 + 特征蒸馏
├── 效果: 与BERT-base相当,体积小4.3倍、速度快5.5倍
└── 特点: 专为移动设备优化

4. MiniLM (微软, 2020)
├── 方法: 深度自注意力蒸馏
├── 关键: 只蒸馏最后一层的自注意力分布与 Value 间的关系(value relation)
├── 效果: 6层学生模型可保留BERT-base约99%的性能
└── 优势: 不需要教师隐藏状态

5. 基于 GPT-3 等大模型 API 的黑盒蒸馏
├── 教师: GPT-3 (175B参数) 等闭源大模型
├── 学生: 各种开源小模型
├── 方法: 黑盒蒸馏(仅使用API返回的文本输出构造训练数据)
└── 应用: Alpaca 等指令跟随小模型的训练

═══════════════════════════════════════════════════════════════════

6.2 最佳实践

Python
class DistillationBestPractices:
    """
    知识蒸馏最佳实践
    """

    @staticmethod
    def temperature_tuning(teacher_model, student_model, val_loader):
        """
        温度参数调优
        """
        temperatures = [1.0, 2.0, 4.0, 6.0, 8.0, 10.0]
        best_temp = 4.0
        best_acc = 0.0

        for T in temperatures:
            # 用当前温度做蒸馏训练并在验证集上评估学生模型
            # (evaluate_with_temperature 为示意性的辅助函数,需自行实现)
            acc = evaluate_with_temperature(
                teacher_model, student_model, val_loader, T
            )

            if acc > best_acc:
                best_acc = acc
                best_temp = T

        return best_temp

    @staticmethod
    def layer_mapping_strategy(teacher_layers, student_layers):
        """
        层映射策略

        决定学生模型的哪层学习教师模型的哪层
        """
        import numpy as np

        strategies = {
            'uniform': np.linspace(0, teacher_layers-1, student_layers, dtype=int),
            'first_k': list(range(student_layers)),
            'last_k': list(range(teacher_layers - student_layers, teacher_layers)),
            'skip': list(range(0, teacher_layers, teacher_layers // student_layers))[:student_layers]
        }

        return strategies

    @staticmethod
    def dynamic_alpha_scheduler(epoch, total_epochs, alpha_range=(0.3, 0.7)):
        """
        动态调整蒸馏权重

        早期侧重蒸馏,后期侧重真实标签
        """
        alpha_min, alpha_max = alpha_range

        # 线性递减
        alpha = alpha_max - (alpha_max - alpha_min) * (epoch / total_epochs)

        return alpha

    @staticmethod
    def data_augmentation_for_distillation(raw_data, teacher_model, num_augmentations=5):
        """
        数据增强辅助蒸馏

        使用教师模型生成更多训练数据
        """
        augmented_data = []

        for sample in raw_data:
            # 原始样本
            augmented_data.append(sample)

            # 生成变体
            for _ in range(num_augmentations):
                # 添加噪声、回译、同义词替换等(augment_sample 为示意性的增强函数)
                variant = augment_sample(sample)

                # 使用教师标注
                with torch.no_grad():
                    teacher_pred = teacher_model(variant)

                augmented_data.append({
                    'input': variant,
                    'teacher_logits': teacher_pred,
                    'label': sample['label']
                })

        return augmented_data

蒸馏效果评估

7.1 评估指标

Python
class DistillationEvaluator:
    """
    蒸馏效果评估
    """

    @staticmethod
    def compute_metrics(teacher_model, student_model, test_loader):
        """
        计算蒸馏效果指标
        """
        metrics = {}

        # 1. 准确率对比
        teacher_acc = evaluate_accuracy(teacher_model, test_loader)
        student_acc = evaluate_accuracy(student_model, test_loader)

        metrics['teacher_accuracy'] = teacher_acc
        metrics['student_accuracy'] = student_acc
        metrics['accuracy_retention'] = student_acc / teacher_acc

        # 2. 推理速度对比
        teacher_speed = measure_inference_speed(teacher_model, test_loader)
        student_speed = measure_inference_speed(student_model, test_loader)

        metrics['teacher_speed'] = teacher_speed
        metrics['student_speed'] = student_speed
        metrics['speedup'] = student_speed / teacher_speed  # 学生相对教师的加速比(吞吐之比)

        # 3. 模型大小对比
        teacher_params = count_parameters(teacher_model)
        student_params = count_parameters(student_model)

        metrics['teacher_params'] = teacher_params
        metrics['student_params'] = student_params
        metrics['compression_ratio'] = teacher_params / student_params

        # 4. 预测一致性
        agreement = compute_prediction_agreement(
            teacher_model, student_model, test_loader
        )
        metrics['prediction_agreement'] = agreement

        # 5. 输出分布相似度
        kl_div = compute_average_kl_divergence(
            teacher_model, student_model, test_loader
        )
        metrics['kl_divergence'] = kl_div

        return metrics

    @staticmethod
    def visualize_distillation_effect(teacher_logits, student_logits, labels):
        """
        可视化蒸馏效果
        """
        import matplotlib.pyplot as plt

        # 1. 置信度分布对比
        teacher_conf = F.softmax(teacher_logits, dim=1).max(dim=1)[0]
        student_conf = F.softmax(student_logits, dim=1).max(dim=1)[0]

        plt.figure(figsize=(12, 4))

        plt.subplot(1, 3, 1)
        plt.hist(teacher_conf.cpu().numpy(), bins=50, alpha=0.5, label='Teacher')
        plt.hist(student_conf.cpu().numpy(), bins=50, alpha=0.5, label='Student')
        plt.xlabel('Confidence')
        plt.ylabel('Count')
        plt.legend()
        plt.title('Confidence Distribution')

        # 2. 预测一致性矩阵
        plt.subplot(1, 3, 2)
        teacher_preds = teacher_logits.argmax(dim=1)
        student_preds = student_logits.argmax(dim=1)

        agreement = (teacher_preds == student_preds).float()
        plt.bar(['Disagree', 'Agree'], [1-agreement.mean(), agreement.mean()])
        plt.ylabel('Proportion')
        plt.title('Prediction Agreement')

        # 3. 正确率对比(按类别)
        plt.subplot(1, 3, 3)
        num_classes = teacher_logits.size(1)
        teacher_acc_per_class = []
        student_acc_per_class = []

        for c in range(num_classes):
            mask = (labels == c)
            if mask.sum() > 0:
                teacher_acc_per_class.append(
                    (teacher_preds[mask] == labels[mask]).float().mean().item()
                )
                student_acc_per_class.append(
                    (student_preds[mask] == labels[mask]).float().mean().item()
                )

        x = range(len(teacher_acc_per_class))
        plt.plot(x, teacher_acc_per_class, label='Teacher', marker='o')
        plt.plot(x, student_acc_per_class, label='Student', marker='s')
        plt.xlabel('Class')
        plt.ylabel('Accuracy')
        plt.legend()
        plt.title('Per-Class Accuracy')

        plt.tight_layout()
        plt.show()

def evaluate_accuracy(model, data_loader):
    """评估准确率"""
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for batch in data_loader:
            inputs, labels = batch
            outputs = model(inputs)
            preds = outputs.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    return correct / total

def measure_inference_speed(model, data_loader, num_batches=100):
    """测量推理速度"""
    import time

    model.eval()
    device = next(model.parameters()).device

    # 预热与测量都在 no_grad 下进行:推理测速不需要构建计算图
    with torch.no_grad():
        # 预热
        for i, batch in enumerate(data_loader):
            if i >= 10:
                break
            inputs, _ = batch
            _ = model(inputs.to(device))  # .to(device)将数据移至GPU/CPU

        # 测量
        start_time = time.time()

        actual_batches = 0
        for i, batch in enumerate(data_loader):
            if i >= num_batches:
                break
            inputs, _ = batch
            _ = model(inputs.to(device))
            actual_batches += 1

        if device.type == 'cuda':
            torch.cuda.synchronize()  # 等待GPU计算完成后再计时

        end_time = time.time()

    return actual_batches / (end_time - start_time)  # batches per second

def count_parameters(model):
    """统计参数量"""
    return sum(p.numel() for p in model.parameters())

def compute_prediction_agreement(teacher_model, student_model, data_loader):
    """计算预测一致性"""
    teacher_model.eval()
    student_model.eval()

    agreements = []

    with torch.no_grad():
        for batch in data_loader:
            inputs, _ = batch

            teacher_outputs = teacher_model(inputs)
            student_outputs = student_model(inputs)

            teacher_preds = teacher_outputs.argmax(dim=1)
            student_preds = student_outputs.argmax(dim=1)

            agreements.append((teacher_preds == student_preds).float())

    return torch.cat(agreements).mean().item()

def compute_average_kl_divergence(teacher_model, student_model, data_loader):
    """计算平均KL散度"""
    teacher_model.eval()
    student_model.eval()

    kl_divs = []

    with torch.no_grad():
        for batch in data_loader:
            inputs, _ = batch

            teacher_outputs = teacher_model(inputs)
            student_outputs = student_model(inputs)

            teacher_probs = F.softmax(teacher_outputs, dim=1)
            student_log_probs = F.log_softmax(student_outputs, dim=1)

            kl = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean')
            kl_divs.append(kl.item())

    return sum(kl_divs) / len(kl_divs)

总结

知识蒸馏方法对比

方法       | 适用场景     | 优点         | 缺点
软标签蒸馏 | 通用场景     | 简单有效     | 需要调整温度
特征蒸馏   | 白盒场景     | 传递更多信息 | 需要维度对齐
注意力蒸馏 | Transformer  | 捕捉结构信息 | 计算复杂
自蒸馏     | 无教师模型   | 不需要教师   | 效果可能较差
在线蒸馏   | 多个模型     | 互相提升     | 训练复杂

关键超参数

Text Only
温度 (Temperature):
├── 推荐范围: 2.0 - 8.0
├── 任务复杂 → 温度高
└── 类别多 → 温度高

蒸馏权重 (Alpha):
├── 推荐范围: 0.5 - 0.9
├── 早期: 高权重(学习教师)
└── 后期: 低权重(微调细节)

层映射:
├── 均匀映射: 通用
├── 最后k层: 任务相关
└── 学习映射: 效果最佳但复杂

最佳实践清单

  • 选择合适的教师-学生模型尺寸比(通常2-10倍)
  • 使用温度调优找到最佳温度
  • 考虑多层蒸馏而不仅是输出
  • 结合数据增强提升效果
  • 动态调整蒸馏权重
  • 评估时关注准确率、速度和大小三个维度

下一步:学习01-数据工程与预处理,进入系统与工程阶段!


最后更新日期:2026-02-12 适用版本:LLM学习教程 v2026