跳转至

03-图神经网络(GNN)

学习时间: 约6-8小时 难度级别: ⭐⭐⭐⭐ 中高级 前置知识: 深度学习基础、线性代数(特征值分解)、PyTorch 学习目标: 理解图数据的表示与建模方法,掌握GCN、GAT、GraphSAGE等核心模型的原理与实现


目录


1. 图数据与图学习概述

GNN架构概览

1.1 什么是图数据

概念 说明 示例
节点 (Node) 实体 用户、论文、原子
边 (Edge) 实体间关系 好友关系、引用、化学键
特征 (Feature) 节点/边的属性向量 用户画像、词嵌入
邻接矩阵 A 描述图结构的矩阵 \(A_{ij}=1\) 表示 \(i,j\) 相连

1.2 图学习的典型任务

Text Only
图学习任务
├── 节点级 (Node-level)
│   ├── 节点分类:预测节点类别(论文主题分类)
│   └── 节点回归:预测节点属性(分子原子电荷)
├── 边级 (Edge-level)
│   └── 链接预测:预测缺失边(推荐系统、知识图谱补全)
└── 图级 (Graph-level)
    └── 图分类:预测整张图的标签(分子毒性预测)

1.3 为什么CNN/RNN不够用

  • 非欧几里得结构:图没有固定的网格或序列拓扑
  • 节点数量不固定:不同图的大小不同
  • 排列不变性:节点顺序任意排列,结果应一致

2. 谱域图卷积

2.1 图拉普拉斯矩阵

\[L = D - A\]

其中 \(D\) 为度矩阵(对角线为每个节点的度),\(A\) 为邻接矩阵。

归一化拉普拉斯矩阵:

\[\tilde{L} = I - D^{-1/2} A D^{-1/2}\]

2.2 从频域到空域

  • \(L\) 做特征分解:\(L = U \Lambda U^T\)
  • 图信号 \(x\) 的图傅里叶变换:\(\hat{x} = U^T x\)
  • 频域滤波:\(y = U g_\theta(\Lambda) U^T x\)

问题:特征分解计算复杂度为 \(O(n^3)\),无法处理大图。

2.3 ChebNet 的多项式近似

用切比雪夫多项式近似滤波器,避免显式特征分解:

\[g_\theta(\tilde{L}) \approx \sum_{k=0}^{K} \theta_k T_k(\tilde{L})\]

复杂度降为 \(O(K|\mathcal{E}|)\),其中 \(|\mathcal{E}|\) 为边数。


3. GCN:图卷积网络

GCN卷积操作

3.1 核心公式

当 ChebNet 取 \(K=1\) 并简化后得到 GCN 的逐层传播规则:

\[H^{(l+1)} = \sigma\left(\tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2} H^{(l)} W^{(l)}\right)\]

其中 \(\tilde{A} = A + I\)(加自环),\(\tilde{D}\)\(\tilde{A}\) 的度矩阵。

3.2 直觉理解

每一层做了两件事: 1. 聚合:每个节点收集邻居的特征(包括自身) 2. 变换:通过可学习的权重矩阵 \(W\) 做线性变换 + 非线性激活

3.3 PyTorch 实现

Python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GCNLayer(nn.Module):  # 继承nn.Module定义神经网络层
    """单层图卷积"""
    def __init__(self, in_features, out_features):  # __init__构造方法,创建对象时自动调用
        super().__init__()  # super()调用父类方法
        self.weight = nn.Parameter(torch.FloatTensor(in_features, out_features))
        self.bias = nn.Parameter(torch.FloatTensor(out_features))
        nn.init.xavier_uniform_(self.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x, adj_norm):
        """
        x: 节点特征矩阵 [N, in_features]
        adj_norm: 归一化邻接矩阵 D^{-1/2} A~ D^{-1/2}
        """
        support = torch.mm(x, self.weight)       # 线性变换
        output = torch.spmm(adj_norm, support)    # 邻居聚合
        return output + self.bias

class GCN(nn.Module):
    def __init__(self, n_feat, n_hidden, n_class, dropout=0.5):
        super().__init__()
        self.gc1 = GCNLayer(n_feat, n_hidden)
        self.gc2 = GCNLayer(n_hidden, n_class)
        self.dropout = dropout

    def forward(self, x, adj):
        x = F.relu(self.gc1(x, adj))
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.gc2(x, adj)
        return F.log_softmax(x, dim=1)

3.4 GCN 的局限性

局限 说明
过平滑 层数增加后所有节点表示趋于一致
固定聚合权重 对所有邻居一视同仁
转导学习 训练时需要完整图结构,无法处理新节点

4. GAT:图注意力网络

GAT注意力机制

4.1 核心思想

为不同邻居分配不同的注意力权重

