diff --git a/src/plangym/control/dm_control.py b/src/plangym/control/dm_control.py index 638a063..0fe14ef 100644 --- a/src/plangym/control/dm_control.py +++ b/src/plangym/control/dm_control.py @@ -47,6 +47,7 @@ def __init__( render_mode="rgb_array", obs_type: str | None = None, remove_time_limit=None, # noqa: ARG002 + return_image: bool = False, ): """Initialize a :class:`DMControlEnv`. @@ -69,6 +70,7 @@ def __init__( render_mode: None|human|rgb_array. remove_time_limit: Ignored. obs_type: One of {"coords", "rgb", "grayscale"}. + return_image: If ``True``, add a "rgb" key to the observation dict. """ self._visualize_reward = visualize_reward @@ -84,6 +86,7 @@ def __init__( autoreset=autoreset, render_mode=render_mode, obs_type=obs_type, + return_image=return_image, ) @property