第12章 视觉Transformer¶
📚 章节概述¶
本章深入讲解Transformer在计算机视觉中的革命性应用。从2020年ViT横空出世,到2025年ViT已成为视觉基础模型的标配架构,深刻改变了计算机视觉的技术范式。本章将从ViT原始架构出发,涵盖主流变体演进、高效设计、下游任务应用,以及工程实践中的关键代码实现。
学习时间:7-10天 难度等级:⭐⭐⭐⭐⭐ 前置知识:第5-6章CNN基础、NLP领域Transformer架构(Self-Attention / Multi-Head Attention / Positional Encoding)
🎯 学习目标¶
完成本章后,你将能够: - 深入理解ViT的每个组件(Patch Embedding / Position Embedding / CLS Token / MHA) - 掌握Swin Transformer的层级设计与窗口注意力机制 - 了解DeiT、BEiT、MAE、EVA、SigLIP等重要变体的核心创新 - 比较ViT与CNN的归纳偏置、数据效率、计算复杂度差异 - 理解ViT在检测(DETR)、分割(SAM)、生成(DiT)等下游任务的应用 - 能够用PyTorch手写完整的ViT前向传播 - 熟练使用HuggingFace预训练ViT模型 - 掌握8道高频大厂面试题
12.1 ViT原始架构深度解析¶
12.1.1 论文背景¶
论文:An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale(Dosovitskiy et al., ICLR 2021)
核心思想:完全抛弃卷积操作,将图像视为一组patch序列,直接使用标准Transformer编码器进行分类。在JFT-300M大规模数据集上预训练后,ViT在ImageNet上达到了SOTA。
关键发现: - Transformer在大规模数据下可以超越CNN - 在中小数据集上ViT不如CNN(缺乏归纳偏置) - Scaling Law在视觉领域同样成立
12.1.2 Patch Embedding¶
Patch Embedding是ViT将2D图像转换为1D序列的核心步骤。
原理: 1. 将 \(H \times W \times C\) 的图像切分为 \(N\) 个不重叠的patch,每个patch大小为 \(P \times P\) 2. patch数量 \(N = \frac{H \times W}{P^2}\),例如 \(224 \times 224\) 图像、\(P=16\) → \(N = 196\) 3. 每个patch展平为 \(P^2 \cdot C\) 维向量(\(16 \times 16 \times 3 = 768\)) 4. 通过线性投影映射到 \(D\) 维嵌入空间
实现方式:实践中通常用一个 kernel_size=stride=P 的卷积层高效实现:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class PatchEmbedding(nn.Module): # 继承nn.Module定义网络层
"""将图像分割为patches并投影到嵌入空间
用Conv2d实现,等价于:切patch → 展平 → 线性投影
但Conv2d利用了GPU并行计算,效率更高
"""
def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
super().__init__() # super()调用父类方法
self.img_size = img_size
self.patch_size = patch_size
self.n_patches = (img_size // patch_size) ** 2 # 196 for 224/16
# Conv2d with kernel_size=stride=patch_size 等价于切patch+线性投影
self.proj = nn.Conv2d(
in_channels, embed_dim,
kernel_size=patch_size, stride=patch_size
)
def forward(self, x):
# x: (B, C, H, W) e.g. (B, 3, 224, 224)
x = self.proj(x) # (B, embed_dim, H/P, W/P) e.g. (B, 768, 14, 14)
x = x.flatten(2) # (B, embed_dim, N) e.g. (B, 768, 196)
x = x.transpose(1, 2) # (B, N, embed_dim) e.g. (B, 196, 768)
return x
12.1.3 Position Embedding¶
Transformer本身对输入顺序不敏感(排列等变性),因此必须额外注入位置信息。
ViT使用可学习的1D位置编码:
| 位置编码类型 | 描述 | 代表模型 |
|---|---|---|
| 可学习1D | 每个位置一个可训练向量,直接加到patch embedding上 | ViT |
| 正弦余弦 | 使用固定的sin/cos函数,不需要训练 | 原始Transformer |
| 可学习2D | 分别为行和列学习位置编码 | DeiT-III |
| 相对位置偏置 | 编码patch对之间的相对距离 | Swin Transformer |
| 旋转位置编码(RoPE) | 通过旋转矩阵注入位置 | EVA-02 |
| 无位置编码 | 利用卷积隐式编码位置 | CvT |
关键细节:ViT论文实验表明,可学习1D与可学习2D位置编码效果几乎一致,因此默认采用更简单的1D方案。
# 可学习的1D位置编码
# +1是因为还有一个CLS token
pos_embed = nn.Parameter(torch.zeros(1, n_patches + 1, embed_dim))
nn.init.trunc_normal_(pos_embed, std=0.02)
位置编码的可视化:训练后的位置编码会自动学习到2D空间结构——相邻patch的位置编码余弦相似度高,对角方向也呈现规律性。
12.1.4 CLS Token¶
设计动机:借鉴BERT的 [CLS] token设计,在patch序列前添加一个可学习的分类token。
工作流程: 1. 初始化一个可学习的向量 cls_token ∈ R^D 2. 拼接到patch embedding序列的最前面:[CLS, p1, p2, ..., pN] 3. 经过Transformer编码器后,CLS token聚合了全局信息 4. 取CLS token的输出接分类头
替代方案:全局平均池化(Global Average Pooling,GAP) - DeiT、BEiT等后续工作发现GAP效果与CLS token相当 - GAP无需额外参数,更简洁
# 方案1: CLS Token(ViT原始)
cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
output = transformer_output[:, 0] # 取CLS位置
# 方案2: Global Average Pooling(很多变体采用)
output = transformer_output[:, 1:].mean(dim=1) # 对所有patch取平均
12.1.5 Multi-Head Self-Attention (MHSA)¶
MHSA是Transformer的核心计算模块。对于输入序列 \(X \in \mathbb{R}^{N \times D}\):
多头机制:将 \(D\) 维分为 \(h\) 个头,每个头维度 \(d_k = D/h\),并行计算后拼接。
计算复杂度分析: - Self-Attention:\(O(N^2 \cdot D)\),\(N=196\) 时计算量可控 - 但\(N\)随分辨率平方增长:\(448 \times 448 \rightarrow N=784\),复杂度增4倍
class MultiHeadSelfAttention(nn.Module):
"""多头自注意力机制
实现要点:
1. QKV投影用一个线性层高效计算
2. 缩放因子 sqrt(d_k) 防止softmax梯度消失
3. 支持attention dropout
"""
def __init__(self, embed_dim=768, num_heads=12, attn_drop=0., proj_drop=0.):
super().__init__()
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.scale = self.head_dim ** -0.5 # 1/sqrt(d_k)
# 一个线性层同时计算Q, K, V(效率更高)
self.qkv = nn.Linear(embed_dim, embed_dim * 3)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(embed_dim, embed_dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x):
B, N, C = x.shape
# QKV投影并reshape为多头格式
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim) # 重塑张量形状
# permute将QKV维度提到最前面,unbind(0)沿第0维拆分为3个张量,避免3次独立线性层节省计算
qkv = qkv.permute(2, 0, 3, 1, 4) # (3, B, heads, N, head_dim)
q, k, v = qkv.unbind(0) # 各 (B, heads, N, head_dim)
# Scaled Dot-Product Attention
attn = (q @ k.transpose(-2, -1)) * self.scale # (B, heads, N, N)
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
# 加权求和并重组
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
12.1.6 Transformer Block(Encoder Layer)¶
ViT使用Pre-Norm结构(LayerNorm在Attention/MLP之前),区别于原始Transformer的Post-Norm:
x → LayerNorm → MHSA → + → LayerNorm → MLP → +
↑________________________| ↑___________________|
residual residual
MLP:两层全连接 + GELU激活,隐藏层维度通常为4倍embed_dim。
class MLP(nn.Module):
"""前馈神经网络(FFN)"""
def __init__(self, embed_dim=768, mlp_ratio=4.0, drop=0.):
super().__init__()
hidden_dim = int(embed_dim * mlp_ratio)
self.fc1 = nn.Linear(embed_dim, hidden_dim)
self.act = nn.GELU()
self.fc2 = nn.Linear(hidden_dim, embed_dim)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class TransformerBlock(nn.Module):
"""Pre-Norm Transformer Block"""
def __init__(self, embed_dim=768, num_heads=12, mlp_ratio=4.0,
drop=0., attn_drop=0.):
super().__init__()
self.norm1 = nn.LayerNorm(embed_dim)
self.attn = MultiHeadSelfAttention(embed_dim, num_heads, attn_drop, drop)
self.norm2 = nn.LayerNorm(embed_dim)
self.mlp = MLP(embed_dim, mlp_ratio, drop)
def forward(self, x):
x = x + self.attn(self.norm1(x)) # 残差连接
x = x + self.mlp(self.norm2(x)) # 残差连接
return x
12.1.7 完整ViT模型¶
class VisionTransformer(nn.Module):
"""完整的Vision Transformer实现
ViT-Base: embed_dim=768, depth=12, num_heads=12 (86M params)
ViT-Large: embed_dim=1024, depth=24, num_heads=16 (307M params)
ViT-Huge: embed_dim=1280, depth=32, num_heads=16 (632M params)
"""
def __init__(self, img_size=224, patch_size=16, in_channels=3, num_classes=1000,
embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0,
drop_rate=0., attn_drop_rate=0.):
super().__init__()
self.num_classes = num_classes
self.embed_dim = embed_dim
# 1. Patch Embedding
self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
num_patches = self.patch_embed.n_patches
# 2. CLS Token
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
# 3. Position Embedding (可学习的1D)
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
self.pos_drop = nn.Dropout(drop_rate)
# 4. Transformer Encoder
self.blocks = nn.Sequential(*[
TransformerBlock(embed_dim, num_heads, mlp_ratio, drop_rate, attn_drop_rate)
for _ in range(depth)
])
# 5. Classification Head
self.norm = nn.LayerNorm(embed_dim)
self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
# 参数初始化
self._init_weights()
def _init_weights(self):
nn.init.trunc_normal_(self.pos_embed, std=0.02)
nn.init.trunc_normal_(self.cls_token, std=0.02)
self.apply(self._init_module_weights)
def _init_module_weights(self, m):
if isinstance(m, nn.Linear): # isinstance检查类型
nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.LayerNorm):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
def forward(self, x):
B = x.shape[0]
# Step 1: Patch Embedding → (B, N, D)
x = self.patch_embed(x)
# Step 2: Prepend CLS Token → (B, N+1, D)
cls_tokens = self.cls_token.expand(B, -1, -1)
x = torch.cat([cls_tokens, x], dim=1) # torch.cat沿已有维度拼接张量
# Step 3: Add Position Embedding
x = x + self.pos_embed
x = self.pos_drop(x)
# Step 4: Transformer Encoder
x = self.blocks(x)
# Step 5: Classification
x = self.norm(x)
cls_output = x[:, 0] # 取CLS token输出
logits = self.head(cls_output)
return logits
# ViT模型变体配置
def vit_base_patch16_224(**kwargs): # *args接收任意位置参数,**kwargs接收任意关键字参数
return VisionTransformer(
patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
def vit_large_patch16_224(**kwargs):
return VisionTransformer(
patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
def vit_huge_patch14_224(**kwargs):
return VisionTransformer(
patch_size=14, embed_dim=1280, depth=32, num_heads=16, **kwargs)
12.1.8 ViT模型族参数对比¶
| 模型 | Patch Size | Embed Dim | Depth | Heads | Params | ImageNet Top-1 |
|---|---|---|---|---|---|---|
| ViT-S/16 | 16 | 384 | 12 | 6 | 22M | 79.9% |
| ViT-B/16 | 16 | 768 | 12 | 12 | 86M | 84.5% |
| ViT-L/16 | 16 | 1024 | 24 | 16 | 307M | 87.1% |
| ViT-H/14 | 14 | 1280 | 32 | 16 | 632M | 88.6% |
| ViT-G/14 | 14 | 1664 | 48 | 16 | 1843M | 90.5% |
注:Top-1精度基于JFT-300M预训练 + ImageNet微调
12.2 ViT变体演进¶
12.2.1 DeiT — 数据高效的知识蒸馏¶
论文:Training data-efficient image transformers & distillation through attention(Touvron et al., ICML 2021)
核心动机:ViT需要JFT-300M级别数据才能超越CNN,DeiT证明了仅用ImageNet-1K也可以训练出强大的ViT。
关键创新: 1. 知识蒸馏Token:在CLS token之外,额外添加一个distillation token - 蒸馏token向教师模型(RegNetY-16GF)学习 - 最终预测 = CLS预测 + 蒸馏token预测的加权平均 2. 强数据增强:RandAugment、Mixup、CutMix、Random Erasing 3. 正则化策略:Stochastic Depth、Repeated Augmentation
输入序列: [CLS_token, distill_token, patch_1, patch_2, ..., patch_N]
分类损失: CrossEntropy(CLS_output, label)
蒸馏损失: KL_Div(distill_output, teacher_output) 或 CrossEntropy(distill_output, teacher_hard_label)
DeiT-III(2022更新):引入3-Augment策略(灰度+翻转+SolarizeAdd),比DeiT效果更好。
| 模型 | ImageNet Top-1 | 预训练数据 |
|---|---|---|
| DeiT-S | 79.8% | ImageNet-1K |
| DeiT-B | 81.8% | ImageNet-1K |
| DeiT-B (蒸馏) | 83.4% | ImageNet-1K |
| DeiT-III-L | 87.7% | ImageNet-21K |
12.2.2 BEiT — BERT式视觉预训练¶
论文:BEiT: BERT Pre-Training of Image Transformers(Bao et al., ICLR 2022)
核心思想:借鉴BERT的Masked Language Modeling (MLM),提出Masked Image Modeling (MIM)。
预训练流程: 1. 使用离散VAE(dVAE)将图像patch编码为离散visual tokens 2. 随机mask约40%的patch 3. 让ViT根据未mask的patch预测被mask位置的visual token
原始图像 → Patch序列 [p1, p2, ..., p196]
↓ 随机mask
掩码图像 → [p1, [MASK], p3, [MASK], p5, ...]
↓ ViT编码
预测被mask位置的visual token ID
BEiT v2(2022):使用VQ-KD(向量量化知识蒸馏)替代dVAE,获得更好的视觉词典。
BEiT-3(2023):统一的多模态预训练,用Multiway Transformer同时处理图像、文本和图像-文本对。
12.2.3 MAE — 掩码自编码器¶
论文:Masked Autoencoders Are Scalable Vision Learners(He et al., CVPR 2022)
核心创新:极高掩码率(75%)+ 非对称编码器-解码器架构。
与BEiT的区别: - MAE直接在像素空间重建(无需visual tokenizer) - 编码器只处理可见patch(25%),极大节省计算 - 解码器轻量(仅用于预训练),下游任务只用编码器
原始图像 196个patch
↓ 随机mask 75%
可见patch (49个) → 【重型Encoder】 → 编码特征
↓ 加入mask tokens和位置编码
全部tokens (196个) → 【轻量Decoder】 → 重建像素
↓
MSE Loss (仅在被mask的位置计算)
为什么75%掩码率有效: - 图像有大量空间冗余(相邻区域高度相关) - 高掩码率迫使模型学习高层语义,而非简单的插值 - 节省预训练时间(编码器只处理25%的token)
12.2.4 EVA — 大规模视觉基础模型¶
论文:EVA: Exploring the Limits of Masked Visual Representation Learning at Scale(Fang et al., CVPR 2023)
关键创新: - 使用CLIP的视觉特征作为MIM的重建目标(而非像素或离散token) - 证明MIM + CLIP特征蒸馏是scaling ViT的有效方案
EVA系列演进: | 模型 | 参数量 | ImageNet Top-1 | 核心创新 | |------|--------|---------------|---------| | EVA | 1.0B | 89.6% | CLIP特征作为MIM目标 | | EVA-02 | 304M | 90.0% | RoPE位置编码 + SwiGLU FFN | | EVA-CLIP | 5.0B | 最强开源CLIP | 融合EVA与CLIP训练 |
12.2.5 SigLIP — 更好的视觉-语言对齐¶
论文:Sigmoid Loss for Language Image Pre-Training(Zhai et al., ICCV 2023)
核心改进:用Sigmoid loss替代CLIP的Softmax (InfoNCE) loss。
Softmax (CLIP) vs Sigmoid (SigLIP): - CLIP:需要全局归一化,依赖大batch size(32K+),分布式训练需要all-gather操作 - SigLIP:逐对独立计算sigmoid,无需全局归一化,对batch size不敏感,可更高效地分布式训练
SigLIP在2024-2025已成为VLM(如PaLI-Gemma、LLaVA-1.5)的标配视觉编码器。
12.3 Swin Transformer深入¶
12.3.1 设计动机¶
ViT的全局自注意力存在两个问题: 1. 计算复杂度:\(O(N^2)\) 随分辨率平方增长,难以处理高分辨率图像 2. 缺乏多尺度特征:单一分辨率的token序列,不适合密集预测任务(检测/分割)
Swin Transformer通过层级设计 + 窗口注意力解决这两个问题。
12.3.2 层级设计(Hierarchical Architecture)¶
Swin采用类似CNN的4阶段层级结构,每阶段通过Patch Merging下采样:
Stage 1: H/4 × W/4, dim=C (56×56, 96维)
↓ Patch Merging (2×2 → 1, dim×2)
Stage 2: H/8 × W/8, dim=2C (28×28, 192维)
↓ Patch Merging
Stage 3: H/16 × W/16, dim=4C (14×14, 384维)
↓ Patch Merging
Stage 4: H/32 × W/32, dim=8C (7×7, 768维)
Patch Merging:将相邻 \(2 \times 2\) 个patch拼接后通过线性层降维,类似CNN的stride-2卷积。
12.3.3 窗口注意力(Window Attention)¶
将feature map划分为 \(M \times M\)(默认7×7)的不重叠窗口,在每个窗口内部做自注意力:
- 计算复杂度:\(O(M^2 \cdot N)\),其中 \(N\) 为总token数 → 线性复杂度
- 对比全局注意力 \(O(N^2)\),窗口注意力在高分辨率下优势巨大
12.3.4 移位窗口(Shifted Window)¶
问题:窗口注意力导致窗口之间没有信息交互。
解决方案:交替使用普通窗口和移位窗口: - 奇数层:常规窗口划分 - 偶数层:窗口向右下移动 \(M/2\) 个像素后重新划分
高效实现:通过cyclic shift + mask实现,避免实际创建更多窗口:
class WindowAttention(nn.Module):
"""Swin Transformer的窗口注意力(含相对位置偏置)"""
def __init__(self, dim, window_size, num_heads):
super().__init__()
self.dim = dim
self.window_size = window_size # (Wh, Ww)
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3)
self.proj = nn.Linear(dim, dim)
# 相对位置偏置表
# (2*Wh-1) * (2*Ww-1) 个可能的相对位置
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2*window_size[0]-1) * (2*window_size[1]-1), num_heads)
)
nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)
# 计算每个token对的相对位置索引
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing='ij')) # (2, Wh, Ww) # torch.stack沿新维度拼接张量
coords_flatten = torch.flatten(coords, 1) # (2, Wh*Ww)
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # (2, N, N)
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # (N, N, 2)
relative_coords[:, :, 0] += window_size[0] - 1
relative_coords[:, :, 1] += window_size[1] - 1
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
relative_position_index = relative_coords.sum(-1) # (N, N)
self.register_buffer("relative_position_index", relative_position_index)
def forward(self, x, mask=None):
B_, N, C = x.shape # B_为窗口总数
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads)
qkv = qkv.permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)
attn = (q @ k.transpose(-2, -1)) * self.scale
# 添加相对位置偏置
# 用index从偏置表查找→view(N,N,heads)→permute为(heads,N,N)→unsqueeze(0)加batch维以广播加到注意力分数上
relative_position_bias = self.relative_position_bias_table[
self.relative_position_index.view(-1)
].view(N, N, -1).permute(2, 0, 1).contiguous()
attn = attn + relative_position_bias.unsqueeze(0) # unsqueeze增加一个维度
# 移位窗口的mask
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N)
attn = attn + mask.unsqueeze(1).unsqueeze(0) # 被mask位置设为-inf
attn = attn.view(-1, self.num_heads, N, N)
attn = attn.softmax(dim=-1)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = self.proj(x)
return x
def window_partition(x, window_size):
"""将feature map划分为窗口"""
B, H, W, C = x.shape
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows # (num_windows*B, window_size, window_size, C)
def window_reverse(windows, window_size, H, W):
"""将窗口还原为feature map"""
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
12.3.5 Swin Transformer模型族¶
| 模型 | Embed Dim | Depths | Heads | Params | ImageNet Top-1 |
|---|---|---|---|---|---|
| Swin-T | 96 | [2,2,6,2] | [3,6,12,24] | 29M | 81.3% |
| Swin-S | 96 | [2,2,18,2] | [3,6,12,24] | 50M | 83.0% |
| Swin-B | 128 | [2,2,18,2] | [4,8,16,32] | 88M | 83.5% |
| Swin-L | 192 | [2,2,18,2] | [6,12,24,48] | 197M | 86.4% |
Swin v2(2022):引入log-CPB(对数连续位置偏置)和余弦注意力,支持更高分辨率和更大模型。
12.4 高效ViT¶
12.4.1 EfficientViT (MIT)¶
论文:EfficientViT: Lightweight Multi-Scale Attention for High-Resolution Dense Prediction(Cai et al., ICCV 2023)
核心创新: - 多尺度线性注意力:用线性注意力(\(O(N)\))替代标准softmax注意力(\(O(N^2)\)) - 级联分组注意力:将不同head分配给不同的特征分片,减少冗余计算 - 在分割任务(Cityscapes、ADE20K)上以极低延迟达到SOTA
| 模型 | Params | FLOPs | ImageNet Top-1 | GPU延迟 |
|---|---|---|---|---|
| EfficientViT-B1 | 9.1M | 0.52G | 79.4% | 0.3ms |
| EfficientViT-B2 | 24.3M | 1.6G | 82.1% | 0.6ms |
| EfficientViT-B3 | 48.6M | 4.0G | 83.5% | 1.1ms |
12.4.2 MobileViT¶
论文:MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer(Mehta & Rastegari, ICLR 2022)
设计理念:结合MobileNet的轻量CNN和ViT的全局注意力。
- 用MobileNetV2的Inverted Residual Block提取局部特征
- 在关键阶段插入Transformer Block捕捉全局依赖
- 参数量仅5-6M,适合移动端部署
12.4.3 TinyViT¶
论文:TinyViT: Fast Pretraining Distillation for Small Vision Transformers(Wu et al., ECCV 2022)
方法:通过大模型蒸馏快速预训练小ViT(5M-21M参数),在ImageNet上达到与大模型相当的精度。
| 模型 | Params | ImageNet Top-1 | 用途 |
|---|---|---|---|
| TinyViT-5M | 5.4M | 79.1% | 移动端推理 |
| TinyViT-11M | 11M | 81.5% | 边缘设备 |
| TinyViT-21M | 21M | 83.2% | 轻量服务端 |
12.5 ViT在下游任务中的应用¶
12.5.1 目标检测 — DETR¶
论文:End-to-End Object Detection with Transformers(Carion et al., ECCV 2020)
革命性意义:去掉了anchor、NMS等手工设计组件,用Transformer实现端到端检测。
架构:
图像 → CNN Backbone → Transformer Encoder → Transformer Decoder → 预测框+类别
↑
Object Queries (可学习)
匹配策略:用匈牙利算法做预测框与GT的二分匹配(Set Prediction)。
演进:DETR → Deformable DETR → DINO → Co-DETR → RT-DETR(实时版)
12.5.2 图像分割 — SAM¶
论文:Segment Anything(Kirillov et al., ICCV 2023)
SAM (Segment Anything Model) 是视觉基础模型的里程碑: - Image Encoder:ViT-H(MAE预训练),提取图像特征 - Prompt Encoder:编码点击、框、文本等提示 - Mask Decoder:轻量Transformer解码器,生成分割掩码
SAM 2(2024):扩展到视频分割,引入Memory Mechanism追踪时序信息。
12.5.3 图像生成 — DiT¶
论文:Scalable Diffusion Models with Transformers(Peebles & Xie, ICCV 2023)
DiT (Diffusion Transformer):用ViT替代U-Net作为扩散模型的去噪网络。
核心设计: - 输入:带噪声的latent patch序列 + 时间步embedding + 类别embedding - 用AdaLN-Zero(自适应LayerNorm)注入条件信息 - DiT是Sora的基础架构
Noisy Latent → Patchify → Transformer Blocks (with AdaLN) → Unpatchify → Denoised Latent
↑ timestep + class label
12.5.4 其他重要应用¶
| 任务 | 代表模型 | 核心用法 |
|---|---|---|
| 语义分割 | SegFormer, Mask2Former | ViT作为backbone + 分割解码器 |
| 深度估计 | DPT, Depth Anything | ViT编码器 + 多尺度解码器 |
| 视频理解 | VideoMAE, TimeSformer | 时空patch + 视频Transformer |
| 点云处理 | Point-BERT, Point-MAE | 3D点云patch化 + ViT |
| 医学影像 | MedViT, SAM-Med | 预训练ViT + 领域微调 |
12.6 ViT vs CNN 对比分析¶
12.6.1 归纳偏置(Inductive Bias)¶
| 特性 | CNN | ViT |
|---|---|---|
| 局部性 | 卷积核限制感受野,天然捕捉局部模式 | 全局注意力,需要从数据中学习局部性 |
| 平移等变性 | 卷积权重共享保证平移等变 | 无此先验,依赖数据增强和位置编码 |
| 层级特征 | pooling自然构建多尺度 | 原始ViT单尺度,需Swin等设计引入 |
| 训练效率 | 归纳偏置帮助小数据学习 | 小数据下表现差,大数据下优势显现 |
12.6.2 数据效率¶
结论:ViT是一个更"通用"但也更"饥渴"的学习器——在足够数据下可以学到比CNN更好的表示。
12.6.3 计算复杂度¶
对于输入分辨率 \(H \times W\)、patch数 \(N = HW/P^2\):
| 操作 | 复杂度 | 说明 |
|---|---|---|
| ViT Self-Attention | \(O(N^2 \cdot D)\) | 随分辨率二次增长 |
| Swin Window Attention | \(O(M^2 \cdot N \cdot D)\) | 线性于N |
| CNN 3×3 Conv | \(O(9 \cdot C^2 \cdot N)\) | 线性于N |
| Linear Attention | \(O(N \cdot D^2)\) | 线性于N |
12.6.4 实践建议¶
| 场景 | 推荐架构 | 原因 |
|---|---|---|
| 预训练数据充足 | ViT/EVA | Scaling表现最佳 |
| 目标检测/分割 | Swin/ViTDet | 多尺度特征+高分辨率 |
| 移动端部署 | MobileViT/EfficientViT | 轻量高效 |
| 中小数据集 | ConvNeXt/CNN+ViT混合 | 归纳偏置有帮助 |
| 多模态应用 | ViT+SigLIP | 与语言模型对接方便 |
12.7 使用预训练ViT模型¶
12.7.1 HuggingFace推理¶
import torch
from transformers import ViTForImageClassification, ViTImageProcessor
from PIL import Image
# 加载预训练模型和处理器
model_name = 'google/vit-base-patch16-224'
model = ViTForImageClassification.from_pretrained(model_name)
processor = ViTImageProcessor.from_pretrained(model_name)
model.eval() # eval()评估模式
# 推理
image = Image.open('image.jpg').convert('RGB')
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad(): # 禁用梯度计算,节省内存
outputs = model(**inputs)
logits = outputs.logits
predicted_class_idx = logits.argmax(-1).item() # 将单元素张量转为Python数值
predicted_class = model.config.id2label[predicted_class_idx]
confidence = torch.softmax(logits, dim=-1).max().item()
print(f"预测类别: {predicted_class}, 置信度: {confidence:.4f}")
12.7.2 timm库使用¶
import timm
import torch
from PIL import Image
from timm.data import resolve_data_config, create_transform
# timm提供了几乎所有主流ViT变体
model = timm.create_model('vit_base_patch16_224.augreg_in21k_ft_in1k', pretrained=True)
model.eval()
# 自动获取模型对应的数据预处理
config = resolve_data_config(model.pretrained_cfg)
transform = create_transform(**config)
image = Image.open('image.jpg').convert('RGB')
input_tensor = transform(image).unsqueeze(0)
with torch.no_grad():
output = model(input_tensor)
probabilities = torch.softmax(output, dim=-1)
top5_prob, top5_idx = probabilities.topk(5)
for i in range(5):
print(f"Top-{i+1}: class={top5_idx[0][i].item()}, prob={top5_prob[0][i].item():.4f}")
# 列出所有可用的ViT模型
vit_models = timm.list_models('vit_*', pretrained=True)
print(f"可用ViT模型数量: {len(vit_models)}")
12.7.3 ViT微调示例¶
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import timm
# 加载预训练ViT并修改分类头
model = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=10)
# 冻结backbone,只训练分类头(适合小数据集)
for name, param in model.named_parameters():
if 'head' not in name:
param.requires_grad = False
# 或者:全模型微调(适合中等数据集),使用较小学习率
optimizer = torch.optim.AdamW([
{'params': model.head.parameters(), 'lr': 1e-3}, # 分类头大学习率
{'params': [p for n, p in model.named_parameters()
if 'head' not in n and p.requires_grad], 'lr': 1e-5} # backbone小学习率
], weight_decay=0.05)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
# 训练循环
def train_one_epoch(model, dataloader, optimizer, criterion, device):
model.train()
total_loss, correct, total = 0, 0, 0
for images, labels in dataloader:
images, labels = images.to(device), labels.to(device) # 移至GPU/CPU
outputs = model(images)
loss = criterion(outputs, labels)
optimizer.zero_grad() # 清零梯度
loss.backward() # 反向传播计算梯度
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step() # 更新参数
total_loss += loss.item()
correct += (outputs.argmax(1) == labels).sum().item()
total += labels.size(0)
return total_loss / len(dataloader), correct / total
12.8 练习题¶
基础题¶
- 简答题:
- ViT的Patch Embedding相当于什么操作?为什么用Conv2d实现?
Patch Embedding相当于用stride=patch_size的Conv2d对图像做不重叠卷积,将每个patch投影为一个向量。用Conv2d实现而非先reshape再线性层,是因为Conv2d能直接完成“切块+投影”两步,硬件利用率更高、实现更简洁。
- 解释Pre-Norm和Post-Norm的区别,ViT使用哪种?
Post-Norm在残差连接之后做LayerNorm(原始Transformer),训练时梯度不稳定,需要学习率warmup。Pre-Norm在残差连接之前做LayerNorm,梯度可以直接通过残差路径流动,训练更稳定,适合深层网络。ViT使用Pre-Norm。
-
为什么ViT需要位置编码?可学习1D和2D位置编码效果有区别吗?
Transformer的Self-Attention是置换不变的(permutation-invariant),无法感知token的空间位置,因此必须添加位置编码。实验表明可学习1D和2D位置编码效果差别很小(约±0.1%),因为1D编码在训练中可以自动学到类似2D的结构信息。
-
计算题:
- 计算ViT-B/16在224×224输入时的patch数量、序列长度和Self-Attention的FLOPs。
Patch数量 = \((224/16)^2 = 14 \times 14 = 196\);加上CLS token后序列长度 \(N = 196 + 1 = 197\)。ViT-B的隐藏维度 \(D = 768\)。单层Self-Attention的FLOPs ≈ \(2 \times 4N^2D = 2 \times 4 \times 197^2 \times 768 \approx 238M\)(包含Q/K/V投影和注意力计算)。ViT-B有12层,总注意力FLOPs ≈ \(12 \times 238M \approx 2.86G\)。
- 如果将输入分辨率改为384×384,序列长度和计算量如何变化?
Patch数量 = \((384/16)^2 = 24 \times 24 = 576\),序列长度 \(N = 577\)。相比224分辨率的197,序列长度增加约2.93倍。由于Self-Attention计算量与 \(N^2\) 成正比,注意力FLOPs增加约 \((577/197)^2 \approx 8.58\) 倍。总模型FLOPs增加约3倍左右(因为MLP部分仅线性增长)。
进阶题¶
- 编程题:
- 从零实现一个完整的ViT(包含Patch Embedding、Position Embedding、MHA、MLP、CLS Token)。
-
使用timm加载预训练Swin-T,在CIFAR-10上微调。
-
分析题:
- 比较MAE和BEiT的预训练策略,各有什么优劣?
MAE随机掩码图像75%的patch,编码器只处理可见patch,解码器重建原始像素。优势:预训练效率高(3-4×加速)、无需额外组件、概念简洁。劣势:重建低层像素可能不够语义化。BEiT先用dVAE将图像编码为离散token,然后预测被掩码位置的token ID。优势:预测目标更语义化。劣势:需要预训练dVAE tokenizer,流程更复杂。
- 为什么DiT能替代U-Net成为扩散模型的backbone?
①Scaling能力更强:Transformer可轻松堆叠层数,遵循Scaling Law,U-Net的残差块+跳连结构难以扩展;②架构简洁,无需精心设计编解码器层数和通道数;③条件注入灵活,AdaLN-Zero可优雅地注入timestep和class/text条件;④与LLM生态统一,方便融入文本条件生成;⑤可利用MAE等预训练ViT权重初始化。
12.9 面试准备¶
大厂面试题¶
Q1: ViT相比CNN有什么优势和劣势?在什么场景下选择ViT?
参考答案: 优势: - 全局感受野:第一层就能看到所有patch之间的关系,CNN需要很深才能获得大感受野 - 可扩展性强:参数量和数据量增加时性能持续提升(Scaling Law) - 架构统一:与NLP的Transformer共享架构,便于多模态融合 - 预训练迁移好:MAE/CLIP等大规模预训练后迁移效果出色
劣势: - 数据饥渴:小数据集上不如CNN(缺乏locality、translation equivariance等归纳偏置) - 计算复杂度高:Self-Attention的O(N²)使高分辨率输入开销大 - 对分辨率敏感:位置编码与输入尺寸绑定,分辨率变化需要插值
选择建议:预训练数据充足时首选ViT;小数据用CNN或混合架构;移动端用EfficientViT/MobileViT;密集预测用Swin/ViTDet。
Q2: 请解释Swin Transformer的移位窗口机制及其作用
参考答案: Swin Transformer在相邻层交替使用两种窗口配置: - 常规窗口(W-MSA):将feature map均匀划分为 \(M \times M\) 窗口 - 移位窗口(SW-MSA):将窗口向右下偏移 \(M/2\),跨越常规窗口边界
作用: 1. 允许相邻窗口之间交换信息,避免信息孤岛 2. 保持线性计算复杂度(只在窗口内做注意力) 3. 高效实现:用cyclic shift + attention mask,无需创建额外窗口
Q3: MAE为什么要用75%这么高的掩码率?
参考答案: 1. 图像的空间冗余高:相邻像素高度相关,低掩码率下模型可以轻松"插值"恢复,学不到高层语义 2. 高掩码率迫使模型理解语义:看到极少信息时必须"推理"内容是什么 3. 计算效率:编码器只处理25%的visible tokens,预训练速度提升3-4倍 4. 与NLP对比:BERT用15%的掩码率,因为语言信息密度高,而图像信息密度低
Q4: DETR如何实现端到端目标检测?相比传统检测器有什么优势?
参考答案: DETR流程: 1. CNN backbone提取特征 → Transformer Encoder全局建模 2. \(N\) 个可学习的Object Queries经Transformer Decoder交叉注意力 3. 每个query直接输出(class, box)预测 4. 用匈牙利算法做预测与GT的二分匹配,计算集合级别损失
优势:去掉了anchor设计、NMS后处理、FPN等手工组件,简化流程 劣势:训练收敛慢(500 epochs)、小目标检测弱 → Deformable DETR解决了这些问题
Q5: DiT为什么能替代U-Net成为扩散模型的标配backbone?
参考答案: 1. Scaling能力:U-Net的残差块+跳跃连接结构不易扩展,DiT基于Transformer可以轻松堆叠层数,遵循Scaling Law 2. 架构简洁:无需精心设计U-Net的编码器-解码器层数和通道数 3. 条件注入灵活:AdaLN-Zero可以优雅地注入timestep和class/text条件 4. 与LLM生态统一:方便融入文本条件生成(如Sora) 5. 预训练兼容:可以利用MAE等预训练的ViT权重初始化
Q6: SigLIP相比CLIP有什么改进?为什么成为VLM视觉编码器的标配?
参考答案: 改进:用Sigmoid loss替代InfoNCE (Softmax) loss - InfoNCE需要全batch的负样本做归一化 → 需要超大batch size和all-gather - Sigmoid loss对每个正/负样本对独立计算 → 不依赖batch size,分布式训练更高效
成为标配的原因: - 训练更稳定、效率更高 - 相同参数量下zero-shot性能更好 - 在PaLI-Gemma、LLaVA-Next等VLM中广泛使用
Q7: 如何将ViT的位置编码从224×224分辨率迁移到384×384?
参考答案: 使用双线性插值调整位置编码维度: 1. 将 \(14 \times 14\) 的2D位置编码reshape 2. 用 F.interpolate 双线性插值到 \(24 \times 24\) 3. 再展平回1D序列 4. CLS token的位置编码保持不变
pos_embed_2d = pos_embed[:, 1:].reshape(1, 14, 14, D).permute(0, 3, 1, 2)
pos_embed_2d = F.interpolate(pos_embed_2d, size=(24, 24), mode='bicubic', align_corners=False) # F.xxx PyTorch函数式API
new_pos_embed = torch.cat([pos_embed[:, :1], pos_embed_2d.flatten(2).transpose(1, 2)], dim=1)
Q8: 比较ViT、Swin Transformer和ConvNeXt三个架构的设计哲学
参考答案:
| 特性 | ViT | Swin Transformer | ConvNeXt |
|---|---|---|---|
| 设计哲学 | 纯Transformer,最小视觉先验 | Transformer + CNN层级设计 | 纯CNN + Transformer训练技巧 |
| 注意力范围 | 全局 | 局部窗口 | 局部(7×7深度可分离卷积) |
| 多尺度特征 | 无(单尺度) | 有(4阶段层级) | 有(4阶段层级) |
| 计算复杂度 | O(N²) | O(N) | O(N) |
| 最佳使用场景 | 基础模型预训练 | 密集预测任务 | 需要CNN生态兼容 |
| 代表应用 | CLIP/MAE/DiT | 检测/分割backbone | 替代ResNet |
12.10 关键论文列表¶
必读论文¶
| 年份 | 论文 | 核心贡献 |
|---|---|---|
| 2020 | ViT: An Image is Worth 16x16 Words | 开创性地将纯Transformer用于视觉 |
| 2021 | DeiT: Training data-efficient image transformers | 知识蒸馏+强数据增强训练ViT |
| 2021 | Swin Transformer: Hierarchical Vision Transformer | 移位窗口+层级结构,适配密集预测 |
| 2021 | BEiT: BERT Pre-Training of Image Transformers | Masked Image Modeling预训练 |
| 2022 | MAE: Masked Autoencoders Are Scalable Vision Learners | 75%掩码+像素重建的高效预训练 |
| 2022 | ConvNeXt: A ConvNet for the 2020s | 用Transformer技巧现代化CNN |
扩展论文¶
| 年份 | 论文 | 核心贡献 |
|---|---|---|
| 2020 | DETR: End-to-End Object Detection with Transformers | Transformer端到端检测 |
| 2023 | EVA-02: A Visual Representation for Neon Genesis | RoPE + SwiGLU的强力ViT |
| 2023 | SigLIP: Sigmoid Loss for Language Image Pre-Training | 更高效的视觉-语言对比学习 |
| 2023 | SAM: Segment Anything | ViT驱动的视觉分割基础模型 |
| 2023 | DiT: Scalable Diffusion Models with Transformers | ViT替代U-Net做扩散模型 |
| 2023 | DINOv2: Learning Robust Visual Features | 自监督ViT视觉基础模型 |
| 2024 | SAM 2: Segment Anything in Images and Videos | ViT+Memory的视频分割 |
12.11 本章小结¶
核心知识点¶
- ViT原始架构:Patch Embedding → CLS Token + Position Embedding → Transformer Encoder → Classification Head
- ViT变体:DeiT(蒸馏)、BEiT(MIM)、MAE(掩码自编码)、EVA(大规模)、SigLIP(高效对比学习)
- Swin Transformer:层级设计 + 窗口注意力 + 移位窗口 + 相对位置偏置
- 高效ViT:EfficientViT、MobileViT、TinyViT——面向端侧部署
- 下游应用:DETR(检测)、SAM(分割)、DiT(生成)
- ViT vs CNN:归纳偏置、数据效率、计算复杂度的全面对比
下一步¶
下一章:13-多模态学习.md - 学习CLIP、BLIP、LLaVA等视觉-语言多模态模型
恭喜完成第12章! 🎉 视觉Transformer已成为现代CV的基石——从分类、检测、分割到生成,ViT无处不在。