\[\alpha_{ij} = \frac{\exp(\text{LeakyReLU}(\vec{a}^T [Wh_i \| Wh_j]))}{\sum_{k \in \mathcal{N}(i)} \exp(\text{LeakyReLU}(\vec{a}^T [Wh_i \| Wh_k]))}\]
\[h_i' = \sigma\left(\sum_{j \in \mathcal{N}(i)} \alpha_{ij} W h_j\right)\]

4.2 多头注意力

类似 Transformer,使用 \(K\) 个独立注意力头并拼接:

\[h_i' = \Big\|_{k=1}^{K} \sigma\left(\sum_{j \in \mathcal{N}(i)} \alpha_{ij}^k W^k h_j\right)\]

4.3 PyTorch 实现

Python
class GATLayer(nn.Module):
    def __init__(self, in_features, out_features, n_heads=4, concat=True):
        super().__init__()
        self.n_heads = n_heads
        self.concat = concat
        self.W = nn.Linear(in_features, out_features * n_heads, bias=False)
        self.a = nn.Parameter(torch.FloatTensor(n_heads, 2 * out_features))
        self.leaky_relu = nn.LeakyReLU(0.2)
        nn.init.xavier_uniform_(self.a)
        self.out_features = out_features

    def forward(self, x, edge_index):
        """
        x: [N, in_features]
        edge_index: [2, E] 边索引
        """
        N = x.size(0)
        h = self.W(x).view(N, self.n_heads, self.out_features)  # [N, heads, out]  # 链式调用,连续执行多个方法

        src, dst = edge_index  # 源节点与目标节点
        h_src = h[src]  # [E, heads, out]
        h_dst = h[dst]  # [E, heads, out]

        # 计算注意力系数
        cat_feat = torch.cat([h_src, h_dst], dim=-1)  # [E, heads, 2*out]
        e = (cat_feat * self.a.unsqueeze(0)).sum(dim=-1)  # [E, heads]  # unsqueeze增加一个维度
        e = self.leaky_relu(e)

        # softmax归一化(针对每个目标节点的所有入边)
        # 使用scatter实现per-node softmax:对每个目标节点,其所有入边的注意力系数归一化
        e_max = e.new_full((N, self.n_heads), -1e9)
        e_max.scatter_reduce_(0, dst.unsqueeze(-1).expand_as(e), e, reduce='amax')
        e_stable = torch.exp(e - e_max[dst])  # 数值稳定
        e_sum = torch.zeros(N, self.n_heads, device=x.device)
        e_sum.scatter_add_(0, dst.unsqueeze(-1).expand_as(e_stable), e_stable)
        alpha = e_stable / (e_sum[dst] + 1e-16)  # [E, heads]

        # 聚合
        out = torch.zeros(N, self.n_heads, self.out_features, device=x.device)
        out.scatter_add_(0, dst.unsqueeze(-1).unsqueeze(-1).expand_as(h_src * alpha.unsqueeze(-1)),
                         h_src * alpha.unsqueeze(-1))

        if self.concat:
            return out.view(N, -1)  # [N, heads * out]
        else:
            return out.mean(dim=1)  # [N, out]

5. GraphSAGE:归纳式学习

GraphSAGE采样策略

5.1 与GCN的关键区别

特性 GCN GraphSAGE
学习方式 转导(Transductive) 归纳(Inductive)
邻居采样 使用全部邻居 采样固定数量邻居
新节点 无法处理 可以泛化到新节点
聚合方式 固定(均值+归一化) 可选(Mean/LSTM/Pool)

5.2 算法流程

Text Only
GraphSAGE 前向传播(K层):
for k = 1 to K:
    for 每个节点 v:
        1. 采样: 从N(v)中随机采样S个邻居
        2. 聚合: h_N(v) = AGG({h_u, u ∈ Sample(N(v))})
        3. 更新: h_v = σ(W · CONCAT(h_v, h_N(v)))
        4. 归一化: h_v = h_v / ||h_v||

5.3 聚合器对比

聚合器 公式 特点
Mean \(\text{mean}(\{h_u\})\) 简单高效
Max Pool \(\max(\sigma(W_{pool} h_u + b))\) 捕捉突出特征
LSTM LSTM随机排列的邻居 表达力强但非排列不变

6. 消息传递框架

6.1 通用消息传递范式 (MPNN)

所有空域GNN可以统一为消息传递框架:

\[m_v^{(t+1)} = \sum_{u \in \mathcal{N}(v)} M_t(h_v^{(t)}, h_u^{(t)}, e_{vu})\]
\[h_v^{(t+1)} = U_t(h_v^{(t)}, m_v^{(t+1)})\]

其中 \(M_t\) 为消息函数,\(U_t\) 为更新函数。

6.2 不同GNN在MPNN下的统一视角

模型 消息函数 聚合 更新
GCN \(\frac{1}{\sqrt{d_u d_v}} W h_u\) Sum ReLU
GAT \(\alpha_{vu} W h_u\) Sum ELU
GIN \(W h_u\) Sum MLP
GGNN \(W h_u\) Sum GRU

6.3 过平滑问题与缓解

过平滑:随着GNN层数增加,所有节点的表示会趋于相同。

缓解策略: - 残差连接\(h_v^{(l+1)} = h_v^{(l+1)} + h_v^{(l)}\) - JumpKnowledge:拼接所有层的输出 - DropEdge:训练时随机删除部分边 - PairNorm / NodeNorm:特定的归一化方法


7. 图级任务与池化

7.1 图读出(Readout)

将节点级表示汇聚为图级表示:

\[h_G = \text{READOUT}(\{h_v | v \in G\})\]

常用方法:均值池化、求和池化、注意力加权池化。

7.2 层次化图池化

Python
# DiffPool 核心思想:学习软聚类分配矩阵
class DiffPoolLayer(nn.Module):
    def __init__(self, in_feat, n_clusters):
        super().__init__()
        self.gnn_embed = GCNLayer(in_feat, in_feat)   # 特征GNN
        self.gnn_pool = GCNLayer(in_feat, n_clusters)  # 池化GNN

    def forward(self, x, adj):
        z = F.relu(self.gnn_embed(x, adj))    # 节点嵌入
        s = F.softmax(self.gnn_pool(x, adj), dim=-1)  # 分配矩阵 [N, K]

        # 粗化
        x_pooled = s.T @ z      # [K, feat]
        adj_pooled = s.T @ adj @ s  # [K, K]
        return x_pooled, adj_pooled

8. 实战:用PyG进行节点分类

8.1 使用 PyTorch Geometric

Python
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv
from torch_geometric.datasets import Planetoid

# 加载 Cora 数据集
dataset = Planetoid(root='./data', name='Cora')
data = dataset[0]

print(f"节点数: {data.num_nodes}")
print(f"边数: {data.num_edges}")
print(f"特征维度: {data.num_node_features}")
print(f"类别数: {dataset.num_classes}")

class GATNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = GATConv(dataset.num_node_features, 8, heads=8)
        self.conv2 = GATConv(64, dataset.num_classes, heads=1, concat=False)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = F.dropout(x, p=0.6, training=self.training)
        x = F.elu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.6, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)

# 训练
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GATNet().to(device)  # .to(device)将数据移至GPU/CPU
data = data.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4)

