跳转至

06-联邦学习

学习时间: 约5-7小时 难度级别: ⭐⭐⭐⭐ 中高级 前置知识: 深度学习基础、分布式系统概念、PyTorch 学习目标: 理解联邦学习的动机与核心算法(FedAvg),掌握数据异构性、通信效率、隐私保护等关键问题


目录


1. 联邦学习概述

联邦学习架构

1.1 数据孤岛问题

传统集中式训练需要将所有数据汇聚到一处,但在实际场景中面临:

挑战 说明
隐私法规 GDPR、中国《个人信息保护法》严格限制数据出境
数据安全 医疗、金融等数据高度敏感
数据壁垒 各机构间数据无法直接共享
带宽限制 海量数据传输成本高昂

1.2 联邦学习的定义

多个参与方在不交换原始数据的前提下,协作训练一个共享的机器学习模型。

核心原则:数据不动,模型动

1.3 联邦学习分类

类型 数据分布 场景
横向联邦 特征相同,样本不同 不同地区的同类机构
纵向联邦 样本相同,特征不同 银行+电商(同一用户不同数据)
联邦迁移 特征和样本都不同 跨领域协作

1.4 联邦学习 vs 传统分布式训练

特性 分布式训练 联邦学习
数据分布 IID(随机分片) Non-IID(自然分布)
通信频率 高频(每个batch) 低频(多轮本地训练)
参与方 少量(几十个GPU) 大量(数百万终端)
数据隐私 不考虑 核心关注
参与方可靠性 低(可能掉线)

2. FedAvg:核心算法

FedAvg算法流程

2.1 算法流程

Text Only
FedAvg算法:
初始化: 服务器持有全局模型 w_0
for 每一轮 t = 1, 2, ..., T:
    1. 服务器随机选择一部分客户端 S_t (|S_t| = C·K)
    2. 服务器将当前全局模型 w_t 广播给选中的客户端
    3. 每个选中的客户端 k:
       - 用本地数据训练 E 个epoch
       - w_k^{t+1} = w_t - η·∇L_k(w_t)  (SGD多步)
       - 将更新后的模型 w_k^{t+1} 上传给服务器
    4. 服务器聚合:
       w_{t+1} = Σ (n_k/n) · w_k^{t+1}
       其中 n_k 为客户端k的样本数, n为总样本数

2.2 关键超参数

参数 含义 影响
C 每轮参与率 C越大收敛越快,但通信开销增大
E 本地训练轮数 E越大通信效率越高,但可能导致客户端漂移
B 本地batch size 影响本地优化质量
η 学习率 需要根据E和数据异构程度调整

2.3 聚合策略对比

方法 聚合方式 特点
FedAvg 按样本数加权平均 经典基线
FedProx FedAvg + 近端项 缓解异构
FedNova 归一化聚合 消除本地步数差异
FedOpt 服务器端自适应优化器 加速收敛

3. 数据异构性问题

异构联邦学习

3.1 Non-IID数据的类型

类型 说明 示例
标签分布偏斜 不同客户端拥有不同的标签分布 用户A只有猫的照片,用户B只有狗
特征分布偏斜 同一标签在不同客户端的特征分布不同 不同地区的方言语音
数量不平衡 客户端数据量差异大 活跃用户 vs 低频用户
概念漂移 数据分布随时间变化 用户兴趣随季节变化

3.2 客户端漂移(Client Drift)

当本地训练多步后,各客户端模型朝各自本地最优方向偏移:

\[w_k^{t+1} \to w_k^* \neq w^*_{global}\]

各客户端的更新方向不一致,聚合后的全局模型可能远离全局最优。

3.3 解决方法

方法 思路
FedProx 在本地目标中加正则项 \(\frac{\mu}{2}\lVert w - w_t\rVert^2\),限制偏离程度
SCAFFOLD 用控制变量修正本地梯度方向
FedBN 不聚合BatchNorm参数(BN统计量反映本地分布)
数据增强 用全局共享的少量公共数据校正方向

4. 通信效率优化

4.1 通信瓶颈

一个典型的ResNet-50有 ~25M 参数(约100MB),每轮每个客户端需上传+下载200MB。

若有1000个客户端、100轮:总通信量约 200TB。

