跳转至

项目3 - Atari游戏

难度: ⭐⭐⭐⭐ 高级 预计时间: 4-6小时 目标: 用DQN玩Atari游戏


1. 项目介绍

1.1 Atari环境

OpenAI Gym/Gymnasium提供: - Breakout - Pong - SpaceInvaders - 等等

挑战: - 高维图像输入 (84×84×4) - 延迟奖励 - 需要理解视觉模式

1.2 成功标准

  • 达到人类水平或超越
  • 稳定的训练过程
  • 良好的泛化能力

2. 图像预处理

2.1 预处理流程

Python
import random
from collections import deque

import cv2
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn

class AtariPreprocessor:
    """Turn raw Atari RGB frames into stacked, normalized grayscale observations.

    Maintains a rolling window of the last `frame_stack` processed frames so
    the agent can observe motion from a single state tensor.
    """

    def __init__(self, frame_size=84, frame_stack=4):
        self.frame_size = frame_size
        self.frame_stack = frame_stack
        # Oldest frame is evicted automatically once the window is full.
        self.frames = deque(maxlen=frame_stack)

    def preprocess(self, frame):
        """Convert one RGB frame to a float32 grayscale array in [0, 1].

        Pipeline: RGB -> grayscale -> resize to (frame_size, frame_size)
        -> scale pixel values from [0, 255] down to [0, 1].
        """
        gray = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        small = cv2.resize(
            gray,
            (self.frame_size, self.frame_size),
            interpolation=cv2.INTER_AREA,
        )
        return small.astype(np.float32) / 255.0

    def reset(self):
        """Drop all buffered frames (call at the start of each episode)."""
        self.frames.clear()

    def add_frame(self, frame):
        """Preprocess `frame`, push it onto the stack, and return the new state.

        Immediately after a reset the window is padded by repeating this first
        frame until `frame_stack` entries exist, so the returned state always
        has a full stack.
        """
        processed = self.preprocess(frame)
        self.frames.append(processed)
        while len(self.frames) < self.frame_stack:
            self.frames.append(processed)
        return self.get_state()

    def get_state(self):
        """Return the stacked frames as a (frame_stack, H, W) array."""
        return np.stack(self.frames, axis=0)

class NoopResetEnv(gym.Wrapper):
    """On reset, execute a random number (1..noop_max) of no-op actions.

    Randomizing the initial frames prevents the agent from memorizing a
    single deterministic start state. Action 0 is assumed to be NOOP.
    """

    def __init__(self, env, noop_max=30):
        super().__init__(env)
        self.noop_max = noop_max
        self.noop_action = 0

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)
        n_noops = np.random.randint(1, self.noop_max + 1)
        for _ in range(n_noops):
            obs, _, terminated, truncated, info = self.env.step(self.noop_action)
            # If the no-ops somehow end the episode, restart and keep going.
            if terminated or truncated:
                obs, info = self.env.reset(**kwargs)
        return obs, info

class MaxAndSkipEnv(gym.Wrapper):
    """Repeat each action `skip` times and max-pool the last two raw frames.

    Frame skipping speeds up training; taking the element-wise max of the
    final two observations removes Atari sprite flicker.
    """

    def __init__(self, env, skip=4):
        super().__init__(env)
        self.skip = skip
        # Keeps only the two most recent raw observations.
        self.obs_buffer = deque(maxlen=2)

    def step(self, action):
        reward_sum = 0.0
        terminated = False
        truncated = False

        for _ in range(self.skip):
            obs, reward, terminated, truncated, info = self.env.step(action)
            self.obs_buffer.append(obs)
            reward_sum += reward
            if terminated or truncated:
                break

        # Element-wise max over the buffered (<=2) frames.
        pooled = np.max(np.stack(self.obs_buffer), axis=0)
        return pooled, reward_sum, terminated, truncated, info

    def reset(self, **kwargs):
        self.obs_buffer.clear()
        obs, info = self.env.reset(**kwargs)
        self.obs_buffer.append(obs)
        return obs, info

def make_atari_env(env_name):
    """Build an Atari environment wrapped with noop resets and frame skipping."""
    base = gym.make(env_name, render_mode=None)
    return MaxAndSkipEnv(NoopResetEnv(base))

3. DQN网络架构

Python
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

