跳转至

02 - SARSA算法:On-Policy TD控制

学习时间: 3-4小时 重要性: ⭐⭐⭐⭐⭐ 第一个实用的TD控制算法 前置知识: TD(0)预测、策略改进


🎯 学习目标

完成本章后,你将能够: - 理解SARSA算法的核心思想 - 掌握On-Policy学习的特点 - 实现SARSA算法解决控制问题 - 分析SARSA的收敛性 - 理解SARSA与Q-Learning的区别


1. SARSA算法简介

1.1 从预测到控制

TD(0) 解决了预测问题:给定策略π,估计V^π。

SARSA 解决控制问题:找到最优策略π*。

关键洞察

使用TD学习来估计动作值函数Q(s,a),然后通过贪心策略改进。

1.2 SARSA名称的由来

SARSA代表算法使用的5个变量:

Text Only
S → A → R → S' → A'

S: 当前状态 (State)
A: 当前动作 (Action)
R: 获得的奖励 (Reward)
S': 下一个状态 (Next State)
A': 下一个动作 (Next Action)

1.3 SARSA更新规则

\[Q(S_t, A_t) \leftarrow Q(S_t, A_t) + \alpha [R_{t+1} + \gamma Q(S_{t+1}, A_{t+1}) - Q(S_t, A_t)]\]

与TD(0)的对比

