Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change _apply_rl_actions in green wave and added tests #42

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions examples/rllab/green_wave.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,16 +61,16 @@ def get_non_flow_params(enter_speed, additional_net_params):


def run_task(*_):
v_enter = 10
inner_length = 300
long_length = 100
short_length = 300
n = 3
m = 3
num_cars_left = 1
num_cars_right = 1
num_cars_top = 1
num_cars_bot = 1
v_enter = 20
inner_length = 200
long_length = 200
short_length = 200
n = 1
m = 1
num_cars_left = 10
num_cars_right = 10
num_cars_top = 10
num_cars_bot = 10
tot_cars = (num_cars_left + num_cars_right) * m \
+ (num_cars_bot + num_cars_top) * n

Expand All @@ -95,8 +95,8 @@ def run_task(*_):

tl_logic = TrafficLights(baseline=False)

additional_env_params = {"target_velocity": 50, "num_steps": 500,
"switch_time": 3.0}
additional_env_params = {"target_velocity": 20, "num_steps": 500,
"min_yellow_time": 4.0, "min_green_time": 8.0}
env_params = EnvParams(additional_params=additional_env_params)

additional_net_params = {"speed_limit": 35, "grid_array": grid_array,
Expand Down
21 changes: 14 additions & 7 deletions flow/envs/green_wave_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
from flow.envs.base_env import Env

ADDITIONAL_ENV_PARAMS = {
# minimum switch time for each traffic light (in seconds)
"switch_time": 2.0,
# minimum switch time for a traffic light in the yellow phase (in seconds)
"min_yellow_time": 2.0,
# minimum switch time for a traffic light in the green phase (in seconds)
"min_green_time": 8.0,
# whether the traffic lights should be actuated by sumo or RL
# options are "controlled" and "actuated"
"tl_type": "controlled",
Expand Down Expand Up @@ -76,11 +78,13 @@ def __init__(self, env_params, sumo_params, scenario):
# keeps track of the last time the light was allowed to change.
self.last_change = np.zeros((self.rows * self.cols, 3))

# when this hits min_switch_time we change from yellow to red
# When this hits min_yellow_time we can change from yellow to red.
# When this hits min_green_time, we can change from green to yellow.
# the second column indicates the direction that is currently being
# allowed to flow. 0 is flowing top to bottom, 1 is left to right
# For third column, 0 signifies yellow and 1 green or red
self.min_switch_time = env_params.additional_params["switch_time"]
self.min_yellow_time = env_params.additional_params["min_yellow_time"]
self.min_green_time = env_params.additional_params["min_green_time"]

if self.tl_type != "actuated":
for i in range(self.rows * self.cols):
Expand Down Expand Up @@ -156,7 +160,7 @@ def _apply_rl_actions(self, rl_actions):
# should switch to red
if self.last_change[i, 2] == 0: # currently yellow
self.last_change[i, 0] += self.sim_step
if self.last_change[i, 0] >= self.min_switch_time:
if self.last_change[i, 0] >= self.min_yellow_time:
if self.last_change[i, 1] == 0:
self.traffic_lights.set_state(
node_id='center{}'.format(i),
Expand All @@ -165,9 +169,11 @@ def _apply_rl_actions(self, rl_actions):
self.traffic_lights.set_state(
node_id='center{}'.format(i),
state='rrrGGGrrrGGG', env=self)
self.last_change[i, 0] = 0.0
self.last_change[i, 2] = 1
else:
if action:
else: # currently green
self.last_change[i, 0] += self.sim_step
if self.last_change[i, 0] >= self.min_green_time and rl_mask[i]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does self.last_change get incremented inside the else when it didn't before? (not too familiar with the green wave stuff though so I might just be missing something obvious)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

self.last_change used to only keep track of last change since a yellow phase began. Now it also keeps track of time since a green phase began. So now, last_change[i, 0] resets when the green phase begins (it didn't use to) and increments during the green phase as well.

if self.last_change[i, 1] == 0:
self.traffic_lights.set_state(
node_id='center{}'.format(i),
Expand All @@ -179,6 +185,7 @@ def _apply_rl_actions(self, rl_actions):
self.last_change[i, 0] = 0.0
self.last_change[i, 1] = not self.last_change[i, 1]
self.last_change[i, 2] = 0


def compute_reward(self, state, rl_actions, **kwargs):
return rewards.penalize_tl_changes(rl_actions >= 0.5, gain=1.0)
Expand Down
40 changes: 40 additions & 0 deletions tests/fast_tests/test_traffic_lights.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from flow.core.experiment import SumoExperiment
from flow.controllers.routing_controllers import GridRouter
from flow.controllers.car_following_models import IDMController
from flow.envs.green_wave_env import PO_TrafficLightGridEnv

os.environ["TEST_FLAG"] = "True"

Expand Down Expand Up @@ -194,6 +195,45 @@ def test_k_closest(self):
self.assertTrue(self.env.vehicles.get_edge(veh_id) in c0_edges)


def test_min_switch(self):
# FOR THE PURPOSES OF THIS TEST, never set min switch to be < 2
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this test enforce the constraint of min_switch >= 2?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The setup script it uses has min_switch >=2, is that what you mean?


# reset the environment
self.env.reset()

# Set up RL environment
self.env = PO_TrafficLightGridEnv(self.env.env_params,
self.env.sumo_params,
self.env.scenario)

for i in range(len(self.env.last_change)):
self.assertEqual(self.env.last_change[i, 2], 1)

# Run one step and make sure the count is incrementing
self.env.step([1] * self.env.num_traffic_lights)
for i in range(len(self.env.last_change)):
self.assertEqual(self.env.last_change[i, 0], 1)
self.assertEqual(self.env.last_change[i, 1], 0)
self.assertEqual(self.env.last_change[i, 2], 1)

# Run until green switches to yellow
for i in range(int(self.env.min_green_time - 1.)):
self.env.step([1] * self.env.num_traffic_lights)
for i in range(len(self.env.last_change)):
self.assertEqual(self.env.last_change[i, 0], 0)
self.assertEqual(self.env.last_change[i, 1], 1)
self.assertEqual(self.env.last_change[i, 2], 0)

# Run until yellow switches to green
for i in range(int(self.env.min_yellow_time)):
self.env.step([1] * self.env.num_traffic_lights)

for i in range(len(self.env.last_change)):
self.assertEqual(self.env.last_change[i, 0], 0)
self.assertEqual(self.env.last_change[i, 1], 1)
self.assertEqual(self.env.last_change[i, 2], 1)


class TestItRuns(unittest.TestCase):
"""
Tests the set_state function
Expand Down
2 changes: 1 addition & 1 deletion tests/setup_scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ def grid_mxn_exp_setup(row_num=1,
if env_params is None:
# set default env_params configuration
additional_env_params = {"target_velocity": 50, "num_steps": 100,
"switch_time": 3.0}
"min_yellow_time": 3.0, "min_green_time": 8.0}

env_params = EnvParams(additional_params=additional_env_params,
horizon=100)
Expand Down