跳转至

02 - DQN详解:深度Q网络

学习时间: 4-5小时 重要性: ⭐⭐⭐⭐⭐ 深度强化学习的里程碑 前置知识: 值函数近似、神经网络基础


🎯 学习目标

完成本章后,你将能够:

  - 理解DQN的核心创新(经验回放、目标网络)
  - 掌握DQN的网络架构设计
  - 实现完整的DQN算法
  - 理解DQN的训练技巧
  - 能够调试DQN训练过程


1. DQN简介

1.1 历史背景

DeepMind (2015):"Human-level control through deep reinforcement learning"

里程碑意义: - 首次成功将深度学习与强化学习结合 - 在Atari游戏上达到人类水平 - 证明了端到端学习的可行性

1.2 核心挑战

神经网络+RL的问题

  1. 数据相关性:连续样本高度相关
  2. 非平稳分布:策略变化导致数据分布变化
  3. 发散风险:Q值可能无界增长

DQN的解决方案: 1. 经验回放(Experience Replay) 2. 目标网络(Target Network)


2. DQN算法

2.1 网络架构

输入:原始像素帧(84×84×4) 输出:每个动作的Q值

Text Only
Input (84×84×4)
Conv1: 32 filters, 8×8, stride 4 + ReLU
Conv2: 64 filters, 4×4, stride 2 + ReLU
Conv3: 64 filters, 3×3, stride 1 + ReLU
Flatten
FC1: 512 units + ReLU
Output: |A| units (每个动作的Q值)

2.2 代码实现

Python
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from collections import deque
import random

class DQNNetwork(nn.Module):
    """Nature-DQN convolutional Q-network: stacked frames in, one Q-value per action.

    Architecture follows Mnih et al. (2015): three conv layers, then a
    512-unit fully connected layer, then a linear output of size n_actions.
    """

    def __init__(self, input_shape, n_actions):
        super().__init__()

        # Three-layer convolutional feature extractor (8x8/4, 4x4/2, 3x3/1).
        self.conv = nn.Sequential(
            nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
        )

        # Fully connected head; its input width is derived from a dummy
        # forward pass so the class works for any spatial input size.
        self.fc = nn.Sequential(
            nn.Linear(self._get_conv_out(input_shape), 512),
            nn.ReLU(),
            nn.Linear(512, n_actions),
        )

    def _get_conv_out(self, shape):
        """Return the flattened feature count the conv stack produces for one input."""
        dummy = self.conv(torch.zeros(1, *shape))
        return int(np.prod(dummy.size()))

    def forward(self, x):
        """Map a batch of frame stacks to a (batch, n_actions) tensor of Q-values."""
        flat = self.conv(x).view(x.size()[0], -1)
        return self.fc(flat)

