跳转至

07 - DiT与Transformer扩散架构

⚠️ 时效性说明:本章涉及前沿模型/价格/榜单等信息,可能随版本快速变化;请以论文原文、官方发布页和 API 文档为准。

学习时间: 4小时 重要性: ⭐⭐⭐⭐⭐ Sora/SD3/FLUX等最新模型的共同基础架构


🎯 学习目标

完成本章后,你将能够: - 理解扩散模型从U-Net到Transformer的范式转变动因 - 掌握DiT的核心设计:Patchify、AdaLN-Zero条件注入 - 了解U-ViT、SiT、MDT、MM-DiT等Transformer扩散变体 - 理解Scaling Laws在扩散Transformer中的体现 - 实现一个简化版DiT模型


1. 从U-Net到Transformer的范式转变

1.1 为什么要替换U-Net?

U-Net自DDPM以来一直是扩散模型的标准骨干网络,但它存在固有局限:

维度 U-Net Transformer
归纳偏置 强空间局部性偏置 弱偏置,更灵活
可扩展性 难以单纯堆叠层数提升 Scaling Laws明确
多模态融合 Cross-attention拼接 天然支持序列混合
分辨率泛化 需要特殊处理 位置编码灵活适配
工程生态 定制化结构 复用LLM训练基础设施

关键动因:随着模型规模增大,Transformer展现出更明确的Scaling Law特性——模型越大、数据越多、效果越好,且改进趋势可预测。这与LLM的成功经验一致。

1.2 Transformer在视觉任务中的发展

Text Only
ViT (2020) → 图像分类
DALL-E (2021) → 自回归图像生成
U-ViT (2023) → 将Transformer嵌入U-Net结构
DiT (2023) → 纯Transformer扩散模型
SD3/FLUX/Sora (2024) → 大规模Transformer扩散模型

2. DiT原理详解

2.1 论文概述

DiT(Peebles & Xie, 2023, "Scalable Diffusion Models with Transformers")首次证明了纯Transformer架构在扩散模型中可以取代U-Net,并展现出清晰的Scaling Laws。

2.2 Patchify:图像到序列的转换

DiT遵循ViT的方式,将图像(或潜空间特征)分割为patch序列:

\[\text{Input: } z \in \mathbb{R}^{C \times H \times W} \xrightarrow{\text{Patchify}} X \in \mathbb{R}^{N \times D}\]

其中 \(N = \frac{H \times W}{p^2}\) 为patch数量,\(p\) 为patch大小,\(D\) 为embedding维度。

Python
import torch
import torch.nn as nn

