03 - Vision Transformers

⚠️ Timeliness note: this chapter touches on cutting-edge models, pricing, and leaderboards that can change quickly between versions; defer to the original papers, official release pages, and API documentation.

Study time: ~6-8 hours | Difficulty: ⭐⭐⭐⭐ upper-intermediate | Prerequisites: Transformer architecture, CNN basics, PyTorch | Learning goals: understand the Vision Transformer (ViT) family of models and master Transformer-based image classification


1. From CNNs to ViT

1.1 Limitations of CNNs

  • Local receptive fields: a convolution kernel sees only a local region, and the receptive field grows only by stacking layers
  • Translation-equivariant, but with no built-in ability to model global context
  • Strong inductive biases: well suited to images, but limiting in flexibility

1.2 Why Transformers Work for Vision

Self-attention in a Transformer is global by construction: every position can attend to every other position. The key question is how to turn an image into a sequence.
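
To make this concrete, a quick back-of-the-envelope sketch (assuming ViT's default 224×224 input and 16×16 patches, formalized in Section 2):

Python
img_size = 224                    # input resolution
patch_size = 16                   # ViT's default patch size
num_tokens = (img_size // patch_size) ** 2
print(num_tokens)                 # 196 tokens per image
print(num_tokens ** 2)            # 38,416 attention pairs per layer: quadratic in tokens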


2. Vision Transformer (ViT)

2.1 Core Idea

[Figure: Vision Transformer architecture]

Dosovitskiy et al. (2020): "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale"

The image is split into fixed-size patches, each of which is flattened and linearly projected to produce a token sequence; position embeddings and a [CLS] token are added, and the sequence is fed into a standard Transformer encoder.

2.2 Patch Embedding

An \(H \times W\) image is split into \(N = \frac{H \times W}{P^2}\) patches of size \(P \times P\).

\[z_0 = [x_{class}; \; x_p^1 E; \; x_p^2 E; \; \cdots; \; x_p^N E] + E_{pos}\]

where \(E \in \mathbb{R}^{(P^2 \cdot C) \times D}\) is the linear projection and \(E_{pos} \in \mathbb{R}^{(N+1) \times D}\) is the position embedding.

Python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class PatchEmbedding(nn.Module):  # subclass nn.Module to define a network layer
    """Split the image into patches and linearly project them"""
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):  # __init__ runs automatically on construction
        super().__init__()  # super() invokes the parent-class initializer
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2

        # a strided conv implements patch embedding (equivalent to
        # split + linear projection, but more efficient)
        self.proj = nn.Conv2d(
            in_channels, embed_dim,
            kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        """
        x: (batch, C, H, W)
        return: (batch, num_patches, embed_dim)
        """
        x = self.proj(x)              # (batch, embed_dim, H/P, W/P)
        x = x.flatten(2)              # (batch, embed_dim, num_patches)
        x = x.transpose(1, 2)         # (batch, num_patches, embed_dim)
        return x
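
A quick shape check of the module above (a usage sketch with the ViT-Base defaults):

Python
patch_embed = PatchEmbedding(img_size=224, patch_size=16, embed_dim=768)
dummy = torch.randn(2, 3, 224, 224)
tokens = patch_embed(dummy)
print(tokens.shape)  # torch.Size([2, 196, 768]): 196 patches, each a 768-dim token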

2.3 Full ViT Implementation

Python
class ViTBlock(nn.Module):
    """ViT Transformer Block(Pre-Norm)"""
    def __init__(self, embed_dim, num_heads, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(embed_dim)

        mlp_hidden = int(embed_dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, mlp_hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_hidden, embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        # Pre-Norm
        normed = self.norm1(x)
        attn_out, _ = self.attn(normed, normed, normed)
        x = x + attn_out

        x = x + self.mlp(self.norm2(x))
        return x

class VisionTransformer(nn.Module):
    """Vision Transformer (ViT)"""
    def __init__(self, img_size=224, patch_size=16, in_channels=3, num_classes=1000,
                 embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        num_patches = self.patch_embed.num_patches

        # [CLS] token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        # learnable position embedding
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(dropout)

        # Transformer Encoder
        self.blocks = nn.Sequential(*[
            ViTBlock(embed_dim, num_heads, mlp_ratio, dropout)
            for _ in range(depth)
        ])

        self.norm = nn.LayerNorm(embed_dim)

        # classification head
        self.head = nn.Linear(embed_dim, num_classes)

        # weight initialization
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):  # isinstance checks the layer type
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)

    def forward(self, x):
        """x: (batch, C, H, W) → (batch, num_classes)"""
        batch_size = x.size(0)

        # 1. Patch Embedding
        x = self.patch_embed(x)  # (batch, num_patches, embed_dim)

        # 2. prepend the [CLS] token
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)  # (batch, num_patches+1, embed_dim)

        # 3. add position embeddings
        x = self.pos_drop(x + self.pos_embed)

        # 4. Transformer Encoder
        x = self.blocks(x)
        x = self.norm(x)

        # 5. classification: take the [CLS] token's output
        cls_output = x[:, 0]
        logits = self.head(cls_output)

        return logits

# ===== ViT configurations and a quick test =====
def vit_tiny(num_classes=100):
    return VisionTransformer(embed_dim=192, depth=12, num_heads=3, num_classes=num_classes)

def vit_small(num_classes=100):
    return VisionTransformer(embed_dim=384, depth=12, num_heads=6, num_classes=num_classes)

def vit_base(num_classes=100):
    return VisionTransformer(embed_dim=768, depth=12, num_heads=12, num_classes=num_classes)

model = vit_tiny(num_classes=10)
x = torch.randn(4, 3, 224, 224)
out = model(x)
print(f"ViT-Tiny 输出: {out.shape}")
print(f"参数量: {sum(p.numel() for p in model.parameters()):,}")

3. DeiT: Data-Efficient Training

3.1 Training Strategy

With DeiT, Touvron et al. (2021) showed that ViT does not need a giant dataset like JFT-300M: it can be trained efficiently on ImageNet alone using the following techniques:

  • Strong data augmentation: RandAugment, Mixup, CutMix (see the Mixup sketch after this list)
  • Regularization: DropPath (stochastic depth), label smoothing
  • Knowledge distillation: a dedicated distillation token trained against a teacher's predictions
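
Mixup is easy to bolt onto a standard training loop. A minimal sketch (alpha=0.8 follows the DeiT recipe; the soft-label cross-entropy in the comment is one common way to consume the mixed targets):

Python
def mixup_batch(images, labels, num_classes, alpha=0.8):
    """Blend random pairs of images and their one-hot labels."""
    lam = torch.distributions.Beta(alpha, alpha).sample().item()
    perm = torch.randperm(images.size(0), device=images.device)
    mixed = lam * images + (1 - lam) * images[perm]
    targets = (lam * F.one_hot(labels, num_classes).float()
               + (1 - lam) * F.one_hot(labels[perm], num_classes).float())
    return mixed, targets

# usage: soft targets need a soft-label loss, e.g.
# mixed, targets = mixup_batch(images, labels, num_classes=10)
# loss = (-targets * F.log_softmax(model(mixed), dim=-1)).sum(dim=-1).mean()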

3.2 DropPath (Stochastic Depth)

Python
class DropPath(nn.Module):
    """DropPath(随机深度):随机跳过整个残差分支"""
    def __init__(self, drop_prob=0.0):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        if not self.training or self.drop_prob == 0.0:
            return x
        keep_prob = 1 - self.drop_prob
        # one Bernoulli draw per sample, broadcast over all remaining dims
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        random_tensor = (torch.rand(shape, device=x.device) < keep_prob).to(x.dtype)
        return x * random_tensor / keep_prob  # rescale so the expected value is unchanged

class DeiTBlock(nn.Module):
    """带 DropPath 的 ViT Block"""
    def __init__(self, embed_dim, num_heads, mlp_ratio=4.0, dropout=0.1, drop_path=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.drop_path1 = DropPath(drop_path)

        self.norm2 = nn.LayerNorm(embed_dim)
        mlp_hidden = int(embed_dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, mlp_hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_hidden, embed_dim),
            nn.Dropout(dropout)
        )
        self.drop_path2 = DropPath(drop_path)

    def forward(self, x):
        normed = self.norm1(x)
        attn_out, _ = self.attn(normed, normed, normed)
        x = x + self.drop_path1(attn_out)
        x = x + self.drop_path2(self.mlp(self.norm2(x)))
        return x
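
In practice drop_path is rarely constant: DeiT-style implementations increase it linearly with depth so that deeper blocks are skipped more often. A sketch of building a stack this way (the 0.1 maximum is an illustrative value):

Python
depth, embed_dim, num_heads = 12, 192, 3
dpr = torch.linspace(0, 0.1, depth).tolist()  # 0.0 for the first block, 0.1 for the last
blocks = nn.Sequential(*[
    DeiTBlock(embed_dim, num_heads, drop_path=dpr[i])
    for i in range(depth)
])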

4. Swin Transformer

4.1 Core Innovations

Liu et al. (2021) designed the Swin Transformer to fix two shortcomings of ViT:

  1. Multi-scale features: a CNN-like hierarchical structure that progressively lowers resolution via Patch Merging
  2. Computational efficiency: window attention cuts the cost from \(O(n^2)\) to \(O(n)\) in the number of tokens (see the arithmetic sketch below)
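
The arithmetic behind the complexity claim, as a sketch (stage-1 shapes of Swin-T):

Python
H = W = 56                                # stage-1 feature map: 3136 tokens
M = 7                                     # window size
n = H * W
global_pairs = n ** 2                     # 9,834,496 attention pairs
window_pairs = (n // M**2) * (M**2) ** 2  # 64 windows x 49^2 = 153,664 pairs
print(global_pairs // window_pairs)       # 64x fewer; grows linearly with n for fixed M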

4.2 Window Attention

Python
class WindowAttention(nn.Module):
    """窗口注意力"""
    def __init__(self, dim, window_size, num_heads):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # window side length (an int here)
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

        # relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size - 1) * (2 * window_size - 1), num_heads)
        )
        nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)

        coords = torch.arange(window_size)
        coords = torch.stack(torch.meshgrid([coords, coords], indexing='ij')).flatten(1)
        relative_coords = coords[:, :, None] - coords[:, None, :]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # (N, N, 2) pairwise coordinate offsets
        relative_coords[:, :, 0] += window_size - 1
        relative_coords[:, :, 1] += window_size - 1
        relative_coords[:, :, 0] *= 2 * window_size - 1
        relative_position_index = relative_coords.sum(-1)
        self.register_buffer('relative_position_index', relative_position_index)

    def forward(self, x, mask=None):
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads)  # reshape to (B_, N, 3, heads, head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)

        attn = (q @ k.transpose(-2, -1)) * self.scale

        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)  # view requires contiguous memory
        ].view(N, N, -1).permute(2, 0, 1)
        attn = attn + relative_position_bias.unsqueeze(0)  # unsqueeze adds a batch dimension

        if mask is not None:
            attn = attn + mask.unsqueeze(1)

        attn = F.softmax(attn, dim=-1)
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        return x

4.3 Window Partition and Reverse

Python
def window_partition(x, window_size):
    """将特征图分割为不重叠的窗口"""
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    windows = windows.view(-1, window_size, window_size, C)
    return windows  # (num_windows*B, window_size, window_size, C)

def window_reverse(windows, window_size, H, W):
    """将窗口合并回特征图"""
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    x = x.view(B, H, W, -1)
    return x
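
A quick sanity check that window_reverse exactly undoes window_partition (a usage sketch with stage-1-like shapes):

Python
x = torch.randn(2, 56, 56, 96)             # (B, H, W, C)
windows = window_partition(x, 7)           # (128, 7, 7, 96): 64 windows per image
x_back = window_reverse(windows, 7, 56, 56)
print(torch.equal(x, x_back))              # True: the round trip is lossless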

4.4 Patch Merging

Python
class PatchMerging(nn.Module):
    """Patch合并 — 降低分辨率(2倍下采样)"""
    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(4 * dim)
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)

    def forward(self, x, H, W):
        B, L, C = x.shape
        x = x.view(B, H, W, C)

        x0 = x[:, 0::2, 0::2, :]
        x1 = x[:, 1::2, 0::2, :]
        x2 = x[:, 0::2, 1::2, :]
        x3 = x[:, 1::2, 1::2, :]
        x = torch.cat([x0, x1, x2, x3], dim=-1)  # (B, H/2, W/2, 4C)

        x = x.view(B, -1, 4 * C)
        x = self.norm(x)
        x = self.reduction(x)  # (B, H/2 * W/2, 2C)

        return x, H // 2, W // 2
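
Shape check for PatchMerging (a usage sketch): each call halves the spatial resolution and doubles the channel width, producing the feature pyramid Swin relies on:

Python
merge = PatchMerging(dim=96)
x = torch.randn(2, 56 * 56, 96)  # stage-1 tokens
y, H2, W2 = merge(x, 56, 56)
print(y.shape, H2, W2)           # torch.Size([2, 784, 192]) 28 28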

5. MAE: Masked Autoencoders

5.1 Idea

He et al. (2022) proposed MAE (Masked Autoencoders): randomly mask a large fraction of image patches (e.g. 75%) and train the model to reconstruct the masked ones.

Python
class MAE(nn.Module):
    """简化版 Masked Autoencoder"""
    def __init__(self, img_size=224, patch_size=16, in_channels=3,
                 embed_dim=768, depth=12, num_heads=12,
                 decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
                 mask_ratio=0.75):
        super().__init__()
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        self.num_patches = self.patch_embed.num_patches
        self.mask_ratio = mask_ratio
        self.patch_size = patch_size

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, embed_dim))

        # encoder
        self.encoder_blocks = nn.Sequential(*[
            ViTBlock(embed_dim, num_heads) for _ in range(depth)
        ])
        self.encoder_norm = nn.LayerNorm(embed_dim)

        # decoder
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, decoder_embed_dim))

        self.decoder_blocks = nn.Sequential(*[
            ViTBlock(decoder_embed_dim, decoder_num_heads) for _ in range(decoder_depth)
        ])
        self.decoder_norm = nn.LayerNorm(decoder_embed_dim)
        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_channels)

        self._init_weights()

    def _init_weights(self):
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.mask_token, std=0.02)
        nn.init.trunc_normal_(self.decoder_pos_embed, std=0.02)

    def random_masking(self, x, mask_ratio):
        """随机掩码,返回可见 token 及索引"""
        B, N, D = x.shape
        len_keep = int(N * (1 - mask_ratio))

        noise = torch.rand(B, N, device=x.device)
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, 1, ids_keep.unsqueeze(-1).expand(-1, -1, D))

        mask = torch.ones(B, N, device=x.device)
        mask[:, :len_keep] = 0
        mask = torch.gather(mask, 1, ids_restore)

        return x_masked, mask, ids_restore

    def forward(self, imgs):
        # 1. Patch embedding
        x = self.patch_embed(imgs)
        x = x + self.pos_embed[:, 1:]

        # 2. random masking
        x, mask, ids_restore = self.random_masking(x, self.mask_ratio)

        # 3. prepend the CLS token (with its position embedding)
        cls_token = self.cls_token + self.pos_embed[:, :1]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)

        # 4. encoder
        x = self.encoder_blocks(x)
        x = self.encoder_norm(x)

        # 5. decoder: re-insert mask tokens and restore the original patch order
        x = self.decoder_embed(x)

        mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
        x_ = torch.cat([x[:, 1:], mask_tokens], dim=1)
        x_ = torch.gather(x_, 1, ids_restore.unsqueeze(-1).expand(-1, -1, x_.shape[2]))
        x = torch.cat([x[:, :1], x_], dim=1)

        x = x + self.decoder_pos_embed
        x = self.decoder_blocks(x)
        x = self.decoder_norm(x)
        pred = self.decoder_pred(x[:, 1:])

        return pred, mask

# quick smoke test
mae = MAE(img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12,
          decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, mask_ratio=0.75)
imgs = torch.randn(2, 3, 224, 224)
pred, mask = mae(imgs)
print(f"MAE 预测: {pred.shape}, 掩码: {mask.shape}")
# pred: (2, 196, 768), mask: (2, 196)
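
The forward pass returns per-patch predictions and the binary mask, but not the training loss. MAE computes MSE only over masked patches; a minimal sketch (the patchify layout here is an assumption chosen to match decoder_pred's p^2*C output per patch):

Python
def mae_loss(imgs, pred, mask, patch_size=16):
    """MSE reconstruction loss, averaged over masked patches only."""
    B, C, H, W = imgs.shape
    p = patch_size
    # patchify: (B, C, H, W) -> (B, N, p*p*C)
    target = imgs.reshape(B, C, H // p, p, W // p, p)
    target = target.permute(0, 2, 4, 3, 5, 1).reshape(B, -1, p * p * C)
    per_patch = ((pred - target) ** 2).mean(dim=-1)  # (B, N)
    return (per_patch * mask).sum() / mask.sum()     # mask == 1 marks hidden patches

loss = mae_loss(imgs, pred, mask)
print(f"Reconstruction loss: {loss.item():.4f}")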

6. Hands-On: ViT Image Classification

Python
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

def train_vit_cifar10():
    """在 CIFAR-10 上训练 ViT"""

    # ===== Data preparation =====
    transform_train = transforms.Compose([
        transforms.Resize(224),
        transforms.RandomHorizontalFlip(),
        transforms.RandAugment(num_ops=2, magnitude=9),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    transform_test = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

    train_loader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=4)  # DataLoader batches the data with shuffling and worker processes
    test_loader = DataLoader(testset, batch_size=256, shuffle=False, num_workers=4)

    # ===== Model =====
    model = vit_tiny(num_classes=10)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)  # move the model to GPU/CPU

    # ===== Training setup =====
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.05)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

    # ===== Training loop =====
    for epoch in range(100):
        model.train()  # switch to training mode
        total_loss, correct, total = 0, 0, 0

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()  # reset gradients so they do not accumulate
            loss.backward()  # backpropagate to compute gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()  # update parameters using the gradients

            total_loss += loss.item()  # .item() converts a 0-dim tensor to a Python number
            _, predicted = outputs.max(1)
            correct += predicted.eq(labels).sum().item()
            total += labels.size(0)

        scheduler.step()
        train_acc = 100. * correct / total

        # evaluation
        model.eval()
        test_correct, test_total = 0, 0
        with torch.no_grad():  # disable gradient tracking to save memory
            for images, labels in test_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                _, predicted = outputs.max(1)
                test_correct += predicted.eq(labels).sum().item()
                test_total += labels.size(0)

        test_acc = 100. * test_correct / test_total
        print(f"Epoch {epoch+1}: Loss={total_loss/len(train_loader):.4f}, "
              f"Train Acc={train_acc:.2f}%, Test Acc={test_acc:.2f}%")

# train_vit_cifar10()

7. Exercises and Self-Check

Exercises

  1. Patch Embedding: implement patch splitting by hand (without Conv2d) and verify it matches the convolutional implementation.
  2. ViT image classification: train ViT-Tiny on CIFAR-10/100 and compare it against ResNet-18.
  3. Window attention: implement one stage of a simplified Swin Transformer.
  4. MAE pretraining: pretrain MAE on a small dataset, then fine-tune for classification, and observe the benefit of pretraining.
  5. Attention visualization: visualize attention maps across ViT layers and heads, and analyze what each head attends to.

Self-Check Checklist

  • Understand how ViT turns an image into a token sequence
  • Can implement Patch Embedding + ViT from scratch
  • Know DeiT's training strategy and DropPath
  • Understand Swin Transformer's window attention and hierarchical design
  • Know the masked-pretraining idea behind MAE
  • Can train a ViT model on a real dataset

Next: 05-生成模型/01-GAN基础