跳转至

第12章 视觉Transformer

视觉Transformer图

📚 章节概述

本章深入讲解Transformer在计算机视觉中的革命性应用。从2020年ViT横空出世,到2025年ViT已成为视觉基础模型的标配架构,深刻改变了计算机视觉的技术范式。本章将从ViT原始架构出发,涵盖主流变体演进、高效设计、下游任务应用,以及工程实践中的关键代码实现。

学习时间:7-10天 难度等级:⭐⭐⭐⭐⭐ 前置知识:第5-6章CNN基础、NLP领域Transformer架构(Self-Attention / Multi-Head Attention / Positional Encoding)

🎯 学习目标

完成本章后,你将能够: - 深入理解ViT的每个组件(Patch Embedding / Position Embedding / CLS Token / MHA) - 掌握Swin Transformer的层级设计与窗口注意力机制 - 了解DeiT、BEiT、MAE、EVA、SigLIP等重要变体的核心创新 - 比较ViT与CNN的归纳偏置、数据效率、计算复杂度差异 - 理解ViT在检测(DETR)、分割(SAM)、生成(DiT)等下游任务的应用 - 能够用PyTorch手写完整的ViT前向传播 - 熟练使用HuggingFace预训练ViT模型 - 掌握8道高频大厂面试题


12.1 ViT原始架构深度解析

12.1.1 论文背景

论文An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale(Dosovitskiy et al., ICLR 2021)

核心思想:完全抛弃卷积操作,将图像视为一组patch序列,直接使用标准Transformer编码器进行分类。在JFT-300M大规模数据集上预训练后,ViT在ImageNet上达到了SOTA。

关键发现: - Transformer在大规模数据下可以超越CNN - 在中小数据集上ViT不如CNN(缺乏归纳偏置) - Scaling Law在视觉领域同样成立

12.1.2 Patch Embedding

Patch Embedding是ViT将2D图像转换为1D序列的核心步骤。

原理: 1. 将 \(H \times W \times C\) 的图像切分为 \(N\) 个不重叠的patch,每个patch大小为 \(P \times P\) 2. patch数量 \(N = \frac{H \times W}{P^2}\),例如 \(224 \times 224\) 图像、\(P=16\)\(N = 196\) 3. 每个patch展平为 \(P^2 \cdot C\) 维向量(\(16 \times 16 \times 3 = 768\)) 4. 通过线性投影映射到 \(D\) 维嵌入空间

实现方式:实践中通常用一个 kernel_size=stride=P 的卷积层高效实现:

Python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class PatchEmbedding(nn.Module):  # 继承nn.Module定义网络层
    """将图像分割为patches并投影到嵌入空间

    用Conv2d实现,等价于:切patch → 展平 → 线性投影
    但Conv2d利用了GPU并行计算,效率更高
    """
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()  # super()调用父类方法
        self.img_size = img_size
        self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2  # 196 for 224/16

        # Conv2d with kernel_size=stride=patch_size 等价于切patch+线性投影
        self.proj = nn.Conv2d(
            in_channels, embed_dim,
            kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        # x: (B, C, H, W) e.g. (B, 3, 224, 224)
        x = self.proj(x)        # (B, embed_dim, H/P, W/P) e.g. (B, 768, 14, 14)
        x = x.flatten(2)        # (B, embed_dim, N) e.g. (B, 768, 196)
        x = x.transpose(1, 2)   # (B, N, embed_dim) e.g. (B, 196, 768)
        return x

12.1.3 Position Embedding

Transformer本身对输入顺序不敏感(排列等变性),因此必须额外注入位置信息。

ViT使用可学习的1D位置编码

位置编码类型 描述 代表模型
可学习1D 每个位置一个可训练向量,直接加到patch embedding上 ViT
正弦余弦 使用固定的sin/cos函数,不需要训练 原始Transformer
可学习2D 分别为行和列学习位置编码 DeiT-III
相对位置偏置 编码patch对之间的相对距离 Swin Transformer
旋转位置编码(RoPE) 通过旋转矩阵注入位置 EVA-02
无位置编码 利用卷积隐式编码位置 CvT

关键细节:ViT论文实验表明,可学习1D与可学习2D位置编码效果几乎一致,因此默认采用更简单的1D方案。

Python
# 可学习的1D位置编码
# +1是因为还有一个CLS token
pos_embed = nn.Parameter(torch.zeros(1, n_patches + 1, embed_dim))
nn.init.trunc_normal_(pos_embed, std=0.02)

位置编码的可视化:训练后的位置编码会自动学习到2D空间结构——相邻patch的位置编码余弦相似度高,对角方向也呈现规律性。

12.1.4 CLS Token

设计动机:借鉴BERT的 [CLS] token设计,在patch序列前添加一个可学习的分类token。

工作流程: 1. 初始化一个可学习的向量 cls_token ∈ R^D 2. 拼接到patch embedding序列的最前面:[CLS, p1, p2, ..., pN] 3. 经过Transformer编码器后,CLS token聚合了全局信息 4. 取CLS token的输出接分类头

替代方案:全局平均池化(Global Average Pooling,GAP) - DeiT、BEiT等后续工作发现GAP效果与CLS token相当 - GAP无需额外参数,更简洁

Python
# 方案1: CLS Token(ViT原始)
cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
output = transformer_output[:, 0]  # 取CLS位置

# 方案2: Global Average Pooling(很多变体采用)
output = transformer_output[:, 1:].mean(dim=1)  # 对所有patch取平均

12.1.5 Multi-Head Self-Attention (MHSA)

MHSA是Transformer的核心计算模块。对于输入序列 \(X \in \mathbb{R}^{N \times D}\)

\[Q = XW_Q, \quad K = XW_K, \quad V = XW_V\]
\[\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V\]

多头机制:将 \(D\) 维分为 \(h\) 个头,每个头维度 \(d_k = D/h\),并行计算后拼接。

计算复杂度分析: - Self-Attention:\(O(N^2 \cdot D)\)\(N=196\) 时计算量可控 - 但\(N\)随分辨率平方增长:\(448 \times 448 \rightarrow N=784\),复杂度增4倍

Python
class MultiHeadSelfAttention(nn.Module):
    """多头自注意力机制

    实现要点:
    1. QKV投影用一个线性层高效计算
    2. 缩放因子 sqrt(d_k) 防止softmax梯度消失
    3. 支持attention dropout
    """
    def __init__(self, embed_dim=768, num_heads=12, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5  # 1/sqrt(d_k)

        # 一个线性层同时计算Q, K, V(效率更高)
        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape

        # QKV投影并reshape为多头格式
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)  # 重塑张量形状
        # permute将QKV维度提到最前面,unbind(0)沿第0维拆分为3个张量,避免3次独立线性层节省计算
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B, heads, N, head_dim)
        q, k, v = qkv.unbind(0)  # 各 (B, heads, N, head_dim)

        # Scaled Dot-Product Attention
        attn = (q @ k.transpose(-2, -1)) * self.scale  # (B, heads, N, N)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # 加权求和并重组
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

