项目2 - 迷宫求解¶
难度: ⭐⭐⭐ 中级 预计时间: 3-4小时 目标: 对比不同RL算法在迷宫问题上的表现
1. 项目介绍¶
1.1 迷宫环境¶
目标:从起点找到通往终点的最短路径
特点: - 可配置迷宫大小 - 可添加障碍物 - 支持多种奖励设置
状态表示: - 智能体位置 (x, y) - 可扩展为部分可观察(只能看到周围)
动作空间: - 0: 上 - 1: 下 - 2: 左 - 3: 右
1.2 成功标准¶
- 找到从起点到终点的路径
- 路径尽可能短
- 对比不同算法的收敛速度
2. 迷宫环境实现¶
Python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import random
class MazeEnv:
    """Grid-maze environment with a Gymnasium-style reset()/step() interface.

    The agent starts at (0, 0) and must reach (size-1, size-1). Cells marked
    1 in ``self.maze`` are obstacles: moving onto one keeps the agent in
    place and costs -1. Every ordinary move costs -0.01, and reaching the
    goal yields +10 and terminates the episode.
    """

    def __init__(self, size=10, obstacle_prob=0.2):
        """Create a random, solvable maze.

        Args:
            size: side length of the square grid.
            obstacle_prob: independent probability that each non-start,
                non-goal cell becomes an obstacle.
        """
        self.size = size
        self.start = (0, 0)
        self.goal = (size - 1, size - 1)
        self.obstacle_prob = obstacle_prob
        # Generated maze is guaranteed to contain a start->goal path.
        self.maze = self._generate_maze()
        self.reset()

    def _generate_maze(self):
        """Sample random mazes until one with a start->goal path appears.

        Retries in a loop rather than recursively (as the original did),
        so a high obstacle_prob cannot trigger a RecursionError.
        """
        while True:
            maze = np.zeros((self.size, self.size))
            for i in range(self.size):
                for j in range(self.size):
                    if (i, j) != self.start and (i, j) != self.goal:
                        if random.random() < self.obstacle_prob:
                            maze[i, j] = 1
            if self._has_path(maze):
                return maze

    def _has_path(self, maze):
        """Return True iff the goal is reachable from the start (BFS)."""
        from collections import deque
        visited = {self.start}
        queue = deque([self.start])
        while queue:
            x, y = queue.popleft()
            if (x, y) == self.goal:
                return True
            for dx, dy in ((-1, 0), (1, 0), (0, -1), (0, 1)):
                nx, ny = x + dx, y + dy
                if 0 <= nx < self.size and 0 <= ny < self.size:
                    if maze[nx, ny] == 0 and (nx, ny) not in visited:
                        visited.add((nx, ny))
                        queue.append((nx, ny))
        return False

    def reset(self):
        """Put the agent back on the start cell.

        Returns:
            (state, info): integer state id and an empty info dict,
            mirroring the Gymnasium reset signature.
        """
        self.agent_pos = self.start
        return self._get_state(), {}

    def _get_state(self):
        """Flatten the (row, col) agent position into one state index."""
        return self.agent_pos[0] * self.size + self.agent_pos[1]

    def step(self, action):
        """Apply one action.

        Actions: 0=up, 1=down, 2=left, 3=right. Moves are clamped at the
        grid border; moving onto an obstacle leaves the agent in place.

        Returns:
            (state, reward, terminated, truncated, info). ``truncated`` is
            always False — this environment never truncates on its own.
        """
        x, y = self.agent_pos
        if action == 0:
            x = max(0, x - 1)
        elif action == 1:
            x = min(self.size - 1, x + 1)
        elif action == 2:
            y = max(0, y - 1)
        elif action == 3:
            y = min(self.size - 1, y + 1)

        if self.maze[x, y] == 1:
            # Bumped into an obstacle: stay put and pay a penalty.
            x, y = self.agent_pos
            reward = -1.0
            done = False
        elif (x, y) == self.goal:
            reward = 10.0
            done = True
        else:
            reward = -0.01  # small per-step cost encourages short paths
            done = False

        self.agent_pos = (x, y)
        return self._get_state(), reward, done, False, {}

    def render(self, path=None):
        """Draw the maze with matplotlib.

        Obstacles are black, free cells white, start green, goal red, the
        agent yellow; an optional list of (row, col) cells in ``path`` is
        shaded blue.
        """
        fig, ax = plt.subplots(figsize=(8, 8))
        for i in range(self.size):
            for j in range(self.size):
                cell_color = 'black' if self.maze[i, j] == 1 else 'white'
                # Row i is drawn at y = size-1-i so row 0 appears on top.
                ax.add_patch(Rectangle((j, self.size - 1 - i), 1, 1,
                                       facecolor=cell_color, edgecolor='gray'))
        ax.add_patch(Rectangle((self.start[1], self.size - 1 - self.start[0]), 1, 1,
                               facecolor='green', edgecolor='gray', alpha=0.5))
        ax.add_patch(Rectangle((self.goal[1], self.size - 1 - self.goal[0]), 1, 1,
                               facecolor='red', edgecolor='gray', alpha=0.5))
        if path:
            for x, y in path:
                ax.add_patch(Rectangle((y, self.size - 1 - x), 1, 1,
                                       facecolor='blue', edgecolor='gray', alpha=0.3))
        ax.add_patch(Rectangle((self.agent_pos[1], self.size - 1 - self.agent_pos[0]), 1, 1,
                               facecolor='yellow', edgecolor='gray'))
        ax.set_xlim(0, self.size)
        ax.set_ylim(0, self.size)
        ax.set_aspect('equal')
        ax.axis('off')
        plt.show()

    def get_state_space(self):
        """Number of discrete states (one per grid cell)."""
        return self.size * self.size

    def get_action_space(self):
        """Number of discrete actions (up/down/left/right)."""
        return 4
