Skip to content

Commit

Permalink
Rework rendering API for simultaneous "human" and "rgb_array" mode (F…
Browse files Browse the repository at this point in the history
  • Loading branch information
RogerJL committed May 31, 2024
1 parent 04fb345 commit b537395
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 17 deletions.
28 changes: 18 additions & 10 deletions gymnasium/wrappers/rendering.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,11 +463,12 @@ class HumanRendering(
"depth_array_list",
]

def __init__(self, env: gym.Env[ObsType, ActType]):
def __init__(self, env: gym.Env[ObsType, ActType], auto_rendering: bool = True):
"""Initialize a :class:`HumanRendering` instance.
Args:
env: The environment that is being wrapped
auto_rendering: Whether to automatically render the environment on step() and reset()
"""
gym.utils.RecordConstructorArgs.__init__(self)
gym.Wrapper.__init__(self, env)
Expand All @@ -482,7 +483,9 @@ def __init__(self, env: gym.Env[ObsType, ActType]):
self.screen_size = None
self.window = None
self.clock = None
self.auto_rendering = auto_rendering

# TODO: check whether patching "human" into the metadata render modes is still needed
if "human" not in self.metadata["render_modes"]:
self.metadata = deepcopy(self.env.metadata)
self.metadata["render_modes"].append("human")
Expand All @@ -495,20 +498,22 @@ def render_mode(self):
def step(self, action: ActType) -> tuple[ObsType, SupportsFloat, bool, bool, dict]:
"""Perform a step in the base environment and render a frame to the screen."""
result = super().step(action)
self._render_frame()
if self.auto_rendering:
self._render_frame()
return result

def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[ObsType, dict[str, Any]]:
"""Reset the base environment and render a frame to the screen."""
result = super().reset(seed=seed, options=options)
self._render_frame()
if self.auto_rendering:
self._render_frame()
return result

def render(self) -> None:
"""This method doesn't do much, actual rendering is performed in :meth:`step` and :meth:`reset`."""
return None
def render(self):
    """Draw the latest frame of the base environment on screen and return the render output.

    Delegates to :meth:`_render_frame`, which returns whatever the base
    environment's ``render()`` produced. :meth:`step` and :meth:`reset` already
    call :meth:`_render_frame` when ``auto_rendering`` is enabled, so an explicit
    call is mainly useful when ``auto_rendering`` is ``False``.
    """
    return self._render_frame()

def _render_frame(self):
"""Fetch the last frame from the base environment and render it to the screen."""
Expand All @@ -520,11 +525,12 @@ def _render_frame(self):
)
assert self.env.render_mode is not None
if self.env.render_mode.endswith("_list"):
last_rgb_array = self.env.render()
assert isinstance(last_rgb_array, list)
last_rgb_array = last_rgb_array[-1]
last_render = self.env.render()
assert isinstance(last_render, list)
last_rgb_array = last_render[-1]
else:
last_rgb_array = self.env.render()
last_render = self.env.render()
last_rgb_array = last_render

assert isinstance(
last_rgb_array, np.ndarray
Expand Down Expand Up @@ -553,6 +559,8 @@ def _render_frame(self):
self.clock.tick(self.metadata["render_fps"])
pygame.display.flip()

return last_render

def close(self):
"""Close the rendering window."""
if self.window is not None:
Expand Down
31 changes: 24 additions & 7 deletions gymnasium/wrappers/vector/rendering.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,18 @@ class HumanRendering(VectorWrapper):
"depth_array_list",
]

def __init__(self, env: VectorEnv, screen_size: tuple[int, int] | None = None):
def __init__(
self,
env: VectorEnv,
screen_size: tuple[int, int] | None = None,
auto_rendering: bool = True,
):
"""Constructor for Human Rendering of Vector-based environments.
Args:
env: The vector environment
screen_size: The rendering screen size otherwise the environment sub-env render size is used
auto_rendering: Whether to automatically render the environment on step() and reset()
"""
VectorWrapper.__init__(self, env)

Expand All @@ -42,7 +48,9 @@ def __init__(self, env: VectorEnv, screen_size: tuple[int, int] | None = None):
self.scaled_subenv_size, self.num_rows, self.num_cols = None, None, None
self.window = None
self.clock = None
self.auto_rendering = auto_rendering

# TODO: check whether patching "human" into the metadata render modes is still needed
if "human" not in self.metadata["render_modes"]:
self.metadata = deepcopy(self.env.metadata)
self.metadata["render_modes"].append("human")
Expand All @@ -57,7 +65,8 @@ def step(
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict[str, Any]]:
"""Perform a step in the base environment and render a frame to the screen."""
result = super().step(actions)
self._render_frame()
if self.auto_rendering:
self._render_frame()
return result

