跳转至

02 - UNet架构详解

学习时间: 4小时 重要性: ⭐⭐⭐⭐⭐ 扩散模型的核心网络结构


🎯 学习目标

完成本章后,你将能够: - 理解UNet的编码器-解码器结构 - 掌握残差块和注意力机制的实现 - 实现完整的UNet网络 - 理解为什么UNet适合扩散模型


1. UNet架构概述

1.1 什么是UNet

UNet是一种编码器-解码器结构的卷积神经网络,最初用于医学图像分割。它的特点是: - U形结构:先下采样(编码),再上采样(解码) - 跳跃连接:将编码器的特征直接连接到解码器 - 多尺度特征:融合不同层次的特征

1.2 为什么扩散模型使用UNet

  1. 保持空间信息:跳跃连接保留了细节信息
  2. 多尺度处理:不同层次的特征对去噪都有帮助
  3. 输入输出同尺寸:适合像素级预测任务
  4. 计算效率:结构对称,易于实现

1.3 UNet结构示意图

Text Only
输入 x_t
┌─────────────────────────────────────────┐
│ 编码器 (Encoder)                        │
│                                         │
│  Conv ──→ Conv ──→ MaxPool ──→ ...     │
│    ↓        ↓                    ↓      │
│   f1       f2                   fN      │
└─────────────────────────────────────────┘
   bottleneck
┌─────────────────────────────────────────┐
│ 解码器 (Decoder)                        │
│                                         │
│  UpSample ──→ Conv ──→ Conv ──→ ...    │
│     ↑          ↑                      │
│    fN         f2                     f1 │
│  (跳跃连接)                              │
└─────────────────────────────────────────┘
输出 ε_θ

2. 基础组件实现

2.1 残差块(Residual Block)

残差块是UNet的基本构建单元,包含两个卷积层和跳跃连接。

Python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ResidualBlock(nn.Module):  # 继承nn.Module定义网络层
    """
    残差块
    """
    def __init__(self, in_channels, out_channels, time_emb_dim, dropout=0.1):
        super().__init__()  # super()调用父类方法

        # 第一个卷积层
        self.conv1 = nn.Sequential(
            nn.GroupNorm(8, in_channels),
            nn.SiLU(),
            nn.Conv2d(in_channels, out_channels, 3, padding=1)
        )

        # 时间嵌入投影
        self.time_mlp = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_emb_dim, out_channels)
        )

        # 第二个卷积层
        self.conv2 = nn.Sequential(
            nn.GroupNorm(8, out_channels),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv2d(out_channels, out_channels, 3, padding=1)
        )

        # 残差连接
        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, 1)
        else:
            self.shortcut = nn.Identity()

    def forward(self, x, t_emb):
        """
        参数:
            x: 输入特征 [B, C, H, W]
            t_emb: 时间嵌入 [B, time_emb_dim]

        返回:
            输出特征 [B, out_channels, H, W]
        """
        h = self.conv1(x)

        # 添加时间嵌入
        h = h + self.time_mlp(t_emb)[:, :, None, None]

        h = self.conv2(h)

        return h + self.shortcut(x)

# 测试
block = ResidualBlock(64, 128, time_emb_dim=256)
x = torch.randn(2, 64, 32, 32)
t_emb = torch.randn(2, 256)
out = block(x, t_emb)
print(f"输入形状: {x.shape}")
print(f"输出形状: {out.shape}")

2.2 注意力块(Attention Block)

注意力机制帮助模型关注重要的空间位置。

