diff --git a/examples/rllab/green_wave.py b/examples/rllab/green_wave.py index 43128ba6..65ac9580 100644 --- a/examples/rllab/green_wave.py +++ b/examples/rllab/green_wave.py @@ -61,16 +61,16 @@ def get_non_flow_params(enter_speed, additional_net_params): def run_task(*_): - v_enter = 10 - inner_length = 300 - long_length = 100 - short_length = 300 - n = 3 - m = 3 - num_cars_left = 1 - num_cars_right = 1 - num_cars_top = 1 - num_cars_bot = 1 + v_enter = 20 + inner_length = 200 + long_length = 200 + short_length = 200 + n = 1 + m = 1 + num_cars_left = 10 + num_cars_right = 10 + num_cars_top = 10 + num_cars_bot = 10 tot_cars = (num_cars_left + num_cars_right) * m \ + (num_cars_bot + num_cars_top) * n @@ -95,8 +95,8 @@ def run_task(*_): tl_logic = TrafficLights(baseline=False) - additional_env_params = {"target_velocity": 50, "num_steps": 500, - "switch_time": 3.0} + additional_env_params = {"target_velocity": 20, "num_steps": 500, + "min_yellow_time": 4.0, "min_green_time": 8.0} env_params = EnvParams(additional_params=additional_env_params) additional_net_params = {"speed_limit": 35, "grid_array": grid_array, diff --git a/flow/envs/green_wave_env.py b/flow/envs/green_wave_env.py index 20e3b6dd..a8809976 100644 --- a/flow/envs/green_wave_env.py +++ b/flow/envs/green_wave_env.py @@ -9,8 +9,10 @@ from flow.envs.base_env import Env ADDITIONAL_ENV_PARAMS = { - # minimum switch time for each traffic light (in seconds) - "switch_time": 2.0, + # minimum switch time for a traffic light in the yellow phase (in seconds) + "min_yellow_time": 2.0, + # minimum switch time for a traffic light in the green phase (in seconds) + "min_green_time": 8.0, # whether the traffic lights should be actuated by sumo or RL # options are "controlled" and "actuated" "tl_type": "controlled", @@ -76,11 +78,13 @@ def __init__(self, env_params, sumo_params, scenario): # keeps track of the last time the light was allowed to change. self.last_change = np.zeros((self.rows * self.cols, 3)) - # when this hits min_switch_time we change from yellow to red + # When this hits min_yellow_time we can change from yellow to red. + # When this hits min_green_time, we can change from green to yellow. # the second column indicates the direction that is currently being # allowed to flow. 0 is flowing top to bottom, 1 is left to right # For third column, 0 signifies yellow and 1 green or red - self.min_switch_time = env_params.additional_params["switch_time"] + self.min_yellow_time = env_params.additional_params["min_yellow_time"] + self.min_green_time = env_params.additional_params["min_green_time"] if self.tl_type != "actuated": for i in range(self.rows * self.cols): @@ -156,7 +160,7 @@ def _apply_rl_actions(self, rl_actions): # should switch to red if self.last_change[i, 2] == 0: # currently yellow self.last_change[i, 0] += self.sim_step - if self.last_change[i, 0] >= self.min_switch_time: + if self.last_change[i, 0] >= self.min_yellow_time: if self.last_change[i, 1] == 0: self.traffic_lights.set_state( node_id='center{}'.format(i), @@ -165,9 +169,11 @@ def _apply_rl_actions(self, rl_actions): self.traffic_lights.set_state( node_id='center{}'.format(i), state='rrrGGGrrrGGG', env=self) + self.last_change[i, 0] = 0.0 self.last_change[i, 2] = 1 - else: - if action: + else: # currently green + self.last_change[i, 0] += self.sim_step + if self.last_change[i, 0] >= self.min_green_time and rl_mask[i]: if self.last_change[i, 1] == 0: self.traffic_lights.set_state( node_id='center{}'.format(i), @@ -179,6 +185,7 @@ def _apply_rl_actions(self, rl_actions): self.last_change[i, 0] = 0.0 self.last_change[i, 1] = not self.last_change[i, 1] self.last_change[i, 2] = 0 + def compute_reward(self, state, rl_actions, **kwargs): return rewards.penalize_tl_changes(rl_actions >= 0.5, gain=1.0) diff --git a/tests/fast_tests/test_traffic_lights.py b/tests/fast_tests/test_traffic_lights.py index 04a884a9..2b15e187 100644 --- a/tests/fast_tests/test_traffic_lights.py +++ b/tests/fast_tests/test_traffic_lights.py @@ -9,6 +9,7 @@ from flow.core.experiment import SumoExperiment from flow.controllers.routing_controllers import GridRouter from flow.controllers.car_following_models import IDMController +from flow.envs.green_wave_env import PO_TrafficLightGridEnv os.environ["TEST_FLAG"] = "True" @@ -194,6 +195,45 @@ def test_k_closest(self): self.assertTrue(self.env.vehicles.get_edge(veh_id) in c0_edges) + def test_min_switch(self): + # FOR THE PURPOSES OF THIS TEST, never set min switch to be < 2 + + # reset the environment + self.env.reset() + + # Set up RL environment + self.env = PO_TrafficLightGridEnv(self.env.env_params, + self.env.sumo_params, + self.env.scenario) + + for i in range(len(self.env.last_change)): + self.assertEqual(self.env.last_change[i, 2], 1) + + # Run one step and make sure the count is incrementing + self.env.step([1] * self.env.num_traffic_lights) + for i in range(len(self.env.last_change)): + self.assertEqual(self.env.last_change[i, 0], 1) + self.assertEqual(self.env.last_change[i, 1], 0) + self.assertEqual(self.env.last_change[i, 2], 1) + + # Run until green switches to yellow + for i in range(int(self.env.min_green_time - 1.)): + self.env.step([1] * self.env.num_traffic_lights) + for i in range(len(self.env.last_change)): + self.assertEqual(self.env.last_change[i, 0], 0) + self.assertEqual(self.env.last_change[i, 1], 1) + self.assertEqual(self.env.last_change[i, 2], 0) + + # Run until yellow switches to green + for i in range(int(self.env.min_yellow_time)): + self.env.step([1] * self.env.num_traffic_lights) + + for i in range(len(self.env.last_change)): + self.assertEqual(self.env.last_change[i, 0], 0) + self.assertEqual(self.env.last_change[i, 1], 1) + self.assertEqual(self.env.last_change[i, 2], 1) + + class TestItRuns(unittest.TestCase): """ Tests the set_state function diff --git a/tests/setup_scripts.py b/tests/setup_scripts.py index d714e7ec..482a1e6a 100644 --- a/tests/setup_scripts.py +++ b/tests/setup_scripts.py @@ -335,7 +335,7 @@ def grid_mxn_exp_setup(row_num=1, if env_params is None: # set default env_params configuration additional_env_params = {"target_velocity": 50, "num_steps": 100, - "switch_time": 3.0} + "min_yellow_time": 3.0, "min_green_time": 8.0} env_params = EnvParams(additional_params=additional_env_params, horizon=100)