02 - UNet架构详解¶
学习时间: 4小时 重要性: ⭐⭐⭐⭐⭐ 扩散模型的核心网络结构
🎯 学习目标¶
完成本章后,你将能够: - 理解UNet的编码器-解码器结构 - 掌握残差块和注意力机制的实现 - 实现完整的UNet网络 - 理解为什么UNet适合扩散模型
1. UNet架构概述¶
1.1 什么是UNet¶
UNet是一种编码器-解码器结构的卷积神经网络,最初用于医学图像分割。它的特点是: - U形结构:先下采样(编码),再上采样(解码) - 跳跃连接:将编码器的特征直接连接到解码器 - 多尺度特征:融合不同层次的特征
1.2 为什么扩散模型使用UNet¶
- 保持空间信息:跳跃连接保留了细节信息
- 多尺度处理:不同层次的特征对去噪都有帮助
- 输入输出同尺寸:适合像素级预测任务
- 计算效率:结构对称,易于实现
1.3 UNet结构示意图¶
Text Only
输入 x_t
↓
┌─────────────────────────────────────────┐
│ 编码器 (Encoder) │
│ │
│ Conv ──→ Conv ──→ MaxPool ──→ ... │
│ ↓ ↓ ↓ │
│ f1 f2 fN │
└─────────────────────────────────────────┘
↓
bottleneck
↓
┌─────────────────────────────────────────┐
│ 解码器 (Decoder) │
│ │
│ UpSample ──→ Conv ──→ Conv ──→ ... │
│ ↑ ↑ │
│ fN f2 f1 │
│ (跳跃连接) │
└─────────────────────────────────────────┘
↓
输出 ε_θ
2. 基础组件实现¶
2.1 残差块(Residual Block)¶
残差块是UNet的基本构建单元,包含两个卷积层和跳跃连接。
Python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
class ResidualBlock(nn.Module):
    """Residual block conditioned on a timestep embedding.

    Two GroupNorm -> SiLU -> Conv3x3 stages; the projected time embedding
    is added channel-wise between them. When the input and output channel
    counts differ, a 1x1 convolution aligns the skip path.
    """

    def __init__(self, in_channels, out_channels, time_emb_dim, dropout=0.1):
        super().__init__()
        # First conv stage: normalize, activate, then change channel count.
        self.conv1 = nn.Sequential(
            nn.GroupNorm(8, in_channels),
            nn.SiLU(),
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
        )
        # Project the time embedding to the feature channel count.
        self.time_mlp = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_emb_dim, out_channels),
        )
        # Second conv stage, with dropout for regularization.
        self.conv2 = nn.Sequential(
            nn.GroupNorm(8, out_channels),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
        )
        # Identity skip, or a 1x1 conv when the channel counts differ.
        self.shortcut = (
            nn.Identity()
            if in_channels == out_channels
            else nn.Conv2d(in_channels, out_channels, 1)
        )

    def forward(self, x, t_emb):
        """
        Args:
            x: input features, [B, C, H, W].
            t_emb: time embedding, [B, time_emb_dim].
        Returns:
            Output features, [B, out_channels, H, W].
        """
        out = self.conv1(x)
        # Broadcast the projected time embedding over the spatial dims.
        out = out + self.time_mlp(t_emb)[:, :, None, None]
        out = self.conv2(out)
        return out + self.shortcut(x)
# Smoke test: 64 -> 128 channels at 32x32 resolution.
block = ResidualBlock(64, 128, time_emb_dim=256)
feat = torch.randn(2, 64, 32, 32)
temb = torch.randn(2, 256)
res_out = block(feat, temb)
print(f"输入形状: {feat.shape}")
print(f"输出形状: {res_out.shape}")
2.2 注意力块(Attention Block)¶
注意力机制帮助模型关注重要的空间位置。
Python
class AttentionBlock(nn.Module):
    """Multi-head self-attention over spatial positions.

    Flattens H x W into a token axis, runs scaled dot-product attention
    per head, projects back, and adds the result to the input (residual).
    """

    def __init__(self, channels, num_heads=4):
        super().__init__()
        self.channels = channels
        self.num_heads = num_heads
        self.norm = nn.GroupNorm(8, channels)
        # A single 1x1 conv produces Q, K and V stacked along channels.
        self.qkv = nn.Conv2d(channels, channels * 3, 1)
        self.proj = nn.Conv2d(channels, channels, 1)

    def forward(self, x):
        """
        Args:
            x: [B, C, H, W]
        Returns:
            [B, C, H, W] (same shape; attention output added residually)
        """
        B, C, H, W = x.shape
        head_dim = C // self.num_heads

        def split_heads(t):
            # [B, C, H, W] -> [B, heads, H*W, head_dim]
            return t.view(B, self.num_heads, head_dim, H * W).transpose(2, 3)

        q, k, v = (split_heads(part) for part in self.qkv(self.norm(x)).chunk(3, dim=1))
        # Scaled dot-product attention over the flattened spatial axis.
        weights = torch.softmax(q @ k.transpose(-2, -1) * head_dim ** -0.5, dim=-1)
        out = (weights @ v).transpose(2, 3).reshape(B, C, H, W)
        return x + self.proj(out)