12.1.6 Transformer Block(Encoder Layer)

ViT使用Pre-Norm结构(LayerNorm在Attention/MLP之前),区别于原始Transformer的Post-Norm:

Text Only
x → LayerNorm → MHSA → + → LayerNorm → MLP → +
↑________________________|  ↑___________________|
       residual                  residual

MLP:两层全连接 + GELU激活,隐藏层维度通常为4倍embed_dim。

Python
class MLP(nn.Module):
    """前馈神经网络(FFN)"""
    def __init__(self, embed_dim=768, mlp_ratio=4.0, drop=0.):
        super().__init__()
        hidden_dim = int(embed_dim * mlp_ratio)
        self.fc1 = nn.Linear(embed_dim, hidden_dim)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, embed_dim)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

class TransformerBlock(nn.Module):
    """Pre-Norm Transformer Block"""
    def __init__(self, embed_dim=768, num_heads=12, mlp_ratio=4.0,
                 drop=0., attn_drop=0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadSelfAttention(embed_dim, num_heads, attn_drop, drop)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = MLP(embed_dim, mlp_ratio, drop)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))   # 残差连接
        x = x + self.mlp(self.norm2(x))    # 残差连接
        return x

12.1.7 完整ViT模型

Python
class VisionTransformer(nn.Module):
    """完整的Vision Transformer实现

    ViT-Base:  embed_dim=768,  depth=12, num_heads=12  (86M params)
    ViT-Large: embed_dim=1024, depth=24, num_heads=16  (307M params)
    ViT-Huge:  embed_dim=1280, depth=32, num_heads=16  (632M params)
    """
    def __init__(self, img_size=224, patch_size=16, in_channels=3, num_classes=1000,
                 embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0,
                 drop_rate=0., attn_drop_rate=0.):
        super().__init__()
        self.num_classes = num_classes
        self.embed_dim = embed_dim

        # 1. Patch Embedding
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        num_patches = self.patch_embed.n_patches

        # 2. CLS Token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        # 3. Position Embedding (可学习的1D)
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(drop_rate)

        # 4. Transformer Encoder
        self.blocks = nn.Sequential(*[
            TransformerBlock(embed_dim, num_heads, mlp_ratio, drop_rate, attn_drop_rate)
            for _ in range(depth)
        ])

        # 5. Classification Head
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        # 参数初始化
        self._init_weights()

    def _init_weights(self):
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        self.apply(self._init_module_weights)

    def _init_module_weights(self, m):
        if isinstance(m, nn.Linear):  # isinstance检查类型
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)

    def forward(self, x):
        B = x.shape[0]

        # Step 1: Patch Embedding → (B, N, D)
        x = self.patch_embed(x)

        # Step 2: Prepend CLS Token → (B, N+1, D)
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)  # torch.cat沿已有维度拼接张量

        # Step 3: Add Position Embedding
        x = x + self.pos_embed
        x = self.pos_drop(x)

        # Step 4: Transformer Encoder
        x = self.blocks(x)

        # Step 5: Classification
        x = self.norm(x)
        cls_output = x[:, 0]  # 取CLS token输出
        logits = self.head(cls_output)
        return logits

