diff --git a/src/plangym/api_tests.py b/src/plangym/api_tests.py index fe9ae4a..162bcfa 100644 --- a/src/plangym/api_tests.py +++ b/src/plangym/api_tests.py @@ -108,6 +108,7 @@ class TestPlanEnv: "autoreset", "delay_setup", "return_image", + "img_shape", ) def test_init(self, env): @@ -141,6 +142,13 @@ def test_obs_shape(self, env): obs, *_ = env.step(env.sample_action()) assert obs.shape == env.obs_shape, (obs.shape, env.obs_shape) + def test_img_shape(self, env): + assert hasattr(env, "img_shape") + assert isinstance(env.img_shape, tuple) + if env.img_shape: + for val in env.img_shape: + assert isinstance(val, int) + def test_action_shape(self, env): assert hasattr(env, "action_shape") assert isinstance(env.action_shape, tuple) @@ -193,6 +201,10 @@ def test_set_state(self, env): def test_reset(self, env): _ = env.reset(return_state=False) state, obs, info = env.reset(return_state=True) + if env.return_image: + assert "rgb" in info + assert isinstance(info["rgb"], numpy.ndarray) + assert info["rgb"].shape == env.img_shape state_is_array = isinstance(state, numpy.ndarray) obs_is_array = isinstance(obs, numpy.ndarray) assert isinstance(info, dict), info diff --git a/src/plangym/core.py b/src/plangym/core.py index ef1cf0d..d70f152 100644 --- a/src/plangym/core.py +++ b/src/plangym/core.py @@ -107,6 +107,12 @@ def return_image(self) -> bool: """ return self._return_image + @cached_property + def img_shape(self) -> tuple[int, ...]: + """Return the shape of the image returned by the environment.""" + img = self.get_image() + return img.shape + def get_image(self) -> None | numpy.ndarray: """Return a numpy array containing the rendered view of the environment. @@ -564,7 +570,7 @@ def gym_env(self): return self._gym_env @property - def obs_shape(self) -> tuple[int, ...]: + def obs_shape(self) -> tuple[int, ...] | None: """Tuple containing the shape of the *observations* returned by the Environment.""" if self.observation_space is None: return None @@ -614,11 +620,6 @@ def remove_time_limit(self) -> bool: """Return True if the Environment can only be stepped for a limited number of times.""" return self._remove_time_limit - @cached_property - def img_shape(self): - img = self.get_image() - return img.shape - def setup(self): """Initialize the target :class:`gym.Env` instance.