跳转至

17 - 图神经网络 (GNN)

图神经网络图

🌐 什么是图数据?

图的基本概念

Text Only
图 G = (V, E)
- V: 节点集合 (Vertices/Nodes)
- E: 边集合 (Edges)

属性:
- 节点特征:X_v ∈ R^d
- 边特征:e_uv ∈ R^k
- 图标签:y_G

图数据示例

Text Only
社交网络:
- 节点:用户
- 边:好友关系
- 特征:用户画像

分子结构:
- 节点:原子
- 边:化学键
- 任务:分子性质预测

推荐系统:
- 节点:用户、物品
- 边:交互行为
- 任务:链接预测

知识图谱:
- 节点:实体
- 边:关系
- 任务:知识推理

为什么需要GNN?

Text Only
传统ML的问题:
❌ 无法处理不规则数据结构
❌ 无法利用关系信息
❌ 节点没有固定顺序(置换不变性)

GNN的优势:
✅ 直接处理图结构数据
✅ 学习节点和边的表示
✅ 保持置换不变性
✅ 支持归纳学习

🧠 图神经网络基础

消息传递框架 (Message Passing)

Text Only
核心思想:邻居聚合

对于每个节点v:
1. 收集邻居信息(消息)
2. 聚合消息
3. 更新自身表示

数学表达:
h_v^(l+1) = UPDATE^(l)(h_v^(l), AGGREGATE^(l)({h_u^(l), ∀u ∈ N(v)}))

其中:
- h_v^(l): 节点v在第l层的特征
- N(v): 节点v的邻居集合
- AGGREGATE: 聚合函数(mean, sum, max)
- UPDATE: 更新函数(通常用神经网络)

GNN通用框架

Python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GNNLayer(nn.Module):  # 继承nn.Module定义神经网络层
    def __init__(self, in_dim, out_dim):  # __init__构造方法,创建对象时自动调用
        super().__init__()  # super()调用父类方法
        self.linear = nn.Linear(in_dim, out_dim)

    def forward(self, x, adj):
        """
        x: 节点特征 (N, in_dim)
        adj: 邻接矩阵 (N, N)
        """
        # 1. 线性变换
        x = self.linear(x)

        # 2. 聚合邻居信息
        x = torch.matmul(adj, x)  # 邻接矩阵乘法 = 邻居聚合

        # 3. 激活函数
        x = F.relu(x)

        return x

🎯 经典GNN模型

GCN (Graph Convolutional Network)

核心思想

谱图卷积的局部一阶近似

Text Only
卷积公式:
H^(l+1) = σ(D̃^(-1/2) Ã D̃^(-1/2) H^(l) W^(l))

其中:
- Ã = A + I: 添加自环的邻接矩阵
- D̃: Ã的度矩阵
- D̃^(-1/2) Ã D̃^(-1/2): 对称归一化
- W^(l): 可学习参数
- σ: 激活函数

直观理解

Text Only
归一化的作用:
- 防止度数大的节点特征过大
- 平衡不同度数节点的影响

聚合过程:
节点v的新特征 = 平均(邻居特征) + 自身特征

代码实现

Python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GCNLayer(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)

    def forward(self, x, adj):
        """
        x: (N, in_features)
        adj: 归一化后的邻接矩阵 (N, N)
        """
        # 线性变换
        support = self.linear(x)

        # 图卷积:聚合邻居
        output = torch.matmul(adj, support)

        return output

class GCN(nn.Module):
    def __init__(self, nfeat, nhid, nclass, dropout=0.5):
        super().__init__()
        self.gc1 = GCNLayer(nfeat, nhid)
        self.gc2 = GCNLayer(nhid, nclass)
        self.dropout = dropout

    def forward(self, x, adj):
        x = F.relu(self.gc1(x, adj))
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.gc2(x, adj)
        return F.log_softmax(x, dim=1)

