diff --git a/gymnasium/wrappers/rendering.py b/gymnasium/wrappers/rendering.py
index 526c47bff..abc5a3ee1 100644
--- a/gymnasium/wrappers/rendering.py
+++ b/gymnasium/wrappers/rendering.py
@@ -463,11 +463,12 @@ class HumanRendering(
         "depth_array_list",
     ]
 
-    def __init__(self, env: gym.Env[ObsType, ActType]):
+    def __init__(self, env: gym.Env[ObsType, ActType], auto_rendering: bool = True):
         """Initialize a :class:`HumanRendering` instance.
 
         Args:
             env: The environment that is being wrapped
+            auto_rendering: Whether to automatically render the environment on step() and reset()
         """
-        gym.utils.RecordConstructorArgs.__init__(self)
+        gym.utils.RecordConstructorArgs.__init__(self, auto_rendering=auto_rendering)
         gym.Wrapper.__init__(self, env)
@@ -482,7 +483,9 @@ def __init__(self, env: gym.Env[ObsType, ActType]):
         self.screen_size = None
         self.window = None
         self.clock = None
+        self.auto_rendering = auto_rendering
 
+        # TODO: check whether this metadata copy is still needed
         if "human" not in self.metadata["render_modes"]:
             self.metadata = deepcopy(self.env.metadata)
             self.metadata["render_modes"].append("human")
@@ -495,7 +498,8 @@ def render_mode(self):
     def step(self, action: ActType) -> tuple[ObsType, SupportsFloat, bool, bool, dict]:
         """Perform a step in the base environment and render a frame to the screen."""
         result = super().step(action)
-        self._render_frame()
+        if self.auto_rendering:
+            self._render_frame()
         return result
 
     def reset(
@@ -503,12 +507,13 @@ def reset(
     ) -> tuple[ObsType, dict[str, Any]]:
         """Reset the base environment and render a frame to the screen."""
         result = super().reset(seed=seed, options=options)
-        self._render_frame()
+        if self.auto_rendering:
+            self._render_frame()
         return result
 
-    def render(self) -> None:
-        """This method doesn't do much, actual rendering is performed in :meth:`step` and :meth:`reset`."""
-        return None
+    def render(self):
+        """Render the base environment's latest frame to the screen and return the base render output."""
+        return self._render_frame()
 
     def _render_frame(self):
         """Fetch the last frame from the base environment and render it to the screen."""
@@ -520,11 +525,12 @@ def _render_frame(self):
             )
         assert self.env.render_mode is not None
         if self.env.render_mode.endswith("_list"):
-            last_rgb_array = self.env.render()
-            assert isinstance(last_rgb_array, list)
-            last_rgb_array = last_rgb_array[-1]
+            last_render = self.env.render()
+            assert isinstance(last_render, list)
+            last_rgb_array = last_render[-1]
         else:
-            last_rgb_array = self.env.render()
+            last_render = self.env.render()
+            last_rgb_array = last_render
 
         assert isinstance(
             last_rgb_array, np.ndarray
@@ -553,6 +559,8 @@ def _render_frame(self):
         self.clock.tick(self.metadata["render_fps"])
         pygame.display.flip()
 
+        return last_render
+
     def close(self):
         """Close the rendering window."""
         if self.window is not None:
diff --git a/gymnasium/wrappers/vector/rendering.py b/gymnasium/wrappers/vector/rendering.py
index 20e92a667..7b66ad078 100644
--- a/gymnasium/wrappers/vector/rendering.py
+++ b/gymnasium/wrappers/vector/rendering.py
@@ -22,12 +22,18 @@ class HumanRendering(VectorWrapper):
         "depth_array_list",
     ]
 
-    def __init__(self, env: VectorEnv, screen_size: tuple[int, int] | None = None):
+    def __init__(
+        self,
+        env: VectorEnv,
+        screen_size: tuple[int, int] | None = None,
+        auto_rendering: bool = True,
+    ):
         """Constructor for Human Rendering of Vector-based environments.
 
         Args:
             env: The vector environment
             screen_size: The rendering screen size otherwise the environment sub-env render size is used
+            auto_rendering: Whether to automatically render the environment on step() and reset()
         """
         VectorWrapper.__init__(self, env)
 
@@ -42,7 +48,9 @@ def __init__(self, env: VectorEnv, screen_size: tuple[int, int] | None = None):
         self.scaled_subenv_size, self.num_rows, self.num_cols = None, None, None
         self.window = None
         self.clock = None
+        self.auto_rendering = auto_rendering
 
+        # TODO: check whether this metadata copy is still needed
         if "human" not in self.metadata["render_modes"]:
             self.metadata = deepcopy(self.env.metadata)
             self.metadata["render_modes"].append("human")
