07 - DiT与Transformer扩散架构¶
⚠️ 时效性说明:本章涉及前沿模型/价格/榜单等信息,可能随版本快速变化;请以论文原文、官方发布页和 API 文档为准。
学习时间: 4小时 重要性: ⭐⭐⭐⭐⭐ Sora/SD3/FLUX等最新模型的共同基础架构
🎯 学习目标¶
完成本章后,你将能够: - 理解扩散模型从U-Net到Transformer的范式转变动因 - 掌握DiT的核心设计:Patchify、AdaLN-Zero条件注入 - 了解U-ViT、SiT、MDT、MM-DiT等Transformer扩散变体 - 理解Scaling Laws在扩散Transformer中的体现 - 实现一个简化版DiT模型
1. 从U-Net到Transformer的范式转变¶
1.1 为什么要替换U-Net?¶
U-Net自DDPM以来一直是扩散模型的标准骨干网络,但它存在固有局限:
| 维度 | U-Net | Transformer |
|---|---|---|
| 归纳偏置 | 强空间局部性偏置 | 弱偏置,更灵活 |
| 可扩展性 | 难以单纯堆叠层数提升 | Scaling Laws明确 |
| 多模态融合 | Cross-attention拼接 | 天然支持序列混合 |
| 分辨率泛化 | 需要特殊处理 | 位置编码灵活适配 |
| 工程生态 | 定制化结构 | 复用LLM训练基础设施 |
关键动因:随着模型规模增大,Transformer展现出更明确的Scaling Law特性——模型越大、数据越多、效果越好,且改进趋势可预测。这与LLM的成功经验一致。
1.2 Transformer在视觉任务中的发展¶
ViT (2020) → 图像分类
↓
DALL-E (2021) → 自回归图像生成
↓
U-ViT (2023) → 将Transformer嵌入U-Net结构
↓
DiT (2023) → 纯Transformer扩散模型
↓
SD3/FLUX/Sora (2024) → 大规模Transformer扩散模型
2. DiT原理详解¶
2.1 论文概述¶
DiT(Peebles & Xie, 2023, "Scalable Diffusion Models with Transformers")首次证明了纯Transformer架构在扩散模型中可以取代U-Net,并展现出清晰的Scaling Laws。
2.2 Patchify:图像到序列的转换¶
DiT遵循ViT的方式,将图像(或潜空间特征)分割为patch序列:
其中 \(N = \frac{H \times W}{p^2}\) 为patch数量,\(p\) 为patch大小,\(D\) 为embedding维度。
import torch
import torch.nn as nn
class PatchEmbed(nn.Module): # 继承nn.Module定义网络层
"""将潜空间特征图转换为patch序列"""
def __init__(self, img_size=32, patch_size=2, in_channels=4, embed_dim=768):
super().__init__() # super()调用父类方法
self.num_patches = (img_size // patch_size) ** 2
self.proj = nn.Conv2d(
in_channels, embed_dim,
kernel_size=patch_size, stride=patch_size
)
def forward(self, x):
# x: [B, C, H, W] → [B, D, H/p, W/p] → [B, N, D]
x = self.proj(x) # [B, D, H/p, W/p]
x = x.flatten(2).transpose(1, 2) # [B, N, D]
return x
Patch大小的影响: - \(p=2\): 序列更长,计算量大,细节更好 - \(p=4\): 序列更短,效率高,适合高分辨率 - \(p=8\): 极端压缩,信息损失较大
2.3 AdaLN-Zero:条件注入机制¶
DiT探索了四种将时间步 \(t\) 和类别标签 \(c\) 注入Transformer Block的方式:
| 方式 | 方法 | FID |
|---|---|---|
| In-context | 将 \(t, c\) 作为额外token拼接到序列中 | 较高 |
| Cross-attention | 将 \(t, c\) 作为cross-attention的KV | 中等 |
| AdaLN | 用 \(t, c\) 预测LayerNorm的 \(\gamma, \beta\) | 较低 |
| AdaLN-Zero | AdaLN + 零初始化scale参数 | 最低 |
AdaLN-Zero的核心思想:
标准LayerNorm对输入进行归一化后,施加可学习的scale和shift: $\(\text{LN}(x) = \gamma \cdot \frac{x - \mu}{\sigma} + \beta\)$
AdaLN-Zero将 \(\gamma, \beta\) 替换为条件预测的参数,并额外引入一个零初始化的门控参数 \(\alpha\):
初始化时 \(\alpha = 0\),使得每个DiT Block在训练初期是恒等映射,保证训练稳定性。
class AdaLNZero(nn.Module):
"""Adaptive Layer Norm Zero — DiT的核心条件注入模块"""
def __init__(self, dim, cond_dim):
super().__init__()
self.norm = nn.LayerNorm(dim, elementwise_affine=False)
# 预测 gamma, beta, alpha 三组参数
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(cond_dim, 6 * dim) # 为attention和MLP各预测3个参数
)
# 零初始化
nn.init.zeros_(self.adaLN_modulation[-1].weight) # [-1]负索引取最后元素
nn.init.zeros_(self.adaLN_modulation[-1].bias)
def forward(self, x, cond):
# cond: [B, cond_dim] → 预测6组调制参数
(gamma1, beta1, alpha1,
gamma2, beta2, alpha2) = self.adaLN_modulation(cond).chunk(6, dim=-1)
return gamma1, beta1, alpha1, gamma2, beta2, alpha2
def modulate(self, x, gamma, beta):
return self.norm(x) * (1 + gamma.unsqueeze(1)) + beta.unsqueeze(1) # unsqueeze增加一个维度
2.4 DiT整体架构¶
输入潜变量 z ∈ R^(C×H×W)
↓
[PatchEmbed] → [B, N, D]
↓
+ Positional Embedding (sin-cos)
↓
┌─────────────────────────┐
│ DiT Block × L │
│ ┌─────────────────────┐ │
│ │ AdaLN-Zero(t, c) │ │
│ │ ↓ │ │
│ │ Self-Attention │ │
│ │ ↓ │ │
│ │ AdaLN-Zero(t, c) │ │
│ │ ↓ │ │
│ │ Pointwise FFN │ │
│ └─────────────────────┘ │
└─────────────────────────┘
↓
[Final AdaLN + Linear]
↓
[Unpatchify] → R^(2C×H×W) (预测噪声+对角协方差)
3. Transformer扩散模型家族¶
3.1 U-ViT¶
U-ViT(Bao et al., 2023, "All are Worth Words: A ViT Backbone for Diffusion Models")将U-Net的跳跃连接引入ViT:
- 将时间步、条件、图像patches视为统一的token序列
- 浅层和深层之间引入long skip connections
- 保持了U-Net跳跃连接的优势,同时获得Transformer的灵活性
3.2 SiT(Scalable Interpolant Transformers)¶
SiT(Ma et al., 2024)结合DiT架构与Flow Matching训练:
- 使用随机插值(Stochastic Interpolant)框架统一多种扩散/流匹配目标
- 在Interpolant框架下自由选择不同的时间调度和速度场参数化
- 同等FLOPs下优于DiT
3.3 MDT(Masked Diffusion Transformer)¶
MDT(Gao et al., 2023, "Masked Diffusion Transformer is a Strong Image Synthesizer")引入了掩码建模思想:
- 训练时随机mask部分patch tokens
- 模型同时学习去噪和patch预测
- 加速训练收敛,提升上下文学习能力
3.4 MM-DiT¶
MM-DiT(Esser et al., 2024)是SD3/SD3.5使用的多模态Transformer(详见上一章):
- 图像和文本token联合注意力
- 模态专属的QKV投影和MLP
- 支持多种文本编码器的灵活接入
3.5 架构对比¶
| 模型 | 条件注入 | 文本处理 | 跳跃连接 | 训练目标 |
|---|---|---|---|---|
| DiT | AdaLN-Zero | 类别标签 | 无 | DDPM \(\epsilon\)-prediction |
| U-ViT | In-context | Token拼接 | 有 | DDPM \(\epsilon\)-prediction |
| SiT | AdaLN-Zero | 类别标签 | 无 | Flow Matching |
| MDT | AdaLN-Zero | 类别标签 | 无 | DDPM + Mask |
| MM-DiT | AdaLN-Zero | 联合注意力 | 无 | Rectified Flow |
4. Scaling Laws在扩散Transformer中的应用¶
4.1 DiT的Scaling实验¶
DiT论文中系统测试了不同规模的模型:
| 模型 | 层数 | 维度 | 注意力头数 | 参数量 | FID (ImageNet 256) |
|---|---|---|---|---|---|
| DiT-S/2 | 12 | 384 | 6 | 33M | 68.4 |
| DiT-B/2 | 12 | 768 | 12 | 130M | 43.5 |
| DiT-L/2 | 24 | 1024 | 16 | 458M | 23.3 |
| DiT-XL/2 | 28 | 1152 | 16 | 675M | 9.62 |
关键发现: 1. FID与模型FLOPs呈幂律关系:\(\text{FID} \propto \text{GFLOPs}^{-\alpha}\) 2. 增大模型比增大patch更有效:相同GFLOPs下,大模型+大patch < 大模型+小patch 3. Scaling趋势未饱和:暗示更大模型将持续改进
4.2 工业级Scaling实践¶
Sora、FLUX等产品验证了扩散Transformer的Scaling Law:
其中 \(N\) 为参数量,\(D\) 为数据量,\(L\) 为验证损失。
实际数据: - FLUX.1: 12B参数 → 目前开源最佳图像质量 - Sora: 估计数十B参数 → 突破性视频质量 - 趋势:扩散模型正在重复LLM的scaling故事
5. 简化版DiT代码实现¶
import torch
import torch.nn as nn
import math
class SinusoidalPosEmb(nn.Module):
"""正弦时间步编码"""
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, t):
device = t.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
emb = t[:, None].float() * emb[None, :]
return torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) # torch.cat沿已有维度拼接张量
class DiTBlock(nn.Module):
"""DiT Transformer Block with AdaLN-Zero"""
def __init__(self, dim, num_heads, mlp_ratio=4.0):
super().__init__()
self.dim = dim
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim ** -0.5
# Self-Attention
self.qkv = nn.Linear(dim, 3 * dim)
self.proj = nn.Linear(dim, dim)
# FFN
mlp_hidden = int(dim * mlp_ratio)
self.mlp = nn.Sequential(
nn.Linear(dim, mlp_hidden),
nn.GELU(),
nn.Linear(mlp_hidden, dim),
)
# AdaLN-Zero: 两层各需 gamma, beta, alpha → 6 * dim
self.norm1 = nn.LayerNorm(dim, elementwise_affine=False)
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False)
self.adaLN = nn.Sequential(
nn.SiLU(),
nn.Linear(dim, 6 * dim),
)
# 关键:零初始化
nn.init.zeros_(self.adaLN[-1].weight)
nn.init.zeros_(self.adaLN[-1].bias)
def forward(self, x, cond):
"""
x: [B, N, D] — patch token序列
cond: [B, D] — 条件嵌入(时间步+类别)
"""
B, N, D = x.shape
# 预测6组调制参数
gamma1, beta1, alpha1, gamma2, beta2, alpha2 = \
self.adaLN(cond).chunk(6, dim=-1)
# --- Attention分支 ---
h = self.norm1(x) * (1 + gamma1.unsqueeze(1)) + beta1.unsqueeze(1)
qkv = self.qkv(h).reshape(B, N, 3, self.num_heads, D // self.num_heads) # 重塑张量形状
qkv = qkv.permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
h = (attn @ v).transpose(1, 2).reshape(B, N, D)
h = self.proj(h)
x = x + alpha1.unsqueeze(1) * h # 零初始化门控
# --- FFN分支 ---
h = self.norm2(x) * (1 + gamma2.unsqueeze(1)) + beta2.unsqueeze(1)
h = self.mlp(h)
x = x + alpha2.unsqueeze(1) * h # 零初始化门控
return x
class DiT(nn.Module):
"""简化版 Diffusion Transformer"""
def __init__(
self,
img_size=32,
patch_size=2,
in_channels=4,
dim=768,
depth=12,
num_heads=12,
num_classes=10,
mlp_ratio=4.0,
):
super().__init__()
self.patch_size = patch_size
self.num_patches = (img_size // patch_size) ** 2
# Patch Embedding
self.patch_embed = nn.Conv2d(
in_channels, dim, kernel_size=patch_size, stride=patch_size
)
# 位置编码 (固定sin-cos)
self.pos_embed = nn.Parameter(
torch.zeros(1, self.num_patches, dim)
)
# 条件编码
self.time_embed = nn.Sequential(
SinusoidalPosEmb(dim),
nn.Linear(dim, dim),
nn.GELU(),
nn.Linear(dim, dim),
)
self.class_embed = nn.Embedding(num_classes, dim)
# Transformer Blocks
self.blocks = nn.ModuleList([
DiTBlock(dim, num_heads, mlp_ratio) for _ in range(depth)
])
# 输出层
self.final_norm = nn.LayerNorm(dim, elementwise_affine=False)
self.final_adaLN = nn.Sequential(
nn.SiLU(),
nn.Linear(dim, 2 * dim),
)
self.final_proj = nn.Linear(dim, patch_size ** 2 * in_channels * 2)
# 2倍channels用于预测噪声和对角方差
self._init_weights()
def _init_weights(self):
# 初始化位置编码
nn.init.normal_(self.pos_embed, std=0.02)
# 零初始化输出层
nn.init.zeros_(self.final_proj.weight)
nn.init.zeros_(self.final_proj.bias)
nn.init.zeros_(self.final_adaLN[-1].weight)
nn.init.zeros_(self.final_adaLN[-1].bias)
def unpatchify(self, x, H, W):
"""[B, N, p*p*C] → [B, C, H, W]"""
p = self.patch_size
c = x.shape[-1] // (p * p)
h, w = H // p, W // p
x = x.reshape(-1, h, w, p, p, c)
x = x.permute(0, 5, 1, 3, 2, 4).reshape(-1, c, H, W)
return x
def forward(self, x, t, y):
"""
x: [B, C, H, W] — 带噪声的潜变量
t: [B] — 时间步
y: [B] — 类别标签
"""
B, C, H, W = x.shape
# Patchify + 位置编码
x = self.patch_embed(x).flatten(2).transpose(1, 2) # [B, N, D]
x = x + self.pos_embed
# 条件编码
cond = self.time_embed(t) + self.class_embed(y) # [B, D]
# Transformer Blocks
for block in self.blocks:
x = block(x, cond)
# 输出投影
gamma, beta = self.final_adaLN(cond).chunk(2, dim=-1)
x = self.final_norm(x) * (1 + gamma.unsqueeze(1)) + beta.unsqueeze(1)
x = self.final_proj(x) # [B, N, p*p*2C]
x = self.unpatchify(x, H, W) # [B, 2C, H, W]
noise_pred, var_pred = x.chunk(2, dim=1)
return noise_pred, var_pred
# === 使用示例 ===
if __name__ == "__main__":
model = DiT(
img_size=32, patch_size=2, in_channels=4,
dim=384, depth=12, num_heads=6, num_classes=10
)
print(f"DiT-S/2 参数量: {sum(p.numel() for p in model.parameters()) / 1e6:.1f}M")
x = torch.randn(2, 4, 32, 32) # 潜空间输入
t = torch.randint(0, 1000, (2,))
y = torch.randint(0, 10, (2,))
noise_pred, var_pred = model(x, t, y)
print(f"输出形状: noise={noise_pred.shape}, var={var_pred.shape}")
# 输出: noise=torch.Size([2, 4, 32, 32]), var=torch.Size([2, 4, 32, 32])
📋 面试要点¶
高频面试题¶
- DiT相比U-Net的核心优势是什么?
- Transformer架构具有清晰的Scaling Laws,模型越大效果持续提升
- 弱归纳偏置使其更灵活,易于多模态扩展
- 可复用LLM训练基础设施(并行策略、优化器等)
-
U-Net的下采样/上采样结构在极高分辨率下效率受限
-
AdaLN-Zero为什么效果最好?
- 零初始化使每个Block初始为恒等映射,类似ResNet的残差学习
- 训练初期模型=直接跳过所有层,梯度直接传到输入,训练更稳定
-
条件信息通过可学习的scale/shift注入,比简单拼接更灵活
-
DiT的Patch Size如何选择?
- \(p=2\) 最常用(DiT-XL/2达到FID 9.62)
- 较小的patch保留更多空间细节但增加序列长度(计算量 \(O(N^2)\))
- 实际中通常在潜空间(如8×下采样后的32×32)上使用patch_size=2
-
对于高分辨率可考虑patch_size=4配合窗口注意力
-
如何理解扩散Transformer的Scaling Laws?
- FID与GFLOPs呈幂律下降关系
- 这与LLM中Loss与参数量/数据量的幂律关系类似
- Sora、FLUX等产品验证了大规模扩散Transformer的有效性
- 意味着扩散模型的天花板远未达到
✏️ 练习¶
练习1:DiT-S/2训练¶
在CIFAR-10上训练简化版DiT-S/2,对比不同条件注入方式(in-context / cross-attention / AdaLN-Zero)的收敛速度和最终FID。
练习2:Patch Size消融实验¶
固定模型参数量,测试patch_size=½/4对生成质量和训练速度的影响。
练习3:结构变体实验¶
实现U-ViT的long skip connection并对比纯DiT,观察收敛速度的差异。
练习4:论文精读¶
- 精读DiT原始论文,关注Table 1-3的Scaling实验
- 阅读U-ViT论文,理解其与DiT在设计哲学上的差异
参考文献¶
- Peebles & Xie, 2023. "Scalable Diffusion Models with Transformers" — DiT原始论文
- Bao et al., 2023. "All are Worth Words: A ViT Backbone for Diffusion Models" — U-ViT
- Ma et al., 2024. "Sit: Exploring Flow and Diffusion-based Generative Models with Scalable Interpolant Transformers" — SiT
- Gao et al., 2023. "Masked Diffusion Transformer is a Strong Image Synthesizer" — MDT
- Esser et al., 2024. "Scaling Rectified Flow Transformers for High-Resolution Image Synthesis" — MM-DiT/SD3
- Dosovitskiy et al., 2021. "An Image is Worth 16x16 Words" — ViT
下一章: 08-流匹配与一致性模型 — 探索更高效的扩散训练与采样范式