05 - Multi-Step Methods

Study time: 3-4 hours | Importance: ⭐⭐⭐⭐ | The bridge between TD and MC | Prerequisites: TD(0), SARSA, eligibility traces


🎯 Learning Objectives

After completing this chapter, you will be able to:

- Understand the core idea of n-step methods
- Master n-step SARSA and n-step Q-Learning
- Understand the forward and backward views of TD(λ)
- Implement eligibility trace algorithms
- Choose the appropriate method for different scenarios


1. Introduction to n-step Methods

1.1 The Continuum from TD(0) to MC

TD(0): uses the 1-step return

\[G_{t:t+1} = R_{t+1} + \gamma V(S_{t+1})\]

MC: uses the full episode return

\[G_t = R_{t+1} + \gamma R_{t+2} + \gamma^2 R_{t+3} + \cdots\]

n-step: uses the n-step return

\[G_{t:t+n} = R_{t+1} + \gamma R_{t+2} + \cdots + \gamma^{n-1} R_{t+n} + \gamma^n V(S_{t+n})\]

1.2 The n-step Return

Definition

\[G_{t:t+n} = \sum_{i=1}^{n} \gamma^{i-1} R_{t+i} + \gamma^n V(S_{t+n})\]

Special cases:

- n = 1: reduces to TD(0)
- n → ∞ (the return runs to the end of the episode): reduces to MC
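To make the definition concrete, here is a minimal sketch (the helper `n_step_return` and the toy numbers are illustrative, not from the original text) that computes \(G_{t:t+n}\) from a recorded trajectory:

Python
def n_step_return(rewards, v_bootstrap, gamma):
    """Compute G_{t:t+n} for n = len(rewards).

    rewards:     [R_{t+1}, ..., R_{t+n}]
    v_bootstrap: V(S_{t+n}), the value estimate to bootstrap from
    """
    g = sum(gamma ** i * r for i, r in enumerate(rewards))
    return g + gamma ** len(rewards) * v_bootstrap

# 3-step return with gamma = 0.9:
# 1.0 + 0.9*0 + 0.81*2.0 + 0.729*0.5 = 2.9845
print(n_step_return([1.0, 0.0, 2.0], v_bootstrap=0.5, gamma=0.9))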


2. n-step TD Prediction

2.1 Algorithm

Python
import numpy as np
from collections import defaultdict  # these imports are used throughout this chapter


def n_step_td_prediction(env, policy, n=3, alpha=0.1, gamma=0.9, num_episodes=1000):
    """
    n-step TD prediction.

    Parameters:
        n: number of steps before bootstrapping
    """
    V = defaultdict(float)  # missing keys default to 0.0, including terminal states

    for episode in range(num_episodes):
        states = [env.reset()[0]]
        rewards = [0]  # dummy entry so that rewards[t+1] is the reward following states[t]

        T = float('inf')  # episode length, unknown until termination
        t = 0

        while True:
            if t < T:
                # Sample an action from the policy and take one environment step
                action = np.random.choice(len(policy[states[t]]), p=policy[states[t]])
                next_state, reward, terminated, truncated, _ = env.step(action)
                done = terminated or truncated

                states.append(next_state)
                rewards.append(reward)

                if done:
                    T = t + 1

            # tau is the time step whose estimate is updated now
            tau = t - n + 1

            if tau >= 0:
                # Compute the n-step return G_{tau:tau+n}
                G = 0
                for i in range(tau + 1, min(tau + n, T) + 1):
                    G += (gamma ** (i - tau - 1)) * rewards[i]

                if tau + n < T:
                    G += (gamma ** n) * V[states[tau + n]]  # bootstrap from current estimate

                # Move V(S_tau) toward the n-step return
                V[states[tau]] += alpha * (G - V[states[tau]])

            if tau == T - 1:
                break

            t += 1

    return V

2.2 Bias-Variance Tradeoff

| n | Bias | Variance | Suitable scenarios |
| --- | --- | --- | --- |
| Small (→ TD(0)) | Higher | Lower | Noisy rewards, need for fast learning |
| Large (→ MC) | Lower | Higher | Need for accurate (low-bias) estimates |
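To see the tradeoff in action, here is a hedged sketch that sweeps n on a toy random walk; the `RandomWalk` class is an illustrative stand-in for any Gymnasium-style environment, not something defined in this course, and the printed numbers will vary from run to run:

Python
class RandomWalk:
    """5-state random walk: states 1..5, terminals 0 and 6, reward 1 on the
    right exit. Follows the Gymnasium reset/step conventions used above."""
    def __init__(self, n_states=5):
        self.n_states = n_states

    def reset(self):
        self.s = (self.n_states + 1) // 2  # start in the middle
        return self.s, {}

    def step(self, action):
        self.s += np.random.choice([-1, 1])  # motion is random; the action is ignored
        terminated = self.s in (0, self.n_states + 1)
        reward = 1.0 if self.s == self.n_states + 1 else 0.0
        return self.s, reward, terminated, False, {}

env = RandomWalk()
policy = {s: [0.5, 0.5] for s in range(env.n_states + 2)}  # dummy 2-action policy
for n in (1, 3, 8):
    V = n_step_td_prediction(env, policy, n=n, num_episodes=2000)
    # True values are s/6; an intermediate n typically gets close fastest
    print(n, [round(V[s], 2) for s in range(1, env.n_states + 1)])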

3. n-step SARSA

3.1 Algorithm

Python
def n_step_sarsa(env, n=3, alpha=0.1, gamma=0.9, epsilon=0.1, num_episodes=1000):
    """n-step SARSA (on-policy control)"""
    n_actions = env.action_space.n  # Gymnasium-style action space, matching reset/step below
    Q = defaultdict(lambda: np.zeros(n_actions))

    def epsilon_greedy(state):
        # Explore with probability epsilon, otherwise act greedily w.r.t. Q
        if np.random.random() < epsilon:
            return np.random.randint(n_actions)
        return np.argmax(Q[state])

    for episode in range(num_episodes):
        states = [env.reset()[0]]
        actions = [epsilon_greedy(states[0])]
        rewards = [0]  # dummy entry for index alignment

        T = float('inf')
        t = 0

        while True:
            if t < T:
                next_state, reward, terminated, truncated, _ = env.step(actions[t])
                done = terminated or truncated
                states.append(next_state)
                rewards.append(reward)

                if done:
                    T = t + 1
                else:
                    actions.append(epsilon_greedy(next_state))

            tau = t - n + 1

            if tau >= 0:
                # Compute the n-step return G_{tau:tau+n}
                G = 0
                for i in range(tau + 1, min(tau + n, T) + 1):
                    G += (gamma ** (i - tau - 1)) * rewards[i]

                if tau + n < T:
                    G += (gamma ** n) * Q[states[tau + n]][actions[tau + n]]

                # Move Q(S_tau, A_tau) toward the n-step return
                Q[states[tau]][actions[tau]] += alpha * (G - Q[states[tau]][actions[tau]])

            if tau == T - 1:
                break

            t += 1

    return Q
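A hedged usage sketch, assuming the gymnasium package and its CliffWalking-v0 environment are available (neither is named in the original text):

Python
import gymnasium as gym

env = gym.make("CliffWalking-v0")
Q = n_step_sarsa(env, n=4, alpha=0.5, epsilon=0.1, num_episodes=500)
greedy_policy = {s: int(np.argmax(q)) for s, q in Q.items()}  # extract greedy actions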

4. TD(λ): Eligibility Traces

4.1 The Forward View

λ-return: a weighted average of all n-step returns

\[G_t^\lambda = (1-\lambda) \sum_{n=1}^{\infty} \lambda^{n-1} G_{t:t+n}\]

Update rule

\[V(S_t) \leftarrow V(S_t) + \alpha (G_t^\lambda - V(S_t))\]
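As a sanity check on the formula, here is a minimal sketch (the helper `lambda_return` and its inputs are assumptions of this note) that computes \(G_t^\lambda\) for a finished episode; it uses the fact that \(G_{t:t+n} = G_t\) for all \(n \ge T - t\), so the geometric tail of weights collapses onto the final n-step return:

Python
def lambda_return(rewards, values, gamma, lam, t=0):
    """λ-return G_t^λ for a finished episode.

    rewards: [R_1, ..., R_T]
    values:  [V(S_0), ..., V(S_T)], with V(terminal) = 0
    """
    T = len(rewards)
    g, returns = 0.0, []
    for n in range(1, T - t + 1):          # n-step returns G_{t:t+n}
        g += gamma ** (n - 1) * rewards[t + n - 1]
        returns.append(g + gamma ** n * values[t + n])
    out = sum((1 - lam) * lam ** (n - 1) * G
              for n, G in enumerate(returns[:-1], start=1))
    return out + lam ** (len(returns) - 1) * returns[-1]  # tail weight on G_t

# With lam=0 this is the TD(0) target; with lam=1 it is the MC return
print(lambda_return([0.0, 0.0, 1.0], [0.1, 0.2, 0.4, 0.0], gamma=0.9, lam=0.8))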

4.2 The Backward View (Eligibility Traces)

Core idea:

- Maintain an "eligibility" value for every state
- Recently visited states have higher eligibility
- Eligibility decays over time

Eligibility trace update

\[E_t(s) = \gamma \lambda E_{t-1}(s) + \mathbb{1}_{S_t=s}\]

TD(λ) update

\[V(s) \leftarrow V(s) + \alpha \delta_t E_t(s), \quad \forall s\]

where \(\delta_t = R_{t+1} + \gamma V(S_{t+1}) - V(S_t)\) is the one-step TD error.
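A tiny numeric sketch (the visit pattern is made up) of how the accumulating trace of a single state evolves under this update:

Python
gamma, lam = 0.9, 0.8
E = 0.0
for t, visited in enumerate([1, 0, 0, 1, 0]):  # state visited at t = 0 and t = 3
    E = gamma * lam * E + visited
    print(t, round(E, 3))
# The trace jumps by 1 on each visit and decays by a factor of γλ in between,
# so credit for the current TD error flows back to recently visited states.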

4.3 Code Implementation

Python
def td_lambda(env, policy, lambda_=0.9, alpha=0.1, gamma=0.9, num_episodes=1000):
    """TD(λ) with eligibility traces"""
    V = defaultdict(float)

    for episode in range(num_episodes):
        E = defaultdict(float)  # eligibility traces, reset at the start of each episode
        state, _ = env.reset()

        done = False
        while not done:
            action = np.random.choice(len(policy[state]), p=policy[state])
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated

            # TD error; the bootstrap term is dropped at terminal states
            delta = reward + gamma * V[next_state] * (not done) - V[state]

            # Bump the eligibility of the current state (accumulating trace)
            E[state] += 1

            # Update every eligible state's value and decay its trace
            for s in E.keys():
                V[s] += alpha * delta * E[s]
                E[s] *= gamma * lambda_

            state = next_state

    return V
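A hedged usage sketch, reusing the toy `RandomWalk` environment from Section 2.2 (an assumption of this note, not part of the original text):

Python
env = RandomWalk()
policy = {s: [0.5, 0.5] for s in range(env.n_states + 2)}
V = td_lambda(env, policy, lambda_=0.9, num_episodes=2000)
print([round(V[s], 2) for s in range(1, env.n_states + 1)])
# lambda_=0 recovers TD(0); lambda_=1 approaches every-visit MC (in the offline case)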

4.4 SARSA(λ)

Python
def sarsa_lambda(env, lambda_=0.9, alpha=0.1, gamma=0.9, epsilon=0.1, num_episodes=1000):
    """SARSA(λ)"""
    n_actions = env.action_space.n  # Gymnasium-style action space
    Q = defaultdict(lambda: np.zeros(n_actions))

    def epsilon_greedy(state):
        if np.random.random() < epsilon:
            return np.random.randint(n_actions)
        return np.argmax(Q[state])

    for episode in range(num_episodes):
        E = defaultdict(lambda: np.zeros(n_actions))  # eligibility traces

        state, _ = env.reset()
        action = epsilon_greedy(state)

        done = False
        while not done:
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            # Placeholder action at terminal states; it is never executed
            next_action = epsilon_greedy(next_state) if not done else 0

            # TD error; no bootstrap term once the episode has ended
            if done:
                delta = reward - Q[state][action]
            else:
                delta = reward + gamma * Q[next_state][next_action] - Q[state][action]

            # Bump the eligibility of the current state-action pair
            E[state][action] += 1

            # Update every eligible Q-value and decay its trace
            for s in E.keys():
                for a in range(n_actions):
                    Q[s][a] += alpha * delta * E[s][a]
                    E[s][a] *= gamma * lambda_

            state = next_state
            action = next_action

    return Q
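One variant worth knowing, though not covered in the original text, is the "replacing trace", which caps the trace at 1 on revisits instead of accumulating it; a small illustrative sketch (the helper `bump_trace` is an assumption of this note):

Python
def bump_trace(E, state, action, replacing=False):
    """Accumulating vs. replacing traces; an illustrative helper."""
    if replacing:
        E[state][action] = 1.0   # replacing: clip at 1 on every revisit
    else:
        E[state][action] += 1.0  # accumulating: grows with repeated visits
    return E

E = defaultdict(lambda: np.zeros(2))
for _ in range(3):
    bump_trace(E, "s0", 0)                  # accumulating -> 3.0
for _ in range(3):
    bump_trace(E, "s0", 1, replacing=True)  # replacing    -> 1.0
print(E["s0"])
# Replacing traces are often more stable when a state-action pair is
# revisited within an episode, at the cost of discarding some credit.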

5. Method Selection Guide

| Method | Bias | Variance | Computational cost | Suitable scenarios |
| --- | --- | --- | --- | --- |
| TD(0) | High | Low | O(1) | Online learning, fast convergence |
| n-step | Medium | Medium | O(n) | Balancing bias and variance |
| TD(λ) | Tunable | Tunable | O(S) | When the balance must be tuned |
| MC | None (unbiased) | High | O(T) | When unbiased estimates are needed |

6. Chapter Summary

Core Concepts

Text Only
Multi-step methods:
├── n-step: uses the n-step return
│   ├── n=1: TD(0)
│   ├── n=∞: MC
│   └── intermediate n: balances bias and variance
└── TD(λ): weighted average of all n-step returns
    ├── forward view: λ-return
    └── backward view: eligibility traces

Eligibility traces:
├── record each state's contribution to the current error
├── decay over time by a factor of γλ
└── update all states online

✅ Self-Test Questions

1. How do the bias and variance of n-step methods change as n grows?

2. Why are the forward and backward views of TD(λ) equivalent?

3. What is the intuitive interpretation of an eligibility trace?

4. How should the value of λ be chosen in practice?


→ Next stage: 03 - Function Approximation and Deep Learning