02 - 离线强化学习¶
学习时间: 4-5小时 重要性: ⭐⭐⭐⭐⭐ 从固定数据集学习 前置知识: Q-Learning、策略约束
🎯 学习目标¶
完成本章后,你将能够: - 理解离线RL的挑战(分布偏移、外推误差) - 掌握BCQ、CQL、IQL等核心算法 - 应用离线RL解决实际问题 - 评估离线RL的性能
1. 离线强化学习简介¶
1.1 什么是离线RL¶
定义:从固定的、预先收集的数据集中学习,不与环境交互
与在线RL的区别: - 在线RL:边交互边学习 - 离线RL:只从数据学习
1.2 应用场景¶
- 医疗:从病历学习治疗策略
- 自动驾驶:从驾驶记录学习
- 推荐系统:从用户行为学习
- 机器人:从演示学习
1.3 核心挑战¶
分布偏移(Distribution Shift): - 学习策略与数据收集策略不同 - 访问未见过状态
外推误差(Extrapolation Error): - 对未见过状态-动作对的Q值估计不准确 - 过度估计问题
2. 挑战详解¶
2.1 分布偏移¶
问题: $\pi_{\text{learned}}(a \mid s) \neq \pi_{\text{data}}(a \mid s)$
后果: - 选择数据中没有的动作 - Q值估计不可靠
2.2 外推误差¶
Q值过度估计: $Q(s,a) \gg Q^*(s,a) \quad \text{for } a \notin \mathcal{D}$(即数据集中未出现的动作)
原因: - Bellman更新使用max操作 - 对未见过动作过度乐观
3. BCQ:Batch-Constrained Q-learning¶
3.1 核心思想¶
约束动作选择: - 只选择数据集中出现过的动作 - 使用生成模型(VAE)学习动作分布
3.2 算法组件¶
- 扰动模型:生成相似动作
- VAE:学习动作分布
- Q网络:评估动作
3.3 代码实现¶
Python
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
class VAE(nn.Module):
    """Conditional variational autoencoder over actions.

    Models the behavior policy's action distribution from the offline
    dataset so BCQ can restrict candidate actions to in-distribution
    ones. Decoded actions are squashed to [-1, 1] by a final Tanh.
    """

    def __init__(self, state_dim, action_dim, latent_dim=32, hidden_dim=750):
        super().__init__()
        # Encoder trunk: concat(state, action) -> hidden features.
        self.encoder = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )
        # Gaussian posterior heads.
        # NOTE(review): despite the name, downstream code
        # (reparameterize / the KL term in BCQ.train) treats this head
        # as LOG-VARIANCE, not log-std — confirm intent.
        self.mean = nn.Linear(hidden_dim, latent_dim)
        self.log_std = nn.Linear(hidden_dim, latent_dim)
        # Decoder: concat(state, latent) -> action in [-1, 1].
        self.decoder = nn.Sequential(
            nn.Linear(state_dim + latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
            nn.Tanh(),
        )

    def encode(self, state, action):
        """Map a (state, action) pair to posterior (mean, log_std)."""
        features = self.encoder(torch.cat([state, action], dim=-1))
        return self.mean(features), self.log_std(features)

    def reparameterize(self, mean, log_std):
        """Draw a latent sample via the reparameterization trick."""
        sigma = torch.exp(0.5 * log_std)  # interprets log_std as log-variance
        noise = torch.randn_like(sigma)
        return noise * sigma + mean

    def decode(self, state, z):
        """Reconstruct an action from a state and a latent code."""
        return self.decoder(torch.cat([state, z], dim=-1))

    def forward(self, state, action):
        """Full pass: encode, sample, decode.

        Returns:
            (reconstructed_action, mean, log_std)
        """
        mean, log_std = self.encode(state, action)
        latent = self.reparameterize(mean, log_std)
        return self.decode(state, latent), mean, log_std
class PerturbationNetwork(nn.Module):
    """Residual network that nudges a candidate action.

    Outputs ``action + phi * tanh(net(state, action))``: a perturbation
    bounded by +/- phi per dimension, letting BCQ adjust VAE-proposed
    actions while staying close to the data distribution.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=400, phi=0.05):
        super().__init__()
        self.phi = phi  # max per-dimension perturbation magnitude
        self.net = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
            nn.Tanh(),
        )

    def forward(self, state, action):
        """Return the action shifted by a bounded learned perturbation."""
        delta = self.net(torch.cat([state, action], dim=-1))
        return self.phi * delta + action
class BCQ:
    """Batch-Constrained deep Q-learning (Fujimoto et al., 2019).

    Offline RL agent: a VAE models the behavior policy's action
    distribution, a perturbation network makes bounded adjustments to
    sampled actions, and clipped double-Q learning evaluates them.
    Restricting candidates to (perturbed) VAE samples avoids the
    extrapolation error of querying Q on out-of-distribution actions.

    Args:
        state_dim: dimensionality of the state vector.
        action_dim: dimensionality of the continuous action vector.
        lr: learning rate shared by all optimizers.
        gamma: discount factor.
        latent_dim: VAE latent size (single source of truth, also used
            when sampling candidate latents).
        n_candidates: candidate actions sampled per state.
        tau: Polyak coefficient for target-network soft updates.
    """

    def __init__(self, state_dim, action_dim, lr=1e-3, gamma=0.99,
                 latent_dim=32, n_candidates=10, tau=0.005):
        self.gamma = gamma
        # Fix: latent size and candidate count were hard-coded magic
        # numbers in select_action/train; keeping them here prevents
        # drift from the VAE architecture.
        self.latent_dim = latent_dim
        self.n_candidates = n_candidates
        self.tau = tau
        # VAE: generative model of dataset actions.
        self.vae = VAE(state_dim, action_dim, latent_dim=latent_dim)
        self.vae_optimizer = optim.Adam(self.vae.parameters(), lr=lr)
        # Perturbation network: bounded adjustment of VAE actions.
        self.perturbation = PerturbationNetwork(state_dim, action_dim)
        self.perturbation_optimizer = optim.Adam(self.perturbation.parameters(), lr=lr)
        # Twin Q-networks plus target copies (clipped double-Q).
        self.q1 = self._make_q_network(state_dim, action_dim)
        self.q2 = self._make_q_network(state_dim, action_dim)
        self.q1_target = self._make_q_network(state_dim, action_dim)
        self.q2_target = self._make_q_network(state_dim, action_dim)
        self.q1_target.load_state_dict(self.q1.state_dict())
        self.q2_target.load_state_dict(self.q2.state_dict())
        self.q1_optimizer = optim.Adam(self.q1.parameters(), lr=lr)
        self.q2_optimizer = optim.Adam(self.q2.parameters(), lr=lr)

    def _make_q_network(self, state_dim, action_dim):
        """Build an MLP mapping concat(state, action) to a scalar Q-value."""
        return nn.Sequential(
            nn.Linear(state_dim + action_dim, 400),
            nn.ReLU(),
            nn.Linear(400, 300),
            nn.ReLU(),
            nn.Linear(300, 1)
        )

    def _soft_update(self, online, target):
        """Polyak-average online parameters into the target network."""
        with torch.no_grad():
            for param, target_param in zip(online.parameters(), target.parameters()):
                target_param.mul_(1.0 - self.tau).add_(self.tau * param)

    def select_action(self, state):
        """Pick the candidate action with the highest Q1 value.

        Samples ``n_candidates`` actions from the VAE decoder, perturbs
        each, and returns the one Q1 scores highest, as a numpy array.
        """
        state = torch.FloatTensor(state).unsqueeze(0)  # add batch dim
        with torch.no_grad():  # inference only — no gradients needed
            candidates = []
            for _ in range(self.n_candidates):
                z = torch.randn(1, self.latent_dim)
                action = self.vae.decode(state, z)
                action = self.perturbation(state, action)
                candidates.append(action)
            actions = torch.cat(candidates, dim=0)
            states = state.repeat(self.n_candidates, 1)
            q_values = self.q1(torch.cat([states, actions], dim=-1))
            best_action = actions[q_values.argmax()]
        return best_action.numpy()

    def train(self, replay_buffer, batch_size=100):
        """Run one BCQ update on a batch from ``replay_buffer``.

        Order: (1) VAE reconstruction + KL, (2) clipped double-Q critic
        update against perturbed VAE candidates, (3) perturbation
        network ascent on Q1, (4) target-network soft updates.
        """
        state, action, reward, next_state, done = replay_buffer.sample(batch_size)

        # --- (1) VAE: reconstruct dataset actions -------------------
        recon, mean, log_std = self.vae(state, action)
        recon_loss = nn.MSELoss()(recon, action)
        # KL(q(z|s,a) || N(0, I)); log_std is treated as log-variance.
        kl_loss = -0.5 * torch.sum(1 + log_std - mean.pow(2) - log_std.exp())
        vae_loss = recon_loss + 0.5 * kl_loss / action.size(0)
        self.vae_optimizer.zero_grad()
        vae_loss.backward()
        self.vae_optimizer.step()

        # --- (2) Critic: clipped double-Q target --------------------
        with torch.no_grad():
            next_actions = []
            for _ in range(self.n_candidates):
                z = torch.randn(batch_size, self.latent_dim)
                next_action = self.vae.decode(next_state, z)
                next_action = self.perturbation(next_state, next_action)
                next_actions.append(next_action)
            # (n_candidates, batch, action_dim)
            next_actions = torch.stack(next_actions, dim=0)
            next_states_expanded = next_state.unsqueeze(0).expand(self.n_candidates, -1, -1)
            next_q1 = self.q1_target(torch.cat([next_states_expanded, next_actions], dim=-1))
            next_q2 = self.q2_target(torch.cat([next_states_expanded, next_actions], dim=-1))
            # min over twins counters overestimation; max over the
            # candidate axis picks the best in-distribution action.
            next_q = torch.min(next_q1, next_q2)
            max_next_q = next_q.max(dim=0)[0]
            target_q = reward + (1 - done) * self.gamma * max_next_q
        current_q1 = self.q1(torch.cat([state, action], dim=-1))
        current_q2 = self.q2(torch.cat([state, action], dim=-1))
        q1_loss = nn.MSELoss()(current_q1, target_q)
        q2_loss = nn.MSELoss()(current_q2, target_q)
        self.q1_optimizer.zero_grad()
        q1_loss.backward()
        self.q1_optimizer.step()
        self.q2_optimizer.zero_grad()
        q2_loss.backward()
        self.q2_optimizer.step()

        # --- (3) Perturbation network: gradient ascent on Q1 --------
        sampled_actions = self.vae.decode(state, torch.randn(batch_size, self.latent_dim))
        perturbed_actions = self.perturbation(state, sampled_actions)
        perturbation_loss = -self.q1(torch.cat([state, perturbed_actions], dim=-1)).mean()
        self.perturbation_optimizer.zero_grad()
        perturbation_loss.backward()
        self.perturbation_optimizer.step()

        # --- (4) Target networks: Polyak soft update ----------------
        # Bug fix: the targets were previously never refreshed after
        # initialization, freezing TD targets at random weights.
        self._soft_update(self.q1, self.q1_target)
        self._soft_update(self.q2, self.q2_target)
4. CQL:Conservative Q-Learning¶
4.1 核心思想¶
保守估计: - 降低未见过动作的Q值 - 防止过度估计
损失函数: $L_{CQL} = L_{DQN} + \alpha \left( \mathbb{E}_{s \sim D,\, a \sim \pi}[Q(s,a)] - \mathbb{E}_{(s,a) \sim D}[Q(s,a)] \right)$(注意 $\alpha$ 作用于整个惩罚项之差)
4.2 代码实现¶
Python
class CQL:
    """Conservative Q-Learning for discrete actions (Kumar et al., 2020).

    Adds a regularizer to standard double-Q TD learning that lowers the
    soft-maximum (logsumexp) of Q over *all* actions while raising Q on
    actions actually present in the dataset, countering overestimation
    of out-of-distribution actions in the offline setting.

    Args:
        state_dim: dimensionality of the state vector.
        n_actions: number of discrete actions.
        lr: Adam learning rate for both Q-networks.
        gamma: discount factor.
        alpha: weight of the conservative (CQL) penalty.
        tau: Polyak coefficient for target-network soft updates.
    """

    def __init__(self, state_dim, n_actions, lr=3e-4, gamma=0.99, alpha=1.0,
                 tau=0.005):
        self.gamma = gamma
        self.alpha = alpha
        self.n_actions = n_actions
        self.tau = tau
        # Twin Q-networks (state -> Q per action) plus target copies.
        self.q1 = self._make_q_network(state_dim, n_actions)
        self.q2 = self._make_q_network(state_dim, n_actions)
        self.q1_target = self._make_q_network(state_dim, n_actions)
        self.q2_target = self._make_q_network(state_dim, n_actions)
        self.q1_target.load_state_dict(self.q1.state_dict())
        self.q2_target.load_state_dict(self.q2.state_dict())
        self.q1_optimizer = optim.Adam(self.q1.parameters(), lr=lr)
        self.q2_optimizer = optim.Adam(self.q2.parameters(), lr=lr)

    def _make_q_network(self, state_dim, n_actions):
        """Build an MLP mapping a state to one Q-value per action."""
        return nn.Sequential(
            nn.Linear(state_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, n_actions)
        )

    def _soft_update(self, online, target):
        """Polyak-average online parameters into the target network."""
        with torch.no_grad():
            for param, target_param in zip(online.parameters(), target.parameters()):
                target_param.mul_(1.0 - self.tau).add_(self.tau * param)

    def train(self, replay_buffer, batch_size=256):
        """Run one CQL update on a batch from ``replay_buffer``."""
        state, action, reward, next_state, done = replay_buffer.sample(batch_size)

        # Double-DQN-style target: Q1-target selects the action,
        # Q2-target evaluates it, reducing overestimation bias.
        with torch.no_grad():
            next_action = self.q1_target(next_state).argmax(dim=1, keepdim=True)
            next_q = self.q2_target(next_state).gather(1, next_action)
            target_q = reward + (1 - done) * self.gamma * next_q

        # Q-values of the actions actually taken in the dataset.
        current_q1 = self.q1(state).gather(1, action.long())
        current_q2 = self.q2(state).gather(1, action.long())

        # CQL penalty: logsumexp(Q(s, .)) - Q(s, a_data). Pushes down
        # the soft-maximum over all actions, pushes up dataset actions.
        all_q1 = self.q1(state)
        all_q2 = self.q2(state)
        cql_loss_q1 = torch.logsumexp(all_q1, dim=1).mean() - current_q1.mean()
        cql_loss_q2 = torch.logsumexp(all_q2, dim=1).mean() - current_q2.mean()

        # Total loss = TD error + alpha * conservative penalty,
        # applied independently per Q-network.
        q1_loss = nn.MSELoss()(current_q1, target_q) + self.alpha * cql_loss_q1
        q2_loss = nn.MSELoss()(current_q2, target_q) + self.alpha * cql_loss_q2

        self.q1_optimizer.zero_grad()
        q1_loss.backward()
        self.q1_optimizer.step()
        self.q2_optimizer.zero_grad()
        q2_loss.backward()
        self.q2_optimizer.step()

        # Bug fix: refresh target networks (previously never updated
        # after initialization, freezing TD targets at random weights).
        self._soft_update(self.q1, self.q1_target)
        self._soft_update(self.q2, self.q2_target)
5. 算法对比¶
| 算法 | 核心思想 | 优点 | 缺点 |
|---|---|---|---|
| BCQ | 约束动作选择 | 简单有效 | 需要训练VAE |
| CQL | 保守Q值估计 | 理论保证 | 计算复杂 |
| IQL | 隐式Q学习 | 稳定 | 需要大量数据 |
6. 本章总结¶
核心概念¶
Text Only
离线RL:
├── 挑战:
│ ├── 分布偏移
│ └── 外推误差
├── 解决方案:
│ ├── BCQ: 约束动作
│ ├── CQL: 保守估计
│ └── IQL: 隐式学习
└── 应用:
├── 医疗
├── 自动驾驶
└── 推荐系统
✅ 自测问题¶
-
离线RL与在线RL的主要区别是什么?
-
什么是外推误差?如何解决?
-
BCQ和CQL的核心思想有什么区别?
📚 延伸阅读¶
- Fujimoto et al. (2019) - BCQ
- Kumar et al. (2020) - CQL
- Kostrikov et al. (2022) - IQL
→ 下一步:03-元强化学习.md