diff --git a/ch09/Reinforce_on_CartPole-v0.png b/ch09/Reinforce_on_CartPole-v0.png
new file mode 100644
index 0000000..2818c60
Binary files /dev/null and b/ch09/Reinforce_on_CartPole-v0.png differ
diff --git a/ch09/Reinforce_on_CartPole-v0_MovingAverage.png b/ch09/Reinforce_on_CartPole-v0_MovingAverage.png
new file mode 100644
index 0000000..7816237
Binary files /dev/null and b/ch09/Reinforce_on_CartPole-v0_MovingAverage.png differ
diff --git a/ch09/algo.py b/ch09/algo.py
new file mode 100644
index 0000000..86458b8
--- /dev/null
+++ b/ch09/algo.py
@@ -0,0 +1,52 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class PolicyNet(nn.Module):
+    def __init__(self, state_dim, hidden_dim, action_dim):
+        super().__init__()
+        self.fc1 = nn.Linear(state_dim, hidden_dim)
+        self.fc2 = nn.Linear(hidden_dim, action_dim)
+
+    def forward(self, x):
+        x = F.relu(self.fc1(x))
+        return F.softmax(self.fc2(x), dim=1)
+
+
+class Reinforce:
+    def __init__(self, state_dim, hidden_dim, action_dim, learning_rate, gamma, device):
+        self.policy_net = PolicyNet(state_dim, hidden_dim, action_dim).to(device)
+        self.optimizer = torch.optim.Adam(self.policy_net.parameters(), lr=learning_rate)
+        self.gamma = gamma
+        self.device = device
+
+    def take_action(self, state):
+        state = torch.tensor([state], dtype=torch.float).to(self.device)
+        probs = self.policy_net(state)
+        action_dist = torch.distributions.Categorical(probs)
+        action = action_dist.sample()
+        return action.item()
+
+    def take_max_action(self, state):
+        state = torch.tensor([state], dtype=torch.float).to(self.device)
+        probs = self.policy_net(state)
+        action = probs.argmax(1)
+        return action.item()
+
+    def update(self, transition_dict):
+        reward_list = transition_dict['rewards']
+        state_list = transition_dict['states']
+        action_list = transition_dict['actions']
+
+        G = 0
+        self.optimizer.zero_grad()
+        for i in reversed(range(len(reward_list))):
+            reward = reward_list[i]
+            state = torch.tensor([state_list[i]], dtype=torch.float).to(self.device)
+            action = torch.tensor([action_list[i]]).view(-1, 1).to(self.device)
+            log_prob = torch.log(self.policy_net(state)).gather(1, action)
+            G = self.gamma * G + reward
+            loss = - log_prob * G
+            loss.backward()
+        self.optimizer.step()
\ No newline at end of file
diff --git a/ch09/display.py b/ch09/display.py
new file mode 100644
index 0000000..676fc7f
--- /dev/null
+++ b/ch09/display.py
@@ -0,0 +1,31 @@
+import gym
+import torch
+from algo import Reinforce
+from time import sleep
+
+
+lr = 0
+hidden_dim = 128
+gamma = 0.98
+device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+env = gym.make('CartPole-v0')
+
+state_dim = env.observation_space.shape[0]
+action_dim = env.action_space.n
+agent = Reinforce(state_dim, hidden_dim, action_dim, lr, gamma, device)
+state_dict = torch.load('reinforce_cartpolev0.pth')
+agent.policy_net.load_state_dict(state_dict)
+
+state = env.reset()
+done = False
+agent_return = 0
+while not done:
+    action = agent.take_max_action(state)
+    next_state, reward, done, _ = env.step(action)
+    agent_return += reward
+    env.render()
+    state = next_state
+    sleep(0.01)
+
+print('Agent return:', agent_return)
\ No newline at end of file
diff --git a/ch09/main.py b/ch09/main.py
new file mode 100644
index 0000000..0afaba7
--- /dev/null
+++ b/ch09/main.py
@@ -0,0 +1,75 @@
+import gym
+import torch
+import numpy as np
+import matplotlib.pyplot as plt
+import sys
+sys.path.append('../')
+
+from rl_utils import *
+from tqdm import tqdm
+from algo import *
+
+
+learning_rate = 1e-3
+num_episodes = 1000
+hidden_dim = 128
+gamma = 0.98
+device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+env_name = 'CartPole-v0'
+env = gym.make(env_name)
+env.seed(0)
+torch.manual_seed(0)
+state_dim = env.observation_space.shape[0]
+action_dim = env.action_space.n
+agent = Reinforce(state_dim, hidden_dim, action_dim, learning_rate, gamma, device)
+
+return_list = []
+for i in range(10):
+    with tqdm(total=num_episodes // 10, desc='Iteration %d' % i) as pbar:
+        for i_episode in range(num_episodes // 10):
+            episode_return = 0
+            transition_dict = {
+                'states': [],
+                'actions': [],
+                'next_states': [],
+                'rewards': [],
+                'dones': []
+            }
+
+            state = env.reset()
+            done = False
+            while not done:
+                action = agent.take_action(state)
+                next_state, reward, done, _ = env.step(action)
+                transition_dict['states'].append(state)
+                transition_dict['actions'].append(action)
+                transition_dict['next_states'].append(next_state)
+                transition_dict['rewards'].append(reward)
+                transition_dict['dones'].append(done)
+                state = next_state
+                episode_return += reward
+            return_list.append(episode_return)
+            agent.update(transition_dict)
+
+            if (i_episode + 1) % 10 == 0:
+                pbar.set_postfix({
+                    'episode': '%d' % (num_episodes / 10 * i + i_episode + 1),
+                    'return': '%.3f' % np.mean(return_list[-10:])
+                })
+            pbar.update(1)
+torch.save(agent.policy_net.state_dict(), 'reinforce_cartpolev0.pth')
+
+episodes_list = list(range(len(return_list)))
+plt.plot(episodes_list, return_list)
+plt.xlabel('Episodes')
+plt.ylabel('Returns')
+plt.title('Reinforce on {}'.format(env_name))
+plt.show()
+
+mv_return = moving_average(return_list, 9)
+plt.plot(episodes_list, mv_return)
+plt.xlabel('Episodes')
+plt.ylabel('Returns')
+plt.title('Reinforce on {}'.format(env_name))
+plt.show()
\ No newline at end of file
diff --git a/ch09/reinforce_cartpolev0.pth b/ch09/reinforce_cartpolev0.pth
new file mode 100644
index 0000000..14cfb6d
Binary files /dev/null and b/ch09/reinforce_cartpolev0.pth differ