4.2 优化策略

策略 方法 压缩比
梯度压缩 Top-K稀疏化 10-100x
量化 低精度传输(INT8/1-bit) 4-32x
知识蒸馏 只传模型输出/logits >100x
增加本地计算 提高E减少通信轮数 取决于E
异步聚合 不需要等待最慢客户端

4.3 梯度压缩示例

Python
def top_k_sparsification(gradient, k_ratio=0.01):
    """Top-K梯度稀疏化"""
    flat_grad = gradient.flatten()
    k = max(1, int(flat_grad.numel() * k_ratio))

    # 取绝对值最大的Top-K个元素
    values, indices = flat_grad.abs().topk(k)
    sparse_grad = torch.zeros_like(flat_grad)
    sparse_grad[indices] = flat_grad[indices]

    return sparse_grad.view_as(gradient)

def quantize_gradient(gradient, n_bits=8):
    """低精度梯度量化"""
    min_val = gradient.min()
    max_val = gradient.max()
    n_levels = 2 ** n_bits - 1

    # 线性量化
    scale = (max_val - min_val) / n_levels
    quantized = torch.round((gradient - min_val) / scale).to(torch.uint8)

    # 解量化
    dequantized = quantized.float() * scale + min_val
    return dequantized

5. 隐私与安全

联邦学习隐私保护

5.1 联邦学习的隐私风险

仅传递模型参数/梯度并不意味着安全

攻击类型 说明
梯度反演 从梯度中恢复训练数据(DLG attack)
成员推断 判断某个样本是否在训练集中
模型反演 从模型中恢复训练数据的统计特征
拜占庭攻击 恶意客户端发送有害更新

5.2 差分隐私(DP)

在梯度中加入校准噪声,提供数学隐私保证:

\[\tilde{g} = \text{clip}(g, C) + \mathcal{N}(0, \sigma^2 C^2 I)\]
参数 含义
ε (epsilon) 隐私预算,越小越隐私
δ (delta) 允许的失败概率
C 梯度裁剪阈值
σ 噪声幅度

5.3 安全聚合(Secure Aggregation)

通过密码学协议,使服务器只能看到聚合结果,无法获取单个客户端的更新:

Text Only
客户端A: 上传 w_A + mask_A
客户端B: 上传 w_B + mask_B
...
服务器聚合后: Σ mask_i = 0 (掩码互相抵消)
服务器只得到: Σ w_i (无法分离)

5.4 拜占庭鲁棒聚合

方法 思路
Median 取各维度中位数
Trimmed Mean 去掉最大最小后取均值
Krum 选择与其他更新最相似的那个
FLTrust 服务器用小量可信数据验证更新方向

6. 个性化联邦学习

6.1 为什么需要个性化

全局模型在数据异构场景下可能对某些客户端表现不佳。每个客户端需要一个适合自身数据分布的模型。

6.2 主要方法

方法 策略 说明
本地微调 全局模型 + 本地微调 简单但可能过拟合
Meta-Learning 学共享初始化(MAML) Per-FedAvg: 全局模型作为好的初始化
混合专家 本地模型 + 全局模型加权 根据样本自适应混合
部分共享 只聚合部分层 FedPer: 共享底层,本地顶层
聚类 客户端分组,组内聚合 相似的客户端共享模型

6.3 FedPer示例思路

Python
class FedPerModel(nn.Module):  # 继承nn.Module定义神经网络层
    def __init__(self):  # __init__构造方法,创建对象时自动调用
        super().__init__()  # super()调用父类方法
        # 共享层(参与联邦聚合)
        self.shared_layers = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
        )
        # 个性化层(仅本地训练)
        self.personal_layers = nn.Sequential(
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 10),
        )

    def forward(self, x):
        x = self.shared_layers(x).flatten(1)
        return self.personal_layers(x)

    def get_shared_params(self):
        return dict(self.shared_layers.named_parameters())

    def get_personal_params(self):
        return dict(self.personal_layers.named_parameters())

7. 联邦学习系统设计

7.1 实际工业部署架构

