# 项目4 - 连续控制
**难度**: ⭐⭐⭐⭐ 高级 | **预计时间**: 4-6小时 | **目标**: 用DDPG/SAC解决连续控制问题
## 1. 项目介绍
### 1.1 连续控制问题
与离散控制的区别:
- 离散控制:动作空间是有限集合(如上下左右),可以为每个动作估计Q值并直接取argmax
- 连续控制:动作是有界区间内的实数向量(如力矩、速度),无法枚举所有动作求max,因此需要由策略网络(Actor)直接输出动作
典型环境:
- Pendulum:倒立摆——对摆杆施加力矩,把它摆起并保持竖直
- LunarLander:月球着陆器(需以 `continuous=True` 创建才是连续动作版本)
- MuJoCo:机器人控制
### 1.2 成功标准
- 稳定的控制性能
- 快速收敛
- 良好的泛化能力
## 2. DDPG实现
### 2.1 DDPG算法
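核心组件:
- Actor:确定性策略 μ(s),直接输出动作
- Critic:动作价值函数 Q(s, a)
- 目标网络:Actor/Critic 的慢速副本,通过软更新 θ' ← τθ + (1-τ)θ' 跟踪在线网络,用来计算更稳定的TD目标
- 经验回放:存储历史转移并随机采样,打破相邻样本之间的相关性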
Python
import math
import random
from collections import deque

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
class Actor(nn.Module):
    """Deterministic policy network mu(s) used by DDPG.

    An MLP with two ReLU hidden layers that ends in ``tanh`` (range [-1, 1]);
    the result is scaled by ``max_action`` so actions lie in
    [-max_action, max_action].
    """

    def __init__(self, state_dim, action_dim, hidden_dim=256, max_action=1.0):
        super().__init__()
        # All layers live in one nn.Sequential called ``net`` so the
        # state_dict keys (net.0.*, net.2.*, net.4.*) stay stable.
        layers = [
            nn.Linear(state_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, action_dim), nn.Tanh(),
        ]
        self.net = nn.Sequential(*layers)
        self.max_action = max_action

    def forward(self, state):
        """Return ``max_action * tanh(MLP(state))``."""
        squashed = self.net(state)
        return self.max_action * squashed
class Critic(nn.Module):
    """State-action value network Q(s, a).

    State and action are concatenated along the last dimension and passed
    through an MLP with two ReLU hidden layers that outputs one value per row.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super().__init__()
        in_features = state_dim + action_dim
        self.net = nn.Sequential(
            nn.Linear(in_features, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, state, action):
        """Return Q(state, action) with a trailing dimension of size 1."""
        joint = torch.cat((state, action), dim=-1)
        return self.net(joint)
class ReplayBuffer:
    """Fixed-capacity FIFO store of (state, action, reward, next_state, done) tuples.

    Backed by ``deque(maxlen=capacity)``: once full, each push evicts the
    oldest transition.
    """

    def __init__(self, capacity=100000):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Append one transition."""
        transition = (state, action, reward, next_state, done)
        self.buffer.append(transition)

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct stored transitions uniformly at random.

        Returns a tuple of five NumPy arrays (states, actions, rewards,
        next_states, dones), each built with ``np.array`` over the sampled
        items. ``random.sample`` raises ``ValueError`` when fewer than
        ``batch_size`` transitions are stored.
        """
        picked = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*picked)
        columns = (states, actions, rewards, next_states, dones)
        return tuple(np.array(column) for column in columns)

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.buffer)
class DDPGAgent:
    """Deep Deterministic Policy Gradient (DDPG) agent.

    Owns an online/target pair for both the deterministic actor and the Q
    critic, one Adam optimizer per online network, and a replay buffer.
    After every gradient step the target networks are moved towards the
    online ones by Polyak averaging with rate ``tau``.
    """

    def __init__(self, state_dim, action_dim, max_action=1.0,
                 actor_lr=1e-3, critic_lr=1e-3, gamma=0.99, tau=0.005):
        self.gamma = gamma  # discount factor in the TD target
        self.tau = tau  # soft-update rate for the target networks
        self.max_action = max_action

        # Build online nets and identical target copies (actor first, then critic).
        self.actor, self.actor_target = self._online_and_target(
            lambda: Actor(state_dim, action_dim, max_action=max_action))
        self.critic, self.critic_target = self._online_and_target(
            lambda: Critic(state_dim, action_dim))
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=actor_lr)
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=critic_lr)

        self.replay_buffer = ReplayBuffer()

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        for network in (self.actor, self.actor_target, self.critic, self.critic_target):
            network.to(self.device)

    @staticmethod
    def _online_and_target(build):
        """Create a network and a target copy that starts with the same weights."""
        online, target = build(), build()
        target.load_state_dict(online.state_dict())
        return online, target

    @staticmethod
    def _soft_update(online, target, tau):
        """Polyak update of every parameter: theta' <- tau*theta + (1 - tau)*theta'."""
        for src, dst in zip(online.parameters(), target.parameters()):
            dst.data.copy_(tau * src.data + (1 - tau) * dst.data)

    def _sample_tensors(self, batch_size):
        """Sample a replay batch and convert each array to a float tensor on ``self.device``.

        Rewards and done flags get a trailing axis of size 1 so they line up
        with the critic's single-column output.
        """
        batch = self.replay_buffer.sample(batch_size)
        states, actions, rewards, next_states, dones = (
            torch.FloatTensor(array).to(self.device) for array in batch
        )
        return states, actions, rewards.unsqueeze(1), next_states, dones.unsqueeze(1)

    def select_action(self, state, noise=0.1):
        """Return the actor's action for one state, optionally with exploration noise.

        When ``noise > 0``, zero-mean Gaussian noise with standard deviation
        ``noise`` is added and the result is clipped to
        [-max_action, max_action]; otherwise the raw actor output is returned.
        """
        obs = torch.FloatTensor(state).unsqueeze(0).to(self.device)
        with torch.no_grad():  # inference only, no graph needed
            greedy = self.actor(obs).cpu().numpy()[0]
        if noise > 0:
            jitter = np.random.normal(0, noise, size=greedy.shape)
            greedy = np.clip(greedy + jitter, -self.max_action, self.max_action)
        return greedy

    def update(self, batch_size=64):
        """Run one DDPG gradient step on a sampled batch.

        Returns:
            ``None`` while the buffer holds fewer than ``batch_size``
            transitions, otherwise ``(critic_loss, actor_loss)`` as floats.
        """
        if len(self.replay_buffer) < batch_size:
            return None
        states, actions, rewards, next_states, dones = self._sample_tensors(batch_size)

        # Critic: regress Q(s, a) onto r + gamma * (1 - done) * Q'(s', mu'(s')).
        with torch.no_grad():
            bootstrap = self.critic_target(next_states, self.actor_target(next_states))
            td_target = rewards + (1 - dones) * self.gamma * bootstrap
        critic_loss = F.mse_loss(self.critic(states, actions), td_target)
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # Actor: maximise Q(s, mu(s)) by minimising its negated batch mean.
        actor_loss = -self.critic(states, self.actor(states)).mean()
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # Targets slowly track the online networks (critic first, then actor).
        self._soft_update(self.critic, self.critic_target, self.tau)
        self._soft_update(self.actor, self.actor_target, self.tau)
        return critic_loss.item(), actor_loss.item()
## 3. SAC实现

SAC(Soft Actor-Critic)使用随机的高斯策略,优化目标是"回报 + α·策略熵";同时训练两个Critic,并取两者的较小值来抑制Q值过估计。
Python
class GaussianActor(nn.Module):
    """Tanh-squashed Gaussian policy network used by SAC.

    A shared two-layer ReLU trunk feeds two linear heads producing the mean
    and the log standard deviation of a diagonal Gaussian over the
    pre-squash action ``u``; actions are ``max_action * tanh(u)``.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=256, max_action=1.0):
        super().__init__()
        self.max_action = max_action
        # log_std is clamped to this range so std = exp(log_std) stays in
        # [exp(-20), exp(2)]: never exactly zero and never huge.
        self.LOG_STD_MIN = -20
        self.LOG_STD_MAX = 2
        self.shared = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )
        self.mean_head = nn.Linear(hidden_dim, action_dim)
        self.log_std_head = nn.Linear(hidden_dim, action_dim)

    def forward(self, state):
        """Return ``(mean, log_std)`` of the pre-squash Gaussian; log_std is clamped."""
        features = self.shared(state)
        mean = self.mean_head(features)
        log_std = self.log_std_head(features)
        log_std = torch.clamp(log_std, self.LOG_STD_MIN, self.LOG_STD_MAX)
        return mean, log_std

    def sample(self, state):
        """Draw a reparameterised action and its log-probability.

        Returns:
            action: ``max_action * tanh(u)`` with ``u ~ N(mean, std)`` drawn via
                ``rsample`` so gradients flow back into the network.
            log_prob: log-density of ``action``, summed over the last
                (action) dimension and kept as a trailing axis of size 1.
        """
        mean, log_std = self.forward(state)
        normal = torch.distributions.Normal(mean, log_std.exp())
        x_t = normal.rsample()  # reparameterisation trick
        action = torch.tanh(x_t) * self.max_action
        # Change of variables for a = max_action * tanh(u):
        #   log p(a) = log N(u) - log(max_action) - log(1 - tanh(u)^2)
        # log(1 - tanh(u)^2) is computed as 2 * (log 2 - u - softplus(-2u)).
        # This identity is exact and stays finite when tanh(u) rounds to +-1,
        # unlike log(1 - tanh(u)^2 + eps), which then saturates at log(eps).
        log_det = 2.0 * (math.log(2.0) - x_t - F.softplus(-2.0 * x_t))
        log_prob = normal.log_prob(x_t) - log_det - math.log(self.max_action)
        log_prob = log_prob.sum(dim=-1, keepdim=True)
        return action, log_prob
class SACAgent:
    """Soft Actor-Critic agent with a fixed entropy temperature ``alpha``.

    Components: a tanh-squashed Gaussian actor, two online critics with
    target copies (clipped double-Q: the smaller of the two estimates is
    used), one Adam optimizer per online network, and a replay buffer.
    """

    def __init__(self, state_dim, action_dim, max_action=1.0,
                 lr=3e-4, gamma=0.99, tau=0.005, alpha=0.2):
        self.gamma = gamma  # discount factor
        self.tau = tau  # soft-update rate for the target critics
        self.alpha = alpha  # entropy weight; fixed (not tuned automatically here)
        self.max_action = max_action

        # Actor (stochastic Gaussian policy)
        self.actor = GaussianActor(state_dim, action_dim, max_action=max_action)
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=lr)

        # Twin critics and target copies that start with identical weights
        self.critic1 = Critic(state_dim, action_dim)
        self.critic2 = Critic(state_dim, action_dim)
        self.critic1_target = Critic(state_dim, action_dim)
        self.critic2_target = Critic(state_dim, action_dim)
        self.critic1_target.load_state_dict(self.critic1.state_dict())
        self.critic2_target.load_state_dict(self.critic2.state_dict())
        self.critic1_optimizer = optim.Adam(self.critic1.parameters(), lr=lr)
        self.critic2_optimizer = optim.Adam(self.critic2.parameters(), lr=lr)

        self.replay_buffer = ReplayBuffer()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        for network in (self.actor, self.critic1, self.critic2,
                        self.critic1_target, self.critic2_target):
            network.to(self.device)

    @staticmethod
    def _soft_update(online, target, tau):
        """Polyak update of every parameter: target <- tau * online + (1 - tau) * target."""
        with torch.no_grad():
            for param, target_param in zip(online.parameters(), target.parameters()):
                target_param.copy_(tau * param + (1 - tau) * target_param)

    def select_action(self, state, deterministic=False):
        """Choose an action for a single (unbatched) state.

        Args:
            state: one environment observation (array-like).
            deterministic: if False (default), sample from the policy as
                during training. If True, return the squashed mean action
                ``max_action * tanh(mean)``, which is the usual choice for
                evaluation.

        Returns:
            The action as a NumPy array.
        """
        state = torch.FloatTensor(state).unsqueeze(0).to(self.device)
        with torch.no_grad():
            if deterministic:
                mean, _ = self.actor(state)
                action = torch.tanh(mean) * self.max_action
            else:
                action, _ = self.actor.sample(state)
        return action.cpu().numpy()[0]

    def update(self, batch_size=64):
        """Run one SAC gradient step on a sampled batch.

        Returns:
            ``None`` while the buffer holds fewer than ``batch_size``
            transitions, otherwise ``(critic1_loss, actor_loss)`` as floats.
        """
        if len(self.replay_buffer) < batch_size:
            return None
        states, actions, rewards, next_states, dones = self.replay_buffer.sample(batch_size)
        states = torch.FloatTensor(states).to(self.device)
        actions = torch.FloatTensor(actions).to(self.device)
        rewards = torch.FloatTensor(rewards).unsqueeze(1).to(self.device)
        next_states = torch.FloatTensor(next_states).to(self.device)
        dones = torch.FloatTensor(dones).unsqueeze(1).to(self.device)

        # Soft TD target: r + gamma * (1 - done) * (min_j Q'_j(s', a') - alpha * log pi(a'|s')),
        # with a' sampled from the current policy.
        with torch.no_grad():
            next_actions, next_log_probs = self.actor.sample(next_states)
            next_q1 = self.critic1_target(next_states, next_actions)
            next_q2 = self.critic2_target(next_states, next_actions)
            next_q = torch.min(next_q1, next_q2) - self.alpha * next_log_probs
            target_q = rewards + (1 - dones) * self.gamma * next_q

        current_q1 = self.critic1(states, actions)
        current_q2 = self.critic2(states, actions)
        critic1_loss = F.mse_loss(current_q1, target_q)
        critic2_loss = F.mse_loss(current_q2, target_q)
        self.critic1_optimizer.zero_grad()
        critic1_loss.backward()
        self.critic1_optimizer.step()
        self.critic2_optimizer.zero_grad()
        critic2_loss.backward()
        self.critic2_optimizer.step()

        # Actor: minimise alpha * log pi(a|s) - min(Q1, Q2)(s, a) with a ~ pi(.|s).
        new_actions, log_probs = self.actor.sample(states)
        q1 = self.critic1(states, new_actions)
        q2 = self.critic2(states, new_actions)
        min_q = torch.min(q1, q2)
        actor_loss = (self.alpha * log_probs - min_q).mean()
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # Target critics slowly track the online critics.
        self._soft_update(self.critic1, self.critic1_target, self.tau)
        self._soft_update(self.critic2, self.critic2_target, self.tau)
        return critic1_loss.item(), actor_loss.item()
## 4. 训练流程
Python
def train_continuous_control(env_name='Pendulum-v1', algorithm='DDPG', num_episodes=500):
    """Train a DDPG or SAC agent on a Gymnasium continuous-control environment.

    Args:
        env_name: Gymnasium environment id. Its action space must expose
            ``shape`` and ``high``; only ``high[0]`` is used as the symmetric
            action bound, so all action dimensions are assumed to share it.
        algorithm: ``'DDPG'`` or ``'SAC'``.
        num_episodes: number of training episodes.

    Returns:
        ``(agent, rewards_history)``, where ``rewards_history`` holds the
        undiscounted return of every episode.

    Raises:
        ValueError: if ``algorithm`` is not recognised (checked before the
            environment is created).
    """
    agent_classes = {'DDPG': DDPGAgent, 'SAC': SACAgent}
    if algorithm not in agent_classes:
        raise ValueError(f"Unknown algorithm: {algorithm}")

    env = gym.make(env_name)
    try:
        state_dim = env.observation_space.shape[0]
        action_dim = env.action_space.shape[0]
        max_action = float(env.action_space.high[0])
        agent = agent_classes[algorithm](state_dim, action_dim, max_action)

        rewards_history = []
        for episode in range(num_episodes):
            state, _ = env.reset()
            episode_reward = 0.0
            done = False
            while not done:
                if algorithm == 'DDPG':
                    action = agent.select_action(state, noise=0.1)
                else:
                    action = agent.select_action(state)

                next_state, reward, terminated, truncated, _ = env.step(action)
                # Only a genuine terminal state should cut off bootstrapping.
                # A time-limit truncation must still bootstrap from s', so the
                # buffer stores `terminated`, not `terminated or truncated`.
                # (Pendulum-v1, for example, never terminates; its episodes
                # end only by the time limit.)
                agent.replay_buffer.push(state, action, reward, next_state, terminated)
                agent.update()

                episode_reward += reward
                state = next_state
                done = terminated or truncated

            rewards_history.append(episode_reward)
            if (episode + 1) % 10 == 0:
                # Mean return over the most recent (up to) 100 episodes.
                avg_reward = np.mean(rewards_history[-100:])
                print(f"Episode {episode + 1}, Avg Reward: {avg_reward:.2f}")
    finally:
        env.close()  # release the environment even if training raises
    return agent, rewards_history
# Main program: train DDPG and then SAC on Pendulum-v1, then plot both
# per-episode reward curves, save the figure and show it.
if __name__ == "__main__":
    trained = {}
    for banner, algo in (("Training DDPG...", 'DDPG'), ("\nTraining SAC...", 'SAC')):
        print(banner)
        trained[algo] = train_continuous_control('Pendulum-v1', algo, 500)

    # Compare the learning curves (matplotlib is only needed from here on).
    import matplotlib.pyplot as plt
    plt.figure(figsize=(10, 6))
    for algo, (_, episode_rewards) in trained.items():
        plt.plot(episode_rewards, label=algo, alpha=0.7)
    plt.xlabel('Episode')
    plt.ylabel('Reward')
    plt.title('DDPG vs SAC on Pendulum')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.savefig('continuous_control_comparison.png', dpi=150)
    plt.show()
## 5. 项目总结
### 学到的技能
- 连续动作空间处理
- DDPG和SAC算法
- 确定性策略与随机策略
- 目标网络软更新
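### 关键概念
- 确定性策略梯度(DDPG):Actor 通过最大化 Q(s, μ(s)) 来更新
- 重参数化技巧:a = max_action · tanh(μ + σ·ε),使采样得到的动作对网络参数可导
- 熵正则化(SAC):目标中的 -α·log π(a|s) 项鼓励探索
- 双Q网络:取两个Critic估计的较小值,缓解Q值过估计
- 软更新:θ' ← τθ + (1-τ)θ',让目标网络缓慢跟踪在线网络

恭喜完成连续控制项目!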
→ 下一个项目:项目5-多智能体协作.md