class PatchEmbed(nn.Module):  # 继承nn.Module定义网络层
    """将潜空间特征图转换为patch序列"""

    def __init__(self, img_size=32, patch_size=2, in_channels=4, embed_dim=768):
        super().__init__()  # super()调用父类方法
        self.num_patches = (img_size // patch_size) ** 2
        self.proj = nn.Conv2d(
            in_channels, embed_dim,
            kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        # x: [B, C, H, W] → [B, D, H/p, W/p] → [B, N, D]
        x = self.proj(x)                    # [B, D, H/p, W/p]
        x = x.flatten(2).transpose(1, 2)    # [B, N, D]
        return x

Patch大小的影响: - \(p=2\): 序列更长,计算量大,细节更好 - \(p=4\): 序列更短,效率高,适合高分辨率 - \(p=8\): 极端压缩,信息损失较大

2.3 AdaLN-Zero:条件注入机制

DiT探索了四种将时间步 \(t\) 和类别标签 \(c\) 注入Transformer Block的方式:

方式 方法 FID
In-context \(t, c\) 作为额外token拼接到序列中 较高
Cross-attention \(t, c\) 作为cross-attention的KV 中等
AdaLN \(t, c\) 预测LayerNorm的 \(\gamma, \beta\) 较低
AdaLN-Zero AdaLN + 零初始化scale参数 最低

AdaLN-Zero的核心思想

标准LayerNorm对输入进行归一化后,施加可学习的scale和shift: $\(\text{LN}(x) = \gamma \cdot \frac{x - \mu}{\sigma} + \beta\)$

AdaLN-Zero将 \(\gamma, \beta\) 替换为条件预测的参数,并额外引入一个零初始化的门控参数 \(\alpha\)

\[\gamma, \beta, \alpha = \text{MLP}(t_{emb} + c_{emb})\]
\[\text{Block Output} = x + \alpha \cdot \text{Attention}(\text{AdaLN}(x; \gamma, \beta))\]

初始化时 \(\alpha = 0\),使得每个DiT Block在训练初期是恒等映射,保证训练稳定性。

Python
class AdaLNZero(nn.Module):
    """Adaptive Layer Norm Zero — DiT的核心条件注入模块"""

    def __init__(self, dim, cond_dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim, elementwise_affine=False)
        # 预测 gamma, beta, alpha 三组参数
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(cond_dim, 6 * dim)  # 为attention和MLP各预测3个参数
        )
        # 零初始化
        nn.init.zeros_(self.adaLN_modulation[-1].weight)  # [-1]负索引取最后元素
        nn.init.zeros_(self.adaLN_modulation[-1].bias)

    def forward(self, x, cond):
        # cond: [B, cond_dim] → 预测6组调制参数
        (gamma1, beta1, alpha1,
         gamma2, beta2, alpha2) = self.adaLN_modulation(cond).chunk(6, dim=-1)

        return gamma1, beta1, alpha1, gamma2, beta2, alpha2

    def modulate(self, x, gamma, beta):
        return self.norm(x) * (1 + gamma.unsqueeze(1)) + beta.unsqueeze(1)  # unsqueeze增加一个维度

2.4 DiT整体架构

Text Only
输入潜变量 z ∈ R^(C×H×W)
   [PatchEmbed] → [B, N, D]
   + Positional Embedding (sin-cos)
   ┌─────────────────────────┐
   │     DiT Block × L       │
   │ ┌─────────────────────┐ │
   │ │ AdaLN-Zero(t, c)    │ │
   │ │     ↓                │ │
   │ │ Self-Attention       │ │
   │ │     ↓                │ │
   │ │ AdaLN-Zero(t, c)    │ │
   │ │     ↓                │ │
   │ │ Pointwise FFN        │ │
   │ └─────────────────────┘ │
   └─────────────────────────┘
   [Final AdaLN + Linear]
   [Unpatchify] → R^(2C×H×W)  (预测噪声+对角协方差)

3. Transformer扩散模型家族

3.1 U-ViT

U-ViT(Bao et al., 2023, "All are Worth Words: A ViT Backbone for Diffusion Models")将U-Net的跳跃连接引入ViT:

  • 将时间步、条件、图像patches视为统一的token序列
  • 浅层和深层之间引入long skip connections
  • 保持了U-Net跳跃连接的优势,同时获得Transformer的灵活性
\[X_{deep} = \text{Block}_{deep}(X_{shallow}) + \text{Linear}(X_{shallow})\]

3.2 SiT(Scalable Interpolant Transformers)

SiT(Ma et al., 2024)结合DiT架构与Flow Matching训练:

  • 使用随机插值(Stochastic Interpolant)框架统一多种扩散/流匹配目标
  • 在Interpolant框架下自由选择不同的时间调度和速度场参数化
  • 同等FLOPs下优于DiT

3.3 MDT(Masked Diffusion Transformer)

MDT(Gao et al., 2023, "Masked Diffusion Transformer is a Strong Image Synthesizer")引入了掩码建模思想:

  • 训练时随机mask部分patch tokens
  • 模型同时学习去噪和patch预测
  • 加速训练收敛,提升上下文学习能力

3.4 MM-DiT

MM-DiT(Esser et al., 2024)是SD3/SD3.5使用的多模态Transformer(详见上一章):

  • 图像和文本token联合注意力
  • 模态专属的QKV投影和MLP
  • 支持多种文本编码器的灵活接入

3.5 架构对比

模型 条件注入 文本处理 跳跃连接 训练目标
DiT AdaLN-Zero 类别标签 DDPM \(\epsilon\)-prediction
U-ViT In-context Token拼接 DDPM \(\epsilon\)-prediction
SiT AdaLN-Zero 类别标签 Flow Matching
MDT AdaLN-Zero 类别标签 DDPM + Mask
MM-DiT AdaLN-Zero 联合注意力 Rectified Flow

4. Scaling Laws在扩散Transformer中的应用

4.1 DiT的Scaling实验

DiT论文中系统测试了不同规模的模型:

模型 层数 维度 注意力头数 参数量 FID (ImageNet 256)
DiT-S/2 12 384 6 33M 68.4
DiT-B/2 12 768 12 130M 43.5
DiT-L/2 24 1024 16 458M 23.3
DiT-XL/2 28 1152 16 675M 9.62

关键发现: 1. FID与模型FLOPs呈幂律关系\(\text{FID} \propto \text{GFLOPs}^{-\alpha}\) 2. 增大模型比增大patch更有效:相同GFLOPs下,大模型+大patch < 大模型+小patch 3. Scaling趋势未饱和:暗示更大模型将持续改进

4.2 工业级Scaling实践

Sora、FLUX等产品验证了扩散Transformer的Scaling Law:

\[L(N, D) \approx \frac{A}{N^{\alpha_N}} + \frac{B}{D^{\alpha_D}} + L_0\]

其中 \(N\) 为参数量,\(D\) 为数据量,\(L\) 为验证损失。

实际数据: - FLUX.1: 12B参数 → 目前开源最佳图像质量 - Sora: 估计数十B参数 → 突破性视频质量 - 趋势:扩散模型正在重复LLM的scaling故事


5. 简化版DiT代码实现

Python
import torch
import torch.nn as nn
import math

class SinusoidalPosEmb(nn.Module):
    """正弦时间步编码"""
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        device = t.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = t[:, None].float() * emb[None, :]
        return torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)  # torch.cat沿已有维度拼接张量

