-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathAleAgent.py
149 lines (124 loc) · 5.12 KB
/
AleAgent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import sys
import pygame
import numpy as np
from ale_python_interface import ALEInterface
from Autoencoder.Encoder import Encoder
from NFQ.NFQ import NFQ
##
# Main class of the program: sets up the ALE emulator (with the visual UI of
# the game), the state encoder and the NFQ learner, and drives the training
# and playing phases.
class AleAgent:
    ##
    # Create the emulator, the encoder, the frame processor and the NFQ net.
    #
    # @param processing_cls Class for processing game visual input; it is
    #        instantiated without arguments and its process() method is
    #        applied to every grayscale frame.
    # @param game_rom Path of the ROM file handed to ALE's loadROM (required).
    # @param encoder_model Path to a stored encoder model; only used when
    #        encoder_weights is given as well.
    # @param encoder_weights Path to stored encoder weights; only used when
    #        encoder_model is given as well.
    # @param NFQ_model Path to a stored NFQ model; only used when NFQ_weights
    #        is given as well.
    # @param NFQ_weights Path to stored NFQ weights; only used when NFQ_model
    #        is given as well.
    def __init__(self, processing_cls, game_rom=None, encoder_model=None,
                 encoder_weights=None, NFQ_model=None, NFQ_weights=None):
        assert game_rom is not None, 'game_rom is required'
        self.game = ALEInterface()

        # A stored encoder is loaded only when both of its files are given;
        # otherwise a default-constructed encoder is used.
        if encoder_weights is not None and encoder_model is not None:
            self.encoder = Encoder(path_to_model=encoder_model,
                                   path_to_weights=encoder_weights)
        else:
            self.encoder = Encoder()

        self.processor = processing_cls()

        # Get & Set the desired settings
        self.game.setInt('random_seed', 0)
        self.game.setInt('frame_skip', 4)

        # Set USE_SDL to true to display the screen. ALE must be compiled
        # with SDL enabled for this to work. On OSX, pygame init is used to
        # proxy-call SDL_main.
        USE_SDL = True
        if USE_SDL:
            if sys.platform == 'darwin':
                pygame.init()
                self.game.setBool('sound', False)  # Sound doesn't work on OSX
            elif sys.platform.startswith('linux'):
                self.game.setBool('sound', False)  # no sound
            self.game.setBool('display_screen', True)

        # Load the ROM file
        self.game.loadROM(game_rom)

        # Get the list of legal actions
        self.legal_actions = self.game.getLegalActionSet()

        # Get actions applicable in current game
        self.minimal_actions = self.game.getMinimalActionSet()

        # Same rule as for the encoder: a stored net needs both files.
        if NFQ_model is not None and NFQ_weights is not None:
            self.NFQ = NFQ(
                self.encoder.out_dim,
                len(self.minimal_actions),
                model_path=NFQ_model,
                weights_path=NFQ_weights
            )
        else:
            self.NFQ = NFQ(self.encoder.out_dim, len(self.minimal_actions))

        # Reusable frame buffer that is handed to getScreenGrayscale().
        (self.screen_width, self.screen_height) = self.game.getScreenDims()
        self.screen_data = np.zeros(
            (self.screen_height, self.screen_width),
            dtype=np.uint8
        )

    ##
    # Grab the current screen and turn it into an encoded state.
    #
    # The frame is read into self.screen_data, passed through the processor
    # and then through the encoder.
    #
    # @return The encoder output for the current frame.
    def _observe_state(self):
        self.game.getScreenGrayscale(self.screen_data)
        pooled_data = self.processor.process(self.screen_data)
        return self.encoder.encode(pooled_data)

    ##
    # Pick the index (into self.minimal_actions) of the next action.
    #
    # A key binding, when given, has priority; if it yields None the choice
    # falls back to epsilon-greedy: a uniformly random index with probability
    # eps, otherwise the action predicted by the NFQ net.
    #
    # @param state Encoded state passed to the NFQ net.
    # @param eps Probability of taking a random action.
    # @param key_binding Optional callable mapping pygame's pressed-keys
    #        state to an action index or None.
    # @return Index of the chosen action.
    def _choose_action(self, state, eps, key_binding=None):
        if key_binding:
            x = key_binding(pygame.key.get_pressed())
            if x is not None:
                return x
        if np.random.rand() < eps:
            return np.random.randint(self.minimal_actions.size)
        return self.NFQ.predict_action(state)

    ##
    # Initialize the reinforcement learning
    #
    # Plays num_of_episodes games. Every step is stored in the NFQ net as the
    # flat transition (state, action index, next state, reward); after each
    # episode the game is reset, the net is trained and then saved.
    #
    # @param num_of_episodes Number of games to play.
    # @param eps Initial exploration rate (probability of a random action).
    # @param key_binding Optional callable mapping pygame's pressed-keys
    #        state to an action index or None; when given, eps is pinned to
    #        0.05 for every episode.
    def train(self, num_of_episodes=1500, eps=0.995, key_binding=None):
        pygame.init()
        for episode in range(num_of_episodes):
            total_reward = 0
            hits = 0
            print('Starting episode:  %s' % (episode + 1))

            if key_binding:
                eps = 0.05
            else:
                # Float division: under Python 2 the former 2/num_of_episodes
                # truncated to 0, so this decay never happened. The floor at
                # 0.0 keeps eps a valid probability.
                eps = max(eps - 2.0 / num_of_episodes, 0.0)

            next_state = self._observe_state()
            while not self.game.game_over():
                current_state = next_state
                x = self._choose_action(current_state, eps, key_binding)

                # Apply an action and get the resulting reward
                reward = self.game.act(self.minimal_actions[x])
                next_state = self._observe_state()

                transition = np.append(current_state, x)
                transition = np.append(transition, next_state)
                transition = np.append(transition, reward)
                self.NFQ.add_transition(transition)

                total_reward += reward
                if reward > 0:
                    hits += 1

                # Additional per-step decay, active only above 0.1.
                if eps > 0.1:
                    eps -= 0.00001
            # end while

            print('Epsilon:  %s' % (eps,))
            print('Episode %s ended with score: %s' % (episode + 1, total_reward))
            print('Hits:  %s' % (hits,))

            self.game.reset_game()
            self.NFQ.train()
            self.NFQ.save_net()
        # end for

    ##
    # Play the game!
    #
    # Follows the action predicted by the NFQ net until the game is over,
    # then prints the accumulated score and the number of actions taken.
    def play(self):
        total_reward = 0
        # Counted from 0 so the report equals the number of actions applied.
        moves = 0
        while not self.game.game_over():
            current_state = self._observe_state()
            x = self.NFQ.predict_action(current_state)
            total_reward += self.game.act(self.minimal_actions[x])
            moves += 1
        print('The game ended with score: %s  after:  %s  moves' % (total_reward, moves))