跳转至

03 - 量化精度损失评估

评估和补偿量化带来的精度损失

📖 章节概述

本章将介绍如何评估量化带来的精度损失,包括精度评估指标、损失分析和补偿方法等内容。

🎯 学习目标

完成本章后,你将能够:

  • 掌握精度评估的指标和方法
  • 了解量化损失的分析技巧
  • 实现精度补偿的方法
  • 能够评估和优化量化效果

1. 精度评估指标

1.1 分类任务指标

Python
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

def evaluate_classification(y_true, y_pred):
    """
    评估分类任务

    Args:
        y_true: 真实标签
        y_pred: 预测标签
    """
    metrics = {
        'accuracy': accuracy_score(y_true, y_pred),
        'precision': precision_score(y_true, y_pred, average='weighted'),
        'recall': recall_score(y_true, y_pred, average='weighted'),
        'f1_score': f1_score(y_true, y_pred, average='weighted')
    }

    return metrics

# 使用示例
# metrics = evaluate_classification(y_true, y_pred)
# print(metrics)

1.2 生成任务指标

Python
import numpy as np
from nltk.translate.bleu_score import sentence_bleu
from rouge import Rouge

def evaluate_generation(references, hypotheses):
    """
    评估生成任务

    Args:
        references: 参考文本列表
        hypotheses: 生成文本列表
    """
    # BLEU分数
    bleu_scores = []
    for ref, hyp in zip(references, hypotheses):
        bleu = sentence_bleu([ref.split()], hyp.split())
        bleu_scores.append(bleu)

    # ROUGE分数
    rouge = Rouge()
    rouge_scores = rouge.get_scores(hypotheses, references, avg=True)

    metrics = {
        'bleu': np.mean(bleu_scores),
        'rouge_1': rouge_scores['rouge-1']['f'],
        'rouge_2': rouge_scores['rouge-2']['f'],
        'rouge_l': rouge_scores['rouge-l']['f']
    }

    return metrics

# 使用示例
# metrics = evaluate_generation(references, hypotheses)
# print(metrics)

2. 损失分析

2.1 量化误差分析

Python
import torch
import numpy as np

def analyze_quantization_error(original, quantized):
    """
    分析量化误差

    Args:
        original: 原始张量
        quantized: 量化张量
    """
    # 计算误差
    error = original - quantized

    # 计算统计信息
    metrics = {
        'mean_absolute_error': torch.mean(torch.abs(error)).item(),
        'mean_squared_error': torch.mean(error ** 2).item(),
        'root_mean_squared_error': torch.sqrt(torch.mean(error ** 2)).item(),
        'max_absolute_error': torch.max(torch.abs(error)).item(),
        'std_error': torch.std(error).item()
    }

    return metrics

# 使用示例
# error_metrics = analyze_quantization_error(original_weight, quantized_weight)
# print(error_metrics)

2.2 层级误差分析

Python
import torch.nn as nn

def analyze_layer_errors(model_fp32, model_quantized, dataloader, max_batches=10):
    """
    分析层级误差:使用 forward hook 收集每层中间激活值

    Args:
        model_fp32: FP32 模型
        model_quantized: 量化模型
        dataloader: 数据加载器
        max_batches: 最多分析多少个 batch
    """
    model_fp32.eval()
    model_quantized.eval()

    layer_errors = {}
    fp32_activations = {}
    quant_activations = {}

    def make_hook(storage_dict, name):
        """创建 forward hook,将每层输出存入 storage_dict"""
        def hook_fn(module, input, output):
            storage_dict[name] = output.detach()
        return hook_fn

    # 为两个模型的所有 Linear 层注册 hook
    fp32_hooks = []
    quant_hooks = []
    for name, module in model_fp32.named_modules():
        if isinstance(module, nn.Linear):
            fp32_hooks.append(module.register_forward_hook(make_hook(fp32_activations, name)))
    for name, module in model_quantized.named_modules():
        if isinstance(module, nn.Linear):
            quant_hooks.append(module.register_forward_hook(make_hook(quant_activations, name)))

    # 运行数据并收集误差
    with torch.no_grad():
        for i, (batch_x, _) in enumerate(dataloader):
            if i >= max_batches:
                break

            # 前向传播
            fp32_activations.clear()
            quant_activations.clear()
            model_fp32(batch_x)
            model_quantized(batch_x)

            # 计算每层误差
            for name in fp32_activations:
                if name in quant_activations:
                    if name not in layer_errors:
                        layer_errors[name] = []
                    error = torch.mean(torch.abs(
                        fp32_activations[name] - quant_activations[name]
                    )).item()
                    layer_errors[name].append(error)

    # 移除 hook
    for hook in fp32_hooks + quant_hooks:
        hook.remove()

    # 计算平均误差
    for name in layer_errors:
        layer_errors[name] = np.mean(layer_errors[name])

    return layer_errors

