diff --git a/test/test_libs.py b/test/test_libs.py
index 895b6695ed3..67fe74f9f6e 100644
--- a/test/test_libs.py
+++ b/test/test_libs.py
@@ -4,17 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 import functools
 import gc
-import importlib
-import os
-from contextlib import nullcontext
-from pathlib import Path
-
-from torchrl._utils import logger as torchrl_logger
-
-from torchrl.data.datasets.gen_dgrl import GenDGRLExperienceReplay
-
-from torchrl.envs.transforms import ActionMask, TransformedEnv
-from torchrl.modules import MaskedCategorical
+import importlib.util
 
 _has_isaac = importlib.util.find_spec("isaacgym") is not None
 
@@ -23,11 +13,13 @@
     import isaacgym  # noqa
     import isaacgymenvs  # noqa
     from torchrl.envs.libs.isaacgym import IsaacGymEnv
-
 import argparse
 import importlib
+import os
 import time
+from contextlib import nullcontext
+from pathlib import Path
 from sys import platform
 from typing import Optional, Union
 
@@ -59,7 +51,8 @@
     TensorDictSequential,
 )
 from torch import nn
-from torchrl._utils import implement_for
+
+from torchrl._utils import implement_for, logger as torchrl_logger
 from torchrl.collectors.collectors import SyncDataCollector
 from torchrl.data import (
     BinaryDiscreteTensorSpec,
@@ -76,6 +69,8 @@
 )
 from torchrl.data.datasets.atari_dqn import AtariDQNExperienceReplay
 from torchrl.data.datasets.d4rl import D4RLExperienceReplay
+
+from torchrl.data.datasets.gen_dgrl import GenDGRLExperienceReplay
 from torchrl.data.datasets.minari_data import MinariExperienceReplay
 from torchrl.data.datasets.openml import OpenMLExperienceReplay
 from torchrl.data.datasets.openx import OpenXExperienceReplay
@@ -116,13 +111,21 @@
 from torchrl.envs.libs.robohive import _has_robohive, RoboHiveEnv
 from torchrl.envs.libs.smacv2 import _has_smacv2, SMACv2Env
 from torchrl.envs.libs.vmas import _has_vmas, VmasEnv, VmasWrapper
+
+from torchrl.envs.transforms import ActionMask, TransformedEnv
 from torchrl.envs.utils import (
     check_env_specs,
     ExplorationType,
     MarlGroupMapType,
     RandomPolicy,
 )
-from torchrl.modules import ActorCriticOperator, MLP, SafeModule, ValueOperator
+from torchrl.modules import (
+    ActorCriticOperator,
+    MaskedCategorical,
+    MLP,
+    SafeModule,
+    ValueOperator,
+)
 
 _has_d4rl = importlib.util.find_spec("d4rl") is not None
 
@@ -3218,22 +3221,28 @@ def test_data(self, dataset):
 )
 @pytest.mark.parametrize("num_envs", [10, 20])
 @pytest.mark.parametrize("device", get_default_devices())
+@pytest.mark.parametrize("from_pixels", [False])
 class TestIsaacGym:
     @classmethod
-    def _run_on_proc(cls, q, task, num_envs, device):
+    def _run_on_proc(cls, q, task, num_envs, device, from_pixels):
         try:
-            env = IsaacGymEnv(task=task, num_envs=num_envs, device=device)
+            env = IsaacGymEnv(
+                task=task, num_envs=num_envs, device=device, from_pixels=from_pixels
+            )
             check_env_specs(env)
             q.put(("succeeded!", None))
         except Exception as err:
             q.put(("failed!", err))
             raise err
 
-    def test_env(self, task, num_envs, device):
+    def test_env(self, task, num_envs, device, from_pixels):
        from torch import multiprocessing as mp
 
         q = mp.Queue(1)
-        proc = mp.Process(target=self._run_on_proc, args=(q, task, num_envs, device))
+        self._run_on_proc(q, task, num_envs, device, from_pixels)
+        proc = mp.Process(
+            target=self._run_on_proc, args=(q, task, num_envs, device, from_pixels)
+        )
         try:
             proc.start()
             msg, error = q.get()
diff --git a/torchrl/envs/libs/gym.py b/torchrl/envs/libs/gym.py
index 72b3494fc29..ba06fdf5586 100644
--- a/torchrl/envs/libs/gym.py
+++ b/torchrl/envs/libs/gym.py
@@ -949,6 +949,9 @@ def _reward_space(self, env):  # noqa: F811
         return rs
 
     def _make_specs(self, env: "gym.Env", batch_size=None) -> None:  # noqa: F821
+        # If batch_size is provided, we use it to tell which batch size must be used
+        # instead of self.batch_size
+        cur_batch_size = self.batch_size if batch_size is None else torch.Size([])
         action_spec = _gym_to_torchrl_spec_transform(
             env.action_space,
             device=self.device,
@@ -962,14 +965,14 @@ def _make_specs(self, env: "gym.Env", batch_size=None) -> None:  # noqa: F821
         if not isinstance(observation_spec, CompositeSpec):
             if self.from_pixels:
                 observation_spec = CompositeSpec(
-                    pixels=observation_spec, shape=self.batch_size
+                    pixels=observation_spec, shape=cur_batch_size
                 )
             else:
                 observation_spec = CompositeSpec(
-                    observation=observation_spec, shape=self.batch_size
+                    observation=observation_spec, shape=cur_batch_size
                 )
-        elif observation_spec.shape[: len(self.batch_size)] != self.batch_size:
-            observation_spec.shape = self.batch_size
+        elif observation_spec.shape[: len(cur_batch_size)] != cur_batch_size:
+            observation_spec.shape = cur_batch_size
 
         reward_space = self._reward_space(env)
         if reward_space is not None:
@@ -989,10 +992,11 @@ def _make_specs(self, env: "gym.Env", batch_size=None) -> None:  # noqa: F821
             observation_spec = observation_spec.expand(
                 *batch_size, *observation_spec.shape
             )
+
         self.done_spec = self._make_done_spec()
         self.action_spec = action_spec
-        if reward_spec.shape[: len(self.batch_size)] != self.batch_size:
-            self.reward_spec = reward_spec.expand(*self.batch_size, *reward_spec.shape)
+        if reward_spec.shape[: len(cur_batch_size)] != cur_batch_size:
+            self.reward_spec = reward_spec.expand(*cur_batch_size, *reward_spec.shape)
         else:
             self.reward_spec = reward_spec
         self.observation_spec = observation_spec
diff --git a/torchrl/envs/libs/isaacgym.py b/torchrl/envs/libs/isaacgym.py
index 90f90043ab7..4c56bea304a 100644
--- a/torchrl/envs/libs/isaacgym.py
+++ b/torchrl/envs/libs/isaacgym.py
@@ -14,6 +14,7 @@
 import torch
 from tensordict import TensorDictBase
 
+from torchrl.data import CompositeSpec
 from torchrl.envs.libs.gym import GymWrapper
 from torchrl.envs.utils import _classproperty, make_composite_from_td
 
@@ -49,9 +50,8 @@ def __init__(
         warnings.warn(
             "IsaacGym environment support is an experimental feature that may change in the future."
         )
-        num_envs = env.num_envs
         super().__init__(
-            env, torch.device(env.device), batch_size=torch.Size([num_envs]), **kwargs
+            env, torch.device(env.device), batch_size=torch.Size([]), **kwargs
         )
         if not hasattr(self, "task"):
             # by convention in IsaacGymEnvs
@@ -59,9 +59,14 @@ def _make_specs(self, env: "gym.Env") -> None:  # noqa: F821
         super()._make_specs(env, batch_size=self.batch_size)
-        self.full_done_spec = {
-            key: spec.squeeze(-1) for key, spec in self.full_done_spec.items(True, True)
-        }
+        self.full_done_spec = CompositeSpec(
+            {
+                key: spec.squeeze(-1)
+                for key, spec in self.full_done_spec.items(True, True)
+            },
+            shape=self.batch_size,
+        )
+
         self.observation_spec["obs"] = self.observation_spec["observation"]
         del self.observation_spec["observation"]
@@ -78,13 +83,25 @@ def _make_specs(self, env: "gym.Env") -> None:  # noqa: F821
         obs_spec.unlock_()
         obs_spec.update(specs)
         obs_spec.lock_()
-        self.__dict__["full_observation_spec"] = obs_spec
+
+    def _output_transform(self, output):
+        obs, reward, done, info = output
+        if self.from_pixels:
+            obs["pixels"] = self._env.render(mode="rgb_array")
+        return obs, reward, done ^ done, done, done, info
+
+    def _reset_output_transform(self, reset_data):
+        reset_data.pop("reward", None)
+        if self.from_pixels:
+            reset_data["pixels"] = self._env.render(mode="rgb_array")
+        return reset_data, {}
 
     @classmethod
-    def _make_envs(cls, *, task, num_envs, device, seed=None, headless=True, **kwargs):
+    def _make_envs(cls, *, task, num_envs, device, seed=None, headless=False, **kwargs):
         import isaacgym  # noqa
         import isaacgymenvs  # noqa
 
+        _ = kwargs.pop("from_pixels", None)
         envs = isaacgymenvs.make(
             seed=seed,
             task=task,
@@ -125,15 +142,8 @@ def read_done(
             done = done.bool()
         return terminated, truncated, done, done.any()
 
-    def read_reward(self, total_reward, step_reward):
-        """Reads a reward and the total reward so far (in the frame skip loop) and returns a sum of the two.
-
-        Args:
-            total_reward (torch.Tensor or TensorDict): total reward so far in the step
-            step_reward (reward in the format provided by the inner env): reward of this particular step
-
-        """
-        return total_reward + step_reward
+    def read_reward(self, total_reward):
+        return total_reward
 
     def read_obs(
         self, observations: Union[Dict[str, Any], torch.Tensor, np.ndarray]
@@ -183,6 +193,13 @@ def __init__(self, task=None, *, env=None, num_envs, device, **kwargs):
             raise RuntimeError("Cannot provide both `task` and `env` arguments.")
         elif env is not None:
             task = env
-        envs = self._make_envs(task=task, num_envs=num_envs, device=device, **kwargs)
+        from_pixels = kwargs.pop("from_pixels", False)
+        envs = self._make_envs(
+            task=task,
+            num_envs=num_envs,
+            device=device,
+            virtual_screen_capture=False,
+            **kwargs,
+        )
         self.task = task
-        super().__init__(envs, **kwargs)
+        super().__init__(envs, from_pixels=from_pixels, **kwargs)
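
For reference, a minimal usage sketch of the entry point touched by this patch, showing how IsaacGymEnv would be built with the new from_pixels keyword. It assumes a working isaacgym/isaacgymenvs installation; the task name "Ant", the environment count, and the device are illustrative assumptions rather than values taken from the patch.

    # Sketch only: exercises the from_pixels keyword added by this patch.
    from torchrl.envs.libs.isaacgym import IsaacGymEnv
    from torchrl.envs.utils import check_env_specs

    # "Ant" is an example task name; any installed isaacgymenvs task would do.
    env = IsaacGymEnv(task="Ant", num_envs=64, device="cuda:0", from_pixels=False)
    check_env_specs(env)  # validates the reworked batch_size/done specs
    td = env.rollout(3)  # short rollout as a smoke test
    print(td)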