# Smoke test: spatial attention on a 16x16 feature map with 128 channels.
attn_block = AttentionBlock(128)
feat = torch.randn(2, 128, 16, 16)
attn_out = attn_block(feat)
print(f"输入形状: {feat.shape}")
print(f"输出形状: {attn_out.shape}")
2.3 下采样和上采样¶
Python
class Downsample(nn.Module):
    """Halve the spatial resolution with a stride-2 3x3 convolution."""

    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, 3, stride=2, padding=1)

    def forward(self, x):
        # [B, C, H, W] -> [B, C, ceil(H/2), ceil(W/2)]
        return self.conv(x)
class Upsample(nn.Module):
    """Double the spatial resolution: nearest-neighbor upsample then a 3x3 conv."""

    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, 3, padding=1)

    def forward(self, x):
        # Nearest-neighbor upsampling by 2, then a conv to smooth the result.
        grown = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.conv(grown)
# Round-trip check: downsample halves H/W, upsample doubles them back.
down = Downsample(64)
up = Upsample(64)
feat = torch.randn(2, 64, 32, 32)
feat_down = down(feat)
feat_up = up(feat_down)
print(f"原始形状: {feat.shape}")
print(f"下采样后: {feat_down.shape}")
print(f"上采样后: {feat_up.shape}")
3. 时间步嵌入¶
3.1 正弦位置编码¶
Python
class SinusoidalPositionEmbeddings(nn.Module):
    """Sinusoidal timestep embeddings for diffusion models.

    Maps each timestep t to a `dim`-dimensional vector whose first half is
    sin(t * freq_i) and second half is cos(t * freq_i), with frequencies
    spaced geometrically from 1 down to 1/10000 (Transformer-style).
    """

    def __init__(self, dim):
        super().__init__()
        # NOTE: dim is assumed even and > 2; an odd dim yields 2*(dim//2)
        # output features, and dim == 2 would divide by zero below.
        self.dim = dim

    def forward(self, time):
        """
        Args:
            time: [B] tensor of timesteps.
        Returns:
            [B, dim] embedding tensor.
        """
        # `math` is imported at module level (was previously re-imported on
        # every forward call).
        device = time.device
        half_dim = self.dim // 2
        # Geometric frequency spacing: exp(-log(10000) * i / (half_dim - 1)).
        step = math.log(10000) / (half_dim - 1)
        freqs = torch.exp(torch.arange(half_dim, device=device) * -step)
        # Outer product: [B, 1] * [1, half_dim] -> [B, half_dim].
        args = time[:, None] * freqs[None, :]
        # Concatenate the sin and cos halves along the feature axis.
        return torch.cat([args.sin(), args.cos()], dim=-1)
import math

