跳转至

10 - FLUX与新一代架构

学习时间: 6小时 重要性: ⭐⭐⭐⭐⭐ 2024-2025年最前沿的图像生成架构,面试高频考点 前置知识: 07-DiT与Transformer扩散架构 | 08-流匹配与一致性模型 | 06-Stable Diffusion进阶(SDXL到SD3)


🎯 学习目标

完成本章后,你将能够: - 深入理解 FLUX.1 的完整架构(DiT backbone + MM-DiT 双流注意力 + 2D RoPE) - 掌握 FLUX vs SD3 vs SDXL 核心差异与演进逻辑 - 从数学推导到代码实现 Rectified Flow(直线流匹配) - 理解 ControlNet 的 zero-convolution 设计与预训练权重复用 - 掌握 IP-Adapter 解耦交叉注意力原理 - 区分 T2I-Adapter、ControlNet、IP-Adapter 的设计哲学与适用场景 - 使用 diffusers 完成 FLUX 推理与 ControlNet 条件控图


1. FLUX.1 架构全解析

1.1 背景与定位

FLUX.1 由 Black Forest Labs(Stability AI 前核心团队 Robin Rombach 等)于 2024 年 8 月发布,代表了"纯 Transformer + Rectified Flow"范式的工程化巅峰。

版本 参数量 许可证 特点
FLUX.1 [pro] 12B 商业API 最高质量,仅API调用
FLUX.1 [dev] 12B 非商业开源 指导蒸馏后的开发版本
FLUX.1 [schnell] 12B Apache 2.0 1-4步超快推理

💡 FLUX.1 与 SD3 共享核心团队和设计思想(MM-DiT、Rectified Flow、多文本编码器),但 FLUX.1 在架构细节上做了大量改进。

1.2 整体架构鸟瞰

Text Only
文本编码器: CLIP ViT-L/14 (pooled 768d) + T5-XXL (序列 4096d)
    Patchify (2×2)          ▼            条件向量: timestep_embed + pooled_text
潜空间 z_T ──────→ 双流 MM-DiT Blocks (×19)  ←─────────────────────────┘
   + 2D RoPE        img流 / txt流 Joint Attention
                     单流 DiT Blocks (×38)
                     img+txt拼接为统一序列, 共享Self-Attn
                     Unpatchify + VAE Decode → 生成图像 (1024²)

双文本编码器设计

编码器 输出维度 作用
CLIP ViT-L/14 768-dim pooled 全局语义向量,用于时间步嵌入的 AdaLN 条件调制
T5-XXL (encoder-only) 4096-dim × seq_len 丰富的文本序列表示,作为 txt 流输入

去掉 SD3 的第二个 CLIP-bigG,因为 T5-XXL 的序列表示已足够丰富,第二个 CLIP 边际收益有限,去掉可减少显存和延迟。

1.3 MM-DiT 双流注意力

📌 深入展开 07-DiT与Transformer扩散架构 中的 MM-DiT 概念。

传统 Cross-Attention vs MM-DiT Joint Attention

Text Only
传统 Cross-Attention (SD 1.5/SDXL):
  Q = image_tokens × W_q    ← 只有图像做查询
  K = text_tokens  × W_k    ← 文本做键值
  V = text_tokens  × W_v
  → 信息单向: text → image

MM-DiT Joint Attention (SD3/FLUX):
  Q = concat(img_tokens × W_q^img, txt_tokens × W_q^txt)
  K = concat(img_tokens × W_k^img, txt_tokens × W_k^txt)
  V = concat(img_tokens × W_v^img, txt_tokens × W_v^txt)
  Attn = softmax(QK^T/√d) V
  → 信息双向: text ↔ image

双流设计:图像和文本各自维护独立的 LayerNorm/MLP 参数,但在注意力计算时拼接做 Joint Attention,实现双向信息交互。

Python
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
import math

class MMDiTBlock(nn.Module):  # 继承nn.Module定义网络层
    """MM-DiT双流注意力块

    双流设计:img/txt各自独立Norm/MLP参数,
    但注意力时拼接QKV实现双向信息交互。
    """
    def __init__(self, dim=3072, heads=24, mlp_ratio=4.0):
        super().__init__()  # super()调用父类方法
        self.heads = heads
        self.head_dim = dim // heads

        # ====== 图像流参数 ======
        self.img_norm1 = nn.LayerNorm(dim, elementwise_affine=False)
        self.img_qkv = nn.Linear(dim, dim * 3)
        self.img_out = nn.Linear(dim, dim)
        self.img_norm2 = nn.LayerNorm(dim, elementwise_affine=False)
        self.img_mlp = nn.Sequential(
            nn.Linear(dim, int(dim * mlp_ratio)),
            nn.GELU(approximate="tanh"),
            nn.Linear(int(dim * mlp_ratio), dim))

        # ====== 文本流参数(结构相同,权重独立)======
        self.txt_norm1 = nn.LayerNorm(dim, elementwise_affine=False)
        self.txt_qkv = nn.Linear(dim, dim * 3)
        self.txt_out = nn.Linear(dim, dim)
        self.txt_norm2 = nn.LayerNorm(dim, elementwise_affine=False)
        self.txt_mlp = nn.Sequential(
            nn.Linear(dim, int(dim * mlp_ratio)),
            nn.GELU(approximate="tanh"),
            nn.Linear(int(dim * mlp_ratio), dim))

        # AdaLN-Zero: 12个调制参数
        # img: (shift_a, scale_a, gate_a, shift_m, scale_m, gate_m) × 1
        # txt: 同上 × 1
        self.adaLN = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 12))

    def _modulate(self, x, shift, scale):
        return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)  # unsqueeze增加一个维度

    def forward(self, img, txt, cond):
        """
        img: [B, N_img, D]  图像patch序列
        txt: [B, N_txt, D]  文本token序列
        cond: [B, D]        时间步+pooled文本条件
        """
        # chunk(12)将条件向量切分为12个调制参数:每对(shift,scale,gate)控制norm/attn/mlp,图像和文本各6个
        m = self.adaLN(cond).chunk(12, dim=-1)

        # ---- Joint Attention ----
        img_n = self._modulate(self.img_norm1(img), m[0], m[1])
        txt_n = self._modulate(self.txt_norm1(txt), m[6], m[7])
        iq, ik, iv = self.img_qkv(img_n).chunk(3, dim=-1)
        tq, tk, tv = self.txt_qkv(txt_n).chunk(3, dim=-1)

        # 拼接Q/K/V做联合注意力
        q = torch.cat([iq, tq], dim=1)  # torch.cat沿已有维度拼接张量
        k = torch.cat([ik, tk], dim=1)
        v = torch.cat([iv, tv], dim=1)

        q = rearrange(q, 'b n (h d) -> b h n d', h=self.heads)
        k = rearrange(k, 'b n (h d) -> b h n d', h=self.heads)
        v = rearrange(v, 'b n (h d) -> b h n d', h=self.heads)
        attn = F.scaled_dot_product_attention(q, k, v)  # Flash Attention  # F.xxx PyTorch函数式API
        attn = rearrange(attn, 'b h n d -> b n (h d)')

        # 拆分回img和txt,各自output projection + gate
        N = img.shape[1]
        img = img + m[2].unsqueeze(1) * self.img_out(attn[:, :N])
        txt = txt + m[8].unsqueeze(1) * self.txt_out(attn[:, N:])

        # ---- MLP (各自独立) ----
        img = img + m[5].unsqueeze(1) * self.img_mlp(
            self._modulate(self.img_norm2(img), m[3], m[4]))
        txt = txt + m[11].unsqueeze(1) * self.txt_mlp(
            self._modulate(self.txt_norm2(txt), m[9], m[10]))
        return img, txt

