Commit
Merge pull request #54 from NREL/ray_v2
Update to use ray v2.3
jlaw9 authored Mar 29, 2023
2 parents 68978b9 + 77e5f4b commit b0e9067
Showing 15 changed files with 62 additions and 38 deletions.
3 changes: 2 additions & 1 deletion docs/source/background/introduction.rst
@@ -75,5 +75,6 @@ random walk down a 1D corridor:
while not done:
action = random.choice(range(len(env.state.children)))
- obs, reward, done, info = env.step(action)
+ obs, reward, terminated, truncated, info = env.step(action)
+ done = terminated or truncated
total_reward += reward
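Most of the edits in this commit repeat the Gymnasium convention shown above: reset() now returns an (observation, info) pair, and step() returns a five-tuple with separate terminated and truncated flags. The following is a minimal, self-contained sketch of that calling convention; it uses the stock CartPole-v1 environment purely as a stand-in and is not part of this commit:

import gymnasium as gym

# Illustrative only: CartPole-v1 stands in for any Gymnasium-style env.
env = gym.make("CartPole-v1")
obs, info = env.reset(seed=0)          # reset() -> (observation, info)

terminated = truncated = False
total_reward = 0.0
while not (terminated or truncated):
    action = env.action_space.sample()
    # step() -> (obs, reward, terminated, truncated, info)
    obs, reward, terminated, truncated, info = env.step(action)
    total_reward += reward

print(f"episode return: {total_reward}")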
13 changes: 7 additions & 6 deletions docs/source/examples/hallway.ipynb
@@ -334,8 +334,9 @@
}
],
"source": [
"obs = env.reset()\n",
"print(obs)"
"obs, info = env.reset()\n",
"print(obs)\n",
"print(info)"
]
},
{
@@ -379,7 +380,7 @@
],
"source": [
"# Not a valid action\n",
"obs, rew, done, info = env.step(1)"
"obs, rew, terminated, truncated, info = env.step(1)"
]
},
{
@@ -390,7 +391,7 @@
"outputs": [],
"source": [
"# A valid action\n",
"obs, rew, done, info = env.step(0)"
"obs, rew, terminated, truncated, info = env.step(0)"
]
},
{
@@ -504,7 +505,7 @@
"metadata": {},
"outputs": [],
"source": [
"obs, rew, done, info = env.step(1)"
"obs, rew, terminated, truncated, info = env.step(1)"
]
},
{
@@ -604,7 +605,7 @@
"env.step(0)\n",
"\n",
"for _ in range(5):\n",
" obs, rew, done, info = env.step(1)\n",
" obs, rew, terminated, truncated, info = env.step(1)\n",
"\n",
"env.make_observation()"
]
8 changes: 5 additions & 3 deletions docs/source/examples/tsp_docs.ipynb
@@ -381,7 +381,8 @@
"rand_rew = 0.\n",
"while not done:\n",
" action = env.action_space.sample()\n",
" _, rew, done, _ = env.step(action)\n",
" _, rew, terminated, truncated, _ = env.step(action)\n",
" done = terminated or truncated\n",
" rand_rew += rew\n",
" \n",
"print(f\"Random reward = {rand_rew}\")\n",
@@ -425,15 +426,16 @@
}
],
"source": [
"obs = env.reset()\n",
"obs, info = env.reset()\n",
"\n",
"done = False\n",
"greedy_rew = 0.\n",
"i = 0\n",
"while not done:\n",
" # Get the node with shortest distance to the parent (current) node\n",
" idx = np.argmin([x[\"parent_dist\"] for x in obs[1:]]) \n",
" obs, rew, done, _ = env.step(idx)\n",
" obs, rew, terminated, truncated, _ = env.step(idx)\n",
" done = terminated or truncated\n",
" greedy_rew += rew\n",
" \n",
"print(f\"Greedy reward = {greedy_rew}\")\n",
11 changes: 6 additions & 5 deletions docs/source/examples/tsp_env.ipynb
@@ -62,7 +62,7 @@
"%%capture\n",
"\n",
"# Reset the environment and initialize the observation, reward, and done fields\n",
"obs = env.reset()\n",
"obs, info = env.reset()\n",
"greedy_reward = 0\n",
"done = False\n",
"\n",
@@ -74,7 +74,8 @@
"\n",
" # Get the observation for the next set of candidate nodes,\n",
" # incremental reward, and done flags\n",
" obs, reward, done, info = env.step(action)\n",
" obs, reward, terminated, truncated, info = env.step(action)\n",
" done = terminated or truncated\n",
"\n",
" # Append the step's reward to the running total\n",
" greedy_reward += reward\n",
@@ -182,7 +183,7 @@
" )[:k]\n",
"\n",
" for entry in top_actions:\n",
" obs, reward, done, info = entry[\"env\"].step(entry[\"action_index\"])\n",
" obs, reward, terminated, truncated, info = entry[\"env\"].step(entry[\"action_index\"])\n",
"\n",
" return [(entry[\"env\"], entry[\"reward\"]) for entry in top_actions], done"
]
@@ -194,7 +195,7 @@
"metadata": {},
"outputs": [],
"source": [
"obs = env.reset()\n",
"obs, info = env.reset()\n",
"env_list = [(env, 0)]\n",
"done = False\n",
"\n",
@@ -212,7 +213,7 @@
"metadata": {},
"outputs": [],
"source": [
"obs = env.reset()\n",
"obs, info = env.reset()\n",
"env_list = [(env, 0)]\n",
"done = False\n",
"\n",
18 changes: 13 additions & 5 deletions experiments/hallway/custom_env.py
@@ -17,7 +17,7 @@
import os
import random

- import gym
+ import gymnasium as gym
import ray
from gym.spaces import Box, Discrete
from ray import tune
@@ -87,19 +87,27 @@ def __init__(self, config: EnvContext):
# Set the seed. This is only used for the final (reach goal) reward.
self.seed(config.worker_index * config.num_workers)

- def reset(self):
+ def reset(self, *, seed=None, options=None):
self.cur_pos = 0
- return [self.cur_pos]
+ info_dict = {}
+ return [self.cur_pos], info_dict

def step(self, action):
assert action in [0, 1], action
if action == 0 and self.cur_pos > 0:
self.cur_pos -= 1
elif action == 1:
self.cur_pos += 1
- done = self.cur_pos >= self.end_pos
+ terminated = self.cur_pos >= self.end_pos
+ truncated = False
# Produce a random reward when we reach the goal.
- return [self.cur_pos], random.random() * 2 if done else -0.1, done, {}
+ return (
+     [self.cur_pos],
+     random.random() * 2 if terminated else -0.1,
+     terminated,
+     truncated,
+     {}
+ )

def seed(self, seed=None):
random.seed(seed)
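For reference, a reset() that also honors the new seed keyword can forward it to the base class so that self.np_random is re-seeded, which is the usual Gymnasium pattern. The sketch below is not part of this commit, and the class name SeededCorridor is a placeholder that mirrors the corridor env above (step() omitted for brevity):

import gymnasium as gym

class SeededCorridor(gym.Env):
    # Hypothetical name; mirrors the corridor env above but omits step() for brevity.
    def __init__(self, end_pos=10):
        self.cur_pos = 0
        self.end_pos = end_pos

    def reset(self, *, seed=None, options=None):
        # Forward the seed to gymnasium.Env.reset(), which re-seeds self.np_random;
        # the env can then draw randomness from self.np_random instead of the
        # global random module.
        super().reset(seed=seed)
        self.cur_pos = 0
        return [self.cur_pos], {}

obs, info = SeededCorridor().reset(seed=42)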
10 changes: 6 additions & 4 deletions experiments/tsp/untrained_model_sampling.ipynb
@@ -137,7 +137,8 @@
" )\n",
" action_probabilities = tf.nn.softmax(masked_action_values).numpy()\n",
" action = np.random.choice(env.max_num_children, size=1, p=action_probabilities)[0]\n",
" obs, reward, done, info = env.step(action)\n",
" obs, reward, terminated, truncated, info = env.step(action)\n",
" done = terminated or truncated\n",
" total_reward += reward\n",
" \n",
" return total_reward"
@@ -213,7 +214,7 @@
}
],
"source": [
"obs = env.reset()\n",
"obs, info = env.reset()\n",
"env.observation_space.contains(obs)"
]
},
@@ -376,11 +377,12 @@
" # run until episode ends\n",
" episode_reward = 0\n",
" done = False\n",
" obs = env.reset()\n",
" obs, info = env.reset()\n",
"\n",
" while not done:\n",
" action = agent.compute_single_action(obs)\n",
" obs, reward, done, info = env.step(action)\n",
" obs, reward, terminated, truncated, info = env.step(action)\n",
" done = terminated or truncated\n",
" episode_reward += reward\n",
" \n",
" return episode_reward"
2 changes: 1 addition & 1 deletion graphenv/examples/hallway/hallway_state.py
@@ -1,7 +1,7 @@
import random
from typing import Dict, Sequence

- import gym
+ import gymnasium as gym
import numpy as np
from graphenv import tf
from graphenv.vertex import Vertex
2 changes: 1 addition & 1 deletion graphenv/examples/tsp/tsp_nfp_state.py
@@ -1,7 +1,7 @@
from math import sqrt
from typing import Dict, Optional

- import gym
+ import gymnasium as gym
import numpy as np
from graphenv.examples.tsp.tsp_preprocessor import TSPPreprocessor
from graphenv.examples.tsp.tsp_state import TSPState
2 changes: 1 addition & 1 deletion graphenv/examples/tsp/tsp_state.py
@@ -1,6 +1,6 @@
from typing import Any, Callable, Dict, List, Optional, Sequence

- import gym
+ import gymnasium as gym
import networkx as nx
import numpy as np
from graphenv import tf
18 changes: 13 additions & 5 deletions graphenv/graph_env.py
@@ -2,7 +2,7 @@
import warnings
from typing import Any, Dict, List, Optional, Tuple

- import gym
+ import gymnasium as gym
import numpy as np
from ray.rllib.env.env_context import EnvContext
from ray.rllib.utils.spaces.repeated import Repeated
@@ -61,17 +61,17 @@ def __init__(self, env_config: EnvContext) -> None:
self.action_space = gym.spaces.Discrete(self.max_num_children)
logger.debug("leaving graphenv construction")

- def reset(self) -> Dict[str, np.ndarray]:
+ def reset(self, *, seed=None, options=None) -> Tuple[Dict[str, np.ndarray], Dict]:
"""Reset this state to the root vertex. It is possible for state.root to
return different root vertices on each call.
Returns:
Dict[str, np.ndarray]: Observation of the root vertex.
"""
self.state = self.state.root
- return self.make_observation()
+ return self.make_observation(), self.state.info

- def step(self, action: int) -> Tuple[Dict[str, np.ndarray], float, bool, dict]:
+ def step(self, action: int) -> Tuple[Dict[str, np.ndarray], float, bool, bool, dict]:
"""Steps the environment to a new state by taking an action. In the
case of GraphEnv, the action specifies which next vertex to move to and
this method advances the environment to that vertex.
@@ -86,7 +86,8 @@ def step(self, action: int) -> Tuple[Dict[str, np.ndarray], float, bool, dict]:
Tuple[Dict[str, np.ndarray], float, bool, dict]: Tuple of:
a dictionary of the new state's observation,
the reward received by moving to the new state's vertex,
- a bool which is true iff the new stae is a terminal vertex,
+ a bool which is true iff the new state is a terminal vertex,
+ a bool which is true if the search is truncated
a dictionary of debugging information related to this call
"""

@@ -115,10 +116,17 @@ def step(self, action: int) -> Tuple[Dict[str, np.ndarray], float, bool, dict]:
RuntimeWarning,
)

+ # In RLlib 2.3, the config options "no_done_at_end", "horizon", and "soft_horizon" are no longer supported
+ # according to the migration guide https://docs.google.com/document/d/1lxYK1dI5s0Wo_jmB6V6XiP-_aEBsXDykXkD1AXRase4/edit#
+ # Instead, wrap your gymnasium environment with a TimeLimit wrapper,
+ # which will set truncated according to the number of timesteps
+ # see https://gymnasium.farama.org/api/wrappers/misc_wrappers/#gymnasium.wrappers.TimeLimit
+ truncated = False
result = (
self.make_observation(),
self.state.reward,
self.state.terminal,
+ truncated,
self.state.info,
)
logger.debug(
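As the new comment in step() explains, episode-length truncation now comes from a wrapper rather than from the removed RLlib options. A minimal sketch of wrapping the environment with Gymnasium's TimeLimit follows; the step budget of 100 is arbitrary, and env_config stands for whatever configuration GraphEnv expects:

from gymnasium.wrappers import TimeLimit

from graphenv.graph_env import GraphEnv

def make_time_limited_env(env_config, max_episode_steps=100):
    # TimeLimit sets truncated=True once max_episode_steps steps have elapsed,
    # replacing the old "horizon"/"soft_horizon" RLlib settings.
    return TimeLimit(GraphEnv(env_config), max_episode_steps=max_episode_steps)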
2 changes: 1 addition & 1 deletion graphenv/graph_model.py
@@ -2,7 +2,7 @@
from abc import abstractmethod
from typing import Dict, List, Tuple

- import gym
+ import gymnasium as gym
from ray.rllib.models.repeated_values import RepeatedValues
from ray.rllib.utils.typing import TensorStructType, TensorType

2 changes: 1 addition & 1 deletion graphenv/vertex.py
@@ -1,7 +1,7 @@
from abc import abstractmethod
from typing import Any, Dict, Generic, List, Optional, Sequence, TypeVar

- import gym
+ import gymnasium as gym

V = TypeVar("V")

3 changes: 2 additions & 1 deletion setup.cfg
@@ -12,9 +12,10 @@ classifiers =
packages = find:
install_requires =
networkx==3.0
- ray[tune,rllib]==2.2.0
+ ray[tune,rllib]==2.3.1
numpy<1.24.0
tqdm==4.64.1
+ matplotlib

[options.extras_require]
tensorflow = tensorflow
4 changes: 2 additions & 2 deletions tests/test_hallway.py
@@ -61,14 +61,14 @@ def test_graphenv_reset(hallway_env: GraphEnv):


def test_graphenv_step(hallway_env: GraphEnv):
- obs, reward, terminal, info = hallway_env.step(0)
+ obs, reward, terminal, truncated, info = hallway_env.step(0)

for _ in range(3):
assert terminal is False
assert reward == -0.1
assert hallway_env.observation_space.contains(obs)
assert hallway_env.action_space.contains(1)
- obs, reward, terminal, info = hallway_env.step(1)
+ obs, reward, terminal, truncated, info = hallway_env.step(1)

assert terminal is True
assert reward > 0
2 changes: 1 addition & 1 deletion tests/test_tsp.py
@@ -27,7 +27,7 @@ def test_graphenv():
}
)

- obs = env.reset()
+ obs, info = env.reset()
assert env.observation_space.contains(obs)


