From 18daeec0573b1860a13fb22c5222db5632cfcd9d Mon Sep 17 00:00:00 2001 From: Roger Larsson Date: Fri, 31 May 2024 21:10:44 +0200 Subject: [PATCH] Rework rendering API for simultaneous "human" and "rgb_array" mode (#1010) --- gymnasium/wrappers/rendering.py | 27 ++++++++++------- gymnasium/wrappers/vector/rendering.py | 30 ++++++++++++++----- tests/wrappers/test_human_rendering.py | 23 ++++++++++++++ tests/wrappers/vector/test_human_rendering.py | 25 ++++++++++++++++ 4 files changed, 88 insertions(+), 17 deletions(-) diff --git a/gymnasium/wrappers/rendering.py b/gymnasium/wrappers/rendering.py index 80206e78b..4ea8105b9 100644 --- a/gymnasium/wrappers/rendering.py +++ b/gymnasium/wrappers/rendering.py @@ -463,11 +463,12 @@ class HumanRendering( "depth_array_list", ] - def __init__(self, env: gym.Env[ObsType, ActType]): + def __init__(self, env: gym.Env[ObsType, ActType], auto_rendering: bool = True): """Initialize a :class:`HumanRendering` instance. Args: env: The environment that is being wrapped + auto_rendering: Whether to automatically render the environment on step() and reset() """ gym.utils.RecordConstructorArgs.__init__(self) gym.Wrapper.__init__(self, env) @@ -475,6 +476,7 @@ def __init__(self, env: gym.Env[ObsType, ActType]): self.screen_size = None self.window = None # Has to be initialized before asserts, as self.window is used in auto close self.clock = None + self.auto_rendering = auto_rendering assert ( self.env.render_mode in self.ACCEPTED_RENDER_MODES @@ -495,7 +497,8 @@ def render_mode(self): def step(self, action: ActType) -> tuple[ObsType, SupportsFloat, bool, bool, dict]: """Perform a step in the base environment and render a frame to the screen.""" result = super().step(action) - self._render_frame() + if self.auto_rendering: + self._render_frame() return result def reset( @@ -503,12 +506,13 @@ def reset( ) -> tuple[ObsType, dict[str, Any]]: """Reset the base environment and render a frame to the screen.""" result = super().reset(seed=seed, 
options=options) - self._render_frame() + if self.auto_rendering: + self._render_frame() return result - def render(self) -> None: - """This method doesn't do much, actual rendering is performed in :meth:`step` and :meth:`reset`.""" - return None + def render(self): + """Render the last frame to the screen and return the base environment's render output (also done automatically in :meth:`step` and :meth:`reset` when ``auto_rendering`` is enabled).""" + return self._render_frame() def _render_frame(self): """Fetch the last frame from the base environment and render it to the screen.""" @@ -520,11 +524,12 @@ def _render_frame(self): ) assert self.env.render_mode is not None if self.env.render_mode.endswith("_list"): - last_rgb_array = self.env.render() - assert isinstance(last_rgb_array, list) - last_rgb_array = last_rgb_array[-1] + last_render = self.env.render() + assert isinstance(last_render, list) + last_rgb_array = last_render[-1] else: - last_rgb_array = self.env.render() + last_render = self.env.render() + last_rgb_array = last_render assert isinstance( last_rgb_array, np.ndarray @@ -553,6 +558,8 @@ def _render_frame(self): self.clock.tick(self.metadata["render_fps"]) pygame.display.flip() + return last_render + def close(self): """Close the rendering window.""" if self.window is not None: diff --git a/gymnasium/wrappers/vector/rendering.py b/gymnasium/wrappers/vector/rendering.py index 95c94bd1b..fc4076e4b 100644 --- a/gymnasium/wrappers/vector/rendering.py +++ b/gymnasium/wrappers/vector/rendering.py @@ -22,12 +22,18 @@ class HumanRendering(VectorWrapper): "depth_array_list", ] - def __init__(self, env: VectorEnv, screen_size: tuple[int, int] | None = None): + def __init__( + self, + env: VectorEnv, + screen_size: tuple[int, int] | None = None, + auto_rendering: bool = True, + ): """Constructor for Human Rendering of Vector-based environments. 
Args: env: The vector environment screen_size: The rendering screen size otherwise the environment sub-env render size is used + auto_rendering: Whether to automatically render the environment on step() and reset() """ VectorWrapper.__init__(self, env) @@ -35,6 +41,7 @@ def __init__(self, env: VectorEnv, screen_size: tuple[int, int] | None = None): self.scaled_subenv_size, self.num_rows, self.num_cols = None, None, None self.window = None # Has to be initialized before asserts, as self.window is used in auto close self.clock = None + self.auto_rendering = auto_rendering assert ( self.env.render_mode in self.ACCEPTED_RENDER_MODES @@ -57,7 +64,8 @@ def step( ) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict[str, Any]]: """Perform a step in the base environment and render a frame to the screen.""" result = super().step(actions) - self._render_frame() + if self.auto_rendering: + self._render_frame() return result def reset( @@ -68,9 +76,14 @@ def reset( ) -> tuple[ObsType, dict[str, Any]]: """Reset the base environment and render a frame to the screen.""" result = super().reset(seed=seed, options=options) - self._render_frame() + if self.auto_rendering: + self._render_frame() return result + def render(self): + """Render the last frame to the screen and return the base vector environment's render output (also done automatically in :meth:`step` and :meth:`reset` when ``auto_rendering`` is enabled).""" + return self._render_frame() + def _render_frame(self): """Fetch the last frame from the base environment and render it to the screen.""" try: @@ -82,11 +95,12 @@ def _render_frame(self): assert self.env.render_mode is not None if self.env.render_mode.endswith("_last"): - subenv_renders = self.env.render() - assert isinstance(subenv_renders, list) - subenv_renders = subenv_renders[-1] + last_render = self.env.render() + assert isinstance(last_render, list) + subenv_renders = last_render[-1] else: - subenv_renders = self.env.render() + last_render = self.env.render() + subenv_renders = last_render assert subenv_renders is not None assert 
len(subenv_renders) == self.num_envs @@ -173,6 +187,8 @@ def _render_frame(self): self.clock.tick(self.metadata["render_fps"]) pygame.display.flip() + return last_render + def close(self): """Close the rendering window.""" if self.window is not None: diff --git a/tests/wrappers/test_human_rendering.py b/tests/wrappers/test_human_rendering.py index bf9ff3931..53324883e 100644 --- a/tests/wrappers/test_human_rendering.py +++ b/tests/wrappers/test_human_rendering.py @@ -1,6 +1,7 @@ """Test suite of HumanRendering wrapper.""" import re +import numpy as np import pytest import gymnasium as gym @@ -31,3 +32,25 @@ def test_human_rendering(): ): HumanRendering(env) env.close() + + +@pytest.mark.parametrize("env_id", ["CartPole-v1"]) +@pytest.mark.parametrize("num_envs", [1, 3, 9]) +@pytest.mark.parametrize("screen_size", [None]) +def test_human_rendering_manual(env_id, num_envs, screen_size): + env = HumanRendering( + gym.make("CartPole-v1", render_mode="rgb_array", disable_env_checker=True), + auto_rendering=False, + ) + assert env.render_mode == "human" + env.reset() + + for _ in range(75): + _, _, terminated, truncated, _ = env.step(env.action_space.sample()) + if terminated or truncated: + env.reset() + rendering = env.render() + # output should match mode + assert isinstance(rendering, np.ndarray) + + env.close() diff --git a/tests/wrappers/vector/test_human_rendering.py b/tests/wrappers/vector/test_human_rendering.py index 9ddb390d7..7588c1b92 100644 --- a/tests/wrappers/vector/test_human_rendering.py +++ b/tests/wrappers/vector/test_human_rendering.py @@ -1,6 +1,7 @@ """Test suite of HumanRendering wrapper.""" import re +import numpy as np import pytest import gymnasium as gym @@ -19,6 +20,7 @@ def test_num_envs_screen_size(env_id, num_envs, screen_size): envs.reset() for _ in range(25): envs.step(envs.action_space.sample()) + envs.close() @@ -41,3 +43,26 @@ def test_render_modes(): ), ): HumanRendering(envs) + + +@pytest.mark.parametrize("env_id", ["CartPole-v1"]) 
+@pytest.mark.parametrize("num_envs", [1, 3, 9]) +@pytest.mark.parametrize("screen_size", [None]) +def test_human_rendering_manual(env_id, num_envs, screen_size): + envs = gym.make_vec(env_id, num_envs=num_envs, render_mode="rgb_array") + envs = HumanRendering(envs, screen_size=screen_size, auto_rendering=False) + + assert envs.render_mode == "human" + assert not envs.auto_rendering + + envs.reset() + + # Test Manual render() call + envs.step(envs.action_space.sample()) + rendering = envs.render() + # output should match mode, list of environment rgb_arrays + assert isinstance(rendering, list) + assert len(rendering) == num_envs + assert isinstance(rendering[0], np.ndarray) + + envs.close()