Commit

update ch09

peterwu4084 committed Apr 5, 2023
1 parent 6ba2abc commit 1cd8384
Showing 6 changed files with 158 additions and 0 deletions.
Binary file added ch09/Reinforce_on_CartPole-v0.png
Binary file added ch09/Reinforce_on_CartPole-v0_MovingAverage.png
52 changes: 52 additions & 0 deletions ch09/algo.py
@@ -0,0 +1,52 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class PolicyNet(nn.Module):
    """Two-layer MLP mapping a state to a probability distribution over actions."""

    def __init__(self, state_dim, hidden_dim, action_dim):
        super().__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return F.softmax(self.fc2(x), dim=1)


class Reinforce:
    def __init__(self, state_dim, hidden_dim, action_dim, learning_rate, gamma, device):
        self.policy_net = PolicyNet(state_dim, hidden_dim, action_dim).to(device)
        self.optimizer = torch.optim.Adam(self.policy_net.parameters(), lr=learning_rate)
        self.gamma = gamma
        self.device = device

    def take_action(self, state):
        # Sample an action from the current policy (used during training).
        state = torch.tensor([state], dtype=torch.float).to(self.device)
        probs = self.policy_net(state)
        action_dist = torch.distributions.Categorical(probs)
        action = action_dist.sample()
        return action.item()

    def take_max_action(self, state):
        # Pick the most probable action greedily (used for evaluation).
        state = torch.tensor([state], dtype=torch.float).to(self.device)
        probs = self.policy_net(state)
        action = probs.argmax(1)
        return action.item()

    def update(self, transition_dict):
        # Monte Carlo policy-gradient update over one complete episode.
        reward_list = transition_dict['rewards']
        state_list = transition_dict['states']
        action_list = transition_dict['actions']

        G = 0
        self.optimizer.zero_grad()
        for i in reversed(range(len(reward_list))):
            # Walk the episode backwards to build the discounted return G_t.
            reward = reward_list[i]
            state = torch.tensor([state_list[i]], dtype=torch.float).to(self.device)
            action = torch.tensor([action_list[i]]).view(-1, 1).to(self.device)
            log_prob = torch.log(self.policy_net(state)).gather(1, action)
            G = self.gamma * G + reward
            loss = -log_prob * G
            loss.backward()  # gradients accumulate across time steps
        self.optimizer.step()
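
Note: the update method above accumulates, step by step, the gradient of the standard REINFORCE (Monte Carlo policy-gradient) loss, with the discounted return built recursively from the end of the episode:

    G_t = r_t + \gamma G_{t+1}, \qquad
    L(\theta) = -\sum_{t=0}^{T-1} G_t \log \pi_\theta(a_t \mid s_t)

Calling loss.backward() once per time step and optimizer.step() once per episode is equivalent to minimizing this summed loss.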
31 changes: 31 additions & 0 deletions ch09/display.py
@@ -0,0 +1,31 @@
import gym
import torch
from algo import Reinforce
from time import sleep


lr = 0  # the learning rate is unused here; the policy is only evaluated
hidden_dim = 128
gamma = 0.98
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

env = gym.make('CartPole-v0')

state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
agent = Reinforce(state_dim, hidden_dim, action_dim, lr, gamma, device)
state_dict = torch.load('reinforce_cartpolev0.pth')
agent.policy_net.load_state_dict(state_dict)

# Roll out one episode greedily with the trained policy and render it
# (uses the classic Gym API: reset() returns the observation, step() a 4-tuple).
state = env.reset()
done = False
agent_return = 0
while not done:
    action = agent.take_max_action(state)
    next_state, reward, done, _ = env.step(action)
    agent_return += reward
    env.render()
    state = next_state
    sleep(0.01)

print('Agent return:', agent_return)
75 changes: 75 additions & 0 deletions ch09/main.py
@@ -0,0 +1,75 @@
import gym
import torch
import numpy as np
import matplotlib.pyplot as plt
import sys
sys.path.append('../')

from rl_utils import *
from tqdm import tqdm
from algo import *


learning_rate = 1e-3
num_episodes = 1000
hidden_dim = 128
gamma = 0.98
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

env_name = 'CartPole-v0'
env = gym.make(env_name)
env.seed(0)
torch.manual_seed(0)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
agent = Reinforce(state_dim, hidden_dim, action_dim, learning_rate, gamma, device)

# Train for num_episodes episodes, split into 10 progress-bar iterations.
return_list = []
for i in range(10):
    with tqdm(total=num_episodes // 10, desc='Iteration %d' % i) as pbar:
        for i_episode in range(num_episodes // 10):
            episode_return = 0
            transition_dict = {
                'states': [],
                'actions': [],
                'next_states': [],
                'rewards': [],
                'dones': []
            }

            # Collect one complete episode before each update
            # (REINFORCE is an on-policy Monte Carlo method).
            state = env.reset()
            done = False
            while not done:
                action = agent.take_action(state)
                next_state, reward, done, _ = env.step(action)
                transition_dict['states'].append(state)
                transition_dict['actions'].append(action)
                transition_dict['next_states'].append(next_state)
                transition_dict['rewards'].append(reward)
                transition_dict['dones'].append(done)
                state = next_state
                episode_return += reward
            return_list.append(episode_return)
            agent.update(transition_dict)

            if (i_episode + 1) % 10 == 0:
                pbar.set_postfix({
                    'episode': '%d' % (num_episodes / 10 * i + i_episode + 1),
                    'return': '%.3f' % np.mean(return_list[-10:])
                })
            pbar.update(1)

torch.save(agent.policy_net.state_dict(), 'reinforce_cartpolev0.pth')

episodes_list = list(range(len(return_list)))
plt.plot(episodes_list, return_list)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('Reinforce on {}'.format(env_name))
plt.show()

# Smoothed learning curve (moving_average is provided by rl_utils).
mv_return = moving_average(return_list, 9)
plt.plot(episodes_list, mv_return)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('Reinforce on {}'.format(env_name))
plt.show()
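
Note: moving_average comes from the shared rl_utils module imported at the top of main.py, which is not part of this commit. A minimal sketch of a compatible helper (it may differ from the actual rl_utils implementation) could look like:

import numpy as np

def moving_average(a, window_size):
    # Smooth a 1-D sequence; the ends use progressively shorter windows
    # so the output has the same length as the input (window_size must be odd).
    cumulative_sum = np.cumsum(np.insert(a, 0, 0))
    middle = (cumulative_sum[window_size:] - cumulative_sum[:-window_size]) / window_size
    r = np.arange(1, window_size - 1, 2)
    begin = np.cumsum(a[:window_size - 1])[::2] / r
    end = (np.cumsum(a[:-window_size:-1])[::2] / r)[::-1]
    return np.concatenate((begin, middle, end))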
Binary file added ch09/reinforce_cartpolev0.pth
Binary file not shown.