# ViT模型变体配置
def vit_base_patch16_224(**kwargs):  # *args接收任意位置参数,**kwargs接收任意关键字参数
    return VisionTransformer(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)

def vit_large_patch16_224(**kwargs):
    return VisionTransformer(
        patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)

def vit_huge_patch14_224(**kwargs):
    return VisionTransformer(
        patch_size=14, embed_dim=1280, depth=32, num_heads=16, **kwargs)

12.1.8 ViT模型族参数对比

模型 Patch Size Embed Dim Depth Heads Params ImageNet Top-1
ViT-S/16 16 384 12 6 22M 79.9%
ViT-B/16 16 768 12 12 86M 84.5%
ViT-L/16 16 1024 24 16 307M 87.1%
ViT-H/14 14 1280 32 16 632M 88.6%
ViT-G/14 14 1664 48 16 1843M 90.5%

注:Top-1精度基于JFT-300M预训练 + ImageNet微调


12.2 ViT变体演进

12.2.1 DeiT — 数据高效的知识蒸馏

论文Training data-efficient image transformers & distillation through attention(Touvron et al., ICML 2021)

核心动机:ViT需要JFT-300M级别数据才能超越CNN,DeiT证明了仅用ImageNet-1K也可以训练出强大的ViT。

关键创新: 1. 知识蒸馏Token:在CLS token之外,额外添加一个distillation token - 蒸馏token向教师模型(RegNetY-16GF)学习 - 最终预测 = CLS预测 + 蒸馏token预测的加权平均 2. 强数据增强:RandAugment、Mixup、CutMix、Random Erasing 3. 正则化策略:Stochastic Depth、Repeated Augmentation

Text Only
输入序列: [CLS_token, distill_token, patch_1, patch_2, ..., patch_N]
分类损失: CrossEntropy(CLS_output, label)
蒸馏损失: KL_Div(distill_output, teacher_output) 或 CrossEntropy(distill_output, teacher_hard_label)

DeiT-III(2022更新):引入3-Augment策略(灰度+翻转+SolarizeAdd),比DeiT效果更好。

模型 ImageNet Top-1 预训练数据
DeiT-S 79.8% ImageNet-1K
DeiT-B 81.8% ImageNet-1K
DeiT-B (蒸馏) 83.4% ImageNet-1K
DeiT-III-L 87.7% ImageNet-21K

12.2.2 BEiT — BERT式视觉预训练

论文BEiT: BERT Pre-Training of Image Transformers(Bao et al., ICLR 2022)

核心思想:借鉴BERT的Masked Language Modeling (MLM),提出Masked Image Modeling (MIM)。

预训练流程: 1. 使用离散VAE(dVAE)将图像patch编码为离散visual tokens 2. 随机mask约40%的patch 3. 让ViT根据未mask的patch预测被mask位置的visual token

Text Only
原始图像 → Patch序列 [p1, p2, ..., p196]
                           ↓ 随机mask
掩码图像 → [p1, [MASK], p3, [MASK], p5, ...]
                           ↓ ViT编码
                    预测被mask位置的visual token ID

BEiT v2(2022):使用VQ-KD(向量量化知识蒸馏)替代dVAE,获得更好的视觉词典。

BEiT-3(2023):统一的多模态预训练,用Multiway Transformer同时处理图像、文本和图像-文本对。

12.2.3 MAE — 掩码自编码器

论文Masked Autoencoders Are Scalable Vision Learners(He et al., CVPR 2022)

核心创新:极高掩码率(75%)+ 非对称编码器-解码器架构。

与BEiT的区别: - MAE直接在像素空间重建(无需visual tokenizer) - 编码器只处理可见patch(25%),极大节省计算 - 解码器轻量(仅用于预训练),下游任务只用编码器

Text Only
原始图像 196个patch
   ↓ 随机mask 75%
可见patch (49个) → 【重型Encoder】 → 编码特征
   ↓ 加入mask tokens和位置编码
全部tokens (196个) → 【轻量Decoder】 → 重建像素
MSE Loss (仅在被mask的位置计算)

为什么75%掩码率有效: - 图像有大量空间冗余(相邻区域高度相关) - 高掩码率迫使模型学习高层语义,而非简单的插值 - 节省预训练时间(编码器只处理25%的token)

12.2.4 EVA — 大规模视觉基础模型

论文EVA: Exploring the Limits of Masked Visual Representation Learning at Scale(Fang et al., CVPR 2023)

关键创新: - 使用CLIP的视觉特征作为MIM的重建目标(而非像素或离散token) - 证明MIM + CLIP特征蒸馏是scaling ViT的有效方案

EVA系列演进: | 模型 | 参数量 | ImageNet Top-1 | 核心创新 | |------|--------|---------------|---------| | EVA | 1.0B | 89.6% | CLIP特征作为MIM目标 | | EVA-02 | 304M | 90.0% | RoPE位置编码 + SwiGLU FFN | | EVA-CLIP | 5.0B | 最强开源CLIP | 融合EVA与CLIP训练 |

