03 - Vision Transformers¶
⚠️ Timeliness note: this chapter touches on cutting-edge models, pricing, and leaderboards, all of which can change quickly between versions; always defer to the original papers, official release pages, and API documentation.
Estimated time: ~6-8 hours  Difficulty: ⭐⭐⭐⭐ upper-intermediate  Prerequisites: Transformer architecture, CNN basics, PyTorch  Goals: understand the Vision Transformer (ViT) family of models and how Transformers are applied to image classification
Contents¶
- 1. From CNN to ViT
- 2. Vision Transformer (ViT)
- 3. DeiT: Data-Efficient Training
- 4. Swin Transformer
- 5. MAE: Masked Autoencoders
- 6. Hands-On: ViT Image Classification
- 7. Exercises and Self-Check
1. From CNN to ViT¶
1.1 Limitations of CNNs¶
- Local receptive field: a convolution kernel only sees a local region, so enlarging the receptive field requires stacking many layers
- Translation-equivariant, but without global modeling capability
- Strong inductive bias: well suited to images, but limits flexibility
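The first point can be made concrete with a one-line recurrence: for stride-1 3x3 convolutions, the receptive field grows by only 2 pixels per layer. A quick sketch (not tied to any particular network):

```python
# Receptive field of stacked stride-1 3x3 convs: rf_{k+1} = rf_k + (kernel_size - 1)
rf = 1
for _ in range(10):
    rf += 3 - 1
print(rf)  # after 10 layers the receptive field is only 21 pixels wide
```

Covering a 224-pixel input this way would take over a hundred layers, which is why CNNs lean on pooling and strided convolutions to grow the receptive field faster.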
1.2 Why Transformers work for vision¶
Self-attention in a Transformer is inherently global: every position can attend to every other position. The key question is how to turn an image into a sequence.
2. Vision Transformer (ViT)¶
2.1 Core idea¶
Dosovitskiy et al. (2020): "An Image is Worth 16x16 Words"
Split the image into fixed-size patches, flatten each patch and project it linearly to obtain a token sequence, add positional embeddings and a [CLS] token, and feed the result into a standard Transformer encoder.
2.2 Patch Embedding¶
An \(H \times W\) image is split into \(N = \frac{H \times W}{P^2}\) patches of size \(P \times P\). Each flattened patch is projected and the positional embedding is added:
\[ z_0 = [x_{\text{class}};\, x_p^1 E;\, x_p^2 E;\, \dots;\, x_p^N E] + E_{pos} \]
where \(E \in \mathbb{R}^{(P^2 \cdot C) \times D}\) is the linear projection and \(E_{pos} \in \mathbb{R}^{(N+1) \times D}\) is the positional embedding.
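Plugging in the standard numbers (224x224 input, 16x16 patches, 3 channels) as a quick check:

```python
H = W = 224   # image size
P = 16        # patch size
C = 3         # channels
N = (H * W) // (P * P)   # number of patches
patch_dim = P * P * C    # flattened patch dimension (input size of the projection E)
print(N, patch_dim)      # 196 768
```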
import torch
import torch.nn as nn
import torch.nn.functional as F

class PatchEmbedding(nn.Module):
    """Split the image into patches and project them linearly."""
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        # Implement patch embedding with a convolution
        # (equivalent to split + linear projection, but more efficient)
        self.proj = nn.Conv2d(
            in_channels, embed_dim,
            kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        """
        x: (batch, C, H, W)
        return: (batch, num_patches, embed_dim)
        """
        x = self.proj(x)       # (batch, embed_dim, H/P, W/P)
        x = x.flatten(2)       # (batch, embed_dim, num_patches)
        x = x.transpose(1, 2)  # (batch, num_patches, embed_dim)
        return x
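A standalone sanity check of the conv-as-patch-embedding trick; the dimensions here are chosen to match ViT-Tiny (192-dim embeddings):

```python
import torch
import torch.nn as nn

# kernel_size == stride == patch size: each patch is touched exactly once
proj = nn.Conv2d(3, 192, kernel_size=16, stride=16)
x = torch.randn(2, 3, 224, 224)
tokens = proj(x).flatten(2).transpose(1, 2)  # (B, 192, 14, 14) -> (B, 196, 192)
print(tokens.shape)  # torch.Size([2, 196, 192])
```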
2.3 Full ViT implementation¶
class ViTBlock(nn.Module):
    """ViT Transformer block (Pre-Norm)."""
    def __init__(self, embed_dim, num_heads, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(embed_dim)
        mlp_hidden = int(embed_dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, mlp_hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_hidden, embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        # Pre-Norm: normalize before each sub-layer, then add the residual
        normed = self.norm1(x)
        attn_out, _ = self.attn(normed, normed, normed)
        x = x + attn_out
        x = x + self.mlp(self.norm2(x))
        return x
class VisionTransformer(nn.Module):
    """Vision Transformer (ViT)."""
    def __init__(self, img_size=224, patch_size=16, in_channels=3, num_classes=1000,
                 embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        num_patches = self.patch_embed.num_patches
        # [CLS] token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # Learnable positional embedding
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(dropout)
        # Transformer encoder
        self.blocks = nn.Sequential(*[
            ViTBlock(embed_dim, num_heads, mlp_ratio, dropout)
            for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)
        # Classification head
        self.head = nn.Linear(embed_dim, num_classes)
        # Initialization
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)

    def forward(self, x):
        """x: (batch, C, H, W) → (batch, num_classes)"""
        batch_size = x.size(0)
        # 1. Patch embedding
        x = self.patch_embed(x)  # (batch, num_patches, embed_dim)
        # 2. Prepend the [CLS] token
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)  # (batch, num_patches+1, embed_dim)
        # 3. Positional embedding
        x = self.pos_drop(x + self.pos_embed)
        # 4. Transformer encoder
        x = self.blocks(x)
        x = self.norm(x)
        # 5. Classify from the [CLS] token's output
        cls_output = x[:, 0]
        logits = self.head(cls_output)
        return logits
# ===== Standard ViT configurations and a quick smoke test =====
def vit_tiny(num_classes=100):
    return VisionTransformer(embed_dim=192, depth=12, num_heads=3, num_classes=num_classes)

def vit_small(num_classes=100):
    return VisionTransformer(embed_dim=384, depth=12, num_heads=6, num_classes=num_classes)

def vit_base(num_classes=100):
    return VisionTransformer(embed_dim=768, depth=12, num_heads=12, num_classes=num_classes)

model = vit_tiny(num_classes=10)
x = torch.randn(4, 3, 224, 224)
out = model(x)
print(f"ViT-Tiny output: {out.shape}")
print(f"Parameter count: {sum(p.numel() for p in model.parameters()):,}")
3. DeiT: Data-Efficient Training¶
3.1 Training strategy¶
Touvron et al. (2021) showed with DeiT that ViT does not need a giant dataset like JFT-300M; it can be trained efficiently on ImageNet alone using:
- Strong data augmentation: RandAugment, Mixup, CutMix
- Regularization: DropPath (stochastic depth), label smoothing
- Knowledge distillation via a dedicated distillation token
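Of these, Mixup is easy to sketch in a few lines: blend random pairs of images and their one-hot labels with a Beta-sampled weight. The `mixup` helper below and its `alpha=0.8` default are illustrative, not DeiT's exact implementation:

```python
import torch
import torch.nn.functional as F

def mixup(images, labels, num_classes, alpha=0.8):
    """Blend random pairs of examples and their one-hot labels."""
    lam = torch.distributions.Beta(alpha, alpha).sample().item()
    perm = torch.randperm(images.size(0))
    mixed = lam * images + (1 - lam) * images[perm]
    y = F.one_hot(labels, num_classes).float()
    targets = lam * y + (1 - lam) * y[perm]  # soft labels; train with soft cross-entropy
    return mixed, targets

imgs = torch.randn(8, 3, 32, 32)
labels = torch.randint(0, 10, (8,))
mixed, targets = mixup(imgs, labels, num_classes=10)
print(mixed.shape, targets.sum(1))  # each soft-label row still sums to 1
```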
3.2 DropPath (Stochastic Depth)¶
class DropPath(nn.Module):
    """DropPath (stochastic depth): randomly skip the entire residual branch per sample."""
    def __init__(self, drop_prob=0.0):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        if not self.training or self.drop_prob == 0.0:
            return x
        keep_prob = 1 - self.drop_prob
        # One Bernoulli draw per sample, broadcast over all remaining dims
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        random_tensor = (torch.rand(shape, device=x.device) < keep_prob).to(x.dtype)
        return x * random_tensor / keep_prob  # rescale so the expected value is unchanged
class DeiTBlock(nn.Module):
    """ViT block with DropPath on both residual branches."""
    def __init__(self, embed_dim, num_heads, mlp_ratio=4.0, dropout=0.1, drop_path=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.drop_path1 = DropPath(drop_path)
        self.norm2 = nn.LayerNorm(embed_dim)
        mlp_hidden = int(embed_dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, mlp_hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_hidden, embed_dim),
            nn.Dropout(dropout)
        )
        self.drop_path2 = DropPath(drop_path)

    def forward(self, x):
        normed = self.norm1(x)
        attn_out, _ = self.attn(normed, normed, normed)
        x = x + self.drop_path1(attn_out)
        x = x + self.drop_path2(self.mlp(self.norm2(x)))
        return x
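The mechanics of DropPath can be seen without the module: a per-sample binary mask of shape (batch, 1, 1) either keeps or zeroes a sample's entire residual branch, and survivors are rescaled by 1/keep_prob so the expected value is unchanged. A standalone sketch:

```python
import torch

torch.manual_seed(0)
keep_prob = 0.5
x = torch.ones(8, 4, 16)                          # (batch, tokens, dim), all ones
mask = (torch.rand(8, 1, 1) < keep_prob).float()  # one Bernoulli draw per sample
out = x * mask / keep_prob                        # dropped samples -> 0.0, kept -> 2.0
print(out.unique())                               # values are only 0.0 and 2.0
```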
4. Swin Transformer¶
4.1 Key innovations¶
Liu et al. (2021) designed the Swin Transformer to address two shortcomings of ViT:
- Multi-scale features: a CNN-like hierarchical structure that progressively reduces resolution via Patch Merging
- Computational efficiency: window attention reduces complexity from \(O(n^2)\) to \(O(n)\) in the number of tokens (for a fixed window size)
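To see why window attention is effectively linear, count query-key pairs on a Swin stage-1 grid (56x56 tokens, 7x7 windows). The numbers below are simple arithmetic, not a benchmark:

```python
n = 56 * 56                                # tokens in the feature map
global_pairs = n * n                       # full self-attention: O(n^2)
w = 7                                      # window side length
num_windows = n // (w * w)                 # 64 windows
window_pairs = num_windows * (w * w) ** 2  # O(n * w^2): linear in n for fixed w
print(global_pairs // window_pairs)        # 64x fewer query-key pairs
```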
4.2 Window Attention¶
class WindowAttention(nn.Module):
    """Self-attention computed within each local window."""
    def __init__(self, dim, window_size, num_heads):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # window side length (int); each window has window_size**2 tokens
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)
        # Relative position bias: one learnable bias per (row offset, col offset) pair, per head
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size - 1) * (2 * window_size - 1), num_heads)
        )
        nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)
        coords = torch.arange(window_size)
        coords = torch.stack(torch.meshgrid([coords, coords], indexing='ij')).flatten(1)
        relative_coords = coords[:, :, None] - coords[:, None, :]  # pairwise (row, col) offsets
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        relative_coords[:, :, 0] += window_size - 1      # shift offsets to start at 0
        relative_coords[:, :, 1] += window_size - 1
        relative_coords[:, :, 0] *= 2 * window_size - 1  # row-major flattening
        relative_position_index = relative_coords.sum(-1)
        self.register_buffer('relative_position_index', relative_position_index)

    def forward(self, x, mask=None):
        B_, N, C = x.shape  # B_ = num_windows * batch, N = tokens per window
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)
        ].view(N, N, -1).permute(2, 0, 1)
        attn = attn + relative_position_bias.unsqueeze(0)
        if mask is not None:  # attention mask used by shifted windows
            attn = attn + mask.unsqueeze(1)
        attn = F.softmax(attn, dim=-1)
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        return x
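The relative-position-index construction in `__init__` is the trickiest part; the same steps can be run standalone for a 2x2 window (4 tokens) to see what they produce:

```python
import torch

ws = 2  # 2x2 window -> 4 tokens, offsets in [-1, 1] per axis
coords = torch.stack(torch.meshgrid([torch.arange(ws), torch.arange(ws)], indexing='ij')).flatten(1)
rel = coords[:, :, None] - coords[:, None, :]  # (2, N, N): pairwise (row, col) offsets
rel = rel.permute(1, 2, 0).contiguous()
rel[:, :, 0] += ws - 1                         # shift row offsets to [0, 2*ws - 2]
rel[:, :, 1] += ws - 1
rel[:, :, 0] *= 2 * ws - 1                     # row-major flatten into a single index
index = rel.sum(-1)
print(index)  # (4, 4) table of indices in [0, 8]; the diagonal is the zero-offset entry 4
```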
4.3 Window partition and reverse¶
def window_partition(x, window_size):
    """Split a feature map into non-overlapping windows."""
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    windows = windows.view(-1, window_size, window_size, C)
    return windows  # (num_windows*B, window_size, window_size, C)

def window_reverse(windows, window_size, H, W):
    """Merge windows back into a feature map."""
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    x = x.view(B, H, W, -1)
    return x
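Partition followed by reverse should be a lossless round trip. A compact standalone check, inlining the same reshape/permute steps:

```python
import torch

B, H, W, C, ws = 2, 8, 8, 16, 4
x = torch.randn(B, H, W, C)
# partition: (B, H, W, C) -> (num_windows * B, ws, ws, C)
win = (x.view(B, H // ws, ws, W // ws, ws, C)
        .permute(0, 1, 3, 2, 4, 5).contiguous()
        .view(-1, ws, ws, C))
# reverse: back to (B, H, W, C)
back = (win.view(B, H // ws, W // ws, ws, ws, C)
           .permute(0, 1, 3, 2, 4, 5).contiguous()
           .view(B, H, W, C))
print(win.shape, torch.equal(back, x))  # torch.Size([8, 4, 4, 16]) True
```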
4.4 Patch Merging¶
class PatchMerging(nn.Module):
    """Patch merging: halve the resolution (2x downsampling) while doubling channels."""
    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(4 * dim)
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)

    def forward(self, x, H, W):
        B, L, C = x.shape
        x = x.view(B, H, W, C)
        x0 = x[:, 0::2, 0::2, :]  # top-left of each 2x2 block
        x1 = x[:, 1::2, 0::2, :]  # bottom-left
        x2 = x[:, 0::2, 1::2, :]  # top-right
        x3 = x[:, 1::2, 1::2, :]  # bottom-right
        x = torch.cat([x0, x1, x2, x3], dim=-1)  # (B, H/2, W/2, 4C)
        x = x.view(B, -1, 4 * C)
        x = self.norm(x)
        x = self.reduction(x)  # (B, H/2 * W/2, 2C)
        return x, H // 2, W // 2
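The strided-slicing step is the heart of PatchMerging: it folds each 2x2 spatial neighborhood into the channel axis. A standalone shape check (before the LayerNorm + Linear that follow in the module):

```python
import torch

B, H, W, C = 1, 4, 4, 8
x = torch.randn(B, H, W, C)
x0 = x[:, 0::2, 0::2, :]  # top-left of each 2x2 block
x1 = x[:, 1::2, 0::2, :]  # bottom-left
x2 = x[:, 0::2, 1::2, :]  # top-right
x3 = x[:, 1::2, 1::2, :]  # bottom-right
merged = torch.cat([x0, x1, x2, x3], dim=-1)
print(merged.shape)  # torch.Size([1, 2, 2, 32]): 4x fewer tokens, 4x more channels
```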
5. MAE: Masked Autoencoders¶
5.1 Idea¶
He et al. (2022) proposed MAE (Masked Autoencoders): randomly mask a large fraction of the image patches (e.g. 75%) and train the model to reconstruct them. Only the visible patches pass through the encoder, and a lightweight decoder predicts the pixels of the masked ones.
class MAE(nn.Module):
    """Simplified Masked Autoencoder."""
    def __init__(self, img_size=224, patch_size=16, in_channels=3,
                 embed_dim=768, depth=12, num_heads=12,
                 decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
                 mask_ratio=0.75):
        super().__init__()
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        self.num_patches = self.patch_embed.num_patches
        self.mask_ratio = mask_ratio
        self.patch_size = patch_size
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, embed_dim))
        # Encoder (sees only the visible patches)
        self.encoder_blocks = nn.Sequential(*[
            ViTBlock(embed_dim, num_heads) for _ in range(depth)
        ])
        self.encoder_norm = nn.LayerNorm(embed_dim)
        # Decoder (lightweight; reconstructs all patches)
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, decoder_embed_dim))
        self.decoder_blocks = nn.Sequential(*[
            ViTBlock(decoder_embed_dim, decoder_num_heads) for _ in range(decoder_depth)
        ])
        self.decoder_norm = nn.LayerNorm(decoder_embed_dim)
        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_channels)
        self._init_weights()

    def _init_weights(self):
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.mask_token, std=0.02)
        nn.init.trunc_normal_(self.decoder_pos_embed, std=0.02)

    def random_masking(self, x, mask_ratio):
        """Randomly mask patches; return visible tokens, the binary mask, and restore indices."""
        B, N, D = x.shape
        len_keep = int(N * (1 - mask_ratio))
        noise = torch.rand(B, N, device=x.device)
        ids_shuffle = torch.argsort(noise, dim=1)   # random permutation per sample
        ids_restore = torch.argsort(ids_shuffle, dim=1)
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, 1, ids_keep.unsqueeze(-1).expand(-1, -1, D))
        mask = torch.ones(B, N, device=x.device)    # 1 = masked, 0 = visible
        mask[:, :len_keep] = 0
        mask = torch.gather(mask, 1, ids_restore)
        return x_masked, mask, ids_restore

    def forward(self, imgs):
        # 1. Patch embedding
        x = self.patch_embed(imgs)
        x = x + self.pos_embed[:, 1:]
        # 2. Random masking (the encoder only sees visible tokens)
        x, mask, ids_restore = self.random_masking(x, self.mask_ratio)
        # 3. Prepend the CLS token
        cls_token = self.cls_token + self.pos_embed[:, :1]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)
        # 4. Encoder
        x = self.encoder_blocks(x)
        x = self.encoder_norm(x)
        # 5. Decoder: fill masked positions with the mask token, restore original order
        x = self.decoder_embed(x)
        mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
        x_ = torch.cat([x[:, 1:], mask_tokens], dim=1)
        x_ = torch.gather(x_, 1, ids_restore.unsqueeze(-1).expand(-1, -1, x_.shape[2]))
        x = torch.cat([x[:, :1], x_], dim=1)
        x = x + self.decoder_pos_embed
        x = self.decoder_blocks(x)
        x = self.decoder_norm(x)
        pred = self.decoder_pred(x[:, 1:])  # per-patch pixel predictions
        return pred, mask

# Quick test
mae = MAE(img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12,
          decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, mask_ratio=0.75)
imgs = torch.randn(2, 3, 224, 224)
pred, mask = mae(imgs)
print(f"MAE pred: {pred.shape}, mask: {mask.shape}")
# pred: (2, 196, 768), mask: (2, 196)
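The class above returns predictions and the mask but no loss. Following the paper, training minimizes mean squared error computed only on the masked patches. The sketch below is a minimal standalone version; `patchify` and `mae_loss` are helpers defined here, not part of the class above:

```python
import torch

torch.manual_seed(0)

def patchify(imgs, p=16):
    """(B, C, H, W) -> (B, N, p*p*C): flatten each p x p patch."""
    B, C, H, W = imgs.shape
    x = imgs.reshape(B, C, H // p, p, W // p, p)
    return x.permute(0, 2, 4, 3, 5, 1).reshape(B, (H // p) * (W // p), p * p * C)

def mae_loss(imgs, pred, mask, p=16):
    """MSE averaged over masked patches only (mask == 1 where the patch was removed)."""
    target = patchify(imgs, p)
    per_patch = ((pred - target) ** 2).mean(dim=-1)  # (B, N)
    return (per_patch * mask).sum() / mask.sum()

imgs = torch.randn(2, 3, 224, 224)
pred = torch.randn(2, 196, 768)                # matches decoder_pred's output shape
mask = (torch.rand(2, 196) < 0.75).float()
loss = mae_loss(imgs, pred, mask)
print(loss.item())
```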
6. Hands-On: ViT Image Classification¶
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

def train_vit_cifar10():
    """Train ViT on CIFAR-10."""
    # ===== Data =====
    transform_train = transforms.Compose([
        transforms.Resize(224),
        transforms.RandomHorizontalFlip(),
        transforms.RandAugment(num_ops=2, magnitude=9),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    transform_test = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
    train_loader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=4)
    test_loader = DataLoader(testset, batch_size=256, shuffle=False, num_workers=4)
    # ===== Model =====
    model = vit_tiny(num_classes=10)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    # ===== Training setup =====
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.05)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    # ===== Training loop =====
    for epoch in range(100):
        model.train()
        total_loss, correct, total = 0, 0, 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()  # clear old gradients
            loss.backward()        # backprop
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # gradient clipping
            optimizer.step()
            total_loss += loss.item()
            _, predicted = outputs.max(1)
            correct += predicted.eq(labels).sum().item()
            total += labels.size(0)
        scheduler.step()
        train_acc = 100. * correct / total
        # Evaluation
        model.eval()
        test_correct, test_total = 0, 0
        with torch.no_grad():  # no gradients needed during evaluation
            for images, labels in test_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                _, predicted = outputs.max(1)
                test_correct += predicted.eq(labels).sum().item()
                test_total += labels.size(0)
        test_acc = 100. * test_correct / test_total
        print(f"Epoch {epoch+1}: Loss={total_loss/len(train_loader):.4f}, "
              f"Train Acc={train_acc:.2f}%, Test Acc={test_acc:.2f}%")

# train_vit_cifar10()  # uncomment to run
7. Exercises and Self-Check¶
Exercises¶
- Patch Embedding: implement patch splitting by hand (without Conv2d) and verify it matches the convolutional implementation.
- ViT image classification: train ViT-Tiny on CIFAR-10/100 and compare against ResNet-18.
- Window attention: implement one stage of a simplified Swin Transformer.
- MAE pretraining: pretrain MAE on a small dataset, then fine-tune for classification and observe the benefit of pretraining.
- Attention visualization: visualize attention maps across different ViT layers and heads; analyze what each head attends to.
Self-check checklist¶
- Understand how ViT turns an image into a token sequence
- Can implement Patch Embedding + ViT from scratch
- Know DeiT's training strategy and DropPath
- Understand Swin Transformer's window attention and hierarchical design
- Know the masked-pretraining idea behind MAE
- Can train a ViT model on a real dataset
Next: 05-生成模型/01-GAN基础

