Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Apr 29, 2024
1 parent 418d2b8 commit ca337e5
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 8 deletions.
8 changes: 4 additions & 4 deletions .github/unittest/linux_libs/scripts_jumanji/install.sh
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,15 @@ git submodule sync && git submodule update --init --recursive
printf "Installing PyTorch with cu121"
if [[ "$TORCH_VERSION" == "nightly" ]]; then
if [ "${CU_VERSION:-}" == cpu ] ; then
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu -U
pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu -U
else
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 -U
pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu121 -U
fi
elif [[ "$TORCH_VERSION" == "stable" ]]; then
if [ "${CU_VERSION:-}" == cpu ] ; then
pip3 install torch --index-url https://download.pytorch.org/whl/cpu
pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu
else
pip3 install torch --index-url https://download.pytorch.org/whl/cu121
pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu121
fi
else
printf "Failed to install pytorch"
Expand Down
11 changes: 7 additions & 4 deletions torchrl/envs/libs/jumanji.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

import numpy as np
import torch
import torchvision.transforms.v2.functional
from packaging import version
from tensordict import TensorDict, TensorDictBase

Expand Down Expand Up @@ -542,9 +541,13 @@ def render(
import jax
import jax.numpy as jnp
import jumanji
import matplotlib
import matplotlib.pyplot as plt
import PIL
try:
import matplotlib
import matplotlib.pyplot as plt
import PIL
import torchvision.transforms.v2.functional
except ImportError as err:
raise ImportError("Rendering with Jumanji requires torchvision, matplotlib and PIL to be installed.") from err

if matplotlib_backend is not None:
matplotlib.use(matplotlib_backend)
Expand Down

0 comments on commit ca337e5

Please sign in to comment.