跳转至

05 - Actor-Critic高级方法

学习时间: 4-5小时 重要性: ⭐⭐⭐⭐ Actor-Critic的进阶技术 前置知识: Actor-Critic、策略梯度


🎯 学习目标

完成本章后,你将能够: - 理解GAE(广义优势估计) - 掌握A3C异步训练 - 了解自然策略梯度 - 理解TRPO的基础概念


1. 广义优势估计 (GAE)

1.1 动机

问题: - TD(0):偏差大,方差小 - MC:偏差小,方差大

解决方案:GAE平衡两者

1.2 GAE公式

\[\hat{A}_t^{GAE(\gamma,\lambda)} = \sum_{l=0}^{\infty}(\gamma\lambda)^l\delta_{t+l}^V\]

其中: $\(\delta_t^V = r_t + \gamma V(s_{t+1}) - V(s_t)\)$

特殊情况: - λ=0:TD(0) - λ=1:MC

1.3 代码实现

Python
def compute_gae(rewards, values, dones, gamma=0.99, lambda_=0.95, next_value=0):
    """
    计算GAE

    参数:
        rewards: 奖励序列(长度N)
        values: 值函数估计序列(长度N,对应每个时间步的V(s_t))
        dones: 终止标志序列(长度N)
        gamma: 折扣因子
        lambda_: GAE参数
        next_value: 最后一步之后的bootstrap值V(s_{N}),终止时为0

    返回:
        advantages: 优势估计
        returns: 回报估计
    """
    advantages = []
    gae = 0

    for t in reversed(range(len(rewards))):
        if t == len(rewards) - 1:
            nv = next_value
        else:
            nv = values[t + 1]

        delta = rewards[t] + gamma * nv * (1 - dones[t]) - values[t]
        gae = delta + gamma * lambda_ * (1 - dones[t]) * gae
        advantages.insert(0, gae)

    advantages = np.array(advantages)  # np.array创建NumPy数组
    returns = advantages + np.array(values)

    return advantages, returns

2. A3C:异步优势Actor-Critic

2.1 核心思想

异步训练: - 多个并行worker - 每个worker独立与环境交互 - 异步更新全局网络

优势: - 无需经验回放 - 多核CPU并行 - 探索多样化

2.2 代码实现

Python
import torch
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp
from torch.distributions import Categorical

class A3CNetwork(nn.Module):  # 继承nn.Module定义网络层
    """A3C共享网络"""

    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super(A3CNetwork, self).__init__()

        # 共享层
        self.shared = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU()
        )

        # Actor
        self.actor = nn.Linear(hidden_dim, action_dim)

        # Critic
        self.critic = nn.Linear(hidden_dim, 1)

    def forward(self, state):
        features = self.shared(state)
        logits = self.actor(features)
        value = self.critic(features)
        return logits, value

class A3CWorker(mp.Process):
    """A3C Worker进程"""

    def __init__(self, global_net, optimizer, env_name, worker_id,
                 gamma=0.99, max_steps=20):
        super(A3CWorker, self).__init__()

        self.global_net = global_net
        self.optimizer = optimizer
        self.env_name = env_name
        self.worker_id = worker_id
        self.gamma = gamma
        self.max_steps = max_steps

        # 本地网络
        self.local_net = A3CNetwork(global_net.shared[0].in_features,
                                     global_net.actor.out_features)
        self.local_net.load_state_dict(global_net.state_dict())

    def run(self):
        """Worker主循环"""
        env = gym.make(self.env_name)
        state, _ = env.reset()

        while True:
            # 同步全局参数
            self.local_net.load_state_dict(self.global_net.state_dict())

            # 收集经验
            states, actions, rewards, values, log_probs = [], [], [], [], []

            for step in range(self.max_steps):
                state_tensor = torch.FloatTensor(state)
                logits, value = self.local_net(state_tensor)

                dist = Categorical(logits=logits)
                action = dist.sample()
                log_prob = dist.log_prob(action)

                next_state, reward, terminated, truncated, _ = env.step(action.item())  # 将单元素张量转为Python数值
                done = terminated or truncated

                states.append(state)
                actions.append(action)
                rewards.append(reward)
                values.append(value)
                log_probs.append(log_prob)

                state = next_state

                if done:
                    state, _ = env.reset()
                    break

            # 计算回报和优势
            R = 0 if done else self.local_net(
                torch.FloatTensor(state))[1].item()

            returns = []
            for r in reversed(rewards):
                R = r + self.gamma * R
                returns.insert(0, R)

            returns = torch.tensor(returns)
            values = torch.stack(values).squeeze()  # squeeze压缩维度  # torch.stack沿新维度拼接张量
            log_probs = torch.stack(log_probs)

            advantages = returns - values.detach()  # 分离计算图,不参与梯度计算

            # 计算损失
            actor_loss = -(log_probs * advantages).mean()
            critic_loss = nn.MSELoss()(values, returns)
            entropy = Categorical(logits=logits).entropy().mean()

            loss = actor_loss + 0.5 * critic_loss - 0.01 * entropy

            # 异步更新全局网络
            self.optimizer.zero_grad()  # 清零梯度
            loss.backward()  # 反向传播计算梯度

            # 梯度上传
            for local_param, global_param in zip(  # zip按位置配对
                self.local_net.parameters(),
                self.global_net.parameters()
            ):
                if global_param.grad is not None:
                    global_param.grad += local_param.grad
                else:
                    global_param.grad = local_param.grad.clone()

            self.optimizer.step()  # 更新参数

class A3C:
    """A3C主类"""

    def __init__(self, state_dim, action_dim, lr=1e-4, num_workers=4):
        self.global_net = A3CNetwork(state_dim, action_dim)
        self.global_net.share_memory()

        self.optimizer = optim.Adam(self.global_net.parameters(), lr=lr)
        self.num_workers = num_workers

    def train(self, env_name, num_episodes=10000):
        """训练"""
        workers = [
            A3CWorker(self.global_net, self.optimizer, env_name, i)
            for i in range(self.num_workers)
        ]

        for worker in workers:
            worker.start()

        for worker in workers:
            worker.join()

3. 自然策略梯度

3.1 核心思想

问题:标准梯度下降忽略参数空间的几何结构

解决方案:使用Fisher信息矩阵

\[\theta_{new} = \theta_{old} + \alpha F^{-1}\nabla J(\theta)\]

3.2 Fisher信息矩阵

\[F = \mathbb{E}_{\pi_\theta}[\nabla_\theta \log \pi_\theta(a|s) \nabla_\theta \log \pi_\theta(a|s)^T]\]

3.3 与TRPO的关系

TRPO使用约束优化: $\(\max_\theta J(\theta) \quad \text{s.t.} \quad D_{KL}(\pi_{\theta_{old}} || \pi_\theta) \leq \delta\)$


4. 本章总结

核心概念

Text Only
Actor-Critic高级方法:
├── GAE: 平衡偏差和方差
├── A3C: 异步并行训练
└── 自然梯度: 考虑参数空间几何

选择建议:
├── 单GPU: A2C (同步版本)
├── 多CPU: A3C
└── 连续控制: GAE + PPO

✅ 自测问题

  1. GAE如何平衡偏差和方差?

  2. A3C相比A2C有什么优势?

  3. 自然策略梯度与普通梯度有什么区别?


📚 延伸阅读

  1. Mnih et al. (2016) - A3C
  2. Schulman et al. (2016) - GAE
  3. Kakade (2002) - Natural Policy Gradient

→ 下一阶段:04-高级算法