Text Only
                    ┌──────────────────┐
                    │   协调服务器      │
                    │  - 模型聚合      │
                    │  - 客户端调度    │
                    │  - 版本管理      │
                    └──────┬───────────┘
                           │ HTTPS/gRPC
          ┌────────────────┼────────────────┐
          ▼                ▼                ▼
    ┌──────────┐    ┌──────────┐    ┌──────────┐
    │ 客户端A   │    │ 客户端B   │    │ 客户端C   │
    │ 本地数据  │    │ 本地数据  │    │ 本地数据  │
    │ 本地训练  │    │ 本地训练  │    │ 本地训练  │
    └──────────┘    └──────────┘    └──────────┘

7.2 工业实践中的额外挑战

挑战 解决思路
设备异质性 允许不同客户端训练不同大小子模型
掉线/延迟 异步聚合、超时淘汰
模型版本 版本号管理、增量更新
激励机制 贡献度评估、奖励分配

7.3 主流框架

框架 开发者 特点
Flower Flower Labs 灵活易用,支持多框架
PySyft OpenMined 注重隐私,支持MPC
FATE 微众银行 工业级,纵向联邦
FedML FedML Inc 跨平台,支持多种拓扑

8. 实战:PyTorch联邦学习模拟

Python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
import numpy as np
from copy import deepcopy

# ==================== 模型定义 ====================
class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, 3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.pool = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(32 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 32 * 7 * 7)  # view重塑张量形状(要求内存连续)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

# ==================== 数据划分(模拟Non-IID) ====================
def create_non_iid_partitions(dataset, n_clients, n_shards_per_client=2):
    """按标签排序后切片,模拟Label-skew Non-IID"""
    labels = np.array([dataset[i][1] for i in range(len(dataset))])  # 列表推导式,简洁创建列表  # np.array创建NumPy数组
    sorted_indices = np.argsort(labels)

    n_shards = n_clients * n_shards_per_client
    shard_size = len(dataset) // n_shards
    shards = [sorted_indices[i * shard_size:(i + 1) * shard_size]
              for i in range(n_shards)]

    np.random.shuffle(shards)
    client_indices = {}
    for i in range(n_clients):
        client_indices[i] = np.concatenate(
            shards[i * n_shards_per_client:(i + 1) * n_shards_per_client]
        )
    return client_indices

# ==================== 客户端本地训练 ====================
def client_update(model, dataloader, lr=0.01, local_epochs=5):
    """客户端本地训练"""
    model.train()  # train()开启训练模式
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)

    for epoch in range(local_epochs):
        for x, y in dataloader:
            optimizer.zero_grad()  # 清零梯度,防止梯度累积
            loss = F.cross_entropy(model(x), y)
            loss.backward()  # 反向传播计算梯度
            optimizer.step()  # 根据梯度更新模型参数

    return model.state_dict()

# ==================== 服务器聚合 ====================
def fedavg_aggregate(global_model, client_state_dicts, client_sizes):
    """FedAvg聚合:按样本数加权平均"""
    total_size = sum(client_sizes)
    global_state = global_model.state_dict()

    for key in global_state:
        global_state[key] = sum(
            client_state_dicts[i][key].float() * (client_sizes[i] / total_size)
            for i in range(len(client_state_dicts))
        )

    global_model.load_state_dict(global_state)
    return global_model

# ==================== 联邦训练主循环 ====================
def federated_train(train_dataset, test_dataset, n_clients=10,
                    n_rounds=20, local_epochs=5, participation_rate=0.3):
    """完整的联邦训练流程"""
    # 数据划分
    client_indices = create_non_iid_partitions(train_dataset, n_clients)

    # 初始化全局模型
    global_model = SimpleCNN()

    # 测试集loader
    test_loader = DataLoader(test_dataset, batch_size=128)

    for round_idx in range(n_rounds):
        # 1. 选择参与客户端
        n_selected = max(1, int(n_clients * participation_rate))
        selected = np.random.choice(n_clients, n_selected, replace=False)

        client_states = []
        client_sizes = []

        for client_id in selected:
            # 2. 下发全局模型
            local_model = deepcopy(global_model)

            # 3. 本地训练
            indices = client_indices[client_id]
            local_dataset = Subset(train_dataset, indices)
            local_loader = DataLoader(local_dataset, batch_size=32, shuffle=True)

            state_dict = client_update(local_model, local_loader,
                                       local_epochs=local_epochs)
            client_states.append(state_dict)
            client_sizes.append(len(indices))

        # 4. 服务器聚合
        global_model = fedavg_aggregate(global_model, client_states, client_sizes)

        # 5. 评估
        if (round_idx + 1) % 5 == 0:
            global_model.eval()  # eval()开启评估模式(关闭Dropout等)
            correct, total = 0, 0
            with torch.no_grad():  # 禁用梯度计算,节省内存
                for x, y in test_loader:
                    pred = global_model(x).argmax(dim=1)
                    correct += (pred == y).sum().item()  # .item()将单元素张量转为Python数值
                    total += y.size(0)
            print(f"Round {round_idx+1}, Accuracy: {correct/total:.4f}")

    return global_model