# 归一化邻接矩阵
def normalize_adj(adj):
    """对称归一化邻接矩阵"""
    adj = adj + torch.eye(adj.size(0))  # 添加自环
    deg = adj.sum(dim=1)
    deg_inv_sqrt = torch.pow(deg, -0.5)
    deg_inv_sqrt[torch.isinf(deg_inv_sqrt)] = 0
    adj_norm = deg_inv_sqrt.unsqueeze(1) * adj * deg_inv_sqrt.unsqueeze(0)  # unsqueeze增加一个维度
    return adj_norm

GraphSAGE

核心创新

归纳式学习 + 多种聚合方式

Text Only
GCN的问题:
- 需要整个图的邻接矩阵
- 无法处理新节点

GraphSAGE的解决:
- 采样固定数量的邻居
- 支持归纳学习(新节点无需重新训练)

算法流程

Text Only
对于每个节点v:
1. 采样:从邻居中采样固定数量k
2. 聚合:聚合邻居特征
   - Mean aggregator
   - LSTM aggregator
   - Pooling aggregator
3. 拼接:将聚合结果与自身特征拼接
4. 变换:线性变换 + 非线性激活

代码实现

Python
class SAGELayer(nn.Module):
    def __init__(self, in_dim, out_dim, aggregator='mean'):
        super().__init__()
        self.aggregator = aggregator
        self.linear = nn.Linear(in_dim * 2, out_dim)

    def forward(self, x, adj, sample_size=10):
        """
        x: (N, in_dim)
        adj: 邻接矩阵 (N, N)
        """
        N = x.size(0)
        h_neigh = []

        for i in range(N):
            # 获取邻居
            neighbors = adj[i].nonzero(as_tuple=True)[0]

            # 采样邻居
            if len(neighbors) > sample_size:
                sampled = neighbors[torch.randperm(len(neighbors))[:sample_size]]
            else:
                sampled = neighbors

            # 聚合邻居特征
            if len(sampled) > 0:
                neigh_feat = x[sampled]
                if self.aggregator == 'mean':
                    agg = neigh_feat.mean(dim=0)
                elif self.aggregator == 'max':
                    agg = neigh_feat.max(dim=0)[0]
            else:
                agg = torch.zeros_like(x[i])

            h_neigh.append(agg)

        h_neigh = torch.stack(h_neigh)

        # 拼接自身和邻居特征
        h_concat = torch.cat([x, h_neigh], dim=1)

        # 线性变换
        h = F.relu(self.linear(h_concat))

        return h

GAT (Graph Attention Network)

核心创新

引入注意力机制,为不同邻居分配不同权重

Text Only
GCN/GraphSAGE的问题:
- 所有邻居权重相同
- 无法区分邻居重要性

GAT的解决:
- 学习邻居的注意力权重
- 重要邻居贡献更大

注意力机制

Text Only
对于节点i和邻居j:

注意力系数:
e_ij = LeakyReLU(a^T [Wh_i || Wh_j])

归一化:
α_ij = softmax_j(e_ij) = exp(e_ij) / Σ_k exp(e_ik)

聚合:
h_i' = σ(Σ_j α_ij Wh_j)

多头注意力:
h_i' = ||_{k=1}^K σ(Σ_j α_ij^k W^k h_j)

代码实现

Python
class GATLayer(nn.Module):
    def __init__(self, in_dim, out_dim, num_heads=8, dropout=0.6):
        super().__init__()
        self.num_heads = num_heads
        self.out_dim = out_dim
        self.head_dim = out_dim // num_heads

        # 每个头独立的线性变换
        self.W = nn.Linear(in_dim, out_dim, bias=False)

        # 注意力参数
        self.a = nn.Parameter(torch.Tensor(num_heads, 2 * self.head_dim, 1))

        self.dropout = nn.Dropout(dropout)
        self.leakyrelu = nn.LeakyReLU(0.2)

        nn.init.xavier_uniform_(self.a)

    def forward(self, x, adj):
        """
        x: (N, in_dim)
        adj: 邻接矩阵 (N, N)
        """
        N = x.size(0)

        # 线性变换
        h = self.W(x)  # (N, out_dim)
        h = h.view(N, self.num_heads, self.head_dim)  # (N, num_heads, head_dim)

        # 计算注意力系数
        # 重复h以计算所有点对
        h_i = h.unsqueeze(2).repeat(1, 1, N, 1)  # (N, num_heads, N, head_dim)
        h_j = h.unsqueeze(0).repeat(N, 1, 1, 1)  # (N, num_heads, N, head_dim)

        # 拼接
        concat = torch.cat([h_i, h_j], dim=-1)  # (N, num_heads, N, 2*head_dim)

        # 计算注意力分数
        e = self.leakyrelu(torch.matmul(concat, self.a)).squeeze(-1)  # (N, num_heads, N)

        # Mask:只保留邻居
        mask = adj.unsqueeze(1).expand(-1, self.num_heads, -1)
        e = e.masked_fill(mask == 0, float('-inf'))

        # Softmax
        alpha = F.softmax(e, dim=2)  # (N, num_heads, N)
        alpha = self.dropout(alpha)

        # 聚合
        h_prime = torch.matmul(alpha, h)  # (N, num_heads, head_dim)
        h_prime = h_prime.view(N, -1)  # (N, out_dim)

        return F.elu(h_prime)