@@ -57,7 +65,8 @@ def step(
     ) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict[str, Any]]:
         """Perform a step in the base environment and render a frame to the screen."""
         result = super().step(actions)
-        self._render_frame()
+        if self.auto_rendering:
+            self._render_frame()
         return result
 
     def reset(
@@ -68,9 +77,14 @@ def reset(
     ) -> tuple[ObsType, dict[str, Any]]:
         """Reset the base environment and render a frame to the screen."""
         result = super().reset(seed=seed, options=options)
-        self._render_frame()
+        if self.auto_rendering:
+            self._render_frame()
         return result
 
+    def render(self):
+        """Render the sub-environments' latest frames to the screen and return the base render output."""
+        return self._render_frame()
+
     def _render_frame(self):
         """Fetch the last frame from the base environment and render it to the screen."""
         try:
@@ -82,11 +96,12 @@ def _render_frame(self):
             )
         assert self.env.render_mode is not None
         if self.env.render_mode.endswith("_last"):
-            subenv_renders = self.env.render()
-            assert isinstance(subenv_renders, list)
-            subenv_renders = subenv_renders[-1]
+            last_render = self.env.render()
+            assert isinstance(last_render, list)
+            subenv_renders = last_render[-1]
         else:
-            subenv_renders = self.env.render()
+            last_render = self.env.render()
+            subenv_renders = last_render
 
         assert subenv_renders is not None
         assert len(subenv_renders) == self.num_envs
@@ -173,6 +188,8 @@ def _render_frame(self):
         self.clock.tick(self.metadata["render_fps"])
         pygame.display.flip()
 
+        return last_render
+
     def close(self):
         """Close the rendering window."""
         if self.window is not None:
diff --git a/tests/wrappers/test_human_rendering.py b/tests/wrappers/test_human_rendering.py
index bf9ff3931..53324883e 100644
--- a/tests/wrappers/test_human_rendering.py
+++ b/tests/wrappers/test_human_rendering.py
@@ -1,6 +1,7 @@
 """Test suite of HumanRendering wrapper."""
 import re
 
+import numpy as np
 import pytest
 
 import gymnasium as gym
@@ -31,3 +32,23 @@ def test_human_rendering():
     ):
         HumanRendering(env)
     env.close()
+
+
+@pytest.mark.parametrize("env_id", ["CartPole-v1"])
+def test_human_rendering_manual(env_id):
+    env = HumanRendering(
+        gym.make(env_id, render_mode="rgb_array", disable_env_checker=True),
+        auto_rendering=False,
+    )
+    assert env.render_mode == "human"
+    env.reset()
+
+    for _ in range(75):
+        _, _, terminated, truncated, _ = env.step(env.action_space.sample())
+        if terminated or truncated:
+            env.reset()
+        rendering = env.render()
+        # a manual render() call should return the base environment's rgb_array frame
+        assert isinstance(rendering, np.ndarray)
+
+    env.close()
diff --git a/tests/wrappers/vector/test_human_rendering.py b/tests/wrappers/vector/test_human_rendering.py
index 9ddb390d7..7588c1b92 100644
--- a/tests/wrappers/vector/test_human_rendering.py
+++ b/tests/wrappers/vector/test_human_rendering.py
@@ -1,6 +1,7 @@
 """Test suite of HumanRendering wrapper."""
 import re
 
+import numpy as np
 import pytest
 
 import gymnasium as gym
@@ -19,6 +20,7 @@ def test_num_envs_screen_size(env_id, num_envs, screen_size):
     envs.reset()
     for _ in range(25):
         envs.step(envs.action_space.sample())
+
     envs.close()
 
 
@@ -41,3 +43,26 @@ def test_render_modes():
         ),
     ):
         HumanRendering(envs)
+
+
+@pytest.mark.parametrize("env_id", ["CartPole-v1"])
+@pytest.mark.parametrize("num_envs", [1, 3, 9])
+@pytest.mark.parametrize("screen_size", [None])
+def test_human_rendering_manual(env_id, num_envs, screen_size):
+    envs = gym.make_vec(env_id, num_envs=num_envs, render_mode="rgb_array")
+    envs = HumanRendering(envs, screen_size=screen_size, auto_rendering=False)
+
+    assert envs.render_mode == "human"
+    assert not envs.auto_rendering
+
+    envs.reset()
+
+    # Test Manual render() call
+    envs.step(envs.action_space.sample())
+    rendering = envs.render()
+    # output should match mode, list of environment rgb_arrays
+    assert isinstance(rendering, list)
+    assert len(rendering) == num_envs
+    assert isinstance(rendering[0], np.ndarray)
+
+    envs.close()