1.4 2D Rotary Position Encoding

FLUX 用 2D RoPE 替代 SD3 的固定正弦编码——这是 LLM 社区 RoPE(Su et al., 2021)在二维空间的推广。

将 head_dim 分两半,分别编码行与列位置:

\[q' = [R_{\theta,h} \cdot q_{:d/2},\; R_{\theta,w} \cdot q_{d/2:}]$$ $$k' = [R_{\theta,h} \cdot k_{:d/2},\; R_{\theta,w} \cdot k_{d/2:}]\]

注意力分数自然编码二维空间距离:

\[q'^T k' = q_{:d/2}^T R_{\theta,h-h'} k_{:d/2} + q_{d/2:}^T R_{\theta,w-w'} k_{d/2:}\]

核心优势: - 分辨率外推:RoPE 是相对位置编码,可泛化到训练时未见的分辨率 - 任意长宽比:不依赖固定 grid size - 无需学习:位置编码是确定性函数,零额外参数

Python
import torch

def get_2d_rope_freqs(height, width, dim, theta=10000.0):
    """生成2D RoPE复数频率矩阵

    Returns: [H*W, dim//2] 复数频率
    """
    quarter_dim = dim // 4
    freqs = 1.0 / (theta ** (torch.arange(0, quarter_dim).float() / quarter_dim))

    # 行频率 → 前半维度
    h_freqs = torch.outer(torch.arange(height).float(), freqs)  # [H, qd]
    h_freqs = h_freqs.unsqueeze(1).expand(-1, width, -1).reshape(-1, quarter_dim)  # 重塑张量形状

    # 列频率 → 后半维度
    w_freqs = torch.outer(torch.arange(width).float(), freqs)  # [W, qd]
    w_freqs = w_freqs.unsqueeze(0).expand(height, -1, -1).reshape(-1, quarter_dim)

    freqs_2d = torch.cat([h_freqs, w_freqs], dim=-1)  # [H*W, dim//2]
    return torch.polar(torch.ones_like(freqs_2d), freqs_2d)

def apply_2d_rope(x, freqs):
    """x: [B, heads, N, head_dim] → 旋转后同形状"""
    # reshape将最后一维两两配对为(实部,虚部),view_as_complex转为复数,乘以频率复数实现旋转位置编码
    xc = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    rotated = xc * freqs.unsqueeze(0).unsqueeze(0)
    return torch.view_as_real(rotated).reshape_as(x).to(x.dtype)

# 示例:1024×1024图像, patch=2, VAE下采样8x → 64×64 patches
freqs = get_2d_rope_freqs(height=64, width=64, dim=128)
print(f"2D RoPE freqs: {freqs.shape}")  # [4096, 64]

1.5 双流→单流混合架构

阶段 类型 层数 特点
第一阶段 双流 MM-DiT 19层 img/txt 各自独立 Norm/MLP,Joint Attention
第二阶段 单流 DiT 38层 img+txt 拼接为统一序列,完全共享参数

设计哲学: - 全双流:每层两套参数,参数量大、效率低 - 全单流:浅层就共享参数,模态对齐不充分 - 混合策略:前期保留模态独立性,后期深度融合,效率与效果最优平衡

Python
class SingleDiTBlock(nn.Module):
    """单流DiT块(第二阶段,img+txt共享全部参数)"""
    def __init__(self, dim=3072, heads=24, mlp_ratio=4.0):
        super().__init__()
        self.heads = heads
        self.norm1 = nn.LayerNorm(dim, elementwise_affine=False)
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)
        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False)
        self.mlp = nn.Sequential(
            nn.Linear(dim, int(dim * mlp_ratio)),
            nn.GELU(approximate="tanh"),
            nn.Linear(int(dim * mlp_ratio), dim))
        self.adaLN = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))

    def forward(self, x, cond):
        """x: [B, N_img+N_txt, D] 拼接后的统一序列"""
        s_a, sc_a, g_a, s_m, sc_m, g_m = self.adaLN(cond).chunk(6, dim=-1)
        xn = self.norm1(x) * (1 + sc_a.unsqueeze(1)) + s_a.unsqueeze(1)
        # 列表推导:qkv线性层输出chunk(3)拆为Q/K/V,每个用rearrange从(B,N,heads*d)重排为(B,heads,N,d)多头格式
        q, k, v = [rearrange(t, 'b n (h d) -> b h n d', h=self.heads)
                    for t in self.qkv(xn).chunk(3, dim=-1)]
        attn = rearrange(F.scaled_dot_product_attention(q, k, v), 'b h n d -> b n (h d)')
        x = x + g_a.unsqueeze(1) * self.proj(attn)
        xn2 = self.norm2(x) * (1 + sc_m.unsqueeze(1)) + s_m.unsqueeze(1)
        x = x + g_m.unsqueeze(1) * self.mlp(xn2)
        return x

