Skip to content

Commit

Permalink
Update tests
Browse files Browse the repository at this point in the history
Signed-off-by: guillemdb <[email protected]>
  • Loading branch information
Guillemdb committed Feb 1, 2025
1 parent 388d6d7 commit 17a7fe0
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 6 deletions.
12 changes: 12 additions & 0 deletions src/plangym/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ class TestPlanEnv:
"autoreset",
"delay_setup",
"return_image",
"img_shape",
)

def test_init(self, env):
Expand Down Expand Up @@ -141,6 +142,13 @@ def test_obs_shape(self, env):
obs, *_ = env.step(env.sample_action())
assert obs.shape == env.obs_shape, (obs.shape, env.obs_shape)

def test_img_shape(self, env):
assert hasattr(env, "img_shape")
assert isinstance(env.img_shape, tuple)
if env.img_shape:
for val in env.img_shape:
assert isinstance(val, int)

def test_action_shape(self, env):
assert hasattr(env, "action_shape")
assert isinstance(env.action_shape, tuple)
Expand Down Expand Up @@ -193,6 +201,10 @@ def test_set_state(self, env):
def test_reset(self, env):
_ = env.reset(return_state=False)
state, obs, info = env.reset(return_state=True)
if env.return_image:
assert "rgb" in info
assert isinstance(info["rgb"], numpy.ndarray)
assert info["rgb"].shape == env.img_shape
state_is_array = isinstance(state, numpy.ndarray)
obs_is_array = isinstance(obs, numpy.ndarray)
assert isinstance(info, dict), info
Expand Down
13 changes: 7 additions & 6 deletions src/plangym/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,12 @@ def return_image(self) -> bool:
"""
return self._return_image

@cached_property
def img_shape(self) -> tuple[int, ...]:
"""Return the shape of the image returned by the environment."""
img = self.get_image()
return img.shape

def get_image(self) -> None | numpy.ndarray:
"""Return a numpy array containing the rendered view of the environment.
Expand Down Expand Up @@ -564,7 +570,7 @@ def gym_env(self):
return self._gym_env

@property
def obs_shape(self) -> tuple[int, ...]:
def obs_shape(self) -> tuple[int, ...] | None:
"""Tuple containing the shape of the *observations* returned by the Environment."""
if self.observation_space is None:
return None
Expand Down Expand Up @@ -614,11 +620,6 @@ def remove_time_limit(self) -> bool:
"""Return True if the Environment can only be stepped for a limited number of times."""
return self._remove_time_limit

@cached_property
def img_shape(self):
img = self.get_image()
return img.shape

def setup(self):
"""Initialize the target :class:`gym.Env` instance.
Expand Down

0 comments on commit 17a7fe0

Please sign in to comment.