# 使用示例(需要torchvision的MNIST数据集)
# from torchvision import datasets, transforms
# train_set = datasets.MNIST('./data', train=True, download=True,
#                            transform=transforms.ToTensor())
# test_set = datasets.MNIST('./data', train=False,
#                           transform=transforms.ToTensor())
# model = federated_train(train_set, test_set)

9. 面试高频题

Q1: FedAvg算法的核心流程是什么?

:每轮通信中:(1) 服务器随机选取部分客户端;(2) 广播当前全局模型;(3) 各客户端在本地数据上训练多个epoch;(4) 上传模型参数到服务器;(5) 服务器按样本数加权平均聚合为新的全局模型。与传统分布式SGD的关键区别是客户端进行多步本地更新,减少通信轮数。

Q2: Non-IID数据为什么影响联邦学习?

:Non-IID导致各客户端的本地目标函数不一致。本地训练多步后,各客户端模型朝不同方向偏移(客户端漂移),聚合后的全局模型可能远离全局最优。标签分布偏斜越严重,性能下降越明显。解决方法包括FedProx(近端正则)、SCAFFOLD(方差修正)、数据共享等。

Q3: 联邦学习真的能保护隐私吗?

:仅靠不传原始数据不够安全。梯度反演攻击(如DLG)可从梯度中恢复训练数据。需要额外的隐私保护机制:差分隐私(加噪声)、安全聚合(密码学协议)、同态加密等。实际部署通常需要结合多种技术。

Q4: 横向联邦和纵向联邦的区别?

:横向联邦中各参与方拥有相同特征、不同样本(如不同医院的病历),将数据按行分割。纵向联邦中各参与方拥有相同样本、不同特征(如银行和电商对同一用户的不同数据),将数据按列分割。纵向联邦还需要实体对齐步骤。

Q5: 个性化联邦学习有哪些典型方法?

:(1) 本地微调:全局模型在本地数据上微调;(2) 元学习:将全局模型作为好的初始化(Per-FedAvg);(3) 部分共享:底层聚合、顶层个性化(FedPer);(4) 客户端聚类:相似客户端共享模型;(5) 混合:本地与全局模型自适应插值。


10. 练习与自我检查

编程练习

  1. 基础:在MNIST上实现FedAvg,对比IID和Non-IID数据分区的性能差异
  2. 进阶:实现FedProx(加近端正则项),观察其在Non-IID下的改善效果
  3. 挑战:实现带差分隐私的联邦学习(DP-FedAvg),分析隐私预算ε与精度的权衡

检查清单

  • 能说明联邦学习的定义和"数据不动模型动"的核心原则
  • 理解横向、纵向、联邦迁移学习的区别
  • 能写出FedAvg的完整算法流程
  • 理解Non-IID数据的类型及其对训练的影响
  • 知道客户端漂移的原因及FedProx/SCAFFOLD的解决思路
  • 了解通信效率优化的主要手段
  • 理解差分隐私和安全聚合的基本原理
  • 能讨论联邦学习的隐私风险和防御方法
  • 了解个性化联邦学习的主要方法
  • 能用PyTorch实现简单的联邦学习模拟

联邦学习应用

扩展阅读: - McMahan et al., 2017: Communication-Efficient Learning of Deep Networks from Decentralized Data (FedAvg) - Li et al., 2020: Federated Optimization in Heterogeneous Networks (FedProx) - Kairouz et al., 2021: Advances and Open Problems in Federated Learning - Yang et al., 2019: Federated Machine Learning: Concept and Applications