diff --git a/examples/exp_configs/rl/multiagent/multiagent_traffic_light_grid.py b/examples/exp_configs/rl/multiagent/multiagent_traffic_light_grid.py
index b8293f638..88c412946 100644
--- a/examples/exp_configs/rl/multiagent/multiagent_traffic_light_grid.py
+++ b/examples/exp_configs/rl/multiagent/multiagent_traffic_light_grid.py
@@ -88,7 +88,7 @@
             "target_velocity": 50,
             "switch_time": 3,
             "num_observed": 2,
-            "discrete": False,
+            "discrete": False,  # set True for DQN
             "tl_type": "actuated",
             "num_local_edges": 4,
             "num_local_lights": 4,
diff --git a/examples/exp_configs/rl/singleagent/singleagent_traffic_light_grid.py b/examples/exp_configs/rl/singleagent/singleagent_traffic_light_grid.py
index 085d26be9..11f94023d 100644
--- a/examples/exp_configs/rl/singleagent/singleagent_traffic_light_grid.py
+++ b/examples/exp_configs/rl/singleagent/singleagent_traffic_light_grid.py
@@ -144,8 +144,8 @@ def get_non_flow_params(enter_speed, add_net_params):
         'target_velocity': 50,
         'switch_time': 3.0,
         'num_observed': 2,
-        'discrete': False,
-        'tl_type': 'controlled'
+        'discrete': False,  # set True for DQN
+        'tl_type': 'actuated'
     }
 
 additional_net_params = {
diff --git a/examples/train.py b/examples/train.py
index 5f8edbb22..be72c34eb 100644
--- a/examples/train.py
+++ b/examples/train.py
@@ -42,7 +42,9 @@ def parse_args(args):
     parser.add_argument(
         '--rl_trainer', type=str, default="rllib",
         help='the RL trainer to use. either rllib or Stable-Baselines')
-
+    parser.add_argument(
+        '--algorithm', type=str, default="PPO",
+        help='RL algorithm to use. Options are PPO and DQN right now.')
     parser.add_argument(
         '--num_cpus', type=int, default=1,
         help='How many CPUs to use')
@@ -101,6 +103,7 @@ def run_model_stablebaseline(flow_params,
 def setup_exps_rllib(flow_params,
                      n_cpus,
                      n_rollouts,
+                     flags,
                      policy_graphs=None,
                      policy_mapping_fn=None,
                      policies_to_train=None):
@@ -114,6 +117,8 @@ def setup_exps_rllib(flow_params,
         number of CPUs to run the experiment over
     n_rollouts : int
         number of rollouts per training iteration
+    flags:
+        custom arguments
     policy_graphs : dict, optional
         TODO
     policy_mapping_fn : function, optional
@@ -139,19 +144,31 @@ def setup_exps_rllib(flow_params,
 
     horizon = flow_params['env'].horizon
 
-    alg_run = "PPO"
-
-    agent_cls = get_agent_class(alg_run)
-    config = deepcopy(agent_cls._default_config)
+    alg_run = flags.algorithm.upper()
+
+    if alg_run == "PPO":
+        agent_cls = get_agent_class(alg_run)
+        config = deepcopy(agent_cls._default_config)
+        config["gamma"] = 0.999  # discount rate
+        config["model"].update({"fcnet_hiddens": [32, 32, 32]})
+        config["use_gae"] = True
+        config["lambda"] = 0.97
+        config["kl_target"] = 0.02
+        config["num_sgd_iter"] = 10
+    elif alg_run == "DQN":
+        agent_cls = get_agent_class(alg_run)
+        config = deepcopy(agent_cls._default_config)
+        config['clip_actions'] = False
+        config["timesteps_per_iteration"] = horizon * n_rollouts
+        # https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/dqn/atari-dist-dqn.yaml
+        config["hiddens"] = [512]
+        config["lr"] = 0.0000625
+        config["schedule_max_timesteps"] = 2000000
+        config["buffer_size"] = 1000000
+        config["target_network_update_freq"] = 8000
 
     config["num_workers"] = n_cpus
     config["train_batch_size"] = horizon * n_rollouts
-    config["gamma"] = 0.999  # discount rate
-    config["model"].update({"fcnet_hiddens": [32, 32, 32]})
-    config["use_gae"] = True
-    config["lambda"] = 0.97
-    config["kl_target"] = 0.02
-    config["num_sgd_iter"] = 10
     config["horizon"] = horizon
 
     # save the flow params for replay
@@ -190,10 +207,10 @@ def train_rllib(submodule, flags):
     policies_to_train = getattr(submodule, "policies_to_train", None)
 
     alg_run, gym_name, config = setup_exps_rllib(
-        flow_params, n_cpus, n_rollouts,
+        flow_params, n_cpus, n_rollouts, flags,
        policy_graphs, policy_mapping_fn, policies_to_train)
 
-    ray.init(num_cpus=n_cpus + 1, object_store_memory=200 * 1024 * 1024)
+    ray.init(num_cpus=n_cpus + 1, ignore_reinit_error=True, object_store_memory=200 * 1024 * 1024)
     exp_config = {
         "run": alg_run,
         "env": gym_name,
diff --git a/flow/envs/multiagent/traffic_light_grid.py b/flow/envs/multiagent/traffic_light_grid.py
index a0438f828..9cf3161b9 100644
--- a/flow/envs/multiagent/traffic_light_grid.py
+++ b/flow/envs/multiagent/traffic_light_grid.py
@@ -208,7 +208,7 @@ def _apply_rl_actions(self, rl_actions):
         for rl_id, rl_action in rl_actions.items():
             i = int(rl_id.split("center")[ID_IDX])
             if self.discrete:
-                raise NotImplementedError
+                action = rl_action
             else:
                 # convert values less than 0.0 to zero and above to 1. 0's
                 # indicate that we should not switch the direction
diff --git a/flow/envs/traffic_light_grid.py b/flow/envs/traffic_light_grid.py
index 8be0cb8a5..24f813ea7 100644
--- a/flow/envs/traffic_light_grid.py
+++ b/flow/envs/traffic_light_grid.py
@@ -19,7 +19,7 @@
     "switch_time": 2.0,
     # whether the traffic lights should be actuated by sumo or RL
     # options are "controlled" and "actuated"
-    "tl_type": "controlled",
+    "tl_type": "actuated",
     # determines whether the action space is meant to be discrete or continuous
     "discrete": False,
 }
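
Note: the sketch below is not part of the patch. It is a minimal, self-contained illustration of the pattern the train.py changes introduce: an --algorithm flag whose value selects between PPO and DQN hyperparameters before the shared settings are applied. build_config, PPO_DEFAULTS, and DQN_DEFAULTS are hypothetical stand-ins for RLlib's get_agent_class(alg_run)._default_config, and the horizon/rollout/CPU numbers are placeholders rather than values taken from the Flow configs.

# Hypothetical sketch (not part of the patch): per-algorithm config branching.
import argparse
from copy import deepcopy

# stand-ins for RLlib's agent default configs
PPO_DEFAULTS = {"num_workers": 1}
DQN_DEFAULTS = {"num_workers": 1}


def build_config(algorithm, horizon, n_rollouts, n_cpus):
    """Return (alg_run, config) for the requested algorithm."""
    alg_run = algorithm.upper()
    if alg_run == "PPO":
        config = deepcopy(PPO_DEFAULTS)
        config["gamma"] = 0.999  # discount rate
        config["model"] = {"fcnet_hiddens": [32, 32, 32]}
        config["use_gae"] = True
        config["lambda"] = 0.97
        config["kl_target"] = 0.02
        config["num_sgd_iter"] = 10
    elif alg_run == "DQN":
        config = deepcopy(DQN_DEFAULTS)
        config["clip_actions"] = False
        config["timesteps_per_iteration"] = horizon * n_rollouts
        config["hiddens"] = [512]
        config["lr"] = 0.0000625
        config["buffer_size"] = 1000000
        config["target_network_update_freq"] = 8000
    else:
        raise ValueError("unsupported algorithm: %s" % algorithm)
    # settings shared by both branches, mirroring the patch
    config["num_workers"] = n_cpus
    config["train_batch_size"] = horizon * n_rollouts
    config["horizon"] = horizon
    return alg_run, config


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--algorithm', type=str, default="PPO",
        help='RL algorithm to use. Options are PPO and DQN right now.')
    flags = parser.parse_args()
    # placeholder horizon/rollout/CPU values, purely for illustration
    alg_run, config = build_config(flags.algorithm, horizon=100,
                                   n_rollouts=10, n_cpus=1)
    print(alg_run, config)

If DQN is selected, the env-side "discrete" flag in the grid exp configs (annotated "# set True for DQN" in the patch) also needs to be set to True so the environment exposes a discrete action space that DQN can train on.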