def reset(
Expand All @@ -68,9 +77,14 @@ def reset(
) -> tuple[ObsType, dict[str, Any]]:
"""Reset the base environment and render a frame to the screen."""
result = super().reset(seed=seed, options=options)
self._render_frame()
if self.auto_rendering:
self._render_frame()
return result

def render(self):
    """Draw the latest sub-environment frames on screen and return the render output.

    Delegates to :meth:`_render_frame`, which returns whatever the base vector
    environment's ``render()`` produced. :meth:`step` and :meth:`reset` already
    call :meth:`_render_frame` when ``auto_rendering`` is enabled, so an explicit
    call is mainly useful when ``auto_rendering`` is ``False``.
    """
    return self._render_frame()

def _render_frame(self):
"""Fetch the last frame from the base environment and render it to the screen."""
try:
Expand All @@ -82,11 +96,12 @@ def _render_frame(self):

assert self.env.render_mode is not None
if self.env.render_mode.endswith("_last"):
subenv_renders = self.env.render()
assert isinstance(subenv_renders, list)
subenv_renders = subenv_renders[-1]
last_render = self.env.render()
assert isinstance(last_render, list)
subenv_renders = last_render[-1]
else:
subenv_renders = self.env.render()
last_render = self.env.render()
subenv_renders = last_render

assert subenv_renders is not None
assert len(subenv_renders) == self.num_envs
Expand Down Expand Up @@ -173,6 +188,8 @@ def _render_frame(self):
self.clock.tick(self.metadata["render_fps"])
pygame.display.flip()

return last_render

def close(self):
"""Close the rendering window."""
if self.window is not None:
Expand Down
23 changes: 23 additions & 0 deletions tests/wrappers/test_human_rendering.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Test suite of HumanRendering wrapper."""
import re

import numpy as np
import pytest

import gymnasium as gym
Expand Down Expand Up @@ -31,3 +32,25 @@ def test_human_rendering():
):
HumanRendering(env)
env.close()


@pytest.mark.parametrize("env_id", ["CartPole-v1"])
@pytest.mark.parametrize("num_envs", [1, 3, 9])
@pytest.mark.parametrize("screen_size", [None])
def test_human_rendering_manual(env_id, num_envs, screen_size):
    """Check manual ``render()`` calls when automatic rendering is disabled.

    With ``auto_rendering=False`` the wrapper must not draw on ``step()`` or
    ``reset()``; an explicit ``render()`` call draws the frame and returns the
    base environment's render output.

    Args:
        env_id: Id of the environment to wrap.
        num_envs: Unused here (the wrapper under test is single-env); kept so the
            parametrization mirrors the vector test.
        screen_size: Unused here; kept for the same reason.
    """
    env = HumanRendering(
        # Build the environment from the parametrized id rather than a hard-coded one.
        gym.make(env_id, render_mode="rgb_array", disable_env_checker=True),
        auto_rendering=False,
    )
    try:
        assert env.render_mode == "human"
        assert not env.auto_rendering
        env.reset()

        for _ in range(75):
            _, _, terminated, truncated, _ = env.step(env.action_space.sample())
            if terminated or truncated:
                env.reset()
            rendering = env.render()
            # The base env runs in "rgb_array" mode, so render() must return one frame.
            assert isinstance(rendering, np.ndarray)
    finally:
        # Always release the rendering window, even when an assertion fails.
        env.close()
25 changes: 25 additions & 0 deletions tests/wrappers/vector/test_human_rendering.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Test suite of HumanRendering wrapper."""
import re

import numpy as np
import pytest

import gymnasium as gym
Expand All @@ -19,6 +20,7 @@ def test_num_envs_screen_size(env_id, num_envs, screen_size):
envs.reset()
for _ in range(25):
envs.step(envs.action_space.sample())

envs.close()


Expand All @@ -41,3 +43,26 @@ def test_render_modes():
),
):
HumanRendering(envs)


@pytest.mark.parametrize("env_id", ["CartPole-v1"])
@pytest.mark.parametrize("num_envs", [1, 3, 9])
@pytest.mark.parametrize("screen_size", [None])
def test_human_rendering_manual(env_id, num_envs, screen_size):
    """Check manual ``render()`` on the vector wrapper with automatic rendering disabled."""
    wrapped = HumanRendering(
        gym.make_vec(env_id, num_envs=num_envs, render_mode="rgb_array"),
        screen_size=screen_size,
        auto_rendering=False,
    )

    assert wrapped.render_mode == "human"
    assert not wrapped.auto_rendering

    wrapped.reset()

    # Take a single step, then ask for a frame explicitly.
    wrapped.step(wrapped.action_space.sample())
    frames = wrapped.render()

    # The output follows the base render mode: a list with one rgb array per sub-env.
    assert isinstance(frames, list)
    assert len(frames) == num_envs
    assert isinstance(frames[0], np.ndarray)

    wrapped.close()

0 comments on commit b537395

Please sign in to comment.