Python
class AttentionBlock(nn.Module):
    """
    空间注意力块
    """
    def __init__(self, channels, num_heads=4):
        super().__init__()
        self.channels = channels
        self.num_heads = num_heads

        self.norm = nn.GroupNorm(8, channels)
        self.qkv = nn.Conv2d(channels, channels * 3, 1)
        self.proj = nn.Conv2d(channels, channels, 1)

    def forward(self, x):
        """
        参数:
            x: [B, C, H, W]

        返回:
            [B, C, H, W]
        """
        B, C, H, W = x.shape

        h = self.norm(x)

        # 计算Q, K, V
        qkv = self.qkv(h)
        q, k, v = qkv.chunk(3, dim=1)

        # 重塑为多头注意力格式
        q = q.view(B, self.num_heads, C // self.num_heads, H * W).transpose(2, 3)  # 重塑张量形状
        k = k.view(B, self.num_heads, C // self.num_heads, H * W).transpose(2, 3)
        v = v.view(B, self.num_heads, C // self.num_heads, H * W).transpose(2, 3)

        # 计算注意力
        scale = (C // self.num_heads) ** -0.5
        attn = torch.softmax(q @ k.transpose(-2, -1) * scale, dim=-1)

        # 应用注意力
        h = attn @ v
        h = h.transpose(2, 3).reshape(B, C, H, W)

        # 投影
        h = self.proj(h)

        return x + h

# 测试
attn_block = AttentionBlock(128)
x = torch.randn(2, 128, 16, 16)
out = attn_block(x)
print(f"输入形状: {x.shape}")
print(f"输出形状: {out.shape}")

2.3 下采样和上采样

Python
class Downsample(nn.Module):
    """下采样层"""
    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, 3, stride=2, padding=1)

    def forward(self, x):
        return self.conv(x)

class Upsample(nn.Module):
    """上采样层"""
    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, 3, padding=1)

    def forward(self, x):
        x = F.interpolate(x, scale_factor=2, mode='nearest')  # F.xxx PyTorch函数式API
        return self.conv(x)

# 测试
down = Downsample(64)
up = Upsample(64)
x = torch.randn(2, 64, 32, 32)
x_down = down(x)
x_up = up(x_down)
print(f"原始形状: {x.shape}")
print(f"下采样后: {x_down.shape}")
print(f"上采样后: {x_up.shape}")

3. 时间步嵌入

3.1 正弦位置编码

Python
class SinusoidalPositionEmbeddings(nn.Module):
    """
    正弦位置编码,用于时间步嵌入
    """
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        """
        参数:
            time: [B] 时间步

        返回:
            [B, dim] 位置编码
        """
        import math  # 确保在使用前导入
        device = time.device
        half_dim = self.dim // 2

        # 计算频率
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)

        # 计算正弦和余弦
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat([embeddings.sin(), embeddings.cos()], dim=-1)  # torch.cat沿已有维度拼接张量

        return embeddings

import math

# 测试
time_emb = SinusoidalPositionEmbeddings(256)
t = torch.tensor([0, 100, 500, 999])
emb = time_emb(t)
print(f"时间步: {t}")
print(f"嵌入形状: {emb.shape}")
print(f"嵌入范围: [{emb.min():.3f}, {emb.max():.3f}]")

3.2 时间嵌入MLP

Python
class TimeEmbeddingMLP(nn.Module):
    """
    时间嵌入的多层感知机
    """
    def __init__(self, time_emb_dim, hidden_dim):
        super().__init__()
        self.mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_emb_dim),
            nn.Linear(time_emb_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

    def forward(self, t):
        return self.mlp(t)

4. 完整UNet实现

Python
class UNet(nn.Module):
    """
    用于扩散模型的UNet
    """
    def __init__(
        self,
        in_channels=3,
        out_channels=3,
        base_channels=64,
        channel_mults=(1, 2, 4, 8),
        num_res_blocks=2,
        time_emb_dim=256,
        attention_resolutions=(8, 16),
        dropout=0.1,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_resolutions = len(channel_mults)
        self._num_res_blocks = num_res_blocks

        # 时间嵌入
        self.time_embed = TimeEmbeddingMLP(time_emb_dim, base_channels * 4)

        # 输入卷积
        self.input_conv = nn.Conv2d(in_channels, base_channels, 3, padding=1)

        # 编码器
        self.encoder_blocks = nn.ModuleList()
        self.encoder_downsamples = nn.ModuleList()

        channels = [base_channels]
        now_channels = base_channels

        for i, mult in enumerate(channel_mults):  # enumerate同时获取索引和元素
            out_ch = base_channels * mult

            for _ in range(num_res_blocks):
                self.encoder_blocks.append(
                    ResidualBlock(now_channels, out_ch, time_emb_dim, dropout)
                )

                # 在指定分辨率添加注意力
                if 32 // (2 ** i) in attention_resolutions:
                    self.encoder_blocks.append(AttentionBlock(out_ch))

                now_channels = out_ch
                channels.append(now_channels)

            # 下采样(最后一层除外)
            if i != len(channel_mults) - 1:
                self.encoder_downsamples.append(Downsample(now_channels))
                channels.append(now_channels)

        # Bottleneck
        self.middle_block1 = ResidualBlock(now_channels, now_channels, time_emb_dim, dropout)
        self.middle_attn = AttentionBlock(now_channels)
        self.middle_block2 = ResidualBlock(now_channels, now_channels, time_emb_dim, dropout)

        # 解码器
        self.decoder_blocks = nn.ModuleList()
        self.decoder_upsamples = nn.ModuleList()

        for i, mult in reversed(list(enumerate(channel_mults))):
            out_ch = base_channels * mult

            for j in range(num_res_blocks + 1):
                # 跳跃连接使输入通道数翻倍
                in_ch = channels.pop() + now_channels
                self.decoder_blocks.append(
                    ResidualBlock(in_ch, out_ch, time_emb_dim, dropout)
                )

                # 在指定分辨率添加注意力
                if 32 // (2 ** i) in attention_resolutions:
                    self.decoder_blocks.append(AttentionBlock(out_ch))

                now_channels = out_ch

            # 上采样(第一层除外)
            if i != 0:
                self.decoder_upsamples.append(Upsample(now_channels))

        # 输出
        self.output_conv = nn.Sequential(
            nn.GroupNorm(8, now_channels),
            nn.SiLU(),
            nn.Conv2d(now_channels, out_channels, 3, padding=1),
        )

    def forward(self, x, t):
        """
        参数:
            x: [B, in_channels, H, W] 输入图像
            t: [B] 时间步

        返回:
            [B, out_channels, H, W] 预测的噪声
        """
        # 时间嵌入
        t_emb = self.time_embed(t)

        # 输入卷积
        h = self.input_conv(x)

        # 编码器:按分辨率层级遍历
        skips = [h]
        enc_idx = 0
        down_idx = 0
        for i in range(self.num_resolutions):
            for _ in range(self._num_res_blocks):
                # ResidualBlock
                h = self.encoder_blocks[enc_idx](h, t_emb)
                enc_idx += 1
                # 可选的AttentionBlock
                if enc_idx < len(self.encoder_blocks) and isinstance(self.encoder_blocks[enc_idx], AttentionBlock):  # isinstance检查类型
                    h = self.encoder_blocks[enc_idx](h)
                    enc_idx += 1
                skips.append(h)
            # 下采样(最后一层除外)
            if i < self.num_resolutions - 1:
                h = self.encoder_downsamples[down_idx](h)
                down_idx += 1
                skips.append(h)

        # Bottleneck
        h = self.middle_block1(h, t_emb)
        h = self.middle_attn(h)
        h = self.middle_block2(h, t_emb)

        # 解码器:按分辨率层级遍历
        dec_idx = 0
        up_idx = 0
        for i in reversed(range(self.num_resolutions)):
            for j in range(self._num_res_blocks + 1):
                # 跳跃连接 + ResidualBlock
                h = torch.cat([h, skips.pop()], dim=1)
                h = self.decoder_blocks[dec_idx](h, t_emb)
                dec_idx += 1
                # 可选的AttentionBlock(不弹出skip)
                if dec_idx < len(self.decoder_blocks) and isinstance(self.decoder_blocks[dec_idx], AttentionBlock):
                    h = self.decoder_blocks[dec_idx](h)
                    dec_idx += 1
            # 上采样(第一层除外)
            if i > 0:
                h = self.decoder_upsamples[up_idx](h)
                up_idx += 1

        # 输出
        return self.output_conv(h)

# 测试UNet
model = UNet(
    in_channels=3,
    out_channels=3,
    base_channels=64,
    channel_mults=(1, 2, 4),
    num_res_blocks=2,
)

x = torch.randn(2, 3, 32, 32)
t = torch.randint(0, 1000, (2,))
out = model(x, t)

print(f"输入形状: {x.shape}")
print(f"时间步: {t}")
print(f"输出形状: {out.shape}")
print(f"参数量: {sum(p.numel() for p in model.parameters()) / 1e6:.2f}M")

5. UNet变体

5.1 更轻量的版本

Python
class LightUNet(nn.Module):
    """
    轻量级UNet,适合快速实验
    """
    def __init__(self, in_channels=3, out_channels=3, base_channels=32):
        super().__init__()

        # 简化版:更少的层和通道
        self.time_embed = nn.Sequential(
            SinusoidalPositionEmbeddings(128),
            nn.Linear(128, 512),
            nn.SiLU(),
            nn.Linear(512, 512),
        )

        # 编码器
        self.enc1 = self._make_block(in_channels, base_channels, 512)
        self.enc2 = self._make_block(base_channels, base_channels * 2, 512)
        self.enc3 = self._make_block(base_channels * 2, base_channels * 4, 512)

        self.down1 = nn.Conv2d(base_channels, base_channels, 3, stride=2, padding=1)
        self.down2 = nn.Conv2d(base_channels * 2, base_channels * 2, 3, stride=2, padding=1)

        # Bottleneck
        self.bottleneck = self._make_block(base_channels * 4, base_channels * 4, 512)

        # 解码器
        self.up2 = nn.ConvTranspose2d(base_channels * 4, base_channels * 2, 4, stride=2, padding=1)
        self.up1 = nn.ConvTranspose2d(base_channels * 2, base_channels, 4, stride=2, padding=1)

        self.dec3 = self._make_block(base_channels * 4, base_channels * 2, 512)
        self.dec2 = self._make_block(base_channels * 2, base_channels, 512)
        self.dec1 = self._make_block(base_channels, base_channels, 512)

        # 输出
        self.output = nn.Conv2d(base_channels, out_channels, 3, padding=1)

    def _make_block(self, in_ch, out_ch, time_dim):
        return nn.ModuleDict({
            'conv1': nn.Conv2d(in_ch, out_ch, 3, padding=1),
            'conv2': nn.Conv2d(out_ch, out_ch, 3, padding=1),
            'norm1': nn.GroupNorm(8, out_ch),
            'norm2': nn.GroupNorm(8, out_ch),
            'time_mlp': nn.Sequential(nn.SiLU(), nn.Linear(time_dim, out_ch)),
        })

    def _forward_block(self, block, x, t_emb):
        h = F.silu(block['norm1'](block['conv1'](x)))
        h = h + block['time_mlp'](t_emb)[:, :, None, None]
        h = F.silu(block['norm2'](block['conv2'](h)))
        return h

    def forward(self, x, t):
        t_emb = self.time_embed(t)

        # 编码器
        e1 = self._forward_block(self.enc1, x, t_emb)
        e2 = self._forward_block(self.enc2, self.down1(e1), t_emb)
        e3 = self._forward_block(self.enc3, self.down2(e2), t_emb)

        # Bottleneck
        b = self._forward_block(self.bottleneck, e3, t_emb)

        # 解码器(带跳跃连接)
        d3 = self._forward_block(self.dec3, torch.cat([self.up2(b), e2], dim=1), t_emb)
        d2 = self._forward_block(self.dec2, torch.cat([self.up1(d3), e1], dim=1), t_emb)
        d1 = self._forward_block(self.dec1, d2, t_emb)

        return self.output(d1)

# 测试轻量版
light_model = LightUNet(base_channels=32)
x = torch.randn(2, 3, 32, 32)
t = torch.randint(0, 1000, (2,))
out = light_model(x, t)
print(f"\n轻量版UNet:")
print(f"输入形状: {x.shape}")
print(f"输出形状: {out.shape}")
print(f"参数量: {sum(p.numel() for p in light_model.parameters()) / 1e6:.2f}M")

6. 本章总结

核心概念

  1. UNet结构
  2. 编码器-解码器架构
  3. 跳跃连接保留细节
  4. 多尺度特征融合

  5. 关键组件

  6. 残差块:基本构建单元
  7. 注意力机制:捕获长距离依赖
  8. 时间嵌入:注入时间步信息

  9. 设计选择

  10. GroupNorm:比BatchNorm更适合小批量
  11. SiLU激活:平滑的非线性
  12. dropout:防止过拟合

关键代码

Python
# 残差块
class ResidualBlock(nn.Module):
    def __init__(self, in_ch, out_ch, time_dim):
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.time_mlp = nn.Linear(time_dim, out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)

    def forward(self, x, t_emb):
        h = self.conv1(x)
        h = h + self.time_mlp(t_emb)[:, :, None, None]
        h = self.conv2(h)
        return h + x

# UNet
class UNet(nn.Module):
    def forward(self, x, t):
        t_emb = self.time_embed(t)
        # 编码器...
        # 解码器...
        return output

📝 自测问题

基础问题

  1. UNet结构
  2. UNet为什么适合扩散模型?
  3. 跳跃连接的作用是什么?
  4. 编码器和解码器的通道数如何设计?

  5. 组件实现

  6. 为什么使用GroupNorm而不是BatchNorm?
  7. 时间嵌入是如何注入到网络中的?
  8. 注意力机制在UNet中起什么作用?

  9. 设计选择

  10. 如何选择合适的base_channels?
  11. channel_mults的设计原则是什么?
  12. 轻量版UNet做了哪些简化?

编程练习

  1. 修改UNet,添加更多的注意力层
  2. 实现一个更深的UNet(5层编码器)
  3. 比较不同配置下的参数量和计算量
  4. 可视化UNet的特征图

思考题

  1. 为什么扩散模型不使用ResNet或Transformer?
  2. 如何设计适合高分辨率图像的UNet?
  3. UNet的哪些部分可以进一步优化?

🔗 下一步

理解了UNet架构后,我们将学习数据加载与预处理,为训练做好准备。

→ 下一步:03-数据加载与预处理.md