diff --git a/.github/unittest/linux_libs/scripts_jumanji/install.sh b/.github/unittest/linux_libs/scripts_jumanji/install.sh index 95a4a5a0e29..04875d6fa3d 100755 --- a/.github/unittest/linux_libs/scripts_jumanji/install.sh +++ b/.github/unittest/linux_libs/scripts_jumanji/install.sh @@ -28,15 +28,15 @@ git submodule sync && git submodule update --init --recursive printf "Installing PyTorch with cu121" if [[ "$TORCH_VERSION" == "nightly" ]]; then if [ "${CU_VERSION:-}" == cpu ] ; then - pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu -U + pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu -U else - pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 -U + pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu121 -U fi elif [[ "$TORCH_VERSION" == "stable" ]]; then if [ "${CU_VERSION:-}" == cpu ] ; then - pip3 install torch --index-url https://download.pytorch.org/whl/cpu + pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu else - pip3 install torch --index-url https://download.pytorch.org/whl/cu121 + pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu121 fi else printf "Failed to install pytorch" diff --git a/torchrl/envs/libs/jumanji.py b/torchrl/envs/libs/jumanji.py index aa5b486577f..0b14bdb2a09 100644 --- a/torchrl/envs/libs/jumanji.py +++ b/torchrl/envs/libs/jumanji.py @@ -9,7 +9,6 @@ import numpy as np import torch -import torchvision.transforms.v2.functional from packaging import version from tensordict import TensorDict, TensorDictBase @@ -542,9 +541,13 @@ def render( import jax import jax.numpy as jnp import jumanji - import matplotlib - import matplotlib.pyplot as plt - import PIL + try: + import matplotlib + import matplotlib.pyplot as plt + import PIL + import torchvision.transforms.v2.functional + except ImportError as err: + raise ImportError("Rendering with Jumanji requires torchvision, matplotlib and PIL to be installed.") from err if matplotlib_backend is not None: matplotlib.use(matplotlib_backend)