10 - FLUX与新一代架构¶
学习时间: 6小时 重要性: ⭐⭐⭐⭐⭐ 2024-2025年最前沿的图像生成架构,面试高频考点 前置知识: 07-DiT与Transformer扩散架构 | 08-流匹配与一致性模型 | 06-Stable Diffusion进阶(SDXL到SD3)
🎯 学习目标¶
完成本章后,你将能够: - 深入理解 FLUX.1 的完整架构(DiT backbone + MM-DiT 双流注意力 + 2D RoPE) - 掌握 FLUX vs SD3 vs SDXL 核心差异与演进逻辑 - 从数学推导到代码实现 Rectified Flow(直线流匹配) - 理解 ControlNet 的 zero-convolution 设计与预训练权重复用 - 掌握 IP-Adapter 解耦交叉注意力原理 - 区分 T2I-Adapter、ControlNet、IP-Adapter 的设计哲学与适用场景 - 使用 diffusers 完成 FLUX 推理与 ControlNet 条件控图
1. FLUX.1 架构全解析¶
1.1 背景与定位¶
FLUX.1 由 Black Forest Labs(Stability AI 前核心团队 Robin Rombach 等)于 2024 年 8 月发布,代表了"纯 Transformer + Rectified Flow"范式的工程化巅峰。
| 版本 | 参数量 | 许可证 | 特点 |
|---|---|---|---|
| FLUX.1 [pro] | 12B | 商业API | 最高质量,仅API调用 |
| FLUX.1 [dev] | 12B | 非商业开源 | 指导蒸馏后的开发版本 |
| FLUX.1 [schnell] | 12B | Apache 2.0 | 1-4步超快推理 |
💡 FLUX.1 与 SD3 共享核心团队和设计思想(MM-DiT、Rectified Flow、多文本编码器),但 FLUX.1 在架构细节上做了大量改进。
1.2 整体架构鸟瞰¶
文本编码器: CLIP ViT-L/14 (pooled 768d) + T5-XXL (序列 4096d)
│
Patchify (2×2) ▼ 条件向量: timestep_embed + pooled_text
潜空间 z_T ──────→ 双流 MM-DiT Blocks (×19) ←─────────────────────────┘
+ 2D RoPE img流 / txt流 Joint Attention
│
▼
单流 DiT Blocks (×38)
img+txt拼接为统一序列, 共享Self-Attn
│
Unpatchify + VAE Decode → 生成图像 (1024²)
双文本编码器设计:
| 编码器 | 输出维度 | 作用 |
|---|---|---|
| CLIP ViT-L/14 | 768-dim pooled | 全局语义向量,用于时间步嵌入的 AdaLN 条件调制 |
| T5-XXL (encoder-only) | 4096-dim × seq_len | 丰富的文本序列表示,作为 txt 流输入 |
去掉 SD3 的第二个 CLIP-bigG,因为 T5-XXL 的序列表示已足够丰富,第二个 CLIP 边际收益有限,去掉可减少显存和延迟。
1.3 MM-DiT 双流注意力¶
📌 深入展开 07-DiT与Transformer扩散架构 中的 MM-DiT 概念。
传统 Cross-Attention vs MM-DiT Joint Attention:
传统 Cross-Attention (SD 1.5/SDXL):
Q = image_tokens × W_q ← 只有图像做查询
K = text_tokens × W_k ← 文本做键值
V = text_tokens × W_v
→ 信息单向: text → image
MM-DiT Joint Attention (SD3/FLUX):
Q = concat(img_tokens × W_q^img, txt_tokens × W_q^txt)
K = concat(img_tokens × W_k^img, txt_tokens × W_k^txt)
V = concat(img_tokens × W_v^img, txt_tokens × W_v^txt)
Attn = softmax(QK^T/√d) V
→ 信息双向: text ↔ image
双流设计:图像和文本各自维护独立的 LayerNorm/MLP 参数,但在注意力计算时拼接做 Joint Attention,实现双向信息交互。
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
import math
class MMDiTBlock(nn.Module): # 继承nn.Module定义网络层
"""MM-DiT双流注意力块
双流设计:img/txt各自独立Norm/MLP参数,
但注意力时拼接QKV实现双向信息交互。
"""
def __init__(self, dim=3072, heads=24, mlp_ratio=4.0):
super().__init__() # super()调用父类方法
self.heads = heads
self.head_dim = dim // heads
# ====== 图像流参数 ======
self.img_norm1 = nn.LayerNorm(dim, elementwise_affine=False)
self.img_qkv = nn.Linear(dim, dim * 3)
self.img_out = nn.Linear(dim, dim)
self.img_norm2 = nn.LayerNorm(dim, elementwise_affine=False)
self.img_mlp = nn.Sequential(
nn.Linear(dim, int(dim * mlp_ratio)),
nn.GELU(approximate="tanh"),
nn.Linear(int(dim * mlp_ratio), dim))
# ====== 文本流参数(结构相同,权重独立)======
self.txt_norm1 = nn.LayerNorm(dim, elementwise_affine=False)
self.txt_qkv = nn.Linear(dim, dim * 3)
self.txt_out = nn.Linear(dim, dim)
self.txt_norm2 = nn.LayerNorm(dim, elementwise_affine=False)
self.txt_mlp = nn.Sequential(
nn.Linear(dim, int(dim * mlp_ratio)),
nn.GELU(approximate="tanh"),
nn.Linear(int(dim * mlp_ratio), dim))
# AdaLN-Zero: 12个调制参数
# img: (shift_a, scale_a, gate_a, shift_m, scale_m, gate_m) × 1
# txt: 同上 × 1
self.adaLN = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 12))
def _modulate(self, x, shift, scale):
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) # unsqueeze增加一个维度
def forward(self, img, txt, cond):
"""
img: [B, N_img, D] 图像patch序列
txt: [B, N_txt, D] 文本token序列
cond: [B, D] 时间步+pooled文本条件
"""
# chunk(12)将条件向量切分为12个调制参数:每对(shift,scale,gate)控制norm/attn/mlp,图像和文本各6个
m = self.adaLN(cond).chunk(12, dim=-1)
# ---- Joint Attention ----
img_n = self._modulate(self.img_norm1(img), m[0], m[1])
txt_n = self._modulate(self.txt_norm1(txt), m[6], m[7])
iq, ik, iv = self.img_qkv(img_n).chunk(3, dim=-1)
tq, tk, tv = self.txt_qkv(txt_n).chunk(3, dim=-1)
# 拼接Q/K/V做联合注意力
q = torch.cat([iq, tq], dim=1) # torch.cat沿已有维度拼接张量
k = torch.cat([ik, tk], dim=1)
v = torch.cat([iv, tv], dim=1)
q = rearrange(q, 'b n (h d) -> b h n d', h=self.heads)
k = rearrange(k, 'b n (h d) -> b h n d', h=self.heads)
v = rearrange(v, 'b n (h d) -> b h n d', h=self.heads)
attn = F.scaled_dot_product_attention(q, k, v) # Flash Attention # F.xxx PyTorch函数式API
attn = rearrange(attn, 'b h n d -> b n (h d)')
# 拆分回img和txt,各自output projection + gate
N = img.shape[1]
img = img + m[2].unsqueeze(1) * self.img_out(attn[:, :N])
txt = txt + m[8].unsqueeze(1) * self.txt_out(attn[:, N:])
# ---- MLP (各自独立) ----
img = img + m[5].unsqueeze(1) * self.img_mlp(
self._modulate(self.img_norm2(img), m[3], m[4]))
txt = txt + m[11].unsqueeze(1) * self.txt_mlp(
self._modulate(self.txt_norm2(txt), m[9], m[10]))
return img, txt
1.4 2D Rotary Position Encoding¶
FLUX 用 2D RoPE 替代 SD3 的固定正弦编码——这是 LLM 社区 RoPE(Su et al., 2021)在二维空间的推广。
将 head_dim 分两半,分别编码行与列位置:
注意力分数自然编码二维空间距离:
核心优势: - 分辨率外推:RoPE 是相对位置编码,可泛化到训练时未见的分辨率 - 任意长宽比:不依赖固定 grid size - 无需学习:位置编码是确定性函数,零额外参数
import torch
def get_2d_rope_freqs(height, width, dim, theta=10000.0):
"""生成2D RoPE复数频率矩阵
Returns: [H*W, dim//2] 复数频率
"""
quarter_dim = dim // 4
freqs = 1.0 / (theta ** (torch.arange(0, quarter_dim).float() / quarter_dim))
# 行频率 → 前半维度
h_freqs = torch.outer(torch.arange(height).float(), freqs) # [H, qd]
h_freqs = h_freqs.unsqueeze(1).expand(-1, width, -1).reshape(-1, quarter_dim) # 重塑张量形状
# 列频率 → 后半维度
w_freqs = torch.outer(torch.arange(width).float(), freqs) # [W, qd]
w_freqs = w_freqs.unsqueeze(0).expand(height, -1, -1).reshape(-1, quarter_dim)
freqs_2d = torch.cat([h_freqs, w_freqs], dim=-1) # [H*W, dim//2]
return torch.polar(torch.ones_like(freqs_2d), freqs_2d)
def apply_2d_rope(x, freqs):
"""x: [B, heads, N, head_dim] → 旋转后同形状"""
# reshape将最后一维两两配对为(实部,虚部),view_as_complex转为复数,乘以频率复数实现旋转位置编码
xc = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
rotated = xc * freqs.unsqueeze(0).unsqueeze(0)
return torch.view_as_real(rotated).reshape_as(x).to(x.dtype)
# 示例:1024×1024图像, patch=2, VAE下采样8x → 64×64 patches
freqs = get_2d_rope_freqs(height=64, width=64, dim=128)
print(f"2D RoPE freqs: {freqs.shape}") # [4096, 64]
1.5 双流→单流混合架构¶
| 阶段 | 类型 | 层数 | 特点 |
|---|---|---|---|
| 第一阶段 | 双流 MM-DiT | 19层 | img/txt 各自独立 Norm/MLP,Joint Attention |
| 第二阶段 | 单流 DiT | 38层 | img+txt 拼接为统一序列,完全共享参数 |
设计哲学: - 全双流:每层两套参数,参数量大、效率低 - 全单流:浅层就共享参数,模态对齐不充分 - 混合策略:前期保留模态独立性,后期深度融合,效率与效果最优平衡
class SingleDiTBlock(nn.Module):
"""单流DiT块(第二阶段,img+txt共享全部参数)"""
def __init__(self, dim=3072, heads=24, mlp_ratio=4.0):
super().__init__()
self.heads = heads
self.norm1 = nn.LayerNorm(dim, elementwise_affine=False)
self.qkv = nn.Linear(dim, dim * 3)
self.proj = nn.Linear(dim, dim)
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False)
self.mlp = nn.Sequential(
nn.Linear(dim, int(dim * mlp_ratio)),
nn.GELU(approximate="tanh"),
nn.Linear(int(dim * mlp_ratio), dim))
self.adaLN = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
def forward(self, x, cond):
"""x: [B, N_img+N_txt, D] 拼接后的统一序列"""
s_a, sc_a, g_a, s_m, sc_m, g_m = self.adaLN(cond).chunk(6, dim=-1)
xn = self.norm1(x) * (1 + sc_a.unsqueeze(1)) + s_a.unsqueeze(1)
# 列表推导:qkv线性层输出chunk(3)拆为Q/K/V,每个用rearrange从(B,N,heads*d)重排为(B,heads,N,d)多头格式
q, k, v = [rearrange(t, 'b n (h d) -> b h n d', h=self.heads)
for t in self.qkv(xn).chunk(3, dim=-1)]
attn = rearrange(F.scaled_dot_product_attention(q, k, v), 'b h n d -> b n (h d)')
x = x + g_a.unsqueeze(1) * self.proj(attn)
xn2 = self.norm2(x) * (1 + sc_m.unsqueeze(1)) + s_m.unsqueeze(1)
x = x + g_m.unsqueeze(1) * self.mlp(xn2)
return x
def timestep_embedding(t, dim, max_period=10000):
"""正弦时间步嵌入"""
half = dim // 2
freqs = torch.exp(-math.log(max_period) * torch.arange(half, device=t.device) / half)
args = t[:, None].float() * freqs[None]
return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
📝 面试考点 1. MM-DiT 与 cross-attention 的本质区别? cross-attention 信息单向 text→image;MM-DiT Joint Attention 拼接 Q/K/V 实现双向信息流 text↔image。 2. 为什么双流→单流而非全双流? 全双流每层两套参数效率低;混合设计前期保留模态特异性,后期共享参数深度融合。 3. 2D RoPE 相比固定位置编码的优势? 相对位置编码,支持分辨率外推和任意长宽比,无需学习参数。
2. FLUX vs SD3 vs SDXL 架构对比¶
2.1 核心架构对比表¶
| 维度 | SDXL (2023.07) | SD3 (2024.03) | FLUX.1 (2024.08) |
|---|---|---|---|
| 骨干网络 | U-Net (3.5B) | MM-DiT (2B/8B) | 混合MM-DiT+DiT (12B) |
| 噪声调度 | DDPM (离散) | Rectified Flow (连续) | Rectified Flow (连续) |
| 文本编码器 | CLIP-L + CLIP-bigG | CLIP-L + CLIP-bigG + T5-XXL | CLIP-L + T5-XXL |
| 位置编码 | 固定正弦 1D | 固定正弦 2D | 2D RoPE |
| 注意力 | Self + Cross-Attn | MM-DiT Joint Attn | 双流MM-DiT + 单流Self |
| VAE通道 | 4ch | 16ch | 16ch |
| 训练目标 | ε-prediction | v-prediction (Flow) | v-prediction (RF,预测速度 \(v_t = x_1 - x_0\)) |
| 最快推理 | ~25步 | ~28步 | 1-4步 (schnell) |
| 推荐CFG | 7.0-7.5 | 7.0 | 3.5 |
| 文字渲染 | 差 | 中 | 优秀 |
| VRAM (fp16) | ~6.5GB | ~8GB | ~24GB (NF4→~8GB) |
2.2 关键升级解读¶
16通道 VAE(SD3/FLUX 共享):
SD 1.5 / SDXL: VAE latent = 4 channels → 1024×1024 → 128×128×4
SD3 / FLUX.1: VAE latent = 16 channels → 1024×1024 → 128×128×16
更丰富的潜空间表示→VAE 重建质量更高→减少高频细节损失(文字笔画、纤细纹理)。代价是 Transformer 输入维度从 4 增至 16(patchify 后 \(16\times2\times2=64\))。
FLUX schnell 无需 CFG:通过指导蒸馏将 CFG 引导效果内化到模型权重,一次前向 = 教师的有条件+无条件两次前向。
2.3 推理配置典型值¶
configs = {
"SDXL": {"steps": 25, "cfg": 7.0, "scheduler": "DPM++ 2M Karras"},
"SD3-Medium": {"steps": 28, "cfg": 7.0, "scheduler": "Euler + logit-normal"},
"FLUX.1-dev": {"steps": 28, "cfg": 3.5, "scheduler": "Euler"},
"FLUX.1-schnell": {"steps": 4, "cfg": 0.0, "scheduler": "Euler (无CFG)"},
}
📝 面试考点 1. 16通道 VAE 的目的? 提升潜空间信息容量,减少高频细节损失,改善文字渲染。 2. FLUX 为何只用 2 个文本编码器? T5-XXL 已涵盖 CLIP-bigG 的语义贡献,去掉减少显存。 3. FLUX schnell 为何不需 CFG? 指导蒸馏已将 CFG 引导效果内化到权重。
3. Rectified Flow:直线流匹配¶
📌 本节深化 08-流匹配与一致性模型 的流匹配理论。
3.1 核心原理¶
Rectified Flow(Liu et al., 2023, "Flow Straight and Fast")的核心思想:让概率传输路径为直线。
线性插值路径:
目标速度场:\(u_t(x_t|x_0,x_1) = x_1 - x_0\)(与 \(t\) 无关——正是"直线"的含义)
训练损失:
3.2 ODE vs SDE 采样¶
ODE(确定性):\(x_{t+\Delta t} = x_t + \Delta t \cdot v_\theta(x_t, t)\) 直线路径→理想 1 步 Euler 即可,实际 4-50 步。
SDE(随机性):\(dx_t = v_\theta dt + \sigma_t dW_t\) 少量噪声可修正 ODE 误差累积,提升多样性。
3.3 与 DDPM 对比¶
| 维度 | DDPM | Rectified Flow |
|---|---|---|
| 路径类型 | 曲线 (\(\sqrt{\bar\alpha_t}x_0 + \sqrt{1-\bar\alpha_t}\epsilon\)) | 直线 (\((1-t)x_0 + tx_1\)) |
| 预测目标 | 噪声 \(\epsilon\) | 速度 \(v = x_1 - x_0\) |
| 采样步数 | 通常 20-1000 | 通常 1-50 |
| 少步质量 | 差(需DDIM/DPM++加速) | 好(路径本身近似直线) |
| 数学框架 | 离散马尔可夫链 / SDE | 连续 ODE (CNF) |
3.4 ReFlow 蒸馏¶
用当前模型 \(v_{\theta_1}\) 做 ODE 推理得新 \((x_0, \hat{x}_1)\) 对,重新训练 \(v_{\theta_2}\)。新 pair 间 ODE 路径更直→\(v_{\theta_2}\) 需更少步数。迭代可进一步拉直。FLUX schnell 即用类似策略实现 1-4 步。
3.5 完整训练与采样代码¶
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
class VelocityNet(nn.Module):
"""简化的速度场网络(用于小图像验证)"""
def __init__(self, ch=1, size=28, hidden=256):
super().__init__()
d = ch * size * size
self.net = nn.Sequential(
nn.Linear(d + 1, hidden), nn.SiLU(),
nn.Linear(hidden, hidden), nn.SiLU(),
nn.Linear(hidden, hidden), nn.SiLU(),
nn.Linear(hidden, d))
self.shape = (ch, size, size)
def forward(self, x_t, t):
B = x_t.shape[0]
x_flat = x_t.reshape(B, -1)
return self.net(torch.cat([x_flat, t.reshape(B,1)], -1)).reshape(B, *self.shape)
class RectifiedFlowTrainer:
"""Rectified Flow 训练与采样"""
def __init__(self, model, lr=2e-4, device="cuda"):
self.model = model.to(device) # 移至GPU/CPU
self.opt = torch.optim.AdamW(model.parameters(), lr=lr)
self.device = device
def train_step(self, x_1):
"""x_1: [B,C,H,W] 真实数据(对应t=1)"""
B = x_1.shape[0]
x_0 = torch.randn_like(x_1) # 噪声
t = torch.rand(B, 1, device=self.device)
t_e = t.unsqueeze(-1).unsqueeze(-1) # [B,1,1,1]
x_t = (1 - t_e) * x_0 + t_e * x_1 # 线性插值
target = x_1 - x_0 # 目标速度
loss = F.mse_loss(self.model(x_t, t), target)
self.opt.zero_grad(); loss.backward(); self.opt.step() # 清零梯度 # 反向传播计算梯度
return loss.item() # 将单元素张量转为Python数值
@torch.no_grad() # 禁用梯度计算,节省内存
def sample_ode(self, batch_size, num_steps=20):
"""ODE Euler采样(确定性)"""
self.model.eval() # eval()评估模式
x = torch.randn(batch_size, *self.model.shape, device=self.device)
dt = 1.0 / num_steps
for i in range(num_steps):
t = torch.full((batch_size, 1), i * dt, device=self.device)
x = x + dt * self.model(x, t)
self.model.train()
return x.clamp(-1, 1)
@torch.no_grad()
def sample_sde(self, batch_size, num_steps=20, sigma=0.01):
"""SDE采样(带随机噪声修正)"""
self.model.eval()
x = torch.randn(batch_size, *self.model.shape, device=self.device)
dt = 1.0 / num_steps
for i in range(num_steps):
t = torch.full((batch_size, 1), i * dt, device=self.device)
noise = torch.randn_like(x) * sigma * dt**0.5 if i < num_steps - 1 else 0
x = x + dt * self.model(x, t) + noise
self.model.train()
return x.clamp(-1, 1)
def train_rectified_flow():
device = "cuda" if torch.cuda.is_available() else "cpu"
transform = transforms.Compose([
transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
dl = DataLoader(datasets.MNIST("./data", train=True, download=True, # DataLoader批量加载数据
transform=transform), batch_size=128, shuffle=True)
trainer = RectifiedFlowTrainer(VelocityNet(), device=device)
for epoch in range(10):
avg_loss = sum(trainer.train_step(x.to(device)) for x,_ in dl) / len(dl)
print(f"Epoch {epoch}: loss={avg_loss:.4f}")
if epoch % 5 == 0:
samples = trainer.sample_ode(16, num_steps=10)
print(f" ODE 10步: range [{samples.min():.2f}, {samples.max():.2f}]")
samples_sde = trainer.sample_sde(16, num_steps=10, sigma=0.05)
print(f" SDE 10步: range [{samples_sde.min():.2f}, {samples_sde.max():.2f}]")
📝 面试考点 1. RF 比 DDPM 高效的核心原因? 直线路径→Euler 离散化误差极小→少步即可高质量。 2. v-prediction 与 ε-prediction 的关系? \(v = x_1 - x_0 = x_1 - \epsilon\),二者可互转。 3. ReFlow 蒸馏如何拉直路径? 用当前模型 ODE 推理产生新 pair 重训,新 pair 间路径更直。
4. ControlNet 原理与实现¶
4.1 核心设计哲学¶
ControlNet(Zhang et al., 2023)教扩散模型遵循空间条件(边缘/深度/姿态),三大原则:
- 复用预训练权重:深拷贝 U-Net encoder blocks 作为可训练副本
- 条件编码器:将条件图编码为与 latent 对齐的特征
- zero-convolution 安全连接:确保训练开始时不破坏原始模型
4.2 Zero-Convolution 详解¶
权重和偏置全零初始化的 1×1 卷积:\(y = Wx + b,\; W=0,\; b=0\)
| 初始化方式 | 训练初期行为 | 风险 |
|---|---|---|
| 随机初始化 | 随机信号注入原始模型 | ❌ 瞬间破坏预训练质量 |
| 零初始化全连接 | 梯度全零,对称性问题 | ❌ 训练僵死 |
| Zero-Conv | 输出为零,原始模型不受影响 | ✅ 安全起步 |
关键洞察:zero-conv 只是 ControlNet→主模型的"桥"初始化为零。副本内部参数继承自预训练模型(非零),梯度正常反传,进而更新 zero-conv 参数。
import torch, torch.nn as nn, copy
class ZeroConv(nn.Module):
"""零初始化1×1卷积"""
def __init__(self, ch):
super().__init__()
self.conv = nn.Conv2d(ch, ch, kernel_size=1)
nn.init.zeros_(self.conv.weight)
nn.init.zeros_(self.conv.bias)
def forward(self, x):
return self.conv(x)
class ControlNet(nn.Module):
"""ControlNet for U-Net based models (SD1.5/SDXL, 简化示意)"""
def __init__(self, base_unet, cond_channels=3):
super().__init__()
# 条件编码器:将条件图(如Canny边缘)编码为latent尺寸特征
self.cond_encoder = nn.Sequential(
nn.Conv2d(cond_channels, 64, 3, stride=2, padding=1), nn.SiLU(),
nn.Conv2d(64, 128, 3, stride=2, padding=1), nn.SiLU(),
nn.Conv2d(128, 256, 3, stride=2, padding=1), nn.SiLU(),
nn.Conv2d(256, 320, 3, stride=1, padding=1))
# 复制encoder各层 + zero-conv连接
encoder_channels = [320, 640, 1280, 1280]
self.ctrl_blocks = nn.ModuleList()
self.zero_convs = nn.ModuleList()
for ch in encoder_channels:
self.ctrl_blocks.append(nn.Sequential(
nn.Conv2d(ch, ch, 3, padding=1),
nn.GroupNorm(32, ch), nn.SiLU()))
self.zero_convs.append(ZeroConv(ch))
# 冻结原始U-Net全部参数
for p in base_unet.parameters():
p.requires_grad = False
def forward(self, cond_image):
"""返回各层控制信号(加到decoder对应层)"""
h = self.cond_encoder(cond_image)
signals = []
for block, zc in zip(self.ctrl_blocks, self.zero_convs): # zip按位置配对
h = block(h)
signals.append(zc(h))
return signals
4.3 训练配置¶
# 关键训练细节:
# 1. 冻结原始U-Net全部参数
# 2. 只训练:条件编码器 + ControlNet副本 + zero-convolution
# 3. 学习率1e-5(较小,从预训练权重开始)
# 4. 数据:成对的(图像, 条件图),如(照片, Canny边缘)
# 5. 数据量仅需原始模型训练量的5-10%
📝 面试考点 1. zero-conv 的设计动因? 桥为零→训练开始不破坏预训练→控制信号渐进注入。副本内部预训练参数保证梯度正常。 2. 训练哪些参数? 条件编码器 + encoder副本 + zero-conv;原始U-Net全冻结。 3. 如何保护预训练质量? zero-conv 初始为零→原始模型初始行为不变→控制渐进式注入。
5. ControlNet for FLUX/SD3 适配¶
5.1 挑战¶
FLUX/SD3 是纯 Transformer 无 encoder-decoder 结构,传统 ControlNet 的 encoder→decoder 注入方式不适用。
| 挑战 | U-Net ControlNet | DiT ControlNet |
|---|---|---|
| 注入点 | encoder block → decoder block | 无decoder,需在DiT block层级注入 |
| 条件格式 | 空间特征图 | 需patchify为token序列 |
| 多模态 | 无需考虑 | 需处理MM-DiT双流结构 |
5.2 方案:复制前 N 层 + zero-linear¶
class FLUXControlNet(nn.Module):
"""FLUX ControlNet适配(简化版)"""
def __init__(self, flux_model, n_ctrl=8, dim=3072, cond_ch=3):
super().__init__()
# 条件编码器:图像条件 → token序列
self.cond_enc = nn.Sequential(
nn.Conv2d(cond_ch, 64, 3, 2, 1), nn.SiLU(),
nn.Conv2d(64, 128, 3, 2, 1), nn.SiLU(),
nn.Conv2d(128, 256, 3, 2, 1), nn.SiLU(),
nn.Conv2d(256, dim // 4, 1))
self.proj = nn.Linear(dim // 4 * 4, dim) # patchify后投影
# 复制前N个双流MM-DiT块
self.ctrl_blocks = nn.ModuleList([
copy.deepcopy(flux_model.double_blocks[i]) for i in range(n_ctrl)])
self.zero_lins = nn.ModuleList([
ZeroLinear(dim) for _ in range(n_ctrl)])
for p in flux_model.parameters():
p.requires_grad = False
def forward(self, img_tok, txt_tok, cond, cond_img, flux):
cond_tok = self.proj(rearrange(
self.cond_enc(cond_img),
'b c (h p1) (w p2) -> b (h w) (c p1 p2)', p1=2, p2=2))
ci, ct = img_tok.clone(), txt_tok.clone()
for i, main_blk in enumerate(flux.double_blocks): # enumerate同时获取索引和元素
img_tok, txt_tok = main_blk(img_tok, txt_tok, cond)
if i < len(self.ctrl_blocks):
ci, ct = self.ctrl_blocks[i](ci + cond_tok, ct, cond)
img_tok = img_tok + self.zero_lins[i](ci) # zero-linear注入
return img_tok, txt_tok
class ZeroLinear(nn.Module):
"""零初始化线性投影"""
def __init__(self, d):
super().__init__()
self.lin = nn.Linear(d, d)
nn.init.zeros_(self.lin.weight); nn.init.zeros_(self.lin.bias)
def forward(self, x): return self.lin(x)
只复制前 N 层的原因:前几层捕获低级空间结构(与条件最相关),全复制 12B 模型不现实。
📝 面试考点 FLUX ControlNet 与 SD1.5 ControlNet 主要区别? SD1.5 利用 U-Net encoder-decoder 对称性注入;FLUX 无 decoder,在 DiT block 层级直接用 zero-linear 注入。
6. IP-Adapter:图像提示适配器¶
6.1 解耦交叉注意力¶
IP-Adapter(Ye et al., 2023)让模型以参考图像做 prompt。核心:文本和图像 cross-attention 完全分离(Decoupled Cross-Attention)。
简单拼接: K = concat(K_text, K_img) → softmax中互相竞争, 质量下降
解耦方案: Output = CrossAttn(Q, K_text, V_text) + λ · CrossAttn(Q, K_img, V_img)
↑ 原始文本路径(冻结) ↑ 新增图像路径(可训练)
流程:参考图 → CLIP Image Encoder → 可学习投影层 → image_tokens → 独立 K/V 投影
class DecoupledCrossAttn(nn.Module):
"""解耦交叉注意力:分别处理文本和图像条件"""
def __init__(self, dim=768, heads=12):
super().__init__()
self.heads, self.hd = heads, dim // heads
# 原始文本cross-attention (冻结)
self.q = nn.Linear(dim, dim)
self.k_text = nn.Linear(dim, dim)
self.v_text = nn.Linear(dim, dim)
self.out_text = nn.Linear(dim, dim)
# 新增图像cross-attention (可训练)
self.k_img = nn.Linear(dim, dim)
self.v_img = nn.Linear(dim, dim)
self.out_img = nn.Linear(dim, dim)
self.ip_scale = nn.Parameter(torch.tensor(1.0))
def forward(self, x, text_emb, img_tokens=None):
B, N, D = x.shape
def to_heads(t, s):
return t.view(B, s, self.heads, self.hd).transpose(1, 2)
q = to_heads(self.q(x), N)
# 文本cross-attention
kt = to_heads(self.k_text(text_emb), text_emb.shape[1])
vt = to_heads(self.v_text(text_emb), text_emb.shape[1])
out = self.out_text(
F.scaled_dot_product_attention(q, kt, vt).transpose(1,2).reshape(B,N,D))
# 图像cross-attention(解耦路径)
if img_tokens is not None:
ki = to_heads(self.k_img(img_tokens), img_tokens.shape[1])
vi = to_heads(self.v_img(img_tokens), img_tokens.shape[1])
img_out = self.out_img(
F.scaled_dot_product_attention(q, ki, vi).transpose(1,2).reshape(B,N,D))
out = out + self.ip_scale * img_out
return out
6.2 变体与训练开销¶
| 变体 | 图像编码器 | token数 | 用途 |
|---|---|---|---|
| IP-Adapter | CLIP ViT-H (CLS) | 4 | 全局风格转移 |
| IP-Adapter Plus | CLIP ViT-H (patch tokens) | 16 | 保留更多空间细节 |
| IP-Adapter FaceID | InsightFace + CLIP | 4 | 人脸一致性保持 |
训练:冻结 U-Net + CLIP Image Encoder,仅训练投影层和图像 cross-attn K/V/out,约 22M 参数。
📝 面试考点 1. 解耦 vs 拼接的优势? 拼接导致注意力竞争和模态干扰;解耦各自 K/V 独立投影,
λ灵活控制强度。 2. 训练参数量? 仅约 22M(投影层+图像 K/V/out),原始模型全冻结。
7. T2I-Adapter vs ControlNet vs IP-Adapter 对比¶
| 维度 | T2I-Adapter | ControlNet | IP-Adapter |
|---|---|---|---|
| 条件类型 | 空间(边缘/深度/姿态) | 空间(边缘/深度/姿态) | 语义(参考图像) |
| 注入方式 | 特征直接加法 | zero-conv 桥接 | 解耦 cross-attn |
| 可训练参数 | ~77M(轻量独立网络) | ~361M(encoder副本) | ~22M(投影+K/V) |
| 控制精度 | 中等 | 高 | 低(语义级) |
| 推理显存增加 | 极少 | ~50% | 极少 |
| 可组合性 | ✅ 可叠加 | ✅ MultiControlNet | ✅ 与ControlNet共用 |
选择策略: - 精确空间控制 → ControlNet(利用预训练encoder,精度最高) - 风格/图像参考 → IP-Adapter(语义级控制,参数最少) - 低显存多条件 → T2I-Adapter(轻量独立,即插即用) - 精确控制+风格 → ControlNet + IP-Adapter 组合
📝 面试考点 1. ControlNet 比 T2I-Adapter 更精确的原因? ControlNet 复用预训练 encoder 全部参数,对结构理解更深;T2I-Adapter 用独立小网络从零学。 2. 三者能否同时使用? 可以,注入位置不同。ControlNet(空间控制)+ IP-Adapter(风格参考)是常见组合。
8. 实战:FLUX 推理代码¶
8.1 FLUX.1-dev 文生图¶
import torch
from diffusers import FluxPipeline, BitsAndBytesConfig
# 方式1: NF4量化加载 (~8GB VRAM)
pipe = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
quantization_config=BitsAndBytesConfig(
load_in_4bit=True, bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16),
torch_dtype=torch.bfloat16)
# 方式2: CPU offload (更通用)
# pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev",
# torch_dtype=torch.bfloat16)
# pipe.enable_model_cpu_offload()
image = pipe(
"A serene mountain landscape at sunset, crystal-clear lake, photorealistic, 8K",
height=1024, width=1024,
guidance_scale=3.5, # FLUX推荐低CFG
num_inference_steps=28,
max_sequence_length=512, # T5最大token长度
generator=torch.Generator("cuda").manual_seed(42)
).images[0]
image.save("flux_dev.png")
8.2 FLUX.1-schnell 快速推理¶
pipe = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()
image = pipe(
"A cyberpunk cityscape, neon lights, rain reflections",
height=1024, width=1024,
guidance_scale=0.0, # schnell不需CFG(蒸馏已内化)
num_inference_steps=4, # 仅4步!
generator=torch.Generator("cuda").manual_seed(42)
).images[0]
image.save("flux_schnell.png")
8.3 任意长宽比与 LoRA 加载¶
# 2D RoPE支持任意分辨率
for name, (h, w) in [("横屏16:9", (576, 1024)), ("竖屏9:16", (1024, 576))]:
img = pipe("cinematic landscape", height=h, width=w,
guidance_scale=3.5, num_inference_steps=28).images[0]
img.save(f"flux_{name}.png")
# 加载LoRA微调权重
pipe.load_lora_weights("your-user/flux-lora", weight_name="lora.safetensors")
pipe.fuse_lora(lora_scale=0.8) # 可调节LoRA强度
# 卸载: pipe.unfuse_lora(); pipe.unload_lora_weights()
📝 面试考点 1. FLUX CFG 为何低于 SDXL (3.5 vs 7.5)? T5-XXL 编码器 + 12B 模型容量使条件理解更强,不需高 CFG 补偿,过高反而过饱和。 2. 低显存运行方案? NF4 量化(40GB→8GB)、CPU offload、bfloat16 混合精度。
9. 实战:ControlNet 条件控图¶
9.1 Canny 边缘控图 (SDXL)¶
import torch, cv2, numpy as np
from PIL import Image
from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel
from diffusers.utils import load_image
def get_canny(path, lo=100, hi=200):
"""提取Canny边缘图"""
edges = cv2.Canny(np.array(load_image(path)), lo, hi) # np.array创建NumPy数组
return Image.fromarray(np.stack([edges]*3, axis=-1))
controlnet = ControlNetModel.from_pretrained(
"diffusers/controlnet-canny-sdxl-1.0",
torch_dtype=torch.float16, variant="fp16")
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
controlnet=controlnet, torch_dtype=torch.float16)
pipe.enable_model_cpu_offload()
image = pipe(
"A beautiful garden, oil painting style, masterpiece",
negative_prompt="blurry, low quality",
image=get_canny("reference.jpg"),
controlnet_conditioning_scale=0.7, # 条件强度
guidance_scale=7.5, num_inference_steps=30,
generator=torch.Generator("cuda").manual_seed(42)
).images[0]
image.save("controlnet_canny.png")
9.2 Depth 深度图控图¶
from transformers import DPTForDepthEstimation, DPTImageProcessor
def get_depth(path):
"""使用DPT估计深度图"""
proc = DPTImageProcessor.from_pretrained("Intel/dpt-large")
model = DPTForDepthEstimation.from_pretrained(
"Intel/dpt-large", torch_dtype=torch.float16).to("cuda")
with torch.no_grad():
d = model(**proc(images=load_image(path), return_tensors="pt").to("cuda")
).predicted_depth.squeeze().cpu().numpy() # squeeze压缩维度
d = ((d - d.min()) / (d.max() - d.min()) * 255).astype(np.uint8)
return Image.fromarray(np.stack([d]*3, -1)).resize((1024, 1024))
controlnet = ControlNetModel.from_pretrained(
"diffusers/controlnet-depth-sdxl-1.0", torch_dtype=torch.float16, variant="fp16")
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
controlnet=controlnet, torch_dtype=torch.float16)
pipe.enable_model_cpu_offload()
image = pipe("Futuristic city, cyberpunk, neon", image=get_depth("scene.jpg"),
controlnet_conditioning_scale=0.6, guidance_scale=7.5,
num_inference_steps=30).images[0]
image.save("controlnet_depth.png")
9.3 Pose 人体姿态 & MultiControlNet¶
from controlnet_aux import OpenposeDetector
# 提取人体姿态骨架
pose_img = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")(
load_image("person.jpg"), detect_resolution=1024, image_resolution=1024)
# MultiControlNet: 多条件组合
controlnets = [
ControlNetModel.from_pretrained("diffusers/controlnet-canny-sdxl-1.0",
torch_dtype=torch.float16, variant="fp16"),
ControlNetModel.from_pretrained("diffusers/controlnet-depth-sdxl-1.0",
torch_dtype=torch.float16, variant="fp16")]
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
controlnet=controlnets, torch_dtype=torch.float16)
pipe.enable_model_cpu_offload()
image = pipe("Beautiful forest scene, photorealistic",
image=[canny_image, depth_image], # 多个条件图
controlnet_conditioning_scale=[0.6, 0.4], # 各自权重
guidance_scale=7.5, num_inference_steps=30).images[0]
9.4 FLUX ControlNet¶
from diffusers import FluxControlNetPipeline, FluxControlNetModel
controlnet = FluxControlNetModel.from_pretrained(
"InstantX/FLUX.1-dev-Controlnet-Canny", torch_dtype=torch.bfloat16)
pipe = FluxControlNetPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
controlnet=controlnet, torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()
image = pipe(
"A detailed watercolor painting of an ancient castle",
control_image=canny_image,
controlnet_conditioning_scale=0.6,
guidance_scale=3.5, num_inference_steps=28,
height=1024, width=1024).images[0]
image.save("flux_controlnet.png")
📝 面试考点 1.
controlnet_conditioning_scale作用? 控制条件信号强度:0→纯文本,1→最大遵循条件图。 2. MultiControlNet 融合方式? 各 ControlNet 独立产生控制信号,按各自权重加权后累加到 decoder 对应层。
10. CFG 推导与面试高频题¶
10.1 Classifier-Free Guidance 完整推导¶
📌 基础参见 02-条件生成与引导,此处给出面试所需完整推导。
贝叶斯分解:\(p(x|c) = \frac{p(c|x)p(x)}{p(c)}\)
取 log 两侧:\(\log p(x|c) = \log p(x) + \log p(c|x) - \log p(c)\)
对 \(x\) 求梯度(score function):
用两次模型调用替代分类器:
代入并引入引导强度 \(w\):
在 Rectified Flow 中作用于速度场:
训练时随机以 ~10% 概率将条件替换为空 \(\varnothing\),使模型同时学有/无条件去噪。
10.2 十道面试精选¶
Q1: FLUX DiT 架构相比 U-Net 有哪些核心优势?
- Scaling Laws 可预测:模型越大越好,改进趋势平滑
- MM-DiT 双向融合:vs cross-attention 单向 text→image
- 2D RoPE 分辨率外推:vs 固定位置编码无法泛化
- 复用 LLM 基础设施:FSDP、Flash Attention、Tensor Parallelism
Q2: 请推导 Classifier-Free Guidance 公式。
从 \(p(x|c)\propto p(c|x)p(x)\) 出发,取 log 对 \(x\) 求梯度得 \(\nabla\log p(x|c)=\nabla\log p(x)+\nabla\log p(c|x)\)。用 \(\nabla\log p(c|x)=\nabla\log p(x|c)-\nabla\log p(x)\) 替代分类器梯度,得:
Q3: ControlNet 为什么用 zero-convolution?
桥初始化为零→训练开始控制信号为零→原始模型行为不变→控制渐进注入。副本内部参数继承自预训练(非零),梯度正常反传更新 zero-conv。随机初始化会瞬间破坏预训练质量。
Q4: Rectified Flow 比 DDPM 采样效率高的根本原因?
直线路径 \((1-t)x_0+tx_1\) 使速度场 \(v\approx x_1-x_0\) 近似与 \(t\) 无关,Euler 离散化误差极小。DDPM 的曲线路径需多步精确追踪,否则误差累积。
Q5: IP-Adapter 解耦 cross-attention 优于拼接的原因?
拼接方式:文本/图像 token 共享 softmax,互相竞争注意力权重,导致条件干扰。解耦方式:各自有独立 K/V 投影矩阵,可用 \(\lambda\) 独立控制强度,text prompt 和 image prompt 互不影响。
Q6: SD3/FLUX 的 16 通道 VAE 意义?
4→16ch 增大潜空间信息容量→VAE 重建质量显著提升→减少高频细节损失(文字笔画、发丝)。代价是 Transformer 输入维度增加(patchify 后 64 维),但 12B 模型可承受。
Q7: FLUX 双流→单流混合架构的设计逻辑?
浅层(19 层双流):保留模态独立性,各自 Norm/MLP 适应不同模态分布。深层(38 层单流):拼接共享参数,实现文本語义和图像特征的深度融合。全双流参数过多;全单流过早强制模态统一。
Q8: SD1.5 的 ControlNet 能否直接用于 SDXL?
不能。三个不兼容:①通道数不同(各层 channel 数不匹配);②block 结构不同(attention head 数、block 布局);③条件机制不同(SDXL 有额外微条件)。需重新训练或使用社区已训练的 SDXL ControlNet。
Q9: ODE vs SDE 采样各自优劣?
ODE:确定性、可复现、少步高质量(FLUX 默认)。SDE:有随机性→更多样化,噪声可修正 ODE 小误差。RF 直线路径下 ODE 已足够好,SDE 作为可选项追求多样性。
Q10: 生产环境如何优化 FLUX 推理性能?
- 量化:NF4(40GB→8GB)、FP8(H100 上 2x 加速)
- 蒸馏:用 schnell(4步)替 dev(28步),7x 加速
- 编译:
torch.compile(pipe.transformer, mode="max-autotune")提速 20-40% - 注意力:Flash Attention (SDPA 自动启用)
- 缓存:预计算 T5/CLIP 文本编码复用
- 并行:Tensor Parallelism 分片到多GPU
11. 总结与知识图谱¶
FLUX与新一代架构
├── FLUX.1架构
│ ├── DiT Backbone(纯Transformer, 12B参数)
│ ├── MM-DiT双流Joint Attention(图文双向信息流)
│ ├── 2D RoPE(分辨率外推 + 任意长宽比)
│ ├── 双流→单流混合(19层双流 + 38层单流)
│ └── 双编码器(CLIP-L pooled + T5-XXL序列)
│
├── 架构对比: FLUX vs SD3 vs SDXL
│ ├── U-Net → MM-DiT → 混合MM-DiT+DiT
│ ├── DDPM → Rectified Flow
│ └── 4ch VAE → 16ch VAE
│
├── Rectified Flow
│ ├── 直线插值路径: x_t = (1-t)x_0 + tx_1
│ ├── v-prediction: v = x_1 - x_0
│ ├── ODE/SDE采样
│ └── ReFlow蒸馏 → schnell 1-4步
│
├── 条件控制
│ ├── ControlNet: zero-conv + 预训练复用 → 高精度空间控制
│ ├── ControlNet for FLUX: DiT block级zero-linear注入
│ ├── IP-Adapter: 解耦cross-attn → 图像语义控制
│ └── T2I-Adapter: 轻量特征加法 → 低显存控制
│
└── 工程实战
├── diffusers加载FLUX (量化/offload)
├── ControlNet条件控图 (Canny/Depth/Pose)
├── MultiControlNet多条件组合
└── 生产优化 (量化/编译/缓存/并行)
📚 参考文献¶
- Peebles & Xie (2023). "Scalable Diffusion Models with Transformers" (DiT). ICCV 2023.
- Esser et al. (2024). "Scaling Rectified Flow Transformers for High-Resolution Image Synthesis" (SD3). ICML 2024.
- Black Forest Labs (2024). "FLUX.1 Technical Report".
- Liu et al. (2023). "Flow Straight and Fast: Learning to Generate and Transfer Data with Rectified Flow". ICLR 2023.
- Zhang et al. (2023). "Adding Conditional Control to Text-to-Image Diffusion Models" (ControlNet). ICCV 2023.
- Ye et al. (2023). "IP-Adapter: Text Compatible Image Prompt Adapter for Text-to-Image Diffusion Models".
- Mou et al. (2023). "T2I-Adapter: Learning Adapters to Dig out More Controllable Ability". AAAI 2024.
- Lipman et al. (2023). "Flow Matching for Generative Modeling". ICLR 2023.
- Ho & Salimans (2022). "Classifier-Free Diffusion Guidance". NeurIPS 2021 Workshop.
- Su et al. (2021). "RoFormer: Enhanced Transformer with Rotary Position Embedding".