def timestep_embedding(t, dim, max_period=10000):
    """正弦时间步嵌入"""
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(half, device=t.device) / half)
    args = t[:, None].float() * freqs[None]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

📝 面试考点 1. MM-DiT 与 cross-attention 的本质区别? cross-attention 信息单向 text→image;MM-DiT Joint Attention 拼接 Q/K/V 实现双向信息流 text↔image。 2. 为什么双流→单流而非全双流? 全双流每层两套参数效率低;混合设计前期保留模态特异性,后期共享参数深度融合。 3. 2D RoPE 相比固定位置编码的优势? 相对位置编码,支持分辨率外推和任意长宽比,无需学习参数。


2. FLUX vs SD3 vs SDXL 架构对比

2.1 核心架构对比表

维度 SDXL (2023.07) SD3 (2024.03) FLUX.1 (2024.08)
骨干网络 U-Net (3.5B) MM-DiT (2B/8B) 混合MM-DiT+DiT (12B)
噪声调度 DDPM (离散) Rectified Flow (连续) Rectified Flow (连续)
文本编码器 CLIP-L + CLIP-bigG CLIP-L + CLIP-bigG + T5-XXL CLIP-L + T5-XXL
位置编码 固定正弦 1D 固定正弦 2D 2D RoPE
注意力 Self + Cross-Attn MM-DiT Joint Attn 双流MM-DiT + 单流Self
VAE通道 4ch 16ch 16ch
训练目标 ε-prediction v-prediction (Flow) v-prediction (RF,预测速度 \(v_t = x_1 - x_0\))
最快推理 ~25步 ~28步 1-4步 (schnell)
推荐CFG 7.0-7.5 7.0 3.5
文字渲染 优秀
VRAM (fp16) ~6.5GB ~8GB ~24GB (NF4→~8GB)

2.2 关键升级解读

16通道 VAE(SD3/FLUX 共享):

Text Only
SD 1.5 / SDXL:  VAE latent = 4 channels  → 1024×1024 → 128×128×4
SD3 / FLUX.1:   VAE latent = 16 channels → 1024×1024 → 128×128×16

更丰富的潜空间表示→VAE 重建质量更高→减少高频细节损失(文字笔画、纤细纹理)。代价是 Transformer 输入维度从 4 增至 16(patchify 后 \(16\times2\times2=64\))。

FLUX schnell 无需 CFG:通过指导蒸馏将 CFG 引导效果内化到模型权重,一次前向 = 教师的有条件+无条件两次前向。

2.3 推理配置典型值

Python
configs = {
    "SDXL":         {"steps": 25, "cfg": 7.0,  "scheduler": "DPM++ 2M Karras"},
    "SD3-Medium":   {"steps": 28, "cfg": 7.0,  "scheduler": "Euler + logit-normal"},
    "FLUX.1-dev":   {"steps": 28, "cfg": 3.5,  "scheduler": "Euler"},
    "FLUX.1-schnell": {"steps": 4, "cfg": 0.0, "scheduler": "Euler (无CFG)"},
}

📝 面试考点 1. 16通道 VAE 的目的? 提升潜空间信息容量,减少高频细节损失,改善文字渲染。 2. FLUX 为何只用 2 个文本编码器? T5-XXL 已涵盖 CLIP-bigG 的语义贡献,去掉减少显存。 3. FLUX schnell 为何不需 CFG? 指导蒸馏已将 CFG 引导效果内化到权重。


3. Rectified Flow:直线流匹配

📌 本节深化 08-流匹配与一致性模型 的流匹配理论。

3.1 核心原理

Rectified Flow(Liu et al., 2023, "Flow Straight and Fast")的核心思想:让概率传输路径为直线

线性插值路径

\[x_t = (1-t) \cdot x_0 + t \cdot x_1, \quad t \in [0,1], \; x_0 \sim \mathcal{N}(0,I), \; x_1 \sim p_\text{data}\]

目标速度场\(u_t(x_t|x_0,x_1) = x_1 - x_0\)(与 \(t\) 无关——正是"直线"的含义)

训练损失

\[\mathcal{L}_\text{RF} = \mathbb{E}_{t \sim \mathcal{U}(0,1),\, x_0,x_1}\left[\|v_\theta(x_t, t) - (x_1 - x_0)\|^2\right]\]

3.2 ODE vs SDE 采样

ODE(确定性)\(x_{t+\Delta t} = x_t + \Delta t \cdot v_\theta(x_t, t)\) 直线路径→理想 1 步 Euler 即可,实际 4-50 步。

SDE(随机性)\(dx_t = v_\theta dt + \sigma_t dW_t\) 少量噪声可修正 ODE 误差累积,提升多样性。

3.3 与 DDPM 对比

维度 DDPM Rectified Flow
路径类型 曲线 (\(\sqrt{\bar\alpha_t}x_0 + \sqrt{1-\bar\alpha_t}\epsilon\)) 直线 (\((1-t)x_0 + tx_1\))
预测目标 噪声 \(\epsilon\) 速度 \(v = x_1 - x_0\)
采样步数 通常 20-1000 通常 1-50
少步质量 差(需DDIM/DPM++加速) 好(路径本身近似直线)
数学框架 离散马尔可夫链 / SDE 连续 ODE (CNF)

3.4 ReFlow 蒸馏

