跳转至

02 - 离线强化学习

学习时间: 4-5小时 重要性: ⭐⭐⭐⭐⭐ 从固定数据集学习 前置知识: Q-Learning、策略约束


🎯 学习目标

完成本章后,你将能够: - 理解离线RL的挑战(分布偏移、外推误差) - 掌握BCQ、CQL、IQL等核心算法 - 应用离线RL解决实际问题 - 评估离线RL的性能


1. 离线强化学习简介

1.1 什么是离线RL

定义:从固定的、预先收集的数据集中学习,不与环境交互

与在线RL的区别: - 在线RL:边交互边学习 - 离线RL:只从数据学习

1.2 应用场景

  • 医疗:从病历学习治疗策略
  • 自动驾驶:从驾驶记录学习
  • 推荐系统:从用户行为学习
  • 机器人:从演示学习

1.3 核心挑战

分布偏移(Distribution Shift): - 学习策略与数据收集策略不同 - 访问未见过状态

外推误差(Extrapolation Error): - 对未见过状态-动作对的Q值估计不准确 - 过度估计问题


2. 挑战详解

2.1 分布偏移

问题: $\pi_{\text{learned}}(a \mid s) \neq \pi_{\text{data}}(a \mid s)$

后果: - 选择数据中没有的动作 - Q值估计不可靠

2.2 外推误差

Q值过度估计: $Q(s,a) \gg Q^*(s,a) \quad \text{for } a \notin \mathcal{D}$

原因: - Bellman更新使用max操作 - 对未见过动作过度乐观


3. BCQ:Batch-Constrained Q-learning

3.1 核心思想

约束动作选择: - 只选择数据集中出现过的动作 - 使用生成模型(VAE)学习动作分布

3.2 算法组件

  1. 扰动模型:生成相似动作
  2. VAE:学习动作分布
  3. Q网络:评估动作

3.3 代码实现

Python
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

class VAE(nn.Module):
    """Conditional variational autoencoder over dataset actions.

    Used by BCQ to model the behavior policy's state-conditioned action
    distribution, so candidate actions can be sampled close to the data.
    """

    def __init__(self, state_dim, action_dim, latent_dim=32, hidden_dim=750):
        super(VAE, self).__init__()

        # Encoder: (state, action) -> shared hidden features.
        self.encoder = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )

        # Gaussian posterior heads (log_std is treated as a log-variance:
        # reparameterize() applies exp(0.5 * log_std)).
        self.mean = nn.Linear(hidden_dim, latent_dim)
        self.log_std = nn.Linear(hidden_dim, latent_dim)

        # Decoder: (state, latent) -> action squashed into [-1, 1].
        self.decoder = nn.Sequential(
            nn.Linear(state_dim + latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
            nn.Tanh(),
        )

    def encode(self, state, action):
        """Map a (state, action) pair to posterior (mean, log_std)."""
        features = self.encoder(torch.cat([state, action], dim=-1))
        return self.mean(features), self.log_std(features)

    def reparameterize(self, mean, log_std):
        """Sample z ~ N(mean, std) via the reparameterization trick."""
        std = torch.exp(0.5 * log_std)
        noise = torch.randn_like(std)
        return mean + noise * std

    def decode(self, state, z):
        """Decode a latent code, conditioned on the state, into an action."""
        return self.decoder(torch.cat([state, z], dim=-1))

    def forward(self, state, action):
        """Encode, sample, decode. Returns (reconstruction, mean, log_std)."""
        mean, log_std = self.encode(state, action)
        latent = self.reparameterize(mean, log_std)
        return self.decode(state, latent), mean, log_std

class PerturbationNetwork(nn.Module):
    """Perturbation model: applies a small, bounded adjustment to a
    candidate action so BCQ can locally improve actions while staying
    close to the data distribution.

    Args:
        state_dim: dimensionality of the state input.
        action_dim: dimensionality of the action input/output.
        hidden_dim: width of the hidden layers.
        phi: maximum perturbation magnitude (tanh output scaled by phi).
        max_action: absolute bound of the valid action range; the perturbed
            action is clamped to [-max_action, max_action].
    """

    def __init__(self, state_dim, action_dim, hidden_dim=400, phi=0.05, max_action=1.0):
        super(PerturbationNetwork, self).__init__()

        self.phi = phi
        self.max_action = max_action

        self.net = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
            nn.Tanh()
        )

    def forward(self, state, action):
        """Return the perturbed action, clamped to the valid action range.

        Bug fix: `action + phi * tanh(...)` can leave [-max_action, max_action]
        (e.g. action = 1.0 plus a positive perturbation), so the result is
        clamped, matching the original BCQ implementation.
        """
        x = torch.cat([state, action], dim=-1)
        perturbation = self.net(x)
        return (action + self.phi * perturbation).clamp(-self.max_action, self.max_action)

