-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathrun_dqn.py
102 lines (94 loc) · 2.42 KB
/
run_dqn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
from game.game import Game
from evo.evo import EvoAlg
from os import path
import random
import time
import sys
from stable_baselines3 import DQN
from stable_baselines3.common.vec_env import DummyVecEnv
# Specification of the world-generation parameters handed to EvoAlg below.
# Each entry maps a parameter name to its 'dtype' and, for numeric
# parameters, 'min'/'max' bounds. How the bounds are interpreted is defined
# by EvoAlg, not here. Key order is kept stable on purpose.
gen_param_specs = {
    # Float-valued densities.
    'initial_rock_density': {'dtype': float, 'min': 0.2, 'max': 0.4},
    'initial_tree_density': {'dtype': float, 'min': 0.2, 'max': 0.4},
    # Integer-valued refinement / neighbourhood settings.
    'rock_refinement_runs': {'dtype': int, 'min': 1, 'max': 3},
    'tree_refinement_runs': {'dtype': int, 'min': 1, 'max': 3},
    'rock_neighbour_depth': {'dtype': int, 'min': 1, 'max': 2},
    'tree_neighbour_depth': {'dtype': int, 'min': 1, 'max': 2},
    'rock_neighbour_number': {'dtype': int, 'min': 4, 'max': 8},
    'tree_neighbour_number': {'dtype': int, 'min': 4, 'max': 8},
    # min and max coincide, so only the value 1 lies within the bounds.
    'base_clear_depth': {'dtype': int, 'min': 1, 'max': 1},
    # Boolean parameter: carries no bounds.
    'enemies_crush_trees': {'dtype': bool},
    'random_seed': {'dtype': int, 'min': 1, 'max': 9999},
    'flee_distance': {'dtype': int, 'min': 0, 'max': 10},
}
# Seed Python's global RNG with a fixed value.
random.seed(42)

# Build the environment: the evolutionary algorithm receives the parameter
# specs, the Game receives it as its evo_system, and the single Game instance
# is wrapped in a DummyVecEnv so stable-baselines3 can drive it.
ea = EvoAlg(gen_param_specs)
env_raw = Game(evo_system=ea)
env = DummyVecEnv([lambda: env_raw])

# Command line:
#   --train <name>   train (or continue training) saved_models/<name>.zip
#   --run <name>     load saved_models/<name>.zip and play forever
if len(sys.argv) == 3 and sys.argv[1] == '--train':
    model_name = sys.argv[2]
    model_path = 'saved_models/{}'.format(model_name)

    # Resume from an existing checkpoint when there is one.
    if path.exists('{}.zip'.format(model_path)):
        model = DQN.load(model_path, env=env)
        print('Loading existing model')
    else:
        print('Creating new model')
        model = DQN(
            'CnnPolicy',
            env,
            verbose=1,
            buffer_size=1000,
            learning_starts=1000,
            exploration_fraction=0.3,
        )

    try:
        model.learn(total_timesteps=100000, log_interval=4)
    except KeyboardInterrupt:
        # Ctrl+C is the expected way to stop a run early; fall through to the
        # save below and exit normally.
        print('Training interrupted, saving progress')
    finally:
        # Always persist the model and the world generator's log exactly once,
        # whether training finished, was interrupted, or failed. Unlike a bare
        # `except:`, any genuine error is re-raised after saving instead of
        # being silently discarded.
        model.save(model_path)
        env_raw.worldgen.save_log(model_name)
elif len(sys.argv) == 3 and sys.argv[1] == '--run':
    model = DQN.load('saved_models/{}'.format(sys.argv[2]), env=env)
    obs = env.reset()
    # Play indefinitely with stochastic action selection; stop with Ctrl+C.
    while True:
        action, _ = model.predict(obs, deterministic=False)
        obs, _reward, done, _info = env.step(action)
        if done:
            # NOTE(review): stable-baselines3 VecEnvs already reset a
            # sub-environment automatically when its episode ends, so this
            # explicit reset probably starts a second fresh episode. Kept to
            # preserve existing behaviour; confirm whether it is intended.
            obs = env.reset()
else:
    # Previously the script exited silently on bad arguments. No sys.exit()
    # here so that importing this module keeps working.
    print(
        'Usage: {} (--train | --run) <model_name>'.format(sys.argv[0]),
        file=sys.stderr,
    )