
05 - Knowledge Distillation

⚠️ Timeliness note: this chapter references frontier models, pricing, leaderboards, and similar information that may change quickly between versions; defer to the original papers, official release pages, and API documentation.

Learning objectives: develop a deep understanding of the principles, methods, and practice of knowledge distillation, and learn how to transfer knowledge from large models to small ones.




Overview of Knowledge Distillation

1.1 What Is Knowledge Distillation?

Text Only
Knowledge Distillation

Core idea: transfer the knowledge of a large teacher model (Teacher) into a small student model (Student)

Why distill?
├── Large models perform well but are slow at inference and hard to deploy
├── Small models are fast but perform worse
└── Distillation: let the small model learn the large model's "dark knowledge"

Dark knowledge:
├── The student learns more than the correct label
├── It also learns the similarity relationships between classes
└── Example: a dog image gets a higher probability for "wolf" than for "car"

Analogy:
├── Teacher: a seasoned expert
├── Student: a beginner
└── Distillation: the student learns the expert's way of thinking, not just the answers
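The "dark knowledge" above can be made concrete with a toy example (the logits here are illustrative values of my own choosing, not from any real model): for a dog image, a teacher's softmax puts more mass on "wolf" than on "car", and raising the temperature makes that relationship far more visible to a student.

```python
import torch
import torch.nn.functional as F

# Hypothetical teacher logits for one dog image: [dog, wolf, car]
logits = torch.tensor([8.0, 4.0, -2.0])

probs_t1 = F.softmax(logits / 1.0, dim=0)  # standard softmax (T=1)
probs_t4 = F.softmax(logits / 4.0, dim=0)  # softened softmax (T=4)

# In both cases wolf > car, but at T=4 the wolf/car gap carries a much
# stronger learning signal for the student
print(probs_t1)
print(probs_t4)
```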

1.2 The Basic Distillation Framework

Text Only
┌────────────────────────────────────────────────────────────┐
│               Knowledge Distillation Framework             │
├────────────────────────────────────────────────────────────┤
│                                                            │
│   Teacher Model (T)              Student Model (S)         │
│   (large, complex)               (small, simple)           │
│                                                            │
│   Input ──▶ T ──▶ logits         Input ──▶ S ──▶ logits    │
│                     │                              │       │
│                     ▼                              ▼       │
│             Softmax (high τ)           Softmax (high τ)    │
│                     │                              │       │
│                     ▼                              ▼       │
│               Soft Targets             Soft Predictions    │
│                     │                              │       │
│                     └────────────┬─────────────────┘       │
│                                  ▼                         │
│                           KL Divergence                    │
│                        (distillation loss)                 │
│                                                            │
│   Hard Targets (ground truth) ──▶ CE Loss (student loss)   │
│                                                            │
│             Total Loss = α·L_soft + β·L_hard               │
└────────────────────────────────────────────────────────────┘

Basic Distillation Methods

2.1 Soft-Label Distillation (Hinton Distillation)

Python
import torch
import torch.nn as nn
import torch.nn.functional as F

class KnowledgeDistillationLoss(nn.Module):
    """
    Hinton知识蒸馏损失

    "Distilling the Knowledge in a Neural Network" (Hinton et al., 2015)
    """
    def __init__(self, temperature=4.0, alpha=0.7):
        super().__init__()  # super()调用父类方法
        self.temperature = temperature
        self.alpha = alpha  # 蒸馏损失权重

    def forward(self, student_logits, teacher_logits, true_labels):
        """
        Args:
            student_logits: 学生模型输出 [batch, num_classes]
            teacher_logits: 教师模型输出 [batch, num_classes]
            true_labels: 真实标签 [batch]

        Returns:
            total_loss: 总损失
            distill_loss: 蒸馏损失(KL散度)
            ce_loss: 交叉熵损失
        """
        # 1. 软目标损失(KL散度)
        soft_student = F.log_softmax(student_logits / self.temperature, dim=1)
        soft_teacher = F.softmax(teacher_logits / self.temperature, dim=1)

        # KL散度 = Σ p_teacher * log(p_teacher / p_student)
        distill_loss = F.kl_div(
            soft_student,
            soft_teacher,
            reduction='batchmean'
        ) * (self.temperature ** 2)  # 缩放因子

        # 2. 硬目标损失(交叉熵)
        ce_loss = F.cross_entropy(student_logits, true_labels)

        # 3. 总损失
        total_loss = self.alpha * distill_loss + (1 - self.alpha) * ce_loss

        return total_loss, distill_loss, ce_loss

# The effect of the temperature parameter
def demonstrate_temperature():
    """
    Visualize how temperature shapes the soft-label distribution.
    """
    import matplotlib.pyplot as plt

    # Example logits
    logits = torch.tensor([2.0, 1.0, 0.1])

    temperatures = [0.5, 1.0, 2.0, 4.0, 8.0]

    plt.figure(figsize=(10, 6))
    for T in temperatures:
        probs = F.softmax(logits / T, dim=0)
        plt.plot(probs.numpy(), label=f'T={T}', marker='o')

    plt.xlabel('Class Index')
    plt.ylabel('Probability')
    plt.title('Effect of Temperature on Softmax Distribution')
    plt.legend()
    plt.grid(True)
    plt.show()

    """
    结论:
    - T→0: 趋近于one-hot(硬标签)
    - T→∞: 趋近于均匀分布
    - 适中的T(2-4): 保留类别间关系,但不过于平滑
    """

# demonstrate_temperature()
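As a quick sanity check of the Hinton loss above (recomputed inline so this snippet stands on its own): when student and teacher logits are identical, the KL term vanishes and the total loss reduces to the weighted cross-entropy term alone.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
T, alpha = 4.0, 0.7
logits = torch.randn(8, 10)           # identical student and teacher logits
labels = torch.randint(0, 10, (8,))

# Same formula as KnowledgeDistillationLoss, with teacher == student
kl = F.kl_div(
    F.log_softmax(logits / T, dim=1),
    F.softmax(logits / T, dim=1),
    reduction='batchmean'
) * T ** 2
ce = F.cross_entropy(logits, labels)
total = alpha * kl + (1 - alpha) * ce

# KL(p‖p) = 0, so total == (1 - alpha) * ce
```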

2.2 Feature Distillation

Python
class FeatureDistillationLoss(nn.Module):
    """
    Feature distillation.

    The student model learns the teacher model's intermediate feature representations.
    """
    def __init__(self, student_dim, teacher_dim, hidden_dim=None):
        super().__init__()

        if hidden_dim is None:
            hidden_dim = student_dim

        # Adaptation layer: maps student features into the teacher feature space
        self.adaptation_layer = nn.Sequential(
            nn.Linear(student_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, teacher_dim)
        )

    def forward(self, student_features, teacher_features):
        """
        Args:
            student_features: [batch, student_dim]
            teacher_features: [batch, teacher_dim]
        """
        # Project student features
        adapted_student = self.adaptation_layer(student_features)

        # Mean squared error
        loss = F.mse_loss(adapted_student, teacher_features)

        return loss

class MultiLayerFeatureDistillation(nn.Module):
    """
    Multi-layer feature distillation (FitNets).

    Distills features at several intermediate layers.
    """
    def __init__(self, student_dims, teacher_dims, weights=None):
        super().__init__()

        assert len(student_dims) == len(teacher_dims)

        self.num_layers = len(student_dims)

        # One adaptation layer per distilled layer
        self.adaptation_layers = nn.ModuleList([
            nn.Linear(s_dim, t_dim)
            for s_dim, t_dim in zip(student_dims, teacher_dims)
        ])

        # Per-layer weights
        if weights is None:
            weights = [1.0] * self.num_layers
        self.weights = weights

    def forward(self, student_features_list, teacher_features_list):
        """
        Args:
            student_features_list: features from several student layers
            teacher_features_list: features from several teacher layers
        """
        total_loss = 0

        for i, (s_feat, t_feat, adapt_layer) in enumerate(
            zip(student_features_list, teacher_features_list, self.adaptation_layers)
        ):
            # Project and compute the per-layer loss
            adapted = adapt_layer(s_feat)
            loss = F.mse_loss(adapted, t_feat)
            total_loss += self.weights[i] * loss

        return total_loss
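A minimal usage sketch of the adaptation pattern the classes above wrap (the dimensions 384 and 768 are arbitrary, chosen only for illustration): a single linear adapter projects the narrower student feature into the teacher's width, and MSE matches the two. The adapter is trained jointly with the student, so the loss is fully differentiable.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
student_feat = torch.randn(4, 384)   # [batch, student_dim]
teacher_feat = torch.randn(4, 768)   # [batch, teacher_dim]

adapter = nn.Linear(384, 768)        # maps student features to teacher width
loss = F.mse_loss(adapter(student_feat), teacher_feat)

# Gradients flow into both the adapter and (in a real model) the student
loss.backward()
```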

2.3 Attention Distillation

Python
class AttentionDistillationLoss(nn.Module):
    """
    Attention distillation.

    The student model learns the teacher model's attention patterns.
    """
    def __init__(self, temperature=1.0):
        super().__init__()
        self.temperature = temperature

    def forward(self, student_attention, teacher_attention):
        """
        Args:
            student_attention: [batch, num_heads, seq_len, seq_len]
            teacher_attention: [batch, num_heads, seq_len, seq_len]
        """
        # Option 1: KL divergence
        student_attn = F.log_softmax(student_attention / self.temperature, dim=-1)
        teacher_attn = F.softmax(teacher_attention / self.temperature, dim=-1)

        kl_loss = F.kl_div(
            student_attn,
            teacher_attn,
            reduction='batchmean'
        )

        # Option 2: mean squared error (sometimes works better)
        mse_loss = F.mse_loss(student_attention, teacher_attention)

        # Combine both
        total_loss = kl_loss + mse_loss

        return total_loss

class MultiHeadAttentionTransfer(nn.Module):
    """
    Multi-head attention transfer.

    Used by TinyBERT and similar methods.
    """
    def __init__(self, student_num_heads, teacher_num_heads):
        super().__init__()
        self.student_num_heads = student_num_heads
        self.teacher_num_heads = teacher_num_heads

    def forward(self, student_attn_list, teacher_attn_list):
        """
        Align attention maps across models with different layer counts.
        """
        # Pick specific teacher layers to distill from, e.g. with a 6-layer
        # student and a 12-layer teacher, use teacher layers 2, 4, 6, 8, 10, 12
        selected_teacher_layers = self.select_layers(
            len(teacher_attn_list),
            len(student_attn_list)
        )

        total_loss = 0
        for student_idx, teacher_idx in enumerate(selected_teacher_layers):
            s_attn = student_attn_list[student_idx]
            t_attn = teacher_attn_list[teacher_idx]

            # If the head counts differ, average the teacher heads and broadcast
            if s_attn.size(1) != t_attn.size(1):
                t_attn = t_attn.mean(dim=1, keepdim=True).expand(-1, s_attn.size(1), -1, -1)

            loss = F.mse_loss(s_attn, t_attn)
            total_loss += loss

        return total_loss / len(student_attn_list)

    def select_layers(self, teacher_layers, student_layers):
        """
        Choose which teacher layer maps to each student layer.
        """
        import numpy as np
        indices = np.linspace(0, teacher_layers - 1, student_layers, dtype=int)
        return indices.tolist()
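The `select_layers` mapping above can be checked directly. For a 12-layer teacher and a 6-layer student, the evenly spaced 0-based indices come out as follows (note `dtype=int` truncates the fractional spacing):

```python
import numpy as np

# Evenly spaced teacher-layer indices for a 6-layer student, 12-layer teacher
indices = np.linspace(0, 12 - 1, 6, dtype=int).tolist()
print(indices)  # → [0, 2, 4, 6, 8, 11]
```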

Advanced Distillation Techniques

3.1 Self-Distillation

Python
class SelfDistillationLoss(nn.Module):
    """
    Self-distillation.

    The model distills itself; no separate teacher model is needed.
    """
    def __init__(self, temperature=4.0):
        super().__init__()
        self.temperature = temperature

    def forward(self, logits, labels, epoch, total_epochs):
        """
        Gradually increase the weight of the soft labels as training proceeds.
        """
        # Hard-label loss
        hard_loss = F.cross_entropy(logits, labels)

        # Build soft targets (from the current model's own predictions)
        with torch.no_grad():  # no gradients flow through the targets
            soft_targets = F.softmax(logits / self.temperature, dim=1)

        # Soft-label loss
        # ⚠️ Illustration only: here soft_targets and logits come from the same
        # distribution, so in theory KL(p‖p) ≡ 0. In practical self-distillation,
        # soft_targets should come from a model snapshot saved at a previous
        # epoch (a "historical version" acting as the teacher). This shows the
        # code structure only.
        soft_loss = F.kl_div(
            F.log_softmax(logits / self.temperature, dim=1),
            soft_targets,
            reduction='batchmean'
        ) * (self.temperature ** 2)

        # Dynamic weighting: mostly hard labels early on, more soft labels later
        alpha = min(0.5, epoch / (total_epochs * 0.5))

        total_loss = (1 - alpha) * hard_loss + alpha * soft_loss

        return total_loss
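One common way to supply a "historical version" teacher is an exponential moving average (EMA) of the student's own weights. This sketch is my own illustration of that idea, not code from the original text: a frozen EMA copy produces the soft targets, and its parameters trail the student's after each optimizer step.

```python
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
student = nn.Linear(16, 4)
ema_teacher = copy.deepcopy(student)   # frozen snapshot acting as teacher
for p in ema_teacher.parameters():
    p.requires_grad_(False)

def update_ema(teacher, student, decay=0.99):
    # teacher ← decay·teacher + (1 − decay)·student, parameter by parameter
    with torch.no_grad():
        for tp, sp in zip(teacher.parameters(), student.parameters()):
            tp.mul_(decay).add_(sp, alpha=1 - decay)

x = torch.randn(8, 16)
with torch.no_grad():
    soft_targets = F.softmax(ema_teacher(x) / 4.0, dim=1)  # teacher soft labels
soft_loss = F.kl_div(
    F.log_softmax(student(x) / 4.0, dim=1), soft_targets,
    reduction='batchmean'
) * 16

update_ema(ema_teacher, student)  # call once per training step
```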

3.2 Online Distillation

Python
class OnlineDistillation:
    """
    Online distillation (Deep Mutual Learning).

    Several student models learn from each other; there is no fixed teacher.
    """
    def __init__(self, models, temperature=4.0):
        self.models = models  # the cohort of student models
        self.temperature = temperature
        self.optimizers = [torch.optim.AdamW(m.parameters()) for m in models]

    def train_step(self, batch):
        """
        Every model distills from every other model.
        """
        inputs, labels = batch

        # Collect predictions from all models
        logits_list = [model(inputs) for model in self.models]

        # Cohort average: an alternative shared "soft teacher" (unused below)
        avg_logits = torch.stack(logits_list).mean(dim=0)

        # Each model learns from its peers
        for i, (model, optimizer) in enumerate(zip(self.models, self.optimizers)):
            optimizer.zero_grad()

            # Hard-label loss
            hard_loss = F.cross_entropy(logits_list[i], labels)

            # Learn from the other models (excluding itself)
            other_logits = torch.stack(
                [logits_list[j] for j in range(len(self.models)) if j != i]
            ).mean(dim=0)

            soft_loss = F.kl_div(
                F.log_softmax(logits_list[i] / self.temperature, dim=1),
                # detach: the peers are targets here, not trainable parties
                F.softmax(other_logits.detach() / self.temperature, dim=1),
                reduction='batchmean'
            ) * (self.temperature ** 2)

            # Combined loss
            loss = hard_loss + soft_loss
            loss.backward()
            optimizer.step()

3.3 Task-Specific Distillation

Python
class TaskSpecificDistillation:
    """
    Task-specific distillation.

    Tailors the distillation strategy to different downstream tasks.
    """

    @staticmethod
    def classification_distillation(student_logits, teacher_logits, labels, alpha=0.7, T=4.0):
        """
        Distillation for classification tasks.
        """
        soft_loss = F.kl_div(
            F.log_softmax(student_logits / T, dim=1),
            F.softmax(teacher_logits / T, dim=1),
            reduction='batchmean'
        ) * (T ** 2)

        hard_loss = F.cross_entropy(student_logits, labels)

        return alpha * soft_loss + (1 - alpha) * hard_loss

    @staticmethod
    def ner_distillation(student_logits, teacher_logits, labels, attention_mask, alpha=0.7, T=4.0):
        """
        Distillation for named entity recognition.

        The loss is computed independently for each token.
        """
        batch_size, seq_len, num_classes = student_logits.shape

        # Flatten the batch and sequence dimensions
        student_logits = student_logits.view(-1, num_classes)
        teacher_logits = teacher_logits.view(-1, num_classes)
        labels = labels.view(-1)
        attention_mask = attention_mask.view(-1)

        # Keep only non-padding tokens
        valid_indices = attention_mask.bool()
        student_logits = student_logits[valid_indices]
        teacher_logits = teacher_logits[valid_indices]
        labels = labels[valid_indices]

        soft_loss = F.kl_div(
            F.log_softmax(student_logits / T, dim=1),
            F.softmax(teacher_logits / T, dim=1),
            reduction='batchmean'
        ) * (T ** 2)

        hard_loss = F.cross_entropy(student_logits, labels)

        return alpha * soft_loss + (1 - alpha) * hard_loss

    @staticmethod
    def generation_distillation(student_logits, teacher_logits, labels, alpha=0.7, T=4.0):
        """
        Distillation for sequence generation tasks.
        """
        # Shift so each position predicts the next token
        shift_student = student_logits[..., :-1, :].contiguous()
        shift_teacher = teacher_logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()

        # Flatten
        shift_student = shift_student.view(-1, shift_student.size(-1))
        shift_teacher = shift_teacher.view(-1, shift_teacher.size(-1))
        shift_labels = shift_labels.view(-1)

        soft_loss = F.kl_div(
            F.log_softmax(shift_student / T, dim=1),
            F.softmax(shift_teacher / T, dim=1),
            reduction='batchmean'
        ) * (T ** 2)

        hard_loss = F.cross_entropy(shift_student, shift_labels)

        return alpha * soft_loss + (1 - alpha) * hard_loss
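The shift in `generation_distillation` pairs each position's logits with the *next* token. A shape check (the sizes here are toy values of my own choosing) makes the alignment explicit:

```python
import torch

batch, seq_len, vocab = 2, 5, 11
logits = torch.randn(batch, seq_len, vocab)
labels = torch.randint(0, vocab, (batch, seq_len))

shift_logits = logits[..., :-1, :].contiguous()  # positions 0 .. seq_len-2
shift_labels = labels[..., 1:].contiguous()      # tokens    1 .. seq_len-1

# Position t's logits are scored against token t+1
print(shift_logits.shape)  # torch.Size([2, 4, 11])
print(shift_labels.shape)  # torch.Size([2, 4])
```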

White-Box vs. Black-Box Distillation

4.1 White-Box Distillation

Python
class WhiteBoxDistillation:
    """
    White-box distillation.

    Assumes full access to the teacher model's internal parameters and features.
    """
    def __init__(self, teacher_model, student_model):
        self.teacher_model = teacher_model
        self.student_model = student_model

        # Capture intermediate features via forward hooks
        self.teacher_features = []
        self.student_features = []

        self._register_hooks()

    def _register_hooks(self):
        """
        Register forward hooks that record layer outputs.
        """
        self.hook_handles = []  # keep the handles so hooks can be removed later

        def get_features(name, storage):
            def hook(module, input, output):
                storage.append(output)
            return hook

        # Register hooks on the chosen layers
        for i, layer in enumerate(self.teacher_model.transformer_layers):
            handle = layer.register_forward_hook(
                get_features(f'teacher_layer_{i}', self.teacher_features)
            )
            self.hook_handles.append(handle)

        for i, layer in enumerate(self.student_model.transformer_layers):
            handle = layer.register_forward_hook(
                get_features(f'student_layer_{i}', self.student_features)
            )
            self.hook_handles.append(handle)

    def forward(self, inputs, labels):
        """
        Run both models and compute the combined distillation loss.
        """
        # Teacher forward pass
        with torch.no_grad():
            teacher_logits = self.teacher_model(inputs)

        # Student forward pass
        student_logits = self.student_model(inputs)

        # 1. Output (logit) distillation
        output_loss = F.kl_div(
            F.log_softmax(student_logits / 4.0, dim=1),
            F.softmax(teacher_logits / 4.0, dim=1),
            reduction='batchmean'
        ) * 16  # T² = 4² = 16

        # 2. Feature distillation
        feature_loss = 0
        for s_feat, t_feat in zip(self.student_features, self.teacher_features):
            # Match dimensions if they differ
            if s_feat.shape != t_feat.shape:
                t_feat = F.adaptive_avg_pool1d(
                    t_feat.transpose(1, 2),
                    s_feat.size(1)
                ).transpose(1, 2)

            feature_loss += F.mse_loss(s_feat, t_feat)

        # 3. Hard-label loss
        ce_loss = F.cross_entropy(student_logits, labels)

        # Weighted total
        total_loss = 0.5 * output_loss + 0.3 * feature_loss + 0.2 * ce_loss

        # Clear the feature caches for the next batch
        self.teacher_features.clear()
        self.student_features.clear()

        return total_loss

4.2 Black-Box Distillation

Python
class BlackBoxDistillation:
    """
    Black-box distillation.

    Only the teacher model's outputs are accessible (e.g. via an API); its
    internal parameters are not.
    """
    def __init__(self, student_model, teacher_api=None):
        self.student_model = student_model
        self.teacher_api = teacher_api  # interface to the teacher model

    def get_teacher_predictions(self, inputs):
        """
        Fetch teacher predictions through the API.

        In practice this might be:
        - the OpenAI API
        - local teacher-model inference
        - precomputed teacher predictions
        """
        if self.teacher_api is None:
            raise ValueError("Teacher API not provided")

        # Call the API for predictions
        teacher_logits = self.teacher_api.predict(inputs)

        return teacher_logits

    def train_step(self, batch):
        """
        One training step.
        """
        inputs, labels = batch

        # Get teacher predictions (black box)
        with torch.no_grad():
            teacher_logits = self.get_teacher_predictions(inputs)

        # Student forward pass
        student_logits = self.student_model(inputs)

        # Only output distillation is possible
        distill_loss = F.kl_div(
            F.log_softmax(student_logits / 4.0, dim=1),
            F.softmax(teacher_logits / 4.0, dim=1),
            reduction='batchmean'
        ) * 16  # T² = 4² = 16

        # Hard-label loss
        ce_loss = F.cross_entropy(student_logits, labels)

        # Weighted total
        total_loss = 0.7 * distill_loss + 0.3 * ce_loss

        return total_loss

class DataFreeDistillation:
    """
    Data-free distillation.

    Distills without the original training data by generating synthetic inputs.
    """
    def __init__(self, student_model, generator_model):
        self.student_model = student_model
        self.generator_model = generator_model  # synthetic-data generator

    def generate_synthetic_data(self, num_samples, sequence_length):
        """
        Produce synthetic inputs.
        """
        # Use the generator (or random sampling)
        synthetic_inputs = self.generator_model.generate(
            num_samples=num_samples,
            sequence_length=sequence_length
        )

        return synthetic_inputs

    def train_step(self, teacher_model):
        """
        One training step.
        """
        # Generate synthetic data
        synthetic_inputs = self.generate_synthetic_data(
            num_samples=32,
            sequence_length=128
        )

        # Teacher predictions
        with torch.no_grad():
            teacher_logits = teacher_model(synthetic_inputs)

        # Student predictions
        student_logits = self.student_model(synthetic_inputs)

        # Distillation loss
        loss = F.kl_div(
            F.log_softmax(student_logits / 4.0, dim=1),
            F.softmax(teacher_logits / 4.0, dim=1),
            reduction='batchmean'
        ) * 16

        return loss

Combining Distillation with Quantization

5.1 Quantization-Aware Distillation

Python
class QuantizationAwareDistillation:
    """
    Quantization-aware distillation.

    Accounts for quantization error during distillation.
    """
    def __init__(self, teacher_model, student_model, num_bits=8):
        self.teacher_model = teacher_model
        self.student_model = student_model
        self.num_bits = num_bits

    def quantize_tensor(self, tensor, num_bits):
        """
        Simulated (fake) quantization.
        """
        # Quantization range and scale
        qmin = -(2 ** (num_bits - 1))
        qmax = 2 ** (num_bits - 1) - 1

        min_val = tensor.min()
        max_val = tensor.max()

        scale = (max_val - min_val) / (qmax - qmin)
        zero_point = qmin - min_val / scale

        # Quantize
        # Note: torch.round blocks gradients; real QAT typically pairs this
        # with a straight-through estimator
        quantized = torch.clamp(
            torch.round(tensor / scale + zero_point),
            qmin, qmax
        )

        # Dequantize
        dequantized = (quantized - zero_point) * scale

        return dequantized

    def forward(self, inputs, labels):
        """
        Forward pass.
        """
        # Teacher predictions (full precision)
        with torch.no_grad():
            teacher_logits = self.teacher_model(inputs)

        # Student predictions (to be fake-quantized)
        student_logits = self.student_model(inputs)

        # Fake-quantize the student outputs
        quantized_student_logits = self.quantize_tensor(
            student_logits, self.num_bits
        )

        # Distillation loss (on the quantized outputs)
        distill_loss = F.kl_div(
            F.log_softmax(quantized_student_logits / 4.0, dim=1),
            F.softmax(teacher_logits / 4.0, dim=1),
            reduction='batchmean'
        ) * 16

        # Hard-label loss (on the raw outputs)
        ce_loss = F.cross_entropy(student_logits, labels)

        return 0.7 * distill_loss + 0.3 * ce_loss
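The fake-quantization helper can be sanity-checked numerically: after an 8-bit round trip with the same min/max affine scheme as above, every value should land within half a quantization step (scale/2) of the original.

```python
import torch

torch.manual_seed(0)
num_bits = 8
qmin, qmax = -(2 ** (num_bits - 1)), 2 ** (num_bits - 1) - 1

t = torch.randn(1000) * 3
scale = (t.max() - t.min()) / (qmax - qmin)
zero_point = qmin - t.min() / scale

q = torch.clamp(torch.round(t / scale + zero_point), qmin, qmax)
deq = (q - zero_point) * scale

# Rounding error is bounded by half a quantization step
max_err = (t - deq).abs().max()
```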

5.2 Progressive Compression

Python
class ProgressiveCompression:
    """
    Progressive compression.

    Distill first, then quantize, and finally prune.
    """
    def __init__(self, teacher_model):
        self.teacher_model = teacher_model
        self.student_model = None

    def stage1_distillation(self, train_loader, num_epochs=10):
        """
        Stage 1: knowledge distillation.
        """
        # Build the student model (smaller than the teacher).
        # create_student_model and KnowledgeDistillationTrainer are assumed to
        # be defined elsewhere (e.g. a trainer wrapping the loss from 2.1).
        self.student_model = self.create_student_model()

        # Distillation training
        distillation_trainer = KnowledgeDistillationTrainer(
            teacher=self.teacher_model,
            student=self.student_model
        )

        for epoch in range(num_epochs):
            for batch in train_loader:
                loss = distillation_trainer.train_step(batch)

        return self.student_model

    def stage2_quantization(self, model, num_bits=8):
        """
        Stage 2: quantization.
        """
        quantized_model = torch.quantization.quantize_dynamic(
            model,
            {nn.Linear},
            dtype=torch.qint8
        )

        return quantized_model

    def stage3_pruning(self, model, sparsity=0.3):
        """
        Stage 3: pruning.
        """
        import torch.nn.utils.prune as prune

        # Apply L1 unstructured pruning to the linear layers
        for name, module in model.named_modules():
            if isinstance(module, nn.Linear):
                prune.l1_unstructured(
                    module,
                    name='weight',
                    amount=sparsity
                )

        return model

    def full_pipeline(self, train_loader):
        """
        Run the full compression pipeline.
        """
        # Stage 1: distillation
        print("Stage 1: Distillation...")
        distilled_model = self.stage1_distillation(train_loader)

        # Stage 2: quantization
        print("Stage 2: Quantization...")
        quantized_model = self.stage2_quantization(distilled_model)

        # Stage 3: pruning
        print("Stage 3: Pruning...")
        final_model = self.stage3_pruning(quantized_model)

        return final_model

Distilling Closed-Source Models

6.1 API-Based Distillation (Black-Box Distillation)

Python
from concurrent.futures import ThreadPoolExecutor

from openai import OpenAI

class BlackBoxDistillation:
    """
    Black-box (API-based) distillation.

    Applies to closed-source models whose internals are inaccessible (GPT-4,
    Claude, etc.). Teacher outputs are obtained via API calls, and distillation
    uses only the resulting input-output pairs.
    """

    def __init__(self, api_key, model_name="gpt-4"):
        self.api_key = api_key
        self.model_name = model_name
        self.client = OpenAI(api_key=api_key)

    def generate_teacher_response(self, prompt, temperature=0.7, max_tokens=2048):
        """
        Fetch a teacher-model output via the API.

        Args:
            prompt: input prompt
            temperature: sampling temperature
            max_tokens: maximum number of tokens to generate

        Returns:
            The teacher model's response text.
        """
        response = self.client.chat.completions.create(
            model=self.model_name,
            messages=[{"role": "user", "content": prompt}],
            temperature=temperature,
            max_tokens=max_tokens
        )
        return response.choices[0].message.content

    def batch_generate_distillation_data(self, prompts, batch_size=100):
        """
        Generate distillation data in batches.

        Uses the closed-source API to produce training data in bulk.
        """
        distillation_data = []

        for i in range(0, len(prompts), batch_size):
            batch = prompts[i:i + batch_size]

            # Issue API calls concurrently
            futures = []
            with ThreadPoolExecutor(max_workers=10) as executor:
                for prompt in batch:
                    future = executor.submit(
                        self.generate_teacher_response, prompt
                    )
                    futures.append((prompt, future))

            # Collect the results
            for prompt, future in futures:
                try:
                    response = future.result(timeout=30)
                    distillation_data.append({
                        "input": prompt,
                        "teacher_output": response,
                        "label": self.extract_label(response)
                    })
                except Exception as e:
                    print(f"Error generating response: {e}")
                    continue

        return distillation_data

    def extract_label(self, response):
        """
        Extract a label from the teacher's output.

        For classification tasks, extract the class label;
        for generation tasks, extract the target output.
        """
        # Simple extraction logic; customize per task in practice
        return response.strip()

    def distill_with_api_data(self, student_model, distillation_data, epochs=3):
        """
        Distill the student model on the API-generated data.
        """
        optimizer = torch.optim.Adam(student_model.parameters(), lr=1e-4)

        for epoch in range(epochs):
            for data in distillation_data:
                # Student forward pass
                student_output = student_model(data["input"])

                # Encode the teacher output as a training signal
                teacher_encoding = self.encode_response(data["teacher_output"])

                # Distillation loss
                loss = self.compute_distillation_loss(
                    student_output, teacher_encoding
                )

                # Backpropagation
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        return student_model

    def encode_response(self, response):
        """
        Encode a teacher response into a training signal.
        """
        # Real applications will need more sophisticated encoding
        return response

    def compute_distillation_loss(self, student_output, teacher_encoding):
        """
        Compute the distillation loss.
        """
        # Simplified; customize per task in practice
        return nn.functional.cross_entropy(
            student_output,
            teacher_encoding
        )

6.2 Chain-of-Thought Distillation (CoT Distillation)

Python
class ChainOfThoughtDistillation:
    """
    思维链蒸馏(CoT Distillation)

    将大模型的思维链推理能力迁移到小模型
    核心思想:让学生模型学习教师的推理过程,而非仅学习答案
    """

    def __init__(self, teacher_model, student_model):
        self.teacher = teacher_model
        self.student = student_model

    def generate_cot_data(self, problems, use_cot_prompt=True):
        """
        生成思维链训练数据

        Args:
            problems: 问题列表
            use_cot_prompt: 是否使用CoT提示

        Returns:
            包含问题、思维链、答案的数据列表
        """
        cot_data = []

        for problem in problems:
            # 使用教师模型生成思维链
            if use_cot_prompt:
                prompt = f"{problem}\n\n请逐步思考并展示你的推理过程。"
            else:
                prompt = problem

            # 获取教师响应
            teacher_response = self.teacher.generate(prompt)

            # 解析思维链和答案
            parsed = self.parse_cot_response(teacher_response)

            cot_data.append({
                "problem": problem,
                "reasoning": parsed["reasoning"],
                "answer": parsed["answer"]
            })

        return cot_data

    def parse_cot_response(self, response):
        """
        解析思维链响应

        提取推理步骤和最终答案
        """
        # 简单的解析逻辑
        lines = response.split("\n")

        reasoning_steps = []
        answer = ""

        for line in lines:
            if "答案:" in line or "Answer:" in line:
                answer = line.split("答案:")[-1].split("Answer:")[-1].strip()
            else:
                reasoning_steps.append(line)

        return {
            "reasoning": "\n".join(reasoning_steps),
            "answer": answer
        }

    def distill_cot(self, cot_data, method="full_cot"):
        """
        执行思维链蒸馏

        Args:
            cot_data: 思维链训练数据
            method: 蒸馏方法
                   - "full_cot": 完整思维链蒸馏
                   - "answer_only": 仅蒸馏答案
                   - "summary_only": 仅蒸馏思维链摘要
        """
        if method == "full_cot":
            return self._distill_full_cot(cot_data)
        elif method == "answer_only":
            return self._distill_answer_only(cot_data)
        elif method == "summary_only":
            return self._distill_summary_only(cot_data)

    def _distill_full_cot(self, cot_data):
        """
        完整思维链蒸馏

        学生模型学习完整的推理过程和最终答案
        """
        total_loss = 0

        for data in cot_data:
            # 学生模型生成完整响应
            student_output = self.student.generate(
                f"{data['problem']}\n\n请逐步思考。"
            )

            # 计算响应级别的蒸馏损失
            response_loss = self.compute_response_loss(
                student_output,
                data["reasoning"] + "\n\n答案:" + data["answer"]
            )

            # 计算token级别的损失
            token_loss = self.compute_token_level_loss(
                student_output,
                data["reasoning"] + " " + data["answer"]
            )

            # 综合损失
            loss = 0.5 * response_loss + 0.5 * token_loss
            total_loss += loss

        return total_loss / len(cot_data)

    def _distill_answer_only(self, cot_data):
        """
        仅蒸馏答案

        简单有效,但丢失推理过程信息
        """
        total_loss = 0

        for data in cot_data:
            # 仅使用答案作为标签
            student_output = self.student.generate(data["problem"])

            # 提取学生模型的答案
            student_answer = self.extract_answer(student_output)

            # 计算答案损失
            loss = self.compute_answer_loss(
                student_answer,
                data["answer"]
            )

            total_loss += loss

        return total_loss / len(cot_data)

    def _distill_summary_only(self, cot_data):
        """
        Summary-only CoT蒸馏

        仅蒸馏思维链摘要,不包含完整推理细节
        减少学生模型的学习负担,同时保留核心推理能力
        """
        total_loss = 0

        for data in cot_data:
            # 生成思维链摘要
            summary = self.summarize_reasoning(data["reasoning"])

            # 学生学习摘要
            student_summary = self.student.generate(
                f"{data['problem']}\n\n请简要思考关键步骤。"
            )

            # 计算摘要级别损失
            loss = self.compute_summary_loss(
                student_summary,
                summary
            )

            total_loss += loss

        return total_loss / len(cot_data)

    def summarize_reasoning(self, reasoning):
        """
        生成思维链摘要

        提取核心推理步骤,移除冗余信息
        """
        # 简化版本,实际应用中可能需要更复杂的摘要逻辑
        lines = reasoning.split("\n")
        key_steps = []

        for line in lines:
            # 保留关键推理步骤
            if any(keyword in line for keyword in ["因此", "所以", "所以", "第一步", "第二步", "关键"]):
                key_steps.append(line)

        return " ".join(key_steps)

    def compute_response_loss(self, student_response, teacher_response):
        """
        计算响应级别的损失
        """
        # 使用文本相似度作为损失
        student_emb = self.embed(student_response)
        teacher_emb = self.embed(teacher_response)

        return 1 - cosine_similarity(student_emb, teacher_emb)

    def compute_token_level_loss(self, student_output, teacher_output):
        """
        计算token级别的损失
        """
        student_tokens = self.tokenize(student_output)
        teacher_tokens = self.tokenize(teacher_output)

        # 对齐token序列
        aligned_loss = 0
        for s_tok, t_tok in zip(student_tokens, teacher_tokens):
            aligned_loss += nn.functional.cross_entropy(s_tok, t_tok)

        return aligned_loss / max(len(student_tokens), len(teacher_tokens))

    def compute_answer_loss(self, student_answer, teacher_answer):
        """
        计算答案损失
        """
        # 简化为文本匹配损失
        return 0 if student_answer == teacher_answer else 1.0

    def compute_summary_loss(self, student_summary, teacher_summary):
        """
        计算摘要损失
        """
        student_emb = self.embed(student_summary)
        teacher_emb = self.embed(teacher_summary)

        return 1 - cosine_similarity(student_emb, teacher_emb)

    def embed(self, text):
        """
        文本嵌入
        """
        # 实际应用中需要使用适当的嵌入模型
        return torch.randn(768)

    def tokenize(self, text):
        """
        分词
        """
        # 简化版本
        return [torch.randn(50257) for _ in range(len(text.split()))]

    def extract_answer(self, response):
        """
        从响应中提取答案
        """
        if "答案:" in response:
            return response.split("答案:")[-1].strip()
        elif "Answer:" in response:
            return response.split("Answer:")[-1].strip()
        return response.strip()
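上面的辅助方法依赖类内部的模型与分词器;下面给出一个可独立运行的纯Python示意版本,便于单独验证摘要与答案提取的逻辑(关键词列表与答案标记均为示例假设):

```python
def summarize_reasoning(reasoning: str) -> str:
    """按关键词抽取核心推理步骤(简化版)。"""
    keywords = ["因此", "所以", "第一步", "第二步", "关键"]
    key_steps = [
        line for line in reasoning.split("\n")
        if any(kw in line for kw in keywords)
    ]
    return " ".join(key_steps)


def extract_answer(response: str) -> str:
    """从响应文本中提取最终答案(简化版)。"""
    for marker in ("答案:", "Answer:"):
        if marker in response:
            return response.split(marker)[-1].strip()
    return response.strip()


cot = "第一步:设未知数x\n代入方程化简\n因此 x = 4\n答案:4"
print(summarize_reasoning(cot))  # 第一步:设未知数x 因此 x = 4
print(extract_answer(cot))       # 4
```

实际应用中,关键词抽取可以替换为小模型做抽取式摘要,但接口形式不变。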

6.3 PRM与ORM在蒸馏中的应用

Python
class RewardGuidedDistillation:
    """
    基于奖励模型的蒸馏

    PRM(过程奖励模型):对每个推理步骤打分
    ORM(结果奖励模型):对整个响应结果打分

    在蒸馏过程中引入奖励信号,引导学生模型学习更好的推理路径
    """

    def __init__(self, teacher_model, student_model, reward_model_type="prm"):
        self.teacher = teacher_model
        self.student = student_model
        self.reward_model_type = reward_model_type

        if reward_model_type == "prm":
            self.reward_model = ProcessRewardModel()
        else:
            self.reward_model = OutcomeRewardModel()

    def distill_with_rewards(self, training_data, use_rewards=True):
        """
        使用奖励信号进行蒸馏

        Args:
            training_data: 训练数据
            use_rewards: 是否使用奖励信号
        """
        total_loss = 0

        for data in training_data:
            # 学生模型生成响应
            student_output = self.student.generate(
                data["problem"],
                return_reasoning_steps=True
            )

            # 获取教师响应
            teacher_output = self.teacher.generate(
                data["problem"],
                return_reasoning_steps=True
            )

            # 计算基础蒸馏损失
            distill_loss = self.compute_distillation_loss(
                student_output, teacher_output
            )

            if use_rewards:
                # 计算奖励
                if self.reward_model_type == "prm":
                    student_reward = self.compute_prm_reward(
                        student_output["reasoning_steps"]
                    )
                    teacher_reward = self.compute_prm_reward(
                        teacher_output["reasoning_steps"]
                    )
                else:
                    student_reward = self.reward_model.score(
                        data["problem"], student_output["answer"]
                    )
                    teacher_reward = self.reward_model.score(
                        data["problem"], teacher_output["answer"]
                    )

                # 奖励引导的蒸馏损失
                reward_weight = abs(student_reward - teacher_reward)
                total_loss += distill_loss * (1 + reward_weight)
            else:
                total_loss += distill_loss

        return total_loss / len(training_data)

    def compute_prm_reward(self, reasoning_steps):
        """
        计算过程奖励

        对每个推理步骤打分,综合得到总奖励
        """
        step_rewards = []

        for step in reasoning_steps:
            reward = self.reward_model.score_step(step)
            step_rewards.append(reward)

        # 位置加权平均作为总奖励:越靠后的步骤权重越高
        if not step_rewards:
            return 0.0
        weights = [i + 1 for i in range(len(step_rewards))]
        total_reward = sum(w * r for w, r in zip(weights, step_rewards)) / sum(weights)

        return total_reward
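位置加权的直觉是:越靠后的步骤越接近最终答案,打分权重应更高。下面是一个可独立运行的小例子,按位置加权后再除以权重之和做归一化,保证总奖励仍落在单步分数的区间内(示意实现):

```python
def weighted_step_reward(step_rewards):
    """位置加权平均:第i步权重为i+1,越靠后的步骤权重越高。"""
    if not step_rewards:
        return 0.0
    weights = [i + 1 for i in range(len(step_rewards))]
    # 归一化,避免步骤数越多总分越高
    return sum(w * r for w, r in zip(weights, step_rewards)) / sum(weights)


print(weighted_step_reward([1.0, 1.0, 1.0]))  # 1.0
print(weighted_step_reward([0.0, 0.0, 1.0]))  # 0.5(末步权重为3/6)
```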


class ProcessRewardModel:
    """
    过程奖励模型(PRM)

    对每个推理步骤独立打分
    """

    def __init__(self):
        # 实际应用中需要加载预训练的PRM
        self.model = None

    def score_step(self, step):
        """
        对单个推理步骤打分

        Returns:
            float: 步骤分数 [0, 1]
        """
        # 简化实现
        # 实际应用中需要使用真实的PRM模型
        if any(keyword in step for keyword in ["正确", "对", "是的"]):
            return 0.9
        elif any(keyword in step for keyword in ["错误", "不对"]):
            return 0.3
        return 0.7


class OutcomeRewardModel:
    """
    结果奖励模型(ORM)

    对整个响应结果打分
    """

    def __init__(self):
        self.model = None

    def score(self, problem, answer):
        """
        对响应结果打分

        Returns:
            float: 响应分数 [0, 1]
        """
        # 简化实现
        # 实际应用中需要使用真实的ORM模型
        ground_truth = self.extract_ground_truth(problem)

        if answer == ground_truth:
            return 1.0
        elif self.is_partial_match(answer, ground_truth):
            return 0.5
        return 0.0

    def extract_ground_truth(self, problem):
        """
        从问题中提取标准答案
        """
        # 简化实现
        return ""

    def is_partial_match(self, answer, ground_truth):
        """
        判断部分匹配
        """
        return False


class BestOfNDistillation:
    """
    Best-of-N 蒸馏

    生成N个候选响应,使用奖励模型选择最佳响应进行蒸馏
    """

    def __init__(self, teacher_model, student_model, n_candidates=8):
        self.teacher = teacher_model
        self.student = student_model
        self.n_candidates = n_candidates
        self.reward_model = ProcessRewardModel()

    def distill_best_of_n(self, problem):
        """
        Best-of-N 蒸馏

        1. 学生模型生成N个候选响应
        2. 使用PRM/ORM选择最佳响应
        3. 使用最佳响应进行蒸馏
        """
        # 生成N个候选
        candidates = []
        for _ in range(self.n_candidates):
            candidate = self.student.generate(
                problem,
                temperature=0.8,
                return_reasoning_steps=True
            )
            candidates.append(candidate)

        # 使用奖励模型选择最佳
        best_candidate = self.select_best(candidates)

        # 获取教师响应
        teacher_response = self.teacher.generate(
            problem,
            return_reasoning_steps=True
        )

        # 计算蒸馏损失
        loss = self.compute_loss(best_candidate, teacher_response)

        return loss

    def select_best(self, candidates):
        """
        使用奖励模型选择最佳候选
        """
        best_score = -float('inf')
        best_candidate = None

        for candidate in candidates:
            # 计算候选的奖励分数
            # ProcessRewardModel 只提供 score_step,按步打分后取平均
            steps = candidate.get("reasoning_steps")
            if steps:
                step_scores = [self.reward_model.score_step(s) for s in steps]
                score = sum(step_scores) / len(step_scores)
            else:
                score = 0.5  # 默认分数

            if score > best_score:
                best_score = score
                best_candidate = candidate

        return best_candidate

    def compute_loss(self, student_output, teacher_output):
        """
        计算蒸馏损失
        """
        # 简化的损失计算
        return 0.1
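Best-of-N 的核心是"生成多个候选、按奖励选最佳"。下面用纯Python把这一选择逻辑独立出来,其中 toy_reward 是演示用的占位奖励函数,并非真实PRM:

```python
def select_best(candidates, reward_fn):
    """在N个候选中按奖励分数选择最佳候选。"""
    return max(candidates, key=reward_fn)


def toy_reward(candidate):
    # 占位奖励:含"因此"的推理步骤得1分,其余步骤得0.5分
    return sum(1.0 if "因此" in s else 0.5 for s in candidate["reasoning_steps"])


candidates = [
    {"answer": "4", "reasoning_steps": ["设未知数x", "因此 x = 4"]},
    {"answer": "5", "reasoning_steps": ["直接猜测"]},
]
best = select_best(candidates, toy_reward)
print(best["answer"])  # 4
```

把 toy_reward 换成真实的 PRM/ORM 打分函数,选择逻辑本身不需要改动。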

6.4 DeepSeek事件技术分析

Python
class DeepSeekDistillationAnalysis:
    """
    DeepSeek蒸馏技术分析

    2025年初DeepSeek-R1的发布引发了业界对模型蒸馏技术的广泛关注
    以下分析基于DeepSeek官方技术报告和公开信息
    """

    @staticmethod
    def analyze_r1_distillation_approach():
        """
        分析DeepSeek-R1的蒸馏方法

        核心创新:
        1. 使用强化学习训练大推理模型
        2. 将推理能力蒸馏到小模型
        3. 开源多个蒸馏版本
        """
        analysis = """
        DeepSeek-R1 蒸馏技术分析
        ════════════════════════════════════════════════════════════

        1. 教师模型选择
        ├── DeepSeek-R1 (671B MoE) 作为教师
        ├── 蒸馏到多个小规模密集模型
        └── Qwen系列和Llama系列作为学生基座

        2. 蒸馏数据生成
        ├── 使用DeepSeek-R1生成高质量推理数据
        ├── 拒绝采样过滤低质量样本
        └── 保留包含正确思维链的数据

        3. 蒸馏训练方法
        ├── SFT(监督微调)作为主要方法
        ├── 学生模型学习教师的思维链格式
        └── 不使用强化学习,仅用SFT

        4. 蒸馏效果
        ├── DeepSeek-R1-Distill-Qwen-7B:
        │   在AIME 2024上胜过GPT-4o
        ├── DeepSeek-R1-Distill-Qwen-32B:
        │   在AIME 2024上胜过o1-mini
        └── 证明小模型可通过蒸馏获得推理能力

        ════════════════════════════════════════════════════════════
        """
        return analysis

    @staticmethod
    def compare_with_openai_distillation():
        """
        DeepSeek vs OpenAI 蒸馏方法对比
        """
        comparison = """
        蒸馏方法对比
        ════════════════════════════════════════════════════════════

        | 维度 | DeepSeek | OpenAI |
        ├─────────────────────────────────────────────────────────┤
        | 教师模型 | DeepSeek-R1 (开源) | o1 (闭源) |
        | 蒸馏方式 | SFT + CoT数据 | API调用 + 黑盒蒸馏 |
        | 思维链 | 原生支持 | 提示工程 |
        | 开源程度 | 完全开源 | 不开源 |
        | 成本 | 低(自托管) | 高(API费用) |

        DeepSeek的优势:
        1. 完全开源,可自由使用和修改
        2. 思维链原生集成,推理能力强
        3. 训练成本低,适合资源有限的团队

        OpenAI的优势:
        1. 模型能力领先(GPT-4系列)
        2. API稳定可靠
        3. 支持多模态
        ════════════════════════════════════════════════════════════
        """
        return comparison

    @staticmethod
    def implementation_recommendations():
        """
        蒸馏实践建议

        基于DeepSeek的经验总结
        """
        recommendations = """
        蒸馏实践建议
        ════════════════════════════════════════════════════════════

        1. 数据准备
        ✅ 使用高质量推理数据
        ✅ 过滤错误思维链样本
        ✅ 保持数据多样性

        2. 蒸馏训练
        ✅ 使用SFT作为基础
        ✅ 添加格式遵循损失
        ✅ 渐进式训练(简单→复杂)

        3. 评估验证
        ✅ 在多个基准上评估
        ✅ 关注思维链质量
        ✅ 对比原教师模型

        4. 生产部署
        ✅ 量化压缩模型
        ✅ 优化推理速度
        ✅ 监控模型质量

        ════════════════════════════════════════════════════════════
        """
        return recommendations


def implement_deepseek_style_distillation():
    """
    实现DeepSeek风格的蒸馏流程
    """
    distillation_pipeline = """
    # DeepSeek风格蒸馏流程

    1. 数据生成
    ```python
    # 使用DeepSeek-R1生成带完整思维链的推理数据
    teacher = DeepSeekR1Model()

    for problem in problems:
        response = teacher.generate_with_cot(problem)
        save_to_dataset(problem, response)
    ```

    2. 数据过滤
    ```python
    # 拒绝采样:只保留最终答案通过校验的教师样本
    for sample in raw_data:
        if verify_answer(sample.response, sample.reference_answer):
            # 保留包含正确思维链的样本
            keep(sample)
        else:
            # 丢弃答案错误的低质量样本
            discard(sample)
    ```

    3. 蒸馏训练
    ```python
    # SFT蒸馏
    for epoch in range(num_epochs):
        for batch in dataloader:
            student_output = student(batch.problems)

            # 匹配教师思维链格式
            loss = compute_cot_loss(
                student_output.reasoning,
                batch.teacher_reasoning
            )

            student.backward(loss)
            student.step()
    ```

    4. 评估
    ```python
    # 多基准评估
    benchmarks = ["AIME", "MATH", "GSM8K", "HumanEval"]

    for bench in benchmarks:
        score = evaluate(student, bench)
        print(f"{bench}: {score}")
    ```
    """
    return distillation_pipeline
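流程中的"数据过滤"一步可以抽象成一个独立函数:对每条教师样本校验最终答案,只保留通过校验的样本。下面是一个可运行的示意实现(字段名为假设):

```python
def rejection_filter(samples, verify_answer):
    """拒绝采样过滤:保留答案通过校验的教师样本。"""
    return [s for s in samples if verify_answer(s["answer"], s["ground_truth"])]


samples = [
    {"problem": "2+2", "answer": "4", "ground_truth": "4"},
    {"problem": "3+3", "answer": "7", "ground_truth": "6"},
]
kept = rejection_filter(samples, lambda a, g: a == g)
print([s["problem"] for s in kept])  # ['2+2']
```

对无标准答案的任务,verify_answer 可以替换为奖励模型打分加阈值过滤。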

6.5 主流蒸馏框架介绍

Python
class DistillationFrameworkGuide:
    """
    主流蒸馏框架对比与使用指南

    2025-2026年主流开源蒸馏工具
    """

    @staticmethod
    def compare_frameworks():
        """
        蒸馏框架对比
        """
        comparison = """
        蒸馏框架对比
        ════════════════════════════════════════════════════════════

        1. DistillFlow (HorusAILabs)
        ├── 多策略蒸馏:logits、注意力、层蒸馏
        ├── 多GPU支持,动态资源分配
        ├── 支持Unsloth、Liger Kernel优化
        └── 适合大规模蒸馏任务

        2. DistillKit (Arcee AI)
        ├── 轻量级、易用
        ├── 支持多种蒸馏策略
        ├── 与HuggingFace生态集成
        └── 适合快速实验

        3. TextBrewer (哈工大讯飞联合实验室)
        ├── 通用蒸馏框架
        ├── 支持多种蒸馏配置
        ├── 多教师蒸馏支持
        └── 适合研究用途

        4. MiniMind
        ├── 完整训练流程
        ├── 单卡低成本
        ├── 适合教学和实践
        └── 代码透明

        ════════════════════════════════════════════════════════════
        """
        return comparison

    @staticmethod
    def get_distillflow_usage():
        """
        DistillFlow使用示例
        """
        usage = """
        # DistillFlow 使用示例

        ```python
        from distillflow import DistillFlow, DistillationConfig

        # 配置蒸馏
        config = DistillationConfig(
            teacher_model="deepseek-ai/DeepSeek-R1",
            student_model="Qwen/Qwen2.5-7B",
            distillation_strategies=["logits", "attention"],
            temperature=4.0,
            alpha=0.7
        )

        # 初始化蒸馏器
        df = DistillFlow(config)

        # 执行蒸馏
        df.distill(
            train_data=train_dataset,
            output_dir="./distilled_model"
        )
        ```
        """
        return usage

    @staticmethod
    def get_distillkit_usage():
        """
        DistillKit使用示例
        """
        usage = """
        # DistillKit 使用示例

        ```python
        from distillkit import DistillKit

        # 初始化
        dk = DistillKit(
            teacher="gpt-4",
            student="gpt-3.5-turbo",
            method="cot_distillation"
        )

        # 生成蒸馏数据
        dk.generate_distillation_data(
            prompts=problem_list,
            api_key=OPENAI_API_KEY
        )

        # 执行蒸馏
        dk.distill(
            epochs=3,
            batch_size=8,
            learning_rate=1e-4
        )
        ```
        """
        return usage


def create_distillation_pipeline():
    """
    创建完整的蒸馏流水线

    综合使用多种蒸馏技术
    """
    pipeline = """
    完整蒸馏流水线
    ════════════════════════════════════════════════════════════

    阶段1:数据准备
    ├── 收集问题数据集
    ├── 使用教师模型生成CoT数据
    ├── 过滤和验证数据质量
    └── 数据增强(如需要)

    阶段2:蒸馏训练
    ├── 选择合适的蒸馏方法
    ├── 配置温度和损失权重
    ├── 渐进式训练
    └── 定期评估验证

    阶段3:后处理
    ├── 模型量化(INT8/INT4)
    ├── 格式转换(GGUF/ONNX)
    └── 性能基准测试

    工具推荐:
    ├── 数据生成:DeepSeek API / OpenAI API
    ├── 蒸馏训练:DistillFlow / DistillKit
    ├── 模型量化:llama.cpp / GPTQ
    └── 评估工具:lm-evaluation-harness

    ════════════════════════════════════════════════════════════
    """
    return pipeline

经典案例与最佳实践

7.1 经典案例

Text Only
经典知识蒸馏案例
═══════════════════════════════════════════════════════════════════

1. DistilBERT (Hugging Face, 2019)
├── 教师: BERT-base (110M参数)
├── 学生: DistilBERT (66M参数,减少40%)
├── 方法: 软标签蒸馏 + 隐藏状态蒸馏 + 余弦嵌入损失
├── 效果: 保留97%性能,推理速度提升60%
└── 代码: transformers.DistilBert

2. TinyBERT (华为, 2020)
├── 教师: BERT-base (110M参数)
├── 学生: TinyBERT-4L (14.5M参数,4层312维)
├── 方法: 嵌入层 + Transformer层 + 预测层 三层蒸馏
├── 效果: 在GLUE上达到BERT-base的96%,体积仅约1/7.5
└── 特点: 两阶段蒸馏(通用蒸馏+任务蒸馏)

3. MobileBERT (Google, 2020)
├── 教师: BERT-large
├── 学生: MobileBERT (25.3M参数,24层,倒瓶颈结构)
├── 方法: 倒瓶颈结构 + 特征蒸馏
├── 效果: 与BERT-base相当,体积小4.3倍,速度快5.5倍
└── 特点: 专为移动设备优化

4. MiniLM (微软, 2020)
├── 方法: 深度自注意力蒸馏
├── 关键: 只蒸馏注意力分布和价值关系
├── 效果: 6层模型保留BERT-base约99%的性能
└── 优势: 不需要教师隐藏状态

5. GPT-3 Distillation (OpenAI)
├── 教师: GPT-3 (175B参数)
├── 学生: 各种小模型
├── 方法: 黑盒蒸馏(仅使用API输出)
└── 应用: Alpaca、Vicuna等社区模型基于API输出训练

═══════════════════════════════════════════════════════════════════
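以上案例的共同基础是 Hinton 式软标签蒸馏:用温度软化教师分布,再最小化学生与教师分布的KL散度。下面用纯Python给出一个不依赖深度学习框架的数值示意(logits为随意取的示例值):

```python
import math


def softmax(logits, temperature=1.0):
    """带温度的softmax:温度越高,输出分布越平滑。"""
    exps = [math.exp(z / temperature) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p, q):
    """KL(p || q):p为教师软标签,q为学生软预测。"""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


teacher_logits = [5.0, 2.0, 0.1]
student_logits = [4.0, 2.5, 0.5]

T = 4.0
p = softmax(teacher_logits, T)
q = softmax(student_logits, T)
# 经典做法:蒸馏损失乘以 T^2,补偿高温下梯度被缩小的影响
distill_loss = (T ** 2) * kl_divergence(p, q)
print(round(distill_loss, 4))
```

实际训练中,该项通常再与真实标签的交叉熵按权重α加权合并,构成总损失。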

7.2 最佳实践

Python
class DistillationBestPractices:
    """
    知识蒸馏最佳实践
    """

    @staticmethod
    def temperature_tuning(teacher_model, student_model, val_loader):
        """
        温度参数调优
        """
        temperatures = [1.0, 2.0, 4.0, 6.0, 8.0, 10.0]
        best_temp = 4.0
        best_acc = 0.0

        for T in temperatures:
            # 简化:直接用当前温度在验证集上评估
            # 完整做法是先以该温度蒸馏训练若干步,再评估
            acc = evaluate_with_temperature(
                teacher_model, student_model, val_loader, T
            )

            if acc > best_acc:
                best_acc = acc
                best_temp = T

        return best_temp

    @staticmethod
    def layer_mapping_strategy(teacher_layers, student_layers):
        """
        层映射策略

        决定学生模型的哪层学习教师模型的哪层
        """
        import numpy as np

        strategies = {
            'uniform': np.linspace(0, teacher_layers-1, student_layers, dtype=int),
            'first_k': list(range(student_layers)),
            'last_k': list(range(teacher_layers - student_layers, teacher_layers)),
            'skip': list(range(0, teacher_layers, teacher_layers // student_layers))[:student_layers]
        }

        return strategies

    @staticmethod
    def dynamic_alpha_scheduler(epoch, total_epochs, alpha_range=(0.3, 0.7)):
        """
        动态调整蒸馏权重

        早期侧重蒸馏,后期侧重真实标签
        """
        alpha_min, alpha_max = alpha_range

        # 线性递减
        alpha = alpha_max - (alpha_max - alpha_min) * (epoch / total_epochs)

        return alpha

    @staticmethod
    def data_augmentation_for_distillation(raw_data, teacher_model, num_augmentations=5):
        """
        数据增强辅助蒸馏

        使用教师模型生成更多训练数据
        """
        augmented_data = []

        for sample in raw_data:
            # 原始样本
            augmented_data.append(sample)

            # 生成变体
            for _ in range(num_augmentations):
                # 添加噪声、回译、同义词替换等
                variant = augment_sample(sample)

                # 使用教师标注
                with torch.no_grad():
                    teacher_pred = teacher_model(variant)

                augmented_data.append({
                    'input': variant,
                    'teacher_logits': teacher_pred,
                    'label': sample['label']
                })

        return augmented_data
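其中 dynamic_alpha_scheduler 的调度曲线可以直接打印出来检查端点与单调性(与上文同样的线性递减公式):

```python
def dynamic_alpha(epoch, total_epochs, alpha_range=(0.3, 0.7)):
    """线性递减:早期α高(侧重蒸馏),后期α低(侧重真实标签)。"""
    alpha_min, alpha_max = alpha_range
    return alpha_max - (alpha_max - alpha_min) * (epoch / total_epochs)


total = 10
schedule = [round(dynamic_alpha(e, total), 2) for e in range(total + 1)]
print(schedule[0], schedule[-1])  # 0.7 0.3
```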

蒸馏效果评估

8.1 评估指标

Python
class DistillationEvaluator:
    """
    蒸馏效果评估
    """

    @staticmethod
    def compute_metrics(teacher_model, student_model, test_loader):
        """
        计算蒸馏效果指标
        """
        metrics = {}

        # 1. 准确率对比
        teacher_acc = evaluate_accuracy(teacher_model, test_loader)
        student_acc = evaluate_accuracy(student_model, test_loader)

        metrics['teacher_accuracy'] = teacher_acc
        metrics['student_accuracy'] = student_acc
        metrics['accuracy_retention'] = student_acc / teacher_acc

        # 2. 推理速度对比
        teacher_speed = measure_inference_speed(teacher_model, test_loader)
        student_speed = measure_inference_speed(student_model, test_loader)

        metrics['teacher_speed'] = teacher_speed
        metrics['student_speed'] = student_speed
        metrics['speedup'] = student_speed / teacher_speed  # 速度为吞吐量(batch/秒),学生更快则加速比>1

        # 3. 模型大小对比
        teacher_params = count_parameters(teacher_model)
        student_params = count_parameters(student_model)

        metrics['teacher_params'] = teacher_params
        metrics['student_params'] = student_params
        metrics['compression_ratio'] = teacher_params / student_params

        # 4. 预测一致性
        agreement = compute_prediction_agreement(
            teacher_model, student_model, test_loader
        )
        metrics['prediction_agreement'] = agreement

        # 5. 输出分布相似度
        kl_div = compute_average_kl_divergence(
            teacher_model, student_model, test_loader
        )
        metrics['kl_divergence'] = kl_div

        return metrics

    @staticmethod
    def visualize_distillation_effect(teacher_logits, student_logits, labels):
        """
        可视化蒸馏效果
        """
        import matplotlib.pyplot as plt

        # 1. 置信度分布对比
        teacher_conf = F.softmax(teacher_logits, dim=1).max(dim=1)[0]
        student_conf = F.softmax(student_logits, dim=1).max(dim=1)[0]

        plt.figure(figsize=(12, 4))

        plt.subplot(1, 3, 1)
        plt.hist(teacher_conf.cpu().numpy(), bins=50, alpha=0.5, label='Teacher')
        plt.hist(student_conf.cpu().numpy(), bins=50, alpha=0.5, label='Student')
        plt.xlabel('Confidence')
        plt.ylabel('Count')
        plt.legend()
        plt.title('Confidence Distribution')

        # 2. 预测一致性矩阵
        plt.subplot(1, 3, 2)
        teacher_preds = teacher_logits.argmax(dim=1)
        student_preds = student_logits.argmax(dim=1)

        agreement = (teacher_preds == student_preds).float()
        plt.bar(['Disagree', 'Agree'], [1-agreement.mean(), agreement.mean()])
        plt.ylabel('Proportion')
        plt.title('Prediction Agreement')

        # 3. 正确率对比(按类别)
        plt.subplot(1, 3, 3)
        num_classes = teacher_logits.size(1)
        teacher_acc_per_class = []
        student_acc_per_class = []

        for c in range(num_classes):
            mask = (labels == c)
            if mask.sum() > 0:
                teacher_acc_per_class.append(
                    (teacher_preds[mask] == labels[mask]).float().mean().item()
                )
                student_acc_per_class.append(
                    (student_preds[mask] == labels[mask]).float().mean().item()
                )

        x = range(len(teacher_acc_per_class))
        plt.plot(x, teacher_acc_per_class, label='Teacher', marker='o')
        plt.plot(x, student_acc_per_class, label='Student', marker='s')
        plt.xlabel('Class')
        plt.ylabel('Accuracy')
        plt.legend()
        plt.title('Per-Class Accuracy')

        plt.tight_layout()
        plt.show()

def evaluate_accuracy(model, data_loader):
    """评估准确率"""
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for batch in data_loader:
            inputs, labels = batch
            outputs = model(inputs)
            preds = outputs.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    return correct / total

def measure_inference_speed(model, data_loader, num_batches=100):
    """测量推理速度"""
    import time

    model.eval()
    device = next(model.parameters()).device

    # 预热
    for i, batch in enumerate(data_loader):
        if i >= 10:
            break
        inputs, _ = batch
        _ = model(inputs.to(device))  # .to(device)将数据移至GPU/CPU

    # 测量
    start_time = time.time()

    for i, batch in enumerate(data_loader):
        if i >= num_batches:
            break
        inputs, _ = batch
        _ = model(inputs.to(device))

    end_time = time.time()

    return num_batches / (end_time - start_time)  # batches per second

def count_parameters(model):
    """统计参数量"""
    return sum(p.numel() for p in model.parameters())

def compute_prediction_agreement(teacher_model, student_model, data_loader):
    """计算预测一致性"""
    teacher_model.eval()
    student_model.eval()

    agreements = []

    with torch.no_grad():
        for batch in data_loader:
            inputs, _ = batch

            teacher_outputs = teacher_model(inputs)
            student_outputs = student_model(inputs)

            teacher_preds = teacher_outputs.argmax(dim=1)
            student_preds = student_outputs.argmax(dim=1)

            agreements.append((teacher_preds == student_preds).float())

    return torch.cat(agreements).mean().item()

def compute_average_kl_divergence(teacher_model, student_model, data_loader):
    """计算平均KL散度"""
    teacher_model.eval()
    student_model.eval()

    kl_divs = []

    with torch.no_grad():
        for batch in data_loader:
            inputs, _ = batch

            teacher_outputs = teacher_model(inputs)
            student_outputs = student_model(inputs)

            teacher_probs = F.softmax(teacher_outputs, dim=1)
            student_log_probs = F.log_softmax(student_outputs, dim=1)

            kl = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean')
            kl_divs.append(kl.item())

    return sum(kl_divs) / len(kl_divs)

总结

知识蒸馏方法对比

| 方法 | 适用场景 | 优点 | 缺点 |
| --- | --- | --- | --- |
| 软标签蒸馏 | 通用场景 | 简单有效 | 需要调整温度 |
| 特征蒸馏 | 白盒场景 | 传递更多信息 | 需要维度对齐 |
| 注意力蒸馏 | Transformer | 捕捉结构信息 | 计算复杂 |
| 自蒸馏 | 无教师模型 | 不需要教师 | 效果可能较差 |
| 在线蒸馏 | 多个模型 | 互相提升 | 训练复杂 |

关键超参数

Text Only
温度 (Temperature):
├── 推荐范围: 2.0 - 8.0
├── 任务复杂 → 温度高
└── 类别多 → 温度高

蒸馏权重 (Alpha):
├── 推荐范围: 0.5 - 0.9
├── 早期: 高权重(学习教师)
└── 后期: 低权重(微调细节)

层映射:
├── 均匀映射: 通用
├── 最后k层: 任务相关
└── 学习映射: 效果最佳但复杂
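温度的作用可以用一个小实验直观感受:同一组logits在不同温度下的softmax输出(数值为示例):

```python
import math


def softmax(logits, temperature):
    exps = [math.exp(z / temperature) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


logits = [6.0, 2.0, 1.0]
for T in (1.0, 4.0, 8.0):
    print(T, [round(p, 3) for p in softmax(logits, T)])
# 温度升高后,次优类别的概率显著变大,类别间关系("暗知识")更容易被学生学到
```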

最佳实践清单

  • 选择合适的教师-学生模型尺寸比(通常 2-10 倍)
  • 使用温度调优找到最佳温度
  • 考虑多层蒸馏而不仅是输出
  • 结合数据增强提升效果
  • 动态调整蒸馏权重
  • 评估时关注准确率、速度和大小三个维度

下一步:学习01-数据工程与预处理,进入系统与工程阶段!


最后更新日期: 2026-03-26 适用版本: LLM 学习教程 v2026