# -*- coding: utf-8 -*-
"""Q-Learning.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1LkCoqnDqCGpixrhh7Q018PCtJQPgj2Qm
"""
import gymnasium as gym
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
from collections import deque
# Hyperparameters
LR = 0.001
GAMMA = 0.99  # discount factor on rewards, so future rewards are worth less
EPSILON_START = 1.0  # initial randomness factor to encourage exploration
# Final epsilon value; we want this to be low at the end, since by then we
# want to exploit what the network has learned
EPSILON_END = 0.01
EPSILON_DECAY = 200
MEMORY_SIZE = 10000
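# The epsilon used in take_action() decays exponentially with the step count:
#   epsilon = EPSILON_END + (EPSILON_START - EPSILON_END) * exp(-iters / EPSILON_DECAY)
# e.g. roughly 1.00 at iters=0, 0.37 at iters=200, and 0.02 at iters=1000.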
class RL(nn.Module):
    def __init__(self, input_dim, output_dim) -> None:
        super(RL, self).__init__()
        self.network = nn.Sequential(
            nn.Linear(input_dim, 64),
            nn.ReLU(),
            nn.Linear(64, output_dim),
        )

    def forward(self, x):
        return self.network(x)
def take_action(state, iters, network):
    # Epsilon-greedy: sample against a decaying epsilon
    choice = random.random()
    probability = EPSILON_END + \
        (EPSILON_START - EPSILON_END) * np.exp(-1. * iters / EPSILON_DECAY)
    if choice > probability:
        # Exploit: pass in the state, take the index of the max Q-value,
        # then convert it to a Python int
        action = network(state).argmax().item()
    else:
        # Explore: pick a random action from the action space
        action = env.action_space.sample()
    return action
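# train() performs one Q-learning update on a random minibatch from the
# replay memory, regressing Q(s, a) toward the TD target
#   r + GAMMA * max_a' Q(s', a')
# with the bootstrap term zeroed out on terminal transitions.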
def train(network, memory, optimizer):
    if len(memory) < 32:
        return
    # Sample a minibatch of transitions and unpack it
    sample = random.sample(memory, 32)
    state, action, reward, next_state, done = zip(*sample)
    state = torch.stack(state)
    action = torch.stack(action)
    reward = torch.stack(reward)
    next_state = torch.stack(next_state)
    done = torch.stack(done)
    # Q-value the network currently assigns to the action that was taken
    current_q = network(state).gather(1, action.unsqueeze(1).long())
    # Max Q-value for the next state; the target must not backpropagate
    with torch.no_grad():
        max_next_q = network(next_state).max(1)[0]
    # Expected (TD target) Q-value; terminal transitions do not bootstrap
    expected_q = reward + GAMMA * max_next_q * (1 - done)
    # Regress the current Q-values toward the targets
    loss = criterion(current_q, expected_q.unsqueeze(1))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
env = gym.make("CartPole-v1")
n_states = env.observation_space.shape[0]
n_actions = env.action_space.n
model = RL(n_states, n_actions)
optimizer = optim.Adam(model.parameters(), lr=LR)
criterion = nn.MSELoss()
reward_sum = 0.0
steps_done = 0  # global step count, so epsilon keeps decaying across episodes
# Now loop
num_episodes = 500
memory = deque(maxlen=MEMORY_SIZE)
for episode in range(num_episodes):
    state = torch.tensor(env.reset()[0], dtype=torch.float32)
    total_reward = 0
    for t in range(200):
        action = take_action(state, steps_done, model)
        observation, reward, terminated, truncated, info = env.step(action)
        next_state = torch.tensor(observation, dtype=torch.float32)
        if terminated:
            # Penalize dropping the pole so the failure state looks bad
            reward = -20
        memory.append((state,
                       torch.tensor(action, dtype=torch.float32),
                       torch.tensor(reward, dtype=torch.float32),
                       next_state,
                       torch.tensor(float(terminated), dtype=torch.float32)))
        state = next_state
        total_reward += reward
        steps_done += 1
        # If the episode is over, reset
        if terminated or truncated:
            break
        train(model, memory, optimizer)
    reward_sum += total_reward
    print(
        f"Episode {episode + 1}, Total Reward: {total_reward}, "
        f"Avg Reward: {reward_sum / (episode + 1)}")
env.close()
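
# --- Usage sketch (not part of the original script) ---
# A minimal greedy-evaluation loop for the trained model; the episode count
# and step cap below are illustrative choices, not values taken from the
# training code above.
eval_env = gym.make("CartPole-v1")
for episode in range(5):
    state = torch.tensor(eval_env.reset()[0], dtype=torch.float32)
    total_reward = 0.0
    for t in range(500):
        with torch.no_grad():
            action = model(state).argmax().item()  # always exploit
        observation, reward, terminated, truncated, info = eval_env.step(action)
        state = torch.tensor(observation, dtype=torch.float32)
        total_reward += reward
        if terminated or truncated:
            break
    print(f"Eval episode {episode + 1}: reward {total_reward}")
eval_env.close()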