-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
74 lines (65 loc) · 1.88 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import gym
import numpy as np
from pathlib import Path
import argparse
import os
import ray
import ray.tune as tune
from maze_env import MalmoMazeEnv
def create_env(config):
    """Factory for the custom MalmoMazeEnv gym environment.

    Passed to ``tune.register_env``; ray.tune invokes it with the
    ``env_config`` dict from the trial configuration.

    Expected keys in ``config``: ``mission_file`` (path to the mission
    XML), ``width``, ``height``, ``millisec_per_tick``.
    """
    mission_xml = Path(config["mission_file"]).read_text()
    return MalmoMazeEnv(
        xml=mission_xml,
        width=config["width"],
        height=config["height"],
        millisec_per_tick=config["millisec_per_tick"])
# For stopping a learner for successful training
def stop_check(trial_id, result, threshold=85):
    """Tune stopping criterion: end the trial once training has converged.

    ray.tune calls this as ``stop_check(trial_id, result)`` after each
    training iteration; the extra ``threshold`` keyword (default 85,
    the original hard-coded value) keeps that call signature working
    while letting other callers reuse the check with a different target.

    Parameters
    ----------
    trial_id : str
        Trial identifier supplied by ray.tune (unused; required by the
        stop-callable signature).
    result : dict
        Latest training result; must contain ``"episode_reward_mean"``.
    threshold : float, optional
        Mean episode reward at which the trial is considered solved.

    Returns
    -------
    bool
        True when the trial should stop.
    """
    return result["episode_reward_mean"] >= threshold
# Main: parse CLI arguments, register the custom Malmo maze environment,
# and train an RLlib DQN agent with ray.tune until stop_check fires.
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # Positional: path to the Malmo mission XML consumed by create_env.
    parser.add_argument("mission_path",
        help="full path to the mission file lava_maze_malmo.xml",
        type=str)
    parser.add_argument("--num_gpus",
        type=int,
        required=False,
        default=0,
        help="number of gpus")
    args = parser.parse_args()
    # Make create_env available to tune under the name used in config["env"].
    tune.register_env("testenv01", create_env)
    ray.init()
    tune.run(
        run_or_experiment="DQN",
        config={
            "log_level": "WARN",
            "env": "testenv01",
            # Forwarded verbatim to create_env(config).
            "env_config": {
                "mission_file": args.mission_path,
                "width": 84,
                "height": 84,
                "millisec_per_tick": 20
            },
            "framework": "tf",
            "num_gpus": args.num_gpus,
            "num_workers": 1,
            # DQN variants enabled: double Q-learning and dueling network.
            "double_q": True,
            "dueling": True,
            "explore": True,
            # Epsilon-greedy exploration annealed 1.0 -> 0.02 over 500k steps.
            "exploration_config": {
                "type": "EpsilonGreedy",
                "initial_epsilon": 1.0,
                "final_epsilon": 0.02,
                "epsilon_timesteps": 500000
            }
        },
        stop=stop_check,  # stops when episode_reward_mean >= 85 (see stop_check)
        checkpoint_freq=1,  # checkpoint after every training iteration
        checkpoint_at_end=True,
        local_dir='./logs'  # root directory for tune results and checkpoints
    )
    print('training has done !')
    ray.shutdown()