# 使用示例
# layer_errors = analyze_layer_errors(model_fp32, model_quantized, dataloader)
# for name, error in sorted(layer_errors.items(), key=lambda x: x[1], reverse=True):
#     print(f"{name}: MAE = {error:.6f}")

3. 补偿方法

3.1 量化后微调

Python
import torch
import torch.nn as nn

def post_quantization_finetune(quantized_model, dataloader, epochs=3):
    """
    量化后微调

    Args:
        quantized_model: 量化模型
        dataloader: 数据加载器
        epochs: 训练轮数
    """
    # 冻结量化参数
    for name, param in quantized_model.named_parameters():
        if 'weight' in name and 'quantized' in name:
            param.requires_grad = False

    # 训练
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, quantized_model.parameters()),  # lambda匿名函数:简洁的单行函数
        lr=0.0001
    )

    quantized_model.train()
    for epoch in range(epochs):
        total_loss = 0
        for batch_x, batch_y in dataloader:
            optimizer.zero_grad()
            outputs = quantized_model(batch_x)
            loss = criterion(outputs, batch_y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        print(f"Epoch {epoch+1}, Loss: {total_loss/len(dataloader):.4f}")

    return quantized_model

# 使用示例
# finetuned_model = post_quantization_finetune(quantized_model, dataloader)

3.2 知识蒸馏补偿

Python
class DistillationLoss(nn.Module):
    """
    蒸馏损失
    """
    def __init__(self, temperature=5.0, alpha=0.5):
        super().__init__()  # super()调用父类方法
        self.temperature = temperature
        self.alpha = alpha
        self.kl_div = nn.KLDivLoss(reduction='batchmean')
        self.ce_loss = nn.CrossEntropyLoss()

    def forward(self, student_outputs, teacher_outputs, targets):
        """
        计算蒸馏损失
        """
        # 软标签损失
        soft_loss = self.kl_div(
            torch.log_softmax(student_outputs / self.temperature, dim=1),
            torch.softmax(teacher_outputs / self.temperature, dim=1)
        ) * (self.temperature ** 2)

        # 硬标签损失
        hard_loss = self.ce_loss(student_outputs, targets)

        # 组合损失
        loss = self.alpha * soft_loss + (1 - self.alpha) * hard_loss

        return loss

def distillation_compensation(teacher_model, student_model, dataloader, epochs=5):
    """
    蒸馏补偿

    Args:
        teacher_model: 教师模型(FP32)
        student_model: 学生模型(量化)
        dataloader: 数据加载器
        epochs: 训练轮数
    """
    # 冻结教师模型
    teacher_model.eval()
    for param in teacher_model.parameters():
        param.requires_grad = False

    # 训练学生模型
    criterion = DistillationLoss(temperature=5.0, alpha=0.5)
    optimizer = torch.optim.Adam(student_model.parameters(), lr=0.001)

    student_model.train()
    for epoch in range(epochs):
        total_loss = 0
        for batch_x, batch_y in dataloader:
            # 教师模型前向传播
            with torch.no_grad():
                teacher_outputs = teacher_model(batch_x)

            # 学生模型前向传播
            student_outputs = student_model(batch_x)

            # 计算损失
            loss = criterion(student_outputs, teacher_outputs, batch_y)

            # 反向传播和优化
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        print(f"Epoch {epoch+1}, Loss: {total_loss/len(dataloader):.4f}")

    return student_model

# 使用示例
# compensated_model = distillation_compensation(teacher_model, student_model, dataloader)

4. 面试题

基础题

Q1: 如何评估量化带来的精度损失?

A: 使用准确率、 BLEU 、 ROUGE 等指标在测试集上评估量化前后的模型性能,对比分析精度损失。

Q2: 什么是量化后微调?

A: 量化后微调是在量化后的模型上使用少量数据进行微调,以补偿量化带来的精度损失。

进阶题

Q3: 如何减少量化带来的精度损失?

A: 可以通过以下方法减少精度损失: 1. 使用量化感知训练( QAT ) 2. 进行量化后微调 3. 使用知识蒸馏 4. 优化量化参数 5. 使用更好的量化算法

Q4: 如何分析量化误差的分布?

A: 可以计算量化误差的统计信息(均值、方差、最大值等),并可视化误差分布,分析误差的主要来源。

5. 练习题

基础练习

  1. 实现精度评估
Python
# 练习: 实现精度评估函数
def evaluate_accuracy(model, dataloader):
    # 你的代码
    pass
  1. 计算量化误差
Python
# 练习: 计算量化误差
def quantization_error(original, quantized):
    # 你的代码
    pass

进阶练习

  1. 实现量化后微调
Python
# 练习: 实现量化后微调
class PostQuantizationFinetune:
    def __init__(self, model):
        # 你的代码
        pass

    def finetune(self, dataloader, epochs):
        # 你的代码
        pass
  1. 实现蒸馏补偿
Python
# 练习: 实现蒸馏补偿
class DistillationCompensation:
    def __init__(self, teacher, student):
        # 你的代码
        pass

    def compensate(self, dataloader, epochs):
        # 你的代码
        pass

📝 参考答案提示

以下为练习题的关键实现思路,建议先独立完成后再对照。

练习 1:实现精度评估

Python
def evaluate_accuracy(model, dataloader):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for batch_x, batch_y in dataloader:
            outputs = model(batch_x)
            preds = outputs.argmax(dim=1)
            correct += (preds == batch_y).sum().item()
            total += batch_y.size(0)
    return correct / total

练习 2:计算量化误差

Python
def quantization_error(original, quantized):
    error = original - quantized
    return {
        'mae': torch.mean(torch.abs(error)).item(),
        'mse': torch.mean(error**2).item(),
        'max_error': torch.max(torch.abs(error)).item()
    }

练习 3 提示:量化后微调的核心是冻结量化参数(requires_grad=False),用小学习率(1e-4~1e-5)只训练非量化层。参考本章 3.1 节 post_quantization_finetune() 函数。

练习 4 提示:蒸馏补偿的核心是 DistillationLoss = α × T² × KL散度(软标签) + (1-α) × CE(硬标签)。冻结教师模型,只训练学生(量化)模型。参考本章 3.2 节 DistillationLoss 类。

6. 最佳实践

✅ 推荐做法

  1. 全面评估
  2. 在多个数据集上评估
  3. 使用多种评估指标
  4. 记录详细结果

  5. 分析误差来源

  6. 分析层级误差
  7. 识别敏感层
  8. 针对性优化

  9. 应用补偿方法

  10. 使用量化后微调
  11. 应用知识蒸馏
  12. 迭代优化

❌ 避免做法

  1. 单一指标评估
  2. 不要只看准确率
  3. 使用多种指标
  4. 综合评估

  5. 忽略误差分析

  6. 分析误差分布
  7. 识别问题根源
  8. 针对性解决

  9. 过度补偿

  10. 不要过度微调
  11. 避免过拟合
  12. 保持泛化能力

7. 总结

本章介绍了量化精度损失的评估和补偿:

  • 评估指标: 分类和生成任务的指标
  • 损失分析: 误差分析和层级分析
  • 补偿方法: 微调和蒸馏补偿

掌握这些方法可以有效减少量化带来的精度损失。

8. 下一步

继续学习04-常见面试问题和解答,准备常见的面试问题。

⚠️ 核验说明(2026-03-26):本页已纳入 2026-03-26 全站统一复核批次。若文中涉及外部模型、API、版本号、价格或第三方产品名称,请以官方文档和实际运行环境为准。


最后更新日期: 2026-03-26