model.train()  # train()开启训练模式
for epoch in range(200):
    optimizer.zero_grad()  # 清零梯度,防止梯度累积
    out = model(data)
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()  # 反向传播计算梯度
    optimizer.step()  # 根据梯度更新模型参数

# 测试
model.eval()
pred = model(data).argmax(dim=1)
correct = (pred[data.test_mask] == data.y[data.test_mask]).sum()
acc = int(correct) / int(data.test_mask.sum())
print(f"测试准确率: {acc:.4f}")

9. 面试高频题

Q1: GCN的核心公式是什么?它的直觉含义是什么?

:GCN逐层传播公式为 \(H^{(l+1)} = \sigma(\tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2} H^{(l)} W^{(l)})\)。直觉上,每个节点在每一层收集自身与邻居的特征,经归一化加权平均后通过线性变换和非线性激活。加自环保留了自身信息,归一化防止度大的节点主导。

Q2: GAT与GCN的本质区别是什么?

:GCN对所有邻居使用固定的归一化权重(基于度),GAT通过注意力机制为每个邻居学习自适应权重。GAT更灵活,能识别不同邻居的重要性差异,但计算代价更高。

Q3: 什么是过平滑问题?如何缓解?

:过平滑是指GNN层数增加时,所有节点的表示趋于相同,因为每层聚合扩大了感受野,最终所有节点接收了相似的全局信息。缓解方法包括:残差连接、JumpKnowledge网络、DropEdge、PairNorm、以及限制GNN层数(通常2-3层)。

Q4: GraphSAGE如何实现归纳学习?

:GraphSAGE学习的是聚合函数而非每个节点的嵌入。训练时通过采样邻居并聚合来生成表示,因此对于从未见过的新节点,只要知道其邻居,就能用学到的聚合函数计算其嵌入。

Q5: GNN在哪些领域有重要应用?

:社交网络分析(社区发现、好友推荐)、推荐系统(用户-物品二部图)、分子性质预测(药物发现)、知识图谱推理、交通流预测、组合优化问题等。


10. 练习与自我检查

编程练习

  1. 基础:用纯PyTorch实现两层GCN,在Cora数据集上训练并评估
  2. 进阶:实现GAT并对比GCN在Cora上的结果
  3. 挑战:在分子数据集(如MUTAG)上实现图分类任务

检查清单

  • 能解释图数据的基本组成(节点、边、邻接矩阵)
  • 理解从谱域图卷积到GCN的推导过程
  • 能手写GCN的核心前向传播代码
  • 理解GAT的注意力计算机制
  • 能说明GraphSAGE与GCN的关键区别
  • 理解消息传递框架(MPNN)的统一视角
  • 知道过平滑问题及至少3种缓解方法
  • 能用PyTorch Geometric完成节点分类实战
  • 了解GNN在实际场景中的应用

GNN应用场景

扩展阅读: - Kipf & Welling, 2017: Semi-Supervised Classification with Graph Convolutional Networks - Veličković et al., 2018: Graph Attention Networks - Hamilton et al., 2017: Inductive Representation Learning on Large Graphs - Xu et al., 2019: How Powerful are Graph Neural Networks? (GIN)