class DQN(nn.Module):
    """Nature-DQN convolutional network for Atari.

    Input: a (batch, 4, 84, 84) float tensor of stacked grayscale frames.
    Output: a (batch, n_actions) tensor of Q-values.
    """

    def __init__(self, n_actions):
        super(DQN, self).__init__()

        # Convolutional feature extractor (Mnih et al. 2015 architecture).
        self.conv = nn.Sequential(
            nn.Conv2d(4, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU()
        )

        # Probe the conv stack once so the Linear layer size always matches.
        self.conv_out_size = self._get_conv_out((4, 84, 84))

        # Q-value head.
        self.fc = nn.Sequential(
            nn.Linear(self.conv_out_size, 512),
            nn.ReLU(),
            nn.Linear(512, n_actions)
        )

    def _get_conv_out(self, shape):
        """Return the flattened feature count self.conv produces for `shape`.

        Uses Tensor.numel() instead of np.prod so this module does not depend
        on numpy being imported (the original code raised NameError when this
        section was run standalone without `import numpy as np`).
        """
        with torch.no_grad():
            return self.conv(torch.zeros(1, *shape)).numel()

    def forward(self, x):
        """Map a (batch, 4, 84, 84) input to (batch, n_actions) Q-values."""
        features = self.conv(x).view(x.size(0), -1)  # flatten per sample
        return self.fc(features)

class DuelingDQN(nn.Module):
    """Dueling DQN: separate state-value and advantage streams over shared features.

    Input: a (batch, 4, 84, 84) float tensor of stacked grayscale frames.
    Output: a (batch, n_actions) tensor of Q-values aggregated as
    Q(s, a) = V(s) + (A(s, a) - mean_a A(s, a)).
    """

    def __init__(self, n_actions):
        super(DuelingDQN, self).__init__()

        # Shared convolutional feature extractor.
        self.conv = nn.Sequential(
            nn.Conv2d(4, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU()
        )

        self.conv_out_size = self._get_conv_out((4, 84, 84))

        # Shared fully-connected trunk.
        self.fc_shared = nn.Sequential(
            nn.Linear(self.conv_out_size, 512),
            nn.ReLU()
        )

        # State-value stream V(s): a single scalar per state.
        self.value_stream = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 1)
        )

        # Advantage stream A(s, a): one value per action.
        self.advantage_stream = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, n_actions)
        )

    def _get_conv_out(self, shape):
        """Return the flattened feature count self.conv produces for `shape`.

        Uses Tensor.numel() instead of np.prod so this module does not depend
        on numpy being imported in this section.
        """
        with torch.no_grad():
            return self.conv(torch.zeros(1, *shape)).numel()

    def forward(self, x):
        """Map a (batch, 4, 84, 84) input to (batch, n_actions) Q-values."""
        features = self.conv(x).view(x.size(0), -1)
        shared = self.fc_shared(features)

        value = self.value_stream(shared)
        advantage = self.advantage_stream(shared)

        # Subtracting the mean advantage makes the V/A decomposition identifiable.
        return value + (advantage - advantage.mean(dim=1, keepdim=True))

4. 完整DQN智能体