12.2.5 SigLIP — 更好的视觉-语言对齐

论文Sigmoid Loss for Language Image Pre-Training(Zhai et al., ICCV 2023)

核心改进:用Sigmoid loss替代CLIP的Softmax (InfoNCE) loss。

Softmax (CLIP) vs Sigmoid (SigLIP): - CLIP:需要全局归一化,依赖大batch size(32K+),分布式训练需要all-gather操作 - SigLIP:逐对独立计算sigmoid,无需全局归一化,对batch size不敏感,可更高效地分布式训练

\[\text{CLIP Loss} = -\log \frac{\exp(\text{sim}(x_i, y_i)/\tau)}{\sum_j \exp(\text{sim}(x_i, y_j)/\tau)}\]
\[\text{SigLIP Loss} = -\frac{1}{N}\sum_{i,j} \log \sigma((-1)^{[i \neq j]} z_{ij}/\tau)\]

SigLIP在2024-2025已成为VLM(如PaLI-Gemma、LLaVA-1.5)的标配视觉编码器。


12.3 Swin Transformer深入

12.3.1 设计动机

ViT的全局自注意力存在两个问题: 1. 计算复杂度\(O(N^2)\) 随分辨率平方增长,难以处理高分辨率图像 2. 缺乏多尺度特征:单一分辨率的token序列,不适合密集预测任务(检测/分割)

Swin Transformer通过层级设计 + 窗口注意力解决这两个问题。

12.3.2 层级设计(Hierarchical Architecture)

Swin采用类似CNN的4阶段层级结构,每阶段通过Patch Merging下采样:

Text Only
Stage 1: H/4 × W/4,  dim=C       (56×56, 96维)
    ↓ Patch Merging (2×2 → 1, dim×2)
Stage 2: H/8 × W/8,  dim=2C      (28×28, 192维)
    ↓ Patch Merging
Stage 3: H/16 × W/16, dim=4C     (14×14, 384维)
    ↓ Patch Merging
Stage 4: H/32 × W/32, dim=8C     (7×7, 768维)

Patch Merging:将相邻 \(2 \times 2\) 个patch拼接后通过线性层降维,类似CNN的stride-2卷积。

12.3.3 窗口注意力(Window Attention)

将feature map划分为 \(M \times M\)(默认7×7)的不重叠窗口,在每个窗口内部做自注意力:

  • 计算复杂度:\(O(M^2 \cdot N)\),其中 \(N\) 为总token数 → 线性复杂度
  • 对比全局注意力 \(O(N^2)\),窗口注意力在高分辨率下优势巨大

12.3.4 移位窗口(Shifted Window)

问题:窗口注意力导致窗口之间没有信息交互。

解决方案:交替使用普通窗口和移位窗口: - 奇数层:常规窗口划分 - 偶数层:窗口向右下移动 \(M/2\) 个像素后重新划分

高效实现:通过cyclic shift + mask实现,避免实际创建更多窗口:

Python
class WindowAttention(nn.Module):
    """Swin Transformer的窗口注意力(含相对位置偏置)"""
    def __init__(self, dim, window_size, num_heads):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # (Wh, Ww)
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

        # 相对位置偏置表
        # (2*Wh-1) * (2*Ww-1) 个可能的相对位置
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2*window_size[0]-1) * (2*window_size[1]-1), num_heads)
        )
        nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)

        # 计算每个token对的相对位置索引
        coords_h = torch.arange(window_size[0])
        coords_w = torch.arange(window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing='ij'))  # (2, Wh, Ww)  # torch.stack沿新维度拼接张量
        coords_flatten = torch.flatten(coords, 1)  # (2, Wh*Ww)

        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # (2, N, N)
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # (N, N, 2)
        relative_coords[:, :, 0] += window_size[0] - 1
        relative_coords[:, :, 1] += window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # (N, N)
        self.register_buffer("relative_position_index", relative_position_index)

    def forward(self, x, mask=None):
        B_, N, C = x.shape  # B_为窗口总数
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)

        attn = (q @ k.transpose(-2, -1)) * self.scale

        # 添加相对位置偏置
        # 用index从偏置表查找→view(N,N,heads)→permute为(heads,N,N)→unsqueeze(0)加batch维以广播加到注意力分数上
        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)
        ].view(N, N, -1).permute(2, 0, 1).contiguous()
        attn = attn + relative_position_bias.unsqueeze(0)  # unsqueeze增加一个维度

        # 移位窗口的mask
        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N)
            attn = attn + mask.unsqueeze(1).unsqueeze(0)  # 被mask位置设为-inf
            attn = attn.view(-1, self.num_heads, N, N)

        attn = attn.softmax(dim=-1)
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        return x

def window_partition(x, window_size):
    """将feature map划分为窗口"""
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows  # (num_windows*B, window_size, window_size, C)

def window_reverse(windows, window_size, H, W):
    """将窗口还原为feature map"""
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x

12.3.5 Swin Transformer模型族