class BCQ:
    """Batch-Constrained deep Q-learning (Fujimoto et al., 2019).

    Restricts action selection to actions resembling the offline dataset:
    a VAE proposes in-distribution actions, a perturbation network locally
    adjusts them, and twin Q-networks (with Polyak-averaged target copies)
    score them.
    """

    def __init__(self, state_dim, action_dim, lr=1e-3, gamma=0.99, tau=0.005):
        """
        Args:
            state_dim: dimensionality of states.
            action_dim: dimensionality of actions.
            lr: learning rate shared by all optimizers.
            gamma: discount factor.
            tau: soft-update coefficient for the target Q-networks.
        """
        self.gamma = gamma
        self.tau = tau
        self.latent_dim = 32  # kept in sync with the VAE's latent_dim

        # VAE models the behavior policy's action distribution.
        self.vae = VAE(state_dim, action_dim, latent_dim=self.latent_dim)
        self.vae_optimizer = optim.Adam(self.vae.parameters(), lr=lr)

        # Perturbation ("actor") network.
        self.perturbation = PerturbationNetwork(state_dim, action_dim)
        self.perturbation_optimizer = optim.Adam(self.perturbation.parameters(), lr=lr)

        # Twin Q-networks plus frozen target copies (clipped double Q).
        self.q1 = self._make_q_network(state_dim, action_dim)
        self.q2 = self._make_q_network(state_dim, action_dim)
        self.q1_target = self._make_q_network(state_dim, action_dim)
        self.q2_target = self._make_q_network(state_dim, action_dim)

        self.q1_target.load_state_dict(self.q1.state_dict())
        self.q2_target.load_state_dict(self.q2.state_dict())

        self.q1_optimizer = optim.Adam(self.q1.parameters(), lr=lr)
        self.q2_optimizer = optim.Adam(self.q2.parameters(), lr=lr)

    def _make_q_network(self, state_dim, action_dim):
        """Build a Q-network mapping (state, action) -> scalar value."""
        return nn.Sequential(
            nn.Linear(state_dim + action_dim, 400),
            nn.ReLU(),
            nn.Linear(400, 300),
            nn.ReLU(),
            nn.Linear(300, 1)
        )

    def _soft_update(self, online, target):
        """Polyak-average the online parameters into the target network."""
        with torch.no_grad():
            for param, target_param in zip(online.parameters(), target.parameters()):
                target_param.mul_(1.0 - self.tau).add_(self.tau * param)

    def select_action(self, state, n_candidates=10):
        """Select an action: sample candidates from the VAE, perturb them,
        and return the candidate with the highest Q1 value."""
        state = torch.FloatTensor(state).unsqueeze(0)  # add batch dimension

        with torch.no_grad():  # inference only; no gradients needed
            # Sample candidate actions from the learned data distribution.
            candidates = []
            for _ in range(n_candidates):
                z = torch.randn(1, self.latent_dim)
                action = self.vae.decode(state, z)
                action = self.perturbation(state, action)
                candidates.append(action)

            candidates = torch.cat(candidates, dim=0)
            states = state.repeat(n_candidates, 1)

            # Pick the candidate with the highest Q1 estimate.
            q_values = self.q1(torch.cat([states, candidates], dim=-1))
            best_action = candidates[q_values.argmax()]

        return best_action.numpy()

    def train(self, replay_buffer, batch_size=100):
        """Run one BCQ training step on a sampled batch.

        NOTE(review): assumes reward/done come back as (batch, 1) column
        tensors so target_q broadcasts correctly — confirm with the buffer.
        """
        state, action, reward, next_state, done = replay_buffer.sample(batch_size)

        # --- VAE: reconstruction + KL toward the standard normal prior ---
        recon, mean, log_std = self.vae(state, action)
        recon_loss = nn.MSELoss()(recon, action)
        kl_loss = -0.5 * torch.sum(1 + log_std - mean.pow(2) - log_std.exp())
        vae_loss = recon_loss + 0.5 * kl_loss / action.size(0)

        self.vae_optimizer.zero_grad()
        vae_loss.backward()
        self.vae_optimizer.step()

        # --- Critics: clipped-double-Q target over sampled candidates ---
        with torch.no_grad():
            next_actions = []
            for _ in range(10):
                z = torch.randn(batch_size, self.latent_dim)
                next_action = self.vae.decode(next_state, z)
                next_action = self.perturbation(next_state, next_action)
                next_actions.append(next_action)

            next_actions = torch.stack(next_actions, dim=0)  # (10, batch, action_dim)

            # min over the two targets counters overestimation.
            next_states_expanded = next_state.unsqueeze(0).expand(10, -1, -1)
            next_q1 = self.q1_target(torch.cat([next_states_expanded, next_actions], dim=-1))
            next_q2 = self.q2_target(torch.cat([next_states_expanded, next_actions], dim=-1))
            next_q = torch.min(next_q1, next_q2)

            # max over the candidate set approximates the constrained max_a Q.
            max_next_q = next_q.max(dim=0)[0]
            target_q = reward + (1 - done) * self.gamma * max_next_q

        current_q1 = self.q1(torch.cat([state, action], dim=-1))
        current_q2 = self.q2(torch.cat([state, action], dim=-1))

        q1_loss = nn.MSELoss()(current_q1, target_q)
        q2_loss = nn.MSELoss()(current_q2, target_q)

        self.q1_optimizer.zero_grad()
        q1_loss.backward()
        self.q1_optimizer.step()

        self.q2_optimizer.zero_grad()
        q2_loss.backward()
        self.q2_optimizer.step()

        # --- Perturbation network: ascend Q1 over perturbed VAE samples ---
        # detach() stops gradients flowing into the VAE, whose optimizer is
        # not stepped here (avoids stale gradients on VAE parameters).
        sampled_actions = self.vae.decode(state, torch.randn(batch_size, self.latent_dim)).detach()
        perturbed_actions = self.perturbation(state, sampled_actions)

        perturbation_loss = -self.q1(torch.cat([state, perturbed_actions], dim=-1)).mean()

        self.perturbation_optimizer.zero_grad()
        perturbation_loss.backward()
        self.perturbation_optimizer.step()

        # Bug fix: the original never updated the target networks, freezing
        # the TD target at the initial random weights. Soft-update them.
        self._soft_update(self.q1, self.q1_target)
        self._soft_update(self.q2, self.q2_target)

