06-联邦学习¶
学习时间: 约5-7小时 难度级别: ⭐⭐⭐⭐ 中高级 前置知识: 深度学习基础、分布式系统概念、PyTorch 学习目标: 理解联邦学习的动机与核心算法(FedAvg),掌握数据异构性、通信效率、隐私保护等关键问题
目录¶
- 1. 联邦学习概述
- 2. FedAvg:核心算法
- 3. 数据异构性问题
- 4. 通信效率优化
- 5. 隐私与安全
- 6. 个性化联邦学习
- 7. 联邦学习系统设计
- 8. 实战:PyTorch联邦学习模拟
- 9. 面试高频题
- 10. 练习与自我检查
1. 联邦学习概述¶
1.1 数据孤岛问题¶
传统集中式训练需要将所有数据汇聚到一处,但在实际场景中面临:
| 挑战 | 说明 |
|---|---|
| 隐私法规 | GDPR、中国《个人信息保护法》严格限制数据出境 |
| 数据安全 | 医疗、金融等数据高度敏感 |
| 数据壁垒 | 各机构间数据无法直接共享 |
| 带宽限制 | 海量数据传输成本高昂 |
1.2 联邦学习的定义¶
多个参与方在不交换原始数据的前提下,协作训练一个共享的机器学习模型。
核心原则:数据不动,模型动。
1.3 联邦学习分类¶
| 类型 | 数据分布 | 场景 |
|---|---|---|
| 横向联邦 | 特征相同,样本不同 | 不同地区的同类机构 |
| 纵向联邦 | 样本相同,特征不同 | 银行+电商(同一用户不同数据) |
| 联邦迁移 | 特征和样本都不同 | 跨领域协作 |
1.4 联邦学习 vs 传统分布式训练¶
| 特性 | 分布式训练 | 联邦学习 |
|---|---|---|
| 数据分布 | IID(随机分片) | Non-IID(自然分布) |
| 通信频率 | 高频(每个batch) | 低频(多轮本地训练) |
| 参与方 | 少量(几十个GPU) | 大量(数百万终端) |
| 数据隐私 | 不考虑 | 核心关注 |
| 参与方可靠性 | 高 | 低(可能掉线) |
2. FedAvg:核心算法¶
2.1 算法流程¶
FedAvg算法:
初始化: 服务器持有全局模型 w_0
for 每一轮 t = 1, 2, ..., T:
1. 服务器随机选择一部分客户端 S_t (|S_t| = C·K)
2. 服务器将当前全局模型 w_t 广播给选中的客户端
3. 每个选中的客户端 k:
- 用本地数据训练 E 个epoch
- w_k^{t+1} = w_t - η·∇L_k(w_t) (SGD多步)
- 将更新后的模型 w_k^{t+1} 上传给服务器
4. 服务器聚合:
w_{t+1} = Σ (n_k/n) · w_k^{t+1}
其中 n_k 为客户端k的样本数, n为总样本数
2.2 关键超参数¶
| 参数 | 含义 | 影响 |
|---|---|---|
| C | 每轮参与率 | C越大收敛越快,但通信开销增大 |
| E | 本地训练轮数 | E越大通信效率越高,但可能导致客户端漂移 |
| B | 本地batch size | 影响本地优化质量 |
| η | 学习率 | 需要根据E和数据异构程度调整 |
2.3 聚合策略对比¶
| 方法 | 聚合方式 | 特点 |
|---|---|---|
| FedAvg | 按样本数加权平均 | 经典基线 |
| FedProx | FedAvg + 近端项 | 缓解异构 |
| FedNova | 归一化聚合 | 消除本地步数差异 |
| FedOpt | 服务器端自适应优化器 | 加速收敛 |
3. 数据异构性问题¶
3.1 Non-IID数据的类型¶
| 类型 | 说明 | 示例 |
|---|---|---|
| 标签分布偏斜 | 不同客户端拥有不同的标签分布 | 用户A只有猫的照片,用户B只有狗 |
| 特征分布偏斜 | 同一标签在不同客户端的特征分布不同 | 不同地区的方言语音 |
| 数量不平衡 | 客户端数据量差异大 | 活跃用户 vs 低频用户 |
| 概念漂移 | 数据分布随时间变化 | 用户兴趣随季节变化 |
3.2 客户端漂移(Client Drift)¶
当本地训练多步后,各客户端模型朝各自本地最优方向偏移:
各客户端的更新方向不一致,聚合后的全局模型可能远离全局最优。
3.3 解决方法¶
| 方法 | 思路 |
|---|---|
| FedProx | 在本地目标中加正则项 \(\frac{\mu}{2}\lVert w - w_t\rVert^2\),限制偏离程度 |
| SCAFFOLD | 用控制变量修正本地梯度方向 |
| FedBN | 不聚合BatchNorm参数(BN统计量反映本地分布) |
| 数据增强 | 用全局共享的少量公共数据校正方向 |
4. 通信效率优化¶
4.1 通信瓶颈¶
一个典型的ResNet-50有 ~25M 参数(约100MB),每轮每个客户端需上传+下载200MB。
若有1000个客户端、100轮:总通信量约 200TB。
4.2 优化策略¶
| 策略 | 方法 | 压缩比 |
|---|---|---|
| 梯度压缩 | Top-K稀疏化 | 10-100x |
| 量化 | 低精度传输(INT8/1-bit) | 4-32x |
| 知识蒸馏 | 只传模型输出/logits | >100x |
| 增加本地计算 | 提高E减少通信轮数 | 取决于E |
| 异步聚合 | 不需要等待最慢客户端 | — |
4.3 梯度压缩示例¶
def top_k_sparsification(gradient, k_ratio=0.01):
"""Top-K梯度稀疏化"""
flat_grad = gradient.flatten()
k = max(1, int(flat_grad.numel() * k_ratio))
# 取绝对值最大的Top-K个元素
values, indices = flat_grad.abs().topk(k)
sparse_grad = torch.zeros_like(flat_grad)
sparse_grad[indices] = flat_grad[indices]
return sparse_grad.view_as(gradient)
def quantize_gradient(gradient, n_bits=8):
"""低精度梯度量化"""
min_val = gradient.min()
max_val = gradient.max()
n_levels = 2 ** n_bits - 1
# 线性量化
scale = (max_val - min_val) / n_levels
quantized = torch.round((gradient - min_val) / scale).to(torch.uint8)
# 解量化
dequantized = quantized.float() * scale + min_val
return dequantized
5. 隐私与安全¶
5.1 联邦学习的隐私风险¶
仅传递模型参数/梯度并不意味着安全:
| 攻击类型 | 说明 |
|---|---|
| 梯度反演 | 从梯度中恢复训练数据(DLG attack) |
| 成员推断 | 判断某个样本是否在训练集中 |
| 模型反演 | 从模型中恢复训练数据的统计特征 |
| 拜占庭攻击 | 恶意客户端发送有害更新 |
5.2 差分隐私(DP)¶
在梯度中加入校准噪声,提供数学隐私保证:
| 参数 | 含义 |
|---|---|
| ε (epsilon) | 隐私预算,越小越隐私 |
| δ (delta) | 允许的失败概率 |
| C | 梯度裁剪阈值 |
| σ | 噪声幅度 |
5.3 安全聚合(Secure Aggregation)¶
通过密码学协议,使服务器只能看到聚合结果,无法获取单个客户端的更新:
客户端A: 上传 w_A + mask_A
客户端B: 上传 w_B + mask_B
...
服务器聚合后: Σ mask_i = 0 (掩码互相抵消)
服务器只得到: Σ w_i (无法分离)
5.4 拜占庭鲁棒聚合¶
| 方法 | 思路 |
|---|---|
| Median | 取各维度中位数 |
| Trimmed Mean | 去掉最大最小后取均值 |
| Krum | 选择与其他更新最相似的那个 |
| FLTrust | 服务器用小量可信数据验证更新方向 |
6. 个性化联邦学习¶
6.1 为什么需要个性化¶
全局模型在数据异构场景下可能对某些客户端表现不佳。每个客户端需要一个适合自身数据分布的模型。
6.2 主要方法¶
| 方法 | 策略 | 说明 |
|---|---|---|
| 本地微调 | 全局模型 + 本地微调 | 简单但可能过拟合 |
| Meta-Learning | 学共享初始化(MAML) | Per-FedAvg: 全局模型作为好的初始化 |
| 混合专家 | 本地模型 + 全局模型加权 | 根据样本自适应混合 |
| 部分共享 | 只聚合部分层 | FedPer: 共享底层,本地顶层 |
| 聚类 | 客户端分组,组内聚合 | 相似的客户端共享模型 |
6.3 FedPer示例思路¶
class FedPerModel(nn.Module): # 继承nn.Module定义神经网络层
def __init__(self): # __init__构造方法,创建对象时自动调用
super().__init__() # super()调用父类方法
# 共享层(参与联邦聚合)
self.shared_layers = nn.Sequential(
nn.Conv2d(3, 32, 3, padding=1),
nn.ReLU(),
nn.Conv2d(32, 64, 3, padding=1),
nn.ReLU(),
nn.AdaptiveAvgPool2d(1),
)
# 个性化层(仅本地训练)
self.personal_layers = nn.Sequential(
nn.Linear(64, 32),
nn.ReLU(),
nn.Linear(32, 10),
)
def forward(self, x):
x = self.shared_layers(x).flatten(1)
return self.personal_layers(x)
def get_shared_params(self):
return dict(self.shared_layers.named_parameters())
def get_personal_params(self):
return dict(self.personal_layers.named_parameters())
7. 联邦学习系统设计¶
7.1 实际工业部署架构¶
┌──────────────────┐
│ 协调服务器 │
│ - 模型聚合 │
│ - 客户端调度 │
│ - 版本管理 │
└──────┬───────────┘
│ HTTPS/gRPC
┌────────────────┼────────────────┐
▼ ▼ ▼
┌──────────┐ ┌──────────┐ ┌──────────┐
│ 客户端A │ │ 客户端B │ │ 客户端C │
│ 本地数据 │ │ 本地数据 │ │ 本地数据 │
│ 本地训练 │ │ 本地训练 │ │ 本地训练 │
└──────────┘ └──────────┘ └──────────┘
7.2 工业实践中的额外挑战¶
| 挑战 | 解决思路 |
|---|---|
| 设备异质性 | 允许不同客户端训练不同大小子模型 |
| 掉线/延迟 | 异步聚合、超时淘汰 |
| 模型版本 | 版本号管理、增量更新 |
| 激励机制 | 贡献度评估、奖励分配 |
7.3 主流框架¶
| 框架 | 开发者 | 特点 |
|---|---|---|
| Flower | Flower Labs | 灵活易用,支持多框架 |
| PySyft | OpenMined | 注重隐私,支持MPC |
| FATE | 微众银行 | 工业级,纵向联邦 |
| FedML | FedML Inc | 跨平台,支持多种拓扑 |
8. 实战:PyTorch联邦学习模拟¶
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
import numpy as np
from copy import deepcopy
# ==================== 模型定义 ====================
class SimpleCNN(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1, 16, 3, padding=1)
self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
self.pool = nn.MaxPool2d(2)
self.fc1 = nn.Linear(32 * 7 * 7, 128)
self.fc2 = nn.Linear(128, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 32 * 7 * 7) # view重塑张量形状(要求内存连续)
x = F.relu(self.fc1(x))
return self.fc2(x)
# ==================== 数据划分(模拟Non-IID) ====================
def create_non_iid_partitions(dataset, n_clients, n_shards_per_client=2):
"""按标签排序后切片,模拟Label-skew Non-IID"""
labels = np.array([dataset[i][1] for i in range(len(dataset))]) # 列表推导式,简洁创建列表 # np.array创建NumPy数组
sorted_indices = np.argsort(labels)
n_shards = n_clients * n_shards_per_client
shard_size = len(dataset) // n_shards
shards = [sorted_indices[i * shard_size:(i + 1) * shard_size]
for i in range(n_shards)]
np.random.shuffle(shards)
client_indices = {}
for i in range(n_clients):
client_indices[i] = np.concatenate(
shards[i * n_shards_per_client:(i + 1) * n_shards_per_client]
)
return client_indices
# ==================== 客户端本地训练 ====================
def client_update(model, dataloader, lr=0.01, local_epochs=5):
"""客户端本地训练"""
model.train() # train()开启训练模式
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
for epoch in range(local_epochs):
for x, y in dataloader:
optimizer.zero_grad() # 清零梯度,防止梯度累积
loss = F.cross_entropy(model(x), y)
loss.backward() # 反向传播计算梯度
optimizer.step() # 根据梯度更新模型参数
return model.state_dict()
# ==================== 服务器聚合 ====================
def fedavg_aggregate(global_model, client_state_dicts, client_sizes):
"""FedAvg聚合:按样本数加权平均"""
total_size = sum(client_sizes)
global_state = global_model.state_dict()
for key in global_state:
global_state[key] = sum(
client_state_dicts[i][key].float() * (client_sizes[i] / total_size)
for i in range(len(client_state_dicts))
)
global_model.load_state_dict(global_state)
return global_model
# ==================== 联邦训练主循环 ====================
def federated_train(train_dataset, test_dataset, n_clients=10,
n_rounds=20, local_epochs=5, participation_rate=0.3):
"""完整的联邦训练流程"""
# 数据划分
client_indices = create_non_iid_partitions(train_dataset, n_clients)
# 初始化全局模型
global_model = SimpleCNN()
# 测试集loader
test_loader = DataLoader(test_dataset, batch_size=128)
for round_idx in range(n_rounds):
# 1. 选择参与客户端
n_selected = max(1, int(n_clients * participation_rate))
selected = np.random.choice(n_clients, n_selected, replace=False)
client_states = []
client_sizes = []
for client_id in selected:
# 2. 下发全局模型
local_model = deepcopy(global_model)
# 3. 本地训练
indices = client_indices[client_id]
local_dataset = Subset(train_dataset, indices)
local_loader = DataLoader(local_dataset, batch_size=32, shuffle=True)
state_dict = client_update(local_model, local_loader,
local_epochs=local_epochs)
client_states.append(state_dict)
client_sizes.append(len(indices))
# 4. 服务器聚合
global_model = fedavg_aggregate(global_model, client_states, client_sizes)
# 5. 评估
if (round_idx + 1) % 5 == 0:
global_model.eval() # eval()开启评估模式(关闭Dropout等)
correct, total = 0, 0
with torch.no_grad(): # 禁用梯度计算,节省内存
for x, y in test_loader:
pred = global_model(x).argmax(dim=1)
correct += (pred == y).sum().item() # .item()将单元素张量转为Python数值
total += y.size(0)
print(f"Round {round_idx+1}, Accuracy: {correct/total:.4f}")
return global_model
# 使用示例(需要torchvision的MNIST数据集)
# from torchvision import datasets, transforms
# train_set = datasets.MNIST('./data', train=True, download=True,
# transform=transforms.ToTensor())
# test_set = datasets.MNIST('./data', train=False,
# transform=transforms.ToTensor())
# model = federated_train(train_set, test_set)
9. 面试高频题¶
Q1: FedAvg算法的核心流程是什么?¶
答:每轮通信中:(1) 服务器随机选取部分客户端;(2) 广播当前全局模型;(3) 各客户端在本地数据上训练多个epoch;(4) 上传模型参数到服务器;(5) 服务器按样本数加权平均聚合为新的全局模型。与传统分布式SGD的关键区别是客户端进行多步本地更新,减少通信轮数。
Q2: Non-IID数据为什么影响联邦学习?¶
答:Non-IID导致各客户端的本地目标函数不一致。本地训练多步后,各客户端模型朝不同方向偏移(客户端漂移),聚合后的全局模型可能远离全局最优。标签分布偏斜越严重,性能下降越明显。解决方法包括FedProx(近端正则)、SCAFFOLD(方差修正)、数据共享等。
Q3: 联邦学习真的能保护隐私吗?¶
答:仅靠不传原始数据不够安全。梯度反演攻击(如DLG)可从梯度中恢复训练数据。需要额外的隐私保护机制:差分隐私(加噪声)、安全聚合(密码学协议)、同态加密等。实际部署通常需要结合多种技术。
Q4: 横向联邦和纵向联邦的区别?¶
答:横向联邦中各参与方拥有相同特征、不同样本(如不同医院的病历),将数据按行分割。纵向联邦中各参与方拥有相同样本、不同特征(如银行和电商对同一用户的不同数据),将数据按列分割。纵向联邦还需要实体对齐步骤。
Q5: 个性化联邦学习有哪些典型方法?¶
答:(1) 本地微调:全局模型在本地数据上微调;(2) 元学习:将全局模型作为好的初始化(Per-FedAvg);(3) 部分共享:底层聚合、顶层个性化(FedPer);(4) 客户端聚类:相似客户端共享模型;(5) 混合:本地与全局模型自适应插值。
10. 练习与自我检查¶
编程练习¶
- 基础:在MNIST上实现FedAvg,对比IID和Non-IID数据分区的性能差异
- 进阶:实现FedProx(加近端正则项),观察其在Non-IID下的改善效果
- 挑战:实现带差分隐私的联邦学习(DP-FedAvg),分析隐私预算ε与精度的权衡
检查清单¶
- 能说明联邦学习的定义和"数据不动模型动"的核心原则
- 理解横向、纵向、联邦迁移学习的区别
- 能写出FedAvg的完整算法流程
- 理解Non-IID数据的类型及其对训练的影响
- 知道客户端漂移的原因及FedProx/SCAFFOLD的解决思路
- 了解通信效率优化的主要手段
- 理解差分隐私和安全聚合的基本原理
- 能讨论联邦学习的隐私风险和防御方法
- 了解个性化联邦学习的主要方法
- 能用PyTorch实现简单的联邦学习模拟
扩展阅读: - McMahan et al., 2017: Communication-Efficient Learning of Deep Networks from Decentralized Data (FedAvg) - Li et al., 2020: Federated Optimization in Heterogeneous Networks (FedProx) - Kairouz et al., 2021: Advances and Open Problems in Federated Learning - Yang et al., 2019: Federated Machine Learning: Concept and Applications