算法 更新目标 学习对象
TD(0) \(R + \gamma V(s')\) 状态值函数V(s)
SARSA \(R + \gamma Q(s', a')\) 动作值函数Q(s,a)

2. SARSA算法详解

2.1 算法流程

Text Only
初始化: Q(s,a) = 0, ∀s,a

对每个episode:
    初始化状态S
    使用ε-贪婪策略选择动作A

    当S不是终止状态时:
        执行动作A, 观察R, S'
        使用ε-贪婪策略选择动作A'

        Q(S,A) ← Q(S,A) + α[R + γQ(S',A') - Q(S,A)]

        S ← S'
        A ← A'

2.2 代码实现

Python
import numpy as np
from collections import defaultdict

def sarsa(env, num_episodes=1000, alpha=0.1, gamma=0.9, epsilon=0.1):
    """
    SARSA算法实现

    参数:
        env: 环境
        num_episodes: 训练的episode数量
        alpha: 学习率
        gamma: 折扣因子
        epsilon: ε-贪婪探索率

    返回:
        Q: 动作值函数
        policy: 学到的策略
        rewards_history: 每episode的总奖励
    """
    n_actions = env.get_action_space()
    Q = defaultdict(lambda: np.zeros(n_actions))  # defaultdict访问不存在的键时返回默认值
    rewards_history = []

    def epsilon_greedy(state, Q, epsilon):
        """ε-贪婪策略"""
        if np.random.random() < epsilon:
            return np.random.randint(n_actions)
        else:
            return np.argmax(Q[state])

    for episode in range(num_episodes):
        state, _ = env.reset()
        action = epsilon_greedy(state, Q, epsilon)

        total_reward = 0
        done = False

        while not done:
            # 执行动作
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            total_reward += reward

            # 选择下一个动作(ε-贪婪)
            next_action = epsilon_greedy(next_state, Q, epsilon)

            # SARSA更新
            if not done:
                td_target = reward + gamma * Q[next_state][next_action]
            else:
                td_target = reward

            td_error = td_target - Q[state][action]
            Q[state][action] += alpha * td_error

            # 转移到下一个状态
            state = next_state
            action = next_action

        rewards_history.append(total_reward)

        # 衰减探索率
        epsilon = max(0.01, epsilon * 0.995)

        if (episode + 1) % 100 == 0:
            avg_reward = np.mean(rewards_history[-100:])
            print(f"Episode {episode + 1}/{num_episodes}, "
                  f"Avg Reward: {avg_reward:.2f}, Epsilon: {epsilon:.3f}")

    # 提取最终策略
    policy = {}
    for state in Q.keys():
        policy[state] = np.argmax(Q[state])

    return Q, policy, rewards_history

2.3 为什么叫"On-Policy"?

On-Policy意味着: - 学习的策略和执行的策略是同一个 - SARSA使用实际执行的下一个动作A'来更新 - 如果A'是随机探索的动作,Q值也会反映这一点

对比Off-Policy: - Off-Policy使用最优动作的Q值来更新,不管实际执行什么


3. SARSA的收敛性

3.1 收敛定理

定理(SARSA收敛性)

在满足以下条件时,SARSA以概率1收敛到最优Q函数:

  1. 所有状态-动作对被无限次访问
  2. 策略是GLIE(Greedy in the Limit with Infinite Exploration)
  3. 步长满足Robbins-Monro条件

GLIE策略: - 探索率ε_t → 0 当 t → ∞ - 但每个状态-动作对被无限次访问

3.2 收敛速率

SARSA的收敛速率:\(O(1/\sqrt{T})\),其中T是总的时间步数。

影响因素: - 学习率α的大小 - 探索率ε的衰减速度 - 环境的随机性


4. SARSA vs Q-Learning

4.1 核心区别

特性 SARSA Q-Learning
策略类型 On-Policy Off-Policy
更新目标 \(R + \gamma Q(s', a')\) \(R + \gamma \max_{a'} Q(s', a')\)
探索处理 考虑实际探索 假设最优动作
风险偏好 保守 激进

4.2 直观对比

悬崖行走问题(Cliff Walking)

Text Only
S . . . . . . . . . . G
. . . . . . . . . . . .
. . . . . . . . . . . .
X X X X X X X X X X X X

S: 起点, G: 终点, X: 悬崖(掉下去奖励-100)
  • SARSA:学习沿着悬崖上方安全路径走(考虑可能掉下去的风险)
  • Q-Learning:学习最优路径(紧贴悬崖,但不考虑探索时的风险)

4.3 何时使用SARSA?

使用SARSA当: - 需要安全的策略 - 探索成本很高 - 实际执行中需要考虑探索

使用Q-Learning当: - 只关心最优策略 - 可以分离探索和学习 - 愿意承担探索风险


5. SARSA(λ):带资格迹的SARSA

5.1 动机

标准SARSA只使用一步信息,SARSA(λ)使用多步信息。

5.2 资格迹

资格迹 \(E_t(s,a)\) 记录状态-动作对的"贡献":

Text Only
初始化: E(s,a) = 0, ∀s,a

对每个episode:
    初始化S, A
    E(s,a) = 0, ∀s,a

    当S不是终止状态时:
        执行A, 观察R, S'
        选择A'(ε-贪婪)

        δ = R + γQ(S',A') - Q(S,A)
        E(S,A) = E(S,A) + 1  # 增量式

        对所有s,a:
            Q(s,a) = Q(s,a) + α * δ * E(s,a)
            E(s,a) = γ * λ * E(s,a)  # 衰减

        S = S', A = A'

5.3 代码实现

Python
def sarsa_lambda(env, num_episodes=1000, alpha=0.1, gamma=0.9,
                 epsilon=0.1, lambda_=0.9):
    """
    SARSA(λ)算法

    参数:
        lambda_: 迹衰减参数
    """
    n_actions = env.get_action_space()
    Q = defaultdict(lambda: np.zeros(n_actions))

    def epsilon_greedy(state, Q, epsilon):
        if np.random.random() < epsilon:
            return np.random.randint(n_actions)
        return np.argmax(Q[state])

    for episode in range(num_episodes):
        # 初始化资格迹
        E = defaultdict(lambda: np.zeros(n_actions))

        state, _ = env.reset()
        action = epsilon_greedy(state, Q, epsilon)

        done = False
        while not done:
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            next_action = epsilon_greedy(next_state, Q, epsilon)

            # 计算TD误差
            if not done:
                td_target = reward + gamma * Q[next_state][next_action]
            else:
                td_target = reward
            td_error = td_target - Q[state][action]

            # 更新资格迹
            E[state][action] += 1

            # 更新所有Q值
            for s in E.keys():
                for a in range(n_actions):
                    Q[s][a] += alpha * td_error * E[s][a]
                    E[s][a] *= gamma * lambda_

            state = next_state
            action = next_action

    return Q

6. 实践练习

练习1:Windy Grid World

Python
class WindyGridWorld:
    """有风网格世界"""

    def __init__(self, size=7):
        self.size = size
        self.start = (3, 0)
        self.goal = (3, 7)
        # 每列的风力(向上)
        self.wind = [0, 0, 0, 1, 1, 1, 2, 2, 1, 0]

    def step(self, action):
        x, y = self.agent_pos

        # 动作效果
        if action == 0:  # 上
            x -= 1
        elif action == 1:  # 下
            x += 1
        elif action == 2:  # 左
            y -= 1
        elif action == 3:  # 右
            y += 1

        # 加上风力
        x -= self.wind[y]

        # 边界检查
        x = np.clip(x, 0, self.size - 1)
        y = np.clip(y, 0, 9)

        self.agent_pos = (x, y)
        reward = -1
        done = (self.agent_pos == self.goal)

        return self._get_state(), reward, done, False, {}

# 用SARSA解决
env = WindyGridWorld()
Q, policy, rewards = sarsa(env, num_episodes=5000, alpha=0.5, epsilon=0.1)

练习2:SARSA vs Q-Learning对比

Python
def compare_algorithms(env, num_runs=10, num_episodes=1000):
    """对比SARSA和Q-Learning"""

    sarsa_rewards = []
    qlearning_rewards = []

    for run in range(num_runs):
        # SARSA
        _, _, sarsa_hist = sarsa(env, num_episodes)
        sarsa_rewards.append(sarsa_hist)

        # Q-Learning
        _, _, q_hist = q_learning(env, num_episodes)
        qlearning_rewards.append(q_hist)

    # 绘制对比图
    plt.figure(figsize=(10, 6))

    sarsa_mean = np.mean(sarsa_rewards, axis=0)
    q_mean = np.mean(qlearning_rewards, axis=0)

    plt.plot(sarsa_mean, label='SARSA', alpha=0.7)
    plt.plot(q_mean, label='Q-Learning', alpha=0.7)

    plt.xlabel('Episode')
    plt.ylabel('Total Reward')
    plt.title('SARSA vs Q-Learning')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()

7. 本章总结

核心概念

Text Only
SARSA:
├── On-Policy: 学习和执行同一策略
├── 更新: Q(s,a) ← Q(s,a) + α[R + γQ(s',a') - Q(s,a)]
├── 使用实际执行的下一个动作a'
└── 更保守,考虑探索风险

SARSA(λ):
├── 使用资格迹
├── 多步更新
└── λ控制迹衰减

学习路径

Text Only
下一步:
├── Q-Learning - Off-Policy TD控制
├── 探索与利用 - 更深入的探索策略
└── 多步方法 - n-step和TD(λ)

✅ 自测问题

  1. 为什么SARSA被称为"On-Policy"算法?

  2. 在悬崖行走问题中,SARSA和Q-Learning会学到什么不同的策略?为什么?

  3. SARSA(λ)中的λ参数有什么作用?λ=0和λ=1分别对应什么算法?

  4. 设计一个场景,说明什么时候应该使用SARSA而不是Q-Learning。

  5. SARSA的收敛需要什么条件?为什么需要GLIE策略?


📚 延伸阅读

  1. Rummery & Niranjan (1994)
  2. "On-line Q-learning using connectionist systems"
  3. SARSA算法的原始论文

  4. Sutton & Barto (2018)

  5. 《Reinforcement Learning: An Introduction》第6章
  6. SARSA的详细讲解

准备好学习Off-Policy的Q-Learning了吗?

→ 下一步:03-Q-Learning算法.md