📊 GNN任务类型

节点分类 (Node Classification)

Text Only
任务:预测每个节点的类别
示例:社交网络用户分类、论文主题分类

输出:每个节点的类别概率
损失:交叉熵损失
Python
# 节点分类示例
class NodeClassifier(nn.Module):
    def __init__(self, nfeat, nhid, nclass):
        super().__init__()
        self.gcn1 = GCNLayer(nfeat, nhid)
        self.gcn2 = GCNLayer(nhid, nclass)

    def forward(self, x, adj):
        h = F.relu(self.gcn1(x, adj))
        h = self.gcn2(h, adj)
        return F.log_softmax(h, dim=1)

# 训练
model = NodeClassifier(nfeat=1433, nhid=16, nclass=7)
optimizer = torch.optim.Adam(model.parameters())

for epoch in range(200):
    model.train()  # train()开启训练模式
    optimizer.zero_grad()  # 清零梯度,防止梯度累积
    output = model(features, adj_norm)
    loss = F.nll_loss(output[train_mask], labels[train_mask])
    loss.backward()  # 反向传播计算梯度
    optimizer.step()  # 根据梯度更新模型参数

图分类 (Graph Classification)

Text Only
任务:预测整个图的类别
示例:分子性质预测、蛋白质功能分类

方法:
1. 节点嵌入 → 2. 图池化 → 3. 分类
Python
class GraphClassifier(nn.Module):
    def __init__(self, nfeat, nhid, nclass):
        super().__init__()
        self.conv1 = GCNLayer(nfeat, nhid)
        self.conv2 = GCNLayer(nhid, nhid)
        self.fc = nn.Linear(nhid, nclass)

    def forward(self, x, adj, batch):
        """
        batch: 指示每个节点属于哪个图
        """
        # 图卷积
        h = F.relu(self.conv1(x, adj))
        h = self.conv2(h, adj)

        # 全局平均池化
        h = global_mean_pool(h, batch)

        # 分类
        out = self.fc(h)
        return F.log_softmax(out, dim=1)

def global_mean_pool(x, batch):
    """全局平均池化
    x: 节点特征 (N, d)
    batch: 批次索引 (N,),指示每个节点属于哪个图
    """
    # 方法1: 使用PyTorch Geometric
    # from torch_scatter import scatter_mean
    # return scatter_mean(x, batch, dim=0)

    # 方法2: 使用PyTorch原生实现
    num_graphs = batch.max().item() + 1
    out = torch.zeros(num_graphs, x.size(1), device=x.device)
    for i in range(num_graphs):
        mask = (batch == i)
        out[i] = x[mask].mean(dim=0)
    return out
Text Only
任务:预测两个节点之间是否存在边
示例:推荐系统、知识图谱补全

