跳转至

05 - 分布式强化学习

学习时间: 3-4小时 重要性: ⭐⭐⭐⭐ 大规模并行训练 前置知识: A3C、经验回放


🎯 学习目标

完成本章后,你将能够: - 理解分布式RL的核心思想 - 掌握Ape-X架构 - 了解IMPALA等分布式方法 - 应用分布式训练加速学习


1. 分布式RL简介

1.1 动机

单机的限制: - 计算资源有限 - 训练时间长 - 样本收集慢

分布式的优势: - 并行环境交互 - 加速训练 - 处理大规模问题

1.2 架构分类

Text Only
分布式RL
├── 分布式Actor,集中式Learner
│   ├── Ape-X
│   └── R2D2
├── 分布式Actor,分布式Learner
│   ├── IMPALA
│   └── SEED
└── 参数服务器架构
    └── A3C变种

2. Ape-X

2.1 核心思想

多个Actor,一个Learner: - Actor并行收集经验 - 优先经验回放缓冲区 - Learner从缓冲区学习

2.2 代码实现

Python
import torch
import torch.nn as nn
import torch.multiprocessing as mp
import gymnasium as gym
from collections import deque
import random
import numpy as np

class Actor:
    """分布式Actor"""

    def __init__(self, actor_id, shared_network, replay_queue, epsilon):
        self.actor_id = actor_id
        self.network = shared_network
        self.replay_queue = replay_queue
        self.epsilon = epsilon

    def run(self, env_name):
        """Actor主循环"""
        env = gym.make(env_name)

        while True:
            state, _ = env.reset()
            done = False

            while not done:
                # ε-贪婪策略
                if random.random() < self.epsilon:
                    action = env.action_space.sample()
                else:
                    with torch.no_grad():  # 禁用梯度计算,节省内存
                        state_tensor = torch.FloatTensor(state)
                        q_values = self.network(state_tensor)
                        action = q_values.argmax().item()  # 将单元素张量转为Python数值

                next_state, reward, terminated, truncated, _ = env.step(action)
                done = terminated or truncated

                # 发送经验到缓冲区
                experience = (state, action, reward, next_state, done)
                self.replay_queue.put(experience)

                state = next_state

class Learner:
    """集中式Learner"""

    def __init__(self, network, target_network, optimizer):
        self.network = network
        self.target_network = target_network
        self.optimizer = optimizer
        self.replay_buffer = PrioritizedReplayBuffer(capacity=1000000)

    def learn(self, batch_size=32):
        """从缓冲区学习"""
        if len(self.replay_buffer) < batch_size:
            return None

        # 优先级采样
        batch, indices, weights = self.replay_buffer.sample(batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)  # zip按位置配对

        # 转换为张量
        states = torch.FloatTensor(states)
        actions = torch.LongTensor(actions)
        rewards = torch.FloatTensor(rewards)
        next_states = torch.FloatTensor(next_states)
        dones = torch.FloatTensor(dones)
        weights = torch.FloatTensor(weights)

        # DQN更新
        current_q = self.network(states).gather(1, actions.unsqueeze(1))  # unsqueeze增加一个维度

        with torch.no_grad():
            next_q = self.target_network(next_states).max(1)[0]
            target_q = rewards + 0.99 * next_q * (1 - dones)

        # 加权损失
        td_errors = torch.abs(current_q.squeeze() - target_q)  # squeeze压缩维度
        loss = (weights * td_errors.pow(2)).mean()

        self.optimizer.zero_grad()  # 清零梯度
        loss.backward()  # 反向传播计算梯度
        self.optimizer.step()  # 更新参数

        # 更新优先级
        self.replay_buffer.update_priorities(indices, td_errors.detach().cpu().numpy())  # 分离计算图,确保在CPU上

        return loss.item()

3. IMPALA

3.1 核心思想

重要性加权Actor-Learner架构: - 多个Actor并行 - 多个Learner并行 - 使用V-trace修正偏差

3.2 V-trace

重要性采样比率: $\(\rho_t = \min\left(\bar{\rho}, \frac{\pi(a_t|s_t)}{\mu(a_t|s_t)}\right)\)$

V-trace目标: $\(v_s = V(s) + \sum_{t=s}^{s+n-1} \gamma^{t-s} \left(\prod_{i=s}^{t-1} c_i\right) \rho_t \delta_t\)$

3.3 代码实现

Python
class IMPALA:
    """IMPALA算法"""

    def __init__(self, network, rho_bar=1.0, c_bar=1.0):
        self.network = network
        self.rho_bar = rho_bar
        self.c_bar = c_bar

    def vtrace(self, rewards, values, log_probs_policy, log_probs_behavior, gamma=0.99):
        """
        计算V-trace目标

        参数:
            rewards: 奖励序列
            values: 价值估计序列
            log_probs_policy: 策略对数概率
            log_probs_behavior: 行为策略对数概率
        """
        # 计算重要性采样比率
        log_rhos = log_probs_policy - log_probs_behavior
        rhos = torch.exp(log_rhos)

        # 截断比率
        clipped_rhos = torch.clamp(rhos, max=self.rho_bar)
        clipped_cs = torch.clamp(rhos, max=self.c_bar)

        # 计算V-trace目标
        deltas = clipped_rhos * (rewards + gamma * values[1:] - values[:-1])

        # 反向计算
        vs = torch.zeros_like(values)
        vs[-1] = values[-1]  # [-1]负索引取最后元素

        for t in reversed(range(len(rewards))):
            vs[t] = values[t] + deltas[t] + gamma * clipped_cs[t] * (vs[t+1] - values[t+1])

        return vs

4. 分布式方法对比

方法 Actor Learner 特点
Ape-X 优先回放
IMPALA V-trace修正
R2D2 循环网络
SEED gRPC通信

5. 本章总结

核心概念

Text Only
分布式RL:
├── 并行Actor: 加速样本收集
├── 优先回放: 高效利用样本
├── V-trace: 修正Off-Policy偏差
└── 可扩展性: 处理大规模问题

选择建议:
├── 研究: IMPALA
├── 工程: Ape-X
└── 大规模: SEED

✅ 自测问题

  1. 分布式RL相比单机RL有什么优势?

  2. Ape-X中的优先级回放如何工作?

  3. V-trace的作用是什么?


📚 延伸阅读

  1. Horgan et al. (2018) - Ape-X
  2. Espeholt et al. (2018) - IMPALA
  3. Kapturowski et al. (2019) - R2D2

→ 下一阶段:05-实战项目