项目3 - Atari游戏¶
难度: ⭐⭐⭐⭐ 高级 预计时间: 4-6小时 目标: 用DQN玩Atari游戏
1. 项目介绍¶
1.1 Atari环境¶
OpenAI Gym/Gymnasium提供: - Breakout - Pong - SpaceInvaders - 等等
挑战: - 高维图像输入 (84×84×4) - 延迟奖励 - 需要理解视觉模式
1.2 成功标准¶
- 达到人类水平或超越
- 稳定的训练过程
- 良好的泛化能力
2. 图像预处理¶
2.1 预处理流程¶
Python
import random
from collections import deque

import cv2
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
class AtariPreprocessor:
    """Converts raw Atari RGB frames into stacked grayscale state arrays."""

    def __init__(self, frame_size=84, frame_stack=4):
        self.frame_size = frame_size
        self.frame_stack = frame_stack
        # Bounded deque: appending beyond frame_stack drops the oldest frame.
        self.frames = deque(maxlen=frame_stack)

    def preprocess(self, frame):
        """Grayscale, downscale to (frame_size, frame_size), scale into [0, 1]."""
        gray = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        small = cv2.resize(
            gray,
            (self.frame_size, self.frame_size),
            interpolation=cv2.INTER_AREA,
        )
        return small.astype(np.float32) / 255.0

    def reset(self):
        """Drop all buffered frames (call at the start of each episode)."""
        self.frames.clear()

    def add_frame(self, frame):
        """Preprocess one frame, push it, pad the stack if short, return the state."""
        processed = self.preprocess(frame)
        self.frames.append(processed)
        # Right after reset the deque holds fewer than frame_stack frames;
        # pad by repeating the newest frame until the stack is full.
        while len(self.frames) < self.frame_stack:
            self.frames.append(processed)
        return self.get_state()

    def get_state(self):
        """Stack buffered frames into a (frame_stack, H, W) float32 array."""
        return np.stack(self.frames, axis=0)
class NoopResetEnv(gym.Wrapper):
    """Advance the env by a random number of no-op actions on every reset.

    Standard Atari trick to randomize the initial state the agent sees.
    """

    def __init__(self, env, noop_max=30):
        super().__init__(env)
        self.noop_max = noop_max
        self.noop_action = 0

    def reset(self, **kwargs):
        """Reset the wrapped env, then apply 1..noop_max no-op steps."""
        obs, info = self.env.reset(**kwargs)
        n_noops = np.random.randint(1, self.noop_max + 1)
        for _ in range(n_noops):
            obs, _, terminated, truncated, info = self.env.step(self.noop_action)
            # If the episode somehow ends during the no-ops, start over.
            if terminated or truncated:
                obs, info = self.env.reset(**kwargs)
        return obs, info
class MaxAndSkipEnv(gym.Wrapper):
    """Repeat each action `skip` times and return the pixel-wise max of the
    last two raw observations (suppresses Atari sprite flicker)."""

    def __init__(self, env, skip=4):
        super().__init__(env)
        self.skip = skip
        # Only the two most recent raw frames are kept for the max-pool.
        self.obs_buffer = deque(maxlen=2)

    def step(self, action):
        """Apply `action` up to `skip` times, summing rewards along the way."""
        accumulated = 0.0
        terminated = truncated = False
        for _ in range(self.skip):
            obs, reward, terminated, truncated, info = self.env.step(action)
            self.obs_buffer.append(obs)
            accumulated += reward
            if terminated or truncated:
                break
        # Pixel-wise maximum over the (up to two) buffered frames.
        merged = np.max(np.stack(self.obs_buffer), axis=0)
        return merged, accumulated, terminated, truncated, info

    def reset(self, **kwargs):
        """Reset the wrapped env and seed the buffer with the first frame."""
        self.obs_buffer.clear()
        obs, info = self.env.reset(**kwargs)
        self.obs_buffer.append(obs)
        return obs, info
def make_atari_env(env_name):
    """Build an Atari env wrapped with no-op resets and frame skipping."""
    base = gym.make(env_name, render_mode=None)
    return MaxAndSkipEnv(NoopResetEnv(base))