4. CQL:Conservative Q-Learning

4.1 核心思想

保守估计: - 降低未见过动作的Q值 - 防止过度估计

损失函数: $L_{CQL} = L_{DQN} + \alpha \left( \mathbb{E}_{s \sim \mathcal{D},\, a \sim \pi}[Q(s,a)] - \mathbb{E}_{(s,a) \sim \mathcal{D}}[Q(s,a)] \right)$(注意:$\alpha$ 作用于两个期望之差,即整个保守惩罚项)

4.2 代码实现

Python
class CQL:
    """Conservative Q-Learning (discrete-action version).

    Adds a conservatism penalty (logsumexp over all actions minus the Q of
    the dataset action) to the standard TD loss, pushing down Q-values of
    out-of-distribution actions.
    """

    def __init__(self, state_dim, n_actions, lr=3e-4, gamma=0.99, alpha=1.0, tau=0.005):
        """
        Args:
            state_dim: dimensionality of states.
            n_actions: number of discrete actions.
            lr: Adam learning rate for both Q-networks.
            gamma: discount factor.
            alpha: weight of the CQL conservatism penalty.
            tau: soft-update coefficient for the target networks.
        """
        self.gamma = gamma
        self.alpha = alpha
        self.n_actions = n_actions
        self.tau = tau

        # Q-networks: state in, one Q-value per action out.
        self.q1 = self._make_q_network(state_dim, n_actions)
        self.q2 = self._make_q_network(state_dim, n_actions)
        self.q1_target = self._make_q_network(state_dim, n_actions)
        self.q2_target = self._make_q_network(state_dim, n_actions)

        self.q1_target.load_state_dict(self.q1.state_dict())
        self.q2_target.load_state_dict(self.q2.state_dict())

        self.q1_optimizer = optim.Adam(self.q1.parameters(), lr=lr)
        self.q2_optimizer = optim.Adam(self.q2.parameters(), lr=lr)

    def _make_q_network(self, state_dim, n_actions):
        """Build a Q-network mapping state -> Q-values for all actions."""
        return nn.Sequential(
            nn.Linear(state_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, n_actions)
        )

    def _soft_update(self, online, target):
        """Polyak-average the online parameters into the target network."""
        with torch.no_grad():
            for param, target_param in zip(online.parameters(), target.parameters()):
                target_param.mul_(1.0 - self.tau).add_(self.tau * param)

    def train(self, replay_buffer, batch_size=256):
        """Run one CQL training step on a sampled batch.

        NOTE(review): assumes action/reward/done are (batch, 1) column
        tensors — confirm with the buffer implementation.
        """
        state, action, reward, next_state, done = replay_buffer.sample(batch_size)

        # Standard DQN target (Double-DQN style: Q1 selects, Q2-target evaluates).
        with torch.no_grad():
            next_action = self.q1_target(next_state).argmax(dim=1, keepdim=True)
            next_q = self.q2_target(next_state).gather(1, next_action)
            target_q = reward + (1 - done) * self.gamma * next_q

        # Current Q-values at the dataset actions.
        current_q1 = self.q1(state).gather(1, action.long())
        current_q2 = self.q2(state).gather(1, action.long())

        # CQL penalty: push down Q over all actions (soft-max via logsumexp),
        # push up Q at the dataset actions.
        all_q1 = self.q1(state)
        all_q2 = self.q2(state)

        cql_loss_q1 = torch.logsumexp(all_q1, dim=1).mean() - current_q1.mean()
        cql_loss_q2 = torch.logsumexp(all_q2, dim=1).mean() - current_q2.mean()

        # Total loss = TD loss + alpha * conservatism penalty (per network).
        q1_loss = nn.MSELoss()(current_q1, target_q) + self.alpha * cql_loss_q1
        q2_loss = nn.MSELoss()(current_q2, target_q) + self.alpha * cql_loss_q2

        self.q1_optimizer.zero_grad()
        q1_loss.backward()
        self.q1_optimizer.step()

        self.q2_optimizer.zero_grad()
        q2_loss.backward()
        self.q2_optimizer.step()

        # Bug fix: the original never updated the target networks, freezing
        # the TD target at the initial random weights. Soft-update them.
        self._soft_update(self.q1, self.q1_target)
        self._soft_update(self.q2, self.q2_target)

5. 算法对比

算法 核心思想 优点 缺点
BCQ 约束动作选择 简单有效 需要训练VAE
CQL 保守Q值估计 理论保证 计算复杂
IQL 隐式Q学习 稳定 需要大量数据

6. 本章总结

核心概念

Text Only
离线RL:
├── 挑战:
│   ├── 分布偏移
│   └── 外推误差
├── 解决方案:
│   ├── BCQ: 约束动作
│   ├── CQL: 保守估计
│   └── IQL: 隐式学习
└── 应用:
    ├── 医疗
    ├── 自动驾驶
    └── 推荐系统

✅ 自测问题

  1. 离线RL与在线RL的主要区别是什么?

  2. 什么是外推误差?如何解决?

  3. BCQ和CQL的核心思想有什么区别?


📚 延伸阅读

  1. Fujimoto et al. (2019) - BCQ
  2. Kumar et al. (2020) - CQL
  3. Kostrikov et al. (2022) - IQL

→ 下一步:03-元强化学习.md