algo.py
import torch
import torch.nn as nn
import torch.nn.functional as F


class PolicyNet(nn.Module):
    """Two-layer MLP that maps a state to a probability distribution over actions."""

    def __init__(self, state_dim, hidden_dim, action_dim):
        super().__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return F.softmax(self.fc2(x), dim=1)


class Reinforce:
    """Monte Carlo policy-gradient (REINFORCE) agent."""

    def __init__(self, state_dim, hidden_dim, action_dim, learning_rate, gamma, device):
        self.policy_net = PolicyNet(state_dim, hidden_dim, action_dim).to(device)
        self.optimizer = torch.optim.Adam(self.policy_net.parameters(), lr=learning_rate)
        self.gamma = gamma  # discount factor
        self.device = device

    def take_action(self, state):
        # Sample an action from the current policy (stochastic, used during training).
        state = torch.tensor([state], dtype=torch.float).to(self.device)
        probs = self.policy_net(state)
        action_dist = torch.distributions.Categorical(probs)
        action = action_dist.sample()
        return action.item()

    def take_max_action(self, state):
        # Pick the most probable action greedily (deterministic, used for evaluation).
        state = torch.tensor([state], dtype=torch.float).to(self.device)
        probs = self.policy_net(state)
        action = probs.argmax(dim=1)
        return action.item()

    def update(self, transition_dict):
        # transition_dict holds one complete episode: lists of states, actions and rewards.
        reward_list = transition_dict['rewards']
        state_list = transition_dict['states']
        action_list = transition_dict['actions']

        G = 0  # discounted return, accumulated from the end of the episode backwards
        self.optimizer.zero_grad()
        for i in reversed(range(len(reward_list))):  # iterate from the last step to the first
            reward = reward_list[i]
            state = torch.tensor([state_list[i]], dtype=torch.float).to(self.device)
            action = torch.tensor([action_list[i]]).view(-1, 1).to(self.device)
            log_prob = torch.log(self.policy_net(state)).gather(1, action)
            G = self.gamma * G + reward
            loss = -log_prob * G  # policy-gradient loss for this time step
            loss.backward()       # gradients accumulate over the whole episode
        self.optimizer.step()     # single optimizer step per episode
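
# The block below is a minimal usage sketch, not part of algo.py. It assumes the
# gymnasium package with the CartPole-v1 environment, and the hyperparameters
# (hidden_dim=128, learning_rate=1e-3, gamma=0.98, 500 episodes) are illustrative
# choices, not values taken from the original file.
#
# import gymnasium as gym
# import torch
#
# env = gym.make('CartPole-v1')
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# agent = Reinforce(state_dim=env.observation_space.shape[0],
#                   hidden_dim=128,
#                   action_dim=env.action_space.n,
#                   learning_rate=1e-3,
#                   gamma=0.98,
#                   device=device)
#
# for episode in range(500):
#     transition_dict = {'states': [], 'actions': [], 'rewards': []}
#     state, _ = env.reset()
#     done = False
#     while not done:
#         action = agent.take_action(state)
#         next_state, reward, terminated, truncated, _ = env.step(action)
#         transition_dict['states'].append(state)
#         transition_dict['actions'].append(action)
#         transition_dict['rewards'].append(reward)
#         state = next_state
#         done = terminated or truncated
#     agent.update(transition_dict)  # one REINFORCE update per finished episode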