-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
102 lines (85 loc) · 3.98 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import os
import multiprocessing

# Silence TensorFlow's C++ log output. The variable has to be in the
# environment before TensorFlow is first imported, so it is set ahead of every
# project import rather than part-way through them.
# NOTE(review): assumes at least one of the src.* modules below imports
# TensorFlow -- confirm; if none did before this point, this move is a no-op.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

from src.algorithms.MultiModelPPO3 import MultiModelPPO3
from src.algorithms.MultiModelPPO44 import MultiModelPPO4
from src.games.driving import DrivingGame
from src.networks.ppo4_networks import PPO4ActorNetwork
from src.util.configs import DataclassSaveMixin
from src.algorithms.MultiModelDDQN import MultiModelDDQN
from src.networks.simple_network import SimpleActorNetwork, SimpleCriticNetwork, SimpleDDQNNetwork, SimpleSwitchNetwork
from src.user_interface.ui import GameUI
from src.games.space_invaders import SpaceInvaders
from src.games.space_invaders_large import SpaceInvadersLarge
from src.algorithms.PPO import PPO
from src.algorithms.DDQN import DDQN
from src.algorithms.MultiModelPPO import MultiModelPPO
if __name__ == '__main__':
    # Script entry point. Exactly one experiment is active at a time (currently
    # the MultiModelPPO4 training run); the commented-out sections are
    # alternative train/play recipes kept for reference.
    #
    # Everything must stay under this guard: with the 'spawn' start method the
    # main module is imported again in each child process, and unguarded code
    # would run there too.
    multiprocessing.set_start_method('spawn')

    ########### PPO Play
    # ppo = PPO()
    # ppo.load('models/ppo_8')
    # ui = GameUI(SpaceInvaders, model = ppo, record=False)
    # # ui = GameUI(SpaceInvaders, record=False)
    # ui.run()
    # # # ui.playback()

    ############ PPO Train
    # print('Starting Training')
    # ppo = PPO()
    # print('Training PPO')
    # print('######################')
    # ppo.train(SpaceInvaders, SimpleActorNetwork, SimpleCriticNetwork,
    # save_location = f'{os.environ["PBS_O_WORKDIR"]}/models/ppo_large_simple',
    # stats_location= f'{os.environ["PBS_O_WORKDIR"]}/models/ppo_large_simple_stats')
    # ppo.train(SpaceInvaders, SimpleActorNetwork,SimpleCriticNetwork, save_location = 'ppo_3')

    ############ DDQN Train
    # ddqn = DDQN()
    # ddqn.train(SpaceInvaders, SimpleDDQNNetwork, save_location = 'ddqn_1')

    ############ DDQN Run
    # ddqn = DDQN()
    # ddqn.load('ddqn_1')
    # ui = GameUI(SpaceInvaders, model = ddqn, record=False)
    # ui.run()
    # # ui.playback()

    ########### MultiModelPPO Train
    # print('Starting Training')
    # MultiModelPPO3.Trainer.Configs.GAMMA = 0.999
    # MultiModelPPO3.train(DrivingGame, SimpleActorNetwork, SimpleCriticNetwork,SimpleSwitchNetwork,
    # save_location = f'{os.environ["PBS_O_WORKDIR"]}/models/driving_multi_model',
    # stats_location= f'{os.environ["PBS_O_WORKDIR"]}/models/driving_multi_model_stats')

    class Configs(MultiModelPPO4.Trainer.Configs):
        """Training parameter overrides for this MultiModelPPO4 run."""
        NUM_WORKERS: int = 1
        # this will be multiplied by 1000
        # NOTE(review): the note above does not say which field it applies to;
        # presumably TOTAL_TIME_STEPS -- confirm against the Trainer.
        TOTAL_TIME_STEPS: int = 600000
        OBSERVATIONS_PER_BATCH: int = 10000
        NUM_AGENTS: int = 2
        UPDATES_PER_ITERATION: int = 1
        GAMMA: float = 0.995
        CLIP: float = 0.2
        REPEAT_ACTION_NUM: int = 3
        GAMES_TO_TRAIN_OVER: int = 1000
        # Annotated like the fields above so that all overrides are declared
        # the same way (matters if the base Configs is handled as a dataclass
        # -- NOTE(review): confirm).
        ENTROPY_SCALAR: float = 0.001
        VOTE_ENTROPY_SCALAR: float = 0.0005

    # Root directory for saved models and stats. PBS_O_WORKDIR is expected to
    # be provided by the PBS job environment; fall back to the current working
    # directory so the script also runs outside a PBS job instead of raising
    # KeyError.
    work_dir = os.environ.get('PBS_O_WORKDIR', os.getcwd())

    # Install the overrides by rebinding the Trainer's Configs attribute.
    # NOTE(review): this rebinding only happens in the parent process; code
    # under this guard is not executed in spawned children, so workers that
    # look up Trainer.Configs themselves would see the defaults -- confirm how
    # the Trainer hands its configuration to workers.
    MultiModelPPO4.Trainer.Configs = Configs
    MultiModelPPO4.train(
        SpaceInvadersLarge, PPO4ActorNetwork, SimpleCriticNetwork,
        save_location=f'{work_dir}/models/ppo44_with_all_votes_dual_gpu',
        stats_location=f'{work_dir}/models/ppo44_with_all_votes_dual_gpu_stats')

    # # # ########### MultiModelPPO Play
    # multi_model_ppo = MultiModelPPO()
    # multi_model_ppo.load('multi_model_ppo_1')
    # # ui = GameUI(SpaceInvadersLarge, model = multi_model_ppo, record=False)
    # ui = GameUI(SpaceInvaders, model = multi_model_ppo, record=False)
    # ui.run()
    # # # ui.playback()

    ############## MultiModelDDQN Train
    # MultiModelDDQN.train(SpaceInvadersLarge, SimpleDDQNNetwork, SimpleCriticNetwork,
    # save_location = f'{os.environ["PBS_O_WORKDIR"]}/models/multi_ddqn_large_single_agent',
    # stats_location = f'{os.environ["PBS_O_WORKDIR"]}/models/multi_ddqn_large_single_agent_stats'
    # )
    # MultiModelDDQN.train(SpaceInvaders, SimpleDDQNNetwork, SimpleCriticNetwork,
    # save_location = 'models/multi_ddqn_simple')

    ############## MultiModelDDQN Play
    # ui = GameUI(SpaceInvaders, model = MultiModelDDQN('models/multi_ddqn_simple'), record=False)
    # ui.run()
    # # # ui.playback()