用当前模型 \(v_{\theta_1}\) 做 ODE 推理得新 \((x_0, \hat{x}_1)\) 对,重新训练 \(v_{\theta_2}\)。新 pair 间 ODE 路径更直→\(v_{\theta_2}\) 需更少步数。迭代可进一步拉直。FLUX schnell 即用类似策略实现 1-4 步。

3.5 完整训练与采样代码

Python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

class VelocityNet(nn.Module):
    """简化的速度场网络(用于小图像验证)"""
    def __init__(self, ch=1, size=28, hidden=256):
        super().__init__()
        d = ch * size * size
        self.net = nn.Sequential(
            nn.Linear(d + 1, hidden), nn.SiLU(),
            nn.Linear(hidden, hidden), nn.SiLU(),
            nn.Linear(hidden, hidden), nn.SiLU(),
            nn.Linear(hidden, d))
        self.shape = (ch, size, size)

    def forward(self, x_t, t):
        B = x_t.shape[0]
        x_flat = x_t.reshape(B, -1)
        return self.net(torch.cat([x_flat, t.reshape(B,1)], -1)).reshape(B, *self.shape)

class RectifiedFlowTrainer:
    """Rectified Flow 训练与采样"""
    def __init__(self, model, lr=2e-4, device="cuda"):
        self.model = model.to(device)  # 移至GPU/CPU
        self.opt = torch.optim.AdamW(model.parameters(), lr=lr)
        self.device = device

    def train_step(self, x_1):
        """x_1: [B,C,H,W] 真实数据(对应t=1)"""
        B = x_1.shape[0]
        x_0 = torch.randn_like(x_1)                          # 噪声
        t = torch.rand(B, 1, device=self.device)
        t_e = t.unsqueeze(-1).unsqueeze(-1)                   # [B,1,1,1]
        x_t = (1 - t_e) * x_0 + t_e * x_1                   # 线性插值
        target = x_1 - x_0                                     # 目标速度
        loss = F.mse_loss(self.model(x_t, t), target)
        self.opt.zero_grad(); loss.backward(); self.opt.step()  # 清零梯度  # 反向传播计算梯度
        return loss.item()  # 将单元素张量转为Python数值

    @torch.no_grad()  # 禁用梯度计算,节省内存
    def sample_ode(self, batch_size, num_steps=20):
        """ODE Euler采样(确定性)"""
        self.model.eval()  # eval()评估模式
        x = torch.randn(batch_size, *self.model.shape, device=self.device)
        dt = 1.0 / num_steps
        for i in range(num_steps):
            t = torch.full((batch_size, 1), i * dt, device=self.device)
            x = x + dt * self.model(x, t)
        self.model.train()
        return x.clamp(-1, 1)

    @torch.no_grad()
    def sample_sde(self, batch_size, num_steps=20, sigma=0.01):
        """SDE采样(带随机噪声修正)"""
        self.model.eval()
        x = torch.randn(batch_size, *self.model.shape, device=self.device)
        dt = 1.0 / num_steps
        for i in range(num_steps):
            t = torch.full((batch_size, 1), i * dt, device=self.device)
            noise = torch.randn_like(x) * sigma * dt**0.5 if i < num_steps - 1 else 0
            x = x + dt * self.model(x, t) + noise
        self.model.train()
        return x.clamp(-1, 1)

