项目1 - CartPole平衡¶
难度: ⭐⭐ 入门级 预计时间: 2-3小时 目标: 用DQN解决经典控制问题
1. 项目介绍¶
1.1 CartPole环境¶
目标:保持杆子直立,同时小车在轨道上移动
状态空间(4维连续): - 小车位置 - 小车速度 - 杆子角度 - 杆子角速度
动作空间(2维离散): - 0:向左推 - 1:向右推
奖励:每步+1,杆子倒下或小车出界则结束
1.2 成功标准¶
- 连续100个episode平均奖励 ≥ 475
- 最大步数达到500
2. 完整代码实现¶
Python
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque
import random
import matplotlib.pyplot as plt
# 设置随机种子
def set_seed(seed=42):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
# DQN网络
class DQN(nn.Module): # 继承nn.Module定义网络层
def __init__(self, state_dim, action_dim, hidden_dim=128):
super(DQN, self).__init__()
self.net = nn.Sequential(
nn.Linear(state_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, action_dim)
)
def forward(self, x):
return self.net(x)
# 经验回放缓冲区
class ReplayBuffer:
def __init__(self, capacity=10000):
self.buffer = deque(maxlen=capacity)
def push(self, state, action, reward, next_state, done):
self.buffer.append((state, action, reward, next_state, done))
def sample(self, batch_size):
batch = random.sample(self.buffer, batch_size)
states, actions, rewards, next_states, dones = zip(*batch) # zip按位置配对
return (np.array(states), np.array(actions), # np.array创建NumPy数组
np.array(rewards), np.array(next_states),
np.array(dones))
def __len__(self): # __len__定义len()行为
return len(self.buffer)
# DQN智能体
class DQNAgent:
def __init__(self, state_dim, action_dim, lr=1e-3, gamma=0.99,
epsilon_start=1.0, epsilon_end=0.01, epsilon_decay=0.995,
buffer_size=10000, batch_size=64, target_update=100):
self.state_dim = state_dim
self.action_dim = action_dim
self.gamma = gamma
self.epsilon = epsilon_start
self.epsilon_end = epsilon_end
self.epsilon_decay = epsilon_decay
self.batch_size = batch_size
self.target_update = target_update
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 网络
self.policy_net = DQN(state_dim, action_dim).to(self.device)
self.target_net = DQN(state_dim, action_dim).to(self.device)
self.target_net.load_state_dict(self.policy_net.state_dict())
self.optimizer = optim.Adam(self.policy_net.parameters(), lr=lr)
self.replay_buffer = ReplayBuffer(buffer_size)
self.steps = 0
def select_action(self, state, training=True):
if training and random.random() < self.epsilon:
return random.randrange(self.action_dim)
with torch.no_grad(): # 禁用梯度计算,节省内存
state = torch.FloatTensor(state).unsqueeze(0).to(self.device) # unsqueeze增加一个维度
q_values = self.policy_net(state)
return q_values.argmax().item() # 将单元素张量转为Python数值
def store_transition(self, state, action, reward, next_state, done):
self.replay_buffer.push(state, action, reward, next_state, done)
def update(self):
if len(self.replay_buffer) < self.batch_size:
return None
states, actions, rewards, next_states, dones = self.replay_buffer.sample(self.batch_size)
states = torch.FloatTensor(states).to(self.device)
actions = torch.LongTensor(actions).to(self.device)
rewards = torch.FloatTensor(rewards).to(self.device)
next_states = torch.FloatTensor(next_states).to(self.device)
dones = torch.FloatTensor(dones).to(self.device)
# 当前Q值
current_q = self.policy_net(states).gather(1, actions.unsqueeze(1))
# 目标Q值
with torch.no_grad():
next_q = self.target_net(next_states).max(1)[0]
target_q = rewards + (1 - dones) * self.gamma * next_q
loss = nn.MSELoss()(current_q.squeeze(), target_q) # squeeze压缩维度
self.optimizer.zero_grad() # 清零梯度
loss.backward() # 反向传播计算梯度
self.optimizer.step() # 更新参数
# 更新目标网络
self.steps += 1
if self.steps % self.target_update == 0:
self.target_net.load_state_dict(self.policy_net.state_dict())
# 衰减探索率
self.epsilon = max(self.epsilon_end, self.epsilon * self.epsilon_decay)
return loss.item()
def save(self, path):
torch.save(self.policy_net.state_dict(), path)
def load(self, path):
self.policy_net.load_state_dict(torch.load(path, weights_only=True))
self.target_net.load_state_dict(self.policy_net.state_dict())
# 训练函数
def train_dqn(env_name='CartPole-v1', num_episodes=1000, render=False):
set_seed(42)
env = gym.make(env_name)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
agent = DQNAgent(state_dim, action_dim)
rewards_history = []
losses_history = []
for episode in range(num_episodes):
state, _ = env.reset()
total_reward = 0
done = False
while not done:
if render:
env.render()
action = agent.select_action(state, training=True)
next_state, reward, terminated, truncated, _ = env.step(action)
done = terminated or truncated
agent.store_transition(state, action, reward, next_state, done)
loss = agent.update()
if loss is not None:
losses_history.append(loss)
total_reward += reward
state = next_state
rewards_history.append(total_reward)
# 打印进度
if (episode + 1) % 100 == 0:
avg_reward = np.mean(rewards_history[-100:])
avg_loss = np.mean(losses_history[-100:]) if losses_history else 0
print(f"Episode {episode + 1}, Avg Reward: {avg_reward:.2f}, "
f"Epsilon: {agent.epsilon:.3f}, Avg Loss: {avg_loss:.4f}")
# 检查是否解决
if avg_reward >= 475:
print(f"\n✓ 环境已解决!在第 {episode + 1} 个episode")
agent.save('cartpole_dqn.pth')
break
env.close()
return agent, rewards_history, losses_history
# 可视化
def plot_results(rewards, losses):
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
# 奖励曲线
axes[0].plot(rewards, alpha=0.3, label='Raw')
if len(rewards) >= 100:
moving_avg = np.convolve(rewards, np.ones(100)/100, mode='valid')
axes[0].plot(range(99, len(rewards)), moving_avg, label='MA(100)', linewidth=2)
axes[0].axhline(y=475, color='r', linestyle='--', label='Solved (475)')
axes[0].set_xlabel('Episode')
axes[0].set_ylabel('Total Reward')
axes[0].set_title('Training Rewards')
axes[0].legend()
axes[0].grid(True, alpha=0.3)
# 损失曲线
if losses:
axes[1].plot(losses, alpha=0.3)
if len(losses) >= 100:
moving_avg = np.convolve(losses, np.ones(100)/100, mode='valid')
axes[1].plot(range(99, len(losses)), moving_avg, linewidth=2)
axes[1].set_xlabel('Update Step')
axes[1].set_ylabel('Loss')
axes[1].set_title('Training Loss')
axes[1].grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('cartpole_training.png', dpi=150)
plt.show()
# 测试函数
def test_agent(env_name='CartPole-v1', model_path='cartpole_dqn.pth', num_episodes=10):
env = gym.make(env_name, render_mode='human')
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
agent = DQNAgent(state_dim, action_dim)
agent.load(model_path)
for episode in range(num_episodes):
state, _ = env.reset()
total_reward = 0
done = False
while not done:
action = agent.select_action(state, training=False)
state, reward, terminated, truncated, _ = env.step(action)
done = terminated or truncated
total_reward += reward
print(f"Test Episode {episode + 1}: Reward = {total_reward}")
env.close()
# 主程序
if __name__ == "__main__":
# 训练
print("开始训练...")
agent, rewards, losses = train_dqn(num_episodes=1000)
# 可视化
plot_results(rewards, losses)
# 测试
print("\n开始测试...")
test_agent()
3. 实验任务¶
任务1:调参实验¶
尝试不同的超参数组合:
Python
experiments = [
{'lr': 1e-3, 'batch_size': 64, 'target_update': 100},
{'lr': 5e-4, 'batch_size': 32, 'target_update': 50},
{'lr': 1e-4, 'batch_size': 128, 'target_update': 200},
]
for i, config in enumerate(experiments): # enumerate同时获取索引和元素
print(f"\n实验 {i+1}: {config}")
agent = DQNAgent(state_dim, action_dim, **config)
# ... 训练并记录结果
任务2:网络架构对比¶
比较不同网络深度的效果:
Python
class ShallowDQN(nn.Module):
def __init__(self, state_dim, action_dim):
super().__init__() # super()调用父类方法
self.net = nn.Sequential(
nn.Linear(state_dim, 64),
nn.ReLU(),
nn.Linear(64, action_dim)
)
class DeepDQN(nn.Module):
def __init__(self, state_dim, action_dim):
super().__init__()
self.net = nn.Sequential(
nn.Linear(state_dim, 256),
nn.ReLU(),
nn.Linear(256, 256),
nn.ReLU(),
nn.Linear(256, 128),
nn.ReLU(),
nn.Linear(128, action_dim)
)
任务3:对比不同算法¶
实现并对比: - DQN - Double DQN - Dueling DQN
4. 常见问题¶
Q1: 训练不稳定¶
解决方案: - 降低学习率 - 增加目标网络更新间隔 - 使用梯度裁剪
Q2: 无法收敛¶
解决方案: - 检查奖励是否归一化 - 增加网络容量 - 调整探索率衰减速度
Q3: 过拟合¶
解决方案: - 增加经验回放缓冲区大小 - 使用Dropout - 早停
5. 扩展挑战¶
挑战1:连续动作版本¶
使用DDPG或SAC解决连续动作版本的CartPole
挑战2:图像输入¶
使用CNN处理渲染的图像而非状态向量
挑战3:多智能体¶
训练多个智能体协作保持多个杆子平衡
6. 项目总结¶
学到的技能¶
- DQN完整实现
- 经验回放
- 目标网络
- 超参数调优
- 模型保存和加载
关键概念¶
Text Only
CartPole with DQN:
├── 状态: 4维连续向量
├── 动作: 2维离散
├── 网络: MLP (输入4, 输出2)
├── 训练: 经验回放 + 目标网络
└── 成功标准: 平均奖励≥475
恭喜完成第一个RL项目!
→ 下一个项目:项目2-迷宫求解.md