模型 Embed Dim Depths Heads Params ImageNet Top-1
Swin-T 96 [2,2,6,2] [3,6,12,24] 29M 81.3%
Swin-S 96 [2,2,18,2] [3,6,12,24] 50M 83.0%
Swin-B 128 [2,2,18,2] [4,8,16,32] 88M 83.5%
Swin-L 192 [2,2,18,2] [6,12,24,48] 197M 86.4%

Swin v2(2022):引入log-CPB(对数连续位置偏置)和余弦注意力,支持更高分辨率和更大模型。


12.4 高效ViT

12.4.1 EfficientViT (MIT)

论文EfficientViT: Lightweight Multi-Scale Attention for High-Resolution Dense Prediction(Cai et al., ICCV 2023)

核心创新: - 多尺度线性注意力:用线性注意力(\(O(N)\))替代标准softmax注意力(\(O(N^2)\)) - 级联分组注意力:将不同head分配给不同的特征分片,减少冗余计算 - 在分割任务(Cityscapes、ADE20K)上以极低延迟达到SOTA

模型 Params FLOPs ImageNet Top-1 GPU延迟
EfficientViT-B1 9.1M 0.52G 79.4% 0.3ms
EfficientViT-B2 24.3M 1.6G 82.1% 0.6ms
EfficientViT-B3 48.6M 4.0G 83.5% 1.1ms

12.4.2 MobileViT

论文MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer(Mehta & Rastegari, ICLR 2022)

设计理念:结合MobileNet的轻量CNN和ViT的全局注意力。

  • 用MobileNetV2的Inverted Residual Block提取局部特征
  • 在关键阶段插入Transformer Block捕捉全局依赖
  • 参数量仅5-6M,适合移动端部署

12.4.3 TinyViT

论文TinyViT: Fast Pretraining Distillation for Small Vision Transformers(Wu et al., ECCV 2022)

方法:通过大模型蒸馏快速预训练小ViT(5M-21M参数),在ImageNet上达到与大模型相当的精度。

模型 Params ImageNet Top-1 用途
TinyViT-5M 5.4M 79.1% 移动端推理
TinyViT-11M 11M 81.5% 边缘设备
TinyViT-21M 21M 83.2% 轻量服务端

12.5 ViT在下游任务中的应用

12.5.1 目标检测 — DETR

论文End-to-End Object Detection with Transformers(Carion et al., ECCV 2020)

革命性意义:去掉了anchor、NMS等手工设计组件,用Transformer实现端到端检测。

架构

Text Only
图像 → CNN Backbone → Transformer Encoder → Transformer Decoder → 预测框+类别
                                              Object Queries (可学习)

匹配策略:用匈牙利算法做预测框与GT的二分匹配(Set Prediction)。

演进:DETR → Deformable DETR → DINO → Co-DETR → RT-DETR(实时版)

12.5.2 图像分割 — SAM

论文Segment Anything(Kirillov et al., ICCV 2023)

SAM (Segment Anything Model) 是视觉基础模型的里程碑: - Image Encoder:ViT-H(MAE预训练),提取图像特征 - Prompt Encoder:编码点击、框、文本等提示 - Mask Decoder:轻量Transformer解码器,生成分割掩码

SAM 2(2024):扩展到视频分割,引入Memory Mechanism追踪时序信息。

12.5.3 图像生成 — DiT

论文Scalable Diffusion Models with Transformers(Peebles & Xie, ICCV 2023)

DiT (Diffusion Transformer):用ViT替代U-Net作为扩散模型的去噪网络。

核心设计: - 输入:带噪声的latent patch序列 + 时间步embedding + 类别embedding - 用AdaLN-Zero(自适应LayerNorm)注入条件信息 - DiT是Sora的基础架构

Text Only
Noisy Latent → Patchify → Transformer Blocks (with AdaLN) → Unpatchify → Denoised Latent
                              ↑ timestep + class label

12.5.4 其他重要应用

任务 代表模型 核心用法
语义分割 SegFormer, Mask2Former ViT作为backbone + 分割解码器
深度估计 DPT, Depth Anything ViT编码器 + 多尺度解码器
视频理解 VideoMAE, TimeSformer 时空patch + 视频Transformer
点云处理 Point-BERT, Point-MAE 3D点云patch化 + ViT
医学影像 MedViT, SAM-Med 预训练ViT + 领域微调

12.6 ViT vs CNN 对比分析

12.6.1 归纳偏置(Inductive Bias)

特性 CNN ViT
局部性 卷积核限制感受野,天然捕捉局部模式 全局注意力,需要从数据中学习局部性
平移等变性 卷积权重共享保证平移等变 无此先验,依赖数据增强和位置编码
层级特征 pooling自然构建多尺度 原始ViT单尺度,需Swin等设计引入
训练效率 归纳偏置帮助小数据学习 小数据下表现差,大数据下优势显现

12.6.2 数据效率

Text Only
数据量          CNN (ResNet)    ViT
< 10K          ★★★★           ★★
10K - 1M       ★★★★           ★★★
1M - 10M       ★★★            ★★★★
> 100M         ★★★            ★★★★★

