
03 - TRPO: Trust Region Policy Optimization

Study time: 4-5 hours | Importance: ⭐⭐⭐⭐ | Policy optimization with theoretical guarantees | Prerequisites: natural policy gradient, KL divergence


🎯 Learning Objectives

After completing this chapter, you will be able to:
  - Understand the core idea of TRPO (the trust-region constraint)
  - Work with the Fisher information matrix and the conjugate gradient method
  - Understand TRPO's theoretical guarantees
  - Relate TRPO to PPO


1. Introduction to TRPO

1.1 Motivation

Problems with the vanilla policy gradient:
  - The step size is hard to choose
  - A large policy update can cause performance to collapse
  - There is no monotonic-improvement guarantee

TRPO's solution:
  - Bound the size of each policy update
  - Use a trust-region method
  - Provide theoretical guarantees

1.2 Core Idea

The constrained optimization problem (actions are sampled from the old policy; the ratio corrects for the change):

\[\max_\theta \mathbb{E}_{s \sim \pi_{\theta_{old}}, a \sim \pi_{\theta_{old}}} \left[\frac{\pi_\theta(a|s)}{\pi_{\theta_{old}}(a|s)} A^{\pi_{\theta_{old}}}(s,a)\right]\]
\[\text{s.t. } \mathbb{E}_{s \sim \pi_{\theta_{old}}} [D_{KL}(\pi_{\theta_{old}}(\cdot|s) \| \pi_\theta(\cdot|s))] \leq \delta\]
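As a concrete sanity check of the objective and constraint, here is a tiny single-state sketch; the logits, advantages, and \(\delta\) are made-up numbers:

```python
import torch
from torch.distributions import Categorical, kl_divergence

# Hypothetical toy setup: one state, 3 actions, advantages assumed known
old_logits = torch.tensor([1.0, 0.0, -1.0])
new_logits = torch.tensor([1.2, 0.1, -1.3])
advantages = torch.tensor([1.0, -0.5, 0.2])

old_pi = Categorical(logits=old_logits)
new_pi = Categorical(logits=new_logits)

# Surrogate objective E_{a~pi_old}[ratio * A]; for a categorical this
# expectation can be computed exactly as a weighted sum over actions
ratio = new_pi.probs / old_pi.probs
surrogate = (old_pi.probs * ratio * advantages).sum()

# Trust-region constraint: KL(pi_old || pi_new) <= delta
kl = kl_divergence(old_pi, new_pi)
delta = 0.01
print(float(surrogate), float(kl))
```

Note that the importance weights cancel exactly here: \(\sum_a \pi_{old}(a)\frac{\pi_{new}(a)}{\pi_{old}(a)}A(a) = \sum_a \pi_{new}(a)A(a)\).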

2. Theoretical Derivation

2.1 Policy Performance Difference

Key lemma (the performance difference lemma)

\[J(\pi_{new}) = J(\pi_{old}) + \mathbb{E}_{\tau \sim \pi_{new}}\left[\sum_{t=0}^{\infty} \gamma^t A^{\pi_{old}}(s_t, a_t)\right]\]

Importance sampling

For a fixed state \(s\), the expectation over actions can be rewritten under the old policy:

\[\mathbb{E}_{a \sim \pi_{new}}[A^{\pi_{old}}(s,a)] = \mathbb{E}_{a \sim \pi_{old}}\left[\frac{\pi_{new}(a|s)}{\pi_{old}(a|s)} A^{\pi_{old}}(s,a)\right]\]

2.2 The Approximate Objective

The surrogate objective

Additionally approximating the state distribution of \(\pi\) by that of \(\pi_{old}\) gives the surrogate:

\[L_{\pi_{old}}(\pi) = \mathbb{E}_{s,a \sim \pi_{old}}\left[\frac{\pi(a|s)}{\pi_{old}(a|s)} A^{\pi_{old}}(s,a)\right]\]
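The importance-sampling identity can be verified numerically for a single state; the two categorical policies and per-action advantages below are made-up:

```python
import torch
from torch.distributions import Categorical

torch.manual_seed(0)
old = Categorical(logits=torch.tensor([0.5, -0.2, 0.1]))
new = Categorical(logits=torch.tensor([0.3, 0.4, -0.5]))
A = torch.tensor([2.0, -1.0, 0.5])  # hypothetical advantages per action

# Exact on-policy expectation under the new policy
direct = (new.probs * A).sum()

# Importance-sampled Monte Carlo estimate using samples from the old policy
a = old.sample((200_000,))
ratio = new.probs[a] / old.probs[a]
estimate = (ratio * A[a]).mean()

print(float(direct), float(estimate))  # the two should agree closely
```

This is exactly how TRPO evaluates a candidate policy using data collected by the old one.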

Theoretical guarantee

With \(\epsilon = \max_{s,a} |A^{\pi_{old}}(s,a)|\), the surrogate bounds the true objective (Schulman et al., 2015):

\[J(\pi_{new}) \geq J(\pi_{old}) + L_{\pi_{old}}(\pi_{new}) - \frac{4\epsilon\gamma}{(1-\gamma)^2} \max_s D_{KL}(\pi_{old}(\cdot|s) \| \pi_{new}(\cdot|s))\]
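This bound supports a short minorize-maximize argument, sketched here following the standard TRPO analysis. Write the penalized surrogate as

\[M(\pi) = L_{\pi_{old}}(\pi) - C \max_s D_{KL}(\pi_{old}(\cdot|s) \,\|\, \pi(\cdot|s))\]

where \(C\) is the penalty coefficient from the bound. Since \(L_{\pi_{old}}(\pi_{old}) = \mathbb{E}_{\pi_{old}}[A^{\pi_{old}}] = 0\) and the KL term vanishes at \(\pi = \pi_{old}\), we have \(M(\pi_{old}) = 0\), hence

\[J(\pi_{new}) - J(\pi_{old}) \geq M(\pi_{new}) - M(\pi_{old}).\]

Any \(\pi_{new}\) that increases \(M\) therefore increases \(J\): this is the monotonic improvement guarantee.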

3. The Fisher Information Matrix

3.1 Definition

\[F = \mathbb{E}_{\pi_\theta}[\nabla_\theta \log \pi_\theta(a|s) \nabla_\theta \log \pi_\theta(a|s)^T]\]

Relationship to the KL divergence

\[D_{KL}(\pi_\theta || \pi_{\theta+\Delta\theta}) \approx \frac{1}{2} \Delta\theta^T F \Delta\theta\]
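The quadratic approximation can be checked numerically. For a categorical policy parameterized by logits \(\theta\), \(\nabla_\theta \log \pi(a) = e_a - \pi\), so \(F = \mathrm{diag}(\pi) - \pi\pi^T\) in closed form; a small sketch with made-up numbers:

```python
import torch
from torch.distributions import Categorical, kl_divergence

theta = torch.tensor([0.5, -0.3, 0.1])
pi = torch.softmax(theta, dim=0)

# Fisher matrix for a categorical with logit parameters:
# grad_theta log pi(a) = e_a - pi  =>  F = diag(pi) - pi pi^T
F = torch.diag(pi) - torch.outer(pi, pi)

# Small parameter perturbation
delta = 1e-2 * torch.tensor([1.0, -2.0, 0.5])

kl = kl_divergence(Categorical(logits=theta), Categorical(logits=theta + delta))
quad = 0.5 * delta @ F @ delta  # second-order approximation of the KL

print(float(kl), float(quad))  # should agree up to O(||delta||^3)
```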

3.2 The Natural Gradient

The natural policy gradient

\[\Delta\theta = F^{-1} \nabla_\theta J(\theta)\]

Problem: F is a large matrix, so inverting it directly is computationally expensive.

Solution: the conjugate gradient method.


4. The Conjugate Gradient Method

4.1 Core Idea

Solve the linear system

\[Fx = g\]

where \(g = \nabla_\theta J(\theta)\).

Avoid explicit inversion by solving iteratively: each iteration only needs the matrix-vector product \(Fp\), which can be computed without ever forming \(F\).

4.2 Algorithm

```python
import torch

def conjugate_gradient(Ax, b, max_iters=10, residual_tol=1e-10):
    """
    Solve Ax = b with the conjugate gradient method.

    Args:
        Ax: function computing the matrix-vector product A @ x
        b: target vector
        max_iters: maximum number of iterations
        residual_tol: tolerance on the squared residual

    Returns:
        x: approximate solution
    """
    x = torch.zeros_like(b)
    r = b.clone()  # residual
    p = r.clone()  # search direction

    r_dot_r = torch.dot(r, r)

    for i in range(max_iters):
        Ap = Ax(p)
        alpha = r_dot_r / (torch.dot(p, Ap) + 1e-8)

        x += alpha * p
        r -= alpha * Ap

        new_r_dot_r = torch.dot(r, r)
        if new_r_dot_r < residual_tol:
            break

        beta = new_r_dot_r / r_dot_r
        p = r + beta * p
        r_dot_r = new_r_dot_r

    return x

def fisher_vector_product(policy, states, vector, damping=0.1):
    """
    Compute the Fisher-vector product F @ v.

    Uses Pearlmutter's trick so F is never formed explicitly.
    `compute_kl_divergence` is assumed to return the mean KL between a
    detached copy of the current policy and the live policy (see the
    class method of the same name below for a concrete version).
    """
    # Gradient of the KL divergence w.r.t. the policy parameters
    kl = compute_kl_divergence(policy, states)

    grads = torch.autograd.grad(kl, policy.parameters(), create_graph=True)
    flat_grads = flatten_gradients(grads)

    # Dot with the vector, then differentiate again: a Hessian-vector product
    grad_v = torch.dot(flat_grads, vector)
    grads2 = torch.autograd.grad(grad_v, policy.parameters())
    flat_grads2 = flatten_gradients(grads2)

    # Damping keeps the product well-conditioned
    return flat_grads2 + damping * vector

def flatten_gradients(grads):
    """Flatten a list of gradient tensors into a single vector."""
    return torch.cat([g.contiguous().view(-1) for g in grads])
```
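A quick sanity check of the solver (the CG loop is restated compactly so this snippet runs on its own): solve a random symmetric positive-definite system and verify that \(Ax \approx b\).

```python
import torch

def cg(Ax, b, iters=50, tol=1e-10):
    # Compact restatement of the conjugate-gradient loop above
    x = torch.zeros_like(b)
    r, p = b.clone(), b.clone()
    rr = r @ r
    for _ in range(iters):
        Ap = Ax(p)
        alpha = rr / (p @ Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        new_rr = r @ r
        if new_rr < tol:
            break
        p = r + (new_rr / rr) * p
        rr = new_rr
    return x

# Random symmetric positive-definite system
torch.manual_seed(0)
M = torch.randn(5, 5)
A = M @ M.T + 5 * torch.eye(5)  # SPD by construction
b = torch.randn(5)

# Only matrix-vector products are ever needed, just as with F in TRPO
x = cg(lambda v: A @ v, b)
print(torch.allclose(A @ x, b, atol=1e-4))
```

In exact arithmetic CG on an n-dimensional system converges in at most n iterations, which is why TRPO gets away with ~10 iterations even for large networks.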

5. The TRPO Algorithm

5.1 Complete Algorithm

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical

class TRPO:
    """The TRPO algorithm (uses conjugate_gradient from above)."""

    def __init__(self, state_dim, action_dim, hidden_dim=64,
                 max_kl=0.01, damping=0.1, cg_iters=10, line_search_iters=10):

        self.max_kl = max_kl
        self.damping = damping
        self.cg_iters = cg_iters
        self.line_search_iters = line_search_iters

        # Policy network
        self.policy = self._build_network(state_dim, action_dim, hidden_dim)

        # Value network
        self.value = self._build_network(state_dim, 1, hidden_dim)
        self.value_optimizer = optim.Adam(self.value.parameters(), lr=3e-4)

    def _build_network(self, input_dim, output_dim, hidden_dim):
        """Build a small MLP."""
        return nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, output_dim)
        )

    def select_action(self, state):
        """Sample an action from the current policy."""
        with torch.no_grad():  # no gradients needed during rollouts
            logits = self.policy(state)
            dist = Categorical(logits=logits)
            action = dist.sample()
            log_prob = dist.log_prob(action)
            return action.item(), log_prob  # .item() converts the scalar tensor to a Python int

    def compute_advantages(self, rewards, values, gamma=0.99, lam=0.95):
        """Compute GAE advantages (assumes the trajectory ends in a terminal state)."""
        advantages = []
        gae = 0

        for t in reversed(range(len(rewards))):
            if t == len(rewards) - 1:
                next_value = 0
            else:
                next_value = values[t + 1]

            delta = rewards[t] + gamma * next_value - values[t]
            gae = delta + gamma * lam * gae
            advantages.insert(0, gae)

        return torch.tensor(advantages, dtype=torch.float32)

    def surrogate_loss(self, states, actions, advantages, old_log_probs):
        """Compute the surrogate objective (to be maximized)."""
        logits = self.policy(states)
        dist = Categorical(logits=logits)
        log_probs = dist.log_prob(actions)

        # Probability ratio pi_new / pi_old
        ratio = torch.exp(log_probs - old_log_probs)

        # Surrogate objective
        surrogate = ratio * advantages

        return surrogate.mean()

    def kl_divergence(self, states, old_logits):
        """Mean KL divergence between the old policy and the current one."""
        with torch.no_grad():
            old_dist = Categorical(logits=old_logits)

        logits = self.policy(states)
        dist = Categorical(logits=logits)

        # KL(old || new)
        kl = torch.distributions.kl_divergence(old_dist, dist)
        return kl.mean()

    def flat_grad(self, loss, params):
        """Compute the gradient of `loss` and flatten it into a vector."""
        grads = torch.autograd.grad(loss, params)
        return torch.cat([g.contiguous().view(-1) for g in grads])

    def trpo_step(self, states, actions, advantages, old_log_probs):
        """One TRPO update."""
        # Save the old policy parameters
        old_policy_params = [p.clone() for p in self.policy.parameters()]

        # Policy gradient of the surrogate objective
        loss = self.surrogate_loss(states, actions, advantages, old_log_probs)
        policy_grad = self.flat_grad(loss, self.policy.parameters())

        # Solve F x = g with the conjugate gradient method
        def fvp(v):
            return self.fisher_vector_product(states, v)

        step_dir = conjugate_gradient(fvp, policy_grad, max_iters=self.cg_iters)

        # Scale the step so that (1/2) x^T F x = max_kl on the boundary
        shs = 0.5 * torch.dot(step_dir, fvp(step_dir))
        lm = torch.sqrt(shs / self.max_kl)
        full_step = step_dir / lm

        # Shapes and sizes used to map the flat vector back onto each parameter tensor
        param_shapes = [(p.shape, p.numel()) for p in self.policy.parameters()]

        # Logits of the old policy, used for the KL check during line search
        with torch.no_grad():
            old_logits = self.policy(states).detach()

        # Backtracking line search
        old_loss = loss.item()

        for ls_iter in range(self.line_search_iters):
            alpha = 0.9 ** ls_iter

            # Slice the flat step vector back into per-parameter tensors
            offset = 0
            new_params = []
            for old_p, (shape, numel) in zip(old_policy_params, param_shapes):
                new_p = old_p + alpha * full_step[offset:offset + numel].view(shape)
                new_params.append(new_p)
                offset += numel

            # Apply the candidate parameters
            for p, new_p in zip(self.policy.parameters(), new_params):
                p.data.copy_(new_p)

            with torch.no_grad():
                # Check for improvement of the surrogate objective
                new_loss = self.surrogate_loss(states, actions, advantages,
                                               old_log_probs).item()

                # KL divergence against the old policy
                new_logits = self.policy(states)
                old_dist = Categorical(logits=old_logits)
                new_dist = Categorical(logits=new_logits)
                kl = torch.distributions.kl_divergence(old_dist, new_dist).mean()

            if new_loss > old_loss and kl < self.max_kl:
                return True

        # No step improved the surrogate within the trust region: restore the old parameters
        for p, old_p in zip(self.policy.parameters(), old_policy_params):
            p.data.copy_(old_p)

        return False

    def fisher_vector_product(self, states, vector):
        """Compute the Fisher-vector product F @ v for the current policy."""
        logits = self.policy(states)
        dist = Categorical(logits=logits)

        # The Hessian of D_KL(pi_theta' || pi_theta) at theta' = theta equals the
        # Fisher matrix, so differentiate KL(stop_grad(pi_theta) || pi_theta) twice
        fixed_dist = Categorical(logits=logits.detach())
        kl = torch.distributions.kl_divergence(fixed_dist, dist).mean()

        grads = torch.autograd.grad(kl, self.policy.parameters(), create_graph=True)
        flat_grads = torch.cat([g.view(-1) for g in grads])

        grad_v = torch.dot(flat_grads, vector)
        grads2 = torch.autograd.grad(grad_v, self.policy.parameters())
        flat_grads2 = torch.cat([g.contiguous().view(-1) for g in grads2])

        return flat_grads2 + self.damping * vector
```
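The GAE recurrence in `compute_advantages` can be checked in isolation (restated here as a standalone function): with \(\gamma = \lambda = 1\) and zero value estimates, the advantages reduce to reward-to-go.

```python
import torch

def gae(rewards, values, gamma=0.99, lam=0.95):
    # Same backward recurrence as compute_advantages above
    advantages, g = [], 0.0
    for t in reversed(range(len(rewards))):
        next_value = values[t + 1] if t < len(rewards) - 1 else 0.0
        delta = rewards[t] + gamma * next_value - values[t]
        g = delta + gamma * lam * g
        advantages.insert(0, g)
    return torch.tensor(advantages)

# With gamma = lam = 1 and zero values, GAE reduces to reward-to-go
adv = gae([1.0, 1.0, 1.0], [0.0, 0.0, 0.0], gamma=1.0, lam=1.0)
print(adv)  # tensor([3., 2., 1.])
```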

6. TRPO vs PPO

| Aspect | TRPO | PPO |
| --- | --- | --- |
| Optimization method | Second-order (conjugate gradient) | First-order (gradient descent) |
| Constraint handling | KL constraint | Clipping or penalty |
| Implementation complexity | Complex | Simple |
| Computational cost | High | Low |
| Sample efficiency | Comparable | Comparable |
| Stability | High (theoretical guarantee) | High (empirically robust) |

Practical advice:
  - Research / theoretical work: use TRPO
  - Engineering / applications: use PPO


7. Chapter Summary

Core concepts

```text
TRPO:
├── Trust region: bounds the size of each policy update
├── KL constraint: underpins the monotonic-improvement guarantee
├── Fisher information matrix: the geometry of parameter space
├── Conjugate gradient: efficiently solves F x = g
└── Theoretical guarantee: monotonic performance improvement

Compared with PPO:
├── TRPO: exact but complex
└── PPO: approximate but simple
```

✅ Self-Test Questions

  1. Why is a trust-region constraint needed?

  2. How is the Fisher information matrix related to the KL divergence?

  3. What are the advantages of the conjugate gradient method?

  4. In which scenarios are TRPO and PPO each the better choice?


📚 Further Reading

  1. Schulman et al., "Trust Region Policy Optimization", ICML 2015

  2. Kakade, "A Natural Policy Gradient", NIPS 2002

  3. Amari, "Natural Gradient Works Efficiently in Learning", Neural Computation, 1998

→ Next: 04-模型基方法.md