03-图神经网络(GNN)¶
学习时间: 约6-8小时 难度级别: ⭐⭐⭐⭐ 中高级 前置知识: 深度学习基础、线性代数(特征值分解)、PyTorch 学习目标: 理解图数据的表示与建模方法,掌握GCN、GAT、GraphSAGE等核心模型的原理与实现
目录¶
- 1. 图数据与图学习概述
- 2. 谱域图卷积
- 3. GCN:图卷积网络
- 4. GAT:图注意力网络
- 5. GraphSAGE:归纳式学习
- 6. 消息传递框架
- 7. 图级任务与池化
- 8. 实战:用PyG进行节点分类
- 9. 面试高频题
- 10. 练习与自我检查
1. 图数据与图学习概述¶
1.1 什么是图数据¶
| 概念 | 说明 | 示例 |
|---|---|---|
| 节点 (Node) | 实体 | 用户、论文、原子 |
| 边 (Edge) | 实体间关系 | 好友关系、引用、化学键 |
| 特征 (Feature) | 节点/边的属性向量 | 用户画像、词嵌入 |
| 邻接矩阵 A | 描述图结构的矩阵 | \(A_{ij}=1\) 表示 \(i,j\) 相连 |
1.2 图学习的典型任务¶
图学习任务
├── 节点级 (Node-level)
│ ├── 节点分类:预测节点类别(论文主题分类)
│ └── 节点回归:预测节点属性(分子原子电荷)
├── 边级 (Edge-level)
│ └── 链接预测:预测缺失边(推荐系统、知识图谱补全)
└── 图级 (Graph-level)
└── 图分类:预测整张图的标签(分子毒性预测)
1.3 为什么CNN/RNN不够用¶
- 非欧几里得结构:图没有固定的网格或序列拓扑
- 节点数量不固定:不同图的大小不同
- 排列不变性:节点顺序任意排列,结果应一致
2. 谱域图卷积¶
2.1 图拉普拉斯矩阵¶
其中 \(D\) 为度矩阵(对角线为每个节点的度),\(A\) 为邻接矩阵。
归一化拉普拉斯矩阵:
2.2 从频域到空域¶
- 对 \(L\) 做特征分解:\(L = U \Lambda U^T\)
- 图信号 \(x\) 的图傅里叶变换:\(\hat{x} = U^T x\)
- 频域滤波:\(y = U g_\theta(\Lambda) U^T x\)
问题:特征分解计算复杂度为 \(O(n^3)\),无法处理大图。
2.3 ChebNet 的多项式近似¶
用切比雪夫多项式近似滤波器,避免显式特征分解:
复杂度降为 \(O(K|\mathcal{E}|)\),其中 \(|\mathcal{E}|\) 为边数。
3. GCN:图卷积网络¶
3.1 核心公式¶
当 ChebNet 取 \(K=1\) 并简化后得到 GCN 的逐层传播规则:
其中 \(\tilde{A} = A + I\)(加自环),\(\tilde{D}\) 为 \(\tilde{A}\) 的度矩阵。
3.2 直觉理解¶
每一层做了两件事: 1. 聚合:每个节点收集邻居的特征(包括自身) 2. 变换:通过可学习的权重矩阵 \(W\) 做线性变换 + 非线性激活
3.3 PyTorch 实现¶
import torch
import torch.nn as nn
import torch.nn.functional as F
class GCNLayer(nn.Module): # 继承nn.Module定义神经网络层
"""单层图卷积"""
def __init__(self, in_features, out_features): # __init__构造方法,创建对象时自动调用
super().__init__() # super()调用父类方法
self.weight = nn.Parameter(torch.FloatTensor(in_features, out_features))
self.bias = nn.Parameter(torch.FloatTensor(out_features))
nn.init.xavier_uniform_(self.weight)
nn.init.zeros_(self.bias)
def forward(self, x, adj_norm):
"""
x: 节点特征矩阵 [N, in_features]
adj_norm: 归一化邻接矩阵 D^{-1/2} A~ D^{-1/2}
"""
support = torch.mm(x, self.weight) # 线性变换
output = torch.spmm(adj_norm, support) # 邻居聚合
return output + self.bias
class GCN(nn.Module):
def __init__(self, n_feat, n_hidden, n_class, dropout=0.5):
super().__init__()
self.gc1 = GCNLayer(n_feat, n_hidden)
self.gc2 = GCNLayer(n_hidden, n_class)
self.dropout = dropout
def forward(self, x, adj):
x = F.relu(self.gc1(x, adj))
x = F.dropout(x, self.dropout, training=self.training)
x = self.gc2(x, adj)
return F.log_softmax(x, dim=1)
3.4 GCN 的局限性¶
| 局限 | 说明 |
|---|---|
| 过平滑 | 层数增加后所有节点表示趋于一致 |
| 固定聚合权重 | 对所有邻居一视同仁 |
| 转导学习 | 训练时需要完整图结构,无法处理新节点 |
4. GAT:图注意力网络¶
4.1 核心思想¶
为不同邻居分配不同的注意力权重:
4.2 多头注意力¶
类似 Transformer,使用 \(K\) 个独立注意力头并拼接:
4.3 PyTorch 实现¶
class GATLayer(nn.Module):
def __init__(self, in_features, out_features, n_heads=4, concat=True):
super().__init__()
self.n_heads = n_heads
self.concat = concat
self.W = nn.Linear(in_features, out_features * n_heads, bias=False)
self.a = nn.Parameter(torch.FloatTensor(n_heads, 2 * out_features))
self.leaky_relu = nn.LeakyReLU(0.2)
nn.init.xavier_uniform_(self.a)
self.out_features = out_features
def forward(self, x, edge_index):
"""
x: [N, in_features]
edge_index: [2, E] 边索引
"""
N = x.size(0)
h = self.W(x).view(N, self.n_heads, self.out_features) # [N, heads, out] # 链式调用,连续执行多个方法
src, dst = edge_index # 源节点与目标节点
h_src = h[src] # [E, heads, out]
h_dst = h[dst] # [E, heads, out]
# 计算注意力系数
cat_feat = torch.cat([h_src, h_dst], dim=-1) # [E, heads, 2*out]
e = (cat_feat * self.a.unsqueeze(0)).sum(dim=-1) # [E, heads] # unsqueeze增加一个维度
e = self.leaky_relu(e)
# softmax归一化(针对每个目标节点的所有入边)
# 使用scatter实现per-node softmax:对每个目标节点,其所有入边的注意力系数归一化
e_max = e.new_full((N, self.n_heads), -1e9)
e_max.scatter_reduce_(0, dst.unsqueeze(-1).expand_as(e), e, reduce='amax')
e_stable = torch.exp(e - e_max[dst]) # 数值稳定
e_sum = torch.zeros(N, self.n_heads, device=x.device)
e_sum.scatter_add_(0, dst.unsqueeze(-1).expand_as(e_stable), e_stable)
alpha = e_stable / (e_sum[dst] + 1e-16) # [E, heads]
# 聚合
out = torch.zeros(N, self.n_heads, self.out_features, device=x.device)
out.scatter_add_(0, dst.unsqueeze(-1).unsqueeze(-1).expand_as(h_src * alpha.unsqueeze(-1)),
h_src * alpha.unsqueeze(-1))
if self.concat:
return out.view(N, -1) # [N, heads * out]
else:
return out.mean(dim=1) # [N, out]
5. GraphSAGE:归纳式学习¶
5.1 与GCN的关键区别¶
| 特性 | GCN | GraphSAGE |
|---|---|---|
| 学习方式 | 转导(Transductive) | 归纳(Inductive) |
| 邻居采样 | 使用全部邻居 | 采样固定数量邻居 |
| 新节点 | 无法处理 | 可以泛化到新节点 |
| 聚合方式 | 固定(均值+归一化) | 可选(Mean/LSTM/Pool) |
5.2 算法流程¶
GraphSAGE 前向传播(K层):
for k = 1 to K:
for 每个节点 v:
1. 采样: 从N(v)中随机采样S个邻居
2. 聚合: h_N(v) = AGG({h_u, u ∈ Sample(N(v))})
3. 更新: h_v = σ(W · CONCAT(h_v, h_N(v)))
4. 归一化: h_v = h_v / ||h_v||
5.3 聚合器对比¶
| 聚合器 | 公式 | 特点 |
|---|---|---|
| Mean | \(\text{mean}(\{h_u\})\) | 简单高效 |
| Max Pool | \(\max(\sigma(W_{pool} h_u + b))\) | 捕捉突出特征 |
| LSTM | LSTM随机排列的邻居 | 表达力强但非排列不变 |
6. 消息传递框架¶
6.1 通用消息传递范式 (MPNN)¶
所有空域GNN可以统一为消息传递框架:
其中 \(M_t\) 为消息函数,\(U_t\) 为更新函数。
6.2 不同GNN在MPNN下的统一视角¶
| 模型 | 消息函数 | 聚合 | 更新 |
|---|---|---|---|
| GCN | \(\frac{1}{\sqrt{d_u d_v}} W h_u\) | Sum | ReLU |
| GAT | \(\alpha_{vu} W h_u\) | Sum | ELU |
| GIN | \(W h_u\) | Sum | MLP |
| GGNN | \(W h_u\) | Sum | GRU |
6.3 过平滑问题与缓解¶
过平滑:随着GNN层数增加,所有节点的表示会趋于相同。
缓解策略: - 残差连接:\(h_v^{(l+1)} = h_v^{(l+1)} + h_v^{(l)}\) - JumpKnowledge:拼接所有层的输出 - DropEdge:训练时随机删除部分边 - PairNorm / NodeNorm:特定的归一化方法
7. 图级任务与池化¶
7.1 图读出(Readout)¶
将节点级表示汇聚为图级表示:
常用方法:均值池化、求和池化、注意力加权池化。
7.2 层次化图池化¶
# DiffPool 核心思想:学习软聚类分配矩阵
class DiffPoolLayer(nn.Module):
def __init__(self, in_feat, n_clusters):
super().__init__()
self.gnn_embed = GCNLayer(in_feat, in_feat) # 特征GNN
self.gnn_pool = GCNLayer(in_feat, n_clusters) # 池化GNN
def forward(self, x, adj):
z = F.relu(self.gnn_embed(x, adj)) # 节点嵌入
s = F.softmax(self.gnn_pool(x, adj), dim=-1) # 分配矩阵 [N, K]
# 粗化
x_pooled = s.T @ z # [K, feat]
adj_pooled = s.T @ adj @ s # [K, K]
return x_pooled, adj_pooled
8. 实战:用PyG进行节点分类¶
8.1 使用 PyTorch Geometric¶
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv
from torch_geometric.datasets import Planetoid
# 加载 Cora 数据集
dataset = Planetoid(root='./data', name='Cora')
data = dataset[0]
print(f"节点数: {data.num_nodes}")
print(f"边数: {data.num_edges}")
print(f"特征维度: {data.num_node_features}")
print(f"类别数: {dataset.num_classes}")
class GATNet(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = GATConv(dataset.num_node_features, 8, heads=8)
self.conv2 = GATConv(64, dataset.num_classes, heads=1, concat=False)
def forward(self, data):
x, edge_index = data.x, data.edge_index
x = F.dropout(x, p=0.6, training=self.training)
x = F.elu(self.conv1(x, edge_index))
x = F.dropout(x, p=0.6, training=self.training)
x = self.conv2(x, edge_index)
return F.log_softmax(x, dim=1)
# 训练
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GATNet().to(device) # .to(device)将数据移至GPU/CPU
data = data.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4)
model.train() # train()开启训练模式
for epoch in range(200):
optimizer.zero_grad() # 清零梯度,防止梯度累积
out = model(data)
loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
loss.backward() # 反向传播计算梯度
optimizer.step() # 根据梯度更新模型参数
# 测试
model.eval()
pred = model(data).argmax(dim=1)
correct = (pred[data.test_mask] == data.y[data.test_mask]).sum()
acc = int(correct) / int(data.test_mask.sum())
print(f"测试准确率: {acc:.4f}")
9. 面试高频题¶
Q1: GCN的核心公式是什么?它的直觉含义是什么?¶
答:GCN逐层传播公式为 \(H^{(l+1)} = \sigma(\tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2} H^{(l)} W^{(l)})\)。直觉上,每个节点在每一层收集自身与邻居的特征,经归一化加权平均后通过线性变换和非线性激活。加自环保留了自身信息,归一化防止度大的节点主导。
Q2: GAT与GCN的本质区别是什么?¶
答:GCN对所有邻居使用固定的归一化权重(基于度),GAT通过注意力机制为每个邻居学习自适应权重。GAT更灵活,能识别不同邻居的重要性差异,但计算代价更高。
Q3: 什么是过平滑问题?如何缓解?¶
答:过平滑是指GNN层数增加时,所有节点的表示趋于相同,因为每层聚合扩大了感受野,最终所有节点接收了相似的全局信息。缓解方法包括:残差连接、JumpKnowledge网络、DropEdge、PairNorm、以及限制GNN层数(通常2-3层)。
Q4: GraphSAGE如何实现归纳学习?¶
答:GraphSAGE学习的是聚合函数而非每个节点的嵌入。训练时通过采样邻居并聚合来生成表示,因此对于从未见过的新节点,只要知道其邻居,就能用学到的聚合函数计算其嵌入。
Q5: GNN在哪些领域有重要应用?¶
答:社交网络分析(社区发现、好友推荐)、推荐系统(用户-物品二部图)、分子性质预测(药物发现)、知识图谱推理、交通流预测、组合优化问题等。
10. 练习与自我检查¶
编程练习¶
- 基础:用纯PyTorch实现两层GCN,在Cora数据集上训练并评估
- 进阶:实现GAT并对比GCN在Cora上的结果
- 挑战:在分子数据集(如MUTAG)上实现图分类任务
检查清单¶
- 能解释图数据的基本组成(节点、边、邻接矩阵)
- 理解从谱域图卷积到GCN的推导过程
- 能手写GCN的核心前向传播代码
- 理解GAT的注意力计算机制
- 能说明GraphSAGE与GCN的关键区别
- 理解消息传递框架(MPNN)的统一视角
- 知道过平滑问题及至少3种缓解方法
- 能用PyTorch Geometric完成节点分类实战
- 了解GNN在实际场景中的应用
扩展阅读: - Kipf & Welling, 2017: Semi-Supervised Classification with Graph Convolutional Networks - Veličković et al., 2018: Graph Attention Networks - Hamilton et al., 2017: Inductive Representation Learning on Large Graphs - Xu et al., 2019: How Powerful are Graph Neural Networks? (GIN)




