play.py
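"""Play Pong-v0 with a trained Deep Q-Network, acting greedily at every step."""
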
import gym
import numpy as np
import torch
from torch.autograd import Variable

from models import History
from nets import DeepQNetwork
from processing import phi_map

def to_variable(arr):
    # Wrap a numpy array in a float Variable for the network.
    v = Variable(torch.from_numpy(arr).float())
    return v

def greedy_action(Q, phi):
    # Epsilon is fixed to 0 at play time, so the policy is fully greedy.
    epsilon = 0.0
    # Obtain a random value in range [0, 1)
    rand = np.random.uniform()
    phi = to_variable(phi)
    # With probability epsilon select a random action a_t
    if rand < epsilon:
        return env.action_space.sample()
    # Otherwise select the action that maximises Q(phi),
    # i.e. a_t = argmax_a Q(phi, a)
    return int(Q(phi).max(1)[1].data[0])

def initial_history(env):
    # Fill the frame history with copies of the first observation
    # (assumes the classic gym API where reset() returns only the observation).
    s = env.reset()
    H = History()
    for _ in range(H.length):
        H.add(s)
    return H

# ----------------------------------
# Play
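# Run the trained agent greedily and render each frame.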
env = gym.make('Pong-v0')
H = initial_history(env)
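# Pong-v0 has 6 discrete actions, so the network outputs 6 Q-values.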
Q = DeepQNetwork(6)
Q.load_state_dict(torch.load('data/models/episode_360.txt',
                             map_location=lambda storage, loc: storage))

while True:
    env.render(mode='human')
    # Build the network input phi from the most recent frames.
    phi = phi_map(H.get())
    action = greedy_action(Q, phi)
    image, reward, done, _ = env.step(action)
    H.add(image)
    if done:
        # Start a new episode and seed the history with its first frame.
        H.add(env.reset())