train_hier.py
"""Alternative RLLib model based on local training
You can visualize experiment results in ~/ray_results using TensorBoard.
"""
import gym
from gym.spaces import Discrete, Box
import numpy as np
import os
import random
import inspect
# Ray imports
import ray
from ray import tune
from ray.tune import grid_search
from ray.tune.schedulers import ASHAScheduler # https://openreview.net/forum?id=S1Y7OOlRZ algo for early stopping
from ray.rllib.env.env_context import EnvContext
from ray.rllib.models import ModelCatalog
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.models.tf.fcnet import FullyConnectedNetwork
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC
from ray.rllib.agents.trainer import Trainer
from ray.rllib.agents.ppo.ppo import DEFAULT_CONFIG
import ray.rllib.agents.ppo as ppo
from ray.rllib.utils.framework import try_import_tf, try_import_torch
# CybORG imports
from agents.hierachy_agents.hier_env import HierEnv
from typing import Any
import torch
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
tf1, tf, tfv = try_import_tf()

class CustomModel(TFModelV2):
    """Example of a keras custom model that just delegates to an fc-net."""

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        super(CustomModel, self).__init__(obs_space, action_space, num_outputs,
                                          model_config, name)
        self.model = FullyConnectedNetwork(obs_space, action_space,
                                           num_outputs, model_config, name)

    def forward(self, input_dict, state, seq_lens):
        return self.model.forward(input_dict, state, seq_lens)

    def value_function(self):
        return self.model.value_function()

def normc_initializer(std: float = 1.0) -> Any:
    """Return an initializer that rescales each row of a weight tensor so its
    L2 norm equals `std` (PyTorch port of the normc initializer)."""
    def initializer(tensor):
        tensor.data.normal_(0, 1)
        tensor.data *= std / torch.sqrt(
            tensor.data.pow(2).sum(1, keepdim=True))
    return initializer

class TorchModel(TorchModelV2, torch.nn.Module):
    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        torch.nn.Module.__init__(self)
        self.model = TorchFC(obs_space, action_space,
                             num_outputs, model_config, name)

    def forward(self, input_dict, state, seq_lens):
        return self.model.forward(input_dict, state, seq_lens)

    def value_function(self):
        return self.model.value_function()

if __name__ == "__main__":
    ray.init()
    # Can also register the env creator function explicitly with
    # register_env("env name", lambda config: EnvClass(config)).
    ModelCatalog.register_custom_model("CybORG_hier_Model", TorchModel)
    config = Trainer.merge_trainer_configs(
        DEFAULT_CONFIG,
        {
            "env": HierEnv,
            "env_config": {
                "null": 0,
            },
            # Use GPUs iff the `RLLIB_NUM_GPUS` env var is set to > 0.
            "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
            "model": {
                "custom_model": "CybORG_hier_Model",
                # "vf_share_layers": False,
            },
            "lr": 0.0001,
            # "momentum": tune.uniform(0, 1),
            "num_workers": 4,  # parallelism
            "framework": "torch",  # may also use "tf", "tf2" or "tfe" if supported
            "eager_tracing": True,  # needed to reach similar execution speed to static-graph mode (tf default)
            "vf_loss_coeff": 0.01,  # scales down the value function loss for better convergence with PPO
        })
    stop = {
        "training_iteration": 5000,   # the number of times tune.report() has been called
        "timesteps_total": 10000000,  # total number of timesteps
        "episode_reward_mean": -0.1,  # when to stop; it would be great if we could define this in terms of a
                                      # more complex expression that also incorporates the episode reward min,
                                      # since there is a lot of variance in the episode reward min
    }
    checkpoint = 'log_dir/PPO_3_step/PPO_HierEnv_ff1e1_00000_0_2022-01-31_16-08-15/checkpoint_000396/checkpoint-396'
    log_dir = 'log_dir/'
    analysis = tune.run(ppo.PPOTrainer,  # algo to use - alternatives: ppo.PPOTrainer, impala.ImpalaTrainer
                        config=config,
                        local_dir=log_dir,
                        stop=stop,
                        checkpoint_at_end=True,
                        checkpoint_freq=1,
                        restore=checkpoint,
                        keep_checkpoints_num=3,
                        checkpoint_score_attr="episode_reward_mean")

    checkpoint_pointer = open("checkpoint_pointer.txt", "w")
    last_checkpoint = analysis.get_last_checkpoint(
        metric="episode_reward_mean", mode="max"
    )
    checkpoint_pointer.write(last_checkpoint)
    print("Best model checkpoint written to: {}".format(last_checkpoint))
    # If you want to throw an error:
    # if True:
    #     check_learning_achieved(analysis, 0.1)
    checkpoint_pointer.close()
    ray.shutdown()
    # You can run `tensorboard --logdir=log_dir/PPO...` to visualise the learning process during and after training.
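
    # ------------------------------------------------------------------
    # Illustrative sketch (an assumption, not part of this repository): one
    # way a separate evaluation script could consume checkpoint_pointer.txt
    # and reload the saved policy. The HierEnv({}) constructor call and the
    # PPOTrainer.restore()/compute_action() usage below reflect a typical
    # RLlib 1.x workflow and are not taken from the authors' code.
    #
    #   import ray
    #   import ray.rllib.agents.ppo as ppo
    #   from agents.hierachy_agents.hier_env import HierEnv
    #
    #   ray.init()
    #   with open("checkpoint_pointer.txt") as f:
    #       checkpoint_path = f.read().strip()
    #   trainer = ppo.PPOTrainer(config=config, env=HierEnv)
    #   trainer.restore(checkpoint_path)
    #   env = HierEnv({})
    #   obs = env.reset()
    #   done, total_reward = False, 0
    #   while not done:
    #       action = trainer.compute_action(obs)
    #       obs, reward, done, info = env.step(action)
    #       total_reward += reward
    #   print("Evaluation episode reward:", total_reward)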