3. 算法对比实现¶
3.1 SARSA¶
Python
import numpy as np
from collections import defaultdict
class SARSAAgent:
    """Tabular SARSA (on-policy TD control) agent with epsilon-greedy exploration."""

    def __init__(self, n_states, n_actions, alpha=0.1, gamma=0.9, epsilon=0.1):
        """Initialize the agent.

        Args:
            n_states: size of the state space (unused here — the Q-table is
                a defaultdict, so unseen states get zero-initialized rows;
                kept for API symmetry with callers).
            n_actions: number of discrete actions.
            alpha: learning rate.
            gamma: discount factor.
            epsilon: initial exploration rate (decayed during training).
        """
        self.n_actions = n_actions
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = epsilon
        # defaultdict returns a fresh zero vector for any unseen state.
        self.Q = defaultdict(lambda: np.zeros(n_actions))

    def select_action(self, state):
        """Epsilon-greedy action selection."""
        if np.random.random() < self.epsilon:
            return np.random.randint(self.n_actions)
        return np.argmax(self.Q[state])

    def update(self, state, action, reward, next_state, next_action, done):
        """One SARSA update: target uses the action actually taken next."""
        if done:
            target = reward
        else:
            target = reward + self.gamma * self.Q[next_state][next_action]
        td_error = target - self.Q[state][action]
        self.Q[state][action] += self.alpha * td_error

    def train(self, env, num_episodes=1000, max_steps=None):
        """Run epsilon-greedy SARSA training.

        Args:
            env: environment with Gymnasium-style reset()/step().
            num_episodes: number of episodes to run.
            max_steps: optional per-episode step cap, preventing an episode
                from looping forever early in training; None (the default)
                reproduces the original unbounded behavior.

        Returns:
            (rewards_history, steps_history): per-episode total reward and
            step count.
        """
        rewards_history = []
        steps_history = []
        for _ in range(num_episodes):
            state, _ = env.reset()
            action = self.select_action(state)
            total_reward = 0
            steps = 0
            done = False
            while not done and (max_steps is None or steps < max_steps):
                next_state, reward, terminated, truncated, _ = env.step(action)
                done = terminated or truncated
                next_action = self.select_action(next_state)
                self.update(state, action, reward, next_state, next_action, done)
                total_reward += reward
                steps += 1
                state, action = next_state, next_action
            rewards_history.append(total_reward)
            steps_history.append(steps)
            # Decay exploration toward a floor of 0.01.
            self.epsilon = max(0.01, self.epsilon * 0.995)
        return rewards_history, steps_history

    def get_policy(self):
        """Return the greedy policy: state -> argmax-Q action."""
        return {state: np.argmax(q) for state, q in self.Q.items()}
3.2 Q-Learning¶
Python
class QLearningAgent:
    """Tabular Q-Learning (off-policy TD control) agent with epsilon-greedy exploration."""

    def __init__(self, n_states, n_actions, alpha=0.1, gamma=0.9, epsilon=0.1):
        """Initialize the agent.

        Args:
            n_states: size of the state space (unused here — the Q-table is
                a defaultdict, so unseen states get zero-initialized rows;
                kept for API symmetry with callers).
            n_actions: number of discrete actions.
            alpha: learning rate.
            gamma: discount factor.
            epsilon: initial exploration rate (decayed during training).
        """
        self.n_actions = n_actions
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = epsilon
        # defaultdict returns a fresh zero vector for any unseen state.
        self.Q = defaultdict(lambda: np.zeros(n_actions))

    def select_action(self, state):
        """Epsilon-greedy action selection."""
        if np.random.random() < self.epsilon:
            return np.random.randint(self.n_actions)
        return np.argmax(self.Q[state])

    def update(self, state, action, reward, next_state, done):
        """One Q-Learning update: target uses max over next-state actions."""
        if done:
            target = reward
        else:
            target = reward + self.gamma * np.max(self.Q[next_state])
        td_error = target - self.Q[state][action]
        self.Q[state][action] += self.alpha * td_error

    def train(self, env, num_episodes=1000, max_steps=None):
        """Run epsilon-greedy Q-Learning training.

        Args:
            env: environment with Gymnasium-style reset()/step().
            num_episodes: number of episodes to run.
            max_steps: optional per-episode step cap, preventing an episode
                from looping forever early in training; None (the default)
                reproduces the original unbounded behavior.

        Returns:
            (rewards_history, steps_history): per-episode total reward and
            step count.
        """
        rewards_history = []
        steps_history = []
        for _ in range(num_episodes):
            state, _ = env.reset()
            total_reward = 0
            steps = 0
            done = False
            while not done and (max_steps is None or steps < max_steps):
                action = self.select_action(state)
                next_state, reward, terminated, truncated, _ = env.step(action)
                done = terminated or truncated
                self.update(state, action, reward, next_state, done)
                total_reward += reward
                steps += 1
                state = next_state
            rewards_history.append(total_reward)
            steps_history.append(steps)
            # Decay exploration toward a floor of 0.01.
            self.epsilon = max(0.01, self.epsilon * 0.995)
        return rewards_history, steps_history

    def get_policy(self):
        """Return the greedy policy: state -> argmax-Q action."""
        return {state: np.argmax(q) for state, q in self.Q.items()}