3. DQN网络架构¶
Python
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
class DQN(nn.Module):
    """Nature-DQN convolutional Q-network for (4, 84, 84) Atari frame stacks."""

    def __init__(self, n_actions):
        super().__init__()
        # Three-conv trunk (8x8/4, 4x4/2, 3x3/1) from the Nature DQN paper.
        self.conv = nn.Sequential(
            nn.Conv2d(4, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
        )
        # Probe with a dummy input so the head's width is computed, not hard-coded.
        self.conv_out_size = self._get_conv_out((4, 84, 84))
        # Two-layer head mapping flattened features to per-action Q-values.
        self.fc = nn.Sequential(
            nn.Linear(self.conv_out_size, 512),
            nn.ReLU(),
            nn.Linear(512, n_actions),
        )

    def _get_conv_out(self, shape):
        """Number of features `self.conv` produces for one input of `shape`."""
        with torch.no_grad():
            probe = self.conv(torch.zeros(1, *shape))
        return probe.numel()

    def forward(self, x):
        """Map a (B, 4, 84, 84) batch to (B, n_actions) Q-values."""
        features = torch.flatten(self.conv(x), start_dim=1)
        return self.fc(features)
class DuelingDQN(nn.Module):
    """Dueling Q-network: shared CNN trunk, separate value and advantage heads."""

    def __init__(self, n_actions):
        super().__init__()
        # Shared convolutional trunk (same layout as the plain DQN).
        self.conv = nn.Sequential(
            nn.Conv2d(4, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
        )
        self.conv_out_size = self._get_conv_out((4, 84, 84))
        # Shared fully-connected layer feeding both streams.
        self.fc_shared = nn.Sequential(
            nn.Linear(self.conv_out_size, 512),
            nn.ReLU(),
        )
        # State-value stream: a single scalar V(s).
        self.value_stream = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
        )
        # Advantage stream: one A(s, a) per action.
        self.advantage_stream = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, n_actions),
        )

    def _get_conv_out(self, shape):
        """Number of features `self.conv` produces for one input of `shape`."""
        with torch.no_grad():
            probe = self.conv(torch.zeros(1, *shape))
        return probe.numel()

    def forward(self, x):
        """Return (B, n_actions) Q-values via the dueling decomposition."""
        features = torch.flatten(self.conv(x), start_dim=1)
        hidden = self.fc_shared(features)
        v = self.value_stream(hidden)
        a = self.advantage_stream(hidden)
        # Q(s,a) = V(s) + A(s,a) - mean_a' A(s,a') keeps V and A identifiable.
        return v + a - a.mean(dim=1, keepdim=True)