class ReplayBuffer:
    """Fixed-capacity FIFO store of (state, action, reward, next_state, done) tuples."""

    def __init__(self, capacity=100000):
        # deque(maxlen=...) evicts the oldest transition automatically when full.
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Append one transition (oldest entry is dropped once at capacity)."""
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        """Uniformly sample `batch_size` transitions.

        Returns five NumPy arrays: states, actions, rewards, next_states, dones.
        """
        chosen = random.sample(self.buffer, batch_size)
        # Transpose the list of transition tuples into per-field columns.
        columns = zip(*chosen)
        return tuple(np.array(col) for col in columns)

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.buffer)

class DQNAgent:
    """DQN agent: epsilon-greedy acting, experience replay, target network.

    Implements the two stabilizing ideas of Mnih et al. (2015): a replay
    buffer that decorrelates training samples, and a periodically hard-copied
    target network that keeps the bootstrap target fixed between syncs.
    """

    def __init__(self, input_shape, n_actions, lr=1e-4, gamma=0.99,
                 epsilon_start=1.0, epsilon_end=0.01, epsilon_decay=0.995,
                 buffer_size=100000, batch_size=32, target_update=1000):
        # input_shape: (channels, height, width) expected by DQNNetwork.
        # target_update: hard-sync period for the target net, in gradient steps.

        self.n_actions = n_actions
        self.gamma = gamma
        self.epsilon = epsilon_start
        self.epsilon_end = epsilon_end
        self.epsilon_decay = epsilon_decay
        self.batch_size = batch_size
        self.target_update = target_update

        # Device: use the GPU when one is available
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Networks: the target net starts as an exact copy of the policy net
        self.policy_net = DQNNetwork(input_shape, n_actions).to(self.device)
        self.target_net = DQNNetwork(input_shape, n_actions).to(self.device)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()  # target net is only read from, never trained directly

        # Optimizer (trains only the policy net)
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=lr)

        # Experience replay buffer
        self.replay_buffer = ReplayBuffer(buffer_size)

        # Gradient-step counter; drives the target-network sync schedule
        self.steps = 0

    def select_action(self, state, training=True):
        """Epsilon-greedy action selection.

        Explores with probability epsilon during training; always greedy
        when ``training`` is False.
        """
        if training and random.random() < self.epsilon:
            return random.randrange(self.n_actions)

        with torch.no_grad():  # no gradients needed just for acting
            state = torch.FloatTensor(state).unsqueeze(0).to(self.device)  # add batch dimension
            q_values = self.policy_net(state)
            return q_values.argmax().item()  # index of the best action as a Python int

    def store_transition(self, state, action, reward, next_state, done):
        """Store one (s, a, r, s', done) transition in the replay buffer."""
        self.replay_buffer.push(state, action, reward, next_state, done)

    def update(self):
        """Run one gradient step on a sampled batch.

        Returns the scalar loss, or None while the buffer is still smaller
        than one batch.
        """
        if len(self.replay_buffer) < self.batch_size:
            return None

        # Sample a batch of transitions
        states, actions, rewards, next_states, dones = self.replay_buffer.sample(self.batch_size)

        # Convert to tensors on the training device
        states = torch.FloatTensor(states).to(self.device)
        actions = torch.LongTensor(actions).to(self.device)
        rewards = torch.FloatTensor(rewards).to(self.device)
        next_states = torch.FloatTensor(next_states).to(self.device)
        dones = torch.FloatTensor(dones).to(self.device)

        # Q-values of the actions actually taken
        current_q = self.policy_net(states).gather(1, actions.unsqueeze(1))

        # Bootstrap targets, computed with the frozen target network
        with torch.no_grad():
            next_q = self.target_net(next_states).max(1)[0]
            # (1 - dones) zeroes the bootstrap term at terminal states
            target_q = rewards + (1 - dones) * self.gamma * next_q

        # TD loss
        loss = nn.MSELoss()(current_q.squeeze(), target_q)  # squeeze drops the gather dimension

        # Optimize the policy net
        self.optimizer.zero_grad()  # clear stale gradients
        loss.backward()  # backprop
        # Gradient clipping guards against exploding updates
        torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), max_norm=10)
        self.optimizer.step()  # apply the update

        # Hard-sync the target network every target_update gradient steps
        self.steps += 1
        if self.steps % self.target_update == 0:
            self.target_net.load_state_dict(self.policy_net.state_dict())

        # Decay exploration (multiplicative, floored at epsilon_end)
        self.epsilon = max(self.epsilon_end, self.epsilon * self.epsilon_decay)

        return loss.item()

    def train(self, env, num_episodes=1000, max_steps=1000):
        """Run the standard act/store/update loop; returns per-episode total rewards.

        Expects a Gymnasium-style env: reset() -> (obs, info) and
        step(a) -> (obs, reward, terminated, truncated, info).
        """
        rewards_history = []

        for episode in range(num_episodes):
            state, info = env.reset()
            total_reward = 0

            for step in range(max_steps):
                # Choose an action (epsilon-greedy)
                action = self.select_action(state, training=True)

                # Step the environment
                next_state, reward, terminated, truncated, info = env.step(action)
                done = terminated or truncated
                total_reward += reward

                # Store the transition
                self.store_transition(state, action, reward, next_state, done)

                # One gradient step (no-op while the buffer is filling)
                loss = self.update()

                state = next_state

                if done:
                    break

            rewards_history.append(total_reward)

            if (episode + 1) % 10 == 0:
                # Rolling mean over the last (up to) 100 episodes
                avg_reward = np.mean(rewards_history[-100:])
                print(f"Episode {episode + 1}, Avg Reward: {avg_reward:.2f}, "
                      f"Epsilon: {self.epsilon:.3f}")

        return rewards_history

3. 关键技巧详解

3.1 经验回放

为什么需要?

  1. 打破相关性:随机采样代替连续样本
  2. 提高样本效率:同一经验可被多次使用
  3. 平滑数据分布:避免非平稳性

实现要点: - 足够大的缓冲区(通常100k-1M) - 随机均匀采样 - 优先经验回放(可选改进)

3.2 目标网络

为什么需要?

