# agent.py
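# Actor-critic agent wrapper: holds the policy (Actor) and value (Critic)
# networks from model.py, their Adam optimizers, a linearly decaying
# learning-rate schedule, and checkpoint save/load helpers. The surrounding
# training loop (not shown in this file) is assumed to compute the actor and
# critic losses and pass them to optimize().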
from model import Actor, Critic
from torch.optim import Adam
from torch import from_numpy
import numpy as np
import torch
from torch.optim.lr_scheduler import LambdaLR


class Agent:
    def __init__(self, env_name, n_iter, n_states, action_bounds, n_actions, lr):
        self.env_name = env_name
        self.n_iter = n_iter
        self.action_bounds = action_bounds
        self.n_actions = n_actions
        self.n_states = n_states
        self.device = torch.device("cpu")
        self.lr = lr

        self.current_policy = Actor(n_states=self.n_states,
                                    n_actions=self.n_actions).to(self.device)
        self.critic = Critic(n_states=self.n_states).to(self.device)

        self.actor_optimizer = Adam(self.current_policy.parameters(), lr=self.lr, eps=1e-5)
        self.critic_optimizer = Adam(self.critic.parameters(), lr=self.lr, eps=1e-5)
        self.critic_loss = torch.nn.MSELoss()

        # Linear learning-rate decay from 1.0 down to 0.0 over n_iter iterations.
        self.scheduler = lambda step: max(1.0 - float(step / self.n_iter), 0)
        self.actor_scheduler = LambdaLR(self.actor_optimizer, lr_lambda=self.scheduler)
        self.critic_scheduler = LambdaLR(self.critic_optimizer, lr_lambda=self.scheduler)

    def choose_dist(self, state):
        state = np.expand_dims(state, 0)
        state = from_numpy(state).float().to(self.device)
        with torch.no_grad():
            dist = self.current_policy(state)
        # action *= self.action_bounds[1]
        # action = np.clip(action, self.action_bounds[0], self.action_bounds[1])
        return dist

    def get_value(self, state):
        state = np.expand_dims(state, 0)
        state = from_numpy(state).float().to(self.device)
        with torch.no_grad():
            value = self.critic(state)
        return value.detach().cpu().numpy()

    def optimize(self, actor_loss, critic_loss):
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        # torch.nn.utils.clip_grad_norm_(self.current_policy.parameters(), 0.5)
        # torch.nn.utils.clip_grad_norm_(self.critic.parameters(), 0.5)
        self.actor_optimizer.step()

        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        # torch.nn.utils.clip_grad_norm_(self.current_policy.parameters(), 0.5)
        # torch.nn.utils.clip_grad_norm_(self.critic.parameters(), 0.5)
        self.critic_optimizer.step()

    def schedule_lr(self):
        # self.total_scheduler.step()
        self.actor_scheduler.step()
        self.critic_scheduler.step()

    def save_weights(self, iteration, state_rms):
        torch.save({"current_policy_state_dict": self.current_policy.state_dict(),
                    "critic_state_dict": self.critic.state_dict(),
                    "actor_optimizer_state_dict": self.actor_optimizer.state_dict(),
                    "critic_optimizer_state_dict": self.critic_optimizer.state_dict(),
                    "actor_scheduler_state_dict": self.actor_scheduler.state_dict(),
                    "critic_scheduler_state_dict": self.critic_scheduler.state_dict(),
                    "iteration": iteration,
                    "state_rms_mean": state_rms.mean,
                    "state_rms_var": state_rms.var,
                    "state_rms_count": state_rms.count}, self.env_name + "_weights.pth")

    def load_weights(self):
        checkpoint = torch.load(self.env_name + "_weights.pth")
        self.current_policy.load_state_dict(checkpoint["current_policy_state_dict"])
        self.critic.load_state_dict(checkpoint["critic_state_dict"])
        self.actor_optimizer.load_state_dict(checkpoint["actor_optimizer_state_dict"])
        self.critic_optimizer.load_state_dict(checkpoint["critic_optimizer_state_dict"])
        self.actor_scheduler.load_state_dict(checkpoint["actor_scheduler_state_dict"])
        self.critic_scheduler.load_state_dict(checkpoint["critic_scheduler_state_dict"])
        iteration = checkpoint["iteration"]
        state_rms_mean = checkpoint["state_rms_mean"]
        state_rms_var = checkpoint["state_rms_var"]
        return iteration, state_rms_mean, state_rms_var

    def set_to_eval_mode(self):
        self.current_policy.eval()
        self.critic.eval()

    def set_to_train_mode(self):
        self.current_policy.train()
        self.critic.train()
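

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original class). It assumes that
# model.Actor returns a torch.distributions object (e.g. a Normal), so that
# .sample() is available on the result of choose_dist(), and it uses a random
# state vector instead of a real environment to stay self-contained. The
# env_name, dimensions, iteration count, and learning rate below are
# placeholder values, not values taken from the repository.
if __name__ == "__main__":
    n_states, n_actions = 8, 2
    agent = Agent(env_name="DemoEnv",
                  n_iter=1000,
                  n_states=n_states,
                  action_bounds=[-1.0, 1.0],
                  n_actions=n_actions,
                  lr=3e-4)

    # Fake observation standing in for an environment state.
    state = np.random.randn(n_states).astype(np.float32)

    # Sample an action from the current policy distribution.
    dist = agent.choose_dist(state)
    action = dist.sample().cpu().numpy()[0]

    # Query the critic for a state-value estimate.
    value = agent.get_value(state)
    print("sampled action:", action, "value estimate:", value)

    # Advance the linear learning-rate schedule by one iteration.
    agent.schedule_lr()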