结论:ViT是一个更"通用"但也更"饥渴"的学习器——在足够数据下可以学到比CNN更好的表示。

12.6.3 计算复杂度

对于输入分辨率 \(H \times W\)、patch数 \(N = HW/P^2\)

操作 复杂度 说明
ViT Self-Attention \(O(N^2 \cdot D)\) 随分辨率二次增长
Swin Window Attention \(O(M^2 \cdot N \cdot D)\) 线性于N
CNN 3×3 Conv \(O(9 \cdot C^2 \cdot N)\) 线性于N
Linear Attention \(O(N \cdot D^2)\) 线性于N

12.6.4 实践建议

场景 推荐架构 原因
预训练数据充足 ViT/EVA Scaling表现最佳
目标检测/分割 Swin/ViTDet 多尺度特征+高分辨率
移动端部署 MobileViT/EfficientViT 轻量高效
中小数据集 ConvNeXt/CNN+ViT混合 归纳偏置有帮助
多模态应用 ViT+SigLIP 与语言模型对接方便

12.7 使用预训练ViT模型

12.7.1 HuggingFace推理

Python
import torch
from transformers import ViTForImageClassification, ViTImageProcessor
from PIL import Image

# 加载预训练模型和处理器
model_name = 'google/vit-base-patch16-224'
model = ViTForImageClassification.from_pretrained(model_name)
processor = ViTImageProcessor.from_pretrained(model_name)

model.eval()  # eval()评估模式

# 推理
image = Image.open('image.jpg').convert('RGB')
inputs = processor(images=image, return_tensors="pt")

with torch.no_grad():  # 禁用梯度计算,节省内存
    outputs = model(**inputs)
    logits = outputs.logits
    predicted_class_idx = logits.argmax(-1).item()  # 将单元素张量转为Python数值
    predicted_class = model.config.id2label[predicted_class_idx]
    confidence = torch.softmax(logits, dim=-1).max().item()

print(f"预测类别: {predicted_class}, 置信度: {confidence:.4f}")

12.7.2 timm库使用

Python
import timm
import torch
from PIL import Image
from timm.data import resolve_data_config, create_transform

# timm提供了几乎所有主流ViT变体
model = timm.create_model('vit_base_patch16_224.augreg_in21k_ft_in1k', pretrained=True)
model.eval()

# 自动获取模型对应的数据预处理
config = resolve_data_config(model.pretrained_cfg)
transform = create_transform(**config)

image = Image.open('image.jpg').convert('RGB')
input_tensor = transform(image).unsqueeze(0)

with torch.no_grad():
    output = model(input_tensor)
    probabilities = torch.softmax(output, dim=-1)
    top5_prob, top5_idx = probabilities.topk(5)

for i in range(5):
    print(f"Top-{i+1}: class={top5_idx[0][i].item()}, prob={top5_prob[0][i].item():.4f}")

# 列出所有可用的ViT模型
vit_models = timm.list_models('vit_*', pretrained=True)
print(f"可用ViT模型数量: {len(vit_models)}")

12.7.3 ViT微调示例

Python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import timm

# 加载预训练ViT并修改分类头
model = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=10)

# 冻结backbone,只训练分类头(适合小数据集)
for name, param in model.named_parameters():
    if 'head' not in name:
        param.requires_grad = False

# 或者:全模型微调(适合中等数据集),使用较小学习率
optimizer = torch.optim.AdamW([
    {'params': model.head.parameters(), 'lr': 1e-3},           # 分类头大学习率
    {'params': [p for n, p in model.named_parameters()
                if 'head' not in n and p.requires_grad], 'lr': 1e-5}  # backbone小学习率
], weight_decay=0.05)

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