def train_rectified_flow():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    transform = transforms.Compose([
        transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
    dl = DataLoader(datasets.MNIST("./data", train=True, download=True,  # DataLoader批量加载数据
        transform=transform), batch_size=128, shuffle=True)

    trainer = RectifiedFlowTrainer(VelocityNet(), device=device)
    for epoch in range(10):
        avg_loss = sum(trainer.train_step(x.to(device)) for x,_ in dl) / len(dl)
        print(f"Epoch {epoch}: loss={avg_loss:.4f}")
        if epoch % 5 == 0:
            samples = trainer.sample_ode(16, num_steps=10)
            print(f"  ODE 10步: range [{samples.min():.2f}, {samples.max():.2f}]")
            samples_sde = trainer.sample_sde(16, num_steps=10, sigma=0.05)
            print(f"  SDE 10步: range [{samples_sde.min():.2f}, {samples_sde.max():.2f}]")

📝 面试考点 1. RF 比 DDPM 高效的核心原因? 直线路径→Euler 离散化误差极小→少步即可高质量。 2. v-prediction 与 ε-prediction 的关系? \(v = x_1 - x_0 = x_1 - \epsilon\),二者可互转。 3. ReFlow 蒸馏如何拉直路径? 用当前模型 ODE 推理产生新 pair 重训,新 pair 间路径更直。


4. ControlNet 原理与实现

4.1 核心设计哲学

ControlNet(Zhang et al., 2023)教扩散模型遵循空间条件(边缘/深度/姿态),三大原则:

  1. 复用预训练权重:深拷贝 U-Net encoder blocks 作为可训练副本
  2. 条件编码器:将条件图编码为与 latent 对齐的特征
  3. zero-convolution 安全连接:确保训练开始时不破坏原始模型

4.2 Zero-Convolution 详解

权重和偏置全零初始化的 1×1 卷积:\(y = Wx + b,\; W=0,\; b=0\)

初始化方式 训练初期行为 风险
随机初始化 随机信号注入原始模型 ❌ 瞬间破坏预训练质量
零初始化全连接 梯度全零,对称性问题 ❌ 训练僵死
Zero-Conv 输出为零,原始模型不受影响 ✅ 安全起步

关键洞察:zero-conv 只是 ControlNet→主模型的"桥"初始化为零。副本内部参数继承自预训练模型(非零),梯度正常反传,进而更新 zero-conv 参数。

Python
import torch, torch.nn as nn, copy

class ZeroConv(nn.Module):
    """零初始化1×1卷积"""
    def __init__(self, ch):
        super().__init__()
        self.conv = nn.Conv2d(ch, ch, kernel_size=1)
        nn.init.zeros_(self.conv.weight)
        nn.init.zeros_(self.conv.bias)
    def forward(self, x):
        return self.conv(x)

class ControlNet(nn.Module):
    """ControlNet for U-Net based models (SD1.5/SDXL, 简化示意)"""
    def __init__(self, base_unet, cond_channels=3):
        super().__init__()
        # 条件编码器:将条件图(如Canny边缘)编码为latent尺寸特征
        self.cond_encoder = nn.Sequential(
            nn.Conv2d(cond_channels, 64, 3, stride=2, padding=1), nn.SiLU(),
            nn.Conv2d(64, 128, 3, stride=2, padding=1), nn.SiLU(),
            nn.Conv2d(128, 256, 3, stride=2, padding=1), nn.SiLU(),
            nn.Conv2d(256, 320, 3, stride=1, padding=1))

        # 复制encoder各层 + zero-conv连接
        encoder_channels = [320, 640, 1280, 1280]
        self.ctrl_blocks = nn.ModuleList()
        self.zero_convs = nn.ModuleList()
        for ch in encoder_channels:
            self.ctrl_blocks.append(nn.Sequential(
                nn.Conv2d(ch, ch, 3, padding=1),
                nn.GroupNorm(32, ch), nn.SiLU()))
            self.zero_convs.append(ZeroConv(ch))

        # 冻结原始U-Net全部参数
        for p in base_unet.parameters():
            p.requires_grad = False

    def forward(self, cond_image):
        """返回各层控制信号(加到decoder对应层)"""
        h = self.cond_encoder(cond_image)
        signals = []
        for block, zc in zip(self.ctrl_blocks, self.zero_convs):  # zip按位置配对
            h = block(h)
            signals.append(zc(h))
        return signals

4.3 训练配置

Python
# 关键训练细节:
# 1. 冻结原始U-Net全部参数
# 2. 只训练:条件编码器 + ControlNet副本 + zero-convolution
# 3. 学习率1e-5(较小,从预训练权重开始)
# 4. 数据:成对的(图像, 条件图),如(照片, Canny边缘)
# 5. 数据量仅需原始模型训练量的5-10%

📝 面试考点 1. zero-conv 的设计动因? 桥为零→训练开始不破坏预训练→控制信号渐进注入。副本内部预训练参数保证梯度正常。 2. 训练哪些参数? 条件编码器 + encoder副本 + zero-conv;原始U-Net全冻结。 3. 如何保护预训练质量? zero-conv 初始为零→原始模型初始行为不变→控制渐进式注入。


5. ControlNet for FLUX/SD3 适配

5.1 挑战

FLUX/SD3 是纯 Transformer 无 encoder-decoder 结构,传统 ControlNet 的 encoder→decoder 注入方式不适用。

挑战 U-Net ControlNet DiT ControlNet
注入点 encoder block → decoder block 无decoder,需在DiT block层级注入
条件格式 空间特征图 需patchify为token序列
多模态 无需考虑 需处理MM-DiT双流结构

5.2 方案:复制前 N 层 + zero-linear

Python
class FLUXControlNet(nn.Module):
    """FLUX ControlNet适配(简化版)"""
    def __init__(self, flux_model, n_ctrl=8, dim=3072, cond_ch=3):
        super().__init__()
        # 条件编码器:图像条件 → token序列
        self.cond_enc = nn.Sequential(
            nn.Conv2d(cond_ch, 64, 3, 2, 1), nn.SiLU(),
            nn.Conv2d(64, 128, 3, 2, 1), nn.SiLU(),
            nn.Conv2d(128, 256, 3, 2, 1), nn.SiLU(),
            nn.Conv2d(256, dim // 4, 1))
        self.proj = nn.Linear(dim // 4 * 4, dim)  # patchify后投影

        # 复制前N个双流MM-DiT块
        self.ctrl_blocks = nn.ModuleList([
            copy.deepcopy(flux_model.double_blocks[i]) for i in range(n_ctrl)])
        self.zero_lins = nn.ModuleList([
            ZeroLinear(dim) for _ in range(n_ctrl)])

        for p in flux_model.parameters():
            p.requires_grad = False

    def forward(self, img_tok, txt_tok, cond, cond_img, flux):
        cond_tok = self.proj(rearrange(
            self.cond_enc(cond_img),
            'b c (h p1) (w p2) -> b (h w) (c p1 p2)', p1=2, p2=2))
        ci, ct = img_tok.clone(), txt_tok.clone()
        for i, main_blk in enumerate(flux.double_blocks):  # enumerate同时获取索引和元素
            img_tok, txt_tok = main_blk(img_tok, txt_tok, cond)
            if i < len(self.ctrl_blocks):
                ci, ct = self.ctrl_blocks[i](ci + cond_tok, ct, cond)
                img_tok = img_tok + self.zero_lins[i](ci)  # zero-linear注入
        return img_tok, txt_tok

class ZeroLinear(nn.Module):
    """零初始化线性投影"""
    def __init__(self, d):
        super().__init__()
        self.lin = nn.Linear(d, d)
        nn.init.zeros_(self.lin.weight); nn.init.zeros_(self.lin.bias)
    def forward(self, x): return self.lin(x)

只复制前 N 层的原因:前几层捕获低级空间结构(与条件最相关),全复制 12B 模型不现实。

📝 面试考点 FLUX ControlNet 与 SD1.5 ControlNet 主要区别? SD1.5 利用 U-Net encoder-decoder 对称性注入;FLUX 无 decoder,在 DiT block 层级直接用 zero-linear 注入。


6. IP-Adapter:图像提示适配器

6.1 解耦交叉注意力

IP-Adapter(Ye et al., 2023)让模型以参考图像做 prompt。核心:文本和图像 cross-attention 完全分离(Decoupled Cross-Attention)。

Text Only
简单拼接: K = concat(K_text, K_img) → softmax中互相竞争, 质量下降
解耦方案: Output = CrossAttn(Q, K_text, V_text) + λ · CrossAttn(Q, K_img, V_img)
                   ↑ 原始文本路径(冻结)         ↑ 新增图像路径(可训练)

流程:参考图 → CLIP Image Encoder → 可学习投影层 → image_tokens → 独立 K/V 投影

Python
class DecoupledCrossAttn(nn.Module):
    """解耦交叉注意力:分别处理文本和图像条件"""
    def __init__(self, dim=768, heads=12):
        super().__init__()
        self.heads, self.hd = heads, dim // heads
        # 原始文本cross-attention (冻结)
        self.q = nn.Linear(dim, dim)
        self.k_text = nn.Linear(dim, dim)
        self.v_text = nn.Linear(dim, dim)
        self.out_text = nn.Linear(dim, dim)
        # 新增图像cross-attention (可训练)
        self.k_img = nn.Linear(dim, dim)
        self.v_img = nn.Linear(dim, dim)
        self.out_img = nn.Linear(dim, dim)
        self.ip_scale = nn.Parameter(torch.tensor(1.0))

    def forward(self, x, text_emb, img_tokens=None):
        B, N, D = x.shape
        def to_heads(t, s):
            return t.view(B, s, self.heads, self.hd).transpose(1, 2)
        q = to_heads(self.q(x), N)
        # 文本cross-attention
        kt = to_heads(self.k_text(text_emb), text_emb.shape[1])
        vt = to_heads(self.v_text(text_emb), text_emb.shape[1])
        out = self.out_text(
            F.scaled_dot_product_attention(q, kt, vt).transpose(1,2).reshape(B,N,D))
        # 图像cross-attention(解耦路径)
        if img_tokens is not None:
            ki = to_heads(self.k_img(img_tokens), img_tokens.shape[1])
            vi = to_heads(self.v_img(img_tokens), img_tokens.shape[1])
            img_out = self.out_img(
                F.scaled_dot_product_attention(q, ki, vi).transpose(1,2).reshape(B,N,D))
            out = out + self.ip_scale * img_out
        return out

6.2 变体与训练开销

变体 图像编码器 token数 用途
IP-Adapter CLIP ViT-H (CLS) 4 全局风格转移
IP-Adapter Plus CLIP ViT-H (patch tokens) 16 保留更多空间细节
IP-Adapter FaceID InsightFace + CLIP 4 人脸一致性保持

训练:冻结 U-Net + CLIP Image Encoder,仅训练投影层和图像 cross-attn K/V/out,约 22M 参数

📝 面试考点 1. 解耦 vs 拼接的优势? 拼接导致注意力竞争和模态干扰;解耦各自 K/V 独立投影,λ 灵活控制强度。 2. 训练参数量? 仅约 22M(投影层+图像 K/V/out),原始模型全冻结。


7. T2I-Adapter vs ControlNet vs IP-Adapter 对比

维度 T2I-Adapter ControlNet IP-Adapter
条件类型 空间(边缘/深度/姿态) 空间(边缘/深度/姿态) 语义(参考图像)
注入方式 特征直接加法 zero-conv 桥接 解耦 cross-attn
可训练参数 ~77M(轻量独立网络) ~361M(encoder副本) ~22M(投影+K/V)
控制精度 中等 低(语义级)
推理显存增加 极少 ~50% 极少
可组合性 ✅ 可叠加 ✅ MultiControlNet ✅ 与ControlNet共用

选择策略: - 精确空间控制 → ControlNet(利用预训练encoder,精度最高) - 风格/图像参考 → IP-Adapter(语义级控制,参数最少) - 低显存多条件 → T2I-Adapter(轻量独立,即插即用) - 精确控制+风格 → ControlNet + IP-Adapter 组合

📝 面试考点 1. ControlNet 比 T2I-Adapter 更精确的原因? ControlNet 复用预训练 encoder 全部参数,对结构理解更深;T2I-Adapter 用独立小网络从零学。 2. 三者能否同时使用? 可以,注入位置不同。ControlNet(空间控制)+ IP-Adapter(风格参考)是常见组合。


8. 实战:FLUX 推理代码

8.1 FLUX.1-dev 文生图

Python
import torch
from diffusers import FluxPipeline, BitsAndBytesConfig

# 方式1: NF4量化加载 (~8GB VRAM)
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True, bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16)

# 方式2: CPU offload (更通用)
# pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev",
#     torch_dtype=torch.bfloat16)
# pipe.enable_model_cpu_offload()

image = pipe(
    "A serene mountain landscape at sunset, crystal-clear lake, photorealistic, 8K",
    height=1024, width=1024,
    guidance_scale=3.5,       # FLUX推荐低CFG
    num_inference_steps=28,
    max_sequence_length=512,  # T5最大token长度
    generator=torch.Generator("cuda").manual_seed(42)
).images[0]
image.save("flux_dev.png")

8.2 FLUX.1-schnell 快速推理

Python
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()

image = pipe(
    "A cyberpunk cityscape, neon lights, rain reflections",
    height=1024, width=1024,
    guidance_scale=0.0,    # schnell不需CFG(蒸馏已内化)
    num_inference_steps=4,  # 仅4步!
    generator=torch.Generator("cuda").manual_seed(42)
).images[0]
image.save("flux_schnell.png")

8.3 任意长宽比与 LoRA 加载

Python
# 2D RoPE支持任意分辨率
for name, (h, w) in [("横屏16:9", (576, 1024)), ("竖屏9:16", (1024, 576))]:
    img = pipe("cinematic landscape", height=h, width=w,
               guidance_scale=3.5, num_inference_steps=28).images[0]
    img.save(f"flux_{name}.png")

# 加载LoRA微调权重
pipe.load_lora_weights("your-user/flux-lora", weight_name="lora.safetensors")
pipe.fuse_lora(lora_scale=0.8)  # 可调节LoRA强度
# 卸载: pipe.unfuse_lora(); pipe.unload_lora_weights()

📝 面试考点 1. FLUX CFG 为何低于 SDXL (3.5 vs 7.5)? T5-XXL 编码器 + 12B 模型容量使条件理解更强,不需高 CFG 补偿,过高反而过饱和。 2. 低显存运行方案? NF4 量化(40GB→8GB)、CPU offload、bfloat16 混合精度。


9. 实战:ControlNet 条件控图

9.1 Canny 边缘控图 (SDXL)

Python
import torch, cv2, numpy as np
from PIL import Image
from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel
from diffusers.utils import load_image

def get_canny(path, lo=100, hi=200):
    """提取Canny边缘图"""
    edges = cv2.Canny(np.array(load_image(path)), lo, hi)  # np.array创建NumPy数组
    return Image.fromarray(np.stack([edges]*3, axis=-1))

controlnet = ControlNetModel.from_pretrained(
    "diffusers/controlnet-canny-sdxl-1.0",
    torch_dtype=torch.float16, variant="fp16")
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnet, torch_dtype=torch.float16)
pipe.enable_model_cpu_offload()

image = pipe(
    "A beautiful garden, oil painting style, masterpiece",
    negative_prompt="blurry, low quality",
    image=get_canny("reference.jpg"),
    controlnet_conditioning_scale=0.7,  # 条件强度
    guidance_scale=7.5, num_inference_steps=30,
    generator=torch.Generator("cuda").manual_seed(42)
).images[0]
image.save("controlnet_canny.png")

9.2 Depth 深度图控图

Python
from transformers import DPTForDepthEstimation, DPTImageProcessor

def get_depth(path):
    """使用DPT估计深度图"""
    proc = DPTImageProcessor.from_pretrained("Intel/dpt-large")
    model = DPTForDepthEstimation.from_pretrained(
        "Intel/dpt-large", torch_dtype=torch.float16).to("cuda")
    with torch.no_grad():
        d = model(**proc(images=load_image(path), return_tensors="pt").to("cuda")
            ).predicted_depth.squeeze().cpu().numpy()  # squeeze压缩维度
    d = ((d - d.min()) / (d.max() - d.min()) * 255).astype(np.uint8)
    return Image.fromarray(np.stack([d]*3, -1)).resize((1024, 1024))

controlnet = ControlNetModel.from_pretrained(
    "diffusers/controlnet-depth-sdxl-1.0", torch_dtype=torch.float16, variant="fp16")
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnet, torch_dtype=torch.float16)
pipe.enable_model_cpu_offload()

image = pipe("Futuristic city, cyberpunk, neon", image=get_depth("scene.jpg"),
    controlnet_conditioning_scale=0.6, guidance_scale=7.5,
    num_inference_steps=30).images[0]
image.save("controlnet_depth.png")

9.3 Pose 人体姿态 & MultiControlNet

Python
from controlnet_aux import OpenposeDetector

# 提取人体姿态骨架
pose_img = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")(
    load_image("person.jpg"), detect_resolution=1024, image_resolution=1024)

# MultiControlNet: 多条件组合
controlnets = [
    ControlNetModel.from_pretrained("diffusers/controlnet-canny-sdxl-1.0",
        torch_dtype=torch.float16, variant="fp16"),
    ControlNetModel.from_pretrained("diffusers/controlnet-depth-sdxl-1.0",
        torch_dtype=torch.float16, variant="fp16")]

pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnets, torch_dtype=torch.float16)
pipe.enable_model_cpu_offload()

