02 - SARSA算法:On-Policy TD控制¶
学习时间: 3-4小时 重要性: ⭐⭐⭐⭐⭐ 第一个实用的TD控制算法 前置知识: TD(0)预测、策略改进
🎯 学习目标¶
完成本章后,你将能够: - 理解SARSA算法的核心思想 - 掌握On-Policy学习的特点 - 实现SARSA算法解决控制问题 - 分析SARSA的收敛性 - 理解SARSA与Q-Learning的区别
1. SARSA算法简介¶
1.1 从预测到控制¶
TD(0) 解决了预测问题:给定策略π,估计V^π。
SARSA 解决控制问题:找到最优策略π*。
关键洞察:
使用TD学习来估计动作值函数Q(s,a),然后通过贪心策略改进。
1.2 SARSA名称的由来¶
SARSA代表算法使用的5个变量:
S → A → R → S' → A'
S: 当前状态 (State)
A: 当前动作 (Action)
R: 获得的奖励 (Reward)
S': 下一个状态 (Next State)
A': 下一个动作 (Next Action)
1.3 SARSA更新规则¶
与TD(0)的对比:
| 算法 | 更新目标 | 学习对象 |
|---|---|---|
| TD(0) | \(R + \gamma V(s')\) | 状态值函数V(s) |
| SARSA | \(R + \gamma Q(s', a')\) | 动作值函数Q(s,a) |
2. SARSA算法详解¶
2.1 算法流程¶
初始化: Q(s,a) = 0, ∀s,a
对每个episode:
初始化状态S
使用ε-贪婪策略选择动作A
当S不是终止状态时:
执行动作A, 观察R, S'
使用ε-贪婪策略选择动作A'
Q(S,A) ← Q(S,A) + α[R + γQ(S',A') - Q(S,A)]
S ← S'
A ← A'
2.2 代码实现¶
import numpy as np
from collections import defaultdict
def sarsa(env, num_episodes=1000, alpha=0.1, gamma=0.9, epsilon=0.1):
"""
SARSA算法实现
参数:
env: 环境
num_episodes: 训练的episode数量
alpha: 学习率
gamma: 折扣因子
epsilon: ε-贪婪探索率
返回:
Q: 动作值函数
policy: 学到的策略
rewards_history: 每episode的总奖励
"""
n_actions = env.get_action_space()
Q = defaultdict(lambda: np.zeros(n_actions)) # defaultdict访问不存在的键时返回默认值
rewards_history = []
def epsilon_greedy(state, Q, epsilon):
"""ε-贪婪策略"""
if np.random.random() < epsilon:
return np.random.randint(n_actions)
else:
return np.argmax(Q[state])
for episode in range(num_episodes):
state, _ = env.reset()
action = epsilon_greedy(state, Q, epsilon)
total_reward = 0
done = False
while not done:
# 执行动作
next_state, reward, terminated, truncated, _ = env.step(action)
done = terminated or truncated
total_reward += reward
# 选择下一个动作(ε-贪婪)
next_action = epsilon_greedy(next_state, Q, epsilon)
# SARSA更新
if not done:
td_target = reward + gamma * Q[next_state][next_action]
else:
td_target = reward
td_error = td_target - Q[state][action]
Q[state][action] += alpha * td_error
# 转移到下一个状态
state = next_state
action = next_action
rewards_history.append(total_reward)
# 衰减探索率
epsilon = max(0.01, epsilon * 0.995)
if (episode + 1) % 100 == 0:
avg_reward = np.mean(rewards_history[-100:])
print(f"Episode {episode + 1}/{num_episodes}, "
f"Avg Reward: {avg_reward:.2f}, Epsilon: {epsilon:.3f}")
# 提取最终策略
policy = {}
for state in Q.keys():
policy[state] = np.argmax(Q[state])
return Q, policy, rewards_history
2.3 为什么叫"On-Policy"?¶
On-Policy意味着: - 学习的策略和执行的策略是同一个 - SARSA使用实际执行的下一个动作A'来更新 - 如果A'是随机探索的动作,Q值也会反映这一点
对比Off-Policy: - Off-Policy使用最优动作的Q值来更新,不管实际执行什么
3. SARSA的收敛性¶
3.1 收敛定理¶
定理(SARSA收敛性):
在满足以下条件时,SARSA以概率1收敛到最优Q函数:
- 所有状态-动作对被无限次访问
- 策略是GLIE(Greedy in the Limit with Infinite Exploration)
- 步长满足Robbins-Monro条件
GLIE策略: - 探索率ε_t → 0 当 t → ∞ - 但每个状态-动作对被无限次访问
3.2 收敛速率¶
SARSA的收敛速率:\(O(1/\sqrt{T})\),其中T是总的时间步数。
影响因素: - 学习率α的大小 - 探索率ε的衰减速度 - 环境的随机性
4. SARSA vs Q-Learning¶
4.1 核心区别¶
| 特性 | SARSA | Q-Learning |
|---|---|---|
| 策略类型 | On-Policy | Off-Policy |
| 更新目标 | \(R + \gamma Q(s', a')\) | \(R + \gamma \max_{a'} Q(s', a')\) |
| 探索处理 | 考虑实际探索 | 假设最优动作 |
| 风险偏好 | 保守 | 激进 |
4.2 直观对比¶
悬崖行走问题(Cliff Walking):
S . . . . . . . . . . G
. . . . . . . . . . . .
. . . . . . . . . . . .
X X X X X X X X X X X X
S: 起点, G: 终点, X: 悬崖(掉下去奖励-100)
- SARSA:学习沿着悬崖上方安全路径走(考虑可能掉下去的风险)
- Q-Learning:学习最优路径(紧贴悬崖,但不考虑探索时的风险)
4.3 何时使用SARSA?¶
使用SARSA当: - 需要安全的策略 - 探索成本很高 - 实际执行中需要考虑探索
使用Q-Learning当: - 只关心最优策略 - 可以分离探索和学习 - 愿意承担探索风险
5. SARSA(λ):带资格迹的SARSA¶
5.1 动机¶
标准SARSA只使用一步信息,SARSA(λ)使用多步信息。
5.2 资格迹¶
资格迹 \(E_t(s,a)\) 记录状态-动作对的"贡献":
初始化: E(s,a) = 0, ∀s,a
对每个episode:
初始化S, A
E(s,a) = 0, ∀s,a
当S不是终止状态时:
执行A, 观察R, S'
选择A'(ε-贪婪)
δ = R + γQ(S',A') - Q(S,A)
E(S,A) = E(S,A) + 1 # 增量式
对所有s,a:
Q(s,a) = Q(s,a) + α * δ * E(s,a)
E(s,a) = γ * λ * E(s,a) # 衰减
S = S', A = A'
5.3 代码实现¶
def sarsa_lambda(env, num_episodes=1000, alpha=0.1, gamma=0.9,
epsilon=0.1, lambda_=0.9):
"""
SARSA(λ)算法
参数:
lambda_: 迹衰减参数
"""
n_actions = env.get_action_space()
Q = defaultdict(lambda: np.zeros(n_actions))
def epsilon_greedy(state, Q, epsilon):
if np.random.random() < epsilon:
return np.random.randint(n_actions)
return np.argmax(Q[state])
for episode in range(num_episodes):
# 初始化资格迹
E = defaultdict(lambda: np.zeros(n_actions))
state, _ = env.reset()
action = epsilon_greedy(state, Q, epsilon)
done = False
while not done:
next_state, reward, terminated, truncated, _ = env.step(action)
done = terminated or truncated
next_action = epsilon_greedy(next_state, Q, epsilon)
# 计算TD误差
if not done:
td_target = reward + gamma * Q[next_state][next_action]
else:
td_target = reward
td_error = td_target - Q[state][action]
# 更新资格迹
E[state][action] += 1
# 更新所有Q值
for s in E.keys():
for a in range(n_actions):
Q[s][a] += alpha * td_error * E[s][a]
E[s][a] *= gamma * lambda_
state = next_state
action = next_action
return Q
6. 实践练习¶
练习1:Windy Grid World¶
class WindyGridWorld:
"""有风网格世界"""
def __init__(self, size=7):
self.size = size
self.start = (3, 0)
self.goal = (3, 7)
# 每列的风力(向上)
self.wind = [0, 0, 0, 1, 1, 1, 2, 2, 1, 0]
def step(self, action):
x, y = self.agent_pos
# 动作效果
if action == 0: # 上
x -= 1
elif action == 1: # 下
x += 1
elif action == 2: # 左
y -= 1
elif action == 3: # 右
y += 1
# 加上风力
x -= self.wind[y]
# 边界检查
x = np.clip(x, 0, self.size - 1)
y = np.clip(y, 0, 9)
self.agent_pos = (x, y)
reward = -1
done = (self.agent_pos == self.goal)
return self._get_state(), reward, done, False, {}
# 用SARSA解决
env = WindyGridWorld()
Q, policy, rewards = sarsa(env, num_episodes=5000, alpha=0.5, epsilon=0.1)
练习2:SARSA vs Q-Learning对比¶
def compare_algorithms(env, num_runs=10, num_episodes=1000):
"""对比SARSA和Q-Learning"""
sarsa_rewards = []
qlearning_rewards = []
for run in range(num_runs):
# SARSA
_, _, sarsa_hist = sarsa(env, num_episodes)
sarsa_rewards.append(sarsa_hist)
# Q-Learning
_, _, q_hist = q_learning(env, num_episodes)
qlearning_rewards.append(q_hist)
# 绘制对比图
plt.figure(figsize=(10, 6))
sarsa_mean = np.mean(sarsa_rewards, axis=0)
q_mean = np.mean(qlearning_rewards, axis=0)
plt.plot(sarsa_mean, label='SARSA', alpha=0.7)
plt.plot(q_mean, label='Q-Learning', alpha=0.7)
plt.xlabel('Episode')
plt.ylabel('Total Reward')
plt.title('SARSA vs Q-Learning')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()
7. 本章总结¶
核心概念¶
SARSA:
├── On-Policy: 学习和执行同一策略
├── 更新: Q(s,a) ← Q(s,a) + α[R + γQ(s',a') - Q(s,a)]
├── 使用实际执行的下一个动作a'
└── 更保守,考虑探索风险
SARSA(λ):
├── 使用资格迹
├── 多步更新
└── λ控制迹衰减
学习路径¶
✅ 自测问题¶
-
为什么SARSA被称为"On-Policy"算法?
-
在悬崖行走问题中,SARSA和Q-Learning会学到什么不同的策略?为什么?
-
SARSA(λ)中的λ参数有什么作用?λ=0和λ=1分别对应什么算法?
-
设计一个场景,说明什么时候应该使用SARSA而不是Q-Learning。
-
SARSA的收敛需要什么条件?为什么需要GLIE策略?
📚 延伸阅读¶
- Rummery & Niranjan (1994)
- "On-line Q-learning using connectionist systems"
-
SARSA算法的原始论文
-
Sutton & Barto (2018)
- 《Reinforcement Learning: An Introduction》第6章
- SARSA的详细讲解
准备好学习Off-Policy的Q-Learning了吗?
→ 下一步:03-Q-Learning算法.md