4. 完整DQN智能体¶
Python
class ReplayBuffer:
    """Fixed-capacity FIFO store of (s, a, r, s', done) transitions for DQN.

    NOTE: depends on the module-level `import random`, which the original file
    was missing (fixed in the import block).
    """

    def __init__(self, capacity=100000):
        # deque(maxlen=...) evicts the oldest transition once capacity is hit.
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Append one transition, evicting the oldest if the buffer is full."""
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        """Uniformly sample `batch_size` transitions without replacement.

        Returns five numpy arrays: states, actions, rewards, next_states,
        dones. Raises ValueError if fewer than batch_size transitions exist.
        """
        batch = random.sample(self.buffer, batch_size)
        # zip(*batch) transposes the list of tuples into per-field tuples.
        states, actions, rewards, next_states, dones = zip(*batch)
        return (np.array(states), np.array(actions),
                np.array(rewards), np.array(next_states),
                np.array(dones))

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.buffer)
class AtariDQNAgent:
    """Double-DQN agent with a dueling network and epsilon-greedy exploration.

    Maintains an online `policy_net`, a periodically hard-copied `target_net`,
    a uniform replay buffer, and a multiplicatively decayed epsilon.
    """

    def __init__(self, n_actions, lr=1e-4, gamma=0.99,
                 epsilon_start=1.0, epsilon_end=0.01, epsilon_decay=0.995,
                 buffer_size=100000, batch_size=32, target_update=1000):
        self.n_actions = n_actions
        self.gamma = gamma
        self.epsilon = epsilon_start
        self.epsilon_end = epsilon_end
        self.epsilon_decay = epsilon_decay
        self.batch_size = batch_size
        self.target_update = target_update
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Online network is trained; target network is a frozen periodic copy.
        self.policy_net = DuelingDQN(n_actions).to(self.device)
        self.target_net = DuelingDQN(n_actions).to(self.device)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=lr)
        self.replay_buffer = ReplayBuffer(buffer_size)
        self.steps = 0
        self.preprocessor = AtariPreprocessor()

    def select_action(self, state, training=True):
        """Epsilon-greedy action for `state`; pure greedy when training=False."""
        if training and random.random() < self.epsilon:
            return random.randrange(self.n_actions)
        with torch.no_grad():
            # unsqueeze(0) adds the batch dimension expected by the network.
            state_t = torch.FloatTensor(state).unsqueeze(0).to(self.device)
            q_values = self.policy_net(state_t)
        return q_values.argmax().item()

    def store_transition(self, state, action, reward, next_state, done):
        """Clip the reward to [-1, 1] (standard Atari trick), then store."""
        reward = np.clip(reward, -1, 1)
        self.replay_buffer.push(state, action, reward, next_state, done)

    def update(self):
        """One Double-DQN gradient step.

        Returns the scalar loss, or None while the buffer holds fewer than
        `batch_size` transitions.
        """
        if len(self.replay_buffer) < self.batch_size:
            return None
        states, actions, rewards, next_states, dones = self.replay_buffer.sample(self.batch_size)
        states = torch.FloatTensor(states).to(self.device)
        actions = torch.LongTensor(actions).to(self.device)
        rewards = torch.FloatTensor(rewards).to(self.device)
        next_states = torch.FloatTensor(next_states).to(self.device)
        dones = torch.FloatTensor(dones).to(self.device)
        # Q(s, a) under the online network. BUGFIX: squeeze the action dim
        # explicitly (dim=1); a bare .squeeze() would also drop the batch dim
        # when batch_size == 1, breaking the loss shape.
        current_q = self.policy_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)
        # Double DQN: online net selects the argmax action, target net evaluates it.
        with torch.no_grad():
            next_actions = self.policy_net(next_states).argmax(1)
            next_q = self.target_net(next_states).gather(1, next_actions.unsqueeze(1)).squeeze(1)
            # (1 - dones) zeroes the bootstrap term on terminal transitions.
            target_q = rewards + (1 - dones) * self.gamma * next_q
        loss = F.mse_loss(current_q, target_q)
        self.optimizer.zero_grad()
        loss.backward()
        # Gradient clipping stabilizes training on sparse Atari rewards.
        torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), max_norm=10)
        self.optimizer.step()
        self.steps += 1
        # Hard-copy the online weights into the target net every `target_update` steps.
        if self.steps % self.target_update == 0:
            self.target_net.load_state_dict(self.policy_net.state_dict())
        # Multiplicative epsilon decay, floored at epsilon_end.
        self.epsilon = max(self.epsilon_end, self.epsilon * self.epsilon_decay)
        return loss.item()

    def save(self, path):
        """Checkpoint both networks, the optimizer, epsilon, and step count."""
        torch.save({
            'policy_net': self.policy_net.state_dict(),
            'target_net': self.target_net.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'epsilon': self.epsilon,
            'steps': self.steps
        }, path)

    def load(self, path):
        """Restore a checkpoint written by `save`."""
        # weights_only=True avoids unpickling arbitrary objects from the file.
        checkpoint = torch.load(path, weights_only=True)
        self.policy_net.load_state_dict(checkpoint['policy_net'])
        self.target_net.load_state_dict(checkpoint['target_net'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.epsilon = checkpoint['epsilon']
        self.steps = checkpoint['steps']
5. 训练循环¶
Python
def train_atari(env_name='Breakout-v4', num_frames=1000000):
    """Train the DQN agent on one Atari game for `num_frames` environment steps.

    Returns (agent, episode_rewards, losses).

    NOTE(review): 'Breakout-v4' already applies a built-in frame skip; stacked
    with MaxAndSkipEnv this skips frames twice. The usual pairing is a
    'NoFrameskip-v4' env id — confirm which is intended.
    """
    # Build the wrapped environment (no-op resets + frame skip).
    env = make_atari_env(env_name)
    n_actions = env.action_space.n
    # Agent plus a standalone preprocessor driven by this loop
    # (the agent's own self.preprocessor is not used here).
    agent = AtariDQNAgent(n_actions)
    preprocessor = AtariPreprocessor()
    # Training statistics.
    episode_rewards = []
    episode_lengths = []
    losses = []
    frame_count = 0
    episode = 0
    while frame_count < num_frames:
        obs, _ = env.reset()
        preprocessor.reset()
        state = preprocessor.add_frame(obs)
        episode_reward = 0
        episode_length = 0
        done = False
        while not done:
            # Epsilon-greedy action from the stacked-frame state.
            action = agent.select_action(state, training=True)
            next_obs, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            # Stack the new frame to form the next state.
            next_state = preprocessor.add_frame(next_obs)
            # Unclipped reward here; the agent clips it before storing.
            agent.store_transition(state, action, reward, next_state, done)
            # One gradient step (returns None while the buffer warms up).
            loss = agent.update()
            if loss is not None:
                losses.append(loss)
            episode_reward += reward
            episode_length += 1
            frame_count += 1
            state = next_state
            # Periodic progress report and checkpoint, every 10k frames.
            if frame_count % 10000 == 0:
                avg_reward = np.mean(episode_rewards[-100:]) if episode_rewards else 0
                avg_loss = np.mean(losses[-1000:]) if losses else 0
                print(f"Frame {frame_count}/{num_frames}, "
                      f"Episode {episode}, "
                      f"Avg Reward: {avg_reward:.2f}, "
                      f"Epsilon: {agent.epsilon:.3f}, "
                      f"Avg Loss: {avg_loss:.4f}")
                agent.save(f'atari_dqn_{frame_count}.pth')
        episode_rewards.append(episode_reward)
        episode_lengths.append(episode_length)
        episode += 1
    env.close()
    return agent, episode_rewards, losses
# Main program: train on Breakout for one million frames.
if __name__ == "__main__":
    agent, rewards, losses = train_atari('Breakout-v4', num_frames=1000000)
6. 项目总结¶
学到的技能¶
- 图像预处理(灰度、缩放、归一化)
- 帧堆叠和跳帧
- CNN架构设计
- 模型保存和加载
- 长时间训练管理
关键概念¶
Text Only
Atari DQN:
├── 预处理: 灰度 + 缩放 + 归一化
├── 帧处理: 跳帧 + 帧堆叠
├── 网络: CNN + Dueling
├── 训练: Double DQN + 经验回放
└── 技巧: 奖励裁剪 + 目标网络
恭喜完成Atari游戏项目!
→ 下一个项目:项目4-连续控制.md