image = pipe("Beautiful forest scene, photorealistic",
    image=[canny_image, depth_image],          # 多个条件图
    controlnet_conditioning_scale=[0.6, 0.4],  # 各自权重
    guidance_scale=7.5, num_inference_steps=30).images[0]

9.4 FLUX ControlNet

Python
from diffusers import FluxControlNetPipeline, FluxControlNetModel

controlnet = FluxControlNetModel.from_pretrained(
    "InstantX/FLUX.1-dev-Controlnet-Canny", torch_dtype=torch.bfloat16)
pipe = FluxControlNetPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    controlnet=controlnet, torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()

image = pipe(
    "A detailed watercolor painting of an ancient castle",
    control_image=canny_image,
    controlnet_conditioning_scale=0.6,
    guidance_scale=3.5, num_inference_steps=28,
    height=1024, width=1024).images[0]
image.save("flux_controlnet.png")

📝 面试考点 1. controlnet_conditioning_scale 作用? 控制条件信号强度:0→纯文本,1→最大遵循条件图。 2. MultiControlNet 融合方式? 各 ControlNet 独立产生控制信号,按各自权重加权后累加到 decoder 对应层。


10. CFG 推导与面试高频题

10.1 Classifier-Free Guidance 完整推导

📌 基础参见 02-条件生成与引导,此处给出面试所需完整推导。