Python
class ReplayBuffer:
    """Fixed-capacity FIFO store of (state, action, reward, next_state, done) tuples."""

    def __init__(self, capacity=100000):
        # deque evicts the oldest transition once capacity is reached.
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Append one transition to the buffer."""
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        """Draw `batch_size` transitions uniformly (without replacement).

        Returns five numpy arrays: states, actions, rewards, next_states, dones,
        each with the batch along axis 0.
        """
        transitions = random.sample(self.buffer, batch_size)
        columns = zip(*transitions)  # regroup per-field instead of per-transition
        return tuple(np.array(column) for column in columns)

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.buffer)

class AtariDQNAgent:
    """Double DQN agent (with Dueling networks) for Atari.

    Combines: epsilon-greedy exploration, experience replay, reward clipping,
    a periodically-synced target network, and Double-DQN target computation.
    """

    def __init__(self, n_actions, lr=1e-4, gamma=0.99,
                 epsilon_start=1.0, epsilon_end=0.01, epsilon_decay=0.995,
                 buffer_size=100000, batch_size=32, target_update=1000):

        self.n_actions = n_actions
        self.gamma = gamma
        self.epsilon = epsilon_start
        self.epsilon_end = epsilon_end
        self.epsilon_decay = epsilon_decay
        self.batch_size = batch_size
        # Sync target network every `target_update` gradient steps.
        self.target_update = target_update

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Online network is trained; target network provides stable bootstrap values.
        self.policy_net = DuelingDQN(n_actions).to(self.device)
        self.target_net = DuelingDQN(n_actions).to(self.device)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()

        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=lr)
        self.replay_buffer = ReplayBuffer(buffer_size)

        self.steps = 0
        self.preprocessor = AtariPreprocessor()

    def select_action(self, state, training=True):
        """Return an action: random with probability epsilon (while training),
        otherwise the greedy argmax of the policy network's Q-values."""
        if training and random.random() < self.epsilon:
            return random.randrange(self.n_actions)

        with torch.no_grad():  # inference only; no gradients needed
            state = torch.FloatTensor(state).unsqueeze(0).to(self.device)  # add batch dim
            q_values = self.policy_net(state)
            return q_values.argmax().item()

    def store_transition(self, state, action, reward, next_state, done):
        """Clip the reward to [-1, 1] (standard Atari trick for stable learning
        across games with different score scales) and store the transition."""
        reward = np.clip(reward, -1, 1)
        self.replay_buffer.push(state, action, reward, next_state, done)

    def update(self):
        """Sample a batch and take one gradient step; return the loss, or None
        if the buffer does not yet hold a full batch."""
        if len(self.replay_buffer) < self.batch_size:
            return None

        states, actions, rewards, next_states, dones = self.replay_buffer.sample(self.batch_size)

        states = torch.FloatTensor(states).to(self.device)
        actions = torch.LongTensor(actions).to(self.device)
        rewards = torch.FloatTensor(rewards).to(self.device)
        next_states = torch.FloatTensor(next_states).to(self.device)
        dones = torch.FloatTensor(dones).to(self.device)

        # Q(s, a) for the actions actually taken.
        # FIX: squeeze the action dim explicitly (dim=1). A bare .squeeze()
        # would also collapse the batch dim when batch_size == 1, producing a
        # 0-d tensor and a shape mismatch against target_q in the loss.
        current_q = self.policy_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)

        # Double DQN target: online net selects the action, target net evaluates it.
        with torch.no_grad():
            next_actions = self.policy_net(next_states).argmax(1)
            next_q = self.target_net(next_states).gather(1, next_actions.unsqueeze(1)).squeeze(1)
            # (1 - dones) zeroes the bootstrap term at terminal states.
            target_q = rewards + (1 - dones) * self.gamma * next_q

        loss = F.mse_loss(current_q, target_q)

        self.optimizer.zero_grad()
        loss.backward()
        # Gradient clipping guards against exploding updates.
        torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), max_norm=10)
        self.optimizer.step()

        # Periodically copy online weights into the target network.
        self.steps += 1
        if self.steps % self.target_update == 0:
            self.target_net.load_state_dict(self.policy_net.state_dict())

        # Decay exploration rate toward its floor.
        self.epsilon = max(self.epsilon_end, self.epsilon * self.epsilon_decay)

        return loss.item()

    def save(self, path):
        """Checkpoint both networks, the optimizer, and training counters."""
        torch.save({
            'policy_net': self.policy_net.state_dict(),
            'target_net': self.target_net.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'epsilon': self.epsilon,
            'steps': self.steps
        }, path)

    def load(self, path):
        """Restore a checkpoint written by save()."""
        checkpoint = torch.load(path, weights_only=True)
        self.policy_net.load_state_dict(checkpoint['policy_net'])
        self.target_net.load_state_dict(checkpoint['target_net'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.epsilon = checkpoint['epsilon']
        self.steps = checkpoint['steps']

5. 训练循环

Python
def train_atari(env_name='Breakout-v4', num_frames=1000000):
    """Run the DQN training loop on an Atari game for `num_frames` env steps.

    Returns (agent, episode_rewards, losses). Checkpoints and progress are
    emitted every 10,000 frames.
    """
    env = make_atari_env(env_name)
    agent = AtariDQNAgent(env.action_space.n)
    preprocessor = AtariPreprocessor()

    # Training statistics.
    episode_rewards = []
    episode_lengths = []
    losses = []

    frame_count = 0
    episode = 0

    while frame_count < num_frames:
        # Start a fresh episode.
        obs, _ = env.reset()
        preprocessor.reset()
        state = preprocessor.add_frame(obs)

        episode_reward = 0
        episode_length = 0
        done = False

        while not done:
            action = agent.select_action(state, training=True)

            next_obs, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            next_state = preprocessor.add_frame(next_obs)

            # Store the (unclipped-here) transition and take a learning step.
            agent.store_transition(state, action, reward, next_state, done)
            loss = agent.update()
            if loss is not None:
                losses.append(loss)

            episode_reward += reward
            episode_length += 1
            frame_count += 1
            state = next_state

            # Periodic progress report and checkpoint.
            if frame_count % 10000 == 0:
                avg_reward = np.mean(episode_rewards[-100:]) if episode_rewards else 0
                avg_loss = np.mean(losses[-1000:]) if losses else 0
                print(f"Frame {frame_count}/{num_frames}, "
                      f"Episode {episode}, "
                      f"Avg Reward: {avg_reward:.2f}, "
                      f"Epsilon: {agent.epsilon:.3f}, "
                      f"Avg Loss: {avg_loss:.4f}")
                agent.save(f'atari_dqn_{frame_count}.pth')

        episode_rewards.append(episode_reward)
        episode_lengths.append(episode_length)
        episode += 1

    env.close()
    return agent, episode_rewards, losses

# 主程序
if __name__ == "__main__":
    agent, rewards, losses = train_atari('Breakout-v4', num_frames=1000000)

6. 项目总结

学到的技能

  • 图像预处理(灰度、缩放、归一化)
  • 帧堆叠和跳帧
  • CNN架构设计
  • 模型保存和加载
  • 长时间训练管理

关键概念

Text Only
Atari DQN:
├── 预处理: 灰度 + 缩放 + 归一化
├── 帧处理: 跳帧 + 帧堆叠
├── 网络: CNN + Dueling
├── 训练: Double DQN + 经验回放
└── 技巧: 奖励裁剪 + 目标网络

恭喜完成Atari游戏项目!

→ 下一个项目:项目4-连续控制.md