方法:
1. 学习节点嵌入
2. 计算节点对之间的相似度
3. 预测链接存在概率
Python
class LinkPredictor(nn.Module):
    def __init__(self, nfeat, nhid):
        super().__init__()
        self.encoder = GCN(nfeat, nhid, nhid)
        self.decoder = nn.Linear(nhid * 2, 1)

    def forward(self, x, adj, edge_index):
        # 编码节点
        z = self.encoder(x, adj)

        # 解码边
        src, dst = edge_index
        z_src = z[src]
        z_dst = z[dst]

        # 拼接并预测
        z_edge = torch.cat([z_src, z_dst], dim=1)
        pred = torch.sigmoid(self.decoder(z_edge))

        return pred

🛠️ GNN实践技巧

使用PyTorch Geometric

Bash
pip install torch-geometric
Python
import torch
from torch_geometric.nn import GCNConv, GATConv, global_mean_pool
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader  # PyG 2.0+ 将 DataLoader 移至 loader 模块

# 定义数据
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]], dtype=torch.long)
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)  # 每个节点一个标量特征

data = Data(x=x, edge_index=edge_index)

# 定义模型
class GNN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = GCNConv(dataset.num_features, 16)
        self.conv2 = GCNConv(16, dataset.num_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index

        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)

        return F.log_softmax(x, dim=1)

处理大图

Text Only
大图挑战:
- 内存:无法存储整个图的邻接矩阵
- 计算:无法一次性前向传播

解决方案:
1. 邻居采样 (Neighbor Sampling)
2. 图采样 (GraphSAGE风格)
3. 子图采样 (Cluster-GCN)
Python
from torch_geometric.loader import NeighborLoader

# 邻居采样
loader = NeighborLoader(
    data,
    num_neighbors=[10, 10],  # 每层采样10个邻居
    batch_size=64,
    input_nodes=train_mask,
)

for batch in loader:
    optimizer.zero_grad()
    out = model(batch.x, batch.edge_index)
    loss = criterion(out, batch.y)
    loss.backward()
    optimizer.step()

📈 GNN应用案例

1. 分子性质预测

Python
from torch_geometric.datasets import MoleculeNet
from torch_geometric.nn import GCNConv, global_mean_pool

# 加载分子数据集
dataset = MoleculeNet(root='./data', name='ESOL')

class MolecularGNN(torch.nn.Module):
    def __init__(self, hidden_dim=64):
        super().__init__()
        self.conv1 = GCNConv(dataset.num_features, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.fc = torch.nn.Linear(hidden_dim, 1)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch

        x = F.relu(self.conv1(x, edge_index))
        x = self.conv2(x, edge_index)

        # 全局池化
        x = global_mean_pool(x, batch)

        return self.fc(x).squeeze()

2. 推荐系统

Python
class GNNRecommender(torch.nn.Module):
    def __init__(self, num_users, num_items, emb_dim=64):
        super().__init__()
        self.num_users = num_users
        self.user_emb = torch.nn.Embedding(num_users, emb_dim)
        self.item_emb = torch.nn.Embedding(num_items, emb_dim)

        self.conv1 = GCNConv(emb_dim, emb_dim)
        self.conv2 = GCNConv(emb_dim, emb_dim)

    def forward(self, edge_index, user_ids, item_ids):
        # 构建节点特征
        x = torch.cat([self.user_emb.weight, self.item_emb.weight], dim=0)

        # 图卷积
        x = F.relu(self.conv1(x, edge_index))
        x = self.conv2(x, edge_index)

        # 预测评分
        user_emb = x[user_ids]
        item_emb = x[self.num_users + item_ids]

        return (user_emb * item_emb).sum(dim=1)

💡 总结

Text Only
GNN的核心:
1. 消息传递:聚合邻居信息
2. 置换不变性:节点顺序不影响结果
3. 归纳能力:支持新节点

经典模型演进:
GCN → GraphSAGE → GAT
↓       ↓          ↓
谱方法  归纳学习   注意力机制

应用方向:
- 节点分类:社交网络分析
- 图分类:分子性质预测
- 链接预测:推荐系统
- 图生成:药物发现

学习建议:
1. 先理解消息传递框架
2. 掌握GCN、GAT核心原理
3. 使用PyTorch Geometric实践
4. 从节点分类任务开始

下一步:学习 18-NLP与Transformer详解.md,掌握自然语言处理核心技术!