4. 对比实验¶
Python
def compare_algorithms(env, num_episodes=500, save_path='maze_comparison.png'):
    """Train SARSA and Q-Learning on the same env and plot both curves.

    Args:
        env: environment exposing get_state_space()/get_action_space() and
            the Gymnasium-style reset()/step() used by the agents' train().
        num_episodes: episodes per agent.
        save_path: file the comparison figure is written to. New,
            backward-compatible parameter — the default matches the
            previously hard-coded filename.

    Returns:
        (sarsa_agent, q_agent): both trained agents.
    """
    n_states = env.get_state_space()
    n_actions = env.get_action_space()

    print("训练 SARSA...")
    sarsa_agent = SARSAAgent(n_states, n_actions, alpha=0.1, epsilon=0.2)
    sarsa_rewards, sarsa_steps = sarsa_agent.train(env, num_episodes)

    print("训练 Q-Learning...")
    q_agent = QLearningAgent(n_states, n_actions, alpha=0.1, epsilon=0.2)
    q_rewards, q_steps = q_agent.train(env, num_episodes)

    # One subplot for per-episode reward, one for per-episode step count.
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    panels = (
        (axes[0], sarsa_rewards, q_rewards, 'Total Reward', 'Reward Comparison'),
        (axes[1], sarsa_steps, q_steps, 'Steps to Goal', 'Steps Comparison'),
    )
    for ax, sarsa_data, q_data, ylabel, title in panels:
        ax.plot(sarsa_data, label='SARSA', alpha=0.7)
        ax.plot(q_data, label='Q-Learning', alpha=0.7)
        ax.set_xlabel('Episode')
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(save_path, dpi=150)
    plt.show()
    return sarsa_agent, q_agent
def visualize_policy(env, agent, title="Policy"):
    """Draw the maze grid and overlay the agent's greedy action per free cell."""
    greedy = agent.get_policy()
    fig, ax = plt.subplots(figsize=(8, 8))

    # Maze cells: obstacles black, free cells white; row 0 drawn at the top.
    for row in range(env.size):
        for col in range(env.size):
            cell_color = 'black' if env.maze[row, col] == 1 else 'white'
            ax.add_patch(Rectangle((col, env.size - 1 - row), 1, 1,
                                   facecolor=cell_color, edgecolor='gray'))

    # One arrow per visited free cell (goal excluded), showing argmax-Q.
    arrows = {0: '↑', 1: '↓', 2: '←', 3: '→'}
    for state, act in greedy.items():
        row, col = divmod(state, env.size)
        if env.maze[row, col] == 0 and (row, col) != env.goal:
            ax.text(col + 0.5, env.size - 1 - row + 0.5, arrows[act],
                    ha='center', va='center', fontsize=20, color='blue')

    ax.set_xlim(0, env.size)
    ax.set_ylim(0, env.size)
    ax.set_aspect('equal')
    ax.axis('off')
    ax.set_title(title)
    plt.show()
# 主程序
# Entry point: build a random maze, train both agents, and show their policies.
if __name__ == "__main__":
    maze_env = MazeEnv(size=10, obstacle_prob=0.2)
    print("迷宫环境:")
    maze_env.render()

    # Train SARSA and Q-Learning on the same maze and plot learning curves.
    trained_sarsa, trained_q = compare_algorithms(maze_env, num_episodes=500)

    print("\nSARSA策略:")
    visualize_policy(maze_env, trained_sarsa, "SARSA Policy")
    print("\nQ-Learning策略:")
    visualize_policy(maze_env, trained_q, "Q-Learning Policy")
5. 项目总结¶
学到的技能¶
- 自定义环境设计
- 多种RL算法实现
- 算法对比分析
- 策略可视化
关键概念¶
恭喜完成迷宫求解项目!
→ 下一个项目:项目3-Atari游戏.md