贝叶斯分解\(p(x|c) = \frac{p(c|x)p(x)}{p(c)}\)

取 log 两侧:\(\log p(x|c) = \log p(x) + \log p(c|x) - \log p(c)\)

\(x\) 求梯度(score function):

\[\nabla_x \log p(x|c) = \nabla_x \log p(x) + \nabla_x \log p(c|x)\]

用两次模型调用替代分类器:

\[\nabla_x \log p(c|x) = \nabla_x \log p(x|c) - \nabla_x \log p(x)\]

代入并引入引导强度 \(w\)

\[\hat\epsilon = \epsilon_\theta(x_t, \varnothing) + w \cdot [\epsilon_\theta(x_t, c) - \epsilon_\theta(x_t, \varnothing)]\]

在 Rectified Flow 中作用于速度场:

\[\hat{v} = v_\theta(x_t, t, \varnothing) + w \cdot [v_\theta(x_t, t, c) - v_\theta(x_t, t, \varnothing)]\]

训练时随机以 ~10% 概率将条件替换为空 \(\varnothing\),使模型同时学有/无条件去噪。

10.2 十道面试精选

Q1: FLUX DiT 架构相比 U-Net 有哪些核心优势?

  1. Scaling Laws 可预测:模型越大越好,改进趋势平滑
  2. MM-DiT 双向融合:vs cross-attention 单向 text→image
  3. 2D RoPE 分辨率外推:vs 固定位置编码无法泛化
  4. 复用 LLM 基础设施:FSDP、Flash Attention、Tensor Parallelism