# 训练循环
def train_one_epoch(model, dataloader, optimizer, criterion, device):
    model.train()
    total_loss, correct, total = 0, 0, 0
    for images, labels in dataloader:
        images, labels = images.to(device), labels.to(device)  # 移至GPU/CPU

        outputs = model(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()  # 清零梯度
        loss.backward()  # 反向传播计算梯度
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()  # 更新参数

        total_loss += loss.item()
        correct += (outputs.argmax(1) == labels).sum().item()
        total += labels.size(0)

    return total_loss / len(dataloader), correct / total

12.8 练习题

基础题

  1. 简答题
  2. ViT的Patch Embedding相当于什么操作?为什么用Conv2d实现?

    Patch Embedding相当于用stride=patch_size的Conv2d对图像做不重叠卷积,将每个patch投影为一个向量。用Conv2d实现而非先reshape再线性层,是因为Conv2d能直接完成“切块+投影”两步,硬件利用率更高、实现更简洁。

  3. 解释Pre-Norm和Post-Norm的区别,ViT使用哪种?

    Post-Norm在残差连接之后做LayerNorm(原始Transformer),训练时梯度不稳定,需要学习率warmup。Pre-Norm在残差连接之前做LayerNorm,梯度可以直接通过残差路径流动,训练更稳定,适合深层网络。ViT使用Pre-Norm。

  4. 为什么ViT需要位置编码?可学习1D和2D位置编码效果有区别吗?

    Transformer的Self-Attention是置换不变的(permutation-invariant),无法感知token的空间位置,因此必须添加位置编码。实验表明可学习1D和2D位置编码效果差别很小(约±0.1%),因为1D编码在训练中可以自动学到类似2D的结构信息。

  5. 计算题

  6. 计算ViT-B/16在224×224输入时的patch数量、序列长度和Self-Attention的FLOPs。

    Patch数量 = \((224/16)^2 = 14 \times 14 = 196\);加上CLS token后序列长度 \(N = 196 + 1 = 197\)。ViT-B的隐藏维度 \(D = 768\)。单层Self-Attention的FLOPs ≈ \(2 \times 4N^2D = 2 \times 4 \times 197^2 \times 768 \approx 238M\)(包含Q/K/V投影和注意力计算)。ViT-B有12层,总注意力FLOPs ≈ \(12 \times 238M \approx 2.86G\)

  7. 如果将输入分辨率改为384×384,序列长度和计算量如何变化?

    Patch数量 = \((384/16)^2 = 24 \times 24 = 576\),序列长度 \(N = 577\)。相比224分辨率的197,序列长度增加约2.93倍。由于Self-Attention计算量与 \(N^2\) 成正比,注意力FLOPs增加约 \((577/197)^2 \approx 8.58\) 倍。总模型FLOPs增加约3倍左右(因为MLP部分仅线性增长)。

进阶题

  1. 编程题
  2. 从零实现一个完整的ViT(包含Patch Embedding、Position Embedding、MHA、MLP、CLS Token)。
  3. 使用timm加载预训练Swin-T,在CIFAR-10上微调。

  4. 分析题

  5. 比较MAE和BEiT的预训练策略,各有什么优劣?

    MAE随机掩码图像75%的patch,编码器只处理可见patch,解码器重建原始像素。优势:预训练效率高(3-4×加速)、无需额外组件、概念简洁。劣势:重建低层像素可能不够语义化。BEiT先用dVAE将图像编码为离散token,然后预测被掩码位置的token ID。优势:预测目标更语义化。劣势:需要预训练dVAE tokenizer,流程更复杂。

  6. 为什么DiT能替代U-Net成为扩散模型的backbone?

    Scaling能力更强:Transformer可轻松堆叠层数,遵循Scaling Law,U-Net的残差块+跳连结构难以扩展;②架构简洁,无需精心设计编解码器层数和通道数;③条件注入灵活,AdaLN-Zero可优雅地注入timestep和class/text条件;④与LLM生态统一,方便融入文本条件生成;⑤可利用MAE等预训练ViT权重初始化。


12.9 面试准备

大厂面试题

Q1: ViT相比CNN有什么优势和劣势?在什么场景下选择ViT?

参考答案优势: - 全局感受野:第一层就能看到所有patch之间的关系,CNN需要很深才能获得大感受野 - 可扩展性强:参数量和数据量增加时性能持续提升(Scaling Law) - 架构统一:与NLP的Transformer共享架构,便于多模态融合 - 预训练迁移好:MAE/CLIP等大规模预训练后迁移效果出色

劣势: - 数据饥渴:小数据集上不如CNN(缺乏locality、translation equivariance等归纳偏置) - 计算复杂度高:Self-Attention的O(N²)使高分辨率输入开销大 - 对分辨率敏感:位置编码与输入尺寸绑定,分辨率变化需要插值

选择建议:预训练数据充足时首选ViT;小数据用CNN或混合架构;移动端用EfficientViT/MobileViT;密集预测用Swin/ViTDet。


Q2: 请解释Swin Transformer的移位窗口机制及其作用

参考答案: Swin Transformer在相邻层交替使用两种窗口配置: - 常规窗口(W-MSA):将feature map均匀划分为 \(M \times M\) 窗口 - 移位窗口(SW-MSA):将窗口向右下偏移 \(M/2\),跨越常规窗口边界

作用: 1. 允许相邻窗口之间交换信息,避免信息孤岛 2. 保持线性计算复杂度(只在窗口内做注意力) 3. 高效实现:用cyclic shift + attention mask,无需创建额外窗口


Q3: MAE为什么要用75%这么高的掩码率?

参考答案: 1. 图像的空间冗余高:相邻像素高度相关,低掩码率下模型可以轻松"插值"恢复,学不到高层语义 2. 高掩码率迫使模型理解语义:看到极少信息时必须"推理"内容是什么 3. 计算效率:编码器只处理25%的visible tokens,预训练速度提升3-4倍 4. 与NLP对比:BERT用15%的掩码率,因为语言信息密度高,而图像信息密度低


Q4: DETR如何实现端到端目标检测?相比传统检测器有什么优势?

参考答案DETR流程: 1. CNN backbone提取特征 → Transformer Encoder全局建模 2. \(N\) 个可学习的Object Queries经Transformer Decoder交叉注意力 3. 每个query直接输出(class, box)预测 4. 用匈牙利算法做预测与GT的二分匹配,计算集合级别损失

优势:去掉了anchor设计、NMS后处理、FPN等手工组件,简化流程 劣势:训练收敛慢(500 epochs)、小目标检测弱 → Deformable DETR解决了这些问题


Q5: DiT为什么能替代U-Net成为扩散模型的标配backbone?

参考答案: 1. Scaling能力:U-Net的残差块+跳跃连接结构不易扩展,DiT基于Transformer可以轻松堆叠层数,遵循Scaling Law 2. 架构简洁:无需精心设计U-Net的编码器-解码器层数和通道数 3. 条件注入灵活:AdaLN-Zero可以优雅地注入timestep和class/text条件 4. 与LLM生态统一:方便融入文本条件生成(如Sora) 5. 预训练兼容:可以利用MAE等预训练的ViT权重初始化


Q6: SigLIP相比CLIP有什么改进?为什么成为VLM视觉编码器的标配?

参考答案改进:用Sigmoid loss替代InfoNCE (Softmax) loss - InfoNCE需要全batch的负样本做归一化 → 需要超大batch size和all-gather - Sigmoid loss对每个正/负样本对独立计算 → 不依赖batch size,分布式训练更高效

成为标配的原因: - 训练更稳定、效率更高 - 相同参数量下zero-shot性能更好 - 在PaLI-Gemma、LLaVA-Next等VLM中广泛使用


Q7: 如何将ViT的位置编码从224×224分辨率迁移到384×384?

参考答案: 使用双线性插值调整位置编码维度: 1. 将 \(14 \times 14\) 的2D位置编码reshape 2. 用 F.interpolate 双线性插值到 \(24 \times 24\) 3. 再展平回1D序列 4. CLS token的位置编码保持不变

Python
pos_embed_2d = pos_embed[:, 1:].reshape(1, 14, 14, D).permute(0, 3, 1, 2)
pos_embed_2d = F.interpolate(pos_embed_2d, size=(24, 24), mode='bicubic', align_corners=False)  # F.xxx PyTorch函数式API
new_pos_embed = torch.cat([pos_embed[:, :1], pos_embed_2d.flatten(2).transpose(1, 2)], dim=1)

Q8: 比较ViT、Swin Transformer和ConvNeXt三个架构的设计哲学

参考答案

特性 ViT Swin Transformer ConvNeXt
设计哲学 纯Transformer,最小视觉先验 Transformer + CNN层级设计 纯CNN + Transformer训练技巧
注意力范围 全局 局部窗口 局部(7×7深度可分离卷积)
多尺度特征 无(单尺度) 有(4阶段层级) 有(4阶段层级)
计算复杂度 O(N²) O(N) O(N)
最佳使用场景 基础模型预训练 密集预测任务 需要CNN生态兼容
代表应用 CLIP/MAE/DiT 检测/分割backbone 替代ResNet

12.10 关键论文列表

必读论文

年份 论文 核心贡献
2020 ViT: An Image is Worth 16x16 Words 开创性地将纯Transformer用于视觉
2021 DeiT: Training data-efficient image transformers 知识蒸馏+强数据增强训练ViT
2021 Swin Transformer: Hierarchical Vision Transformer 移位窗口+层级结构,适配密集预测
2021 BEiT: BERT Pre-Training of Image Transformers Masked Image Modeling预训练
2022 MAE: Masked Autoencoders Are Scalable Vision Learners 75%掩码+像素重建的高效预训练
2022 ConvNeXt: A ConvNet for the 2020s 用Transformer技巧现代化CNN

扩展论文

年份 论文 核心贡献
2020 DETR: End-to-End Object Detection with Transformers Transformer端到端检测
2023 EVA-02: A Visual Representation for Neon Genesis RoPE + SwiGLU的强力ViT
2023 SigLIP: Sigmoid Loss for Language Image Pre-Training 更高效的视觉-语言对比学习
2023 SAM: Segment Anything ViT驱动的视觉分割基础模型
2023 DiT: Scalable Diffusion Models with Transformers ViT替代U-Net做扩散模型
2023 DINOv2: Learning Robust Visual Features 自监督ViT视觉基础模型
2024 SAM 2: Segment Anything in Images and Videos ViT+Memory的视频分割

12.11 本章小结

核心知识点

  1. ViT原始架构:Patch Embedding → CLS Token + Position Embedding → Transformer Encoder → Classification Head
  2. ViT变体:DeiT(蒸馏)、BEiT(MIM)、MAE(掩码自编码)、EVA(大规模)、SigLIP(高效对比学习)
  3. Swin Transformer:层级设计 + 窗口注意力 + 移位窗口 + 相对位置偏置
  4. 高效ViT:EfficientViT、MobileViT、TinyViT——面向端侧部署
  5. 下游应用:DETR(检测)、SAM(分割)、DiT(生成)
  6. ViT vs CNN:归纳偏置、数据效率、计算复杂度的全面对比

下一步

下一章13-多模态学习.md - 学习CLIP、BLIP、LLaVA等视觉-语言多模态模型


恭喜完成第12章! 🎉 视觉Transformer已成为现代CV的基石——从分类、检测、分割到生成,ViT无处不在。