# Probe the embedding at a few timesteps spanning the diffusion range.
pos_emb = SinusoidalPositionEmbeddings(256)
t = torch.tensor([0, 100, 500, 999])
emb = pos_emb(t)
print(f"时间步: {t}")
print(f"嵌入形状: {emb.shape}")
print(f"嵌入范围: [{emb.min():.3f}, {emb.max():.3f}]")
3.2 时间嵌入MLP¶
Python
class TimeEmbeddingMLP(nn.Module):
    """Sinusoidal encoding followed by a two-layer MLP.

    Turns raw timesteps into a `hidden_dim`-dimensional conditioning
    vector, which the residual blocks consume.
    """

    def __init__(self, time_emb_dim, hidden_dim):
        super().__init__()
        stages = [
            SinusoidalPositionEmbeddings(time_emb_dim),
            nn.Linear(time_emb_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
        ]
        self.mlp = nn.Sequential(*stages)

    def forward(self, t):
        # t: [B] timesteps -> [B, hidden_dim]
        return self.mlp(t)
4. 完整UNet实现¶
Python
class UNet(nn.Module):
    """UNet noise-prediction network for diffusion models.

    Encoder/decoder with residual blocks, optional spatial attention at
    selected resolutions, and skip connections from encoder to decoder.

    Args:
        in_channels: channels of the noisy input image x_t.
        out_channels: channels of the predicted noise.
        base_channels: width of the first stage; later stages scale by
            `channel_mults`.
        channel_mults: per-stage channel multipliers (length sets depth).
        num_res_blocks: residual blocks per stage.
        time_emb_dim: dimensionality of the sinusoidal timestep encoding.
        attention_resolutions: spatial sizes at which attention is inserted.
        dropout: dropout rate inside residual blocks.
        image_size: expected input spatial size; only used to decide which
            stages match `attention_resolutions`. Defaults to 32, matching
            the previously hard-coded behavior.
    """

    def __init__(
        self,
        in_channels=3,
        out_channels=3,
        base_channels=64,
        channel_mults=(1, 2, 4, 8),
        num_res_blocks=2,
        time_emb_dim=256,
        attention_resolutions=(8, 16),
        dropout=0.1,
        image_size=32,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_resolutions = len(channel_mults)
        self._num_res_blocks = num_res_blocks

        # Timestep embedding. Its output width is base_channels * 4, so all
        # residual blocks must be built with that width. (Previously they
        # used time_emb_dim, which only worked when base_channels * 4
        # happened to equal time_emb_dim — e.g. 64 * 4 == 256.)
        time_hidden_dim = base_channels * 4
        self.time_embed = TimeEmbeddingMLP(time_emb_dim, time_hidden_dim)

        # Input stem.
        self.input_conv = nn.Conv2d(in_channels, base_channels, 3, padding=1)

        # ---- Encoder ----
        self.encoder_blocks = nn.ModuleList()
        self.encoder_downsamples = nn.ModuleList()
        channels = [base_channels]  # bookkeeping for skip-connection widths
        now_channels = base_channels
        for i, mult in enumerate(channel_mults):
            out_ch = base_channels * mult
            for _ in range(num_res_blocks):
                self.encoder_blocks.append(
                    ResidualBlock(now_channels, out_ch, time_hidden_dim, dropout)
                )
                # Attach attention where this stage's resolution matches.
                if image_size // (2 ** i) in attention_resolutions:
                    self.encoder_blocks.append(AttentionBlock(out_ch))
                now_channels = out_ch
                channels.append(now_channels)
            # Downsample between stages (not after the last one).
            if i != len(channel_mults) - 1:
                self.encoder_downsamples.append(Downsample(now_channels))
                channels.append(now_channels)

        # ---- Bottleneck ----
        self.middle_block1 = ResidualBlock(now_channels, now_channels, time_hidden_dim, dropout)
        self.middle_attn = AttentionBlock(now_channels)
        self.middle_block2 = ResidualBlock(now_channels, now_channels, time_hidden_dim, dropout)

        # ---- Decoder ----
        self.decoder_blocks = nn.ModuleList()
        self.decoder_upsamples = nn.ModuleList()
        for i, mult in reversed(list(enumerate(channel_mults))):
            out_ch = base_channels * mult
            for _ in range(num_res_blocks + 1):
                # Each decoder block concatenates a skip connection, so its
                # input width is current channels + the popped skip's width.
                in_ch = channels.pop() + now_channels
                self.decoder_blocks.append(
                    ResidualBlock(in_ch, out_ch, time_hidden_dim, dropout)
                )
                # Attach attention at the same resolutions as the encoder.
                if image_size // (2 ** i) in attention_resolutions:
                    self.decoder_blocks.append(AttentionBlock(out_ch))
                now_channels = out_ch
            # Upsample between stages (not at the finest resolution).
            if i != 0:
                self.decoder_upsamples.append(Upsample(now_channels))

        # ---- Output head ----
        self.output_conv = nn.Sequential(
            nn.GroupNorm(8, now_channels),
            nn.SiLU(),
            nn.Conv2d(now_channels, out_channels, 3, padding=1),
        )

    def forward(self, x, t):
        """
        Args:
            x: [B, in_channels, H, W] noisy input image.
            t: [B] timesteps.
        Returns:
            [B, out_channels, H, W] predicted noise epsilon_theta.
        """
        t_emb = self.time_embed(t)
        h = self.input_conv(x)

        # Encoder: walk the flat ModuleList stage by stage, collecting skips.
        skips = [h]
        enc_idx = 0
        down_idx = 0
        for i in range(self.num_resolutions):
            for _ in range(self._num_res_blocks):
                h = self.encoder_blocks[enc_idx](h, t_emb)
                enc_idx += 1
                # Consume the optional AttentionBlock registered after it.
                if enc_idx < len(self.encoder_blocks) and isinstance(
                    self.encoder_blocks[enc_idx], AttentionBlock
                ):
                    h = self.encoder_blocks[enc_idx](h)
                    enc_idx += 1
                skips.append(h)
            if i < self.num_resolutions - 1:
                h = self.encoder_downsamples[down_idx](h)
                down_idx += 1
                skips.append(h)

        # Bottleneck.
        h = self.middle_block1(h, t_emb)
        h = self.middle_attn(h)
        h = self.middle_block2(h, t_emb)

        # Decoder: pop skips in reverse order; attention blocks do not pop.
        dec_idx = 0
        up_idx = 0
        for i in reversed(range(self.num_resolutions)):
            for _ in range(self._num_res_blocks + 1):
                h = torch.cat([h, skips.pop()], dim=1)
                h = self.decoder_blocks[dec_idx](h, t_emb)
                dec_idx += 1
                if dec_idx < len(self.decoder_blocks) and isinstance(
                    self.decoder_blocks[dec_idx], AttentionBlock
                ):
                    h = self.decoder_blocks[dec_idx](h)
                    dec_idx += 1
            if i > 0:
                h = self.decoder_upsamples[up_idx](h)
                up_idx += 1

        return self.output_conv(h)
# Build a 3-stage UNet and run one forward pass on CIFAR-sized inputs.
model = UNet(
    in_channels=3,
    out_channels=3,
    base_channels=64,
    channel_mults=(1, 2, 4),
    num_res_blocks=2,
)
x = torch.randn(2, 3, 32, 32)
t = torch.randint(0, 1000, (2,))
out = model(x, t)
param_count = sum(p.numel() for p in model.parameters())
print(f"输入形状: {x.shape}")
print(f"时间步: {t}")
print(f"输出形状: {out.shape}")
print(f"参数量: {param_count / 1e6:.2f}M")
5. UNet变体¶
5.1 更轻量的版本¶
Python
class LightUNet(nn.Module):
    """A compact UNet for quick experiments.

    Three encoder stages, a bottleneck, and three decoder stages with skip
    connections. Each stage is a ModuleDict of conv/norm layers plus a
    per-stage time-embedding projection.
    """

    def __init__(self, in_channels=3, out_channels=3, base_channels=32):
        super().__init__()
        # Timestep embedding: sinusoidal encoding -> two-layer MLP (512-d).
        self.time_embed = nn.Sequential(
            SinusoidalPositionEmbeddings(128),
            nn.Linear(128, 512),
            nn.SiLU(),
            nn.Linear(512, 512),
        )
        # Encoder stages and stride-2 downsampling convs.
        self.enc1 = self._make_block(in_channels, base_channels, 512)
        self.enc2 = self._make_block(base_channels, base_channels * 2, 512)
        self.enc3 = self._make_block(base_channels * 2, base_channels * 4, 512)
        self.down1 = nn.Conv2d(base_channels, base_channels, 3, stride=2, padding=1)
        self.down2 = nn.Conv2d(base_channels * 2, base_channels * 2, 3, stride=2, padding=1)
        # Bottleneck.
        self.bottleneck = self._make_block(base_channels * 4, base_channels * 4, 512)
        # Decoder: transposed convs restore resolution before each stage.
        self.up2 = nn.ConvTranspose2d(base_channels * 4, base_channels * 2, 4, stride=2, padding=1)
        self.up1 = nn.ConvTranspose2d(base_channels * 2, base_channels, 4, stride=2, padding=1)
        self.dec3 = self._make_block(base_channels * 4, base_channels * 2, 512)
        self.dec2 = self._make_block(base_channels * 2, base_channels, 512)
        self.dec1 = self._make_block(base_channels, base_channels, 512)
        # Output projection back to image channels.
        self.output = nn.Conv2d(base_channels, out_channels, 3, padding=1)

    def _make_block(self, in_ch, out_ch, time_dim):
        # One stage: two conv/norm pairs plus a time-embedding projection.
        return nn.ModuleDict({
            'conv1': nn.Conv2d(in_ch, out_ch, 3, padding=1),
            'conv2': nn.Conv2d(out_ch, out_ch, 3, padding=1),
            'norm1': nn.GroupNorm(8, out_ch),
            'norm2': nn.GroupNorm(8, out_ch),
            'time_mlp': nn.Sequential(nn.SiLU(), nn.Linear(time_dim, out_ch)),
        })

    def _forward_block(self, block, x, t_emb):
        # conv -> norm -> SiLU, inject time embedding, conv -> norm -> SiLU.
        out = F.silu(block['norm1'](block['conv1'](x)))
        out = out + block['time_mlp'](t_emb)[:, :, None, None]
        return F.silu(block['norm2'](block['conv2'](out)))

    def forward(self, x, t):
        t_emb = self.time_embed(t)
        # Encoder path; skip1/skip2 feed the decoder's concatenations.
        skip1 = self._forward_block(self.enc1, x, t_emb)
        skip2 = self._forward_block(self.enc2, self.down1(skip1), t_emb)
        deep = self._forward_block(self.enc3, self.down2(skip2), t_emb)
        # Bottleneck.
        deep = self._forward_block(self.bottleneck, deep, t_emb)
        # Decoder path with skip connections.
        h = self._forward_block(self.dec3, torch.cat([self.up2(deep), skip2], dim=1), t_emb)
        h = self._forward_block(self.dec2, torch.cat([self.up1(h), skip1], dim=1), t_emb)
        h = self._forward_block(self.dec1, h, t_emb)
        return self.output(h)
# Forward the lightweight model on a dummy batch and report its size.
light_model = LightUNet(base_channels=32)
images = torch.randn(2, 3, 32, 32)
steps = torch.randint(0, 1000, (2,))
pred = light_model(images, steps)
light_params = sum(p.numel() for p in light_model.parameters())
print(f"\n轻量版UNet:")
print(f"输入形状: {images.shape}")
print(f"输出形状: {pred.shape}")
print(f"参数量: {light_params / 1e6:.2f}M")
6. 本章总结¶
核心概念¶
- UNet结构
- 编码器-解码器架构
- 跳跃连接保留细节
- 多尺度特征融合
- 关键组件
- 残差块:基本构建单元
- 注意力机制:捕获长距离依赖
- 时间嵌入:注入时间步信息
- 设计选择
- GroupNorm:比BatchNorm更适合小批量
- SiLU激活:平滑的非线性
- dropout:防止过拟合
关键代码¶
Python
# Residual block (simplified summary). Two fixes vs the earlier version of
# this snippet: nn.Module requires super().__init__() before submodules can
# be assigned, and the residual add needs a 1x1 shortcut when the channel
# count changes (plain `h + x` fails for in_ch != out_ch).
class ResidualBlock(nn.Module):
    def __init__(self, in_ch, out_ch, time_dim):
        super().__init__()  # required: registers the module's internal state
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.time_mlp = nn.Linear(time_dim, out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        # Identity when channels match, otherwise a 1x1 projection.
        self.shortcut = nn.Identity() if in_ch == out_ch else nn.Conv2d(in_ch, out_ch, 1)

    def forward(self, x, t_emb):
        h = self.conv1(x)
        h = h + self.time_mlp(t_emb)[:, :, None, None]  # broadcast time over H, W
        h = self.conv2(h)
        return h + self.shortcut(x)
# UNet (sketch) — illustrative pseudocode only: `output` is an undefined
# placeholder and the encoder/decoder bodies are elided. Refer to the full
# UNet implementation earlier in this chapter for runnable code.
class UNet(nn.Module):
    def forward(self, x, t):
        t_emb = self.time_embed(t)
        # encoder ...
        # decoder ...
        return output
📝 自测问题¶
基础问题¶
- UNet结构
- UNet为什么适合扩散模型?
- 跳跃连接的作用是什么?
- 编码器和解码器的通道数如何设计?
- 组件实现
- 为什么使用GroupNorm而不是BatchNorm?
- 时间嵌入是如何注入到网络中的?
- 注意力机制在UNet中起什么作用?
- 设计选择
- 如何选择合适的base_channels?
- channel_mults的设计原则是什么?
- 轻量版UNet做了哪些简化?
编程练习¶
- 修改UNet,添加更多的注意力层
- 实现一个更深的UNet(5层编码器)
- 比较不同配置下的参数量和计算量
- 可视化UNet的特征图
思考题¶
- 为什么扩散模型不使用ResNet或Transformer?
- 如何设计适合高分辨率图像的UNet?
- UNet的哪些部分可以进一步优化?
🔗 下一步¶
理解了UNet架构后,我们将学习数据加载与预处理,为训练做好准备。
→ 下一步:03-数据加载与预处理.md