Q2: 请推导 Classifier-Free Guidance 公式。

\(p(x|c)\propto p(c|x)p(x)\) 出发,取 log 对 \(x\) 求梯度得 \(\nabla\log p(x|c)=\nabla\log p(x)+\nabla\log p(c|x)\)。用 \(\nabla\log p(c|x)=\nabla\log p(x|c)-\nabla\log p(x)\) 替代分类器梯度,得:

\[\hat\epsilon = (1-w)\epsilon_\theta(\varnothing) + w\cdot\epsilon_\theta(c)\]

Q3: ControlNet 为什么用 zero-convolution?

桥初始化为零→训练开始控制信号为零→原始模型行为不变→控制渐进注入。副本内部参数继承自预训练(非零),梯度正常反传更新 zero-conv。随机初始化会瞬间破坏预训练质量。


Q4: Rectified Flow 比 DDPM 采样效率高的根本原因?

直线路径 \((1-t)x_0+tx_1\) 使速度场 \(v\approx x_1-x_0\) 近似与 \(t\) 无关,Euler 离散化误差极小。DDPM 的曲线路径需多步精确追踪,否则误差累积。


Q5: IP-Adapter 解耦 cross-attention 优于拼接的原因?

拼接方式:文本/图像 token 共享 softmax,互相竞争注意力权重,导致条件干扰。解耦方式:各自有独立 K/V 投影矩阵,可用 \(\lambda\) 独立控制强度,text prompt 和 image prompt 互不影响。


Q6: SD3/FLUX 的 16 通道 VAE 意义?

4→16ch 增大潜空间信息容量→VAE 重建质量显著提升→减少高频细节损失(文字笔画、发丝)。代价是 Transformer 输入维度增加(patchify 后 64 维),但 12B 模型可承受。


Q7: FLUX 双流→单流混合架构的设计逻辑?

浅层(19 层双流):保留模态独立性,各自 Norm/MLP 适应不同模态分布。深层(38 层单流):拼接共享参数,实现文本語义和图像特征的深度融合。全双流参数过多;全单流过早强制模态统一。


Q8: SD1.5 的 ControlNet 能否直接用于 SDXL?

不能。三个不兼容:①通道数不同(各层 channel 数不匹配);②block 结构不同(attention head 数、block 布局);③条件机制不同(SDXL 有额外微条件)。需重新训练或使用社区已训练的 SDXL ControlNet。


Q9: ODE vs SDE 采样各自优劣?

ODE:确定性、可复现、少步高质量(FLUX 默认)。SDE:有随机性→更多样化,噪声可修正 ODE 小误差。RF 直线路径下 ODE 已足够好,SDE 作为可选项追求多样性。


Q10: 生产环境如何优化 FLUX 推理性能?

  1. 量化:NF4(40GB→8GB)、FP8(H100 上 2x 加速)
  2. 蒸馏:用 schnell(4步)替 dev(28步),7x 加速
  3. 编译torch.compile(pipe.transformer, mode="max-autotune") 提速 20-40%
  4. 注意力:Flash Attention (SDPA 自动启用)
  5. 缓存:预计算 T5/CLIP 文本编码复用
  6. 并行:Tensor Parallelism 分片到多GPU

11. 总结与知识图谱

Text Only
FLUX与新一代架构
├── FLUX.1架构
│   ├── DiT Backbone(纯Transformer, 12B参数)
│   ├── MM-DiT双流Joint Attention(图文双向信息流)
│   ├── 2D RoPE(分辨率外推 + 任意长宽比)
│   ├── 双流→单流混合(19层双流 + 38层单流)
│   └── 双编码器(CLIP-L pooled + T5-XXL序列)
├── 架构对比: FLUX vs SD3 vs SDXL
│   ├── U-Net → MM-DiT → 混合MM-DiT+DiT
│   ├── DDPM → Rectified Flow
│   └── 4ch VAE → 16ch VAE
├── Rectified Flow
│   ├── 直线插值路径: x_t = (1-t)x_0 + tx_1
│   ├── v-prediction: v = x_1 - x_0
│   ├── ODE/SDE采样
│   └── ReFlow蒸馏 → schnell 1-4步
├── 条件控制
│   ├── ControlNet: zero-conv + 预训练复用 → 高精度空间控制
│   ├── ControlNet for FLUX: DiT block级zero-linear注入
│   ├── IP-Adapter: 解耦cross-attn → 图像语义控制
│   └── T2I-Adapter: 轻量特征加法 → 低显存控制
└── 工程实战
    ├── diffusers加载FLUX (量化/offload)
    ├── ControlNet条件控图 (Canny/Depth/Pose)
    ├── MultiControlNet多条件组合
    └── 生产优化 (量化/编译/缓存/并行)

📚 参考文献

  1. Peebles & Xie (2023). "Scalable Diffusion Models with Transformers" (DiT). ICCV 2023.
  2. Esser et al. (2024). "Scaling Rectified Flow Transformers for High-Resolution Image Synthesis" (SD3). ICML 2024.
  3. Black Forest Labs (2024). "FLUX.1 Technical Report".
  4. Liu et al. (2023). "Flow Straight and Fast: Learning to Generate and Transfer Data with Rectified Flow". ICLR 2023.
  5. Zhang et al. (2023). "Adding Conditional Control to Text-to-Image Diffusion Models" (ControlNet). ICCV 2023.
  6. Ye et al. (2023). "IP-Adapter: Text Compatible Image Prompt Adapter for Text-to-Image Diffusion Models".
  7. Mou et al. (2023). "T2I-Adapter: Learning Adapters to Dig out More Controllable Ability". AAAI 2024.
  8. Lipman et al. (2023). "Flow Matching for Generative Modeling". ICLR 2023.
  9. Ho & Salimans (2022). "Classifier-Free Diffusion Guidance". NeurIPS 2021 Workshop.
  10. Su et al. (2021). "RoFormer: Enhanced Transformer with Rotary Position Embedding".