class DiTBlock(nn.Module):
    """DiT Transformer Block with AdaLN-Zero"""

    def __init__(self, dim, num_heads, mlp_ratio=4.0):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        # Self-Attention
        self.qkv = nn.Linear(dim, 3 * dim)
        self.proj = nn.Linear(dim, dim)

        # FFN
        mlp_hidden = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_hidden),
            nn.GELU(),
            nn.Linear(mlp_hidden, dim),
        )

        # AdaLN-Zero: 两层各需 gamma, beta, alpha → 6 * dim
        self.norm1 = nn.LayerNorm(dim, elementwise_affine=False)
        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False)
        self.adaLN = nn.Sequential(
            nn.SiLU(),
            nn.Linear(dim, 6 * dim),
        )
        # 关键:零初始化
        nn.init.zeros_(self.adaLN[-1].weight)
        nn.init.zeros_(self.adaLN[-1].bias)

    def forward(self, x, cond):
        """
        x: [B, N, D] — patch token序列
        cond: [B, D] — 条件嵌入(时间步+类别)
        """
        B, N, D = x.shape

        # 预测6组调制参数
        gamma1, beta1, alpha1, gamma2, beta2, alpha2 = \
            self.adaLN(cond).chunk(6, dim=-1)

        # --- Attention分支 ---
        h = self.norm1(x) * (1 + gamma1.unsqueeze(1)) + beta1.unsqueeze(1)
        qkv = self.qkv(h).reshape(B, N, 3, self.num_heads, D // self.num_heads)  # 重塑张量形状
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        h = (attn @ v).transpose(1, 2).reshape(B, N, D)
        h = self.proj(h)
        x = x + alpha1.unsqueeze(1) * h  # 零初始化门控

        # --- FFN分支 ---
        h = self.norm2(x) * (1 + gamma2.unsqueeze(1)) + beta2.unsqueeze(1)
        h = self.mlp(h)
        x = x + alpha2.unsqueeze(1) * h  # 零初始化门控

        return x

class DiT(nn.Module):
    """简化版 Diffusion Transformer"""

    def __init__(
        self,
        img_size=32,
        patch_size=2,
        in_channels=4,
        dim=768,
        depth=12,
        num_heads=12,
        num_classes=10,
        mlp_ratio=4.0,
    ):
        super().__init__()
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2

        # Patch Embedding
        self.patch_embed = nn.Conv2d(
            in_channels, dim, kernel_size=patch_size, stride=patch_size
        )
        # 位置编码 (固定sin-cos)
        self.pos_embed = nn.Parameter(
            torch.zeros(1, self.num_patches, dim)
        )

        # 条件编码
        self.time_embed = nn.Sequential(
            SinusoidalPosEmb(dim),
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Linear(dim, dim),
        )
        self.class_embed = nn.Embedding(num_classes, dim)

        # Transformer Blocks
        self.blocks = nn.ModuleList([
            DiTBlock(dim, num_heads, mlp_ratio) for _ in range(depth)
        ])

        # 输出层
        self.final_norm = nn.LayerNorm(dim, elementwise_affine=False)
        self.final_adaLN = nn.Sequential(
            nn.SiLU(),
            nn.Linear(dim, 2 * dim),
        )
        self.final_proj = nn.Linear(dim, patch_size ** 2 * in_channels * 2)
        # 2倍channels用于预测噪声和对角方差

        self._init_weights()

    def _init_weights(self):
        # 初始化位置编码
        nn.init.normal_(self.pos_embed, std=0.02)
        # 零初始化输出层
        nn.init.zeros_(self.final_proj.weight)
        nn.init.zeros_(self.final_proj.bias)
        nn.init.zeros_(self.final_adaLN[-1].weight)
        nn.init.zeros_(self.final_adaLN[-1].bias)

    def unpatchify(self, x, H, W):
        """[B, N, p*p*C] → [B, C, H, W]"""
        p = self.patch_size
        c = x.shape[-1] // (p * p)
        h, w = H // p, W // p
        x = x.reshape(-1, h, w, p, p, c)
        x = x.permute(0, 5, 1, 3, 2, 4).reshape(-1, c, H, W)
        return x

    def forward(self, x, t, y):
        """
        x: [B, C, H, W] — 带噪声的潜变量
        t: [B] — 时间步
        y: [B] — 类别标签
        """
        B, C, H, W = x.shape

        # Patchify + 位置编码
        x = self.patch_embed(x).flatten(2).transpose(1, 2)  # [B, N, D]
        x = x + self.pos_embed

        # 条件编码
        cond = self.time_embed(t) + self.class_embed(y)     # [B, D]

        # Transformer Blocks
        for block in self.blocks:
            x = block(x, cond)

        # 输出投影
        gamma, beta = self.final_adaLN(cond).chunk(2, dim=-1)
        x = self.final_norm(x) * (1 + gamma.unsqueeze(1)) + beta.unsqueeze(1)
        x = self.final_proj(x)   # [B, N, p*p*2C]
        x = self.unpatchify(x, H, W)  # [B, 2C, H, W]

        noise_pred, var_pred = x.chunk(2, dim=1)
        return noise_pred, var_pred

# === 使用示例 ===
if __name__ == "__main__":
    model = DiT(
        img_size=32, patch_size=2, in_channels=4,
        dim=384, depth=12, num_heads=6, num_classes=10
    )
    print(f"DiT-S/2 参数量: {sum(p.numel() for p in model.parameters()) / 1e6:.1f}M")

    x = torch.randn(2, 4, 32, 32)  # 潜空间输入
    t = torch.randint(0, 1000, (2,))
    y = torch.randint(0, 10, (2,))

    noise_pred, var_pred = model(x, t, y)
    print(f"输出形状: noise={noise_pred.shape}, var={var_pred.shape}")
    # 输出: noise=torch.Size([2, 4, 32, 32]), var=torch.Size([2, 4, 32, 32])

📋 面试要点

高频面试题

  1. DiT相比U-Net的核心优势是什么?
  2. Transformer架构具有清晰的Scaling Laws,模型越大效果持续提升
  3. 弱归纳偏置使其更灵活,易于多模态扩展
  4. 可复用LLM训练基础设施(并行策略、优化器等)
  5. U-Net的下采样/上采样结构在极高分辨率下效率受限

  6. AdaLN-Zero为什么效果最好?

  7. 零初始化使每个Block初始为恒等映射,类似ResNet的残差学习
  8. 训练初期模型=直接跳过所有层,梯度直接传到输入,训练更稳定
  9. 条件信息通过可学习的scale/shift注入,比简单拼接更灵活

  10. DiT的Patch Size如何选择?

  11. \(p=2\) 最常用(DiT-XL/2达到FID 9.62)
  12. 较小的patch保留更多空间细节但增加序列长度(计算量 \(O(N^2)\)
  13. 实际中通常在潜空间(如8×下采样后的32×32)上使用patch_size=2
  14. 对于高分辨率可考虑patch_size=4配合窗口注意力

  15. 如何理解扩散Transformer的Scaling Laws?

  16. FID与GFLOPs呈幂律下降关系
  17. 这与LLM中Loss与参数量/数据量的幂律关系类似
  18. Sora、FLUX等产品验证了大规模扩散Transformer的有效性
  19. 意味着扩散模型的天花板远未达到

✏️ 练习

练习1:DiT-S/2训练

在CIFAR-10上训练简化版DiT-S/2,对比不同条件注入方式(in-context / cross-attention / AdaLN-Zero)的收敛速度和最终FID。

练习2:Patch Size消融实验

固定模型参数量,测试patch_size=½/4对生成质量和训练速度的影响。

练习3:结构变体实验

实现U-ViT的long skip connection并对比纯DiT,观察收敛速度的差异。

练习4:论文精读

  • 精读DiT原始论文,关注Table 1-3的Scaling实验
  • 阅读U-ViT论文,理解其与DiT在设计哲学上的差异

参考文献

  1. Peebles & Xie, 2023. "Scalable Diffusion Models with Transformers" — DiT原始论文
  2. Bao et al., 2023. "All are Worth Words: A ViT Backbone for Diffusion Models" — U-ViT
  3. Ma et al., 2024. "Sit: Exploring Flow and Diffusion-based Generative Models with Scalable Interpolant Transformers" — SiT
  4. Gao et al., 2023. "Masked Diffusion Transformer is a Strong Image Synthesizer" — MDT
  5. Esser et al., 2024. "Scaling Rectified Flow Transformers for High-Resolution Image Synthesis" — MM-DiT/SD3
  6. Dosovitskiy et al., 2021. "An Image is Worth 16x16 Words" — ViT

下一章: 08-流匹配与一致性模型 — 探索更高效的扩散训练与采样范式