diff --git a/src/gym/envs/atari/environment.py b/src/gym/envs/atari/environment.py index 442d9bbc4..b584ba9fb 100644 --- a/src/gym/envs/atari/environment.py +++ b/src/gym/envs/atari/environment.py @@ -1,12 +1,12 @@ +import warnings +from typing import Optional, Union, Tuple, Dict, Any, List + import numpy as np import gym import gym.logger as logger from gym import error, spaces from gym import utils -from gym.utils import seeding - -from typing import Optional, Union, Tuple, Dict, Any, List import ale_py.roms as roms from ale_py._ale_py import ALEInterface, ALEState, Action, LoggerMode @@ -74,11 +74,22 @@ def __init__( raise error.Error( f"Invalid observation type: {obs_type}. Expecting: rgb, grayscale, ram." ) - if not ( - isinstance(frameskip, int) - or (isinstance(frameskip, tuple) and len(frameskip) == 2) - ): - raise error.Error(f"Invalid frameskip type: {frameskip}") + + if type(frameskip) not in (int, tuple): + raise error.Error(f"Invalid frameskip type: {type(frameskip)}.") + if isinstance(frameskip, int) and frameskip <= 0: + raise error.Error( + f"Invalid frameskip of {frameskip}, frameskip must be positive.") + elif isinstance(frameskip, tuple) and len(frameskip) != 2: + raise error.Error( + f"Invalid stochastic frameskip length of {len(frameskip)}, expected length 2.") + elif isinstance(frameskip, tuple) and frameskip[0] > frameskip[1]: + raise error.Error( + f"Invalid stochastic frameskip, lower bound is greater than upper bound.") + elif isinstance(frameskip, tuple) and frameskip[0] <= 0: + raise error.Error( + f"Invalid stochastic frameskip lower bound is greater than upper bound.") + if render_mode is not None and render_mode not in {"rgb_array", "human"}: raise error.Error( f"Render mode {render_mode} not supported (rgb_array, human)." @@ -98,7 +109,6 @@ def __init__( # Initialize ALE self.ale = ALEInterface() - self.viewer = None self._game = rom_id_to_name(game) @@ -112,7 +122,8 @@ def __init__( # Set logger mode to error only self.ale.setLoggerMode(LoggerMode.Error) # Config sticky action prob. - self.ale.setFloat("repeat_action_probability", repeat_action_probability) + self.ale.setFloat("repeat_action_probability", + repeat_action_probability) # If render mode is human we can display screen and sound if render_mode == "human": @@ -146,7 +157,8 @@ def __init__( low=0, high=255, dtype=np.uint8, shape=image_shape ) else: - raise error.Error(f"Unrecognized observation type: {self._obs_type}") + raise error.Error( + f"Unrecognized observation type: {self._obs_type}") def seed(self, seed: Optional[int] = None) -> Tuple[int, int]: """ @@ -162,10 +174,13 @@ def seed(self, seed: Optional[int] = None) -> Tuple[int, int]: Returns: tuple[int, int] => (np seed, ALE seed) """ - self.np_random, seed1 = seeding.np_random(seed) - seed2 = seeding.hash_seed(seed1 + 1) % 2 ** 31 + ss = np.random.SeedSequence(seed) + seed1, seed2 = ss.generate_state(n_words=2) - self.ale.setInt("random_seed", seed2) + self.np_random = np.random.default_rng(seed1) + # ALE only takes signed integers for `setInt`, it'll get converted back + # to unsigned in StellaEnvironment. + self.ale.setInt("random_seed", seed2.astype(np.int32)) if not hasattr(roms, self._game): raise error.Error( @@ -212,7 +227,7 @@ def step(self, action_ind: int) -> Tuple[np.ndarray, float, bool, Dict[str, Any] if isinstance(self._frameskip, int): frameskip = self._frameskip elif isinstance(self._frameskip, tuple): - frameskip = self.np_random.randint(*self._frameskip) + frameskip = self.np_random.integers(*self._frameskip) else: raise error.Error(f"Invalid frameskip type: {self._frameskip}") @@ -224,7 +239,7 @@ def step(self, action_ind: int) -> Tuple[np.ndarray, float, bool, Dict[str, Any] return self._get_obs(), reward, terminal, self._get_info() def reset( - self, *, seed: Optional[int] = None, return_info: bool = False + self, *, seed: Optional[int] = None, return_info: bool = False, options: Optional[Dict[str, Any]] = None ) -> Union[Tuple[np.ndarray, Dict[str, Any]], np.ndarray]: """ Resets environment and returns initial observation. @@ -247,7 +262,7 @@ def reset( else: return obs - def render(self, mode: str) -> None: + def render(self, mode: str) -> Any: """ Render is not supported by ALE. We use a paradigm similar to Gym3 which allows you to specify `render_mode` during construction. @@ -261,28 +276,21 @@ def render(self, mode: str) -> None: if mode == "rgb_array": return img elif mode == "human": - from gym.envs.classic_control import rendering - - if self.viewer is None: - logger.warn( - ( - "We strongly suggest supplying `render_mode` when " - "constructing your environment, e.g., gym.make(ID, render_mode='human'). " - "Using `render_mode` provides access to proper scaling, audio support, " - "and proper framerates." - ) + warnings.warn( + ( + "render('human') is deprecated. Please supply `render_mode` when " + "constructing your environment, e.g., gym.make(ID, render_mode='human'). " + "The new `render_mode` keyword argument supports DPI scaling, " + "audio support, and native framerates." ) - self.viewer = rendering.SimpleImageViewer() - self.viewer.imshow(img) - return self.viewer.isopen + ) + return False def close(self) -> None: """ Cleanup any leftovers by the environment """ - if self.viewer is not None: - self.viewer.close() - self.viewer = None + pass def _get_obs(self) -> np.ndarray: """ @@ -296,7 +304,8 @@ def _get_obs(self) -> np.ndarray: elif self._obs_type == "grayscale": return self.ale.getScreenGrayscale() else: - raise error.Error(f"Unrecognized observation type: {self._obs_type}") + raise error.Error( + f"Unrecognized observation type: {self._obs_type}") def _get_info(self) -> Dict[str, Any]: info = { diff --git a/tests/python/gym/test_gym_interface.py b/tests/python/gym/test_gym_interface.py index 8d95826c1..7b2ee9dff 100644 --- a/tests/python/gym/test_gym_interface.py +++ b/tests/python/gym/test_gym_interface.py @@ -1,23 +1,24 @@ +# fmt: off import pytest pytest.importorskip("gym") pytest.importorskip("gym.envs.atari") -import numpy as np - -from unittest.mock import patch -from itertools import product - -from gym import spaces -from gym.envs.registration import registry -from gym.core import Env -from gym.utils.env_checker import check_env - from ale_py.gym import ( register_legacy_gym_envs, _register_gym_configs, register_gym_envs, ) +from gym import error +from gym.utils.env_checker import check_env +from gym.core import Env +from gym.envs.registration import registry +from gym.envs.atari.environment import AtariEnv +from gym import spaces +from itertools import product +from unittest.mock import patch +import numpy as np +# fmt: on def test_register_legacy_env_id(): @@ -123,7 +124,8 @@ def test_register_gym_envs(test_rom_path): suffixes = [] versions = ["-v5"] - all_ids = set(map("".join, product(games, obs_types, suffixes, versions))) + all_ids = set(map("".join, product( + games, obs_types, suffixes, versions))) assert all_ids.issubset(envids) @@ -331,6 +333,13 @@ def test_gym_reset_with_infos(tetris_gym): assert "rgb" in info +@pytest.mark.parametrize("frameskip", [0, -1, 4.0, (-1, 5), (0, 5), (5, 2), (1, 2, 3)]) +def test_frameskip_warnings(test_rom_path, frameskip): + with patch("ale_py.roms.Tetris", create=True, new_callable=lambda: test_rom_path): + with pytest.raises(error.Error): + AtariEnv('Tetris', frameskip=frameskip) + + def test_gym_compliance(tetris_gym): try: check_env(tetris_gym) diff --git a/tests/python/gym/test_legacy_registration.py b/tests/python/gym/test_legacy_registration.py index a769f9095..e6f13fd9a 100644 --- a/tests/python/gym/test_legacy_registration.py +++ b/tests/python/gym/test_legacy_registration.py @@ -99,7 +99,7 @@ def test_legacy_env_specs(): """ for spec in specs: assert spec in registry.env_specs - kwargs = registry.env_specs[spec]._kwargs + kwargs = registry.env_specs[spec].kwargs max_episode_steps = registry.env_specs[spec].max_episode_steps # Assert necessary parameters are set