防止自举导致的不稳定性: - 目标值:\(r + \gamma \max_{a'} Q(s', a'; \theta^-)\) - 当前值:\(Q(s, a; \theta)\) - 使用不同的参数\(\theta^-\)和\(\theta\)

更新频率: - 每隔N步复制一次(通常1000-10000) - 或软更新:\(\theta^- \leftarrow \tau \theta + (1-\tau) \theta^-\)

3.3 奖励裁剪

Atari游戏中的问题: - 不同游戏奖励尺度差异大 - 大奖励可能导致Q值爆炸

解决方案

Python
# Clip every reward into [-1, 1] so all games share one reward scale
reward = np.clip(reward, -1, 1)

3.4 跳帧(Frame Skipping)

目的:加速训练,减少计算

做法:每4帧执行一次动作,重复上次动作


4. 训练技巧

4.1 超参数调优

超参数 典型值 说明
学习率 1e-4 ~ 1e-3 Adam优化器
折扣因子 0.99 长期回报
ε衰减 0.995-0.999 探索率衰减
批次大小 32 梯度更新
目标更新频率 1000-10000 目标网络更新
缓冲区大小 100k-1M 经验存储

4.2 调试技巧

Python
# 1. Monitor Q-values (inside DQNAgent.update)
if self.steps % 100 == 0:
    avg_q = current_q.mean().item()
    print(f"Step {self.steps}, Avg Q: {avg_q:.2f}")

# 2. Monitor the loss (update() returns None until the buffer fills)
if loss is not None:
    print(f"Loss: {loss:.4f}")

# 3. Visualize training
import matplotlib.pyplot as plt

def plot_training(rewards, losses):
    """Plot the episode-reward and training-loss curves side by side."""
    fig, (reward_ax, loss_ax) = plt.subplots(1, 2, figsize=(12, 4))

    # Reward curve
    reward_ax.plot(rewards)
    reward_ax.set_title('Episode Rewards')
    reward_ax.set_xlabel('Episode')
    reward_ax.set_ylabel('Total Reward')

    # Loss curve
    loss_ax.plot(losses)
    loss_ax.set_title('Training Loss')
    loss_ax.set_xlabel('Update Step')
    loss_ax.set_ylabel('Loss')

    plt.tight_layout()
    plt.show()

5. 实践练习

练习1:CartPole with DQN

Python
import gymnasium as gym

# Create the environment
env = gym.make('CartPole-v1')

# Read the state dimensionality and number of discrete actions
state_dim = env.observation_space.shape[0]
n_actions = env.action_space.n

# Build a DQN agent (simplified version using an MLP)
class SimpleDQN(nn.Module):
    """Two-hidden-layer MLP Q-network for low-dimensional states (e.g. CartPole)."""

    def __init__(self, state_dim, n_actions):
        super().__init__()
        layers = [
            nn.Linear(state_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, n_actions),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return one Q-value per action for the given state batch."""
        return self.net(x)

# Training. DQNAgent builds a Conv2d network by default, which suits image input;
# CartPole states are 4-dim vectors, so initialize with a valid image shape first,
# then swap in the MLP.
agent = DQNAgent(input_shape=(1, 84, 84), n_actions=n_actions)
# Replace the Conv2d nets with MLPs suited to CartPole's low-dimensional state
agent.policy_net = SimpleDQN(state_dim, n_actions).to(agent.device)
agent.target_net = SimpleDQN(state_dim, n_actions).to(agent.device)
agent.target_net.load_state_dict(agent.policy_net.state_dict())
agent.optimizer = optim.Adam(agent.policy_net.parameters(), lr=1e-3)

rewards = agent.train(env, num_episodes=500)

练习2:可视化Q值

Python
def visualize_q_values(agent, env):
    """Scatter-plot each action's Q-value over 100 freshly reset states."""
    sampled_states = []
    sampled_q = []

    for _ in range(100):
        state, info = env.reset()
        sampled_states.append(state)

        # Query the policy network without tracking gradients.
        with torch.no_grad():
            batch = torch.FloatTensor(state).unsqueeze(0).to(agent.device)
            sampled_q.append(agent.policy_net(batch).cpu().numpy()[0])

    sampled_states = np.array(sampled_states)
    sampled_q = np.array(sampled_q)

    # One subplot per action, colored by that action's Q-value over the
    # first two state dimensions.
    plt.figure(figsize=(12, 4))
    for action in range(agent.n_actions):
        plt.subplot(1, agent.n_actions, action + 1)
        plt.scatter(sampled_states[:, 0], sampled_states[:, 1],
                    c=sampled_q[:, action], cmap='viridis')
        plt.colorbar()
        plt.title(f'Q(s, a={action})')
    plt.show()

6. 本章总结

核心概念

Text Only
DQN:
├── 网络架构: CNN处理图像输入
├── 经验回放: 打破样本相关性
├── 目标网络: 稳定学习目标
└── 训练技巧:
    ├── 奖励裁剪
    ├── 跳帧
    └── 梯度裁剪

关键创新:
├── 端到端学习: 原始像素到动作
├── 稳定性: 经验回放 + 目标网络
└── 泛化性: 相似状态相似Q值

✅ 自测问题

  1. 经验回放的作用是什么?为什么要随机采样而不是按顺序使用?

  2. 目标网络为什么能提高稳定性?软更新和硬更新有什么区别?

  3. DQN中Q值为什么会发散?有哪些防止发散的技巧?

  4. 设计一个实验验证经验回放的效果。


📚 延伸阅读

  1. Mnih et al. (2015), "Human-level control through deep reinforcement learning", Nature, 518(7540), 529-533

  2. Mnih et al. (2013), "Playing Atari with Deep Reinforcement Learning" —— 最初的DQN论文(arXiv)

准备好学习DQN的改进版本了吗?

→ 下一步:03-DQN改进算法.md