03 - TRPO: Trust Region Policy Optimization
Study time: 4-5 hours · Importance: ⭐⭐⭐⭐ · Policy optimization with theoretical guarantees · Prerequisites: natural policy gradient, KL divergence
🎯 Learning Objectives
After completing this chapter, you will be able to:
- Understand the core idea of TRPO (the trust region constraint)
- Work with the Fisher information matrix and the conjugate gradient method
- Understand TRPO's theoretical guarantees
- Understand the relationship between TRPO and PPO
1. Introduction to TRPO
1.1 Motivation
Problems with the standard policy gradient:
- The step size is hard to choose
- A large policy update can cause performance to collapse
- There is no monotonic-improvement guarantee
TRPO's solution:
- Limit the size of each policy update
- Use a trust region method
- Provide theoretical guarantees
1.2 Core Idea
The constrained optimization problem:
\[\max_\theta \mathbb{E}_{s \sim \pi_{\theta_{old}}, a \sim \pi_\theta} \left[\frac{\pi_\theta(a|s)}{\pi_{\theta_{old}}(a|s)} A^{\pi_{\theta_{old}}}(s,a)\right]\]
\[\text{s.t. } \mathbb{E}_{s \sim \pi_{\theta_{old}}} [D_{KL}(\pi_{\theta_{old}}(\cdot|s) || \pi_\theta(\cdot|s))] \leq \delta\]
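For a discrete action space, both the objective and the constraint can be evaluated directly at a single state. A minimal sketch (the two policies, the advantages, and δ below are made-up toy values):

```python
import math

# toy categorical policies over 3 actions at one state (made-up numbers)
pi_old = [0.5, 0.3, 0.2]
pi_new = [0.4, 0.4, 0.2]
advantage = [1.0, -0.5, 0.2]  # A^{pi_old}(s, a) for each action

# objective: E_{a ~ pi_old}[ pi_new(a|s)/pi_old(a|s) * A(s, a) ]
surrogate = sum(po * (pn / po) * adv
                for po, pn, adv in zip(pi_old, pi_new, advantage))

# constraint: KL(pi_old || pi_new) must stay below delta
kl = sum(po * math.log(po / pn) for po, pn in zip(pi_old, pi_new))
delta = 0.01
feasible = kl <= delta
```

Here the candidate update improves the surrogate but violates the KL budget, which is exactly the situation TRPO's constraint is designed to reject.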
2. Theoretical Derivation
2.1 Policy Performance Difference
Key lemma (performance difference):
\[J(\pi_{new}) = J(\pi_{old}) + \mathbb{E}_{s \sim d^{\pi_{new}},\, a \sim \pi_{new}}[A^{\pi_{old}}(s,a)]\]
Importance sampling over actions (approximating the new policy's state distribution \(d^{\pi_{new}}\) by the old policy's is what makes the objective in the next section a surrogate):
\[\mathbb{E}_{a \sim \pi_{new}}[A^{\pi_{old}}(s,a)] = \mathbb{E}_{a \sim \pi_{old}}\left[\frac{\pi_{new}(a|s)}{\pi_{old}(a|s)} A^{\pi_{old}}(s,a)\right]\]
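The identity is easy to check exactly for a discrete action space, since both sides are finite sums over actions (the probabilities and advantages below are made-up):

```python
pi_old = [0.6, 0.3, 0.1]
pi_new = [0.2, 0.5, 0.3]
adv = [0.8, -0.3, 1.5]  # A^{pi_old}(s, a) for each action

# left-hand side: expectation under the new policy
lhs = sum(pn * a for pn, a in zip(pi_new, adv))

# right-hand side: expectation under the old policy, reweighted
# by the probability ratio pi_new / pi_old
rhs = sum(po * (pn / po) * a for po, pn, a in zip(pi_old, pi_new, adv))
```

In practice the old policy generates the samples, so the ratio is computed from stored log-probabilities rather than summed exactly; the identity is what makes that reweighting unbiased.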
2.2 Approximate Objective
The surrogate objective:
\[L_{\pi_{old}}(\pi) = \mathbb{E}_{\pi_{old}}\left[\frac{\pi(a|s)}{\pi_{old}(a|s)} A^{\pi_{old}}(s,a)\right]\]
Theoretical guarantee (Schulman et al., 2015):
\[J(\pi_{new}) \geq L_{\pi_{old}}(\pi_{new}) - \frac{4\epsilon\gamma}{(1-\gamma)^2} \max_s D_{KL}(\pi_{old}(\cdot|s) \,\|\, \pi_{new}(\cdot|s))\]
where \(\epsilon = \max_{s,a} |A^{\pi_{old}}(s,a)|\). Maximizing this lower bound at every step yields monotonic policy improvement.
3. Fisher Information Matrix
3.1 Definition
\[F = \mathbb{E}_{\pi_\theta}[\nabla_\theta \log \pi_\theta(a|s) \nabla_\theta \log \pi_\theta(a|s)^T]\]
Relation to the KL divergence (second-order Taylor expansion):
\[D_{KL}(\pi_\theta || \pi_{\theta+\Delta\theta}) \approx \frac{1}{2} \Delta\theta^T F \Delta\theta\]
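For a softmax (categorical) policy with logits \(\theta\), the Fisher matrix has the closed form \(F = \mathrm{diag}(p) - p p^T\), which lets us check the quadratic approximation numerically. A small sketch (the logits and the perturbation are arbitrary):

```python
import numpy as np

def softmax(theta):
    e = np.exp(theta - theta.max())
    return e / e.sum()

theta = np.array([0.2, -0.5, 0.1])
delta = 0.01 * np.array([1.0, -2.0, 0.5])  # small parameter perturbation

p = softmax(theta)
q = softmax(theta + delta)

# exact KL(pi_theta || pi_{theta+delta})
kl_exact = np.sum(p * np.log(p / q))

# quadratic approximation 0.5 * delta^T F delta, F = diag(p) - p p^T
F = np.diag(p) - np.outer(p, p)
kl_quad = 0.5 * delta @ F @ delta
```

For perturbations of this size the two values agree to within a few percent; the gap is the third-order remainder of the Taylor expansion.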
3.2 Natural Gradient
The natural policy gradient:
\[\Delta\theta = F^{-1} \nabla_\theta J(\theta)\]
Problem: for a large network, F is a huge matrix, and inverting it directly is prohibitively expensive.
Solution: the conjugate gradient method, which only needs Fisher-vector products.
4. Conjugate Gradient Method
4.1 Core Idea
Solve the linear system:
\[Fx = g\]
where \(g = \nabla_\theta J(\theta)\).
The inverse is never formed explicitly; an iterative method is used instead.
4.2 Algorithm
```python
import torch


def conjugate_gradient(Ax, b, max_iters=10, residual_tol=1e-10):
    """Solve Ax = b with the conjugate gradient method.

    Args:
        Ax: function computing the matrix-vector product A @ x
        b: target vector
        max_iters: maximum number of iterations
        residual_tol: tolerance on the squared residual

    Returns:
        x: approximate solution
    """
    x = torch.zeros_like(b)
    r = b.clone()  # residual
    p = r.clone()  # search direction
    r_dot_r = torch.dot(r, r)
    for _ in range(max_iters):
        Ap = Ax(p)
        alpha = r_dot_r / (torch.dot(p, Ap) + 1e-8)
        x += alpha * p
        r -= alpha * Ap
        new_r_dot_r = torch.dot(r, r)
        if new_r_dot_r < residual_tol:
            break
        beta = new_r_dot_r / r_dot_r
        p = r + beta * p
        r_dot_r = new_r_dot_r
    return x


def fisher_vector_product(policy, states, vector, damping=0.1):
    """Compute the Fisher-vector product F @ v.

    Uses Pearlmutter's trick (a Hessian-vector product of the KL)
    to avoid forming F explicitly. `compute_kl_divergence` is assumed
    to return the mean KL between the current policy and a detached
    copy of itself, so its Hessian at the current parameters equals F.
    """
    kl = compute_kl_divergence(policy, states)
    # first-order gradient of the KL (keep the graph for a second pass)
    grads = torch.autograd.grad(kl, policy.parameters(), create_graph=True)
    flat_grads = flatten_gradients(grads)
    # inner product with the vector, then differentiate again
    grad_v = torch.dot(flat_grads, vector)
    grads2 = torch.autograd.grad(grad_v, policy.parameters())
    flat_grads2 = flatten_gradients(grads2)
    # the damping term improves the conditioning of the CG solve
    return flat_grads2 + damping * vector


def flatten_gradients(grads):
    """Flatten a list of gradient tensors into a single vector."""
    return torch.cat([g.view(-1) for g in grads])
```
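The solver can be sanity-checked against a direct solve on a small symmetric positive definite system. A NumPy analogue of the loop above (the matrix is arbitrary but SPD by construction):

```python
import numpy as np

def cg(Ax, b, max_iters=50, residual_tol=1e-16):
    """NumPy version of the conjugate gradient loop above."""
    x = np.zeros_like(b)
    r = b.copy()
    p = r.copy()
    r_dot_r = r @ r
    for _ in range(max_iters):
        Ap = Ax(p)
        alpha = r_dot_r / (p @ Ap)
        x += alpha * p
        r -= alpha * Ap
        new_r_dot_r = r @ r
        if new_r_dot_r < residual_tol:
            break
        p = r + (new_r_dot_r / r_dot_r) * p
        r_dot_r = new_r_dot_r
    return x

rng = np.random.default_rng(0)
M = rng.standard_normal((5, 5))
A = M @ M.T + 5.0 * np.eye(5)  # SPD by construction
b = rng.standard_normal(5)

x_cg = cg(lambda v: A @ v, b)          # matrix touched only via products
x_direct = np.linalg.solve(A, b)       # reference solution
```

Note that `cg` only ever calls `Ax`, never indexes the matrix; that is what lets TRPO substitute a Fisher-vector product for an explicit Fisher matrix.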
5. TRPO Algorithm
5.1 Complete Algorithm
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical


class TRPO:
    """TRPO agent for discrete action spaces."""

    def __init__(self, state_dim, action_dim, hidden_dim=64,
                 max_kl=0.01, damping=0.1, cg_iters=10, line_search_iters=10):
        self.max_kl = max_kl
        self.damping = damping
        self.cg_iters = cg_iters
        self.line_search_iters = line_search_iters
        # policy network
        self.policy = self._build_network(state_dim, action_dim, hidden_dim)
        # value network
        self.value = self._build_network(state_dim, 1, hidden_dim)
        self.value_optimizer = optim.Adam(self.value.parameters(), lr=3e-4)

    def _build_network(self, input_dim, output_dim, hidden_dim):
        """Build a two-hidden-layer MLP."""
        return nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, output_dim),
        )

    def select_action(self, state):
        """Sample an action from the current policy."""
        with torch.no_grad():  # no gradients needed during rollout
            logits = self.policy(state)
            dist = Categorical(logits=logits)
            action = dist.sample()
            log_prob = dist.log_prob(action)
        return action.item(), log_prob  # .item() converts the 0-d tensor to an int

    def compute_advantages(self, rewards, values, gamma=0.99, lam=0.95):
        """Compute GAE advantages (single episode, no done masking)."""
        advantages = []
        gae = 0.0
        for t in reversed(range(len(rewards))):
            next_value = 0.0 if t == len(rewards) - 1 else values[t + 1]
            delta = rewards[t] + gamma * next_value - values[t]
            gae = delta + gamma * lam * gae
            advantages.insert(0, gae)
        return torch.tensor(advantages, dtype=torch.float32)

    def surrogate_loss(self, states, actions, advantages, old_log_probs):
        """Surrogate objective E[ratio * advantage] (maximized, despite the name)."""
        logits = self.policy(states)
        dist = Categorical(logits=logits)
        log_probs = dist.log_prob(actions)
        ratio = torch.exp(log_probs - old_log_probs)  # probability ratio
        return (ratio * advantages).mean()

    def kl_divergence(self, states, old_logits):
        """Mean KL(old || new) between the old and current policies."""
        with torch.no_grad():
            old_dist = Categorical(logits=old_logits)
        dist = Categorical(logits=self.policy(states))
        return torch.distributions.kl_divergence(old_dist, dist).mean()

    def flat_grad(self, loss, params):
        """Gradient of `loss` w.r.t. `params`, flattened into one vector."""
        grads = torch.autograd.grad(loss, params, create_graph=True)
        return torch.cat([g.view(-1) for g in grads])

    def trpo_step(self, states, actions, advantages, old_log_probs):
        """One TRPO update: natural gradient step plus backtracking line search."""
        # save the old policy parameters
        old_policy_params = [p.clone() for p in self.policy.parameters()]
        # policy gradient of the surrogate objective
        loss = self.surrogate_loss(states, actions, advantages, old_log_probs)
        policy_grad = self.flat_grad(loss, self.policy.parameters())

        # solve F x = g with conjugate gradient
        def fvp(v):
            return self.fisher_vector_product(states, v)

        step_dir = conjugate_gradient(fvp, policy_grad, max_iters=self.cg_iters)
        # scale the step so that 0.5 * step^T F step = max_kl
        shs = 0.5 * torch.dot(step_dir, fvp(step_dir))
        lm = torch.sqrt(shs / self.max_kl)
        full_step = step_dir / lm
        # shapes and sizes used to map the flat step back to parameter tensors
        param_shapes = [(p.shape, p.numel()) for p in self.policy.parameters()]
        # logits of the old policy, kept out of the graph for the KL check
        with torch.no_grad():
            old_logits = self.policy(states).detach()

        # backtracking line search
        old_loss = loss.item()
        for ls_iter in range(self.line_search_iters):
            alpha = 0.9 ** ls_iter
            # un-flatten the scaled step, one slice per parameter tensor
            offset = 0
            new_params = []
            for old_p, (shape, numel) in zip(old_policy_params, param_shapes):
                new_p = old_p + alpha * full_step[offset:offset + numel].view(shape)
                new_params.append(new_p)
                offset += numel
            # write the candidate parameters into the network
            for p, new_p in zip(self.policy.parameters(), new_params):
                p.data.copy_(new_p)
            # check improvement and the KL constraint against the old policy
            new_loss = self.surrogate_loss(
                states, actions, advantages, old_log_probs).item()
            with torch.no_grad():
                old_dist = Categorical(logits=old_logits)
                new_dist = Categorical(logits=self.policy(states))
                kl = torch.distributions.kl_divergence(old_dist, new_dist).mean()
            if new_loss > old_loss and kl < self.max_kl:
                return True
        # no acceptable step found: restore the old parameters
        for p, old_p in zip(self.policy.parameters(), old_policy_params):
            p.data.copy_(old_p)
        return False

    def fisher_vector_product(self, states, vector):
        """Fisher-vector product F @ v via a KL Hessian-vector product."""
        logits = self.policy(states)
        dist = Categorical(logits=logits)
        # The Hessian of KL(pi_detached || pi_theta) at theta equals the
        # Fisher information matrix, so differentiate the KL twice.
        fixed_dist = Categorical(logits=logits.detach())
        kl = torch.distributions.kl_divergence(fixed_dist, dist).mean()
        grads = torch.autograd.grad(kl, self.policy.parameters(), create_graph=True)
        flat_grads = torch.cat([g.view(-1) for g in grads])
        grad_v = torch.dot(flat_grads, vector)
        grads2 = torch.autograd.grad(grad_v, self.policy.parameters())
        flat_grads2 = torch.cat([g.contiguous().view(-1) for g in grads2])
        return flat_grads2 + self.damping * vector
```
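The step-length computation inside `trpo_step` can be checked in isolation: dividing the conjugate gradient direction by \(\sqrt{\tfrac{1}{2} s^T F s / \delta}\) puts the quadratic KL estimate of the resulting step exactly on the trust-region boundary. A NumPy sketch, with an arbitrary SPD matrix standing in for the Fisher matrix:

```python
import numpy as np

rng = np.random.default_rng(1)
M = rng.standard_normal((4, 4))
F = M @ M.T + np.eye(4)         # SPD stand-in for the Fisher matrix
g = rng.standard_normal(4)      # stand-in for the policy gradient
max_kl = 0.01

step_dir = np.linalg.solve(F, g)        # what conjugate gradient approximates
shs = 0.5 * step_dir @ F @ step_dir     # 0.5 * s^T F s
full_step = step_dir / np.sqrt(shs / max_kl)

# quadratic KL estimate of the scaled step: exactly the trust-region radius
kl_estimate = 0.5 * full_step @ F @ full_step
```

The line search then only ever shrinks this step (`alpha = 0.9 ** ls_iter`), so the quadratic estimate of the KL never exceeds `max_kl`; the explicit KL check guards against the cases where the quadratic approximation is too optimistic.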
6. TRPO vs PPO

| Aspect | TRPO | PPO |
|---|---|---|
| Optimization | Second-order (conjugate gradient) | First-order (gradient descent) |
| Constraint handling | KL constraint | Clipping or penalty |
| Implementation complexity | High | Low |
| Computational cost | High | Low |
| Sample efficiency | High | Medium |
| Stability | High | High |

Practical advice:
- Research / theory: use TRPO
- Engineering / applications: use PPO
7. Chapter Summary
Core Concepts
```text
TRPO:
├── Trust region: limits the size of each policy update
├── KL constraint: underpins the monotonic-improvement bound
├── Fisher information matrix: the geometry of parameter space
├── Conjugate gradient: efficient solve for the natural gradient
└── Theoretical guarantee: monotonic performance improvement

vs. PPO:
├── TRPO: exact but complex
└── PPO: approximate but simple
```
✅ Self-Check Questions

- Why is a trust region constraint needed?
- How is the Fisher information matrix related to the KL divergence?
- What are the advantages of the conjugate gradient method?
- Which scenarios suit TRPO, and which suit PPO?
📚 Further Reading

- Schulman et al. (2015), "Trust Region Policy Optimization", ICML 2015
- Kakade (2002), "A Natural Policy Gradient", NIPS 2002
- Amari (1998), "Natural Gradient Works Efficiently in Learning", Neural Computation
→ Next: 04-模型基方法.md (model-based methods)