diff --git a/.github/unittest/linux/scripts/environment.yml b/.github/unittest/linux/scripts/environment.yml
index 30e01cfc4b5..2234683a497 100644
--- a/.github/unittest/linux/scripts/environment.yml
+++ b/.github/unittest/linux/scripts/environment.yml
@@ -25,6 +25,7 @@ dependencies:
- imageio==2.26.0
- wandb
- dm_control
+ - mujoco
- mlflow
- av
- coverage
diff --git a/.github/unittest/linux/scripts/run_all.sh b/.github/unittest/linux/scripts/run_all.sh
index 38235043d3f..22cdad1b479 100755
--- a/.github/unittest/linux/scripts/run_all.sh
+++ b/.github/unittest/linux/scripts/run_all.sh
@@ -3,8 +3,8 @@
set -euxo pipefail
set -v
-# ==================================================================================== #
-# ================================ Setup env ========================================= #
+# =============================================================================== #
+# ================================ Init ========================================= #
if [[ $OSTYPE != 'darwin'* ]]; then
@@ -31,6 +31,10 @@ if [[ $OSTYPE != 'darwin'* ]]; then
cp $this_dir/10_nvidia.json /usr/share/glvnd/egl_vendor.d/10_nvidia.json
fi
+
+# ==================================================================================== #
+# ================================ Setup env ========================================= #
+
# Avoid error: "fatal: unsafe repository"
git config --global --add safe.directory '*'
root_dir="$(git rev-parse --show-toplevel)"
@@ -61,7 +65,7 @@ if [ ! -d "${env_dir}" ]; then
fi
conda activate "${env_dir}"
-# 4. Install Conda dependencies
+# 3. Install Conda dependencies
printf "* Installing dependencies (except PyTorch)\n"
echo " - python=${PYTHON_VERSION}" >> "${this_dir}/environment.yml"
cat "${this_dir}/environment.yml"
@@ -76,7 +80,7 @@ export DISPLAY=:0
export SDL_VIDEODRIVER=dummy
# legacy from bash scripts: remove?
-conda env config vars set MUJOCO_GL=$MUJOCO_GL PYOPENGL_PLATFORM=$MUJOCO_GL DISPLAY=:0 SDL_VIDEODRIVER=dummy LAZY_LEGACY_OP=False
+conda env config vars set MUJOCO_GL=$MUJOCO_GL PYOPENGL_PLATFORM=$MUJOCO_GL DISPLAY=:0 SDL_VIDEODRIVER=dummy LAZY_LEGACY_OP=False RL_LOGGING_LEVEL=DEBUG
pip3 install pip --upgrade
pip install virtualenv
@@ -88,10 +92,15 @@ conda deactivate
conda activate "${env_dir}"
echo "installing gymnasium"
-pip3 install "gymnasium"
-pip3 install ale_py
-pip3 install mo-gymnasium[mujoco] # requires here bc needs mujoco-py
-pip3 install mujoco -U
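+# gymnasium 1.0 introduces breaking API changes that are not supported here yet, hence the <1.0 pins below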
+if [[ "$PYTHON_VERSION" == "3.12" ]]; then
+ pip3 install ale-py
+ pip3 install sympy
+ pip3 install "gymnasium[accept-rom-license,mujoco]<1.0" mo-gymnasium[mujoco]
+else
+ pip3 install "gymnasium[atari,accept-rom-license,mujoco]<1.0" mo-gymnasium[mujoco]
+fi
+pip3 install "mujoco" -U
# sanity check: remove?
python3 -c """
@@ -127,13 +135,13 @@ if [[ "$TORCH_VERSION" == "nightly" ]]; then
if [ "${CU_VERSION:-}" == cpu ] ; then
pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu -U
else
- pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/$CU_VERSION
+ pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/$CU_VERSION -U
fi
elif [[ "$TORCH_VERSION" == "stable" ]]; then
if [ "${CU_VERSION:-}" == cpu ] ; then
- pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu
+ pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu -U
else
- pip3 install torch torchvision --index-url https://download.pytorch.org/whl/$CU_VERSION
+ pip3 install torch torchvision --index-url https://download.pytorch.org/whl/$CU_VERSION -U
fi
else
printf "Failed to install pytorch"
@@ -181,7 +189,9 @@ fi
export PYTORCH_TEST_WITH_SLOW='1'
python -m torch.utils.collect_env
-# Avoid error: "fatal: unsafe repository"
+## Avoid error: "fatal: unsafe repository"
+#git config --global --add safe.directory '*'
+#root_dir="$(git rev-parse --show-toplevel)"
# solves ImportError: /lib64/libstdc++.so.6: version `GLIBCXX_3.4.21' not found
#export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$lib_dir
@@ -198,7 +208,8 @@ if [ "${CU_VERSION:-}" != cpu ] ; then
--timeout=120 --mp_fork_if_no_cuda
else
python .github/unittest/helpers/coverage_run_parallel.py -m pytest test \
- --instafail --durations 200 -vv --capture no --ignore test/test_rlhf.py --ignore test/test_distributed.py \
+ --instafail --durations 200 -vv --capture no --ignore test/test_rlhf.py \
+ --ignore test/test_distributed.py \
--timeout=120 --mp_fork_if_no_cuda
fi
diff --git a/.github/unittest/linux_distributed/scripts/environment.yml b/.github/unittest/linux_distributed/scripts/environment.yml
index 6d27071791b..76160f7a16a 100644
--- a/.github/unittest/linux_distributed/scripts/environment.yml
+++ b/.github/unittest/linux_distributed/scripts/environment.yml
@@ -24,6 +24,7 @@ dependencies:
- imageio==2.26.0
- wandb
- dm_control
+ - mujoco
- mlflow
- av
- coverage
diff --git a/.github/unittest/linux_distributed/scripts/setup_env.sh b/.github/unittest/linux_distributed/scripts/setup_env.sh
index 501dbe1c914..2a48ab21459 100755
--- a/.github/unittest/linux_distributed/scripts/setup_env.sh
+++ b/.github/unittest/linux_distributed/scripts/setup_env.sh
@@ -119,7 +119,7 @@ if [[ $OSTYPE != 'darwin'* ]]; then
rm ale_py-0.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
fi
echo "installing gymnasium"
- pip install "gymnasium[atari,accept-rom-license]"
+ pip install "gymnasium[atari,accept-rom-license]<1.0"
else
- pip install "gymnasium[atari,accept-rom-license]"
+ pip install "gymnasium[atari,accept-rom-license]<1.0"
fi
diff --git a/.github/unittest/linux_examples/scripts/environment.yml b/.github/unittest/linux_examples/scripts/environment.yml
index 688921f826a..f7dddbc5e3c 100644
--- a/.github/unittest/linux_examples/scripts/environment.yml
+++ b/.github/unittest/linux_examples/scripts/environment.yml
@@ -22,6 +22,7 @@ dependencies:
- hydra-core
- imageio==2.26.0
- dm_control
+ - mujoco
- mlflow
- av
- coverage
diff --git a/.github/unittest/linux_examples/scripts/run_all.sh b/.github/unittest/linux_examples/scripts/run_all.sh
index 37719e51074..073ef59ed3f 100755
--- a/.github/unittest/linux_examples/scripts/run_all.sh
+++ b/.github/unittest/linux_examples/scripts/run_all.sh
@@ -130,7 +130,7 @@ elif [[ $PY_VERSION == *"3.11"* ]]; then
pip install ale_py-0.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
rm ale_py-0.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
fi
-pip install "gymnasium[atari,accept-rom-license]"
+pip install "gymnasium[atari,accept-rom-license]<1.0"
# ============================================================================================ #
# ================================ PyTorch & TorchRL ========================================= #
@@ -150,15 +150,15 @@ git submodule sync && git submodule update --init --recursive
printf "Installing PyTorch with %s\n" "${CU_VERSION}"
if [[ "$TORCH_VERSION" == "nightly" ]]; then
if [ "${CU_VERSION:-}" == cpu ] ; then
- pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu -U
+ pip3 install --pre torch torchvision numpy==1.26.4 --index-url https://download.pytorch.org/whl/nightly/cpu -U
else
- pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/$CU_VERSION
+ pip3 install --pre torch torchvision numpy==1.26.4 --index-url https://download.pytorch.org/whl/nightly/$CU_VERSION
fi
elif [[ "$TORCH_VERSION" == "stable" ]]; then
if [ "${CU_VERSION:-}" == cpu ] ; then
- pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu
+ pip3 install torch torchvision numpy==1.26.4 --index-url https://download.pytorch.org/whl/cpu
else
- pip3 install torch torchvision --index-url https://download.pytorch.org/whl/$CU_VERSION
+ pip3 install torch torchvision numpy==1.26.4 --index-url https://download.pytorch.org/whl/$CU_VERSION
fi
else
printf "Failed to install pytorch"
diff --git a/.github/unittest/linux_examples/scripts/run_test.sh b/.github/unittest/linux_examples/scripts/run_test.sh
index f8b700c0410..ef0d081f8fd 100755
--- a/.github/unittest/linux_examples/scripts/run_test.sh
+++ b/.github/unittest/linux_examples/scripts/run_test.sh
@@ -205,6 +205,13 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/iq
env.train_num_envs=2 \
logger.mode=offline \
logger.backend=
+ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/gail/gail.py \
+ ppo.collector.total_frames=48 \
+ replay_buffer.batch_size=16 \
+ ppo.loss.mini_batch_size=10 \
+ ppo.collector.frames_per_batch=16 \
+ logger.mode=offline \
+ logger.backend=
# With single envs
python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dreamer/dreamer.py \
diff --git a/.github/unittest/linux_libs/scripts_brax/install.sh b/.github/unittest/linux_libs/scripts_brax/install.sh
index 80efdc536ab..20a2643dac8 100755
--- a/.github/unittest/linux_libs/scripts_brax/install.sh
+++ b/.github/unittest/linux_libs/scripts_brax/install.sh
@@ -34,7 +34,7 @@ if [[ "$TORCH_VERSION" == "nightly" ]]; then
fi
elif [[ "$TORCH_VERSION" == "stable" ]]; then
if [ "${CU_VERSION:-}" == cpu ] ; then
- pip3 install torch --index-url https://download.pytorch.org/whl/cpu
+ pip3 install torch --index-url https://download.pytorch.org/whl/cpu -U
else
pip3 install torch --index-url https://download.pytorch.org/whl/cu121
fi
diff --git a/.github/unittest/linux_libs/scripts_envpool/environment.yml b/.github/unittest/linux_libs/scripts_envpool/environment.yml
index 9259a2a4a43..74a3c91cf06 100644
--- a/.github/unittest/linux_libs/scripts_envpool/environment.yml
+++ b/.github/unittest/linux_libs/scripts_envpool/environment.yml
@@ -19,4 +19,5 @@ dependencies:
- pyyaml
- scipy
- dm_control
+ - mujoco
- coverage
diff --git a/.github/unittest/linux_libs/scripts_envpool/setup_env.sh b/.github/unittest/linux_libs/scripts_envpool/setup_env.sh
index bb5c09079ea..aabc153bde3 100755
--- a/.github/unittest/linux_libs/scripts_envpool/setup_env.sh
+++ b/.github/unittest/linux_libs/scripts_envpool/setup_env.sh
@@ -82,9 +82,9 @@ if [[ $OSTYPE != 'darwin'* ]]; then
fi
echo "installing gym"
# envpool does not currently work with gymnasium
- pip install "gym[atari,accept-rom-license]"
+ pip install "gym[atari,accept-rom-license]<1.0"
else
- pip install "gym[atari,accept-rom-license]"
+ pip install "gym[atari,accept-rom-license]<1.0"
fi
pip install envpool treevalue
diff --git a/.github/unittest/linux_libs/scripts_gym/batch_scripts.sh b/.github/unittest/linux_libs/scripts_gym/batch_scripts.sh
index 11921b44821..dc264e07b2d 100755
--- a/.github/unittest/linux_libs/scripts_gym/batch_scripts.sh
+++ b/.github/unittest/linux_libs/scripts_gym/batch_scripts.sh
@@ -6,6 +6,7 @@
DIR="$(cd "$(dirname "$0")" && pwd)"
set -e
+set -v
eval "$(./conda/bin/conda shell.bash hook)"
conda activate ./env
@@ -139,7 +140,7 @@ conda deactivate
conda create --prefix ./cloned_env --clone ./env -y
conda activate ./cloned_env
-pip3 install 'gymnasium[accept-rom-license,ale-py,atari]' mo-gymnasium gymnasium-robotics -U
+pip3 install 'gymnasium[accept-rom-license,ale-py,atari]<1.0' mo-gymnasium gymnasium-robotics -U
$DIR/run_test.sh
diff --git a/.github/unittest/linux_libs/scripts_gym/install.sh b/.github/unittest/linux_libs/scripts_gym/install.sh
index d3eac779861..a66fe5fddd1 100755
--- a/.github/unittest/linux_libs/scripts_gym/install.sh
+++ b/.github/unittest/linux_libs/scripts_gym/install.sh
@@ -7,6 +7,7 @@ unset PYTORCH_VERSION
apt-get update && apt-get install -y git wget gcc g++
set -e
+set -v
eval "$(./conda/bin/conda shell.bash hook)"
conda activate ./env
@@ -39,7 +40,7 @@ printf "Installing PyTorch with %s\n" "${CU_VERSION}"
if [ "${CU_VERSION:-}" == cpu ] ; then
conda install pytorch==2.0 torchvision==0.15 cpuonly -c pytorch -y
else
- conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.8 -c pytorch -c nvidia -y
+ conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.8 numpy==1.26 numpy-base==1.26 -c pytorch -c nvidia -y
fi
# Solving circular import: https://stackoverflow.com/questions/75501048/how-to-fix-attributeerror-partially-initialized-module-charset-normalizer-has
diff --git a/.github/unittest/linux_libs/scripts_habitat/setup_env.sh b/.github/unittest/linux_libs/scripts_habitat/setup_env.sh
index fc182a669ea..6ad970c3f47 100755
--- a/.github/unittest/linux_libs/scripts_habitat/setup_env.sh
+++ b/.github/unittest/linux_libs/scripts_habitat/setup_env.sh
@@ -67,8 +67,10 @@ pip install pip --upgrade
conda env update --file "${this_dir}/environment.yml" --prune
-#conda install habitat-sim withbullet headless -c conda-forge -c aihabitat -y
conda install habitat-sim withbullet headless -c conda-forge -c aihabitat -y
-conda run python -m pip install git+https://github.com/facebookresearch/habitat-lab.git@stable#subdirectory=habitat-lab
-#conda run python -m pip install git+https://github.com/facebookresearch/habitat-lab.git#subdirectory=habitat-baselines
+git clone https://github.com/facebookresearch/habitat-lab.git
+cd habitat-lab
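+# the habitat-lab repository hosts both the habitat-lab and habitat-baselines packages as subdirectories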
+pip3 install -e habitat-lab
+pip3 install -e habitat-baselines # install habitat_baselines
conda run python -m pip install "gym[atari,accept-rom-license]" pygame
diff --git a/.github/unittest/linux_libs/scripts_minari/environment.yml b/.github/unittest/linux_libs/scripts_minari/environment.yml
index 27963a42a24..ad5bfc12650 100644
--- a/.github/unittest/linux_libs/scripts_minari/environment.yml
+++ b/.github/unittest/linux_libs/scripts_minari/environment.yml
@@ -17,4 +17,4 @@ dependencies:
- pyyaml
- scipy
- hydra-core
- - minari
+ - minari[gcs,hdf5]
diff --git a/.github/unittest/linux_libs/scripts_open_spiel/environment.yml b/.github/unittest/linux_libs/scripts_open_spiel/environment.yml
new file mode 100644
index 00000000000..937c37d58f6
--- /dev/null
+++ b/.github/unittest/linux_libs/scripts_open_spiel/environment.yml
@@ -0,0 +1,20 @@
+channels:
+ - pytorch
+ - defaults
+dependencies:
+ - pip
+ - pip:
+ - hypothesis
+ - future
+ - cloudpickle
+ - pytest
+ - pytest-cov
+ - pytest-mock
+ - pytest-instafail
+ - pytest-rerunfailures
+ - pytest-error-for-skips
+ - expecttest
+ - pyyaml
+ - scipy
+ - hydra-core
+ - open_spiel
diff --git a/.github/unittest/linux_libs/scripts_open_spiel/install.sh b/.github/unittest/linux_libs/scripts_open_spiel/install.sh
new file mode 100755
index 00000000000..95a4a5a0e29
--- /dev/null
+++ b/.github/unittest/linux_libs/scripts_open_spiel/install.sh
@@ -0,0 +1,62 @@
+#!/usr/bin/env bash
+
+unset PYTORCH_VERSION
+# For the unit tests, nightly PyTorch is installed by the section below,
+# so there is no need to set PYTORCH_VERSION.
+# In fact, keeping PYTORCH_VERSION would force us to hardcode the PyTorch version in the config.
+
+set -e
+
+eval "$(./conda/bin/conda shell.bash hook)"
+conda activate ./env
+
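+# CU_VERSION is either "cpu" or a compact CUDA tag such as "cu118"/"cu121";
+# the slicing below recovers the dotted version string (e.g. "cu121" -> "12.1").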
+if [ "${CU_VERSION:-}" == cpu ] ; then
+ version="cpu"
+else
+ if [[ ${#CU_VERSION} -eq 4 ]]; then
+ CUDA_VERSION="${CU_VERSION:2:1}.${CU_VERSION:3:1}"
+ elif [[ ${#CU_VERSION} -eq 5 ]]; then
+ CUDA_VERSION="${CU_VERSION:2:2}.${CU_VERSION:4:1}"
+ fi
+ echo "Using CUDA $CUDA_VERSION as determined by CU_VERSION ($CU_VERSION)"
+ version="$(python -c "print('.'.join(\"${CUDA_VERSION}\".split('.')[:2]))")"
+fi
+
+# submodules
+git submodule sync && git submodule update --init --recursive
+
+printf "Installing PyTorch with cu121"
+if [[ "$TORCH_VERSION" == "nightly" ]]; then
+ if [ "${CU_VERSION:-}" == cpu ] ; then
+ pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu -U
+ else
+ pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 -U
+ fi
+elif [[ "$TORCH_VERSION" == "stable" ]]; then
+ if [ "${CU_VERSION:-}" == cpu ] ; then
+ pip3 install torch --index-url https://download.pytorch.org/whl/cpu
+ else
+ pip3 install torch --index-url https://download.pytorch.org/whl/cu121
+ fi
+else
+ printf "Failed to install pytorch"
+ exit 1
+fi
+
+# install tensordict
+if [[ "$RELEASE" == 0 ]]; then
+ pip3 install git+https://github.com/pytorch/tensordict.git
+else
+ pip3 install tensordict
+fi
+
+# smoke test
+python -c "import functorch;import tensordict"
+
+printf "* Installing torchrl\n"
+python setup.py develop
+
+# smoke test
+python -c "import torchrl"
diff --git a/.github/unittest/linux_libs/scripts_open_spiel/post_process.sh b/.github/unittest/linux_libs/scripts_open_spiel/post_process.sh
new file mode 100755
index 00000000000..e97bf2a7b1b
--- /dev/null
+++ b/.github/unittest/linux_libs/scripts_open_spiel/post_process.sh
@@ -0,0 +1,6 @@
+#!/usr/bin/env bash
+
+set -e
+
+eval "$(./conda/bin/conda shell.bash hook)"
+conda activate ./env
diff --git a/.github/unittest/linux_libs/scripts_open_spiel/run-clang-format.py b/.github/unittest/linux_libs/scripts_open_spiel/run-clang-format.py
new file mode 100755
index 00000000000..5783a885d86
--- /dev/null
+++ b/.github/unittest/linux_libs/scripts_open_spiel/run-clang-format.py
@@ -0,0 +1,356 @@
+#!/usr/bin/env python
+"""
+MIT License
+
+Copyright (c) 2017 Guillaume Papin
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+
+A wrapper script around clang-format, suitable for linting multiple files
+and for use in continuous integration.
+
+This is an alternative API for the clang-format command line.
+It runs over multiple files and directories in parallel.
+A diff output is produced and a sensible exit code is returned.
+
+"""
+
+import argparse
+import difflib
+import fnmatch
+import multiprocessing
+import os
+import signal
+import subprocess
+import sys
+import traceback
+from functools import partial
+
+try:
+ from subprocess import DEVNULL # py3k
+except ImportError:
+ DEVNULL = open(os.devnull, "wb")
+
+
+DEFAULT_EXTENSIONS = "c,h,C,H,cpp,hpp,cc,hh,c++,h++,cxx,hxx,cu"
+
+
+class ExitStatus:
+ SUCCESS = 0
+ DIFF = 1
+ TROUBLE = 2
+
+
+def list_files(files, recursive=False, extensions=None, exclude=None):
+ if extensions is None:
+ extensions = []
+ if exclude is None:
+ exclude = []
+
+ out = []
+ for file in files:
+ if recursive and os.path.isdir(file):
+ for dirpath, dnames, fnames in os.walk(file):
+ fpaths = [os.path.join(dirpath, fname) for fname in fnames]
+ for pattern in exclude:
+ # os.walk() supports trimming down the dnames list
+ # by modifying it in-place,
+ # to avoid unnecessary directory listings.
+ dnames[:] = [
+ x
+ for x in dnames
+ if not fnmatch.fnmatch(os.path.join(dirpath, x), pattern)
+ ]
+ fpaths = [x for x in fpaths if not fnmatch.fnmatch(x, pattern)]
+ for f in fpaths:
+ ext = os.path.splitext(f)[1][1:]
+ if ext in extensions:
+ out.append(f)
+ else:
+ out.append(file)
+ return out
+
+
+def make_diff(file, original, reformatted):
+ return list(
+ difflib.unified_diff(
+ original,
+ reformatted,
+ fromfile=f"{file}\t(original)",
+ tofile=f"{file}\t(reformatted)",
+ n=3,
+ )
+ )
+
+
+class DiffError(Exception):
+ def __init__(self, message, errs=None):
+ super().__init__(message)
+ self.errs = errs or []
+
+
+class UnexpectedError(Exception):
+ def __init__(self, message, exc=None):
+ super().__init__(message)
+ self.formatted_traceback = traceback.format_exc()
+ self.exc = exc
+
+
+def run_clang_format_diff_wrapper(args, file):
+ try:
+ ret = run_clang_format_diff(args, file)
+ return ret
+ except DiffError:
+ raise
+ except Exception as e:
+ raise UnexpectedError(f"{file}: {e.__class__.__name__}: {e}", e)
+
+
+def run_clang_format_diff(args, file):
+ try:
+ with open(file, encoding="utf-8") as f:
+ original = f.readlines()
+ except OSError as exc:
+ raise DiffError(str(exc))
+ invocation = [args.clang_format_executable, file]
+
+ # Use of utf-8 to decode the process output.
+ #
+ # Hopefully, this is the correct thing to do.
+ #
+ # It's done due to the following assumptions (which may be incorrect):
+ # - clang-format will return the bytes read from the files as-is,
+ # without conversion, and it is already assumed that the files use utf-8.
+ # - if the diagnostics were internationalized, they would use utf-8:
+ # > Adding Translations to Clang
+ # >
+ # > Not possible yet!
+ # > Diagnostic strings should be written in UTF-8,
+ # > the client can translate to the relevant code page if needed.
+ # > Each translation completely replaces the format string
+ # > for the diagnostic.
+ # > -- http://clang.llvm.org/docs/InternalsManual.html#internals-diag-translation
+
+ try:
+ proc = subprocess.Popen(
+ invocation,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ universal_newlines=True,
+ encoding="utf-8",
+ )
+ except OSError as exc:
+ raise DiffError(
+ f"Command '{subprocess.list2cmdline(invocation)}' failed to start: {exc}"
+ )
+ proc_stdout = proc.stdout
+ proc_stderr = proc.stderr
+
+ # hopefully the stderr pipe won't get full and block the process
+ outs = list(proc_stdout.readlines())
+ errs = list(proc_stderr.readlines())
+ proc.wait()
+ if proc.returncode:
+ raise DiffError(
+ "Command '{}' returned non-zero exit status {}".format(
+ subprocess.list2cmdline(invocation), proc.returncode
+ ),
+ errs,
+ )
+ return make_diff(file, original, outs), errs
+
+
+def bold_red(s):
+ return "\x1b[1m\x1b[31m" + s + "\x1b[0m"
+
+
+def colorize(diff_lines):
+ def bold(s):
+ return "\x1b[1m" + s + "\x1b[0m"
+
+ def cyan(s):
+ return "\x1b[36m" + s + "\x1b[0m"
+
+ def green(s):
+ return "\x1b[32m" + s + "\x1b[0m"
+
+ def red(s):
+ return "\x1b[31m" + s + "\x1b[0m"
+
+ for line in diff_lines:
+ if line[:4] in ["--- ", "+++ "]:
+ yield bold(line)
+ elif line.startswith("@@ "):
+ yield cyan(line)
+ elif line.startswith("+"):
+ yield green(line)
+ elif line.startswith("-"):
+ yield red(line)
+ else:
+ yield line
+
+
+def print_diff(diff_lines, use_color):
+ if use_color:
+ diff_lines = colorize(diff_lines)
+ sys.stdout.writelines(diff_lines)
+
+
+def print_trouble(prog, message, use_colors):
+ error_text = "error:"
+ if use_colors:
+ error_text = bold_red(error_text)
+ print(f"{prog}: {error_text} {message}", file=sys.stderr)
+
+
+def main():
+ parser = argparse.ArgumentParser(description=__doc__)
+ parser.add_argument(
+ "--clang-format-executable",
+ metavar="EXECUTABLE",
+ help="path to the clang-format executable",
+ default="clang-format",
+ )
+ parser.add_argument(
+ "--extensions",
+ help=f"comma separated list of file extensions (default: {DEFAULT_EXTENSIONS})",
+ default=DEFAULT_EXTENSIONS,
+ )
+ parser.add_argument(
+ "-r",
+ "--recursive",
+ action="store_true",
+ help="run recursively over directories",
+ )
+ parser.add_argument("files", metavar="file", nargs="+")
+ parser.add_argument("-q", "--quiet", action="store_true")
+ parser.add_argument(
+ "-j",
+ metavar="N",
+ type=int,
+ default=0,
+ help="run N clang-format jobs in parallel (default number of cpus + 1)",
+ )
+ parser.add_argument(
+ "--color",
+ default="auto",
+ choices=["auto", "always", "never"],
+ help="show colored diff (default: auto)",
+ )
+ parser.add_argument(
+ "-e",
+ "--exclude",
+ metavar="PATTERN",
+ action="append",
+ default=[],
+ help="exclude paths matching the given glob-like pattern(s) from recursive search",
+ )
+
+ args = parser.parse_args()
+
+ # use default signal handling; like diff, return the SIGINT value on ^C
+ # https://bugs.python.org/issue14229#msg156446
+ signal.signal(signal.SIGINT, signal.SIG_DFL)
+ try:
+ signal.SIGPIPE
+ except AttributeError:
+ # compatibility, SIGPIPE does not exist on Windows
+ pass
+ else:
+ signal.signal(signal.SIGPIPE, signal.SIG_DFL)
+
+ colored_stdout = False
+ colored_stderr = False
+ if args.color == "always":
+ colored_stdout = True
+ colored_stderr = True
+ elif args.color == "auto":
+ colored_stdout = sys.stdout.isatty()
+ colored_stderr = sys.stderr.isatty()
+
+ version_invocation = [args.clang_format_executable, "--version"]
+ try:
+ subprocess.check_call(version_invocation, stdout=DEVNULL)
+ except subprocess.CalledProcessError as e:
+ print_trouble(parser.prog, str(e), use_colors=colored_stderr)
+ return ExitStatus.TROUBLE
+ except OSError as e:
+ print_trouble(
+ parser.prog,
+ f"Command '{subprocess.list2cmdline(version_invocation)}' failed to start: {e}",
+ use_colors=colored_stderr,
+ )
+ return ExitStatus.TROUBLE
+
+ retcode = ExitStatus.SUCCESS
+ files = list_files(
+ args.files,
+ recursive=args.recursive,
+ exclude=args.exclude,
+ extensions=args.extensions.split(","),
+ )
+
+ if not files:
+ return
+
+ njobs = args.j
+ if njobs == 0:
+ njobs = multiprocessing.cpu_count() + 1
+ njobs = min(len(files), njobs)
+
+ if njobs == 1:
+ # execute directly instead of in a pool,
+ # less overhead, simpler stacktraces
+ it = (run_clang_format_diff_wrapper(args, file) for file in files)
+ pool = None
+ else:
+ pool = multiprocessing.Pool(njobs)
+ it = pool.imap_unordered(partial(run_clang_format_diff_wrapper, args), files)
+ while True:
+ try:
+ outs, errs = next(it)
+ except StopIteration:
+ break
+ except DiffError as e:
+ print_trouble(parser.prog, str(e), use_colors=colored_stderr)
+ retcode = ExitStatus.TROUBLE
+ sys.stderr.writelines(e.errs)
+ except UnexpectedError as e:
+ print_trouble(parser.prog, str(e), use_colors=colored_stderr)
+ sys.stderr.write(e.formatted_traceback)
+ retcode = ExitStatus.TROUBLE
+ # stop at the first unexpected error,
+ # something could be very wrong,
+ # don't process all files unnecessarily
+ if pool:
+ pool.terminate()
+ break
+ else:
+ sys.stderr.writelines(errs)
+ if outs == []:
+ continue
+ if not args.quiet:
+ print_diff(outs, use_color=colored_stdout)
+ if retcode == ExitStatus.SUCCESS:
+ retcode = ExitStatus.DIFF
+ return retcode
+
+
+if __name__ == "__main__":
+ sys.exit(main())
diff --git a/.github/unittest/linux_libs/scripts_open_spiel/run_test.sh b/.github/unittest/linux_libs/scripts_open_spiel/run_test.sh
new file mode 100755
index 00000000000..a09229bf59a
--- /dev/null
+++ b/.github/unittest/linux_libs/scripts_open_spiel/run_test.sh
@@ -0,0 +1,28 @@
+#!/usr/bin/env bash
+
+set -e
+
+eval "$(./conda/bin/conda shell.bash hook)"
+conda activate ./env
+
+apt-get update && apt-get install -y git wget
+
+export PYTORCH_TEST_WITH_SLOW='1'
+export LAZY_LEGACY_OP=False
+python -m torch.utils.collect_env
+# Avoid error: "fatal: unsafe repository"
+git config --global --add safe.directory '*'
+
+root_dir="$(git rev-parse --show-toplevel)"
+env_dir="${root_dir}/env"
+lib_dir="${env_dir}/lib"
+
+conda deactivate && conda activate ./env
+
+# this workflow only tests the libs
+python -c "import pyspiel"
+
+python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 --capture no -k TestOpenSpiel --error-for-skips --runslow
+
+coverage combine
+coverage xml -i
diff --git a/.github/unittest/linux_optdeps/scripts/setup_env.sh b/.github/unittest/linux_libs/scripts_open_spiel/setup_env.sh
similarity index 96%
rename from .github/unittest/linux_optdeps/scripts/setup_env.sh
rename to .github/unittest/linux_libs/scripts_open_spiel/setup_env.sh
index aa83bca32fc..e7b08ab02ff 100755
--- a/.github/unittest/linux_optdeps/scripts/setup_env.sh
+++ b/.github/unittest/linux_libs/scripts_open_spiel/setup_env.sh
@@ -6,8 +6,11 @@
# Do not install PyTorch and torchvision here, otherwise they also get cached.
set -e
+set -v
this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
+# Avoid error: "fatal: unsafe repository"
+
git config --global --add safe.directory '*'
root_dir="$(git rev-parse --show-toplevel)"
conda_dir="${root_dir}/conda"
diff --git a/.github/unittest/linux_libs/scripts_openx/install.sh b/.github/unittest/linux_libs/scripts_openx/install.sh
index 1be73fc1de0..c657fd48b46 100755
--- a/.github/unittest/linux_libs/scripts_openx/install.sh
+++ b/.github/unittest/linux_libs/scripts_openx/install.sh
@@ -37,9 +37,9 @@ if [[ "$TORCH_VERSION" == "nightly" ]]; then
fi
elif [[ "$TORCH_VERSION" == "stable" ]]; then
if [ "${CU_VERSION:-}" == cpu ] ; then
- pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu
+ pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu -U
else
- pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu121
+ pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu121 -U
fi
else
printf "Failed to install pytorch"
diff --git a/.github/unittest/linux_libs/scripts_pettingzoo/environment.yml b/.github/unittest/linux_libs/scripts_pettingzoo/environment.yml
index f6c79784a0b..8f4e35c8efa 100644
--- a/.github/unittest/linux_libs/scripts_pettingzoo/environment.yml
+++ b/.github/unittest/linux_libs/scripts_pettingzoo/environment.yml
@@ -20,3 +20,4 @@ dependencies:
- pyyaml
- autorom[accept-rom-license]
- pettingzoo[all]==1.24.3
+ - gymnasium<1.0.0
diff --git a/.github/unittest/linux_libs/scripts_rlhf/install.sh b/.github/unittest/linux_libs/scripts_rlhf/install.sh
index d0363186c1a..9a5cf82074b 100755
--- a/.github/unittest/linux_libs/scripts_rlhf/install.sh
+++ b/.github/unittest/linux_libs/scripts_rlhf/install.sh
@@ -31,15 +31,15 @@ git submodule sync && git submodule update --init --recursive
printf "Installing PyTorch with cu121"
if [[ "$TORCH_VERSION" == "nightly" ]]; then
if [ "${CU_VERSION:-}" == cpu ] ; then
- pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu -U
+ pip3 install --pre torch numpy==1.26.4 --index-url https://download.pytorch.org/whl/nightly/cpu -U
else
- pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 -U
+ pip3 install --pre torch numpy==1.26.4 --index-url https://download.pytorch.org/whl/nightly/cu121 -U
fi
elif [[ "$TORCH_VERSION" == "stable" ]]; then
if [ "${CU_VERSION:-}" == cpu ] ; then
- pip3 install torch --index-url https://download.pytorch.org/whl/cpu
+ pip3 install torch numpy==1.26.4 --index-url https://download.pytorch.org/whl/cpu
else
- pip3 install torch --index-url https://download.pytorch.org/whl/cu121
+ pip3 install torch numpy==1.26.4 --index-url https://download.pytorch.org/whl/cu121
fi
else
printf "Failed to install pytorch"
diff --git a/.github/unittest/linux_libs/scripts_robohive/environment.yml b/.github/unittest/linux_libs/scripts_robohive/environment.yml
index cff88245d1e..4b6e4ef4f0e 100644
--- a/.github/unittest/linux_libs/scripts_robohive/environment.yml
+++ b/.github/unittest/linux_libs/scripts_robohive/environment.yml
@@ -6,7 +6,7 @@ dependencies:
- protobuf
- pip:
# Initial version is required to install Atari ROMS in setup_env.sh
- - gymnasium
+ - gymnasium<1.0
- hypothesis
- future
- cloudpickle
diff --git a/.github/unittest/linux_libs/scripts_unity_mlagents/environment.yml b/.github/unittest/linux_libs/scripts_unity_mlagents/environment.yml
new file mode 100644
index 00000000000..6dc82afbc25
--- /dev/null
+++ b/.github/unittest/linux_libs/scripts_unity_mlagents/environment.yml
@@ -0,0 +1,21 @@
+channels:
+ - pytorch
+ - defaults
+dependencies:
+ - python==3.10.12
+ - pip
+ - pip:
+ - mlagents_envs==1.0.0
+ - hypothesis
+ - future
+ - cloudpickle
+ - pytest
+ - pytest-cov
+ - pytest-mock
+ - pytest-instafail
+ - pytest-rerunfailures
+ - pytest-error-for-skips
+ - expecttest
+ - pyyaml
+ - scipy
+ - hydra-core
diff --git a/.github/unittest/linux_libs/scripts_unity_mlagents/install.sh b/.github/unittest/linux_libs/scripts_unity_mlagents/install.sh
new file mode 100755
index 00000000000..95a4a5a0e29
--- /dev/null
+++ b/.github/unittest/linux_libs/scripts_unity_mlagents/install.sh
@@ -0,0 +1,60 @@
+#!/usr/bin/env bash
+
+unset PYTORCH_VERSION
+# For the unit tests, nightly PyTorch is installed by the section below,
+# so there is no need to set PYTORCH_VERSION.
+# In fact, keeping PYTORCH_VERSION would force us to hardcode the PyTorch version in the config.
+
+set -e
+
+eval "$(./conda/bin/conda shell.bash hook)"
+conda activate ./env
+
+if [ "${CU_VERSION:-}" == cpu ] ; then
+ version="cpu"
+else
+ if [[ ${#CU_VERSION} -eq 4 ]]; then
+ CUDA_VERSION="${CU_VERSION:2:1}.${CU_VERSION:3:1}"
+ elif [[ ${#CU_VERSION} -eq 5 ]]; then
+ CUDA_VERSION="${CU_VERSION:2:2}.${CU_VERSION:4:1}"
+ fi
+ echo "Using CUDA $CUDA_VERSION as determined by CU_VERSION ($CU_VERSION)"
+ version="$(python -c "print('.'.join(\"${CUDA_VERSION}\".split('.')[:2]))")"
+fi
+
+# submodules
+git submodule sync && git submodule update --init --recursive
+
+printf "Installing PyTorch with cu121"
+if [[ "$TORCH_VERSION" == "nightly" ]]; then
+ if [ "${CU_VERSION:-}" == cpu ] ; then
+ pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu -U
+ else
+ pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 -U
+ fi
+elif [[ "$TORCH_VERSION" == "stable" ]]; then
+ if [ "${CU_VERSION:-}" == cpu ] ; then
+ pip3 install torch --index-url https://download.pytorch.org/whl/cpu
+ else
+ pip3 install torch --index-url https://download.pytorch.org/whl/cu121
+ fi
+else
+ printf "Failed to install pytorch"
+ exit 1
+fi
+
+# install tensordict
+if [[ "$RELEASE" == 0 ]]; then
+ pip3 install git+https://github.com/pytorch/tensordict.git
+else
+ pip3 install tensordict
+fi
+
+# smoke test
+python -c "import functorch;import tensordict"
+
+printf "* Installing torchrl\n"
+python setup.py develop
+
+# smoke test
+python -c "import torchrl"
diff --git a/.github/unittest/linux_libs/scripts_unity_mlagents/post_process.sh b/.github/unittest/linux_libs/scripts_unity_mlagents/post_process.sh
new file mode 100755
index 00000000000..e97bf2a7b1b
--- /dev/null
+++ b/.github/unittest/linux_libs/scripts_unity_mlagents/post_process.sh
@@ -0,0 +1,6 @@
+#!/usr/bin/env bash
+
+set -e
+
+eval "$(./conda/bin/conda shell.bash hook)"
+conda activate ./env
diff --git a/.github/unittest/linux_libs/scripts_unity_mlagents/run-clang-format.py b/.github/unittest/linux_libs/scripts_unity_mlagents/run-clang-format.py
new file mode 100755
index 00000000000..5783a885d86
--- /dev/null
+++ b/.github/unittest/linux_libs/scripts_unity_mlagents/run-clang-format.py
@@ -0,0 +1,356 @@
+#!/usr/bin/env python
+"""
+MIT License
+
+Copyright (c) 2017 Guillaume Papin
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+
+A wrapper script around clang-format, suitable for linting multiple files
+and for use in continuous integration.
+
+This is an alternative API for the clang-format command line.
+It runs over multiple files and directories in parallel.
+A diff output is produced and a sensible exit code is returned.
+
+"""
+
+import argparse
+import difflib
+import fnmatch
+import multiprocessing
+import os
+import signal
+import subprocess
+import sys
+import traceback
+from functools import partial
+
+try:
+ from subprocess import DEVNULL # py3k
+except ImportError:
+ DEVNULL = open(os.devnull, "wb")
+
+
+DEFAULT_EXTENSIONS = "c,h,C,H,cpp,hpp,cc,hh,c++,h++,cxx,hxx,cu"
+
+
+class ExitStatus:
+ SUCCESS = 0
+ DIFF = 1
+ TROUBLE = 2
+
+
+def list_files(files, recursive=False, extensions=None, exclude=None):
+ if extensions is None:
+ extensions = []
+ if exclude is None:
+ exclude = []
+
+ out = []
+ for file in files:
+ if recursive and os.path.isdir(file):
+ for dirpath, dnames, fnames in os.walk(file):
+ fpaths = [os.path.join(dirpath, fname) for fname in fnames]
+ for pattern in exclude:
+ # os.walk() supports trimming down the dnames list
+ # by modifying it in-place,
+ # to avoid unnecessary directory listings.
+ dnames[:] = [
+ x
+ for x in dnames
+ if not fnmatch.fnmatch(os.path.join(dirpath, x), pattern)
+ ]
+ fpaths = [x for x in fpaths if not fnmatch.fnmatch(x, pattern)]
+ for f in fpaths:
+ ext = os.path.splitext(f)[1][1:]
+ if ext in extensions:
+ out.append(f)
+ else:
+ out.append(file)
+ return out
+
+
+def make_diff(file, original, reformatted):
+ return list(
+ difflib.unified_diff(
+ original,
+ reformatted,
+ fromfile=f"{file}\t(original)",
+ tofile=f"{file}\t(reformatted)",
+ n=3,
+ )
+ )
+
+
+class DiffError(Exception):
+ def __init__(self, message, errs=None):
+ super().__init__(message)
+ self.errs = errs or []
+
+
+class UnexpectedError(Exception):
+ def __init__(self, message, exc=None):
+ super().__init__(message)
+ self.formatted_traceback = traceback.format_exc()
+ self.exc = exc
+
+
+def run_clang_format_diff_wrapper(args, file):
+ try:
+ ret = run_clang_format_diff(args, file)
+ return ret
+ except DiffError:
+ raise
+ except Exception as e:
+ raise UnexpectedError(f"{file}: {e.__class__.__name__}: {e}", e)
+
+
+def run_clang_format_diff(args, file):
+ try:
+ with open(file, encoding="utf-8") as f:
+ original = f.readlines()
+ except OSError as exc:
+ raise DiffError(str(exc))
+ invocation = [args.clang_format_executable, file]
+
+ # Use of utf-8 to decode the process output.
+ #
+ # Hopefully, this is the correct thing to do.
+ #
+ # It's done due to the following assumptions (which may be incorrect):
+ # - clang-format will return the bytes read from the files as-is,
+ # without conversion, and it is already assumed that the files use utf-8.
+ # - if the diagnostics were internationalized, they would use utf-8:
+ # > Adding Translations to Clang
+ # >
+ # > Not possible yet!
+ # > Diagnostic strings should be written in UTF-8,
+ # > the client can translate to the relevant code page if needed.
+ # > Each translation completely replaces the format string
+ # > for the diagnostic.
+ # > -- http://clang.llvm.org/docs/InternalsManual.html#internals-diag-translation
+
+ try:
+ proc = subprocess.Popen(
+ invocation,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ universal_newlines=True,
+ encoding="utf-8",
+ )
+ except OSError as exc:
+ raise DiffError(
+ f"Command '{subprocess.list2cmdline(invocation)}' failed to start: {exc}"
+ )
+ proc_stdout = proc.stdout
+ proc_stderr = proc.stderr
+
+ # hopefully the stderr pipe won't get full and block the process
+ outs = list(proc_stdout.readlines())
+ errs = list(proc_stderr.readlines())
+ proc.wait()
+ if proc.returncode:
+ raise DiffError(
+ "Command '{}' returned non-zero exit status {}".format(
+ subprocess.list2cmdline(invocation), proc.returncode
+ ),
+ errs,
+ )
+ return make_diff(file, original, outs), errs
+
+
+def bold_red(s):
+ return "\x1b[1m\x1b[31m" + s + "\x1b[0m"
+
+
+def colorize(diff_lines):
+ def bold(s):
+ return "\x1b[1m" + s + "\x1b[0m"
+
+ def cyan(s):
+ return "\x1b[36m" + s + "\x1b[0m"
+
+ def green(s):
+ return "\x1b[32m" + s + "\x1b[0m"
+
+ def red(s):
+ return "\x1b[31m" + s + "\x1b[0m"
+
+ for line in diff_lines:
+ if line[:4] in ["--- ", "+++ "]:
+ yield bold(line)
+ elif line.startswith("@@ "):
+ yield cyan(line)
+ elif line.startswith("+"):
+ yield green(line)
+ elif line.startswith("-"):
+ yield red(line)
+ else:
+ yield line
+
+
+def print_diff(diff_lines, use_color):
+ if use_color:
+ diff_lines = colorize(diff_lines)
+ sys.stdout.writelines(diff_lines)
+
+
+def print_trouble(prog, message, use_colors):
+ error_text = "error:"
+ if use_colors:
+ error_text = bold_red(error_text)
+ print(f"{prog}: {error_text} {message}", file=sys.stderr)
+
+
+def main():
+ parser = argparse.ArgumentParser(description=__doc__)
+ parser.add_argument(
+ "--clang-format-executable",
+ metavar="EXECUTABLE",
+ help="path to the clang-format executable",
+ default="clang-format",
+ )
+ parser.add_argument(
+ "--extensions",
+ help=f"comma separated list of file extensions (default: {DEFAULT_EXTENSIONS})",
+ default=DEFAULT_EXTENSIONS,
+ )
+ parser.add_argument(
+ "-r",
+ "--recursive",
+ action="store_true",
+ help="run recursively over directories",
+ )
+ parser.add_argument("files", metavar="file", nargs="+")
+ parser.add_argument("-q", "--quiet", action="store_true")
+ parser.add_argument(
+ "-j",
+ metavar="N",
+ type=int,
+ default=0,
+ help="run N clang-format jobs in parallel (default number of cpus + 1)",
+ )
+ parser.add_argument(
+ "--color",
+ default="auto",
+ choices=["auto", "always", "never"],
+ help="show colored diff (default: auto)",
+ )
+ parser.add_argument(
+ "-e",
+ "--exclude",
+ metavar="PATTERN",
+ action="append",
+ default=[],
+ help="exclude paths matching the given glob-like pattern(s) from recursive search",
+ )
+
+ args = parser.parse_args()
+
+ # use default signal handling; like diff, return the SIGINT value on ^C
+ # https://bugs.python.org/issue14229#msg156446
+ signal.signal(signal.SIGINT, signal.SIG_DFL)
+ try:
+ signal.SIGPIPE
+ except AttributeError:
+ # compatibility, SIGPIPE does not exist on Windows
+ pass
+ else:
+ signal.signal(signal.SIGPIPE, signal.SIG_DFL)
+
+ colored_stdout = False
+ colored_stderr = False
+ if args.color == "always":
+ colored_stdout = True
+ colored_stderr = True
+ elif args.color == "auto":
+ colored_stdout = sys.stdout.isatty()
+ colored_stderr = sys.stderr.isatty()
+
+ version_invocation = [args.clang_format_executable, "--version"]
+ try:
+ subprocess.check_call(version_invocation, stdout=DEVNULL)
+ except subprocess.CalledProcessError as e:
+ print_trouble(parser.prog, str(e), use_colors=colored_stderr)
+ return ExitStatus.TROUBLE
+ except OSError as e:
+ print_trouble(
+ parser.prog,
+ f"Command '{subprocess.list2cmdline(version_invocation)}' failed to start: {e}",
+ use_colors=colored_stderr,
+ )
+ return ExitStatus.TROUBLE
+
+ retcode = ExitStatus.SUCCESS
+ files = list_files(
+ args.files,
+ recursive=args.recursive,
+ exclude=args.exclude,
+ extensions=args.extensions.split(","),
+ )
+
+ if not files:
+ return
+
+ njobs = args.j
+ if njobs == 0:
+ njobs = multiprocessing.cpu_count() + 1
+ njobs = min(len(files), njobs)
+
+ if njobs == 1:
+ # execute directly instead of in a pool,
+ # less overhead, simpler stacktraces
+ it = (run_clang_format_diff_wrapper(args, file) for file in files)
+ pool = None
+ else:
+ pool = multiprocessing.Pool(njobs)
+ it = pool.imap_unordered(partial(run_clang_format_diff_wrapper, args), files)
+ while True:
+ try:
+ outs, errs = next(it)
+ except StopIteration:
+ break
+ except DiffError as e:
+ print_trouble(parser.prog, str(e), use_colors=colored_stderr)
+ retcode = ExitStatus.TROUBLE
+ sys.stderr.writelines(e.errs)
+ except UnexpectedError as e:
+ print_trouble(parser.prog, str(e), use_colors=colored_stderr)
+ sys.stderr.write(e.formatted_traceback)
+ retcode = ExitStatus.TROUBLE
+ # stop at the first unexpected error,
+ # something could be very wrong,
+ # don't process all files unnecessarily
+ if pool:
+ pool.terminate()
+ break
+ else:
+ sys.stderr.writelines(errs)
+ if outs == []:
+ continue
+ if not args.quiet:
+ print_diff(outs, use_color=colored_stdout)
+ if retcode == ExitStatus.SUCCESS:
+ retcode = ExitStatus.DIFF
+ return retcode
+
+
+if __name__ == "__main__":
+ sys.exit(main())
diff --git a/.github/unittest/linux_libs/scripts_unity_mlagents/run_test.sh b/.github/unittest/linux_libs/scripts_unity_mlagents/run_test.sh
new file mode 100755
index 00000000000..d5bb8695c44
--- /dev/null
+++ b/.github/unittest/linux_libs/scripts_unity_mlagents/run_test.sh
@@ -0,0 +1,28 @@
+#!/usr/bin/env bash
+
+set -e
+
+eval "$(./conda/bin/conda shell.bash hook)"
+conda activate ./env
+
+apt-get update && apt-get install -y git wget
+
+export PYTORCH_TEST_WITH_SLOW='1'
+export LAZY_LEGACY_OP=False
+python -m torch.utils.collect_env
+# Avoid error: "fatal: unsafe repository"
+git config --global --add safe.directory '*'
+
+root_dir="$(git rev-parse --show-toplevel)"
+env_dir="${root_dir}/env"
+lib_dir="${env_dir}/lib"
+
+conda deactivate && conda activate ./env
+
+# this workflow only tests the libs
+python -c "import mlagents_envs"
+
+python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 --capture no -k TestUnityMLAgents --runslow
+
+coverage combine
+coverage xml -i
diff --git a/.github/unittest/linux_libs/scripts_unity_mlagents/setup_env.sh b/.github/unittest/linux_libs/scripts_unity_mlagents/setup_env.sh
new file mode 100755
index 00000000000..e7b08ab02ff
--- /dev/null
+++ b/.github/unittest/linux_libs/scripts_unity_mlagents/setup_env.sh
@@ -0,0 +1,49 @@
+#!/usr/bin/env bash
+
+# This script is for setting up environment in which unit test is ran.
+# To speed up the CI time, the resulting environment is cached.
+#
+# Do not install PyTorch and torchvision here, otherwise they also get cached.
+
+set -e
+set -v
+
+this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
+# Avoid error: "fatal: unsafe repository"
+
+git config --global --add safe.directory '*'
+root_dir="$(git rev-parse --show-toplevel)"
+conda_dir="${root_dir}/conda"
+env_dir="${root_dir}/env"
+
+cd "${root_dir}"
+
+case "$(uname -s)" in
+ Darwin*) os=MacOSX;;
+ *) os=Linux
+esac
+
+# 1. Install conda at ./conda
+if [ ! -d "${conda_dir}" ]; then
+ printf "* Installing conda\n"
+ wget -O miniconda.sh "http://repo.continuum.io/miniconda/Miniconda3-latest-${os}-x86_64.sh"
+ bash ./miniconda.sh -b -f -p "${conda_dir}"
+fi
+eval "$(${conda_dir}/bin/conda shell.bash hook)"
+
+# 2. Create test environment at ./env
+printf "python: ${PYTHON_VERSION}\n"
+if [ ! -d "${env_dir}" ]; then
+ printf "* Creating a test environment\n"
+ conda create --prefix "${env_dir}" -y python="$PYTHON_VERSION"
+fi
+conda activate "${env_dir}"
+
+# 3. Install Conda dependencies
+printf "* Installing dependencies (except PyTorch)\n"
+echo " - python=${PYTHON_VERSION}" >> "${this_dir}/environment.yml"
+cat "${this_dir}/environment.yml"
+
+pip install pip --upgrade
+
+conda env update --file "${this_dir}/environment.yml" --prune
diff --git a/.github/unittest/linux_libs/scripts_vd4rl/install.sh b/.github/unittest/linux_libs/scripts_vd4rl/install.sh
index 1be73fc1de0..256f8d065f6 100755
--- a/.github/unittest/linux_libs/scripts_vd4rl/install.sh
+++ b/.github/unittest/linux_libs/scripts_vd4rl/install.sh
@@ -37,7 +37,7 @@ if [[ "$TORCH_VERSION" == "nightly" ]]; then
fi
elif [[ "$TORCH_VERSION" == "stable" ]]; then
if [ "${CU_VERSION:-}" == cpu ] ; then
- pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu
+ pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu -U
else
pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu121
fi
diff --git a/.github/unittest/linux_olddeps/scripts_gym_0_13/environment.yml b/.github/unittest/linux_olddeps/scripts_gym_0_13/environment.yml
index d34011e7bdc..06c4a112933 100644
--- a/.github/unittest/linux_olddeps/scripts_gym_0_13/environment.yml
+++ b/.github/unittest/linux_olddeps/scripts_gym_0_13/environment.yml
@@ -22,6 +22,7 @@ dependencies:
- scipy
- hydra-core
- dm_control -e git+https://github.com/deepmind/dm_control.git@c053360edea6170acfd9c8f65446703307d9d352#egg={dm_control}
+ - mujoco
- patchelf
- pyopengl==3.1.4
- ray
diff --git a/.github/unittest/linux_olddeps/scripts_gym_0_13/install.sh b/.github/unittest/linux_olddeps/scripts_gym_0_13/install.sh
index 58a33cd43f4..c1dde8bb7d0 100755
--- a/.github/unittest/linux_olddeps/scripts_gym_0_13/install.sh
+++ b/.github/unittest/linux_olddeps/scripts_gym_0_13/install.sh
@@ -39,7 +39,7 @@ printf "Installing PyTorch with %s\n" "${CU_VERSION}"
if [ "${CU_VERSION:-}" == cpu ] ; then
conda install pytorch==2.0 torchvision==0.15 cpuonly -c pytorch -y
else
- conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.8 -c pytorch -c nvidia -y
+ conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.8 numpy==1.26 -c pytorch -c nvidia -y
fi
# Solving circular import: https://stackoverflow.com/questions/75501048/how-to-fix-attributeerror-partially-initialized-module-charset-normalizer-has
diff --git a/.github/unittest/linux_optdeps/scripts/install.sh b/.github/unittest/linux_optdeps/scripts/install.sh
deleted file mode 100755
index 8ccbfbb8e19..00000000000
--- a/.github/unittest/linux_optdeps/scripts/install.sh
+++ /dev/null
@@ -1,36 +0,0 @@
-#!/usr/bin/env bash
-
-unset PYTORCH_VERSION
-
-set -e
-set -v
-
-eval "$(./conda/bin/conda shell.bash hook)"
-conda activate ./env
-
-if [[ ${#CU_VERSION} -eq 4 ]]; then
- CUDA_VERSION="${CU_VERSION:2:1}.${CU_VERSION:3:1}"
-elif [[ ${#CU_VERSION} -eq 5 ]]; then
- CUDA_VERSION="${CU_VERSION:2:2}.${CU_VERSION:4:1}"
-fi
-echo "Using CUDA $CUDA_VERSION as determined by CU_VERSION ($CU_VERSION)"
-version="$(python -c "print('.'.join(\"${CUDA_VERSION}\".split('.')[:2]))")"
-
-# submodules
-git submodule sync && git submodule update --init --recursive
-
-printf "Installing PyTorch with %s\n" "${CU_VERSION}"
-pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/$CU_VERSION
-
-# install tensordict
-if [[ "$RELEASE" == 0 ]]; then
- pip3 install git+https://github.com/pytorch/tensordict.git
-else
- pip3 install tensordict
-fi
-
-printf "* Installing torchrl\n"
-python setup.py develop
-
-# smoke test
-python -c "import torchrl"
diff --git a/.github/unittest/linux_optdeps/scripts/run_all.sh b/.github/unittest/linux_optdeps/scripts/run_all.sh
index 9edfec5ea46..7f34ffd42fd 100755
--- a/.github/unittest/linux_optdeps/scripts/run_all.sh
+++ b/.github/unittest/linux_optdeps/scripts/run_all.sh
@@ -2,9 +2,10 @@
set -euxo pipefail
set -v
+set -e
-# ==================================================================================== #
-# ================================ Init ============================================== #
+# =============================================================================== #
+# ================================ Init ========================================= #
if [[ $OSTYPE != 'darwin'* ]]; then
@@ -35,18 +36,134 @@ fi
# ==================================================================================== #
# ================================ Setup env ========================================= #
-bash ${this_dir}/setup_env.sh
+# Avoid error: "fatal: unsafe repository"
+git config --global --add safe.directory '*'
+root_dir="$(git rev-parse --show-toplevel)"
+conda_dir="${root_dir}/conda"
+env_dir="${root_dir}/env"
+lib_dir="${env_dir}/lib"
+
+cd "${root_dir}"
+
+case "$(uname -s)" in
+ Darwin*) os=MacOSX;;
+ *) os=Linux
+esac
+
+# 1. Install conda at ./conda
+if [ ! -d "${conda_dir}" ]; then
+ printf "* Installing conda\n"
+ wget -O miniconda.sh "http://repo.continuum.io/miniconda/Miniconda3-latest-${os}-x86_64.sh"
+ bash ./miniconda.sh -b -f -p "${conda_dir}"
+fi
+eval "$(${conda_dir}/bin/conda shell.bash hook)"
+
+# 2. Create test environment at ./env
+printf "python: ${PYTHON_VERSION}\n"
+if [ ! -d "${env_dir}" ]; then
+ printf "* Creating a test environment\n"
+ conda create --prefix "${env_dir}" -y python="$PYTHON_VERSION"
+fi
+conda activate "${env_dir}"
+
+# 3. Install Conda dependencies
+printf "* Installing dependencies (except PyTorch)\n"
+echo " - python=${PYTHON_VERSION}" >> "${this_dir}/environment.yml"
+cat "${this_dir}/environment.yml"
+
+pip3 install pip --upgrade
+
+conda env update --file "${this_dir}/environment.yml" --prune
# ============================================================================================ #
# ================================ PyTorch & TorchRL ========================================= #
-bash ${this_dir}/install.sh
+unset PYTORCH_VERSION
+
+if [ "${CU_VERSION:-}" == cpu ] ; then
+ version="cpu"
+ echo "Using cpu build"
+else
+ if [[ ${#CU_VERSION} -eq 4 ]]; then
+ CUDA_VERSION="${CU_VERSION:2:1}.${CU_VERSION:3:1}"
+ elif [[ ${#CU_VERSION} -eq 5 ]]; then
+ CUDA_VERSION="${CU_VERSION:2:2}.${CU_VERSION:4:1}"
+ fi
+ echo "Using CUDA $CUDA_VERSION as determined by CU_VERSION ($CU_VERSION)"
+ version="$(python -c "print('.'.join(\"${CUDA_VERSION}\".split('.')[:2]))")"
+fi
+
+# submodules
+git submodule sync && git submodule update --init --recursive
+
+printf "Installing PyTorch with %s\n" "${CU_VERSION}"
+if [[ "$TORCH_VERSION" == "nightly" ]]; then
+ if [ "${CU_VERSION:-}" == cpu ] ; then
+ pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu -U
+ else
+ pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/$CU_VERSION -U
+ fi
+elif [[ "$TORCH_VERSION" == "stable" ]]; then
+ if [ "${CU_VERSION:-}" == cpu ] ; then
+ pip3 install torch --index-url https://download.pytorch.org/whl/cpu -U
+ else
+ pip3 install torch --index-url https://download.pytorch.org/whl/$CU_VERSION -U
+ fi
+else
+ printf "Failed to install pytorch"
+ exit 1
+fi
+
+# smoke test
+python -c "import functorch"
+
+## install snapshot
+#if [[ "$TORCH_VERSION" == "nightly" ]]; then
+# pip3 install git+https://github.com/pytorch/torchsnapshot
+#else
+# pip3 install torchsnapshot
+#fi
+
+# install tensordict
+if [[ "$RELEASE" == 0 ]]; then
+ pip3 install git+https://github.com/pytorch/tensordict.git
+else
+ pip3 install tensordict
+fi
+
+printf "* Installing torchrl\n"
+python setup.py develop
+
+# smoke test
+python -c "import torchrl"
# ==================================================================================== #
# ================================ Run tests ========================================= #
-bash ${this_dir}/run_test.sh
+# find libstdc
+STDC_LOC=$(find conda/ -name "libstdc++.so.6" | head -1)
+
+export PYTORCH_TEST_WITH_SLOW='1'
+export LAZY_LEGACY_OP=False
+python -m torch.utils.collect_env
+# Avoid error: "fatal: unsafe repository"
+git config --global --add safe.directory '*'
+root_dir="$(git rev-parse --show-toplevel)"
+
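+# MKL_THREADING_LAYER=GNU works around MKL/libgomp threading-layer conflicts at import time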
+export MKL_THREADING_LAYER=GNU
+export CKPT_BACKEND=torch
+export MAX_IDLE_COUNT=100
+export BATCHED_PIPE_TIMEOUT=60
+
+python .github/unittest/helpers/coverage_run_parallel.py -m pytest test \
+ --instafail --durations 200 -vv --capture no --ignore test/test_rlhf.py \
+ --ignore test/test_distributed.py \
+ --timeout=120 --mp_fork_if_no_cuda
+
+coverage combine
+coverage xml -i
# ==================================================================================== #
# ================================ Post-proc ========================================= #
diff --git a/.github/unittest/linux_optdeps/scripts/run_test.sh b/.github/unittest/linux_optdeps/scripts/run_test.sh
deleted file mode 100755
index e3f8c31cfc9..00000000000
--- a/.github/unittest/linux_optdeps/scripts/run_test.sh
+++ /dev/null
@@ -1,25 +0,0 @@
-#!/usr/bin/env bash
-
-set -e
-
-eval "$(./conda/bin/conda shell.bash hook)"
-conda activate ./env
-
-# find libstdc
-STDC_LOC=$(find conda/ -name "libstdc++.so.6" | head -1)
-
-export PYTORCH_TEST_WITH_SLOW='1'
-export LAZY_LEGACY_OP=False
-python -m torch.utils.collect_env
-# Avoid error: "fatal: unsafe repository"
-git config --global --add safe.directory '*'
-root_dir="$(git rev-parse --show-toplevel)"
-export MKL_THREADING_LAYER=GNU
-export CKPT_BACKEND=torch
-export BATCHED_PIPE_TIMEOUT=60
-
-MUJOCO_GL=egl python .github/unittest/helpers/coverage_run_parallel.py -m pytest --instafail \
- -v --durations 200 --ignore test/test_distributed.py --ignore test/test_rlhf.py --capture no \
- --timeout=120 --mp_fork_if_no_cuda
-coverage combine
-coverage xml -i
diff --git a/.github/workflows/benchmarks.yml b/.github/workflows/benchmarks.yml
index 8eaed2fb825..28832d5229b 100644
--- a/.github/workflows/benchmarks.yml
+++ b/.github/workflows/benchmarks.yml
@@ -24,25 +24,37 @@ jobs:
name: CPU Pytest benchmark
runs-on: ubuntu-20.04
steps:
- - uses: actions/checkout@v3
- - uses: actions/setup-python@v4
+ - name: Who triggered this?
+ run: |
+ echo "Action triggered by ${{ github.event.pull_request.html_url }}"
+ - name: Checkout
+ uses: actions/checkout@v3
+ with:
+ fetch-depth: 50 # this is to make sure we obtain the target base commit
+ - name: Python Setup
+ uses: actions/setup-python@v4
with:
- python-version: 3.8
+ python-version: '3.10'
- name: Setup Environment
run: |
- python -m pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu -U
- python -m pip install git+https://github.com/pytorch/tensordict
- python setup.py develop
- python -m pip install pytest pytest-benchmark
+ python3.10 -m venv ./py310
+ source ./py310/bin/activate
+
+ python3 -m pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu -U
+ python3 -m pip install git+https://github.com/pytorch/tensordict
+ python3 setup.py develop
+ python3 -m pip install pytest pytest-benchmark
python3 -m pip install "gym[accept-rom-license,atari]"
- python3 -m pip install dm_control
- - name: Run benchmarks
- run: |
+ python3 -m pip install "dm_control" "mujoco"
+
cd benchmarks/
- python -m pytest --benchmark-json output.json
+ export TORCHDYNAMO_INLINE_INBUILT_NN_MODULES=1
+ export TD_GET_DEFAULTS_TO_NONE=1
+ python3 -m pytest -vvv --rank 0 --benchmark-json output.json --ignore test_collectors_benchmark.py
- name: Store benchmark results
- if: ${{ github.ref == 'refs/heads/main' || github.event_name == 'workflow_dispatch' }}
uses: benchmark-action/github-action-benchmark@v1
+ if: ${{ github.ref == 'refs/heads/main' || github.event_name == 'workflow_dispatch' }}
with:
name: CPU Benchmark Results
tool: 'pytest'
@@ -66,46 +77,73 @@ jobs:
image: nvidia/cuda:12.3.0-base-ubuntu22.04
options: --gpus all
steps:
- - name: Install deps
+ - name: Set GITHUB_BRANCH environment variable
run: |
- export TZ=Europe/London
- export DEBIAN_FRONTEND=noninteractive # tzdata bug
- apt-get update -y
- apt-get install software-properties-common -y
- add-apt-repository ppa:git-core/candidate -y
- apt-get update -y
- apt-get upgrade -y
- apt-get -y install libglu1-mesa libgl1-mesa-glx libosmesa6 gcc curl g++ unzip wget libglfw3-dev libgles2-mesa-dev libglew-dev sudo git cmake libz-dev
+ if [ "${{ github.event_name }}" == "push" ]; then
+          export GITHUB_BRANCH=${{ github.ref_name }}
+ elif [ "${{ github.event_name }}" == "pull_request" ]; then
+ export GITHUB_BRANCH=${{ github.event.pull_request.head.ref }}
+ else
+ echo "Unsupported event type"
+ exit 1
+ fi
+ echo "GITHUB_BRANCH=$GITHUB_BRANCH" >> $GITHUB_ENV
+ - name: Who triggered this?
+ run: |
+ echo "Action triggered by ${{ github.event.pull_request.html_url }}"
- name: Check ldd --version
run: ldd --version
- name: Checkout
uses: actions/checkout@v3
+ with:
+ fetch-depth: 50 # this is to make sure we obtain the target base commit
- name: Python Setup
uses: actions/setup-python@v4
with:
- python-version: 3.8
+ python-version: '3.10'
+ - name: Setup Environment
+ run: |
+ export TZ=Europe/London
+ export DEBIAN_FRONTEND=noninteractive # tzdata bug
+ apt-get update -y
+ apt-get install software-properties-common -y
+ add-apt-repository ppa:git-core/candidate -y
+ apt-get update -y
+ apt-get upgrade -y
+ apt-get -y install libglu1-mesa libgl1-mesa-glx libosmesa6 gcc curl g++ unzip wget libglfw3-dev libgles2-mesa-dev libglew-dev sudo git cmake libz-dev libpython3.10-dev
- name: Setup git
run: git config --global --add safe.directory /__w/rl/rl
- name: setup Path
run: |
echo /usr/local/bin >> $GITHUB_PATH
- - name: Setup Environment
+ - name: Setup benchmarks
run: |
- python3 -m pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu121 -U
- python3 -m pip install git+https://github.com/pytorch/tensordict
- python3 setup.py develop
- python3 -m pip install pytest pytest-benchmark
- python3 -m pip install "gym[accept-rom-license,atari]"
- python3 -m pip install dm_control
- - name: check GPU presence
+ echo "BASE_SHA=$(echo ${{ github.event.pull_request.base.sha }} | cut -c1-8)" >> $GITHUB_ENV
+ echo "HEAD_SHA=$(echo ${{ github.event.pull_request.head.sha }} | cut -c1-8)" >> $GITHUB_ENV
+ echo "BASELINE_JSON=$(mktemp)" >> $GITHUB_ENV
+ echo "CONTENDER_JSON=$(mktemp)" >> $GITHUB_ENV
+ echo "PR_COMMENT=$(mktemp)" >> $GITHUB_ENV
+ - name: Run
run: |
- python -c """import torch
+ python3.10 -m venv --system-site-packages ./py310
+ source ./py310/bin/activate
+ export PYTHON_INCLUDE_DIR=/usr/include/python3.10
+
+ python3.10 -m pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu124 -U
+ python3.10 -m pip install cmake ninja pytest pytest-benchmark mujoco dm_control "gym[accept-rom-license,atari]"
+ python3.10 -m pip install git+https://github.com/pytorch/tensordict
+ python3.10 setup.py develop
+ # python3.10 -m pip install git+https://github.com/pytorch/rl@$GITHUB_BRANCH
+
+ # test import
+ python3 -c """import torch
assert torch.cuda.device_count()
"""
- - name: Run benchmarks
- run: |
+
cd benchmarks/
- python3 -m pytest --benchmark-json output.json
+ export TORCHDYNAMO_INLINE_INBUILT_NN_MODULES=1
+ export TD_GET_DEFAULTS_TO_NONE=1
+ python3 -m pytest -vvv --rank 0 --benchmark-json output.json --ignore test_collectors_benchmark.py
- name: Store benchmark results
uses: benchmark-action/github-action-benchmark@v1
if: ${{ github.ref == 'refs/heads/main' || github.event_name == 'workflow_dispatch' }}
diff --git a/.github/workflows/benchmarks_pr.yml b/.github/workflows/benchmarks_pr.yml
index a8a1bc4c8dc..dfd8850a6f7 100644
--- a/.github/workflows/benchmarks_pr.yml
+++ b/.github/workflows/benchmarks_pr.yml
@@ -1,5 +1,4 @@
name: Continuous Benchmark (PR)
-
on:
pull_request:
@@ -12,6 +11,7 @@ concurrency:
cancel-in-progress: true
jobs:
+
benchmark_cpu:
name: CPU Pytest benchmark
runs-on: ubuntu-20.04
@@ -26,15 +26,7 @@ jobs:
- name: Python Setup
uses: actions/setup-python@v4
with:
- python-version: 3.8
- - name: Setup Environment
- run: |
- python -m pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu -U
- python -m pip install git+https://github.com/pytorch/tensordict
- python setup.py develop
- python -m pip install pytest pytest-benchmark
- python3 -m pip install "gym[accept-rom-license,atari]"
- python3 -m pip install dm_control
+ python-version: '3.10'
- name: Setup benchmarks
run: |
echo "BASE_SHA=$(echo ${{ github.event.pull_request.base.sha }} | cut -c1-8)" >> $GITHUB_ENV
@@ -42,10 +34,22 @@ jobs:
echo "BASELINE_JSON=$(mktemp)" >> $GITHUB_ENV
echo "CONTENDER_JSON=$(mktemp)" >> $GITHUB_ENV
echo "PR_COMMENT=$(mktemp)" >> $GITHUB_ENV
- - name: Run benchmarks
+ - name: Setup Environment and tests
run: |
+ python3.10 -m venv ./py310
+ source ./py310/bin/activate
+
+ python3 -m pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu -U
+ python3 -m pip install git+https://github.com/pytorch/tensordict
+ python3 setup.py develop
+ python3 -m pip install pytest pytest-benchmark
+ python3 -m pip install "gym[accept-rom-license,atari]"
+ python3 -m pip install "dm_control" "mujoco"
+
cd benchmarks/
- RUN_BENCHMARK="pytest --rank 0 --benchmark-json "
+ export TORCHDYNAMO_INLINE_INBUILT_NN_MODULES=1
+ export TD_GET_DEFAULTS_TO_NONE=1
+ RUN_BENCHMARK="python3 -m pytest -vvv --rank 0 --ignore test_collectors_benchmark.py --benchmark-json "
git checkout ${{ github.event.pull_request.base.sha }}
$RUN_BENCHMARK ${{ env.BASELINE_JSON }}
git checkout ${{ github.event.pull_request.head.sha }}
@@ -69,22 +73,23 @@ jobs:
run:
shell: bash -l {0}
container:
- image: nvidia/cuda:12.3.0-base-ubuntu22.04
+ image: nvidia/cuda:12.4.1-cudnn-runtime-ubuntu22.04
options: --gpus all
steps:
+ - name: Set GITHUB_BRANCH environment variable
+ run: |
+ if [ "${{ github.event_name }}" == "push" ]; then
+            export GITHUB_BRANCH=${{ github.ref_name }}
+ elif [ "${{ github.event_name }}" == "pull_request" ]; then
+ export GITHUB_BRANCH=${{ github.event.pull_request.head.ref }}
+ else
+ echo "Unsupported event type"
+ exit 1
+ fi
+ echo "GITHUB_BRANCH=$GITHUB_BRANCH" >> $GITHUB_ENV
- name: Who triggered this?
run: |
echo "Action triggered by ${{ github.event.pull_request.html_url }}"
- - name: Install deps
- run: |
- export TZ=Europe/London
- export DEBIAN_FRONTEND=noninteractive # tzdata bug
- apt-get update -y
- apt-get install software-properties-common -y
- add-apt-repository ppa:git-core/candidate -y
- apt-get update -y
- apt-get upgrade -y
- apt-get -y install libglu1-mesa libgl1-mesa-glx libosmesa6 gcc curl g++ unzip wget libglfw3-dev libgles2-mesa-dev libglew-dev sudo git cmake libz-dev
- name: Check ldd --version
run: ldd --version
- name: Checkout
@@ -94,25 +99,22 @@ jobs:
- name: Python Setup
uses: actions/setup-python@v4
with:
- python-version: 3.8
+ python-version: '3.10'
+ - name: Setup Environment
+ run: |
+ export TZ=Europe/London
+ export DEBIAN_FRONTEND=noninteractive # tzdata bug
+ apt-get update -y
+ apt-get install software-properties-common -y
+ add-apt-repository ppa:git-core/candidate -y
+ apt-get update -y
+ apt-get upgrade -y
+ apt-get -y install libglu1-mesa libgl1-mesa-glx libosmesa6 gcc curl g++ unzip wget libglfw3-dev libgles2-mesa-dev libglew-dev sudo git cmake libz-dev libpython3.10-dev
- name: Setup git
run: git config --global --add safe.directory /__w/rl/rl
- name: setup Path
run: |
echo /usr/local/bin >> $GITHUB_PATH
- - name: Setup Environment
- run: |
- python3 -m pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu121 -U
- python3 -m pip install git+https://github.com/pytorch/tensordict
- python3 setup.py develop
- python3 -m pip install pytest pytest-benchmark
- python3 -m pip install "gym[accept-rom-license,atari]"
- python3 -m pip install dm_control
- - name: check GPU presence
- run: |
- python -c """import torch
- assert torch.cuda.device_count()
- """
- name: Setup benchmarks
run: |
echo "BASE_SHA=$(echo ${{ github.event.pull_request.base.sha }} | cut -c1-8)" >> $GITHUB_ENV
@@ -120,10 +122,27 @@ jobs:
echo "BASELINE_JSON=$(mktemp)" >> $GITHUB_ENV
echo "CONTENDER_JSON=$(mktemp)" >> $GITHUB_ENV
echo "PR_COMMENT=$(mktemp)" >> $GITHUB_ENV
- - name: Run benchmarks
+ - name: Run
run: |
+ python3.10 -m venv --system-site-packages ./py310
+ source ./py310/bin/activate
+ export PYTHON_INCLUDE_DIR=/usr/include/python3.10
+
+ python3.10 -m pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu124 -U
+ python3.10 -m pip install cmake ninja pytest pytest-benchmark mujoco dm_control "gym[accept-rom-license,atari]"
+ python3.10 -m pip install git+https://github.com/pytorch/tensordict
+ python3.10 setup.py develop
+ # python3.10 -m pip install git+https://github.com/pytorch/rl@$GITHUB_BRANCH
+
+ # test import
+ python3 -c """import torch
+ assert torch.cuda.device_count()
+ """
+
cd benchmarks/
- RUN_BENCHMARK="pytest --rank 0 --benchmark-json "
+ export TORCHDYNAMO_INLINE_INBUILT_NN_MODULES=1
+ export TD_GET_DEFAULTS_TO_NONE=1
+ RUN_BENCHMARK="python3 -m pytest -vvv --rank 0 --ignore test_collectors_benchmark.py --benchmark-json "
git checkout ${{ github.event.pull_request.base.sha }}
$RUN_BENCHMARK ${{ env.BASELINE_JSON }}
git checkout ${{ github.event.pull_request.head.sha }}
diff --git a/.github/workflows/build-wheels-aarch64-linux.yml b/.github/workflows/build-wheels-aarch64-linux.yml
new file mode 100644
index 00000000000..63818f07365
--- /dev/null
+++ b/.github/workflows/build-wheels-aarch64-linux.yml
@@ -0,0 +1,51 @@
+name: Build Aarch64 Linux Wheels
+
+on:
+ pull_request:
+ push:
+ branches:
+ - nightly
+ - main
+ - release/*
+ tags:
+ # NOTE: Binary build pipelines should only get triggered on release candidate builds
+ # Release candidate tags look like: v1.11.0-rc1
+ - v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+
+ workflow_dispatch:
+
+permissions:
+ id-token: write
+ contents: read
+
+jobs:
+ generate-matrix:
+ uses: pytorch/test-infra/.github/workflows/generate_binary_build_matrix.yml@main
+ with:
+ package-type: wheel
+ os: linux-aarch64
+ test-infra-repository: pytorch/test-infra
+ test-infra-ref: main
+ with-cuda: disable
+ build:
+ needs: generate-matrix
+ strategy:
+ fail-fast: false
+ matrix:
+ include:
+ - repository: pytorch/rl
+ smoke-test-script: test/smoke_test.py
+ package-name: torchrl
+ name: pytorch/rl
+ uses: pytorch/test-infra/.github/workflows/build_wheels_linux.yml@main
+ with:
+ repository: ${{ matrix.repository }}
+ ref: ""
+ test-infra-repository: pytorch/test-infra
+ test-infra-ref: main
+ build-matrix: ${{ needs.generate-matrix.outputs.matrix }}
+ package-name: ${{ matrix.package-name }}
+ smoke-test-script: ${{ matrix.smoke-test-script }}
+ trigger-event: ${{ github.event_name }}
+ env-var-script: .github/scripts/td_script.sh
+ architecture: aarch64
+ setup-miniconda: false
diff --git a/.github/workflows/build-wheels-windows.yml b/.github/workflows/build-wheels-windows.yml
index 9f2666ccdbf..556d805c643 100644
--- a/.github/workflows/build-wheels-windows.yml
+++ b/.github/workflows/build-wheels-windows.yml
@@ -32,10 +32,12 @@ jobs:
matrix:
include:
- repository: pytorch/rl
+ pre-script: .github/scripts/td_script.sh
+ env-script: .github/scripts/version_script.bat
post-script: "python packaging/wheel/relocate.py"
smoke-test-script: test/smoke_test.py
package-name: torchrl
- name: pytorch/rl
+ name: ${{ matrix.repository }}
uses: pytorch/test-infra/.github/workflows/build_wheels_windows.yml@main
with:
repository: ${{ matrix.repository }}
@@ -43,8 +45,9 @@ jobs:
test-infra-repository: pytorch/test-infra
test-infra-ref: main
build-matrix: ${{ needs.generate-matrix.outputs.matrix }}
+ pre-script: ${{ matrix.pre-script }}
+ env-script: ${{ matrix.env-script }}
+ post-script: ${{ matrix.post-script }}
package-name: ${{ matrix.package-name }}
smoke-test-script: ${{ matrix.smoke-test-script }}
trigger-event: ${{ github.event_name }}
- pre-script: .github/scripts/td_script.sh
- env-script: .github/scripts/version_script.bat
diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml
index 749e85b64dd..e153641e775 100644
--- a/.github/workflows/docs.yml
+++ b/.github/workflows/docs.yml
@@ -3,14 +3,13 @@ name: Generate documentation
on:
push:
branches:
+ - nightly
- main
- release/*
tags:
- v[0-9]+.[0-9]+.[0-9]
- v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+
pull_request:
- branches:
- - "*"
workflow_dispatch:
concurrency:
@@ -23,7 +22,7 @@ jobs:
build-docs:
strategy:
matrix:
- python_version: ["3.9"]
+ python_version: ["3.10"]
cuda_arch_version: ["12.1"]
uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
with:
@@ -35,7 +34,7 @@ jobs:
script: |
set -e
set -v
- apt-get update && apt-get install -y git wget gcc g++
+ apt-get update && apt-get install -y -f git wget gcc g++ dialog apt-utils
root_dir="$(pwd)"
conda_dir="${root_dir}/conda"
env_dir="${root_dir}/env"
@@ -47,14 +46,14 @@ jobs:
bash ./miniconda.sh -b -f -p "${conda_dir}"
eval "$(${conda_dir}/bin/conda shell.bash hook)"
printf "* Creating a test environment\n"
- conda create --prefix "${env_dir}" -y python=3.8
+ conda create --prefix "${env_dir}" -y python=3.10
printf "* Activating\n"
conda activate "${env_dir}"
-
+
# 2. upgrade pip, ninja and packaging
- apt-get install python3.8 python3-pip -y
+ apt-get install python3-pip unzip -y -f
python3 -m pip install --upgrade pip
- python3 -m pip install setuptools ninja packaging -U
+ python3 -m pip install setuptools ninja packaging cmake -U
# 3. check python version
python3 --version
diff --git a/.github/workflows/nightly_build.yml b/.github/workflows/nightly_build.yml
index c7a7d344157..08eb61bfa6c 100644
--- a/.github/workflows/nightly_build.yml
+++ b/.github/workflows/nightly_build.yml
@@ -39,7 +39,7 @@ jobs:
runs-on: ubuntu-20.04
strategy:
matrix:
- python_version: [["3.8", "cp38-cp38"], ["3.9", "cp39-cp39"], ["3.10", "cp310-cp310"], ["3.11", "cp311-cp311"], ["3.12", "cp312-cp312"]]
+ python_version: [["3.9", "cp39-cp39"], ["3.10", "cp310-cp310"], ["3.11", "cp311-cp311"], ["3.12", "cp312-cp312"]]
cuda_support: [["", "cpu", "cpu"]]
container: pytorch/manylinux-${{ matrix.cuda_support[2] }}
steps:
@@ -79,7 +79,7 @@ jobs:
runs-on: ubuntu-20.04
strategy:
matrix:
- python_version: [["3.8", "cp38-cp38"], ["3.9", "cp39-cp39"], ["3.10", "cp310-cp310"], ["3.11", "cp311-cp311"], ["3.12", "cp312-cp312"]]
+ python_version: [["3.9", "cp39-cp39"], ["3.10", "cp310-cp310"], ["3.11", "cp311-cp311"], ["3.12", "cp312-cp312"]]
cuda_support: [["", "cpu", "cpu"]]
container: pytorch/manylinux-${{ matrix.cuda_support[2] }}
steps:
@@ -110,7 +110,7 @@ jobs:
runs-on: ubuntu-20.04
strategy:
matrix:
- python_version: [["3.8", "cp38-cp38"], ["3.9", "cp39-cp39"], ["3.10", "cp310-cp310"], ["3.11", "cp311-cp311"], ["3.12", "cp312-cp312"]]
+ python_version: [["3.9", "cp39-cp39"], ["3.10", "cp310-cp310"], ["3.11", "cp311-cp311"], ["3.12", "cp312-cp312"]]
cuda_support: [["", "cpu", "cpu"]]
steps:
- name: Setup Python
@@ -172,7 +172,7 @@ jobs:
runs-on: windows-latest
strategy:
matrix:
- python_version: [["3.8", "3.8"], ["3.9", "3.9"], ["3.10", "3.10.3"], ["3.11", "3.11"], ["3.12", "3.12"]]
+ python_version: [["3.9", "3.9"], ["3.10", "3.10.3"], ["3.11", "3.11"], ["3.12", "3.12"]]
steps:
- name: Setup Python
uses: actions/setup-python@v5
@@ -205,7 +205,7 @@ jobs:
runs-on: windows-latest
strategy:
matrix:
- python_version: [["3.8", "3.8"], ["3.9", "3.9"], ["3.10", "3.10.3"], ["3.11", "3.11"], ["3.12", "3.12"]]
+ python_version: [["3.9", "3.9"], ["3.10", "3.10.3"], ["3.11", "3.11"], ["3.12", "3.12"]]
steps:
- name: Setup Python
uses: actions/setup-python@v5
@@ -262,7 +262,7 @@ jobs:
runs-on: windows-latest
strategy:
matrix:
- python_version: [["3.8", "3.8"], ["3.9", "3.9"], ["3.10", "3.10.3"], ["3.11", "3.11"], ["3.12", "3.12"]]
+ python_version: [["3.9", "3.9"], ["3.10", "3.10.3"], ["3.11", "3.11"], ["3.12", "3.12"]]
steps:
- name: Checkout torchrl
uses: actions/checkout@v3
diff --git a/.github/workflows/test-linux-examples.yml b/.github/workflows/test-linux-examples.yml
index fd0adaf6ed5..39c97fae266 100644
--- a/.github/workflows/test-linux-examples.yml
+++ b/.github/workflows/test-linux-examples.yml
@@ -49,6 +49,7 @@ jobs:
echo "PYTHON_VERSION: $PYTHON_VERSION"
echo "CU_VERSION: $CU_VERSION"
+ export TD_GET_DEFAULTS_TO_NONE=1
## setup_env.sh
bash .github/unittest/linux_examples/scripts/run_all.sh
diff --git a/.github/workflows/test-linux-habitat.yml b/.github/workflows/test-linux-habitat.yml
index 3f6e89a70f9..6a1c52f90fa 100644
--- a/.github/workflows/test-linux-habitat.yml
+++ b/.github/workflows/test-linux-habitat.yml
@@ -46,5 +46,6 @@ jobs:
export CU_VERSION="cu${CUDA_ARCH_VERSION:0:2}${CUDA_ARCH_VERSION:3:1}"
# Remove the following line when the GPU tests are working inside docker, and uncomment the above lines
#export CU_VERSION="cpu"
+ export TD_GET_DEFAULTS_TO_NONE=1
bash .github/unittest/linux_libs/scripts_habitat/run_all.sh
diff --git a/.github/workflows/test-linux-libs.yml b/.github/workflows/test-linux-libs.yml
index 9e1875cac18..bd394f39fa7 100644
--- a/.github/workflows/test-linux-libs.yml
+++ b/.github/workflows/test-linux-libs.yml
@@ -44,6 +44,7 @@ jobs:
export TAR_OPTIONS="--no-same-owner"
export UPLOAD_CHANNEL="nightly"
export TF_CPP_MIN_LOG_LEVEL=0
+ export TD_GET_DEFAULTS_TO_NONE=1
bash .github/unittest/linux_libs/scripts_ataridqn/setup_env.sh
bash .github/unittest/linux_libs/scripts_ataridqn/install.sh
@@ -81,6 +82,7 @@ jobs:
export UPLOAD_CHANNEL="nightly"
export TF_CPP_MIN_LOG_LEVEL=0
export BATCHED_PIPE_TIMEOUT=60
+ export TD_GET_DEFAULTS_TO_NONE=1
nvidia-smi
@@ -114,6 +116,7 @@ jobs:
export UPLOAD_CHANNEL="nightly"
export TF_CPP_MIN_LOG_LEVEL=0
export BATCHED_PIPE_TIMEOUT=60
+ export TD_GET_DEFAULTS_TO_NONE=1
bash .github/unittest/linux_libs/scripts_d4rl/setup_env.sh
bash .github/unittest/linux_libs/scripts_d4rl/install.sh
@@ -148,6 +151,7 @@ jobs:
export UPLOAD_CHANNEL="nightly"
export TF_CPP_MIN_LOG_LEVEL=0
export BATCHED_PIPE_TIMEOUT=60
+ export TD_GET_DEFAULTS_TO_NONE=1
bash .github/unittest/linux_libs/scripts_d4rl/setup_env.sh
bash .github/unittest/linux_libs/scripts_d4rl/install.sh
@@ -181,6 +185,7 @@ jobs:
export TAR_OPTIONS="--no-same-owner"
export UPLOAD_CHANNEL="nightly"
export TF_CPP_MIN_LOG_LEVEL=0
+ export TD_GET_DEFAULTS_TO_NONE=1
bash .github/unittest/linux_libs/scripts_gen-dgrl/setup_env.sh
bash .github/unittest/linux_libs/scripts_gen-dgrl/install.sh
@@ -216,6 +221,7 @@ jobs:
export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/work/mujoco-py/mujoco_py/binaries/linux/mujoco210/bin"
export TAR_OPTIONS="--no-same-owner"
export BATCHED_PIPE_TIMEOUT=60
+ export TD_GET_DEFAULTS_TO_NONE=1
./.github/unittest/linux_libs/scripts_gym/setup_env.sh
./.github/unittest/linux_libs/scripts_gym/batch_scripts.sh
@@ -251,6 +257,7 @@ jobs:
export UPLOAD_CHANNEL="nightly"
export TF_CPP_MIN_LOG_LEVEL=0
export BATCHED_PIPE_TIMEOUT=60
+ export TD_GET_DEFAULTS_TO_NONE=1
nvidia-smi
@@ -285,6 +292,7 @@ jobs:
export UPLOAD_CHANNEL="nightly"
export TF_CPP_MIN_LOG_LEVEL=0
export BATCHED_PIPE_TIMEOUT=60
+ export TD_GET_DEFAULTS_TO_NONE=1
nvidia-smi
@@ -293,6 +301,82 @@ jobs:
bash .github/unittest/linux_libs/scripts_meltingpot/run_test.sh
bash .github/unittest/linux_libs/scripts_meltingpot/post_process.sh
+ unittests-open_spiel:
+ strategy:
+ matrix:
+ python_version: ["3.9"]
+ cuda_arch_version: ["12.1"]
+ if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Environments') }}
+ uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
+ with:
+ repository: pytorch/rl
+ runner: "linux.g5.4xlarge.nvidia.gpu"
+ gpu-arch-type: cuda
+ gpu-arch-version: "11.7"
+ docker-image: "pytorch/manylinux-cuda124"
+ timeout: 120
+ script: |
+ if [[ "${{ github.ref }}" =~ release/* ]]; then
+ export RELEASE=1
+ export TORCH_VERSION=stable
+ else
+ export RELEASE=0
+ export TORCH_VERSION=nightly
+ fi
+
+ set -euo pipefail
+ export PYTHON_VERSION="3.9"
+ export CU_VERSION="12.1"
+ export TAR_OPTIONS="--no-same-owner"
+ export UPLOAD_CHANNEL="nightly"
+ export TF_CPP_MIN_LOG_LEVEL=0
+ export BATCHED_PIPE_TIMEOUT=60
+
+ nvidia-smi
+
+ bash .github/unittest/linux_libs/scripts_open_spiel/setup_env.sh
+ bash .github/unittest/linux_libs/scripts_open_spiel/install.sh
+ bash .github/unittest/linux_libs/scripts_open_spiel/run_test.sh
+ bash .github/unittest/linux_libs/scripts_open_spiel/post_process.sh
+
+ unittests-unity_mlagents:
+ strategy:
+ matrix:
+ python_version: ["3.10.12"]
+ cuda_arch_version: ["12.1"]
+ if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Environments') }}
+ uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
+ with:
+ repository: pytorch/rl
+ runner: "linux.g5.4xlarge.nvidia.gpu"
+ gpu-arch-type: cuda
+ gpu-arch-version: "11.7"
+ docker-image: "pytorch/manylinux-cuda124"
+ timeout: 120
+ script: |
+ if [[ "${{ github.ref }}" =~ release/* ]]; then
+ export RELEASE=1
+ export TORCH_VERSION=stable
+ else
+ export RELEASE=0
+ export TORCH_VERSION=nightly
+ fi
+
+ set -euo pipefail
+ export PYTHON_VERSION="3.10.12"
+ export CU_VERSION="12.1"
+ export TAR_OPTIONS="--no-same-owner"
+ export UPLOAD_CHANNEL="nightly"
+ export TF_CPP_MIN_LOG_LEVEL=0
+ export BATCHED_PIPE_TIMEOUT=60
+
+ nvidia-smi
+
+ bash .github/unittest/linux_libs/scripts_unity_mlagents/setup_env.sh
+ bash .github/unittest/linux_libs/scripts_unity_mlagents/install.sh
+ bash .github/unittest/linux_libs/scripts_unity_mlagents/run_test.sh
+ bash .github/unittest/linux_libs/scripts_unity_mlagents/post_process.sh
+
unittests-minari:
strategy:
matrix:
@@ -321,6 +405,7 @@ jobs:
export UPLOAD_CHANNEL="nightly"
export TF_CPP_MIN_LOG_LEVEL=0
export BATCHED_PIPE_TIMEOUT=60
+ export TD_GET_DEFAULTS_TO_NONE=1
bash .github/unittest/linux_libs/scripts_minari/setup_env.sh
bash .github/unittest/linux_libs/scripts_minari/install.sh
@@ -355,6 +440,7 @@ jobs:
export UPLOAD_CHANNEL="nightly"
export TF_CPP_MIN_LOG_LEVEL=0
export BATCHED_PIPE_TIMEOUT=60
+ export TD_GET_DEFAULTS_TO_NONE=1
bash .github/unittest/linux_libs/scripts_openx/setup_env.sh
bash .github/unittest/linux_libs/scripts_openx/install.sh
@@ -387,6 +473,7 @@ jobs:
export UPLOAD_CHANNEL="nightly"
export TF_CPP_MIN_LOG_LEVEL=0
export BATCHED_PIPE_TIMEOUT=60
+ export TD_GET_DEFAULTS_TO_NONE=1
nvidia-smi
@@ -423,6 +510,7 @@ jobs:
export UPLOAD_CHANNEL="nightly"
export TF_CPP_MIN_LOG_LEVEL=0
export BATCHED_PIPE_TIMEOUT=60
+ export TD_GET_DEFAULTS_TO_NONE=1
bash .github/unittest/linux_libs/scripts_robohive/setup_env.sh
bash .github/unittest/linux_libs/scripts_robohive/install_and_run_test.sh
@@ -456,6 +544,7 @@ jobs:
export UPLOAD_CHANNEL="nightly"
export TF_CPP_MIN_LOG_LEVEL=0
export BATCHED_PIPE_TIMEOUT=60
+ export TD_GET_DEFAULTS_TO_NONE=1
bash .github/unittest/linux_libs/scripts_roboset/setup_env.sh
bash .github/unittest/linux_libs/scripts_roboset/install.sh
@@ -491,6 +580,7 @@ jobs:
export UPLOAD_CHANNEL="nightly"
export TF_CPP_MIN_LOG_LEVEL=0
export BATCHED_PIPE_TIMEOUT=60
+ export TD_GET_DEFAULTS_TO_NONE=1
bash .github/unittest/linux_libs/scripts_sklearn/setup_env.sh
bash .github/unittest/linux_libs/scripts_sklearn/install.sh
@@ -527,6 +617,7 @@ jobs:
export UPLOAD_CHANNEL="nightly"
export TF_CPP_MIN_LOG_LEVEL=0
export BATCHED_PIPE_TIMEOUT=60
+ export TD_GET_DEFAULTS_TO_NONE=1
nvidia-smi
@@ -563,6 +654,7 @@ jobs:
export UPLOAD_CHANNEL="nightly"
export TF_CPP_MIN_LOG_LEVEL=0
export BATCHED_PIPE_TIMEOUT=60
+ export TD_GET_DEFAULTS_TO_NONE=1
bash .github/unittest/linux_libs/scripts_vd4rl/setup_env.sh
bash .github/unittest/linux_libs/scripts_vd4rl/install.sh
@@ -599,6 +691,7 @@ jobs:
export UPLOAD_CHANNEL="nightly"
export TF_CPP_MIN_LOG_LEVEL=0
export BATCHED_PIPE_TIMEOUT=60
+ export TD_GET_DEFAULTS_TO_NONE=1
nvidia-smi
diff --git a/.github/workflows/test-linux-rlhf.yml b/.github/workflows/test-linux-rlhf.yml
index 832d432c997..accbe6e7610 100644
--- a/.github/workflows/test-linux-rlhf.yml
+++ b/.github/workflows/test-linux-rlhf.yml
@@ -44,6 +44,7 @@ jobs:
export TAR_OPTIONS="--no-same-owner"
export UPLOAD_CHANNEL="nightly"
export TF_CPP_MIN_LOG_LEVEL=0
+ export TD_GET_DEFAULTS_TO_NONE=1
bash .github/unittest/linux_libs/scripts_rlhf/setup_env.sh
bash .github/unittest/linux_libs/scripts_rlhf/install.sh
diff --git a/.github/workflows/test-linux.yml b/.github/workflows/test-linux.yml
index e8728180c67..75a646c25c4 100644
--- a/.github/workflows/test-linux.yml
+++ b/.github/workflows/test-linux.yml
@@ -22,7 +22,7 @@ jobs:
tests-cpu:
strategy:
matrix:
- python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
+ python_version: ["3.9", "3.10", "3.11", "3.12"]
fail-fast: false
uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
with:
@@ -38,6 +38,39 @@ jobs:
export RELEASE=0
export TORCH_VERSION=nightly
fi
+ export TD_GET_DEFAULTS_TO_NONE=1
+ # Set env vars from matrix
+ export PYTHON_VERSION=${{ matrix.python_version }}
+ export CU_VERSION="cpu"
+
+ echo "PYTHON_VERSION: $PYTHON_VERSION"
+ echo "CU_VERSION: $CU_VERSION"
+
+ ## setup_env.sh
+ bash .github/unittest/linux/scripts/run_all.sh
+
+ tests-cpu-oldget:
+    # Tests that TD_GET_DEFAULTS_TO_NONE=0 keeps working, since it remains the default for tensordict up to v0.7
+ strategy:
+ matrix:
+ python_version: ["3.12"]
+ fail-fast: false
+ uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
+ with:
+ runner: linux.12xlarge
+ repository: pytorch/rl
+ docker-image: "nvidia/cuda:12.2.0-devel-ubuntu22.04"
+ timeout: 90
+ script: |
+ if [[ "${{ github.ref }}" =~ release/* ]]; then
+ export RELEASE=1
+ export TORCH_VERSION=stable
+ else
+ export RELEASE=0
+ export TORCH_VERSION=nightly
+ fi
+ export TD_GET_DEFAULTS_TO_NONE=0
+
# Set env vars from matrix
export PYTHON_VERSION=${{ matrix.python_version }}
export CU_VERSION="cpu"
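
Illustrative note (not part of this diff): a minimal sketch of what the
`TD_GET_DEFAULTS_TO_NONE` flag exported above toggles, assuming tensordict's
dict-like `get` semantics; the key name `missing_key` is hypothetical.

```python
import os

# the CI jobs above export this before the test suite runs
os.environ["TD_GET_DEFAULTS_TO_NONE"] = "1"

import torch
from tensordict import TensorDict

td = TensorDict({"a": torch.zeros(3)}, batch_size=[3])

# with the flag at 1, `get` behaves like `dict.get`: missing keys return None
assert td.get("missing_key") is None
# with the legacy default (0, kept until tensordict 0.7), the same call raises
# a KeyError unless an explicit default is passed
```
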
@@ -75,6 +108,8 @@ jobs:
export RELEASE=0
export TORCH_VERSION=nightly
fi
+ export TD_GET_DEFAULTS_TO_NONE=1
+
# Remove the following line when the GPU tests are working inside docker, and uncomment the above lines
#export CU_VERSION="cpu"
@@ -110,6 +145,7 @@ jobs:
export TORCH_VERSION=nightly
fi
export TF_CPP_MIN_LOG_LEVEL=0
+ export TD_GET_DEFAULTS_TO_NONE=1
bash .github/unittest/linux_olddeps/scripts_gym_0_13/setup_env.sh
@@ -119,8 +155,8 @@ jobs:
tests-optdeps:
strategy:
matrix:
- python_version: ["3.10"] # "3.8", "3.9", "3.10", "3.11"
- cuda_arch_version: ["12.1"] # "11.6", "11.7"
+ python_version: ["3.11"]
+ cuda_arch_version: ["12.1"]
fail-fast: false
uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
with:
@@ -136,9 +172,6 @@ jobs:
# Commenting these out for now because the GPU test are not working inside docker
export CUDA_ARCH_VERSION=${{ matrix.cuda_arch_version }}
export CU_VERSION="cu${CUDA_ARCH_VERSION:0:2}${CUDA_ARCH_VERSION:3:1}"
- # Remove the following line when the GPU tests are working inside docker, and uncomment the above lines
- #export CU_VERSION="cpu"
-
if [[ "${{ github.ref }}" =~ release/* ]]; then
export RELEASE=1
export TORCH_VERSION=stable
@@ -146,6 +179,7 @@ jobs:
export RELEASE=0
export TORCH_VERSION=nightly
fi
+ export TD_GET_DEFAULTS_TO_NONE=1
echo "PYTHON_VERSION: $PYTHON_VERSION"
echo "CU_VERSION: $CU_VERSION"
@@ -156,7 +190,7 @@ jobs:
tests-stable-gpu:
strategy:
matrix:
- python_version: ["3.10"] # "3.8", "3.9", "3.10", "3.11"
+ python_version: ["3.10"] # "3.9", "3.10", "3.11"
cuda_arch_version: ["11.8"] # "11.6", "11.7"
fail-fast: false
uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
@@ -187,6 +221,7 @@ jobs:
echo "PYTHON_VERSION: $PYTHON_VERSION"
echo "CU_VERSION: $CU_VERSION"
+ export TD_GET_DEFAULTS_TO_NONE=1
## setup_env.sh
bash .github/unittest/linux/scripts/run_all.sh
diff --git a/.github/workflows/test-windows-optdepts.yml b/.github/workflows/test-windows-optdepts.yml
index e98b6c1810e..14a8dd7ab13 100644
--- a/.github/workflows/test-windows-optdepts.yml
+++ b/.github/workflows/test-windows-optdepts.yml
@@ -42,6 +42,7 @@ jobs:
export RELEASE=0
export TORCH_VERSION=nightly
fi
+ export TD_GET_DEFAULTS_TO_NONE=1
## setup_env.sh
./.github/unittest/windows_optdepts/scripts/setup_env.sh
diff --git a/.github/workflows/wheels-legacy.yml b/.github/workflows/wheels-legacy.yml
index 80dd2640e17..998b9eba4b2 100644
--- a/.github/workflows/wheels-legacy.yml
+++ b/.github/workflows/wheels-legacy.yml
@@ -5,6 +5,7 @@ on:
push:
branches:
- release/*
+ - main
concurrency:
# Documentation suggests ${{ github.head_ref }}, but that's only available on pull_request/pull_request_target triggers, so using ${{ github.ref }}.
@@ -18,7 +19,7 @@ jobs:
runs-on: windows-latest
strategy:
matrix:
- python_version: [["3.8", "3.8"], ["3.9", "3.9"], ["3.10", "3.10.3"], ["3.11", "3.11"], ["3.12", "3.12"]]
+ python_version: [["3.9", "3.9"], ["3.10", "3.10.3"], ["3.11", "3.11"], ["3.12", "3.12"]]
steps:
- name: Setup Python
uses: actions/setup-python@v2
@@ -36,12 +37,12 @@ jobs:
python3 -mpip install wheel
TORCHRL_BUILD_VERSION=0.5.0 python3 setup.py bdist_wheel
- name: Upload wheel for the test-wheel job
- uses: actions/upload-artifact@v2
+ uses: actions/upload-artifact@v3
with:
name: torchrl-win-${{ matrix.python_version[0] }}.whl
path: dist/torchrl-*.whl
- name: Upload wheel for download
- uses: actions/upload-artifact@v2
+ uses: actions/upload-artifact@v3
with:
name: torchrl-batch.whl
path: dist/*.whl
@@ -50,7 +51,7 @@ jobs:
needs: build-wheel-windows
strategy:
matrix:
- python_version: [ "3.8", "3.9", "3.10", "3.11", "3.12" ]
+        python_version: ["3.9", "3.10", "3.11", "3.12"]
runs-on: windows-latest
steps:
- name: Setup Python
@@ -76,7 +77,7 @@ jobs:
run: |
python3 -mpip install numpy pytest pytest-cov codecov unittest-xml-reporting "pillow>=4.1.1" scipy av networkx expecttest pyyaml
- name: Download built wheels
- uses: actions/download-artifact@v2
+ uses: actions/download-artifact@v3
with:
name: torchrl-win-${{ matrix.python_version }}.whl
path: wheels
diff --git a/README.md b/README.md
index f82a8ff0c4c..abcf7349192 100644
--- a/README.md
+++ b/README.md
@@ -99,68 +99,69 @@ lines of code*!
from torchrl.collectors import SyncDataCollector
from torchrl.data.replay_buffers import TensorDictReplayBuffer, \
- LazyTensorStorage, SamplerWithoutReplacement
+ LazyTensorStorage, SamplerWithoutReplacement
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import ProbabilisticActor, ValueOperator, TanhNormal
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE
- env = GymEnv("Pendulum-v1")
+ env = GymEnv("Pendulum-v1")
model = TensorDictModule(
- nn.Sequential(
- nn.Linear(3, 128), nn.Tanh(),
- nn.Linear(128, 128), nn.Tanh(),
- nn.Linear(128, 128), nn.Tanh(),
- nn.Linear(128, 2),
- NormalParamExtractor()
- ),
- in_keys=["observation"],
- out_keys=["loc", "scale"]
+ nn.Sequential(
+ nn.Linear(3, 128), nn.Tanh(),
+ nn.Linear(128, 128), nn.Tanh(),
+ nn.Linear(128, 128), nn.Tanh(),
+ nn.Linear(128, 2),
+ NormalParamExtractor()
+ ),
+ in_keys=["observation"],
+ out_keys=["loc", "scale"]
)
critic = ValueOperator(
- nn.Sequential(
- nn.Linear(3, 128), nn.Tanh(),
- nn.Linear(128, 128), nn.Tanh(),
- nn.Linear(128, 128), nn.Tanh(),
- nn.Linear(128, 1),
- ),
- in_keys=["observation"],
+ nn.Sequential(
+ nn.Linear(3, 128), nn.Tanh(),
+ nn.Linear(128, 128), nn.Tanh(),
+ nn.Linear(128, 128), nn.Tanh(),
+ nn.Linear(128, 1),
+ ),
+ in_keys=["observation"],
)
actor = ProbabilisticActor(
- model,
- in_keys=["loc", "scale"],
- distribution_class=TanhNormal,
- distribution_kwargs={"min": -1.0, "max": 1.0},
- return_log_prob=True
- )
+ model,
+ in_keys=["loc", "scale"],
+ distribution_class=TanhNormal,
+ distribution_kwargs={"low": -1.0, "high": 1.0},
+ return_log_prob=True
+ )
buffer = TensorDictReplayBuffer(
- LazyTensorStorage(1000),
- SamplerWithoutReplacement()
- )
+ storage=LazyTensorStorage(1000),
+ sampler=SamplerWithoutReplacement(),
+ batch_size=50,
+ )
collector = SyncDataCollector(
- env,
- actor,
- frames_per_batch=1000,
- total_frames=1_000_000
- )
- loss_fn = ClipPPOLoss(actor, critic, gamma=0.99)
+ env,
+ actor,
+ frames_per_batch=1000,
+ total_frames=1_000_000,
+ )
+ loss_fn = ClipPPOLoss(actor, critic)
+ adv_fn = GAE(value_network=critic, average_gae=True, gamma=0.99, lmbda=0.95)
optim = torch.optim.Adam(loss_fn.parameters(), lr=2e-4)
- adv_fn = GAE(value_network=critic, gamma=0.99, lmbda=0.95, average_gae=True)
+
for data in collector: # collect data
- for epoch in range(10):
- adv_fn(data) # compute advantage
- buffer.extend(data.view(-1))
- for i in range(20): # consume data
- sample = buffer.sample(50) # mini-batch
- loss_vals = loss_fn(sample)
- loss_val = sum(
- value for key, value in loss_vals.items() if
- key.startswith("loss")
- )
- loss_val.backward()
- optim.step()
- optim.zero_grad()
- print(f"avg reward: {data['next', 'reward'].mean().item(): 4.4f}")
+ for epoch in range(10):
+ adv_fn(data) # compute advantage
+ buffer.extend(data)
+ for sample in buffer: # consume data
+ loss_vals = loss_fn(sample)
+ loss_val = sum(
+ value for key, value in loss_vals.items() if
+ key.startswith("loss")
+ )
+ loss_val.backward()
+ optim.step()
+ optim.zero_grad()
+ print(f"avg reward: {data['next', 'reward'].mean().item(): 4.4f}")
```
@@ -478,7 +479,7 @@ And it is `functorch` and `torch.compile` compatible!
policy_explore = EGreedyWrapper(policy)
with set_exploration_type(ExplorationType.RANDOM):
tensordict = policy_explore(tensordict) # will use eps-greedy
- with set_exploration_type(ExplorationType.MODE):
+ with set_exploration_type(ExplorationType.DETERMINISTIC):
tensordict = policy_explore(tensordict) # will not use eps-greedy
```
@@ -521,23 +522,288 @@ If you would like to contribute to new features, check our [call for contributio
## Examples, tutorials and demos
-A series of [examples](https://github.com/pytorch/rl/blob/main/examples/) are provided with an illustrative purpose:
-- [DQN](https://github.com/pytorch/rl/blob/main/sota-implementations/dqn)
-- [DDPG](https://github.com/pytorch/rl/blob/main/sota-implementations/ddpg/ddpg.py)
-- [IQL](https://github.com/pytorch/rl/blob/main/sota-implementations/iql/iql_offline.py)
-- [CQL](https://github.com/pytorch/rl/blob/main/sota-implementations/cql/cql_offline.py)
-- [TD3](https://github.com/pytorch/rl/blob/main/sota-implementations/td3/td3.py)
-- [TD3+BC](https://github.com/pytorch/rl/blob/main/sota-implementations/td3+bc/td3+bc.py)
-- [A2C](https://github.com/pytorch/rl/blob/main/examples/a2c_old/a2c.py)
-- [PPO](https://github.com/pytorch/rl/blob/main/sota-implementations/ppo/ppo.py)
-- [SAC](https://github.com/pytorch/rl/blob/main/sota-implementations/sac/sac.py)
-- [REDQ](https://github.com/pytorch/rl/blob/main/sota-implementations/redq/redq.py)
-- [Dreamer](https://github.com/pytorch/rl/blob/main/sota-implementations/dreamer/dreamer.py)
-- [Decision Transformers](https://github.com/pytorch/rl/blob/main/sota-implementations/decision_transformer)
-- [RLHF](https://github.com/pytorch/rl/blob/main/examples/rlhf)
+A series of [State-of-the-Art implementations](https://github.com/pytorch/rl/blob/main/sota-implementations/) are provided with an illustrative purpose:
+
+| Algorithm | Compile Support** | Tensordict-free API | Modular Losses | Continuous and Discrete |
+| --- | --- | --- | --- | --- |
+| DQN | 1.9x | + | NA | + (through ActionDiscretizer transform) |
+| DDPG | 1.87x | + | + | - (continuous only) |
+| IQL | 3.22x | + | + | + |
+| CQL | 2.68x | + | + | + |
+| TD3 | 2.27x | + | + | - (continuous only) |
+| TD3+BC | untested | + | + | - (continuous only) |
+| A2C | 2.67x | + | - | + |
+| PPO | 2.42x | + | - | + |
+| SAC | 2.62x | + | - | + |
+| REDQ | 2.28x | + | - | - (continuous only) |
+| Dreamer v1 | untested | + | + (different classes) | - (continuous only) |
+| Decision Transformers | untested | + | NA | - (continuous only) |
+| CrossQ | untested | + | + | - (continuous only) |
+| Gail | untested | + | NA | + |
+| Impala | untested | + | - | + |
+| IQL (MARL) | untested | + | + | + |
+| DDPG (MARL) | untested | + | + | - (continuous only) |
+| PPO (MARL) | untested | + | - | + |
+| QMIX-VDN (MARL) | untested | + | NA | + |
+| SAC (MARL) | untested | + | - | + |
+| RLHF | NA | + | NA | NA |
+
+** The number indicates expected speed-up compared to eager mode when executed on CPU. Numbers may vary depending on
+ architecture and device.
and many more to come!
+[Code examples](examples/) showcasing toy code snippets and training scripts are also available:
+- [RLHF](examples/rlhf)
+- [Memory-mapped replay buffers](examples/torchrl_features)
+
+
Check the [examples](https://github.com/pytorch/rl/blob/main/sota-implementations/) directory for more details
about handling the various configuration settings.
@@ -592,7 +858,7 @@ Importantly, the nightly builds require the nightly builds of PyTorch too.
To install extra dependencies, call
```bash
-pip3 install "torchrl[atari,dm_control,gym_continuous,rendering,tests,utils,marl,checkpointing]"
+pip3 install "torchrl[atari,dm_control,gym_continuous,rendering,tests,utils,marl,open_spiel,checkpointing]"
```
or a subset of these.
diff --git a/benchmarks/test_collectors_benchmark.py b/benchmarks/test_collectors_benchmark.py
index 1bdd26c0746..f2273d5cc3f 100644
--- a/benchmarks/test_collectors_benchmark.py
+++ b/benchmarks/test_collectors_benchmark.py
@@ -3,16 +3,20 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
+import time
import pytest
import torch.cuda
+import tqdm
from torchrl.collectors import SyncDataCollector
from torchrl.collectors.collectors import (
MultiaSyncDataCollector,
MultiSyncDataCollector,
)
-from torchrl.envs import EnvCreator, GymEnv, StepCounter, TransformedEnv
+from torchrl.data import LazyTensorStorage, ReplayBuffer
+from torchrl.data.utils import CloudpickleWrapper
+from torchrl.envs import EnvCreator, GymEnv, ParallelEnv, StepCounter, TransformedEnv
from torchrl.envs.libs.dm_control import DMControlEnv
from torchrl.envs.utils import RandomPolicy
@@ -180,6 +184,57 @@ def test_async_pixels(benchmark):
benchmark(execute_collector, c)
+class TestRBGCollector:
+ @pytest.mark.parametrize(
+        "n_col,n_workers_per_col",
+ [
+ [2, 2],
+ [4, 2],
+ [8, 2],
+ [16, 2],
+ [2, 1],
+ [4, 1],
+ [8, 1],
+ [16, 1],
+ ],
+ )
+    def test_multiasync_rb(self, n_col, n_workers_per_col):
+        make_env = EnvCreator(lambda: GymEnv("ALE/Pong-v5"))
+        if n_workers_per_col > 1:
+            make_env = ParallelEnv(n_workers_per_col, make_env)
+            env = make_env
+            policy = RandomPolicy(env.action_spec)
+        else:
+            env = make_env()
+            policy = RandomPolicy(env.action_spec)
+
+ storage = LazyTensorStorage(10_000)
+ rb = ReplayBuffer(storage=storage)
+ rb.extend(env.rollout(2, policy).reshape(-1))
+ rb.append_transform(CloudpickleWrapper(lambda x: x.reshape(-1)), invert=True)
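+        # invert=True applies the reshape on write (extend) rather than on sample,
+        # so batches arriving from the collector are stored as flat transitions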
+
+        fpb = n_workers_per_col * 100
+        total_frames = n_workers_per_col * 100_000
+ c = MultiaSyncDataCollector(
+ [make_env] * n_col,
+ policy,
+ frames_per_batch=fpb,
+ total_frames=total_frames,
+ replay_buffer=rb,
+ )
+        frames = 0
+        fps = float("nan")  # no estimate until the warmup batches have passed
+        pbar = tqdm.tqdm(total=total_frames - (n_col * fpb))
+        for i, _ in enumerate(c):
+            # the first n_col batches are treated as warmup and are not timed
+            if i == n_col:
+                t0 = time.time()
+            if i >= n_col:
+                frames += fpb
+            if i > n_col:
+                fps = frames / (time.time() - t0)
+            pbar.update(fpb)
+            pbar.set_description(f"fps: {fps: 4.4f}")
+
+
if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
diff --git a/benchmarks/test_objectives_benchmarks.py b/benchmarks/test_objectives_benchmarks.py
index 4cfc8470a15..d07b40595bc 100644
--- a/benchmarks/test_objectives_benchmarks.py
+++ b/benchmarks/test_objectives_benchmarks.py
@@ -6,9 +6,11 @@
import pytest
import torch
+from packaging import version
from tensordict import TensorDict
from tensordict.nn import (
+ InteractionType,
NormalParamExtractor,
ProbabilisticTensorDictModule as ProbMod,
ProbabilisticTensorDictSequential as ProbSeq,
@@ -16,7 +18,7 @@
TensorDictSequential as Seq,
)
from torch.nn import functional as F
-from torchrl.data.tensor_specs import BoundedTensorSpec, UnboundedContinuousTensorSpec
+from torchrl.data.tensor_specs import Bounded, Unbounded
from torchrl.modules import MLP, QValueActor, TanhNormal
from torchrl.objectives import (
A2CLoss,
@@ -42,6 +44,20 @@
vec_td_lambda_return_estimate,
)
+TORCH_VERSION = torch.__version__
+FULLGRAPH = version.parse(".".join(TORCH_VERSION.split(".")[:3])) >= version.parse(
+ "2.5.0"
+) # Anything from 2.5, incl. nightlies, allows for fullgraph
+
+
+@pytest.fixture(scope="module")
+def set_default_device():
+ cur_device = torch.get_default_device()
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+ torch.set_default_device(device)
+ yield
+ torch.set_default_device(cur_device)
+
class setup_value_fn:
def __init__(self, has_lmbda, has_state_value):
@@ -137,7 +153,26 @@ def test_gae_speed(benchmark, gae_fn, gamma_tensor, batches, timesteps):
)
-def test_dqn_speed(benchmark, n_obs=8, n_act=4, depth=3, ncells=128, batch=128):
+def _maybe_compile(fn, compile, td, fullgraph=FULLGRAPH, warmup=3):
+ if compile:
+ if isinstance(compile, str):
+ fn = torch.compile(fn, mode=compile, fullgraph=fullgraph)
+ else:
+ fn = torch.compile(fn, fullgraph=fullgraph)
+
+ for _ in range(warmup):
+ fn(td)
+
+ return fn
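+# Usage sketch (illustrative): the benchmarks below call
+#     loss = _maybe_compile(loss, compile, td)
+# where ``compile`` is False, True, or a ``torch.compile`` mode string such as
+# "reduce-overhead"; the returned callable has already been run `warmup` times.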
+
+
+@pytest.mark.parametrize("backward", [None, "backward"])
+@pytest.mark.parametrize("compile", [False, True, "reduce-overhead"])
+def test_dqn_speed(
+ benchmark, backward, compile, n_obs=8, n_act=4, depth=3, ncells=128, batch=128
+):
+ if compile:
+ torch._dynamo.reset_code_caches()
net = MLP(in_features=n_obs, out_features=n_act, depth=depth, num_cells=ncells)
action_space = "one-hot"
mod = QValueActor(net, in_keys=["obs"], action_space=action_space)
@@ -155,10 +190,36 @@ def test_dqn_speed(benchmark, n_obs=8, n_act=4, depth=3, ncells=128, batch=128):
[batch],
)
loss(td)
- benchmark(loss, td)
+ loss = _maybe_compile(loss, compile, td)
+
+ if backward:
+
+ def loss_and_bw(td):
+ losses = loss(td)
+ sum(
+ [val for key, val in losses.items() if key.startswith("loss")]
+ ).backward()
+
+ benchmark.pedantic(
+ loss_and_bw,
+ args=(td,),
+ setup=loss.zero_grad,
+ iterations=1,
+ warmup_rounds=5,
+ rounds=50,
+ )
+ else:
+ benchmark(loss, td)
-def test_ddpg_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64):
+
+@pytest.mark.parametrize("backward", [None, "backward"])
+@pytest.mark.parametrize("compile", [False, True, "reduce-overhead"])
+def test_ddpg_speed(
+ benchmark, backward, compile, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64
+):
+ if compile:
+ torch._dynamo.reset_code_caches()
common = MLP(
num_cells=ncells,
in_features=n_obs,
@@ -200,10 +261,36 @@ def test_ddpg_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden
loss = DDPGLoss(actor, value)
loss(td)
- benchmark(loss, td)
+ loss = _maybe_compile(loss, compile, td)
+
+ if backward:
+
+ def loss_and_bw(td):
+ losses = loss(td)
+ sum(
+ [val for key, val in losses.items() if key.startswith("loss")]
+ ).backward()
+
+ benchmark.pedantic(
+ loss_and_bw,
+ args=(td,),
+ setup=loss.zero_grad,
+ iterations=1,
+ warmup_rounds=5,
+ rounds=50,
+ )
+ else:
+ benchmark(loss, td)
-def test_sac_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64):
+
+@pytest.mark.parametrize("backward", [None, "backward"])
+@pytest.mark.parametrize("compile", [False, True, "reduce-overhead"])
+def test_sac_speed(
+ benchmark, backward, compile, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64
+):
+ if compile:
+ torch._dynamo.reset_code_caches()
common = MLP(
num_cells=ncells,
in_features=n_obs,
@@ -245,23 +332,48 @@ def test_sac_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=
in_keys=["loc", "scale"],
out_keys=["action"],
distribution_class=TanhNormal,
+ distribution_kwargs={"safe_tanh": False},
),
)
value_head = Mod(
value, in_keys=["hidden", "action"], out_keys=["state_action_value"]
)
value = Seq(common, value_head)
- value(actor(td))
+ value(actor(td.clone()))
- loss = SACLoss(
- actor, value, action_spec=UnboundedContinuousTensorSpec(shape=(n_act,))
- )
+ loss = SACLoss(actor, value, action_spec=Unbounded(shape=(n_act,)))
loss(td)
- benchmark(loss, td)
+ loss = _maybe_compile(loss, compile, td)
+
+ if backward:
-def test_redq_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64):
+ def loss_and_bw(td):
+ losses = loss(td)
+ sum(
+ [val for key, val in losses.items() if key.startswith("loss")]
+ ).backward()
+
+ benchmark.pedantic(
+ loss_and_bw,
+ args=(td,),
+ setup=loss.zero_grad,
+ iterations=1,
+ warmup_rounds=5,
+ rounds=50,
+ )
+ else:
+ benchmark(loss, td)
+
+
+@pytest.mark.parametrize("backward", [None, "backward"])
+@pytest.mark.parametrize("compile", [False, True, "reduce-overhead"])
+def test_redq_speed(
+ benchmark, backward, compile, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64
+):
+ if compile:
+ torch._dynamo.reset_code_caches()
common = MLP(
num_cells=ncells,
in_features=n_obs,
@@ -304,25 +416,50 @@ def test_redq_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden
out_keys=["action"],
distribution_class=TanhNormal,
return_log_prob=True,
+ distribution_kwargs={"safe_tanh": False},
),
)
value_head = Mod(
value, in_keys=["hidden", "action"], out_keys=["state_action_value"]
)
value = Seq(common, value_head)
- value(actor(td))
+ value(actor(td.copy()))
- loss = REDQLoss(
- actor, value, action_spec=UnboundedContinuousTensorSpec(shape=(n_act,))
- )
+ loss = REDQLoss(actor, value, action_spec=Unbounded(shape=(n_act,)))
loss(td)
- benchmark(loss, td)
+ loss = _maybe_compile(loss, compile, td)
+
+ if backward:
+
+ def loss_and_bw(td):
+ losses = loss(td)
+ totalloss = sum(
+ [val for key, val in losses.items() if key.startswith("loss")]
+ )
+ totalloss.backward()
+
+ loss_and_bw(td)
+
+ benchmark.pedantic(
+ loss_and_bw,
+ args=(td,),
+ setup=loss.zero_grad,
+ iterations=1,
+ warmup_rounds=5,
+ rounds=50,
+ )
+ else:
+ benchmark(loss, td)
+@pytest.mark.parametrize("backward", [None, "backward"])
+@pytest.mark.parametrize("compile", [False, True, "reduce-overhead"])
def test_redq_deprec_speed(
- benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64
+ benchmark, backward, compile, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64
):
+ if compile:
+ torch._dynamo.reset_code_caches()
common = MLP(
num_cells=ncells,
in_features=n_obs,
@@ -365,23 +502,48 @@ def test_redq_deprec_speed(
out_keys=["action"],
distribution_class=TanhNormal,
return_log_prob=True,
+ distribution_kwargs={"safe_tanh": False},
),
)
value_head = Mod(
value, in_keys=["hidden", "action"], out_keys=["state_action_value"]
)
value = Seq(common, value_head)
- value(actor(td))
+ value(actor(td.copy()))
- loss = REDQLoss_deprecated(
- actor, value, action_spec=UnboundedContinuousTensorSpec(shape=(n_act,))
- )
+ loss = REDQLoss_deprecated(actor, value, action_spec=Unbounded(shape=(n_act,)))
loss(td)
- benchmark(loss, td)
+ loss = _maybe_compile(loss, compile, td)
+
+ if backward:
+
+ def loss_and_bw(td):
+ losses = loss(td)
+ sum(
+ [val for key, val in losses.items() if key.startswith("loss")]
+ ).backward()
+
+ benchmark.pedantic(
+ loss_and_bw,
+ args=(td,),
+ setup=loss.zero_grad,
+ iterations=1,
+ warmup_rounds=5,
+ rounds=50,
+ )
+ else:
+ benchmark(loss, td)
-def test_td3_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64):
+
+@pytest.mark.parametrize("backward", [None, "backward"])
+@pytest.mark.parametrize("compile", [False, True, "reduce-overhead"])
+def test_td3_speed(
+ benchmark, backward, compile, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64
+):
+ if compile:
+ torch._dynamo.reset_code_caches()
common = MLP(
num_cells=ncells,
in_features=n_obs,
@@ -423,26 +585,54 @@ def test_td3_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=
in_keys=["loc", "scale"],
out_keys=["action"],
distribution_class=TanhNormal,
+ distribution_kwargs={"safe_tanh": False},
return_log_prob=True,
+ default_interaction_type=InteractionType.DETERMINISTIC,
),
)
value_head = Mod(
value, in_keys=["hidden", "action"], out_keys=["state_action_value"]
)
value = Seq(common, value_head)
- value(actor(td))
+ value(actor(td.clone()))
loss = TD3Loss(
actor,
value,
- action_spec=BoundedTensorSpec(shape=(n_act,), low=-1, high=1),
+ action_spec=Bounded(shape=(n_act,), low=-1, high=1),
)
loss(td)
- benchmark.pedantic(loss, args=(td,), rounds=100, iterations=10)
+ loss = _maybe_compile(loss, compile, td)
+
+ if backward:
+
+ def loss_and_bw(td):
+ losses = loss(td)
+ sum(
+ [val for key, val in losses.items() if key.startswith("loss")]
+ ).backward()
-def test_cql_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64):
+ benchmark.pedantic(
+ loss_and_bw,
+ args=(td,),
+ setup=loss.zero_grad,
+ iterations=1,
+ warmup_rounds=5,
+ rounds=50,
+ )
+ else:
+ benchmark.pedantic(loss, args=(td,), rounds=100, iterations=10)
+
+
+@pytest.mark.parametrize("backward", [None, "backward"])
+@pytest.mark.parametrize("compile", [False, True, "reduce-overhead"])
+def test_cql_speed(
+ benchmark, backward, compile, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64
+):
+ if compile:
+ torch._dynamo.reset_code_caches()
common = MLP(
num_cells=ncells,
in_features=n_obs,
@@ -481,26 +671,59 @@ def test_cql_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=
Mod(actor_net, in_keys=["hidden"], out_keys=["param"]),
Mod(NormalParamExtractor(), in_keys=["param"], out_keys=["loc", "scale"]),
ProbMod(
- in_keys=["loc", "scale"], out_keys=["action"], distribution_class=TanhNormal
+ in_keys=["loc", "scale"],
+ out_keys=["action"],
+ distribution_class=TanhNormal,
+ distribution_kwargs={"safe_tanh": False},
),
)
value_head = Mod(
value, in_keys=["hidden", "action"], out_keys=["state_action_value"]
)
value = Seq(common, value_head)
- value(actor(td))
+ value(actor(td.copy()))
- loss = CQLLoss(
- actor, value, action_spec=UnboundedContinuousTensorSpec(shape=(n_act,))
- )
+ loss = CQLLoss(actor, value, action_spec=Unbounded(shape=(n_act,)))
loss(td)
- benchmark(loss, td)
+ loss = _maybe_compile(loss, compile, td)
+
+ if backward:
+
+ def loss_and_bw(td):
+ losses = loss(td)
+ sum(
+ [val for key, val in losses.items() if key.startswith("loss")]
+ ).backward()
+ benchmark.pedantic(
+ loss_and_bw,
+ args=(td,),
+ setup=loss.zero_grad,
+ iterations=1,
+ warmup_rounds=5,
+ rounds=50,
+ )
+ else:
+ benchmark(loss, td)
+
+
+@pytest.mark.parametrize("backward", [None, "backward"])
+@pytest.mark.parametrize("compile", [False, True, "reduce-overhead"])
def test_a2c_speed(
- benchmark, n_obs=8, n_act=4, n_hidden=64, ncells=128, batch=128, T=10
+ benchmark,
+ backward,
+ compile,
+ n_obs=8,
+ n_act=4,
+ n_hidden=64,
+ ncells=128,
+ batch=128,
+ T=10,
):
+ if compile:
+ torch._dynamo.reset_code_caches()
common_net = MLP(
num_cells=ncells,
in_features=n_obs,
@@ -541,7 +764,10 @@ def test_a2c_speed(
Mod(actor_net, in_keys=["hidden"], out_keys=["param"]),
Mod(NormalParamExtractor(), in_keys=["param"], out_keys=["loc", "scale"]),
ProbMod(
- in_keys=["loc", "scale"], out_keys=["action"], distribution_class=TanhNormal
+ in_keys=["loc", "scale"],
+ out_keys=["action"],
+ distribution_class=TanhNormal,
+ distribution_kwargs={"safe_tanh": False},
),
)
critic = Seq(common, Mod(value_net, in_keys=["hidden"], out_keys=["state_value"]))
@@ -552,12 +778,44 @@ def test_a2c_speed(
advantage = GAE(value_network=critic, gamma=0.99, lmbda=0.95, shifted=True)
advantage(td)
loss(td)
- benchmark(loss, td)
+ loss = _maybe_compile(loss, compile, td)
+ if backward:
+
+ def loss_and_bw(td):
+ losses = loss(td)
+ sum(
+ [val for key, val in losses.items() if key.startswith("loss")]
+ ).backward()
+
+ benchmark.pedantic(
+ loss_and_bw,
+ args=(td,),
+ setup=loss.zero_grad,
+ iterations=1,
+ warmup_rounds=5,
+ rounds=50,
+ )
+ else:
+ benchmark(loss, td)
+
+
+@pytest.mark.parametrize("backward", [None, "backward"])
+@pytest.mark.parametrize("compile", [False, True, "reduce-overhead"])
def test_ppo_speed(
- benchmark, n_obs=8, n_act=4, n_hidden=64, ncells=128, batch=128, T=10
+ benchmark,
+ backward,
+ compile,
+ n_obs=8,
+ n_act=4,
+ n_hidden=64,
+ ncells=128,
+ batch=128,
+ T=10,
):
+ if compile:
+ torch._dynamo.reset_code_caches()
common_net = MLP(
num_cells=ncells,
in_features=n_obs,
@@ -598,7 +856,10 @@ def test_ppo_speed(
Mod(actor_net, in_keys=["hidden"], out_keys=["param"]),
Mod(NormalParamExtractor(), in_keys=["param"], out_keys=["loc", "scale"]),
ProbMod(
- in_keys=["loc", "scale"], out_keys=["action"], distribution_class=TanhNormal
+ in_keys=["loc", "scale"],
+ out_keys=["action"],
+ distribution_class=TanhNormal,
+ distribution_kwargs={"safe_tanh": False},
),
)
critic = Seq(common, Mod(value_net, in_keys=["hidden"], out_keys=["state_value"]))
@@ -609,12 +870,44 @@ def test_ppo_speed(
advantage = GAE(value_network=critic, gamma=0.99, lmbda=0.95, shifted=True)
advantage(td)
loss(td)
- benchmark(loss, td)
+ loss = _maybe_compile(loss, compile, td)
+
+ if backward:
+
+ def loss_and_bw(td):
+ losses = loss(td)
+ sum(
+ [val for key, val in losses.items() if key.startswith("loss")]
+ ).backward()
+
+ benchmark.pedantic(
+ loss_and_bw,
+ args=(td,),
+ setup=loss.zero_grad,
+ iterations=1,
+ warmup_rounds=5,
+ rounds=50,
+ )
+ else:
+ benchmark(loss, td)
+
+@pytest.mark.parametrize("backward", [None, "backward"])
+@pytest.mark.parametrize("compile", [False, True, "reduce-overhead"])
def test_reinforce_speed(
- benchmark, n_obs=8, n_act=4, n_hidden=64, ncells=128, batch=128, T=10
+ benchmark,
+ backward,
+ compile,
+ n_obs=8,
+ n_act=4,
+ n_hidden=64,
+ ncells=128,
+ batch=128,
+ T=10,
):
+ if compile:
+ torch._dynamo.reset_code_caches()
common_net = MLP(
num_cells=ncells,
in_features=n_obs,
@@ -655,7 +948,10 @@ def test_reinforce_speed(
Mod(actor_net, in_keys=["hidden"], out_keys=["param"]),
Mod(NormalParamExtractor(), in_keys=["param"], out_keys=["loc", "scale"]),
ProbMod(
- in_keys=["loc", "scale"], out_keys=["action"], distribution_class=TanhNormal
+ in_keys=["loc", "scale"],
+ out_keys=["action"],
+ distribution_class=TanhNormal,
+ distribution_kwargs={"safe_tanh": False},
),
)
critic = Seq(common, Mod(value_net, in_keys=["hidden"], out_keys=["state_value"]))
@@ -666,12 +962,44 @@ def test_reinforce_speed(
advantage = GAE(value_network=critic, gamma=0.99, lmbda=0.95, shifted=True)
advantage(td)
loss(td)
- benchmark(loss, td)
+ loss = _maybe_compile(loss, compile, td)
+
+    if backward:
+
+        def loss_and_bw(td):
+ losses = loss(td)
+ sum(
+ [val for key, val in losses.items() if key.startswith("loss")]
+ ).backward()
+
+ benchmark.pedantic(
+ loss_and_bw,
+ args=(td,),
+ setup=loss.zero_grad,
+ iterations=1,
+ warmup_rounds=5,
+ rounds=50,
+ )
+ else:
+ benchmark(loss, td)
+
+
+@pytest.mark.parametrize("backward", [None, "backward"])
+@pytest.mark.parametrize("compile", [False, True, "reduce-overhead"])
def test_iql_speed(
- benchmark, n_obs=8, n_act=4, n_hidden=64, ncells=128, batch=128, T=10
+ benchmark,
+ backward,
+ compile,
+ n_obs=8,
+ n_act=4,
+ n_hidden=64,
+ ncells=128,
+ batch=128,
+ T=10,
):
+ if compile:
+ torch._dynamo.reset_code_caches()
common_net = MLP(
num_cells=ncells,
in_features=n_obs,
@@ -718,7 +1046,10 @@ def test_iql_speed(
Mod(actor_net, in_keys=["hidden"], out_keys=["param"]),
Mod(NormalParamExtractor(), in_keys=["param"], out_keys=["loc", "scale"]),
ProbMod(
- in_keys=["loc", "scale"], out_keys=["action"], distribution_class=TanhNormal
+ in_keys=["loc", "scale"],
+ out_keys=["action"],
+ distribution_class=TanhNormal,
+ distribution_kwargs={"safe_tanh": False},
),
)
value = Seq(common, Mod(value_net, in_keys=["hidden"], out_keys=["state_value"]))
@@ -731,7 +1062,27 @@ def test_iql_speed(
loss = IQLLoss(actor_network=actor, value_network=value, qvalue_network=qvalue)
loss(td)
- benchmark(loss, td)
+
+ loss = _maybe_compile(loss, compile, td)
+
+ if backward:
+
+ def loss_and_bw(td):
+ losses = loss(td)
+ sum(
+ [val for key, val in losses.items() if key.startswith("loss")]
+ ).backward()
+
+ benchmark.pedantic(
+ loss_and_bw,
+ args=(td,),
+ setup=loss.zero_grad,
+ iterations=1,
+ warmup_rounds=5,
+ rounds=50,
+ )
+ else:
+ benchmark(loss, td)
+
+
if __name__ == "__main__":
diff --git a/docs/requirements.txt b/docs/requirements.txt
index f6138cac30a..702a2884421 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -15,9 +15,8 @@ sphinx_design
torchvision
dm_control
-atari-py
-ale-py
-gym[classic_control,accept-rom-license]
+mujoco
+gym[classic_control,accept-rom-license,ale-py,atari]
pygame
tqdm
ipython
diff --git a/docs/source/_static/img/collector-copy.png b/docs/source/_static/img/collector-copy.png
new file mode 100644
index 00000000000..8a8921cacca
Binary files /dev/null and b/docs/source/_static/img/collector-copy.png differ
diff --git a/docs/source/_static/img/rename_transform.png b/docs/source/_static/img/rename_transform.png
new file mode 100644
index 00000000000..3de518362cd
Binary files /dev/null and b/docs/source/_static/img/rename_transform.png differ
diff --git a/docs/source/reference/collectors.rst b/docs/source/reference/collectors.rst
index 74bd058b8f0..41b743dce15 100644
--- a/docs/source/reference/collectors.rst
+++ b/docs/source/reference/collectors.rst
@@ -45,7 +45,7 @@ worker) may also impact the memory management. The key parameters to control are
:obj:`devices` which controls the execution devices (ie the device of the policy)
and :obj:`storing_device` which will control the device where the environment and
data are stored during a rollout. A good heuristic is usually to use the same device
-for storage and compute, which is the default behaviour when only the `devices` argument
+for storage and compute, which is the default behavior when only the `devices` argument
is being passed.
Besides those compute parameters, users may choose to configure the following parameters:
@@ -99,6 +99,25 @@ delivers batches of data on a first-come, first-serve basis, whereas
:class:`~torchrl.collectors.MultiSyncDataCollector` gathers data from
each sub-collector before delivering it.
+Collectors and policy copies
+----------------------------
+
+When passing a policy to a collector, we can choose the device on which this policy will be run. This can be used to
+keep the training version of the policy on one device and the inference version on another. For example, if you have
+two CUDA devices, it may be wise to train on one device and execute the policy for inference on the other. If that is
+the case, the :meth:`~torchrl.collectors.DataCollector.update_policy_weights_` method can be used to copy the parameters from one
+device to the other (if no copy is required, this method is a no-op).
+
+Since the goal is to avoid calling `policy.to(policy_device)` explicitly, the collector will, if necessary, deepcopy
+the policy structure during instantiation and populate the copy with the parameters placed on the new device.
+Since not all policies support deepcopies (e.g., policies using CUDA graphs or relying on third-party libraries), we
+try to limit the cases where a deepcopy will be executed. The following chart shows when this will occur.
+
+.. figure:: /_static/img/collector-copy.png
+
+ Policy copy decision tree in Collectors.
+
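+A minimal sketch of this workflow (a hypothetical setup: ``make_env``, ``policy`` and ``optim_step`` are placeholders,
+and we assume a machine with two CUDA devices where the trained weights live on ``cuda:0``):
+
+ >>> from torchrl.collectors import SyncDataCollector
+ >>> collector = SyncDataCollector(
+ ...     make_env,
+ ...     policy,
+ ...     device="cuda:1",  # inference device: the policy weights are copied here
+ ...     frames_per_batch=100,
+ ...     total_frames=10_000,
+ ... )
+ >>> for data in collector:
+ ...     optim_step(data)  # updates the cuda:0 parameters
+ ...     collector.update_policy_weights_()  # re-syncs the inference copy on cuda:1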
+
Collectors and replay buffers interoperability
----------------------------------------------
diff --git a/docs/source/reference/data.rst b/docs/source/reference/data.rst
index 0dca499f4d9..6fbeada5bd0 100644
--- a/docs/source/reference/data.rst
+++ b/docs/source/reference/data.rst
@@ -171,7 +171,7 @@ using the following components:
Storage choice is very influential on replay buffer sampling latency, especially
in distributed reinforcement learning settings with larger data volumes.
:class:`~torchrl.data.replay_buffers.storages.LazyMemmapStorage` is highly
-advised in distributed settings with shared storage due to the lower serialisation
+advised in distributed settings with shared storage due to the lower serialization
cost of MemoryMappedTensors as well as the ability to specify file storage locations
for improved node failure recovery.
The following mean sampling latency improvements over using :class:`~torchrl.data.replay_buffers.ListStorage`
@@ -877,11 +877,58 @@ TensorSpec
.. _ref_specs:
-The `TensorSpec` parent class and subclasses define the basic properties of observations and actions in TorchRL, such
-as shape, device, dtype and domain.
+The :class:`~torchrl.data.TensorSpec` parent class and subclasses define the basic properties of states, observations,
+actions, rewards and done status in TorchRL, such as their shape, device, dtype and domain.
+
It is important that your environment specs match the input and output that it sends and receives, as
-:obj:`ParallelEnv` will create buffers from these specs to communicate with the spawn processes.
-Check the :obj:`torchrl.envs.utils.check_env_specs` method for a sanity check.
+:class:`~torchrl.envs.ParallelEnv` will create buffers from these specs to communicate with the spawned processes.
+Check the :func:`torchrl.envs.utils.check_env_specs` method for a sanity check.
+
+If needed, specs can be automatically generated from data using the :func:`~torchrl.envs.utils.make_composite_from_td`
+function.
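+
+For instance (a minimal sketch):
+
+ >>> import torch
+ >>> from tensordict import TensorDict
+ >>> from torchrl.envs.utils import make_composite_from_td
+ >>> td = TensorDict({"obs": torch.zeros(3)}, [])
+ >>> spec = make_composite_from_td(td)  # a Composite matching the "obs" entry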
+
+Specs fall into two main categories: numerical and categorical.
+
+.. table:: Numerical TensorSpec subclasses.
+
+ +-------------------------------------------------------------------------------+
+ | Numerical |
+ +=====================================+=========================================+
+ | Bounded | Unbounded |
+ +-----------------+-------------------+-------------------+---------------------+
+ | BoundedDiscrete | BoundedContinuous | UnboundedDiscrete | UnboundedContinuous |
+ +-----------------+-------------------+-------------------+---------------------+
+
+Whenever a :class:`~torchrl.data.Bounded` instance is created, its domain (defined either implicitly by its dtype or
+explicitly by the `"domain"` keyword argument) will determine if the instantiated class will be of :class:`~torchrl.data.BoundedContinuous`
+or :class:`~torchrl.data.BoundedDiscrete` type. The same applies to the :class:`~torchrl.data.Unbounded` class.
+See these classes for further information.
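+
+For instance, following the dispatch described above (a minimal sketch):
+
+ >>> import torch
+ >>> from torchrl.data import Bounded
+ >>> type(Bounded(low=-1.0, high=1.0, shape=(2,))).__name__
+ 'BoundedContinuous'
+ >>> type(Bounded(low=0, high=10, shape=(2,), dtype=torch.int64)).__name__
+ 'BoundedDiscrete'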
+
+.. table:: Categorical TensorSpec subclasses.
+
+ +------------------------------------------------------------------+
+ | Categorical |
+ +========+=============+=============+==================+==========+
+ | OneHot | MultiOneHot | Categorical | MultiCategorical | Binary |
+ +--------+-------------+-------------+------------------+----------+
+
+Unlike ``gymnasium``, TorchRL does not have the concept of an arbitrary list of specs. If multiple specs have to be
+combined, TorchRL assumes that the data will be presented as dictionaries (more specifically, as
+:class:`~tensordict.TensorDict` or related formats). The corresponding :class:`~torchrl.data.TensorSpec` class in these
+cases is the :class:`~torchrl.data.Composite` spec.
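+
+For example (a sketch):
+
+ >>> from torchrl.data import Composite, OneHot, Unbounded
+ >>> spec = Composite(observation=Unbounded(shape=(3,)), action=OneHot(4))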
+
+Nevertheless, specs can be stacked together using :func:`~torch.stack`: if they are identical, their shape will be
+expanded accordingly.
+Otherwise, a lazy stack will be created through the :class:`~torchrl.data.Stacked` class.
+
+Similarly, ``TensorSpecs`` share some behavior with :class:`~torch.Tensor` and
+:class:`~tensordict.TensorDict`: they can be reshaped, indexed, squeezed, unsqueezed, moved to another device (``to``)
+or unbound (``unbind``) as regular :class:`~torch.Tensor` instances would be.
+
+Specs where some dimensions are ``-1`` are said to be "dynamic", and the negative dimensions indicate that the
+corresponding data has an inconsistent shape. When seen by an optimizer or an environment (e.g., a batched environment
+such as :class:`~torchrl.envs.ParallelEnv`), these negative shapes tell TorchRL to avoid using buffers as the tensor shapes are
+not predictable.
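+
+As an illustration (a sketch; shapes are arbitrary):
+
+ >>> import torch
+ >>> from torchrl.data import Unbounded
+ >>> spec = Unbounded(shape=(3,))
+ >>> torch.stack([spec, spec], dim=0).shape  # identical specs: the shape is expanded
+ torch.Size([2, 3])
+ >>> dynamic = Unbounded(shape=(3, -1, 2))  # the second dimension is dynamic
+ >>> dynamic.shape
+ torch.Size([3, -1, 2])
+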
.. currentmodule:: torchrl.data
@@ -890,19 +937,40 @@ Check the :obj:`torchrl.envs.utils.check_env_specs` method for a sanity check.
:template: rl_template.rst
TensorSpec
+ Binary
+ Bounded
+ Categorical
+ Composite
+ MultiCategorical
+ MultiOneHot
+ NonTensor
+ OneHot
+ Stacked
+ StackedComposite
+ Unbounded
+ UnboundedContinuous
+ UnboundedDiscrete
+
+The following classes are deprecated and just point to the classes above:
+
+.. currentmodule:: torchrl.data
+
+.. autosummary::
+ :toctree: generated/
+ :template: rl_template.rst
+
BinaryDiscreteTensorSpec
BoundedTensorSpec
CompositeSpec
DiscreteTensorSpec
+ LazyStackedCompositeSpec
+ LazyStackedTensorSpec
MultiDiscreteTensorSpec
MultiOneHotDiscreteTensorSpec
NonTensorSpec
OneHotDiscreteTensorSpec
UnboundedContinuousTensorSpec
UnboundedDiscreteTensorSpec
- LazyStackedTensorSpec
- LazyStackedCompositeSpec
- NonTensorSpec
Reinforcement Learning From Human Feedback (RLHF)
-------------------------------------------------
diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst
index 11a5bb041a6..960daf0fb12 100644
--- a/docs/source/reference/envs.rst
+++ b/docs/source/reference/envs.rst
@@ -28,9 +28,9 @@ Each env will have the following attributes:
This is especially useful for transforms (see below). For parametric environments (e.g.
model-based environments), the device does represent the hardware that will be used to
compute the operations.
-- :obj:`env.observation_spec`: a :class:`~torchrl.data.CompositeSpec` object
+- :obj:`env.observation_spec`: a :class:`~torchrl.data.Composite` object
containing all the observation key-spec pairs.
-- :obj:`env.state_spec`: a :class:`~torchrl.data.CompositeSpec` object
+- :obj:`env.state_spec`: a :class:`~torchrl.data.Composite` object
containing all the input key-spec pairs (except action). For most stateful
environments, this container will be empty.
- :obj:`env.action_spec`: a :class:`~torchrl.data.TensorSpec` object
@@ -39,10 +39,10 @@ Each env will have the following attributes:
the reward spec.
- :obj:`env.done_spec`: a :class:`~torchrl.data.TensorSpec` object representing
the done-flag spec. See the section on trajectory termination below.
-- :obj:`env.input_spec`: a :class:`~torchrl.data.CompositeSpec` object containing
+- :obj:`env.input_spec`: a :class:`~torchrl.data.Composite` object containing
all the input keys (:obj:`"full_action_spec"` and :obj:`"full_state_spec"`).
It is locked and should not be modified directly.
-- :obj:`env.output_spec`: a :class:`~torchrl.data.CompositeSpec` object containing
+- :obj:`env.output_spec`: a :class:`~torchrl.data.Composite` object containing
all the output keys (:obj:`"full_observation_spec"`, :obj:`"full_reward_spec"` and :obj:`"full_done_spec"`).
It is locked and should not be modified directly.
@@ -318,7 +318,7 @@ have on an environment returning zeros after reset:
We also offer the :class:`~.SerialEnv` class that enjoys the exact same API but is executed
serially. This is mostly useful for testing purposes, when one wants to assess the
-behaviour of a :class:`~.ParallelEnv` without launching the subprocesses.
+behavior of a :class:`~.ParallelEnv` without launching the subprocesses.
In addition to :class:`~.ParallelEnv`, which offers process-based parallelism, we also provide a way to create
multithreaded environments with :obj:`~.MultiThreadedEnv`. This class uses `EnvPool `_
@@ -433,28 +433,28 @@ only the done flag is shared across agents (as in VMAS):
... action_specs.append(agent_i_action_spec)
... reward_specs.append(agent_i_reward_spec)
... observation_specs.append(agent_i_observation_spec)
- >>> env.action_spec = CompositeSpec(
+ >>> env.action_spec = Composite(
... {
- ... "agents": CompositeSpec(
+ ... "agents": Composite(
... {"action": torch.stack(action_specs)}, shape=(env.n_agents,)
... )
... }
...)
- >>> env.reward_spec = CompositeSpec(
+ >>> env.reward_spec = Composite(
... {
- ... "agents": CompositeSpec(
+ ... "agents": Composite(
... {"reward": torch.stack(reward_specs)}, shape=(env.n_agents,)
... )
... }
...)
- >>> env.observation_spec = CompositeSpec(
+ >>> env.observation_spec = Composite(
... {
- ... "agents": CompositeSpec(
+ ... "agents": Composite(
... {"observation": torch.stack(observation_specs)}, shape=(env.n_agents,)
... )
... }
...)
- >>> env.done_spec = DiscreteTensorSpec(
+ >>> env.done_spec = Categorical(
... n=2,
... shape=torch.Size((1,)),
... dtype=torch.bool,
@@ -499,7 +499,7 @@ current episode.
To handle these cases, torchrl provides a :class:`~torchrl.envs.AutoResetTransform` that will copy the observations
that result from the call to `step` to the next `reset` and skip the calls to `reset` during rollouts (in both
:meth:`~torchrl.envs.EnvBase.rollout` and :class:`~torchrl.collectors.SyncDataCollector` iterations).
-This transform class also provides a fine-grained control over the behaviour to be adopted for the invalid observations,
+This transform class also provides a fine-grained control over the behavior to be adopted for the invalid observations,
which can be masked with `"nan"` or any other values, or not masked at all.
To tell torchrl that an environment is auto-resetting, it is sufficient to provide an ``auto_reset`` argument
@@ -582,23 +582,23 @@ the ``return_contiguous=False`` argument.
Here is a working example:
>>> from torchrl.envs import EnvBase
- >>> from torchrl.data import UnboundedContinuousTensorSpec, CompositeSpec, BoundedTensorSpec, BinaryDiscreteTensorSpec
+ >>> from torchrl.data import Unbounded, Composite, Bounded, Binary
>>> import torch
>>> from tensordict import TensorDict, TensorDictBase
>>>
>>> class EnvWithDynamicSpec(EnvBase):
... def __init__(self, max_count=5):
... super().__init__(batch_size=())
- ... self.observation_spec = CompositeSpec(
- ... observation=UnboundedContinuousTensorSpec(shape=(3, -1, 2)),
+ ... self.observation_spec = Composite(
+ ... observation=Unbounded(shape=(3, -1, 2)),
... )
- ... self.action_spec = BoundedTensorSpec(low=-1, high=1, shape=(2,))
- ... self.full_done_spec = CompositeSpec(
- ... done=BinaryDiscreteTensorSpec(1, shape=(1,), dtype=torch.bool),
- ... terminated=BinaryDiscreteTensorSpec(1, shape=(1,), dtype=torch.bool),
- ... truncated=BinaryDiscreteTensorSpec(1, shape=(1,), dtype=torch.bool),
+ ... self.action_spec = Bounded(low=-1, high=1, shape=(2,))
+ ... self.full_done_spec = Composite(
+ ... done=Binary(1, shape=(1,), dtype=torch.bool),
+ ... terminated=Binary(1, shape=(1,), dtype=torch.bool),
+ ... truncated=Binary(1, shape=(1,), dtype=torch.bool),
... )
- ... self.reward_spec = UnboundedContinuousTensorSpec((1,), dtype=torch.float)
+ ... self.reward_spec = Unbounded((1,), dtype=torch.float)
... self.count = 0
... self.max_count = max_count
...
@@ -722,6 +722,9 @@ Since each transform uses a ``"in_keys"``/``"out_keys"`` set of keyword argument
also easy to root the transform graph to each component of the observation data (e.g.
pixels or states etc).
+Forward and inverse transforms
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
Transforms also have an ``inv`` method that is called before
the action is applied in reverse order over the composed transform chain:
this allows to apply transforms to data in the environment before the action is taken
@@ -733,6 +736,20 @@ in the environment. The keys to be included in this inverse transform are passed
>>> env.append_transform(DoubleToFloat(in_keys_inv=["action"])) # will map the action from float32 to float64 before calling the base_env.step
+The way ``in_keys`` relates to ``in_keys_inv`` can be understood by considering the base environment as the "inner" part
+of the transform. In contrast, the user's inputs to and outputs from the transform belong to the
+outside world. The following figure shows what this means in practice for the :class:`~torchrl.envs.RenameTransform`
+class: the input ``TensorDict`` of the ``step`` function must have the ``out_keys_inv`` listed in its entries as they
+are part of the outside world. The transform changes these names to make them match the names of the inner, base
+environment using the ``in_keys_inv``. The reverse process is applied to the output tensordict, where the ``in_keys``
+are mapped to the corresponding ``out_keys``.
+
+.. figure:: /_static/img/rename_transform.png
+
+ Rename transform logic
+
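+In code, this can be sketched as follows (the environment and key names are illustrative):
+
+ >>> from torchrl.envs import GymEnv, RenameTransform, TransformedEnv
+ >>> env = TransformedEnv(
+ ...     GymEnv("Pendulum-v1"),
+ ...     RenameTransform(
+ ...         in_keys=["observation"], out_keys=["obs"],  # outputs: inner name -> outer name
+ ...         in_keys_inv=["action"], out_keys_inv=["act"],  # inputs: outer name -> inner name
+ ...     ),
+ ... )
+ >>> rollout = env.rollout(3)  # the resulting data carries "obs" and "act" entries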
+
+
Cloning transforms
~~~~~~~~~~~~~~~~~~
@@ -755,10 +772,10 @@ registered buffers:
>>> TransformedEnv(base_env, third_transform.clone()) # works
On a single process or if the buffers are placed in shared memory, this will
-result in all the clone transforms to keep the same behaviour even if the
+result in all the clone transforms to keep the same behavior even if the
buffers are changed in place (which is what will happen with the :class:`CatFrames`
transform, for instance). In distributed settings, this may not hold and one
-should be careful about the expected behaviour of the cloned transforms in this
+should be careful about the expected behavior of the cloned transforms in this
context.
Finally, notice that indexing multiple transforms from a :class:`Compose` transform
may also result in loss of parenthood for these transforms: the reason is that
@@ -979,11 +996,9 @@ Helpers
RandomPolicy
check_env_specs
- exploration_mode #deprecated
exploration_type
get_available_libraries
make_composite_from_td
- set_exploration_mode #deprecated
set_exploration_type
step_mdp
terminated_or_truncated
@@ -1061,7 +1076,7 @@ the current gym backend or any of its modules:
Another tool that comes in handy with gym and other external dependencies is
the :class:`torchrl._utils.implement_for` class. Decorating a function
with ``@implement_for`` will tell torchrl that, depending on the version
-indicated, a specific behaviour is to be expected. This allows us to easily
+indicated, a specific behavior is to be expected. This allows us to easily
support multiple versions of gym without requiring any effort from the user side.
For example, considering that our virtual environment has the v0.26.2 installed,
the following function will return ``1`` when queried:
@@ -1098,11 +1113,15 @@ the following function will return ``1`` when queried:
MultiThreadedEnv
MultiThreadedEnvWrapper
OpenMLEnv
+ OpenSpielWrapper
+ OpenSpielEnv
PettingZooEnv
PettingZooWrapper
RoboHiveEnv
SMACv2Env
SMACv2Wrapper
+ UnityMLAgentsEnv
+ UnityMLAgentsWrapper
VmasEnv
VmasWrapper
gym_backend
diff --git a/docs/source/reference/modules.rst b/docs/source/reference/modules.rst
index c73ed5083fd..e1642868228 100644
--- a/docs/source/reference/modules.rst
+++ b/docs/source/reference/modules.rst
@@ -57,16 +57,21 @@ projected (in a L1-manner) into the desired domain.
SafeSequential
TanhModule
-Exploration wrappers
-~~~~~~~~~~~~~~~~~~~~
+Exploration wrappers and modules
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-To efficiently explore the environment, TorchRL proposes a series of wrappers
+To efficiently explore the environment, TorchRL proposes a series of modules
that will override the action sampled by the policy by a noisier version.
-Their behaviour is controlled by :func:`~torchrl.envs.utils.exploration_mode`:
-if the exploration is set to ``"random"``, the exploration is active. In all
+Their behavior is controlled by :func:`~torchrl.envs.utils.exploration_type`:
+if the exploration is set to ``ExplorationType.RANDOM``, the exploration is active. In all
other cases, the action written in the tensordict is simply the network output.
-.. currentmodule:: torchrl.modules.tensordict_module
+.. note:: Unlike other exploration modules, :class:`~torchrl.modules.ConsistentDropoutModule`
+ uses the ``train``/``eval`` mode to comply with the regular ``Dropout`` API in PyTorch.
+ The :func:`~torchrl.envs.utils.set_exploration_type` context manager will have no effect on
+ this module.
+
+.. currentmodule:: torchrl.modules
.. autosummary::
:toctree: generated/
@@ -74,6 +79,7 @@ other cases, the action written in the tensordict is simply the network output.
AdditiveGaussianModule
AdditiveGaussianWrapper
+ ConsistentDropoutModule
EGreedyModule
EGreedyWrapper
OrnsteinUhlenbeckProcessModule
@@ -163,11 +169,91 @@ resulting action in the input tensordict along with the list of action values.
>>> from tensordict import TensorDict
>>> from tensordict.nn.functional_modules import make_functional
>>> from torch import nn
- >>> from torchrl.data import OneHotDiscreteTensorSpec
+ >>> from torchrl.data import OneHot
>>> from torchrl.modules.tensordict_module.actors import QValueActor
>>> td = TensorDict({'observation': torch.randn(5, 3)}, [5])
>>> # we have 4 actions to choose from
- >>> action_spec = OneHotDiscreteTensorSpec(4)
+ >>> action_spec = OneHot(4)
>>> # the model reads a state of dimension 3 and outputs 4 values, one for each action available
>>> module = nn.Linear(3, 4)
>>> qvalue_actor = QValueActor(module=module, spec=action_spec)
@@ -196,13 +282,57 @@ class:
>>> import torch
>>> from tensordict import TensorDict
>>> from torch import nn
- >>> from torchrl.data import OneHotDiscreteTensorSpec
+ >>> from torchrl.data import OneHot
>>> from torchrl.modules import DistributionalQValueActor, MLP
>>> td = TensorDict({'observation': torch.randn(5, 4)}, [5])
>>> nbins = 3
>>> # our model reads the observation and outputs a stack of 4 logits (one for each action) of size nbins=3
>>> module = MLP(out_features=(nbins, 4), depth=2)
- >>> action_spec = OneHotDiscreteTensorSpec(4)
+ >>> action_spec = OneHot(4)
>>> qvalue_actor = DistributionalQValueActor(module=module, spec=action_spec, support=torch.arange(nbins))
>>> td = qvalue_actor(td)
>>> print(td)
@@ -314,12 +444,13 @@ Regular modules
:toctree: generated/
:template: rl_template_noinherit.rst
- MLP
- ConvNet
+ BatchRenorm1d
+ ConsistentDropout
Conv3dNet
- SqueezeLayer
+ ConvNet
+ MLP
Squeeze2dLayer
- BatchRenorm
+ SqueezeLayer
Algorithm-specific modules
~~~~~~~~~~~~~~~~~~~~~~~~~~
diff --git a/docs/source/reference/objectives.rst b/docs/source/reference/objectives.rst
index 1d92c390a4e..b3f8e242a9e 100644
--- a/docs/source/reference/objectives.rst
+++ b/docs/source/reference/objectives.rst
@@ -157,7 +157,7 @@ CrossQ
:toctree: generated/
:template: rl_template_noinherit.rst
- CrossQ
+ CrossQLoss
IQL
----
@@ -179,6 +179,15 @@ CQL
CQLLoss
DiscreteCQLLoss
+GAIL
+----
+
+.. autosummary::
+ :toctree: generated/
+ :template: rl_template_noinherit.rst
+
+ GAILLoss
+
DT
----
diff --git a/examples/distributed/collectors/multi_nodes/delayed_dist.py b/examples/distributed/collectors/multi_nodes/delayed_dist.py
index b140ee7bc67..7b7e053f498 100644
--- a/examples/distributed/collectors/multi_nodes/delayed_dist.py
+++ b/examples/distributed/collectors/multi_nodes/delayed_dist.py
@@ -114,7 +114,7 @@
def main():
import gym
from torchrl.collectors import MultiSyncDataCollector, SyncDataCollector
- from torchrl.data import BoundedTensorSpec
+ from torchrl.data import Bounded
from torchrl.envs.libs.gym import GymEnv, set_gym_backend
from torchrl.envs.utils import RandomPolicy
@@ -128,7 +128,7 @@ def make_env():
collector = DistributedDataCollector(
[EnvCreator(make_env)] * num_jobs,
- policy=RandomPolicy(BoundedTensorSpec(-1, 1, shape=(1,))),
+ policy=RandomPolicy(Bounded(-1, 1, shape=(1,))),
launcher="submitit_delayed",
frames_per_batch=frames_per_batch,
total_frames=total_frames,
diff --git a/examples/distributed/collectors/multi_nodes/delayed_rpc.py b/examples/distributed/collectors/multi_nodes/delayed_rpc.py
index adff8864413..f63c4d17409 100644
--- a/examples/distributed/collectors/multi_nodes/delayed_rpc.py
+++ b/examples/distributed/collectors/multi_nodes/delayed_rpc.py
@@ -113,7 +113,7 @@
def main():
import gym
from torchrl.collectors import MultiSyncDataCollector, SyncDataCollector
- from torchrl.data import BoundedTensorSpec
+ from torchrl.data import Bounded
from torchrl.envs.libs.gym import GymEnv, set_gym_backend
from torchrl.envs.utils import RandomPolicy
@@ -127,7 +127,7 @@ def make_env():
collector = RPCDataCollector(
[EnvCreator(make_env)] * num_jobs,
- policy=RandomPolicy(BoundedTensorSpec(-1, 1, shape=(1,))),
+ policy=RandomPolicy(Bounded(-1, 1, shape=(1,))),
launcher="submitit_delayed",
frames_per_batch=frames_per_batch,
total_frames=total_frames,
diff --git a/examples/distributed/collectors/multi_nodes/ray_train.py b/examples/distributed/collectors/multi_nodes/ray_train.py
index b05e92619fa..5697d88dc61 100644
--- a/examples/distributed/collectors/multi_nodes/ray_train.py
+++ b/examples/distributed/collectors/multi_nodes/ray_train.py
@@ -26,7 +26,7 @@
TransformedEnv,
)
from torchrl.envs.libs.gym import GymEnv
-from torchrl.envs.utils import check_env_specs, set_exploration_mode
+from torchrl.envs.utils import check_env_specs, ExplorationType, set_exploration_type
from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE
@@ -85,8 +85,8 @@
in_keys=["loc", "scale"],
distribution_class=TanhNormal,
distribution_kwargs={
- "min": env.action_spec.space.low,
- "max": env.action_spec.space.high,
+ "low": env.action_spec.space.low,
+ "high": env.action_spec.space.high,
},
return_log_prob=True,
)
@@ -201,7 +201,7 @@
stepcount_str = f"step count (max): {logs['step_count'][-1]}"
logs["lr"].append(optim.param_groups[0]["lr"])
lr_str = f"lr policy: {logs['lr'][-1]: 4.4f}"
- with set_exploration_mode("mean"), torch.no_grad():
+ with set_exploration_type(ExplorationType.MODE), torch.no_grad():
# execute a rollout with the trained policy
eval_rollout = env.rollout(1000, policy_module)
logs["eval reward"].append(eval_rollout["next", "reward"].mean().item())
diff --git a/examples/distributed/replay_buffers/distributed_replay_buffer.py b/examples/distributed/replay_buffers/distributed_replay_buffer.py
index c7504fbf8ee..f25ea0bdc8b 100644
--- a/examples/distributed/replay_buffers/distributed_replay_buffer.py
+++ b/examples/distributed/replay_buffers/distributed_replay_buffer.py
@@ -150,8 +150,8 @@ def _create_and_launch_data_collectors(self) -> None:
class ReplayBufferNode(RemoteTensorDictReplayBuffer):
"""Experience replay buffer node that is capable of accepting remote connections. Being a `RemoteTensorDictReplayBuffer`
- means all of it's public methods are remotely invokable using `torch.rpc`.
- Using a LazyMemmapStorage is highly advised in distributed settings with shared storage due to the lower serialisation
+ means all of its public methods are remotely invokable using `torch.rpc`.
+ Using a LazyMemmapStorage is highly advised in distributed settings with shared storage due to the lower serialization
cost of MemoryMappedTensors as well as the ability to specify file storage locations which can improve ability to recover from node failures.
Args:
diff --git a/examples/envs/gym-async-info-reader.py b/examples/envs/gym-async-info-reader.py
index 3f98e039290..72330f13030 100644
--- a/examples/envs/gym-async-info-reader.py
+++ b/examples/envs/gym-async-info-reader.py
@@ -48,7 +48,7 @@ def step(self, action):
if __name__ == "__main__":
import torch
- from torchrl.data.tensor_specs import UnboundedContinuousTensorSpec
+ from torchrl.data.tensor_specs import Unbounded
from torchrl.envs import check_env_specs, GymEnv, GymWrapper
args = parser.parse_args()
@@ -66,7 +66,7 @@ def step(self, action):
keys = ["field1"]
specs = [
- UnboundedContinuousTensorSpec(shape=(num_envs, 3), dtype=torch.float64),
+ Unbounded(shape=(num_envs, 3), dtype=torch.float64),
]
# Create an info reader: this object will read the info and write its content to the tensordict
diff --git a/knowledge_base/VIDEO_CUSTOMISATION.md b/knowledge_base/VIDEO_CUSTOMISATION.md
index 956110d89aa..e28334708b2 100644
--- a/knowledge_base/VIDEO_CUSTOMISATION.md
+++ b/knowledge_base/VIDEO_CUSTOMISATION.md
@@ -50,9 +50,5 @@ as advised by the documentation.
We can improve the video quality by appending all our desired settings
(as keyword arguments) to `recorder` like so:
```python
-# The arguments' types don't appear to matter too much, as long as they are
-# appropriate for Python.
-# For example, this would work as well:
-# logger = CSVLogger(exp_name="my_exp", crf=17, preset="slow")
-logger = CSVLogger(exp_name="my_exp", crf="17", preset="slow")
+recorder = VideoRecorder(logger, tag="my_video", options={"crf": "17", "preset": "slow"})
```
diff --git a/pytest.ini b/pytest.ini
index 36d047d3055..39fe36617a1 100644
--- a/pytest.ini
+++ b/pytest.ini
@@ -4,6 +4,8 @@ addopts =
-ra
# Make tracebacks shorter
--tb=native
+markers =
+ unity_editor: marks tests that require a running Unity Editor instance
testpaths =
test
xfail_strict = True
diff --git a/setup.py b/setup.py
index 1c4f267bb94..a52711b0ed5 100644
--- a/setup.py
+++ b/setup.py
@@ -152,11 +152,15 @@ def get_extensions():
}
sources = list(extension_sources)
+ include_dirs = [this_dir]
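+ # optionally extend the include path (e.g., when the Python headers live outside the default prefix)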
+ python_include_dir = os.getenv("PYTHON_INCLUDE_DIR")
+ if python_include_dir is not None:
+ include_dirs.append(python_include_dir)
ext_modules = [
extension(
"torchrl._torchrl",
sources,
- include_dirs=[this_dir],
+ include_dirs=include_dirs,
extra_compile_args=extra_compile_args,
extra_link_args=extra_link_args,
)
@@ -191,7 +195,7 @@ def _main(argv):
# tag = _run_cmd(["git", "describe", "--tags", "--exact-match", "@"])
this_directory = Path(__file__).parent
- long_description = (this_directory / "README.md").read_text()
+ long_description = (this_directory / "README.md").read_text(encoding="utf8")
sys.argv = [sys.argv[0]] + unknown
extra_requires = {
@@ -203,7 +207,7 @@ def _main(argv):
"pygame",
],
"dm_control": ["dm_control"],
- "gym_continuous": ["gymnasium", "mujoco"],
+ "gym_continuous": ["gymnasium<1.0", "mujoco"],
"rendering": ["moviepy"],
"tests": ["pytest", "pyyaml", "pytest-instafail", "scipy"],
"utils": [
@@ -229,6 +233,7 @@ def _main(argv):
"pillow",
],
"marl": ["vmas>=1.2.10", "pettingzoo>=1.24.1", "dm-meltingpot"],
+ "open_spiel": ["open_spiel>=1.5"],
}
extra_requires["all"] = set()
for key in list(extra_requires.keys()):
diff --git a/sota-implementations/a2c/a2c_atari.py b/sota-implementations/a2c/a2c_atari.py
index f8c18147306..42ef4301c4d 100644
--- a/sota-implementations/a2c/a2c_atari.py
+++ b/sota-implementations/a2c/a2c_atari.py
@@ -226,6 +226,9 @@ def main(cfg: "DictConfig"): # noqa: F821
collector.update_policy_weights_()
sampling_start = time.time()
+ collector.shutdown()
+ if not test_env.is_closed:
+ test_env.close()
end_time = time.time()
execution_time = end_time - start_time
torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")
diff --git a/sota-implementations/a2c/a2c_mujoco.py b/sota-implementations/a2c/a2c_mujoco.py
index d115174eb9c..2b390d39d2a 100644
--- a/sota-implementations/a2c/a2c_mujoco.py
+++ b/sota-implementations/a2c/a2c_mujoco.py
@@ -212,6 +212,9 @@ def main(cfg: "DictConfig"): # noqa: F821
collector.update_policy_weights_()
sampling_start = time.time()
+ collector.shutdown()
+ if not test_env.is_closed:
+ test_env.close()
end_time = time.time()
execution_time = end_time - start_time
torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")
diff --git a/sota-implementations/a2c/utils_atari.py b/sota-implementations/a2c/utils_atari.py
index 58fa8541d90..6a09ff715e4 100644
--- a/sota-implementations/a2c/utils_atari.py
+++ b/sota-implementations/a2c/utils_atari.py
@@ -7,8 +7,8 @@
import torch.nn
import torch.optim
from tensordict.nn import TensorDictModule
-from torchrl.data import CompositeSpec
-from torchrl.data.tensor_specs import DiscreteBox
+from torchrl.data import Composite
+from torchrl.data.tensor_specs import CategoricalBox
from torchrl.envs import (
CatFrames,
DoubleToFloat,
@@ -92,7 +92,7 @@ def make_ppo_modules_pixels(proof_environment):
input_shape = proof_environment.observation_spec["pixels"].shape
# Define distribution class and kwargs
- if isinstance(proof_environment.action_spec.space, DiscreteBox):
+ if isinstance(proof_environment.action_spec.space, CategoricalBox):
num_outputs = proof_environment.action_spec.space.n
distribution_class = OneHotCategorical
distribution_kwargs = {}
@@ -148,7 +148,7 @@ def make_ppo_modules_pixels(proof_environment):
policy_module = ProbabilisticActor(
policy_module,
in_keys=["logits"],
- spec=CompositeSpec(action=proof_environment.action_spec),
+ spec=Composite(action=proof_environment.action_spec),
distribution_class=distribution_class,
distribution_kwargs=distribution_kwargs,
return_log_prob=True,
diff --git a/sota-implementations/a2c/utils_mujoco.py b/sota-implementations/a2c/utils_mujoco.py
index 9bb5a1f6307..996706ce4f9 100644
--- a/sota-implementations/a2c/utils_mujoco.py
+++ b/sota-implementations/a2c/utils_mujoco.py
@@ -8,7 +8,7 @@
import torch.optim
from tensordict.nn import AddStateIndependentNormalScale, TensorDictModule
-from torchrl.data import CompositeSpec
+from torchrl.data import Composite
from torchrl.envs import (
ClipTransform,
DoubleToFloat,
@@ -90,7 +90,7 @@ def make_ppo_models_state(proof_environment):
out_keys=["loc", "scale"],
),
in_keys=["loc", "scale"],
- spec=CompositeSpec(action=proof_environment.action_spec),
+ spec=Composite(action=proof_environment.action_spec),
distribution_class=distribution_class,
distribution_kwargs=distribution_kwargs,
return_log_prob=True,
diff --git a/sota-implementations/cql/cql_offline.py b/sota-implementations/cql/cql_offline.py
index 5ca70f83b53..73155d9fa1a 100644
--- a/sota-implementations/cql/cql_offline.py
+++ b/sota-implementations/cql/cql_offline.py
@@ -58,14 +58,14 @@ def main(cfg: "DictConfig"): # noqa: F821
device = "cpu"
device = torch.device(device)
+ # Create replay buffer
+ replay_buffer = make_offline_replay_buffer(cfg.replay_buffer)
+
# Create env
train_env, eval_env = make_environment(
cfg, train_num_envs=1, eval_num_envs=cfg.logger.eval_envs, logger=logger
)
- # Create replay buffer
- replay_buffer = make_offline_replay_buffer(cfg.replay_buffer)
-
# Create agent
model = make_cql_model(cfg, train_env, eval_env, device)
del train_env
@@ -107,9 +107,6 @@ def main(cfg: "DictConfig"): # noqa: F821
q_loss = q_loss + cql_loss
- alpha_loss = loss_vals["loss_alpha"]
- alpha_prime_loss = loss_vals["loss_alpha_prime"]
-
# update model
alpha_loss = loss_vals["loss_alpha"]
alpha_prime_loss = loss_vals["loss_alpha_prime"]
diff --git a/sota-implementations/cql/utils.py b/sota-implementations/cql/utils.py
index fae54da049a..c1d6fb52024 100644
--- a/sota-implementations/cql/utils.py
+++ b/sota-implementations/cql/utils.py
@@ -11,7 +11,7 @@
from torchrl.collectors import SyncDataCollector
from torchrl.data import (
- CompositeSpec,
+ Composite,
LazyMemmapStorage,
TensorDictPrioritizedReplayBuffer,
TensorDictReplayBuffer,
@@ -252,7 +252,7 @@ def make_discretecql_model(cfg, train_env, eval_env, device="cpu"):
actor_net = MLP(**actor_net_kwargs)
qvalue_module = QValueActor(
module=actor_net,
- spec=CompositeSpec(action=action_spec),
+ spec=Composite(action=action_spec),
in_keys=["observation"],
)
qvalue_module = qvalue_module.to(device)
diff --git a/sota-implementations/crossq/crossq.py b/sota-implementations/crossq/crossq.py
index df34d4ae68d..b07ae880046 100644
--- a/sota-implementations/crossq/crossq.py
+++ b/sota-implementations/crossq/crossq.py
@@ -203,7 +203,7 @@ def main(cfg: "DictConfig"): # noqa: F821
# Evaluation
if abs(collected_frames % eval_iter) < frames_per_batch:
- with set_exploration_type(ExplorationType.MODE), torch.no_grad():
+ with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
eval_start = time.time()
eval_rollout = eval_env.rollout(
eval_rollout_steps,
@@ -220,6 +220,10 @@ def main(cfg: "DictConfig"): # noqa: F821
sampling_start = time.time()
collector.shutdown()
+ if not eval_env.is_closed:
+ eval_env.close()
+ if not train_env.is_closed:
+ train_env.close()
end_time = time.time()
execution_time = end_time - start_time
torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")
diff --git a/sota-implementations/ddpg/ddpg.py b/sota-implementations/ddpg/ddpg.py
index 1b038d69d15..cebc3685625 100644
--- a/sota-implementations/ddpg/ddpg.py
+++ b/sota-implementations/ddpg/ddpg.py
@@ -205,6 +205,10 @@ def main(cfg: "DictConfig"): # noqa: F821
collector.shutdown()
end_time = time.time()
execution_time = end_time - start_time
+ if not eval_env.is_closed:
+ eval_env.close()
+ if not train_env.is_closed:
+ train_env.close()
torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")
diff --git a/sota-implementations/decision_transformer/dt.py b/sota-implementations/decision_transformer/dt.py
index 9cca9fd8af5..b892462339c 100644
--- a/sota-implementations/decision_transformer/dt.py
+++ b/sota-implementations/decision_transformer/dt.py
@@ -131,6 +131,8 @@ def main(cfg: "DictConfig"): # noqa: F821
log_metrics(logger, to_log, i)
pbar.close()
+ if not test_env.is_closed:
+ test_env.close()
torchrl_logger.info(f"Training time: {time.time() - start_time}")
diff --git a/sota-implementations/decision_transformer/online_dt.py b/sota-implementations/decision_transformer/online_dt.py
index da2241ce9fa..184c850b626 100644
--- a/sota-implementations/decision_transformer/online_dt.py
+++ b/sota-implementations/decision_transformer/online_dt.py
@@ -145,6 +145,8 @@ def main(cfg: "DictConfig"): # noqa: F821
log_metrics(logger, to_log, i)
pbar.close()
+ if not test_env.is_closed:
+ test_env.close()
torchrl_logger.info(f"Training time: {time.time() - start_time}")
diff --git a/sota-implementations/decision_transformer/utils.py b/sota-implementations/decision_transformer/utils.py
index 409833c75fa..ee2cc6e424c 100644
--- a/sota-implementations/decision_transformer/utils.py
+++ b/sota-implementations/decision_transformer/utils.py
@@ -38,7 +38,7 @@
)
from torchrl.envs.libs.dm_control import DMControlEnv
from torchrl.envs.libs.gym import set_gym_backend
-from torchrl.envs.utils import set_exploration_mode
+from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.modules import (
DTActor,
OnlineDTActor,
@@ -374,13 +374,12 @@ def make_odt_model(cfg):
module=actor_module,
distribution_class=dist_class,
distribution_kwargs=dist_kwargs,
- default_interaction_mode="random",
cache_dist=False,
return_log_prob=False,
)
# init the lazy layers
- with torch.no_grad(), set_exploration_mode("random"):
+ with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM):
td = proof_environment.rollout(max_steps=100)
td["action"] = td["next", "action"]
actor(td)
@@ -428,13 +427,12 @@ def make_dt_model(cfg):
module=actor_module,
distribution_class=dist_class,
distribution_kwargs=dist_kwargs,
- default_interaction_mode="random",
cache_dist=False,
return_log_prob=False,
)
# init the lazy layers
- with torch.no_grad(), set_exploration_mode("random"):
+ with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM):
td = proof_environment.rollout(max_steps=100)
td["action"] = td["next", "action"]
actor(td)
diff --git a/sota-implementations/discrete_sac/discrete_sac.py b/sota-implementations/discrete_sac/discrete_sac.py
index 386f743c7d3..a9a08827f5d 100644
--- a/sota-implementations/discrete_sac/discrete_sac.py
+++ b/sota-implementations/discrete_sac/discrete_sac.py
@@ -222,6 +222,10 @@ def main(cfg: "DictConfig"): # noqa: F821
sampling_start = time.time()
collector.shutdown()
+ if not eval_env.is_closed:
+ eval_env.close()
+ if not train_env.is_closed:
+ train_env.close()
end_time = time.time()
execution_time = end_time - start_time
torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")
diff --git a/sota-implementations/discrete_sac/utils.py b/sota-implementations/discrete_sac/utils.py
index ddffffc2a8e..8051f07fe95 100644
--- a/sota-implementations/discrete_sac/utils.py
+++ b/sota-implementations/discrete_sac/utils.py
@@ -12,7 +12,7 @@
from torch import nn, optim
from torchrl.collectors import SyncDataCollector
from torchrl.data import (
- CompositeSpec,
+ Composite,
TensorDictPrioritizedReplayBuffer,
TensorDictReplayBuffer,
)
@@ -203,7 +203,7 @@ def make_sac_agent(cfg, train_env, eval_env, device):
out_keys=["logits"],
)
actor = ProbabilisticActor(
- spec=CompositeSpec(action=eval_env.action_spec),
+ spec=Composite(action=eval_env.action_spec),
module=actor_module,
in_keys=["logits"],
out_keys=["action"],
diff --git a/sota-implementations/dqn/dqn_atari.py b/sota-implementations/dqn/dqn_atari.py
index 906273ee2f5..5d0162080e2 100644
--- a/sota-implementations/dqn/dqn_atari.py
+++ b/sota-implementations/dqn/dqn_atari.py
@@ -228,6 +228,9 @@ def main(cfg: "DictConfig"): # noqa: F821
sampling_start = time.time()
collector.shutdown()
+ if not test_env.is_closed:
+ test_env.close()
+
end_time = time.time()
execution_time = end_time - start_time
torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")
diff --git a/sota-implementations/dqn/dqn_cartpole.py b/sota-implementations/dqn/dqn_cartpole.py
index 173f88f7028..8149c700958 100644
--- a/sota-implementations/dqn/dqn_cartpole.py
+++ b/sota-implementations/dqn/dqn_cartpole.py
@@ -207,6 +207,8 @@ def main(cfg: "DictConfig"): # noqa: F821
sampling_start = time.time()
collector.shutdown()
+ if not test_env.is_closed:
+ test_env.close()
end_time = time.time()
execution_time = end_time - start_time
torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")
diff --git a/sota-implementations/dqn/utils_atari.py b/sota-implementations/dqn/utils_atari.py
index 3dbbfe87af4..6f39e824c60 100644
--- a/sota-implementations/dqn/utils_atari.py
+++ b/sota-implementations/dqn/utils_atari.py
@@ -5,7 +5,7 @@
import torch.nn
import torch.optim
-from torchrl.data import CompositeSpec
+from torchrl.data import Composite
from torchrl.envs import (
CatFrames,
DoubleToFloat,
@@ -84,7 +84,7 @@ def make_dqn_modules_pixels(proof_environment):
)
qvalue_module = QValueActor(
module=torch.nn.Sequential(cnn, mlp),
- spec=CompositeSpec(action=action_spec),
+ spec=Composite(action=action_spec),
in_keys=["pixels"],
)
return qvalue_module
diff --git a/sota-implementations/dqn/utils_cartpole.py b/sota-implementations/dqn/utils_cartpole.py
index 2df280a04b4..c7f7491ad15 100644
--- a/sota-implementations/dqn/utils_cartpole.py
+++ b/sota-implementations/dqn/utils_cartpole.py
@@ -5,7 +5,7 @@
import torch.nn
import torch.optim
-from torchrl.data import CompositeSpec
+from torchrl.data import Composite
from torchrl.envs import RewardSum, StepCounter, TransformedEnv
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import MLP, QValueActor
@@ -48,7 +48,7 @@ def make_dqn_modules(proof_environment):
qvalue_module = QValueActor(
module=mlp,
- spec=CompositeSpec(action=action_spec),
+ spec=Composite(action=action_spec),
in_keys=["observation"],
)
return qvalue_module
diff --git a/sota-implementations/dreamer/dreamer_utils.py b/sota-implementations/dreamer/dreamer_utils.py
index 6745b1a079a..849d8c813b6 100644
--- a/sota-implementations/dreamer/dreamer_utils.py
+++ b/sota-implementations/dreamer/dreamer_utils.py
@@ -20,11 +20,11 @@
from torchrl.collectors import SyncDataCollector
from torchrl.data import (
- CompositeSpec,
+ Composite,
LazyMemmapStorage,
SliceSampler,
TensorDictReplayBuffer,
- UnboundedContinuousTensorSpec,
+ Unbounded,
)
from torchrl.envs import (
@@ -92,8 +92,8 @@ def _make_env(cfg, device, from_pixels=False):
else:
raise NotImplementedError(f"Unknown lib {lib}.")
default_dict = {
- "state": UnboundedContinuousTensorSpec(shape=(cfg.networks.state_dim,)),
- "belief": UnboundedContinuousTensorSpec(shape=(cfg.networks.rssm_hidden_dim,)),
+ "state": Unbounded(shape=(cfg.networks.state_dim,)),
+ "belief": Unbounded(shape=(cfg.networks.rssm_hidden_dim,)),
}
env = env.append_transform(
TensorDictPrimer(random=False, default_value=0, **default_dict)
@@ -469,13 +469,13 @@ def _dreamer_make_actor_sim(action_key, proof_environment, actor_module):
actor_module,
in_keys=["state", "belief"],
out_keys=["loc", "scale"],
- spec=CompositeSpec(
+ spec=Composite(
**{
- "loc": UnboundedContinuousTensorSpec(
+ "loc": Unbounded(
proof_environment.action_spec.shape,
device=proof_environment.action_spec.device,
),
- "scale": UnboundedContinuousTensorSpec(
+ "scale": Unbounded(
proof_environment.action_spec.shape,
device=proof_environment.action_spec.device,
),
@@ -488,7 +488,7 @@ def _dreamer_make_actor_sim(action_key, proof_environment, actor_module):
default_interaction_type=InteractionType.RANDOM,
distribution_class=TanhNormal,
distribution_kwargs={"tanh_loc": True},
- spec=CompositeSpec(**{action_key: proof_environment.action_spec}),
+ spec=Composite(**{action_key: proof_environment.action_spec}),
),
)
return actor_simulator
@@ -526,12 +526,12 @@ def _dreamer_make_actor_real(
actor_module,
in_keys=["state", "belief"],
out_keys=["loc", "scale"],
- spec=CompositeSpec(
+ spec=Composite(
**{
- "loc": UnboundedContinuousTensorSpec(
+ "loc": Unbounded(
proof_environment.action_spec.shape,
),
- "scale": UnboundedContinuousTensorSpec(
+ "scale": Unbounded(
proof_environment.action_spec.shape,
),
}
@@ -543,9 +543,7 @@ def _dreamer_make_actor_real(
default_interaction_type=InteractionType.DETERMINISTIC,
distribution_class=TanhNormal,
distribution_kwargs={"tanh_loc": True},
- spec=CompositeSpec(
- **{action_key: proof_environment.action_spec.to("cpu")}
- ),
+ spec=Composite(**{action_key: proof_environment.action_spec.to("cpu")}),
),
),
SafeModule(
diff --git a/sota-implementations/gail/config.yaml b/sota-implementations/gail/config.yaml
new file mode 100644
index 00000000000..cf6c8053037
--- /dev/null
+++ b/sota-implementations/gail/config.yaml
@@ -0,0 +1,46 @@
+env:
+ env_name: HalfCheetah-v4
+ seed: 42
+ backend: gymnasium
+
+logger:
+ backend: wandb
+ project_name: gail
+ group_name: null
+ exp_name: gail_ppo
+ test_interval: 5000
+ num_test_episodes: 5
+ video: False
+ mode: online
+
+ppo:
+ collector:
+ frames_per_batch: 2048
+ total_frames: 1_000_000
+
+ optim:
+ lr: 3e-4
+ weight_decay: 0.0
+ anneal_lr: True
+
+ loss:
+ gamma: 0.99
+ mini_batch_size: 64
+ ppo_epochs: 10
+ gae_lambda: 0.95
+ clip_epsilon: 0.2
+ anneal_clip_epsilon: False
+ critic_coef: 0.25
+ entropy_coef: 0.0
+ loss_critic_type: l2
+
+gail:
+ hidden_dim: 128
+ lr: 3e-4
+ use_grad_penalty: False
+ gp_lambda: 10.0
+ device: null
+
+replay_buffer:
+ dataset: halfcheetah-expert-v2
+ batch_size: 256
diff --git a/sota-implementations/gail/gail.py b/sota-implementations/gail/gail.py
new file mode 100644
index 00000000000..a3c64693fb3
--- /dev/null
+++ b/sota-implementations/gail/gail.py
@@ -0,0 +1,281 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+"""GAIL Example.
+
+This is a self-contained example of an offline GAIL training script.
+
+The helper functions for GAIL are coded in gail_utils.py and the helper functions for PPO in ppo_utils.py.
+
+"""
+import hydra
+import numpy as np
+import torch
+import tqdm
+
+from gail_utils import log_metrics, make_gail_discriminator, make_offline_replay_buffer
+from ppo_utils import eval_model, make_env, make_ppo_models
+from torchrl.collectors import SyncDataCollector
+from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
+from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
+
+from torchrl.envs import set_gym_backend
+from torchrl.envs.utils import ExplorationType, set_exploration_type
+from torchrl.objectives import ClipPPOLoss, GAILLoss
+from torchrl.objectives.value.advantages import GAE
+from torchrl.record import VideoRecorder
+from torchrl.record.loggers import generate_exp_name, get_logger
+
+
+@hydra.main(config_path="", config_name="config")
+def main(cfg: "DictConfig"): # noqa: F821
+ set_gym_backend(cfg.env.backend).set()
+
+ device = cfg.gail.device
+ if device in ("", None):
+ if torch.cuda.is_available():
+ device = "cuda:0"
+ else:
+ device = "cpu"
+ device = torch.device(device)
+ num_mini_batches = (
+ cfg.ppo.collector.frames_per_batch // cfg.ppo.loss.mini_batch_size
+ )
+ total_network_updates = (
+ (cfg.ppo.collector.total_frames // cfg.ppo.collector.frames_per_batch)
+ * cfg.ppo.loss.ppo_epochs
+ * num_mini_batches
+ )
+
+ # Create logger
+ exp_name = generate_exp_name("Gail", cfg.logger.exp_name)
+ logger = None
+ if cfg.logger.backend:
+ logger = get_logger(
+ logger_type=cfg.logger.backend,
+ logger_name="gail_logging",
+ experiment_name=exp_name,
+ wandb_kwargs={
+ "mode": cfg.logger.mode,
+ "config": dict(cfg),
+ "project": cfg.logger.project_name,
+ "group": cfg.logger.group_name,
+ },
+ )
+
+ # Set seeds
+ torch.manual_seed(cfg.env.seed)
+ np.random.seed(cfg.env.seed)
+
+ # Create models (check utils_mujoco.py)
+ actor, critic = make_ppo_models(cfg.env.env_name)
+ actor, critic = actor.to(device), critic.to(device)
+
+ # Create collector
+ collector = SyncDataCollector(
+ create_env_fn=make_env(cfg.env.env_name, device),
+ policy=actor,
+ frames_per_batch=cfg.ppo.collector.frames_per_batch,
+ total_frames=cfg.ppo.collector.total_frames,
+ device=device,
+ storing_device=device,
+ max_frames_per_traj=-1,
+ )
+
+ # Create data buffer
+ data_buffer = TensorDictReplayBuffer(
+ storage=LazyMemmapStorage(cfg.ppo.collector.frames_per_batch),
+ sampler=SamplerWithoutReplacement(),
+ batch_size=cfg.ppo.loss.mini_batch_size,
+ )
+
+ # Create loss and adv modules
+ adv_module = GAE(
+ gamma=cfg.ppo.loss.gamma,
+ lmbda=cfg.ppo.loss.gae_lambda,
+ value_network=critic,
+ average_gae=False,
+ )
+
+ loss_module = ClipPPOLoss(
+ actor_network=actor,
+ critic_network=critic,
+ clip_epsilon=cfg.ppo.loss.clip_epsilon,
+ loss_critic_type=cfg.ppo.loss.loss_critic_type,
+ entropy_coef=cfg.ppo.loss.entropy_coef,
+ critic_coef=cfg.ppo.loss.critic_coef,
+ normalize_advantage=True,
+ )
+
+ # Create optimizers
+ actor_optim = torch.optim.Adam(actor.parameters(), lr=cfg.ppo.optim.lr, eps=1e-5)
+ critic_optim = torch.optim.Adam(critic.parameters(), lr=cfg.ppo.optim.lr, eps=1e-5)
+
+ # Create replay buffer
+ replay_buffer = make_offline_replay_buffer(cfg.replay_buffer)
+
+ # Create Discriminator
+ discriminator = make_gail_discriminator(cfg, collector.env, device)
+
+ # Create loss
+ discriminator_loss = GAILLoss(
+ discriminator,
+ use_grad_penalty=cfg.gail.use_grad_penalty,
+ gp_lambda=cfg.gail.gp_lambda,
+ )
+
+ # Create optimizer
+ discriminator_optim = torch.optim.Adam(
+ params=discriminator.parameters(), lr=cfg.gail.lr
+ )
+
+ # Create test environment
+ logger_video = cfg.logger.video
+ test_env = make_env(cfg.env.env_name, device, from_pixels=logger_video)
+ if logger_video:
+ test_env = test_env.append_transform(
+ VideoRecorder(logger, tag="rendering/test", in_keys=["pixels"])
+ )
+ test_env.eval()
+
+ # Training loop
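+    # Each collection batch: (1) update the discriminator on mixed expert and
+    # policy data, (2) relabel the rewards with the discriminator, (3) run PPO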
+ collected_frames = 0
+ num_network_updates = 0
+ pbar = tqdm.tqdm(total=cfg.ppo.collector.total_frames)
+
+    # Extract cfg variables once to avoid repeated DictConfig lookups in the training loop
+ cfg_loss_ppo_epochs = cfg.ppo.loss.ppo_epochs
+ cfg_optim_anneal_lr = cfg.ppo.optim.anneal_lr
+ cfg_optim_lr = cfg.ppo.optim.lr
+ cfg_loss_anneal_clip_eps = cfg.ppo.loss.anneal_clip_epsilon
+ cfg_loss_clip_epsilon = cfg.ppo.loss.clip_epsilon
+ cfg_logger_test_interval = cfg.logger.test_interval
+ cfg_logger_num_test_episodes = cfg.logger.num_test_episodes
+
+ for i, data in enumerate(collector):
+
+ log_info = {}
+ frames_in_batch = data.numel()
+ collected_frames += frames_in_batch
+ pbar.update(data.numel())
+
+ # Update discriminator
+ # Get expert data
+ expert_data = replay_buffer.sample()
+ expert_data = expert_data.to(device)
+ # Add collector data to expert data
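+        # The first expert-batch-size collector transitions are written under the
+        # loss's collector_* keys, so GAILLoss sees expert and policy samples in a
+        # single tensordict (assumes frames_per_batch >= the expert batch size)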
+ expert_data.set(
+ discriminator_loss.tensor_keys.collector_action,
+ data["action"][: expert_data.batch_size[0]],
+ )
+ expert_data.set(
+ discriminator_loss.tensor_keys.collector_observation,
+ data["observation"][: expert_data.batch_size[0]],
+ )
+ d_loss = discriminator_loss(expert_data)
+
+ # Backward pass
+ discriminator_optim.zero_grad()
+ d_loss.get("loss").backward()
+ discriminator_optim.step()
+
+ # Compute discriminator reward
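+        # GAIL surrogate reward: r(s, a) = -log(1 - D(s, a)); the 1e-8 term
+        # guards against log(0) when the discriminator saturates at 1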
+ with torch.no_grad():
+ data = discriminator(data)
+ d_rewards = -torch.log(1 - data["d_logits"] + 1e-8)
+
+ # Set discriminator rewards to tensordict
+ data.set(("next", "reward"), d_rewards)
+
+ # Get training rewards and episode lengths
+ episode_rewards = data["next", "episode_reward"][data["next", "done"]]
+ if len(episode_rewards) > 0:
+ episode_length = data["next", "step_count"][data["next", "done"]]
+ log_info.update(
+ {
+ "train/reward": episode_rewards.mean().item(),
+ "train/episode_length": episode_length.sum().item()
+ / len(episode_length),
+ }
+ )
+ # Update PPO
+ for _ in range(cfg_loss_ppo_epochs):
+
+ # Compute GAE
+ with torch.no_grad():
+ data = adv_module(data)
+ data_reshape = data.reshape(-1)
+
+ # Update the data buffer
+ data_buffer.extend(data_reshape)
+
+        for batch in data_buffer:
+
+ # Get a data batch
+ batch = batch.to(device)
+
+ # Linearly decrease the learning rate and clip epsilon
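+                # alpha decays linearly from 1 to 0 over total_network_updates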
+ alpha = 1.0
+ if cfg_optim_anneal_lr:
+ alpha = 1 - (num_network_updates / total_network_updates)
+ for group in actor_optim.param_groups:
+ group["lr"] = cfg_optim_lr * alpha
+ for group in critic_optim.param_groups:
+ group["lr"] = cfg_optim_lr * alpha
+ if cfg_loss_anneal_clip_eps:
+ loss_module.clip_epsilon.copy_(cfg_loss_clip_epsilon * alpha)
+ num_network_updates += 1
+
+ # Forward pass PPO loss
+ loss = loss_module(batch)
+ critic_loss = loss["loss_critic"]
+ actor_loss = loss["loss_objective"] + loss["loss_entropy"]
+
+ # Backward pass
+ actor_loss.backward()
+ critic_loss.backward()
+
+ # Update the networks
+ actor_optim.step()
+ critic_optim.step()
+ actor_optim.zero_grad()
+ critic_optim.zero_grad()
+
+ log_info.update(
+ {
+ "train/actor_loss": actor_loss.item(),
+ "train/critic_loss": critic_loss.item(),
+ "train/discriminator_loss": d_loss["loss"].item(),
+ "train/lr": alpha * cfg_optim_lr,
+ "train/clip_epsilon": (
+ alpha * cfg_loss_clip_epsilon
+ if cfg_loss_anneal_clip_eps
+ else cfg_loss_clip_epsilon
+ ),
+ }
+ )
+
+        # Evaluation
+ with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
+ if ((i - 1) * frames_in_batch) // cfg_logger_test_interval < (
+ i * frames_in_batch
+ ) // cfg_logger_test_interval:
+ actor.eval()
+ test_rewards = eval_model(
+ actor, test_env, num_episodes=cfg_logger_num_test_episodes
+ )
+ log_info.update(
+ {
+ "eval/reward": test_rewards.mean(),
+ }
+ )
+ actor.train()
+ if logger is not None:
+ log_metrics(logger, log_info, i)
+
+ pbar.close()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/sota-implementations/gail/gail_utils.py b/sota-implementations/gail/gail_utils.py
new file mode 100644
index 00000000000..067e9c8c927
--- /dev/null
+++ b/sota-implementations/gail/gail_utils.py
@@ -0,0 +1,69 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch.nn as nn
+import torch.optim
+
+from torchrl.data.datasets.d4rl import D4RLExperienceReplay
+from torchrl.data.replay_buffers import SamplerWithoutReplacement
+from torchrl.envs import DoubleToFloat
+
+from torchrl.modules import SafeModule
+
+
+# ====================================================================
+# Offline Replay buffer
+# ---------------------------
+
+
+def make_offline_replay_buffer(rb_cfg):
+ data = D4RLExperienceReplay(
+ dataset_id=rb_cfg.dataset,
+ split_trajs=False,
+ batch_size=rb_cfg.batch_size,
+ sampler=SamplerWithoutReplacement(drop_last=False),
+ prefetch=4,
+ direct_download=True,
+ )
+
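+    # D4RL arrays may be stored as float64; cast them to float32 to match the networks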
+ data.append_transform(DoubleToFloat())
+
+ return data
+
+
+def make_gail_discriminator(cfg, train_env, device="cpu"):
+ """Make GAIL discriminator."""
+
+ state_dim = train_env.observation_spec["observation"].shape[0]
+ action_dim = train_env.action_spec.shape[0]
+
+ hidden_dim = cfg.gail.hidden_dim
+
+ # Define Discriminator Network
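+    # A small MLP mapping (state, action) to a probability in (0, 1); note that
+    # despite the "d_logits" out_key below, the output is a sigmoid probability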
+ class Discriminator(nn.Module):
+ def __init__(self, state_dim, action_dim):
+            super().__init__()
+ self.fc1 = nn.Linear(state_dim + action_dim, hidden_dim)
+ self.fc2 = nn.Linear(hidden_dim, hidden_dim)
+ self.fc3 = nn.Linear(hidden_dim, 1)
+
+ def forward(self, state, action):
+ x = torch.cat([state, action], dim=1)
+ x = torch.relu(self.fc1(x))
+ x = torch.relu(self.fc2(x))
+ return torch.sigmoid(self.fc3(x))
+
+ d_module = SafeModule(
+ module=Discriminator(state_dim, action_dim),
+ in_keys=["observation", "action"],
+ out_keys=["d_logits"],
+ )
+ return d_module.to(device)
+
+
+def log_metrics(logger, metrics, step):
+ if logger is not None:
+ for metric_name, metric_value in metrics.items():
+ logger.log_scalar(metric_name, metric_value, step)
diff --git a/sota-implementations/gail/ppo_utils.py b/sota-implementations/gail/ppo_utils.py
new file mode 100644
index 00000000000..7986738f8e6
--- /dev/null
+++ b/sota-implementations/gail/ppo_utils.py
@@ -0,0 +1,150 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch.nn
+import torch.optim
+
+from tensordict.nn import AddStateIndependentNormalScale, TensorDictModule
+from torchrl.data import CompositeSpec
+from torchrl.envs import (
+ ClipTransform,
+ DoubleToFloat,
+ ExplorationType,
+ RewardSum,
+ StepCounter,
+ TransformedEnv,
+ VecNorm,
+)
+from torchrl.envs.libs.gym import GymEnv
+from torchrl.modules import MLP, ProbabilisticActor, TanhNormal, ValueOperator
+from torchrl.record import VideoRecorder
+
+
+# ====================================================================
+# Environment utils
+# --------------------------------------------------------------------
+
+
+def make_env(env_name="HalfCheetah-v4", device="cpu", from_pixels: bool = False):
+ env = GymEnv(env_name, device=device, from_pixels=from_pixels, pixels_only=False)
+ env = TransformedEnv(env)
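+    # Normalize observations with running statistics, then clip them to
+    # [-10, 10] to bound outliers early in training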
+ env.append_transform(VecNorm(in_keys=["observation"], decay=0.99999, eps=1e-2))
+ env.append_transform(ClipTransform(in_keys=["observation"], low=-10, high=10))
+ env.append_transform(RewardSum())
+ env.append_transform(StepCounter())
+ env.append_transform(DoubleToFloat(in_keys=["observation"]))
+ return env
+
+
+# ====================================================================
+# Model utils
+# --------------------------------------------------------------------
+
+
+def make_ppo_models_state(proof_environment):
+
+ # Define input shape
+ input_shape = proof_environment.observation_spec["observation"].shape
+
+ # Define policy output distribution class
+ num_outputs = proof_environment.action_spec.shape[-1]
+ distribution_class = TanhNormal
+ distribution_kwargs = {
+ "low": proof_environment.action_spec.space.low,
+ "high": proof_environment.action_spec.space.high,
+ "tanh_loc": False,
+ }
+
+ # Define policy architecture
+ policy_mlp = MLP(
+ in_features=input_shape[-1],
+ activation_class=torch.nn.Tanh,
+ out_features=num_outputs, # predict only loc
+ num_cells=[64, 64],
+ )
+
+ # Initialize policy weights
+ for layer in policy_mlp.modules():
+ if isinstance(layer, torch.nn.Linear):
+ torch.nn.init.orthogonal_(layer.weight, 1.0)
+ layer.bias.data.zero_()
+
+ # Add state-independent normal scale
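+    # The policy std is a learned, observation-independent parameter (one per
+    # action dimension); scale_lb keeps it strictly positive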
+ policy_mlp = torch.nn.Sequential(
+ policy_mlp,
+ AddStateIndependentNormalScale(
+ proof_environment.action_spec.shape[-1], scale_lb=1e-8
+ ),
+ )
+
+ # Add probabilistic sampling of the actions
+ policy_module = ProbabilisticActor(
+ TensorDictModule(
+ module=policy_mlp,
+ in_keys=["observation"],
+ out_keys=["loc", "scale"],
+ ),
+ in_keys=["loc", "scale"],
+ spec=CompositeSpec(action=proof_environment.action_spec),
+ distribution_class=distribution_class,
+ distribution_kwargs=distribution_kwargs,
+ return_log_prob=True,
+ default_interaction_type=ExplorationType.RANDOM,
+ )
+
+ # Define value architecture
+ value_mlp = MLP(
+ in_features=input_shape[-1],
+ activation_class=torch.nn.Tanh,
+ out_features=1,
+ num_cells=[64, 64],
+ )
+
+ # Initialize value weights
+ for layer in value_mlp.modules():
+ if isinstance(layer, torch.nn.Linear):
+ torch.nn.init.orthogonal_(layer.weight, 0.01)
+ layer.bias.data.zero_()
+
+ # Define value module
+ value_module = ValueOperator(
+ value_mlp,
+ in_keys=["observation"],
+ )
+
+ return policy_module, value_module
+
+
+def make_ppo_models(env_name):
+ proof_environment = make_env(env_name, device="cpu")
+ actor, critic = make_ppo_models_state(proof_environment)
+ return actor, critic
+
+
+# ====================================================================
+# Evaluation utils
+# --------------------------------------------------------------------
+
+
+def dump_video(module):
+ if isinstance(module, VideoRecorder):
+ module.dump()
+
+
+def eval_model(actor, test_env, num_episodes=3):
+ test_rewards = []
+ for _ in range(num_episodes):
+ td_test = test_env.rollout(
+ policy=actor,
+ auto_reset=True,
+ auto_cast_to_device=True,
+ break_when_any_done=True,
+ max_steps=10_000_000,
+ )
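+        # RewardSum accumulates per-episode returns under "episode_reward";
+        # indexing at the done flags yields one value per finished episode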
+ reward = td_test["next", "episode_reward"][td_test["next", "done"]]
+ test_rewards.append(reward.cpu())
+ test_env.apply(dump_video)
+ del td_test
+ return torch.cat(test_rewards, 0).mean()
diff --git a/sota-implementations/impala/utils.py b/sota-implementations/impala/utils.py
index b365dca3867..30293940377 100644
--- a/sota-implementations/impala/utils.py
+++ b/sota-implementations/impala/utils.py
@@ -6,7 +6,7 @@
import torch.nn
import torch.optim
from tensordict.nn import TensorDictModule
-from torchrl.data import CompositeSpec
+from torchrl.data import Composite
from torchrl.envs import (
CatFrames,
DoubleToFloat,
@@ -100,7 +100,7 @@ def make_ppo_modules_pixels(proof_environment):
out_keys=["common_features"],
)
- # Define on head for the policy
+ # Define one head for the policy
policy_net = MLP(
in_features=common_mlp_output.shape[-1],
out_features=num_outputs,
@@ -117,7 +117,7 @@ def make_ppo_modules_pixels(proof_environment):
policy_module = ProbabilisticActor(
policy_module,
in_keys=["logits"],
- spec=CompositeSpec(action=proof_environment.action_spec),
+ spec=Composite(action=proof_environment.action_spec),
distribution_class=distribution_class,
distribution_kwargs=distribution_kwargs,
return_log_prob=True,
diff --git a/sota-implementations/iql/iql_offline.py b/sota-implementations/iql/iql_offline.py
index d1a16fd8192..53581782d20 100644
--- a/sota-implementations/iql/iql_offline.py
+++ b/sota-implementations/iql/iql_offline.py
@@ -141,6 +141,10 @@ def main(cfg: "DictConfig"): # noqa: F821
log_metrics(logger, to_log, i)
pbar.close()
+ if not eval_env.is_closed:
+ eval_env.close()
+ if not train_env.is_closed:
+ train_env.close()
torchrl_logger.info(f"Training time: {time.time() - start_time}")
diff --git a/sota-implementations/iql/iql_online.py b/sota-implementations/iql/iql_online.py
index d50ff806294..3cdff06ffa2 100644
--- a/sota-implementations/iql/iql_online.py
+++ b/sota-implementations/iql/iql_online.py
@@ -204,6 +204,12 @@ def main(cfg: "DictConfig"): # noqa: F821
collector.shutdown()
end_time = time.time()
execution_time = end_time - start_time
+
+ if not eval_env.is_closed:
+ eval_env.close()
+ if not train_env.is_closed:
+ train_env.close()
+
torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")
diff --git a/sota-implementations/iql/utils.py b/sota-implementations/iql/utils.py
index 61d31b88eb8..a24c6168375 100644
--- a/sota-implementations/iql/utils.py
+++ b/sota-implementations/iql/utils.py
@@ -11,7 +11,7 @@
from torchrl.collectors import SyncDataCollector
from torchrl.data import (
- CompositeSpec,
+ Composite,
LazyMemmapStorage,
TensorDictPrioritizedReplayBuffer,
TensorDictReplayBuffer,
@@ -306,7 +306,7 @@ def make_discrete_iql_model(cfg, train_env, eval_env, device):
out_keys=["logits"],
)
actor = ProbabilisticActor(
- spec=CompositeSpec(action=eval_env.action_spec),
+ spec=Composite(action=eval_env.action_spec),
module=actor_module,
in_keys=["logits"],
out_keys=["action"],
diff --git a/sota-implementations/multiagent/iql.py b/sota-implementations/multiagent/iql.py
index a4d2b88a9d0..39750c5d425 100644
--- a/sota-implementations/multiagent/iql.py
+++ b/sota-implementations/multiagent/iql.py
@@ -225,6 +225,12 @@ def train(cfg: "DictConfig"): # noqa: F821
logger.experiment.log({}, commit=True)
sampling_start = time.time()
+ collector.shutdown()
+ if not env.is_closed:
+ env.close()
+ if not env_test.is_closed:
+ env_test.close()
+
if __name__ == "__main__":
train()
diff --git a/sota-implementations/multiagent/maddpg_iddpg.py b/sota-implementations/multiagent/maddpg_iddpg.py
index e9de2ac4e14..aad1df14fff 100644
--- a/sota-implementations/multiagent/maddpg_iddpg.py
+++ b/sota-implementations/multiagent/maddpg_iddpg.py
@@ -251,6 +251,11 @@ def train(cfg: "DictConfig"): # noqa: F821
if cfg.logger.backend == "wandb":
logger.experiment.log({}, commit=True)
sampling_start = time.time()
+ collector.shutdown()
+ if not env.is_closed:
+ env.close()
+ if not env_test.is_closed:
+ env_test.close()
if __name__ == "__main__":
diff --git a/sota-implementations/multiagent/mappo_ippo.py b/sota-implementations/multiagent/mappo_ippo.py
index fa006a7d4a2..d2e218b843a 100644
--- a/sota-implementations/multiagent/mappo_ippo.py
+++ b/sota-implementations/multiagent/mappo_ippo.py
@@ -254,6 +254,11 @@ def train(cfg: "DictConfig"): # noqa: F821
if cfg.logger.backend == "wandb":
logger.experiment.log({}, commit=True)
sampling_start = time.time()
+ collector.shutdown()
+ if not env.is_closed:
+ env.close()
+ if not env_test.is_closed:
+ env_test.close()
if __name__ == "__main__":
diff --git a/sota-implementations/multiagent/qmix_vdn.py b/sota-implementations/multiagent/qmix_vdn.py
index 4e6a962c556..c5993f902c6 100644
--- a/sota-implementations/multiagent/qmix_vdn.py
+++ b/sota-implementations/multiagent/qmix_vdn.py
@@ -259,6 +259,11 @@ def train(cfg: "DictConfig"): # noqa: F821
if cfg.logger.backend == "wandb":
logger.experiment.log({}, commit=True)
sampling_start = time.time()
+ collector.shutdown()
+ if not env.is_closed:
+ env.close()
+ if not env_test.is_closed:
+ env_test.close()
if __name__ == "__main__":
diff --git a/sota-implementations/multiagent/sac.py b/sota-implementations/multiagent/sac.py
index f7b2523010b..cfafdd47c96 100644
--- a/sota-implementations/multiagent/sac.py
+++ b/sota-implementations/multiagent/sac.py
@@ -318,6 +318,11 @@ def train(cfg: "DictConfig"): # noqa: F821
if cfg.logger.backend == "wandb":
logger.experiment.log({}, commit=True)
sampling_start = time.time()
+ collector.shutdown()
+ if not env.is_closed:
+ env.close()
+ if not env_test.is_closed:
+ env_test.close()
if __name__ == "__main__":
diff --git a/sota-implementations/ppo/ppo_atari.py b/sota-implementations/ppo/ppo_atari.py
index 2b02254032a..6d8883393d5 100644
--- a/sota-implementations/ppo/ppo_atari.py
+++ b/sota-implementations/ppo/ppo_atari.py
@@ -243,6 +243,9 @@ def main(cfg: "DictConfig"): # noqa: F821
sampling_start = time.time()
collector.shutdown()
+ if not test_env.is_closed:
+ test_env.close()
+
end_time = time.time()
execution_time = end_time - start_time
torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")
diff --git a/sota-implementations/ppo/ppo_mujoco.py b/sota-implementations/ppo/ppo_mujoco.py
index 219ae1b59b6..8cfea74d0bc 100644
--- a/sota-implementations/ppo/ppo_mujoco.py
+++ b/sota-implementations/ppo/ppo_mujoco.py
@@ -235,6 +235,9 @@ def main(cfg: "DictConfig"): # noqa: F821
collector.update_policy_weights_()
sampling_start = time.time()
+ collector.shutdown()
+ if not test_env.is_closed:
+ test_env.close()
end_time = time.time()
execution_time = end_time - start_time
torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")
diff --git a/sota-implementations/ppo/utils_atari.py b/sota-implementations/ppo/utils_atari.py
index 2344da518bc..50f91ed49cd 100644
--- a/sota-implementations/ppo/utils_atari.py
+++ b/sota-implementations/ppo/utils_atari.py
@@ -6,8 +6,8 @@
import torch.nn
import torch.optim
from tensordict.nn import TensorDictModule
-from torchrl.data import CompositeSpec
-from torchrl.data.tensor_specs import DiscreteBox
+from torchrl.data import Composite
+from torchrl.data.tensor_specs import CategoricalBox
from torchrl.envs import (
CatFrames,
DoubleToFloat,
@@ -92,7 +92,7 @@ def make_ppo_modules_pixels(proof_environment):
input_shape = proof_environment.observation_spec["pixels"].shape
# Define distribution class and kwargs
- if isinstance(proof_environment.action_spec.space, DiscreteBox):
+ if isinstance(proof_environment.action_spec.space, CategoricalBox):
num_outputs = proof_environment.action_spec.space.n
distribution_class = OneHotCategorical
distribution_kwargs = {}
@@ -148,7 +148,7 @@ def make_ppo_modules_pixels(proof_environment):
policy_module = ProbabilisticActor(
policy_module,
in_keys=["logits"],
- spec=CompositeSpec(action=proof_environment.action_spec),
+ spec=Composite(action=proof_environment.action_spec),
distribution_class=distribution_class,
distribution_kwargs=distribution_kwargs,
return_log_prob=True,
diff --git a/sota-implementations/ppo/utils_mujoco.py b/sota-implementations/ppo/utils_mujoco.py
index 7986738f8e6..a05d205b000 100644
--- a/sota-implementations/ppo/utils_mujoco.py
+++ b/sota-implementations/ppo/utils_mujoco.py
@@ -7,7 +7,7 @@
import torch.optim
from tensordict.nn import AddStateIndependentNormalScale, TensorDictModule
-from torchrl.data import CompositeSpec
+from torchrl.data import Composite
from torchrl.envs import (
ClipTransform,
DoubleToFloat,
@@ -87,7 +87,7 @@ def make_ppo_models_state(proof_environment):
out_keys=["loc", "scale"],
),
in_keys=["loc", "scale"],
- spec=CompositeSpec(action=proof_environment.action_spec),
+ spec=Composite(action=proof_environment.action_spec),
distribution_class=distribution_class,
distribution_kwargs=distribution_kwargs,
return_log_prob=True,
diff --git a/sota-implementations/redq/config.yaml b/sota-implementations/redq/config.yaml
index e60191c0f93..818f3386fda 100644
--- a/sota-implementations/redq/config.yaml
+++ b/sota-implementations/redq/config.yaml
@@ -36,7 +36,6 @@ collector:
multi_step: 1
n_steps_return: 3
max_frames_per_traj: -1
- exploration_mode: random
logger:
backend: wandb
diff --git a/sota-implementations/redq/redq.py b/sota-implementations/redq/redq.py
index eb802f6773d..865533aee2f 100644
--- a/sota-implementations/redq/redq.py
+++ b/sota-implementations/redq/redq.py
@@ -159,7 +159,7 @@ def main(cfg: "DictConfig"): # noqa: F821
use_env_creator=False,
)()
if isinstance(create_env_fn, ParallelEnv):
- raise NotImplementedError("This behaviour is deprecated")
+ raise NotImplementedError("This behavior is deprecated")
elif isinstance(create_env_fn, EnvCreator):
recorder.transform[1:].load_state_dict(
get_norm_state_dict(create_env_fn()), strict=False
diff --git a/sota-implementations/redq/utils.py b/sota-implementations/redq/utils.py
index dd922372cbb..8312d359366 100644
--- a/sota-implementations/redq/utils.py
+++ b/sota-implementations/redq/utils.py
@@ -1021,7 +1021,6 @@ def make_collector_offpolicy(
"init_random_frames": cfg.collector.init_random_frames,
"split_trajs": True,
# trajectories must be separated if multi-step is used
- "exploration_type": ExplorationType.from_str(cfg.collector.exploration_mode),
}
collector = collector_helper(**collector_helper_kwargs)
diff --git a/sota-implementations/sac/sac.py b/sota-implementations/sac/sac.py
index 9904fe072ab..68860500149 100644
--- a/sota-implementations/sac/sac.py
+++ b/sota-implementations/sac/sac.py
@@ -215,6 +215,10 @@ def main(cfg: "DictConfig"): # noqa: F821
sampling_start = time.time()
collector.shutdown()
+ if not eval_env.is_closed:
+ eval_env.close()
+ if not train_env.is_closed:
+ train_env.close()
end_time = time.time()
execution_time = end_time - start_time
torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")
diff --git a/sota-implementations/td3/td3.py b/sota-implementations/td3/td3.py
index 632ee58503d..01a59686ac9 100644
--- a/sota-implementations/td3/td3.py
+++ b/sota-implementations/td3/td3.py
@@ -213,6 +213,10 @@ def main(cfg: "DictConfig"): # noqa: F821
sampling_start = time.time()
collector.shutdown()
+ if not eval_env.is_closed:
+ eval_env.close()
+ if not train_env.is_closed:
+ train_env.close()
end_time = time.time()
execution_time = end_time - start_time
torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")
diff --git a/sota-implementations/td3_bc/td3_bc.py b/sota-implementations/td3_bc/td3_bc.py
index 7c43fdc1a12..930ff509488 100644
--- a/sota-implementations/td3_bc/td3_bc.py
+++ b/sota-implementations/td3_bc/td3_bc.py
@@ -128,7 +128,7 @@ def main(cfg: "DictConfig"): # noqa: F821
# evaluation
if i % evaluation_interval == 0:
- with set_exploration_type(ExplorationType.MODE), torch.no_grad():
+ with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
eval_td = eval_env.rollout(
max_steps=eval_steps, policy=model[0], auto_cast_to_device=True
)
@@ -138,6 +138,8 @@ def main(cfg: "DictConfig"): # noqa: F821
if logger is not None:
log_metrics(logger, to_log, i)
+ if not eval_env.is_closed:
+ eval_env.close()
pbar.close()
torchrl_logger.info(f"Training time: {time.time() - start_time}")
diff --git a/test/_utils_internal.py b/test/_utils_internal.py
index 61b0c003f9d..51535afa606 100644
--- a/test/_utils_internal.py
+++ b/test/_utils_internal.py
@@ -56,7 +56,7 @@ def HALFCHEETAH_VERSIONED():
def PONG_VERSIONED():
# load gym
- # Gymnasium says that the ale_py behaviour changes from 1.0
+ # Gymnasium says that the ale_py behavior changes from 1.0
# but with python 3.12 it is already the case with 0.29.1
try:
import ale_py # noqa
@@ -70,7 +70,7 @@ def PONG_VERSIONED():
def BREAKOUT_VERSIONED():
# load gym
- # Gymnasium says that the ale_py behaviour changes from 1.0
+ # Gymnasium says that the ale_py behavior changes from 1.0
# but with python 3.12 it is already the case with 0.29.1
try:
import ale_py # noqa
@@ -121,7 +121,7 @@ def _set_gym_environments(): # noqa: F811
_BREAKOUT_VERSIONED = "ALE/Breakout-v5"
-@implement_for("gymnasium")
+@implement_for("gymnasium", None, "1.0.0")
def _set_gym_environments(): # noqa: F811
global _CARTPOLE_VERSIONED, _HALFCHEETAH_VERSIONED, _PENDULUM_VERSIONED, _PONG_VERSIONED, _BREAKOUT_VERSIONED
@@ -132,6 +132,11 @@ def _set_gym_environments(): # noqa: F811
_BREAKOUT_VERSIONED = "ALE/Breakout-v5"
+@implement_for("gymnasium", "1.0.0", None)
+def _set_gym_environments(): # noqa: F811
+ raise ImportError
+
+
if _has_gym:
_set_gym_environments()
@@ -155,6 +160,8 @@ def get_default_devices():
return [torch.device("cpu")]
elif num_cuda == 1:
return [torch.device("cuda:0")]
+ elif torch.mps.is_available():
+ return [torch.device("mps:0")]
else:
# then run on all devices
return get_available_devices()
diff --git a/test/conftest.py b/test/conftest.py
index ca418d7b6f2..f2648a18041 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -113,6 +113,18 @@ def pytest_addoption(parser):
help="Use 'fork' start method for mp dedicated tests only if there is no cuda device available.",
)
+ parser.addoption(
+ "--unity_editor",
+ action="store_true",
+ default=False,
+ help="Run tests that require manually pressing play in the Unity editor.",
+ )
+
+
+def pytest_runtest_setup(item):
+ if "unity_editor" in item.keywords and not item.config.getoption("--unity_editor"):
+ pytest.skip("need --unity_editor option to run this test")
+
def pytest_configure(config):
config.addinivalue_line("markers", "slow: mark test as slow to run")
diff --git a/test/mocking_classes.py b/test/mocking_classes.py
index ea4327bb460..4d86d8ec0ac 100644
--- a/test/mocking_classes.py
+++ b/test/mocking_classes.py
@@ -11,15 +11,15 @@
from tensordict.utils import expand_right, NestedKey
from torchrl.data.tensor_specs import (
- BinaryDiscreteTensorSpec,
- BoundedTensorSpec,
- CompositeSpec,
- DiscreteTensorSpec,
- MultiOneHotDiscreteTensorSpec,
- NonTensorSpec,
- OneHotDiscreteTensorSpec,
+ Binary,
+ Bounded,
+ Categorical,
+ Composite,
+ MultiOneHot,
+ NonTensor,
+ OneHot,
TensorSpec,
- UnboundedContinuousTensorSpec,
+ Unbounded,
)
from torchrl.data.utils import consolidate_spec
from torchrl.envs.common import EnvBase
@@ -27,27 +27,27 @@
from torchrl.envs.utils import _terminated_or_truncated
spec_dict = {
- "bounded": BoundedTensorSpec,
- "one_hot": OneHotDiscreteTensorSpec,
- "categorical": DiscreteTensorSpec,
- "unbounded": UnboundedContinuousTensorSpec,
- "binary": BinaryDiscreteTensorSpec,
- "mult_one_hot": MultiOneHotDiscreteTensorSpec,
- "composite": CompositeSpec,
+ "bounded": Bounded,
+ "one_hot": OneHot,
+ "categorical": Categorical,
+ "unbounded": Unbounded,
+ "binary": Binary,
+ "mult_one_hot": MultiOneHot,
+ "composite": Composite,
}
default_spec_kwargs = {
- OneHotDiscreteTensorSpec: {"n": 7},
- DiscreteTensorSpec: {"n": 7},
- BoundedTensorSpec: {"minimum": -torch.ones(4), "maximum": torch.ones(4)},
- UnboundedContinuousTensorSpec: {
+ OneHot: {"n": 7},
+ Categorical: {"n": 7},
+ Bounded: {"minimum": -torch.ones(4), "maximum": torch.ones(4)},
+ Unbounded: {
"shape": [
7,
]
},
- BinaryDiscreteTensorSpec: {"n": 7},
- MultiOneHotDiscreteTensorSpec: {"nvec": [7, 3, 5]},
- CompositeSpec: {},
+ Binary: {"n": 7},
+ MultiOneHot: {"nvec": [7, 3, 5]},
+ Composite: {},
}
@@ -68,8 +68,8 @@ def __new__(
torch.get_default_dtype()
)
reward_spec = cls._output_spec["full_reward_spec"]
- if isinstance(reward_spec, CompositeSpec):
- reward_spec = CompositeSpec(
+ if isinstance(reward_spec, Composite):
+ reward_spec = Composite(
{
key: item.to(torch.get_default_dtype())
for key, item in reward_spec.items(True, True)
@@ -80,19 +80,19 @@ def __new__(
else:
reward_spec = reward_spec.to(torch.get_default_dtype())
cls._output_spec["full_reward_spec"] = reward_spec
- if not isinstance(cls._output_spec["full_reward_spec"], CompositeSpec):
- cls._output_spec["full_reward_spec"] = CompositeSpec(
+ if not isinstance(cls._output_spec["full_reward_spec"], Composite):
+ cls._output_spec["full_reward_spec"] = Composite(
reward=cls._output_spec["full_reward_spec"],
shape=cls._output_spec["full_reward_spec"].shape[:-1],
)
- if not isinstance(cls._output_spec["full_done_spec"], CompositeSpec):
- cls._output_spec["full_done_spec"] = CompositeSpec(
+ if not isinstance(cls._output_spec["full_done_spec"], Composite):
+ cls._output_spec["full_done_spec"] = Composite(
done=cls._output_spec["full_done_spec"].clone(),
terminated=cls._output_spec["full_done_spec"].clone(),
shape=cls._output_spec["full_done_spec"].shape[:-1],
)
- if not isinstance(cls._input_spec["full_action_spec"], CompositeSpec):
- cls._input_spec["full_action_spec"] = CompositeSpec(
+ if not isinstance(cls._input_spec["full_action_spec"], Composite):
+ cls._input_spec["full_action_spec"] = Composite(
action=cls._input_spec["full_action_spec"],
shape=cls._input_spec["full_action_spec"].shape[:-1],
)
@@ -156,15 +156,15 @@ def __new__(
):
batch_size = kwargs.setdefault("batch_size", torch.Size([]))
if action_spec is None:
- action_spec = UnboundedContinuousTensorSpec(
+ action_spec = Unbounded(
(
*batch_size,
1,
)
)
if observation_spec is None:
- observation_spec = CompositeSpec(
- observation=UnboundedContinuousTensorSpec(
+ observation_spec = Composite(
+ observation=Unbounded(
(
*batch_size,
1,
@@ -173,35 +173,35 @@ def __new__(
shape=batch_size,
)
if reward_spec is None:
- reward_spec = UnboundedContinuousTensorSpec(
+ reward_spec = Unbounded(
(
*batch_size,
1,
)
)
if done_spec is None:
- done_spec = DiscreteTensorSpec(2, dtype=torch.bool, shape=(*batch_size, 1))
+ done_spec = Categorical(2, dtype=torch.bool, shape=(*batch_size, 1))
if state_spec is None:
- state_spec = CompositeSpec(shape=batch_size)
- input_spec = CompositeSpec(
+ state_spec = Composite(shape=batch_size)
+ input_spec = Composite(
full_action_spec=action_spec, full_state_spec=state_spec, shape=batch_size
)
- cls._output_spec = CompositeSpec(shape=batch_size)
+ cls._output_spec = Composite(shape=batch_size)
cls._output_spec["full_reward_spec"] = reward_spec
cls._output_spec["full_done_spec"] = done_spec
cls._output_spec["full_observation_spec"] = observation_spec
cls._input_spec = input_spec
- if not isinstance(cls._output_spec["full_reward_spec"], CompositeSpec):
- cls._output_spec["full_reward_spec"] = CompositeSpec(
+ if not isinstance(cls._output_spec["full_reward_spec"], Composite):
+ cls._output_spec["full_reward_spec"] = Composite(
reward=cls._output_spec["full_reward_spec"], shape=batch_size
)
- if not isinstance(cls._output_spec["full_done_spec"], CompositeSpec):
- cls._output_spec["full_done_spec"] = CompositeSpec(
+ if not isinstance(cls._output_spec["full_done_spec"], Composite):
+ cls._output_spec["full_done_spec"] = Composite(
done=cls._output_spec["full_done_spec"], shape=batch_size
)
- if not isinstance(cls._input_spec["full_action_spec"], CompositeSpec):
- cls._input_spec["full_action_spec"] = CompositeSpec(
+ if not isinstance(cls._input_spec["full_action_spec"], Composite):
+ cls._input_spec["full_action_spec"] = Composite(
action=cls._input_spec["full_action_spec"], shape=batch_size
)
return super().__new__(*args, **kwargs)
@@ -268,15 +268,15 @@ def __new__(
):
batch_size = kwargs.setdefault("batch_size", torch.Size([]))
if action_spec is None:
- action_spec = UnboundedContinuousTensorSpec(
+ action_spec = Unbounded(
(
*batch_size,
1,
)
)
if state_spec is None:
- state_spec = CompositeSpec(
- observation=UnboundedContinuousTensorSpec(
+ state_spec = Composite(
+ observation=Unbounded(
(
*batch_size,
1,
@@ -285,8 +285,8 @@ def __new__(
shape=batch_size,
)
if observation_spec is None:
- observation_spec = CompositeSpec(
- observation=UnboundedContinuousTensorSpec(
+ observation_spec = Composite(
+ observation=Unbounded(
(
*batch_size,
1,
@@ -295,33 +295,33 @@ def __new__(
shape=batch_size,
)
if reward_spec is None:
- reward_spec = UnboundedContinuousTensorSpec(
+ reward_spec = Unbounded(
(
*batch_size,
1,
)
)
if done_spec is None:
- done_spec = DiscreteTensorSpec(2, dtype=torch.bool, shape=(*batch_size, 1))
- cls._output_spec = CompositeSpec(shape=batch_size)
+ done_spec = Categorical(2, dtype=torch.bool, shape=(*batch_size, 1))
+ cls._output_spec = Composite(shape=batch_size)
cls._output_spec["full_reward_spec"] = reward_spec
cls._output_spec["full_done_spec"] = done_spec
cls._output_spec["full_observation_spec"] = observation_spec
- cls._input_spec = CompositeSpec(
+ cls._input_spec = Composite(
full_action_spec=action_spec,
full_state_spec=state_spec,
shape=batch_size,
)
- if not isinstance(cls._output_spec["full_reward_spec"], CompositeSpec):
- cls._output_spec["full_reward_spec"] = CompositeSpec(
+ if not isinstance(cls._output_spec["full_reward_spec"], Composite):
+ cls._output_spec["full_reward_spec"] = Composite(
reward=cls._output_spec["full_reward_spec"], shape=batch_size
)
- if not isinstance(cls._output_spec["full_done_spec"], CompositeSpec):
- cls._output_spec["full_done_spec"] = CompositeSpec(
+ if not isinstance(cls._output_spec["full_done_spec"], Composite):
+ cls._output_spec["full_done_spec"] = Composite(
done=cls._output_spec["full_done_spec"], shape=batch_size
)
- if not isinstance(cls._input_spec["full_action_spec"], CompositeSpec):
- cls._input_spec["full_action_spec"] = CompositeSpec(
+ if not isinstance(cls._input_spec["full_action_spec"], Composite):
+ cls._input_spec["full_action_spec"] = Composite(
action=cls._input_spec["full_action_spec"], shape=batch_size
)
return super().__new__(cls, *args, **kwargs)
@@ -442,46 +442,38 @@ def __new__(
size = cls.size = 7
if observation_spec is None:
cls.out_key = "observation"
- observation_spec = CompositeSpec(
- observation=UnboundedContinuousTensorSpec(
- shape=torch.Size([*batch_size, size])
- ),
- observation_orig=UnboundedContinuousTensorSpec(
- shape=torch.Size([*batch_size, size])
- ),
+ observation_spec = Composite(
+ observation=Unbounded(shape=torch.Size([*batch_size, size])),
+ observation_orig=Unbounded(shape=torch.Size([*batch_size, size])),
shape=batch_size,
)
if action_spec is None:
if categorical_action_encoding:
- action_spec_cls = DiscreteTensorSpec
+ action_spec_cls = Categorical
action_spec = action_spec_cls(n=7, shape=batch_size)
else:
- action_spec_cls = OneHotDiscreteTensorSpec
+ action_spec_cls = OneHot
action_spec = action_spec_cls(n=7, shape=(*batch_size, 7))
if reward_spec is None:
- reward_spec = CompositeSpec(
- reward=UnboundedContinuousTensorSpec(shape=(1,))
- )
+ reward_spec = Composite(reward=Unbounded(shape=(1,)))
if done_spec is None:
- done_spec = CompositeSpec(
- terminated=DiscreteTensorSpec(
- 2, dtype=torch.bool, shape=(*batch_size, 1)
- )
+ done_spec = Composite(
+ terminated=Categorical(2, dtype=torch.bool, shape=(*batch_size, 1))
)
if state_spec is None:
cls._out_key = "observation_orig"
- state_spec = CompositeSpec(
+ state_spec = Composite(
{
cls._out_key: observation_spec["observation"],
},
shape=batch_size,
)
- cls._output_spec = CompositeSpec(shape=batch_size)
+ cls._output_spec = Composite(shape=batch_size)
cls._output_spec["full_reward_spec"] = reward_spec
cls._output_spec["full_done_spec"] = done_spec
cls._output_spec["full_observation_spec"] = observation_spec
- cls._input_spec = CompositeSpec(
+ cls._input_spec = Composite(
full_action_spec=action_spec,
full_state_spec=state_spec,
shape=batch_size,
@@ -553,17 +545,13 @@ def __new__(
size = cls.size = 7
if observation_spec is None:
cls.out_key = "observation"
- observation_spec = CompositeSpec(
- observation=UnboundedContinuousTensorSpec(
- shape=torch.Size([*batch_size, size])
- ),
- observation_orig=UnboundedContinuousTensorSpec(
- shape=torch.Size([*batch_size, size])
- ),
+ observation_spec = Composite(
+ observation=Unbounded(shape=torch.Size([*batch_size, size])),
+ observation_orig=Unbounded(shape=torch.Size([*batch_size, size])),
shape=batch_size,
)
if action_spec is None:
- action_spec = BoundedTensorSpec(
+ action_spec = Bounded(
-1,
1,
(
@@ -572,23 +560,23 @@ def __new__(
),
)
if reward_spec is None:
- reward_spec = UnboundedContinuousTensorSpec(shape=(*batch_size, 1))
+ reward_spec = Unbounded(shape=(*batch_size, 1))
if done_spec is None:
- done_spec = DiscreteTensorSpec(2, dtype=torch.bool, shape=(*batch_size, 1))
+ done_spec = Categorical(2, dtype=torch.bool, shape=(*batch_size, 1))
if state_spec is None:
cls._out_key = "observation_orig"
- state_spec = CompositeSpec(
+ state_spec = Composite(
{
cls._out_key: observation_spec["observation"],
},
shape=batch_size,
)
- cls._output_spec = CompositeSpec(shape=batch_size)
+ cls._output_spec = Composite(shape=batch_size)
cls._output_spec["full_reward_spec"] = reward_spec
cls._output_spec["full_done_spec"] = done_spec
cls._output_spec["full_observation_spec"] = observation_spec
- cls._input_spec = CompositeSpec(
+ cls._input_spec = Composite(
full_action_spec=action_spec,
full_state_spec=state_spec,
shape=batch_size,
@@ -681,25 +669,21 @@ def __new__(
batch_size = kwargs.setdefault("batch_size", torch.Size([]))
if observation_spec is None:
cls.out_key = "pixels"
- observation_spec = CompositeSpec(
- pixels=UnboundedContinuousTensorSpec(
- shape=torch.Size([*batch_size, 1, 7, 7])
- ),
- pixels_orig=UnboundedContinuousTensorSpec(
- shape=torch.Size([*batch_size, 1, 7, 7])
- ),
+ observation_spec = Composite(
+ pixels=Unbounded(shape=torch.Size([*batch_size, 1, 7, 7])),
+ pixels_orig=Unbounded(shape=torch.Size([*batch_size, 1, 7, 7])),
shape=batch_size,
)
if action_spec is None:
- action_spec = OneHotDiscreteTensorSpec(7, shape=(*batch_size, 7))
+ action_spec = OneHot(7, shape=(*batch_size, 7))
if reward_spec is None:
- reward_spec = UnboundedContinuousTensorSpec(shape=(*batch_size, 1))
+ reward_spec = Unbounded(shape=(*batch_size, 1))
if done_spec is None:
- done_spec = DiscreteTensorSpec(2, dtype=torch.bool, shape=(*batch_size, 1))
+ done_spec = Categorical(2, dtype=torch.bool, shape=(*batch_size, 1))
if state_spec is None:
cls._out_key = "pixels_orig"
- state_spec = CompositeSpec(
+ state_spec = Composite(
{
cls._out_key: observation_spec["pixels_orig"].clone(),
},
@@ -741,25 +725,17 @@ def __new__(
batch_size = kwargs.setdefault("batch_size", torch.Size([]))
if observation_spec is None:
cls.out_key = "pixels"
- observation_spec = CompositeSpec(
- pixels=UnboundedContinuousTensorSpec(
- shape=torch.Size([*batch_size, 7, 7, 3])
- ),
- pixels_orig=UnboundedContinuousTensorSpec(
- shape=torch.Size([*batch_size, 7, 7, 3])
- ),
+ observation_spec = Composite(
+ pixels=Unbounded(shape=torch.Size([*batch_size, 7, 7, 3])),
+ pixels_orig=Unbounded(shape=torch.Size([*batch_size, 7, 7, 3])),
shape=batch_size,
)
if action_spec is None:
- action_spec_cls = (
- DiscreteTensorSpec
- if categorical_action_encoding
- else OneHotDiscreteTensorSpec
- )
+ action_spec_cls = Categorical if categorical_action_encoding else OneHot
action_spec = action_spec_cls(7, shape=(*batch_size, 7))
if state_spec is None:
cls._out_key = "pixels_orig"
- state_spec = CompositeSpec(
+ state_spec = Composite(
{
cls._out_key: observation_spec["pixels_orig"],
},
@@ -808,25 +784,21 @@ def __new__(
pixel_shape = [1, 7, 7]
if observation_spec is None:
cls.out_key = "pixels"
- observation_spec = CompositeSpec(
- pixels=UnboundedContinuousTensorSpec(
- shape=torch.Size([*batch_size, *pixel_shape])
- ),
- pixels_orig=UnboundedContinuousTensorSpec(
- shape=torch.Size([*batch_size, *pixel_shape])
- ),
+ observation_spec = Composite(
+ pixels=Unbounded(shape=torch.Size([*batch_size, *pixel_shape])),
+ pixels_orig=Unbounded(shape=torch.Size([*batch_size, *pixel_shape])),
shape=batch_size,
)
if action_spec is None:
- action_spec = BoundedTensorSpec(-1, 1, [*batch_size, pixel_shape[-1]])
+ action_spec = Bounded(-1, 1, [*batch_size, pixel_shape[-1]])
if reward_spec is None:
- reward_spec = UnboundedContinuousTensorSpec(shape=(*batch_size, 1))
+ reward_spec = Unbounded(shape=(*batch_size, 1))
if done_spec is None:
- done_spec = DiscreteTensorSpec(2, dtype=torch.bool, shape=(*batch_size, 1))
+ done_spec = Categorical(2, dtype=torch.bool, shape=(*batch_size, 1))
if state_spec is None:
cls._out_key = "pixels_orig"
- state_spec = CompositeSpec(
+ state_spec = Composite(
{cls._out_key: observation_spec["pixels"]}, shape=batch_size
)
return super().__new__(
@@ -865,13 +837,9 @@ def __new__(
batch_size = kwargs.setdefault("batch_size", torch.Size([]))
if observation_spec is None:
cls.out_key = "pixels"
- observation_spec = CompositeSpec(
- pixels=UnboundedContinuousTensorSpec(
- shape=torch.Size([*batch_size, 7, 7, 3])
- ),
- pixels_orig=UnboundedContinuousTensorSpec(
- shape=torch.Size([*batch_size, 7, 7, 3])
- ),
+ observation_spec = Composite(
+ pixels=Unbounded(shape=torch.Size([*batch_size, 7, 7, 3])),
+ pixels_orig=Unbounded(shape=torch.Size([*batch_size, 7, 7, 3])),
)
return super().__new__(
*args,
@@ -928,8 +896,8 @@ def __init__(
device=device,
batch_size=batch_size,
)
- self.observation_spec = CompositeSpec(
- hidden_observation=UnboundedContinuousTensorSpec(
+ self.observation_spec = Composite(
+ hidden_observation=Unbounded(
(
*self.batch_size,
4,
@@ -937,8 +905,8 @@ def __init__(
),
shape=self.batch_size,
)
- self.state_spec = CompositeSpec(
- hidden_observation=UnboundedContinuousTensorSpec(
+ self.state_spec = Composite(
+ hidden_observation=Unbounded(
(
*self.batch_size,
4,
@@ -946,13 +914,13 @@ def __init__(
),
shape=self.batch_size,
)
- self.action_spec = UnboundedContinuousTensorSpec(
+ self.action_spec = Unbounded(
(
*self.batch_size,
1,
)
)
- self.reward_spec = UnboundedContinuousTensorSpec(
+ self.reward_spec = Unbounded(
(
*self.batch_size,
1,
@@ -1012,8 +980,8 @@ def __init__(self, max_steps: int = 5, start_val: int = 0, **kwargs):
self.max_steps = max_steps
self.start_val = start_val
- self.observation_spec = CompositeSpec(
- observation=UnboundedContinuousTensorSpec(
+ self.observation_spec = Composite(
+ observation=Unbounded(
(
*self.batch_size,
1,
@@ -1024,14 +992,14 @@ def __init__(self, max_steps: int = 5, start_val: int = 0, **kwargs):
shape=self.batch_size,
device=self.device,
)
- self.reward_spec = UnboundedContinuousTensorSpec(
+ self.reward_spec = Unbounded(
(
*self.batch_size,
1,
),
device=self.device,
)
- self.done_spec = DiscreteTensorSpec(
+ self.done_spec = Categorical(
2,
dtype=torch.bool,
shape=(
@@ -1040,9 +1008,7 @@ def __init__(self, max_steps: int = 5, start_val: int = 0, **kwargs):
),
device=self.device,
)
- self.action_spec = BinaryDiscreteTensorSpec(
- n=1, shape=[*self.batch_size, 1], device=self.device
- )
+ self.action_spec = Binary(n=1, shape=[*self.batch_size, 1], device=self.device)
self.register_buffer(
"count",
torch.zeros((*self.batch_size, 1), device=self.device, dtype=torch.int),
@@ -1072,7 +1038,10 @@ def _step(
tensordict: TensorDictBase,
) -> TensorDictBase:
action = tensordict.get(self.action_key)
- self.count += action.to(dtype=torch.int, device=self.device)
+ self.count += action.to(
+ dtype=torch.int,
+ device=self.action_spec.device if self.device is None else self.device,
+ )
tensordict = TensorDict(
source={
"observation": self.count.clone(),
@@ -1129,9 +1098,9 @@ def __init__(
self.nested_reward = nest_reward
if self.nested_obs_action:
- self.observation_spec = CompositeSpec(
+ self.observation_spec = Composite(
{
- "data": CompositeSpec(
+ "data": Composite(
{
"states": self.observation_spec["observation"]
.unsqueeze(-1)
@@ -1145,9 +1114,9 @@ def __init__(
},
shape=self.batch_size,
)
- self.action_spec = CompositeSpec(
+ self.action_spec = Composite(
{
- "data": CompositeSpec(
+ "data": Composite(
{
"action": self.action_spec.unsqueeze(-1).expand(
*self.batch_size, self.nested_dim, 1
@@ -1163,9 +1132,9 @@ def __init__(
)
if self.nested_reward:
- self.reward_spec = CompositeSpec(
+ self.reward_spec = Composite(
{
- "data": CompositeSpec(
+ "data": Composite(
{
"reward": self.reward_spec.unsqueeze(-1).expand(
*self.batch_size, self.nested_dim, 1
@@ -1184,12 +1153,12 @@ def __init__(
done_spec = self.full_done_spec.unsqueeze(-1).expand(
*self.batch_size, self.nested_dim
)
- done_spec = CompositeSpec(
+ done_spec = Composite(
{"data": done_spec},
shape=self.batch_size,
)
if self.has_root_done:
- done_spec["done"] = DiscreteTensorSpec(
+ done_spec["done"] = Categorical(
2,
shape=(
*self.batch_size,
@@ -1309,8 +1278,8 @@ def __init__(
self.max_steps = max_steps
- self.observation_spec = CompositeSpec(
- observation=UnboundedContinuousTensorSpec(
+ self.observation_spec = Composite(
+ observation=Unbounded(
(
*self.batch_size,
1,
@@ -1319,13 +1288,13 @@ def __init__(
),
shape=self.batch_size,
)
- self.reward_spec = UnboundedContinuousTensorSpec(
+ self.reward_spec = Unbounded(
(
*self.batch_size,
1,
)
)
- self.done_spec = DiscreteTensorSpec(
+ self.done_spec = Categorical(
2,
dtype=torch.bool,
shape=(
@@ -1333,7 +1302,7 @@ def __init__(
1,
),
)
- self.action_spec = BinaryDiscreteTensorSpec(n=1, shape=[*self.batch_size, 1])
+ self.action_spec = Binary(n=1, shape=[*self.batch_size, 1])
self.count = torch.zeros(
(*self.batch_size, 1), device=self.device, dtype=torch.int
@@ -1419,34 +1388,30 @@ def _make_specs(self):
obs_spec_unlazy = consolidate_spec(obs_specs)
action_specs = torch.stack(action_specs, dim=0)
- self.unbatched_observation_spec = CompositeSpec(
+ self.unbatched_observation_spec = Composite(
lazy=obs_spec_unlazy,
- state=UnboundedContinuousTensorSpec(shape=(64, 64, 3)),
+ state=Unbounded(shape=(64, 64, 3)),
device=self.device,
)
- self.unbatched_action_spec = CompositeSpec(
+ self.unbatched_action_spec = Composite(
lazy=action_specs,
device=self.device,
)
- self.unbatched_reward_spec = CompositeSpec(
+ self.unbatched_reward_spec = Composite(
{
- "lazy": CompositeSpec(
- {
- "reward": UnboundedContinuousTensorSpec(
- shape=(self.n_nested_dim, 1)
- )
- },
+ "lazy": Composite(
+ {"reward": Unbounded(shape=(self.n_nested_dim, 1))},
shape=(self.n_nested_dim,),
)
},
device=self.device,
)
- self.unbatched_done_spec = CompositeSpec(
+ self.unbatched_done_spec = Composite(
{
- "lazy": CompositeSpec(
+ "lazy": Composite(
{
- "done": DiscreteTensorSpec(
+ "done": Categorical(
n=2,
shape=(self.n_nested_dim, 1),
dtype=torch.bool,
@@ -1472,17 +1437,17 @@ def _make_specs(self):
)
def get_agent_obs_spec(self, i):
- camera = BoundedTensorSpec(low=0, high=200, shape=(7, 7, 3))
- vector_3d = UnboundedContinuousTensorSpec(shape=(3,))
- vector_2d = UnboundedContinuousTensorSpec(shape=(2,))
- lidar = BoundedTensorSpec(low=0, high=5, shape=(8,))
+ camera = Bounded(low=0, high=200, shape=(7, 7, 3))
+ vector_3d = Unbounded(shape=(3,))
+ vector_2d = Unbounded(shape=(2,))
+ lidar = Bounded(low=0, high=5, shape=(8,))
- tensor_0 = UnboundedContinuousTensorSpec(shape=(1,))
- tensor_1 = BoundedTensorSpec(low=0, high=3, shape=(1, 2))
- tensor_2 = UnboundedContinuousTensorSpec(shape=(1, 2, 3))
+ tensor_0 = Unbounded(shape=(1,))
+ tensor_1 = Bounded(low=0, high=3, shape=(1, 2))
+ tensor_2 = Unbounded(shape=(1, 2, 3))
if i == 0:
- return CompositeSpec(
+ return Composite(
{
"camera": camera,
"lidar": lidar,
@@ -1492,7 +1457,7 @@ def get_agent_obs_spec(self, i):
device=self.device,
)
elif i == 1:
- return CompositeSpec(
+ return Composite(
{
"camera": camera,
"lidar": lidar,
@@ -1502,7 +1467,7 @@ def get_agent_obs_spec(self, i):
device=self.device,
)
elif i == 2:
- return CompositeSpec(
+ return Composite(
{
"camera": camera,
"vector": vector_2d,
@@ -1514,8 +1479,8 @@ def get_agent_obs_spec(self, i):
raise ValueError(f"Index {i} undefined for index 3")
def get_agent_action_spec(self, i):
- action_3d = BoundedTensorSpec(low=-1, high=1, shape=(3,))
- action_2d = BoundedTensorSpec(low=-1, high=1, shape=(2,))
+ action_3d = Bounded(low=-1, high=1, shape=(3,))
+ action_2d = Bounded(low=-1, high=1, shape=(2,))
# Some have 2d action and some 3d
# TODO Introduce composite heterogeneous actions
@@ -1528,7 +1493,7 @@ def get_agent_action_spec(self, i):
else:
raise ValueError(f"Index {i} undefined for index 3")
- return CompositeSpec({"action": ret})
+ return Composite({"action": ret})
def _reset(
self,
@@ -1659,18 +1624,16 @@ def __init__(self, max_steps: int = 5, start_val: int = 0, **kwargs):
)
def make_specs(self):
- self.unbatched_observation_spec = CompositeSpec(
- nested_1=CompositeSpec(
- observation=BoundedTensorSpec(
- low=0, high=200, shape=(self.nested_dim_1, 3)
- ),
+ self.unbatched_observation_spec = Composite(
+ nested_1=Composite(
+ observation=Bounded(low=0, high=200, shape=(self.nested_dim_1, 3)),
shape=(self.nested_dim_1,),
),
- nested_2=CompositeSpec(
- observation=UnboundedContinuousTensorSpec(shape=(self.nested_dim_2, 2)),
+ nested_2=Composite(
+ observation=Unbounded(shape=(self.nested_dim_2, 2)),
shape=(self.nested_dim_2,),
),
- observation=UnboundedContinuousTensorSpec(
+ observation=Unbounded(
shape=(
10,
10,
@@ -1679,51 +1642,51 @@ def make_specs(self):
),
)
- self.unbatched_action_spec = CompositeSpec(
- nested_1=CompositeSpec(
- action=DiscreteTensorSpec(n=2, shape=(self.nested_dim_1,)),
+ self.unbatched_action_spec = Composite(
+ nested_1=Composite(
+ action=Categorical(n=2, shape=(self.nested_dim_1,)),
shape=(self.nested_dim_1,),
),
- nested_2=CompositeSpec(
- azione=BoundedTensorSpec(low=0, high=100, shape=(self.nested_dim_2, 1)),
+ nested_2=Composite(
+ azione=Bounded(low=0, high=100, shape=(self.nested_dim_2, 1)),
shape=(self.nested_dim_2,),
),
- action=OneHotDiscreteTensorSpec(n=2),
+ action=OneHot(n=2),
)
- self.unbatched_reward_spec = CompositeSpec(
- nested_1=CompositeSpec(
- gift=UnboundedContinuousTensorSpec(shape=(self.nested_dim_1, 1)),
+ self.unbatched_reward_spec = Composite(
+ nested_1=Composite(
+ gift=Unbounded(shape=(self.nested_dim_1, 1)),
shape=(self.nested_dim_1,),
),
- nested_2=CompositeSpec(
- reward=UnboundedContinuousTensorSpec(shape=(self.nested_dim_2, 1)),
+ nested_2=Composite(
+ reward=Unbounded(shape=(self.nested_dim_2, 1)),
shape=(self.nested_dim_2,),
),
- reward=UnboundedContinuousTensorSpec(shape=(1,)),
+ reward=Unbounded(shape=(1,)),
)
- self.unbatched_done_spec = CompositeSpec(
- nested_1=CompositeSpec(
- done=DiscreteTensorSpec(
+ self.unbatched_done_spec = Composite(
+ nested_1=Composite(
+ done=Categorical(
n=2,
shape=(self.nested_dim_1, 1),
dtype=torch.bool,
),
- terminated=DiscreteTensorSpec(
+ terminated=Categorical(
n=2,
shape=(self.nested_dim_1, 1),
dtype=torch.bool,
),
shape=(self.nested_dim_1,),
),
- nested_2=CompositeSpec(
- done=DiscreteTensorSpec(
+ nested_2=Composite(
+ done=Categorical(
n=2,
shape=(self.nested_dim_2, 1),
dtype=torch.bool,
),
- terminated=DiscreteTensorSpec(
+ terminated=Categorical(
n=2,
shape=(self.nested_dim_2, 1),
dtype=torch.bool,
@@ -1731,12 +1694,12 @@ def make_specs(self):
shape=(self.nested_dim_2,),
),
# done at the root always prevail
- done=DiscreteTensorSpec(
+ done=Categorical(
n=2,
shape=(1,),
dtype=torch.bool,
),
- terminated=DiscreteTensorSpec(
+ terminated=Categorical(
n=2,
shape=(1,),
dtype=torch.bool,
@@ -1829,15 +1792,15 @@ def _set_seed(self, seed: Optional[int]):
class EnvWithMetadata(EnvBase):
def __init__(self):
super().__init__()
- self.observation_spec = CompositeSpec(
- tensor=UnboundedContinuousTensorSpec(3),
- non_tensor=NonTensorSpec(shape=()),
+ self.observation_spec = Composite(
+ tensor=Unbounded(3),
+ non_tensor=NonTensor(shape=()),
)
- self.state_spec = CompositeSpec(
- non_tensor=NonTensorSpec(shape=()),
+ self.state_spec = Composite(
+ non_tensor=NonTensor(shape=()),
)
- self.reward_spec = UnboundedContinuousTensorSpec(1)
- self.action_spec = UnboundedContinuousTensorSpec(1)
+ self.reward_spec = Unbounded(1)
+ self.action_spec = Unbounded(1)
def _reset(self, tensordict):
data = self.observation_spec.zero()
@@ -1935,16 +1898,16 @@ def _reset(self, tensordict=None):
class EnvWithDynamicSpec(EnvBase):
def __init__(self, max_count=5):
super().__init__(batch_size=())
- self.observation_spec = CompositeSpec(
- observation=UnboundedContinuousTensorSpec(shape=(3, -1, 2)),
+ self.observation_spec = Composite(
+ observation=Unbounded(shape=(3, -1, 2)),
)
- self.action_spec = BoundedTensorSpec(low=-1, high=1, shape=(2,))
- self.full_done_spec = CompositeSpec(
- done=BinaryDiscreteTensorSpec(1, shape=(1,), dtype=torch.bool),
- terminated=BinaryDiscreteTensorSpec(1, shape=(1,), dtype=torch.bool),
- truncated=BinaryDiscreteTensorSpec(1, shape=(1,), dtype=torch.bool),
+ self.action_spec = Bounded(low=-1, high=1, shape=(2,))
+ self.full_done_spec = Composite(
+ done=Binary(1, shape=(1,), dtype=torch.bool),
+ terminated=Binary(1, shape=(1,), dtype=torch.bool),
+ truncated=Binary(1, shape=(1,), dtype=torch.bool),
)
- self.reward_spec = UnboundedContinuousTensorSpec((1,), dtype=torch.float)
+ self.reward_spec = Unbounded((1,), dtype=torch.float)
self.count = 0
self.max_count = max_count
diff --git a/test/test_actors.py b/test/test_actors.py
index 2d160e31bba..b81f322b708 100644
--- a/test/test_actors.py
+++ b/test/test_actors.py
@@ -14,14 +14,7 @@
from tensordict.nn.distributions import NormalParamExtractor
from torch import distributions as dist, nn
-from torchrl.data import (
- BinaryDiscreteTensorSpec,
- BoundedTensorSpec,
- CompositeSpec,
- DiscreteTensorSpec,
- MultiOneHotDiscreteTensorSpec,
- OneHotDiscreteTensorSpec,
-)
+from torchrl.data import Binary, Bounded, Categorical, Composite, MultiOneHot, OneHot
from torchrl.data.rlhf.dataset import _has_transformers
from torchrl.modules import MLP, SafeModule, TanhDelta, TanhNormal
from torchrl.modules.tensordict_module.actors import (
@@ -50,9 +43,7 @@
)
def test_probabilistic_actor_nested_delta(log_prob_key, nested_dim=5, n_actions=3):
env = NestedCountingEnv(nested_dim=nested_dim)
- action_spec = BoundedTensorSpec(
- shape=torch.Size((nested_dim, n_actions)), high=1, low=-1
- )
+ action_spec = Bounded(shape=torch.Size((nested_dim, n_actions)), high=1, low=-1)
policy_module = TensorDictModule(
nn.Linear(1, 1), in_keys=[("data", "states")], out_keys=[("data", "param")]
)
@@ -63,8 +54,8 @@ def test_probabilistic_actor_nested_delta(log_prob_key, nested_dim=5, n_actions=
out_keys=[("data", "action")],
distribution_class=TanhDelta,
distribution_kwargs={
- "min": action_spec.space.low,
- "max": action_spec.space.high,
+ "low": action_spec.space.low,
+ "high": action_spec.space.high,
},
log_prob_key=log_prob_key,
return_log_prob=True,
@@ -86,8 +77,8 @@ def test_probabilistic_actor_nested_delta(log_prob_key, nested_dim=5, n_actions=
out_keys=[("data", "action")],
distribution_class=TanhDelta,
distribution_kwargs={
- "min": action_spec.space.low,
- "max": action_spec.space.high,
+ "low": action_spec.space.low,
+ "high": action_spec.space.high,
},
log_prob_key=log_prob_key,
return_log_prob=True,
@@ -111,9 +102,7 @@ def test_probabilistic_actor_nested_delta(log_prob_key, nested_dim=5, n_actions=
)
def test_probabilistic_actor_nested_normal(log_prob_key, nested_dim=5, n_actions=3):
env = NestedCountingEnv(nested_dim=nested_dim)
- action_spec = BoundedTensorSpec(
- shape=torch.Size((nested_dim, n_actions)), high=1, low=-1
- )
+ action_spec = Bounded(shape=torch.Size((nested_dim, n_actions)), high=1, low=-1)
actor_net = nn.Sequential(
nn.Linear(1, 2),
NormalParamExtractor(),
@@ -181,7 +170,7 @@ def test_distributional_qvalue_hook_wrong_action_space(self):
DistributionalQValueHook(action_space="wrong_value", support=None)
def test_distributional_qvalue_hook_conflicting_spec(self):
- spec = OneHotDiscreteTensorSpec(3)
+ spec = OneHot(3)
_process_action_space_spec("one-hot", spec)
_process_action_space_spec("one_hot", spec)
_process_action_space_spec("one_hot", None)
@@ -190,19 +179,19 @@ def test_distributional_qvalue_hook_conflicting_spec(self):
ValueError, match="The action spec and the action space do not match"
):
_process_action_space_spec("multi-one-hot", spec)
- spec = MultiOneHotDiscreteTensorSpec([3, 3])
+ spec = MultiOneHot([3, 3])
_process_action_space_spec("multi-one-hot", spec)
_process_action_space_spec(spec, spec)
with pytest.raises(
ValueError, match="Passing an action_space as a TensorSpec and a spec"
):
- _process_action_space_spec(OneHotDiscreteTensorSpec(3), spec)
+ _process_action_space_spec(OneHot(3), spec)
with pytest.raises(
- ValueError, match="action_space cannot be of type CompositeSpec"
+ ValueError, match="action_space cannot be of type Composite"
):
- _process_action_space_spec(CompositeSpec(), spec)
+ _process_action_space_spec(Composite(), spec)
with pytest.raises(KeyError, match="action could not be found in the spec"):
- _process_action_space_spec(None, CompositeSpec())
+ _process_action_space_spec(None, Composite())
with pytest.raises(
ValueError, match="Neither action_space nor spec was defined"
):
@@ -248,10 +237,10 @@ def test_nested_keys(self, nested_action, batch_size, nested_dim=5):
ValueError,
match="Passing an action_space as a TensorSpec and a spec isn't allowed, unless they match.",
):
- _process_action_space_spec(BinaryDiscreteTensorSpec(n=1), action_spec)
- _process_action_space_spec(BinaryDiscreteTensorSpec(n=1), leaf_action_spec)
+ _process_action_space_spec(Binary(n=1), action_spec)
+ _process_action_space_spec(Binary(n=1), leaf_action_spec)
with pytest.raises(
- ValueError, match="action_space cannot be of type CompositeSpec"
+ ValueError, match="action_space cannot be of type Composite"
):
_process_action_space_spec(action_spec, None)
@@ -652,7 +641,7 @@ def test_value_based_policy(device):
torch.manual_seed(0)
obs_dim = 4
action_dim = 5
- action_spec = OneHotDiscreteTensorSpec(action_dim)
+ action_spec = OneHot(action_dim)
def make_net():
net = MLP(in_features=obs_dim, out_features=action_dim, depth=2, device=device)
@@ -681,9 +670,7 @@ def make_net():
assert (action.sum(-1) == 1).all()
-@pytest.mark.parametrize(
- "spec", [None, OneHotDiscreteTensorSpec(3), MultiOneHotDiscreteTensorSpec([3, 2])]
-)
+@pytest.mark.parametrize("spec", [None, OneHot(3), MultiOneHot([3, 2])])
@pytest.mark.parametrize(
"action_space", [None, "one-hot", "one_hot", "mult-one-hot", "mult_one_hot"]
)
@@ -706,12 +693,9 @@ def test_qvalactor_construct(
QValueActor(**kwargs)
return
if (
- type(spec) is MultiOneHotDiscreteTensorSpec
+ type(spec) is MultiOneHot
and action_space not in ("mult-one-hot", "mult_one_hot", None)
- ) or (
- type(spec) is OneHotDiscreteTensorSpec
- and action_space not in ("one-hot", "one_hot", None)
- ):
+ ) or (type(spec) is OneHot and action_space not in ("one-hot", "one_hot", None)):
with pytest.raises(
ValueError, match="The action spec and the action space do not match"
):
@@ -725,7 +709,7 @@ def test_value_based_policy_categorical(device):
torch.manual_seed(0)
obs_dim = 4
action_dim = 5
- action_spec = DiscreteTensorSpec(action_dim)
+ action_spec = Categorical(action_dim)
def make_net():
net = MLP(in_features=obs_dim, out_features=action_dim, depth=2, device=device)
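Aside: the rename sweep in test_actors.py above is purely mechanical — every `*TensorSpec` class keeps its behavior under a shorter name (`BoundedTensorSpec` → `Bounded`, `OneHotDiscreteTensorSpec` → `OneHot`, `DiscreteTensorSpec` → `Categorical`, `MultiOneHotDiscreteTensorSpec` → `MultiOneHot`, `CompositeSpec` → `Composite`), and `TanhDelta` now spells its bounds `low`/`high` instead of `min`/`max`. A minimal sketch of the new-style API (shapes and values are illustrative, not taken from the tests):

import torch
from torchrl.data import Bounded, Composite, OneHot

# New-style spec names; drop-in replacements for the old *TensorSpec classes.
action_spec = Bounded(shape=torch.Size((5, 3)), low=-1, high=1)
full_spec = Composite(action=action_spec, choice=OneHot(4))

# Bounds are still read from .space, exactly as the updated tests do above.
assert (action_spec.space.low == -1).all()
sample = action_spec.rand()  # uniform draw inside [low, high]
assert ((sample >= -1) & (sample <= 1)).all()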
diff --git a/test/test_collector.py b/test/test_collector.py
index 7d7208aead0..9e6ccd79408 100644
--- a/test/test_collector.py
+++ b/test/test_collector.py
@@ -50,7 +50,12 @@
TensorDict,
TensorDictBase,
)
-from tensordict.nn import TensorDictModule, TensorDictModuleBase, TensorDictSequential
+from tensordict.nn import (
+ CudaGraphModule,
+ TensorDictModule,
+ TensorDictModuleBase,
+ TensorDictSequential,
+)
from torch import nn
from torchrl._utils import (
@@ -68,13 +73,15 @@
)
from torchrl.collectors.utils import split_trajectories
from torchrl.data import (
- CompositeSpec,
+ Composite,
+ LazyMemmapStorage,
LazyTensorStorage,
- NonTensorSpec,
+ NonTensor,
ReplayBuffer,
TensorSpec,
- UnboundedContinuousTensorSpec,
+ Unbounded,
)
+from torchrl.data.utils import CloudpickleWrapper
from torchrl.envs import (
EnvBase,
EnvCreator,
@@ -114,19 +121,6 @@ def forward(self, observation):
return self.linear(observation)
-class TensorDictCompatiblePolicy(nn.Module):
- def __init__(self, out_features: int):
- super().__init__()
- self.in_keys = ["observation"]
- self.out_keys = ["action"]
- self.linear = nn.LazyLinear(out_features)
-
- def forward(self, tensordict):
- return tensordict.set(
- self.out_keys[0], self.linear(tensordict.get(self.in_keys[0]))
- )
-
-
class UnwrappablePolicy(nn.Module):
def __init__(self, out_features: int):
super().__init__()
@@ -210,22 +204,16 @@ class DeviceLessEnv(EnvBase):
def __init__(self, default_device):
self.default_device = default_device
super().__init__(device=None)
- self.observation_spec = CompositeSpec(
- observation=UnboundedContinuousTensorSpec((), device=default_device)
+ self.observation_spec = Composite(
+ observation=Unbounded((), device=default_device)
)
- self.reward_spec = UnboundedContinuousTensorSpec(1, device=default_device)
- self.full_done_spec = CompositeSpec(
- done=UnboundedContinuousTensorSpec(
- 1, dtype=torch.bool, device=self.default_device
- ),
- truncated=UnboundedContinuousTensorSpec(
- 1, dtype=torch.bool, device=self.default_device
- ),
- terminated=UnboundedContinuousTensorSpec(
- 1, dtype=torch.bool, device=self.default_device
- ),
+ self.reward_spec = Unbounded(1, device=default_device)
+ self.full_done_spec = Composite(
+ done=Unbounded(1, dtype=torch.bool, device=self.default_device),
+ truncated=Unbounded(1, dtype=torch.bool, device=self.default_device),
+ terminated=Unbounded(1, dtype=torch.bool, device=self.default_device),
)
- self.action_spec = UnboundedContinuousTensorSpec((), device=None)
+ self.action_spec = Unbounded((), device=None)
assert self.device is None
assert self.full_observation_spec is not None
assert self.full_done_spec is not None
@@ -268,29 +256,17 @@ class EnvWithDevice(EnvBase):
def __init__(self, default_device):
self.default_device = default_device
super().__init__(device=self.default_device)
- self.observation_spec = CompositeSpec(
- observation=UnboundedContinuousTensorSpec(
- (), device=self.default_device
- )
- )
- self.reward_spec = UnboundedContinuousTensorSpec(
- 1, device=self.default_device
+ self.observation_spec = Composite(
+ observation=Unbounded((), device=self.default_device)
)
- self.full_done_spec = CompositeSpec(
- done=UnboundedContinuousTensorSpec(
- 1, dtype=torch.bool, device=self.default_device
- ),
- truncated=UnboundedContinuousTensorSpec(
- 1, dtype=torch.bool, device=self.default_device
- ),
- terminated=UnboundedContinuousTensorSpec(
- 1, dtype=torch.bool, device=self.default_device
- ),
+ self.reward_spec = Unbounded(1, device=self.default_device)
+ self.full_done_spec = Composite(
+ done=Unbounded(1, dtype=torch.bool, device=self.default_device),
+ truncated=Unbounded(1, dtype=torch.bool, device=self.default_device),
+ terminated=Unbounded(1, dtype=torch.bool, device=self.default_device),
device=self.default_device,
)
- self.action_spec = UnboundedContinuousTensorSpec(
- (), device=self.default_device
- )
+ self.action_spec = Unbounded((), device=self.default_device)
assert self.device == _make_ordinal_device(
torch.device(self.default_device)
)
@@ -1295,7 +1271,7 @@ def make_env():
policy,
copier,
OrnsteinUhlenbeckProcessModule(
- spec=CompositeSpec({key: None for key in policy.out_keys})
+ spec=Composite({key: None for key in policy.out_keys})
),
)
@@ -1368,12 +1344,12 @@ def test_collector_output_keys(
],
}
if explicit_spec:
- hidden_spec = UnboundedContinuousTensorSpec((1, hidden_size))
- policy_kwargs["spec"] = CompositeSpec(
- action=UnboundedContinuousTensorSpec(),
+ hidden_spec = Unbounded((1, hidden_size))
+ policy_kwargs["spec"] = Composite(
+ action=Unbounded(),
hidden1=hidden_spec,
hidden2=hidden_spec,
- next=CompositeSpec(hidden1=hidden_spec, hidden2=hidden_spec),
+ next=Composite(hidden1=hidden_spec, hidden2=hidden_spec),
)
policy = SafeModule(**policy_kwargs)
@@ -1614,8 +1590,8 @@ def test_auto_wrap_error(self, collector_class, env_maker):
policy = UnwrappablePolicy(out_features=env_maker().action_spec.shape[-1])
with pytest.raises(
TypeError,
- match=(r"Arguments to policy.forward are incompatible with entries in"),
- ) if collector_class is SyncDataCollector else pytest.raises(EOFError):
+ match=("Arguments to policy.forward are incompatible with entries in"),
+ ):
collector_class(
**self._create_collector_kwargs(env_maker, collector_class, policy)
)
@@ -1844,10 +1820,15 @@ def test_set_truncated(collector_cls):
NestedCountingEnv(), InitTracker()
).add_truncated_keys()
env = env_fn()
- policy = env.rand_action
+ policy = CloudpickleWrapper(env.rand_action)
if collector_cls == SyncDataCollector:
collector = collector_cls(
- env, policy=policy, frames_per_batch=20, total_frames=-1, set_truncated=True
+ env,
+ policy=policy,
+ frames_per_batch=20,
+ total_frames=-1,
+ set_truncated=True,
+ trust_policy=True,
)
else:
collector = collector_cls(
@@ -1857,6 +1838,7 @@ def test_set_truncated(collector_cls):
total_frames=-1,
cat_results="stack",
set_truncated=True,
+ trust_policy=True,
)
try:
for data in collector:
@@ -2164,21 +2146,18 @@ def test_multi_collector_consistency(
assert_allclose_td(c2.unsqueeze(0), d2)
-@pytest.mark.skipif(not torch.cuda.device_count(), reason="No casting if no cuda")
+@pytest.mark.skipif(
+ not torch.cuda.is_available() and not torch.mps.is_available(),
+ reason="No casting if no cuda",
+)
class TestUpdateParams:
class DummyEnv(EnvBase):
def __init__(self, device, batch_size=[]): # noqa: B006
super().__init__(batch_size=batch_size, device=device)
self.state = torch.zeros(self.batch_size, device=device)
- self.observation_spec = CompositeSpec(
- state=UnboundedContinuousTensorSpec(shape=(), device=device)
- )
- self.action_spec = UnboundedContinuousTensorSpec(
- shape=batch_size, device=device
- )
- self.reward_spec = UnboundedContinuousTensorSpec(
- shape=(*batch_size, 1), device=device
- )
+ self.observation_spec = Composite(state=Unbounded(shape=(), device=device))
+ self.action_spec = Unbounded(shape=batch_size, device=device)
+ self.reward_spec = Unbounded(shape=(*batch_size, 1), device=device)
def _step(
self,
@@ -2228,8 +2207,8 @@ def forward(self, td):
@pytest.mark.parametrize(
"policy_device,env_device",
[
- ["cpu", "cuda"],
- ["cuda", "cpu"],
+ ["cpu", get_default_devices()[0]],
+ [get_default_devices()[0], "cpu"],
# ["cpu", "cuda:0"], # 1226: faster execution
# ["cuda:0", "cpu"],
# ["cuda", "cuda:0"],
@@ -2253,9 +2232,7 @@ def test_param_sync(self, give_weights, collector, policy_device, env_device):
policy.param.data += 1
policy.buf.data += 2
if give_weights:
- d = dict(policy.named_parameters())
- d.update(policy.named_buffers())
- p_w = TensorDict(d, [])
+ p_w = TensorDict.from_module(policy)
else:
p_w = None
col.update_policy_weights_(p_w)
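Aside: `TensorDict.from_module` replaces the manual merge of `named_parameters()` and `named_buffers()` that the old test body performed. A small sketch of the equivalence, using a stand-in flat module (for nested modules, `from_module` nests keys per submodule instead of using dotted strings):

import torch
from tensordict import TensorDict
from torch import nn

policy = nn.Linear(3, 4)  # stand-in for the test policy

# Old pattern: merge parameters and buffers by hand.
d = dict(policy.named_parameters())
d.update(policy.named_buffers())
manual = TensorDict(d, [])

# New pattern: one call captures both params and buffers.
auto = TensorDict.from_module(policy)
assert set(manual.keys(True, True)) == set(auto.keys(True, True))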
@@ -2609,8 +2586,15 @@ def test_unique_traj_sync(self, cat_results):
buffer.extend(d)
assert c._use_buffers
traj_ids = buffer[:].get(("collector", "traj_ids"))
- # check that we have as many trajs as expected (no skip)
- assert traj_ids.unique().numel() == traj_ids.max() + 1
+        # Ideally, we'd like (sorted_traj.values == sorted_traj.indices).all() to hold,
+        # but in practice one env can reach the end of the rollout and do a reset
+        # (which we don't want to prevent), incrementing the global traj count
+        # while the others have not finished yet. In that case, that traj number
+        # will never appear.
+ # sorted_traj = traj_ids.unique().sort()
+ # assert (sorted_traj.values == sorted_traj.indices).all()
+ # assert traj_ids.unique().numel() == traj_ids.max() + 1
+
# check that trajs are not overlapping
if stack_results:
sets = [
@@ -2670,6 +2654,89 @@ def test_dynamic_multiasync_collector(self):
assert data.names[-1] == "time"
+class TestCompile:
+ @pytest.mark.parametrize(
+ "collector_cls",
+        # Clearing compiled policies causes a segfault on machines with CUDA
+ [SyncDataCollector, MultiaSyncDataCollector, MultiSyncDataCollector]
+ if not torch.cuda.is_available()
+ else [SyncDataCollector],
+ )
+ @pytest.mark.parametrize("compile_policy", [True, {}, {"mode": "default"}])
+ @pytest.mark.parametrize(
+ "device", [torch.device("cuda:0" if torch.cuda.is_available() else "cpu")]
+ )
+ def test_compiled_policy(self, collector_cls, compile_policy, device):
+ policy = TensorDictModule(
+ nn.Linear(7, 7, device=device), in_keys=["observation"], out_keys=["action"]
+ )
+ make_env = functools.partial(ContinuousActionVecMockEnv, device=device)
+ if collector_cls is SyncDataCollector:
+ torch._dynamo.reset_code_caches()
+ collector = SyncDataCollector(
+ make_env(),
+ policy,
+ frames_per_batch=30,
+ total_frames=120,
+ compile_policy=compile_policy,
+ )
+ assert collector.compiled_policy
+ else:
+ collector = collector_cls(
+ [make_env] * 2,
+ policy,
+ frames_per_batch=30,
+ total_frames=120,
+ compile_policy=compile_policy,
+ )
+ assert collector.compiled_policy
+ try:
+ for data in collector:
+ assert data is not None
+ finally:
+ collector.shutdown()
+ del collector
+
+ @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available")
+ @pytest.mark.parametrize(
+ "collector_cls",
+ [SyncDataCollector],
+ )
+ @pytest.mark.parametrize("cudagraph_policy", [True, {}, {"warmup": 10}])
+ def test_cudagraph_policy(self, collector_cls, cudagraph_policy):
+ device = torch.device("cuda:0")
+ policy = TensorDictModule(
+ nn.Linear(7, 7, device=device), in_keys=["observation"], out_keys=["action"]
+ )
+ make_env = functools.partial(ContinuousActionVecMockEnv, device=device)
+ if collector_cls is SyncDataCollector:
+ collector = SyncDataCollector(
+ make_env(),
+ policy,
+ frames_per_batch=30,
+ total_frames=120,
+ cudagraph_policy=cudagraph_policy,
+ device=device,
+ )
+ assert collector.cudagraphed_policy
+ else:
+ collector = collector_cls(
+ [make_env] * 2,
+ policy,
+ frames_per_batch=30,
+ total_frames=120,
+ cudagraph_policy=cudagraph_policy,
+ device=device,
+ )
+ assert collector.cudagraphed_policy
+ try:
+ for data in collector:
+ assert data is not None
+ finally:
+ collector.shutdown()
+ del collector
+
+
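Aside: the new `TestCompile` class exercises two collector knobs: `compile_policy` (a bool, or a kwargs dict forwarded to `torch.compile`) and `cudagraph_policy` (a bool or kwargs dict that wraps the policy in a `CudaGraphModule`). A minimal sketch of the first knob, assuming gym and Pendulum are available (env and layer sizes are illustrative):

from tensordict.nn import TensorDictModule
from torch import nn
from torchrl.collectors import SyncDataCollector
from torchrl.envs import GymEnv  # assumes gym is installed

policy = TensorDictModule(
    nn.LazyLinear(1), in_keys=["observation"], out_keys=["action"]
)
collector = SyncDataCollector(
    GymEnv("Pendulum-v1"),
    policy,
    frames_per_batch=32,
    total_frames=128,
    compile_policy={"mode": "default"},  # dict is forwarded to torch.compile
    # cudagraph_policy=True would additionally wrap the policy in a CudaGraphModule
)
for batch in collector:
    assert batch.numel() == 32  # one env, 32 frames per batch
collector.shutdown()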
@pytest.mark.skipif(not _has_gym, reason="gym required for this test")
class TestCollectorsNonTensor:
class AddNontTensorData(Transform):
@@ -2685,7 +2752,7 @@ def _reset(
def transform_observation_spec(
self, observation_spec: TensorSpec
) -> TensorSpec:
- observation_spec["nt"] = NonTensorSpec(shape=())
+ observation_spec["nt"] = NonTensor(shape=())
return observation_spec
@classmethod
@@ -2775,6 +2842,289 @@ def test_async(self, use_buffers):
del collector
+class TestCollectorRB:
+ @pytest.mark.skipif(not _has_gym, reason="requires gym.")
+ def test_collector_rb_sync(self):
+ env = SerialEnv(8, lambda cp=CARTPOLE_VERSIONED(): GymEnv(cp))
+ env.set_seed(0)
+ rb = ReplayBuffer(storage=LazyTensorStorage(256, ndim=2), batch_size=5)
+ collector = SyncDataCollector(
+ env,
+ RandomPolicy(env.action_spec),
+ replay_buffer=rb,
+ total_frames=256,
+ frames_per_batch=16,
+ )
+ torch.manual_seed(0)
+
+ for c in collector:
+ assert c is None
+ rb.sample()
+ rbdata0 = rb[:].clone()
+ collector.shutdown()
+ if not env.is_closed:
+ env.close()
+ del collector, env
+
+ env = SerialEnv(8, lambda cp=CARTPOLE_VERSIONED(): GymEnv(cp))
+ env.set_seed(0)
+ rb = ReplayBuffer(storage=LazyTensorStorage(256, ndim=2), batch_size=5)
+ collector = SyncDataCollector(
+ env, RandomPolicy(env.action_spec), total_frames=256, frames_per_batch=16
+ )
+ torch.manual_seed(0)
+
+ for i, c in enumerate(collector):
+ rb.extend(c)
+ torch.testing.assert_close(
+ rbdata0[:, : (i + 1) * 2]["observation"], rb[:]["observation"]
+ )
+ assert c is not None
+ rb.sample()
+
+ rbdata1 = rb[:].clone()
+ collector.shutdown()
+ if not env.is_closed:
+ env.close()
+ del collector, env
+ assert assert_allclose_td(rbdata0, rbdata1)
+
+ @pytest.mark.skipif(not _has_gym, reason="requires gym.")
+ @pytest.mark.parametrize("replay_buffer_chunk", [False, True])
+ @pytest.mark.parametrize("env_creator", [False, True])
+ @pytest.mark.parametrize("storagetype", [LazyTensorStorage, LazyMemmapStorage])
+ def test_collector_rb_multisync(
+ self, replay_buffer_chunk, env_creator, storagetype, tmpdir
+ ):
+ if not env_creator:
+ env = GymEnv(CARTPOLE_VERSIONED()).append_transform(StepCounter())
+ env.set_seed(0)
+ action_spec = env.action_spec
+ env = lambda env=env: env
+ else:
+ env = EnvCreator(
+ lambda cp=CARTPOLE_VERSIONED(): GymEnv(cp).append_transform(
+ StepCounter()
+ )
+ )
+ action_spec = env.meta_data.specs["input_spec", "full_action_spec"]
+
+ if storagetype == LazyMemmapStorage:
+ storagetype = functools.partial(LazyMemmapStorage, scratch_dir=tmpdir)
+ rb = ReplayBuffer(storage=storagetype(256), batch_size=5)
+
+ collector = MultiSyncDataCollector(
+ [env, env],
+ RandomPolicy(action_spec),
+ replay_buffer=rb,
+ total_frames=256,
+ frames_per_batch=32,
+ replay_buffer_chunk=replay_buffer_chunk,
+ )
+ torch.manual_seed(0)
+ pred_len = 0
+ for c in collector:
+ pred_len += 32
+ assert c is None
+ assert len(rb) == pred_len
+ collector.shutdown()
+ assert len(rb) == 256
+ if not replay_buffer_chunk:
+ steps_counts = rb["step_count"].squeeze().split(16)
+ collector_ids = rb["collector", "traj_ids"].squeeze().split(16)
+ for step_count, ids in zip(steps_counts, collector_ids):
+ step_countdiff = step_count.diff()
+ idsdiff = ids.diff()
+ assert (
+ (step_countdiff == 1) | (step_countdiff < 0)
+ ).all(), steps_counts
+ assert (idsdiff >= 0).all()
+
+ @pytest.mark.skipif(not _has_gym, reason="requires gym.")
+ @pytest.mark.parametrize("replay_buffer_chunk", [False, True])
+ @pytest.mark.parametrize("env_creator", [False, True])
+ @pytest.mark.parametrize("storagetype", [LazyTensorStorage, LazyMemmapStorage])
+ def test_collector_rb_multiasync(
+ self, replay_buffer_chunk, env_creator, storagetype, tmpdir
+ ):
+ if not env_creator:
+ env = GymEnv(CARTPOLE_VERSIONED()).append_transform(StepCounter())
+ env.set_seed(0)
+ action_spec = env.action_spec
+ env = lambda env=env: env
+ else:
+ env = EnvCreator(
+ lambda cp=CARTPOLE_VERSIONED(): GymEnv(cp).append_transform(
+ StepCounter()
+ )
+ )
+ action_spec = env.meta_data.specs["input_spec", "full_action_spec"]
+
+ if storagetype == LazyMemmapStorage:
+ storagetype = functools.partial(LazyMemmapStorage, scratch_dir=tmpdir)
+ rb = ReplayBuffer(storage=storagetype(256), batch_size=5)
+
+ collector = MultiaSyncDataCollector(
+ [env, env],
+ RandomPolicy(action_spec),
+ replay_buffer=rb,
+ total_frames=256,
+ frames_per_batch=16,
+ replay_buffer_chunk=replay_buffer_chunk,
+ )
+ torch.manual_seed(0)
+ pred_len = 0
+ for c in collector:
+ pred_len += 16
+ assert c is None
+ assert len(rb) >= pred_len
+ collector.shutdown()
+ assert len(rb) == 256
+ if not replay_buffer_chunk:
+ steps_counts = rb["step_count"].squeeze().split(16)
+ collector_ids = rb["collector", "traj_ids"].squeeze().split(16)
+ for step_count, ids in zip(steps_counts, collector_ids):
+ step_countdiff = step_count.diff()
+ idsdiff = ids.diff()
+ assert (
+ (step_countdiff == 1) | (step_countdiff < 0)
+ ).all(), steps_counts
+ assert (idsdiff >= 0).all()
+
+
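Aside: `TestCollectorRB` relies on the collectors writing straight into a `ReplayBuffer` when one is passed as `replay_buffer=`: iteration then yields `None` and the frames are read back from the buffer. A hedged sketch of that pattern (env choice is illustrative):

from torchrl.collectors import RandomPolicy, SyncDataCollector
from torchrl.data import LazyTensorStorage, ReplayBuffer
from torchrl.envs import GymEnv  # illustrative env choice

env = GymEnv("CartPole-v1")
rb = ReplayBuffer(storage=LazyTensorStorage(256), batch_size=16)
collector = SyncDataCollector(
    env,
    RandomPolicy(env.action_spec),
    replay_buffer=rb,
    frames_per_batch=16,
    total_frames=64,
)
for batch in collector:
    assert batch is None  # frames were written to rb, not yielded
    sample = rb.sample()  # 16 transitions, per the buffer's batch_size
collector.shutdown()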
+def __deepcopy_error__(*args, **kwargs):
+ raise RuntimeError("deepcopy not allowed")
+
+
+@pytest.mark.filterwarnings("error")
+@pytest.mark.filterwarnings("ignore:Tensordict is registered in PyTree")
+@pytest.mark.parametrize(
+ "collector_type",
+ [
+ SyncDataCollector,
+ MultiaSyncDataCollector,
+ functools.partial(MultiSyncDataCollector, cat_results="stack"),
+ ],
+)
+def test_no_deepcopy_policy(collector_type):
+    # Tests that the collector instantiation does not make a deepcopy of the policy unless necessary.
+ #
+ # The only situation where we want to deepcopy the policy is when the policy_device differs from the actual device
+ # of the policy. This can only be checked if the policy is an nn.Module and any of the params is not on the desired
+ # device.
+ #
+    # If the policy is not an nn.Module or has no parameters, passing a policy_device should warn (we don't know what
+    # to do, but we can trust that the user does).
+
+ shared_device = torch.device("cpu")
+ if torch.cuda.is_available():
+ original_device = torch.device("cuda:0")
+ elif torch.mps.is_available():
+ original_device = torch.device("mps")
+ else:
+ pytest.skip("No GPU or MPS device")
+
+ def make_policy(device=None, nn_module=True):
+ if nn_module:
+ return TensorDictModule(
+ nn.Linear(7, 7, device=device),
+ in_keys=["observation"],
+ out_keys=["action"],
+ )
+ policy = make_policy(device=device)
+ return CloudpickleWrapper(policy)
+
+ def make_and_test_policy(
+ policy,
+ policy_device=None,
+ env_device=None,
+ device=None,
+ trust_policy=None,
+ ):
+        # make sure the policy raises if it gets deep-copied
+
+ policy.__deepcopy__ = __deepcopy_error__
+ envs = ContinuousActionVecMockEnv(device=env_device)
+ if collector_type is not SyncDataCollector:
+ envs = [envs, envs]
+ c = collector_type(
+ envs,
+ policy=policy,
+ total_frames=1000,
+ frames_per_batch=10,
+ policy_device=policy_device,
+ env_device=env_device,
+ device=device,
+ trust_policy=trust_policy,
+ )
+ for _ in c:
+ return
+
+ # Simplest use cases
+ policy = make_policy()
+ make_and_test_policy(policy)
+
+ if collector_type is SyncDataCollector or original_device.type != "mps":
+ # mps cannot be shared
+ policy = make_policy(device=original_device)
+ make_and_test_policy(policy, env_device=original_device)
+
+ if collector_type is SyncDataCollector or original_device.type != "mps":
+ policy = make_policy(device=original_device)
+ make_and_test_policy(
+ policy, policy_device=original_device, env_device=original_device
+ )
+
+ # a deepcopy must occur when the policy_device differs from the actual device
+ with pytest.raises(RuntimeError, match="deepcopy not allowed"):
+ policy = make_policy(device=original_device)
+ make_and_test_policy(
+ policy, policy_device=shared_device, env_device=shared_device
+ )
+
+ # a deepcopy must occur when device differs from the actual device
+ with pytest.raises(RuntimeError, match="deepcopy not allowed"):
+ policy = make_policy(device=original_device)
+ make_and_test_policy(policy, device=shared_device)
+
+ # If the policy is not an nn.Module, we can't cast it to device, so we assume that the policy device
+ # is there to inform us
+ substitute_device = (
+ original_device if torch.cuda.is_available() else torch.device("cpu")
+ )
+ policy = make_policy(substitute_device, nn_module=False)
+ with pytest.warns(UserWarning):
+ make_and_test_policy(
+ policy, policy_device=substitute_device, env_device=substitute_device
+ )
+    # For instance, if the env is on CPU, knowing the policy device helps cast data to the right device
+ with pytest.warns(UserWarning):
+ make_and_test_policy(
+ policy, policy_device=substitute_device, env_device=shared_device
+ )
+ make_and_test_policy(
+ policy,
+ policy_device=substitute_device,
+ env_device=shared_device,
+ trust_policy=True,
+ )
+
+ # If there is no policy_device, we assume that the user is doing things right too but don't warn
+ if collector_type is SyncDataCollector or original_device.type != "mps":
+ policy = make_policy(original_device, nn_module=False)
+ make_and_test_policy(policy, env_device=original_device)
+
+ # If the policy is a CudaGraphModule, we know it's on cuda - no need to warn
+ if torch.cuda.is_available() and collector_type is SyncDataCollector:
+ policy = make_policy(original_device)
+ cudagraph_policy = CudaGraphModule(policy)
+ make_and_test_policy(
+ cudagraph_policy,
+ policy_device=original_device,
+ env_device=shared_device,
+ )
+
+
if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
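Aside: `test_no_deepcopy_policy` pins down the copy rule its comments describe: a deepcopy only happens when `policy_device` (or `device`) disagrees with where the policy's parameters actually live, and `trust_policy=True` silences the warning for policies the collector cannot introspect. A minimal sketch, assuming a CUDA device (env choice is illustrative):

from tensordict.nn import TensorDictModule
from torch import nn
from torchrl.collectors import SyncDataCollector
from torchrl.envs import GymEnv  # illustrative

policy = TensorDictModule(
    nn.LazyLinear(1, device="cuda:0"), in_keys=["observation"], out_keys=["action"]
)
# Devices agree -> the collector can use the policy as-is, no deepcopy.
collector = SyncDataCollector(
    GymEnv("Pendulum-v1", device="cuda:0"),
    policy,
    policy_device="cuda:0",
    env_device="cuda:0",
    frames_per_batch=10,
    total_frames=20,
)
# Asking for policy_device="cpu" instead would require casting the parameters,
# hence a copy of the policy; for a non-nn.Module policy, trust_policy=True
# tells the collector to use it untouched and not warn.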
diff --git a/test/test_cost.py b/test/test_cost.py
index 871d9170aa1..3530fff825d 100644
--- a/test/test_cost.py
+++ b/test/test_cost.py
@@ -13,8 +13,8 @@
from packaging import version as pack_version
from tensordict._C import unravel_keys
-
from tensordict.nn import (
+ CompositeDistribution,
InteractionType,
ProbabilisticTensorDictModule,
ProbabilisticTensorDictModule as ProbMod,
@@ -25,7 +25,6 @@
TensorDictSequential as Seq,
)
from torchrl.envs.utils import exploration_type, ExplorationType, set_exploration_type
-
from torchrl.modules.models import QMixer
_has_functorch = True
@@ -49,19 +48,12 @@
from mocking_classes import ContinuousActionConvMockEnv
# from torchrl.data.postprocs.utils import expand_as_right
-from tensordict import assert_allclose_td, TensorDict
+from tensordict import assert_allclose_td, TensorDict, TensorDictBase
from tensordict.nn import NormalParamExtractor, TensorDictModule
from tensordict.nn.utils import Buffer
from tensordict.utils import unravel_key
from torch import autograd, nn
-from torchrl.data import (
- BoundedTensorSpec,
- CompositeSpec,
- DiscreteTensorSpec,
- MultiOneHotDiscreteTensorSpec,
- OneHotDiscreteTensorSpec,
- UnboundedContinuousTensorSpec,
-)
+from torchrl.data import Bounded, Categorical, Composite, MultiOneHot, OneHot, Unbounded
from torchrl.data.postprocs.postprocs import MultiStep
from torchrl.envs.model_based.dreamer import DreamerEnv
from torchrl.envs.transforms import TensorDictPrimer, TransformedEnv
@@ -105,6 +97,7 @@
DreamerModelLoss,
DreamerValueLoss,
DTLoss,
+ GAILLoss,
IQLLoss,
KLPENPPOLoss,
OnlineDTLoss,
@@ -153,9 +146,15 @@
_split_and_pad_sequence,
)
+TORCH_VERSION = torch.__version__
# Capture all warnings
-pytestmark = pytest.mark.filterwarnings("error")
+pytestmark = [
+ pytest.mark.filterwarnings("error"),
+ pytest.mark.filterwarnings(
+ "ignore:The current behavior of MLP when not providing `num_cells` is that the number"
+ ),
+]
class _check_td_steady:
@@ -303,9 +302,9 @@ def _create_mock_actor(
):
# Actor
if action_spec_type == "one_hot":
- action_spec = OneHotDiscreteTensorSpec(action_dim)
+ action_spec = OneHot(action_dim)
elif action_spec_type == "categorical":
- action_spec = DiscreteTensorSpec(action_dim)
+ action_spec = Categorical(action_dim)
# elif action_spec_type == "nd_bounded":
# action_spec = BoundedTensorSpec(
# -torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
@@ -317,7 +316,7 @@ def _create_mock_actor(
if is_nn_module:
return module.to(device)
actor = QValueActor(
- spec=CompositeSpec(
+ spec=Composite(
{
"action": action_spec,
(
@@ -348,14 +347,12 @@ def _create_mock_distributional_actor(
# Actor
var_nums = None
if action_spec_type == "mult_one_hot":
- action_spec = MultiOneHotDiscreteTensorSpec(
- [action_dim // 2, action_dim // 2]
- )
+ action_spec = MultiOneHot([action_dim // 2, action_dim // 2])
var_nums = action_spec.nvec
elif action_spec_type == "one_hot":
- action_spec = OneHotDiscreteTensorSpec(action_dim)
+ action_spec = OneHot(action_dim)
elif action_spec_type == "categorical":
- action_spec = DiscreteTensorSpec(action_dim)
+ action_spec = Categorical(action_dim)
else:
raise ValueError(f"Wrong {action_spec_type}")
support = torch.linspace(vmin, vmax, atoms, dtype=torch.float)
@@ -366,7 +363,7 @@ def _create_mock_distributional_actor(
# if is_nn_module:
# return module
actor = DistributionalQValueActor(
- spec=CompositeSpec(
+ spec=Composite(
{
"action": action_spec,
action_value_key: None,
@@ -775,7 +772,7 @@ def test_dqn_notensordict(
):
n_obs = 3
n_action = 4
- action_spec = OneHotDiscreteTensorSpec(n_action)
+ action_spec = OneHot(n_action)
module = nn.Linear(n_obs, n_action) # a simple value model
actor = QValueActor(
spec=action_spec,
@@ -936,9 +933,9 @@ def _create_mock_actor(
):
# Actor
if action_spec_type == "one_hot":
- action_spec = OneHotDiscreteTensorSpec(action_dim)
+ action_spec = OneHot(action_dim)
elif action_spec_type == "categorical":
- action_spec = DiscreteTensorSpec(action_dim)
+ action_spec = Categorical(action_dim)
else:
raise ValueError(f"Wrong {action_spec_type}")
@@ -1385,7 +1382,7 @@ class TestDDPG(LossModuleTestBase):
def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"):
# Actor
- action_spec = BoundedTensorSpec(
+ action_spec = Bounded(
-torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
)
module = nn.Linear(obs_dim, action_dim)
@@ -2023,7 +2020,7 @@ def _create_mock_actor(
dropout=0.0,
):
# Actor
- action_spec = BoundedTensorSpec(
+ action_spec = Bounded(
-torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
)
module = nn.Sequential(
@@ -2375,7 +2372,7 @@ def test_td3_separate_losses(
loss_fn = TD3Loss(
actor,
value,
- action_spec=BoundedTensorSpec(shape=(n_act,), low=-1, high=1),
+ action_spec=Bounded(shape=(n_act,), low=-1, high=1),
loss_function="l2",
separate_losses=separate_losses,
)
@@ -2729,7 +2726,7 @@ def _create_mock_actor(
dropout=0.0,
):
# Actor
- action_spec = BoundedTensorSpec(
+ action_spec = Bounded(
-torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
)
module = nn.Sequential(
@@ -3088,7 +3085,7 @@ def test_td3bc_separate_losses(
loss_fn = TD3BCLoss(
actor,
value,
- action_spec=BoundedTensorSpec(shape=(n_act,), low=-1, high=1),
+ action_spec=Bounded(shape=(n_act,), low=-1, high=1),
loss_function="l2",
separate_losses=separate_losses,
)
@@ -3453,21 +3450,40 @@ def _create_mock_actor(
device="cpu",
observation_key="observation",
action_key="action",
+ composite_action_dist=False,
):
# Actor
- action_spec = BoundedTensorSpec(
+ action_spec = Bounded(
-torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
)
+ if composite_action_dist:
+ action_spec = Composite({action_key: {"action1": action_spec}})
net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor())
+ if composite_action_dist:
+ distribution_class = functools.partial(
+ CompositeDistribution,
+ distribution_map={
+ "action1": TanhNormal,
+ },
+ aggregate_probabilities=True,
+ )
+ module_out_keys = [
+ ("params", "action1", "loc"),
+ ("params", "action1", "scale"),
+ ]
+ actor_in_keys = ["params"]
+ else:
+ distribution_class = TanhNormal
+ module_out_keys = actor_in_keys = ["loc", "scale"]
module = TensorDictModule(
- net, in_keys=[observation_key], out_keys=["loc", "scale"]
+ net, in_keys=[observation_key], out_keys=module_out_keys
)
actor = ProbabilisticActor(
module=module,
- in_keys=["loc", "scale"],
- spec=action_spec,
- distribution_class=TanhNormal,
+ distribution_class=distribution_class,
+ in_keys=actor_in_keys,
out_keys=[action_key],
+ spec=action_spec,
)
return actor.to(device)
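Aside: the `composite_action_dist=True` branch above (and its many siblings below) all build the same actor shape: a `CompositeDistribution` over a nested action entry, with distribution parameters written under a `"params"` sub-tensordict. A condensed sketch of that construction (dimensions are illustrative):

import functools
import torch
from tensordict import TensorDict
from tensordict.nn import CompositeDistribution, NormalParamExtractor, TensorDictModule
from torch import nn
from torchrl.data import Bounded, Composite
from torchrl.modules import ProbabilisticActor, TanhNormal

obs_dim, action_dim = 3, 4
action_spec = Composite({"action": {"action1": Bounded(-1, 1, (action_dim,))}})
net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor())
module = TensorDictModule(
    net,
    in_keys=["observation"],
    out_keys=[("params", "action1", "loc"), ("params", "action1", "scale")],
)
actor = ProbabilisticActor(
    module=module,
    in_keys=["params"],
    out_keys=["action"],
    spec=action_spec,
    distribution_class=functools.partial(
        CompositeDistribution,
        distribution_map={"action1": TanhNormal},
        name_map={"action1": ("action", "action1")},
        aggregate_probabilities=True,
    ),
    return_log_prob=True,
)
td = actor(TensorDict({"observation": torch.randn(2, obs_dim)}, batch_size=[2]))
assert ("action", "action1") in td.keys(True)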
@@ -3487,6 +3503,8 @@ def __init__(self):
self.linear = nn.Linear(obs_dim + action_dim, 1)
def forward(self, obs, act):
+ if isinstance(act, TensorDictBase):
+ act = act.get("action1")
return self.linear(torch.cat([obs, act], -1))
module = ValueClass()
@@ -3515,8 +3533,26 @@ def _create_mock_value(
return value.to(device)
def _create_mock_common_layer_setup(
- self, n_obs=3, n_act=4, ncells=4, batch=2, n_hidden=2
+ self,
+ n_obs=3,
+ n_act=4,
+ ncells=4,
+ batch=2,
+ n_hidden=2,
+ composite_action_dist=False,
):
+ class QValueClass(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.linear1 = nn.Linear(n_hidden + n_act, n_hidden)
+ self.relu = nn.ReLU()
+ self.linear2 = nn.Linear(n_hidden, 1)
+
+ def forward(self, obs, act):
+ if isinstance(act, TensorDictBase):
+ act = act.get("action1")
+ return self.linear2(self.relu(self.linear1(torch.cat([obs, act], -1))))
+
common = MLP(
num_cells=ncells,
in_features=n_obs,
@@ -3529,17 +3565,13 @@ def _create_mock_common_layer_setup(
depth=1,
out_features=2 * n_act,
)
- qvalue = MLP(
- in_features=n_hidden + n_act,
- num_cells=ncells,
- depth=1,
- out_features=1,
- )
+ qvalue = QValueClass()
batch = [batch]
+ action = torch.randn(*batch, n_act)
td = TensorDict(
{
"obs": torch.randn(*batch, n_obs),
- "action": torch.randn(*batch, n_act),
+ "action": {"action1": action} if composite_action_dist else action,
"done": torch.zeros(*batch, 1, dtype=torch.bool),
"terminated": torch.zeros(*batch, 1, dtype=torch.bool),
"next": {
@@ -3552,14 +3584,30 @@ def _create_mock_common_layer_setup(
batch,
)
common = Mod(common, in_keys=["obs"], out_keys=["hidden"])
+ if composite_action_dist:
+ distribution_class = functools.partial(
+ CompositeDistribution,
+ distribution_map={
+ "action1": TanhNormal,
+ },
+ aggregate_probabilities=True,
+ )
+ module_out_keys = [
+ ("params", "action1", "loc"),
+ ("params", "action1", "scale"),
+ ]
+ actor_in_keys = ["params"]
+ else:
+ distribution_class = TanhNormal
+ module_out_keys = actor_in_keys = ["loc", "scale"]
actor = ProbSeq(
common,
Mod(actor_net, in_keys=["hidden"], out_keys=["param"]),
- Mod(NormalParamExtractor(), in_keys=["param"], out_keys=["loc", "scale"]),
+ Mod(NormalParamExtractor(), in_keys=["param"], out_keys=module_out_keys),
ProbMod(
- in_keys=["loc", "scale"],
+ in_keys=actor_in_keys,
out_keys=["action"],
- distribution_class=TanhNormal,
+ distribution_class=distribution_class,
),
)
qvalue_head = Mod(
@@ -3585,6 +3633,7 @@ def _create_mock_data_sac(
done_key="done",
terminated_key="terminated",
reward_key="reward",
+ composite_action_dist=False,
):
# create a tensordict
obs = torch.randn(batch, obs_dim, device=device)
@@ -3606,14 +3655,21 @@ def _create_mock_data_sac(
terminated_key: terminated,
reward_key: reward,
},
- action_key: action,
+ action_key: {"action1": action} if composite_action_dist else action,
},
device=device,
)
return td
def _create_seq_mock_data_sac(
- self, batch=8, T=4, obs_dim=3, action_dim=4, atoms=None, device="cpu"
+ self,
+ batch=8,
+ T=4,
+ obs_dim=3,
+ action_dim=4,
+ atoms=None,
+ device="cpu",
+ composite_action_dist=False,
):
# create a tensordict
total_obs = torch.randn(batch, T + 1, obs_dim, device=device)
@@ -3629,6 +3685,7 @@ def _create_seq_mock_data_sac(
done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device)
terminated = torch.zeros(batch, T, 1, dtype=torch.bool, device=device)
mask = torch.ones(batch, T, dtype=torch.bool, device=device)
+ action = action.masked_fill_(~mask.unsqueeze(-1), 0.0)
td = TensorDict(
batch_size=(batch, T),
source={
@@ -3640,7 +3697,7 @@ def _create_seq_mock_data_sac(
"reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0),
},
"collector": {"mask": mask},
- "action": action.masked_fill_(~mask.unsqueeze(-1), 0.0),
+ "action": {"action1": action} if composite_action_dist else action,
},
names=[None, "time"],
device=device,
@@ -3653,6 +3710,7 @@ def _create_seq_mock_data_sac(
@pytest.mark.parametrize("num_qvalue", [1, 2, 4, 8])
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("td_est", list(ValueEstimators) + [None])
+ @pytest.mark.parametrize("composite_action_dist", [True, False])
def test_sac(
self,
delay_value,
@@ -3662,14 +3720,19 @@ def test_sac(
device,
version,
td_est,
+ composite_action_dist,
):
if (delay_actor or delay_qvalue) and not delay_value:
pytest.skip("incompatible config")
torch.manual_seed(self.seed)
- td = self._create_mock_data_sac(device=device)
+ td = self._create_mock_data_sac(
+ device=device, composite_action_dist=composite_action_dist
+ )
- actor = self._create_mock_actor(device=device)
+ actor = self._create_mock_actor(
+ device=device, composite_action_dist=composite_action_dist
+ )
qvalue = self._create_mock_qvalue(device=device)
if version == 1:
value = self._create_mock_value(device=device)
@@ -3819,6 +3882,7 @@ def test_sac(
@pytest.mark.parametrize("delay_qvalue", (True, False))
@pytest.mark.parametrize("num_qvalue", [2])
@pytest.mark.parametrize("device", get_default_devices())
+ @pytest.mark.parametrize("composite_action_dist", [True, False])
def test_sac_state_dict(
self,
delay_value,
@@ -3827,13 +3891,16 @@ def test_sac_state_dict(
num_qvalue,
device,
version,
+ composite_action_dist,
):
if (delay_actor or delay_qvalue) and not delay_value:
pytest.skip("incompatible config")
torch.manual_seed(self.seed)
- actor = self._create_mock_actor(device=device)
+ actor = self._create_mock_actor(
+ device=device, composite_action_dist=composite_action_dist
+ )
qvalue = self._create_mock_qvalue(device=device)
if version == 1:
value = self._create_mock_value(device=device)
@@ -3869,20 +3936,24 @@ def test_sac_state_dict(
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("separate_losses", [False, True])
+ @pytest.mark.parametrize("composite_action_dist", [True, False])
def test_sac_separate_losses(
self,
device,
separate_losses,
version,
+ composite_action_dist,
n_act=4,
):
torch.manual_seed(self.seed)
- actor, qvalue, common, td = self._create_mock_common_layer_setup(n_act=n_act)
+ actor, qvalue, common, td = self._create_mock_common_layer_setup(
+ n_act=n_act, composite_action_dist=composite_action_dist
+ )
loss_fn = SACLoss(
actor_network=actor,
qvalue_network=qvalue,
- action_spec=UnboundedContinuousTensorSpec(shape=(n_act,)),
+ action_spec=Unbounded(shape=(n_act,)),
num_qvalue_nets=1,
separate_losses=separate_losses,
)
@@ -3963,6 +4034,7 @@ def test_sac_separate_losses(
@pytest.mark.parametrize("delay_qvalue", (True, False))
@pytest.mark.parametrize("num_qvalue", [1, 2, 4, 8])
@pytest.mark.parametrize("device", get_default_devices())
+ @pytest.mark.parametrize("composite_action_dist", [True, False])
def test_sac_batcher(
self,
n,
@@ -3972,13 +4044,18 @@ def test_sac_batcher(
num_qvalue,
device,
version,
+ composite_action_dist,
):
if (delay_actor or delay_qvalue) and not delay_value:
pytest.skip("incompatible config")
torch.manual_seed(self.seed)
- td = self._create_seq_mock_data_sac(device=device)
+ td = self._create_seq_mock_data_sac(
+ device=device, composite_action_dist=composite_action_dist
+ )
- actor = self._create_mock_actor(device=device)
+ actor = self._create_mock_actor(
+ device=device, composite_action_dist=composite_action_dist
+ )
qvalue = self._create_mock_qvalue(device=device)
if version == 1:
value = self._create_mock_value(device=device)
@@ -4129,10 +4206,11 @@ def test_sac_batcher(
@pytest.mark.parametrize(
"td_est", [ValueEstimators.TD1, ValueEstimators.TD0, ValueEstimators.TDLambda]
)
- def test_sac_tensordict_keys(self, td_est, version):
- td = self._create_mock_data_sac()
+ @pytest.mark.parametrize("composite_action_dist", [True, False])
+ def test_sac_tensordict_keys(self, td_est, version, composite_action_dist):
+ td = self._create_mock_data_sac(composite_action_dist=composite_action_dist)
- actor = self._create_mock_actor()
+ actor = self._create_mock_actor(composite_action_dist=composite_action_dist)
qvalue = self._create_mock_qvalue()
if version == 1:
value = self._create_mock_value()
@@ -4152,7 +4230,7 @@ def test_sac_tensordict_keys(self, td_est, version):
"value": "state_value",
"state_action_value": "state_action_value",
"action": "action",
- "log_prob": "_log_prob",
+ "log_prob": "sample_log_prob",
"reward": "reward",
"done": "done",
"terminated": "terminated",
@@ -4286,14 +4364,14 @@ def test_state_dict(self, version):
loss = SACLoss(
actor_network=policy,
qvalue_network=value,
- action_spec=UnboundedContinuousTensorSpec(shape=(2,)),
+ action_spec=Unbounded(shape=(2,)),
)
state = loss.state_dict()
loss = SACLoss(
actor_network=policy,
qvalue_network=value,
- action_spec=UnboundedContinuousTensorSpec(shape=(2,)),
+ action_spec=Unbounded(shape=(2,)),
)
loss.load_state_dict(state)
@@ -4301,7 +4379,7 @@ def test_state_dict(self, version):
loss = SACLoss(
actor_network=policy,
qvalue_network=value,
- action_spec=UnboundedContinuousTensorSpec(shape=(2,)),
+ action_spec=Unbounded(shape=(2,)),
)
loss.target_entropy
state = loss.state_dict()
@@ -4309,20 +4387,25 @@ def test_state_dict(self, version):
loss = SACLoss(
actor_network=policy,
qvalue_network=value,
- action_spec=UnboundedContinuousTensorSpec(shape=(2,)),
+ action_spec=Unbounded(shape=(2,)),
)
loss.load_state_dict(state)
@pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"])
- def test_sac_reduction(self, reduction, version):
+ @pytest.mark.parametrize("composite_action_dist", [True, False])
+ def test_sac_reduction(self, reduction, version, composite_action_dist):
torch.manual_seed(self.seed)
device = (
torch.device("cpu")
if torch.cuda.device_count() == 0
else torch.device("cuda")
)
- td = self._create_mock_data_sac(device=device)
- actor = self._create_mock_actor(device=device)
+ td = self._create_mock_data_sac(
+ device=device, composite_action_dist=composite_action_dist
+ )
+ actor = self._create_mock_actor(
+ device=device, composite_action_dist=composite_action_dist
+ )
qvalue = self._create_mock_qvalue(device=device)
if version == 1:
value = self._create_mock_value(device=device)
@@ -4367,7 +4450,7 @@ def _create_mock_actor(
action_key="action",
):
# Actor
- action_spec = OneHotDiscreteTensorSpec(action_dim)
+ action_spec = OneHot(action_dim)
net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor())
module = TensorDictModule(net, in_keys=[observation_key], out_keys=["logits"])
actor = ProbabilisticActor(
@@ -4953,7 +5036,7 @@ def _create_mock_actor(
action_key="action",
):
# Actor
- action_spec = BoundedTensorSpec(
+ action_spec = Bounded(
-torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
)
net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor())
@@ -5269,7 +5352,7 @@ def test_crossq_separate_losses(
loss_fn = CrossQLoss(
actor_network=actor,
qvalue_network=qvalue,
- action_spec=UnboundedContinuousTensorSpec(shape=(n_act,)),
+ action_spec=Unbounded(shape=(n_act,)),
num_qvalue_nets=1,
separate_losses=separate_losses,
)
@@ -5574,14 +5657,14 @@ def test_state_dict(
loss = CrossQLoss(
actor_network=policy,
qvalue_network=value,
- action_spec=UnboundedContinuousTensorSpec(shape=(2,)),
+ action_spec=Unbounded(shape=(2,)),
)
state = loss.state_dict()
loss = CrossQLoss(
actor_network=policy,
qvalue_network=value,
- action_spec=UnboundedContinuousTensorSpec(shape=(2,)),
+ action_spec=Unbounded(shape=(2,)),
)
loss.load_state_dict(state)
@@ -5589,7 +5672,7 @@ def test_state_dict(
loss = CrossQLoss(
actor_network=policy,
qvalue_network=value,
- action_spec=UnboundedContinuousTensorSpec(shape=(2,)),
+ action_spec=Unbounded(shape=(2,)),
)
loss.target_entropy
state = loss.state_dict()
@@ -5597,7 +5680,7 @@ def test_state_dict(
loss = CrossQLoss(
actor_network=policy,
qvalue_network=value,
- action_spec=UnboundedContinuousTensorSpec(shape=(2,)),
+ action_spec=Unbounded(shape=(2,)),
)
loss.load_state_dict(state)
@@ -5648,7 +5731,7 @@ def _create_mock_actor(
action_key="action",
):
# Actor
- action_spec = BoundedTensorSpec(
+ action_spec = Bounded(
-torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
)
net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor())
@@ -6593,7 +6676,7 @@ class TestCQL(LossModuleTestBase):
def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"):
# Actor
- action_spec = BoundedTensorSpec(
+ action_spec = Bounded(
-torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
)
net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor())
@@ -7156,9 +7239,9 @@ def _create_mock_actor(
):
# Actor
if action_spec_type == "one_hot":
- action_spec = OneHotDiscreteTensorSpec(action_dim)
+ action_spec = OneHot(action_dim)
elif action_spec_type == "categorical":
- action_spec = DiscreteTensorSpec(action_dim)
+ action_spec = Categorical(action_dim)
else:
raise ValueError(f"Wrong action spec type: {action_spec_type}")
@@ -7166,7 +7249,7 @@ def _create_mock_actor(
if is_nn_module:
return module.to(device)
actor = QValueActor(
- spec=CompositeSpec(
+ spec=Composite(
{
"action": action_spec,
(
@@ -7476,7 +7559,7 @@ def test_dcql_notensordict(
):
n_obs = 3
n_action = 4
- action_spec = OneHotDiscreteTensorSpec(n_action)
+ action_spec = OneHot(n_action)
module = nn.Linear(n_obs, n_action) # a simple value model
actor = QValueActor(
spec=action_spec,
@@ -7547,21 +7630,46 @@ def _create_mock_actor(
obs_dim=3,
action_dim=4,
device="cpu",
+ action_key="action",
observation_key="observation",
sample_log_prob_key="sample_log_prob",
+ composite_action_dist=False,
):
# Actor
- action_spec = BoundedTensorSpec(
+ action_spec = Bounded(
-torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
)
+ if composite_action_dist:
+ action_spec = Composite({action_key: {"action1": action_spec}})
net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor())
+ if composite_action_dist:
+ distribution_class = functools.partial(
+ CompositeDistribution,
+ distribution_map={
+ "action1": TanhNormal,
+ },
+ name_map={
+ "action1": (action_key, "action1"),
+ },
+ log_prob_key=sample_log_prob_key,
+ aggregate_probabilities=True,
+ )
+ module_out_keys = [
+ ("params", "action1", "loc"),
+ ("params", "action1", "scale"),
+ ]
+ actor_in_keys = ["params"]
+ else:
+ distribution_class = TanhNormal
+ module_out_keys = actor_in_keys = ["loc", "scale"]
module = TensorDictModule(
- net, in_keys=[observation_key], out_keys=["loc", "scale"]
+ net, in_keys=[observation_key], out_keys=module_out_keys
)
actor = ProbabilisticActor(
module=module,
- distribution_class=TanhNormal,
- in_keys=["loc", "scale"],
+ distribution_class=distribution_class,
+ in_keys=actor_in_keys,
+ out_keys=[action_key],
spec=action_spec,
return_log_prob=True,
log_prob_key=sample_log_prob_key,
@@ -7585,22 +7693,52 @@ def _create_mock_value(
)
return value.to(device)
- def _create_mock_actor_value(self, batch=2, obs_dim=3, action_dim=4, device="cpu"):
+ def _create_mock_actor_value(
+ self,
+ batch=2,
+ obs_dim=3,
+ action_dim=4,
+ device="cpu",
+ composite_action_dist=False,
+ sample_log_prob_key="sample_log_prob",
+ ):
# Actor
- action_spec = BoundedTensorSpec(
+ action_spec = Bounded(
-torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
)
+ if composite_action_dist:
+ action_spec = Composite({"action": {"action1": action_spec}})
base_layer = nn.Linear(obs_dim, 5)
net = nn.Sequential(
base_layer, nn.Linear(5, 2 * action_dim), NormalParamExtractor()
)
+ if composite_action_dist:
+ distribution_class = functools.partial(
+ CompositeDistribution,
+ distribution_map={
+ "action1": TanhNormal,
+ },
+ name_map={
+ "action1": ("action", "action1"),
+ },
+ log_prob_key=sample_log_prob_key,
+ aggregate_probabilities=True,
+ )
+ module_out_keys = [
+ ("params", "action1", "loc"),
+ ("params", "action1", "scale"),
+ ]
+ actor_in_keys = ["params"]
+ else:
+ distribution_class = TanhNormal
+ module_out_keys = actor_in_keys = ["loc", "scale"]
module = TensorDictModule(
- net, in_keys=["observation"], out_keys=["loc", "scale"]
+ net, in_keys=["observation"], out_keys=module_out_keys
)
actor = ProbabilisticActor(
module=module,
- distribution_class=TanhNormal,
- in_keys=["loc", "scale"],
+ distribution_class=distribution_class,
+ in_keys=actor_in_keys,
spec=action_spec,
return_log_prob=True,
)
@@ -7612,22 +7750,50 @@ def _create_mock_actor_value(self, batch=2, obs_dim=3, action_dim=4, device="cpu
return actor.to(device), value.to(device)
def _create_mock_actor_value_shared(
- self, batch=2, obs_dim=3, action_dim=4, device="cpu"
+ self,
+ batch=2,
+ obs_dim=3,
+ action_dim=4,
+ device="cpu",
+ composite_action_dist=False,
+ sample_log_prob_key="sample_log_prob",
):
# Actor
- action_spec = BoundedTensorSpec(
+ action_spec = Bounded(
-torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
)
+ if composite_action_dist:
+ action_spec = Composite({"action": {"action1": action_spec}})
base_layer = nn.Linear(obs_dim, 5)
common = TensorDictModule(
base_layer, in_keys=["observation"], out_keys=["hidden"]
)
net = nn.Sequential(nn.Linear(5, 2 * action_dim), NormalParamExtractor())
- module = TensorDictModule(net, in_keys=["hidden"], out_keys=["loc", "scale"])
+ if composite_action_dist:
+ distribution_class = functools.partial(
+ CompositeDistribution,
+ distribution_map={
+ "action1": TanhNormal,
+ },
+ name_map={
+ "action1": ("action", "action1"),
+ },
+ log_prob_key=sample_log_prob_key,
+ aggregate_probabilities=True,
+ )
+ module_out_keys = [
+ ("params", "action1", "loc"),
+ ("params", "action1", "scale"),
+ ]
+ actor_in_keys = ["params"]
+ else:
+ distribution_class = TanhNormal
+ module_out_keys = actor_in_keys = ["loc", "scale"]
+ module = TensorDictModule(net, in_keys=["hidden"], out_keys=module_out_keys)
actor_head = ProbabilisticActor(
module=module,
- distribution_class=TanhNormal,
- in_keys=["loc", "scale"],
+ distribution_class=distribution_class,
+ in_keys=actor_in_keys,
spec=action_spec,
return_log_prob=True,
)
@@ -7657,6 +7823,7 @@ def _create_mock_data_ppo(
done_key="done",
terminated_key="terminated",
sample_log_prob_key="sample_log_prob",
+ composite_action_dist=False,
):
# create a tensordict
obs = torch.randn(batch, obs_dim, device=device)
@@ -7682,13 +7849,17 @@ def _create_mock_data_ppo(
terminated_key: terminated,
reward_key: reward,
},
- action_key: action,
+ action_key: {"action1": action} if composite_action_dist else action,
sample_log_prob_key: torch.randn_like(action[..., 1]) / 10,
- loc_key: loc,
- scale_key: scale,
},
device=device,
)
+ if composite_action_dist:
+ td[("params", "action1", loc_key)] = loc
+ td[("params", "action1", scale_key)] = scale
+ else:
+ td[loc_key] = loc
+ td[scale_key] = scale
return td
def _create_seq_mock_data_ppo(
@@ -7701,6 +7872,7 @@ def _create_seq_mock_data_ppo(
device="cpu",
sample_log_prob_key="sample_log_prob",
action_key="action",
+ composite_action_dist=False,
):
# create a tensordict
total_obs = torch.randn(batch, T + 1, obs_dim, device=device)
@@ -7716,8 +7888,11 @@ def _create_seq_mock_data_ppo(
done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device)
terminated = torch.zeros(batch, T, 1, dtype=torch.bool, device=device)
mask = torch.ones(batch, T, dtype=torch.bool, device=device)
+ action = action.masked_fill_(~mask.unsqueeze(-1), 0.0)
params_mean = torch.randn_like(action) / 10
params_scale = torch.rand_like(action) / 10
+ loc = params_mean.masked_fill_(~mask.unsqueeze(-1), 0.0)
+ scale = params_scale.masked_fill_(~mask.unsqueeze(-1), 0.0)
td = TensorDict(
batch_size=(batch, T),
source={
@@ -7729,16 +7904,21 @@ def _create_seq_mock_data_ppo(
"reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0),
},
"collector": {"mask": mask},
- action_key: action.masked_fill_(~mask.unsqueeze(-1), 0.0),
+ action_key: {"action1": action} if composite_action_dist else action,
sample_log_prob_key: (
torch.randn_like(action[..., 1]) / 10
).masked_fill_(~mask, 0.0),
- "loc": params_mean.masked_fill_(~mask.unsqueeze(-1), 0.0),
- "scale": params_scale.masked_fill_(~mask.unsqueeze(-1), 0.0),
},
device=device,
names=[None, "time"],
)
+ if composite_action_dist:
+ td[("params", "action1", "loc")] = loc
+ td[("params", "action1", "scale")] = scale
+ else:
+ td["loc"] = loc
+ td["scale"] = scale
+
return td
@pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss))
@@ -7747,6 +7927,7 @@ def _create_seq_mock_data_ppo(
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("td_est", list(ValueEstimators) + [None])
@pytest.mark.parametrize("functional", [True, False])
+ @pytest.mark.parametrize("composite_action_dist", [True, False])
def test_ppo(
self,
loss_class,
@@ -7755,11 +7936,16 @@ def test_ppo(
advantage,
td_est,
functional,
+ composite_action_dist,
):
torch.manual_seed(self.seed)
- td = self._create_seq_mock_data_ppo(device=device)
+ td = self._create_seq_mock_data_ppo(
+ device=device, composite_action_dist=composite_action_dist
+ )
- actor = self._create_mock_actor(device=device)
+ actor = self._create_mock_actor(
+ device=device, composite_action_dist=composite_action_dist
+ )
value = self._create_mock_value(device=device)
if advantage == "gae":
advantage = GAE(
@@ -7799,7 +7985,10 @@ def test_ppo(
loss = loss_fn(td)
if isinstance(loss_fn, KLPENPPOLoss):
- kl = loss.pop("kl")
+ if composite_action_dist:
+ kl = loss.pop("kl_approx")
+ else:
+ kl = loss.pop("kl")
assert (kl != 0).any()
loss_critic = loss["loss_critic"]
@@ -7836,10 +8025,15 @@ def test_ppo(
@pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss))
@pytest.mark.parametrize("gradient_mode", (True,))
@pytest.mark.parametrize("device", get_default_devices())
- def test_ppo_state_dict(self, loss_class, device, gradient_mode):
+ @pytest.mark.parametrize("composite_action_dist", [True, False])
+ def test_ppo_state_dict(
+ self, loss_class, device, gradient_mode, composite_action_dist
+ ):
torch.manual_seed(self.seed)
- actor = self._create_mock_actor(device=device)
+ actor = self._create_mock_actor(
+ device=device, composite_action_dist=composite_action_dist
+ )
value = self._create_mock_value(device=device)
loss_fn = loss_class(actor, value, loss_critic_type="l2")
sd = loss_fn.state_dict()
@@ -7849,11 +8043,16 @@ def test_ppo_state_dict(self, loss_class, device, gradient_mode):
@pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss))
@pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None))
@pytest.mark.parametrize("device", get_default_devices())
- def test_ppo_shared(self, loss_class, device, advantage):
+ @pytest.mark.parametrize("composite_action_dist", [True, False])
+ def test_ppo_shared(self, loss_class, device, advantage, composite_action_dist):
torch.manual_seed(self.seed)
- td = self._create_seq_mock_data_ppo(device=device)
+ td = self._create_seq_mock_data_ppo(
+ device=device, composite_action_dist=composite_action_dist
+ )
- actor, value = self._create_mock_actor_value(device=device)
+ actor, value = self._create_mock_actor_value(
+ device=device, composite_action_dist=composite_action_dist
+ )
if advantage == "gae":
advantage = GAE(
gamma=0.9,
@@ -7935,18 +8134,24 @@ def test_ppo_shared(self, loss_class, device, advantage):
)
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("separate_losses", [True, False])
+ @pytest.mark.parametrize("composite_action_dist", [True, False])
def test_ppo_shared_seq(
self,
loss_class,
device,
advantage,
separate_losses,
+ composite_action_dist,
):
"""Tests PPO with shared module with and without passing twice across the common module."""
torch.manual_seed(self.seed)
- td = self._create_seq_mock_data_ppo(device=device)
+ td = self._create_seq_mock_data_ppo(
+ device=device, composite_action_dist=composite_action_dist
+ )
- model, actor, value = self._create_mock_actor_value_shared(device=device)
+ model, actor, value = self._create_mock_actor_value_shared(
+ device=device, composite_action_dist=composite_action_dist
+ )
value2 = value[-1] # prune the common module
if advantage == "gae":
advantage = GAE(
@@ -8004,8 +8209,20 @@ def test_ppo_shared_seq(
grad2 = TensorDict(dict(model.named_parameters()), []).apply(
lambda x: x.grad.clone()
)
- assert_allclose_td(loss, loss2)
- assert_allclose_td(grad, grad2)
+ if composite_action_dist and loss_class is KLPENPPOLoss:
+ # KL computation for composite dist is based on randomly
+ # sampled data, thus will not be the same.
+            # Similarly, the objective loss depends on the KL, so it will
+ # not be the same either.
+ # Finally, gradients will be different too.
+ loss.pop("kl", None)
+ loss2.pop("kl", None)
+ loss.pop("loss_objective", None)
+ loss2.pop("loss_objective", None)
+ assert_allclose_td(loss, loss2)
+ else:
+ assert_allclose_td(loss, loss2)
+ assert_allclose_td(grad, grad2)
model.zero_grad()
@pytest.mark.skipif(
@@ -8015,11 +8232,18 @@ def test_ppo_shared_seq(
@pytest.mark.parametrize("gradient_mode", (True, False))
@pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None))
@pytest.mark.parametrize("device", get_default_devices())
- def test_ppo_diff(self, loss_class, device, gradient_mode, advantage):
+ @pytest.mark.parametrize("composite_action_dist", [True, False])
+ def test_ppo_diff(
+ self, loss_class, device, gradient_mode, advantage, composite_action_dist
+ ):
torch.manual_seed(self.seed)
- td = self._create_seq_mock_data_ppo(device=device)
+ td = self._create_seq_mock_data_ppo(
+ device=device, composite_action_dist=composite_action_dist
+ )
- actor = self._create_mock_actor(device=device)
+ actor = self._create_mock_actor(
+ device=device, composite_action_dist=composite_action_dist
+ )
value = self._create_mock_value(device=device)
if advantage == "gae":
advantage = GAE(
@@ -8108,8 +8332,9 @@ def zero_param(p):
ValueEstimators.TDLambda,
],
)
- def test_ppo_tensordict_keys(self, loss_class, td_est):
- actor = self._create_mock_actor()
+ @pytest.mark.parametrize("composite_action_dist", [True, False])
+ def test_ppo_tensordict_keys(self, loss_class, td_est, composite_action_dist):
+ actor = self._create_mock_actor(composite_action_dist=composite_action_dist)
value = self._create_mock_value()
loss_fn = loss_class(actor, value, loss_critic_type="l2")
@@ -8148,7 +8373,10 @@ def test_ppo_tensordict_keys(self, loss_class, td_est):
@pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss))
@pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None))
@pytest.mark.parametrize("td_est", list(ValueEstimators) + [None])
- def test_ppo_tensordict_keys_run(self, loss_class, advantage, td_est):
+ @pytest.mark.parametrize("composite_action_dist", [True, False])
+ def test_ppo_tensordict_keys_run(
+ self, loss_class, advantage, td_est, composite_action_dist
+ ):
"""Test PPO loss module with non-default tensordict keys."""
torch.manual_seed(self.seed)
gradient_mode = True
@@ -8163,9 +8391,12 @@ def test_ppo_tensordict_keys_run(self, loss_class, advantage, td_est):
td = self._create_seq_mock_data_ppo(
sample_log_prob_key=tensor_keys["sample_log_prob"],
action_key=tensor_keys["action"],
+ composite_action_dist=composite_action_dist,
)
actor = self._create_mock_actor(
- sample_log_prob_key=tensor_keys["sample_log_prob"]
+ sample_log_prob_key=tensor_keys["sample_log_prob"],
+ composite_action_dist=composite_action_dist,
+ action_key=tensor_keys["action"],
)
value = self._create_mock_value(out_keys=[tensor_keys["value"]])
@@ -8256,6 +8487,12 @@ def test_ppo_tensordict_keys_run(self, loss_class, advantage, td_est):
@pytest.mark.parametrize("reward_key", ["reward", "reward2"])
@pytest.mark.parametrize("done_key", ["done", "done2"])
@pytest.mark.parametrize("terminated_key", ["terminated", "terminated2"])
+ @pytest.mark.parametrize(
+ "composite_action_dist",
+ [
+ False,
+ ],
+ )
def test_ppo_notensordict(
self,
loss_class,
@@ -8265,6 +8502,7 @@ def test_ppo_notensordict(
reward_key,
done_key,
terminated_key,
+ composite_action_dist,
):
torch.manual_seed(self.seed)
td = self._create_mock_data_ppo(
@@ -8274,10 +8512,14 @@ def test_ppo_notensordict(
reward_key=reward_key,
done_key=done_key,
terminated_key=terminated_key,
+ composite_action_dist=composite_action_dist,
)
actor = self._create_mock_actor(
- observation_key=observation_key, sample_log_prob_key=sample_log_prob_key
+ observation_key=observation_key,
+ sample_log_prob_key=sample_log_prob_key,
+ composite_action_dist=composite_action_dist,
+ action_key=action_key,
)
value = self._create_mock_value(observation_key=observation_key)
@@ -8300,7 +8542,9 @@ def test_ppo_notensordict(
f"next_{observation_key}": td.get(("next", observation_key)),
}
if loss_class is KLPENPPOLoss:
- kwargs.update({"loc": td.get("loc"), "scale": td.get("scale")})
+ loc_key = "params" if composite_action_dist else "loc"
+ scale_key = "params" if composite_action_dist else "scale"
+ kwargs.update({loc_key: td.get(loc_key), scale_key: td.get(scale_key)})
td = TensorDict(kwargs, td.batch_size, names=["time"]).unflatten_keys("_")
@@ -8313,6 +8557,7 @@ def test_ppo_notensordict(
loss_val = loss(**kwargs)
torch.manual_seed(self.seed)
if beta is not None:
loss.beta = beta.clone()
loss_val_td = loss(td)
@@ -8340,15 +8585,20 @@ def test_ppo_notensordict(
@pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss))
@pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"])
- def test_ppo_reduction(self, reduction, loss_class):
+ @pytest.mark.parametrize("composite_action_dist", [True, False])
+ def test_ppo_reduction(self, reduction, loss_class, composite_action_dist):
torch.manual_seed(self.seed)
device = (
torch.device("cpu")
if torch.cuda.device_count() == 0
else torch.device("cuda")
)
- td = self._create_seq_mock_data_ppo(device=device)
- actor = self._create_mock_actor(device=device)
+ td = self._create_seq_mock_data_ppo(
+ device=device, composite_action_dist=composite_action_dist
+ )
+ actor = self._create_mock_actor(
+ device=device, composite_action_dist=composite_action_dist
+ )
value = self._create_mock_value(device=device)
advantage = GAE(
gamma=0.9,
@@ -8376,10 +8626,17 @@ def test_ppo_reduction(self, reduction, loss_class):
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss))
@pytest.mark.parametrize("clip_value", [True, False, None, 0.5, torch.tensor(0.5)])
- def test_ppo_value_clipping(self, clip_value, loss_class, device):
+ @pytest.mark.parametrize("composite_action_dist", [True, False])
+ def test_ppo_value_clipping(
+ self, clip_value, loss_class, device, composite_action_dist
+ ):
torch.manual_seed(self.seed)
- td = self._create_seq_mock_data_ppo(device=device)
- actor = self._create_mock_actor(device=device)
+ td = self._create_seq_mock_data_ppo(
+ device=device, composite_action_dist=composite_action_dist
+ )
+ actor = self._create_mock_actor(
+ device=device, composite_action_dist=composite_action_dist
+ )
value = self._create_mock_value(device=device)
advantage = GAE(
gamma=0.9,
@@ -8438,22 +8695,47 @@ def _create_mock_actor(
obs_dim=3,
action_dim=4,
device="cpu",
+ action_key="action",
observation_key="observation",
sample_log_prob_key="sample_log_prob",
+ composite_action_dist=False,
):
# Actor
- action_spec = BoundedTensorSpec(
+ action_spec = Bounded(
-torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
)
+ if composite_action_dist:
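+ # nest the leaf spec under (action_key, "action1") to match the composite dist's output tree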
+ action_spec = Composite({action_key: {"action1": action_spec}})
net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor())
+ if composite_action_dist:
+ distribution_class = functools.partial(
+ CompositeDistribution,
+ distribution_map={
+ "action1": TanhNormal,
+ },
+ name_map={
+ "action1": (action_key, "action1"),
+ },
+ log_prob_key=sample_log_prob_key,
+ aggregate_probabilities=True,
+ )
+ module_out_keys = [
+ ("params", "action1", "loc"),
+ ("params", "action1", "scale"),
+ ]
+ actor_in_keys = ["params"]
+ else:
+ distribution_class = TanhNormal
+ module_out_keys = actor_in_keys = ["loc", "scale"]
module = TensorDictModule(
- net, in_keys=[observation_key], out_keys=["loc", "scale"]
+ net, in_keys=[observation_key], out_keys=module_out_keys
)
actor = ProbabilisticActor(
module=module,
- in_keys=["loc", "scale"],
+ in_keys=actor_in_keys,
+ out_keys=[action_key],
spec=action_spec,
- distribution_class=TanhNormal,
+ distribution_class=distribution_class,
return_log_prob=True,
log_prob_key=sample_log_prob_key,
)
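For reference, a minimal sketch (not part of the diff) of what the composite branch above builds, assuming tensordict's CompositeDistribution behaves as configured in this helper:

    import torch
    from tensordict import TensorDict
    from tensordict.nn import CompositeDistribution
    from torchrl.modules import TanhNormal

    # one (loc, scale) pair per sub-action, nested the way module_out_keys writes them
    params = TensorDict(
        {"action1": {"loc": torch.zeros(3, 4), "scale": torch.ones(3, 4)}},
        batch_size=[3],
    )
    dist = CompositeDistribution(
        params,
        distribution_map={"action1": TanhNormal},
        name_map={"action1": ("action", "action1")},  # where samples are written
        log_prob_key="sample_log_prob",
        aggregate_probabilities=True,  # sum per-action log-probs into one tensor
    )
    sample = dist.sample()      # TensorDict carrying ("action", "action1")
    lp = dist.log_prob(sample)  # aggregated log-prob of shape [3]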
@@ -8477,7 +8759,15 @@ def _create_mock_value(
return value.to(device)
def _create_mock_common_layer_setup(
- self, n_obs=3, n_act=4, ncells=4, batch=2, n_hidden=2, T=10
+ self,
+ n_obs=3,
+ n_act=4,
+ ncells=4,
+ batch=2,
+ n_hidden=2,
+ T=10,
+ composite_action_dist=False,
+ sample_log_prob_key="sample_log_prob",
):
common_net = MLP(
num_cells=ncells,
@@ -8498,10 +8788,11 @@ def _create_mock_common_layer_setup(
out_features=1,
)
batch = [batch, T]
+ action = torch.randn(*batch, n_act)
td = TensorDict(
{
"obs": torch.randn(*batch, n_obs),
- "action": torch.randn(*batch, n_act),
+ "action": {"action1": action} if composite_action_dist else action,
"sample_log_prob": torch.randn(*batch),
"done": torch.zeros(*batch, 1, dtype=torch.bool),
"terminated": torch.zeros(*batch, 1, dtype=torch.bool),
@@ -8516,14 +8807,36 @@ def _create_mock_common_layer_setup(
names=[None, "time"],
)
common = Mod(common_net, in_keys=["obs"], out_keys=["hidden"])
+
+ if composite_action_dist:
+ distribution_class = functools.partial(
+ CompositeDistribution,
+ distribution_map={
+ "action1": TanhNormal,
+ },
+ name_map={
+ "action1": ("action", "action1"),
+ },
+ log_prob_key=sample_log_prob_key,
+ aggregate_probabilities=True,
+ )
+ module_out_keys = [
+ ("params", "action1", "loc"),
+ ("params", "action1", "scale"),
+ ]
+ actor_in_keys = ["params"]
+ else:
+ distribution_class = TanhNormal
+ module_out_keys = actor_in_keys = ["loc", "scale"]
+
actor = ProbSeq(
common,
Mod(actor_net, in_keys=["hidden"], out_keys=["param"]),
- Mod(NormalParamExtractor(), in_keys=["param"], out_keys=["loc", "scale"]),
+ Mod(NormalParamExtractor(), in_keys=["param"], out_keys=module_out_keys),
ProbMod(
- in_keys=["loc", "scale"],
+ in_keys=actor_in_keys,
out_keys=["action"],
- distribution_class=TanhNormal,
+ distribution_class=distribution_class,
),
)
critic = Seq(
@@ -8547,6 +8860,7 @@ def _create_seq_mock_data_a2c(
done_key="done",
terminated_key="terminated",
sample_log_prob_key="sample_log_prob",
+ composite_action_dist=False,
):
# create a tensordict
total_obs = torch.randn(batch, T + 1, obs_dim, device=device)
@@ -8562,8 +8876,11 @@ def _create_seq_mock_data_a2c(
done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device)
terminated = torch.zeros(batch, T, 1, dtype=torch.bool, device=device)
mask = ~torch.zeros(batch, T, dtype=torch.bool, device=device)
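+ # mimic collector padding: out-of-mask entries in actions and dist params are zero-filled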
+ action = action.masked_fill_(~mask.unsqueeze(-1), 0.0)
params_mean = torch.randn_like(action) / 10
params_scale = torch.rand_like(action) / 10
+ loc = params_mean.masked_fill_(~mask.unsqueeze(-1), 0.0)
+ scale = params_scale.masked_fill_(~mask.unsqueeze(-1), 0.0)
td = TensorDict(
batch_size=(batch, T),
source={
@@ -8575,17 +8892,21 @@ def _create_seq_mock_data_a2c(
reward_key: reward.masked_fill_(~mask.unsqueeze(-1), 0.0),
},
"collector": {"mask": mask},
- action_key: action.masked_fill_(~mask.unsqueeze(-1), 0.0),
+ action_key: {"action1": action} if composite_action_dist else action,
sample_log_prob_key: torch.randn_like(action[..., 1]).masked_fill_(
~mask, 0.0
)
/ 10,
- "loc": params_mean.masked_fill_(~mask.unsqueeze(-1), 0.0),
- "scale": params_scale.masked_fill_(~mask.unsqueeze(-1), 0.0),
},
device=device,
names=[None, "time"],
)
+ if composite_action_dist:
+ td[("params", "action1", "loc")] = loc
+ td[("params", "action1", "scale")] = scale
+ else:
+ td["loc"] = loc
+ td["scale"] = scale
return td
@pytest.mark.parametrize("gradient_mode", (True, False))
@@ -8593,11 +8914,24 @@ def _create_seq_mock_data_a2c(
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("td_est", list(ValueEstimators) + [None])
@pytest.mark.parametrize("functional", (True, False))
- def test_a2c(self, device, gradient_mode, advantage, td_est, functional):
+ @pytest.mark.parametrize("composite_action_dist", [True, False])
+ def test_a2c(
+ self,
+ device,
+ gradient_mode,
+ advantage,
+ td_est,
+ functional,
+ composite_action_dist,
+ ):
torch.manual_seed(self.seed)
- td = self._create_seq_mock_data_a2c(device=device)
+ td = self._create_seq_mock_data_a2c(
+ device=device, composite_action_dist=composite_action_dist
+ )
- actor = self._create_mock_actor(device=device)
+ actor = self._create_mock_actor(
+ device=device, composite_action_dist=composite_action_dist
+ )
value = self._create_mock_value(device=device)
if advantage == "gae":
advantage = GAE(
@@ -8630,14 +8964,24 @@ def test_a2c(self, device, gradient_mode, advantage, td_est, functional):
functional=functional,
)
+ def set_requires_grad(tensor, requires_grad):
+ tensor.requires_grad = requires_grad
+ return tensor
+
# Check error is raised when actions require grads
- td["action"].requires_grad = True
+ if composite_action_dist:
+ td["action"].apply_(lambda x: set_requires_grad(x, True))
+ else:
+ td["action"].requires_grad = True
with pytest.raises(
RuntimeError,
- match="tensordict stored action require grad.",
+ match="tensordict stored action requires grad.",
):
_ = loss_fn._log_probs(td)
- td["action"].requires_grad = False
+ if composite_action_dist:
+ td["action"].apply_(lambda x: set_requires_grad(x, False))
+ else:
+ td["action"].requires_grad = False
td = td.exclude(loss_fn.tensor_keys.value_target)
if advantage is not None:
@@ -8678,9 +9022,12 @@ def test_a2c(self, device, gradient_mode, advantage, td_est, functional):
@pytest.mark.parametrize("gradient_mode", (True, False))
@pytest.mark.parametrize("device", get_default_devices())
- def test_a2c_state_dict(self, device, gradient_mode):
+ @pytest.mark.parametrize("composite_action_dist", [True, False])
+ def test_a2c_state_dict(self, device, gradient_mode, composite_action_dist):
torch.manual_seed(self.seed)
- actor = self._create_mock_actor(device=device)
+ actor = self._create_mock_actor(
+ device=device, composite_action_dist=composite_action_dist
+ )
value = self._create_mock_value(device=device)
loss_fn = A2CLoss(actor, value, loss_critic_type="l2")
sd = loss_fn.state_dict()
@@ -8688,23 +9035,36 @@ def test_a2c_state_dict(self, device, gradient_mode):
loss_fn2.load_state_dict(sd)
@pytest.mark.parametrize("separate_losses", [False, True])
- def test_a2c_separate_losses(self, separate_losses):
+ @pytest.mark.parametrize("composite_action_dist", [True, False])
+ def test_a2c_separate_losses(self, separate_losses, composite_action_dist):
torch.manual_seed(self.seed)
- actor, critic, common, td = self._create_mock_common_layer_setup()
+ actor, critic, common, td = self._create_mock_common_layer_setup(
+ composite_action_dist=composite_action_dist
+ )
loss_fn = A2CLoss(
actor_network=actor,
critic_network=critic,
separate_losses=separate_losses,
)
+ def set_requires_grad(tensor, requires_grad):
+ tensor.requires_grad = requires_grad
+ return tensor
+
# Check error is raised when actions require grads
- td["action"].requires_grad = True
+ if composite_action_dist:
+ td["action"].apply_(lambda x: set_requires_grad(x, True))
+ else:
+ td["action"].requires_grad = True
with pytest.raises(
RuntimeError,
- match="tensordict stored action require grad.",
+ match="tensordict stored action requires grad.",
):
_ = loss_fn._log_probs(td)
- td["action"].requires_grad = False
+ if composite_action_dist:
+ td["action"].apply_(lambda x: set_requires_grad(x, False))
+ else:
+ td["action"].requires_grad = False
td = td.exclude(loss_fn.tensor_keys.value_target)
loss = loss_fn(td)
@@ -8748,13 +9108,18 @@ def test_a2c_separate_losses(self, separate_losses):
@pytest.mark.parametrize("gradient_mode", (True, False))
@pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None))
@pytest.mark.parametrize("device", get_default_devices())
- def test_a2c_diff(self, device, gradient_mode, advantage):
+ @pytest.mark.parametrize("composite_action_dist", [True, False])
+ def test_a2c_diff(self, device, gradient_mode, advantage, composite_action_dist):
if pack_version.parse(torch.__version__) > pack_version.parse("1.14"):
raise pytest.skip("make_functional_with_buffers needs to be changed")
torch.manual_seed(self.seed)
- td = self._create_seq_mock_data_a2c(device=device)
+ td = self._create_seq_mock_data_a2c(
+ device=device, composite_action_dist=composite_action_dist
+ )
- actor = self._create_mock_actor(device=device)
+ actor = self._create_mock_actor(
+ device=device, composite_action_dist=composite_action_dist
+ )
value = self._create_mock_value(device=device)
if advantage == "gae":
advantage = GAE(
@@ -8824,8 +9189,9 @@ def test_a2c_diff(self, device, gradient_mode, advantage):
ValueEstimators.TDLambda,
],
)
- def test_a2c_tensordict_keys(self, td_est):
- actor = self._create_mock_actor()
+ @pytest.mark.parametrize("composite_action_dist", [True, False])
+ def test_a2c_tensordict_keys(self, td_est, composite_action_dist):
+ actor = self._create_mock_actor(composite_action_dist=composite_action_dist)
value = self._create_mock_value()
loss_fn = A2CLoss(actor, value, loss_critic_type="l2")
@@ -8870,7 +9236,10 @@ def test_a2c_tensordict_keys(self, td_est):
)
@pytest.mark.parametrize("advantage", ("gae", "vtrace", None))
@pytest.mark.parametrize("device", get_default_devices())
- def test_a2c_tensordict_keys_run(self, device, advantage, td_est):
+ @pytest.mark.parametrize("composite_action_dist", [True, False])
+ def test_a2c_tensordict_keys_run(
+ self, device, advantage, td_est, composite_action_dist
+ ):
"""Test A2C loss module with non-default tensordict keys."""
torch.manual_seed(self.seed)
gradient_mode = True
@@ -8890,10 +9259,14 @@ def test_a2c_tensordict_keys_run(self, device, advantage, td_est):
done_key=done_key,
terminated_key=terminated_key,
sample_log_prob_key=sample_log_prob_key,
+ composite_action_dist=composite_action_dist,
)
actor = self._create_mock_actor(
- device=device, sample_log_prob_key=sample_log_prob_key
+ device=device,
+ sample_log_prob_key=sample_log_prob_key,
+ composite_action_dist=composite_action_dist,
+ action_key=action_key,
)
value = self._create_mock_value(device=device, out_keys=[value_key])
if advantage == "gae":
@@ -8972,12 +9345,26 @@ def test_a2c_tensordict_keys_run(self, device, advantage, td_est):
@pytest.mark.parametrize("reward_key", ["reward", "reward2"])
@pytest.mark.parametrize("done_key", ["done", "done2"])
@pytest.mark.parametrize("terminated_key", ["terminated", "terminated2"])
+ @pytest.mark.parametrize(
+ "composite_action_dist",
+ [
+ False,
+ ],
+ )
def test_a2c_notensordict(
- self, action_key, observation_key, reward_key, done_key, terminated_key
+ self,
+ action_key,
+ observation_key,
+ reward_key,
+ done_key,
+ terminated_key,
+ composite_action_dist,
):
torch.manual_seed(self.seed)
- actor = self._create_mock_actor(observation_key=observation_key)
+ actor = self._create_mock_actor(
+ observation_key=observation_key, composite_action_dist=composite_action_dist
+ )
value = self._create_mock_value(observation_key=observation_key)
td = self._create_seq_mock_data_a2c(
action_key=action_key,
@@ -8985,6 +9372,7 @@ def test_a2c_notensordict(
reward_key=reward_key,
done_key=done_key,
terminated_key=terminated_key,
+ composite_action_dist=composite_action_dist,
)
loss = A2CLoss(actor, value)
@@ -9029,15 +9417,20 @@ def test_a2c_notensordict(
assert loss_critic == loss_val_td["loss_critic"]
@pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"])
- def test_a2c_reduction(self, reduction):
+ @pytest.mark.parametrize("composite_action_dist", [True, False])
+ def test_a2c_reduction(self, reduction, composite_action_dist):
torch.manual_seed(self.seed)
device = (
torch.device("cpu")
if torch.cuda.device_count() == 0
else torch.device("cuda")
)
- td = self._create_seq_mock_data_a2c(device=device)
- actor = self._create_mock_actor(device=device)
+ td = self._create_seq_mock_data_a2c(
+ device=device, composite_action_dist=composite_action_dist
+ )
+ actor = self._create_mock_actor(
+ device=device, composite_action_dist=composite_action_dist
+ )
value = self._create_mock_value(device=device)
advantage = GAE(
gamma=0.9,
@@ -9064,10 +9457,15 @@ def test_a2c_reduction(self, reduction):
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("clip_value", [True, None, 0.5, torch.tensor(0.5)])
- def test_a2c_value_clipping(self, clip_value, device):
+ @pytest.mark.parametrize("composite_action_dist", [True, False])
+ def test_a2c_value_clipping(self, clip_value, device, composite_action_dist):
torch.manual_seed(self.seed)
- td = self._create_seq_mock_data_a2c(device=device)
- actor = self._create_mock_actor(device=device)
+ td = self._create_seq_mock_data_a2c(
+ device=device, composite_action_dist=composite_action_dist
+ )
+ actor = self._create_mock_actor(
+ device=device, composite_action_dist=composite_action_dist
+ )
value = self._create_mock_value(device=device)
advantage = GAE(
gamma=0.9,
@@ -9151,7 +9549,7 @@ def test_reinforce_value_net(
distribution_class=TanhNormal,
return_log_prob=True,
in_keys=["loc", "scale"],
- spec=UnboundedContinuousTensorSpec(n_act),
+ spec=Unbounded(n_act),
)
if advantage == "gae":
advantage = GAE(
@@ -9261,7 +9659,7 @@ def test_reinforce_tensordict_keys(self, td_est):
distribution_class=TanhNormal,
return_log_prob=True,
in_keys=["loc", "scale"],
- spec=UnboundedContinuousTensorSpec(n_act),
+ spec=Unbounded(n_act),
)
loss_fn = ReinforceLoss(
@@ -9455,7 +9853,7 @@ def test_reinforce_notensordict(
distribution_class=TanhNormal,
return_log_prob=True,
in_keys=["loc", "scale"],
- spec=UnboundedContinuousTensorSpec(n_act),
+ spec=Unbounded(n_act),
)
loss = ReinforceLoss(actor_network=actor_net, critic_network=value_net)
loss.set_keys(
@@ -9631,8 +10029,8 @@ def _create_world_model_model(self, rssm_hidden_dim, state_dim, mlp_num_units=13
ContinuousActionConvMockEnv(pixel_shape=[3, *self.img_size])
)
default_dict = {
- "state": UnboundedContinuousTensorSpec(state_dim),
- "belief": UnboundedContinuousTensorSpec(rssm_hidden_dim),
+ "state": Unbounded(state_dim),
+ "belief": Unbounded(rssm_hidden_dim),
}
mock_env.append_transform(
TensorDictPrimer(random=False, default_value=0, **default_dict)
@@ -9708,8 +10106,8 @@ def _create_mb_env(self, rssm_hidden_dim, state_dim, mlp_num_units=13):
ContinuousActionConvMockEnv(pixel_shape=[3, *self.img_size])
)
default_dict = {
- "state": UnboundedContinuousTensorSpec(state_dim),
- "belief": UnboundedContinuousTensorSpec(rssm_hidden_dim),
+ "state": Unbounded(state_dim),
+ "belief": Unbounded(rssm_hidden_dim),
}
mock_env.append_transform(
TensorDictPrimer(random=False, default_value=0, **default_dict)
@@ -9759,8 +10157,8 @@ def _create_actor_model(self, rssm_hidden_dim, state_dim, mlp_num_units=13):
ContinuousActionConvMockEnv(pixel_shape=[3, *self.img_size])
)
default_dict = {
- "state": UnboundedContinuousTensorSpec(state_dim),
- "belief": UnboundedContinuousTensorSpec(rssm_hidden_dim),
+ "state": Unbounded(state_dim),
+ "belief": Unbounded(rssm_hidden_dim),
}
mock_env.append_transform(
TensorDictPrimer(random=False, default_value=0, **default_dict)
@@ -10049,7 +10447,7 @@ class TestOnlineDT(LossModuleTestBase):
def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"):
# Actor
- action_spec = BoundedTensorSpec(
+ action_spec = Bounded(
-torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
)
net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor())
@@ -10281,7 +10679,7 @@ class TestDT(LossModuleTestBase):
def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"):
# Actor
- action_spec = BoundedTensorSpec(
+ action_spec = Bounded(
-torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
)
net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor())
@@ -10459,6 +10857,227 @@ def test_dt_reduction(self, reduction):
assert loss["loss"].shape == torch.Size([])
+class TestGAIL(LossModuleTestBase):
+ seed = 0
+
+ def _create_mock_discriminator(
+ self, batch=2, obs_dim=3, action_dim=4, device="cpu"
+ ):
+ # Discriminator
+ body = TensorDictModule(
+ MLP(
+ in_features=obs_dim + action_dim,
+ out_features=32,
+ depth=1,
+ num_cells=32,
+ activation_class=torch.nn.ReLU,
+ activate_last_layer=True,
+ ),
+ in_keys=["observation", "action"],
+ out_keys="hidden",
+ )
+ head = TensorDictModule(
+ MLP(
+ in_features=32,
+ out_features=1,
+ depth=0,
+ num_cells=32,
+ activation_class=torch.nn.Sigmoid,
+ activate_last_layer=True,
+ ),
+ in_keys="hidden",
+ out_keys="d_logits",
+ )
+ discriminator = TensorDictSequential(body, head)
+
+ return discriminator.to(device)
+
+ def _create_mock_data_gail(self, batch=2, obs_dim=3, action_dim=4, device="cpu"):
+ # create a tensordict
+ obs = torch.randn(batch, obs_dim, device=device)
+ action = torch.randn(batch, action_dim, device=device).clamp(-1, 1)
+ td = TensorDict(
+ batch_size=(batch,),
+ source={
+ "observation": obs,
+ "action": action,
+ "collector_action": action,
+ "collector_observation": obs,
+ },
+ device=device,
+ )
+ return td
+
+ def _create_seq_mock_data_gail(
+ self, batch=2, T=4, obs_dim=3, action_dim=4, device="cpu"
+ ):
+ # create a tensordict
+ obs = torch.randn(batch, T, obs_dim, device=device)
+ action = torch.randn(batch, T, action_dim, device=device).clamp(-1, 1)
+
+ td = TensorDict(
+ batch_size=(batch, T),
+ source={
+ "observation": obs,
+ "action": action,
+ "collector_action": action,
+ "collector_observation": obs,
+ },
+ device=device,
+ )
+ return td
+
+ def test_gail_tensordict_keys(self):
+ discriminator = self._create_mock_discriminator()
+ loss_fn = GAILLoss(discriminator)
+
+ default_keys = {
+ "expert_action": "action",
+ "expert_observation": "observation",
+ "collector_action": "collector_action",
+ "collector_observation": "collector_observation",
+ "discriminator_pred": "d_logits",
+ }
+
+ self.tensordict_keys_test(
+ loss_fn,
+ default_keys=default_keys,
+ )
+
+ @pytest.mark.parametrize("device", get_default_devices())
+ @pytest.mark.parametrize("use_grad_penalty", [True, False])
+ @pytest.mark.parametrize("gp_lambda", [0.1, 1.0])
+ def test_gail_notensordict(self, device, use_grad_penalty, gp_lambda):
+ torch.manual_seed(self.seed)
+ discriminator = self._create_mock_discriminator(device=device)
+ loss_fn = GAILLoss(
+ discriminator, use_grad_penalty=use_grad_penalty, gp_lambda=gp_lambda
+ )
+
+ tensordict = self._create_mock_data_gail(device=device)
+
+ in_keys = self._flatten_in_keys(loss_fn.in_keys)
+ kwargs = dict(tensordict.flatten_keys("_").select(*in_keys))
+
+ loss_val_td = loss_fn(tensordict)
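+ # with use_grad_penalty=True, keyword dispatch returns the loss and the penalty term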
+ if use_grad_penalty:
+ loss_val, _ = loss_fn(**kwargs)
+ else:
+ loss_val = loss_fn(**kwargs)
+
+ torch.testing.assert_close(loss_val_td.get("loss"), loss_val)
+ # test select
+ loss_fn.select_out_keys("loss")
+ if torch.__version__ >= "2.0.0":
+ loss_discriminator = loss_fn(**kwargs)
+ else:
+ with pytest.raises(
+ RuntimeError,
+ match="You are likely using tensordict.nn.dispatch with keyword arguments",
+ ):
+ loss_discriminator = loss_fn(**kwargs)
+ return
+ assert loss_discriminator == loss_val_td["loss"]
+
+ @pytest.mark.parametrize("device", get_available_devices())
+ @pytest.mark.parametrize("use_grad_penalty", [True, False])
+ @pytest.mark.parametrize("gp_lambda", [0.1, 1.0])
+ def test_gail(self, device, use_grad_penalty, gp_lambda):
+ torch.manual_seed(self.seed)
+ td = self._create_mock_data_gail(device=device)
+
+ discriminator = self._create_mock_discriminator(device=device)
+
+ loss_fn = GAILLoss(
+ discriminator, use_grad_penalty=use_grad_penalty, gp_lambda=gp_lambda
+ )
+ loss = loss_fn(td)
+ loss_transformer = loss["loss"]
+ loss_transformer.backward(retain_graph=True)
+ named_parameters = loss_fn.named_parameters()
+
+ for name, p in named_parameters:
+ if p.grad is not None and p.grad.norm() > 0.0:
+ assert "discriminator" in name
+ if p.grad is None:
+ assert "discriminator" not in name
+ loss_fn.zero_grad()
+
+ sum([loss_transformer]).backward()
+ named_parameters = list(loss_fn.named_parameters())
+ named_buffers = list(loss_fn.named_buffers())
+
+ assert len({p for n, p in named_parameters}) == len(list(named_parameters))
+ assert len({p for n, p in named_buffers}) == len(list(named_buffers))
+
+ for name, p in named_parameters:
+ assert p.grad.norm() > 0.0, f"parameter {name} has a null gradient"
+
+ @pytest.mark.parametrize("device", get_available_devices())
+ def test_gail_state_dict(self, device):
+ torch.manual_seed(self.seed)
+
+ discriminator = self._create_mock_discriminator(device=device)
+
+ loss_fn = GAILLoss(discriminator)
+ sd = loss_fn.state_dict()
+ loss_fn2 = GAILLoss(discriminator)
+ loss_fn2.load_state_dict(sd)
+
+ @pytest.mark.parametrize("device", get_available_devices())
+ @pytest.mark.parametrize("use_grad_penalty", [True, False])
+ @pytest.mark.parametrize("gp_lambda", [0.1, 1.0])
+ def test_seq_gail(self, device, use_grad_penalty, gp_lambda):
+ torch.manual_seed(self.seed)
+ td = self._create_seq_mock_data_gail(device=device)
+
+ discriminator = self._create_mock_discriminator(device=device)
+
+ loss_fn = GAILLoss(
+ discriminator, use_grad_penalty=use_grad_penalty, gp_lambda=gp_lambda
+ )
+ loss = loss_fn(td)
+ loss_transformer = loss["loss"]
+ loss_transformer.backward(retain_graph=True)
+ named_parameters = loss_fn.named_parameters()
+
+ for name, p in named_parameters:
+ if p.grad is not None and p.grad.norm() > 0.0:
+ assert "discriminator" in name
+ if p.grad is None:
+ assert "discriminator" not in name
+ loss_fn.zero_grad()
+
+ sum([loss_transformer]).backward()
+ named_parameters = list(loss_fn.named_parameters())
+ named_buffers = list(loss_fn.named_buffers())
+
+ assert len({p for n, p in named_parameters}) == len(list(named_parameters))
+ assert len({p for n, p in named_buffers}) == len(list(named_buffers))
+
+ for name, p in named_parameters:
+ assert p.grad.norm() > 0.0, f"parameter {name} has a null gradient"
+
+ @pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"])
+ @pytest.mark.parametrize("use_grad_penalty", [True, False])
+ @pytest.mark.parametrize("gp_lambda", [0.1, 1.0])
+ def test_gail_reduction(self, reduction, use_grad_penalty, gp_lambda):
+ torch.manual_seed(self.seed)
+ device = (
+ torch.device("cpu")
+ if torch.cuda.device_count() == 0
+ else torch.device("cuda")
+ )
+ td = self._create_mock_data_gail(device=device)
+ discriminator = self._create_mock_discriminator(device=device)
+ loss_fn = GAILLoss(discriminator, reduction=reduction)
+ loss = loss_fn(td)
+ if reduction == "none":
+ assert loss["loss"].shape == (td["observation"].shape[0], 1)
+ else:
+ assert loss["loss"].shape == torch.Size([])
+
+
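For context, a minimal sketch of the objective these tests exercise, assuming the conventional GAIL discriminator loss (see torchrl.objectives.GAILLoss for the actual implementation): the discriminator is pushed towards 1 on expert pairs and towards 0 on collector pairs, with an optional gradient penalty weighted by gp_lambda.

    import torch

    def gail_discriminator_loss(d_expert, d_collector, grad_penalty=None, gp_lambda=1.0):
        # d_expert / d_collector: sigmoid outputs of the discriminator, in (0, 1)
        loss = -(torch.log(d_expert + 1e-8) + torch.log(1.0 - d_collector + 1e-8)).mean()
        if grad_penalty is not None:
            # regularize the discriminator's gradient norm (WGAN-GP style)
            loss = loss + gp_lambda * grad_penalty
        return loss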
@pytest.mark.skipif(
not _has_functorch, reason=f"functorch not installed: {FUNCTORCH_ERR}"
)
@@ -10474,7 +11093,7 @@ def _create_mock_actor(
observation_key="observation",
):
# Actor
- action_spec = BoundedTensorSpec(
+ action_spec = Bounded(
-torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
)
net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor())
@@ -11285,7 +11904,7 @@ def _create_mock_actor(
action_key="action",
):
# Actor
- action_spec = OneHotDiscreteTensorSpec(action_dim)
+ action_spec = OneHot(action_dim)
net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor())
module = TensorDictModule(net, in_keys=[observation_key], out_keys=["logits"])
actor = ProbabilisticActor(
@@ -14750,7 +15369,7 @@ def _forward_value_estimator_keys(self, **kwargs) -> None:
class MyLoss3(MyLoss2):
@dataclass
class _AcceptedKeys:
- some_key = "some_value"
+ some_key: str = "some_value"
loss_module = MyLoss3()
assert loss_module.tensor_keys.some_key == "some_value"
@@ -15112,6 +15731,68 @@ def __init__(self):
assert p.device == dest
+@pytest.mark.skipif(TORCH_VERSION < "2.5", reason="requires torch>=2.5")
+def test_exploration_compile():
+ m = ProbabilisticTensorDictModule(
+ in_keys=["loc", "scale"],
+ out_keys=["sample"],
+ distribution_class=torch.distributions.Normal,
+ )
+
+ it = exploration_type()
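+ # the ambient exploration type must be restored after each compiled call below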
+
+ @torch.compile(fullgraph=True)
+ def func(t):
+ with set_exploration_type(ExplorationType.RANDOM):
+ t0 = m(t.clone())
+ t1 = m(t.clone())
+ return t0, t1
+
+ t = TensorDict(loc=torch.randn(3), scale=torch.rand(3))
+ t0, t1 = func(t)
+ assert (t0["sample"] != t1["sample"]).any()
+ assert it == exploration_type()
+
+ @torch.compile(fullgraph=True)
+ def func(t):
+ with set_exploration_type(ExplorationType.MEAN):
+ t0 = m(t.clone())
+ t1 = m(t.clone())
+ return t0, t1
+
+ t = TensorDict(loc=torch.randn(3), scale=torch.rand(3))
+ t0, t1 = func(t)
+ assert (t0["sample"] == t1["sample"]).all()
+ assert it == exploration_type()
+
+ @torch.compile(fullgraph=True)
+ @set_exploration_type(ExplorationType.RANDOM)
+ def func(t):
+ t0 = m(t.clone())
+ t1 = m(t.clone())
+ return t0, t1
+
+ t = TensorDict(loc=torch.randn(3), scale=torch.rand(3))
+ t0, t1 = func(t)
+ assert (t0["sample"] != t1["sample"]).any()
+ assert it == exploration_type()
+
+ @torch.compile(fullgraph=True)
+ @set_exploration_type(ExplorationType.MEAN)
+ def func(t):
+ t0 = m(t.clone())
+ t1 = m(t.clone())
+ return t0, t1
+
+ t = TensorDict(loc=torch.randn(3), scale=torch.rand(3))
+ t0, t1 = func(t)
+ assert (t0["sample"] == t1["sample"]).all()
+ assert it == exploration_type()
+
+
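The same semantics can be sanity-checked without torch.compile; a minimal sketch, assuming ExplorationType.MEAN makes the module return the distribution mean while RANDOM draws a sample:

    import torch
    from tensordict import TensorDict
    from tensordict.nn import ProbabilisticTensorDictModule
    from torchrl.envs.utils import ExplorationType, set_exploration_type

    m = ProbabilisticTensorDictModule(
        in_keys=["loc", "scale"],
        out_keys=["sample"],
        distribution_class=torch.distributions.Normal,
    )
    with set_exploration_type(ExplorationType.MEAN):
        out = m(TensorDict(loc=torch.zeros(3), scale=torch.ones(3)))
    assert (out["sample"] == 0).all()  # the mean of Normal(0, 1) is exactly 0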
def test_loss_exploration():
class DummyLoss(LossModule):
def forward(self, td, mode):
diff --git a/test/test_distributed.py b/test/test_distributed.py
index 40b4f5eae44..fd369f64962 100644
--- a/test/test_distributed.py
+++ b/test/test_distributed.py
@@ -405,7 +405,7 @@ def _test_distributed_collector_updatepolicy(
MultiaSyncDataCollector,
],
)
- @pytest.mark.parametrize("update_interval", [1_000_000, 1])
+ @pytest.mark.parametrize("update_interval", [1])
def test_distributed_collector_updatepolicy(self, collector_class, update_interval):
"""Testing various collector classes to be used in nodes."""
queue = mp.Queue(1)
diff --git a/test/test_distributions.py b/test/test_distributions.py
index 53bfda343a2..8a5b651531e 100644
--- a/test/test_distributions.py
+++ b/test/test_distributions.py
@@ -190,8 +190,8 @@ def test_truncnormal(self, min, max, vecs, upscale, shape, device):
d = TruncatedNormal(
*vecs,
upscale=upscale,
- min=min,
- max=max,
+ low=min,
+ high=max,
)
assert d.device == device
for _ in range(100):
@@ -218,7 +218,7 @@ def test_truncnormal_against_scipy(self):
high = 2
low = -1
log_pi_x = TruncatedNormal(
- mu, sigma, min=low, max=high, tanh_loc=False
+ mu, sigma, low=low, high=high, tanh_loc=False
).log_prob(x)
pi_x = torch.exp(log_pi_x)
log_pi_x.backward(torch.ones_like(log_pi_x))
@@ -264,8 +264,8 @@ def test_truncnormal_mode(self, min, max, vecs, upscale, shape, device):
d = TruncatedNormal(
*vecs,
upscale=upscale,
- min=min,
- max=max,
+ low=min,
+ high=max,
)
assert d.mode is not None
assert d.entropy() is not None
diff --git a/test/test_env.py b/test/test_env.py
index dee03c06e7d..9602c596f22 100644
--- a/test/test_env.py
+++ b/test/test_env.py
@@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.
import argparse
+import contextlib
import functools
import gc
import os.path
@@ -66,12 +67,7 @@
from torch import nn
from torchrl.collectors import MultiSyncDataCollector, SyncDataCollector
-from torchrl.data.tensor_specs import (
- CompositeSpec,
- DiscreteTensorSpec,
- NonTensorSpec,
- UnboundedContinuousTensorSpec,
-)
+from torchrl.data.tensor_specs import Categorical, Composite, NonTensor, Unbounded
from torchrl.envs import (
CatFrames,
CatTensors,
@@ -495,6 +491,26 @@ def test_mb_env_batch_lock(self, device, seed=0):
class TestParallel:
+ def test_create_env_fn(self, maybe_fork_ParallelEnv):
+ def make_env():
+ return GymEnv(PENDULUM_VERSIONED())
+
+ with pytest.raises(
+ RuntimeError, match="len\\(create_env_fn\\) and num_workers mismatch"
+ ):
+ maybe_fork_ParallelEnv(4, [make_env, make_env])
+
+ def test_create_env_kwargs(self, maybe_fork_ParallelEnv):
+ def make_env():
+ return GymEnv(PENDULUM_VERSIONED())
+
+ with pytest.raises(
+ RuntimeError, match="len\\(create_env_kwargs\\) and num_workers mismatch"
+ ):
+ maybe_fork_ParallelEnv(
+ 4, make_env, create_env_kwargs=[{"seed": 0}, {"seed": 1}]
+ )
+
@pytest.mark.skipif(
not torch.cuda.device_count(), reason="No cuda device detected."
)
@@ -1125,6 +1141,25 @@ def env_fn2(seed):
env1.close()
env2.close()
+ @pytest.mark.parametrize("parallel", [True, False])
+ def test_parallel_env_update_kwargs(self, parallel, maybe_fork_ParallelEnv):
+ def make_env(seed=None):
+ env = DiscreteActionConvMockEnv()
+ if seed is not None:
+ env.set_seed(seed)
+ return env
+
+ _class = maybe_fork_ParallelEnv if parallel else SerialEnv
+ env = _class(
+ num_workers=2,
+ create_env_fn=make_env,
+ create_env_kwargs=[{"seed": 0}, {"seed": 1}],
+ )
+ with pytest.raises(
+ RuntimeError, match="len\\(kwargs\\) and num_workers mismatch"
+ ):
+ env.update_kwargs([{"seed": 42}])
+
@pytest.mark.parametrize("batch_size", [(32, 5), (4,), (1,), ()])
@pytest.mark.parametrize("n_workers", [2, 1])
def test_parallel_env_reset_flag(
@@ -1908,18 +1943,12 @@ def test_info_dict_reader(self, device, seed=0):
env.set_info_dict_reader(
default_info_dict_reader(
["x_position"],
- spec=CompositeSpec(
- x_position=UnboundedContinuousTensorSpec(
- dtype=torch.float64, shape=()
- )
- ),
+ spec=Composite(x_position=Unbounded(dtype=torch.float64, shape=())),
)
)
assert "x_position" in env.observation_spec.keys()
- assert isinstance(
- env.observation_spec["x_position"], UnboundedContinuousTensorSpec
- )
+ assert isinstance(env.observation_spec["x_position"], Unbounded)
tensordict = env.reset()
tensordict = env.rand_step(tensordict)
@@ -1932,13 +1961,13 @@ def test_info_dict_reader(self, device, seed=0):
)
for spec in (
- {"x_position": UnboundedContinuousTensorSpec((), dtype=torch.float64)},
+ {"x_position": Unbounded((), dtype=torch.float64)},
# None,
- CompositeSpec(
- x_position=UnboundedContinuousTensorSpec((), dtype=torch.float64),
+ Composite(
+ x_position=Unbounded((), dtype=torch.float64),
shape=[],
),
- [UnboundedContinuousTensorSpec((), dtype=torch.float64)],
+ [Unbounded((), dtype=torch.float64)],
):
env2 = GymWrapper(gym.make("HalfCheetah-v4"))
env2.set_info_dict_reader(
@@ -2079,7 +2108,7 @@ def main_penv(j, q=None):
],
)
spec = env_p.action_spec
- policy = TestConcurrentEnvs.Policy(CompositeSpec(action=spec.to(device)))
+ policy = TestConcurrentEnvs.Policy(Composite(action=spec.to(device)))
N = 10
r_p = []
r_s = []
@@ -2113,7 +2142,7 @@ def main_collector(j, q=None):
lambda i=i: CountingEnv(i, device=device) for i in range(j, j + n_workers)
]
spec = make_envs[0]().action_spec
- policy = TestConcurrentEnvs.Policy(CompositeSpec(action=spec))
+ policy = TestConcurrentEnvs.Policy(Composite(action=spec))
collector = MultiSyncDataCollector(
make_envs,
policy,
@@ -2225,7 +2254,7 @@ def test_nested_env(self, envclass):
else:
raise NotImplementedError
reset = env.reset()
- assert not isinstance(env.reward_spec, CompositeSpec)
+ assert not isinstance(env.reward_spec, Composite)
for done_key in env.done_keys:
assert (
env.full_done_spec[done_key]
@@ -2496,8 +2525,8 @@ def test_mocking_envs(envclass):
class TestTerminatedOrTruncated:
@pytest.mark.parametrize("done_key", ["done", "terminated", "truncated"])
def test_root_prevail(self, done_key):
- _spec = DiscreteTensorSpec(2, shape=(), dtype=torch.bool)
- spec = CompositeSpec({done_key: _spec, ("agent", done_key): _spec})
+ _spec = Categorical(2, shape=(), dtype=torch.bool)
+ spec = Composite({done_key: _spec, ("agent", done_key): _spec})
data = TensorDict({done_key: [False], ("agent", done_key): [True, False]}, [])
assert not _terminated_or_truncated(data)
assert not _terminated_or_truncated(data, full_done_spec=spec)
@@ -2560,8 +2589,8 @@ def test_terminated_or_truncated_nospec(self):
def test_terminated_or_truncated_spec(self):
done_shape = (2, 1)
nested_done_shape = (2, 3, 1)
- spec = CompositeSpec(
- done=DiscreteTensorSpec(2, shape=done_shape, dtype=torch.bool),
+ spec = Composite(
+ done=Categorical(2, shape=done_shape, dtype=torch.bool),
shape=[
2,
],
@@ -2578,12 +2607,12 @@ def test_terminated_or_truncated_spec(self):
)
assert data.get("_reset", None) is None
- spec = CompositeSpec(
+ spec = Composite(
{
- ("agent", "done"): DiscreteTensorSpec(
+ ("agent", "done"): Categorical(
2, shape=nested_done_shape, dtype=torch.bool
),
- ("nested", "done"): DiscreteTensorSpec(
+ ("nested", "done"): Categorical(
2, shape=nested_done_shape, dtype=torch.bool
),
},
@@ -2618,11 +2647,11 @@ def test_terminated_or_truncated_spec(self):
assert data["agent", "_reset"].shape == nested_done_shape
assert data["nested", "_reset"].shape == nested_done_shape
- spec = CompositeSpec(
+ spec = Composite(
{
- "truncated": DiscreteTensorSpec(2, shape=done_shape, dtype=torch.bool),
- "terminated": DiscreteTensorSpec(2, shape=done_shape, dtype=torch.bool),
- ("nested", "terminated"): DiscreteTensorSpec(
+ "truncated": Categorical(2, shape=done_shape, dtype=torch.bool),
+ "terminated": Categorical(2, shape=done_shape, dtype=torch.bool),
+ ("nested", "terminated"): Categorical(
2, shape=nested_done_shape, dtype=torch.bool
),
},
@@ -2774,15 +2803,15 @@ def test_backprop(device, maybe_fork_ParallelEnv, share_individual_td):
class DifferentiableEnv(EnvBase):
def __init__(self, device):
super().__init__(device=device)
- self.observation_spec = CompositeSpec(
- observation=UnboundedContinuousTensorSpec(3, device=device),
+ self.observation_spec = Composite(
+ observation=Unbounded(3, device=device),
device=device,
)
- self.action_spec = CompositeSpec(
- action=UnboundedContinuousTensorSpec(3, device=device), device=device
+ self.action_spec = Composite(
+ action=Unbounded(3, device=device), device=device
)
- self.reward_spec = CompositeSpec(
- reward=UnboundedContinuousTensorSpec(1, device=device), device=device
+ self.reward_spec = Composite(
+ reward=Unbounded(1, device=device), device=device
)
self.seed = 0
@@ -3283,7 +3312,7 @@ def _reset(
return tensordict_reset
def transform_observation_spec(self, observation_spec):
- observation_spec["string"] = NonTensorSpec(())
+ observation_spec["string"] = NonTensor(())
return observation_spec
@pytest.mark.parametrize("batched", ["serial", "parallel"])
@@ -3351,6 +3380,98 @@ def test_pendulum_env(self):
assert r.shape == torch.Size((5, 10))
+@pytest.mark.parametrize("device", [None, *get_default_devices()])
+@pytest.mark.parametrize("env_device", [None, *get_default_devices()])
+class TestPartialSteps:
+ @pytest.mark.parametrize("use_buffers", [False, True])
+ def test_parallel_partial_steps(
+ self, use_buffers, device, env_device, maybe_fork_ParallelEnv
+ ):
+ with torch.device(device) if device is not None else contextlib.nullcontext():
+ penv = maybe_fork_ParallelEnv(
+ 4,
+ lambda: CountingEnv(max_steps=10, start_val=2, device=env_device),
+ use_buffers=use_buffers,
+ device=device,
+ )
+ td = penv.reset()
+ psteps = torch.zeros(4, dtype=torch.bool)
+ psteps[[1, 3]] = True
+ td.set("_step", psteps)
+
+ td.set("action", penv.action_spec.one())
+ td = penv.step(td)
+ assert (td[0].get("next") == 0).all()
+ assert (td[1].get("next") != 0).any()
+ assert (td[2].get("next") == 0).all()
+ assert (td[3].get("next") != 0).any()
+
+ @pytest.mark.parametrize("use_buffers", [False, True])
+ def test_parallel_partial_step_and_maybe_reset(
+ self, use_buffers, device, env_device, maybe_fork_ParallelEnv
+ ):
+ with torch.device(device) if device is not None else contextlib.nullcontext():
+ penv = maybe_fork_ParallelEnv(
+ 4,
+ lambda: CountingEnv(max_steps=10, start_val=2, device=env_device),
+ use_buffers=use_buffers,
+ device=device,
+ )
+ td = penv.reset()
+ psteps = torch.zeros(4, dtype=torch.bool)
+ psteps[[1, 3]] = True
+ td.set("_step", psteps)
+
+ td.set("action", penv.action_spec.one())
+ td, tdreset = penv.step_and_maybe_reset(td)
+ assert (td[0].get("next") == 0).all()
+ assert (td[1].get("next") != 0).any()
+ assert (td[2].get("next") == 0).all()
+ assert (td[3].get("next") != 0).any()
+
+ @pytest.mark.parametrize("use_buffers", [False, True])
+ def test_serial_partial_steps(self, use_buffers, device, env_device):
+ with torch.device(device) if device is not None else contextlib.nullcontext():
+ penv = SerialEnv(
+ 4,
+ lambda: CountingEnv(max_steps=10, start_val=2, device=env_device),
+ use_buffers=use_buffers,
+ device=device,
+ )
+ td = penv.reset()
+ psteps = torch.zeros(4, dtype=torch.bool)
+ psteps[[1, 3]] = True
+ td.set("_step", psteps)
+
+ td.set("action", penv.action_spec.one())
+ td = penv.step(td)
+ assert (td[0].get("next") == 0).all()
+ assert (td[1].get("next") != 0).any()
+ assert (td[2].get("next") == 0).all()
+ assert (td[3].get("next") != 0).any()
+
+ @pytest.mark.parametrize("use_buffers", [False, True])
+ def test_serial_partial_step_and_maybe_reset(self, use_buffers, device, env_device):
+ with torch.device(device) if device is not None else contextlib.nullcontext():
+ penv = SerialEnv(
+ 4,
+ lambda: CountingEnv(max_steps=10, start_val=2, device=env_device),
+ use_buffers=use_buffers,
+ device=device,
+ )
+ td = penv.reset()
+ psteps = torch.zeros(4, dtype=torch.bool)
+ psteps[[1, 3]] = True
+ td.set("_step", psteps)
+
+ td.set("action", penv.action_spec.one())
+ td = penv.step(td)
+ assert (td[0].get("next") == 0).all()
+ assert (td[1].get("next") != 0).any()
+ assert (td[2].get("next") == 0).all()
+ assert (td[3].get("next") != 0).any()
+
+
if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
diff --git a/test/test_exploration.py b/test/test_exploration.py
index 83ee4bc4220..f6a3ab7041b 100644
--- a/test/test_exploration.py
+++ b/test/test_exploration.py
@@ -21,12 +21,7 @@
from torchrl._utils import _replace_last
from torchrl.collectors import SyncDataCollector
-from torchrl.data import (
- BoundedTensorSpec,
- CompositeSpec,
- DiscreteTensorSpec,
- OneHotDiscreteTensorSpec,
-)
+from torchrl.data import Bounded, Categorical, Composite, OneHot
from torchrl.envs import SerialEnv
from torchrl.envs.transforms.transforms import gSDENoise, InitTracker, TransformedEnv
from torchrl.envs.utils import set_exploration_type
@@ -36,7 +31,7 @@
NormalParamExtractor,
TanhNormal,
)
-from torchrl.modules.models.exploration import LazygSDEModule
+from torchrl.modules.models.exploration import ConsistentDropoutModule, LazygSDEModule
from torchrl.modules.tensordict_module.actors import (
Actor,
ProbabilisticActor,
@@ -59,7 +54,7 @@ class TestEGreedy:
@set_exploration_type(InteractionType.RANDOM)
def test_egreedy(self, eps_init, module):
torch.manual_seed(0)
- spec = BoundedTensorSpec(1, 1, torch.Size([4]))
+ spec = Bounded(1, 1, torch.Size([4]))
module = torch.nn.Linear(4, 4, bias=False)
policy = Actor(spec=spec, module=module)
@@ -91,9 +86,9 @@ def test_egreedy_masked(self, module, eps_init, spec_class):
batch_size = (3, 4, 2)
module = torch.nn.Linear(action_size, action_size, bias=False)
if spec_class == "discrete":
- spec = DiscreteTensorSpec(action_size)
+ spec = Categorical(action_size)
else:
- spec = OneHotDiscreteTensorSpec(
+ spec = OneHot(
action_size,
shape=(action_size,),
)
@@ -166,7 +161,7 @@ def test_no_spec_error(
action_size = 4
batch_size = (3, 4, 2)
module = torch.nn.Linear(action_size, action_size, bias=False)
- spec = OneHotDiscreteTensorSpec(action_size, shape=(action_size,))
+ spec = OneHot(action_size, shape=(action_size,))
policy = QValueActor(spec=spec, module=module)
explorative_policy = TensorDictSequential(
policy,
@@ -187,7 +182,7 @@ def test_no_spec_error(
@pytest.mark.parametrize("module", [True, False])
def test_wrong_action_shape(self, module):
torch.manual_seed(0)
- spec = BoundedTensorSpec(1, 1, torch.Size([4]))
+ spec = Bounded(1, 1, torch.Size([4]))
module = torch.nn.Linear(4, 5, bias=False)
policy = Actor(spec=spec, module=module)
@@ -240,7 +235,7 @@ def test_ou(
device
)
module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
- action_spec = BoundedTensorSpec(-torch.ones(d_act), torch.ones(d_act), (d_act,))
+ action_spec = Bounded(-torch.ones(d_act), torch.ones(d_act), (d_act,))
policy = ProbabilisticActor(
spec=action_spec,
module=module,
@@ -444,7 +439,7 @@ def test_additivegaussian_sd(
pytest.skip("module raises an error if given spec=None")
torch.manual_seed(seed)
- action_spec = BoundedTensorSpec(
+ action_spec = Bounded(
-torch.ones(d_act, device=device),
torch.ones(d_act, device=device),
(d_act,),
@@ -463,9 +458,7 @@ def test_additivegaussian_sd(
spec=None,
)
policy = ProbabilisticActor(
- spec=CompositeSpec(action=action_spec)
- if spec_origin is not None
- else None,
+ spec=Composite(action=action_spec) if spec_origin is not None else None,
module=module,
in_keys=["loc", "scale"],
distribution_class=TanhNormal,
@@ -541,7 +534,7 @@ def test_additivegaussian(
device
)
module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
- action_spec = BoundedTensorSpec(
+ action_spec = Bounded(
-torch.ones(d_act, device=device),
torch.ones(d_act, device=device),
(d_act,),
@@ -644,7 +637,7 @@ def test_no_spec_error(self, device):
@pytest.mark.parametrize("safe", [True, False])
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize(
- "exploration_type", [InteractionType.RANDOM, InteractionType.MODE]
+ "exploration_type", [InteractionType.RANDOM, InteractionType.DETERMINISTIC]
)
def test_gsde(
state_dim, action_dim, gSDE, device, safe, exploration_type, batch=16, bound=0.1
@@ -670,7 +663,7 @@ def test_gsde(
module = SafeModule(wrapper, in_keys=in_keys, out_keys=["loc", "scale"])
distribution_class = TanhNormal
distribution_kwargs = {"low": -bound, "high": bound}
- spec = BoundedTensorSpec(
+ spec = Bounded(
-torch.ones(action_dim) * bound, torch.ones(action_dim) * bound, (action_dim,)
).to(device)
@@ -708,7 +701,10 @@ def test_gsde(
with set_exploration_type(exploration_type):
action1 = module(td).get("action")
action2 = actor(td.exclude("action")).get("action")
- if gSDE or exploration_type == InteractionType.MODE:
+ if gSDE or exploration_type in (
+ InteractionType.DETERMINISTIC,
+ InteractionType.MODE,
+ ):
torch.testing.assert_close(action1, action2)
else:
with pytest.raises(AssertionError):
@@ -742,6 +738,156 @@ def test_gsde_init(sigma_init, state_dim, action_dim, mean, std, device, learn_s
), f"failed: mean={mean}, std={std}, sigma_init={sigma_init}, actual: {sigma.mean()}"
+class TestConsistentDropout:
+ @pytest.mark.parametrize("dropout_p", [0.0, 0.1, 0.5])
+ @pytest.mark.parametrize("parallel_spec", [False, True])
+ @pytest.mark.parametrize("device", get_default_devices())
+ def test_consistent_dropout(self, dropout_p, parallel_spec, device):
+ """
+
+ This preliminary test seeks to ensure two things for both
+ ConsistentDropout and ConsistentDropoutModule:
+ 1. Rollout transitions generate a dropout mask as desired.
+ - We can easily verify the existence of a mask
+ 2. The dropout mask is correctly applied.
+ - We will check with stochastic policies whether or not
+ the loc and scale are the same.
+ """
+ torch.manual_seed(0)
+
+ # NOTE: use a module with exactly one dropout layer;
+ # that is what this test is built around.
+ @torch.no_grad
+ def inner_verify_routine(module, env):
+ # Perform transitions.
+ collector = SyncDataCollector(
+ create_env_fn=env,
+ policy=module,
+ frames_per_batch=1,
+ total_frames=10,
+ device=device,
+ )
+ for frames in collector:
+ masks = [
+ (key, value)
+ for key, value in frames.items()
+ if key.startswith("mask_")
+ ]
+ # Assert rollouts do indeed correctly generate the masks.
+ assert len(masks) == 1, (
+ "Expected exactly ONE mask since we only put "
+ f"one dropout module, got {len(masks)}."
+ )
+
+ # Verify that replaying the same batch reproduces the result
+ # (a Monte Carlo-style spot check).
+ sentinel_mask = masks[0][1].clone()
+ sentinel_outputs = frames.select("loc", "scale").clone()
+
+ desired_dropout_mask = torch.full_like(
+ sentinel_mask, 1 / (1 - dropout_p)
+ )
+ desired_dropout_mask[sentinel_mask == 0.0] = 0.0
+ # As of 15/08/24, :meth:`~torch.nn.functional.dropout`
+ # is used under the hood, so the scaling is checked explicitly.
+ assert torch.allclose(
+ sentinel_mask, desired_dropout_mask
+ ), "Dropout was not scaled properly."
+
+ new_frames = module(frames.clone())
+ infer_mask = new_frames[masks[0][0]]
+ infer_outputs = new_frames.select("loc", "scale")
+ assert (infer_mask == sentinel_mask).all(), "Mask does not match"
+
+ assert all(
+ [
+ torch.allclose(infer_outputs[key], sentinel_outputs[key])
+ for key in ("loc", "scale")
+ ]
+ ), (
+ "Outputs do not match:\n "
+ f"{infer_outputs['loc']}\n--- vs ---\n{sentinel_outputs['loc']}"
+ f"{infer_outputs['scale']}\n--- vs ---\n{sentinel_outputs['scale']}"
+ )
+
+ env = SerialEnv(
+ 2,
+ ContinuousActionVecMockEnv,
+ )
+ env = TransformedEnv(env.to(device), InitTracker())
+ env = env.to(device)
+ # the module must work with the action spec of a single env or a serial env
+ if parallel_spec:
+ action_spec = env.action_spec
+ else:
+ action_spec = ContinuousActionVecMockEnv(device=device).action_spec
+ d_act = action_spec.shape[-1]
+
+ # NOTE: use a module with exactly one dropout layer;
+ # that is what this test is built around.
+ module_td_seq = TensorDictSequential(
+ TensorDictModule(
+ nn.LazyLinear(2 * d_act), in_keys=["observation"], out_keys=["out"]
+ ),
+ ConsistentDropoutModule(p=dropout_p, in_keys="out"),
+ TensorDictModule(
+ NormalParamExtractor(), in_keys=["out"], out_keys=["loc", "scale"]
+ ),
+ )
+
+ policy_td_seq = ProbabilisticActor(
+ module=module_td_seq,
+ in_keys=["loc", "scale"],
+ distribution_class=TanhNormal,
+ default_interaction_type=InteractionType.RANDOM,
+ spec=action_spec,
+ ).to(device)
+
+ # Wake up the policies
+ policy_td_seq(env.reset())
+
+ # Test.
+ inner_verify_routine(policy_td_seq, env)
+
+ def test_consistent_dropout_primer(self):
+ import torch
+
+ from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq
+ from torchrl.envs import SerialEnv, StepCounter
+ from torchrl.modules import ConsistentDropoutModule, get_primers_from_module
+
+ torch.manual_seed(0)
+
+ m = Seq(
+ Mod(
+ torch.nn.Linear(7, 4),
+ in_keys=["observation"],
+ out_keys=["intermediate"],
+ ),
+ ConsistentDropoutModule(
+ p=0.5,
+ input_shape=(
+ 2,
+ 4,
+ ),
+ in_keys="intermediate",
+ ),
+ Mod(torch.nn.Linear(4, 7), in_keys=["intermediate"], out_keys=["action"]),
+ )
+ primer = get_primers_from_module(m)
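+ # the primer registers the mask entry in the env state so a fresh mask is drawn at each reset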
+ env0 = ContinuousActionVecMockEnv().append_transform(StepCounter(5))
+ env1 = ContinuousActionVecMockEnv().append_transform(StepCounter(6))
+ env = SerialEnv(2, [lambda env=env0: env, lambda env=env1: env])
+ env = env.append_transform(primer)
+ r = env.rollout(10, m, break_when_any_done=False)
+ mask = [k for k in r.keys() if k.startswith("mask")][0]
+ assert (r[mask][0, :5] != r[mask][0, 5:6]).any()
+ assert (r[mask][0, :4] == r[mask][0, 4:5]).all()
+
+ assert (r[mask][1, :6] != r[mask][1, 6:7]).any()
+ assert (r[mask][1, :5] == r[mask][1, 5:6]).all()
+
+
if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
diff --git a/test/test_helpers.py b/test/test_helpers.py
index f468eddf6ed..cf28252a318 100644
--- a/test/test_helpers.py
+++ b/test/test_helpers.py
@@ -30,7 +30,7 @@
MockSerialEnv,
)
from packaging import version
-from torchrl.data import BoundedTensorSpec, CompositeSpec
+from torchrl.data import Bounded, Composite
from torchrl.envs.libs.gym import _has_gym
from torchrl.envs.transforms import ObservationNorm
from torchrl.envs.transforms.transforms import (
@@ -259,17 +259,14 @@ def test_transformed_env_constructor_with_state_dict(from_pixels):
def test_initialize_stats_from_observation_norms(device, keys, composed, initialized):
obs_spec, stat_key = None, None
if keys:
- obs_spec = CompositeSpec(
- **{
- key: BoundedTensorSpec(high=1, low=1, shape=torch.Size([1]))
- for key in keys
- }
+ obs_spec = Composite(
+ **{key: Bounded(high=1, low=1, shape=torch.Size([1])) for key in keys}
)
stat_key = keys[0]
env = ContinuousActionVecMockEnv(
device=device,
observation_spec=obs_spec,
- action_spec=BoundedTensorSpec(low=1, high=2, shape=torch.Size((1,))),
+ action_spec=Bounded(low=1, high=2, shape=torch.Size((1,))),
)
env.out_key = "observation"
else:
diff --git a/test/test_libs.py b/test/test_libs.py
index 6ccbf2788a9..3d04648fd4e 100644
--- a/test/test_libs.py
+++ b/test/test_libs.py
@@ -5,6 +5,7 @@
import functools
import gc
import importlib.util
+import urllib.error
_has_isaac = importlib.util.find_spec("isaacgym") is not None
@@ -18,10 +19,12 @@
import os
import time
+import urllib
from contextlib import nullcontext
from pathlib import Path
from sys import platform
from typing import Optional, Union
+from unittest import mock
import numpy as np
import pytest
@@ -36,6 +39,7 @@
PENDULUM_VERSIONED,
PONG_VERSIONED,
rand_reset,
+ retry,
rollout_consistency_assertion,
)
from packaging import version
@@ -55,16 +59,16 @@
from torchrl._utils import implement_for, logger as torchrl_logger
from torchrl.collectors.collectors import SyncDataCollector
from torchrl.data import (
- BinaryDiscreteTensorSpec,
- BoundedTensorSpec,
- CompositeSpec,
- DiscreteTensorSpec,
- MultiDiscreteTensorSpec,
- MultiOneHotDiscreteTensorSpec,
- OneHotDiscreteTensorSpec,
+ Binary,
+ Bounded,
+ Categorical,
+ Composite,
+ MultiCategorical,
+ MultiOneHot,
+ OneHot,
ReplayBuffer,
ReplayBufferEnsemble,
- UnboundedContinuousTensorSpec,
+ Unbounded,
UnboundedDiscreteTensorSpec,
)
from torchrl.data.datasets.atari_dqn import AtariDQNExperienceReplay
@@ -107,9 +111,15 @@
from torchrl.envs.libs.jumanji import _has_jumanji, JumanjiEnv
from torchrl.envs.libs.meltingpot import MeltingpotEnv, MeltingpotWrapper
from torchrl.envs.libs.openml import OpenMLEnv
+from torchrl.envs.libs.openspiel import _has_pyspiel, OpenSpielEnv, OpenSpielWrapper
from torchrl.envs.libs.pettingzoo import _has_pettingzoo, PettingZooEnv
from torchrl.envs.libs.robohive import _has_robohive, RoboHiveEnv
from torchrl.envs.libs.smacv2 import _has_smacv2, SMACv2Env
+from torchrl.envs.libs.unity_mlagents import (
+ _has_unity_mlagents,
+ UnityMLAgentsEnv,
+ UnityMLAgentsWrapper,
+)
from torchrl.envs.libs.vmas import _has_vmas, VmasEnv, VmasWrapper
from torchrl.envs.transforms import ActionMask, TransformedEnv
@@ -152,6 +162,16 @@
_has_meltingpot = importlib.util.find_spec("meltingpot") is not None
+_has_minigrid = importlib.util.find_spec("minigrid") is not None
+
+
+@pytest.fixture(scope="session", autouse=True)
+def maybe_init_minigrid():
+ if _has_minigrid and _has_gymnasium:
+ import minigrid
+
+ minigrid.register_minigrid_envs()
+
def get_gym_pixel_wrapper():
try:
@@ -206,18 +226,16 @@ def __init__(self, arg1, *, arg2, **kwargs):
assert arg1 == 1
assert arg2 == 2
- self.observation_spec = CompositeSpec(
- observation=UnboundedContinuousTensorSpec((*self.batch_size, 3)),
- other=CompositeSpec(
- another_other=UnboundedContinuousTensorSpec((*self.batch_size, 3)),
+ self.observation_spec = Composite(
+ observation=Unbounded((*self.batch_size, 3)),
+ other=Composite(
+ another_other=Unbounded((*self.batch_size, 3)),
shape=self.batch_size,
),
shape=self.batch_size,
)
- self.action_spec = UnboundedContinuousTensorSpec((*self.batch_size, 3))
- self.done_spec = DiscreteTensorSpec(
- 2, (*self.batch_size, 1), dtype=torch.bool
- )
+ self.action_spec = Unbounded((*self.batch_size, 3))
+ self.done_spec = Categorical(2, (*self.batch_size, 1), dtype=torch.bool)
self.full_done_spec["truncated"] = self.full_done_spec["terminated"].clone()
def _reset(self, tensordict):
@@ -242,16 +260,14 @@ def _set_seed(self, seed):
@implement_for("gym", None, "0.18")
def _make_spec(self, batch_size, cat, cat_shape, multicat, multicat_shape):
- return CompositeSpec(
- a=UnboundedContinuousTensorSpec(shape=(*batch_size, 1)),
- b=CompositeSpec(
- c=cat(5, shape=cat_shape, dtype=torch.int64), shape=batch_size
- ),
+ return Composite(
+ a=Unbounded(shape=(*batch_size, 1)),
+ b=Composite(c=cat(5, shape=cat_shape, dtype=torch.int64), shape=batch_size),
d=cat(5, shape=cat_shape, dtype=torch.int64),
e=multicat([2, 3], shape=(*batch_size, multicat_shape), dtype=torch.int64),
- f=BoundedTensorSpec(-3, 4, shape=(*batch_size, 1)),
+ f=Bounded(-3, 4, shape=(*batch_size, 1)),
# g=UnboundedDiscreteTensorSpec(shape=(*batch_size, 1), dtype=torch.long),
- h=BinaryDiscreteTensorSpec(n=5, shape=(*batch_size, 5)),
+ h=Binary(n=5, shape=(*batch_size, 5)),
shape=batch_size,
)
@@ -259,44 +275,38 @@ def _make_spec(self, batch_size, cat, cat_shape, multicat, multicat_shape):
def _make_spec( # noqa: F811
self, batch_size, cat, cat_shape, multicat, multicat_shape
):
- return CompositeSpec(
- a=UnboundedContinuousTensorSpec(shape=(*batch_size, 1)),
- b=CompositeSpec(
- c=cat(5, shape=cat_shape, dtype=torch.int64), shape=batch_size
- ),
+ return Composite(
+ a=Unbounded(shape=(*batch_size, 1)),
+ b=Composite(c=cat(5, shape=cat_shape, dtype=torch.int64), shape=batch_size),
d=cat(5, shape=cat_shape, dtype=torch.int64),
e=multicat([2, 3], shape=(*batch_size, multicat_shape), dtype=torch.int64),
- f=BoundedTensorSpec(-3, 4, shape=(*batch_size, 1)),
+ f=Bounded(-3, 4, shape=(*batch_size, 1)),
g=UnboundedDiscreteTensorSpec(shape=(*batch_size, 1), dtype=torch.long),
- h=BinaryDiscreteTensorSpec(n=5, shape=(*batch_size, 5)),
+ h=Binary(n=5, shape=(*batch_size, 5)),
shape=batch_size,
)
- @implement_for("gymnasium")
+ @implement_for("gymnasium", None, "1.0.0")
def _make_spec( # noqa: F811
self, batch_size, cat, cat_shape, multicat, multicat_shape
):
- return CompositeSpec(
- a=UnboundedContinuousTensorSpec(shape=(*batch_size, 1)),
- b=CompositeSpec(
- c=cat(5, shape=cat_shape, dtype=torch.int64), shape=batch_size
- ),
+ return Composite(
+ a=Unbounded(shape=(*batch_size, 1)),
+ b=Composite(c=cat(5, shape=cat_shape, dtype=torch.int64), shape=batch_size),
d=cat(5, shape=cat_shape, dtype=torch.int64),
e=multicat([2, 3], shape=(*batch_size, multicat_shape), dtype=torch.int64),
- f=BoundedTensorSpec(-3, 4, shape=(*batch_size, 1)),
+ f=Bounded(-3, 4, shape=(*batch_size, 1)),
g=UnboundedDiscreteTensorSpec(shape=(*batch_size, 1), dtype=torch.long),
- h=BinaryDiscreteTensorSpec(n=5, shape=(*batch_size, 5)),
+ h=Binary(n=5, shape=(*batch_size, 5)),
shape=batch_size,
)
@pytest.mark.parametrize("categorical", [True, False])
def test_gym_spec_cast(self, categorical):
batch_size = [3, 4]
- cat = DiscreteTensorSpec if categorical else OneHotDiscreteTensorSpec
+ cat = Categorical if categorical else OneHot
cat_shape = batch_size if categorical else (*batch_size, 5)
- multicat = (
- MultiDiscreteTensorSpec if categorical else MultiOneHotDiscreteTensorSpec
- )
+ multicat = MultiCategorical if categorical else MultiOneHot
multicat_shape = 2 if categorical else 5
spec = self._make_spec(batch_size, cat, cat_shape, multicat, multicat_shape)
recon = _gym_to_torchrl_spec_transform(
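The recurring `@implement_for("gymnasium", None, "1.0.0")` pins in this file cap these code paths at gymnasium 1.0: `implement_for` registers one variant per backend-version range and dispatches at call time to whichever matches the installed package, with bounds read as [from_version, to_version). A minimal sketch of the pattern (function name and bodies are illustrative):

from torchrl._utils import implement_for

@implement_for("gymnasium", None, "1.0.0")
def backend_feature():
    # selected when gymnasium < 1.0 is installed
    return "pre-1.0 code path"

@implement_for("gymnasium", "1.0.0")
def backend_feature():  # noqa: F811
    # selected once a gymnasium >= 1.0 code path exists
    return "post-1.0 code path"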
@@ -321,7 +331,7 @@ def test_gym_spec_cast_tuple_sequential(self, order):
# @pytest.mark.parametrize("order", ["seq_tuple", "tuple_seq"])
@pytest.mark.parametrize("order", ["tuple_seq"])
- @implement_for("gymnasium")
+ @implement_for("gymnasium", None, "1.0.0")
def test_gym_spec_cast_tuple_sequential(self, order): # noqa: F811
with set_gym_backend("gymnasium"):
if order == "seq_tuple":
@@ -837,7 +847,7 @@ def info_reader(info, tensordict):
finally:
set_gym_backend(gb).set()
- @implement_for("gymnasium")
+ @implement_for("gymnasium", None, "1.0.0")
def test_one_hot_and_categorical(self):
# tests that one-hot and categorical work ok when an integer is expected as action
cliff_walking = GymEnv("CliffWalking-v0", categorical_action_encoding=True)
@@ -856,7 +866,7 @@ def test_one_hot_and_categorical(self): # noqa: F811
# versions.
return
- @implement_for("gymnasium")
+ @implement_for("gymnasium", None, "1.0.0")
@pytest.mark.parametrize(
"envname",
["HalfCheetah-v4", "CartPole-v1", "ALE/Pong-v5"]
@@ -882,7 +892,7 @@ def test_vecenvs_wrapper(self, envname):
assert env.batch_size == torch.Size([2])
check_env_specs(env)
- @implement_for("gymnasium")
+ @implement_for("gymnasium", None, "1.0.0")
# this env has a Dict-based observation, which is a nice thing to test
@pytest.mark.parametrize(
"envname",
@@ -1044,7 +1054,7 @@ def test_gym_output_num(self, wrapper): # noqa: F811
finally:
set_gym_backend(gym).set()
- @implement_for("gymnasium")
+ @implement_for("gymnasium", None, "1.0.0")
@pytest.mark.parametrize("wrapper", [True, False])
def test_gym_output_num(self, wrapper): # noqa: F811
# gym has 5 outputs, with truncation
@@ -1147,7 +1157,7 @@ def test_vecenvs_nan(self): # noqa: F811
del c
return
- @implement_for("gymnasium")
+ @implement_for("gymnasium", None, "1.0.0")
def test_vecenvs_nan(self): # noqa: F811
# new versions of gym must never return nan for next values when there is a done state
torch.manual_seed(0)
@@ -1288,6 +1298,24 @@ def test_resetting_strategies(self, heterogeneous):
gc.collect()
+@pytest.mark.skipif(
+ not _has_minigrid or not _has_gymnasium, reason="MiniGrid not found"
+)
+class TestMiniGrid:
+ @pytest.mark.parametrize(
+ "id",
+ [
+ "BabyAI-KeyCorridorS6R3-v0",
+ "MiniGrid-Empty-16x16-v0",
+ "MiniGrid-BlockedUnlockPickup-v0",
+ ],
+ )
+ def test_minigrid(self, id):
+ env_base = gymnasium.make(id)
+ env = GymWrapper(env_base)
+ check_env_specs(env)
+
+
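The new TestMiniGrid class leans on the autouse fixture above, which registers the MiniGrid ids with gymnasium once per session. A minimal sketch of the same flow outside pytest, with the env id taken from the parametrization above:

import gymnasium
import minigrid
from torchrl.envs import GymWrapper
from torchrl.envs.utils import check_env_specs

minigrid.register_minigrid_envs()  # done by the autouse fixture during tests
env = GymWrapper(gymnasium.make("MiniGrid-Empty-16x16-v0"))
check_env_specs(env)  # round-trips reset/step against the declared specs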
@implement_for("gym", None, "0.26")
def _make_gym_environment(env_name): # noqa: F811
gym = gym_backend()
@@ -1300,7 +1328,7 @@ def _make_gym_environment(env_name): # noqa: F811
return gym.make(env_name, render_mode="rgb_array")
-@implement_for("gymnasium")
+@implement_for("gymnasium", None, "1.0.0")
def _make_gym_environment(env_name): # noqa: F811
gym = gym_backend()
return gym.make(env_name, render_mode="rgb_array")
@@ -2804,34 +2832,11 @@ def _minari_selected_datasets():
torch.manual_seed(0)
- # We rely on sorting the keys as v0 < v1 but if the version is greater than 9 this won't work
- total_keys = sorted(minari.list_remote_datasets())
- assert not any(
- key[-2:] == "10" for key in total_keys
- ), "You should adapt the Minari test scripts as some dataset have a version >= 10 and sorting will fail."
- total_keys_splits = [key.split("-") for key in total_keys]
+ total_keys = sorted(
+ minari.list_remote_datasets(latest_version=True, compatible_minari_version=True)
+ )
indices = torch.randperm(len(total_keys))[:20]
keys = [total_keys[idx] for idx in indices]
- keys = [
- key
- for key in keys
- if "=0.4" in minari.list_remote_datasets()[key]["minari_version"]
- ]
-
- def _replace_with_max(key):
- key_split = key.split("-")
- same_entries = (
- torch.tensor(
- [total_key[:-1] == key_split[:-1] for total_key in total_keys_splits]
- )
- .nonzero()
- .squeeze()
- .tolist()
- )
- last_same_entry = same_entries[-1]
- return total_keys[last_same_entry]
-
- keys = [_replace_with_max(key) for key in keys]
assert len(keys) > 5, keys
_MINARI_DATASETS += keys
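Alongside the simplified discovery above, test_minari_preproc (next hunk) stops depending on the module-level dataset list and pins a known dataset instead. A minimal sketch of loading it directly, reusing the arguments from the test:

from torchrl.data.datasets.minari_data import MinariExperienceReplay

dataset = MinariExperienceReplay(
    "D4RL/pointmaze/large-v2",
    batch_size=32,
    split_trajs=False,
    download="force",  # re-download even if a cached copy exists
)
batch = dataset.sample()  # a TensorDict of 32 transitions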
@@ -2861,12 +2866,8 @@ def test_load(self, selected_dataset, split):
break
def test_minari_preproc(self, tmpdir):
- global _MINARI_DATASETS
- if not _MINARI_DATASETS:
- _minari_selected_datasets()
- selected_dataset = _MINARI_DATASETS[0]
dataset = MinariExperienceReplay(
- selected_dataset,
+ "D4RL/pointmaze/large-v2",
batch_size=32,
split_trajs=False,
download="force",
@@ -3073,7 +3074,7 @@ def test_atari_preproc(self, dataset_id, tmpdir):
t = Compose(
UnsqueezeTransform(
- unsqueeze_dim=-3, in_keys=["observation", ("next", "observation")]
+ dim=-3, in_keys=["observation", ("next", "observation")]
),
Resize(32, in_keys=["observation", ("next", "observation")]),
RenameTransform(in_keys=["action"], out_keys=["other_action"]),
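`UnsqueezeTransform` (and, in test_rb.py further down, `SqueezeTransform`) now takes a generic `dim` keyword in place of the transform-specific `unsqueeze_dim`/`squeeze_dim` names. Minimal usage with the renamed argument:

from torchrl.envs.transforms import SqueezeTransform, UnsqueezeTransform

unsqueeze = UnsqueezeTransform(dim=-3, in_keys=["observation"])
squeeze = SqueezeTransform(dim=-1, in_keys=["observation"])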
@@ -3692,7 +3693,7 @@ class TestRoboHive:
# The other option would be not to use parametrize, but that also
# means less informative stack traces.
# In the CI, robohive should not coexist with other libs so that's fine.
- # Robohive logging behaviour can be controlled via ROBOHIVE_VERBOSITY=ALL/INFO/(WARN)/ERROR/ONCE/ALWAYS/SILENT
+ # Robohive logging behavior can be controlled via ROBOHIVE_VERBOSITY=ALL/INFO/(WARN)/ERROR/ONCE/ALWAYS/SILENT
@pytest.mark.parametrize("from_pixels", [False, True])
@pytest.mark.parametrize("from_depths", [False, True])
@pytest.mark.parametrize("envname", RoboHiveEnv.available_envs)
@@ -3812,6 +3813,270 @@ def test_collector(self):
collector.shutdown()
+# List of OpenSpiel games to test
+# TODO: Some of the games in `OpenSpielWrapper.available_envs` raise errors for
+# a few different reasons, mostly because we do not support chance nodes yet. So
+# we cannot run tests on all of them yet.
+_openspiel_games = [
+ # ----------------
+ # Sequential games
+ # 1-player
+ "morpion_solitaire",
+ # 2-player
+ "amazons",
+ "battleship",
+ "breakthrough",
+ "checkers",
+ "chess",
+ "cliff_walking",
+ "clobber",
+ "connect_four",
+ "cursor_go",
+ "dark_chess",
+ "dark_hex",
+ "dark_hex_ir",
+ "dots_and_boxes",
+ "go",
+ "havannah",
+ "hex",
+ "kriegspiel",
+ "mancala",
+ "nim",
+ "nine_mens_morris",
+ "othello",
+ "oware",
+ "pentago",
+ "phantom_go",
+ "phantom_ttt",
+ "phantom_ttt_ir",
+ "sheriff",
+ "tic_tac_toe",
+ "twixt",
+ "ultimate_tic_tac_toe",
+ "y",
+ # --------------
+ # Parallel games
+ # 2-player
+ "blotto",
+ "matrix_bos",
+ "matrix_brps",
+ "matrix_cd",
+ "matrix_coordination",
+ "matrix_mp",
+ "matrix_pd",
+ "matrix_rps",
+ "matrix_rpsw",
+ "matrix_sh",
+ "matrix_shapleys_game",
+ "oshi_zumo",
+ # 3-player
+ "matching_pennies_3p",
+]
+
+
+@pytest.mark.skipif(not _has_pyspiel, reason="open_spiel not found")
+class TestOpenSpiel:
+ @pytest.mark.parametrize("game_string", _openspiel_games)
+ @pytest.mark.parametrize("return_state", [False, True])
+ @pytest.mark.parametrize("categorical_actions", [False, True])
+ def test_all_envs(self, game_string, return_state, categorical_actions):
+ env = OpenSpielEnv(
+ game_string,
+ categorical_actions=categorical_actions,
+ return_state=return_state,
+ )
+ check_env_specs(env)
+
+ @pytest.mark.parametrize("game_string", _openspiel_games)
+ @pytest.mark.parametrize("return_state", [False, True])
+ @pytest.mark.parametrize("categorical_actions", [False, True])
+ def test_wrapper(self, game_string, return_state, categorical_actions):
+ import pyspiel
+
+ base_env = pyspiel.load_game(game_string).new_initial_state()
+ env_torchrl = OpenSpielWrapper(
+ base_env, categorical_actions=categorical_actions, return_state=return_state
+ )
+ env_torchrl.rollout(max_steps=5)
+
+ @pytest.mark.parametrize("game_string", _openspiel_games)
+ @pytest.mark.parametrize("return_state", [False, True])
+ @pytest.mark.parametrize("categorical_actions", [False, True])
+ def test_reset_state(self, game_string, return_state, categorical_actions):
+ env = OpenSpielEnv(
+ game_string,
+ categorical_actions=categorical_actions,
+ return_state=return_state,
+ )
+ td = env.reset()
+ td_init = td.clone()
+
+ # Perform an action
+ td = env.step(env.full_action_spec.rand())
+
+ # Save the current td for reset
+ td_reset = td["next"].clone()
+
+ # Perform a second action
+ td = env.step(env.full_action_spec.rand())
+
+ # Resetting to a specific state can only happen if `return_state` is
+ # enabled. Otherwise, it is reset to the initial state.
+ if return_state:
+ # Check that the state was reset to the specified state
+ td = env.reset(td_reset)
+ assert (td == td_reset).all()
+ else:
+ # Check that the state was reset to the initial state
+ td = env.reset()
+ assert (td == td_init).all()
+
+ def test_chance_not_implemented(self):
+ with pytest.raises(
+ NotImplementedError,
+ match="not yet supported",
+ ):
+ OpenSpielEnv("bridge")
+
+
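OpenSpielEnv builds a game from its OpenSpiel string name, while OpenSpielWrapper adopts an existing pyspiel state; both paths are exercised by the tests above. A minimal sketch of the two entry points, assuming open_spiel is installed:

import pyspiel
from torchrl.envs.libs.openspiel import OpenSpielEnv, OpenSpielWrapper

# construct from the game's string name
env = OpenSpielEnv("tic_tac_toe", categorical_actions=True, return_state=True)
td = env.reset()
td = env.step(env.full_action_spec.rand())

# or wrap a pre-built pyspiel state
state = pyspiel.load_game("tic_tac_toe").new_initial_state()
env = OpenSpielWrapper(state, categorical_actions=True)
env.rollout(max_steps=5)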
+# NOTE: Each of the registered envs is around 180 MB, so only test a few.
+_mlagents_registered_envs = [
+ "3DBall",
+ "StrikersVsGoalie",
+]
+
+
+@pytest.mark.skipif(not _has_unity_mlagents, reason="mlagents_envs not found")
+class TestUnityMLAgents:
+ @mock.patch("mlagents_envs.env_utils.launch_executable")
+ @mock.patch("mlagents_envs.environment.UnityEnvironment._get_communicator")
+ @pytest.mark.parametrize(
+ "group_map",
+ [None, MarlGroupMapType.ONE_GROUP_PER_AGENT, MarlGroupMapType.ALL_IN_ONE_GROUP],
+ )
+ def test_env(self, mock_communicator, mock_launcher, group_map):
+ from mlagents_envs.mock_communicator import MockCommunicator
+
+ mock_communicator.return_value = MockCommunicator(
+ discrete_action=False, visual_inputs=0
+ )
+ env = UnityMLAgentsEnv(" ", group_map=group_map)
+ try:
+ check_env_specs(env)
+ finally:
+ env.close()
+
+ @mock.patch("mlagents_envs.env_utils.launch_executable")
+ @mock.patch("mlagents_envs.environment.UnityEnvironment._get_communicator")
+ @pytest.mark.parametrize(
+ "group_map",
+ [None, MarlGroupMapType.ONE_GROUP_PER_AGENT, MarlGroupMapType.ALL_IN_ONE_GROUP],
+ )
+ def test_wrapper(self, mock_communicator, mock_launcher, group_map):
+ from mlagents_envs.environment import UnityEnvironment
+ from mlagents_envs.mock_communicator import MockCommunicator
+
+ mock_communicator.return_value = MockCommunicator(
+ discrete_action=False, visual_inputs=0
+ )
+ env = UnityMLAgentsWrapper(UnityEnvironment(" "), group_map=group_map)
+ try:
+ check_env_specs(env)
+ finally:
+ env.close()
+
+ @mock.patch("mlagents_envs.env_utils.launch_executable")
+ @mock.patch("mlagents_envs.environment.UnityEnvironment._get_communicator")
+ @pytest.mark.parametrize(
+ "group_map",
+ [None, MarlGroupMapType.ONE_GROUP_PER_AGENT, MarlGroupMapType.ALL_IN_ONE_GROUP],
+ )
+ def test_rollout(self, mock_communicator, mock_launcher, group_map):
+ from mlagents_envs.environment import UnityEnvironment
+ from mlagents_envs.mock_communicator import MockCommunicator
+
+ mock_communicator.return_value = MockCommunicator(
+ discrete_action=False, visual_inputs=0
+ )
+ env = UnityMLAgentsWrapper(UnityEnvironment(" "), group_map=group_map)
+ try:
+ env.rollout(
+ max_steps=500, break_when_any_done=False, break_when_all_done=False
+ )
+ finally:
+ env.close()
+
+ @pytest.mark.unity_editor
+ def test_with_editor(self):
+ print("Please press play in the Unity editor") # noqa: T201
+ env = UnityMLAgentsEnv(timeout_wait=30)
+ try:
+ env.reset()
+ check_env_specs(env)
+
+ # Perform a rollout
+ td = env.reset()
+ env.rollout(
+ max_steps=100, break_when_any_done=False, break_when_all_done=False
+ )
+
+ # Step manually
+ tensordicts = []
+ td = env.reset()
+ tensordicts.append(td)
+ traj_len = 200
+ for _ in range(traj_len - 1):
+ td = env.step(td.update(env.full_action_spec.rand()))
+ tensordicts.append(td)
+
+ traj = torch.stack(tensordicts)
+ assert traj.batch_size == torch.Size([traj_len])
+ finally:
+ env.close()
+
+ @retry(
+ (
+ urllib.error.HTTPError,
+ urllib.error.URLError,
+ urllib.error.ContentTooShortError,
+ ),
+ 5,
+ )
+ @pytest.mark.parametrize("registered_name", _mlagents_registered_envs)
+ @pytest.mark.parametrize(
+ "group_map",
+ [None, MarlGroupMapType.ONE_GROUP_PER_AGENT, MarlGroupMapType.ALL_IN_ONE_GROUP],
+ )
+ def test_registered_envs(self, registered_name, group_map):
+ env = UnityMLAgentsEnv(
+ registered_name=registered_name,
+ no_graphics=True,
+ group_map=group_map,
+ )
+ try:
+ check_env_specs(env)
+
+ # Perform a rollout
+ td = env.reset()
+ env.rollout(
+ max_steps=20, break_when_any_done=False, break_when_all_done=False
+ )
+
+ # Step manually
+ tensordicts = []
+ td = env.reset()
+ tensordicts.append(td)
+ traj_len = 20
+ for _ in range(traj_len - 1):
+ td = env.step(td.update(env.full_action_spec.rand()))
+ tensordicts.append(td)
+
+ traj = torch.stack(tensordicts)
+ assert traj.batch_size == torch.Size([traj_len])
+ finally:
+ env.close()
+
+
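`retry` (added to the test-util imports at the top of this diff) shields the download-heavy registered-env tests from transient network errors by re-running the test body when one of the listed exception types is raised, up to the given number of attempts. A sketch of such a decorator; the real helper lives in the shared test utilities, so this reimplementation is illustrative only:

import functools
import time

def retry(exc_types, num_retries):
    # illustrative stand-in for the helper imported from the test utils
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            for attempt in range(num_retries):
                try:
                    return fn(*args, **kwargs)
                except exc_types:
                    if attempt == num_retries - 1:
                        raise
                    time.sleep(1)  # brief pause before retrying
        return wrapper
    return decorator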
@pytest.mark.skipif(not _has_meltingpot, reason="Meltingpot not found")
class TestMeltingpot:
@pytest.mark.parametrize("substrate", MeltingpotWrapper.available_envs)
diff --git a/test/test_loggers.py b/test/test_loggers.py
index 735911bd95c..eb40ca1fdb8 100644
--- a/test/test_loggers.py
+++ b/test/test_loggers.py
@@ -281,25 +281,27 @@ def test_log_video(self, wandb_logger):
# C - number of image channels (e.g. 3 for RGB), H, W - image dimensions.
- # the first 64 frames are black and the next 64 are white
+ # the first 128 frames are black and the next 128 are white
video = torch.cat(
- (torch.zeros(64, 1, 32, 32), torch.full((64, 1, 32, 32), 255))
+ (torch.zeros(128, 1, 32, 32), torch.full((128, 1, 32, 32), 255))
)
video = video[None, :]
wandb_logger.log_video(
name="foo",
video=video,
- fps=6,
+ fps=4,
+ format="mp4",
)
wandb_logger.log_video(
- name="foo_12fps",
+ name="foo_16fps",
video=video,
- fps=24,
+ fps=16,
+ format="mp4",
)
sleep(0.01) # wait until events are registered
# check that fps can be passed and that it has an impact on the length of the video
- video_6fps_size = wandb_logger.experiment.summary["foo"]["size"]
- video_24fps_size = wandb_logger.experiment.summary["foo_12fps"]["size"]
- assert video_6fps_size > video_24fps_size, video_6fps_size
+ video_4fps_size = wandb_logger.experiment.summary["foo"]["size"]
+ video_16fps_size = wandb_logger.experiment.summary["foo_16fps"]["size"]
+ assert video_4fps_size > video_16fps_size, (video_4fps_size, video_16fps_size)
# check that we catch the error in case the format of the tensor is wrong
video_wrong_format = torch.zeros(64, 2, 32, 32)
diff --git a/test/test_modules.py b/test/test_modules.py
index 00e58678788..8966b61154c 100644
--- a/test/test_modules.py
+++ b/test/test_modules.py
@@ -16,7 +16,7 @@
from packaging import version
from tensordict import TensorDict
from torch import nn
-from torchrl.data.tensor_specs import BoundedTensorSpec, CompositeSpec
+from torchrl.data.tensor_specs import Bounded, Composite
from torchrl.modules import (
CEMPlanner,
DTActor,
@@ -466,9 +466,7 @@ def test_dreamer_decoder(
@pytest.mark.parametrize("deter_size", [20, 30])
@pytest.mark.parametrize("action_size", [3, 6])
def test_rssm_prior(self, device, batch_size, stoch_size, deter_size, action_size):
- action_spec = BoundedTensorSpec(
- shape=(action_size,), dtype=torch.float32, low=-1, high=1
- )
+ action_spec = Bounded(shape=(action_size,), dtype=torch.float32, low=-1, high=1)
rssm_prior = RSSMPrior(
action_spec,
hidden_dim=stoch_size,
@@ -521,9 +519,7 @@ def test_rssm_posterior(self, device, batch_size, stoch_size, deter_size):
def test_rssm_rollout(
self, device, batch_size, temporal_size, stoch_size, deter_size, action_size
):
- action_spec = BoundedTensorSpec(
- shape=(action_size,), dtype=torch.float32, low=-1, high=1
- )
+ action_spec = Bounded(shape=(action_size,), dtype=torch.float32, low=-1, high=1)
rssm_prior = RSSMPrior(
action_spec,
hidden_dim=stoch_size,
@@ -650,10 +646,10 @@ def test_errors(self):
):
TanhModule(in_keys=["a", "b"], out_keys=["a"])
with pytest.raises(ValueError, match=r"The minimum value \(-2\) provided"):
- spec = BoundedTensorSpec(-1, 1, shape=())
+ spec = Bounded(-1, 1, shape=())
TanhModule(in_keys=["act"], low=-2, spec=spec)
with pytest.raises(ValueError, match=r"The maximum value \(-2\) provided to"):
- spec = BoundedTensorSpec(-1, 1, shape=())
+ spec = Bounded(-1, 1, shape=())
TanhModule(in_keys=["act"], high=-2, spec=spec)
with pytest.raises(ValueError, match="Got high < low"):
TanhModule(in_keys=["act"], high=-2, low=-1)
@@ -709,12 +705,12 @@ def test_multi_inputs(self, out_keys, has_spec):
if any(has_spec):
spec = {}
if has_spec[0]:
- spec.update({real_out_keys[0]: BoundedTensorSpec(-2.0, 2.0, shape=())})
+ spec.update({real_out_keys[0]: Bounded(-2.0, 2.0, shape=())})
low, high = -2.0, 2.0
if has_spec[1]:
- spec.update({real_out_keys[1]: BoundedTensorSpec(-3.0, 3.0, shape=())})
+ spec.update({real_out_keys[1]: Bounded(-3.0, 3.0, shape=())})
low, high = None, None
- spec = CompositeSpec(spec)
+ spec = Composite(spec)
else:
spec = None
low, high = -2.0, 2.0
diff --git a/test/test_rb.py b/test/test_rb.py
index e17cd410c49..24b33f89795 100644
--- a/test/test_rb.py
+++ b/test/test_rb.py
@@ -26,6 +26,7 @@
assert_allclose_td,
is_tensor_collection,
is_tensorclass,
+ LazyStackedTensorDict,
tensorclass,
TensorDict,
TensorDictBase,
@@ -58,6 +59,11 @@
SliceSampler,
SliceSamplerWithoutReplacement,
)
+from torchrl.data.replay_buffers.scheduler import (
+ LinearScheduler,
+ SchedulerList,
+ StepScheduler,
+)
from torchrl.data.replay_buffers.storages import (
LazyMemmapStorage,
@@ -99,6 +105,7 @@
VecNorm,
)
+
OLD_TORCH = parse(torch.__version__) < parse("2.0.0")
_has_tv = importlib.util.find_spec("torchvision") is not None
_has_gym = importlib.util.find_spec("gym") is not None
@@ -109,6 +116,11 @@
".".join([str(s) for s in version.parse(str(torch.__version__)).release])
) >= version.parse("2.3.0")
+ReplayBufferRNG = functools.partial(ReplayBuffer, generator=torch.Generator())
+TensorDictReplayBufferRNG = functools.partial(
+ TensorDictReplayBuffer, generator=torch.Generator()
+)
+
@pytest.mark.parametrize(
"sampler",
@@ -125,17 +137,27 @@
"rb_type,storage,datatype",
[
[ReplayBuffer, ListStorage, None],
+ [ReplayBufferRNG, ListStorage, None],
[TensorDictReplayBuffer, ListStorage, "tensordict"],
+ [TensorDictReplayBufferRNG, ListStorage, "tensordict"],
[RemoteTensorDictReplayBuffer, ListStorage, "tensordict"],
[ReplayBuffer, LazyTensorStorage, "tensor"],
[ReplayBuffer, LazyTensorStorage, "tensordict"],
[ReplayBuffer, LazyTensorStorage, "pytree"],
+ [ReplayBufferRNG, LazyTensorStorage, "tensor"],
+ [ReplayBufferRNG, LazyTensorStorage, "tensordict"],
+ [ReplayBufferRNG, LazyTensorStorage, "pytree"],
[TensorDictReplayBuffer, LazyTensorStorage, "tensordict"],
+ [TensorDictReplayBufferRNG, LazyTensorStorage, "tensordict"],
[RemoteTensorDictReplayBuffer, LazyTensorStorage, "tensordict"],
[ReplayBuffer, LazyMemmapStorage, "tensor"],
[ReplayBuffer, LazyMemmapStorage, "tensordict"],
[ReplayBuffer, LazyMemmapStorage, "pytree"],
+ [ReplayBufferRNG, LazyMemmapStorage, "tensor"],
+ [ReplayBufferRNG, LazyMemmapStorage, "tensordict"],
+ [ReplayBufferRNG, LazyMemmapStorage, "pytree"],
[TensorDictReplayBuffer, LazyMemmapStorage, "tensordict"],
+ [TensorDictReplayBufferRNG, LazyMemmapStorage, "tensordict"],
[RemoteTensorDictReplayBuffer, LazyMemmapStorage, "tensordict"],
],
)
@@ -531,6 +553,20 @@ def test_errors(self, storage_type):
):
storage_type(data, max_size=4)
+ def test_existsok_lazymemmap(self, tmpdir):
+ storage0 = LazyMemmapStorage(10, scratch_dir=tmpdir)
+ rb = ReplayBuffer(storage=storage0)
+ rb.extend(TensorDict(a=torch.randn(3), batch_size=[3]))
+
+ storage1 = LazyMemmapStorage(10, scratch_dir=tmpdir)
+ rb = ReplayBuffer(storage=storage1)
+ with pytest.raises(RuntimeError, match="existsok"):
+ rb.extend(TensorDict(a=torch.randn(3), batch_size=[3]))
+
+ storage2 = LazyMemmapStorage(10, scratch_dir=tmpdir, existsok=True)
+ rb = ReplayBuffer(storage=storage2)
+ rb.extend(TensorDict(a=torch.randn(3), batch_size=[3]))
+
@pytest.mark.parametrize(
"data_type", ["tensor", "tensordict", "tensorclass", "pytree"]
)
@@ -686,6 +722,20 @@ def test_storage_state_dict(self, storage_in, storage_out, init_out, backend):
s = new_replay_buffer.sample()
assert (s.exclude("index") == 1).all()
+ @pytest.mark.parametrize("storage_type", [LazyMemmapStorage, LazyTensorStorage])
+ def test_extend_lazystack(self, storage_type):
+
+ rb = ReplayBuffer(
+ storage=storage_type(6),
+ batch_size=2,
+ )
+ td1 = TensorDict(a=torch.rand(5, 4, 8), batch_size=5)
+ td2 = TensorDict(a=torch.rand(5, 3, 8), batch_size=5)
+ ltd = LazyStackedTensorDict(td1, td2, stack_dim=1)
+ rb.extend(ltd)
+ rb.sample(3)
+ assert len(rb) == 5
+
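`test_extend_lazystack` covers extending a buffer with a lazily-stacked batch whose leaves disagree in shape along the stack dimension; extension still walks dim 0, which is why five elements land in the buffer. The shape bookkeeping, spelled out:

import torch
from tensordict import LazyStackedTensorDict, TensorDict

td1 = TensorDict(a=torch.rand(5, 4, 8), batch_size=5)
td2 = TensorDict(a=torch.rand(5, 3, 8), batch_size=5)
ltd = LazyStackedTensorDict(td1, td2, stack_dim=1)
print(ltd.batch_size)  # torch.Size([5, 2]); dim 0 remains the extension dim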
@pytest.mark.parametrize("device_data", get_default_devices())
@pytest.mark.parametrize("storage_type", [LazyMemmapStorage, LazyTensorStorage])
@pytest.mark.parametrize("data_type", ["tensor", "tc", "td", "pytree"])
@@ -1155,17 +1205,115 @@ def test_replay_buffer_trajectories(stack, reduction, datatype):
# sampled_td_filtered.batch_size = [3, 4]
+class TestRNG:
+ def test_rb_rng(self):
+ state = torch.random.get_rng_state()
+ rb = ReplayBufferRNG(sampler=RandomSampler(), storage=LazyTensorStorage(100))
+ rb.extend(torch.arange(100))
+ rb._rng.set_state(state)
+ a = rb.sample(32)
+ rb._rng.set_state(state)
+ b = rb.sample(32)
+ assert (a == b).all()
+ c = rb.sample(32)
+ assert (a != c).any()
+
+ def test_prb_rng(self):
+ state = torch.random.get_rng_state()
+ rb = ReplayBuffer(
+ sampler=PrioritizedSampler(100, 1.0, 1.0),
+ storage=LazyTensorStorage(100),
+ generator=torch.Generator(),
+ )
+ rb.extend(torch.arange(100))
+ rb.update_priority(index=torch.arange(100), priority=torch.arange(1, 101))
+
+ rb._rng.set_state(state)
+ a = rb.sample(32)
+
+ rb._rng.set_state(state)
+ b = rb.sample(32)
+ assert (a == b).all()
+
+ c = rb.sample(32)
+ assert (a != c).any()
+
+ def test_slice_rng(self):
+ state = torch.random.get_rng_state()
+ rb = ReplayBuffer(
+ sampler=SliceSampler(num_slices=4),
+ storage=LazyTensorStorage(100),
+ generator=torch.Generator(),
+ )
+ done = torch.zeros(100, 1, dtype=torch.bool)
+ done[49] = 1
+ done[-1] = 1
+ data = TensorDict(
+ {
+ "data": torch.arange(100),
+ ("next", "done"): done,
+ },
+ batch_size=[100],
+ )
+ rb.extend(data)
+
+ rb._rng.set_state(state)
+ a = rb.sample(32)
+
+ rb._rng.set_state(state)
+ b = rb.sample(32)
+ assert (a == b).all()
+
+ c = rb.sample(32)
+ assert (a != c).any()
+
+ def test_rng_state_dict(self):
+ state = torch.random.get_rng_state()
+ rb = ReplayBufferRNG(sampler=RandomSampler(), storage=LazyTensorStorage(100))
+ rb.extend(torch.arange(100))
+ rb._rng.set_state(state)
+ sd = rb.state_dict()
+ assert sd.get("_rng") is not None
+ a = rb.sample(32)
+
+ rb.load_state_dict(sd)
+ b = rb.sample(32)
+ assert (a == b).all()
+ c = rb.sample(32)
+ assert (a != c).any()
+
+ def test_rng_dumps(self, tmpdir):
+ state = torch.random.get_rng_state()
+ rb = ReplayBufferRNG(sampler=RandomSampler(), storage=LazyTensorStorage(100))
+ rb.extend(torch.arange(100))
+ rb._rng.set_state(state)
+ rb.dumps(tmpdir)
+ a = rb.sample(32)
+
+ rb.loads(tmpdir)
+ b = rb.sample(32)
+ assert (a == b).all()
+ c = rb.sample(32)
+ assert (a != c).any()
+
+
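The generator-aware buffers above thread a private torch.Generator (exposed as `_rng`) through sampling, so draws can be replayed without disturbing the global RNG. A minimal usage sketch:

import torch
from torchrl.data import LazyTensorStorage, ReplayBuffer

gen = torch.Generator()
gen.manual_seed(0)
rb = ReplayBuffer(storage=LazyTensorStorage(100), generator=gen)
rb.extend(torch.arange(100))

state = rb._rng.get_state()
a = rb.sample(32)
rb._rng.set_state(state)  # rewind the buffer-private generator
b = rb.sample(32)
assert (a == b).all()  # identical draws; the global torch RNG was never touched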
@pytest.mark.parametrize(
"rbtype,storage",
[
(ReplayBuffer, None),
(ReplayBuffer, ListStorage),
+ (ReplayBufferRNG, None),
+ (ReplayBufferRNG, ListStorage),
(PrioritizedReplayBuffer, None),
(PrioritizedReplayBuffer, ListStorage),
(TensorDictReplayBuffer, None),
(TensorDictReplayBuffer, ListStorage),
(TensorDictReplayBuffer, LazyTensorStorage),
(TensorDictReplayBuffer, LazyMemmapStorage),
+ (TensorDictReplayBufferRNG, None),
+ (TensorDictReplayBufferRNG, ListStorage),
+ (TensorDictReplayBufferRNG, LazyTensorStorage),
+ (TensorDictReplayBufferRNG, LazyMemmapStorage),
(TensorDictPrioritizedReplayBuffer, None),
(TensorDictPrioritizedReplayBuffer, ListStorage),
(TensorDictPrioritizedReplayBuffer, LazyTensorStorage),
@@ -1175,33 +1323,34 @@ def test_replay_buffer_trajectories(stack, reduction, datatype):
@pytest.mark.parametrize("size", [3, 5, 100])
@pytest.mark.parametrize("prefetch", [0])
class TestBuffers:
- _default_params_rb = {}
- _default_params_td_rb = {}
- _default_params_prb = {"alpha": 0.8, "beta": 0.9}
- _default_params_td_prb = {"alpha": 0.8, "beta": 0.9}
+
+ default_constr = {
+ ReplayBuffer: ReplayBuffer,
+ PrioritizedReplayBuffer: functools.partial(
+ PrioritizedReplayBuffer, alpha=0.8, beta=0.9
+ ),
+ TensorDictReplayBuffer: TensorDictReplayBuffer,
+ TensorDictPrioritizedReplayBuffer: functools.partial(
+ TensorDictPrioritizedReplayBuffer, alpha=0.8, beta=0.9
+ ),
+ TensorDictReplayBufferRNG: TensorDictReplayBufferRNG,
+ ReplayBufferRNG: ReplayBufferRNG,
+ }
def _get_rb(self, rbtype, size, storage, prefetch):
if storage is not None:
storage = storage(size)
- if rbtype is ReplayBuffer:
- params = self._default_params_rb
- elif rbtype is PrioritizedReplayBuffer:
- params = self._default_params_prb
- elif rbtype is TensorDictReplayBuffer:
- params = self._default_params_td_rb
- elif rbtype is TensorDictPrioritizedReplayBuffer:
- params = self._default_params_td_prb
- else:
- raise NotImplementedError(rbtype)
- rb = rbtype(storage=storage, prefetch=prefetch, batch_size=3, **params)
+ rb = self.default_constr[rbtype](
+ storage=storage, prefetch=prefetch, batch_size=3
+ )
return rb
def _get_datum(self, rbtype):
- if rbtype is ReplayBuffer:
+ if rbtype in (ReplayBuffer, ReplayBufferRNG):
data = torch.randint(100, (1,))
elif rbtype is PrioritizedReplayBuffer:
data = torch.randint(100, (1,))
- elif rbtype is TensorDictReplayBuffer:
+ elif rbtype in (TensorDictReplayBuffer, TensorDictReplayBufferRNG):
data = TensorDict({"a": torch.randint(100, (1,))}, [])
elif rbtype is TensorDictPrioritizedReplayBuffer:
data = TensorDict({"a": torch.randint(100, (1,))}, [])
@@ -1210,11 +1359,11 @@ def _get_datum(self, rbtype):
return data
def _get_data(self, rbtype, size):
- if rbtype is ReplayBuffer:
+ if rbtype in (ReplayBuffer, ReplayBufferRNG):
data = [torch.randint(100, (1,)) for _ in range(size)]
elif rbtype is PrioritizedReplayBuffer:
data = [torch.randint(100, (1,)) for _ in range(size)]
- elif rbtype is TensorDictReplayBuffer:
+ elif rbtype in (TensorDictReplayBuffer, TensorDictReplayBufferRNG):
data = TensorDict(
{
"a": torch.randint(100, (size,)),
@@ -1627,10 +1776,8 @@ def test_insert_transform(self):
not _has_tv, reason="needs torchvision dependency"
),
),
- pytest.param(
- partial(UnsqueezeTransform, unsqueeze_dim=-1), id="UnsqueezeTransform"
- ),
- pytest.param(partial(SqueezeTransform, squeeze_dim=-1), id="SqueezeTransform"),
+ pytest.param(partial(UnsqueezeTransform, dim=-1), id="UnsqueezeTransform"),
+ pytest.param(partial(SqueezeTransform, dim=-1), id="SqueezeTransform"),
GrayScale,
pytest.param(partial(ObservationNorm, loc=1, scale=2), id="ObservationNorm"),
pytest.param(partial(CatFrames, dim=-3, N=4), id="CatFrames"),
@@ -1950,13 +2097,16 @@ def exec_multiproc_rb(
init=True,
writer_type=TensorDictRoundRobinWriter,
sampler_type=RandomSampler,
+ device=None,
):
rb = TensorDictReplayBuffer(
storage=storage_type(21), writer=writer_type(), sampler=sampler_type()
)
if init:
td = TensorDict(
- {"a": torch.zeros(10), "next": {"reward": torch.ones(10)}}, [10]
+ {"a": torch.zeros(10), "next": {"reward": torch.ones(10)}},
+ [10],
+ device=device,
)
rb.extend(td)
q0 = mp.Queue(1)
@@ -1984,13 +2134,6 @@ def test_error_list(self):
with pytest.raises(RuntimeError, match="Cannot share a storage of type"):
self.exec_multiproc_rb(storage_type=ListStorage)
- def test_error_nonshared(self):
- # non shared tensor storage cannot be shared
- with pytest.raises(
- RuntimeError, match="The storage must be place in shared memory"
- ):
- self.exec_multiproc_rb(storage_type=LazyTensorStorage)
-
def test_error_maxwriter(self):
# TensorDictMaxValueWriter cannot be shared
with pytest.raises(RuntimeError, match="cannot be shared between processes"):
@@ -2902,6 +3045,76 @@ def test_prioritized_slice_sampler_episodes(device):
), "after priority update, only episode 1 and 3 are expected to be sampled"
+@pytest.mark.parametrize("alpha", [0.6, torch.tensor(1.0)])
+@pytest.mark.parametrize("beta", [0.7, torch.tensor(0.1)])
+@pytest.mark.parametrize("gamma", [0.1])
+@pytest.mark.parametrize("total_steps", [200])
+@pytest.mark.parametrize("n_annealing_steps", [100])
+@pytest.mark.parametrize("anneal_every_n", [10, 159])
+@pytest.mark.parametrize("alpha_min", [0, 0.2])
+@pytest.mark.parametrize("beta_max", [1, 1.4])
+def test_prioritized_parameter_scheduler(
+ alpha,
+ beta,
+ gamma,
+ total_steps,
+ n_annealing_steps,
+ anneal_every_n,
+ alpha_min,
+ beta_max,
+):
+ rb = TensorDictPrioritizedReplayBuffer(
+ alpha=alpha, beta=beta, storage=ListStorage(max_size=1000)
+ )
+ data = TensorDict({"data": torch.randn(1000, 5)}, batch_size=1000)
+ rb.extend(data)
+ alpha_scheduler = LinearScheduler(
+ rb, param_name="alpha", final_value=alpha_min, num_steps=n_annealing_steps
+ )
+ beta_scheduler = StepScheduler(
+ rb,
+ param_name="beta",
+ gamma=gamma,
+ n_steps=anneal_every_n,
+ max_value=beta_max,
+ mode="additive",
+ )
+
+ scheduler = SchedulerList(schedulers=(alpha_scheduler, beta_scheduler))
+
+ alpha = alpha if torch.is_tensor(alpha) else torch.tensor(alpha)
+ alpha_min = torch.tensor(alpha_min)
+ expected_alpha_vals = torch.linspace(alpha, alpha_min, n_annealing_steps + 1)
+ expected_alpha_vals = torch.nn.functional.pad(
+ expected_alpha_vals, (0, total_steps - n_annealing_steps), value=alpha_min
+ )
+
+ annealing_steps = total_steps // anneal_every_n
+ gammas = torch.arange(0, annealing_steps + 1, dtype=torch.float32) * gamma
+ expected_beta_vals = (
+ (beta + gammas).repeat_interleave(anneal_every_n).clip(None, beta_max)
+ )
+ for i in range(total_steps):
+ curr_alpha = rb.sampler.alpha
+ torch.testing.assert_close(
+ curr_alpha
+ if torch.is_tensor(curr_alpha)
+ else torch.tensor(curr_alpha).float(),
+ expected_alpha_vals[i],
+ msg=f"expected {expected_alpha_vals[i]}, got {curr_alpha}",
+ )
+ curr_beta = rb.sampler.beta
+ torch.testing.assert_close(
+ curr_beta
+ if torch.is_tensor(curr_beta)
+ else torch.tensor(curr_beta).float(),
+ expected_beta_vals[i],
+ msg=f"expected {expected_beta_vals[i]}, got {curr_beta}",
+ )
+ rb.sample(20)
+ scheduler.step()
+
+
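The test above drives both new scheduler flavors through a SchedulerList: LinearScheduler interpolates a sampler parameter towards final_value over num_steps calls, and StepScheduler nudges one by gamma every n_steps calls (additively here), clipped at max_value. A compact usage sketch mirroring a training loop:

import torch
from tensordict import TensorDict
from torchrl.data import ListStorage, TensorDictPrioritizedReplayBuffer
from torchrl.data.replay_buffers.scheduler import (
    LinearScheduler,
    SchedulerList,
    StepScheduler,
)

rb = TensorDictPrioritizedReplayBuffer(
    alpha=0.6, beta=0.4, storage=ListStorage(max_size=1000)
)
rb.extend(TensorDict({"data": torch.randn(1000, 5)}, batch_size=1000))

schedulers = SchedulerList(
    schedulers=(
        LinearScheduler(rb, param_name="alpha", final_value=0.0, num_steps=100),
        StepScheduler(
            rb, param_name="beta", gamma=0.1, n_steps=10, max_value=1.0, mode="additive"
        ),
    )
)
for _ in range(200):
    rb.sample(20)
    schedulers.step()  # anneal alpha down, step beta up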
class TestEnsemble:
def _make_data(self, data_type):
if data_type is torch.Tensor:
diff --git a/test/test_specs.py b/test/test_specs.py
index 2d597d770f0..82d2b7f2e1d 100644
--- a/test/test_specs.py
+++ b/test/test_specs.py
@@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.
import argparse
import contextlib
+import warnings
import numpy as np
import pytest
@@ -14,19 +15,32 @@
from tensordict import LazyStackedTensorDict, TensorDict, TensorDictBase
from tensordict.utils import _unravel_key_to_tuple
from torchrl._utils import _make_ordinal_device
+
from torchrl.data.tensor_specs import (
_keys_to_empty_composite_spec,
+ Binary,
BinaryDiscreteTensorSpec,
+ Bounded,
BoundedTensorSpec,
+ Categorical,
+ Composite,
CompositeSpec,
+ ContinuousBox,
DiscreteTensorSpec,
- LazyStackedCompositeSpec,
+ MultiCategorical,
MultiDiscreteTensorSpec,
+ MultiOneHot,
MultiOneHotDiscreteTensorSpec,
+ NonTensor,
NonTensorSpec,
+ OneHot,
OneHotDiscreteTensorSpec,
+ StackedComposite,
TensorSpec,
+ Unbounded,
+ UnboundedContinuous,
UnboundedContinuousTensorSpec,
+ UnboundedDiscrete,
UnboundedDiscreteTensorSpec,
)
from torchrl.data.utils import check_no_exclusive_keys, consolidate_spec
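Old and new spec names are imported side by side because the tests below exercise the rename. The correspondence, as the paired imports suggest:

from torchrl.data.tensor_specs import (
    Binary,  # replaces BinaryDiscreteTensorSpec
    Bounded,  # replaces BoundedTensorSpec
    Categorical,  # replaces DiscreteTensorSpec
    Composite,  # replaces CompositeSpec
    MultiCategorical,  # replaces MultiDiscreteTensorSpec
    MultiOneHot,  # replaces MultiOneHotDiscreteTensorSpec
    NonTensor,  # replaces NonTensorSpec
    OneHot,  # replaces OneHotDiscreteTensorSpec
    StackedComposite,  # replaces LazyStackedCompositeSpec
    UnboundedContinuous,  # replaces UnboundedContinuousTensorSpec
    UnboundedDiscrete,  # replaces UnboundedDiscreteTensorSpec
)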
@@ -38,9 +52,7 @@ def test_bounded(dtype):
np.random.seed(0)
for _ in range(100):
bounds = torch.randn(2).sort()[0]
- ts = BoundedTensorSpec(
- bounds[0].item(), bounds[1].item(), torch.Size((1,)), dtype=dtype
- )
+ ts = Bounded(bounds[0].item(), bounds[1].item(), torch.Size((1,)), dtype=dtype)
_dtype = dtype
if dtype is None:
_dtype = torch.get_default_dtype()
@@ -53,7 +65,7 @@ def test_bounded(dtype):
assert (ts.encode(ts.to_numpy(r)) == r).all()
-@pytest.mark.parametrize("cls", [OneHotDiscreteTensorSpec, DiscreteTensorSpec])
+@pytest.mark.parametrize("cls", [OneHot, Categorical])
def test_discrete(cls):
torch.manual_seed(0)
np.random.seed(0)
@@ -78,7 +90,7 @@ def test_discrete(cls):
def test_unbounded(dtype):
torch.manual_seed(0)
np.random.seed(0)
- ts = UnboundedContinuousTensorSpec(dtype=dtype)
+ ts = Unbounded(dtype=dtype)
if dtype is None:
dtype = torch.get_default_dtype()
@@ -99,7 +111,7 @@ def test_ndbounded(dtype, shape):
for _ in range(100):
lb = torch.rand(10) - 1
ub = torch.rand(10) + 1
- ts = BoundedTensorSpec(lb, ub, dtype=dtype)
+ ts = Bounded(lb, ub, dtype=dtype)
_dtype = dtype
if dtype is None:
_dtype = torch.get_default_dtype()
@@ -150,7 +162,7 @@ def test_ndunbounded(dtype, n, shape):
torch.manual_seed(0)
np.random.seed(0)
- ts = UnboundedContinuousTensorSpec(
+ ts = Unbounded(
shape=[
n,
],
@@ -195,7 +207,7 @@ def test_binary(n, shape):
torch.manual_seed(0)
np.random.seed(0)
- ts = BinaryDiscreteTensorSpec(n)
+ ts = Binary(n)
for _ in range(100):
r = ts.rand(shape)
assert r.shape == torch.Size(
@@ -238,7 +250,7 @@ def test_binary(n, shape):
def test_mult_onehot(shape, ns):
torch.manual_seed(0)
np.random.seed(0)
- ts = MultiOneHotDiscreteTensorSpec(nvec=ns)
+ ts = MultiOneHot(nvec=ns)
for _ in range(100):
r = ts.rand(shape)
assert r.shape == torch.Size(
@@ -279,7 +291,7 @@ def test_mult_onehot(shape, ns):
def test_multi_discrete(shape, ns, dtype):
torch.manual_seed(0)
np.random.seed(0)
- ts = MultiDiscreteTensorSpec(ns, dtype=dtype)
+ ts = MultiCategorical(ns, dtype=dtype)
_real_shape = shape if shape is not None else []
nvec_shape = torch.tensor(ns).size()
for _ in range(100):
@@ -315,9 +327,9 @@ def test_multi_discrete(shape, ns, dtype):
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("shape", [None, [], [1], [1, 2]])
def test_discrete_conversion(n, device, shape):
- categorical = DiscreteTensorSpec(n, device=device, shape=shape)
+ categorical = Categorical(n, device=device, shape=shape)
shape_one_hot = [n] if not shape else [*shape, n]
- one_hot = OneHotDiscreteTensorSpec(n, device=device, shape=shape_one_hot)
+ one_hot = OneHot(n, device=device, shape=shape_one_hot)
assert categorical != one_hot
assert categorical.to_one_hot_spec() == one_hot
@@ -333,8 +345,8 @@ def test_discrete_conversion(n, device, shape):
@pytest.mark.parametrize("shape", [torch.Size([3]), torch.Size([4, 5])])
@pytest.mark.parametrize("device", get_default_devices())
def test_multi_discrete_conversion(ns, shape, device):
- categorical = MultiDiscreteTensorSpec(ns, device=device)
- one_hot = MultiOneHotDiscreteTensorSpec(ns, device=device)
+ categorical = MultiCategorical(ns, device=device)
+ one_hot = MultiOneHot(ns, device=device)
assert categorical != one_hot
assert categorical.to_one_hot_spec() == one_hot
@@ -356,14 +368,14 @@ def _composite_spec(shape, is_complete=True, device=None, dtype=None):
torch.manual_seed(0)
np.random.seed(0)
- return CompositeSpec(
- obs=BoundedTensorSpec(
+ return Composite(
+ obs=Bounded(
torch.zeros(*shape, 3, 32, 32),
torch.ones(*shape, 3, 32, 32),
dtype=dtype,
device=device,
),
- act=UnboundedContinuousTensorSpec(
+ act=Unbounded(
(
*shape,
7,
@@ -379,9 +391,9 @@ def _composite_spec(shape, is_complete=True, device=None, dtype=None):
def test_getitem(self, shape, is_complete, device, dtype):
ts = self._composite_spec(shape, is_complete, device, dtype)
- assert isinstance(ts["obs"], BoundedTensorSpec)
+ assert isinstance(ts["obs"], Bounded)
if is_complete:
- assert isinstance(ts["act"], UnboundedContinuousTensorSpec)
+ assert isinstance(ts["act"], Unbounded)
else:
assert ts["act"] is None
with pytest.raises(KeyError):
@@ -397,21 +409,17 @@ def test_setitem_forbidden_keys(self, shape, is_complete, device, dtype):
def test_setitem_matches_device(self, shape, is_complete, device, dtype, dest):
ts = self._composite_spec(shape, is_complete, device, dtype)
- ts["good"] = UnboundedContinuousTensorSpec(
- shape=shape, device=device, dtype=dtype
- )
+ ts["good"] = Unbounded(shape=shape, device=device, dtype=dtype)
cm = (
contextlib.nullcontext()
if (device == dest) or (device is None)
else pytest.raises(
- RuntimeError, match="All devices of CompositeSpec must match"
+ RuntimeError, match="All devices of Composite must match"
)
)
with cm:
# auto-casting is introduced since v0.3
- ts["bad"] = UnboundedContinuousTensorSpec(
- shape=shape, device=dest, dtype=dtype
- )
+ ts["bad"] = Unbounded(shape=shape, device=dest, dtype=dtype)
assert ts.device == device
assert ts["good"].device == (
device if device is not None else torch.zeros(()).device
@@ -490,7 +498,7 @@ def test_rand(self, shape, is_complete, device, dtype, shape_other):
def test_repr(self, shape, is_complete, device, dtype):
ts = self._composite_spec(shape, is_complete, device, dtype)
output = repr(ts)
- assert output.startswith("CompositeSpec")
+ assert output.startswith("Composite")
assert "obs: " in output
assert "act: " in output
@@ -606,7 +614,7 @@ def test_nested_composite_spec_delitem(self, shape, is_complete, device, dtype):
def test_nested_composite_spec_update(self, shape, is_complete, device, dtype):
ts = self._composite_spec(shape, is_complete, device, dtype)
ts["nested_cp"] = self._composite_spec(shape, is_complete, device, dtype)
- td2 = CompositeSpec(new=None)
+ td2 = Composite(new=None)
ts.update(td2)
assert set(ts.keys(include_nested=True)) == {
"obs",
@@ -619,7 +627,7 @@ def test_nested_composite_spec_update(self, shape, is_complete, device, dtype):
ts = self._composite_spec(shape, is_complete, device, dtype)
ts["nested_cp"] = self._composite_spec(shape, is_complete, device, dtype)
- td2 = CompositeSpec(nested_cp=CompositeSpec(new=None).to(device))
+ td2 = Composite(nested_cp=Composite(new=None).to(device))
ts.update(td2)
assert set(ts.keys(include_nested=True)) == {
"obs",
@@ -632,7 +640,7 @@ def test_nested_composite_spec_update(self, shape, is_complete, device, dtype):
ts = self._composite_spec(shape, is_complete, device, dtype)
ts["nested_cp"] = self._composite_spec(shape, is_complete, device, dtype)
- td2 = CompositeSpec(nested_cp=CompositeSpec(act=None).to(device))
+ td2 = Composite(nested_cp=Composite(act=None).to(device))
ts.update(td2)
assert set(ts.keys(include_nested=True)) == {
"obs",
@@ -645,13 +653,13 @@ def test_nested_composite_spec_update(self, shape, is_complete, device, dtype):
ts = self._composite_spec(shape, is_complete, device, dtype)
ts["nested_cp"] = self._composite_spec(shape, is_complete, device, dtype)
- td2 = CompositeSpec(
- nested_cp=CompositeSpec(act=None, shape=shape).to(device), shape=shape
+ td2 = Composite(
+ nested_cp=Composite(act=None, shape=shape).to(device), shape=shape
)
ts.update(td2)
- td2 = CompositeSpec(
- nested_cp=CompositeSpec(
- act=UnboundedContinuousTensorSpec(shape=shape, device=device),
+ td2 = Composite(
+ nested_cp=Composite(
+ act=Unbounded(shape=shape, device=device),
shape=shape,
),
shape=shape,
@@ -668,8 +676,8 @@ def test_nested_composite_spec_update(self, shape, is_complete, device, dtype):
def test_change_batch_size(self, shape, is_complete, device, dtype):
ts = self._composite_spec(shape, is_complete, device, dtype)
- ts["nested"] = CompositeSpec(
- leaf=UnboundedContinuousTensorSpec(shape, device=device),
+ ts["nested"] = Composite(
+ leaf=Unbounded(shape, device=device),
shape=shape,
device=device,
)
@@ -690,12 +698,12 @@ def test_change_batch_size(self, shape, is_complete, device, dtype):
@pytest.mark.parametrize("device", get_default_devices())
def test_create_composite_nested(shape, device):
d = [
- {("a", "b"): UnboundedContinuousTensorSpec(shape=shape, device=device)},
- {"a": {"b": UnboundedContinuousTensorSpec(shape=shape, device=device)}},
+ {("a", "b"): Unbounded(shape=shape, device=device)},
+ {"a": {"b": Unbounded(shape=shape, device=device)}},
]
for _d in d:
- c = CompositeSpec(_d, shape=shape)
- assert isinstance(c["a", "b"], UnboundedContinuousTensorSpec)
+ c = Composite(_d, shape=shape)
+ assert isinstance(c["a", "b"], Unbounded)
assert c["a"].shape == torch.Size(shape)
assert c.device is None # device not explicitly passed
assert c["a"].device is None # device not explicitly passed
@@ -708,10 +716,8 @@ def test_create_composite_nested(shape, device):
@pytest.mark.parametrize("recurse", [True, False])
def test_lock(recurse):
shape = [3, 4, 5]
- spec = CompositeSpec(
- a=CompositeSpec(
- b=CompositeSpec(shape=shape[:3], device="cpu"), shape=shape[:2]
- ),
+ spec = Composite(
+ a=Composite(b=Composite(shape=shape[:3], device="cpu"), shape=shape[:2]),
shape=shape[:1],
)
spec["a"] = spec["a"].clone()
@@ -719,15 +725,15 @@ def test_lock(recurse):
assert not spec.locked
spec.lock_(recurse=recurse)
assert spec.locked
- with pytest.raises(RuntimeError, match="Cannot modify a locked CompositeSpec."):
+ with pytest.raises(RuntimeError, match="Cannot modify a locked Composite."):
spec["a"] = spec["a"].clone()
- with pytest.raises(RuntimeError, match="Cannot modify a locked CompositeSpec."):
+ with pytest.raises(RuntimeError, match="Cannot modify a locked Composite."):
spec.set("a", spec["a"].clone())
if recurse:
assert spec["a"].locked
- with pytest.raises(RuntimeError, match="Cannot modify a locked CompositeSpec."):
+ with pytest.raises(RuntimeError, match="Cannot modify a locked Composite."):
spec["a"].set("b", spec["a", "b"].clone())
- with pytest.raises(RuntimeError, match="Cannot modify a locked CompositeSpec."):
+ with pytest.raises(RuntimeError, match="Cannot modify a locked Composite."):
spec["a", "b"] = spec["a", "b"].clone()
else:
assert not spec["a"].locked
@@ -763,33 +769,25 @@ def test_equality_bounded(self):
device = "cpu"
dtype = torch.float16
- ts = BoundedTensorSpec(minimum, maximum, torch.Size((1,)), device, dtype)
+ ts = Bounded(minimum, maximum, torch.Size((1,)), device, dtype)
- ts_same = BoundedTensorSpec(minimum, maximum, torch.Size((1,)), device, dtype)
+ ts_same = Bounded(minimum, maximum, torch.Size((1,)), device, dtype)
assert ts == ts_same
- ts_other = BoundedTensorSpec(
- minimum + 1, maximum, torch.Size((1,)), device, dtype
- )
+ ts_other = Bounded(minimum + 1, maximum, torch.Size((1,)), device, dtype)
assert ts != ts_other
- ts_other = BoundedTensorSpec(
- minimum, maximum + 1, torch.Size((1,)), device, dtype
- )
+ ts_other = Bounded(minimum, maximum + 1, torch.Size((1,)), device, dtype)
assert ts != ts_other
if torch.cuda.device_count():
- ts_other = BoundedTensorSpec(
- minimum, maximum, torch.Size((1,)), "cuda:0", dtype
- )
+ ts_other = Bounded(minimum, maximum, torch.Size((1,)), "cuda:0", dtype)
assert ts != ts_other
- ts_other = BoundedTensorSpec(
- minimum, maximum, torch.Size((1,)), device, torch.float64
- )
+ ts_other = Bounded(minimum, maximum, torch.Size((1,)), device, torch.float64)
assert ts != ts_other
ts_other = TestEquality._ts_make_all_fields_equal(
- UnboundedContinuousTensorSpec(device=device, dtype=dtype), ts
+ Unbounded(device=device, dtype=dtype), ts
)
assert ts != ts_other
@@ -799,38 +797,34 @@ def test_equality_onehot(self):
dtype = torch.float16
use_register = False
- ts = OneHotDiscreteTensorSpec(
- n=n, device=device, dtype=dtype, use_register=use_register
- )
+ ts = OneHot(n=n, device=device, dtype=dtype, use_register=use_register)
- ts_same = OneHotDiscreteTensorSpec(
- n=n, device=device, dtype=dtype, use_register=use_register
- )
+ ts_same = OneHot(n=n, device=device, dtype=dtype, use_register=use_register)
assert ts == ts_same
- ts_other = OneHotDiscreteTensorSpec(
+ ts_other = OneHot(
n=n + 1, device=device, dtype=dtype, use_register=use_register
)
assert ts != ts_other
if torch.cuda.device_count():
- ts_other = OneHotDiscreteTensorSpec(
+ ts_other = OneHot(
n=n, device="cuda:0", dtype=dtype, use_register=use_register
)
assert ts != ts_other
- ts_other = OneHotDiscreteTensorSpec(
+ ts_other = OneHot(
n=n, device=device, dtype=torch.float64, use_register=use_register
)
assert ts != ts_other
- ts_other = OneHotDiscreteTensorSpec(
+ ts_other = OneHot(
n=n, device=device, dtype=dtype, use_register=not use_register
)
assert ts != ts_other
ts_other = TestEquality._ts_make_all_fields_equal(
- UnboundedContinuousTensorSpec(device=device, dtype=dtype), ts
+ Unbounded(device=device, dtype=dtype), ts
)
assert ts != ts_other
@@ -838,21 +832,25 @@ def test_equality_unbounded(self):
device = "cpu"
dtype = torch.float16
- ts = UnboundedContinuousTensorSpec(device=device, dtype=dtype)
+ ts = Unbounded(device=device, dtype=dtype)
- ts_same = UnboundedContinuousTensorSpec(device=device, dtype=dtype)
+ ts_same = Unbounded(device=device, dtype=dtype)
assert ts == ts_same
if torch.cuda.device_count():
- ts_other = UnboundedContinuousTensorSpec(device="cuda:0", dtype=dtype)
+ ts_other = Unbounded(device="cuda:0", dtype=dtype)
assert ts != ts_other
- ts_other = UnboundedContinuousTensorSpec(device=device, dtype=torch.float64)
+ ts_other = Unbounded(device=device, dtype=torch.float64)
assert ts != ts_other
ts_other = TestEquality._ts_make_all_fields_equal(
- BoundedTensorSpec(0, 1, torch.Size((1,)), device, dtype), ts
+ Bounded(0, 1, torch.Size((1,)), device, dtype), ts
+ )
+ ts_other.space = ContinuousBox(
+ ts_other.space.low * 0, ts_other.space.high * 0 + 1
)
+ assert ts.space != ts_other.space, (ts.space, ts_other.space)
assert ts != ts_other
def test_equality_ndbounded(self):
@@ -861,36 +859,28 @@ def test_equality_ndbounded(self):
device = "cpu"
dtype = torch.float16
- ts = BoundedTensorSpec(low=minimum, high=maximum, device=device, dtype=dtype)
+ ts = Bounded(low=minimum, high=maximum, device=device, dtype=dtype)
- ts_same = BoundedTensorSpec(
- low=minimum, high=maximum, device=device, dtype=dtype
- )
+ ts_same = Bounded(low=minimum, high=maximum, device=device, dtype=dtype)
assert ts == ts_same
- ts_other = BoundedTensorSpec(
- low=minimum + 1, high=maximum, device=device, dtype=dtype
- )
+ ts_other = Bounded(low=minimum + 1, high=maximum, device=device, dtype=dtype)
assert ts != ts_other
- ts_other = BoundedTensorSpec(
- low=minimum, high=maximum + 1, device=device, dtype=dtype
- )
+ ts_other = Bounded(low=minimum, high=maximum + 1, device=device, dtype=dtype)
assert ts != ts_other
if torch.cuda.device_count():
- ts_other = BoundedTensorSpec(
- low=minimum, high=maximum, device="cuda:0", dtype=dtype
- )
+ ts_other = Bounded(low=minimum, high=maximum, device="cuda:0", dtype=dtype)
assert ts != ts_other
- ts_other = BoundedTensorSpec(
+ ts_other = Bounded(
low=minimum, high=maximum, device=device, dtype=torch.float64
)
assert ts != ts_other
ts_other = TestEquality._ts_make_all_fields_equal(
- UnboundedContinuousTensorSpec(device=device, dtype=dtype), ts
+ Unbounded(device=device, dtype=dtype), ts
)
assert ts != ts_other
@@ -900,32 +890,28 @@ def test_equality_discrete(self):
device = "cpu"
dtype = torch.float16
- ts = DiscreteTensorSpec(n=n, shape=shape, device=device, dtype=dtype)
+ ts = Categorical(n=n, shape=shape, device=device, dtype=dtype)
- ts_same = DiscreteTensorSpec(n=n, shape=shape, device=device, dtype=dtype)
+ ts_same = Categorical(n=n, shape=shape, device=device, dtype=dtype)
assert ts == ts_same
- ts_other = DiscreteTensorSpec(n=n + 1, shape=shape, device=device, dtype=dtype)
+ ts_other = Categorical(n=n + 1, shape=shape, device=device, dtype=dtype)
assert ts != ts_other
if torch.cuda.device_count():
- ts_other = DiscreteTensorSpec(
- n=n, shape=shape, device="cuda:0", dtype=dtype
- )
+ ts_other = Categorical(n=n, shape=shape, device="cuda:0", dtype=dtype)
assert ts != ts_other
- ts_other = DiscreteTensorSpec(
- n=n, shape=shape, device=device, dtype=torch.float64
- )
+ ts_other = Categorical(n=n, shape=shape, device=device, dtype=torch.float64)
assert ts != ts_other
- ts_other = DiscreteTensorSpec(
+ ts_other = Categorical(
n=n, shape=torch.Size([2]), device=device, dtype=torch.float64
)
assert ts != ts_other
ts_other = TestEquality._ts_make_all_fields_equal(
- UnboundedContinuousTensorSpec(device=device, dtype=dtype), ts
+ Unbounded(device=device, dtype=dtype), ts
)
assert ts != ts_other
@@ -941,30 +927,24 @@ def test_equality_ndunbounded(self, shape):
device = "cpu"
dtype = torch.float16
- ts = UnboundedContinuousTensorSpec(shape=shape, device=device, dtype=dtype)
+ ts = Unbounded(shape=shape, device=device, dtype=dtype)
- ts_same = UnboundedContinuousTensorSpec(shape=shape, device=device, dtype=dtype)
+ ts_same = Unbounded(shape=shape, device=device, dtype=dtype)
assert ts == ts_same
- other_shape = 13 if type(shape) == int else torch.Size(np.array(shape) + 10)
- ts_other = UnboundedContinuousTensorSpec(
- shape=other_shape, device=device, dtype=dtype
- )
+ other_shape = 13 if isinstance(shape, int) else torch.Size(np.array(shape) + 10)
+ ts_other = Unbounded(shape=other_shape, device=device, dtype=dtype)
assert ts != ts_other
if torch.cuda.device_count():
- ts_other = UnboundedContinuousTensorSpec(
- shape=shape, device="cuda:0", dtype=dtype
- )
+ ts_other = Unbounded(shape=shape, device="cuda:0", dtype=dtype)
assert ts != ts_other
- ts_other = UnboundedContinuousTensorSpec(
- shape=shape, device=device, dtype=torch.float64
- )
+ ts_other = Unbounded(shape=shape, device=device, dtype=torch.float64)
assert ts != ts_other
ts_other = TestEquality._ts_make_all_fields_equal(
- BoundedTensorSpec(0, 1, torch.Size((1,)), device, dtype), ts
+ Bounded(0, 1, torch.Size((1,)), device, dtype), ts
)
# Unbounded and bounded without space are technically the same
assert ts == ts_other
@@ -974,23 +954,23 @@ def test_equality_binary(self):
device = "cpu"
dtype = torch.float16
- ts = BinaryDiscreteTensorSpec(n=n, device=device, dtype=dtype)
+ ts = Binary(n=n, device=device, dtype=dtype)
- ts_same = BinaryDiscreteTensorSpec(n=n, device=device, dtype=dtype)
+ ts_same = Binary(n=n, device=device, dtype=dtype)
assert ts == ts_same
- ts_other = BinaryDiscreteTensorSpec(n=n + 5, device=device, dtype=dtype)
+ ts_other = Binary(n=n + 5, device=device, dtype=dtype)
assert ts != ts_other
if torch.cuda.device_count():
- ts_other = BinaryDiscreteTensorSpec(n=n, device="cuda:0", dtype=dtype)
+ ts_other = Binary(n=n, device="cuda:0", dtype=dtype)
assert ts != ts_other
- ts_other = BinaryDiscreteTensorSpec(n=n, device=device, dtype=torch.float64)
+ ts_other = Binary(n=n, device=device, dtype=torch.float64)
assert ts != ts_other
ts_other = TestEquality._ts_make_all_fields_equal(
- BoundedTensorSpec(0, 1, torch.Size((1,)), device, dtype), ts
+ Bounded(0, 1, torch.Size((1,)), device, dtype), ts
)
assert ts != ts_other
@@ -999,42 +979,32 @@ def test_equality_multi_onehot(self, nvec):
device = "cpu"
dtype = torch.float16
- ts = MultiOneHotDiscreteTensorSpec(nvec=nvec, device=device, dtype=dtype)
+ ts = MultiOneHot(nvec=nvec, device=device, dtype=dtype)
- ts_same = MultiOneHotDiscreteTensorSpec(nvec=nvec, device=device, dtype=dtype)
+ ts_same = MultiOneHot(nvec=nvec, device=device, dtype=dtype)
assert ts == ts_same
other_nvec = np.array(nvec) + 3
- ts_other = MultiOneHotDiscreteTensorSpec(
- nvec=other_nvec, device=device, dtype=dtype
- )
+ ts_other = MultiOneHot(nvec=other_nvec, device=device, dtype=dtype)
assert ts != ts_other
other_nvec = [12]
- ts_other = MultiOneHotDiscreteTensorSpec(
- nvec=other_nvec, device=device, dtype=dtype
- )
+ ts_other = MultiOneHot(nvec=other_nvec, device=device, dtype=dtype)
assert ts != ts_other
other_nvec = [12, 13]
- ts_other = MultiOneHotDiscreteTensorSpec(
- nvec=other_nvec, device=device, dtype=dtype
- )
+ ts_other = MultiOneHot(nvec=other_nvec, device=device, dtype=dtype)
assert ts != ts_other
if torch.cuda.device_count():
- ts_other = MultiOneHotDiscreteTensorSpec(
- nvec=nvec, device="cuda:0", dtype=dtype
- )
+ ts_other = MultiOneHot(nvec=nvec, device="cuda:0", dtype=dtype)
assert ts != ts_other
- ts_other = MultiOneHotDiscreteTensorSpec(
- nvec=nvec, device=device, dtype=torch.float64
- )
+ ts_other = MultiOneHot(nvec=nvec, device=device, dtype=torch.float64)
assert ts != ts_other
ts_other = TestEquality._ts_make_all_fields_equal(
- BoundedTensorSpec(0, 1, torch.Size((1,)), device, dtype), ts
+ Bounded(0, 1, torch.Size((1,)), device, dtype), ts
)
assert ts != ts_other
@@ -1043,34 +1013,32 @@ def test_equality_multi_discrete(self, nvec):
device = "cpu"
dtype = torch.float16
- ts = MultiDiscreteTensorSpec(nvec=nvec, device=device, dtype=dtype)
+ ts = MultiCategorical(nvec=nvec, device=device, dtype=dtype)
- ts_same = MultiDiscreteTensorSpec(nvec=nvec, device=device, dtype=dtype)
+ ts_same = MultiCategorical(nvec=nvec, device=device, dtype=dtype)
assert ts == ts_same
other_nvec = np.array(nvec) + 3
- ts_other = MultiDiscreteTensorSpec(nvec=other_nvec, device=device, dtype=dtype)
+ ts_other = MultiCategorical(nvec=other_nvec, device=device, dtype=dtype)
assert ts != ts_other
other_nvec = [12]
- ts_other = MultiDiscreteTensorSpec(nvec=other_nvec, device=device, dtype=dtype)
+ ts_other = MultiCategorical(nvec=other_nvec, device=device, dtype=dtype)
assert ts != ts_other
other_nvec = [12, 13]
- ts_other = MultiDiscreteTensorSpec(nvec=other_nvec, device=device, dtype=dtype)
+ ts_other = MultiCategorical(nvec=other_nvec, device=device, dtype=dtype)
assert ts != ts_other
if torch.cuda.device_count():
- ts_other = MultiDiscreteTensorSpec(nvec=nvec, device="cuda:0", dtype=dtype)
+ ts_other = MultiCategorical(nvec=nvec, device="cuda:0", dtype=dtype)
assert ts != ts_other
- ts_other = MultiDiscreteTensorSpec(
- nvec=nvec, device=device, dtype=torch.float64
- )
+ ts_other = MultiCategorical(nvec=nvec, device=device, dtype=torch.float64)
assert ts != ts_other
ts_other = TestEquality._ts_make_all_fields_equal(
- BoundedTensorSpec(0, 1, torch.Size((1,)), device, dtype), ts
+ Bounded(0, 1, torch.Size((1,)), device, dtype), ts
)
assert ts != ts_other
@@ -1080,69 +1048,63 @@ def test_equality_composite(self):
device = "cpu"
dtype = torch.float16
- bounded = BoundedTensorSpec(0, 1, torch.Size((1,)), device, dtype)
- bounded_same = BoundedTensorSpec(0, 1, torch.Size((1,)), device, dtype)
- bounded_other = BoundedTensorSpec(0, 2, torch.Size((1,)), device, dtype)
+ bounded = Bounded(0, 1, torch.Size((1,)), device, dtype)
+ bounded_same = Bounded(0, 1, torch.Size((1,)), device, dtype)
+ bounded_other = Bounded(0, 2, torch.Size((1,)), device, dtype)
- nd = BoundedTensorSpec(
- low=minimum, high=maximum + 1, device=device, dtype=dtype
- )
- nd_same = BoundedTensorSpec(
- low=minimum, high=maximum + 1, device=device, dtype=dtype
- )
- _ = BoundedTensorSpec(low=minimum, high=maximum + 3, device=device, dtype=dtype)
+ nd = Bounded(low=minimum, high=maximum + 1, device=device, dtype=dtype)
+ nd_same = Bounded(low=minimum, high=maximum + 1, device=device, dtype=dtype)
+ _ = Bounded(low=minimum, high=maximum + 3, device=device, dtype=dtype)
# Equality tests
- ts = CompositeSpec(ts1=bounded)
- ts_same = CompositeSpec(ts1=bounded)
+ ts = Composite(ts1=bounded)
+ ts_same = Composite(ts1=bounded)
assert ts == ts_same
- ts = CompositeSpec(ts1=bounded)
- ts_same = CompositeSpec(ts1=bounded_same)
+ ts = Composite(ts1=bounded)
+ ts_same = Composite(ts1=bounded_same)
assert ts == ts_same
- ts = CompositeSpec(ts1=bounded, ts2=nd)
- ts_same = CompositeSpec(ts1=bounded, ts2=nd)
+ ts = Composite(ts1=bounded, ts2=nd)
+ ts_same = Composite(ts1=bounded, ts2=nd)
assert ts == ts_same
- ts = CompositeSpec(ts1=bounded, ts2=nd)
- ts_same = CompositeSpec(ts1=bounded_same, ts2=nd_same)
+ ts = Composite(ts1=bounded, ts2=nd)
+ ts_same = Composite(ts1=bounded_same, ts2=nd_same)
assert ts == ts_same
- ts = CompositeSpec(ts1=bounded, ts2=nd)
- ts_same = CompositeSpec(ts2=nd_same, ts1=bounded_same)
+ ts = Composite(ts1=bounded, ts2=nd)
+ ts_same = Composite(ts2=nd_same, ts1=bounded_same)
assert ts == ts_same
# Inequality tests
- ts = CompositeSpec(ts1=bounded)
- ts_other = CompositeSpec(ts5=bounded)
+ ts = Composite(ts1=bounded)
+ ts_other = Composite(ts5=bounded)
assert ts != ts_other
- ts = CompositeSpec(ts1=bounded)
- ts_other = CompositeSpec(ts1=bounded_other)
+ ts = Composite(ts1=bounded)
+ ts_other = Composite(ts1=bounded_other)
assert ts != ts_other
- ts = CompositeSpec(ts1=bounded)
- ts_other = CompositeSpec(ts1=nd)
+ ts = Composite(ts1=bounded)
+ ts_other = Composite(ts1=nd)
assert ts != ts_other
- ts = CompositeSpec(ts1=bounded)
- ts_other = CompositeSpec(ts1=bounded, ts2=nd)
+ ts = Composite(ts1=bounded)
+ ts_other = Composite(ts1=bounded, ts2=nd)
assert ts != ts_other
- ts = CompositeSpec(ts1=bounded, ts2=nd)
- ts_other = CompositeSpec(ts2=nd)
+ ts = Composite(ts1=bounded, ts2=nd)
+ ts_other = Composite(ts2=nd)
assert ts != ts_other
- ts = CompositeSpec(ts1=bounded, ts2=nd)
- ts_other = CompositeSpec(ts1=bounded, ts2=nd, ts3=bounded_other)
+ ts = Composite(ts1=bounded, ts2=nd)
+ ts_other = Composite(ts1=bounded, ts2=nd, ts3=bounded_other)
assert ts != ts_other
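The equality hunks above pin down what the renamed Composite keeps from CompositeSpec: equality compares the key set and each child spec, while keyword order is irrelevant. A minimal sketch of that contract, assuming the new names are exported by torchrl.data as this module's imports suggest:

import torch
from torchrl.data import Bounded, Composite, Unbounded

box = Bounded(0, 1, torch.Size((1,)))
free = Unbounded(shape=(3,))
# Keyword order does not affect equality...
assert Composite(x=box, y=free) == Composite(y=free.clone(), x=box.clone())
# ...but a missing or extra key does.
assert Composite(x=box) != Composite(x=box, y=free)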
class TestSpec:
- @pytest.mark.parametrize(
- "action_spec_cls", [OneHotDiscreteTensorSpec, DiscreteTensorSpec]
- )
+ @pytest.mark.parametrize("action_spec_cls", [OneHot, Categorical])
def test_discrete_action_spec_reconstruct(self, action_spec_cls):
torch.manual_seed(0)
action_spec = action_spec_cls(10)
@@ -1161,7 +1123,7 @@ def test_discrete_action_spec_reconstruct(self, action_spec_cls):
def test_mult_discrete_action_spec_reconstruct(self):
torch.manual_seed(0)
- action_spec = MultiOneHotDiscreteTensorSpec((10, 5))
+ action_spec = MultiOneHot((10, 5))
actions_tensors = [action_spec.rand() for _ in range(10)]
actions_categorical = [action_spec.to_categorical(a) for a in actions_tensors]
@@ -1183,7 +1145,7 @@ def test_mult_discrete_action_spec_reconstruct(self):
def test_one_hot_discrete_action_spec_rand(self):
torch.manual_seed(0)
- action_spec = OneHotDiscreteTensorSpec(10)
+ action_spec = OneHot(10)
sample = action_spec.rand((100000,))
@@ -1197,7 +1159,7 @@ def test_one_hot_discrete_action_spec_rand(self):
def test_categorical_action_spec_rand(self):
torch.manual_seed(1)
- action_spec = DiscreteTensorSpec(10)
+ action_spec = Categorical(10)
sample = action_spec.rand((10000,))
@@ -1213,7 +1175,7 @@ def test_mult_discrete_action_spec_rand(self):
torch.manual_seed(0)
ns = (10, 5)
N = 100000
- action_spec = MultiOneHotDiscreteTensorSpec((10, 5))
+ action_spec = MultiOneHot((10, 5))
actions_tensors = [action_spec.rand() for _ in range(10)]
actions_categorical = [action_spec.to_categorical(a) for a in actions_tensors]
@@ -1238,7 +1200,7 @@ def test_mult_discrete_action_spec_rand(self):
assert chisquare(sample_list).pvalue > 0.1
def test_categorical_action_spec_encode(self):
- action_spec = DiscreteTensorSpec(10)
+ action_spec = Categorical(10)
projected = action_spec.project(
torch.tensor([-100, -1, 0, 1, 9, 10, 100], dtype=torch.long)
@@ -1255,12 +1217,12 @@ def test_categorical_action_spec_encode(self):
).all()
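The encode/project tests rely on one contract only: project maps out-of-range entries back inside the spec (the exact mapping is an implementation detail). A hedged sketch:

import torch
from torchrl.data import Categorical

spec = Categorical(10)
raw = torch.tensor([-100, -1, 0, 1, 9, 10, 100], dtype=torch.long)
assert not spec.is_in(raw)             # -100, -1, 10 and 100 are invalid categories
assert spec.is_in(spec.project(raw))   # projection always lands inside the spec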
def test_bounded_rand(self):
- spec = BoundedTensorSpec(-3, 3, torch.Size((1,)))
+ spec = Bounded(-3, 3, torch.Size((1,)))
sample = torch.stack([spec.rand() for _ in range(100)])
assert (-3 <= sample).all() and (3 >= sample).all()
def test_ndbounded_shape(self):
- spec = BoundedTensorSpec(-3, 3 * torch.ones(10, 5), shape=[10, 5])
+ spec = Bounded(-3, 3 * torch.ones(10, 5), shape=[10, 5])
sample = torch.stack([spec.rand() for _ in range(100)], 0)
assert (-3 <= sample).all() and (3 >= sample).all()
assert sample.shape == torch.Size([100, 10, 5])
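test_ndbounded_shape leans on Bounded broadcasting a scalar low against a tensor high to settle the final shape; the same construction in isolation:

import torch
from torchrl.data import Bounded

spec = Bounded(-3, 3 * torch.ones(10, 5), shape=[10, 5])  # scalar low, tensor high
sample = spec.rand()
assert sample.shape == torch.Size([10, 5])
assert spec.is_in(sample)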
@@ -1270,9 +1232,7 @@ class TestExpand:
@pytest.mark.parametrize("shape1", [None, (4,), (5, 4)])
@pytest.mark.parametrize("shape2", [(), (10,)])
def test_binary(self, shape1, shape2):
- spec = BinaryDiscreteTensorSpec(
- n=4, shape=shape1, device="cpu", dtype=torch.bool
- )
+ spec = Binary(n=4, shape=shape1, device="cpu", dtype=torch.bool)
if shape1 is not None:
shape2_real = (*shape2, *shape1)
else:
@@ -1304,9 +1264,7 @@ def test_binary(self, shape1, shape2):
],
)
def test_bounded(self, shape1, shape2, mini, maxi):
- spec = BoundedTensorSpec(
- mini, maxi, shape=shape1, device="cpu", dtype=torch.bool
- )
+ spec = Bounded(mini, maxi, shape=shape1, device="cpu", dtype=torch.bool)
shape1 = spec.shape
assert shape1 == torch.Size([10])
shape2_real = (*shape2, *shape1)
@@ -1326,7 +1284,7 @@ def test_bounded(self, shape1, shape2, mini, maxi):
def test_composite(self):
batch_size = (5,)
- spec1 = BoundedTensorSpec(
+ spec1 = Bounded(
-torch.ones([*batch_size, 10]),
torch.ones([*batch_size, 10]),
shape=(
@@ -1336,22 +1294,16 @@ def test_composite(self):
device="cpu",
dtype=torch.bool,
)
- spec2 = BinaryDiscreteTensorSpec(
- n=4, shape=(*batch_size, 4), device="cpu", dtype=torch.bool
- )
- spec3 = DiscreteTensorSpec(
- n=4, shape=batch_size, device="cpu", dtype=torch.long
- )
- spec4 = MultiDiscreteTensorSpec(
+ spec2 = Binary(n=4, shape=(*batch_size, 4), device="cpu", dtype=torch.bool)
+ spec3 = Categorical(n=4, shape=batch_size, device="cpu", dtype=torch.long)
+ spec4 = MultiCategorical(
nvec=(4, 5, 6), shape=(*batch_size, 3), device="cpu", dtype=torch.long
)
- spec5 = MultiOneHotDiscreteTensorSpec(
+ spec5 = MultiOneHot(
nvec=(4, 5, 6), shape=(*batch_size, 15), device="cpu", dtype=torch.long
)
- spec6 = OneHotDiscreteTensorSpec(
- n=15, shape=(*batch_size, 15), device="cpu", dtype=torch.long
- )
- spec7 = UnboundedContinuousTensorSpec(
+ spec6 = OneHot(n=15, shape=(*batch_size, 15), device="cpu", dtype=torch.long)
+ spec7 = Unbounded(
shape=(*batch_size, 9),
device="cpu",
dtype=torch.float64,
@@ -1361,7 +1313,7 @@ def test_composite(self):
device="cpu",
dtype=torch.long,
)
- spec = CompositeSpec(
+ spec = Composite(
spec1=spec1,
spec2=spec2,
spec3=spec3,
@@ -1392,7 +1344,7 @@ def test_composite(self):
@pytest.mark.parametrize("shape1", [None, (), (5,)])
@pytest.mark.parametrize("shape2", [(), (10,)])
def test_discrete(self, shape1, shape2):
- spec = DiscreteTensorSpec(n=4, shape=shape1, device="cpu", dtype=torch.long)
+ spec = Categorical(n=4, shape=shape1, device="cpu", dtype=torch.long)
if shape1 is not None:
shape2_real = (*shape2, *shape1)
else:
@@ -1418,7 +1370,7 @@ def test_multidiscrete(self, shape1, shape2):
shape1 = (3,)
else:
shape1 = (*shape1, 3)
- spec = MultiDiscreteTensorSpec(
+ spec = MultiCategorical(
nvec=(4, 5, 6), shape=shape1, device="cpu", dtype=torch.long
)
if shape1 is not None:
@@ -1446,9 +1398,7 @@ def test_multionehot(self, shape1, shape2):
shape1 = (15,)
else:
shape1 = (*shape1, 15)
- spec = MultiOneHotDiscreteTensorSpec(
- nvec=(4, 5, 6), shape=shape1, device="cpu", dtype=torch.long
- )
+ spec = MultiOneHot(nvec=(4, 5, 6), shape=shape1, device="cpu", dtype=torch.long)
if shape1 is not None:
shape2_real = (*shape2, *shape1)
else:
@@ -1468,11 +1418,11 @@ def test_multionehot(self, shape1, shape2):
assert spec2.zero().shape == spec2.shape
def test_non_tensor(self):
- spec = NonTensorSpec((3, 4), device="cpu")
+ spec = NonTensor((3, 4), device="cpu")
assert (
spec.expand(2, 3, 4)
== spec.expand((2, 3, 4))
- == NonTensorSpec((2, 3, 4), device="cpu")
+ == NonTensor((2, 3, 4), device="cpu")
)
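As the test above shows, NonTensor.expand accepts varargs and a single shape tuple interchangeably, and both spellings compare equal to a freshly built spec:

from torchrl.data import NonTensor

spec = NonTensor((3, 4), device="cpu")
expanded = spec.expand(2, 3, 4)
assert expanded == spec.expand((2, 3, 4)) == NonTensor((2, 3, 4), device="cpu")
assert expanded.shape == (2, 3, 4)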
@pytest.mark.parametrize("shape1", [None, (), (5,)])
@@ -1482,9 +1432,7 @@ def test_onehot(self, shape1, shape2):
shape1 = (15,)
else:
shape1 = (*shape1, 15)
- spec = OneHotDiscreteTensorSpec(
- n=15, shape=shape1, device="cpu", dtype=torch.long
- )
+ spec = OneHot(n=15, shape=shape1, device="cpu", dtype=torch.long)
if shape1 is not None:
shape2_real = (*shape2, *shape1)
else:
@@ -1510,9 +1458,7 @@ def test_unbounded(self, shape1, shape2):
shape1 = (15,)
else:
shape1 = (*shape1, 15)
- spec = UnboundedContinuousTensorSpec(
- shape=shape1, device="cpu", dtype=torch.float64
- )
+ spec = Unbounded(shape=shape1, device="cpu", dtype=torch.float64)
if shape1 is not None:
shape2_real = (*shape2, *shape1)
else:
@@ -1571,9 +1517,7 @@ class TestClone:
],
)
def test_binary(self, shape1):
- spec = BinaryDiscreteTensorSpec(
- n=4, shape=shape1, device="cpu", dtype=torch.bool
- )
+ spec = Binary(n=4, shape=shape1, device="cpu", dtype=torch.bool)
assert spec == spec.clone()
assert spec is not spec.clone()
@@ -1589,15 +1533,13 @@ def test_binary(self, shape1):
],
)
def test_bounded(self, shape1, mini, maxi):
- spec = BoundedTensorSpec(
- mini, maxi, shape=shape1, device="cpu", dtype=torch.bool
- )
+ spec = Bounded(mini, maxi, shape=shape1, device="cpu", dtype=torch.bool)
assert spec == spec.clone()
assert spec is not spec.clone()
def test_composite(self):
batch_size = (5,)
- spec1 = BoundedTensorSpec(
+ spec1 = Bounded(
-torch.ones([*batch_size, 10]),
torch.ones([*batch_size, 10]),
shape=(
@@ -1607,22 +1549,16 @@ def test_composite(self):
device="cpu",
dtype=torch.bool,
)
- spec2 = BinaryDiscreteTensorSpec(
- n=4, shape=(*batch_size, 4), device="cpu", dtype=torch.bool
- )
- spec3 = DiscreteTensorSpec(
- n=4, shape=batch_size, device="cpu", dtype=torch.long
- )
- spec4 = MultiDiscreteTensorSpec(
+ spec2 = Binary(n=4, shape=(*batch_size, 4), device="cpu", dtype=torch.bool)
+ spec3 = Categorical(n=4, shape=batch_size, device="cpu", dtype=torch.long)
+ spec4 = MultiCategorical(
nvec=(4, 5, 6), shape=(*batch_size, 3), device="cpu", dtype=torch.long
)
- spec5 = MultiOneHotDiscreteTensorSpec(
+ spec5 = MultiOneHot(
nvec=(4, 5, 6), shape=(*batch_size, 15), device="cpu", dtype=torch.long
)
- spec6 = OneHotDiscreteTensorSpec(
- n=15, shape=(*batch_size, 15), device="cpu", dtype=torch.long
- )
- spec7 = UnboundedContinuousTensorSpec(
+ spec6 = OneHot(n=15, shape=(*batch_size, 15), device="cpu", dtype=torch.long)
+ spec7 = Unbounded(
shape=(*batch_size, 9),
device="cpu",
dtype=torch.float64,
@@ -1632,7 +1568,7 @@ def test_composite(self):
device="cpu",
dtype=torch.long,
)
- spec = CompositeSpec(
+ spec = Composite(
spec1=spec1,
spec2=spec2,
spec3=spec3,
@@ -1654,7 +1590,7 @@ def test_discrete(
self,
shape1,
):
- spec = DiscreteTensorSpec(n=4, shape=shape1, device="cpu", dtype=torch.long)
+ spec = Categorical(n=4, shape=shape1, device="cpu", dtype=torch.long)
assert spec == spec.clone()
assert spec is not spec.clone()
@@ -1667,7 +1603,7 @@ def test_multidiscrete(
shape1 = (3,)
else:
shape1 = (*shape1, 3)
- spec = MultiDiscreteTensorSpec(
+ spec = MultiCategorical(
nvec=(4, 5, 6), shape=shape1, device="cpu", dtype=torch.long
)
assert spec == spec.clone()
@@ -1682,14 +1618,12 @@ def test_multionehot(
shape1 = (15,)
else:
shape1 = (*shape1, 15)
- spec = MultiOneHotDiscreteTensorSpec(
- nvec=(4, 5, 6), shape=shape1, device="cpu", dtype=torch.long
- )
+ spec = MultiOneHot(nvec=(4, 5, 6), shape=shape1, device="cpu", dtype=torch.long)
assert spec == spec.clone()
assert spec is not spec.clone()
def test_non_tensor(self):
- spec = NonTensorSpec(shape=(3, 4), device="cpu")
+ spec = NonTensor(shape=(3, 4), device="cpu")
assert spec.clone() == spec
assert spec.clone() is not spec
@@ -1702,9 +1636,7 @@ def test_onehot(
shape1 = (15,)
else:
shape1 = (*shape1, 15)
- spec = OneHotDiscreteTensorSpec(
- n=15, shape=shape1, device="cpu", dtype=torch.long
- )
+ spec = OneHot(n=15, shape=shape1, device="cpu", dtype=torch.long)
assert spec == spec.clone()
assert spec is not spec.clone()
@@ -1717,9 +1649,7 @@ def test_unbounded(
shape1 = (15,)
else:
shape1 = (*shape1, 15)
- spec = UnboundedContinuousTensorSpec(
- shape=shape1, device="cpu", dtype=torch.float64
- )
+ spec = Unbounded(shape=shape1, device="cpu", dtype=torch.float64)
assert spec == spec.clone()
assert spec is not spec.clone()
@@ -1740,9 +1670,7 @@ def test_unboundeddiscrete(
class TestUnbind:
@pytest.mark.parametrize("shape1", [(5, 4)])
def test_binary(self, shape1):
- spec = BinaryDiscreteTensorSpec(
- n=4, shape=shape1, device="cpu", dtype=torch.bool
- )
+ spec = Binary(n=4, shape=shape1, device="cpu", dtype=torch.bool)
assert spec == torch.stack(spec.unbind(0), 0)
with pytest.raises(ValueError):
spec.unbind(-1)
@@ -1759,16 +1687,14 @@ def test_binary(self, shape1):
],
)
def test_bounded(self, shape1, mini, maxi):
- spec = BoundedTensorSpec(
- mini, maxi, shape=shape1, device="cpu", dtype=torch.bool
- )
+ spec = Bounded(mini, maxi, shape=shape1, device="cpu", dtype=torch.bool)
assert spec == torch.stack(spec.unbind(0), 0)
with pytest.raises(ValueError):
spec.unbind(-1)
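unbind is the inverse of torch.stack along a batch dimension, which is the round trip these tests keep asserting; a standalone sketch:

import torch
from torchrl.data import Bounded

spec = Bounded(-1, 1, shape=(5, 4))
parts = spec.unbind(0)               # five Bounded specs of shape (4,)
assert len(parts) == 5
assert torch.stack(parts, 0) == spec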
def test_composite(self):
batch_size = (5,)
- spec1 = BoundedTensorSpec(
+ spec1 = Bounded(
-torch.ones([*batch_size, 10]),
torch.ones([*batch_size, 10]),
shape=(
@@ -1778,22 +1704,16 @@ def test_composite(self):
device="cpu",
dtype=torch.bool,
)
- spec2 = BinaryDiscreteTensorSpec(
- n=4, shape=(*batch_size, 4), device="cpu", dtype=torch.bool
- )
- spec3 = DiscreteTensorSpec(
- n=4, shape=batch_size, device="cpu", dtype=torch.long
- )
- spec4 = MultiDiscreteTensorSpec(
+ spec2 = Binary(n=4, shape=(*batch_size, 4), device="cpu", dtype=torch.bool)
+ spec3 = Categorical(n=4, shape=batch_size, device="cpu", dtype=torch.long)
+ spec4 = MultiCategorical(
nvec=(4, 5, 6), shape=(*batch_size, 3), device="cpu", dtype=torch.long
)
- spec5 = MultiOneHotDiscreteTensorSpec(
+ spec5 = MultiOneHot(
nvec=(4, 5, 6), shape=(*batch_size, 15), device="cpu", dtype=torch.long
)
- spec6 = OneHotDiscreteTensorSpec(
- n=15, shape=(*batch_size, 15), device="cpu", dtype=torch.long
- )
- spec7 = UnboundedContinuousTensorSpec(
+ spec6 = OneHot(n=15, shape=(*batch_size, 15), device="cpu", dtype=torch.long)
+ spec7 = Unbounded(
shape=(*batch_size, 9),
device="cpu",
dtype=torch.float64,
@@ -1803,7 +1723,7 @@ def test_composite(self):
device="cpu",
dtype=torch.long,
)
- spec = CompositeSpec(
+ spec = Composite(
spec1=spec1,
spec2=spec2,
spec3=spec3,
@@ -1822,7 +1742,7 @@ def test_discrete(
self,
shape1,
):
- spec = DiscreteTensorSpec(n=4, shape=shape1, device="cpu", dtype=torch.long)
+ spec = Categorical(n=4, shape=shape1, device="cpu", dtype=torch.long)
assert spec == torch.stack(spec.unbind(0), 0)
assert spec == torch.stack(spec.unbind(-1), -1)
@@ -1835,7 +1755,7 @@ def test_multidiscrete(
shape1 = (3,)
else:
shape1 = (*shape1, 3)
- spec = MultiDiscreteTensorSpec(
+ spec = MultiCategorical(
nvec=(4, 5, 6), shape=shape1, device="cpu", dtype=torch.long
)
assert spec == torch.stack(spec.unbind(0), 0)
@@ -1851,15 +1771,13 @@ def test_multionehot(
shape1 = (15,)
else:
shape1 = (*shape1, 15)
- spec = MultiOneHotDiscreteTensorSpec(
- nvec=(4, 5, 6), shape=shape1, device="cpu", dtype=torch.long
- )
+ spec = MultiOneHot(nvec=(4, 5, 6), shape=shape1, device="cpu", dtype=torch.long)
assert spec == torch.stack(spec.unbind(0), 0)
with pytest.raises(ValueError):
spec.unbind(-1)
def test_non_tensor(self):
- spec = NonTensorSpec(shape=(3, 4), device="cpu")
+ spec = NonTensor(shape=(3, 4), device="cpu")
assert spec.unbind(1)[0] == spec[:, 0]
assert spec.unbind(1)[0] is not spec[:, 0]
@@ -1872,9 +1790,7 @@ def test_onehot(
shape1 = (15,)
else:
shape1 = (*shape1, 15)
- spec = OneHotDiscreteTensorSpec(
- n=15, shape=shape1, device="cpu", dtype=torch.long
- )
+ spec = OneHot(n=15, shape=shape1, device="cpu", dtype=torch.long)
assert spec == torch.stack(spec.unbind(0), 0)
with pytest.raises(ValueError):
spec.unbind(-1)
@@ -1888,9 +1804,7 @@ def test_unbounded(
shape1 = (15,)
else:
shape1 = (*shape1, 15)
- spec = UnboundedContinuousTensorSpec(
- shape=shape1, device="cpu", dtype=torch.float64
- )
+ spec = Unbounded(shape=shape1, device="cpu", dtype=torch.float64)
assert spec == torch.stack(spec.unbind(0), 0)
assert spec == torch.stack(spec.unbind(-1), -1)
@@ -1908,15 +1822,15 @@ def test_unboundeddiscrete(
assert spec == torch.stack(spec.unbind(-1), -1)
def test_composite_encode_err(self):
- c = CompositeSpec(
- a=UnboundedContinuousTensorSpec(
+ c = Composite(
+ a=Unbounded(
1,
),
- b=UnboundedContinuousTensorSpec(
+ b=Unbounded(
2,
),
)
- with pytest.raises(KeyError, match="The CompositeSpec instance with keys"):
+ with pytest.raises(KeyError, match="The Composite instance with keys"):
c.encode({"c": 0})
with pytest.raises(
RuntimeError, match="raised a RuntimeError. Scroll up to know more"
@@ -1932,9 +1846,7 @@ def test_composite_encode_err(self):
class TestTo:
@pytest.mark.parametrize("shape1", [(5, 4)])
def test_binary(self, shape1, device):
- spec = BinaryDiscreteTensorSpec(
- n=4, shape=shape1, device="cpu", dtype=torch.bool
- )
+ spec = Binary(n=4, shape=shape1, device="cpu", dtype=torch.bool)
assert spec.to(device).device == device
@pytest.mark.parametrize(
@@ -1949,14 +1861,12 @@ def test_binary(self, shape1, device):
],
)
def test_bounded(self, shape1, mini, maxi, device):
- spec = BoundedTensorSpec(
- mini, maxi, shape=shape1, device="cpu", dtype=torch.bool
- )
+ spec = Bounded(mini, maxi, shape=shape1, device="cpu", dtype=torch.bool)
assert spec.to(device).device == device
def test_composite(self, device):
batch_size = (5,)
- spec1 = BoundedTensorSpec(
+ spec1 = Bounded(
-torch.ones([*batch_size, 10]),
torch.ones([*batch_size, 10]),
shape=(
@@ -1966,22 +1876,16 @@ def test_composite(self, device):
device="cpu",
dtype=torch.bool,
)
- spec2 = BinaryDiscreteTensorSpec(
- n=4, shape=(*batch_size, 4), device="cpu", dtype=torch.bool
- )
- spec3 = DiscreteTensorSpec(
- n=4, shape=batch_size, device="cpu", dtype=torch.long
- )
- spec4 = MultiDiscreteTensorSpec(
+ spec2 = Binary(n=4, shape=(*batch_size, 4), device="cpu", dtype=torch.bool)
+ spec3 = Categorical(n=4, shape=batch_size, device="cpu", dtype=torch.long)
+ spec4 = MultiCategorical(
nvec=(4, 5, 6), shape=(*batch_size, 3), device="cpu", dtype=torch.long
)
- spec5 = MultiOneHotDiscreteTensorSpec(
+ spec5 = MultiOneHot(
nvec=(4, 5, 6), shape=(*batch_size, 15), device="cpu", dtype=torch.long
)
- spec6 = OneHotDiscreteTensorSpec(
- n=15, shape=(*batch_size, 15), device="cpu", dtype=torch.long
- )
- spec7 = UnboundedContinuousTensorSpec(
+ spec6 = OneHot(n=15, shape=(*batch_size, 15), device="cpu", dtype=torch.long)
+ spec7 = Unbounded(
shape=(*batch_size, 9),
device="cpu",
dtype=torch.float64,
@@ -1991,7 +1895,7 @@ def test_composite(self, device):
device="cpu",
dtype=torch.long,
)
- spec = CompositeSpec(
+ spec = Composite(
spec1=spec1,
spec2=spec2,
spec3=spec3,
@@ -2010,7 +1914,7 @@ def test_discrete(
shape1,
device,
):
- spec = DiscreteTensorSpec(n=4, shape=shape1, device="cpu", dtype=torch.long)
+ spec = Categorical(n=4, shape=shape1, device="cpu", dtype=torch.long)
assert spec.to(device).device == device
@pytest.mark.parametrize("shape1", [(5,), (5, 6)])
@@ -2019,7 +1923,7 @@ def test_multidiscrete(self, shape1, device):
shape1 = (3,)
else:
shape1 = (*shape1, 3)
- spec = MultiDiscreteTensorSpec(
+ spec = MultiCategorical(
nvec=(4, 5, 6), shape=shape1, device="cpu", dtype=torch.long
)
assert spec.to(device).device == device
@@ -2030,13 +1934,11 @@ def test_multionehot(self, shape1, device):
shape1 = (15,)
else:
shape1 = (*shape1, 15)
- spec = MultiOneHotDiscreteTensorSpec(
- nvec=(4, 5, 6), shape=shape1, device="cpu", dtype=torch.long
- )
+ spec = MultiOneHot(nvec=(4, 5, 6), shape=shape1, device="cpu", dtype=torch.long)
assert spec.to(device).device == device
def test_non_tensor(self, device):
- spec = NonTensorSpec(shape=(3, 4), device="cpu")
+ spec = NonTensor(shape=(3, 4), device="cpu")
assert spec.to(device).device == device
@pytest.mark.parametrize("shape1", [(5,), (5, 6)])
@@ -2045,9 +1947,7 @@ def test_onehot(self, shape1, device):
shape1 = (15,)
else:
shape1 = (*shape1, 15)
- spec = OneHotDiscreteTensorSpec(
- n=15, shape=shape1, device="cpu", dtype=torch.long
- )
+ spec = OneHot(n=15, shape=shape1, device="cpu", dtype=torch.long)
assert spec.to(device).device == device
@pytest.mark.parametrize("shape1", [(5,), (5, 6)])
@@ -2056,9 +1956,7 @@ def test_unbounded(self, shape1, device):
shape1 = (15,)
else:
shape1 = (*shape1, 15)
- spec = UnboundedContinuousTensorSpec(
- shape=shape1, device="cpu", dtype=torch.float64
- )
+ spec = Unbounded(shape=shape1, device="cpu", dtype=torch.float64)
assert spec.to(device).device == device
@pytest.mark.parametrize("shape1", [(5,), (5, 6)])
@@ -2079,10 +1977,10 @@ class TestStack:
def test_stack_binarydiscrete(self, shape, stack_dim):
n = 5
shape = (*shape, n)
- c1 = BinaryDiscreteTensorSpec(n=n, shape=shape)
+ c1 = Binary(n=n, shape=shape)
c2 = c1.clone()
c = torch.stack([c1, c2], stack_dim)
- assert isinstance(c, BinaryDiscreteTensorSpec)
+ assert isinstance(c, Binary)
shape = list(shape)
if stack_dim < 0:
stack_dim = len(shape) + stack_dim + 1
@@ -2092,7 +1990,7 @@ def test_stack_binarydiscrete(self, shape, stack_dim):
def test_stack_binarydiscrete_expand(self, shape, stack_dim):
n = 5
shape = (*shape, n)
- c1 = BinaryDiscreteTensorSpec(n=n, shape=shape)
+ c1 = Binary(n=n, shape=shape)
c2 = c1.clone()
c = torch.stack([c1, c2], stack_dim)
shape = list(shape)
@@ -2105,7 +2003,7 @@ def test_stack_binarydiscrete_expand(self, shape, stack_dim):
def test_stack_binarydiscrete_rand(self, shape, stack_dim):
n = 5
shape = (*shape, n)
- c1 = BinaryDiscreteTensorSpec(n=n, shape=shape)
+ c1 = Binary(n=n, shape=shape)
c2 = c1.clone()
c = torch.stack([c1, c2], 0)
r = c.rand()
@@ -2114,7 +2012,7 @@ def test_stack_binarydiscrete_rand(self, shape, stack_dim):
def test_stack_binarydiscrete_zero(self, shape, stack_dim):
n = 5
shape = (*shape, n)
- c1 = BinaryDiscreteTensorSpec(n=n, shape=shape)
+ c1 = Binary(n=n, shape=shape)
c2 = c1.clone()
c = torch.stack([c1, c2], 0)
r = c.zero()
@@ -2124,10 +2022,10 @@ def test_stack_bounded(self, shape, stack_dim):
mini = -1
maxi = 1
shape = (*shape,)
- c1 = BoundedTensorSpec(mini, maxi, shape=shape)
+ c1 = Bounded(mini, maxi, shape=shape)
c2 = c1.clone()
c = torch.stack([c1, c2], stack_dim)
- assert isinstance(c, BoundedTensorSpec)
+ assert isinstance(c, Bounded)
shape = list(shape)
if stack_dim < 0:
stack_dim = len(shape) + stack_dim + 1
@@ -2138,7 +2036,7 @@ def test_stack_bounded_expand(self, shape, stack_dim):
mini = -1
maxi = 1
shape = (*shape,)
- c1 = BoundedTensorSpec(mini, maxi, shape=shape)
+ c1 = Bounded(mini, maxi, shape=shape)
c2 = c1.clone()
c = torch.stack([c1, c2], stack_dim)
shape = list(shape)
@@ -2152,7 +2050,7 @@ def test_stack_bounded_rand(self, shape, stack_dim):
mini = -1
maxi = 1
shape = (*shape,)
- c1 = BoundedTensorSpec(mini, maxi, shape=shape)
+ c1 = Bounded(mini, maxi, shape=shape)
c2 = c1.clone()
c = torch.stack([c1, c2], 0)
r = c.rand()
@@ -2162,7 +2060,7 @@ def test_stack_bounded_zero(self, shape, stack_dim):
mini = -1
maxi = 1
shape = (*shape,)
- c1 = BoundedTensorSpec(mini, maxi, shape=shape)
+ c1 = Bounded(mini, maxi, shape=shape)
c2 = c1.clone()
c = torch.stack([c1, c2], 0)
r = c.zero()
@@ -2171,10 +2069,10 @@ def test_stack_bounded_zero(self, shape, stack_dim):
def test_stack_discrete(self, shape, stack_dim):
n = 4
shape = (*shape,)
- c1 = DiscreteTensorSpec(n, shape=shape)
+ c1 = Categorical(n, shape=shape)
c2 = c1.clone()
c = torch.stack([c1, c2], stack_dim)
- assert isinstance(c, DiscreteTensorSpec)
+ assert isinstance(c, Categorical)
shape = list(shape)
if stack_dim < 0:
stack_dim = len(shape) + stack_dim + 1
@@ -2184,7 +2082,7 @@ def test_stack_discrete(self, shape, stack_dim):
def test_stack_discrete_expand(self, shape, stack_dim):
n = 4
shape = (*shape,)
- c1 = DiscreteTensorSpec(n, shape=shape)
+ c1 = Categorical(n, shape=shape)
c2 = c1.clone()
c = torch.stack([c1, c2], stack_dim)
shape = list(shape)
@@ -2197,7 +2095,7 @@ def test_stack_discrete_expand(self, shape, stack_dim):
def test_stack_discrete_rand(self, shape, stack_dim):
n = 4
shape = (*shape,)
- c1 = DiscreteTensorSpec(n, shape=shape)
+ c1 = Categorical(n, shape=shape)
c2 = c1.clone()
c = torch.stack([c1, c2], 0)
r = c.rand()
@@ -2206,7 +2104,7 @@ def test_stack_discrete_rand(self, shape, stack_dim):
def test_stack_discrete_zero(self, shape, stack_dim):
n = 4
shape = (*shape,)
- c1 = DiscreteTensorSpec(n, shape=shape)
+ c1 = Categorical(n, shape=shape)
c2 = c1.clone()
c = torch.stack([c1, c2], 0)
r = c.zero()
@@ -2215,10 +2113,10 @@ def test_stack_discrete_zero(self, shape, stack_dim):
def test_stack_multidiscrete(self, shape, stack_dim):
nvec = [4, 5]
shape = (*shape, 2)
- c1 = MultiDiscreteTensorSpec(nvec, shape=shape)
+ c1 = MultiCategorical(nvec, shape=shape)
c2 = c1.clone()
c = torch.stack([c1, c2], stack_dim)
- assert isinstance(c, MultiDiscreteTensorSpec)
+ assert isinstance(c, MultiCategorical)
shape = list(shape)
if stack_dim < 0:
stack_dim = len(shape) + stack_dim + 1
@@ -2228,7 +2126,7 @@ def test_stack_multidiscrete(self, shape, stack_dim):
def test_stack_multidiscrete_expand(self, shape, stack_dim):
nvec = [4, 5]
shape = (*shape, 2)
- c1 = MultiDiscreteTensorSpec(nvec, shape=shape)
+ c1 = MultiCategorical(nvec, shape=shape)
c2 = c1.clone()
c = torch.stack([c1, c2], stack_dim)
shape = list(shape)
@@ -2241,7 +2139,7 @@ def test_stack_multidiscrete_expand(self, shape, stack_dim):
def test_stack_multidiscrete_rand(self, shape, stack_dim):
nvec = [4, 5]
shape = (*shape, 2)
- c1 = MultiDiscreteTensorSpec(nvec, shape=shape)
+ c1 = MultiCategorical(nvec, shape=shape)
c2 = c1.clone()
c = torch.stack([c1, c2], 0)
r = c.rand()
@@ -2250,7 +2148,7 @@ def test_stack_multidiscrete_rand(self, shape, stack_dim):
def test_stack_multidiscrete_zero(self, shape, stack_dim):
nvec = [4, 5]
shape = (*shape, 2)
- c1 = MultiDiscreteTensorSpec(nvec, shape=shape)
+ c1 = MultiCategorical(nvec, shape=shape)
c2 = c1.clone()
c = torch.stack([c1, c2], 0)
r = c.zero()
@@ -2259,10 +2157,10 @@ def test_stack_multidiscrete_zero(self, shape, stack_dim):
def test_stack_multionehot(self, shape, stack_dim):
nvec = [4, 5]
shape = (*shape, 9)
- c1 = MultiOneHotDiscreteTensorSpec(nvec, shape=shape)
+ c1 = MultiOneHot(nvec, shape=shape)
c2 = c1.clone()
c = torch.stack([c1, c2], stack_dim)
- assert isinstance(c, MultiOneHotDiscreteTensorSpec)
+ assert isinstance(c, MultiOneHot)
shape = list(shape)
if stack_dim < 0:
stack_dim = len(shape) + stack_dim + 1
@@ -2272,7 +2170,7 @@ def test_stack_multionehot(self, shape, stack_dim):
def test_stack_multionehot_expand(self, shape, stack_dim):
nvec = [4, 5]
shape = (*shape, 9)
- c1 = MultiOneHotDiscreteTensorSpec(nvec, shape=shape)
+ c1 = MultiOneHot(nvec, shape=shape)
c2 = c1.clone()
c = torch.stack([c1, c2], stack_dim)
shape = list(shape)
@@ -2285,7 +2183,7 @@ def test_stack_multionehot_expand(self, shape, stack_dim):
def test_stack_multionehot_rand(self, shape, stack_dim):
nvec = [4, 5]
shape = (*shape, 9)
- c1 = MultiOneHotDiscreteTensorSpec(nvec, shape=shape)
+ c1 = MultiOneHot(nvec, shape=shape)
c2 = c1.clone()
c = torch.stack([c1, c2], 0)
r = c.rand()
@@ -2294,15 +2192,15 @@ def test_stack_multionehot_rand(self, shape, stack_dim):
def test_stack_multionehot_zero(self, shape, stack_dim):
nvec = [4, 5]
shape = (*shape, 9)
- c1 = MultiOneHotDiscreteTensorSpec(nvec, shape=shape)
+ c1 = MultiOneHot(nvec, shape=shape)
c2 = c1.clone()
c = torch.stack([c1, c2], 0)
r = c.zero()
assert r.shape == c.shape
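Stacking two compatible specs yields a spec of the same class with the stack dimension inserted, so rand/zero immediately produce batched samples; sketch:

import torch
from torchrl.data import MultiOneHot

c1 = MultiOneHot([4, 5], shape=(9,))   # event size 4 + 5
c = torch.stack([c1, c1.clone()], 0)
assert isinstance(c, MultiOneHot)
assert c.shape == torch.Size([2, 9])
assert c.zero().shape == c.shape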
def test_stack_non_tensor(self, shape, stack_dim):
- spec0 = NonTensorSpec(shape=shape, device="cpu")
- spec1 = NonTensorSpec(shape=shape, device="cpu")
+ spec0 = NonTensor(shape=shape, device="cpu")
+ spec1 = NonTensor(shape=shape, device="cpu")
new_spec = torch.stack([spec0, spec1], stack_dim)
shape_insert = list(shape)
shape_insert.insert(stack_dim, 2)
@@ -2312,10 +2210,10 @@ def test_stack_non_tensor(self, shape, stack_dim):
def test_stack_onehot(self, shape, stack_dim):
n = 5
shape = (*shape, 5)
- c1 = OneHotDiscreteTensorSpec(n, shape=shape)
+ c1 = OneHot(n, shape=shape)
c2 = c1.clone()
c = torch.stack([c1, c2], stack_dim)
- assert isinstance(c, OneHotDiscreteTensorSpec)
+ assert isinstance(c, OneHot)
shape = list(shape)
if stack_dim < 0:
stack_dim = len(shape) + stack_dim + 1
@@ -2325,7 +2223,7 @@ def test_stack_onehot(self, shape, stack_dim):
def test_stack_onehot_expand(self, shape, stack_dim):
n = 5
shape = (*shape, 5)
- c1 = OneHotDiscreteTensorSpec(n, shape=shape)
+ c1 = OneHot(n, shape=shape)
c2 = c1.clone()
c = torch.stack([c1, c2], stack_dim)
shape = list(shape)
@@ -2338,7 +2236,7 @@ def test_stack_onehot_expand(self, shape, stack_dim):
def test_stack_onehot_rand(self, shape, stack_dim):
n = 5
shape = (*shape, 5)
- c1 = OneHotDiscreteTensorSpec(n, shape=shape)
+ c1 = OneHot(n, shape=shape)
c2 = c1.clone()
c = torch.stack([c1, c2], 0)
r = c.rand()
@@ -2347,7 +2245,7 @@ def test_stack_onehot_rand(self, shape, stack_dim):
def test_stack_onehot_zero(self, shape, stack_dim):
n = 5
shape = (*shape, 5)
- c1 = OneHotDiscreteTensorSpec(n, shape=shape)
+ c1 = OneHot(n, shape=shape)
c2 = c1.clone()
c = torch.stack([c1, c2], 0)
r = c.zero()
@@ -2355,10 +2253,10 @@ def test_stack_onehot_zero(self, shape, stack_dim):
def test_stack_unboundedcont(self, shape, stack_dim):
shape = (*shape,)
- c1 = UnboundedContinuousTensorSpec(shape=shape)
+ c1 = Unbounded(shape=shape)
c2 = c1.clone()
c = torch.stack([c1, c2], stack_dim)
- assert isinstance(c, UnboundedContinuousTensorSpec)
+ assert isinstance(c, Unbounded)
shape = list(shape)
if stack_dim < 0:
stack_dim = len(shape) + stack_dim + 1
@@ -2367,7 +2265,7 @@ def test_stack_unboundedcont(self, shape, stack_dim):
def test_stack_unboundedcont_expand(self, shape, stack_dim):
shape = (*shape,)
- c1 = UnboundedContinuousTensorSpec(shape=shape)
+ c1 = Unbounded(shape=shape)
c2 = c1.clone()
c = torch.stack([c1, c2], stack_dim)
shape = list(shape)
@@ -2379,7 +2277,7 @@ def test_stack_unboundedcont_expand(self, shape, stack_dim):
def test_stack_unboundedcont_rand(self, shape, stack_dim):
shape = (*shape,)
- c1 = UnboundedContinuousTensorSpec(shape=shape)
+ c1 = Unbounded(shape=shape)
c2 = c1.clone()
c = torch.stack([c1, c2], 0)
r = c.rand()
@@ -2434,8 +2332,8 @@ def test_stack_unboundeddiscrete_zero(self, shape, stack_dim):
assert r.shape == c.shape
def test_to_numpy(self, shape, stack_dim):
- c1 = BoundedTensorSpec(-1, 1, shape=shape, dtype=torch.float64)
- c2 = BoundedTensorSpec(-1, 1, shape=shape, dtype=torch.float64)
+ c1 = Bounded(-1, 1, shape=shape, dtype=torch.float64)
+ c2 = Bounded(-1, 1, shape=shape, dtype=torch.float64)
c = torch.stack([c1, c2], stack_dim)
@@ -2455,13 +2353,13 @@ def test_to_numpy(self, shape, stack_dim):
c.to_numpy(val + 1, safe=True)
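to_numpy converts a sample to a numpy array, and safe=True validates the value against the spec before converting (which is what the shifted value above exercises); happy-path sketch:

import numpy as np
import torch
from torchrl.data import Bounded

c1 = Bounded(-1, 1, shape=(3,), dtype=torch.float64)
c = torch.stack([c1, c1.clone()], 0)
val = c.rand()
arr = c.to_numpy(val, safe=True)       # an in-spec value passes validation
assert isinstance(arr, np.ndarray) and arr.shape == (2, 3)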
def test_malformed_stack(self, shape, stack_dim):
- c1 = BoundedTensorSpec(-1, 1, shape=shape, dtype=torch.float64)
- c2 = BoundedTensorSpec(-1, 1, shape=shape, dtype=torch.float32)
+ c1 = Bounded(-1, 1, shape=shape, dtype=torch.float64)
+ c2 = Bounded(-1, 1, shape=shape, dtype=torch.float32)
with pytest.raises(RuntimeError, match="Dtypes differ"):
torch.stack([c1, c2], stack_dim)
- c1 = BoundedTensorSpec(-1, 1, shape=shape, dtype=torch.float32)
- c2 = UnboundedContinuousTensorSpec(shape=shape, dtype=torch.float32)
+ c1 = Bounded(-1, 1, shape=shape, dtype=torch.float32)
+ c2 = Unbounded(shape=shape, dtype=torch.float32)
c3 = UnboundedDiscreteTensorSpec(shape=shape, dtype=torch.float32)
with pytest.raises(
RuntimeError,
@@ -2470,40 +2368,40 @@ def test_malformed_stack(self, shape, stack_dim):
torch.stack([c1, c2], stack_dim)
torch.stack([c3, c2], stack_dim)
- c1 = BoundedTensorSpec(-1, 1, shape=shape, dtype=torch.float32)
- c2 = BoundedTensorSpec(-1, 1, shape=shape + (3,), dtype=torch.float32)
+ c1 = Bounded(-1, 1, shape=shape, dtype=torch.float32)
+ c2 = Bounded(-1, 1, shape=shape + (3,), dtype=torch.float32)
with pytest.raises(RuntimeError, match="Ndims differ"):
torch.stack([c1, c2], stack_dim)
-class TestDenseStackedCompositeSpecs:
+class TestDenseStackedComposite:
def test_stack(self):
- c1 = CompositeSpec(a=UnboundedContinuousTensorSpec())
+ c1 = Composite(a=Unbounded())
c2 = c1.clone()
c = torch.stack([c1, c2], 0)
- assert isinstance(c, CompositeSpec)
+ assert isinstance(c, Composite)
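The split between the two renamed test classes mirrors the two stacking outcomes: identical entries stack into a plain Composite, while entries with diverging keys produce the lazy StackedComposite. A sketch, assuming StackedComposite is importable from torchrl.data like the other renamed classes:

import torch
from torchrl.data import Composite, StackedComposite, Unbounded

c1 = Composite(a=Unbounded())
assert isinstance(torch.stack([c1, c1.clone()], 0), Composite)

c2 = Composite(a=Unbounded(), b=Unbounded())
assert isinstance(torch.stack([c1.clone(), c2], 0), StackedComposite)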
-class TestLazyStackedCompositeSpecs:
+class TestLazyStackedComposite:
def _get_heterogeneous_specs(
self,
batch_size=(),
stack_dim: int = 0,
):
- shared = BoundedTensorSpec(low=0, high=1, shape=(*batch_size, 32, 32, 3))
- hetero_3d = UnboundedContinuousTensorSpec(
+ shared = Bounded(low=0, high=1, shape=(*batch_size, 32, 32, 3))
+ hetero_3d = Unbounded(
shape=(
*batch_size,
3,
)
)
- hetero_2d = UnboundedContinuousTensorSpec(
+ hetero_2d = Unbounded(
shape=(
*batch_size,
2,
)
)
- lidar = BoundedTensorSpec(
+ lidar = Bounded(
low=0,
high=5,
shape=(
@@ -2512,9 +2410,9 @@ def _get_heterogeneous_specs(
),
)
- individual_0_obs = CompositeSpec(
+ individual_0_obs = Composite(
{
- "individual_0_obs_0": UnboundedContinuousTensorSpec(
+ "individual_0_obs_0": Unbounded(
shape=(
*batch_size,
3,
@@ -2524,25 +2422,21 @@ def _get_heterogeneous_specs(
},
shape=(*batch_size, 3),
)
- individual_1_obs = CompositeSpec(
+ individual_1_obs = Composite(
{
- "individual_1_obs_0": BoundedTensorSpec(
+ "individual_1_obs_0": Bounded(
low=0, high=3, shape=(*batch_size, 3, 1, 2)
)
},
shape=(*batch_size, 3),
)
- individual_2_obs = CompositeSpec(
- {
- "individual_1_obs_0": UnboundedContinuousTensorSpec(
- shape=(*batch_size, 3, 1, 2, 3)
- )
- },
+ individual_2_obs = Composite(
+ {"individual_1_obs_0": Unbounded(shape=(*batch_size, 3, 1, 2, 3))},
shape=(*batch_size, 3),
)
spec_list = [
- CompositeSpec(
+ Composite(
{
"shared": shared,
"lidar": lidar,
@@ -2551,7 +2445,7 @@ def _get_heterogeneous_specs(
},
shape=batch_size,
),
- CompositeSpec(
+ Composite(
{
"shared": shared,
"lidar": lidar,
@@ -2560,7 +2454,7 @@ def _get_heterogeneous_specs(
},
shape=batch_size,
),
- CompositeSpec(
+ Composite(
{
"shared": shared,
"hetero": hetero_2d,
@@ -2573,10 +2467,8 @@ def _get_heterogeneous_specs(
return torch.stack(spec_list, dim=stack_dim).cpu()
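In a heterogeneous stack like this one, shared entries stay addressable on the stacked spec while exclusive entries only surface on the sub-spec that owns them; a reduced sketch:

import torch
from torchrl.data import Composite, Unbounded

stacked = torch.stack(
    [
        Composite(shared=Unbounded(), only_first=Unbounded()),
        Composite(shared=Unbounded()),
    ],
    0,
)
assert "shared" in stacked.keys()
assert "only_first" in stacked[0].keys()
assert "only_first" not in stacked[1].keys()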
def test_stack_index(self):
- c1 = CompositeSpec(a=UnboundedContinuousTensorSpec())
- c2 = CompositeSpec(
- a=UnboundedContinuousTensorSpec(), b=UnboundedDiscreteTensorSpec()
- )
+ c1 = Composite(a=Unbounded())
+ c2 = Composite(a=Unbounded(), b=UnboundedDiscreteTensorSpec())
c = torch.stack([c1, c2], 0)
assert c.shape == torch.Size([2])
assert c[0] is c1
@@ -2585,19 +2477,19 @@ def test_stack_index(self):
assert c[..., 1] is c2
assert c[0, ...] is c1
assert c[1, ...] is c2
- assert isinstance(c[:], LazyStackedCompositeSpec)
+ assert isinstance(c[:], StackedComposite)
@pytest.mark.parametrize("stack_dim", [0, 1, 2, -3, -2, -1])
def test_stack_index_multdim(self, stack_dim):
- c1 = CompositeSpec(a=UnboundedContinuousTensorSpec(shape=(1, 3)), shape=(1, 3))
- c2 = CompositeSpec(
- a=UnboundedContinuousTensorSpec(shape=(1, 3)),
+ c1 = Composite(a=Unbounded(shape=(1, 3)), shape=(1, 3))
+ c2 = Composite(
+ a=Unbounded(shape=(1, 3)),
b=UnboundedDiscreteTensorSpec(shape=(1, 3)),
shape=(1, 3),
)
c = torch.stack([c1, c2], stack_dim)
if stack_dim in (0, -3):
- assert isinstance(c[:], LazyStackedCompositeSpec)
+ assert isinstance(c[:], StackedComposite)
assert c.shape == torch.Size([2, 1, 3])
assert c[0] is c1
assert c[1] is c2
@@ -2614,7 +2506,7 @@ def test_stack_index_multdim(self, stack_dim):
assert c[0, ...] is c1
assert c[1, ...] is c2
- elif stack_dim == (1, -2):
+ elif stack_dim in (1, -2):
- assert isinstance(c[:, :], LazyStackedCompositeSpec)
+ assert isinstance(c[:, :], StackedComposite)
assert c.shape == torch.Size([1, 2, 3])
assert c[:, 0] is c1
assert c[:, 1] is c2
@@ -2641,7 +2533,7 @@ def test_stack_index_multdim(self, stack_dim):
assert c[:, 0, ...] is c1
assert c[:, 1, ...] is c2
- elif stack_dim == (2, -1):
+ elif stack_dim in (2, -1):
- assert isinstance(c[:, :, :], LazyStackedCompositeSpec)
+ assert isinstance(c[:, :, :], StackedComposite)
with pytest.raises(
IndexError, match="along dimension 0 when the stack dimension is 2."
):
@@ -2660,9 +2552,9 @@ def test_stack_index_multdim(self, stack_dim):
@pytest.mark.parametrize("stack_dim", [0, 1, 2, -3, -2, -1])
def test_stack_expand_multi(self, stack_dim):
- c1 = CompositeSpec(a=UnboundedContinuousTensorSpec(shape=(1, 3)), shape=(1, 3))
- c2 = CompositeSpec(
- a=UnboundedContinuousTensorSpec(shape=(1, 3)),
+ c1 = Composite(a=Unbounded(shape=(1, 3)), shape=(1, 3))
+ c2 = Composite(
+ a=Unbounded(shape=(1, 3)),
b=UnboundedDiscreteTensorSpec(shape=(1, 3)),
shape=(1, 3),
)
@@ -2691,9 +2583,9 @@ def test_stack_expand_multi(self, stack_dim):
@pytest.mark.parametrize("stack_dim", [0, 1, 2, -3, -2, -1])
def test_stack_rand(self, stack_dim):
- c1 = CompositeSpec(a=UnboundedContinuousTensorSpec(shape=(1, 3)), shape=(1, 3))
- c2 = CompositeSpec(
- a=UnboundedContinuousTensorSpec(shape=(1, 3)),
+ c1 = Composite(a=Unbounded(shape=(1, 3)), shape=(1, 3))
+ c2 = Composite(
+ a=Unbounded(shape=(1, 3)),
b=UnboundedDiscreteTensorSpec(shape=(1, 3)),
shape=(1, 3),
)
@@ -2713,9 +2605,9 @@ def test_stack_rand(self, stack_dim):
@pytest.mark.parametrize("stack_dim", [0, 1, 2, -3, -2, -1])
def test_stack_rand_shape(self, stack_dim):
- c1 = CompositeSpec(a=UnboundedContinuousTensorSpec(shape=(1, 3)), shape=(1, 3))
- c2 = CompositeSpec(
- a=UnboundedContinuousTensorSpec(shape=(1, 3)),
+ c1 = Composite(a=Unbounded(shape=(1, 3)), shape=(1, 3))
+ c2 = Composite(
+ a=Unbounded(shape=(1, 3)),
b=UnboundedDiscreteTensorSpec(shape=(1, 3)),
shape=(1, 3),
)
@@ -2736,9 +2628,9 @@ def test_stack_rand_shape(self, stack_dim):
@pytest.mark.parametrize("stack_dim", [0, 1, 2, -3, -2, -1])
def test_stack_zero(self, stack_dim):
- c1 = CompositeSpec(a=UnboundedContinuousTensorSpec(shape=(1, 3)), shape=(1, 3))
- c2 = CompositeSpec(
- a=UnboundedContinuousTensorSpec(shape=(1, 3)),
+ c1 = Composite(a=Unbounded(shape=(1, 3)), shape=(1, 3))
+ c2 = Composite(
+ a=Unbounded(shape=(1, 3)),
b=UnboundedDiscreteTensorSpec(shape=(1, 3)),
shape=(1, 3),
)
@@ -2758,9 +2650,9 @@ def test_stack_zero(self, stack_dim):
@pytest.mark.parametrize("stack_dim", [0, 1, 2, -3, -2, -1])
def test_stack_zero_shape(self, stack_dim):
- c1 = CompositeSpec(a=UnboundedContinuousTensorSpec(shape=(1, 3)), shape=(1, 3))
- c2 = CompositeSpec(
- a=UnboundedContinuousTensorSpec(shape=(1, 3)),
+ c1 = Composite(a=Unbounded(shape=(1, 3)), shape=(1, 3))
+ c2 = Composite(
+ a=Unbounded(shape=(1, 3)),
b=UnboundedDiscreteTensorSpec(shape=(1, 3)),
shape=(1, 3),
)
@@ -2782,14 +2674,14 @@ def test_stack_zero_shape(self, stack_dim):
@pytest.mark.skipif(not torch.cuda.device_count(), reason="no cuda")
@pytest.mark.parametrize("stack_dim", [0, 1, 2, -3, -2, -1])
def test_to(self, stack_dim):
- c1 = CompositeSpec(a=UnboundedContinuousTensorSpec(shape=(1, 3)), shape=(1, 3))
- c2 = CompositeSpec(
- a=UnboundedContinuousTensorSpec(shape=(1, 3)),
+ c1 = Composite(a=Unbounded(shape=(1, 3)), shape=(1, 3))
+ c2 = Composite(
+ a=Unbounded(shape=(1, 3)),
b=UnboundedDiscreteTensorSpec(shape=(1, 3)),
shape=(1, 3),
)
c = torch.stack([c1, c2], stack_dim)
- assert isinstance(c, LazyStackedCompositeSpec)
+ assert isinstance(c, StackedComposite)
cdevice = c.to("cuda:0")
assert cdevice.device != c.device
assert cdevice.device == torch.device("cuda:0")
@@ -2799,9 +2691,9 @@ def test_to(self, stack_dim):
assert cdevice[index].device == torch.device("cuda:0")
def test_clone(self):
- c1 = CompositeSpec(a=UnboundedContinuousTensorSpec(shape=(1, 3)), shape=(1, 3))
- c2 = CompositeSpec(
- a=UnboundedContinuousTensorSpec(shape=(1, 3)),
+ c1 = Composite(a=Unbounded(shape=(1, 3)), shape=(1, 3))
+ c2 = Composite(
+ a=Unbounded(shape=(1, 3)),
b=UnboundedDiscreteTensorSpec(shape=(1, 3)),
shape=(1, 3),
)
@@ -2811,9 +2703,9 @@ def test_clone(self):
assert cclone[0] == c[0]
def test_to_numpy(self):
- c1 = CompositeSpec(a=BoundedTensorSpec(-1, 1, shape=(1, 3)), shape=(1, 3))
- c2 = CompositeSpec(
- a=BoundedTensorSpec(-1, 1, shape=(1, 3)),
+ c1 = Composite(a=Bounded(-1, 1, shape=(1, 3)), shape=(1, 3))
+ c2 = Composite(
+ a=Bounded(-1, 1, shape=(1, 3)),
b=UnboundedDiscreteTensorSpec(shape=(1, 3)),
shape=(1, 3),
)
@@ -2829,9 +2721,9 @@ def test_to_numpy(self):
c.to_numpy(td_fail, safe=True)
def test_unsqueeze(self):
- c1 = CompositeSpec(a=BoundedTensorSpec(-1, 1, shape=(1, 3)), shape=(1, 3))
- c2 = CompositeSpec(
- a=BoundedTensorSpec(-1, 1, shape=(1, 3)),
+ c1 = Composite(a=Bounded(-1, 1, shape=(1, 3)), shape=(1, 3))
+ c2 = Composite(
+ a=Bounded(-1, 1, shape=(1, 3)),
b=UnboundedDiscreteTensorSpec(shape=(1, 3)),
shape=(1, 3),
)
@@ -2984,12 +2876,11 @@ def test_project(self, batch_size):
def test_repr(self):
c = self._get_heterogeneous_specs()
-
- expected = f"""LazyStackedCompositeSpec(
+ expected = f"""StackedComposite(
fields={{
- hetero: LazyStackedUnboundedContinuousTensorSpec(
+ hetero: StackedUnboundedContinuous(
shape=torch.Size([3, -1]), device=cpu, dtype=torch.float32, domain=continuous),
- shared: BoundedTensorSpec(
+ shared: BoundedContinuous(
shape=torch.Size([3, 32, 32, 3]),
space=ContinuousBox(
low=Tensor(shape=torch.Size([3, 32, 32, 3]), device=cpu, dtype=torch.float32, contiguous=True),
@@ -2999,7 +2890,7 @@ def test_repr(self):
domain=continuous)}},
exclusive_fields={{
0 ->
- lidar: BoundedTensorSpec(
+ lidar: BoundedContinuous(
shape=torch.Size([20]),
space=ContinuousBox(
low=Tensor(shape=torch.Size([20]), device=cpu, dtype=torch.float32, contiguous=True),
@@ -3007,17 +2898,19 @@ def test_repr(self):
device=cpu,
dtype=torch.float32,
domain=continuous),
- individual_0_obs: CompositeSpec(
- individual_0_obs_0: UnboundedContinuousTensorSpec(
+ individual_0_obs: Composite(
+ individual_0_obs_0: UnboundedContinuous(
shape=torch.Size([3, 1]),
- space=None,
+ space=ContinuousBox(
+ low=Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, contiguous=True),
+ high=Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, contiguous=True)),
device=cpu,
dtype=torch.float32,
domain=continuous),
device=cpu,
shape=torch.Size([3])),
1 ->
- lidar: BoundedTensorSpec(
+ lidar: BoundedContinuous(
shape=torch.Size([20]),
space=ContinuousBox(
low=Tensor(shape=torch.Size([20]), device=cpu, dtype=torch.float32, contiguous=True),
@@ -3025,8 +2918,8 @@ def test_repr(self):
device=cpu,
dtype=torch.float32,
domain=continuous),
- individual_1_obs: CompositeSpec(
- individual_1_obs_0: BoundedTensorSpec(
+ individual_1_obs: Composite(
+ individual_1_obs_0: BoundedContinuous(
shape=torch.Size([3, 1, 2]),
space=ContinuousBox(
low=Tensor(shape=torch.Size([3, 1, 2]), device=cpu, dtype=torch.float32, contiguous=True),
@@ -3037,10 +2930,12 @@ def test_repr(self):
device=cpu,
shape=torch.Size([3])),
2 ->
- individual_2_obs: CompositeSpec(
- individual_1_obs_0: UnboundedContinuousTensorSpec(
+ individual_2_obs: Composite(
+ individual_1_obs_0: UnboundedContinuous(
shape=torch.Size([3, 1, 2, 3]),
- space=None,
+ space=ContinuousBox(
+ low=Tensor(shape=torch.Size([3, 1, 2, 3]), device=cpu, dtype=torch.float32, contiguous=True),
+ high=Tensor(shape=torch.Size([3, 1, 2, 3]), device=cpu, dtype=torch.float32, contiguous=True)),
device=cpu,
dtype=torch.float32,
domain=continuous),
@@ -3054,11 +2949,11 @@ def test_repr(self):
c = c[0:2]
del c["individual_0_obs"]
del c["individual_1_obs"]
- expected = f"""LazyStackedCompositeSpec(
+ expected = f"""StackedComposite(
fields={{
- hetero: LazyStackedUnboundedContinuousTensorSpec(
+ hetero: StackedUnboundedContinuous(
shape=torch.Size([2, -1]), device=cpu, dtype=torch.float32, domain=continuous),
- lidar: BoundedTensorSpec(
+ lidar: BoundedContinuous(
shape=torch.Size([2, 20]),
space=ContinuousBox(
low=Tensor(shape=torch.Size([2, 20]), device=cpu, dtype=torch.float32, contiguous=True),
@@ -3066,7 +2961,7 @@ def test_repr(self):
device=cpu,
dtype=torch.float32,
domain=continuous),
- shared: BoundedTensorSpec(
+ shared: BoundedContinuous(
shape=torch.Size([2, 32, 32, 3]),
space=ContinuousBox(
low=Tensor(shape=torch.Size([2, 32, 32, 3]), device=cpu, dtype=torch.float32, contiguous=True),
@@ -3100,7 +2995,7 @@ def test_consolidate_spec(self, batch_size):
@pytest.mark.parametrize("batch_size", [(), (2,), (2, 1)])
def test_consolidate_spec_exclusive_lazy_stacked(self, batch_size):
- shared = UnboundedContinuousTensorSpec(
+ shared = Unbounded(
shape=(
*batch_size,
5,
@@ -3110,29 +3005,29 @@ def test_consolidate_spec_exclusive_lazy_stacked(self, batch_size):
)
lazy_spec = torch.stack(
[
- UnboundedContinuousTensorSpec(shape=(*batch_size, 5, 6, 7)),
- UnboundedContinuousTensorSpec(shape=(*batch_size, 5, 7, 7)),
- UnboundedContinuousTensorSpec(shape=(*batch_size, 5, 8, 7)),
- UnboundedContinuousTensorSpec(shape=(*batch_size, 5, 8, 7)),
+ Unbounded(shape=(*batch_size, 5, 6, 7)),
+ Unbounded(shape=(*batch_size, 5, 7, 7)),
+ Unbounded(shape=(*batch_size, 5, 8, 7)),
+ Unbounded(shape=(*batch_size, 5, 8, 7)),
],
dim=len(batch_size),
)
spec_list = [
- CompositeSpec(
+ Composite(
{
"shared": shared,
"lazy_spec": lazy_spec,
},
shape=batch_size,
),
- CompositeSpec(
+ Composite(
{
"shared": shared,
},
shape=batch_size,
),
- CompositeSpec(
+ Composite(
{},
shape=batch_size,
device="cpu",
@@ -3168,9 +3063,7 @@ def test_update(self, batch_size, stack_dim=0):
spec[1]["individual_1_obs"]["individual_1_obs_0"].space.low.sum() == 0
) # Only non-exclusive keys will be updated
- new = torch.stack(
- [UnboundedContinuousTensorSpec(shape=(*batch_size, i)) for i in range(3)], 0
- )
+ new = torch.stack([Unbounded(shape=(*batch_size, i)) for i in range(3)], 0)
spec2["new"] = new
spec.update(spec2)
assert spec["new"] == new
@@ -3181,7 +3074,7 @@ def test_set_item(self, batch_size, stack_dim):
spec = self._get_heterogeneous_specs(batch_size, stack_dim)
new = torch.stack(
- [UnboundedContinuousTensorSpec(shape=(*batch_size, i)) for i in range(3)],
+ [Unbounded(shape=(*batch_size, i)) for i in range(3)],
stack_dim,
)
spec["new"] = new
@@ -3196,15 +3089,15 @@ def test_set_item(self, batch_size, stack_dim):
spec[("other", "key")] = new
assert spec[("other", "key")] == new
- assert isinstance(spec["other"], LazyStackedCompositeSpec)
+ assert isinstance(spec["other"], StackedComposite)
with pytest.raises(RuntimeError, match="key should be a Sequence"):
spec[0] = new
comp = torch.stack(
[
- CompositeSpec(
- {"a": UnboundedContinuousTensorSpec(shape=(*batch_size, i))},
+ Composite(
+ {"a": Unbounded(shape=(*batch_size, i))},
shape=batch_size,
)
for i in range(3)
@@ -3220,10 +3113,10 @@ def test_set_item(self, batch_size, stack_dim):
@pytest.mark.parametrize(
"spec_class",
[
- BinaryDiscreteTensorSpec,
- OneHotDiscreteTensorSpec,
- MultiOneHotDiscreteTensorSpec,
- CompositeSpec,
+ Binary,
+ OneHot,
+ MultiOneHot,
+ Composite,
],
)
@pytest.mark.parametrize(
@@ -3240,13 +3133,13 @@ def test_set_item(self, batch_size, stack_dim):
], # [:,1:2,1]
)
def test_invalid_indexing(spec_class, idx):
- if spec_class in [BinaryDiscreteTensorSpec, OneHotDiscreteTensorSpec]:
+ if spec_class in [Binary, OneHot]:
spec = spec_class(n=4, shape=[3, 4])
- elif spec_class == MultiDiscreteTensorSpec:
+ elif spec_class == MultiCategorical:
spec = spec_class([2, 2, 2], shape=[3])
- elif spec_class == MultiOneHotDiscreteTensorSpec:
+ elif spec_class == MultiOneHot:
spec = spec_class([4], shape=[3, 4])
- elif spec_class == CompositeSpec:
+ elif spec_class == Composite:
spec = spec_class(k=UnboundedDiscreteTensorSpec(shape=(3, 4)), shape=(3,))
with pytest.raises(IndexError):
spec[idx]
@@ -3256,13 +3149,13 @@ def test_invalid_indexing(spec_class, idx):
@pytest.mark.parametrize(
"spec_class",
[
- BinaryDiscreteTensorSpec,
- DiscreteTensorSpec,
- MultiOneHotDiscreteTensorSpec,
- OneHotDiscreteTensorSpec,
- UnboundedContinuousTensorSpec,
+ Binary,
+ Categorical,
+ MultiOneHot,
+ OneHot,
+ Unbounded,
UnboundedDiscreteTensorSpec,
- CompositeSpec,
+ Composite,
],
)
def test_valid_indexing(spec_class):
@@ -3270,14 +3163,14 @@ def test_valid_indexing(spec_class):
args = {"0d": [], "2d": [], "3d": [], "4d": [], "5d": []}
kwargs = {}
if spec_class in [
- BinaryDiscreteTensorSpec,
- DiscreteTensorSpec,
- OneHotDiscreteTensorSpec,
+ Binary,
+ Categorical,
+ OneHot,
]:
args = {"0d": [0], "2d": [3], "3d": [4], "4d": [6], "5d": [7]}
- elif spec_class == MultiOneHotDiscreteTensorSpec:
+ elif spec_class == MultiOneHot:
args = {"0d": [[0]], "2d": [[3]], "3d": [[4]], "4d": [[6]], "5d": [[7]]}
- elif spec_class == MultiDiscreteTensorSpec:
+ elif spec_class == MultiCategorical:
args = {
"0d": [[0]],
"2d": [[2] * 3],
@@ -3285,7 +3178,7 @@ def test_valid_indexing(spec_class):
"4d": [[1] * 6],
"5d": [[2] * 7],
}
- elif spec_class == BoundedTensorSpec:
+ elif spec_class == Bounded:
min_max = (-1, -1)
args = {
"0d": min_max,
@@ -3294,17 +3187,17 @@ def test_valid_indexing(spec_class):
"4d": min_max,
"5d": min_max,
}
- elif spec_class == CompositeSpec:
+ elif spec_class == Composite:
kwargs = {
"k1": UnboundedDiscreteTensorSpec(shape=(5, 3, 4, 6, 7, 8)),
- "k2": OneHotDiscreteTensorSpec(n=7, shape=(5, 3, 4, 6, 7)),
+ "k2": OneHot(n=7, shape=(5, 3, 4, 6, 7)),
}
spec_0d = spec_class(*args["0d"], **kwargs)
if spec_class in [
- UnboundedContinuousTensorSpec,
+ Unbounded,
UnboundedDiscreteTensorSpec,
- CompositeSpec,
+ Composite,
]:
spec_0d = spec_class(*args["0d"], shape=[], **kwargs)
spec_2d = spec_class(*args["2d"], shape=[5, 3], **kwargs)
@@ -3374,10 +3267,10 @@ def test_valid_indexing(spec_class):
# Specific tests when specs have non-indexable dimensions
if spec_class in [
- BinaryDiscreteTensorSpec,
- OneHotDiscreteTensorSpec,
- MultiDiscreteTensorSpec,
- MultiOneHotDiscreteTensorSpec,
+ Binary,
+ OneHot,
+ MultiCategorical,
+ MultiOneHot,
]:
# Ellipsis
assert spec_0d[None].shape == torch.Size([1, 0])
@@ -3390,7 +3283,6 @@ def test_valid_indexing(spec_class):
assert spec_3d[None, 1, ..., None].shape == torch.Size([1, 3, 1, 4])
assert spec_4d[:, None, ..., None, :].shape == torch.Size([5, 1, 3, 1, 4, 6])
- # BoundedTensorSpec, DiscreteTensorSpec, UnboundedContinuousTensorSpec, UnboundedDiscreteTensorSpec, CompositeSpec
else:
# Integers
assert spec_2d[0, 1].shape == torch.Size([])
@@ -3407,7 +3299,7 @@ def test_valid_indexing(spec_class):
assert spec_4d[:, None, ..., None, :].shape == torch.Size([5, 1, 3, 4, 1, 6])
# Additional tests for composite spec
- if spec_class == CompositeSpec:
+ if spec_class == Composite:
assert spec_2d[1]["k1"].shape == torch.Size([3, 4, 6, 7, 8])
assert spec_3d[[1, 2]]["k1"].shape == torch.Size([2, 3, 4, 6, 7, 8])
assert spec_2d[torch.randint(3, (3, 2))]["k1"].shape == torch.Size(
@@ -3422,9 +3314,7 @@ def test_valid_indexing(spec_class):
def test_composite_contains():
- spec = CompositeSpec(
- a=CompositeSpec(b=CompositeSpec(c=UnboundedContinuousTensorSpec()))
- )
+ spec = Composite(a=Composite(b=Composite(c=Unbounded())))
assert "a" in spec.keys()
assert "a" in spec.keys(True)
assert ("a",) in spec.keys()
@@ -3444,10 +3334,10 @@ def get_all_keys(spec: TensorSpec, include_exclusive: bool):
"""
keys = set()
- if isinstance(spec, LazyStackedCompositeSpec) and include_exclusive:
+ if isinstance(spec, StackedComposite) and include_exclusive:
for t in spec._specs:
keys = keys.union(get_all_keys(t, include_exclusive))
- if isinstance(spec, CompositeSpec):
+ if isinstance(spec, Composite):
for key in spec.keys():
keys.add((key,))
inner_keys = get_all_keys(spec[key], include_exclusive)
@@ -3481,7 +3371,7 @@ def _make_mask(self, shape):
def _one_hot_spec(self, shape, device, n):
shape = torch.Size([*shape, n])
mask = self._make_mask(shape).to(device)
- return OneHotDiscreteTensorSpec(n, shape, device, mask=mask)
+ return OneHot(n, shape, device, mask=mask)
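The mask narrows the admissible categories, and both sampling and validation are expected to honour it; a hedged sketch of that assumption:

import torch
from torchrl.data import OneHot

mask = torch.tensor([True, False, True, True])
spec = OneHot(4, mask=mask)                 # shape defaults to (4,)
for _ in range(10):
    sample = spec.rand()
    assert spec.is_in(sample)
    assert not sample[1]                    # the masked-out category is never drawn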
def _mult_one_hot_spec(self, shape, device, n):
shape = torch.Size([*shape, n + n + 2])
@@ -3492,11 +3382,11 @@ def _mult_one_hot_spec(self, shape, device, n):
],
-1,
)
- return MultiOneHotDiscreteTensorSpec([n, n + 2], shape, device, mask=mask)
+ return MultiOneHot([n, n + 2], shape, device, mask=mask)
def _discrete_spec(self, shape, device, n):
mask = self._make_mask(torch.Size([*shape, n])).to(device)
- return DiscreteTensorSpec(n, shape, device, mask=mask)
+ return Categorical(n, shape, device, mask=mask)
def _mult_discrete_spec(self, shape, device, n):
shape = torch.Size([*shape, 2])
@@ -3507,7 +3397,7 @@ def _mult_discrete_spec(self, shape, device, n):
],
-1,
)
- return MultiDiscreteTensorSpec([n, n + 2], shape, device, mask=mask)
+ return MultiCategorical([n, n + 2], shape, device, mask=mask)
def test_equal(self, shape, device, spectype, rand_shape, n=5):
shape = torch.Size(shape)
@@ -3579,7 +3469,7 @@ def test_project(self, shape, device, spectype, rand_shape, n=5):
class TestDynamicSpec:
def test_all(self):
- spec = UnboundedContinuousTensorSpec((-1, 1, 2))
+ spec = Unbounded((-1, 1, 2))
unb = spec
assert spec.shape == (-1, 1, 2)
x = torch.randn(3, 1, 2)
@@ -3593,14 +3483,14 @@ def test_all(self):
xunbd = x
assert spec.is_in(x)
- spec = BoundedTensorSpec(shape=(-1, 1, 2), low=-1, high=1)
+ spec = Bounded(shape=(-1, 1, 2), low=-1, high=1)
bound = spec
assert spec.shape == (-1, 1, 2)
x = torch.rand((3, 1, 2))
xbound = x
assert spec.is_in(x)
- spec = OneHotDiscreteTensorSpec(shape=(-1, 1, 2, 4), n=4)
+ spec = OneHot(shape=(-1, 1, 2, 4), n=4)
oneh = spec
assert spec.shape == (-1, 1, 2, 4)
x = torch.zeros((3, 1, 2, 4), dtype=torch.bool)
@@ -3608,14 +3498,14 @@ def test_all(self):
xoneh = x
assert spec.is_in(x)
- spec = DiscreteTensorSpec(shape=(-1, 1, 2), n=4)
+ spec = Categorical(shape=(-1, 1, 2), n=4)
disc = spec
assert spec.shape == (-1, 1, 2)
x = torch.randint(4, (3, 1, 2))
xdisc = x
assert spec.is_in(x)
- spec = MultiOneHotDiscreteTensorSpec(shape=(-1, 1, 2, 7), nvec=[3, 4])
+ spec = MultiOneHot(shape=(-1, 1, 2, 7), nvec=[3, 4])
moneh = spec
assert spec.shape == (-1, 1, 2, 7)
x = torch.zeros((3, 1, 2, 7), dtype=torch.bool)
@@ -3624,7 +3514,7 @@ def test_all(self):
xmoneh = x
assert spec.is_in(x)
- spec = MultiDiscreteTensorSpec(shape=(-1, 1, 2, 2), nvec=[3, 4])
+ spec = MultiCategorical(shape=(-1, 1, 2, 2), nvec=[3, 4])
mdisc = spec
assert spec.mask is None
assert spec.shape == (-1, 1, 2, 2)
@@ -3632,7 +3522,7 @@ def test_all(self):
xmdisc = x
assert spec.is_in(x)
- spec = CompositeSpec(
+ spec = Composite(
unb=unb,
unbd=unbd,
bound=bound,
@@ -3659,15 +3549,15 @@ def test_all(self):
assert spec.is_in(data)
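A -1 in a spec shape is a wildcard for a runtime-sized dimension: is_in accepts any extent there while still checking the remaining dims and the bounds. Sketch:

import torch
from torchrl.data import Bounded, Unbounded

dyn = Unbounded((-1, 1, 2))
assert dyn.is_in(torch.randn(3, 1, 2))
assert dyn.is_in(torch.randn(8, 1, 2))                 # any size fits the -1 slot
bounded = Bounded(shape=(-1, 1, 2), low=-1, high=1)
assert not bounded.is_in(2 * torch.ones(4, 1, 2))      # bounds are still enforced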
def test_expand(self):
- unb = UnboundedContinuousTensorSpec((-1, 1, 2))
+ unb = Unbounded((-1, 1, 2))
unbd = UnboundedDiscreteTensorSpec((-1, 1, 2))
- bound = BoundedTensorSpec(shape=(-1, 1, 2), low=-1, high=1)
- oneh = OneHotDiscreteTensorSpec(shape=(-1, 1, 2, 4), n=4)
- disc = DiscreteTensorSpec(shape=(-1, 1, 2), n=4)
- moneh = MultiOneHotDiscreteTensorSpec(shape=(-1, 1, 2, 7), nvec=[3, 4])
- mdisc = MultiDiscreteTensorSpec(shape=(-1, 1, 2, 2), nvec=[3, 4])
+ bound = Bounded(shape=(-1, 1, 2), low=-1, high=1)
+ oneh = OneHot(shape=(-1, 1, 2, 4), n=4)
+ disc = Categorical(shape=(-1, 1, 2), n=4)
+ moneh = MultiOneHot(shape=(-1, 1, 2, 7), nvec=[3, 4])
+ mdisc = MultiCategorical(shape=(-1, 1, 2, 2), nvec=[3, 4])
- spec = CompositeSpec(
+ spec = Composite(
unb=unb,
unbd=unbd,
bound=bound,
@@ -3689,7 +3579,7 @@ def test_expand(self):
class TestNonTensorSpec:
def test_sample(self):
- nts = NonTensorSpec(shape=(3, 4))
+ nts = NonTensor(shape=(3, 4))
assert nts.one((2,)).shape == (2, 3, 4)
assert nts.rand((2,)).shape == (2, 3, 4)
assert nts.zero((2,)).shape == (2, 3, 4)
@@ -3707,26 +3597,24 @@ def test_device_ordinal():
assert _make_ordinal_device(device) is None
device = torch.device("cuda")
- unb = UnboundedContinuousTensorSpec((-1, 1, 2), device=device)
+ unb = Unbounded((-1, 1, 2), device=device)
assert unb.device == torch.device("cuda:0")
unbd = UnboundedDiscreteTensorSpec((-1, 1, 2), device=device)
assert unbd.device == torch.device("cuda:0")
- bound = BoundedTensorSpec(shape=(-1, 1, 2), low=-1, high=1, device=device)
+ bound = Bounded(shape=(-1, 1, 2), low=-1, high=1, device=device)
assert bound.device == torch.device("cuda:0")
- oneh = OneHotDiscreteTensorSpec(shape=(-1, 1, 2, 4), n=4, device=device)
+ oneh = OneHot(shape=(-1, 1, 2, 4), n=4, device=device)
assert oneh.device == torch.device("cuda:0")
- disc = DiscreteTensorSpec(shape=(-1, 1, 2), n=4, device=device)
+ disc = Categorical(shape=(-1, 1, 2), n=4, device=device)
assert disc.device == torch.device("cuda:0")
- moneh = MultiOneHotDiscreteTensorSpec(
- shape=(-1, 1, 2, 7), nvec=[3, 4], device=device
- )
+ moneh = MultiOneHot(shape=(-1, 1, 2, 7), nvec=[3, 4], device=device)
assert moneh.device == torch.device("cuda:0")
- mdisc = MultiDiscreteTensorSpec(shape=(-1, 1, 2, 2), nvec=[3, 4], device=device)
+ mdisc = MultiCategorical(shape=(-1, 1, 2, 2), nvec=[3, 4], device=device)
assert mdisc.device == torch.device("cuda:0")
- mdisc = NonTensorSpec(shape=(-1, 1, 2, 2), device=device)
+ mdisc = NonTensor(shape=(-1, 1, 2, 2), device=device)
assert mdisc.device == torch.device("cuda:0")
- spec = CompositeSpec(
+ spec = Composite(
unb=unb,
unbd=unbd,
bound=bound,
@@ -3740,6 +3628,181 @@ def test_device_ordinal():
assert spec.device == torch.device("cuda:0")
+class TestLegacy:
+ def test_one_hot(self):
+ with pytest.warns(
+ DeprecationWarning,
+ match="The OneHotDiscreteTensorSpec has been deprecated and will be removed in v0.7. Please use OneHot instead.",
+ ):
+ one_hot = OneHotDiscreteTensorSpec(n=4)
+ assert isinstance(one_hot, OneHotDiscreteTensorSpec)
+ assert isinstance(one_hot, OneHot)
+ assert not isinstance(one_hot, Categorical)
+ one_hot = OneHot(n=4)
+ assert isinstance(one_hot, OneHotDiscreteTensorSpec)
+ assert isinstance(one_hot, OneHot)
+ assert not isinstance(one_hot, Categorical)
+
+ def test_discrete(self):
+ with pytest.warns(
+ DeprecationWarning,
+ match="The DiscreteTensorSpec has been deprecated and will be removed in v0.7. Please use Categorical instead.",
+ ):
+ discrete = DiscreteTensorSpec(n=4)
+ assert isinstance(discrete, DiscreteTensorSpec)
+ assert isinstance(discrete, Categorical)
+ assert not isinstance(discrete, OneHot)
+ discrete = Categorical(n=4)
+ assert isinstance(discrete, DiscreteTensorSpec)
+ assert isinstance(discrete, Categorical)
+ assert not isinstance(discrete, OneHot)
+
+ def test_unbounded(self):
+
+ unbounded_continuous_impl = Unbounded(dtype=torch.float)
+ assert isinstance(unbounded_continuous_impl, Unbounded)
+ assert isinstance(unbounded_continuous_impl, UnboundedContinuous)
+ assert isinstance(unbounded_continuous_impl, UnboundedContinuousTensorSpec)
+ assert not isinstance(unbounded_continuous_impl, UnboundedDiscreteTensorSpec)
+
+ unbounded_discrete_impl = Unbounded(dtype=torch.int)
+ assert isinstance(unbounded_discrete_impl, Unbounded)
+ assert isinstance(unbounded_discrete_impl, UnboundedDiscrete)
+ assert isinstance(unbounded_discrete_impl, UnboundedDiscreteTensorSpec)
+ assert not isinstance(unbounded_discrete_impl, UnboundedContinuousTensorSpec)
+
+ with pytest.warns(
+ DeprecationWarning,
+ match="The UnboundedContinuousTensorSpec has been deprecated and will be removed in v0.7. Please use Unbounded instead.",
+ ):
+ unbounded_continuous = UnboundedContinuousTensorSpec()
+ assert isinstance(unbounded_continuous, Unbounded)
+ assert isinstance(unbounded_continuous, UnboundedContinuous)
+ assert isinstance(unbounded_continuous, UnboundedContinuousTensorSpec)
+ assert not isinstance(unbounded_continuous, UnboundedDiscreteTensorSpec)
+
+ with warnings.catch_warnings():
+ # escalate warnings to errors: the new name must construct warning-free
+ warnings.simplefilter("error")
+ unbounded_continuous = UnboundedContinuous()
+
+ with pytest.warns(
+ DeprecationWarning,
+ match="The UnboundedDiscreteTensorSpec has been deprecated and will be removed in v0.7. Please use Unbounded instead.",
+ ):
+ unbounded_discrete = UnboundedDiscreteTensorSpec()
+ assert isinstance(unbounded_discrete, Unbounded)
+ assert isinstance(unbounded_discrete, UnboundedDiscrete)
+ assert isinstance(unbounded_discrete, UnboundedDiscreteTensorSpec)
+ assert not isinstance(unbounded_discrete, UnboundedContinuousTensorSpec)
+
+ with warnings.catch_warnings():
+ # escalate warnings to errors: the new name must construct warning-free
+ warnings.simplefilter("error")
+ unbounded_discrete = UnboundedDiscrete()
+
+ # Mismatched dtypes: the legacy name should resolve to the class matching the dtype
+ with pytest.warns(DeprecationWarning):
+ unbounded_continuous_fake = UnboundedContinuousTensorSpec(dtype=torch.int32)
+ assert isinstance(unbounded_continuous_fake, Unbounded)
+ assert not isinstance(unbounded_continuous_fake, UnboundedContinuous)
+ assert not isinstance(unbounded_continuous_fake, UnboundedContinuousTensorSpec)
+ assert isinstance(unbounded_continuous_fake, UnboundedDiscrete)
+ assert isinstance(unbounded_continuous_fake, UnboundedDiscreteTensorSpec)
+
+ with pytest.warns(DeprecationWarning):
+ unbounded_discrete_fake = UnboundedDiscreteTensorSpec(dtype=torch.float32)
+ assert isinstance(unbounded_discrete_fake, Unbounded)
+ assert isinstance(unbounded_discrete_fake, UnboundedContinuous)
+ assert isinstance(unbounded_discrete_fake, UnboundedContinuousTensorSpec)
+ assert not isinstance(unbounded_discrete_fake, UnboundedDiscrete)
+ assert not isinstance(unbounded_discrete_fake, UnboundedDiscreteTensorSpec)
+
+ def test_multi_one_hot(self):
+ with pytest.warns(
+ DeprecationWarning,
+ match="The MultiOneHotDiscreteTensorSpec has been deprecated and will be removed in v0.7. Please use MultiOneHot instead.",
+ ):
+ one_hot = MultiOneHotDiscreteTensorSpec(nvec=[4, 3])
+ assert isinstance(one_hot, MultiOneHotDiscreteTensorSpec)
+ assert isinstance(one_hot, MultiOneHot)
+ assert not isinstance(one_hot, MultiCategorical)
+ one_hot = MultiOneHot(nvec=[4, 3])
+ assert isinstance(one_hot, MultiOneHotDiscreteTensorSpec)
+ assert isinstance(one_hot, MultiOneHot)
+ assert not isinstance(one_hot, MultiCategorical)
+
+ def test_multi_categorical(self):
+ with pytest.warns(
+ DeprecationWarning,
+ match="The MultiDiscreteTensorSpec has been deprecated and will be removed in v0.7. Please use MultiCategorical instead.",
+ ):
+ categorical = MultiDiscreteTensorSpec(nvec=[4, 3])
+ assert isinstance(categorical, MultiDiscreteTensorSpec)
+ assert isinstance(categorical, MultiCategorical)
+ assert not isinstance(categorical, MultiOneHot)
+ categorical = MultiCategorical(nvec=[4, 3])
+ assert isinstance(categorical, MultiDiscreteTensorSpec)
+ assert isinstance(categorical, MultiCategorical)
+ assert not isinstance(categorical, MultiOneHot)
+
+ def test_binary(self):
+ with pytest.warns(
+ DeprecationWarning,
+ match="The BinaryDiscreteTensorSpec has been deprecated and will be removed in v0.7. Please use Binary instead.",
+ ):
+ binary = BinaryDiscreteTensorSpec(5)
+ assert isinstance(binary, BinaryDiscreteTensorSpec)
+ assert isinstance(binary, Binary)
+ assert not isinstance(binary, MultiOneHot)
+ binary = Binary(5)
+ assert isinstance(binary, BinaryDiscreteTensorSpec)
+ assert isinstance(binary, Binary)
+ assert not isinstance(binary, MultiOneHot)
+
+ def test_bounded(self):
+ with pytest.warns(
+ DeprecationWarning,
+ match="The BoundedTensorSpec has been deprecated and will be removed in v0.7. Please use Bounded instead.",
+ ):
+ bounded = BoundedTensorSpec(-2, 2, shape=())
+ assert isinstance(bounded, BoundedTensorSpec)
+ assert isinstance(bounded, Bounded)
+ assert not isinstance(bounded, MultiOneHot)
+ bounded = Bounded(-2, 2, shape=())
+ assert isinstance(bounded, BoundedTensorSpec)
+ assert isinstance(bounded, Bounded)
+ assert not isinstance(bounded, MultiOneHot)
+
+ def test_composite(self):
+ with pytest.warns(
+ DeprecationWarning,
+ match="The CompositeSpec has been deprecated and will be removed in v0.7. Please use Composite instead.",
+ ):
+ composite = CompositeSpec()
+ assert isinstance(composite, CompositeSpec)
+ assert isinstance(composite, Composite)
+ assert not isinstance(composite, MultiOneHot)
+ composite = Composite()
+ assert isinstance(composite, CompositeSpec)
+ assert isinstance(composite, Composite)
+ assert not isinstance(composite, MultiOneHot)
+
+ def test_non_tensor(self):
+ with pytest.warns(
+ DeprecationWarning,
+ match="The NonTensorSpec has been deprecated and will be removed in v0.7. Please use NonTensor instead.",
+ ):
+ non_tensor = NonTensorSpec()
+ assert isinstance(non_tensor, NonTensorSpec)
+ assert isinstance(non_tensor, NonTensor)
+ assert not isinstance(non_tensor, MultiOneHot)
+ non_tensor = NonTensor()
+ assert isinstance(non_tensor, NonTensorSpec)
+ assert isinstance(non_tensor, NonTensor)
+ assert not isinstance(non_tensor, MultiOneHot)
+
+
if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
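The TestLegacy suite above pins down a two-way isinstance contract: constructing the deprecated name warns, yet old-name and new-name instances each pass isinstance checks against the other class. A minimal sketch of one way to meet that contract, with a hypothetical `_LegacyMeta` metaclass standing in for whatever TorchRL actually uses (the dtype-based rerouting exercised in test_unbounded would need extra logic on top):

    import warnings

    class _LegacyMeta(type):
        # instances of the new class also satisfy isinstance checks
        # against the deprecated alias
        def __instancecheck__(cls, instance):
            return isinstance(instance, cls.__mro__[1])

    class Bounded:  # stand-in for the new spec class
        def __init__(self, low=None, high=None, shape=None):
            self.low, self.high, self.shape = low, high, shape

    class BoundedTensorSpec(Bounded, metaclass=_LegacyMeta):
        def __init__(self, *args, **kwargs):
            warnings.warn(
                "The BoundedTensorSpec has been deprecated and will be removed "
                "in v0.7. Please use Bounded instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            super().__init__(*args, **kwargs)

    assert isinstance(Bounded(-1, 1, ()), BoundedTensorSpec)  # new passes as old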
diff --git a/test/test_tensordictmodules.py b/test/test_tensordictmodules.py
index 38360a464e0..ea177cb9f96 100644
--- a/test/test_tensordictmodules.py
+++ b/test/test_tensordictmodules.py
@@ -11,11 +11,7 @@
from tensordict import LazyStackedTensorDict, pad, TensorDict, unravel_key_list
from tensordict.nn import InteractionType, TensorDictModule, TensorDictSequential
from torch import nn
-from torchrl.data.tensor_specs import (
- BoundedTensorSpec,
- CompositeSpec,
- UnboundedContinuousTensorSpec,
-)
+from torchrl.data.tensor_specs import Bounded, Composite, Unbounded
from torchrl.envs import (
CatFrames,
Compose,
@@ -119,8 +115,8 @@ def forward(self, x):
return self.linear_1(x), self.linear_2(x)
spec_dict = {
- "_": UnboundedContinuousTensorSpec((4,)),
- "out_2": UnboundedContinuousTensorSpec((3,)),
+ "_": Unbounded((4,)),
+ "out_2": Unbounded((3,)),
}
# warning due to "_" in spec keys
@@ -129,7 +125,7 @@ def forward(self, x):
MultiHeadLinear(5, 4, 3),
in_keys=["input"],
out_keys=["_", "out_2"],
- spec=CompositeSpec(**spec_dict),
+ spec=Composite(**spec_dict),
)
@pytest.mark.parametrize("safe", [True, False])
@@ -146,9 +142,9 @@ def test_stateful(self, safe, spec_type, lazy):
if spec_type is None:
spec = None
elif spec_type == "bounded":
- spec = BoundedTensorSpec(-0.1, 0.1, 4)
+ spec = Bounded(-0.1, 0.1, 4)
elif spec_type == "unbounded":
- spec = UnboundedContinuousTensorSpec(4)
+ spec = Unbounded(4)
if safe and spec is None:
with pytest.raises(
@@ -189,7 +185,7 @@ def test_stateful(self, safe, spec_type, lazy):
@pytest.mark.parametrize("out_keys", [["loc", "scale"], ["loc_1", "scale_1"]])
@pytest.mark.parametrize("lazy", [True, False])
@pytest.mark.parametrize(
- "exp_mode", [InteractionType.MODE, InteractionType.RANDOM, None]
+ "exp_mode", [InteractionType.DETERMINISTIC, InteractionType.RANDOM, None]
)
def test_stateful_probabilistic(self, safe, spec_type, lazy, exp_mode, out_keys):
torch.manual_seed(0)
@@ -210,9 +206,9 @@ def test_stateful_probabilistic(self, safe, spec_type, lazy, exp_mode, out_keys)
if spec_type is None:
spec = None
elif spec_type == "bounded":
- spec = BoundedTensorSpec(-0.1, 0.1, 4)
+ spec = Bounded(-0.1, 0.1, 4)
elif spec_type == "unbounded":
- spec = UnboundedContinuousTensorSpec(4)
+ spec = Unbounded(4)
else:
raise NotImplementedError
@@ -291,9 +287,9 @@ def test_stateful(self, safe, spec_type, lazy):
if spec_type is None:
spec = None
elif spec_type == "bounded":
- spec = BoundedTensorSpec(-0.1, 0.1, 4)
+ spec = Bounded(-0.1, 0.1, 4)
elif spec_type == "unbounded":
- spec = UnboundedContinuousTensorSpec(4)
+ spec = Unbounded(4)
kwargs = {}
@@ -368,9 +364,9 @@ def test_stateful_probabilistic(self, safe, spec_type, lazy):
if spec_type is None:
spec = None
elif spec_type == "bounded":
- spec = BoundedTensorSpec(-0.1, 0.1, 4)
+ spec = Bounded(-0.1, 0.1, 4)
elif spec_type == "unbounded":
- spec = UnboundedContinuousTensorSpec(4)
+ spec = Unbounded(4)
else:
raise NotImplementedError
@@ -481,7 +477,7 @@ def test_sequential_partial(self, stack):
net3 = nn.Sequential(net3, NormalParamExtractor())
net3 = SafeModule(net3, in_keys=["c"], out_keys=["loc", "scale"])
- spec = BoundedTensorSpec(-0.1, 0.1, 4)
+ spec = Bounded(-0.1, 0.1, 4)
kwargs = {"distribution_class": TanhNormal}
@@ -1340,7 +1336,7 @@ def call(data, params):
def test_safe_specs():
out_key = ("a", "b")
- spec = CompositeSpec(CompositeSpec({out_key: UnboundedContinuousTensorSpec()}))
+ spec = Composite(Composite({out_key: Unbounded()}))
original_spec = spec.clone()
mod = SafeModule(
module=nn.Linear(3, 1),
@@ -1354,9 +1350,7 @@ def test_safe_specs():
def test_actor_critic_specs():
action_key = ("agents", "action")
- spec = CompositeSpec(
- CompositeSpec({action_key: UnboundedContinuousTensorSpec(shape=(3,))})
- )
+ spec = Composite(Composite({action_key: Unbounded(shape=(3,))}))
policy_module = TensorDictModule(
nn.Linear(3, 1),
in_keys=[("agents", "observation")],
diff --git a/test/test_transforms.py b/test/test_transforms.py
index c38908eba1d..ca9a031bb2f 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -15,7 +15,6 @@
from functools import partial
from sys import platform
-import numpy as np
import pytest
import tensordict.tensordict
@@ -51,15 +50,15 @@
from torch import multiprocessing as mp, nn, Tensor
from torchrl._utils import _replace_last, prod
from torchrl.data import (
- BoundedTensorSpec,
- CompositeSpec,
- DiscreteTensorSpec,
+ Bounded,
+ Categorical,
+ Composite,
LazyTensorStorage,
ReplayBuffer,
TensorDictReplayBuffer,
TensorSpec,
TensorStorage,
- UnboundedContinuousTensorSpec,
+ Unbounded,
)
from torchrl.envs import (
ActionMask,
@@ -147,7 +146,7 @@ class TransformBase:
We ask that every new transform test be coded following this minimum-requirement class.
- Of course, specific behaviours can also be tested separately.
+ Of course, specific behaviors can also be tested separately.
If your transform identifies an issue with the EnvBase or _BatchedEnv abstraction(s),
this needs to be corrected independently.
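Concretely, following the minimum-requirement class means subclassing TransformBase and implementing each scenario it mandates. A skeletal sketch for a hypothetical MyNewTransform (only TransformBase and the helpers already imported in this file are real names):

    class TestMyNewTransform(TransformBase):
        def test_single_trans_env_check(self):
            # appended to a single env, the transform must keep specs consistent
            env = TransformedEnv(ContinuousActionVecMockEnv(), MyNewTransform())
            check_env_specs(env)

        def test_transform_no_env(self):
            # the transform must also work standalone on a plain TensorDict
            t = MyNewTransform()
            t(TensorDict({"observation": torch.randn(3)}, []))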
@@ -934,21 +933,17 @@ def test_catframes_transform_observation_spec(self):
)
mins = [0, 0.5]
maxes = [0.5, 1]
- observation_spec = CompositeSpec(
+ observation_spec = Composite(
{
- key: BoundedTensorSpec(
- space_min, space_max, (1, 3, 3), dtype=torch.double
- )
+ key: Bounded(space_min, space_max, (1, 3, 3), dtype=torch.double)
for key, space_min, space_max in zip(keys, mins, maxes)
}
)
result = cat_frames.transform_observation_spec(observation_spec)
- observation_spec = CompositeSpec(
+ observation_spec = Composite(
{
- key: BoundedTensorSpec(
- space_min, space_max, (1, 3, 3), dtype=torch.double
- )
+ key: Bounded(space_min, space_max, (1, 3, 3), dtype=torch.double)
for key, space_min, space_max in zip(keys, mins, maxes)
}
)
@@ -1502,15 +1497,12 @@ def test_r3mnet_transform_observation_spec(
):
r3m_net = _R3MNet(in_keys, out_keys, model, del_keys)
- observation_spec = CompositeSpec(
- {key: BoundedTensorSpec(-1, 1, (3, 16, 16), device) for key in in_keys}
+ observation_spec = Composite(
+ {key: Bounded(-1, 1, (3, 16, 16), device) for key in in_keys}
)
if del_keys:
- exp_ts = CompositeSpec(
- {
- key: UnboundedContinuousTensorSpec(r3m_net.outdim, device)
- for key in out_keys
- }
+ exp_ts = Composite(
+ {key: Unbounded(r3m_net.outdim, device) for key in out_keys}
)
observation_spec_out = r3m_net.transform_observation_spec(observation_spec)
@@ -1526,8 +1518,8 @@ def test_r3mnet_transform_observation_spec(
for key in in_keys:
ts_dict[key] = observation_spec[key]
for key in out_keys:
- ts_dict[key] = UnboundedContinuousTensorSpec(r3m_net.outdim, device)
- exp_ts = CompositeSpec(ts_dict)
+ ts_dict[key] = Unbounded(r3m_net.outdim, device)
+ exp_ts = Composite(ts_dict)
observation_spec_out = r3m_net.transform_observation_spec(observation_spec)
@@ -2020,12 +2012,12 @@ def test_transform_no_env(self, keys, device, out_key):
assert tdc.get("dont touch").shape == dont_touch.shape
if len(keys) == 1:
- observation_spec = BoundedTensorSpec(0, 1, (1, 4, 32))
+ observation_spec = Bounded(0, 1, (1, 4, 32))
observation_spec = cattensors.transform_observation_spec(observation_spec)
assert observation_spec.shape == torch.Size([1, len(keys) * 4, 32])
else:
- observation_spec = CompositeSpec(
- {key: BoundedTensorSpec(0, 1, (1, 4, 32)) for key in keys}
+ observation_spec = Composite(
+ {key: Bounded(0, 1, (1, 4, 32)) for key in keys}
)
observation_spec = cattensors.transform_observation_spec(observation_spec)
assert observation_spec[out_key].shape == torch.Size([1, len(keys) * 4, 32])
@@ -2166,12 +2158,12 @@ def test_transform_no_env(self, keys, h, nchannels, batch, device):
assert (td.get("dont touch") == dont_touch).all()
if len(keys) == 1:
- observation_spec = BoundedTensorSpec(-1, 1, (nchannels, 16, 16))
+ observation_spec = Bounded(-1, 1, (nchannels, 16, 16))
observation_spec = crop.transform_observation_spec(observation_spec)
assert observation_spec.shape == torch.Size([nchannels, 20, h])
else:
- observation_spec = CompositeSpec(
- {key: BoundedTensorSpec(-1, 1, (nchannels, 16, 16)) for key in keys}
+ observation_spec = Composite(
+ {key: Bounded(-1, 1, (nchannels, 16, 16)) for key in keys}
)
observation_spec = crop.transform_observation_spec(observation_spec)
for key in keys:
@@ -2373,12 +2365,12 @@ def test_transform_no_env(self, keys, h, nchannels, batch, device):
assert (td.get("dont touch") == dont_touch).all()
if len(keys) == 1:
- observation_spec = BoundedTensorSpec(-1, 1, (nchannels, 16, 16))
+ observation_spec = Bounded(-1, 1, (nchannels, 16, 16))
observation_spec = cc.transform_observation_spec(observation_spec)
assert observation_spec.shape == torch.Size([nchannels, 20, h])
else:
- observation_spec = CompositeSpec(
- {key: BoundedTensorSpec(-1, 1, (nchannels, 16, 16)) for key in keys}
+ observation_spec = Composite(
+ {key: Bounded(-1, 1, (nchannels, 16, 16)) for key in keys}
)
observation_spec = cc.transform_observation_spec(observation_spec)
for key in keys:
@@ -2722,18 +2714,15 @@ def test_double2float(self, keys, keys_inv, device):
assert td.get("dont touch").dtype != torch.double
if len(keys_total) == 1 and len(keys_inv) and keys[0] == "action":
- action_spec = BoundedTensorSpec(0, 1, (1, 3, 3), dtype=torch.double)
- input_spec = CompositeSpec(
- full_action_spec=CompositeSpec(action=action_spec), full_state_spec=None
+ action_spec = Bounded(0, 1, (1, 3, 3), dtype=torch.double)
+ input_spec = Composite(
+ full_action_spec=Composite(action=action_spec), full_state_spec=None
)
action_spec = double2float.transform_input_spec(input_spec)
assert action_spec.dtype == torch.float
else:
- observation_spec = CompositeSpec(
- {
- key: BoundedTensorSpec(0, 1, (1, 3, 3), dtype=torch.double)
- for key in keys
- }
+ observation_spec = Composite(
+ {key: Bounded(0, 1, (1, 3, 3), dtype=torch.double) for key in keys}
)
observation_spec = double2float.transform_observation_spec(observation_spec)
for key in keys:
@@ -2950,13 +2939,13 @@ class TestExcludeTransform(TransformBase):
class EnvWithManyKeys(EnvBase):
def __init__(self):
super().__init__()
- self.observation_spec = CompositeSpec(
- a=UnboundedContinuousTensorSpec(3),
- b=UnboundedContinuousTensorSpec(3),
- c=UnboundedContinuousTensorSpec(3),
+ self.observation_spec = Composite(
+ a=Unbounded(3),
+ b=Unbounded(3),
+ c=Unbounded(3),
)
- self.reward_spec = UnboundedContinuousTensorSpec(1)
- self.action_spec = UnboundedContinuousTensorSpec(2)
+ self.reward_spec = Unbounded(1)
+ self.action_spec = Unbounded(2)
def _step(
self,
@@ -3188,13 +3177,13 @@ class TestSelectTransform(TransformBase):
class EnvWithManyKeys(EnvBase):
def __init__(self):
super().__init__()
- self.observation_spec = CompositeSpec(
- a=UnboundedContinuousTensorSpec(3),
- b=UnboundedContinuousTensorSpec(3),
- c=UnboundedContinuousTensorSpec(3),
+ self.observation_spec = Composite(
+ a=Unbounded(3),
+ b=Unbounded(3),
+ c=Unbounded(3),
)
- self.reward_spec = UnboundedContinuousTensorSpec(1)
- self.action_spec = UnboundedContinuousTensorSpec(2)
+ self.reward_spec = Unbounded(1)
+ self.action_spec = Unbounded(2)
def _step(
self,
@@ -3513,15 +3502,12 @@ def test_transform_no_env(self, keys, size, nchannels, batch, device):
assert (td.get("dont touch") == dont_touch).all()
if len(keys) == 1:
- observation_spec = BoundedTensorSpec(-1, 1, (*size, nchannels, 16, 16))
+ observation_spec = Bounded(-1, 1, (*size, nchannels, 16, 16))
observation_spec = flatten.transform_observation_spec(observation_spec)
assert observation_spec.shape[-3] == expected_size
else:
- observation_spec = CompositeSpec(
- {
- key: BoundedTensorSpec(-1, 1, (*size, nchannels, 16, 16))
- for key in keys
- }
+ observation_spec = Composite(
+ {key: Bounded(-1, 1, (*size, nchannels, 16, 16)) for key in keys}
)
observation_spec = flatten.transform_observation_spec(observation_spec)
for key in keys:
@@ -3556,15 +3542,12 @@ def test_transform_compose(self, keys, size, nchannels, batch, device):
assert (td.get("dont touch") == dont_touch).all()
if len(keys) == 1:
- observation_spec = BoundedTensorSpec(-1, 1, (*size, nchannels, 16, 16))
+ observation_spec = Bounded(-1, 1, (*size, nchannels, 16, 16))
observation_spec = flatten.transform_observation_spec(observation_spec)
assert observation_spec.shape[-3] == expected_size
else:
- observation_spec = CompositeSpec(
- {
- key: BoundedTensorSpec(-1, 1, (*size, nchannels, 16, 16))
- for key in keys
- }
+ observation_spec = Composite(
+ {key: Bounded(-1, 1, (*size, nchannels, 16, 16)) for key in keys}
)
observation_spec = flatten.transform_observation_spec(observation_spec)
for key in keys:
@@ -3801,12 +3784,12 @@ def test_transform_no_env(self, keys, device):
assert (td.get("dont touch") == dont_touch).all()
if len(keys) == 1:
- observation_spec = BoundedTensorSpec(-1, 1, (nchannels, 16, 16))
+ observation_spec = Bounded(-1, 1, (nchannels, 16, 16))
observation_spec = gs.transform_observation_spec(observation_spec)
assert observation_spec.shape == torch.Size([1, 16, 16])
else:
- observation_spec = CompositeSpec(
- {key: BoundedTensorSpec(-1, 1, (nchannels, 16, 16)) for key in keys}
+ observation_spec = Composite(
+ {key: Bounded(-1, 1, (nchannels, 16, 16)) for key in keys}
)
observation_spec = gs.transform_observation_spec(observation_spec)
for key in keys:
@@ -3838,12 +3821,12 @@ def test_transform_compose(self, keys, device):
assert (td.get("dont touch") == dont_touch).all()
if len(keys) == 1:
- observation_spec = BoundedTensorSpec(-1, 1, (nchannels, 16, 16))
+ observation_spec = Bounded(-1, 1, (nchannels, 16, 16))
observation_spec = gs.transform_observation_spec(observation_spec)
assert observation_spec.shape == torch.Size([1, 16, 16])
else:
- observation_spec = CompositeSpec(
- {key: BoundedTensorSpec(-1, 1, (nchannels, 16, 16)) for key in keys}
+ observation_spec = Composite(
+ {key: Bounded(-1, 1, (nchannels, 16, 16)) for key in keys}
)
observation_spec = gs.transform_observation_spec(observation_spec)
for key in keys:
@@ -4443,9 +4426,7 @@ def test_observationnorm(
assert (td.get("dont touch") == dont_touch).all()
if len(keys) == 1:
- observation_spec = BoundedTensorSpec(
- 0, 1, (nchannels, 16, 16), device=device
- )
+ observation_spec = Bounded(0, 1, (nchannels, 16, 16), device=device)
observation_spec = on.transform_observation_spec(observation_spec)
if standard_normal:
assert (observation_spec.space.low == -loc / scale).all()
@@ -4455,11 +4436,8 @@ def test_observationnorm(
assert (observation_spec.space.high == scale + loc).all()
else:
- observation_spec = CompositeSpec(
- {
- key: BoundedTensorSpec(0, 1, (nchannels, 16, 16), device=device)
- for key in keys
- }
+ observation_spec = Composite(
+ {key: Bounded(0, 1, (nchannels, 16, 16), device=device) for key in keys}
)
observation_spec = on.transform_observation_spec(observation_spec)
for key in keys:
@@ -4480,15 +4458,11 @@ def test_observationnorm_init_stats(
):
def make_env():
base_env = ContinuousActionVecMockEnv(
- observation_spec=CompositeSpec(
- observation=BoundedTensorSpec(
- low=1, high=1, shape=torch.Size([size])
- ),
- observation_orig=BoundedTensorSpec(
- low=1, high=1, shape=torch.Size([size])
- ),
+ observation_spec=Composite(
+ observation=Bounded(low=1, high=1, shape=torch.Size([size])),
+ observation_orig=Bounded(low=1, high=1, shape=torch.Size([size])),
),
- action_spec=BoundedTensorSpec(low=1, high=1, shape=torch.Size((size,))),
+ action_spec=Bounded(low=1, high=1, shape=torch.Size((size,))),
seed=0,
)
base_env.out_key = "observation"
@@ -4669,12 +4643,12 @@ def test_transform_no_env(self, interpolation, keys, nchannels, batch, device):
assert (td.get("dont touch") == dont_touch).all()
if len(keys) == 1:
- observation_spec = BoundedTensorSpec(-1, 1, (nchannels, 16, 16))
+ observation_spec = Bounded(-1, 1, (nchannels, 16, 16))
observation_spec = resize.transform_observation_spec(observation_spec)
assert observation_spec.shape == torch.Size([nchannels, 20, 21])
else:
- observation_spec = CompositeSpec(
- {key: BoundedTensorSpec(-1, 1, (nchannels, 16, 16)) for key in keys}
+ observation_spec = Composite(
+ {key: Bounded(-1, 1, (nchannels, 16, 16)) for key in keys}
)
observation_spec = resize.transform_observation_spec(observation_spec)
for key in keys:
@@ -4706,12 +4680,12 @@ def test_transform_compose(self, interpolation, keys, nchannels, batch, device):
assert (td.get("dont touch") == dont_touch).all()
if len(keys) == 1:
- observation_spec = BoundedTensorSpec(-1, 1, (nchannels, 16, 16))
+ observation_spec = Bounded(-1, 1, (nchannels, 16, 16))
observation_spec = resize.transform_observation_spec(observation_spec)
assert observation_spec.shape == torch.Size([nchannels, 20, 21])
else:
- observation_spec = CompositeSpec(
- {key: BoundedTensorSpec(-1, 1, (nchannels, 16, 16)) for key in keys}
+ observation_spec = Composite(
+ {key: Bounded(-1, 1, (nchannels, 16, 16)) for key in keys}
)
observation_spec = resize.transform_observation_spec(observation_spec)
for key in keys:
@@ -4947,7 +4921,7 @@ def test_reward_scaling(self, batch, scale, loc, keys, device, standard_normal):
assert (td.get("dont touch") == td_copy.get("dont touch")).all()
if len(keys_total) == 1:
- reward_spec = UnboundedContinuousTensorSpec(device=device)
+ reward_spec = Unbounded(device=device)
reward_spec = reward_scaling.transform_reward_spec(reward_spec)
assert reward_spec.shape == torch.Size([1])
@@ -5140,7 +5114,9 @@ def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv):
pass
@pytest.mark.parametrize("has_in_keys,", [True, False])
- @pytest.mark.parametrize("reset_keys,", [None, ["_reset"] * 3])
+ @pytest.mark.parametrize(
+ "reset_keys,", [[("some", "nested", "reset")], ["_reset"] * 3, None]
+ )
def test_trans_multi_key(
self, has_in_keys, reset_keys, n_workers=2, batch_size=(3, 2), max_steps=5
):
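The next hunk leans on the conditional-expectation idiom: a pytest.raises context when failure is expected, contextlib.nullcontext() otherwise. A standalone sketch of the pattern, with should_fail as an illustrative flag:

    import contextlib

    import pytest

    def expectation(should_fail):
        # expect the spec-matching error only for the mismatched reset key
        if should_fail:
            return pytest.raises(ValueError, match="Could not match the env reset_keys")
        return contextlib.nullcontext()

    # usage: with expectation(reset_keys == [("some", "nested", "reset")]): ...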
@@ -5162,9 +5138,9 @@ def test_trans_multi_key(
)
with pytest.raises(
ValueError, match="Could not match the env reset_keys"
- ) if reset_keys is None else contextlib.nullcontext():
+ ) if reset_keys == [("some", "nested", "reset")] else contextlib.nullcontext():
check_env_specs(env)
- if reset_keys is not None:
+ if reset_keys != [("some", "nested", "reset")]:
td = env.rollout(max_steps, policy=policy)
for reward_key in env.reward_keys:
reward_key = _unravel_key_to_tuple(reward_key)
@@ -5341,24 +5317,24 @@ def test_sum_reward(self, keys, device):
# test transform_observation_spec
base_env = ContinuousActionVecMockEnv(
- reward_spec=UnboundedContinuousTensorSpec(shape=(3, 16, 16)),
+ reward_spec=Unbounded(shape=(3, 16, 16)),
)
transformed_env = TransformedEnv(base_env, RewardSum())
transformed_observation_spec1 = transformed_env.observation_spec
- assert isinstance(transformed_observation_spec1, CompositeSpec)
+ assert isinstance(transformed_observation_spec1, Composite)
assert "episode_reward" in transformed_observation_spec1.keys()
assert "observation" in transformed_observation_spec1.keys()
base_env = ContinuousActionVecMockEnv(
- reward_spec=UnboundedContinuousTensorSpec(),
- observation_spec=CompositeSpec(
- observation=UnboundedContinuousTensorSpec(),
- some_extra_observation=UnboundedContinuousTensorSpec(),
+ reward_spec=Unbounded(),
+ observation_spec=Composite(
+ observation=Unbounded(),
+ some_extra_observation=Unbounded(),
),
)
transformed_env = TransformedEnv(base_env, RewardSum())
transformed_observation_spec2 = transformed_env.observation_spec
- assert isinstance(transformed_observation_spec2, CompositeSpec)
+ assert isinstance(transformed_observation_spec2, Composite)
assert "some_extra_observation" in transformed_observation_spec2.keys()
assert "episode_reward" in transformed_observation_spec2.keys()
@@ -5653,7 +5629,7 @@ def test_transform_model(self):
class TestUnsqueezeTransform(TransformBase):
- @pytest.mark.parametrize("unsqueeze_dim", [1, -2])
+ @pytest.mark.parametrize("dim", [1, -2])
@pytest.mark.parametrize("nchannels", [1, 3])
@pytest.mark.parametrize("batch", [[], [2], [2, 4]])
@pytest.mark.parametrize("size", [[], [4]])
@@ -5661,14 +5637,10 @@ class TestUnsqueezeTransform(TransformBase):
"keys", [["observation", ("some_other", "nested_key")], ["observation_pixels"]]
)
@pytest.mark.parametrize("device", get_default_devices())
- def test_transform_no_env(
- self, keys, size, nchannels, batch, device, unsqueeze_dim
- ):
+ def test_transform_no_env(self, keys, size, nchannels, batch, device, dim):
torch.manual_seed(0)
dont_touch = torch.randn(*batch, *size, nchannels, 16, 16, device=device)
- unsqueeze = UnsqueezeTransform(
- unsqueeze_dim, in_keys=keys, allow_positive_dim=True
- )
+ unsqueeze = UnsqueezeTransform(dim, in_keys=keys, allow_positive_dim=True)
td = TensorDict(
{
key: torch.randn(*batch, *size, nchannels, 16, 16, device=device)
@@ -5678,16 +5650,16 @@ def test_transform_no_env(
device=device,
)
td.set("dont touch", dont_touch.clone())
- if unsqueeze_dim >= 0 and unsqueeze_dim < len(batch):
+ if dim >= 0 and dim < len(batch):
with pytest.raises(RuntimeError, match="batch dimension mismatch"):
unsqueeze(td)
return
unsqueeze(td)
expected_size = [*batch, *size, nchannels, 16, 16]
- if unsqueeze_dim < 0:
- expected_size.insert(len(expected_size) + unsqueeze_dim + 1, 1)
+ if dim < 0:
+ expected_size.insert(len(expected_size) + dim + 1, 1)
else:
- expected_size.insert(unsqueeze_dim, 1)
+ expected_size.insert(dim, 1)
expected_size = torch.Size(expected_size)
for key in keys:
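The expected-shape arithmetic above mirrors torch.unsqueeze's treatment of negative dims, which insert at len(shape) + dim + 1. A quick self-contained check of that equivalence:

    import torch

    shape = [2, 3, 4]
    for dim in (1, -1, -2):
        expected = list(shape)
        if dim < 0:
            # same index math as the test: negative dims count from the end
            expected.insert(len(expected) + dim + 1, 1)
        else:
            expected.insert(dim, 1)
        assert torch.randn(shape).unsqueeze(dim).shape == torch.Size(expected)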
@@ -5695,20 +5667,18 @@ def test_transform_no_env(
batch,
size,
nchannels,
- unsqueeze_dim,
+ dim,
)
assert (td.get("dont touch") == dont_touch).all()
if len(keys) == 1:
- observation_spec = BoundedTensorSpec(
- -1, 1, (*batch, *size, nchannels, 16, 16)
- )
+ observation_spec = Bounded(-1, 1, (*batch, *size, nchannels, 16, 16))
observation_spec = unsqueeze.transform_observation_spec(observation_spec)
assert observation_spec.shape == expected_size
else:
- observation_spec = CompositeSpec(
+ observation_spec = Composite(
{
- key: BoundedTensorSpec(-1, 1, (*batch, *size, nchannels, 16, 16))
+ key: Bounded(-1, 1, (*batch, *size, nchannels, 16, 16))
for key in keys
}
)
@@ -5716,7 +5686,7 @@ def test_transform_no_env(
for key in keys:
assert observation_spec[key].shape == expected_size
- @pytest.mark.parametrize("unsqueeze_dim", [1, -2])
+ @pytest.mark.parametrize("dim", [1, -2])
@pytest.mark.parametrize("nchannels", [1, 3])
@pytest.mark.parametrize("batch", [[], [2], [2, 4]])
@pytest.mark.parametrize("size", [[], [4]])
@@ -5732,13 +5702,11 @@ def test_transform_no_env(
[("next", "observation_pixels")],
],
)
- def test_unsqueeze_inv(
- self, keys, keys_inv, size, nchannels, batch, device, unsqueeze_dim
- ):
+ def test_unsqueeze_inv(self, keys, keys_inv, size, nchannels, batch, device, dim):
torch.manual_seed(0)
keys_total = set(keys + keys_inv)
unsqueeze = UnsqueezeTransform(
- unsqueeze_dim, in_keys=keys, in_keys_inv=keys_inv, allow_positive_dim=True
+ dim, in_keys=keys, in_keys_inv=keys_inv, allow_positive_dim=True
)
td = TensorDict(
{
@@ -5754,8 +5722,8 @@ def test_unsqueeze_inv(
for key in keys_total.difference(keys_inv):
assert td.get(key).shape == torch.Size(expected_size)
- if expected_size[unsqueeze_dim] == 1:
- del expected_size[unsqueeze_dim]
+ if expected_size[dim] == 1:
+ del expected_size[dim]
for key in keys_inv:
assert td_modif.get(key).shape == torch.Size(expected_size)
# for key in keys_inv:
@@ -5815,7 +5783,7 @@ def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv):
except RuntimeError:
pass
- @pytest.mark.parametrize("unsqueeze_dim", [1, -2])
+ @pytest.mark.parametrize("dim", [1, -2])
@pytest.mark.parametrize("nchannels", [1, 3])
@pytest.mark.parametrize("batch", [[], [2], [2, 4]])
@pytest.mark.parametrize("size", [[], [4]])
@@ -5823,13 +5791,11 @@ def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv):
"keys", [["observation", "some_other_key"], ["observation_pixels"]]
)
@pytest.mark.parametrize("device", get_default_devices())
- def test_transform_compose(
- self, keys, size, nchannels, batch, device, unsqueeze_dim
- ):
+ def test_transform_compose(self, keys, size, nchannels, batch, device, dim):
torch.manual_seed(0)
dont_touch = torch.randn(*batch, *size, nchannels, 16, 16, device=device)
unsqueeze = Compose(
- UnsqueezeTransform(unsqueeze_dim, in_keys=keys, allow_positive_dim=True)
+ UnsqueezeTransform(dim, in_keys=keys, allow_positive_dim=True)
)
td = TensorDict(
{
@@ -5840,16 +5806,16 @@ def test_transform_compose(
device=device,
)
td.set("dont touch", dont_touch.clone())
- if unsqueeze_dim >= 0 and unsqueeze_dim < len(batch):
+ if dim >= 0 and dim < len(batch):
with pytest.raises(RuntimeError, match="batch dimension mismatch"):
unsqueeze(td)
return
unsqueeze(td)
expected_size = [*batch, *size, nchannels, 16, 16]
- if unsqueeze_dim < 0:
- expected_size.insert(len(expected_size) + unsqueeze_dim + 1, 1)
+ if dim < 0:
+ expected_size.insert(len(expected_size) + dim + 1, 1)
else:
- expected_size.insert(unsqueeze_dim, 1)
+ expected_size.insert(dim, 1)
expected_size = torch.Size(expected_size)
for key in keys:
@@ -5857,20 +5823,18 @@ def test_transform_compose(
batch,
size,
nchannels,
- unsqueeze_dim,
+ dim,
)
assert (td.get("dont touch") == dont_touch).all()
if len(keys) == 1:
- observation_spec = BoundedTensorSpec(
- -1, 1, (*batch, *size, nchannels, 16, 16)
- )
+ observation_spec = Bounded(-1, 1, (*batch, *size, nchannels, 16, 16))
observation_spec = unsqueeze.transform_observation_spec(observation_spec)
assert observation_spec.shape == expected_size
else:
- observation_spec = CompositeSpec(
+ observation_spec = Composite(
{
- key: BoundedTensorSpec(-1, 1, (*batch, *size, nchannels, 16, 16))
+ key: Bounded(-1, 1, (*batch, *size, nchannels, 16, 16))
for key in keys
}
)
@@ -5895,10 +5859,10 @@ def test_transform_env(self, out_keys):
check_env_specs(env)
@pytest.mark.parametrize("out_keys", [None, ["stuff"]])
- @pytest.mark.parametrize("unsqueeze_dim", [-1, 1])
- def test_transform_model(self, out_keys, unsqueeze_dim):
+ @pytest.mark.parametrize("dim", [-1, 1])
+ def test_transform_model(self, out_keys, dim):
t = UnsqueezeTransform(
- unsqueeze_dim,
+ dim,
in_keys=["observation"],
out_keys=out_keys,
allow_positive_dim=True,
@@ -5908,21 +5872,21 @@ def test_transform_model(self, out_keys, unsqueeze_dim):
)
t(td)
expected_shape = [3, 4]
- if unsqueeze_dim >= 0:
- expected_shape.insert(unsqueeze_dim, 1)
+ if dim >= 0:
+ expected_shape.insert(dim, 1)
else:
- expected_shape.insert(len(expected_shape) + unsqueeze_dim + 1, 1)
+ expected_shape.insert(len(expected_shape) + dim + 1, 1)
if out_keys is None:
assert td["observation"].shape == torch.Size(expected_shape)
else:
assert td[out_keys[0]].shape == torch.Size(expected_shape)
@pytest.mark.parametrize("out_keys", [None, ["stuff"]])
- @pytest.mark.parametrize("unsqueeze_dim", [-1, 1])
+ @pytest.mark.parametrize("dim", [-1, 1])
@pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer])
- def test_transform_rb(self, rbclass, out_keys, unsqueeze_dim):
+ def test_transform_rb(self, rbclass, out_keys, dim):
t = UnsqueezeTransform(
- unsqueeze_dim,
+ dim,
in_keys=["observation"],
out_keys=out_keys,
allow_positive_dim=True,
@@ -5935,10 +5899,10 @@ def test_transform_rb(self, rbclass, out_keys, unsqueeze_dim):
rb.extend(td)
td = rb.sample(2)
expected_shape = [2, 3, 4]
- if unsqueeze_dim >= 0:
- expected_shape.insert(unsqueeze_dim, 1)
+ if dim >= 0:
+ expected_shape.insert(dim, 1)
else:
- expected_shape.insert(len(expected_shape) + unsqueeze_dim + 1, 1)
+ expected_shape.insert(len(expected_shape) + dim + 1, 1)
if out_keys is None:
assert td["observation"].shape == torch.Size(expected_shape)
else:
@@ -5962,7 +5926,7 @@ def test_transform_inverse(self):
class TestSqueezeTransform(TransformBase):
- @pytest.mark.parametrize("squeeze_dim", [1, -2])
+ @pytest.mark.parametrize("dim", [1, -2])
@pytest.mark.parametrize("nchannels", [1, 3])
@pytest.mark.parametrize("batch", [[], [2], [2, 4]])
@pytest.mark.parametrize("size", [[], [4]])
@@ -5983,12 +5947,12 @@ class TestSqueezeTransform(TransformBase):
],
)
def test_transform_no_env(
- self, keys, keys_inv, size, nchannels, batch, device, squeeze_dim
+ self, keys, keys_inv, size, nchannels, batch, device, dim
):
torch.manual_seed(0)
keys_total = set(keys + keys_inv)
squeeze = SqueezeTransform(
- squeeze_dim, in_keys=keys, in_keys_inv=keys_inv, allow_positive_dim=True
+ dim, in_keys=keys, in_keys_inv=keys_inv, allow_positive_dim=True
)
td = TensorDict(
{
@@ -6003,12 +5967,12 @@ def test_transform_no_env(
for key in keys_total.difference(keys):
assert td.get(key).shape == torch.Size(expected_size)
- if expected_size[squeeze_dim] == 1:
- del expected_size[squeeze_dim]
+ if expected_size[dim] == 1:
+ del expected_size[dim]
for key in keys:
assert td.get(key).shape == torch.Size(expected_size)
- @pytest.mark.parametrize("squeeze_dim", [1, -2])
+ @pytest.mark.parametrize("dim", [1, -2])
@pytest.mark.parametrize("nchannels", [1, 3])
@pytest.mark.parametrize("batch", [[], [2], [2, 4]])
@pytest.mark.parametrize("size", [[], [4]])
@@ -6028,15 +5992,13 @@ def test_transform_no_env(
[("next", "observation_pixels")],
],
)
- def test_squeeze_inv(
- self, keys, keys_inv, size, nchannels, batch, device, squeeze_dim
- ):
+ def test_squeeze_inv(self, keys, keys_inv, size, nchannels, batch, device, dim):
torch.manual_seed(0)
- if squeeze_dim >= 0:
- squeeze_dim = squeeze_dim + len(batch)
+ if dim >= 0:
+ dim = dim + len(batch)
keys_total = set(keys + keys_inv)
squeeze = SqueezeTransform(
- squeeze_dim, in_keys=keys, in_keys_inv=keys_inv, allow_positive_dim=True
+ dim, in_keys=keys, in_keys_inv=keys_inv, allow_positive_dim=True
)
td = TensorDict(
{
@@ -6051,14 +6013,14 @@ def test_squeeze_inv(
for key in keys_total.difference(keys_inv):
assert td.get(key).shape == torch.Size(expected_size)
- if squeeze_dim < 0:
- expected_size.insert(len(expected_size) + squeeze_dim + 1, 1)
+ if dim < 0:
+ expected_size.insert(len(expected_size) + dim + 1, 1)
else:
- expected_size.insert(squeeze_dim, 1)
+ expected_size.insert(dim, 1)
expected_size = torch.Size(expected_size)
for key in keys_inv:
- assert td.get(key).shape == torch.Size(expected_size), squeeze_dim
+ assert td.get(key).shape == torch.Size(expected_size), dim
@property
def _circular_transform(self):
@@ -6131,7 +6093,7 @@ def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv):
except RuntimeError:
pass
- @pytest.mark.parametrize("squeeze_dim", [1, -2])
+ @pytest.mark.parametrize("dim", [1, -2])
@pytest.mark.parametrize("nchannels", [1, 3])
@pytest.mark.parametrize("batch", [[], [2], [2, 4]])
@pytest.mark.parametrize("size", [[], [4]])
@@ -6144,13 +6106,13 @@ def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv):
"keys_inv", [[], ["action", "some_other_key"], [("next", "observation_pixels")]]
)
def test_transform_compose(
- self, keys, keys_inv, size, nchannels, batch, device, squeeze_dim
+ self, keys, keys_inv, size, nchannels, batch, device, dim
):
torch.manual_seed(0)
keys_total = set(keys + keys_inv)
squeeze = Compose(
SqueezeTransform(
- squeeze_dim, in_keys=keys, in_keys_inv=keys_inv, allow_positive_dim=True
+ dim, in_keys=keys, in_keys_inv=keys_inv, allow_positive_dim=True
)
)
td = TensorDict(
@@ -6166,8 +6128,8 @@ def test_transform_compose(
for key in keys_total.difference(keys):
assert td.get(key).shape == torch.Size(expected_size)
- if expected_size[squeeze_dim] == 1:
- del expected_size[squeeze_dim]
+ if expected_size[dim] == 1:
+ del expected_size[dim]
for key in keys:
assert td.get(key).shape == torch.Size(expected_size)
@@ -6184,9 +6146,9 @@ def test_transform_env(self, keys_inv):
@pytest.mark.parametrize("out_keys", [None, ["obs_sq"]])
def test_transform_model(self, out_keys):
- squeeze_dim = 1
+ dim = 1
t = SqueezeTransform(
- squeeze_dim,
+ dim,
in_keys=["observation"],
out_keys=out_keys,
allow_positive_dim=True,
@@ -6205,9 +6167,9 @@ def test_transform_model(self, out_keys):
@pytest.mark.parametrize("out_keys", [None, ["obs_sq"]])
@pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer])
def test_transform_rb(self, out_keys, rbclass):
- squeeze_dim = -2
+ dim = -2
t = SqueezeTransform(
- squeeze_dim,
+ dim,
in_keys=["observation"],
out_keys=out_keys,
allow_positive_dim=True,
@@ -6466,7 +6428,7 @@ def test_transform_no_env(self, keys, batch, device):
assert (td.get("dont touch") == dont_touch).all()
if len(keys) == 1:
- observation_spec = BoundedTensorSpec(0, 255, (16, 16, 3), dtype=torch.uint8)
+ observation_spec = Bounded(0, 255, (16, 16, 3), dtype=torch.uint8)
observation_spec = totensorimage.transform_observation_spec(
observation_spec
)
@@ -6474,11 +6436,8 @@ def test_transform_no_env(self, keys, batch, device):
assert (observation_spec.space.low == 0).all()
assert (observation_spec.space.high == 1).all()
else:
- observation_spec = CompositeSpec(
- {
- key: BoundedTensorSpec(0, 255, (16, 16, 3), dtype=torch.uint8)
- for key in keys
- }
+ observation_spec = Composite(
+ {key: Bounded(0, 255, (16, 16, 3), dtype=torch.uint8) for key in keys}
)
observation_spec = totensorimage.transform_observation_spec(
observation_spec
@@ -6515,7 +6474,7 @@ def test_transform_compose(self, keys, batch, device):
assert (td.get("dont touch") == dont_touch).all()
if len(keys) == 1:
- observation_spec = BoundedTensorSpec(0, 255, (16, 16, 3), dtype=torch.uint8)
+ observation_spec = Bounded(0, 255, (16, 16, 3), dtype=torch.uint8)
observation_spec = totensorimage.transform_observation_spec(
observation_spec
)
@@ -6523,11 +6482,8 @@ def test_transform_compose(self, keys, batch, device):
assert (observation_spec.space.low == 0).all()
assert (observation_spec.space.high == 1).all()
else:
- observation_spec = CompositeSpec(
- {
- key: BoundedTensorSpec(0, 255, (16, 16, 3), dtype=torch.uint8)
- for key in keys
- }
+ observation_spec = Composite(
+ {key: Bounded(0, 255, (16, 16, 3), dtype=torch.uint8) for key in keys}
)
observation_spec = totensorimage.transform_observation_spec(
observation_spec
@@ -6670,7 +6626,7 @@ class TestTensorDictPrimer(TransformBase):
def test_single_trans_env_check(self):
env = TransformedEnv(
ContinuousActionVecMockEnv(),
- TensorDictPrimer(mykey=UnboundedContinuousTensorSpec([3])),
+ TensorDictPrimer(mykey=Unbounded([3])),
)
check_env_specs(env)
assert "mykey" in env.reset().keys()
@@ -6682,14 +6638,10 @@ def test_nested_key_env(self):
env = TransformedEnv(
env,
TensorDictPrimer(
- CompositeSpec(
+ Composite(
{
- "nested_1": CompositeSpec(
- {
- "mykey": UnboundedContinuousTensorSpec(
- (env.nested_dim_1, 4)
- )
- },
+ "nested_1": Composite(
+ {"mykey": Unbounded((env.nested_dim_1, 4))},
shape=(env.nested_dim_1,),
)
}
@@ -6707,13 +6659,13 @@ def test_nested_key_env(self):
assert ("next", "nested_1", "mykey") in env.rollout(3).keys(True, True)
def test_transform_no_env(self):
- t = TensorDictPrimer(mykey=UnboundedContinuousTensorSpec([3]))
+ t = TensorDictPrimer(mykey=Unbounded([3]))
td = TensorDict({"a": torch.zeros(())}, [])
t(td)
assert "mykey" in td.keys()
def test_transform_model(self):
- t = TensorDictPrimer(mykey=UnboundedContinuousTensorSpec([3]))
+ t = TensorDictPrimer(mykey=Unbounded([3]))
model = nn.Sequential(t, nn.Identity())
td = TensorDict({}, [])
model(td)
@@ -6722,7 +6674,7 @@ def test_transform_model(self):
@pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer])
def test_transform_rb(self, rbclass):
batch_size = (2,)
- t = TensorDictPrimer(mykey=UnboundedContinuousTensorSpec([*batch_size, 3]))
+ t = TensorDictPrimer(mykey=Unbounded([*batch_size, 3]))
rb = rbclass(storage=LazyTensorStorage(10))
rb.append_transform(t)
td = TensorDict({"a": torch.zeros(())}, [])
@@ -6734,7 +6686,7 @@ def test_transform_inverse(self):
raise pytest.skip("No inverse method for TensorDictPrimer")
def test_transform_compose(self):
- t = Compose(TensorDictPrimer(mykey=UnboundedContinuousTensorSpec([3])))
+ t = Compose(TensorDictPrimer(mykey=Unbounded([3])))
td = TensorDict({"a": torch.zeros(())}, [])
t(td)
assert "mykey" in td.keys()
@@ -6743,7 +6695,7 @@ def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv):
def make_env():
return TransformedEnv(
ContinuousActionVecMockEnv(),
- TensorDictPrimer(mykey=UnboundedContinuousTensorSpec([3])),
+ TensorDictPrimer(mykey=Unbounded([3])),
)
env = maybe_fork_ParallelEnv(2, make_env)
@@ -6761,7 +6713,7 @@ def test_serial_trans_env_check(self):
def make_env():
return TransformedEnv(
ContinuousActionVecMockEnv(),
- TensorDictPrimer(mykey=UnboundedContinuousTensorSpec([3])),
+ TensorDictPrimer(mykey=Unbounded([3])),
)
env = SerialEnv(2, make_env)
@@ -6778,7 +6730,7 @@ def make_env():
def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv):
env = TransformedEnv(
maybe_fork_ParallelEnv(2, ContinuousActionVecMockEnv),
- TensorDictPrimer(mykey=UnboundedContinuousTensorSpec([2, 4])),
+ TensorDictPrimer(mykey=Unbounded([2, 4])),
)
try:
check_env_specs(env)
@@ -6796,7 +6748,7 @@ def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv):
def test_trans_serial_env_check(self, spec_shape):
env = TransformedEnv(
SerialEnv(2, ContinuousActionVecMockEnv),
- TensorDictPrimer(mykey=UnboundedContinuousTensorSpec(spec_shape)),
+ TensorDictPrimer(mykey=Unbounded(spec_shape)),
)
check_env_specs(env)
assert "mykey" in env.reset().keys()
@@ -6810,8 +6762,8 @@ def test_trans_serial_env_check(self, spec_shape):
@pytest.mark.parametrize(
"spec",
[
- CompositeSpec(b=BoundedTensorSpec(-3, 3, [4])),
- BoundedTensorSpec(-3, 3, [4]),
+ Composite(b=Bounded(-3, 3, [4])),
+ Bounded(-3, 3, [4]),
],
)
@pytest.mark.parametrize("random", [True, False])
@@ -6861,9 +6813,7 @@ def make_env():
else:
assert (tensordict_select == value).all()
- if isinstance(spec, CompositeSpec) and any(
- key != "action" for key in default_keys
- ):
+ if isinstance(spec, Composite) and any(key != "action" for key in default_keys):
for key in default_keys:
if key in ("action",):
continue
@@ -6878,7 +6828,7 @@ def test_tensordictprimer_batching(self, batched_class, break_when_any_done):
env = TransformedEnv(
batched_class(2, lambda: GymEnv(CARTPOLE_VERSIONED())),
- TensorDictPrimer(mykey=UnboundedContinuousTensorSpec([2, 4])),
+ TensorDictPrimer(mykey=Unbounded([2, 4])),
)
torch.manual_seed(0)
env.set_seed(0)
@@ -6888,7 +6838,7 @@ def test_tensordictprimer_batching(self, batched_class, break_when_any_done):
2,
lambda: TransformedEnv(
GymEnv(CARTPOLE_VERSIONED()),
- TensorDictPrimer(mykey=UnboundedContinuousTensorSpec([4])),
+ TensorDictPrimer(mykey=Unbounded([4])),
),
)
torch.manual_seed(0)
@@ -6902,9 +6852,7 @@ def create_tensor():
env = TransformedEnv(
ContinuousActionVecMockEnv(),
- TensorDictPrimer(
- mykey=UnboundedContinuousTensorSpec([3]), default_value=create_tensor
- ),
+ TensorDictPrimer(mykey=Unbounded([3]), default_value=create_tensor),
)
check_env_specs(env)
assert "mykey" in env.reset().keys()
@@ -6913,8 +6861,8 @@ def create_tensor():
def test_dict_default_value(self):
# Test with a dict of float default values
- key1_spec = UnboundedContinuousTensorSpec([3])
- key2_spec = UnboundedContinuousTensorSpec([3])
+ key1_spec = Unbounded([3])
+ key2_spec = Unbounded([3])
env = TransformedEnv(
ContinuousActionVecMockEnv(),
TensorDictPrimer(
@@ -6937,8 +6885,8 @@ def test_dict_default_value(self):
assert (rollout_td.get(("next", "mykey2")) == 2.0).all()
# Test with a dict of callable default values
- key1_spec = UnboundedContinuousTensorSpec([3])
- key2_spec = DiscreteTensorSpec(3, dtype=torch.int64)
+ key1_spec = Unbounded([3])
+ key2_spec = Categorical(3, dtype=torch.int64)
env = TransformedEnv(
ContinuousActionVecMockEnv(),
TensorDictPrimer(
@@ -7751,13 +7699,11 @@ def test_vipnet_transform_observation_spec(
):
vip_net = _VIPNet(in_keys, out_keys, model, del_keys)
- observation_spec = CompositeSpec(
- {key: BoundedTensorSpec(-1, 1, (3, 16, 16), device) for key in in_keys}
+ observation_spec = Composite(
+ {key: Bounded(-1, 1, (3, 16, 16), device) for key in in_keys}
)
if del_keys:
- exp_ts = CompositeSpec(
- {key: UnboundedContinuousTensorSpec(1024, device) for key in out_keys}
- )
+ exp_ts = Composite({key: Unbounded(1024, device) for key in out_keys})
observation_spec_out = vip_net.transform_observation_spec(observation_spec)
@@ -7772,8 +7718,8 @@ def test_vipnet_transform_observation_spec(
for key in in_keys:
ts_dict[key] = observation_spec[key]
for key in out_keys:
- ts_dict[key] = UnboundedContinuousTensorSpec(1024, device)
- exp_ts = CompositeSpec(ts_dict)
+ ts_dict[key] = Unbounded(1024, device)
+ exp_ts = Composite(ts_dict)
observation_spec_out = vip_net.transform_observation_spec(observation_spec)
@@ -8466,8 +8412,8 @@ def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec:
env.transform.transform_reward_spec(env.base_env.full_reward_spec)
def test_independent_obs_specs_from_shared_env(self):
- obs_spec = CompositeSpec(
- observation=BoundedTensorSpec(low=0, high=10, shape=torch.Size((1,)))
+ obs_spec = Composite(
+ observation=Bounded(low=0, high=10, shape=torch.Size((1,)))
)
base_env = ContinuousActionVecMockEnv(observation_spec=obs_spec)
t1 = TransformedEnv(
@@ -8490,7 +8436,7 @@ def test_independent_obs_specs_from_shared_env(self):
assert base_env.observation_spec["observation"].space.high == 10
def test_independent_reward_specs_from_shared_env(self):
- reward_spec = UnboundedContinuousTensorSpec()
+ reward_spec = Unbounded()
base_env = ContinuousActionVecMockEnv(reward_spec=reward_spec)
t1 = TransformedEnv(
base_env, transform=RewardClipping(clamp_min=0, clamp_max=4)
@@ -8508,8 +8454,14 @@ def test_independent_reward_specs_from_shared_env(self):
assert t2_reward_spec.space.low == -2
assert t2_reward_spec.space.high == 2
- assert base_env.reward_spec.space.low == -np.inf
- assert base_env.reward_spec.space.high == np.inf
+ assert (
+ base_env.reward_spec.space.low
+ == torch.finfo(base_env.reward_spec.dtype).min
+ )
+ assert (
+ base_env.reward_spec.space.high
+ == torch.finfo(base_env.reward_spec.dtype).max
+ )
def test_allow_done_after_reset(self):
base_env = ContinuousActionVecMockEnv(allow_done_after_reset=True)
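The assertion rewrite above encodes that an Unbounded spec reports the dtype's finite extremes via torch.finfo rather than IEEE infinities, which is also why the numpy import could be dropped. Assuming the default reward dtype is torch.float32, the bounds it checks are:

    import torch

    dtype = torch.float32
    low, high = torch.finfo(dtype).min, torch.finfo(dtype).max
    assert low == -high                                  # symmetric, roughly +/- 3.4e38
    assert low > float("-inf") and high < float("inf")  # finite, unlike np.inf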
@@ -8637,13 +8589,13 @@ def test_compose(self, keys, batch, device, nchannels=1, N=4):
assert (td.get("dont touch") == dont_touch).all()
if len(keys) == 1:
- observation_spec = BoundedTensorSpec(0, 255, (nchannels, 16, 16))
+ observation_spec = Bounded(0, 255, (nchannels, 16, 16))
# StepCounter does not accept non-composite specs
observation_spec = compose[:2].transform_observation_spec(observation_spec)
assert observation_spec.shape == torch.Size([nchannels * N, 16, 16])
else:
- observation_spec = CompositeSpec(
- {key: BoundedTensorSpec(0, 255, (nchannels, 16, 16)) for key in keys}
+ observation_spec = Composite(
+ {key: Bounded(0, 255, (nchannels, 16, 16)) for key in keys}
)
observation_spec = compose.transform_observation_spec(observation_spec)
for key in keys:
@@ -8715,6 +8667,35 @@ def test_compose_indexing(self):
assert last_t.scale == 4
assert last_t2.scale == 4
+ def test_compose_action_spec(self):
+ # Create a Compose transform that renames "action" to "action_1" and then to "action_2"
+ c = Compose(
+ RenameTransform(
+ in_keys=(),
+ out_keys=(),
+ in_keys_inv=("action",),
+ out_keys_inv=("action_1",),
+ ),
+ RenameTransform(
+ in_keys=(),
+ out_keys=(),
+ in_keys_inv=("action_1",),
+ out_keys_inv=("action_2",),
+ ),
+ )
+ base_env = ContinuousActionVecMockEnv()
+ env = TransformedEnv(base_env, c)
+
+ # Check the `full_action_spec`s
+ assert "action_2" in env.full_action_spec
+ # Ensure intermediate keys are no longer in the action spec
+ assert "action_1" not in env.full_action_spec
+ assert "action" not in env.full_action_spec
+
+ # Final check to ensure clean sampling from the action_spec
+ action = env.rand_action()
+ assert "action_2" in action
+
@pytest.mark.parametrize("device", get_default_devices())
def test_finitetensordictcheck(self, device):
ftd = FiniteTensorDictCheck()
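A note on test_compose_action_spec above: chaining two inverse renames leaves only the outermost key in full_action_spec, and the inverse pass unwinds the chain at step time. In sketch form, continuing the c and env built in that test:

    td = env.rand_action()   # sampled from full_action_spec["action_2"]
    assert "action_2" in td
    # during env.step, the inverse renames map action_2 -> action_1 -> action,
    # so the base env still receives the key it originally declared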
@@ -8936,10 +8917,8 @@ def test_batch_unlocked_with_batch_size_transformed(device):
pytest.param(
partial(FlattenObservation, first_dim=-3, last_dim=-3), id="FlattenObservation"
),
- pytest.param(
- partial(UnsqueezeTransform, unsqueeze_dim=-1), id="UnsqueezeTransform"
- ),
- pytest.param(partial(SqueezeTransform, squeeze_dim=-1), id="SqueezeTransform"),
+ pytest.param(partial(UnsqueezeTransform, dim=-1), id="UnsqueezeTransform"),
+ pytest.param(partial(SqueezeTransform, dim=-1), id="SqueezeTransform"),
GrayScale,
pytest.param(
partial(ObservationNorm, in_keys=["observation"]), id="ObservationNorm"
@@ -9371,6 +9350,28 @@ def test_transform_inverse(self, create_copy):
else:
assert "b" not in tensordict.keys()
+ def test_rename_action(self, create_copy):
+ base_env = ContinuousActionVecMockEnv()
+ env = base_env.append_transform(
+ RenameTransform(
+ in_keys=[],
+ out_keys=[],
+ in_keys_inv=["action"],
+ out_keys_inv=[("renamed", "action")],
+ create_copy=create_copy,
+ )
+ )
+ r = env.rollout(3)
+ assert ("renamed", "action") in env.action_keys, env.action_keys
+ assert ("renamed", "action") in r
+ assert env.full_action_spec[("renamed", "action")] is not None
+ if create_copy:
+ assert "action" in env.action_keys
+ assert "action" in r
+ else:
+ assert "action" not in env.action_keys
+ assert "action" not in r
+
class TestInitTracker(TransformBase):
@pytest.mark.skipif(not _has_gym, reason="no gym detected")
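An aside on test_rename_action just above: create_copy decides whether the inverse rename moves or duplicates the action entry. With the same arguments as the test:

    t = RenameTransform(
        in_keys=[], out_keys=[],
        in_keys_inv=["action"], out_keys_inv=[("renamed", "action")],
        create_copy=True,
    )
    # create_copy=True  -> both "action" and ("renamed", "action") in action_keys
    # create_copy=False -> only ("renamed", "action") remains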
@@ -9600,9 +9601,7 @@ def _make_transform_env(self, out_key, base_env):
return Compose(
TensorDictPrimer(
primers={
- "sample_log_prob": UnboundedContinuousTensorSpec(
- shape=base_env.action_spec.shape[:-1]
- )
+ "sample_log_prob": Unbounded(shape=base_env.action_spec.shape[:-1])
}
),
transform,
@@ -9836,20 +9835,18 @@ def test_kl_lstm(self):
class TestActionMask(TransformBase):
@property
def _env_class(self):
- from torchrl.data import BinaryDiscreteTensorSpec, DiscreteTensorSpec
+ from torchrl.data import Binary, Categorical
class MaskedEnv(EnvBase):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
- self.action_spec = DiscreteTensorSpec(4)
- self.state_spec = CompositeSpec(
- action_mask=BinaryDiscreteTensorSpec(4, dtype=torch.bool)
+ self.action_spec = Categorical(4)
+ self.state_spec = Composite(action_mask=Binary(4, dtype=torch.bool))
+ self.observation_spec = Composite(
+ obs=Unbounded(3),
+ action_mask=Binary(4, dtype=torch.bool),
)
- self.observation_spec = CompositeSpec(
- obs=UnboundedContinuousTensorSpec(3),
- action_mask=BinaryDiscreteTensorSpec(4, dtype=torch.bool),
- )
- self.reward_spec = UnboundedContinuousTensorSpec(1)
+ self.reward_spec = Unbounded(1)
def _reset(self, tensordict):
td = self.observation_spec.rand()
@@ -9960,16 +9957,27 @@ def test_transform_inverse(self):
class TestDeviceCastTransformPart(TransformBase):
+ @pytest.fixture(scope="class")
+ def _cast_device(self):
+ if torch.cuda.is_available():
+ yield torch.device("cuda:0")
+ elif torch.backends.mps.is_available():
+ yield torch.device("mps:0")
+ else:
+ yield torch.device("cpu:1")
+
@pytest.mark.parametrize("in_keys", ["observation"])
@pytest.mark.parametrize("out_keys", [None, ["obs_device"]])
@pytest.mark.parametrize("in_keys_inv", ["action"])
@pytest.mark.parametrize("out_keys_inv", [None, ["action_device"]])
- def test_single_trans_env_check(self, in_keys, out_keys, in_keys_inv, out_keys_inv):
+ def test_single_trans_env_check(
+ self, in_keys, out_keys, in_keys_inv, out_keys_inv, _cast_device
+ ):
env = ContinuousActionVecMockEnv(device="cpu:0")
env = TransformedEnv(
env,
DeviceCastTransform(
- "cpu:1",
+ _cast_device,
in_keys=in_keys,
out_keys=out_keys,
in_keys_inv=in_keys_inv,
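The _cast_device fixture above replaces the hard-coded "cpu:1" so the cast targets a real accelerator when one is available, while CPU-only runners still get a distinct device index. A hypothetical minimal consumer, using the _call entry point exercised later in this class:

    def test_cast_roundtrip(self, _cast_device):
        t = DeviceCastTransform(_cast_device, "cpu:0", in_keys=["a"], out_keys=["b"])
        td = TensorDict({"a": torch.randn((), device="cpu:0")}, [], device="cpu:0")
        out = t._call(td)
        assert out["b"].device == _cast_device  # cast value lands under out_key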
@@ -9983,12 +9991,14 @@ def test_single_trans_env_check(self, in_keys, out_keys, in_keys_inv, out_keys_i
@pytest.mark.parametrize("out_keys", [None, ["obs_device"]])
@pytest.mark.parametrize("in_keys_inv", ["action"])
@pytest.mark.parametrize("out_keys_inv", [None, ["action_device"]])
- def test_serial_trans_env_check(self, in_keys, out_keys, in_keys_inv, out_keys_inv):
+ def test_serial_trans_env_check(
+ self, in_keys, out_keys, in_keys_inv, out_keys_inv, _cast_device
+ ):
def make_env():
return TransformedEnv(
ContinuousActionVecMockEnv(device="cpu:0"),
DeviceCastTransform(
- "cpu:1",
+ _cast_device,
in_keys=in_keys,
out_keys=out_keys,
in_keys_inv=in_keys_inv,
@@ -10005,13 +10015,13 @@ def make_env():
@pytest.mark.parametrize("in_keys_inv", ["action"])
@pytest.mark.parametrize("out_keys_inv", [None, ["action_device"]])
def test_parallel_trans_env_check(
- self, in_keys, out_keys, in_keys_inv, out_keys_inv
+ self, in_keys, out_keys, in_keys_inv, out_keys_inv, _cast_device
):
def make_env():
return TransformedEnv(
ContinuousActionVecMockEnv(device="cpu:0"),
DeviceCastTransform(
- "cpu:1",
+ _cast_device,
in_keys=in_keys,
out_keys=out_keys,
in_keys_inv=in_keys_inv,
@@ -10037,14 +10047,16 @@ def make_env():
@pytest.mark.parametrize("out_keys", [None, ["obs_device"]])
@pytest.mark.parametrize("in_keys_inv", ["action"])
@pytest.mark.parametrize("out_keys_inv", [None, ["action_device"]])
- def test_trans_serial_env_check(self, in_keys, out_keys, in_keys_inv, out_keys_inv):
+ def test_trans_serial_env_check(
+ self, in_keys, out_keys, in_keys_inv, out_keys_inv, _cast_device
+ ):
def make_env():
return ContinuousActionVecMockEnv(device="cpu:0")
env = TransformedEnv(
SerialEnv(2, make_env),
DeviceCastTransform(
- "cpu:1",
+ _cast_device,
in_keys=in_keys,
out_keys=out_keys,
in_keys_inv=in_keys_inv,
@@ -10059,7 +10071,7 @@ def make_env():
@pytest.mark.parametrize("in_keys_inv", ["action"])
@pytest.mark.parametrize("out_keys_inv", [None, ["action_device"]])
def test_trans_parallel_env_check(
- self, in_keys, out_keys, in_keys_inv, out_keys_inv
+ self, in_keys, out_keys, in_keys_inv, out_keys_inv, _cast_device
):
def make_env():
return ContinuousActionVecMockEnv(device="cpu:0")
@@ -10071,7 +10083,7 @@ def make_env():
mp_start_method=mp_ctx if not torch.cuda.is_available() else "spawn",
),
DeviceCastTransform(
- "cpu:1",
+ _cast_device,
in_keys=in_keys,
out_keys=out_keys,
in_keys_inv=in_keys_inv,
@@ -10087,8 +10099,8 @@ def make_env():
except RuntimeError:
pass
- def test_transform_no_env(self):
- t = DeviceCastTransform("cpu:1", "cpu:0", in_keys=["a"], out_keys=["b"])
+ def test_transform_no_env(self, _cast_device):
+ t = DeviceCastTransform(_cast_device, "cpu:0", in_keys=["a"], out_keys=["b"])
td = TensorDict({"a": torch.randn((), device="cpu:0")}, [], device="cpu:0")
tdt = t._call(td)
assert tdt.device is None
@@ -10097,12 +10109,14 @@ def test_transform_no_env(self):
@pytest.mark.parametrize("out_keys", [None, ["obs_device"]])
@pytest.mark.parametrize("in_keys_inv", ["action"])
@pytest.mark.parametrize("out_keys_inv", [None, ["action_device"]])
- def test_transform_env(self, in_keys, out_keys, in_keys_inv, out_keys_inv):
+ def test_transform_env(
+ self, in_keys, out_keys, in_keys_inv, out_keys_inv, _cast_device
+ ):
env = ContinuousActionVecMockEnv(device="cpu:0")
env = TransformedEnv(
env,
DeviceCastTransform(
- "cpu:1",
+ _cast_device,
in_keys=in_keys,
out_keys=out_keys,
in_keys_inv=in_keys_inv,
@@ -10110,13 +10124,13 @@ def test_transform_env(self, in_keys, out_keys, in_keys_inv, out_keys_inv):
),
)
assert env.device is None
- assert env.transform.device == torch.device("cpu:1")
+ assert env.transform.device == _cast_device
assert env.transform.orig_device == torch.device("cpu:0")
- def test_transform_compose(self):
+ def test_transform_compose(self, _cast_device):
t = Compose(
DeviceCastTransform(
- "cpu:1",
+ _cast_device,
"cpu:0",
in_keys=["a"],
out_keys=["b"],
@@ -10128,7 +10142,7 @@ def test_transform_compose(self):
td = TensorDict(
{
"a": torch.randn((), device="cpu:0"),
- "c": torch.randn((), device="cpu:1"),
+ "c": torch.randn((), device=_cast_device),
},
[],
device="cpu:0",
@@ -10139,11 +10153,11 @@ def test_transform_compose(self):
assert tdt.device is None
assert tdit.device is None
- def test_transform_model(self):
+ def test_transform_model(self, _cast_device):
t = nn.Sequential(
Compose(
DeviceCastTransform(
- "cpu:1",
+ _cast_device,
"cpu:0",
in_keys=["a"],
out_keys=["b"],
@@ -10166,11 +10180,11 @@ def test_transform_model(self):
@pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer])
@pytest.mark.parametrize("storage", [LazyTensorStorage])
- def test_transform_rb(self, rbclass, storage):
+ def test_transform_rb(self, rbclass, storage, _cast_device):
# we don't test casting to cuda on Memmap tensor storage since it's discouraged
t = Compose(
DeviceCastTransform(
- "cpu:1",
+ _cast_device,
"cpu:0",
in_keys=["a"],
out_keys=["b"],
@@ -10183,7 +10197,7 @@ def test_transform_rb(self, rbclass, storage):
td = TensorDict(
{
"a": torch.randn((), device="cpu:0"),
- "c": torch.randn((), device="cpu:1"),
+ "c": torch.randn((), device=_cast_device),
},
[],
device="cpu:0",
@@ -10467,17 +10481,18 @@ def test_transform_no_env(self, batch):
reason="EndOfLifeTransform can only be tested when Gym is present.",
)
class TestEndOfLife(TransformBase):
+    pytestmark = pytest.mark.filterwarnings("ignore:The base_env is not a gym env")
+
def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv):
def make():
with set_gym_backend("gymnasium"):
return GymEnv(BREAKOUT_VERSIONED())
- with pytest.warns(UserWarning, match="The base_env is not a gym env"):
- with pytest.raises(AttributeError):
- env = TransformedEnv(
- maybe_fork_ParallelEnv(2, make), transform=EndOfLifeTransform()
- )
- check_env_specs(env)
+ with pytest.raises(AttributeError):
+ env = TransformedEnv(
+ maybe_fork_ParallelEnv(2, make), transform=EndOfLifeTransform()
+ )
+ check_env_specs(env)
def test_trans_serial_env_check(self):
def make():
@@ -10987,27 +11002,25 @@ class TestRemoveEmptySpecs(TransformBase):
class DummyEnv(EnvBase):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
- self.observation_spec = CompositeSpec(
- observation=UnboundedContinuousTensorSpec((*self.batch_size, 3)),
- other=CompositeSpec(
- another_other=CompositeSpec(shape=self.batch_size),
+ self.observation_spec = Composite(
+ observation=Unbounded((*self.batch_size, 3)),
+ other=Composite(
+ another_other=Composite(shape=self.batch_size),
shape=self.batch_size,
),
shape=self.batch_size,
)
- self.action_spec = UnboundedContinuousTensorSpec((*self.batch_size, 3))
- self.done_spec = DiscreteTensorSpec(
- 2, (*self.batch_size, 1), dtype=torch.bool
- )
+ self.action_spec = Unbounded((*self.batch_size, 3))
+ self.done_spec = Categorical(2, (*self.batch_size, 1), dtype=torch.bool)
self.full_done_spec["truncated"] = self.full_done_spec["terminated"].clone()
- self.reward_spec = CompositeSpec(
- reward=UnboundedContinuousTensorSpec(*self.batch_size, 1),
- other_reward=CompositeSpec(shape=self.batch_size),
+ self.reward_spec = Composite(
+ reward=Unbounded(*self.batch_size, 1),
+ other_reward=Composite(shape=self.batch_size),
shape=self.batch_size,
)
- self.state_spec = CompositeSpec(
- state=CompositeSpec(
- sub=CompositeSpec(shape=self.batch_size), shape=self.batch_size
+ self.state_spec = Composite(
+ state=Composite(
+ sub=Composite(shape=self.batch_size), shape=self.batch_size
),
shape=self.batch_size,
)
@@ -11213,11 +11226,9 @@ class MyEnv(EnvBase):
def __init__(self):
super().__init__()
- self.observation_spec = CompositeSpec(
- observation=UnboundedContinuousTensorSpec(3)
- )
- self.reward_spec = UnboundedContinuousTensorSpec(1)
- self.action_spec = UnboundedContinuousTensorSpec(1)
+ self.observation_spec = Composite(observation=Unbounded(3))
+ self.reward_spec = Unbounded(1)
+ self.action_spec = Unbounded(1)
def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
tensordict_batch_size = (
diff --git a/test/test_utils.py b/test/test_utils.py
index c2ce2eae6b9..4224a36b54f 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -174,8 +174,8 @@ def test_implement_for_reset():
("0.9.0", "0.1.0", "0.21.0", True),
("0.19.99", "0.19.9", "0.21.0", True),
("0.19.99", None, "0.19.0", False),
- ("5.61.77", "0.21.0", None, True),
- ("5.61.77", None, "0.21.0", False),
+ ("0.99.0", "0.21.0", None, True),
+ ("0.99.0", None, "0.21.0", False),
],
)
def test_implement_for_check_versions(
@@ -189,9 +189,9 @@ def test_implement_for_check_versions(
@pytest.mark.parametrize(
"gymnasium_version, expected_from_version_gymnasium, expected_to_version_gymnasium",
[
- ("0.27.0", None, None),
- ("0.27.2", None, None),
- ("5.1.77", None, None),
+ ("0.27.0", None, "1.0.0"),
+ ("0.27.2", None, "1.0.0"),
+ # ("1.0.1", "1.0.0", None),
],
)
@pytest.mark.parametrize(
@@ -199,7 +199,7 @@ def test_implement_for_check_versions(
[
("0.21.0", "0.21.0", None),
("0.22.0", "0.21.0", None),
- ("5.61.77", "0.21.0", None),
+ ("0.99.0", "0.21.0", None),
("0.9.0", None, "0.21.0"),
("0.20.0", None, "0.21.0"),
("0.19.99", None, "0.21.0"),
@@ -228,6 +228,8 @@ def test_set_gym_environments(
import gymnasium
# look for the right function that should be called according to gym versions (and same for gymnasium)
+ expected_fn_gymnasium = None
+ expected_fn_gym = None
for impfor in implement_for._setters:
if impfor.fn.__name__ == "_set_gym_environments":
if (impfor.module_name, impfor.from_version, impfor.to_version) == (
@@ -242,20 +244,22 @@ def test_set_gym_environments(
expected_to_version_gymnasium,
):
expected_fn_gymnasium = impfor.fn
+ if expected_fn_gym is not None and expected_fn_gymnasium is not None:
+ break
with set_gym_backend(gymnasium):
assert (
- _utils_internal._set_gym_environments == expected_fn_gymnasium
+ _utils_internal._set_gym_environments is expected_fn_gymnasium
), expected_fn_gym
with set_gym_backend(gym):
assert (
- _utils_internal._set_gym_environments == expected_fn_gym
+ _utils_internal._set_gym_environments is expected_fn_gym
), expected_fn_gymnasium
with set_gym_backend(gymnasium):
assert (
- _utils_internal._set_gym_environments == expected_fn_gymnasium
+ _utils_internal._set_gym_environments is expected_fn_gymnasium
), expected_fn_gym
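
These tests pin down the version-window semantics of `implement_for`: a `(module, from_version, to_version)` triple selects the implementation whose half-open window `[from_version, to_version)` contains the installed backend version. A minimal sketch of that dispatch, mirroring the updated docstring in `torchrl/_utils.py` (hypothetical class, for illustration only):

    from torchrl._utils import implement_for

    class Fun:
        @implement_for("gym", "0.21.0", None)
        def fun(self, x):
            return x + 2  # selected for any gym >= 0.21.0

        @implement_for("gymnasium", None, "1.0.0")
        def fun(self, x):  # noqa: F811
            return x + 3  # selected for any gymnasium < 1.0.0
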
diff --git a/torchrl/__init__.py b/torchrl/__init__.py
index 25103423cac..7a41bf0ab8f 100644
--- a/torchrl/__init__.py
+++ b/torchrl/__init__.py
@@ -3,6 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
+import weakref
from warnings import warn
import torch
@@ -10,6 +11,7 @@
from tensordict import set_lazy_legacy
from torch import multiprocessing as mp
+from torch.distributions.transforms import _InverseTransform, ComposeTransform
set_lazy_legacy(False).set()
@@ -25,6 +27,11 @@
except ImportError:
__version__ = None
+try:
+ from torch.compiler import is_dynamo_compiling
+except ImportError:
+ from torch._dynamo import is_compiling as is_dynamo_compiling
+
_init_extension()
try:
@@ -51,3 +58,45 @@
filter_warnings_subprocess = True
_THREAD_POOL_INIT = torch.get_num_threads()
+
+
+# monkey-patch dist transforms until https://github.com/pytorch/pytorch/pull/135001/ finds a home
+@property
+def _inv(self):
+ """Patched version of Transform.inv.
+
+ Returns the inverse :class:`Transform` of this transform.
+
+ This should satisfy ``t.inv.inv is t``.
+ """
+ inv = None
+ if self._inv is not None:
+ inv = self._inv()
+ if inv is None:
+ inv = _InverseTransform(self)
+ if not is_dynamo_compiling():
+ self._inv = weakref.ref(inv)
+ return inv
+
+
+torch.distributions.transforms.Transform.inv = _inv
+
+
+@property
+def _inv(self):
+ inv = None
+ if self._inv is not None:
+ inv = self._inv()
+ if inv is None:
+ inv = ComposeTransform([p.inv for p in reversed(self.parts)])
+ if not is_dynamo_compiling():
+ self._inv = weakref.ref(inv)
+ inv._inv = weakref.ref(self)
+ else:
+ # We need inv.inv to be equal to self, but weakref can cause a graph break
+ inv._inv = lambda out=self: out
+
+ return inv
+
+
+ComposeTransform.inv = _inv
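
A quick sketch of the invariant the two patched properties maintain: the inverse is cached through a weakref, so repeated accesses return the same object and round-tripping recovers the original transform (same observable behavior as upstream, only made compile-friendly):

    from torch.distributions.transforms import ExpTransform

    t = ExpTransform()
    inv = t.inv            # first access builds and caches an _InverseTransform
    assert t.inv is inv    # cached: same object on repeated access
    assert t.inv.inv is t  # round-trip recovers the original transform
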
diff --git a/torchrl/_utils.py b/torchrl/_utils.py
index 895f3d80fdc..3af44ee0ed7 100644
--- a/torchrl/_utils.py
+++ b/torchrl/_utils.py
@@ -46,7 +46,7 @@
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)
-VERBOSE = strtobool(os.environ.get("VERBOSE", "0"))
+VERBOSE = strtobool(os.environ.get("VERBOSE", str(logger.isEnabledFor(logging.DEBUG))))
_os_is_windows = sys.platform == "win32"
RL_WARNINGS = strtobool(os.environ.get("RL_WARNINGS", "1"))
if RL_WARNINGS:
@@ -269,7 +269,7 @@ class implement_for:
... # More recent gym versions will return x + 2
... return x + 2
...
- >>> @implement_for("gymnasium")
+ >>> @implement_for("gymnasium", None, "1.0.0")
>>> def fun(self, x):
... # If gymnasium is to be used instead of gym, x+3 will be returned
... return x + 3
@@ -785,4 +785,6 @@ def _make_ordinal_device(device: torch.device):
return device
if device.type == "cuda" and device.index is None:
return torch.device("cuda", index=torch.cuda.current_device())
+ if device.type == "mps" and device.index is None:
+ return torch.device("mps", index=0)
return device
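
For illustration, how the helper behaves after this change (note that `_make_ordinal_device` is a private helper and subject to change):

    import torch
    from torchrl._utils import _make_ordinal_device

    # Index-less accelerator devices are pinned to an explicit ordinal, so
    # comparisons such as `weight.device != policy_device` cannot be fooled by
    # torch.device("mps") != torch.device("mps:0").
    assert _make_ordinal_device(torch.device("cpu")) == torch.device("cpu")
    if torch.backends.mps.is_available():
        assert _make_ordinal_device(torch.device("mps")) == torch.device("mps", 0)
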
diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py
index be24a06e39c..91355ae261f 100644
--- a/torchrl/collectors/collectors.py
+++ b/torchrl/collectors/collectors.py
@@ -10,13 +10,13 @@
import contextlib
import functools
-
import os
import queue
import sys
import time
+import typing
import warnings
-from collections import OrderedDict
+from collections import defaultdict, OrderedDict
from copy import deepcopy
from multiprocessing import connection, queues
from multiprocessing.managers import SyncManager
@@ -33,7 +33,8 @@
TensorDictBase,
TensorDictParams,
)
-from tensordict.nn import TensorDictModule
+from tensordict.base import NO_DEFAULT
+from tensordict.nn import CudaGraphModule, TensorDictModule
from torch import multiprocessing as mp
from torch.utils.data import IterableDataset
@@ -50,20 +51,22 @@
VERBOSE,
)
from torchrl.collectors.utils import split_trajectories
+from torchrl.data import ReplayBuffer
from torchrl.data.tensor_specs import TensorSpec
from torchrl.data.utils import CloudpickleWrapper, DEVICE_TYPING
from torchrl.envs.common import _do_nothing, EnvBase
+from torchrl.envs.env_creator import EnvCreator
from torchrl.envs.transforms import StepCounter, TransformedEnv
from torchrl.envs.utils import (
_aggregate_end_of_traj,
- _convert_exploration_type,
_make_compatible_policy,
- _NonParametricPolicyWrapper,
ExplorationType,
+ RandomPolicy,
set_exploration_type,
)
_TIMEOUT = 1.0
+INSTANTIATE_TIMEOUT = 20
_MIN_TIMEOUT = 1e-3 # should be several orders of magnitude inferior wrt time spent collecting a trajectory
# MAX_IDLE_COUNT is the maximum number of times a Dataloader worker can timeout with his queue.
_MAX_IDLE_COUNT = int(os.environ.get("MAX_IDLE_COUNT", 1000))
@@ -131,59 +134,97 @@ class DataCollectorBase(IterableDataset, metaclass=abc.ABCMeta):
"""Base class for data collectors."""
_iterator = None
+ total_frames: int
+ frames_per_batch: int
+ trust_policy: bool
+ compiled_policy: bool
+ cudagraphed_policy: bool
def _get_policy_and_device(
self,
- policy: Optional[
- Union[
- TensorDictModule,
- Callable[[TensorDictBase], TensorDictBase],
- ]
- ] = None,
+ policy: Callable[[Any], Any] | None = None,
observation_spec: TensorSpec = None,
+ policy_device: Any = NO_DEFAULT,
+ env_maker: Any | None = None,
+ env_maker_kwargs: dict | None = None,
) -> Tuple[TensorDictModule, Union[None, Callable[[], dict]]]:
"""Util method to get a policy and its device given the collector __init__ inputs.
Args:
- create_env_fn (Callable or list of callables): an env creator
- function (or a list of creators)
- create_env_kwargs (dictionary): kwargs for the env creator
policy (TensorDictModule, optional): a policy to be used
observation_spec (TensorSpec, optional): spec of the observations
+            policy_device (torch.device, optional): the device where the policy should be placed.
+                Defaults to ``self.policy_device``.
+            env_maker (callable or batched env, optional): the env maker for this device/policy pair.
+            env_maker_kwargs (dict, optional): kwargs for the env maker.
"""
- policy = _make_compatible_policy(
- policy, observation_spec, env=getattr(self, "env", None)
- )
- param_and_buf = TensorDict.from_module(policy, as_module=True)
-
- def get_weights_fn(param_and_buf=param_and_buf):
- return param_and_buf.data
-
- if self.policy_device:
- # create a stateless policy and populate it with params
- def _map_to_device_params(param, device):
- is_param = isinstance(param, nn.Parameter)
-
- pd = param.detach().to(device, non_blocking=True)
-
- if is_param:
- pd = nn.Parameter(pd, requires_grad=False)
- return pd
+ if policy_device is NO_DEFAULT:
+ policy_device = self.policy_device
+
+ if not self.trust_policy:
+ env = getattr(self, "env", None)
+ policy = _make_compatible_policy(
+ policy,
+ observation_spec,
+ env=env,
+ env_maker=env_maker,
+ env_maker_kwargs=env_maker_kwargs,
+ )
+ if not policy_device:
+ return policy, None
- # Create a stateless policy, then populate this copy with params on device
- with param_and_buf.apply(
- functools.partial(_map_to_device_params, device="meta"),
- filter_empty=False,
- ).to_module(policy):
- policy = deepcopy(policy)
+ if isinstance(policy, nn.Module):
+ param_and_buf = TensorDict.from_module(policy, as_module=True)
+ else:
+            # Not an nn.Module: use an empty TensorDict so we fall through to the warning below
+ param_and_buf = TensorDict()
+
+ i = -1
+ for p in param_and_buf.values(True, True):
+ i += 1
+ if p.device != policy_device:
+ # Then we need casting
+ break
+ else:
+ if i == -1 and not self.trust_policy:
+                # No parameter/buffer was found: casting is impossible, so warn and trust the user's device setup
+ warnings.warn(
+ "A policy device was provided but no parameter/buffer could be found in "
+ "the policy. Casting to policy_device is therefore impossible. "
+ "The collector will trust that the devices match. To suppress this "
+ "warning, set `trust_policy=True` when building the collector."
+ )
+ return policy, None
- param_and_buf.apply(
- functools.partial(_map_to_device_params, device=self.policy_device),
- filter_empty=False,
- ).to_module(policy)
+ def map_weight(
+ weight,
+ policy_device=policy_device,
+ ):
- return policy, get_weights_fn
+ is_param = isinstance(weight, nn.Parameter)
+ is_buffer = isinstance(weight, nn.Buffer)
+ weight = weight.data
+ if weight.device != policy_device:
+ weight = weight.to(policy_device)
+ elif weight.device.type in ("cpu", "mps"):
+ weight = weight.share_memory_()
+ if is_param:
+ weight = nn.Parameter(weight, requires_grad=False)
+ elif is_buffer:
+ weight = nn.Buffer(weight)
+ return weight
+
+ # Create a stateless policy, then populate this copy with params on device
+ get_original_weights = functools.partial(TensorDict.from_module, policy)
+ with param_and_buf.to("meta").to_module(policy):
+ policy = deepcopy(policy)
+
+ param_and_buf.apply(
+            map_weight,
+ filter_empty=False,
+ ).to_module(policy)
+ return policy, get_original_weights
def update_policy_weights_(
self, policy_weights: Optional[TensorDictBase] = None
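
The weight-mapping rule applied above, restated as a standalone sketch (an illustration of what `map_weight` does, not the collector's own code path; the `nn.Buffer` re-wrapping is omitted since it only exists on recent PyTorch versions):

    import torch
    from torch import nn

    def map_weight_sketch(weight, policy_device):
        is_param = isinstance(weight, nn.Parameter)
        weight = weight.data                       # detach from autograd
        if weight.device != policy_device:
            weight = weight.to(policy_device)      # cast once, upfront
        elif weight.device.type in ("cpu", "mps"):
            weight = weight.share_memory_()        # share instead of copying
        if is_param:
            weight = nn.Parameter(weight, requires_grad=False)
        return weight

    p = nn.Parameter(torch.randn(3))
    w = map_weight_sketch(p, torch.device("cpu"))
    assert not w.requires_grad and w.is_shared()
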
@@ -234,6 +275,16 @@ def state_dict(self) -> OrderedDict:
def load_state_dict(self, state_dict: OrderedDict) -> None:
raise NotImplementedError
+ def _read_compile_kwargs(self, compile_policy, cudagraph_policy):
+ self.compiled_policy = compile_policy not in (False, None)
+ self.cudagraphed_policy = cudagraph_policy not in (False, None)
+ self.compiled_policy_kwargs = (
+ {} if not isinstance(compile_policy, typing.Mapping) else compile_policy
+ )
+ self.cudagraphed_policy_kwargs = (
+ {} if not isinstance(cudagraph_policy, typing.Mapping) else cudagraph_policy
+ )
+
def __repr__(self) -> str:
string = f"{self.__class__.__name__}()"
return string
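
A standalone sketch of the normalization `_read_compile_kwargs` performs on the `compile_policy` / `cudagraph_policy` arguments: any truthy value enables the feature, and a mapping doubles as its kwargs.

    import typing

    def read_flag(value):
        enabled = value not in (False, None)
        kwargs = value if isinstance(value, typing.Mapping) else {}
        return enabled, kwargs

    assert read_flag(None) == (False, {})
    assert read_flag(True) == (True, {})
    assert read_flag({"mode": "reduce-overhead"}) == (True, {"mode": "reduce-overhead"})
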
@@ -241,6 +292,11 @@ def __repr__(self) -> str:
def __class_getitem__(self, index):
raise NotImplementedError
+ def __len__(self) -> int:
+ if self.total_frames > 0:
+ return -(self.total_frames // -self.frames_per_batch)
+ raise RuntimeError("Non-terminating collectors do not have a length")
+
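
The `__len__` above rounds up with the negative floor-division idiom; a one-line check of the arithmetic:

    total_frames, frames_per_batch = 1000, 64
    assert -(total_frames // -frames_per_batch) == 16  # == math.ceil(1000 / 64)
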
@accept_remote_rref_udf_invocation
class SyncDataCollector(DataCollectorBase):
@@ -357,6 +413,17 @@ class SyncDataCollector(DataCollectorBase):
use_buffers (bool, optional): if ``True``, a buffer will be used to stack the data.
This isn't compatible with environments with dynamic specs. Defaults to ``True``
for envs without dynamic specs, ``False`` for others.
+        replay_buffer (ReplayBuffer, optional): if provided, the collector will not yield tensordicts
+            but will populate the buffer instead. Defaults to ``None``.
+        trust_policy (bool, optional): if ``True``, a non-TensorDictModule policy will be assumed to be
+            compatible with the collector. This defaults to ``True`` for CudaGraphModules
+            and ``False`` otherwise.
+        compile_policy (bool or Dict[str, Any], optional): if ``True``, the policy will be compiled
+            with the default behavior of :func:`~torch.compile`. If a dictionary of kwargs is passed,
+            it will be used to compile the policy.
+ cudagraph_policy (bool or Dict[str, Any], optional): if ``True``, the policy will be wrapped
+ in :class:`~tensordict.nn.CudaGraphModule` with default kwargs.
+ If a dictionary of kwargs is passed, it will be used to wrap the policy.
Examples:
>>> from torchrl.envs.libs.gym import GymEnv
@@ -440,19 +507,20 @@ def __init__(
postproc: Callable[[TensorDictBase], TensorDictBase] | None = None,
split_trajs: bool | None = None,
exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE,
- exploration_mode: str | None = None,
return_same_td: bool = False,
reset_when_done: bool = True,
interruptor=None,
set_truncated: bool = False,
use_buffers: bool | None = None,
+ replay_buffer: ReplayBuffer | None = None,
+        trust_policy: bool | None = None,
+ compile_policy: bool | Dict[str, Any] | None = None,
+ cudagraph_policy: bool | Dict[str, Any] | None = None,
+ **kwargs,
):
from torchrl.envs.batched_envs import BatchedEnvBase
self.closed = True
- exploration_type = _convert_exploration_type(
- exploration_mode=exploration_mode, exploration_type=exploration_type
- )
if create_env_kwargs is None:
create_env_kwargs = {}
if not isinstance(create_env_fn, EnvBase):
@@ -468,9 +536,20 @@ def __init__(
env.update_kwargs(create_env_kwargs)
if policy is None:
- from torchrl.collectors import RandomPolicy
policy = RandomPolicy(env.full_action_spec)
+ if trust_policy is None:
+ trust_policy = isinstance(policy, (RandomPolicy, CudaGraphModule))
+ self.trust_policy = trust_policy
+ self._read_compile_kwargs(compile_policy, cudagraph_policy)
+
+ ##########################
+ # Trajectory pool
+ self._traj_pool_val = kwargs.pop("traj_pool", None)
+ if kwargs:
+ raise TypeError(
+ f"Keys {list(kwargs.keys())} are unknown to {type(self).__name__}."
+ )
##########################
# Setting devices:
@@ -493,7 +572,8 @@ def __init__(
# Cuda handles sync
if torch.cuda.is_available():
self._sync_storage = torch.cuda.synchronize
- elif torch.backends.mps.is_available():
+ elif torch.backends.mps.is_available() and hasattr(torch, "mps"):
+            # torch.mps is missing from older PyTorch versions, hence the hasattr guard
self._sync_storage = torch.mps.synchronize
elif self.storing_device.type == "cpu":
self._sync_storage = _do_nothing
@@ -507,7 +587,7 @@ def __init__(
# Cuda handles sync
if torch.cuda.is_available():
self._sync_env = torch.cuda.synchronize
- elif torch.backends.mps.is_available():
+ elif torch.backends.mps.is_available() and hasattr(torch, "mps"):
self._sync_env = torch.mps.synchronize
elif self.env_device.type == "cpu":
self._sync_env = _do_nothing
@@ -520,7 +600,7 @@ def __init__(
# Cuda handles sync
if torch.cuda.is_available():
self._sync_policy = torch.cuda.synchronize
- elif torch.backends.mps.is_available():
+ elif torch.backends.mps.is_available() and hasattr(torch, "mps"):
self._sync_policy = torch.mps.synchronize
elif self.policy_device.type == "cpu":
self._sync_policy = _do_nothing
@@ -538,10 +618,19 @@ def __init__(
self.env: EnvBase = env
del env
+ self.replay_buffer = replay_buffer
+ if self.replay_buffer is not None:
+ if postproc is not None:
+ raise TypeError("postproc must be None when a replay buffer is passed.")
+ if use_buffers:
+ raise TypeError("replay_buffer is exclusive with use_buffers.")
if use_buffers is None:
- use_buffers = not self.env._has_dynamic_specs
+ use_buffers = not self.env._has_dynamic_specs and self.replay_buffer is None
self._use_buffers = use_buffers
+
self.closed = False
+
if not reset_when_done:
raise ValueError("reset_when_done is deprectated.")
self.reset_when_done = reset_when_done
@@ -551,11 +640,15 @@ def __init__(
policy=policy,
observation_spec=self.env.observation_spec,
)
-
if isinstance(self.policy, nn.Module):
self.policy_weights = TensorDict.from_module(self.policy, as_module=True)
else:
- self.policy_weights = TensorDict({}, [])
+ self.policy_weights = TensorDict()
+
+ if self.compiled_policy:
+ self.policy = torch.compile(self.policy, **self.compiled_policy_kwargs)
+ if self.cudagraphed_policy:
+ self.policy = CudaGraphModule(self.policy, **self.cudagraphed_policy_kwargs)
if self.env_device:
self.env: EnvBase = self.env.to(self.env_device)
@@ -655,6 +748,13 @@ def __init__(
self._frames = 0
self._iter = -1
+ @property
+ def _traj_pool(self):
+ pool = getattr(self, "_traj_pool_val", None)
+ if pool is None:
+ pool = self._traj_pool_val = _TrajectoryPool()
+ return pool
+
def _make_shuttle(self):
# Shuttle is a deviceless tensordict that just carried data from env to policy and policy to env
with torch.no_grad():
@@ -665,9 +765,9 @@ def _make_shuttle(self):
else:
self._shuttle_has_no_device = False
- traj_ids = torch.arange(self.n_env, device=self.storing_device).view(
- self.env.batch_size
- )
+ traj_ids = self._traj_pool.get_traj_and_increment(
+ self.n_env, device=self.storing_device
+ ).view(self.env.batch_size)
self._shuttle.set(
("collector", "traj_ids"),
traj_ids,
@@ -753,11 +853,13 @@ def check_exclusive(val):
# changed them here).
# This will cause a failure to update entries when policy and env device mismatch and
# casting is necessary.
- def filter_policy(value_output, value_input, value_input_clone):
- if (
- (value_input is None)
- or (value_output is not value_input)
- or ~torch.isclose(value_output, value_input_clone).any()
+ def filter_policy(name, value_output, value_input, value_input_clone):
+ if (value_input is None) or (
+ (value_output is not value_input)
+ and (
+ value_output.device != value_input_clone.device
+ or ~torch.isclose(value_output, value_input_clone).any()
+ )
):
return value_output
@@ -767,6 +869,7 @@ def filter_policy(value_output, value_input, value_input_clone):
policy_input_clone,
default=None,
filter_empty=True,
+ named=True,
)
self._policy_output_keys = list(
self._policy_output_keys.union(
@@ -871,7 +974,15 @@ def set_seed(self, seed: int, static_seed: bool = False) -> int:
>>> out_seed = collector.set_seed(1) # out_seed = 6
"""
- return self.env.set_seed(seed, static_seed=static_seed)
+ out = self.env.set_seed(seed, static_seed=static_seed)
+ return out
+
+ def _increment_frames(self, numel):
+ self._frames += numel
+ completed = self._frames >= self.total_frames
+ if completed:
+ self.env.close()
+ return completed
def iterator(self) -> Iterator[TensorDictBase]:
"""Iterates through the DataCollector.
@@ -917,14 +1028,15 @@ def cuda_check(tensor: torch.Tensor):
for stream in streams:
stack.enter_context(torch.cuda.stream(stream))
- total_frames = self.total_frames
-
while self._frames < self.total_frames:
self._iter += 1
tensordict_out = self.rollout()
- self._frames += tensordict_out.numel()
- if self._frames >= total_frames:
- self.env.close()
+ if tensordict_out is None:
+ # if a replay buffer is passed, there is no tensordict_out
+ # frames are updated within the rollout function
+ yield
+ continue
+ self._increment_frames(tensordict_out.numel())
if self.split_trajs:
tensordict_out = split_trajectories(
@@ -976,14 +1088,20 @@ def _update_traj_ids(self, env_output) -> None:
env_output.get("next"), done_keys=self.env.done_keys
)
if traj_sop.any():
+ device = self.storing_device
+
traj_ids = self._shuttle.get(("collector", "traj_ids"))
- traj_sop = traj_sop.to(self.storing_device)
- traj_ids = traj_ids.clone().to(self.storing_device)
- traj_ids[traj_sop] = traj_ids.max() + torch.arange(
- 1,
- traj_sop.sum() + 1,
- device=self.storing_device,
+ if device is not None:
+ traj_ids = traj_ids.to(device)
+ traj_sop = traj_sop.to(device)
+ elif traj_sop.device != traj_ids.device:
+ traj_sop = traj_sop.to(traj_ids.device)
+
+ pool = self._traj_pool
+ new_traj = pool.get_traj_and_increment(
+ traj_sop.sum(), device=traj_sop.device
)
+ traj_ids = traj_ids.masked_scatter(traj_sop, new_traj)
self._shuttle.set(("collector", "traj_ids"), traj_ids)
@torch.no_grad()
@@ -1053,13 +1171,18 @@ def rollout(self) -> TensorDictBase:
next_data.clear_device_()
self._shuttle.set("next", next_data)
- if self.storing_device is not None:
- tensordicts.append(
- self._shuttle.to(self.storing_device, non_blocking=True)
- )
- self._sync_storage()
+ if self.replay_buffer is not None:
+ self.replay_buffer.add(self._shuttle)
+ if self._increment_frames(self._shuttle.numel()):
+ return
else:
- tensordicts.append(self._shuttle)
+ if self.storing_device is not None:
+ tensordicts.append(
+ self._shuttle.to(self.storing_device, non_blocking=True)
+ )
+ self._sync_storage()
+ else:
+ tensordicts.append(self._shuttle)
# carry over collector data without messing up devices
collector_data = self._shuttle.get("collector").copy()
@@ -1067,13 +1190,14 @@ def rollout(self) -> TensorDictBase:
if self._shuttle_has_no_device:
self._shuttle.clear_device_()
self._shuttle.set("collector", collector_data)
-
self._update_traj_ids(env_output)
if (
self.interruptor is not None
and self.interruptor.collection_stopped()
):
+ if self.replay_buffer is not None:
+ return
result = self._final_rollout
if self._use_buffers:
try:
@@ -1109,6 +1233,8 @@ def rollout(self) -> TensorDictBase:
self._final_rollout.ndim - 1,
out=self._final_rollout,
)
+ elif self.replay_buffer is not None:
+ return
else:
result = TensorDict.maybe_dense_stack(tensordicts, dim=-1)
result.refine_names(..., "time")
@@ -1359,7 +1485,7 @@ class _MultiDataCollector(DataCollectorBase):
workers may charge the cpu load too much and harm performance.
cat_results (str, int or None): (:class:`~torchrl.collectors.MultiSyncDataCollector` exclusively).
If ``"stack"``, the data collected from the workers will be stacked along the
- first dimension. This is the preferred behaviour as it is the most compatible
+ first dimension. This is the preferred behavior as it is the most compatible
with the rest of the library.
If ``0``, results will be concatenated along the first dimension
of the outputs, which can be the batched dimension if the environments are
@@ -1367,7 +1493,7 @@ class _MultiDataCollector(DataCollectorBase):
A ``cat_results`` value of ``-1`` will always concatenate results along the
time dimension. This should be preferred over the default. Intermediate values
are also accepted.
- Defaults to ``0``.
+ Defaults to ``"stack"``.
.. note:: From v0.5, this argument will default to ``"stack"`` for a better
interoperability with the rest of the library.
@@ -1380,6 +1506,17 @@ class _MultiDataCollector(DataCollectorBase):
use_buffers (bool, optional): if ``True``, a buffer will be used to stack the data.
This isn't compatible with environments with dynamic specs. Defaults to ``True``
for envs without dynamic specs, ``False`` for others.
+        replay_buffer (ReplayBuffer, optional): if provided, the collector will not yield tensordicts
+            but will populate the buffer instead. Defaults to ``None``.
+        trust_policy (bool, optional): if ``True``, a non-TensorDictModule policy will be assumed to be
+            compatible with the collector. This defaults to ``True`` for CudaGraphModules
+            and ``False`` otherwise.
+        compile_policy (bool or Dict[str, Any], optional): if ``True``, the policy will be compiled
+            with the default behavior of :func:`~torch.compile`. If a dictionary of kwargs is passed,
+            it will be used to compile the policy.
+ cudagraph_policy (bool or Dict[str, Any], optional): if ``True``, the policy will be wrapped
+ in :class:`~tensordict.nn.CudaGraphModule` with default kwargs.
+ If a dictionary of kwargs is passed, it will be used to wrap the policy.
"""
@@ -1406,7 +1543,6 @@ def __init__(
postproc: Optional[Callable[[TensorDictBase], TensorDictBase]] = None,
split_trajs: Optional[bool] = None,
exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE,
- exploration_mode=None,
reset_when_done: bool = True,
update_at_each_batch: bool = False,
preemptive_threshold: float = None,
@@ -1415,10 +1551,12 @@ def __init__(
cat_results: str | int | None = None,
set_truncated: bool = False,
use_buffers: bool | None = None,
+ replay_buffer: ReplayBuffer | None = None,
+ replay_buffer_chunk: bool = True,
+        trust_policy: bool | None = None,
+ compile_policy: bool | Dict[str, Any] | None = None,
+ cudagraph_policy: bool | Dict[str, Any] | None = None,
):
- exploration_type = _convert_exploration_type(
- exploration_mode=exploration_mode, exploration_type=exploration_type
- )
self.closed = True
self.num_workers = len(create_env_fn)
@@ -1426,6 +1564,7 @@ def __init__(
self.num_sub_threads = num_sub_threads
self.num_threads = num_threads
self.create_env_fn = create_env_fn
+ self._read_compile_kwargs(compile_policy, cudagraph_policy)
self.create_env_kwargs = (
create_env_kwargs
if create_env_kwargs is not None
@@ -1458,79 +1597,42 @@ def __init__(
del storing_device, env_device, policy_device, device
self._use_buffers = use_buffers
+ self.replay_buffer = replay_buffer
+ self._check_replay_buffer_init()
+ self.replay_buffer_chunk = replay_buffer_chunk
+ if (
+ replay_buffer is not None
+ and hasattr(replay_buffer, "shared")
+ and not replay_buffer.shared
+ ):
+ replay_buffer.share()
- _policy_weights_dict = {}
- _get_weights_fn_dict = {}
-
- if policy is not None:
- policy = _NonParametricPolicyWrapper(policy)
- policy_weights = TensorDict.from_module(policy, as_module=True)
-
- # store a stateless policy
- with policy_weights.apply(_make_meta_params).to_module(policy):
- # TODO:
- self.policy = deepcopy(policy)
-
- else:
- policy_weights = TensorDict()
- self.policy = None
-
- for policy_device in policy_devices:
- # if we have already mapped onto that device, get that value
- if policy_device in _policy_weights_dict:
- continue
- # If policy device is None, the only thing we need to do is
- # make sure that the weights are shared.
- if policy_device is None:
+ self._policy_weights_dict = {}
+ self._get_weights_fn_dict = {}
- def map_weight(
- weight,
- ):
- is_param = isinstance(weight, nn.Parameter)
- weight = weight.data
- if weight.device.type in ("cpu", "mps"):
- weight = weight.share_memory_()
- if is_param:
- weight = nn.Parameter(weight, requires_grad=False)
- return weight
-
- # in other cases, we need to cast the policy if and only if not all the weights
- # are on the appropriate device
- else:
- # check the weights devices
- has_different_device = [False]
+ if trust_policy is None:
+ trust_policy = isinstance(policy, CudaGraphModule)
+ self.trust_policy = trust_policy
- def map_weight(
- weight,
- policy_device=policy_device,
- has_different_device=has_different_device,
- ):
- is_param = isinstance(weight, nn.Parameter)
- weight = weight.data
- if weight.device != policy_device:
- has_different_device[0] = True
- weight = weight.to(policy_device)
- elif weight.device.type in ("cpu", "mps"):
- weight = weight.share_memory_()
- if is_param:
- weight = nn.Parameter(weight, requires_grad=False)
- return weight
-
- local_policy_weights = TensorDictParams(
- policy_weights.apply(map_weight, filter_empty=False)
+ for policy_device, env_maker, env_maker_kwargs in zip(
+ self.policy_device, self.create_env_fn, self.create_env_kwargs
+ ):
+ (policy_copy, get_weights_fn,) = self._get_policy_and_device(
+ policy=policy,
+ policy_device=policy_device,
+ env_maker=env_maker,
+ env_maker_kwargs=env_maker_kwargs,
)
-
- def _get_weight_fn(weights=policy_weights):
- # This function will give the local_policy_weight the original weights.
- # see self.update_policy_weights_ to see how this is used
- return weights
-
- # We lock the weights to be able to cache a bunch of ops and to avoid modifying it
- _policy_weights_dict[policy_device] = local_policy_weights.lock_()
- _get_weights_fn_dict[policy_device] = _get_weight_fn
-
- self._policy_weights_dict = _policy_weights_dict
- self._get_weights_fn_dict = _get_weights_fn_dict
+ if type(policy_copy) is not type(policy):
+ policy = policy_copy
+ weights = (
+ TensorDict.from_module(policy_copy)
+ if isinstance(policy_copy, nn.Module)
+ else TensorDict()
+ )
+ self._policy_weights_dict[policy_device] = weights
+ self._get_weights_fn_dict[policy_device] = get_weights_fn
+ self.policy = policy
if total_frames is None or total_frames < 0:
total_frames = float("inf")
@@ -1598,6 +1700,26 @@ def _get_weight_fn(weights=policy_weights):
)
self.cat_results = cat_results
+ def _check_replay_buffer_init(self):
+ if self.replay_buffer is None:
+ return
+ is_init = getattr(self.replay_buffer._storage, "initialized", True)
+ if not is_init:
+ if isinstance(self.create_env_fn[0], EnvCreator):
+ fake_td = self.create_env_fn[0].meta_data.tensordict
+ elif isinstance(self.create_env_fn[0], EnvBase):
+ fake_td = self.create_env_fn[0].fake_tensordict()
+ else:
+ fake_td = self.create_env_fn[0](
+ **self.create_env_kwargs[0]
+ ).fake_tensordict()
+ fake_td["collector", "traj_ids"] = torch.zeros(
+ fake_td.shape, dtype=torch.long
+ )
+
+ self.replay_buffer.add(fake_td)
+ self.replay_buffer.empty()
+
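
The same pre-initialization trick, sketched standalone: write one fake batch so a lazy storage materializes its buffers (making them shareable with worker processes), then empty the buffer so the fake data can never be sampled. Assumes gymnasium's Pendulum-v1; `_storage.initialized` is private API.

    import torch
    from torchrl.data import LazyTensorStorage, ReplayBuffer
    from torchrl.envs.libs.gym import GymEnv

    rb = ReplayBuffer(storage=LazyTensorStorage(100))
    fake_td = GymEnv("Pendulum-v1").fake_tensordict()
    fake_td["collector", "traj_ids"] = torch.zeros(fake_td.shape, dtype=torch.long)
    rb.add(fake_td)   # storage buffers now exist
    rb.empty()        # but no fake data remains
    assert rb._storage.initialized and len(rb) == 0
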
@classmethod
def _total_workers_from_env(cls, env_creators):
if isinstance(env_creators, (tuple, list)):
@@ -1665,10 +1787,10 @@ def frames_per_batch_worker(self):
raise NotImplementedError
def update_policy_weights_(self, policy_weights=None) -> None:
+ if isinstance(policy_weights, TensorDictParams):
+ policy_weights = policy_weights.data
for _device in self._policy_weights_dict:
if policy_weights is not None:
- if isinstance(policy_weights, TensorDictParams):
- policy_weights = policy_weights.data
self._policy_weights_dict[_device].data.update_(policy_weights)
elif self._get_weights_fn_dict[_device] is not None:
original_weights = self._get_weights_fn_dict[_device]()
@@ -1694,6 +1816,8 @@ def _run_processes(self) -> None:
queue_out = mp.Queue(self._queue_len) # sends data from proc to main
self.procs = []
self.pipes = []
+ self._traj_pool = _TrajectoryPool(lock=True)
+
for i, (env_fun, env_fun_kwargs) in enumerate(
zip(self.create_env_fn, self.create_env_kwargs)
):
@@ -1708,9 +1832,12 @@ def _run_processes(self) -> None:
storing_device = self.storing_device[i]
env_device = self.env_device[i]
policy = self.policy
- with self._policy_weights_dict[policy_device].to_module(
- policy
- ) if policy is not None else contextlib.nullcontext():
+ policy_weights = self._policy_weights_dict[policy_device]
+ if policy is not None and policy_weights is not None:
+ cm = policy_weights.to_module(policy)
+ else:
+ cm = contextlib.nullcontext()
+ with cm:
kwargs = {
"pipe_parent": pipe_parent,
"pipe_child": pipe_child,
@@ -1730,6 +1857,16 @@ def _run_processes(self) -> None:
"interruptor": self.interruptor,
"set_truncated": self.set_truncated,
"use_buffers": self._use_buffers,
+ "replay_buffer": self.replay_buffer,
+ "replay_buffer_chunk": self.replay_buffer_chunk,
+ "traj_pool": self._traj_pool,
+ "trust_policy": self.trust_policy,
+ "compile_policy": self.compiled_policy_kwargs
+ if self.compiled_policy
+ else False,
+ "cudagraph_policy": self.cudagraphed_policy_kwargs
+ if self.cudagraphed_policy
+ else False,
}
proc = _ProcessNoWarn(
target=_main_async_collector,
@@ -1754,6 +1891,7 @@ def _run_processes(self) -> None:
self.procs.append(proc)
self.pipes.append(pipe_parent)
for pipe_parent in self.pipes:
+            if not pipe_parent.poll(timeout=INSTANTIATE_TIMEOUT):
+                raise RuntimeError(
+                    f"A collector worker did not respond within {INSTANTIATE_TIMEOUT} seconds."
+                )
msg = pipe_parent.recv()
if msg != "instantiated":
raise RuntimeError(msg)
@@ -2069,29 +2207,12 @@ def iterator(self) -> Iterator[TensorDictBase]:
cat_results = self.cat_results
if cat_results is None:
cat_results = "stack"
- warnings.warn(
- f"`cat_results` was not specified in the constructor of {type(self).__name__}. "
- f"For MultiSyncDataCollector, `cat_results` indicates how the data should "
- f"be packed: the preferred option and current default is `cat_results='stack'` "
- f"which provides the best interoperability across torchrl components. "
- f"Other accepted values are `cat_results=0` (previous behaviour) and "
- f"`cat_results=-1` (cat along time dimension). Among these two, the latter "
- f"should be preferred for consistency across environment configurations. "
- f"Currently, the default value is `'stack'`."
- f"From v0.6 onward, this warning will be removed. "
- f"To suppress this warning, set `cat_results` to the desired value.",
- category=DeprecationWarning,
- )
self.buffers = {}
dones = [False for _ in range(self.num_workers)]
workers_frames = [0 for _ in range(self.num_workers)]
same_device = None
self.out_buffer = None
- last_traj_ids = [-10 for _ in range(self.num_workers)]
- last_traj_ids_subs = [None for _ in range(self.num_workers)]
- traj_max = -1
- traj_ids_list = [None for _ in range(self.num_workers)]
preempt = self.interruptor is not None and self.preemptive_threshold < 1.0
while not all(dones) and self._frames < self.total_frames:
@@ -2125,7 +2246,13 @@ def iterator(self) -> Iterator[TensorDictBase]:
for _ in range(self.num_workers):
new_data, j = self.queue_out.get()
use_buffers = self._use_buffers
- if j == 0 or not use_buffers:
+ if self.replay_buffer is not None:
+ idx = new_data
+ workers_frames[idx] = (
+ workers_frames[idx] + self.frames_per_batch_worker
+ )
+ continue
+ elif j == 0 or not use_buffers:
try:
data, idx = new_data
self.buffers[idx] = data
@@ -2167,51 +2294,25 @@ def iterator(self) -> Iterator[TensorDictBase]:
if workers_frames[idx] >= self.total_frames:
dones[idx] = True
+ if self.replay_buffer is not None:
+ yield
+ self._frames += self.frames_per_batch_worker * self.num_workers
+ continue
+
# we have to correct the traj_ids to make sure that they don't overlap
# We can count the number of frames collected for free in this loop
n_collected = 0
for idx in range(self.num_workers):
buffer = buffers[idx]
traj_ids = buffer.get(("collector", "traj_ids"))
- is_last = traj_ids == last_traj_ids[idx]
- # If we `cat` interrupted data, we have already filtered out
- # non-valid steps. If we stack, we haven't.
- if preempt and cat_results == "stack":
- valid = buffer.get(("collector", "traj_ids")) != -1
- if valid.ndim > 2:
- valid = valid.flatten(0, -2)
- if valid.ndim == 2:
- valid = valid.any(0)
- last_traj_ids[idx] = traj_ids[..., valid][..., -1:].clone()
- else:
- last_traj_ids[idx] = traj_ids[..., -1:].clone()
- if not is_last.all():
- traj_to_correct = traj_ids[~is_last]
- traj_to_correct = (
- traj_to_correct + (traj_max + 1) - traj_to_correct.min()
- )
- traj_ids = traj_ids.masked_scatter(~is_last, traj_to_correct)
- # is_last can only be true if we're after the first iteration
- if is_last.any():
- traj_ids = torch.where(
- is_last, last_traj_ids_subs[idx].expand_as(traj_ids), traj_ids
- )
-
if preempt:
if cat_results == "stack":
mask_frames = buffer.get(("collector", "traj_ids")) != -1
- traj_ids = torch.where(mask_frames, traj_ids, -1)
n_collected += mask_frames.sum().cpu()
- last_traj_ids_subs[idx] = traj_ids[..., valid][..., -1:].clone()
else:
- last_traj_ids_subs[idx] = traj_ids[..., -1:].clone()
n_collected += traj_ids.numel()
else:
- last_traj_ids_subs[idx] = traj_ids[..., -1:].clone()
n_collected += traj_ids.numel()
- traj_ids_list[idx] = traj_ids
-
- traj_max = max(traj_max, traj_ids.max())
if same_device is None:
prev_device = None
@@ -2232,9 +2333,6 @@ def iterator(self) -> Iterator[TensorDictBase]:
self.out_buffer = stack(
[item.cpu() for item in buffers.values()], 0
)
- self.out_buffer.set(
- ("collector", "traj_ids"), torch.stack(traj_ids_list), inplace=True
- )
else:
if self._use_buffers is None:
torchrl_logger.warning(
@@ -2251,9 +2349,6 @@ def iterator(self) -> Iterator[TensorDictBase]:
self.out_buffer = torch.cat(
[item.cpu() for item in buffers.values()], cat_results
)
- self.out_buffer.set_(
- ("collector", "traj_ids"), torch.cat(traj_ids_list, cat_results)
- )
except RuntimeError as err:
if (
preempt
@@ -2397,7 +2492,7 @@ class MultiaSyncDataCollector(_MultiDataCollector):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
- self.out_tensordicts = {}
+ self.out_tensordicts = defaultdict(lambda: None)
self.running = False
if self.postprocs is not None:
@@ -2442,7 +2537,9 @@ def frames_per_batch_worker(self):
def _get_from_queue(self, timeout=None) -> Tuple[int, int, TensorDictBase]:
new_data, j = self.queue_out.get(timeout=timeout)
use_buffers = self._use_buffers
- if j == 0 or not use_buffers:
+ if self.replay_buffer is not None:
+ idx = new_data
+ elif j == 0 or not use_buffers:
try:
data, idx = new_data
self.out_tensordicts[idx] = data
@@ -2457,7 +2554,7 @@ def _get_from_queue(self, timeout=None) -> Tuple[int, int, TensorDictBase]:
else:
idx = new_data
out = self.out_tensordicts[idx]
- if j == 0 or use_buffers:
+        if self.replay_buffer is None and (j == 0 or use_buffers):
# we clone the data to make sure that we'll be working with a fixed copy
out = out.clone()
return idx, j, out
@@ -2479,12 +2576,19 @@ def iterator(self) -> Iterator[TensorDictBase]:
workers_frames = [0 for _ in range(self.num_workers)]
while self._frames < self.total_frames:
- _check_for_faulty_process(self.procs)
self._iter += 1
- idx, j, out = self._get_from_queue()
- worker_frames = out.numel()
- if self.split_trajs:
- out = split_trajectories(out, prefix="collector")
+ while True:
+ try:
+ idx, j, out = self._get_from_queue(timeout=10.0)
+ break
+                except (queue.Empty, TimeoutError):  # mp.Queue.get raises queue.Empty on timeout
+ _check_for_faulty_process(self.procs)
+ if self.replay_buffer is None:
+ worker_frames = out.numel()
+ if self.split_trajs:
+ out = split_trajectories(out, prefix="collector")
+ else:
+ worker_frames = self.frames_per_batch_worker
self._frames += worker_frames
workers_frames[idx] = workers_frames[idx] + worker_frames
if self.postprocs:
@@ -2500,7 +2604,7 @@ def iterator(self) -> Iterator[TensorDictBase]:
else:
msg = "continue"
self.pipes[idx].send((idx, msg))
- if self._exclude_private_keys:
+ if out is not None and self._exclude_private_keys:
excluded_keys = [key for key in out.keys() if key.startswith("_")]
out = out.exclude(*excluded_keys)
yield out
@@ -2687,7 +2791,6 @@ def __init__(
postproc: Optional[Callable[[TensorDictBase], TensorDictBase]] = None,
split_trajs: Optional[bool] = None,
exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE,
- exploration_mode=None,
reset_when_done: bool = True,
update_at_each_batch: bool = False,
preemptive_threshold: float = None,
@@ -2712,7 +2815,6 @@ def __init__(
env_device=env_device,
storing_device=storing_device,
exploration_type=exploration_type,
- exploration_mode=exploration_mode,
reset_when_done=reset_when_done,
update_at_each_batch=update_at_each_batch,
preemptive_threshold=preemptive_threshold,
@@ -2762,6 +2864,12 @@ def _main_async_collector(
interruptor=None,
set_truncated: bool = False,
use_buffers: bool | None = None,
+ replay_buffer: ReplayBuffer | None = None,
+ replay_buffer_chunk: bool = True,
+ traj_pool: _TrajectoryPool = None,
+ trust_policy: bool = False,
+ compile_policy: bool = False,
+ cudagraph_policy: bool = False,
) -> None:
pipe_parent.close()
# init variables that will be cleared when closing
@@ -2782,10 +2890,15 @@ def _main_async_collector(
env_device=env_device,
exploration_type=exploration_type,
reset_when_done=reset_when_done,
- return_same_td=True,
+ return_same_td=replay_buffer is None,
interruptor=interruptor,
set_truncated=set_truncated,
use_buffers=use_buffers,
+ replay_buffer=replay_buffer if replay_buffer_chunk else None,
+ traj_pool=traj_pool,
+ trust_policy=trust_policy,
+ compile_policy=compile_policy,
+ cudagraph_policy=cudagraph_policy,
)
use_buffers = inner_collector._use_buffers
if verbose:
@@ -2848,6 +2961,25 @@ def _main_async_collector(
# In that case, we skip the collected trajectory and get the message from main. This is faster than
# sending the trajectory in the queue until timeout when it's never going to be received.
continue
+
+ if replay_buffer is not None:
+ if not replay_buffer_chunk:
+ next_data.names = None
+ replay_buffer.extend(next_data)
+
+ try:
+ queue_out.put((idx, j), timeout=_TIMEOUT)
+ if verbose:
+ torchrl_logger.info(f"worker {idx} successfully sent data")
+ j += 1
+ has_timed_out = False
+ continue
+ except queue.Full:
+ if verbose:
+ torchrl_logger.info(f"worker {idx} has timed out")
+ has_timed_out = True
+ continue
+
if j == 0 or not use_buffers:
collected_tensordict = next_data
if (
@@ -2862,10 +2994,16 @@ def _main_async_collector(
# if policy is on cuda and env on cuda, we are fine with this
# If policy is on cuda and env on cpu (or opposite) we put tensors that
# are on cpu in shared mem.
+ MPS_ERROR = (
+ "tensors on mps device cannot be put in shared memory. Make sure "
+ "the shared device (aka storing_device) is set to CPU."
+ )
if collected_tensordict.device is not None:
- # placehoder in case we need different behaviours
- if collected_tensordict.device.type in ("cpu", "mps"):
+                # placeholder in case we need different behaviors
+ if collected_tensordict.device.type in ("cpu",):
collected_tensordict.share_memory_()
+ elif collected_tensordict.device.type in ("mps",):
+ raise RuntimeError(MPS_ERROR)
elif collected_tensordict.device.type == "cuda":
collected_tensordict.share_memory_()
else:
@@ -2874,11 +3012,13 @@ def _main_async_collector(
)
else:
# make sure each cpu tensor is shared - assuming non-cpu devices are shared
- collected_tensordict.apply(
- lambda x: x.share_memory_()
- if x.device.type in ("cpu", "mps")
- else x
- )
+ def cast_tensor(x, MPS_ERROR=MPS_ERROR):
+ if x.device.type in ("cpu",):
+ x.share_memory_()
+ if x.device.type in ("mps",):
+ RuntimeError(MPS_ERROR)
+
+ collected_tensordict.apply(cast_tensor, filter_empty=True)
data = (collected_tensordict, idx)
else:
if next_data is not collected_tensordict:
@@ -2956,3 +3096,20 @@ def _make_meta_params(param):
if is_param:
pd = nn.Parameter(pd, requires_grad=False)
return pd
+
+
+class _TrajectoryPool:
+ def __init__(self, ctx=None, lock: bool = False):
+ self.ctx = ctx
+ self._traj_id = torch.zeros((), device="cpu", dtype=torch.int).share_memory_()
+ if ctx is None:
+ self.lock = contextlib.nullcontext() if not lock else mp.RLock()
+ else:
+ self.lock = contextlib.nullcontext() if not lock else ctx.RLock()
+
+ def get_traj_and_increment(self, n=1, device=None):
+ with self.lock:
+ v = self._traj_id.item()
+ out = torch.arange(v, v + n).to(device)
+ self._traj_id.copy_(1 + out[-1].item())
+ return out
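
How the pool hands out ids, for illustration (`_TrajectoryPool` is private API and subject to change): each call reserves a contiguous block and bumps the shared counter, under a lock when used across processes.

    from torchrl.collectors.collectors import _TrajectoryPool

    pool = _TrajectoryPool(lock=True)
    a = pool.get_traj_and_increment(3)  # tensor([0, 1, 2])
    b = pool.get_traj_and_increment(2)  # tensor([3, 4])
    assert b.min() > a.max()            # blocks never overlap, even across workers
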
diff --git a/torchrl/collectors/distributed/generic.py b/torchrl/collectors/distributed/generic.py
index 596c1f5d191..729b8a48171 100644
--- a/torchrl/collectors/distributed/generic.py
+++ b/torchrl/collectors/distributed/generic.py
@@ -30,11 +30,10 @@
MAX_TIME_TO_CONNECT,
TCP_PORT,
)
-from torchrl.collectors.utils import split_trajectories
+from torchrl.collectors.utils import _NON_NN_POLICY_WEIGHTS, split_trajectories
from torchrl.data.utils import CloudpickleWrapper
from torchrl.envs.common import EnvBase
from torchrl.envs.env_creator import EnvCreator
-from torchrl.envs.utils import _convert_exploration_type
SUBMITIT_ERR = None
try:
@@ -172,18 +171,11 @@ def _run_collector(
)
if isinstance(policy, nn.Module):
- policy_weights = TensorDict(dict(policy.named_parameters()), [])
- # TODO: Do we want this?
- # updates the policy weights to avoid them to be shared
- if all(
- param.device == torch.device("cpu") for param in policy_weights.values()
- ):
- policy = deepcopy(policy)
- policy_weights = TensorDict(dict(policy.named_parameters()), [])
-
- policy_weights = policy_weights.apply(lambda x: x.data)
+ policy_weights = TensorDict.from_module(policy)
+ policy_weights = policy_weights.data.lock_()
else:
- policy_weights = TensorDict({}, [])
+ warnings.warn(_NON_NN_POLICY_WEIGHTS)
+ policy_weights = TensorDict(lock=True)
collector = collector_class(
env_make,
@@ -426,7 +418,6 @@ def __init__(
postproc: Callable | None = None,
split_trajs: bool = False,
exploration_type: "ExporationType" = DEFAULT_EXPLORATION_TYPE, # noqa
- exploration_mode: str = None,
collector_class: Type = SyncDataCollector,
collector_kwargs: dict = None,
num_workers_per_collector: int = 1,
@@ -438,9 +429,6 @@ def __init__(
launcher: str = "submitit",
tcp_port: int = None,
):
- exploration_type = _convert_exploration_type(
- exploration_mode=exploration_mode, exploration_type=exploration_type
- )
if collector_class == "async":
collector_class = MultiaSyncDataCollector
@@ -452,10 +440,11 @@ def __init__(
self.env_constructors = create_env_fn
self.policy = policy
if isinstance(policy, nn.Module):
- policy_weights = TensorDict(dict(policy.named_parameters()), [])
- policy_weights = policy_weights.apply(lambda x: x.data)
+ policy_weights = TensorDict.from_module(policy)
+ policy_weights = policy_weights.data.lock_()
else:
- policy_weights = TensorDict({}, [])
+ warnings.warn(_NON_NN_POLICY_WEIGHTS)
+ policy_weights = TensorDict(lock=True)
self.policy_weights = policy_weights
self.num_workers = len(create_env_fn)
self.frames_per_batch = frames_per_batch
@@ -820,6 +809,8 @@ def _iterator_dist(self):
for i in range(self.num_workers):
rank = i + 1
+ if self._VERBOSE:
+ torchrl_logger.info(f"shutting down rank {rank}.")
self._store.set(f"NODE_{rank}_in", b"shutdown")
def _next_sync(self, total_frames):
diff --git a/torchrl/collectors/distributed/ray.py b/torchrl/collectors/distributed/ray.py
index 79b3ee9063c..1f088c2c404 100644
--- a/torchrl/collectors/distributed/ray.py
+++ b/torchrl/collectors/distributed/ray.py
@@ -20,7 +20,7 @@
MultiSyncDataCollector,
SyncDataCollector,
)
-from torchrl.collectors.utils import split_trajectories
+from torchrl.collectors.utils import _NON_NN_POLICY_WEIGHTS, split_trajectories
from torchrl.envs.common import EnvBase
from torchrl.envs.env_creator import EnvCreator
@@ -99,7 +99,7 @@ class RayCollector(DataCollectorBase):
The class dictionary input parameter "ray_init_config" can be used to provide the kwargs to
call Ray initialization method ray.init(). If "ray_init_config" is not provided, the default
- behaviour is to autodetect an existing Ray cluster or start a new Ray instance locally if no
+ behavior is to autodetect an existing Ray cluster or start a new Ray instance locally if no
existing cluster is found. Refer to Ray documentation for advanced initialization kwargs.
Similarly, dictionary input parameter "remote_configs" can be used to specify the kwargs for
@@ -401,9 +401,11 @@ def check_list_length_consistency(*lists):
self._local_policy = policy
if isinstance(self._local_policy, nn.Module):
- policy_weights = TensorDict(dict(policy.named_parameters()), [])
+ policy_weights = TensorDict.from_module(self._local_policy)
+ policy_weights = policy_weights.data.lock_()
else:
- policy_weights = TensorDict({}, [])
+ warnings.warn(_NON_NN_POLICY_WEIGHTS)
+ policy_weights = TensorDict(lock=True)
self.policy_weights = policy_weights
self.collector_class = collector_class
self.collected_frames = 0
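
The pattern now shared by all the distributed collectors, as a standalone sketch: take a detached (`.data`) and locked view of the module's parameters and buffers instead of rebuilding a dict from `named_parameters()` by hand.

    import torch.nn as nn
    from tensordict import TensorDict

    policy = nn.Linear(4, 2)
    policy_weights = TensorDict.from_module(policy).data.lock_()
    assert set(policy_weights.keys()) == {"weight", "bias"}
    assert not policy_weights["weight"].requires_grad  # detached view
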
diff --git a/torchrl/collectors/distributed/rpc.py b/torchrl/collectors/distributed/rpc.py
index b6c324bb7b5..73247df4b0c 100644
--- a/torchrl/collectors/distributed/rpc.py
+++ b/torchrl/collectors/distributed/rpc.py
@@ -22,9 +22,8 @@
IDLE_TIMEOUT,
TCP_PORT,
)
-from torchrl.collectors.utils import split_trajectories
+from torchrl.collectors.utils import _NON_NN_POLICY_WEIGHTS, split_trajectories
from torchrl.data.utils import CloudpickleWrapper
-from torchrl.envs.utils import _convert_exploration_type
SUBMITIT_ERR = None
try:
@@ -275,7 +274,6 @@ def __init__(
postproc: Callable | None = None,
split_trajs: bool = False,
exploration_type: "ExporationType" = DEFAULT_EXPLORATION_TYPE, # noqa
- exploration_mode: str = None,
collector_class=SyncDataCollector,
collector_kwargs=None,
num_workers_per_collector=1,
@@ -288,9 +286,6 @@ def __init__(
visible_devices=None,
tensorpipe_options=None,
):
- exploration_type = _convert_exploration_type(
- exploration_mode=exploration_mode, exploration_type=exploration_type
- )
if collector_class == "async":
collector_class = MultiaSyncDataCollector
elif collector_class == "sync":
@@ -301,9 +296,11 @@ def __init__(
self.env_constructors = create_env_fn
self.policy = policy
if isinstance(policy, nn.Module):
- policy_weights = TensorDict(dict(policy.named_parameters()), [])
+ policy_weights = TensorDict.from_module(policy)
+ policy_weights = policy_weights.data.lock_()
else:
- policy_weights = TensorDict({}, [])
+ warnings.warn(_NON_NN_POLICY_WEIGHTS)
+ policy_weights = TensorDict(lock=True)
self.policy_weights = policy_weights
self.num_workers = len(create_env_fn)
self.frames_per_batch = frames_per_batch
diff --git a/torchrl/collectors/distributed/sync.py b/torchrl/collectors/distributed/sync.py
index 6f959086c83..481fb70cc31 100644
--- a/torchrl/collectors/distributed/sync.py
+++ b/torchrl/collectors/distributed/sync.py
@@ -8,6 +8,7 @@
import os
import socket
+import warnings
from copy import copy, deepcopy
from datetime import timedelta
from typing import Callable, List, OrderedDict
@@ -29,11 +30,10 @@
DEFAULT_SLURM_CONF,
MAX_TIME_TO_CONNECT,
)
-from torchrl.collectors.utils import split_trajectories
+from torchrl.collectors.utils import _NON_NN_POLICY_WEIGHTS, split_trajectories
from torchrl.data.utils import CloudpickleWrapper
from torchrl.envs.common import EnvBase
from torchrl.envs.env_creator import EnvCreator
-from torchrl.envs.utils import _convert_exploration_type
SUBMITIT_ERR = None
try:
@@ -78,18 +78,11 @@ def _distributed_init_collection_node(
)
if isinstance(policy, nn.Module):
- policy_weights = TensorDict(dict(policy.named_parameters()), [])
- # TODO: Do we want this?
- # updates the policy weights to avoid them to be shared
- if all(
- param.device == torch.device("cpu") for param in policy_weights.values()
- ):
- policy = deepcopy(policy)
- policy_weights = TensorDict(dict(policy.named_parameters()), [])
-
- policy_weights = policy_weights.apply(lambda x: x.data)
+ policy_weights = TensorDict.from_module(policy)
+ policy_weights = policy_weights.data.lock_()
else:
- policy_weights = TensorDict({}, [])
+ warnings.warn(_NON_NN_POLICY_WEIGHTS)
+ policy_weights = TensorDict(lock=True)
collector = collector_class(
env_make,
@@ -291,7 +284,6 @@ def __init__(
postproc: Callable | None = None,
split_trajs: bool = False,
exploration_type: "ExporationType" = DEFAULT_EXPLORATION_TYPE, # noqa
- exploration_mode: str = None,
collector_class=SyncDataCollector,
collector_kwargs=None,
num_workers_per_collector=1,
@@ -302,9 +294,6 @@ def __init__(
launcher="submitit",
tcp_port=None,
):
- exploration_type = _convert_exploration_type(
- exploration_mode=exploration_mode, exploration_type=exploration_type
- )
if collector_class == "async":
collector_class = MultiaSyncDataCollector
@@ -315,11 +304,14 @@ def __init__(
self.collector_class = collector_class
self.env_constructors = create_env_fn
self.policy = policy
+
if isinstance(policy, nn.Module):
- policy_weights = TensorDict(dict(policy.named_parameters()), [])
- policy_weights = policy_weights.apply(lambda x: x.data)
+ policy_weights = TensorDict.from_module(policy)
+ policy_weights = policy_weights.data.lock_()
else:
- policy_weights = TensorDict({}, [])
+ warnings.warn(_NON_NN_POLICY_WEIGHTS)
+ policy_weights = TensorDict(lock=True)
+
self.policy_weights = policy_weights
self.num_workers = len(create_env_fn)
self.frames_per_batch = frames_per_batch
diff --git a/torchrl/collectors/distributed/utils.py b/torchrl/collectors/distributed/utils.py
index aeee573f8dc..2dd6fcf6c93 100644
--- a/torchrl/collectors/distributed/utils.py
+++ b/torchrl/collectors/distributed/utils.py
@@ -53,10 +53,10 @@ class submitit_delayed_launcher:
... def main():
... from torchrl.envs.utils import RandomPolicy
... from torchrl.envs.libs.gym import GymEnv
- ... from torchrl.data import BoundedTensorSpec
+ ... from torchrl.data import BoundedContinuous
... collector = DistributedDataCollector(
... [EnvCreator(lambda: GymEnv("Pendulum-v1"))] * num_jobs,
- ... policy=RandomPolicy(BoundedTensorSpec(-1, 1, shape=(1,))),
+ ... policy=RandomPolicy(BoundedContinuous(-1, 1, shape=(1,))),
... launcher="submitit_delayed",
... )
... for data in collector:
diff --git a/torchrl/collectors/utils.py b/torchrl/collectors/utils.py
index d777da3de2a..74bea267c22 100644
--- a/torchrl/collectors/utils.py
+++ b/torchrl/collectors/utils.py
@@ -11,6 +11,12 @@
from tensordict import NestedKey, pad, set_lazy_legacy, TensorDictBase
+_NON_NN_POLICY_WEIGHTS = (
+ "The policy is not an nn.Module. TorchRL will assume that the parameter set is empty and "
+ "update_policy_weights_ will be a no-op."
+)
+
+
def _stack_output(fun) -> Callable:
def stacked_output_fun(*args, **kwargs):
out = fun(*args, **kwargs)
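The rpc.py and sync.py hunks above replace the hand-rolled `TensorDict(dict(policy.named_parameters()), [])` pattern with `TensorDict.from_module` followed by `.data.lock_()`. Below is a minimal standalone sketch of that pattern, assuming a tensordict release that ships `TensorDict.from_module` and in-place locking:

```python
import torch
from tensordict import TensorDict
from torch import nn

policy = nn.Linear(4, 2)

# Capture the module's parameters (and buffers) as a TensorDict, detach to
# plain data, and lock the structure so keys cannot be added or removed.
policy_weights = TensorDict.from_module(policy)
policy_weights = policy_weights.data.lock_()

# The structure mirrors the module: keys are parameter names.
assert set(policy_weights.keys()) == {"weight", "bias"}

# Writing a new key to a locked TensorDict raises an error.
try:
    policy_weights["extra"] = torch.zeros(1)
except RuntimeError as err:
    print(f"locked: {err}")
```

Locking makes the weight container safe to share with worker processes: the key set is frozen, so a stray write on one end cannot silently desynchronize the weight layout.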
diff --git a/torchrl/data/__init__.py b/torchrl/data/__init__.py
index c894c15724b..026a0b3baf2 100644
--- a/torchrl/data/__init__.py
+++ b/torchrl/data/__init__.py
@@ -59,19 +59,32 @@
TokenizedDatasetLoader,
)
from .tensor_specs import (
+ Binary,
BinaryDiscreteTensorSpec,
+ Bounded,
BoundedTensorSpec,
+ Categorical,
+ Composite,
CompositeSpec,
DEVICE_TYPING,
DiscreteTensorSpec,
LazyStackedCompositeSpec,
LazyStackedTensorSpec,
+ MultiCategorical,
MultiDiscreteTensorSpec,
+ MultiOneHot,
MultiOneHotDiscreteTensorSpec,
+ NonTensor,
NonTensorSpec,
+ OneHot,
OneHotDiscreteTensorSpec,
+ Stacked,
+ StackedComposite,
TensorSpec,
+ Unbounded,
+ UnboundedContinuous,
UnboundedContinuousTensorSpec,
+ UnboundedDiscrete,
UnboundedDiscreteTensorSpec,
)
from .utils import check_no_exclusive_keys, consolidate_spec, contains_lazy_spec
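The export list above introduces the shortened spec names (`Bounded`, `Categorical`, `Composite`, `Unbounded`, ...) alongside the legacy `*TensorSpec` classes. A hedged sketch of the new surface, assuming the new names are drop-in replacements for the legacy ones (as the minari_data.py hunk below suggests):

```python
import torch
from torchrl.data import Bounded, Categorical, Composite, Unbounded

spec = Composite(
    observation=Bounded(low=-1.0, high=1.0, shape=(3,), dtype=torch.float32),
    action=Categorical(n=4, shape=(), dtype=torch.int64),
    reward=Unbounded(shape=(1,), dtype=torch.float32),
)
sample = spec.rand()       # a TensorDict drawn from the spec's space
assert spec.is_in(sample)  # validates dtype/device/shape and bounds
```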
diff --git a/torchrl/data/datasets/minari_data.py b/torchrl/data/datasets/minari_data.py
index 3cc8d7437c0..d6a49f17113 100644
--- a/torchrl/data/datasets/minari_data.py
+++ b/torchrl/data/datasets/minari_data.py
@@ -25,12 +25,7 @@
from torchrl.data.replay_buffers.samplers import Sampler
from torchrl.data.replay_buffers.storages import TensorStorage
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
-from torchrl.data.tensor_specs import (
- BoundedTensorSpec,
- CompositeSpec,
- DiscreteTensorSpec,
- UnboundedContinuousTensorSpec,
-)
+from torchrl.data.tensor_specs import Bounded, Categorical, Composite, Unbounded
from torchrl.envs.utils import _classproperty
_has_tqdm = importlib.util.find_spec("tqdm", None) is not None
@@ -398,24 +393,22 @@ def _proc_spec(spec):
if spec is None:
return
if spec["type"] == "Dict":
- return CompositeSpec(
+ return Composite(
{key: _proc_spec(subspec) for key, subspec in spec["subspaces"].items()}
)
elif spec["type"] == "Box":
if all(item == -float("inf") for item in spec["low"]) and all(
item == float("inf") for item in spec["high"]
):
- return UnboundedContinuousTensorSpec(
- spec["shape"], dtype=_DTYPE_DIR[spec["dtype"]]
- )
- return BoundedTensorSpec(
+ return Unbounded(spec["shape"], dtype=_DTYPE_DIR[spec["dtype"]])
+ return Bounded(
shape=spec["shape"],
low=torch.as_tensor(spec["low"]),
high=torch.as_tensor(spec["high"]),
dtype=_DTYPE_DIR[spec["dtype"]],
)
elif spec["type"] == "Discrete":
- return DiscreteTensorSpec(
+ return Categorical(
spec["n"], shape=spec["shape"], dtype=_DTYPE_DIR[spec["dtype"]]
)
else:
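The `_proc_spec` rewrite above keeps the same dispatch while adopting the new class names: a `Box` whose bounds are all infinite becomes an `Unbounded` spec, anything else a `Bounded` one. A standalone sketch of that rule (the helper name `box_to_spec` is illustrative, not torchrl API):

```python
import torch
from torchrl.data import Bounded, Unbounded

def box_to_spec(low, high, shape, dtype):
    # All-infinite bounds carry no information: collapse to Unbounded.
    low = torch.as_tensor(low, dtype=dtype)
    high = torch.as_tensor(high, dtype=dtype)
    if (low == -torch.inf).all() and (high == torch.inf).all():
        return Unbounded(shape, dtype=dtype)
    return Bounded(low=low, high=high, shape=shape, dtype=dtype)

print(box_to_spec(-torch.inf, torch.inf, (3,), torch.float32))  # Unbounded
print(box_to_spec(-1.0, 1.0, (3,), torch.float32))              # Bounded
```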
diff --git a/torchrl/data/datasets/openx.py b/torchrl/data/datasets/openx.py
index 975384a3662..2dbf0720a37 100644
--- a/torchrl/data/datasets/openx.py
+++ b/torchrl/data/datasets/openx.py
@@ -77,7 +77,7 @@ class for more information on how to interact with non-tensor data
shuffle=False will also impact the sampling. We advise users to
create a copy of the dataset where the ``shuffle`` attribute of the
sampler is set to ``False`` if they wish to enjoy the two different
- behaviours (shuffled and not) within the same code base.
+ behaviors (shuffled and not) within the same code base.
num_slices (int, optional): the number of slices in a batch. This
corresponds to the number of trajectories present in a batch.
@@ -134,7 +134,7 @@ class for more information on how to interact with non-tensor data
the dataset. This isn't possible at a reasonable cost with
`streaming=True`: in this case, trajectories will be sampled
one at a time and delivered as such (with cropping to comply with
- the batch-size etc). The behaviour of the two modalities is
+ the batch-size etc). The behavior of the two modalities is
much more similar when `num_slices` and `slice_len` are specified,
as in these cases, views of sub-episodes will be returned in both
cases.
diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py
index a688bd8585e..2e0eeb80705 100644
--- a/torchrl/data/replay_buffers/replay_buffers.py
+++ b/torchrl/data/replay_buffers/replay_buffers.py
@@ -7,6 +7,7 @@
import collections
import contextlib
import json
+import multiprocessing
import textwrap
import threading
import warnings
@@ -23,6 +24,7 @@
is_tensorclass,
LazyStackedTensorDict,
NestedKey,
+ TensorDict,
TensorDictBase,
unravel_key,
)
@@ -120,7 +122,15 @@ class ReplayBuffer:
>>> for d in data.unbind(1):
... rb.add(d)
>>> rb.extend(data)
+ generator (torch.Generator, optional): a generator to use for sampling.
+ Using a dedicated generator for the replay buffer allows fine-grained control
+ over seeding, for instance keeping the global seed different but the RB seed
+ identical across distributed jobs.
+ Defaults to ``None`` (global default generator).
+ .. warning:: As of now, the generator has no effect on the transforms.
+ shared (bool, optional): whether the buffer will be shared using multiprocessing or not.
+ Defaults to ``False``.
Examples:
>>> import torch
@@ -204,6 +214,8 @@ def __init__(
batch_size: int | None = None,
dim_extend: int | None = None,
checkpointer: "StorageCheckpointerBase" | None = None, # noqa: F821
+ generator: torch.Generator | None = None,
+ shared: bool = False,
) -> None:
self._storage = storage if storage is not None else ListStorage(max_size=1_000)
self._storage.attach(self)
@@ -220,6 +232,9 @@ def __init__(
if self._prefetch_cap:
self._prefetch_executor = ThreadPoolExecutor(max_workers=self._prefetch_cap)
+ self.shared = shared
+ self.share(self.shared)
+
self._replay_lock = threading.RLock()
self._futures_lock = threading.RLock()
from torchrl.envs.transforms.transforms import (
@@ -262,6 +277,20 @@ def __init__(
raise ValueError("dim_extend must be a positive value.")
self.dim_extend = dim_extend
self._storage.checkpointer = checkpointer
+ self.set_rng(generator=generator)
+
+ def share(self, shared: bool = True):
+ self.shared = shared
+ if self.shared:
+ self._write_lock = multiprocessing.Lock()
+ else:
+ self._write_lock = contextlib.nullcontext()
+
+ def set_rng(self, generator):
+ self._rng = generator
+ self._storage._rng = generator
+ self._sampler._rng = generator
+ self._writer._rng = generator
@property
def dim_extend(self):
@@ -335,6 +364,11 @@ def __len__(self) -> int:
with self._replay_lock:
return len(self._storage)
+ @property
+ def write_count(self):
+ """The total number of items written so far in the buffer through add and extend."""
+ return self._writer._write_count
+
def __repr__(self) -> str:
from torchrl.envs.transforms import Compose
@@ -414,6 +448,9 @@ def state_dict(self) -> Dict[str, Any]:
"_writer": self._writer.state_dict(),
"_transforms": self._transform.state_dict(),
"_batch_size": self._batch_size,
+ "_rng": (self._rng.get_state().clone(), str(self._rng.device))
+ if self._rng is not None
+ else None,
}
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
@@ -422,6 +459,12 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
self._writer.load_state_dict(state_dict["_writer"])
self._transform.load_state_dict(state_dict["_transforms"])
self._batch_size = state_dict["_batch_size"]
+ rng = state_dict.get("_rng")
+ if rng is not None:
+ state, device = rng
+ rng = torch.Generator(device=device)
+ rng.set_state(state)
+ self.set_rng(generator=rng)
def dumps(self, path):
"""Saves the replay buffer on disk at the specified path.
@@ -465,6 +508,13 @@ def dumps(self, path):
self._storage.dumps(path / "storage")
self._sampler.dumps(path / "sampler")
self._writer.dumps(path / "writer")
+ if self._rng is not None:
+ rng_state = TensorDict(
+ rng_state=self._rng.get_state().clone(),
+ device=self._rng.device,
+ )
+ rng_state.memmap(path / "rng_state")
+
# fall back on state_dict for transforms
transform_sd = self._transform.state_dict()
if transform_sd:
@@ -487,6 +537,11 @@ def loads(self, path):
self._storage.loads(path / "storage")
self._sampler.loads(path / "sampler")
self._writer.loads(path / "writer")
+ if (path / "rng_state").exists():
+ rng_state = TensorDict.load_memmap(path / "rng_state")
+ rng = torch.Generator(device=rng_state.device)
+ rng.set_state(rng_state["rng_state"])
+ self.set_rng(rng)
# fall back on state_dict for transforms
if (path / "transform.t").exists():
self._transform.load_state_dict(torch.load(path / "transform.t"))
@@ -540,13 +595,13 @@ def add(self, data: Any) -> int:
return self._add(data)
def _add(self, data):
- with self._replay_lock:
+ with self._replay_lock, self._write_lock:
index = self._writer.add(data)
self._sampler.add(index)
return index
def _extend(self, data: Sequence) -> torch.Tensor:
- with self._replay_lock:
+ with self._replay_lock, self._write_lock:
if self.dim_extend > 0:
data = self._transpose(data)
index = self._writer.extend(data)
@@ -594,7 +649,7 @@ def update_priority(
if self.dim_extend > 0 and priority.ndim > 1:
priority = self._transpose(priority).flatten()
# priority = priority.flatten()
- with self._replay_lock:
+ with self._replay_lock, self._write_lock:
self._sampler.update_priority(index, priority, storage=self.storage)
@pin_memory_output
@@ -753,6 +808,12 @@ def __iter__(self):
def __getstate__(self) -> Dict[str, Any]:
state = self.__dict__.copy()
+ if self._rng is not None:
+ rng_state = TensorDict(
+ rng_state=self._rng.get_state().clone(),
+ device=self._rng.device,
+ )
+ state["_rng"] = rng_state
_replay_lock = state.pop("_replay_lock", None)
_futures_lock = state.pop("_futures_lock", None)
if _replay_lock is not None:
@@ -762,6 +823,13 @@ def __getstate__(self) -> Dict[str, Any]:
return state
def __setstate__(self, state: Dict[str, Any]):
+ rngstate = None
+ if "_rng" in state:
+ rngstate = state["_rng"]
+ if rngstate is not None:
+ rng = torch.Generator(device=rngstate.device)
+ rng.set_state(rngstate["rng_state"])
+
if "_replay_lock_placeholder" in state:
state.pop("_replay_lock_placeholder")
_replay_lock = threading.RLock()
@@ -771,6 +839,8 @@ def __setstate__(self, state: Dict[str, Any]):
_futures_lock = threading.RLock()
state["_futures_lock"] = _futures_lock
self.__dict__.update(state)
+ if rngstate is not None:
+ self.set_rng(rng)
@property
def sampler(self):
@@ -995,6 +1065,15 @@ class TensorDictReplayBuffer(ReplayBuffer):
>>> for d in data.unbind(1):
... rb.add(d)
>>> rb.extend(data)
+ generator (torch.Generator, optional): a generator to use for sampling.
+ Using a dedicated generator for the replay buffer allows fine-grained control
+ over seeding, for instance keeping the global seed different but the RB seed
+ identical across distributed jobs.
+ Defaults to ``None`` (global default generator).
+
+ .. warning:: As of now, the generator has no effect on the transforms.
+ shared (bool, optional): whether the buffer will be shared using multiprocessing or not.
+ Defaults to ``False``.
Examples:
>>> import torch
@@ -1207,7 +1286,7 @@ def sample(
if include_info is not None:
warnings.warn(
"include_info is going to be deprecated soon."
- "The default behaviour has changed to `include_info=True` "
+ "The default behavior has changed to `include_info=True` "
"to avoid bugs linked to wrongly preassigned values in the "
"output tensordict."
)
@@ -1327,6 +1406,15 @@ class TensorDictPrioritizedReplayBuffer(TensorDictReplayBuffer):
>>> for d in data.unbind(1):
... rb.add(d)
>>> rb.extend(data)
+ generator (torch.Generator, optional): a generator to use for sampling.
+ Using a dedicated generator for the replay buffer allows fine-grained control
+ over seeding, for instance keeping the global seed different but the RB seed
+ identical across distributed jobs.
+ Defaults to ``None`` (global default generator).
+
+ .. warning:: As of now, the generator has no effect on the transforms.
+ shared (bool, optional): whether the buffer will be shared using multiprocessing or not.
+ Defaults to ``False``.
Examples:
>>> import torch
@@ -1400,6 +1488,8 @@ def __init__(
reduction: str = "max",
batch_size: int | None = None,
dim_extend: int | None = None,
+ generator: torch.Generator | None = None,
+ shared: bool = False,
) -> None:
if storage is None:
storage = ListStorage(max_size=1_000)
@@ -1416,6 +1506,8 @@ def __init__(
transform=transform,
batch_size=batch_size,
dim_extend=dim_extend,
+ generator=generator,
+ shared=shared,
)
@@ -1454,12 +1546,18 @@ def update_tensordict_priority(self, data: TensorDictBase) -> None:
class InPlaceSampler:
"""A sampler to write tennsordicts in-place.
- To be used cautiously as this may lead to unexpected behaviour (i.e. tensordicts
+ .. warning:: This class is deprecated and will be removed in v0.7.
+
+ To be used cautiously as this may lead to unexpected behavior (i.e. tensordicts
overwritten during execution).
"""
def __init__(self, device: DEVICE_TYPING | None = None):
+ warnings.warn(
+ "InPlaceSampler has been deprecated and will be removed in v0.7.",
+ category=DeprecationWarning,
+ )
self.out = None
if device is None:
device = "cpu"
@@ -1555,6 +1653,15 @@ class ReplayBufferEnsemble(ReplayBuffer):
sampled according to the probabilities ``p``. Can also
be passed to `torchrl.data.replay_buffers.samplers.SamplerEnsemble`
if the buffer is built explicitly.
+ generator (torch.Generator, optional): a generator to use for sampling.
+ Using a dedicated generator for the replay buffer allows fine-grained control
+ over seeding, for instance keeping the global seed different but the RB seed
+ identical across distributed jobs.
+ Defaults to ``None`` (global default generator).
+
+ .. warning:: As of now, the generator has no effect on the transforms.
+ shared (bool, optional): whether the buffer will be shared using multiprocessing or not.
+ Defaults to ``False``.
Examples:
>>> from torchrl.envs import Compose, ToTensorImage, Resize, RenameTransform
@@ -1644,6 +1751,8 @@ def __init__(
p: Tensor = None,
sample_from_all: bool = False,
num_buffer_sampled: int | None = None,
+ generator: torch.Generator | None = None,
+ shared: bool = False,
**kwargs,
):
@@ -1680,6 +1789,8 @@ def __init__(
transform=transform,
batch_size=batch_size,
collate_fn=collate_fn,
+ generator=generator,
+ shared=shared,
**kwargs,
)
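A short usage sketch of the seeding API introduced above (`generator`/`shared` kwargs, `set_rng` and `write_count`), assuming a torchrl build that includes this patch:

```python
import torch
from torchrl.data import LazyTensorStorage, ReplayBuffer

rng = torch.Generator()
rng.manual_seed(0)

rb = ReplayBuffer(
    storage=LazyTensorStorage(100),
    batch_size=8,
    generator=rng,   # dedicated RNG: sampling no longer consumes the global seed
    shared=False,    # set True to guard writes with a multiprocessing lock
)
rb.extend(torch.arange(64))
print(rb.write_count)  # 64: total items written through add/extend
sample_a = rb.sample()

# Re-seeding the dedicated generator reproduces the sample stream without
# touching torch's global RNG state.
rng.manual_seed(0)
sample_b = rb.sample()
assert (sample_a == sample_b).all()
```

The same RNG is persisted through `state_dict`/`dumps` and scrubbed from pickled state, as the hunks above show, so checkpoint-and-resume keeps the sampling stream reproducible.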
diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py
index 582ac88f52d..45fede16cf5 100644
--- a/torchrl/data/replay_buffers/samplers.py
+++ b/torchrl/data/replay_buffers/samplers.py
@@ -22,7 +22,7 @@
from torchrl._extension import EXTENSION_WARNING
-from torchrl._utils import _replace_last, implement_for, logger
+from torchrl._utils import _replace_last, logger
from torchrl.data.replay_buffers.storages import Storage, StorageEnsemble, TensorStorage
from torchrl.data.replay_buffers.utils import _is_int, unravel_index
@@ -46,6 +46,9 @@ class Sampler(ABC):
# need to keep track of the number of remaining batches
_remaining_batches = int(torch.iinfo(torch.int64).max)
+ # The RNG is set by the replay buffer
+ _rng: torch.Generator | None = None
+
@abstractmethod
def sample(self, storage: Storage, batch_size: int) -> Tuple[Any, dict]:
...
@@ -64,7 +67,7 @@ def update_priority(
storage: Storage | None = None,
) -> dict | None:
warnings.warn(
- f"Calling update_priority() on a sampler {type(self).__name__} that is not prioritized. Make sure this is the indented behaviour."
+ f"Calling update_priority() on a sampler {type(self).__name__} that is not prioritized. Make sure this is the indented behavior."
)
return
@@ -105,6 +108,11 @@ def loads(self, path):
def __repr__(self):
return f"{self.__class__.__name__}()"
+ def __getstate__(self):
+ state = copy(self.__dict__)
+ state["_rng"] = None
+ return state
+
class RandomSampler(Sampler):
"""A uniformly random sampler for composable replay buffers.
@@ -192,7 +200,9 @@ def _get_sample_list(self, storage: Storage, len_storage: int, batch_size: int):
device = storage.device if hasattr(storage, "device") else None
if self.shuffle:
- _sample_list = torch.randperm(len_storage, device=device)
+ _sample_list = torch.randperm(
+ len_storage, device=device, generator=self._rng
+ )
else:
_sample_list = torch.arange(len_storage, device=device)
self._sample_list = _sample_list
@@ -385,13 +395,28 @@ def __repr__(self):
def max_size(self):
return self._max_capacity
+ @property
+ def alpha(self):
+ return self._alpha
+
+ @alpha.setter
+ def alpha(self, value):
+ self._alpha = value
+
+ @property
+ def beta(self):
+ return self._beta
+
+ @beta.setter
+ def beta(self, value):
+ self._beta = value
+
def __getstate__(self):
if get_spawning_popen() is not None:
raise RuntimeError(
f"Samplers of type {type(self)} cannot be shared between processes."
)
- state = copy(self.__dict__)
- return state
+ return super().__getstate__()
def _init(self):
if self.dtype in (torch.float, torch.FloatType, torch.float32):
@@ -473,7 +498,11 @@ def sample(self, storage: Storage, batch_size: int) -> torch.Tensor:
raise RuntimeError("non-positive p_min")
# For some undefined reason, only np.random works here.
# All PT attempts fail, even when subsequently transformed into numpy
- mass = np.random.uniform(0.0, p_sum, size=batch_size)
+ if self._rng is None:
+ mass = np.random.uniform(0.0, p_sum, size=batch_size)
+ else:
+ mass = torch.rand(batch_size, generator=self._rng) * p_sum
+
# mass = torch.zeros(batch_size, dtype=torch.double).uniform_(0.0, p_sum)
# mass = torch.rand(batch_size).mul_(p_sum)
index = self._sum_tree.scan_lower_bound(mass)
@@ -929,7 +958,7 @@ def __getstate__(self):
f"one process will NOT erase the cache on another process's sampler, "
f"which will cause synchronization issues."
)
- state = copy(self.__dict__)
+ state = super().__getstate__()
state["_cache"] = {}
return state
@@ -969,22 +998,20 @@ def _find_start_stop_traj(
# faster
end = trajectory[:-1] != trajectory[1:]
- end = torch.cat([end, trajectory[-1:] != trajectory[:1]], 0)
+ if not at_capacity:
+ end = torch.cat([end, torch.ones_like(end[:1])], 0)
+ else:
+ end = torch.cat([end, trajectory[-1:] != trajectory[:1]], 0)
length = trajectory.shape[0]
else:
- # TODO: check that storage is at capacity here, if not we need to assume that the last element of end is True
-
# We presume that not done at the end means that the traj spans across end and beginning of storage
length = end.shape[0]
+ if not at_capacity:
+ end = end.clone()
+ end[length - 1] = True
+ ndim = end.ndim
- if not at_capacity:
- end = torch.index_fill(
- end,
- index=torch.tensor(-1, device=end.device, dtype=torch.long),
- dim=0,
- value=1,
- )
- else:
+ if at_capacity:
# we must have at least one end by traj to individuate trajectories
# so if no end can be found we set it manually
if cursor is not None:
@@ -1006,7 +1033,6 @@ def _find_start_stop_traj(
mask = ~end.any(0, True)
mask = torch.cat([torch.zeros_like(end[:-1]), mask])
end = torch.masked_fill(mask, end, 1)
- ndim = end.ndim
if ndim == 0:
raise RuntimeError(
"Expected the end-of-trajectory signal to be at least 1-dimensional."
@@ -1063,9 +1089,28 @@ def _tensor_slices_from_startend(self, seq_length, start, storage_length):
# seq_length is a 1d tensor indicating the desired length of each sequence
if isinstance(seq_length, int):
- result = torch.cat(
- [self._start_to_end(_start, length=seq_length) for _start in start]
+ arange = torch.arange(seq_length, device=start.device, dtype=start.dtype)
+ ndims = start.shape[-1] - 1 if (start.ndim - 1) else 0
+ if ndims:
+ arange_reshaped = torch.empty(
+ arange.shape + torch.Size([ndims + 1]),
+ device=start.device,
+ dtype=start.dtype,
+ )
+ arange_reshaped[..., 0] = arange
+ arange_reshaped[..., 1:] = 0
+ else:
+ arange_reshaped = arange.unsqueeze(-1)
+ arange_expanded = arange_reshaped.expand(
+ torch.Size([start.shape[0]]) + arange_reshaped.shape
)
+ if start.shape != arange_expanded.shape:
+ n_missing_dims = arange_expanded.dim() - start.dim()
+ start_expanded = start[
+ (slice(None),) + (None,) * n_missing_dims
+ ].expand_as(arange_expanded)
+ result = (start_expanded + arange_expanded).flatten(0, 1)
+
else:
# when padding is needed
result = torch.cat(
@@ -1094,7 +1139,7 @@ def _get_stop_and_length(self, storage, fallback=True):
"Could not get a tensordict out of the storage, which is required for SliceSampler to compute the trajectories."
)
vals = self._find_start_stop_traj(
- trajectory=trajectory,
+ trajectory=trajectory.clone(),
at_capacity=storage._is_full,
cursor=getattr(storage, "_last_cursor", None),
)
@@ -1187,7 +1232,9 @@ def _sample_slices(
# start_idx and stop_idx are 2d tensors organized like a non-zero
def get_traj_idx(maxval):
- return torch.randint(maxval, (num_slices,), device=lengths.device)
+ return torch.randint(
+ maxval, (num_slices,), device=lengths.device, generator=self._rng
+ )
if (lengths < seq_length).any():
if self.strict_length:
@@ -1290,7 +1337,8 @@ def _get_index(
start_point = -span_right
relative_starts = (
- torch.rand(num_slices, device=lengths.device) * (end_point - start_point)
+ torch.rand(num_slices, device=lengths.device, generator=self._rng)
+ * (end_point - start_point)
).floor().to(start_idx.dtype) + start_point
if self.span[0]:
@@ -1800,34 +1848,38 @@ def __repr__(self):
def __getstate__(self):
state = SliceSampler.__getstate__(self)
state.update(PrioritizedSampler.__getstate__(self))
+ return state
def mark_update(
self, index: Union[int, torch.Tensor], *, storage: Storage | None = None
) -> None:
return PrioritizedSampler.mark_update(self, index, storage=storage)
- @implement_for("torch", "2.4")
def _padded_indices(self, shapes, arange) -> torch.Tensor:
# this complex mumbo jumbo creates a left padded tensor with valid indices on the right, e.g.
# tensor([[ 0, 1, 2, 3, 4],
# [-1, -1, 5, 6, 7],
# [-1, 8, 9, 10, 11]])
# where the -1 items on the left are padded values
- st, off = torch._nested_compute_contiguous_strides_offsets(shapes.flip(0))
- nt = torch._nested_view_from_buffer(
- arange.flip(0).contiguous(), shapes.flip(0), st, off
+ num_groups = shapes.shape[0]
+ max_group_len = shapes.max()
+ pad_lengths = max_group_len - shapes
+
+ # Get all the start and end indices within arange for each group
+ group_ends = shapes.cumsum(0)
+ group_starts = torch.empty_like(group_ends)
+ group_starts[0] = 0
+ group_starts[1:] = group_ends[:-1]
+ pad = torch.empty(
+ (num_groups, max_group_len), dtype=arange.dtype, device=arange.device
)
- pad = nt.to_padded_tensor(-1).flip(-1).flip(0)
- return pad
+ for pad_row, group_start, group_end, pad_len in zip(
+ pad, group_starts, group_ends, pad_lengths
+ ):
+ pad_row[:pad_len] = -1
+ pad_row[pad_len:] = arange[group_start:group_end]
- @implement_for("torch", None, "2.4")
- def _padded_indices(self, shapes, arange) -> torch.Tensor: # noqa: F811
- arange = arange.flip(0).split(shapes.flip(0).squeeze().unbind())
- return (
- torch.nn.utils.rnn.pad_sequence(arange, batch_first=True, padding_value=-1)
- .flip(-1)
- .flip(0)
- )
+ return pad
def _preceding_stop_idx(self, storage, lengths, seq_length, start_idx):
preceding_stop_idx = self._cache.get("preceding_stop_idx")
@@ -2033,6 +2085,7 @@ class SamplerEnsemble(Sampler):
def __init__(
self, *samplers, p=None, sample_from_all=False, num_buffer_sampled=None
):
+ self._rng_private = None
self._samplers = samplers
self.sample_from_all = sample_from_all
if sample_from_all and p is not None:
@@ -2042,6 +2095,16 @@ def __init__(
self.p = p
self.num_buffer_sampled = num_buffer_sampled
+ @property
+ def _rng(self):
+ return self._rng_private
+
+ @_rng.setter
+ def _rng(self, value):
+ self._rng_private = value
+ for sampler in self._samplers:
+ sampler._rng = value
+
@property
def p(self):
return self._p
@@ -2082,7 +2145,10 @@ def sample(self, storage, batch_size):
else:
if self.p is None:
buffer_ids = torch.randint(
- len(self._samplers), (self.num_buffer_sampled,)
+ len(self._samplers),
+ (self.num_buffer_sampled,),
+ generator=self._rng,
+ device=getattr(storage, "device", None),
)
else:
buffer_ids = torch.multinomial(self.p, self.num_buffer_sampled, True)
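The `_padded_indices` rewrite above drops the version-gated nested-tensor path in favor of a portable loop. A standalone sketch of the left-padded layout it produces (values taken from the comment in the hunk; the names here are illustrative, not torchrl API):

```python
import torch

shapes = torch.tensor([5, 3, 4])          # per-group lengths
arange = torch.arange(int(shapes.sum()))  # flat indices 0..11

max_len = int(shapes.max())
ends = shapes.cumsum(0)
starts = torch.cat([torch.zeros(1, dtype=ends.dtype), ends[:-1]])

# Left-pad each row with -1 so valid indices sit flush on the right.
pad = torch.full((len(shapes), max_len), -1, dtype=arange.dtype)
for row, start, end, length in zip(pad, starts, ends, shapes):
    row[max_len - length :] = arange[start:end]

print(pad)
# tensor([[ 0,  1,  2,  3,  4],
#         [-1, -1,  5,  6,  7],
#         [-1,  8,  9, 10, 11]])
```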
diff --git a/torchrl/data/replay_buffers/scheduler.py b/torchrl/data/replay_buffers/scheduler.py
new file mode 100644
index 00000000000..6829424c620
--- /dev/null
+++ b/torchrl/data/replay_buffers/scheduler.py
@@ -0,0 +1,267 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+
+from typing import Any, Callable, Dict
+
+import numpy as np
+
+import torch
+
+from torchrl.data.replay_buffers.replay_buffers import ReplayBuffer
+from torchrl.data.replay_buffers.samplers import Sampler
+
+
+class ParameterScheduler(ABC):
+ """Scheduler to adjust the value of a given parameter of a replay buffer's sampler.
+
+ The scheduler can, for example, be used to alter the alpha and beta values in the PrioritizedSampler.
+
+ Args:
+ obj (ReplayBuffer or Sampler): the replay buffer whose sampler should be adjusted (or the sampler itself).
+ param_name (str): the name of the attribute to adjust, e.g. `beta` to adjust the beta parameter
+ min_value (Union[int, float], optional): a lower bound for the parameter to be adjusted.
+ Defaults to `None`.
+ max_value (Union[int, float], optional): an upper bound for the parameter to be adjusted.
+ Defaults to `None`.
+
+ """
+
+ def __init__(
+ self,
+ obj: ReplayBuffer | Sampler,
+ param_name: str,
+ min_value: int | float | None = None,
+ max_value: int | float | None = None,
+ ):
+ if not isinstance(obj, (ReplayBuffer, Sampler)):
+ raise TypeError(
+ f"ParameterScheduler only supports Sampler class. Pass either `ReplayBuffer` or `Sampler` object. Got {type(obj)} instead."
+ )
+ self.sampler = obj.sampler if isinstance(obj, ReplayBuffer) else obj
+ self.param_name = param_name
+ self._min_val = min_value if min_value is not None else float("-inf")
+ self._max_val = max_value if max_value is not None else float("inf")
+ if not hasattr(self.sampler, self.param_name):
+ raise ValueError(
+ f"Provided class {type(obj).__name__} does not have an attribute {param_name}"
+ )
+ initial_val = getattr(self.sampler, self.param_name)
+ if isinstance(initial_val, torch.Tensor):
+ initial_val = initial_val.clone()
+ self.backend = torch
+ else:
+ self.backend = np
+ self.initial_val = initial_val
+ self._step_cnt = 0
+
+ def state_dict(self):
+ """Returns the state of the scheduler as a :class:`dict`.
+
+ It contains an entry for every variable in ``self.__dict__`` which
+ is not the sampler.
+ """
+ sd = dict(self.__dict__)
+ del sd["sampler"]
+ return sd
+
+ def load_state_dict(self, state_dict: Dict[str, Any]):
+ """Load the scheduler's state.
+
+ Args:
+ state_dict (dict): scheduler state. Should be an object returned
+ from a call to :meth:`state_dict`.
+ """
+ self.__dict__.update(state_dict)
+
+ def step(self):
+ self._step_cnt += 1
+ # Apply the step function
+ new_value = self._step()
+ # clip value to specified range
+ new_value_clipped = self.backend.clip(new_value, self._min_val, self._max_val)
+ # Set the new value of the parameter dynamically
+ setattr(self.sampler, self.param_name, new_value_clipped)
+
+ @abstractmethod
+ def _step(self):
+ ...
+
+
+class LambdaScheduler(ParameterScheduler):
+ """Sets a parameter to its initial value times a given function.
+
+ Similar to :class:`~torch.optim.lr_scheduler.LambdaLR`.
+
+ Args:
+ obj (ReplayBuffer or Sampler): the replay buffer whose sampler to adjust (or the sampler itself).
+ param_name (str): the name of the attribute to adjust, e.g. `beta` to adjust the
+ beta parameter.
+ lambda_fn (Callable[[int], float]): A function which computes a multiplicative factor given an integer
+ parameter ``step_count``.
+ min_value (Union[int, float], optional): a lower bound for the parameter to be adjusted.
+ Defaults to `None`.
+ max_value (Union[int, float], optional): an upper bound for the parameter to be adjusted.
+ Defaults to `None`.
+
+ """
+
+ def __init__(
+ self,
+ obj: ReplayBuffer | Sampler,
+ param_name: str,
+ lambda_fn: Callable[[int], float],
+ min_value: int | float | None = None,
+ max_value: int | float | None = None,
+ ):
+ super().__init__(obj, param_name, min_value, max_value)
+ self.lambda_fn = lambda_fn
+
+ def _step(self):
+ return self.initial_val * self.lambda_fn(self._step_cnt)
+
+
+class LinearScheduler(ParameterScheduler):
+ """A linear scheduler for gradually altering a parameter in an object over a given number of steps.
+
+ This scheduler linearly interpolates between the initial value of the parameter and a final target value.
+
+ Args:
+ obj (ReplayBuffer or Sampler): the replay buffer whose sampler to adjust (or the sampler itself).
+ param_name (str): the name of the attribute to adjust, e.g. `beta` to adjust the
+ beta parameter.
+ final_value (number): The final value that the parameter will reach after the
+ specified number of steps.
+ num_steps (int): The total number of steps over which the parameter
+ will be linearly altered.
+
+ Example:
+ >>> # xdoctest: +SKIP
+ >>> # Assuming sampler uses initial beta = 0.6
+ >>> # beta = 0.7 if step == 1
+ >>> # beta = 0.8 if step == 2
+ >>> # beta = 0.9 if step == 3
+ >>> # beta = 1.0 if step >= 4
+ >>> scheduler = LinearScheduler(sampler, param_name='beta', final_value=1.0, num_steps=4)
+ >>> for epoch in range(100):
+ >>> train(...)
+ >>> validate(...)
+ >>> scheduler.step()
+ """
+
+ def __init__(
+ self,
+ obj: ReplayBuffer | Sampler,
+ param_name: str,
+ final_value: int | float,
+ num_steps: int,
+ ):
+ super().__init__(obj, param_name)
+ if isinstance(self.initial_val, torch.Tensor):
+ # cast to same type as initial value
+ final_value = torch.tensor(final_value).to(self.initial_val)
+ self.final_val = final_value
+ self.num_steps = num_steps
+ self._delta = (self.final_val - self.initial_val) / self.num_steps
+
+ def _step(self):
+ # Nit: we should use torch.where instead of if/else here to make the scheduler compatible with compile
+ # without graph breaks
+ if self._step_cnt < self.num_steps:
+ return self.initial_val + (self._delta * self._step_cnt)
+ else:
+ return self.final_val
+
+
+class StepScheduler(ParameterScheduler):
+ """A step scheduler that alters a parameter after every n steps using either multiplicative or additive changes.
+
+ The scheduler can apply:
+ 1. Multiplicative changes: `new_val = curr_val * gamma`
+ 2. Additive changes: `new_val = curr_val + gamma`
+
+ Args:
+ obj (ReplayBuffer or Sampler): the replay buffer whose sampler to adjust (or the sampler itself).
+ param_name (str): the name of the attribute to adjust, e.g. `beta` to adjust the
+ beta parameter.
+ gamma (int or float, optional): The value by which to adjust the parameter,
+ either in a multiplicative or additive way.
+ n_steps (int, optional): The number of steps after which the parameter should be altered.
+ Defaults to 1.
+ mode (str, optional): The mode of scheduling. Can be either `'multiplicative'` or `'additive'`.
+ Defaults to `'multiplicative'`.
+ min_value (int or float, optional): a lower bound for the parameter to be adjusted.
+ Defaults to `None`.
+ max_value (int or float, optional): an upper bound for the parameter to be adjusted.
+ Defaults to `None`.
+
+ Example:
+ >>> # xdoctest: +SKIP
+ >>> # Assuming sampler uses initial beta = 0.6
+ >>> # beta = 0.6 if 0 <= step < 10
+ >>> # beta = 0.7 if 10 <= step < 20
+ >>> # beta = 0.8 if 20 <= step < 30
+ >>> # beta = 0.9 if 30 <= step < 40
+ >>> # beta = 1.0 if 40 <= step
+ >>> scheduler = StepScheduler(sampler, param_name='beta', gamma=0.1, mode='additive', max_value=1.0)
+ >>> for epoch in range(100):
+ >>> train(...)
+ >>> validate(...)
+ >>> scheduler.step()
+ """
+
+ def __init__(
+ self,
+ obj: ReplayBuffer | Sampler,
+ param_name: str,
+ gamma: int | float = 0.9,
+ n_steps: int = 1,
+ mode: str = "multiplicative",
+ min_value: int | float | None = None,
+ max_value: int | float | None = None,
+ ):
+
+ super().__init__(obj, param_name, min_value, max_value)
+ self.gamma = gamma
+ self.n_steps = n_steps
+ self.mode = mode
+ if mode == "additive":
+ operator = self.backend.add
+ elif mode == "multiplicative":
+ operator = self.backend.multiply
+ else:
+ raise ValueError(
+ f"Invalid mode: {mode}. Choose 'multiplicative' or 'additive'."
+ )
+ self.operator = operator
+
+ def _step(self):
+ """Applies the scheduling logic to alter the parameter value every `n_steps`."""
+ # Check if the current step count is a multiple of n_steps
+ current_val = getattr(self.sampler, self.param_name)
+ # Nit: we should use torch.where instead of if/else here to make the scheduler compatible with compile
+ # without graph breaks
+ if self._step_cnt % self.n_steps == 0:
+ return self.operator(current_val, self.gamma)
+ else:
+ return current_val
+
+
+class SchedulerList:
+ """Simple container abstracting a list of schedulers."""
+
+ def __init__(self, schedulers: list[ParameterScheduler]) -> None:
+ if isinstance(schedulers, ParameterScheduler):
+ schedulers = [schedulers]
+ self.schedulers = schedulers
+
+ def append(self, scheduler: ParameterScheduler):
+ self.schedulers.append(scheduler)
+
+ def step(self):
+ for scheduler in self.schedulers:
+ scheduler.step()
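A hedged end-to-end sketch of the new schedulers, assuming a `PrioritizedSampler` whose `alpha`/`beta` are now exposed as settable properties (see the samplers.py hunk above) and that the new module lives at `torchrl.data.replay_buffers.scheduler` as this file's path indicates:

```python
import torch
from torchrl.data import LazyTensorStorage, ReplayBuffer
from torchrl.data.replay_buffers.samplers import PrioritizedSampler
from torchrl.data.replay_buffers.scheduler import (
    LinearScheduler,
    SchedulerList,
    StepScheduler,
)

rb = ReplayBuffer(
    storage=LazyTensorStorage(1000),
    sampler=PrioritizedSampler(max_capacity=1000, alpha=0.6, beta=0.4),
    batch_size=32,
)
schedulers = SchedulerList(
    [
        # anneal beta linearly from 0.4 to 1.0 over 100 steps
        LinearScheduler(rb, param_name="beta", final_value=1.0, num_steps=100),
        # decay alpha multiplicatively every 10 steps, never below 0.1
        StepScheduler(rb, param_name="alpha", gamma=0.9, n_steps=10, min_value=0.1),
    ]
)
rb.extend(torch.randn(100, 4))
for _ in range(100):
    batch = rb.sample()
    schedulers.step()
print(rb.sampler.alpha, rb.sampler.beta)
```

Mirroring `torch.optim` LR schedulers, `step()` is called once per training iteration and clips the new value to the `[min_value, max_value]` range before writing it back onto the sampler.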
diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py
index 58b1729296d..a36c59b66d9 100644
--- a/torchrl/data/replay_buffers/storages.py
+++ b/torchrl/data/replay_buffers/storages.py
@@ -5,6 +5,8 @@
from __future__ import annotations
import abc
+
+import logging
import os
import textwrap
import warnings
@@ -22,6 +24,7 @@
TensorDict,
TensorDictBase,
)
+from tensordict.base import _NESTED_TENSORS_AS_LISTS
from tensordict.memmap import MemoryMappedTensor
from torch import multiprocessing as mp
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
@@ -54,6 +57,7 @@ class Storage:
ndim = 1
max_size: int
_default_checkpointer: StorageCheckpointerBase = StorageCheckpointerBase
+ _rng: torch.Generator | None = None
def __init__(
self, max_size: int, checkpointer: StorageCheckpointerBase | None = None
@@ -142,7 +146,13 @@ def _empty(self):
def _rand_given_ndim(self, batch_size):
# a method to return random indices given the storage ndim
if self.ndim == 1:
- return torch.randint(0, len(self), (batch_size,))
+ return torch.randint(
+ 0,
+ len(self),
+ (batch_size,),
+ generator=self._rng,
+ device=getattr(self, "device", None),
+ )
raise RuntimeError(
f"Random number generation is not implemented for storage of type {type(self)} with ndim {self.ndim}. "
f"Please report this exception as well as the use case (incl. buffer construction) on github."
@@ -185,6 +195,11 @@ def load(self, *args, **kwargs):
"""Alias for :meth:`~.loads`."""
return self.loads(*args, **kwargs)
+ def __getstate__(self):
+ state = copy(self.__dict__)
+ state["_rng"] = None
+ return state
+
class ListStorage(Storage):
"""A storage stored in a list.
@@ -299,7 +314,7 @@ def __getstate__(self):
raise RuntimeError(
f"Cannot share a storage of type {type(self)} between processes."
)
- state = copy(self.__dict__)
+ state = super().__getstate__()
return state
def __repr__(self):
@@ -497,7 +512,10 @@ def _rand_given_ndim(self, batch_size):
if self.ndim == 1:
return super()._rand_given_ndim(batch_size)
shape = self.shape
- return tuple(torch.randint(_dim, (batch_size,)) for _dim in shape)
+ return tuple(
+ torch.randint(_dim, (batch_size,), generator=self._rng, device=self.device)
+ for _dim in shape
+ )
def flatten(self):
if self.ndim == 1:
@@ -522,7 +540,7 @@ def flatten(self):
)
def __getstate__(self):
- state = copy(self.__dict__)
+ state = super().__getstate__()
if get_spawning_popen() is None:
length = self._len
del state["_len_value"]
@@ -539,15 +557,24 @@ def __getstate__(self):
# check that the content is shared, otherwise tell the user we can't help
storage = self._storage
STORAGE_ERR = "The storage must be place in shared memory or memmapped before being shared between processes."
+
+ # If the content is on cpu, it will be placed in shared memory.
+ # If it's on cuda, it's already shared.
+ # If it's memmapped, there is no issue either.
+ # Only if the device is neither "cpu" nor "cuda" may we have a problem.
+ def assert_is_sharable(tensor):
+ if tensor.device is None or tensor.device.type in (
+ "cuda",
+ "cpu",
+ "meta",
+ ):
+ return
+ raise RuntimeError(STORAGE_ERR)
+
if is_tensor_collection(storage):
- if not storage.is_memmap() and not storage.is_shared():
- raise RuntimeError(STORAGE_ERR)
+ storage.apply(assert_is_sharable, filter_empty=True)
else:
- if (
- not isinstance(storage, MemoryMappedTensor)
- and not storage.is_shared()
- ):
- raise RuntimeError(STORAGE_ERR)
+ tree_map(assert_is_sharable, storage)
return state
@@ -722,7 +749,7 @@ def set( # noqa: F811
"A cursor of length superior to the storage capacity was provided. "
"To accommodate for this, the cursor will be truncated to its last "
"element such that its length matched the length of the storage. "
- "This may **not** be the optimal behaviour for your application! "
+ "This may **not** be the optimal behavior for your application! "
"Make sure that the storage capacity is big enough to support the "
"batch size provided."
)
@@ -882,9 +909,7 @@ def max_size_along_dim0(data_shape):
if is_tensor_collection(data):
out = data.to(self.device)
- out = out.expand(max_size_along_dim0(data.shape))
- out = out.clone()
- out = out.zero_()
+ out = torch.empty_like(out.expand(max_size_along_dim0(data.shape)))
else:
# if Tensor, we just create a MemoryMappedTensor of the desired shape, device and dtype
out = tree_map(
@@ -906,6 +931,8 @@ class LazyMemmapStorage(LazyTensorStorage):
Args:
max_size (int): size of the storage, i.e. maximum number of elements stored
in the buffer.
+
+ Keyword Args:
scratch_dir (str or path): directory where memmap-tensors will be written.
device (torch.device, optional): device where the sampled tensors will be
stored and sent. Default is :obj:`torch.device("cpu")`.
@@ -916,6 +943,9 @@ class LazyMemmapStorage(LazyTensorStorage):
measuring the storage size. For instance, a storage of shape ``[3, 4]``
has capacity ``3`` if ``ndim=1`` and ``12`` if ``ndim=2``.
Defaults to ``1``.
+ existsok (bool, optional): whether an error should be raised if any of the
+ tensors already exists on disk. Defaults to ``False``. If ``True``, the
+ tensor will be opened as is, not overwritten.
.. note:: When checkpointing a ``LazyMemmapStorage``, one can provide a path identical to where the storage is
already stored to avoid executing long copies of data that is already stored on disk.
@@ -992,10 +1022,12 @@ def __init__(
scratch_dir=None,
device: torch.device = "cpu",
ndim: int = 1,
+ existsok: bool = False,
):
super().__init__(max_size, ndim=ndim)
self.initialized = False
self.scratch_dir = None
+ self.existsok = existsok
if scratch_dir is not None:
self.scratch_dir = str(scratch_dir)
if self.scratch_dir[-1] != "/":
@@ -1091,17 +1123,23 @@ def max_size_along_dim0(data_shape):
if is_tensor_collection(data):
out = data.clone().to(self.device)
out = out.expand(max_size_along_dim0(data.shape))
- out = out.memmap_like(prefix=self.scratch_dir)
- for key, tensor in sorted(
- out.items(include_nested=True, leaves_only=True), key=str
- ):
- try:
- filesize = os.path.getsize(tensor.filename) / 1024 / 1024
- torchrl_logger.debug(
- f"\t{key}: {tensor.filename}, {filesize} Mb of storage (size: {tensor.shape})."
- )
- except (AttributeError, RuntimeError):
- pass
+ out = out.memmap_like(prefix=self.scratch_dir, existsok=self.existsok)
+ if torchrl_logger.isEnabledFor(logging.DEBUG):
+ for key, tensor in sorted(
+ out.items(
+ include_nested=True,
+ leaves_only=True,
+ is_leaf=_NESTED_TENSORS_AS_LISTS,
+ ),
+ key=str,
+ ):
+ try:
+ filesize = os.path.getsize(tensor.filename) / 1024 / 1024
+ torchrl_logger.debug(
+ f"\t{key}: {tensor.filename}, {filesize} Mb of storage (size: {tensor.shape})."
+ )
+ except (AttributeError, RuntimeError):
+ pass
else:
out = _init_pytree(self.scratch_dir, max_size_along_dim0, data)
self._storage = out
@@ -1142,6 +1180,7 @@ def __init__(
*storages: Storage,
transforms: List["Transform"] = None, # noqa: F821
):
+ self._rng_private = None
self._storages = storages
self._transforms = transforms
if transforms is not None and len(transforms) != len(storages):
@@ -1149,6 +1188,16 @@ def __init__(
"transforms must have the same length as the storages " "provided."
)
+ @property
+ def _rng(self):
+ return self._rng_private
+
+ @_rng.setter
+ def _rng(self, value):
+ self._rng_private = value
+ for storage in self._storages:
+ storage._rng = value
+
@property
def _attached_entities(self):
return set()
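A short sketch of the new `existsok` flag on `LazyMemmapStorage`, assuming the semantics documented above: with `existsok=True`, memmap files already present under `scratch_dir` are accepted instead of raising. The per-tensor file-size report is now gated behind DEBUG-level logging, so it costs nothing in normal runs.

```python
import tempfile

import torch
from torchrl.data import LazyMemmapStorage, ReplayBuffer

scratch = tempfile.mkdtemp()
rb = ReplayBuffer(
    # Reuse any memmap files already present in scratch instead of raising.
    storage=LazyMemmapStorage(100, scratch_dir=scratch, existsok=True),
    batch_size=8,
)
rb.extend(torch.randn(50, 3))
print(rb.sample().shape)  # torch.Size([8, 3])
```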
diff --git a/torchrl/data/replay_buffers/utils.py b/torchrl/data/replay_buffers/utils.py
index 3a4141fd218..15a90f1a8f5 100644
--- a/torchrl/data/replay_buffers/utils.py
+++ b/torchrl/data/replay_buffers/utils.py
@@ -802,7 +802,7 @@ def _path2str(path, default_name=None):
if result == default_name:
raise RuntimeError(
"A tensor had the same identifier as the default name used when the buffer contains "
- f"a single tensor (name={default_name}). This behaviour is not allowed. Please rename your "
+ f"a single tensor (name={default_name}). This behavior is not allowed. Please rename your "
f"tensor in the map/dict or set a new default name with the environment variable SINGLE_TENSOR_BUFFER_NAME."
)
return result
diff --git a/torchrl/data/replay_buffers/writers.py b/torchrl/data/replay_buffers/writers.py
index ea3b2b4a047..3a95c3975cc 100644
--- a/torchrl/data/replay_buffers/writers.py
+++ b/torchrl/data/replay_buffers/writers.py
@@ -38,6 +38,7 @@ class Writer(ABC):
"""A ReplayBuffer base Writer class."""
_storage: Storage
+ _rng: torch.Generator | None = None
def __init__(self) -> None:
self._storage = None
@@ -103,6 +104,11 @@ def _replicate_index(self, index):
def __repr__(self):
return f"{self.__class__.__name__}()"
+ def __getstate__(self):
+ state = copy(self.__dict__)
+ state["_rng"] = None
+ return state
+
class ImmutableDatasetWriter(Writer):
"""A blocking writer for immutable datasets."""
@@ -157,6 +163,7 @@ def add(self, data: Any) -> int | torch.Tensor:
self._cursor = (self._cursor + 1) % self._storage._max_size_along_dim0(
single_data=data
)
+ self._write_count += 1
# Replicate index requires the shape of the storage to be known
# Other than that, a "flat" (1d) index is ok to write the data
self._storage.set(_cursor, data)
@@ -185,6 +192,7 @@ def extend(self, data: Sequence) -> torch.Tensor:
)
# we need to update the cursor first to avoid race conditions between workers
self._cursor = (batch_size + cur_size) % max_size_along0
+ self._write_count += batch_size
# Replicate index requires the shape of the storage to be known
# Other than that, a "flat" (1d) index is ok to write the data
self._storage.set(index, data)
@@ -216,8 +224,22 @@ def _cursor(self, value):
_cursor_value = self._cursor_value = mp.Value("i", 0)
_cursor_value.value = value
+ @property
+ def _write_count(self):
+ _write_count = self.__dict__.get("_write_count_value", None)
+ if _write_count is None:
+ _write_count = self._write_count_value = mp.Value("i", 0)
+ return _write_count.value
+
+ @_write_count.setter
+ def _write_count(self, value):
+ _write_count = self.__dict__.get("_write_count_value", None)
+ if _write_count is None:
+ _write_count = self._write_count_value = mp.Value("i", 0)
+ _write_count.value = value
+
def __getstate__(self):
- state = copy(self.__dict__)
+ state = super().__getstate__()
if get_spawning_popen() is None:
cursor = self._cursor
del state["_cursor_value"]
@@ -243,6 +265,7 @@ def add(self, data: Any) -> int | torch.Tensor:
# we need to update the cursor first to avoid race conditions between workers
max_size_along_dim0 = self._storage._max_size_along_dim0(single_data=data)
self._cursor = (index + 1) % max_size_along_dim0
+ self._write_count += 1
if not is_tensorclass(data):
data.set(
"index",
@@ -269,6 +292,7 @@ def extend(self, data: Sequence) -> torch.Tensor:
)
# we need to update the cursor first to avoid race conditions between workers
self._cursor = (batch_size + cur_size) % max_size_along_dim0
+ self._write_count += batch_size
# storage must convert the data to the appropriate format if needed
if not is_tensorclass(data):
data.set(
@@ -454,6 +478,20 @@ def get_insert_index(self, data: Any) -> int:
return ret
+ @property
+ def _write_count(self):
+ _write_count = self.__dict__.get("_write_count_value", None)
+ if _write_count is None:
+ _write_count = self._write_count_value = mp.Value("i", 0)
+ return _write_count.value
+
+ @_write_count.setter
+ def _write_count(self, value):
+ _write_count = self.__dict__.get("_write_count_value", None)
+ if _write_count is None:
+ _write_count = self._write_count_value = mp.Value("i", 0)
+ _write_count.value = value
+
def add(self, data: Any) -> int | torch.Tensor:
"""Inserts a single element of data at an appropriate index, and returns that index.
@@ -463,6 +501,7 @@ def add(self, data: Any) -> int | torch.Tensor:
index = self.get_insert_index(data)
if index is not None:
data.set("index", index)
+ self._write_count += 1
# Replicate index requires the shape of the storage to be known
# Other than that, a "flat" (1d) index is ok to write the data
self._storage.set(index, data)
@@ -482,6 +521,7 @@ def extend(self, data: TensorDictBase) -> None:
for data_idx, sample in enumerate(data):
storage_idx = self.get_insert_index(sample)
if storage_idx is not None:
+ self._write_count += 1
data_to_replace[storage_idx] = data_idx
# -1 will be interpreted as invalid by prioritized buffers
@@ -511,9 +551,10 @@ def _empty(self) -> None:
def __getstate__(self):
if get_spawning_popen() is not None:
raise RuntimeError(
- f"Writers of type {type(self)} cannot be shared between processes."
+ f"Writers of type {type(self)} cannot be shared between processes. "
+ f"Please submit an issue at https://github.com/pytorch/rl if this feature is needed."
)
- state = copy(self.__dict__)
+ state = super().__getstate__()
return state
def dumps(self, path):
@@ -582,8 +623,19 @@ class WriterEnsemble(Writer):
"""
def __init__(self, *writers):
+ self._rng_private = None
self._writers = writers
+ @property
+ def _rng(self):
+ return self._rng_private
+
+ @_rng.setter
+ def _rng(self, value):
+ self._rng_private = value
+ for writer in self._writers:
+ writer._rng = value
+
def _empty(self):
raise NotImplementedError
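`WriterEnsemble` above uses the same RNG-forwarding idiom as `SamplerEnsemble` and `StorageEnsemble`: setting `_rng` on the container propagates the generator to every component. A standalone sketch of the pattern (illustrative names, not torchrl API):

```python
import torch

class _Component:
    _rng = None  # torch.Generator or None, set by the container

class _Ensemble:
    def __init__(self, *components):
        self._rng_private = None
        self._components = components

    @property
    def _rng(self):
        return self._rng_private

    @_rng.setter
    def _rng(self, value):
        # Store locally, then fan the generator out to every component.
        self._rng_private = value
        for component in self._components:
            component._rng = value

ensemble = _Ensemble(_Component(), _Component())
ensemble._rng = torch.Generator().manual_seed(42)
assert all(c._rng is ensemble._rng for c in ensemble._components)
```

Combined with the `__getstate__` overrides that null out `_rng` before pickling, this keeps a single generator authoritative per buffer while never trying to ship a `torch.Generator` across process boundaries.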
diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py
index 7c787b3ccfc..98a32de5715 100644
--- a/torchrl/data/tensor_specs.py
+++ b/torchrl/data/tensor_specs.py
@@ -6,6 +6,7 @@
from __future__ import annotations
import abc
+import enum
import math
import warnings
from collections.abc import Iterable
@@ -20,6 +21,7 @@
Generic,
List,
Optional,
+ overload,
Sequence,
Tuple,
TypeVar,
@@ -38,6 +40,7 @@
TensorDictBase,
unravel_key,
)
+from tensordict.base import NO_DEFAULT
from tensordict.utils import _getitem_batch_size, NestedKey
from torchrl._utils import _make_ordinal_device, get_binary_env_var
@@ -72,14 +75,27 @@
_DEFAULT_SHAPE = torch.Size((1,))
-DEVICE_ERR_MSG = "device of empty CompositeSpec is not defined."
+DEVICE_ERR_MSG = "device of empty Composite is not defined."
NOT_IMPLEMENTED_ERROR = NotImplementedError(
"method is not currently implemented."
" If you are interested in this feature please submit"
" an issue at https://github.com/pytorch/rl/issues"
)
-NO_DEFAULT = object()
+
+def _size(list_of_ints):
+ # ensures that np int64 elements don't slip through Size
+ # see https://github.com/pytorch/pytorch/issues/127194
+ return torch.Size([int(i) for i in list_of_ints])
+
+
+# Akin to TD's NO_DEFAULT but won't raise a KeyError when found in a TD or used as default
+class _NoDefault(enum.IntEnum):
+ ZERO = 0
+ ONE = 1
+
+
+NO_DEFAULT_RL = _NoDefault.ONE
def _default_dtype_and_device(
@@ -190,7 +206,7 @@ def _shape_indexing(
Shape of the resulting spec
Examples:
>>> idx = (2, ..., None)
- >>> DiscreteTensorSpec(2, shape=(3, 4))[idx].shape
+ >>> Categorical(2, shape=(3, 4))[idx].shape
torch.Size([4, 1])
>>> _shape_indexing([3, 4], idx)
torch.Size([4, 1])
@@ -350,7 +366,7 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> ContinuousBox:
def __repr__(self):
return f"{self.__class__.__name__}()"
- def clone(self) -> DiscreteBox:
+ def clone(self) -> CategoricalBox:
return deepcopy(self)
@@ -387,16 +403,6 @@ def high(self, value):
self.device = value.device
self._high = value.cpu()
- @low.setter
- def low(self, value):
- self.device = value.device
- self._low = value.cpu()
-
- @high.setter
- def high(self, value):
- self.device = value.device
- self._high = value.cpu()
-
def __post_init__(self):
self.low = self.low.clone()
self.high = self.high.clone()
@@ -450,19 +456,25 @@ def __eq__(self, other):
@dataclass(repr=False)
-class DiscreteBox(Box):
- """A box of discrete values."""
+class CategoricalBox(Box):
+ """A box of discrete, categorical values."""
n: int
register = invertible_dict()
- def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> DiscreteBox:
+ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CategoricalBox:
return deepcopy(self)
def __repr__(self):
return f"{self.__class__.__name__}(n={self.n})"
+class DiscreteBox(CategoricalBox):
+ """Deprecated version of :class:`CategoricalBox`."""
+
+ ...
+
+
@dataclass(repr=False)
class BoxList(Box):
"""A box of discrete values."""
@@ -485,7 +497,7 @@ def __len__(self):
@staticmethod
def from_nvec(nvec: torch.Tensor):
if nvec.ndim == 0:
- return DiscreteBox(nvec.item())
+ return CategoricalBox(nvec.item())
else:
return BoxList([BoxList.from_nvec(n) for n in nvec.unbind(-1)])
@@ -505,14 +517,30 @@ def __repr__(self):
@dataclass(repr=False)
class TensorSpec:
- """Parent class of the tensor meta-data containers for observation, actions and rewards.
+ """Parent class of the tensor meta-data containers.
+
+ TorchRL's TensorSpec are used to present what input/output is to be expected for a specific class,
+ or sometimes to simulate simple behaviors by generating random data within a defined space.
+
+ TensorSpecs are primarily used in environments to specify their input/output structure without needing to
+ execute the environment (or even start it). They can also be used to instantiate shared buffers to pass
+ data from worker to worker.
+
+ TensorSpecs are dataclasses that always share the following fields: `shape`, `space`, `dtype` and `device`.
+
+ As such, TensorSpecs share some common behavior with :class:`~torch.Tensor` and :class:`~tensordict.TensorDict`:
+ they can be reshaped, indexed, squeezed, unsqueezed, moved to another device etc.
Args:
- shape (torch.Size): size of the tensor
- space (Box): Box instance describing what kind of values can be
- expected
- device (torch.device): device of the tensor
- dtype (torch.dtype): dtype of the tensor
+ shape (torch.Size): size of the tensor. The shape includes the batch dimensions as well as the feature
+ dimension. A negative shape (``-1``) means that the dimension has a variable number of elements.
+ space (Box): Box instance describing what kind of values can be expected.
+ device (torch.device): device of the tensor.
+ dtype (torch.dtype): dtype of the tensor.
+
+ .. note:: A spec can be constructed from a :class:`~tensordict.TensorDict` using the :func:`~torchrl.envs.utils.make_composite_from_td`
+ function. This function makes a low-assumption educated guess on the specs that may correspond to the input
+ tensordict and can help to build specs automatically without an in-depth knowledge of the `TensorSpec` API.
"""
@@ -537,21 +565,35 @@ def decorator(func):
@property
def device(self) -> torch.device:
+ """The device of the spec.
+
+ Only :class:`Composite` specs can have a ``None`` device. All leaves must have a non-null device.
+ """
return self._device
@device.setter
def device(self, device: torch.device | None) -> None:
self._device = _make_ordinal_device(device)
- def clear_device_(self):
- """A no-op for all leaf specs (which must have a device)."""
+ def clear_device_(self) -> T:
+ """A no-op for all leaf specs (which must have a device).
+
+ For :class:`Composite` specs, this method will erase the device.
+ """
return self
def encode(
- self, val: Union[np.ndarray, torch.Tensor], *, ignore_device=False
- ) -> torch.Tensor:
+ self,
+ val: np.ndarray | torch.Tensor | TensorDictBase,
+ *,
+ ignore_device: bool = False,
+ ) -> torch.Tensor | TensorDictBase:
"""Encodes a value given the specified spec, and return the corresponding tensor.
+ This method is to be used in environments that return a value (e.g., a numpy array) that can be
+ easily mapped to the TorchRL required domain.
+ If the value is already a tensor, the spec will not alter its value and will return it as-is.
+
Args:
val (np.ndarray or torch.Tensor): value to be encoded as tensor.
@@ -604,11 +646,15 @@ def __ne__(self, other):
def __setattr__(self, key, value):
if key == "shape":
- value = torch.Size(value)
+ value = _size(value)
super().__setattr__(key, value)
- def to_numpy(self, val: torch.Tensor, safe: bool = None) -> np.ndarray:
- """Returns the np.ndarray correspondent of an input tensor.
+ def to_numpy(
+ self, val: torch.Tensor | TensorDictBase, safe: bool = None
+ ) -> np.ndarray | dict:
+ """Returns the ``np.ndarray`` correspondent of an input tensor.
+
+ This is intended to be the inverse operation of :meth:`.encode`.
Args:
val (torch.Tensor): tensor to be transformed into numpy.
@@ -617,7 +663,7 @@ def to_numpy(self, val: torch.Tensor, safe: bool = None) -> np.ndarray:
Defaults to the value of the ``CHECK_SPEC_ENCODE`` environment variable.
Returns:
- a np.ndarray
+ a np.ndarray.
"""
if safe is None:
@@ -627,19 +673,31 @@ def to_numpy(self, val: torch.Tensor, safe: bool = None) -> np.ndarray:
return val.detach().cpu().numpy()
@property
- def ndim(self):
+ def ndim(self) -> int:
+ """Number of dimensions of the spec shape.
+
+ Shortcut for ``len(spec.shape)``.
+
+ """
return self.ndimension()
- def ndimension(self):
+ def ndimension(self) -> int:
+ """Number of dimensions of the spec shape.
+
+ Shortcut for ``len(spec.shape)``.
+
+ """
return len(self.shape)
@property
- def _safe_shape(self):
+ def _safe_shape(self) -> torch.Size:
"""Returns a shape where all heterogeneous values are replaced by one (to be expandable)."""
- return torch.Size([int(v) if v >= 0 else 1 for v in self.shape])
+ return _size([int(v) if v >= 0 else 1 for v in self.shape])
@abc.abstractmethod
- def index(self, index: INDEX_TYPING, tensor_to_index: torch.Tensor) -> torch.Tensor:
+ def index(
+ self, index: INDEX_TYPING, tensor_to_index: torch.Tensor | TensorDictBase
+ ) -> torch.Tensor | TensorDictBase:
"""Indexes the input tensor.
Args:
@@ -652,20 +710,25 @@ def index(self, index: INDEX_TYPING, tensor_to_index: torch.Tensor) -> torch.Ten
"""
...
+ @overload
+ def expand(self, shape: torch.Size):
+ ...
+
@abc.abstractmethod
- def expand(self, *shape):
- """Returns a new Spec with the extended shape.
+ def expand(self, *shape: int) -> T:
+ """Returns a new Spec with the expanded shape.
Args:
- *shape (tuple or iterable of int): the new shape of the Spec. Must comply with the current shape:
+ *shape (tuple or iterable of int): the new shape of the Spec.
+ Must be broadcastable with the current shape:
its length must be at least as long as the current shape length,
- and its last values must be complient too; ie they can only differ
+ and its last values must be compliant too; i.e., they can only differ
from it if the current dimension is a singleton.
"""
...
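A short sketch of the broadcasting rule spelled out in this docstring (``Unbounded`` is used here only as a stand-in leaf spec):

    import torch
    from torchrl.data import Unbounded

    spec = Unbounded(shape=(1, 3), dtype=torch.float32)
    # Leading dims can be prepended; trailing dims may only change where the
    # current dim is a singleton (1 -> 2 below), so (1, 3) expands to (4, 2, 3).
    expanded = spec.expand(4, 2, 3)
    assert expanded.shape == torch.Size([4, 2, 3])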
- def squeeze(self, dim: int | None = None):
+ def squeeze(self, dim: int | None = None) -> T:
"""Returns a new Spec with all the dimensions of size ``1`` removed.
When ``dim`` is given, a squeeze operation is done only in that dimension.
@@ -679,21 +742,30 @@ def squeeze(self, dim: int | None = None):
return self
return self.__class__(shape=shape, device=self.device, dtype=self.dtype)
- def unsqueeze(self, dim: int):
+ def unsqueeze(self, dim: int) -> T:
+ """Returns a new Spec with one more singleton dimension (at the position indicated by ``dim``).
+
+ Args:
+ dim (int): the position at which to insert the singleton dimension.
+
+ """
shape = _unsqueezed_shape(self.shape, dim)
return self.__class__(shape=shape, device=self.device, dtype=self.dtype)
- def make_neg_dim(self, dim):
+ def make_neg_dim(self, dim: int) -> T:
+ """Converts a specific dimension to ``-1``."""
if dim < 0:
dim = self.ndim + dim
if dim < 0 or dim > self.ndim - 1:
raise ValueError(f"dim={dim} is out of bound for ndim={self.ndim}")
- self.shape = torch.Size(
- [s if i != dim else -1 for i, s in enumerate(self.shape)]
- )
+ self.shape = _size([s if i != dim else -1 for i, s in enumerate(self.shape)])
+
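As a sketch of the intent here (``-1`` marks a dynamic/heterogeneous dim, as in ``_safe_shape`` above; ``Unbounded`` is again just an illustrative leaf spec):

    import torch
    from torchrl.data import Unbounded

    spec = Unbounded(shape=(4, 3), dtype=torch.float32)
    spec.make_neg_dim(0)   # negative indices are normalized, e.g. make_neg_dim(-2)
    assert spec.shape == torch.Size([-1, 3])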
+ @overload
+ def reshape(self, shape: torch.Size) -> T:
+ ...
- def reshape(self, *shape):
- """Reshapes a tensorspec.
+ def reshape(self, *shape) -> T:
+ """Reshapes a ``TensorSpec``.
Check :func:`~torch.reshape` for more information on this method.
@@ -705,23 +777,23 @@ def reshape(self, *shape):
view = reshape
@abc.abstractmethod
- def _reshape(self, shape):
+ def _reshape(self, shape: torch.Size) -> T:
...
- def unflatten(self, dim, sizes):
- """Unflattens a tensorspec.
+ def unflatten(self, dim: int, sizes: Tuple[int, ...]) -> T:
+ """Unflattens a ``TensorSpec``.
Check :func:`~torch.unflatten` for more information on this method.
"""
return self._unflatten(dim, sizes)
- def _unflatten(self, dim, sizes):
+ def _unflatten(self, dim: int, sizes: Tuple[int, ...]) -> T:
shape = torch.zeros(self.shape, device="meta").unflatten(dim, sizes).shape
return self._reshape(shape)
- def flatten(self, start_dim, end_dim):
- """Flattens a tensorspec.
+ def flatten(self, start_dim: int, end_dim: int) -> T:
+ """Flattens a ``TensorSpec``.
Check :func:`~torch.flatten` for more information on this method.
@@ -733,31 +805,39 @@ def _flatten(self, start_dim, end_dim):
return self._reshape(shape)
@abc.abstractmethod
- def _project(self, val: torch.Tensor) -> torch.Tensor:
+ def _project(
+ self, val: torch.Tensor | TensorDictBase
+ ) -> torch.Tensor | TensorDictBase:
raise NotImplementedError(type(self))
@abc.abstractmethod
- def is_in(self, val: torch.Tensor) -> bool:
- """If the value :obj:`val` is in the box defined by the TensorSpec, returns True, otherwise False.
+ def is_in(self, val: torch.Tensor | TensorDictBase) -> bool:
+ """If the value ``val`` could have been generated by the ``TensorSpec``, returns ``True``, otherwise ``False``.
+
+ More precisely, the ``is_in`` method checks that the value ``val`` is within the limits defined by the ``space``
+ attribute (the box), and that the ``dtype``, ``device``, ``shape`` and possibly other metadata match those
+ of the spec. If any of these checks fails, the ``is_in`` method will return ``False``.
Args:
- val (torch.Tensor): value to be checked
+ val (torch.Tensor): value to be checked.
Returns:
- boolean indicating if values belongs to the TensorSpec box
+ boolean indicating if the value belongs to the TensorSpec box.
"""
...
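A small sketch of the checks enumerated above, under the assumption that ``Bounded.is_in`` verifies both the box and the metadata:

    import torch
    from torchrl.data import Bounded

    spec = Bounded(low=0.0, high=1.0, shape=(2,), dtype=torch.float32)
    assert spec.is_in(torch.tensor([0.2, 0.9]))
    assert not spec.is_in(torch.tensor([1.5, 0.0]))                   # outside the box
    assert not spec.is_in(torch.tensor([0, 1], dtype=torch.int64))    # dtype mismatch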
- def contains(self, item):
- """Returns whether a sample is contained within the space defined by the TensorSpec.
+ def contains(self, item: torch.Tensor | TensorDictBase) -> bool:
+ """If the value ``val`` could have been generated by the ``TensorSpec``, returns ``True``, otherwise ``False``.
See :meth:`~.is_in` for more information.
"""
return self.is_in(item)
- def project(self, val: torch.Tensor) -> torch.Tensor:
- """If the input tensor is not in the TensorSpec box, it maps it back to it given some heuristic.
+ def project(
+ self, val: torch.Tensor | TensorDictBase
+ ) -> torch.Tensor | TensorDictBase:
+ """If the input tensor is not in the TensorSpec box, it maps it back to it given some defined heuristic.
Args:
val (torch.Tensor): tensor to be mapped to the box.
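For illustration, a sketch with a ``Categorical`` leaf, whose heuristic clamps out-of-range indices:

    import torch
    from torchrl.data import Categorical

    spec = Categorical(3)                 # values in {0, 1, 2}
    projected = spec.project(torch.tensor(7))
    assert spec.is_in(projected)          # 7 is clamped back to 2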
@@ -785,10 +865,10 @@ def assert_is_in(self, value: torch.Tensor) -> None:
)
def type_check(self, value: torch.Tensor, key: NestedKey = None) -> None:
- """Checks the input value dtype against the TensorSpec dtype and raises an exception if they don't match.
+ """Checks the input value ``dtype`` against the ``TensorSpec`` ``dtype`` and raises an exception if they don't match.
Args:
- value (torch.Tensor): tensor whose dtype has to be checked
+ value (torch.Tensor): tensor whose dtype has to be checked.
key (str, optional): if the TensorSpec has keys, the value
dtype will be checked against the spec pointed by the
indicated key.
@@ -801,8 +881,11 @@ def type_check(self, value: torch.Tensor, key: NestedKey = None) -> None:
)
@abc.abstractmethod
- def rand(self, shape=None) -> torch.Tensor:
- """Returns a random tensor in the space defined by the spec. The sampling will be uniform unless the box is unbounded.
+ def rand(self, shape: torch.Size = None) -> torch.Tensor | TensorDictBase:
+ """Returns a random tensor in the space defined by the spec.
+
+ The sampling will be done uniformly over the space, unless the box is unbounded,
+ in which case values will be drawn from a normal distribution.
Args:
shape (torch.Size): shape of the random tensor
@@ -811,19 +894,22 @@ def rand(self, shape=None) -> torch.Tensor:
a random tensor sampled in the TensorSpec box.
"""
- raise NotImplementedError
+ ...
- @property
- def sample(self):
+ def sample(self, shape: torch.Size = None) -> torch.Tensor | TensorDictBase:
"""Returns a random tensor in the space defined by the spec.
See :meth:`~.rand` for details.
"""
- return self.rand
+ return self.rand(shape=shape)
- def zero(self, shape=None) -> torch.Tensor:
+ def zero(self, shape: torch.Size = None) -> torch.Tensor | TensorDictBase:
"""Returns a zero-filled tensor in the box.
+ .. note:: Even though there is no guarantee that ``0`` belongs to the spec domain,
+ this method will not raise an exception when this condition is violated.
+ The primary use case of ``zero`` is to generate empty data buffers, not meaningful data.
+
Args:
shape (torch.Size): shape of the zero-tensor
@@ -832,26 +918,59 @@ def zero(self, shape=None) -> torch.Tensor:
"""
if shape is None:
- shape = torch.Size([])
+ shape = _size([])
return torch.zeros(
(*shape, *self._safe_shape), dtype=self.dtype, device=self.device
)
+ def zeros(self, shape: torch.Size = None) -> torch.Tensor | TensorDictBase:
+ """Proxy to :meth:`~.zero`."""
+ return self.zero(shape=shape)
+
+ def one(self, shape: torch.Size = None) -> torch.Tensor | TensorDictBase:
+ """Returns a one-filled tensor in the box.
+
+ .. note:: Even though there is no guarantee that ``1`` belongs to the spec domain,
+ this method will not raise an exception when this condition is violated.
+ The primary use case of ``one`` is to generate empty data buffers, not meaningful data.
+
+ Args:
+ shape (torch.Size): shape of the one-tensor
+
+ Returns:
+ a one-filled tensor sampled in the TensorSpec box.
+
+ """
+ if self.dtype == torch.bool:
+ return ~self.zero(shape=shape)
+ return self.zero(shape) + 1
+
+ def ones(self, shape: torch.Size = None) -> torch.Tensor | TensorDictBase:
+ """Proxy to :meth:`~.one`."""
+ return self.one(shape=shape)
+
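A quick sketch of the buffer-oriented helpers added here, including the boolean special case (``Binary`` assumed importable from ``torchrl.data``):

    import torch
    from torchrl.data import Binary, Bounded

    spec = Bounded(low=-1.0, high=1.0, shape=(2,), dtype=torch.float32)
    assert (spec.zeros() == torch.zeros(2)).all()
    assert (spec.ones() == torch.ones(2)).all()
    # For torch.bool, one() is implemented as ~zero() rather than zero() + 1.
    bool_spec = Binary(n=2, dtype=torch.bool)
    assert bool_spec.one().all() and not bool_spec.zero().any()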
@abc.abstractmethod
def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> "TensorSpec":
- raise NotImplementedError
+ """Casts a TensorSpec to a device or a dtype.
+
+ Returns the same spec if no change is made.
+ """
+ ...
def cpu(self):
+ """Casts the TensorSpec to 'cpu' device."""
return self.to("cpu")
def cuda(self, device=None):
+ """Casts the TensorSpec to 'cuda' device."""
if device is None:
return self.to("cuda")
return self.to(f"cuda:{device}")
@abc.abstractmethod
def clone(self) -> "TensorSpec":
- raise NotImplementedError
+ """Creates a copy of the TensorSpec."""
+ ...
def __repr__(self):
shape_str = indent("shape=" + str(self.shape), " " * 4)
@@ -898,7 +1017,7 @@ def __init__(self, *specs: tuple[T, ...], dim: int) -> None:
self.dim = len(self.shape) + self.dim
def clear_device_(self):
- """Clears the device of the CompositeSpec."""
+ """Clears the device of the Composite."""
for spec in self._specs:
spec.clear_device_()
return self
@@ -997,7 +1116,7 @@ def clone(self) -> T:
def stack_dim(self):
return self.dim
- def zero(self, shape=None) -> TensorDictBase:
+ def zero(self, shape: torch.Size = None) -> TensorDictBase:
if shape is not None:
dim = self.dim + len(shape)
else:
@@ -1008,7 +1127,7 @@ def zero(self, shape=None) -> TensorDictBase:
)
return torch.nested.nested_tensor([spec.zero(shape) for spec in self._specs])
- def one(self, shape=None) -> TensorDictBase:
+ def one(self, shape: torch.Size = None) -> TensorDictBase:
if shape is not None:
dim = self.dim + len(shape)
else:
@@ -1019,7 +1138,7 @@ def one(self, shape=None) -> TensorDictBase:
)
return torch.nested.nested_tensor([spec.one(shape) for spec in self._specs])
- def rand(self, shape=None) -> TensorDictBase:
+ def rand(self, shape: torch.Size = None) -> TensorDictBase:
if shape is not None:
dim = self.dim + len(shape)
else:
@@ -1125,7 +1244,7 @@ def squeeze(self, dim: int = None):
)
-class LazyStackedTensorSpec(_LazyStackedMixin[TensorSpec], TensorSpec):
+class Stacked(_LazyStackedMixin[TensorSpec], TensorSpec):
"""A lazy representation of a stack of tensor specs.
Stacks tensor-specs together along one dimension.
@@ -1134,13 +1253,13 @@ class LazyStackedTensorSpec(_LazyStackedMixin[TensorSpec], TensorSpec):
Indexing is allowed but only along the stack dimension.
- This class is aimed to be used in multi-task and multi-agent settings, where
+ This class is intended for multi-task and multi-agent settings, where
heterogeneous specs may occur (same semantic but different shape).
"""
def __eq__(self, other):
- if not isinstance(other, LazyStackedTensorSpec):
+ if not isinstance(other, Stacked):
return False
if self.device != other.device:
raise RuntimeError((self, other))
@@ -1161,8 +1280,7 @@ def to_numpy(self, val: torch.Tensor, safe: bool = None) -> dict:
if safe:
if val.shape[self.dim] != len(self._specs):
raise ValueError(
- "Size of LazyStackedTensorSpec and val differ along the stacking "
- "dimension"
+ "Size of Stacked and val differ along the stacking " "dimension"
)
for spec, v in zip(self._specs, torch.unbind(val, dim=self.dim)):
spec.assert_is_in(v)
@@ -1174,7 +1292,7 @@ def __repr__(self):
dtype_str = "dtype=" + str(self.dtype)
domain_str = "domain=" + str(self._specs[0].domain)
sub_string = ", ".join([shape_str, device_str, dtype_str, domain_str])
- string = f"LazyStacked{self._specs[0].__class__.__name__}(\n {sub_string})"
+ string = f"Stacked{self._specs[0].__class__.__name__}(\n {sub_string})"
return string
@property
@@ -1204,7 +1322,7 @@ def shape(self):
if dim < 0:
dim = len(shape) + dim + 1
shape.insert(dim, len(self._specs))
- return torch.Size(shape)
+ return _size(shape)
@shape.setter
def shape(self, shape):
@@ -1216,7 +1334,7 @@ def shape(self, shape):
raise RuntimeError(
f"The shape attribute mismatches between the input {shape} and self.shape={self.shape}."
)
- shape_strip = torch.Size([s for i, s in enumerate(self.shape) if i != self.dim])
+ shape_strip = _size([s for i, s in enumerate(self.shape) if i != self.dim])
for spec in self._specs:
spec.shape = shape_strip
@@ -1295,7 +1413,7 @@ def encode(
@dataclass(repr=False)
-class OneHotDiscreteTensorSpec(TensorSpec):
+class OneHot(TensorSpec):
"""A unidimensional, one-hot discrete tensor spec.
By default, TorchRL assumes that categorical variables are encoded as
@@ -1316,10 +1434,10 @@ class OneHotDiscreteTensorSpec(TensorSpec):
Args:
n (int): number of possible outcomes.
shape (torch.Size, optional): total shape of the sampled tensors.
- If provided, the last dimension must match n.
+ If provided, the last dimension must match ``n``.
device (str, int or torch.device, optional): device of the tensors.
dtype (str or torch.dtype, optional): dtype of the tensors.
- user_register (bool): experimental feature. If True, every integer
+ use_register (bool): experimental feature. If ``True``, every integer
will be mapped onto a binary vector in the order in which they
appear. This feature is designed for environment with no
a-priori definition of the number of possible outcomes (e.g.
@@ -1329,16 +1447,29 @@ class OneHotDiscreteTensorSpec(TensorSpec):
mask (torch.Tensor or None): mask some of the possible outcomes when a
sample is taken. See :meth:`~.update_mask` for more information.
+ Examples:
+ >>> from torchrl.data.tensor_specs import OneHot
+ >>> spec = OneHot(5, shape=(2, 5))
+ >>> spec.rand()
+ tensor([[False, True, False, False, False],
+ [False, True, False, False, False]])
+ >>> mask = torch.tensor([
+ ... [False, False, False, False, True],
+ ... [False, False, False, False, True]
+ ... ])
+ >>> spec.update_mask(mask)
+ >>> spec.rand()
+ tensor([[False, False, False, False, True],
+ [False, False, False, False, True]])
+
"""
shape: torch.Size
- space: DiscreteBox
+ space: CategoricalBox
device: torch.device | None = None
dtype: torch.dtype = torch.float
domain: str = ""
- # SPEC_HANDLED_FUNCTIONS = {}
-
def __init__(
self,
n: int,
@@ -1350,11 +1481,11 @@ def __init__(
):
dtype, device = _default_dtype_and_device(dtype, device)
self.use_register = use_register
- space = DiscreteBox(n)
+ space = CategoricalBox(n)
if shape is None:
- shape = torch.Size((space.n,))
+ shape = _size((space.n,))
else:
- shape = torch.Size(shape)
+ shape = _size(shape)
if not len(shape) or shape[-1] != space.n:
raise ValueError(
f"The last value of the shape must match n for transform of type {self.__class__}. "
@@ -1378,12 +1509,12 @@ def update_mask(self, mask):
mask (torch.Tensor or None): boolean mask. If None, the mask is
disabled. Otherwise, the shape of the mask must be expandable to
the shape of the spec. ``False`` masks an outcome and ``True``
- leaves the outcome unmasked. If all of the possible outcomes are
+ leaves the outcome unmasked. If all the possible outcomes are
masked, then an error is raised when a sample is taken.
Examples:
>>> mask = torch.tensor([True, False, False])
- >>> ts = OneHotDiscreteTensorSpec(3, (2, 3,), dtype=torch.int64, mask=mask)
+ >>> ts = OneHot(3, (2, 3,), dtype=torch.int64, mask=mask)
>>> # All but one of the three possible outcomes are masked
>>> ts.rand()
tensor([[1, 0, 0],
@@ -1398,7 +1529,7 @@ def update_mask(self, mask):
raise ValueError("Only boolean masks are accepted.")
self.mask = mask
- def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec:
+ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> OneHot:
if dest is None:
return self
if isinstance(dest, torch.dtype):
@@ -1418,7 +1549,7 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec:
mask=self.mask.to(dest) if self.mask is not None else None,
)
- def clone(self) -> OneHotDiscreteTensorSpec:
+ def clone(self) -> OneHot:
return self.__class__(
n=self.space.n,
shape=self.shape,
@@ -1536,11 +1667,11 @@ def unbind(self, dim: int = 0):
for i in range(self.shape[dim])
)
- def rand(self, shape=None) -> torch.Tensor:
+ def rand(self, shape: torch.Size = None) -> torch.Tensor:
if shape is None:
shape = self.shape[:-1]
else:
- shape = torch.Size([*shape, *self.shape[:-1]])
+ shape = _size([*shape, *self.shape[:-1]])
mask = self.mask
if mask is None:
n = self.space.n
@@ -1561,7 +1692,7 @@ def rand(self, shape=None) -> torch.Tensor:
def encode(
self,
val: Union[np.ndarray, torch.Tensor],
- space: Optional[DiscreteBox] = None,
+ space: Optional[CategoricalBox] = None,
*,
ignore_device: bool = False,
) -> torch.Tensor:
@@ -1619,7 +1750,7 @@ def __getitem__(self, idx: SHAPE_INDEX_TYPING):
indexed_shape = _shape_indexing(self.shape[:-1], idx)
return self.__class__(
n=self.space.n,
- shape=torch.Size(indexed_shape + [self.shape[-1]]),
+ shape=_size(indexed_shape + [self.shape[-1]]),
device=self.device,
dtype=self.dtype,
use_register=self.use_register,
@@ -1689,6 +1820,16 @@ def to_categorical(self, val: torch.Tensor, safe: bool = None) -> torch.Tensor:
Returns:
The categorical tensor.
+
+ Examples:
+ >>> one_hot = OneHot(3, shape=(2, 3))
+ >>> one_hot_sample = one_hot.rand()
+ >>> one_hot_sample
+ tensor([[False, True, False],
+ [False, True, False]])
+ >>> categ_sample = one_hot.to_categorical(one_hot_sample)
+ >>> categ_sample
+ tensor([1, 1])
"""
if safe is None:
safe = _CHECK_SPEC_ENCODE
@@ -1696,25 +1837,103 @@ def to_categorical(self, val: torch.Tensor, safe: bool = None) -> torch.Tensor:
self.assert_is_in(val)
return val.long().argmax(-1)
- def to_categorical_spec(self) -> DiscreteTensorSpec:
- """Converts the spec to the equivalent categorical spec."""
- return DiscreteTensorSpec(
+ def to_categorical_spec(self) -> Categorical:
+ """Converts the spec to the equivalent categorical spec.
+
+ Examples:
+ >>> one_hot = OneHot(3, shape=(2, 3))
+ >>> one_hot.to_categorical_spec()
+ Categorical(
+ shape=torch.Size([2]),
+ space=CategoricalBox(n=3),
+ device=cpu,
+ dtype=torch.int64,
+ domain=discrete)
+
+ """
+ return Categorical(
self.space.n,
device=self.device,
shape=self.shape[:-1],
mask=self.mask,
)
+ def to_one_hot(self, val: torch.Tensor, safe: bool = None) -> torch.Tensor:
+ """No-op for OneHot."""
+ return val
+
+ def to_one_hot_spec(self) -> OneHot:
+ """No-op for OneHot."""
+ return self
+
+
+class _BoundedMeta(abc.ABCMeta):
+ def __call__(cls, *args, **kwargs):
+ instance = super().__call__(*args, **kwargs)
+ if instance.domain == "continuous":
+ instance.__class__ = BoundedContinuous
+ else:
+ instance.__class__ = BoundedDiscrete
+ return instance
+
@dataclass(repr=False)
-class BoundedTensorSpec(TensorSpec):
- """A bounded continuous tensor spec.
+class Bounded(TensorSpec, metaclass=_BoundedMeta):
+ """A bounded tensor spec.
+
+ ``Bounded`` specs will never appear as such and will always be subclassed as :class:`BoundedContinuous`
+ or :class:`BoundedDiscrete` depending on their dtype (floating-point dtypes will result in
+ :class:`BoundedContinuous` instances, all others in :class:`BoundedDiscrete` instances).
Args:
low (np.ndarray, torch.Tensor or number): lower bound of the box.
high (np.ndarray, torch.Tensor or number): upper bound of the box.
+ shape (torch.Size): the shape of the ``Bounded`` spec. The shape must be specified.
+ Inputs ``low``, ``high`` and ``shape`` must be broadcastable.
device (str, int or torch.device, optional): device of the tensors.
dtype (str or torch.dtype, optional): dtype of the tensors.
+ domain (str): `"continuous"` or `"discrete"`. Can be used to override the automatic type assignment.
+
+ Examples:
+ >>> spec = Bounded(low=-1, high=1, shape=(), dtype=torch.float)
+ >>> spec
+ BoundedContinuous(
+ shape=torch.Size([]),
+ space=ContinuousBox(
+ low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True),
+ high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)),
+ device=cpu,
+ dtype=torch.float32,
+ domain=continuous)
+ >>> spec = Bounded(low=-1, high=1, shape=(), dtype=torch.int)
+ >>> spec
+ BoundedDiscrete(
+ shape=torch.Size([]),
+ space=ContinuousBox(
+ low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, contiguous=True),
+ high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, contiguous=True)),
+ device=cpu,
+ dtype=torch.int32,
+ domain=discrete)
+ >>> spec.to(torch.float)
+ BoundedContinuous(
+ shape=torch.Size([]),
+ space=ContinuousBox(
+ low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True),
+ high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)),
+ device=cpu,
+ dtype=torch.float32,
+ domain=continuous)
+ >>> spec = Bounded(low=-1, high=1, shape=(), dtype=torch.int, domain="continuous")
+ >>> spec
+ BoundedContinuous(
+ shape=torch.Size([]),
+ space=ContinuousBox(
+ low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, contiguous=True),
+ high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, contiguous=True)),
+ device=cpu,
+ dtype=torch.int32,
+ domain=continuous)
"""
@@ -1748,13 +1967,18 @@ def __init__(
"Minimum is deprecated since v0.4.0, using low instead.",
category=DeprecationWarning,
)
- domain = kwargs.pop("domain", "continuous")
+ domain = kwargs.pop("domain", None)
if len(kwargs):
raise TypeError(f"Got unrecognised kwargs {tuple(kwargs.keys())}.")
dtype, device = _default_dtype_and_device(dtype, device)
if dtype is None:
dtype = torch.get_default_dtype()
+ if domain is None:
+ if dtype.is_floating_point:
+ domain = "continuous"
+ else:
+ domain = "discrete"
if not isinstance(low, torch.Tensor):
low = torch.tensor(low, dtype=dtype, device=device)
@@ -1769,7 +1993,7 @@ def __init__(
if dtype is not None and high.dtype is not dtype:
high = high.to(dtype)
err_msg = (
- "BoundedTensorSpec requires the shape to be explicitely (via "
+ "Bounded requires the shape to be explicitely (via "
"the shape argument) or implicitely defined (via either the "
"minimum or the maximum or both). If the maximum and/or the "
"minimum have a non-singleton shape, they must match the "
@@ -1777,9 +2001,9 @@ def __init__(
)
if shape is not None and not isinstance(shape, torch.Size):
if isinstance(shape, int):
- shape = torch.Size([shape])
+ shape = _size([shape])
else:
- shape = torch.Size(list(shape))
+ shape = _size(list(shape))
if shape is not None:
shape_corr = _remove_neg_shapes(shape)
else:
@@ -1812,9 +2036,9 @@ def __init__(
shape = low.shape
else:
if isinstance(shape_corr, float):
- shape_corr = torch.Size([shape_corr])
+ shape_corr = _size([shape_corr])
elif not isinstance(shape_corr, torch.Size):
- shape_corr = torch.Size(shape_corr)
+ shape_corr = _size(shape_corr)
shape_corr_err_msg = (
f"low and shape_corr mismatch, got {low.shape} and {shape_corr}"
)
@@ -1945,9 +2169,9 @@ def unbind(self, dim: int = 0):
for low, high in zip(low, high)
)
- def rand(self, shape=None) -> torch.Tensor:
+ def rand(self, shape: torch.Size = None) -> torch.Tensor:
if shape is None:
- shape = torch.Size([])
+ shape = _size([])
a, b = self.space
if self.dtype in (torch.float, torch.double, torch.half):
shape = [*shape, *self._safe_shape]
@@ -1971,9 +2195,7 @@ def rand(self, shape=None) -> torch.Tensor:
else:
mini = self.space.low
interval = maxi - mini
- r = torch.rand(
- torch.Size([*shape, *self._safe_shape]), device=interval.device
- )
+ r = torch.rand(_size([*shape, *self._safe_shape]), device=interval.device)
r = interval * r
r = self.space.low + r
r = r.to(self.dtype).to(self.device)
@@ -2028,7 +2250,7 @@ def is_in(self, val: torch.Tensor) -> bool:
return False
raise err
- def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec:
+ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> Bounded:
if isinstance(dest, torch.dtype):
dest_dtype = dest
dest_device = self.device
@@ -2039,15 +2261,16 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec:
dest_device = torch.device(dest)
if dest_device == self.device and dest_dtype == self.dtype:
return self
- return self.__class__(
- low=self.space.low.to(dest),
- high=self.space.high.to(dest),
+ self.space.device = dest_device
+ return Bounded(
+ low=self.space.low,
+ high=self.space.high,
shape=self.shape,
device=dest_device,
dtype=dest_dtype,
)
- def clone(self) -> BoundedTensorSpec:
+ def clone(self) -> Bounded:
return self.__class__(
low=self.space.low.clone(),
high=self.space.high.clone(),
@@ -2063,7 +2286,7 @@ def __getitem__(self, idx: SHAPE_INDEX_TYPING):
"Pending resolution of https://github.com/pytorch/pytorch/issues/100080."
)
- indexed_shape = torch.Size(_shape_indexing(self.shape, idx))
+ indexed_shape = _size(_shape_indexing(self.shape, idx))
# Expand is required as pytorch.tensor indexing
return self.__class__(
low=self.space.low[idx].clone().expand(indexed_shape),
@@ -2074,6 +2297,45 @@ def __getitem__(self, idx: SHAPE_INDEX_TYPING):
)
+class BoundedContinuous(Bounded, metaclass=_BoundedMeta):
+ """A specialized version of :class:`torchrl.data.Bounded` with continuous space."""
+
+ def __init__(
+ self,
+ low: Union[float, torch.Tensor, np.ndarray] = None,
+ high: Union[float, torch.Tensor, np.ndarray] = None,
+ shape: Optional[Union[torch.Size, int]] = None,
+ device: Optional[DEVICE_TYPING] = None,
+ dtype: Optional[Union[torch.dtype, str]] = None,
+ domain: str = "continuous",
+ ):
+ super().__init__(
+ low=low, high=high, shape=shape, device=device, dtype=dtype, domain=domain
+ )
+
+
+class BoundedDiscrete(Bounded, metaclass=_BoundedMeta):
+ """A specialized version of :class:`torchrl.data.Bounded` with discrete space."""
+
+ def __init__(
+ self,
+ low: Union[float, torch.Tensor, np.ndarray] = None,
+ high: Union[float, torch.Tensor, np.ndarray] = None,
+ shape: Optional[Union[torch.Size, int]] = None,
+ device: Optional[DEVICE_TYPING] = None,
+ dtype: Optional[Union[torch.dtype, str]] = None,
+ domain: str = "discrete",
+ ):
+ super().__init__(
+ low=low,
+ high=high,
+ shape=shape,
+ device=device,
+ dtype=dtype,
+ domain=domain,
+ )
+
+
def _is_nested_list(index, notuple=False):
if not notuple and isinstance(index, tuple):
for idx in index:
@@ -2088,8 +2350,14 @@ def _is_nested_list(index, notuple=False):
return False
-class NonTensorSpec(TensorSpec):
- """A spec for non-tensor data."""
+class NonTensor(TensorSpec):
+ """A spec for non-tensor data.
+
+ This spec has a shape, device and dtype like :class:`~tensordict.NonTensorData`.
+
+ :meth:`.rand` will return a :class:`~tensordict.NonTensorData` object with a ``None`` data value
+ (the same goes for :meth:`.zero` and :meth:`.one`).
+ """
def __init__(
self,
@@ -2099,7 +2367,7 @@ def __init__(
**kwargs,
):
if isinstance(shape, int):
- shape = torch.Size([shape])
+ shape = _size([shape])
_, device = _default_dtype_and_device(None, device)
domain = None
@@ -2107,7 +2375,7 @@ def __init__(
shape=shape, space=None, device=device, dtype=dtype, domain=domain, **kwargs
)
- def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> NonTensorSpec:
+ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> NonTensor:
if isinstance(dest, torch.dtype):
dest_dtype = dest
dest_device = self.device
@@ -2120,7 +2388,7 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> NonTensorSpec:
return self
return self.__class__(shape=self.shape, device=dest_device, dtype=None)
- def clone(self) -> NonTensorSpec:
+ def clone(self) -> NonTensor:
return self.__class__(shape=self.shape, device=self.device, dtype=self.dtype)
def rand(self, shape=None):
@@ -2158,7 +2426,7 @@ def is_in(self, val: torch.Tensor) -> bool:
def expand(self, *shape):
if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)):
shape = shape[0]
- shape = torch.Size(shape)
+ shape = _size(shape)
if not all(
(old == 1) or (old == new)
for old, new in zip(self.shape, shape[-len(self.shape) :])
@@ -2181,7 +2449,7 @@ def _unflatten(self, dim, sizes):
def __getitem__(self, idx: SHAPE_INDEX_TYPING):
"""Indexes the current TensorSpec based on the provided index."""
- indexed_shape = torch.Size(_shape_indexing(self.shape, idx))
+ indexed_shape = _size(_shape_indexing(self.shape, idx))
return self.__class__(shape=indexed_shape, device=self.device, dtype=self.dtype)
def unbind(self, dim: int = 0):
@@ -2203,17 +2471,76 @@ def unbind(self, dim: int = 0):
)
+class _UnboundedMeta(abc.ABCMeta):
+ def __call__(cls, *args, **kwargs):
+ instance = super().__call__(*args, **kwargs)
+ if instance.domain == "continuous":
+ instance.__class__ = UnboundedContinuous
+ else:
+ instance.__class__ = UnboundedDiscrete
+ return instance
+
+
@dataclass(repr=False)
-class UnboundedContinuousTensorSpec(TensorSpec):
- """An unbounded continuous tensor spec.
+class Unbounded(TensorSpec, metaclass=_UnboundedMeta):
+ """An unbounded tensor spec.
+
+ ``Unbounded`` specs will never appear as such and will always be subclassed as :class:`UnboundedContinuous`
+ or :class:`UnboundedDiscrete` depending on their dtype (floating-point dtypes will result in
+ :class:`UnboundedContinuous` instances, all others in :class:`UnboundedDiscrete` instances).
+
+ Although it has no meaningful upper or lower limit, this class still has a :attr:`Box` space that encodes
+ the maximum and minimum values that the dtype accepts.
Args:
+ shape (torch.Size): the shape of the ``Unbounded`` spec. The shape must be specified.
device (str, int or torch.device, optional): device of the tensors.
- dtype (str or torch.dtype, optional): dtype of the tensors
- (should be an floating point dtype such as float, double etc.)
- """
+ dtype (str or torch.dtype, optional): dtype of the tensors.
+ domain (str): `"continuous"` or `"discrete"`. Can be used to override the automatic type assignment.
- # SPEC_HANDLED_FUNCTIONS = {}
+ Examples:
+ >>> spec = Unbounded(shape=(), dtype=torch.float)
+ >>> spec
+ UnboundedContinuous(
+ shape=torch.Size([]),
+ space=ContinuousBox(
+ low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True),
+ high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)),
+ device=cpu,
+ dtype=torch.float32,
+ domain=continuous)
+ >>> spec = Unbounded(shape=(), dtype=torch.int)
+ >>> spec
+ UnboundedDiscrete(
+ shape=torch.Size([]),
+ space=ContinuousBox(
+ low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, contiguous=True),
+ high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, contiguous=True)),
+ device=cpu,
+ dtype=torch.int32,
+ domain=discrete)
+ >>> spec.to(torch.float)
+ UnboundedContinuous(
+ shape=torch.Size([]),
+ space=ContinuousBox(
+ low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True),
+ high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)),
+ device=cpu,
+ dtype=torch.float32,
+ domain=continuous)
+ >>> spec = Unbounded(shape=(), dtype=torch.int, domain="continuous")
+ >>> spec
+ UnboundedContinuous(
+ shape=torch.Size([]),
+ space=ContinuousBox(
+ low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, contiguous=True),
+ high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, contiguous=True)),
+ device=cpu,
+ dtype=torch.int32,
+ domain=continuous)
+
+ """
def __init__(
self,
@@ -2223,26 +2550,37 @@ def __init__(
**kwargs,
):
if isinstance(shape, int):
- shape = torch.Size([shape])
+ shape = _size([shape])
dtype, device = _default_dtype_and_device(dtype, device)
- box = (
- ContinuousBox(
- torch.as_tensor(-np.inf, device=device).expand(shape),
- torch.as_tensor(np.inf, device=device).expand(shape),
- )
- if shape == _DEFAULT_SHAPE
- else None
+ if dtype == torch.bool:
+ min_value = False
+ max_value = True
+ default_domain = "discrete"
+ else:
+ if dtype.is_floating_point:
+ min_value = torch.finfo(dtype).min
+ max_value = torch.finfo(dtype).max
+ default_domain = "continuous"
+ else:
+ min_value = torch.iinfo(dtype).min
+ max_value = torch.iinfo(dtype).max
+ default_domain = "discrete"
+ box = ContinuousBox(
+ torch.full(
+ _remove_neg_shapes(shape), min_value, device=device, dtype=dtype
+ ),
+ torch.full(
+ _remove_neg_shapes(shape), max_value, device=device, dtype=dtype
+ ),
)
- default_domain = "continuous" if dtype.is_floating_point else "discrete"
+
domain = kwargs.pop("domain", default_domain)
super().__init__(
shape=shape, space=box, device=device, dtype=dtype, domain=domain, **kwargs
)
- def to(
- self, dest: Union[torch.dtype, DEVICE_TYPING]
- ) -> UnboundedContinuousTensorSpec:
+ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> Unbounded:
if isinstance(dest, torch.dtype):
dest_dtype = dest
dest_device = self.device
@@ -2253,14 +2591,14 @@ def to(
dest_device = torch.device(dest)
if dest_device == self.device and dest_dtype == self.dtype:
return self
- return self.__class__(shape=self.shape, device=dest_device, dtype=dest_dtype)
+ return Unbounded(shape=self.shape, device=dest_device, dtype=dest_dtype)
- def clone(self) -> UnboundedContinuousTensorSpec:
+ def clone(self) -> Unbounded:
return self.__class__(shape=self.shape, device=self.device, dtype=self.dtype)
- def rand(self, shape=None) -> torch.Tensor:
+ def rand(self, shape: torch.Size = None) -> torch.Tensor:
if shape is None:
- shape = torch.Size([])
+ shape = _size([])
shape = [*shape, *self.shape]
if self.dtype.is_floating_point:
return torch.randn(shape, device=self.device, dtype=self.dtype)
@@ -2301,7 +2639,7 @@ def _unflatten(self, dim, sizes):
def __getitem__(self, idx: SHAPE_INDEX_TYPING):
"""Indexes the current TensorSpec based on the provided index."""
- indexed_shape = torch.Size(_shape_indexing(self.shape, idx))
+ indexed_shape = _size(_shape_indexing(self.shape, idx))
return self.__class__(shape=indexed_shape, device=self.device, dtype=self.dtype)
def unbind(self, dim: int = 0):
@@ -2324,21 +2662,12 @@ def unbind(self, dim: int = 0):
def __eq__(self, other):
# those specs are equivalent to a discrete spec
- if isinstance(other, UnboundedDiscreteTensorSpec):
- return (
- UnboundedDiscreteTensorSpec(
- shape=self.shape,
- device=self.device,
- dtype=self.dtype,
- )
- == other
- )
- if isinstance(other, BoundedTensorSpec):
+ if isinstance(other, Bounded):
minval, maxval = _minmax_dtype(self.dtype)
minval = torch.as_tensor(minval).to(self.device, self.dtype)
maxval = torch.as_tensor(maxval).to(self.device, self.dtype)
return (
- BoundedTensorSpec(
+ Bounded(
shape=self.shape,
high=maxval,
low=minval,
@@ -2348,185 +2677,43 @@ def __eq__(self, other):
)
== other
)
+ elif isinstance(other, Unbounded):
+ if self.dtype != other.dtype:
+ return False
+ if self.shape != other.shape:
+ return False
+ if self.device != other.device:
+ return False
+ return True
return super().__eq__(other)
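A sketch of the equivalence encoded above: an ``Unbounded`` spec compares equal to a ``Bounded`` spec whose limits are the dtype extrema.

    import torch
    from torchrl.data import Bounded, Unbounded

    info = torch.finfo(torch.float32)
    unbounded = Unbounded(shape=(), dtype=torch.float32)
    bounded = Bounded(low=info.min, high=info.max, shape=(), dtype=torch.float32)
    assert unbounded == bounded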
-@dataclass(repr=False)
-class UnboundedDiscreteTensorSpec(TensorSpec):
- """An unbounded discrete tensor spec.
+class UnboundedContinuous(Unbounded):
+ """A specialized version of :class:`torchrl.data.Unbounded` with continuous space."""
- Args:
- device (str, int or torch.device, optional): device of the tensors.
- dtype (str or torch.dtype, optional): dtype of the tensors
- (should be an integer dtype such as long, uint8 etc.)
- """
+ ...
- # SPEC_HANDLED_FUNCTIONS = {}
+
+class UnboundedDiscrete(Unbounded):
+ """A specialized version of :class:`torchrl.data.Unbounded` with discrete space."""
def __init__(
self,
shape: Union[torch.Size, int] = _DEFAULT_SHAPE,
device: Optional[DEVICE_TYPING] = None,
dtype: Optional[Union[str, torch.dtype]] = torch.int64,
+ **kwargs,
):
- if isinstance(shape, int):
- shape = torch.Size([shape])
-
- dtype, device = _default_dtype_and_device(dtype, device)
- if dtype == torch.bool:
- min_value = False
- max_value = True
- else:
- if dtype.is_floating_point:
- min_value = torch.finfo(dtype).min
- max_value = torch.finfo(dtype).max
- else:
- min_value = torch.iinfo(dtype).min
- max_value = torch.iinfo(dtype).max
- space = ContinuousBox(
- torch.full(_remove_neg_shapes(shape), min_value, device=device),
- torch.full(_remove_neg_shapes(shape), max_value, device=device),
- )
-
- super().__init__(
- shape=shape,
- space=space,
- device=device,
- dtype=dtype,
- domain="discrete",
- )
-
- def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec:
- if isinstance(dest, torch.dtype):
- dest_dtype = dest
- dest_device = self.device
- elif dest is None:
- return self
- else:
- dest_dtype = self.dtype
- dest_device = torch.device(dest)
- if dest_device == self.device and dest_dtype == self.dtype:
- return self
- return self.__class__(shape=self.shape, device=dest_device, dtype=dest_dtype)
-
- def clone(self) -> UnboundedDiscreteTensorSpec:
- return self.__class__(shape=self.shape, device=self.device, dtype=self.dtype)
-
- def rand(self, shape=None) -> torch.Tensor:
- if shape is None:
- shape = torch.Size([])
- interval = self.space.high - self.space.low
- r = torch.rand(torch.Size([*shape, *interval.shape]), device=interval.device)
- r = r * interval
- r = self.space.low + r
- r = r.to(self.dtype)
- return r.to(self.device)
-
- def is_in(self, val: torch.Tensor) -> bool:
- shape = torch.broadcast_shapes(self._safe_shape, val.shape)
- return val.shape == shape and val.dtype == self.dtype
-
- def expand(self, *shape):
- if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)):
- shape = shape[0]
- if any(s1 != s2 and s2 != 1 for s1, s2 in zip(shape[-self.ndim :], self.shape)):
- raise ValueError(
- f"The last {self.ndim} of the expanded shape {shape} must match the"
- f"shape of the {self.__class__.__name__} spec in expand()."
- )
- return self.__class__(shape=shape, device=self.device, dtype=self.dtype)
-
- def _reshape(self, shape):
- return self.__class__(shape=shape, device=self.device, dtype=self.dtype)
-
- def _unflatten(self, dim, sizes):
- shape = torch.zeros(self.shape, device="meta").unflatten(dim, sizes).shape
- return self.__class__(
- shape=shape,
- device=self.device,
- dtype=self.dtype,
- )
-
- def __getitem__(self, idx: SHAPE_INDEX_TYPING):
- """Indexes the current TensorSpec based on the provided index."""
- indexed_shape = torch.Size(_shape_indexing(self.shape, idx))
- return self.__class__(shape=indexed_shape, device=self.device, dtype=self.dtype)
-
- def unbind(self, dim: int = 0):
- orig_dim = dim
- if dim < 0:
- dim = len(self.shape) + dim
- if dim < 0:
- raise ValueError(
- f"Cannot unbind along dim {orig_dim} with shape {self.shape}."
- )
- shape = tuple(s for i, s in enumerate(self.shape) if i != dim)
- return tuple(
- self.__class__(
- shape=shape,
- device=self.device,
- dtype=self.dtype,
- )
- for i in range(self.shape[dim])
- )
-
- def __eq__(self, other):
- # those specs are equivalent to a discrete spec
- if isinstance(other, UnboundedContinuousTensorSpec):
- return (
- UnboundedContinuousTensorSpec(
- shape=self.shape,
- device=self.device,
- dtype=self.dtype,
- domain=self.domain,
- )
- == other
- )
- if isinstance(other, BoundedTensorSpec):
- return (
- BoundedTensorSpec(
- shape=self.shape,
- high=self.space.high,
- low=self.space.low,
- dtype=self.dtype,
- device=self.device,
- domain=self.domain,
- )
- == other
- )
- return super().__eq__(other)
-
- def __ne__(self, other):
- # those specs are equivalent to a discrete spec
- if isinstance(other, UnboundedContinuousTensorSpec):
- return (
- UnboundedContinuousTensorSpec(
- shape=self.shape,
- device=self.device,
- dtype=self.dtype,
- domain=self.domain,
- )
- != other
- )
- if isinstance(other, BoundedTensorSpec):
- return (
- BoundedTensorSpec(
- shape=self.shape,
- high=self.space.high,
- low=self.space.low,
- dtype=self.dtype,
- device=self.device,
- domain=self.domain,
- )
- != other
- )
- return super().__ne__(other)
+ super().__init__(shape=shape, device=device, dtype=dtype, **kwargs)
@dataclass(repr=False)
-class MultiOneHotDiscreteTensorSpec(OneHotDiscreteTensorSpec):
+class MultiOneHot(OneHot):
"""A concatenation of one-hot discrete tensor spec.
+ This class can be used when a single tensor must carry information about multiple one-hot encoded
+ values.
+
The last dimension of the shape (domain of the tensor elements) cannot be indexed.
Args:
@@ -2541,20 +2728,22 @@ class MultiOneHotDiscreteTensorSpec(OneHotDiscreteTensorSpec):
sample is taken. See :meth:`~.update_mask` for more information.
Examples:
- >>> ts = MultiOneHotDiscreteTensorSpec((3,2,3))
- >>> ts.is_in(torch.tensor([0,0,1,
- ... 0,1,
- ... 1,0,0]))
+ >>> ts = MultiOneHot((3,2,3))
+ >>> ts.rand()
+ tensor([ True, False, False, True, False, False, False, True])
+ >>> ts.is_in(torch.tensor([
+ ... 0, 0, 1,
+ ... 0, 1,
+ ... 1, 0, 0], dtype=torch.bool))
True
- >>> ts.is_in(torch.tensor([1,0,1,
- ... 0,1,
- ... 1,0,0])) # False
+ >>> ts.is_in(torch.tensor([
+ ... 1, 0, 1,
+ ... 0, 1,
+ ... 1, 0, 0], dtype=torch.bool))
False
"""
- # SPEC_HANDLED_FUNCTIONS = {}
-
def __init__(
self,
nvec: Sequence[int],
@@ -2567,17 +2756,17 @@ def __init__(
self.nvec = nvec
dtype, device = _default_dtype_and_device(dtype, device)
if shape is None:
- shape = torch.Size((sum(nvec),))
+ shape = _size((sum(nvec),))
else:
- shape = torch.Size(shape)
+ shape = _size(shape)
if shape[-1] != sum(nvec):
raise ValueError(
f"The last value of the shape must match sum(nvec) for transform of type {self.__class__}. "
f"Got sum(nvec)={sum(nvec)} and shape={shape}."
)
- space = BoxList([DiscreteBox(n) for n in nvec])
+ space = BoxList([CategoricalBox(n) for n in nvec])
self.use_register = use_register
- super(OneHotDiscreteTensorSpec, self).__init__(
+ super(OneHot, self).__init__(
shape,
space,
device,
@@ -2601,7 +2790,7 @@ def update_mask(self, mask):
Examples:
>>> mask = torch.tensor([True, False, False,
... True, True])
- >>> ts = MultiOneHotDiscreteTensorSpec((3, 2), (2, 5), dtype=torch.int64, mask=mask)
+ >>> ts = MultiOneHot((3, 2), (2, 5), dtype=torch.int64, mask=mask)
>>> # All but one of the three possible outcomes for the first
>>> # one-hot group are masked, but neither of the two possible
>>> # outcomes for the second one-hot group are masked.
@@ -2618,7 +2807,7 @@ def update_mask(self, mask):
raise ValueError("Only boolean masks are accepted.")
self.mask = mask
- def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec:
+ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> MultiOneHot:
if isinstance(dest, torch.dtype):
dest_dtype = dest
dest_device = self.device
@@ -2629,7 +2818,7 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec:
dest_device = torch.device(dest)
if dest_device == self.device and dest_dtype == self.dtype:
return self
- return self.__class__(
+ return MultiOneHot(
nvec=deepcopy(self.nvec),
shape=self.shape,
device=dest_device,
@@ -2637,7 +2826,7 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec:
mask=self.mask.to(dest) if self.mask is not None else None,
)
- def clone(self) -> MultiOneHotDiscreteTensorSpec:
+ def clone(self) -> MultiOneHot:
return self.__class__(
nvec=deepcopy(self.nvec),
shape=self.shape,
@@ -2670,7 +2859,7 @@ def rand(self, shape: Optional[torch.Size] = None) -> torch.Tensor:
if shape is None:
shape = self.shape[:-1]
else:
- shape = torch.Size([*shape, *self.shape[:-1]])
+ shape = _size([*shape, *self.shape[:-1]])
mask = self.mask
if mask is None:
@@ -2722,9 +2911,7 @@ def encode(
f"value {v} is greater than the allowed max {space.n}"
)
x.append(
- super(MultiOneHotDiscreteTensorSpec, self).encode(
- v, space, ignore_device=ignore_device
- )
+ super(MultiOneHot, self).encode(v, space, ignore_device=ignore_device)
)
return torch.cat(x, -1).reshape(self.shape)
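For illustration, a sketch of the grouped encoding path exercised above (one categorical index per group in, concatenated one-hot blocks out):

    import torch
    from torchrl.data import MultiOneHot

    spec = MultiOneHot((3, 2))
    encoded = spec.encode(torch.tensor([2, 0]))
    # group 1 (n=3) one-hot for index 2, then group 2 (n=2) one-hot for index 0
    assert (encoded == torch.tensor([False, False, True, True, False])).all()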
@@ -2776,7 +2963,7 @@ def _split_self(self):
n = space.n
shape = self.shape[:-1] + (n,)
result.append(
- OneHotDiscreteTensorSpec(
+ OneHot(
n=n,
shape=shape,
device=device,
@@ -2798,6 +2985,16 @@ def to_categorical(self, val: torch.Tensor, safe: bool = None) -> torch.Tensor:
Returns:
The categorical tensor.
+
+ Examples:
+ >>> mone_hot = MultiOneHot((2, 3, 4))
+ >>> onehot_sample = mone_hot.rand()
+ >>> onehot_sample
+ tensor([False, True, False, False, True, False, True, False, False])
+ >>> categ_sample = mone_hot.to_categorical(onehot_sample)
+ >>> categ_sample
+ tensor([1, 2, 1])
+
"""
if safe is None:
safe = _CHECK_SPEC_ENCODE
@@ -2806,15 +3003,36 @@ def to_categorical(self, val: torch.Tensor, safe: bool = None) -> torch.Tensor:
vals = self._split(val)
return torch.stack([val.long().argmax(-1) for val in vals], -1)
- def to_categorical_spec(self) -> MultiDiscreteTensorSpec:
- """Converts the spec to the equivalent categorical spec."""
- return MultiDiscreteTensorSpec(
+ def to_categorical_spec(self) -> MultiCategorical:
+ """Converts the spec to the equivalent categorical spec.
+
+ Examples:
+ >>> mone_hot = MultiOneHot((2, 3, 4))
+ >>> categ = mone_hot.to_categorical_spec()
+ >>> categ
+ MultiCategorical(
+ shape=torch.Size([3]),
+ space=BoxList(boxes=[CategoricalBox(n=2), CategoricalBox(n=3), CategoricalBox(n=4)]),
+ device=cpu,
+ dtype=torch.int64,
+ domain=discrete)
+
+ """
+ return MultiCategorical(
[_space.n for _space in self.space],
device=self.device,
shape=[*self.shape[:-1], len(self.space)],
mask=self.mask,
)
+ def to_one_hot(self, val: torch.Tensor, safe: bool = None) -> torch.Tensor:
+ """No-op for MultiOneHot."""
+ return val
+
+ def to_one_hot_spec(self) -> OneHot:
+ """No-op for MultiOneHot."""
+ return self
+
def expand(self, *shape):
nvecs = [space.n for space in self.space]
if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)):
@@ -2917,28 +3135,21 @@ def __getitem__(self, idx: SHAPE_INDEX_TYPING):
indexed_shape = _shape_indexing(self.shape[:-1], idx)
return self.__class__(
nvec=self.nvec,
- shape=torch.Size(indexed_shape + [self.shape[-1]]),
+ shape=_size(indexed_shape + [self.shape[-1]]),
device=self.device,
dtype=self.dtype,
)
-class DiscreteTensorSpec(TensorSpec):
+class Categorical(TensorSpec):
"""A discrete tensor spec.
- An alternative to OneHotTensorSpec for categorical variables in TorchRL. Instead of
- using multiplication, categorical variables perform indexing which can speed up
+ An alternative to :class:`OneHot` for categorical variables in TorchRL.
+ Categorical variables perform indexing instead of masking, which can speed up
computation and reduce memory cost for large categorical variables.
- The last dimension of the spec (length n of the binary vector) cannot be indexed
- Example:
- >>> batch, size = 3, 4
- >>> action_value = torch.arange(batch*size)
- >>> action_value = action_value.view(batch, size).to(torch.float)
- >>> action = torch.argmax(action_value, dim=-1).to(torch.long)
- >>> chosen_action_value = action_value[range(batch), action]
- >>> print(chosen_action_value)
- tensor([ 3., 7., 11.])
+ The spec will have the shape defined by the ``shape`` argument: if a singleton dimension is
+ desired for the training dimension, one should specify it explicitly.
Args:
n (int): number of possible outcomes.
@@ -2948,10 +3159,32 @@ class DiscreteTensorSpec(TensorSpec):
mask (torch.Tensor or None): mask some of the possible outcomes when a
sample is taken. See :meth:`~.update_mask` for more information.
+ Examples:
+ >>> categ = Categorical(3)
+ >>> categ
+ Categorical(
+ shape=torch.Size([]),
+ space=CategoricalBox(n=3),
+ device=cpu,
+ dtype=torch.int64,
+ domain=discrete)
+ >>> categ.rand()
+ tensor(2)
+ >>> categ = Categorical(3, shape=(1,))
+ >>> categ
+ Categorical(
+ shape=torch.Size([1]),
+ space=CategoricalBox(n=3),
+ device=cpu,
+ dtype=torch.int64,
+ domain=discrete)
+ >>> categ.rand()
+ tensor([1])
+
"""
shape: torch.Size
- space: DiscreteBox
+ space: CategoricalBox
device: torch.device | None = None
dtype: torch.dtype = torch.float
domain: str = ""
@@ -2967,9 +3200,9 @@ def __init__(
mask: torch.Tensor | None = None,
):
if shape is None:
- shape = torch.Size([])
+ shape = _size([])
dtype, device = _default_dtype_and_device(dtype, device)
- space = DiscreteBox(n)
+ space = CategoricalBox(n)
super().__init__(
shape=shape, space=space, device=device, dtype=dtype, domain="discrete"
)
@@ -2994,7 +3227,7 @@ def update_mask(self, mask):
Examples:
>>> mask = torch.tensor([True, False, True])
- >>> ts = DiscreteTensorSpec(3, (10,), dtype=torch.int64, mask=mask)
+ >>> ts = Categorical(3, (10,), dtype=torch.int64, mask=mask)
>>> # One of the three possible outcomes is masked
>>> ts.rand()
tensor([0, 2, 2, 0, 2, 0, 2, 2, 0, 2])
@@ -3008,14 +3241,14 @@ def update_mask(self, mask):
raise ValueError("Only boolean masks are accepted.")
self.mask = mask
- def rand(self, shape=None) -> torch.Tensor:
+ def rand(self, shape: torch.Size = None) -> torch.Tensor:
if shape is None:
- shape = torch.Size([])
+ shape = _size([])
if self.mask is None:
return torch.randint(
0,
self.space.n,
- torch.Size([*shape, *self.shape]),
+ _size([*shape, *self.shape]),
device=self.device,
dtype=self.dtype,
)
@@ -3035,7 +3268,7 @@ def _project(self, val: torch.Tensor) -> torch.Tensor:
if self.mask is None:
return val.clamp_(min=0, max=self.space.n - 1)
shape = self.mask.shape
- shape = torch.Size([*torch.broadcast_shapes(shape[:-1], val.shape), shape[-1]])
+ shape = _size([*torch.broadcast_shapes(shape[:-1], val.shape), shape[-1]])
mask_expand = self.mask.expand(shape)
gathered = mask_expand.gather(-1, val.unsqueeze(-1))
oob = ~gathered.all(-1)
@@ -3054,14 +3287,14 @@ def is_in(self, val: torch.Tensor) -> bool:
return False
return (0 <= val).all() and (val < self.space.n).all()
shape = self.mask.shape
- shape = torch.Size([*torch.broadcast_shapes(shape[:-1], val.shape), shape[-1]])
+ shape = _size([*torch.broadcast_shapes(shape[:-1], val.shape), shape[-1]])
mask_expand = self.mask.expand(shape)
gathered = mask_expand.gather(-1, val.unsqueeze(-1))
return gathered.all()
def __getitem__(self, idx: SHAPE_INDEX_TYPING):
"""Indexes the current TensorSpec based on the provided index."""
- indexed_shape = torch.Size(_shape_indexing(self.shape, idx))
+ indexed_shape = _size(_shape_indexing(self.shape, idx))
return self.__class__(
n=self.space.n,
shape=indexed_shape,
@@ -3106,6 +3339,15 @@ def to_one_hot(self, val: torch.Tensor, safe: bool = None) -> torch.Tensor:
Returns:
The one-hot encoded tensor.
+
+ Examples:
+ >>> categ = Categorical(3)
+ >>> categ_sample = categ.zero()
+ >>> categ_sample
+ tensor(0)
+ >>> onehot_sample = categ.to_one_hot(categ_sample)
+ >>> onehot_sample
+ tensor([ True, False, False])
"""
if safe is None:
safe = _CHECK_SPEC_ENCODE
@@ -3113,15 +3355,35 @@ def to_one_hot(self, val: torch.Tensor, safe: bool = None) -> torch.Tensor:
self.assert_is_in(val)
return torch.nn.functional.one_hot(val, self.space.n).bool()
- def to_one_hot_spec(self) -> OneHotDiscreteTensorSpec:
- """Converts the spec to the equivalent one-hot spec."""
+ def to_categorical(self, val: torch.Tensor, safe: bool = None) -> torch.Tensor:
+ """No-op for categorical."""
+ return val
+
+ def to_one_hot_spec(self) -> OneHot:
+ """Converts the spec to the equivalent one-hot spec.
+
+ Examples:
+ >>> categ = Categorical(3)
+ >>> categ.to_one_hot_spec()
+ OneHot(
+ shape=torch.Size([3]),
+ space=CategoricalBox(n=3),
+ device=cpu,
+ dtype=torch.bool,
+ domain=discrete)
+
+ """
shape = [*self.shape, self.space.n]
- return OneHotDiscreteTensorSpec(
+ return OneHot(
n=self.space.n,
shape=shape,
device=self.device,
)
+ def to_categorical_spec(self) -> Categorical:
+ """No-op for categorical."""
+ return self
+
def expand(self, *shape):
if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)):
shape = shape[0]
@@ -3199,7 +3461,7 @@ def unbind(self, dim: int = 0):
for i in range(self.shape[dim])
)
- def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec:
+ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> Categorical:
if isinstance(dest, torch.dtype):
dest_dtype = dest
dest_device = self.device
@@ -3214,7 +3476,7 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec:
n=self.space.n, shape=self.shape, device=dest_device, dtype=dest_dtype
)
- def clone(self) -> DiscreteTensorSpec:
+ def clone(self) -> Categorical:
return self.__class__(
n=self.space.n,
shape=self.shape,
@@ -3225,32 +3487,59 @@ def clone(self) -> DiscreteTensorSpec:
@dataclass(repr=False)
-class BinaryDiscreteTensorSpec(DiscreteTensorSpec):
+class Binary(Categorical):
"""A binary discrete tensor spec.
+ A binary tensor spec encodes tensors of arbitrary size where the values are either 0 or 1 (or ``True`` or ``False``
+ if the dtype is ``torch.bool``).
+
+ Unlike :class:`OneHot`, `Binary` can have more than one non-null element along the last dimension.
+
Args:
- n (int): length of the binary vector.
+ n (int): length of the binary vector. If provided along with ``shape``, ``shape[-1]`` must match ``n``.
+ If not provided, ``shape`` must be passed.
+
+ .. warning:: the ``n`` argument from ``Binary`` must not be confused with the ``n`` argument from :class:`Categorical`
+ or :class:`OneHot`, which denotes the maximum number of elements that can be sampled.
+ For clarity, use ``shape`` instead.
+
shape (torch.Size, optional): total shape of the sampled tensors.
- If provided, the last dimension must match n.
+ If provided, the last dimension must match ``n``.
device (str, int or torch.device, optional): device of the tensors.
- dtype (str or torch.dtype, optional): dtype of the tensors. Defaults to torch.long.
+ dtype (str or torch.dtype, optional): dtype of the tensors.
+ Defaults to ``torch.int8``.
Examples:
- >>> spec = BinaryDiscreteTensorSpec(n=4, shape=(5, 4), device="cpu", dtype=torch.bool)
- >>> print(spec.zero())
+ >>> torch.manual_seed(0)
+ >>> spec = Binary(n=4, shape=(2, 4))
+ >>> print(spec.rand())
+ tensor([[0, 1, 1, 0],
+ [1, 1, 1, 1]], dtype=torch.int8)
+ >>> spec = Binary(shape=(2, 4))
+ >>> print(spec.rand())
+ tensor([[1, 1, 1, 0],
+ [0, 1, 0, 0]], dtype=torch.int8)
+ >>> spec = Binary(n=4)
+ >>> print(spec.rand())
+ tensor([0, 0, 0, 1], dtype=torch.int8)
+
"""
def __init__(
self,
- n: int,
+ n: int | None = None,
shape: Optional[torch.Size] = None,
device: Optional[DEVICE_TYPING] = None,
dtype: Union[str, torch.dtype] = torch.int8,
):
+ if n is None and not shape:
+ raise TypeError("Must provide either n or shape.")
+ if n is None:
+ n = shape[-1]
if shape is None or not len(shape):
- shape = torch.Size((n,))
+ shape = _size((n,))
else:
- shape = torch.Size(shape)
+ shape = _size(shape)
if shape[-1] != n:
raise ValueError(
f"The last value of the shape must match n for spec {self.__class__}. "
@@ -3318,7 +3607,7 @@ def unbind(self, dim: int = 0):
for i in range(self.shape[dim])
)
- def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec:
+ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> Binary:
if isinstance(dest, torch.dtype):
dest_dtype = dest
dest_device = self.device
@@ -3333,7 +3622,7 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec:
n=self.shape[-1], shape=self.shape, device=dest_device, dtype=dest_dtype
)
- def clone(self) -> BinaryDiscreteTensorSpec:
+ def clone(self) -> Binary:
return self.__class__(
n=self.shape[-1],
shape=self.shape,
@@ -3349,14 +3638,14 @@ def __getitem__(self, idx: SHAPE_INDEX_TYPING):
indexed_shape = _shape_indexing(self.shape[:-1], idx)
return self.__class__(
n=self.shape[-1],
- shape=torch.Size(indexed_shape + [self.shape[-1]]),
+ shape=_size(indexed_shape + [self.shape[-1]]),
device=self.device,
dtype=self.dtype,
)
def __eq__(self, other):
- if not isinstance(other, BinaryDiscreteTensorSpec):
- if isinstance(other, DiscreteTensorSpec):
+ if not isinstance(other, Binary):
+ if isinstance(other, Categorical):
return (
other.n == 2
and other.device == self.device
@@ -3368,7 +3657,7 @@ def __eq__(self, other):
@dataclass(repr=False)
-class MultiDiscreteTensorSpec(DiscreteTensorSpec):
+class MultiCategorical(Categorical):
"""A concatenation of discrete tensor spec.
Args:
@@ -3385,15 +3674,13 @@ class MultiDiscreteTensorSpec(DiscreteTensorSpec):
sample is taken. See :meth:`~.update_mask` for more information.
Examples:
- >>> ts = MultiDiscreteTensorSpec((3, 2, 3))
+ >>> ts = MultiCategorical((3, 2, 3))
>>> ts.is_in(torch.tensor([2, 0, 1]))
True
- >>> ts.is_in(torch.tensor([2, 2, 1]))
+ >>> ts.is_in(torch.tensor([2, 10, 1]))
False
"""
- # SPEC_HANDLED_FUNCTIONS = {}
-
def __init__(
self,
nvec: Union[Sequence[int], torch.Tensor, int],
@@ -3412,7 +3699,7 @@ def __init__(
if shape is None:
shape = nvec.shape
else:
- shape = torch.Size(shape)
+ shape = _size(shape)
if shape[-1] != nvec.shape[-1]:
raise ValueError(
f"The last value of the shape must match nvec.shape[-1] for transform of type {self.__class__}. "
@@ -3422,7 +3709,7 @@ def __init__(
self.nvec = self.nvec.expand(_remove_neg_shapes(shape))
space = BoxList.from_nvec(self.nvec)
- super(DiscreteTensorSpec, self).__init__(
+ super(Categorical, self).__init__(
shape, space, device, dtype, domain="discrete"
)
self.update_mask(mask)
@@ -3442,9 +3729,10 @@ def update_mask(self, mask):
sample is taken.
Examples:
+ >>> torch.manual_seed(0)
>>> mask = torch.tensor([False, False, True,
... True, True])
- >>> ts = MultiDiscreteTensorSpec((3, 2), (5, 2,), dtype=torch.int64, mask=mask)
+ >>> ts = MultiCategorical((3, 2), (5, 2,), dtype=torch.int64, mask=mask)
>>> # All but one of the three possible outcomes for the first
>>> # group are masked, but neither of the two possible
>>> # outcomes for the second group are masked.
@@ -3453,7 +3741,7 @@ def update_mask(self, mask):
[2, 0],
[2, 1],
[2, 1],
- [2, 0]])
+ [2, 1]])
"""
if mask is not None:
try:
@@ -3464,7 +3752,7 @@ def update_mask(self, mask):
raise ValueError("Only boolean masks are accepted.")
self.mask = mask
- def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec:
+ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> MultiCategorical:
if isinstance(dest, torch.dtype):
dest_dtype = dest
dest_device = self.device
@@ -3503,7 +3791,7 @@ def __eq__(self, other):
and mask_equal
)
- def clone(self) -> MultiDiscreteTensorSpec:
+ def clone(self) -> MultiCategorical:
return self.__class__(
nvec=self.nvec.clone(),
shape=None,
@@ -3541,7 +3829,7 @@ def rand(self, shape: Optional[torch.Size] = None) -> torch.Tensor:
*self.shape[:-1],
)
x = self._rand(space=self.space, shape=shape, i=self.nvec.ndim)
- if self.remove_singleton and self.shape == torch.Size([1]):
+ if self.remove_singleton and self.shape == _size([1]):
x = x.squeeze(-1)
return x
@@ -3565,9 +3853,7 @@ def _split_self(self):
for n, _mask in zip(nvec, mask):
shape = self.shape[:-1]
result.append(
- DiscreteTensorSpec(
- n=n, shape=shape, device=device, dtype=dtype, mask=_mask
- )
+ Categorical(n=n, shape=shape, device=device, dtype=dtype, mask=_mask)
)
return result
@@ -3620,7 +3906,7 @@ def is_in(self, val: torch.Tensor) -> bool:
def to_one_hot(
self, val: torch.Tensor, safe: bool = None
- ) -> Union[MultiOneHotDiscreteTensorSpec, torch.Tensor]:
+ ) -> Union[MultiOneHot, torch.Tensor]:
"""Encodes a discrete tensor from the spec domain into its one-hot correspondent.
Args:
@@ -3644,16 +3930,24 @@ def to_one_hot(
-1,
).to(self.device)
- def to_one_hot_spec(self) -> MultiOneHotDiscreteTensorSpec:
+ def to_one_hot_spec(self) -> MultiOneHot:
"""Converts the spec to the equivalent one-hot spec."""
nvec = [_space.n for _space in self.space]
- return MultiOneHotDiscreteTensorSpec(
+ return MultiOneHot(
nvec,
device=self.device,
shape=[*self.shape[:-1], sum(nvec)],
mask=self.mask,
)
+ def to_categorical(self, val: torch.Tensor, safe: bool = None) -> MultiCategorical:
+ """Not op for MultiCategorical."""
+ return val
+
+ def to_categorical_spec(self) -> MultiCategorical:
+ """Not op for MultiCategorical."""
+ return self
+
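The two new methods make ``MultiCategorical`` uniform with the one-hot specs' conversion API by answering the categorical-conversion calls with no-ops. A short sketch of the intended round-trip, assuming the class is importable from ``torchrl.data``:

    >>> from torchrl.data import MultiCategorical
    >>> ts = MultiCategorical((3, 2, 3))
    >>> ts.to_categorical_spec() is ts      # already categorical: no-op
    True
    >>> ts.to_one_hot_spec().shape          # one-hot width is sum(nvec) = 8
    torch.Size([8])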
def expand(self, *shape):
if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)):
shape = shape[0]
@@ -3770,12 +4064,16 @@ def __getitem__(self, idx: SHAPE_INDEX_TYPING):
)
-class CompositeSpec(TensorSpec):
+class Composite(TensorSpec):
"""A composition of TensorSpecs.
+    If a ``TensorSpec`` is the set-description of a single tensor, then ``Composite`` is to specs what
+    the :class:`~tensordict.TensorDict` class is to tensors. Like :class:`~tensordict.TensorDict`, it has a
+    ``shape`` (akin to the ``TensorDict``'s ``batch_size``) and an optional ``device``.
+
Args:
*args: if an unnamed argument is passed, it must be a dictionary with keys
- matching the expected keys to be found in the :obj:`CompositeSpec` object.
+ matching the expected keys to be found in the :obj:`Composite` object.
-            This is useful to build nested CompositeSpecs with tuple indices.
+            This is useful to build nested Composite specs with tuple indices.
**kwargs (key (str): value (TensorSpec)): dictionary of tensorspecs
to be stored. Values can be None, in which case is_in will be assumed
@@ -3792,53 +4090,59 @@ class CompositeSpec(TensorSpec):
to the batch-size of the corresponding tensordicts.
Examples:
- >>> pixels_spec = BoundedTensorSpec(
- ... torch.zeros(3,32,32),
- ... torch.ones(3, 32, 32))
- >>> observation_vector_spec = BoundedTensorSpec(torch.zeros(33),
- ... torch.ones(33))
- >>> composite_spec = CompositeSpec(
+ >>> pixels_spec = Bounded(
+ ... low=torch.zeros(4, 3, 32, 32),
+ ... high=torch.ones(4, 3, 32, 32),
+ ... dtype=torch.uint8
+ ... )
+ >>> observation_vector_spec = Bounded(
+ ... low=torch.zeros(4, 33),
+ ... high=torch.ones(4, 33),
+ ... dtype=torch.float)
+ >>> composite_spec = Composite(
... pixels=pixels_spec,
- ... observation_vector=observation_vector_spec)
- >>> td = TensorDict({"pixels": torch.rand(10,3,32,32),
- ... "observation_vector": torch.rand(10,33)}, batch_size=[10])
- >>> print("td (rand) is within bounds: ", composite_spec.is_in(td))
- td (rand) is within bounds: True
- >>> td = TensorDict({"pixels": torch.randn(10,3,32,32),
- ... "observation_vector": torch.randn(10,33)}, batch_size=[10])
- >>> print("td (randn) is within bounds: ", composite_spec.is_in(td))
- td (randn) is within bounds: False
- >>> td_project = composite_spec.project(td)
- >>> print("td modification done in place: ", td_project is td)
- td modification done in place: True
- >>> print("check td is within bounds after projection: ",
- ... composite_spec.is_in(td_project))
- check td is within bounds after projection: True
- >>> print("random td: ", composite_spec.rand([3,]))
- random td: TensorDict(
+ ... observation_vector=observation_vector_spec,
+ ... shape=(4,)
+ ... )
+ >>> composite_spec
+ Composite(
+ pixels: BoundedDiscrete(
+ shape=torch.Size([4, 3, 32, 32]),
+ space=ContinuousBox(
+ low=Tensor(shape=torch.Size([4, 3, 32, 32]), device=cpu, dtype=torch.uint8, contiguous=True),
+ high=Tensor(shape=torch.Size([4, 3, 32, 32]), device=cpu, dtype=torch.uint8, contiguous=True)),
+ device=cpu,
+ dtype=torch.uint8,
+ domain=discrete),
+ observation_vector: BoundedContinuous(
+ shape=torch.Size([4, 33]),
+ space=ContinuousBox(
+ low=Tensor(shape=torch.Size([4, 33]), device=cpu, dtype=torch.float32, contiguous=True),
+ high=Tensor(shape=torch.Size([4, 33]), device=cpu, dtype=torch.float32, contiguous=True)),
+ device=cpu,
+ dtype=torch.float32,
+ domain=continuous),
+ device=None,
+ shape=torch.Size([4]))
+ >>> td = composite_spec.rand()
+ >>> td
+ TensorDict(
fields={
- observation_vector: Tensor(torch.Size([3, 33]), dtype=torch.float32),
- pixels: Tensor(torch.Size([3, 3, 32, 32]), dtype=torch.float32)},
- batch_size=torch.Size([3]),
+ observation_vector: Tensor(shape=torch.Size([4, 33]), device=cpu, dtype=torch.float32, is_shared=False),
+ pixels: Tensor(shape=torch.Size([4, 3, 32, 32]), device=cpu, dtype=torch.uint8, is_shared=False)},
+ batch_size=torch.Size([4]),
device=None,
is_shared=False)
-
- Examples:
>>> # we can build a nested composite spec using unnamed arguments
- >>> print(CompositeSpec({("a", "b"): None, ("a", "c"): None}))
- CompositeSpec(
- a: CompositeSpec(
+ >>> print(Composite({("a", "b"): None, ("a", "c"): None}))
+ Composite(
+ a: Composite(
b: None,
- c: None))
-
- CompositeSpec supports nested indexing:
- >>> spec = CompositeSpec(obs=None)
- >>> spec["nested", "x"] = None
- >>> print(spec)
- CompositeSpec(
- nested: CompositeSpec(
- x: None),
- x: None)
+ c: None,
+ device=None,
+ shape=torch.Size([])),
+ device=None,
+ shape=torch.Size([]))
"""
@@ -3862,17 +4166,17 @@ def shape(self, value: torch.Size):
if self.locked:
raise RuntimeError("Cannot modify shape of locked composite spec.")
for key, spec in self.items():
- if isinstance(spec, CompositeSpec):
+ if isinstance(spec, Composite):
if spec.shape[: len(value)] != value:
spec.shape = value
elif spec is not None:
if spec.shape[: len(value)] != value:
raise ValueError(
- f"The shape of the spec and the CompositeSpec mismatch during shape resetting: the "
+ f"The shape of the spec and the Composite mismatch during shape resetting: the "
f"{self.ndim} first dimensions should match but got self['{key}'].shape={spec.shape} and "
- f"CompositeSpec.shape={self.shape}."
+ f"Composite.shape={self.shape}."
)
- self._shape = torch.Size(value)
+ self._shape = _size(value)
def is_empty(self):
"""Whether the composite spec contains specs or not."""
@@ -3887,25 +4191,30 @@ def ndimension(self):
def set(self, name, spec):
if self.locked:
- raise RuntimeError("Cannot modify a locked CompositeSpec.")
+ raise RuntimeError("Cannot modify a locked Composite.")
if spec is not None:
shape = spec.shape
if shape[: self.ndim] != self.shape:
raise ValueError(
- "The shape of the spec and the CompositeSpec mismatch: the first "
+ "The shape of the spec and the Composite mismatch: the first "
f"{self.ndim} dimensions should match but got spec.shape={spec.shape} and "
- f"CompositeSpec.shape={self.shape}."
+ f"Composite.shape={self.shape}."
)
self._specs[name] = spec
- def __init__(self, *args, shape=None, device=None, **kwargs):
+ def __init__(
+ self, *args, shape: torch.Size = None, device: torch.device = None, **kwargs
+ ):
+ # For compatibility with TensorDict
+ batch_size = kwargs.pop("batch_size", None)
+ if batch_size is not None:
+ if shape is not None:
+ raise TypeError("Cannot specify both batch_size and shape.")
+ shape = batch_size
+
if shape is None:
- # Should we do this? Other specs have a default empty shape, maybe it would make sense to keep it
- # optional for composite (for clarity and easiness of use).
- # warnings.warn("shape=None for CompositeSpec will soon be deprecated. Make sure you set the "
- # "batch size of your CompositeSpec as you would do for a tensordict.")
- shape = []
- self._shape = torch.Size(shape)
+ shape = _size(())
+ self._shape = _size(shape)
self._specs = {}
for key, value in kwargs.items():
self.set(key, value)
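The ``batch_size`` alias mirrors the :class:`~tensordict.TensorDict` constructor; a sketch of the added branch, assuming ``Composite`` is importable from ``torchrl.data``:

    >>> from torchrl.data import Composite
    >>> Composite(batch_size=(4,)).shape          # alias for shape
    torch.Size([4])
    >>> Composite(shape=(4,), batch_size=(4,))    # both at once: raises TypeError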
@@ -3918,7 +4227,7 @@ def __init__(self, *args, shape=None, device=None, **kwargs):
if item is None:
continue
if (
- isinstance(item, CompositeSpec)
+ isinstance(item, Composite)
and item.device is None
and _device is not None
):
@@ -3927,22 +4236,22 @@ def __init__(self, *args, shape=None, device=None, **kwargs):
raise RuntimeError(
f"Setting a new attribute ({key}) on another device "
f"({item.device} against {_device}). All devices of "
- "CompositeSpec must match."
+ "Composite must match."
)
self._device = _device
if len(args):
if len(args) > 1:
raise RuntimeError(
- "Got multiple arguments, when at most one is expected for CompositeSpec."
+ "Got multiple arguments, when at most one is expected for Composite."
)
argdict = args[0]
- if not isinstance(argdict, (dict, CompositeSpec)):
+ if not isinstance(argdict, (dict, Composite)):
raise RuntimeError(
f"Expected a dictionary of specs, but got an argument of type {type(argdict)}."
)
for k, item in argdict.items():
if isinstance(item, dict):
- item = CompositeSpec(item, shape=shape, device=_device)
+ item = Composite(item, shape=shape, device=_device)
self[k] = item
@property
@@ -3959,14 +4268,14 @@ def device(self, device: DEVICE_TYPING):
self.to(device)
def clear_device_(self):
- """Clears the device of the CompositeSpec."""
+ """Clears the device of the Composite."""
self._device = None
for spec in self._specs.values():
spec.clear_device_()
return self
def __getitem__(self, idx):
- """Indexes the current CompositeSpec based on the provided index."""
+ """Indexes the current Composite based on the provided index."""
if isinstance(idx, (str, tuple)):
idx_unravel = unravel_key(idx)
else:
@@ -3975,7 +4284,7 @@ def __getitem__(self, idx):
if isinstance(idx_unravel, tuple):
return self[idx[0]][idx[1:]]
if idx_unravel in {"shape", "device", "dtype", "space"}:
- raise AttributeError(f"CompositeSpec has no key {idx_unravel}")
+ raise AttributeError(f"Composite has no key {idx_unravel}")
return self._specs[idx_unravel]
indexed_shape = _shape_indexing(self.shape, idx)
@@ -3987,9 +4296,9 @@ def __getitem__(self, idx):
if any(
isinstance(v, spec_class)
for spec_class in [
- BinaryDiscreteTensorSpec,
- MultiDiscreteTensorSpec,
- OneHotDiscreteTensorSpec,
+ Binary,
+ MultiCategorical,
+ OneHot,
]
):
protected_dims = 1
@@ -4011,7 +4320,7 @@ def __getitem__(self, idx):
)
def get(self, item, default=NO_DEFAULT):
- """Gets an item from the CompositeSpec.
+ """Gets an item from the Composite.
If the item is absent, a default value can be passed.
@@ -4026,7 +4335,7 @@ def get(self, item, default=NO_DEFAULT):
def __setitem__(self, key, value):
if isinstance(key, tuple) and len(key) > 1:
if key[0] not in self.keys(True):
- self[key[0]] = CompositeSpec(shape=self.shape, device=self.device)
+ self[key[0]] = Composite(shape=self.shape, device=self.device)
self[key[0]][key[1:]] = value
return
elif isinstance(key, tuple):
@@ -4035,20 +4344,20 @@ def __setitem__(self, key, value):
elif not isinstance(key, str):
raise TypeError(f"Got key of type {type(key)} when a string was expected.")
if key in {"shape", "device", "dtype", "space"}:
- raise AttributeError(f"CompositeSpec[{key}] cannot be set")
+ raise AttributeError(f"Composite[{key}] cannot be set")
if isinstance(value, dict):
- value = CompositeSpec(value, device=self._device, shape=self.shape)
+ value = Composite(value, device=self._device, shape=self.shape)
if (
value is not None
and self.device is not None
and value.device != self.device
):
- if isinstance(value, CompositeSpec) and value.device is None:
+ if isinstance(value, Composite) and value.device is None:
value = value.clone().to(self.device)
else:
raise RuntimeError(
f"Setting a new attribute ({key}) on another device ({value.device} against {self.device}). "
- f"All devices of CompositeSpec must match."
+ f"All devices of Composite must match."
)
self.set(key, value)
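Setting a nested key auto-creates the intermediate ``Composite`` with the parent's shape and device, as the first branch above shows. A minimal sketch, with the repr following the conventions of the class docstring above:

    >>> spec = Composite()
    >>> spec["nested", "x"] = None        # "nested" is created on the fly
    >>> print(spec)
    Composite(
        nested: Composite(
            x: None,
            device=None,
            shape=torch.Size([])),
        device=None,
        shape=torch.Size([]))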
@@ -4077,17 +4386,17 @@ def encode(
if isinstance(vals, TensorDict):
            out = vals.empty()  # create an empty tensordict similar to vals
else:
- out = TensorDict._new_unsafe({}, torch.Size([]))
+ out = TensorDict._new_unsafe({}, _size([]))
for key, item in vals.items():
if item is None:
raise RuntimeError(
- "CompositeSpec.encode cannot be used with missing values."
+ "Composite.encode cannot be used with missing values."
)
try:
out[key] = self[key].encode(item, ignore_device=ignore_device)
except KeyError:
raise KeyError(
- f"The CompositeSpec instance with keys {self.keys()} does not have a '{key}' key."
+ f"The Composite instance with keys {self.keys()} does not have a '{key}' key."
)
except RuntimeError as err:
raise RuntimeError(
@@ -4100,7 +4409,7 @@ def __repr__(self) -> str:
indent(f"{k}: {str(item)}", 4 * " ") for k, item in self._specs.items()
]
sub_str = ",\n".join(sub_str)
- return f"CompositeSpec(\n{sub_str},\n device={self._device},\n shape={self.shape})"
+ return f"Composite(\n{sub_str},\n device={self._device},\n shape={self.shape})"
def type_check(
self,
@@ -4119,9 +4428,9 @@ def type_check(
def is_in(self, val: Union[dict, TensorDictBase]) -> bool:
for key, item in self._specs.items():
- if item is None or (isinstance(item, CompositeSpec) and item.is_empty()):
+ if item is None or (isinstance(item, Composite) and item.is_empty()):
continue
- val_item = val.get(key)
+ val_item = val.get(key, NO_DEFAULT)
if not item.is_in(val_item):
return False
return True
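Passing ``NO_DEFAULT`` to ``val.get`` tightens ``is_in``: a key that the spec declares but the value lacks now raises instead of silently comparing against ``None``. A hedged sketch, assuming ``Unbounded`` accepts a ``shape`` keyword and the new names are exported from ``torchrl.data``:

    >>> import torch
    >>> from tensordict import TensorDict
    >>> from torchrl.data import Composite, Unbounded
    >>> spec = Composite(obs=Unbounded(shape=(3,)))
    >>> spec.is_in(TensorDict({"obs": torch.randn(3)}, []))
    True
    >>> spec.is_in(TensorDict({}, []))    # missing "obs": raises a KeyError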
@@ -4135,9 +4444,9 @@ def project(self, val: TensorDictBase) -> TensorDictBase:
val.set(key, self._specs[key].project(_val))
return val
- def rand(self, shape=None) -> TensorDictBase:
+ def rand(self, shape: torch.Size = None) -> TensorDictBase:
if shape is None:
- shape = torch.Size([])
+ shape = _size([])
_dict = {}
for key, item in self.items():
if item is not None:
@@ -4146,7 +4455,7 @@ def rand(self, shape=None) -> TensorDictBase:
# TensorDict requirements
return TensorDict._new_unsafe(
_dict,
- batch_size=torch.Size([*shape, *self.shape]),
+ batch_size=_size([*shape, *self.shape]),
device=self._device,
)
@@ -4157,24 +4466,24 @@ def keys(
*,
is_leaf: Callable[[type], bool] | None = None,
) -> _CompositeSpecKeysView: # noqa: D417
- """Keys of the CompositeSpec.
+ """Keys of the Composite.
        The keys argument reflects those of :class:`tensordict.TensorDict`.
Args:
include_nested (bool, optional): if ``False``, the returned keys will not be nested. They will
represent only the immediate children of the root, and not the whole nested sequence, i.e. a
- :obj:`CompositeSpec(next=CompositeSpec(obs=None))` will lead to the keys
+ :obj:`Composite(next=Composite(obs=None))` will lead to the keys
:obj:`["next"]. Default is ``False``, i.e. nested keys will not
be returned.
leaves_only (bool, optional): if ``False``, the values returned
- will contain every level of nesting, i.e. a :obj:`CompositeSpec(next=CompositeSpec(obs=None))`
+ will contain every level of nesting, i.e. a :obj:`Composite(next=Composite(obs=None))`
will lead to the keys :obj:`["next", ("next", "obs")]`.
Default is ``False``.
Keyword Args:
is_leaf (callable, optional): reads a type and returns a boolean indicating if that type
- should be seen as a leaf. By default, all non-CompositeSpec nodes are considered as
+ should be seen as a leaf. By default, all non-Composite nodes are considered as
leaves.
"""
@@ -4192,22 +4501,22 @@ def items(
*,
is_leaf: Callable[[type], bool] | None = None,
) -> _CompositeSpecItemsView: # noqa: D417
- """Items of the CompositeSpec.
+ """Items of the Composite.
Args:
include_nested (bool, optional): if ``False``, the returned keys will not be nested. They will
represent only the immediate children of the root, and not the whole nested sequence, i.e. a
- :obj:`CompositeSpec(next=CompositeSpec(obs=None))` will lead to the keys
+ :obj:`Composite(next=Composite(obs=None))` will lead to the keys
:obj:`["next"]. Default is ``False``, i.e. nested keys will not
be returned.
leaves_only (bool, optional): if ``False``, the values returned
- will contain every level of nesting, i.e. a :obj:`CompositeSpec(next=CompositeSpec(obs=None))`
+ will contain every level of nesting, i.e. a :obj:`Composite(next=Composite(obs=None))`
will lead to the keys :obj:`["next", ("next", "obs")]`.
Default is ``False``.
Keyword Args:
is_leaf (callable, optional): reads a type and returns a boolean indicating if that type
- should be seen as a leaf. By default, all non-CompositeSpec nodes are considered as
+ should be seen as a leaf. By default, all non-Composite nodes are considered as
leaves.
"""
return _CompositeSpecItemsView(
@@ -4224,22 +4533,22 @@ def values(
*,
is_leaf: Callable[[type], bool] | None = None,
) -> _CompositeSpecValuesView: # noqa: D417
- """Values of the CompositeSpec.
+ """Values of the Composite.
Args:
include_nested (bool, optional): if ``False``, the returned keys will not be nested. They will
represent only the immediate children of the root, and not the whole nested sequence, i.e. a
- :obj:`CompositeSpec(next=CompositeSpec(obs=None))` will lead to the keys
+ :obj:`Composite(next=Composite(obs=None))` will lead to the keys
:obj:`["next"]. Default is ``False``, i.e. nested keys will not
be returned.
leaves_only (bool, optional): if ``False``, the values returned
- will contain every level of nesting, i.e. a :obj:`CompositeSpec(next=CompositeSpec(obs=None))`
+ will contain every level of nesting, i.e. a :obj:`Composite(next=Composite(obs=None))`
will lead to the keys :obj:`["next", ("next", "obs")]`.
Default is ``False``.
Keyword Args:
is_leaf (callable, optional): reads a type and returns a boolean indicating if that type
- should be seen as a leaf. By default, all non-CompositeSpec nodes are considered as
+ should be seen as a leaf. By default, all non-Composite nodes are considered as
leaves.
"""
return _CompositeSpecItemsView(
@@ -4254,7 +4563,7 @@ def _reshape(self, shape):
key: val.reshape((*shape, *val.shape[self.ndimension() :]))
for key, val in self._specs.items()
}
- return CompositeSpec(_specs, shape=shape)
+ return Composite(_specs, shape=shape)
def _unflatten(self, dim, sizes):
shape = torch.zeros(self.shape, device="meta").unflatten(dim, sizes).shape
@@ -4263,12 +4572,12 @@ def _unflatten(self, dim, sizes):
def __len__(self):
return len(self.keys())
- def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec:
+ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> Composite:
if dest is None:
return self
if not isinstance(dest, (str, int, torch.device)):
raise ValueError(
- "Only device casting is allowed with specs of type CompositeSpec."
+ "Only device casting is allowed with specs of type Composite."
)
if self._device and self._device == torch.device(dest):
return self
@@ -4283,7 +4592,7 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec:
kwargs[key] = value.to(dest)
return self.__class__(**kwargs, device=_device, shape=self.shape)
- def clone(self) -> CompositeSpec:
+ def clone(self) -> Composite:
try:
device = self.device
except RuntimeError:
@@ -4312,9 +4621,9 @@ def empty(self):
def to_numpy(self, val: TensorDict, safe: bool = None) -> dict:
return {key: self[key].to_numpy(val) for key, val in val.items()}
- def zero(self, shape=None) -> TensorDictBase:
+ def zero(self, shape: torch.Size = None) -> TensorDictBase:
if shape is None:
- shape = torch.Size([])
+ shape = _size([])
try:
device = self.device
except RuntimeError:
@@ -4325,7 +4634,7 @@ def zero(self, shape=None) -> TensorDictBase:
for key in self.keys(True)
if isinstance(key, str) and self[key] is not None
},
- torch.Size([*shape, *self._safe_shape]),
+ _size([*shape, *self._safe_shape]),
device=device,
)
@@ -4338,9 +4647,9 @@ def __eq__(self, other):
and all((self._specs[key] == spec) for (key, spec) in other._specs.items())
)
- def update(self, dict_or_spec: Union[CompositeSpec, Dict[str, TensorSpec]]) -> None:
+ def update(self, dict_or_spec: Union[Composite, Dict[str, TensorSpec]]) -> None:
for key, item in dict_or_spec.items():
- if key in self.keys(True) and isinstance(self[key], CompositeSpec):
+ if key in self.keys(True) and isinstance(self[key], Composite):
self[key].update(item)
continue
try:
@@ -4381,7 +4690,7 @@ def expand(self, *shape):
else None
for key, value in tuple(self.items())
}
- out = CompositeSpec(
+ out = Composite(
specs,
shape=shape,
device=device,
@@ -4402,7 +4711,7 @@ def squeeze(self, dim: int | None = None):
except RuntimeError:
device = self._device
- return CompositeSpec(
+ return Composite(
{key: value.squeeze(dim) for key, value in self.items()},
shape=shape,
device=device,
@@ -4428,7 +4737,7 @@ def unsqueeze(self, dim: int):
except RuntimeError:
device = self._device
- return CompositeSpec(
+ return Composite(
{
key: value.unsqueeze(dim) if value is not None else None
for key, value in self.items()
@@ -4457,19 +4766,19 @@ def unbind(self, dim: int = 0):
)
def lock_(self, recurse=False):
- """Locks the CompositeSpec and prevents modification of its content.
+ """Locks the Composite and prevents modification of its content.
This is only a first-level lock, unless specified otherwise through the
``recurse`` arg.
Leaf specs can always be modified in place, but they cannot be replaced
- in their CompositeSpec parent.
+ in their Composite parent.
Examples:
>>> shape = [3, 4, 5]
- >>> spec = CompositeSpec(
- ... a=CompositeSpec(
- ... b=CompositeSpec(shape=shape[:3], device="cpu"), shape=shape[:2]
+ >>> spec = Composite(
+ ... a=Composite(
+ ... b=Composite(shape=shape[:3], device="cpu"), shape=shape[:2]
... ),
... shape=shape[:1],
... )
@@ -4500,12 +4809,12 @@ def lock_(self, recurse=False):
self._locked = True
if recurse:
for value in self.values():
- if isinstance(value, CompositeSpec):
+ if isinstance(value, Composite):
value.lock_(recurse)
return self
def unlock_(self, recurse=False):
- """Unlocks the CompositeSpec and allows modification of its content.
+ """Unlocks the Composite and allows modification of its content.
This is only a first-level lock modification, unless specified
otherwise through the ``recurse`` arg.
@@ -4514,7 +4823,7 @@ def unlock_(self, recurse=False):
self._locked = False
if recurse:
for value in self.values():
- if isinstance(value, CompositeSpec):
+ if isinstance(value, Composite):
value.unlock_(recurse)
return self
@@ -4523,7 +4832,7 @@ def locked(self):
return self._locked
-class LazyStackedCompositeSpec(_LazyStackedMixin[CompositeSpec], CompositeSpec):
+class StackedComposite(_LazyStackedMixin[Composite], Composite):
"""A lazy representation of a stack of composite specs.
Stacks composite specs together along one dimension.
@@ -4539,7 +4848,7 @@ class LazyStackedCompositeSpec(_LazyStackedMixin[CompositeSpec], CompositeSpec):
def update(self, dict) -> None:
for key, item in dict.items():
if key in self.keys() and isinstance(
- item, (Dict, CompositeSpec, LazyStackedCompositeSpec)
+ item, (Dict, Composite, StackedComposite)
):
for spec, sub_item in zip(self._specs, item.unbind(self.dim)):
spec[key].update(sub_item)
@@ -4548,7 +4857,7 @@ def update(self, dict) -> None:
return self
def __eq__(self, other):
- if not isinstance(other, LazyStackedCompositeSpec):
+ if not isinstance(other, StackedComposite):
return False
if len(self._specs) != len(other._specs):
return False
@@ -4567,7 +4876,7 @@ def to_numpy(self, val: TensorDict, safe: bool = None) -> dict:
if safe:
if val.shape[self.dim] != len(self._specs):
raise ValueError(
- "Size of LazyStackedCompositeSpec and val differ along the "
+ "Size of StackedComposite and val differ along the "
"stacking dimension"
)
for spec, v in zip(self._specs, torch.unbind(val, dim=self.dim)):
@@ -4665,7 +4974,7 @@ def __repr__(self) -> str:
string = ",\n".join(
[sub_str, exclusive_key_str, device_str, shape_str, stack_dim]
)
- return f"LazyStackedCompositeSpec(\n{string})"
+ return f"StackedComposite(\n{string})"
def repr_exclusive_keys(self):
keys = set(self.keys())
@@ -4771,7 +5080,7 @@ def shape(self):
if dim < 0:
dim = len(shape) + dim + 1
shape.insert(dim, len(self._specs))
- return torch.Size(shape)
+ return _size(shape)
def expand(self, *shape):
if len(shape) == 1 and not isinstance(shape[0], (int,)):
@@ -4803,7 +5112,7 @@ def expand(self, *shape):
)
def empty(self):
- return LazyStackedCompositeSpec.maybe_dense_stack(
+ return StackedComposite.maybe_dense_stack(
[spec.empty() for spec in self._specs], dim=self.stack_dim
)
@@ -4812,7 +5121,7 @@ def encode(
) -> Dict[str, torch.Tensor]:
raise NOT_IMPLEMENTED_ERROR
- def zero(self, shape=None) -> TensorDictBase:
+ def zero(self, shape: torch.Size = None) -> TensorDictBase:
if shape is not None:
dim = self.dim + len(shape)
else:
@@ -4821,7 +5130,7 @@ def zero(self, shape=None) -> TensorDictBase:
[spec.zero(shape) for spec in self._specs], dim
)
- def one(self, shape=None) -> TensorDictBase:
+ def one(self, shape: torch.Size = None) -> TensorDictBase:
if shape is not None:
dim = self.dim + len(shape)
else:
@@ -4830,7 +5139,7 @@ def one(self, shape=None) -> TensorDictBase:
[spec.one(shape) for spec in self._specs], dim
)
- def rand(self, shape=None) -> TensorDictBase:
+ def rand(self, shape: torch.Size = None) -> TensorDictBase:
if shape is not None:
dim = self.dim + len(shape)
else:
@@ -4840,7 +5149,6 @@ def rand(self, shape=None) -> TensorDictBase:
)
-# for SPEC_CLASS in [BinaryDiscreteTensorSpec, BoundedTensorSpec, DiscreteTensorSpec, MultiDiscreteTensorSpec, MultiOneHotDiscreteTensorSpec, OneHotDiscreteTensorSpec, UnboundedContinuousTensorSpec, UnboundedDiscreteTensorSpec]:
@TensorSpec.implements_for_spec(torch.stack)
def _stack_specs(list_of_spec, dim, out=None):
if out is not None:
@@ -4873,12 +5181,12 @@ def _stack_specs(list_of_spec, dim, out=None):
dim += len(shape) + 1
shape.insert(dim, len(list_of_spec))
return spec0.clone().unsqueeze(dim).expand(shape)
- return LazyStackedTensorSpec(*list_of_spec, dim=dim)
+ return Stacked(*list_of_spec, dim=dim)
else:
raise NotImplementedError
-@CompositeSpec.implements_for_spec(torch.stack)
+@Composite.implements_for_spec(torch.stack)
def _stack_composite_specs(list_of_spec, dim, out=None):
if out is not None:
raise NotImplementedError(
@@ -4888,7 +5196,7 @@ def _stack_composite_specs(list_of_spec, dim, out=None):
if not len(list_of_spec):
raise ValueError("Cannot stack an empty list of specs.")
spec0 = list_of_spec[0]
- if isinstance(spec0, CompositeSpec):
+ if isinstance(spec0, Composite):
devices = {spec.device for spec in list_of_spec}
if len(devices) == 1:
device = list(devices)[0]
@@ -4903,7 +5211,7 @@ def _stack_composite_specs(list_of_spec, dim, out=None):
all_equal = True
for spec in list_of_spec[1:]:
- if not isinstance(spec, CompositeSpec):
+ if not isinstance(spec, Composite):
raise RuntimeError(
"Stacking specs cannot occur: Found more than one type of spec in "
"the list."
@@ -4920,7 +5228,7 @@ def _stack_composite_specs(list_of_spec, dim, out=None):
dim += len(shape) + 1
shape.insert(dim, len(list_of_spec))
return spec0.clone().unsqueeze(dim).expand(shape)
- return LazyStackedCompositeSpec(*list_of_spec, dim=dim)
+ return StackedComposite(*list_of_spec, dim=dim)
else:
raise NotImplementedError
@@ -4930,8 +5238,8 @@ def _squeeze_spec(spec: TensorSpec, *args, **kwargs) -> TensorSpec:
return spec.squeeze(*args, **kwargs)
-@CompositeSpec.implements_for_spec(torch.squeeze)
-def _squeeze_composite_spec(spec: CompositeSpec, *args, **kwargs) -> CompositeSpec:
+@Composite.implements_for_spec(torch.squeeze)
+def _squeeze_composite_spec(spec: Composite, *args, **kwargs) -> Composite:
return spec.squeeze(*args, **kwargs)
@@ -4940,16 +5248,16 @@ def _unsqueeze_spec(spec: TensorSpec, *args, **kwargs) -> TensorSpec:
return spec.unsqueeze(*args, **kwargs)
-@CompositeSpec.implements_for_spec(torch.unsqueeze)
-def _unsqueeze_composite_spec(spec: CompositeSpec, *args, **kwargs) -> CompositeSpec:
+@Composite.implements_for_spec(torch.unsqueeze)
+def _unsqueeze_composite_spec(spec: Composite, *args, **kwargs) -> Composite:
return spec.unsqueeze(*args, **kwargs)
def _keys_to_empty_composite_spec(keys):
- """Given a list of keys, creates a CompositeSpec tree where each leaf is assigned a None value."""
+ """Given a list of keys, creates a Composite tree where each leaf is assigned a None value."""
if not len(keys):
return
- c = CompositeSpec()
+ c = Composite()
for key in keys:
if isinstance(key, str):
c[key] = None
@@ -4957,7 +5265,7 @@ def _keys_to_empty_composite_spec(keys):
if c[key[0]] is None:
# if the value is None we just replace it
c[key[0]] = _keys_to_empty_composite_spec([key[1:]])
- elif isinstance(c[key[0]], CompositeSpec):
+ elif isinstance(c[key[0]], Composite):
# if the value is Composite, we update it
out = _keys_to_empty_composite_spec([key[1:]])
if out is not None:
@@ -4973,7 +5281,7 @@ def _squeezed_shape(shape: torch.Size, dim: int | None) -> torch.Size | None:
if dim is None:
if len(shape) == 1 or shape.count(1) == 0:
return None
- new_shape = torch.Size([s for s in shape if s != 1])
+ new_shape = _size([s for s in shape if s != 1])
else:
if dim < 0:
dim += len(shape)
@@ -4981,7 +5289,7 @@ def _squeezed_shape(shape: torch.Size, dim: int | None) -> torch.Size | None:
if shape[dim] != 1:
return None
- new_shape = torch.Size([s for i, s in enumerate(shape) if i != dim])
+ new_shape = _size([s for i, s in enumerate(shape) if i != dim])
return new_shape
@@ -4997,15 +5305,15 @@ def _unsqueezed_shape(shape: torch.Size, dim: int) -> torch.Size:
new_shape = list(shape)
new_shape.insert(dim, 1)
- return torch.Size(new_shape)
+ return _size(new_shape)
class _CompositeSpecItemsView:
- """Wrapper class that enables richer behaviour of `items` for CompositeSpec."""
+ """Wrapper class that enables richer behavior of `items` for Composite."""
def __init__(
self,
- composite: CompositeSpec,
+ composite: Composite,
include_nested,
leaves_only,
*,
@@ -5023,13 +5331,13 @@ def __iter__(self):
if is_leaf in (None, _NESTED_TENSORS_AS_LISTS):
def _is_leaf(cls):
- return not issubclass(cls, CompositeSpec)
+ return not issubclass(cls, Composite)
else:
_is_leaf = is_leaf
def _iter_from_item(key, item):
- if self.include_nested and isinstance(item, CompositeSpec):
+ if self.include_nested and isinstance(item, Composite):
for subkey, subitem in item.items(
include_nested=True,
leaves_only=self.leaves_only,
@@ -5054,7 +5362,7 @@ def _iter_from_item(key, item):
def _get_composite_items(self, is_leaf):
- if isinstance(self.composite, LazyStackedCompositeSpec):
+ if isinstance(self.composite, StackedComposite):
from tensordict.base import _NESTED_TENSORS_AS_LISTS
if is_leaf is _NESTED_TENSORS_AS_LISTS:
@@ -5141,5 +5449,149 @@ def _minmax_dtype(dtype):
def _remove_neg_shapes(*shape):
if len(shape) == 1 and not isinstance(shape[0], int):
- return _remove_neg_shapes(*shape[0])
- return torch.Size([int(d) if d >= 0 else 1 for d in shape])
+ shape = shape[0]
+ if isinstance(shape, np.integer):
+ shape = (int(shape),)
+ return _remove_neg_shapes(*shape)
+ return _size([int(d) if d >= 0 else 1 for d in shape])
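The reworked helper now also unwraps scalar ``np.integer`` inputs before normalizing negative dimensions to 1; a sketch of both paths:

    >>> import numpy as np
    >>> _remove_neg_shapes(3, -1, 2)      # -1 placeholders become 1
    torch.Size([3, 1, 2])
    >>> _remove_neg_shapes(np.int64(5))   # numpy scalar integers are unwrapped
    torch.Size([5])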
+
+
+##############
+# Legacy
+#
+class _LegacySpecMeta(abc.ABCMeta):
+ def __call__(cls, *args, **kwargs):
+ warnings.warn(
+ f"The {cls.__name__} has been deprecated and will be removed in v0.7. Please use "
+ f"{cls.__bases__[-1].__name__} instead.",
+ category=DeprecationWarning,
+ )
+ instance = super().__call__(*args, **kwargs)
+ if (
+ type(instance) in (UnboundedDiscreteTensorSpec, UnboundedDiscrete)
+ and instance.domain == "continuous"
+ ):
+ instance.__class__ = UnboundedContinuous
+ elif (
+ type(instance) in (UnboundedContinuousTensorSpec, UnboundedContinuous)
+ and instance.domain == "discrete"
+ ):
+ instance.__class__ = UnboundedDiscrete
+ return instance
+
+ def __instancecheck__(cls, instance):
+ check0 = super().__instancecheck__(instance)
+ if check0:
+ return True
+ parent_cls = cls.__bases__[-1]
+ return isinstance(instance, parent_cls)
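The metaclass keeps the old names usable while steering code toward the new ones: instantiating a legacy class warns, and ``isinstance`` checks fall through to the renamed parent. A sketch of both behaviors, assuming the old and new names are in scope as in this module:

    >>> import warnings
    >>> with warnings.catch_warnings(record=True) as caught:
    ...     warnings.simplefilter("always")
    ...     spec = CompositeSpec()                 # emits a DeprecationWarning
    >>> any(issubclass(w.category, DeprecationWarning) for w in caught)
    True
    >>> isinstance(Composite(), CompositeSpec)     # legacy isinstance still passes
    True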
+
+
+class CompositeSpec(Composite, metaclass=_LegacySpecMeta):
+ """Deprecated version of :class:`torchrl.data.Composite`."""
+
+ ...
+
+
+class OneHotDiscreteTensorSpec(OneHot, metaclass=_LegacySpecMeta):
+ """Deprecated version of :class:`torchrl.data.OneHot`."""
+
+ ...
+
+
+class MultiOneHotDiscreteTensorSpec(MultiOneHot, metaclass=_LegacySpecMeta):
+ """Deprecated version of :class:`torchrl.data.MultiOneHot`."""
+
+ ...
+
+
+class NonTensorSpec(NonTensor, metaclass=_LegacySpecMeta):
+ """Deprecated version of :class:`torchrl.data.NonTensor`."""
+
+ ...
+
+
+class MultiDiscreteTensorSpec(MultiCategorical, metaclass=_LegacySpecMeta):
+ """Deprecated version of :class:`torchrl.data.MultiCategorical`."""
+
+ ...
+
+
+class LazyStackedTensorSpec(Stacked, metaclass=_LegacySpecMeta):
+ """Deprecated version of :class:`torchrl.data.Stacked`."""
+
+ ...
+
+
+class LazyStackedCompositeSpec(StackedComposite, metaclass=_LegacySpecMeta):
+ """Deprecated version of :class:`torchrl.data.StackedComposite`."""
+
+ ...
+
+
+class DiscreteTensorSpec(Categorical, metaclass=_LegacySpecMeta):
+ """Deprecated version of :class:`torchrl.data.Categorical`."""
+
+ ...
+
+
+class BinaryDiscreteTensorSpec(Binary, metaclass=_LegacySpecMeta):
+ """Deprecated version of :class:`torchrl.data.Binary`."""
+
+ ...
+
+
+_BoundedLegacyMeta = type("_BoundedLegacyMeta", (_LegacySpecMeta, _BoundedMeta), {})
+
+
+class BoundedTensorSpec(Bounded, metaclass=_BoundedLegacyMeta):
+ """Deprecated version of :class:`torchrl.data.Bounded`."""
+
+ ...
+
+
+class _UnboundedContinuousMetaclass(_UnboundedMeta):
+ def __instancecheck__(cls, instance):
+ return isinstance(instance, Unbounded) and instance.domain == "continuous"
+
+
+_LegacyUnboundedContinuousMetaclass = type(
+ "_LegacyUnboundedDiscreteMetaclass",
+ (_UnboundedContinuousMetaclass, _LegacySpecMeta),
+ {},
+)
+
+
+class UnboundedContinuousTensorSpec(
+ Unbounded, metaclass=_LegacyUnboundedContinuousMetaclass
+):
+ """Deprecated version of :class:`torchrl.data.Unbounded` with continuous space."""
+
+ ...
+
+
+class _UnboundedDiscreteMetaclass(_UnboundedMeta):
+ def __instancecheck__(cls, instance):
+ return isinstance(instance, Unbounded) and instance.domain == "discrete"
+
+
+_LegacyUnboundedDiscreteMetaclass = type(
+ "_LegacyUnboundedDiscreteMetaclass",
+ (_UnboundedDiscreteMetaclass, _LegacySpecMeta),
+ {},
+)
+
+
+class UnboundedDiscreteTensorSpec(
+ Unbounded, metaclass=_LegacyUnboundedDiscreteMetaclass
+):
+ """Deprecated version of :class:`torchrl.data.Unbounded` with discrete space."""
+
+ def __init__(
+ self,
+ shape: Union[torch.Size, int] = _DEFAULT_SHAPE,
+ device: Optional[DEVICE_TYPING] = None,
+ dtype: Optional[Union[str, torch.dtype]] = torch.int64,
+ **kwargs,
+ ):
+ super().__init__(shape=shape, device=device, dtype=dtype, **kwargs)
diff --git a/torchrl/data/utils.py b/torchrl/data/utils.py
index fb4ec30daed..db2c8afca10 100644
--- a/torchrl/data/utils.py
+++ b/torchrl/data/utils.py
@@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
+import functools
import typing
from typing import Any, Callable, List, Tuple, Union
@@ -13,14 +14,14 @@
from torch import Tensor
from torchrl.data.tensor_specs import (
- BinaryDiscreteTensorSpec,
- CompositeSpec,
- DiscreteTensorSpec,
- LazyStackedCompositeSpec,
- LazyStackedTensorSpec,
- MultiDiscreteTensorSpec,
- MultiOneHotDiscreteTensorSpec,
- OneHotDiscreteTensorSpec,
+ Binary,
+ Categorical,
+ Composite,
+ MultiCategorical,
+ MultiOneHot,
+ OneHot,
+ Stacked,
+ StackedComposite,
TensorSpec,
)
@@ -50,10 +51,10 @@
ACTION_SPACE_MAP = {
- OneHotDiscreteTensorSpec: "one_hot",
- MultiOneHotDiscreteTensorSpec: "mult_one_hot",
- BinaryDiscreteTensorSpec: "binary",
- DiscreteTensorSpec: "categorical",
+ OneHot: "one_hot",
+ MultiOneHot: "mult_one_hot",
+ Binary: "binary",
+ Categorical: "categorical",
"one_hot": "one_hot",
"one-hot": "one_hot",
"mult_one_hot": "mult_one_hot",
@@ -62,7 +63,7 @@
"multi-one-hot": "mult_one_hot",
"binary": "binary",
"categorical": "categorical",
- MultiDiscreteTensorSpec: "multi_categorical",
+ MultiCategorical: "multi_categorical",
"multi_categorical": "multi_categorical",
"multi-categorical": "multi_categorical",
"multi_discrete": "multi_categorical",
@@ -71,14 +72,14 @@
def consolidate_spec(
- spec: CompositeSpec,
+ spec: Composite,
recurse_through_entries: bool = True,
recurse_through_stack: bool = True,
):
"""Given a TensorSpec, removes exclusive keys by adding 0 shaped specs.
Args:
- spec (CompositeSpec): the spec to be consolidated.
+ spec (Composite): the spec to be consolidated.
recurse_through_entries (bool): if True, call the function recursively on all entries of the spec.
Default is True.
recurse_through_stack (bool): if True, if the provided spec is lazy, the function recursively
@@ -87,10 +88,10 @@ def consolidate_spec(
"""
spec = spec.clone()
- if not isinstance(spec, (CompositeSpec, LazyStackedCompositeSpec)):
+ if not isinstance(spec, (Composite, StackedComposite)):
return spec
- if isinstance(spec, LazyStackedCompositeSpec):
+ if isinstance(spec, StackedComposite):
keys = set(spec.keys()) # shared keys
exclusive_keys_per_spec = [
set() for _ in range(len(spec._specs))
@@ -128,7 +129,7 @@ def consolidate_spec(
if recurse_through_entries:
for key, value in spec.items():
- if isinstance(value, (CompositeSpec, LazyStackedCompositeSpec)):
+ if isinstance(value, (Composite, StackedComposite)):
spec.set(
key,
consolidate_spec(
@@ -145,16 +146,16 @@ def _empty_like_spec(specs: List[TensorSpec], shape):
"Found same key in lazy specs corresponding to entries with different classes"
)
spec = specs[0]
- if isinstance(spec, (CompositeSpec, LazyStackedCompositeSpec)):
+ if isinstance(spec, (Composite, StackedComposite)):
-        # the exclusive key has values which are CompositeSpecs ->
+        # the exclusive key has values which are Composite specs ->
# we create an empty composite spec with same batch size
return spec.empty()
- elif isinstance(spec, LazyStackedTensorSpec):
+ elif isinstance(spec, Stacked):
-        # the exclusive key has values which are LazyStackedTensorSpecs ->
-        # we create a LazyStackedTensorSpec with the same shape (aka same -1s) as the first in the list.
+        # the exclusive key has values which are Stacked specs ->
+        # we create a Stacked spec with the same shape (aka same -1s) as the first in the list.
# this will not add any new -1s when they are stacked
shape = list(shape[: spec.stack_dim]) + list(shape[spec.stack_dim + 1 :])
- return LazyStackedTensorSpec(
+ return Stacked(
*[_empty_like_spec(spec._specs, shape) for _ in spec._specs],
dim=spec.stack_dim,
)
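A hedged sketch of how ``consolidate_spec`` and ``_empty_like_spec`` compose, asserted only at the level the docstrings describe (an exclusive key in a stacked spec is padded with a 0-shaped spec so that no sub-spec is missing it):

    >>> import torch
    >>> from torchrl.data import Composite, Unbounded
    >>> shared = Composite(obs=Unbounded(shape=(3,)))
    >>> extra = Composite(obs=Unbounded(shape=(3,)), extra=Unbounded(shape=(2,)))
    >>> stacked = torch.stack([shared, extra], 0)   # StackedComposite with one exclusive key
    >>> check_no_exclusive_keys(consolidate_spec(stacked))
    True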
@@ -191,14 +192,14 @@ def check_no_exclusive_keys(spec: TensorSpec, recurse: bool = True):
spec (TensorSpec): the spec to check
recurse (bool): if True, check recursively in nested specs. Default is True.
"""
- if isinstance(spec, LazyStackedCompositeSpec):
+ if isinstance(spec, StackedComposite):
keys = set(spec.keys())
for inner_td in spec._specs:
if recurse and not check_no_exclusive_keys(inner_td):
return False
if set(inner_td.keys()) != keys:
return False
- elif isinstance(spec, CompositeSpec) and recurse:
+ elif isinstance(spec, Composite) and recurse:
for value in spec.values():
if not check_no_exclusive_keys(value):
return False
@@ -214,9 +215,9 @@ def contains_lazy_spec(spec: TensorSpec) -> bool:
spec (TensorSpec): the spec to check
"""
- if isinstance(spec, (LazyStackedTensorSpec, LazyStackedCompositeSpec)):
+ if isinstance(spec, (Stacked, StackedComposite)):
return True
- elif isinstance(spec, CompositeSpec):
+ elif isinstance(spec, Composite):
for inner_spec in spec.values():
if contains_lazy_spec(inner_spec):
return True
@@ -235,6 +236,8 @@ def __init__(self, fn: Callable, **kwargs):
self.fn = fn
self.kwargs = kwargs
+ functools.update_wrapper(self, getattr(fn, "forward", fn))
+
def __getstate__(self):
import cloudpickle
@@ -244,6 +247,7 @@ def __setstate__(self, ob: bytes):
import pickle
self.fn, self.kwargs = pickle.loads(ob)
+ functools.update_wrapper(self, self.fn)
def __call__(self, *args, **kwargs) -> Any:
kwargs.update(self.kwargs)
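``functools.update_wrapper`` makes the wrapper transparent to introspection both at construction and after unpickling; a sketch with a plain function, assuming ``CloudpickleWrapper`` is imported from ``torchrl.data.utils``:

    >>> from torchrl.data.utils import CloudpickleWrapper
    >>> def my_policy(td):
    ...     """Reads an observation and writes an action."""
    ...     return td
    >>> wrapped = CloudpickleWrapper(my_policy)
    >>> wrapped.__doc__       # metadata is copied from the wrapped callable
    'Reads an observation and writes an action.'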
@@ -253,7 +257,7 @@ def __call__(self, *args, **kwargs) -> Any:
def _process_action_space_spec(action_space, spec):
original_spec = spec
composite_spec = False
- if isinstance(spec, CompositeSpec):
+ if isinstance(spec, Composite):
# this will break whenever our action is more complex than a single tensor
try:
if "action" in spec.keys():
@@ -274,8 +278,8 @@ def _process_action_space_spec(action_space, spec):
"with a leaf 'action' entry. Otherwise, simply remove the spec and use the action_space only."
)
if action_space is not None:
- if isinstance(action_space, CompositeSpec):
- raise ValueError("action_space cannot be of type CompositeSpec.")
+ if isinstance(action_space, Composite):
+ raise ValueError("action_space cannot be of type Composite.")
if (
spec is not None
and isinstance(action_space, TensorSpec)
@@ -305,7 +309,7 @@ def _process_action_space_spec(action_space, spec):
def _find_action_space(action_space):
if isinstance(action_space, TensorSpec):
- if isinstance(action_space, CompositeSpec):
+ if isinstance(action_space, Composite):
if "action" in action_space.keys():
_key = "action"
else:
diff --git a/torchrl/envs/__init__.py b/torchrl/envs/__init__.py
index ced185d7e00..047550fa9d7 100644
--- a/torchrl/envs/__init__.py
+++ b/torchrl/envs/__init__.py
@@ -28,12 +28,16 @@
MultiThreadedEnv,
MultiThreadedEnvWrapper,
OpenMLEnv,
+ OpenSpielEnv,
+ OpenSpielWrapper,
PettingZooEnv,
PettingZooWrapper,
RoboHiveEnv,
set_gym_backend,
SMACv2Env,
SMACv2Wrapper,
+ UnityMLAgentsEnv,
+ UnityMLAgentsWrapper,
VmasEnv,
VmasWrapper,
)
@@ -100,12 +104,10 @@
from .utils import (
check_env_specs,
check_marl_grouping,
- exploration_mode,
exploration_type,
ExplorationType,
make_composite_from_td,
MarlGroupMapType,
- set_exploration_mode,
set_exploration_type,
step_mdp,
)
diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py
index 4996e527527..02c7f5893dc 100644
--- a/torchrl/envs/batched_envs.py
+++ b/torchrl/envs/batched_envs.py
@@ -28,6 +28,7 @@
TensorDictBase,
unravel_key,
)
+from tensordict.utils import _zip_strict
from torch import multiprocessing as mp
from torchrl._utils import (
_check_for_faulty_process,
@@ -36,7 +37,7 @@
logger as torchrl_logger,
VERBOSE,
)
-from torchrl.data.tensor_specs import CompositeSpec, NonTensorSpec
+from torchrl.data.tensor_specs import Composite, NonTensor
from torchrl.data.utils import CloudpickleWrapper, contains_lazy_spec, DEVICE_TYPING
from torchrl.envs.common import _do_nothing, _EnvPostInit, EnvBase, EnvMetaData
from torchrl.envs.env_creator import get_env_metadata
@@ -318,14 +319,20 @@ def __init__(
create_env_fn = [create_env_fn for _ in range(num_workers)]
elif len(create_env_fn) != num_workers:
raise RuntimeError(
- f"num_workers and len(create_env_fn) mismatch, "
- f"got {len(create_env_fn)} and {num_workers}"
+ f"len(create_env_fn) and num_workers mismatch, "
+ f"got {len(create_env_fn)} and {num_workers}."
)
+
create_env_kwargs = {} if create_env_kwargs is None else create_env_kwargs
if isinstance(create_env_kwargs, dict):
create_env_kwargs = [
deepcopy(create_env_kwargs) for _ in range(num_workers)
]
+ elif len(create_env_kwargs) != num_workers:
+ raise RuntimeError(
+ f"len(create_env_kwargs) and num_workers mismatch, "
+ f"got {len(create_env_kwargs)} and {num_workers}."
+ )
self.policy_proof = policy_proof
self.num_workers = num_workers
@@ -534,7 +541,11 @@ def update_kwargs(self, kwargs: Union[dict, List[dict]]) -> None:
for _kwargs in self.create_env_kwargs:
_kwargs.update(kwargs)
else:
- for _kwargs, _new_kwargs in zip(self.create_env_kwargs, kwargs):
+ if len(kwargs) != self.num_workers:
+ raise RuntimeError(
+ f"len(kwargs) and num_workers mismatch, got {len(kwargs)} and {self.num_workers}."
+ )
+ for _kwargs, _new_kwargs in _zip_strict(self.create_env_kwargs, kwargs):
_kwargs.update(_new_kwargs)
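The strict zip and the explicit length check surface worker/kwargs mismatches immediately rather than silently dropping entries. Hypothetical usage with a two-worker env (``SerialEnv`` and ``GymEnv`` stand in for any batched setup, and the ``g`` kwarg is illustrative):

    >>> env = SerialEnv(2, lambda: GymEnv("Pendulum-v1"))
    >>> env.update_kwargs({"g": 9.81})                  # a dict is broadcast to all workers
    >>> env.update_kwargs([{"g": 9.81}, {"g": 1.62}])   # one dict per worker
    >>> env.update_kwargs([{"g": 9.81}])                # 1 != 2 workers: raises RuntimeError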
def _get_in_keys_to_exclude(self, tensordict):
@@ -550,7 +561,7 @@ def _set_properties(self):
cls = type(self)
- def _check_for_empty_spec(specs: CompositeSpec):
+ def _check_for_empty_spec(specs: Composite):
for subspec in (
"full_state_spec",
"full_action_spec",
@@ -559,9 +570,9 @@ def _check_for_empty_spec(specs: CompositeSpec):
"full_observation_spec",
):
for key, spec in reversed(
- list(specs.get(subspec, default=CompositeSpec()).items(True))
+ list(specs.get(subspec, default=Composite()).items(True))
):
- if isinstance(spec, CompositeSpec) and spec.is_empty():
+ if isinstance(spec, Composite) and spec.is_empty():
raise RuntimeError(
f"The environment passed to {cls.__name__} has empty specs in {key}. Consider using "
f"torchrl.envs.transforms.RemoveEmptySpecs to remove the empty specs."
@@ -675,7 +686,7 @@ def _create_td(self) -> None:
self.full_done_spec,
):
for key, _spec in spec.items(True, True):
- if isinstance(_spec, NonTensorSpec):
+ if isinstance(_spec, NonTensor):
non_tensor_keys.append(key)
self._non_tensor_keys = non_tensor_keys
@@ -1031,12 +1042,18 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
if out_tds is not None:
out_tds[i] = _td
+ device = self.device
if not self._use_buffers:
result = LazyStackedTensorDict.maybe_dense_stack(out_tds)
+ if result.device != device:
+ if device is None:
+ result = result.clear_device_()
+ else:
+ result = result.to(device, non_blocking=self.non_blocking)
+ self._sync_w2m()
return result
selected_output_keys = self._selected_reset_keys_filt
- device = self.device
# select + clone creates 2 tds, but we can create one only
def select_and_clone(name, tensor):
@@ -1066,18 +1083,30 @@ def _step(
self,
tensordict: TensorDict,
) -> TensorDict:
- tensordict_in = tensordict.clone(False)
+ partial_steps = tensordict.get("_step", None)
+ tensordict_save = tensordict
+ if partial_steps is not None and partial_steps.all():
+ partial_steps = None
+ if partial_steps is not None:
+ partial_steps = partial_steps.view(tensordict.shape)
+ tensordict = tensordict[partial_steps]
+ workers_range = partial_steps.nonzero(as_tuple=True)[0].tolist()
+ tensordict_in = tensordict
+ else:
+ workers_range = range(self.num_workers)
+ tensordict_in = tensordict.clone(False)
+ # if self._use_buffers:
+ # shared_tensordict_parent = self.shared_tensordict_parent
+
data_in = []
- for i in range(self.num_workers):
+ for i, td_ in zip(workers_range, tensordict_in):
# shared_tensordicts are locked, and we need to select the keys since we update in-place.
# There may be unexpected keys, such as "_reset", that we should comfortably ignore here.
env_device = self._envs[i].device
if env_device != self.device and env_device is not None:
- data_in.append(
- tensordict_in[i].to(env_device, non_blocking=self.non_blocking)
- )
+ data_in.append(td_.to(env_device, non_blocking=self.non_blocking))
else:
- data_in.append(tensordict_in[i])
+ data_in.append(td_)
self._sync_m2w()
out_tds = None
@@ -1086,7 +1115,7 @@ def _step(
if self._use_buffers:
next_td = self.shared_tensordict_parent.get("next")
- for i, _data_in in enumerate(data_in):
+ for i, _data_in in zip(workers_range, data_in):
out_td = self._envs[i]._step(_data_in)
next_td[i].update_(
out_td,
@@ -1095,32 +1124,43 @@ def _step(
)
if out_tds is not None:
out_tds.append(out_td)
- else:
- for i, _data_in in enumerate(data_in):
- out_td = self._envs[i]._step(_data_in)
- out_tds.append(out_td)
- return LazyStackedTensorDict.maybe_dense_stack(out_tds)
- # We must pass a clone of the tensordict, as the values of this tensordict
- # will be modified in-place at further steps
- device = self.device
+ # We must pass a clone of the tensordict, as the values of this tensordict
+ # will be modified in-place at further steps
+ device = self.device
- def select_and_clone(name, tensor):
- if name in self._selected_step_keys:
- return tensor.clone()
+ def select_and_clone(name, tensor):
+ if name in self._selected_step_keys:
+ return tensor.clone()
- out = next_td.named_apply(select_and_clone, nested_keys=True, filter_empty=True)
- if out_tds is not None:
- out.update(
- LazyStackedTensorDict(*out_tds), keys_to_update=self._non_tensor_keys
+ if partial_steps is not None:
+ next_td = TensorDict.lazy_stack([next_td[i] for i in workers_range])
+ out = next_td.named_apply(
+ select_and_clone, nested_keys=True, filter_empty=True
)
+ if out_tds is not None:
+ out.update(
+ LazyStackedTensorDict(*out_tds),
+ keys_to_update=self._non_tensor_keys,
+ )
+
+ if out.device != device:
+ if device is None:
+ out = out.clear_device_()
+ elif out.device != device:
+ out = out.to(device, non_blocking=self.non_blocking)
+ self._sync_w2m()
+ else:
+ for i, _data_in in zip(workers_range, data_in):
+ out_td = self._envs[i]._step(_data_in)
+ out_tds.append(out_td)
+ out = LazyStackedTensorDict.maybe_dense_stack(out_tds)
+
+ if partial_steps is not None:
+ result = out.new_zeros(tensordict_save.shape)
+ result[partial_steps] = out
+ return result
- if out.device != device:
- if device is None:
- out = out.clear_device_()
- elif out.device != device:
- out = out.to(device, non_blocking=self.non_blocking)
- self._sync_w2m()
return out
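Across these methods the pattern is the same: an optional boolean ``"_step"`` entry selects which workers actually step, and the stacked result is scattered back into a zeroed container of the original batch shape. The mask arithmetic in isolation, with plain tensors:

    >>> import torch
    >>> partial_steps = torch.tensor([True, False, True])
    >>> partial_steps.nonzero(as_tuple=True)[0].tolist()   # workers_range
    [0, 2]
    >>> out = torch.ones(2, 4)              # stacked results from the stepped workers
    >>> result = out.new_zeros(3, 4)
    >>> result[partial_steps] = out         # scatter back into the full batch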
def __getattr__(self, attr: str) -> Any:
@@ -1435,20 +1475,30 @@ def load_state_dict(self, state_dict: OrderedDict) -> None:
def _step_and_maybe_reset_no_buffers(
self, tensordict: TensorDictBase
) -> Tuple[TensorDictBase, TensorDictBase]:
+ partial_steps = tensordict.get("_step", None)
+ tensordict_save = tensordict
+ if partial_steps is not None and partial_steps.all():
+ partial_steps = None
+ if partial_steps is not None:
+ partial_steps = partial_steps.view(tensordict.shape)
+ tensordict = tensordict[partial_steps]
+ workers_range = partial_steps.nonzero(as_tuple=True)[0].tolist()
+ else:
+ workers_range = range(self.num_workers)
td = tensordict.consolidate(share_memory=True, inplace=True, num_threads=1)
- for i in range(td.shape[0]):
+        for pos, i in enumerate(workers_range):
            # We send the same td multiple times as it is in shared mem and we just need to index it
            # in each process (positionally, since td may have been masked above).
            # If we don't do this, we need to unbind it but then the custom pickler will require
            # some extra metadata to be collected.
-            self.parent_channels[i].send(("step_and_maybe_reset", (td, i)))
+            self.parent_channels[i].send(("step_and_maybe_reset", (td, pos)))
- results = [None] * self.num_workers
+ results = [None] * len(workers_range)
consumed_indices = []
- events = set(range(self.num_workers))
- while len(consumed_indices) < self.num_workers:
+ events = set(workers_range)
+ while len(consumed_indices) < len(workers_range):
for i in list(events):
if self._events[i].is_set():
-                    results[i] = self.parent_channels[i].recv()
+                    results[workers_range.index(i)] = self.parent_channels[i].recv()
@@ -1457,9 +1507,14 @@ def _step_and_maybe_reset_no_buffers(
events.discard(i)
out_next, out_root = zip(*(future for future in results))
- return TensorDict.maybe_dense_stack(out_next), TensorDict.maybe_dense_stack(
+ out = TensorDict.maybe_dense_stack(out_next), TensorDict.maybe_dense_stack(
out_root
)
+        if partial_steps is not None:
+            # out is a (next, root) pair: scatter each half back into the full batch
+            stepped_next, stepped_root = out
+            result_next = stepped_next.new_zeros(tensordict_save.shape)
+            result_root = stepped_root.new_zeros(tensordict_save.shape)
+            result_next[partial_steps] = stepped_next
+            result_root[partial_steps] = stepped_root
+            return result_next, result_root
+        return out
@torch.no_grad()
@_check_start
@@ -1471,6 +1526,42 @@ def step_and_maybe_reset(
# return self._step_and_maybe_reset_no_buffers(tensordict)
return super().step_and_maybe_reset(tensordict)
+ partial_steps = tensordict.get("_step", None)
+ tensordict_save = tensordict
+ if partial_steps is not None and partial_steps.all():
+ partial_steps = None
+ if partial_steps is not None:
+ partial_steps = partial_steps.view(tensordict.shape)
+ workers_range = partial_steps.nonzero(as_tuple=True)[0].tolist()
+ shared_tensordict_parent = TensorDict.lazy_stack(
+ [self.shared_tensordict_parent[i] for i in workers_range]
+ )
+ next_td = TensorDict.lazy_stack(
+ [self._shared_tensordict_parent_next[i] for i in workers_range]
+ )
+ tensordict_ = TensorDict.lazy_stack(
+ [self._shared_tensordict_parent_root[i] for i in workers_range]
+ )
+ if self.shared_tensordict_parent.device is None:
+ tensordict = tensordict._fast_apply(
+ lambda x, y: x[partial_steps].to(y.device)
+ if y is not None
+ else x[partial_steps],
+ self.shared_tensordict_parent,
+ default=None,
+ device=None,
+ batch_size=shared_tensordict_parent.shape,
+ )
+ else:
+ tensordict = tensordict[partial_steps].to(
+ self.shared_tensordict_parent.device
+ )
+ else:
+ workers_range = range(self.num_workers)
+ shared_tensordict_parent = self.shared_tensordict_parent
+ next_td = self._shared_tensordict_parent_next
+ tensordict_ = self._shared_tensordict_parent_root
+
# We must use the in_keys and nothing else for the following reasons:
# - efficiency: copying all the keys will in practice mean doing a lot
# of writing operations since the input tensordict may (and often will)
@@ -1479,7 +1570,7 @@ def step_and_maybe_reset(
# and this transform overrides an observation key (eg, CatFrames)
# the shape, dtype or device may not necessarily match and writing
# the value in-place will fail.
- self.shared_tensordict_parent.update_(
+ shared_tensordict_parent.update_(
tensordict,
keys_to_update=self._env_input_keys,
non_blocking=self.non_blocking,
@@ -1489,46 +1580,41 @@ def step_and_maybe_reset(
# if we have input "next" data (eg, RNNs which pass the next state)
# the sub-envs will need to process them through step_and_maybe_reset.
# We keep track of which keys are present to let the worker know what
- # should be passd to the env (we don't want to pass done states for instance)
+ # should be passed to the env (we don't want to pass done states for instance)
next_td_keys = list(next_td_passthrough.keys(True, True))
- data = [
- {"next_td_passthrough_keys": next_td_keys}
- for _ in range(self.num_workers)
- ]
- self.shared_tensordict_parent.get("next").update_(
+ data = [{"next_td_passthrough_keys": next_td_keys} for _ in workers_range]
+ shared_tensordict_parent.get("next").update_(
next_td_passthrough, non_blocking=self.non_blocking
)
else:
# next_td_keys = None
- data = [{} for _ in range(self.num_workers)]
+ data = [{} for _ in workers_range]
if self._non_tensor_keys:
- for i in range(self.num_workers):
+            for pos in range(len(workers_range)):
-                data[i]["non_tensor_data"] = tensordict[i].select(
+                data[pos]["non_tensor_data"] = tensordict[pos].select(
*self._non_tensor_keys, strict=False
)
self._sync_m2w()
- for i in range(self.num_workers):
- self.parent_channels[i].send(("step_and_maybe_reset", data[i]))
+ for i, _data in zip(workers_range, data):
+ self.parent_channels[i].send(("step_and_maybe_reset", _data))
- for i in range(self.num_workers):
+ for i in workers_range:
event = self._events[i]
event.wait(self._timeout)
event.clear()
if self._non_tensor_keys:
non_tensor_tds = []
- for i in range(self.num_workers):
+ for i in workers_range:
msg, non_tensor_td = self.parent_channels[i].recv()
non_tensor_tds.append(non_tensor_td)
# We must pass a clone of the tensordict, as the values of this tensordict
# will be modified in-place at further steps
- next_td = self._shared_tensordict_parent_next
- tensordict_ = self._shared_tensordict_parent_root
device = self.device
- if self.shared_tensordict_parent.device == device:
+ if shared_tensordict_parent.device == device:
next_td = next_td.clone()
tensordict_ = tensordict_.clone()
elif device is not None:
@@ -1558,22 +1644,49 @@ def step_and_maybe_reset(
keys_to_update=[("next", key) for key in self._non_tensor_keys],
)
tensordict_.update(non_tensor_tds, keys_to_update=self._non_tensor_keys)
+
+ if partial_steps is not None:
+ result = tensordict.new_zeros(tensordict_save.shape)
+ result_ = tensordict_.new_zeros(tensordict_save.shape)
+ result[partial_steps] = tensordict
+ result_[partial_steps] = tensordict_
+ return result, result_
+
return tensordict, tensordict_
def _step_no_buffers(
self, tensordict: TensorDictBase
) -> Tuple[TensorDictBase, TensorDictBase]:
+ partial_steps = tensordict.get("_step", None)
+ tensordict_save = tensordict
+ if partial_steps is not None and partial_steps.all():
+ partial_steps = None
+ if partial_steps is not None:
+ partial_steps = partial_steps.view(tensordict.shape)
+ tensordict = tensordict[partial_steps]
+ workers_range = partial_steps.nonzero(as_tuple=True)[0].tolist()
+ else:
+ workers_range = range(self.num_workers)
+
data = tensordict.consolidate(share_memory=True, inplace=True, num_threads=1)
- for i, local_data in enumerate(data.unbind(0)):
+ for i, local_data in zip(workers_range, data.unbind(0)):
self.parent_channels[i].send(("step", local_data))
# for i in range(data.shape[0]):
# self.parent_channels[i].send(("step", (data, i)))
out_tds = []
- for i, channel in enumerate(self.parent_channels):
+ for i in workers_range:
+ channel = self.parent_channels[i]
self._events[i].wait()
td = channel.recv()
out_tds.append(td)
- return LazyStackedTensorDict.maybe_dense_stack(out_tds)
+ out = LazyStackedTensorDict.maybe_dense_stack(out_tds)
+ if self.device is not None and out.device != self.device:
+ out = out.to(self.device, non_blocking=self.non_blocking)
+ if partial_steps is not None:
+ result = out.new_zeros(tensordict_save.shape)
+ result[partial_steps] = out
+ return result
+ return out
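
The same `"_step"` mask also decides which workers receive messages at all. A small sketch of how the mask becomes `workers_range`, mirroring the lines above:

```python
import torch

# Turn the boolean "_step" entry into the indices of workers to step.
partial_steps = torch.tensor([True, False, True, True])
if partial_steps.all():
    workers_range = range(len(partial_steps))  # fast path: step everyone
else:
    workers_range = partial_steps.nonzero(as_tuple=True)[0].tolist()
print(list(workers_range))  # [0, 2, 3]
```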
@torch.no_grad()
@_check_start
@@ -1588,8 +1701,35 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
# and this transform overrides an observation key (eg, CatFrames)
# the shape, dtype or device may not necessarily match and writing
# the value in-place will fail.
+ partial_steps = tensordict.get("_step", None)
+ tensordict_save = tensordict
+ if partial_steps is not None and partial_steps.all():
+ partial_steps = None
+ if partial_steps is not None:
+ partial_steps = partial_steps.view(tensordict.shape)
+ workers_range = partial_steps.nonzero(as_tuple=True)[0].tolist()
+ shared_tensordict_parent = TensorDict.lazy_stack(
+ [self.shared_tensordicts[i] for i in workers_range]
+ )
+ if self.shared_tensordict_parent.device is None:
+ tensordict = tensordict._fast_apply(
+ lambda x, y: x[partial_steps].to(y.device)
+ if y is not None
+ else x[partial_steps],
+ self.shared_tensordict_parent,
+ default=None,
+ device=None,
+ batch_size=shared_tensordict_parent.shape,
+ )
+ else:
+ tensordict = tensordict[partial_steps].to(
+ self.shared_tensordict_parent.device
+ )
+ else:
+ workers_range = range(self.num_workers)
+ shared_tensordict_parent = self.shared_tensordict_parent
- self.shared_tensordict_parent.update_(
+ shared_tensordict_parent.update_(
tensordict,
keys_to_update=list(self._env_input_keys),
non_blocking=self.non_blocking,
@@ -1599,20 +1739,20 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
# if we have input "next" data (eg, RNNs which pass the next state)
# the sub-envs will need to process them through step_and_maybe_reset.
# We keep track of which keys are present to let the worker know what
- # should be passd to the env (we don't want to pass done states for instance)
+ # should be passed to the env (we don't want to pass done states for instance)
next_td_keys = list(next_td_passthrough.keys(True, True))
data = [
{"next_td_passthrough_keys": next_td_keys}
for _ in range(self.num_workers)
]
- self.shared_tensordict_parent.get("next").update_(
+ shared_tensordict_parent.get("next").update_(
next_td_passthrough, non_blocking=self.non_blocking
)
else:
data = [{} for _ in range(self.num_workers)]
if self._non_tensor_keys:
- for i in range(self.num_workers):
+ for i in workers_range:
data[i]["non_tensor_data"] = tensordict[i].select(
*self._non_tensor_keys, strict=False
)
@@ -1622,23 +1762,23 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
if self.event is not None:
self.event.record()
self.event.synchronize()
- for i in range(self.num_workers):
+ for i in workers_range:
self.parent_channels[i].send(("step", data[i]))
- for i in range(self.num_workers):
+ for i in workers_range:
event = self._events[i]
event.wait(self._timeout)
event.clear()
if self._non_tensor_keys:
non_tensor_tds = []
- for i in range(self.num_workers):
+ for i in workers_range:
msg, non_tensor_td = self.parent_channels[i].recv()
non_tensor_tds.append(non_tensor_td)
# We must pass a clone of the tensordict, as the values of this tensordict
# will be modified in-place at further steps
- next_td = self.shared_tensordict_parent.get("next")
+ next_td = shared_tensordict_parent.get("next")
device = self.device
if next_td.device != device and device is not None:
@@ -1665,6 +1805,10 @@ def select_and_clone(name, tensor):
keys_to_update=self._non_tensor_keys,
)
self._sync_w2m()
+ if partial_steps is not None:
+ result = out.new_zeros(tensordict_save.shape)
+ result[partial_steps] = out
+ return result
return out
def _reset_no_buffers(
@@ -1698,7 +1842,11 @@ def _reset_no_buffers(
self._events[i].wait()
td = channel.recv()
out_tds[i] = td
- return LazyStackedTensorDict.maybe_dense_stack(out_tds)
+ result = LazyStackedTensorDict.maybe_dense_stack(out_tds)
+ device = self.device
+ if device is not None and result.device != device:
+ return result.to(self.device, non_blocking=self.non_blocking)
+ return result
@torch.no_grad()
@_check_start
diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py
index b9216b58e86..31ff0b905af 100644
--- a/torchrl/envs/common.py
+++ b/torchrl/envs/common.py
@@ -25,12 +25,7 @@
seed_generator,
)
-from torchrl.data.tensor_specs import (
- CompositeSpec,
- DiscreteTensorSpec,
- TensorSpec,
- UnboundedContinuousTensorSpec,
-)
+from torchrl.data.tensor_specs import Categorical, Composite, TensorSpec, Unbounded
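
Throughout this patch the legacy spec classes are swapped for shorter successors. A minimal reference sketch of the mapping these import changes assume (the dict and helper below are illustrative only, not a torchrl API):

```python
# Illustrative mapping from the legacy spec names removed above to the
# new ones imported instead (hypothetical helper, for reference only).
LEGACY_TO_NEW = {
    "CompositeSpec": "Composite",
    "DiscreteTensorSpec": "Categorical",
    "BoundedTensorSpec": "Bounded",
    "UnboundedContinuousTensorSpec": "Unbounded",
    "UnboundedDiscreteTensorSpec": "UnboundedDiscrete",
}

def modernize(name: str) -> str:
    """Map a legacy spec class name to its new name; others pass through."""
    return LEGACY_TO_NEW.get(name, name)

assert modernize("CompositeSpec") == "Composite"
assert modernize("TensorSpec") == "TensorSpec"  # base class is unchanged
```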
from torchrl.data.utils import DEVICE_TYPING
from torchrl.envs.utils import (
_make_compatible_policy,
@@ -62,7 +57,7 @@ def __init__(
self,
*,
tensordict: TensorDictBase,
- specs: CompositeSpec,
+ specs: Composite,
batch_size: torch.Size,
env_str: str,
device: torch.device,
@@ -91,7 +86,7 @@ def tensordict(self, value: TensorDictBase):
self._tensordict = value.to("cpu")
@specs.setter
- def specs(self, value: CompositeSpec):
+ def specs(self, value: Composite):
self._specs = value.to("cpu")
@staticmethod
@@ -191,6 +186,27 @@ def __call__(cls, *args, **kwargs):
return AutoResetEnv(
instance, AutoResetTransform(replace=auto_reset_replace)
)
+
+ done_keys = set(instance.full_done_spec.keys(True, True))
+ obs_keys = set(instance.full_observation_spec.keys(True, True))
+ reward_keys = set(instance.full_reward_spec.keys(True, True))
+ # state_keys can match obs_keys so we don't test that
+ action_keys = set(instance.full_action_spec.keys(True, True))
+ state_keys = set(instance.full_state_spec.keys(True, True))
+ total_set = set()
+ for keyset in (done_keys, obs_keys, reward_keys):
+ if total_set.intersection(keyset):
+ raise RuntimeError(
+ f"The set of keys of one spec collides (culprit: {total_set.intersection(keyset)}) with another."
+ )
+ total_set = total_set.union(keyset)
+ total_set = set()
+ for keyset in (state_keys, action_keys):
+ if total_set.intersection(keyset):
+ raise RuntimeError(
+ f"The set of keys of one spec collides (culprit: {total_set.intersection(keyset)}) with another."
+ )
+ total_set = total_set.union(keyset)
return instance
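
The new post-init check rejects specs whose key sets overlap. The same accumulate-and-intersect loop on plain sets, for illustration:

```python
# The collision check above, reduced to plain key sets.
done_keys = {"done", "terminated"}
obs_keys = {"observation"}
reward_keys = {"reward"}

total_set = set()
for keyset in (done_keys, obs_keys, reward_keys):
    clash = total_set.intersection(keyset)
    if clash:
        raise RuntimeError(f"The keys of one spec collide with those of another (culprit: {clash}).")
    total_set = total_set.union(keyset)
print(sorted(total_set))  # ['done', 'observation', 'reward', 'terminated']
```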
@@ -212,29 +228,29 @@ class EnvBase(nn.Module, metaclass=_EnvPostInit):
be done after a call to :meth:`~.reset` is made. Defaults to ``False``.
Attributes:
- done_spec (CompositeSpec): equivalent to ``full_done_spec`` as all
+ done_spec (Composite): equivalent to ``full_done_spec`` as all
``done_specs`` contain at least a ``"done"`` and a ``"terminated"`` entry
action_spec (TensorSpec): the spec of the action. Links to the spec of the leaf
action if only one action tensor is to be expected. Otherwise links to
``full_action_spec``.
- observation_spec (CompositeSpec): equivalent to ``full_observation_spec``.
+ observation_spec (Composite): equivalent to ``full_observation_spec``.
reward_spec (TensorSpec): the spec of the reward. Links to the spec of the leaf
reward if only one reward tensor is to be expected. Otherwise links to
``full_reward_spec``.
- state_spec (CompositeSpec): equivalent to ``full_state_spec``.
- full_done_spec (CompositeSpec): a composite spec such that ``full_done_spec.zero()``
+ state_spec (Composite): equivalent to ``full_state_spec``.
+ full_done_spec (Composite): a composite spec such that ``full_done_spec.zero()``
returns a tensordict containing only the leaves encoding the done status of the
environment.
- full_action_spec (CompositeSpec): a composite spec such that ``full_action_spec.zero()``
+ full_action_spec (Composite): a composite spec such that ``full_action_spec.zero()``
returns a tensordict containing only the leaves encoding the action of the
environment.
- full_observation_spec (CompositeSpec): a composite spec such that ``full_observation_spec.zero()``
+ full_observation_spec (Composite): a composite spec such that ``full_observation_spec.zero()``
returns a tensordict containing only the leaves encoding the observation of the
environment.
- full_reward_spec (CompositeSpec): a composite spec such that ``full_reward_spec.zero()``
+ full_reward_spec (Composite): a composite spec such that ``full_reward_spec.zero()``
returns a tensordict containing only the leaves encoding the reward of the
environment.
- full_state_spec (CompositeSpec): a composite spec such that ``full_state_spec.zero()``
+ full_state_spec (Composite): a composite spec such that ``full_state_spec.zero()``
returns a tensordict containing only the leaves encoding the inputs (actions
excluded) of the environment.
batch_size (torch.Size): The batch-size of the environment.
@@ -253,9 +269,9 @@ class EnvBase(nn.Module, metaclass=_EnvPostInit):
>>> from torchrl.envs import EnvBase
>>> class CounterEnv(EnvBase):
... def __init__(self, batch_size=(), device=None, **kwargs):
- ... self.observation_spec = CompositeSpec(
- ... count=UnboundedContinuousTensorSpec(batch_size, device=device, dtype=torch.int64))
- ... self.action_spec = UnboundedContinuousTensorSpec(batch_size, device=device, dtype=torch.int8)
+ ... self.observation_spec = Composite(
+ ... count=Unbounded(batch_size, device=device, dtype=torch.int64))
+ ... self.action_spec = Unbounded(batch_size, device=device, dtype=torch.int8)
... # done spec and reward spec are set automatically
... def _step(self, tensordict):
...
@@ -264,10 +280,10 @@ class EnvBase(nn.Module, metaclass=_EnvPostInit):
>>> env.batch_size # how many envs are run at once
torch.Size([])
>>> env.input_spec
- CompositeSpec(
+ Composite(
full_state_spec: None,
- full_action_spec: CompositeSpec(
- action: BoundedTensorSpec(
+ full_action_spec: Composite(
+ action: BoundedContinuous(
shape=torch.Size([1]),
space=ContinuousBox(
low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True),
@@ -276,7 +292,7 @@ class EnvBase(nn.Module, metaclass=_EnvPostInit):
dtype=torch.float32,
domain=continuous), device=cpu, shape=torch.Size([])), device=cpu, shape=torch.Size([]))
>>> env.action_spec
- BoundedTensorSpec(
+ BoundedContinuous(
shape=torch.Size([1]),
space=ContinuousBox(
low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True),
@@ -285,8 +301,8 @@ class EnvBase(nn.Module, metaclass=_EnvPostInit):
dtype=torch.float32,
domain=continuous)
>>> env.observation_spec
- CompositeSpec(
- observation: BoundedTensorSpec(
+ Composite(
+ observation: BoundedContinuous(
shape=torch.Size([3]),
space=ContinuousBox(
low=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True),
@@ -295,14 +311,14 @@ class EnvBase(nn.Module, metaclass=_EnvPostInit):
dtype=torch.float32,
domain=continuous), device=cpu, shape=torch.Size([]))
>>> env.reward_spec
- UnboundedContinuousTensorSpec(
+ UnboundedContinuous(
shape=torch.Size([1]),
space=None,
device=cpu,
dtype=torch.float32,
domain=continuous)
>>> env.done_spec
- DiscreteTensorSpec(
+ Categorical(
shape=torch.Size([1]),
space=DiscreteBox(n=2),
device=cpu,
@@ -310,16 +326,16 @@ class EnvBase(nn.Module, metaclass=_EnvPostInit):
domain=discrete)
>>> # the output_spec contains all the expected outputs
>>> env.output_spec
- CompositeSpec(
- full_reward_spec: CompositeSpec(
- reward: UnboundedContinuousTensorSpec(
+ Composite(
+ full_reward_spec: Composite(
+ reward: UnboundedContinuous(
shape=torch.Size([1]),
space=None,
device=cpu,
dtype=torch.float32,
domain=continuous), device=cpu, shape=torch.Size([])),
- full_observation_spec: CompositeSpec(
- observation: BoundedTensorSpec(
+ full_observation_spec: Composite(
+ observation: BoundedContinuous(
shape=torch.Size([3]),
space=ContinuousBox(
low=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True),
@@ -327,8 +343,8 @@ class EnvBase(nn.Module, metaclass=_EnvPostInit):
device=cpu,
dtype=torch.float32,
domain=continuous), device=cpu, shape=torch.Size([])),
- full_done_spec: CompositeSpec(
- done: DiscreteTensorSpec(
+ full_done_spec: Composite(
+ done: Categorical(
shape=torch.Size([1]),
space=DiscreteBox(n=2),
device=cpu,
@@ -544,10 +560,10 @@ def input_spec(self) -> TensorSpec:
>>> from torchrl.envs.libs.gym import GymEnv
>>> env = GymEnv("Pendulum-v1")
>>> env.input_spec
- CompositeSpec(
+ Composite(
full_state_spec: None,
- full_action_spec: CompositeSpec(
- action: BoundedTensorSpec(
+ full_action_spec: Composite(
+ action: BoundedContinuous(
shape=torch.Size([1]),
space=ContinuousBox(
low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True),
@@ -560,7 +576,7 @@ def input_spec(self) -> TensorSpec:
"""
input_spec = self.__dict__.get("_input_spec")
if input_spec is None:
- input_spec = CompositeSpec(
+ input_spec = Composite(
full_state_spec=None,
shape=self.batch_size,
device=self.device,
@@ -591,16 +607,16 @@ def output_spec(self) -> TensorSpec:
>>> from torchrl.envs.libs.gym import GymEnv
>>> env = GymEnv("Pendulum-v1")
>>> env.output_spec
- CompositeSpec(
- full_reward_spec: CompositeSpec(
- reward: UnboundedContinuousTensorSpec(
+ Composite(
+ full_reward_spec: Composite(
+ reward: UnboundedContinuous(
shape=torch.Size([1]),
space=None,
device=cpu,
dtype=torch.float32,
domain=continuous), device=cpu, shape=torch.Size([])),
- full_observation_spec: CompositeSpec(
- observation: BoundedTensorSpec(
+ full_observation_spec: Composite(
+ observation: BoundedContinuous(
shape=torch.Size([3]),
space=ContinuousBox(
low=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True),
@@ -608,8 +624,8 @@ def output_spec(self) -> TensorSpec:
device=cpu,
dtype=torch.float32,
domain=continuous), device=cpu, shape=torch.Size([])),
- full_done_spec: CompositeSpec(
- done: DiscreteTensorSpec(
+ full_done_spec: Composite(
+ done: Categorical(
shape=torch.Size([1]),
space=DiscreteBox(n=2),
device=cpu,
@@ -620,7 +636,7 @@ def output_spec(self) -> TensorSpec:
"""
output_spec = self.__dict__.get("_output_spec")
if output_spec is None:
- output_spec = CompositeSpec(
+ output_spec = Composite(
shape=self.batch_size,
device=self.device,
).lock_()
@@ -688,9 +704,9 @@ def action_spec(self) -> TensorSpec:
If the action spec is provided as a simple spec, this will be returned.
- >>> env.action_spec = UnboundedContinuousTensorSpec(1)
+ >>> env.action_spec = Unbounded(1)
>>> env.action_spec
- UnboundedContinuousTensorSpec(
+ UnboundedContinuous(
shape=torch.Size([1]),
space=ContinuousBox(
low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True),
@@ -702,9 +718,9 @@ def action_spec(self) -> TensorSpec:
If the action spec is provided as a composite spec and contains only one leaf,
this function will return just the leaf.
- >>> env.action_spec = CompositeSpec({"nested": {"action": UnboundedContinuousTensorSpec(1)}})
+ >>> env.action_spec = Composite({"nested": {"action": Unbounded(1)}})
>>> env.action_spec
- UnboundedContinuousTensorSpec(
+ UnboundedContinuous(
shape=torch.Size([1]),
space=ContinuousBox(
low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True),
@@ -716,11 +732,11 @@ def action_spec(self) -> TensorSpec:
If the action spec is provided as a composite spec and has more than one leaf,
this function will return the whole spec.
- >>> env.action_spec = CompositeSpec({"nested": {"action": UnboundedContinuousTensorSpec(1), "another_action": DiscreteTensorSpec(1)}})
+ >>> env.action_spec = Composite({"nested": {"action": Unbounded(1), "another_action": Categorical(1)}})
>>> env.action_spec
- CompositeSpec(
- nested: CompositeSpec(
- action: UnboundedContinuousTensorSpec(
+ Composite(
+ nested: Composite(
+ action: UnboundedContinuous(
shape=torch.Size([1]),
space=ContinuousBox(
low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True),
@@ -728,7 +744,7 @@ def action_spec(self) -> TensorSpec:
device=cpu,
dtype=torch.float32,
domain=continuous),
- another_action: DiscreteTensorSpec(
+ another_action: Categorical(
shape=torch.Size([]),
space=DiscreteBox(n=1),
device=cpu,
@@ -745,7 +761,7 @@ def action_spec(self) -> TensorSpec:
>>> from torchrl.envs.libs.gym import GymEnv
>>> env = GymEnv("Pendulum-v1")
>>> env.action_spec
- BoundedTensorSpec(
+ BoundedContinuous(
shape=torch.Size([1]),
space=ContinuousBox(
low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True),
@@ -794,16 +810,16 @@ def action_spec(self, value: TensorSpec) -> None:
f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size})."
)
- if isinstance(value, CompositeSpec):
+ if isinstance(value, Composite):
for _ in value.values(True, True): # noqa: B007
break
else:
raise RuntimeError(
- "An empty CompositeSpec was passed for the action spec. "
+ "An empty Composite was passed for the action spec. "
"This is currently not permitted."
)
else:
- value = CompositeSpec(
+ value = Composite(
action=value.to(device), shape=self.batch_size, device=device
)
@@ -812,10 +828,10 @@ def action_spec(self, value: TensorSpec) -> None:
self.input_spec.lock_()
@property
- def full_action_spec(self) -> CompositeSpec:
+ def full_action_spec(self) -> Composite:
"""The full action spec.
- ``full_action_spec`` is a :class:`~torchrl.data.CompositeSpec`` instance
+ ``full_action_spec`` is a :class:`~torchrl.data.Composite`` instance
that contains all the action entries.
Examples:
@@ -824,8 +840,8 @@ def full_action_spec(self) -> CompositeSpec:
... break
>>> env = BraxEnv(envname)
>>> env.full_action_spec
- CompositeSpec(
- action: BoundedTensorSpec(
+ Composite(
+ action: BoundedContinuous(
shape=torch.Size([8]),
space=ContinuousBox(
low=Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, contiguous=True),
@@ -835,10 +851,16 @@ def full_action_spec(self) -> CompositeSpec:
domain=continuous), device=cpu, shape=torch.Size([]))
"""
- return self.input_spec["full_action_spec"]
+ full_action_spec = self.input_spec.get("full_action_spec", None)
+ if full_action_spec is None:
+ full_action_spec = Composite(shape=self.batch_size, device=self.device)
+ self.input_spec.unlock_()
+ self.input_spec["full_action_spec"] = full_action_spec
+ self.input_spec.lock_()
+ return full_action_spec
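
`full_action_spec` now lazily materialises an empty `Composite` instead of raising a `KeyError` on first access. A dependency-free sketch of that lazy-default property pattern (class and names are hypothetical stand-ins):

```python
# Lazy-default property: build the container on first access, cache it,
# return the cached object afterwards.
class _SpecHolder:  # hypothetical stand-in for EnvBase
    def __init__(self):
        self._input_spec = {}

    @property
    def full_action_spec(self):
        spec = self._input_spec.get("full_action_spec")
        if spec is None:
            spec = {}  # stands in for Composite(shape=batch_size, device=device)
            self._input_spec["full_action_spec"] = spec
        return spec

holder = _SpecHolder()
assert holder.full_action_spec is holder.full_action_spec  # same cached object
```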
@full_action_spec.setter
- def full_action_spec(self, spec: CompositeSpec) -> None:
+ def full_action_spec(self, spec: Composite) -> None:
self.action_spec = spec
# Reward spec
@@ -881,9 +903,9 @@ def reward_spec(self) -> TensorSpec:
If the reward spec is provided as a simple spec, this will be returned.
- >>> env.reward_spec = UnboundedContinuousTensorSpec(1)
+ >>> env.reward_spec = Unbounded(1)
>>> env.reward_spec
- UnboundedContinuousTensorSpec(
+ UnboundedContinuous(
shape=torch.Size([1]),
space=ContinuousBox(
low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True),
@@ -895,9 +917,9 @@ def reward_spec(self) -> TensorSpec:
If the reward spec is provided as a composite spec and contains only one leaf,
this function will return just the leaf.
- >>> env.reward_spec = CompositeSpec({"nested": {"reward": UnboundedContinuousTensorSpec(1)}})
+ >>> env.reward_spec = Composite({"nested": {"reward": Unbounded(1)}})
>>> env.reward_spec
- UnboundedContinuousTensorSpec(
+ UnboundedContinuous(
shape=torch.Size([1]),
space=ContinuousBox(
low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True),
@@ -909,11 +931,11 @@ def reward_spec(self) -> TensorSpec:
If the reward spec is provided as a composite spec and has more than one leaf,
this function will return the whole spec.
- >>> env.reward_spec = CompositeSpec({"nested": {"reward": UnboundedContinuousTensorSpec(1), "another_reward": DiscreteTensorSpec(1)}})
+ >>> env.reward_spec = Composite({"nested": {"reward": Unbounded(1), "another_reward": Categorical(1)}})
>>> env.reward_spec
- CompositeSpec(
- nested: CompositeSpec(
- reward: UnboundedContinuousTensorSpec(
+ Composite(
+ nested: Composite(
+ reward: UnboundedContinuous(
shape=torch.Size([1]),
space=ContinuousBox(
low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True),
@@ -921,7 +943,7 @@ def reward_spec(self) -> TensorSpec:
device=cpu,
dtype=torch.float32,
domain=continuous),
- another_reward: DiscreteTensorSpec(
+ another_reward: Categorical(
shape=torch.Size([]),
space=DiscreteBox(n=1),
device=cpu,
@@ -938,7 +960,7 @@ def reward_spec(self) -> TensorSpec:
>>> from torchrl.envs.libs.gym import GymEnv
>>> env = GymEnv("Pendulum-v1")
>>> env.reward_spec
- UnboundedContinuousTensorSpec(
+ UnboundedContinuous(
shape=torch.Size([1]),
space=None,
device=cpu,
@@ -952,7 +974,7 @@ def reward_spec(self) -> TensorSpec:
# this will be raised if there is not full_reward_spec (unlikely) or no reward_key
# Since output_spec is lazily populated with an empty composite spec for
# reward_spec, the second case is much more likely to occur.
- self.reward_spec = UnboundedContinuousTensorSpec(
+ self.reward_spec = Unbounded(
shape=(*self.batch_size, 1),
device=self.device,
)
@@ -982,16 +1004,16 @@ def reward_spec(self, value: TensorSpec) -> None:
raise ValueError(
f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size})."
)
- if isinstance(value, CompositeSpec):
+ if isinstance(value, Composite):
for _ in value.values(True, True): # noqa: B007
break
else:
raise RuntimeError(
- "An empty CompositeSpec was passed for the reward spec. "
+ "An empty Composite was passed for the reward spec. "
"This is currently not permitted."
)
else:
- value = CompositeSpec(
+ value = Composite(
reward=value.to(device), shape=self.batch_size, device=device
)
for leaf in value.values(True, True):
@@ -1007,10 +1029,10 @@ def reward_spec(self, value: TensorSpec) -> None:
self.output_spec.lock_()
@property
- def full_reward_spec(self) -> CompositeSpec:
+ def full_reward_spec(self) -> Composite:
"""The full reward spec.
- ``full_reward_spec`` is a :class:`~torchrl.data.CompositeSpec`` instance
+ ``full_reward_spec`` is a :class:`~torchrl.data.Composite`` instance
that contains all the reward entries.
Examples:
@@ -1019,9 +1041,9 @@ def full_reward_spec(self) -> CompositeSpec:
>>> base_env = GymWrapper(gymnasium.make("Pendulum-v1"))
>>> env = TransformedEnv(base_env, RenameTransform("reward", ("nested", "reward")))
>>> env.full_reward_spec
- CompositeSpec(
- nested: CompositeSpec(
- reward: UnboundedContinuousTensorSpec(
+ Composite(
+ nested: Composite(
+ reward: UnboundedContinuous(
shape=torch.Size([1]),
space=ContinuousBox(
low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True),
@@ -1034,7 +1056,7 @@ def full_reward_spec(self) -> CompositeSpec:
return self.output_spec["full_reward_spec"]
@full_reward_spec.setter
- def full_reward_spec(self, spec: CompositeSpec) -> None:
+ def full_reward_spec(self, spec: Composite) -> None:
self.reward_spec = spec.to(self.device) if self.device is not None else spec
# done spec
@@ -1068,10 +1090,10 @@ def done_key(self):
return self.done_keys[0]
@property
- def full_done_spec(self) -> CompositeSpec:
+ def full_done_spec(self) -> Composite:
"""The full done spec.
- ``full_done_spec`` is a :class:`~torchrl.data.CompositeSpec`` instance
+ ``full_done_spec`` is a :class:`~torchrl.data.Composite`` instance
that contains all the done entries.
It can be used to generate fake data with a structure that mimics the
one obtained at runtime.
@@ -1081,14 +1103,14 @@ def full_done_spec(self) -> CompositeSpec:
>>> from torchrl.envs import GymWrapper
>>> env = GymWrapper(gymnasium.make("Pendulum-v1"))
>>> env.full_done_spec
- CompositeSpec(
- done: DiscreteTensorSpec(
+ Composite(
+ done: Categorical(
shape=torch.Size([1]),
space=DiscreteBox(n=2),
device=cpu,
dtype=torch.bool,
domain=discrete),
- truncated: DiscreteTensorSpec(
+ truncated: Categorical(
shape=torch.Size([1]),
space=DiscreteBox(n=2),
device=cpu,
@@ -1099,7 +1121,7 @@ def full_done_spec(self) -> CompositeSpec:
return self.output_spec["full_done_spec"]
@full_done_spec.setter
- def full_done_spec(self, spec: CompositeSpec) -> None:
+ def full_done_spec(self, spec: Composite) -> None:
self.done_spec = spec.to(self.device) if self.device is not None else spec
# Done spec: done specs belong to output_spec
@@ -1111,9 +1133,9 @@ def done_spec(self) -> TensorSpec:
If the done spec is provided as a simple spec, this will be returned.
- >>> env.done_spec = DiscreteTensorSpec(2, dtype=torch.bool)
+ >>> env.done_spec = Categorical(2, dtype=torch.bool)
>>> env.done_spec
- DiscreteTensorSpec(
+ Categorical(
shape=torch.Size([]),
space=DiscreteBox(n=2),
device=cpu,
@@ -1123,9 +1145,9 @@ def done_spec(self) -> TensorSpec:
If the done spec is provided as a composite spec and contains only one leaf,
this function will return just the leaf.
- >>> env.done_spec = CompositeSpec({"nested": {"done": DiscreteTensorSpec(2, dtype=torch.bool)}})
+ >>> env.done_spec = Composite({"nested": {"done": Categorical(2, dtype=torch.bool)}})
>>> env.done_spec
- DiscreteTensorSpec(
+ Categorical(
shape=torch.Size([]),
space=DiscreteBox(n=2),
device=cpu,
@@ -1135,17 +1157,17 @@ def done_spec(self) -> TensorSpec:
If the done spec is provided as a composite spec and has more than one leaf,
this function will return the whole spec.
- >>> env.done_spec = CompositeSpec({"nested": {"done": DiscreteTensorSpec(2, dtype=torch.bool), "another_done": DiscreteTensorSpec(2, dtype=torch.bool)}})
+ >>> env.done_spec = Composite({"nested": {"done": Categorical(2, dtype=torch.bool), "another_done": Categorical(2, dtype=torch.bool)}})
>>> env.done_spec
- CompositeSpec(
- nested: CompositeSpec(
- done: DiscreteTensorSpec(
+ Composite(
+ nested: Composite(
+ done: Categorical(
shape=torch.Size([]),
space=DiscreteBox(n=2),
device=cpu,
dtype=torch.bool,
domain=discrete),
- another_done: DiscreteTensorSpec(
+ another_done: Categorical(
shape=torch.Size([]),
space=DiscreteBox(n=2),
device=cpu,
@@ -1162,7 +1184,7 @@ def done_spec(self) -> TensorSpec:
>>> from torchrl.envs.libs.gym import GymEnv
>>> env = GymEnv("Pendulum-v1")
>>> env.done_spec
- DiscreteTensorSpec(
+ Categorical(
shape=torch.Size([1]),
space=DiscreteBox(n=2),
device=cpu,
@@ -1185,16 +1207,16 @@ def _create_done_specs(self):
try:
full_done_spec = self.output_spec["full_done_spec"]
except KeyError:
- full_done_spec = CompositeSpec(
+ full_done_spec = Composite(
shape=self.output_spec.shape, device=self.output_spec.device
)
- full_done_spec["done"] = DiscreteTensorSpec(
+ full_done_spec["done"] = Categorical(
n=2,
shape=(*full_done_spec.shape, 1),
dtype=torch.bool,
device=self.device,
)
- full_done_spec["terminated"] = DiscreteTensorSpec(
+ full_done_spec["terminated"] = Categorical(
n=2,
shape=(*full_done_spec.shape, 1),
dtype=torch.bool,
@@ -1215,7 +1237,7 @@ def check_local_done(spec):
spec["terminated"] = item.clone()
elif key == "terminated" and "done" not in spec.keys():
spec["done"] = item.clone()
- elif isinstance(item, CompositeSpec):
+ elif isinstance(item, Composite):
check_local_done(item)
else:
if shape is None:
@@ -1229,10 +1251,10 @@ def check_local_done(spec):
# if the spec is empty, we need to add a done and terminated manually
if spec.is_empty():
- spec["done"] = DiscreteTensorSpec(
+ spec["done"] = Categorical(
n=2, shape=(*spec.shape, 1), dtype=torch.bool, device=self.device
)
- spec["terminated"] = DiscreteTensorSpec(
+ spec["terminated"] = Categorical(
n=2, shape=(*spec.shape, 1), dtype=torch.bool, device=self.device
)
@@ -1260,16 +1282,16 @@ def done_spec(self, value: TensorSpec) -> None:
raise ValueError(
f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size})."
)
- if isinstance(value, CompositeSpec):
+ if isinstance(value, Composite):
for _ in value.values(True, True): # noqa: B007
break
else:
raise RuntimeError(
- "An empty CompositeSpec was passed for the done spec. "
+ "An empty Composite was passed for the done spec. "
"This is currently not permitted."
)
else:
- value = CompositeSpec(
+ value = Composite(
done=value.to(device),
terminated=value.to(device),
shape=self.batch_size,
@@ -1290,10 +1312,10 @@ def done_spec(self, value: TensorSpec) -> None:
# observation spec: observation specs belong to output_spec
@property
- def observation_spec(self) -> CompositeSpec:
+ def observation_spec(self) -> Composite:
"""Observation spec.
- Must be a :class:`torchrl.data.CompositeSpec` instance.
+ Must be a :class:`torchrl.data.Composite` instance.
The keys listed in the spec are directly accessible after reset and step.
In TorchRL, even though they are not properly speaking "observations"
@@ -1307,8 +1329,8 @@ def observation_spec(self) -> CompositeSpec:
>>> from torchrl.envs.libs.gym import GymEnv
>>> env = GymEnv("Pendulum-v1")
>>> env.observation_spec
- CompositeSpec(
- observation: BoundedTensorSpec(
+ Composite(
+ observation: BoundedContinuous(
shape=torch.Size([3]),
space=ContinuousBox(
low=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True),
@@ -1318,9 +1340,9 @@ def observation_spec(self) -> CompositeSpec:
domain=continuous), device=cpu, shape=torch.Size([]))
"""
- observation_spec = self.output_spec["full_observation_spec"]
+ observation_spec = self.output_spec.get("full_observation_spec", default=None)
if observation_spec is None:
- observation_spec = CompositeSpec(shape=self.batch_size, device=self.device)
+ observation_spec = Composite(shape=self.batch_size, device=self.device)
self.output_spec.unlock_()
self.output_spec["full_observation_spec"] = observation_spec
self.output_spec.lock_()
@@ -1330,7 +1352,7 @@ def observation_spec(self) -> CompositeSpec:
def observation_spec(self, value: TensorSpec) -> None:
try:
self.output_spec.unlock_()
- if not isinstance(value, CompositeSpec):
+ if not isinstance(value, Composite):
raise TypeError("The type of an observation_spec must be Composite.")
elif value.shape[: len(self.batch_size)] != self.batch_size:
raise ValueError(
@@ -1348,19 +1370,19 @@ def observation_spec(self, value: TensorSpec) -> None:
self.output_spec.lock_()
@property
- def full_observation_spec(self) -> CompositeSpec:
+ def full_observation_spec(self) -> Composite:
return self.observation_spec
@full_observation_spec.setter
- def full_observation_spec(self, spec: CompositeSpec):
+ def full_observation_spec(self, spec: Composite):
self.observation_spec = spec
# state spec: state specs belong to input_spec
@property
- def state_spec(self) -> CompositeSpec:
+ def state_spec(self) -> Composite:
"""State spec.
- Must be a :class:`torchrl.data.CompositeSpec` instance.
+ Must be a :class:`torchrl.data.Composite` instance.
The keys listed here should be provided as input alongside actions to the environment.
In TorchRL, even though they are not properly speaking "state"
@@ -1376,10 +1398,10 @@ def state_spec(self) -> CompositeSpec:
... break
>>> env = BraxEnv(envname)
>>> env.state_spec
- CompositeSpec(
- state: CompositeSpec(
- pipeline_state: CompositeSpec(
- q: UnboundedContinuousTensorSpec(
+ Composite(
+ state: Composite(
+ pipeline_state: Composite(
+ q: UnboundedContinuous(
shape=torch.Size([15]),
space=None,
device=cpu,
@@ -1391,14 +1413,14 @@ def state_spec(self) -> CompositeSpec:
"""
state_spec = self.input_spec["full_state_spec"]
if state_spec is None:
- state_spec = CompositeSpec(shape=self.batch_size, device=self.device)
+ state_spec = Composite(shape=self.batch_size, device=self.device)
self.input_spec.unlock_()
self.input_spec["full_state_spec"] = state_spec
self.input_spec.lock_()
return state_spec
@state_spec.setter
- def state_spec(self, value: CompositeSpec) -> None:
+ def state_spec(self, value: Composite) -> None:
try:
self.input_spec.unlock_()
try:
@@ -1406,12 +1428,12 @@ def state_spec(self, value: CompositeSpec) -> None:
except AttributeError:
pass
if value is None:
- self.input_spec["full_state_spec"] = CompositeSpec(
+ self.input_spec["full_state_spec"] = Composite(
device=self.device, shape=self.batch_size
)
else:
device = self.input_spec.device
- if not isinstance(value, CompositeSpec):
+ if not isinstance(value, Composite):
raise TypeError("The type of an state_spec must be Composite.")
elif value.shape[: len(self.batch_size)] != self.batch_size:
raise ValueError(
@@ -1428,10 +1450,10 @@ def state_spec(self, value: CompositeSpec) -> None:
self.input_spec.lock_()
@property
- def full_state_spec(self) -> CompositeSpec:
+ def full_state_spec(self) -> Composite:
"""The full state spec.
- ``full_state_spec`` is a :class:`~torchrl.data.CompositeSpec`` instance
+ ``full_state_spec`` is a :class:`~torchrl.data.Composite`` instance
that contains all the state entries (ie, the input data that is not action).
Examples:
@@ -1440,10 +1462,10 @@ def full_state_spec(self) -> CompositeSpec:
... break
>>> env = BraxEnv(envname)
>>> env.full_state_spec
- CompositeSpec(
- state: CompositeSpec(
- pipeline_state: CompositeSpec(
- q: UnboundedContinuousTensorSpec(
+ Composite(
+ state: Composite(
+ pipeline_state: Composite(
+ q: UnboundedContinuous(
shape=torch.Size([15]),
space=None,
device=cpu,
@@ -1455,7 +1477,7 @@ def full_state_spec(self) -> CompositeSpec:
return self.state_spec
@full_state_spec.setter
- def full_state_spec(self, spec: CompositeSpec) -> None:
+ def full_state_spec(self, spec: Composite) -> None:
self.state_spec = spec
def step(self, tensordict: TensorDictBase) -> TensorDictBase:
@@ -1494,7 +1516,7 @@ def step(self, tensordict: TensorDictBase) -> TensorDictBase:
@classmethod
def _complete_done(
- cls, done_spec: CompositeSpec, data: TensorDictBase
+ cls, done_spec: Composite, data: TensorDictBase
) -> TensorDictBase:
"""Completes the data structure at step time to put missing done keys."""
# by default, if a done key is missing, it is assumed that it is False
@@ -1508,7 +1530,7 @@ def _complete_done(
i = -1
for i, (key, item) in enumerate(done_spec.items()): # noqa: B007
val = data.get(key, None)
- if isinstance(item, CompositeSpec):
+ if isinstance(item, Composite):
if val is not None:
cls._complete_done(item, val)
continue
@@ -2141,9 +2163,9 @@ def reset(
self._assert_tensordict_shape(tensordict)
tensordict_reset = self._reset(tensordict, **kwargs)
- # We assume that this is done properly
- # if reset.device != self.device:
- # reset = reset.to(self.device, non_blocking=True)
+ # We assume that this is done properly
+ # if reset.device != self.device:
+ # reset = reset.to(self.device, non_blocking=True)
if tensordict_reset is tensordict:
raise RuntimeError(
"EnvBase._reset should return outplace changes to the input "
@@ -2300,14 +2322,14 @@ def rand_step(self, tensordict: Optional[TensorDictBase] = None) -> TensorDictBa
return self.step(tensordict)
@property
- def specs(self) -> CompositeSpec:
+ def specs(self) -> Composite:
"""Returns a Composite container where all the environment are present.
This feature allows one to create an environment, retrieve all of the specs in a single data container and then
erase the environment from the workspace.
"""
- return CompositeSpec(
+ return Composite(
output_spec=self.output_spec,
input_spec=self.input_spec,
shape=self.batch_size,
@@ -2322,13 +2344,16 @@ def rollout(
max_steps: int,
policy: Optional[Callable[[TensorDictBase], TensorDictBase]] = None,
callback: Optional[Callable[[TensorDictBase, ...], Any]] = None,
+ *,
auto_reset: bool = True,
auto_cast_to_device: bool = False,
- break_when_any_done: bool = True,
+ break_when_any_done: bool | None = None,
+ break_when_all_done: bool | None = None,
return_contiguous: bool = True,
tensordict: Optional[TensorDictBase] = None,
set_truncated: bool = False,
out=None,
+ trust_policy: bool = False,
):
"""Executes a rollout in the environment.
@@ -2347,6 +2372,8 @@ def rollout(
TensorDict. Defaults to ``None``. The output of ``callback`` will not be collected, it is the user
responsibility to save any result within the callback call if data needs to be carried over beyond
the call to ``rollout``.
+
+ Keyword Args:
auto_reset (bool, optional): if ``True``, resets automatically the environment
if it is in a done state when the rollout is initiated.
Default is ``True``.
@@ -2354,6 +2381,7 @@ def rollout(
policy device before the policy is used. Default is ``False``.
break_when_any_done (bool): breaks if any of the done state is True. If False, a reset() is
called on the sub-envs that are done. Default is True.
+ break_when_all_done (bool): breaks if all of the done states are True. If True,
+ sub-envs that are done are not stepped further until every env is done. Cannot
+ be True when break_when_any_done is also True. Default is False.
return_contiguous (bool): if False, a LazyStackedTensorDict will be returned. Default is True.
tensordict (TensorDict, optional): if ``auto_reset`` is False, an initial
tensordict must be provided. Rollout will check if this tensordict has done flags and reset the
@@ -2367,6 +2395,9 @@ def rollout(
``done_spec``, an exception is raised.
Truncated keys can be set through ``env.add_truncated_keys``.
Defaults to ``False``.
+ trust_policy (bool, optional): if ``True``, a non-TensorDictModule policy will be
+ assumed to be compatible with the collector. This defaults to ``True`` for
+ CudaGraphModules and ``False`` otherwise.
Returns:
TensorDict object containing the resulting trajectory.
@@ -2550,9 +2581,26 @@ def rollout(
... )
"""
+ if break_when_any_done is None: # True by default
+ if break_when_all_done: # all overrides
+ break_when_any_done = False
+ else:
+ break_when_any_done = True
+ if break_when_all_done is None:
+ # There is no case where break_when_all_done is True by default
+ break_when_all_done = False
+ if break_when_all_done and break_when_any_done:
+ raise TypeError(
+ "Cannot have both break_when_all_done and break_when_any_done True at the same time."
+ )
+
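
The flag resolution above gives `break_when_any_done` a default of `True` unless `break_when_all_done` overrides it, and forbids enabling both. A standalone sketch of the same decision table (function name is hypothetical):

```python
# Standalone version of the flag resolution above.
def resolve(break_when_any_done=None, break_when_all_done=None):
    if break_when_any_done is None:       # True by default, unless "all" overrides
        break_when_any_done = not break_when_all_done
    if break_when_all_done is None:       # never True by default
        break_when_all_done = False
    if break_when_all_done and break_when_any_done:
        raise TypeError("Cannot have both True at the same time.")
    return break_when_any_done, break_when_all_done

assert resolve() == (True, False)
assert resolve(break_when_all_done=True) == (False, True)
```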
if policy is not None:
policy = _make_compatible_policy(
- policy, self.observation_spec, env=self, fast_wrap=True
+ policy,
+ self.observation_spec,
+ env=self,
+ fast_wrap=True,
+ trust_policy=trust_policy,
)
if auto_cast_to_device:
try:
@@ -2583,8 +2631,12 @@ def rollout(
"env_device": env_device,
"callback": callback,
}
- if break_when_any_done:
- tensordicts = self._rollout_stop_early(**kwargs)
+ if break_when_any_done or break_when_all_done:
+ tensordicts = self._rollout_stop_early(
+ break_when_all_done=break_when_all_done,
+ break_when_any_done=break_when_any_done,
+ **kwargs,
+ )
else:
tensordicts = self._rollout_nonstop(**kwargs)
batch_size = self.batch_size if tensordict is None else tensordict.batch_size
@@ -2644,6 +2696,8 @@ def _step_mdp(self):
def _rollout_stop_early(
self,
*,
+ break_when_any_done,
+ break_when_all_done,
tensordict,
auto_cast_to_device,
max_steps,
@@ -2656,6 +2710,7 @@ def _rollout_stop_early(
if auto_cast_to_device:
sync_func = _get_sync_func(policy_device, env_device)
tensordicts = []
+ partial_steps = True
for i in range(max_steps):
if auto_cast_to_device:
if policy_device is not None:
@@ -2673,6 +2728,14 @@ def _rollout_stop_early(
tensordict.clear_device_()
tensordict = self.step(tensordict)
td_append = tensordict.copy()
+ if break_when_all_done:
+ if partial_steps is not True:
+ # At least one partial step has been done
+ del td_append["_partial_steps"]
+ td_append = torch.where(
+ partial_steps.view(td_append.shape), td_append, tensordicts[-1]
+ )
+
tensordicts.append(td_append)
if i == max_steps - 1:
@@ -2680,16 +2743,31 @@ def _rollout_stop_early(
break
tensordict = self._step_mdp(tensordict)
- # done and truncated are in done_keys
- # We read if any key is done.
- any_done = _terminated_or_truncated(
- tensordict,
- full_done_spec=self.output_spec["full_done_spec"],
- key=None,
- )
-
- if any_done:
- break
+ if break_when_any_done:
+ # done and truncated are in done_keys
+ # We read if any key is done.
+ any_done = _terminated_or_truncated(
+ tensordict,
+ full_done_spec=self.output_spec["full_done_spec"],
+ key=None,
+ )
+ if any_done:
+ break
+ else:
+ _terminated_or_truncated(
+ tensordict,
+ full_done_spec=self.output_spec["full_done_spec"],
+ key="_partial_steps",
+ write_full_false=False,
+ )
+ partial_step_curr = tensordict.get("_partial_steps", None)
+ if partial_step_curr is not None:
+ partial_step_curr = ~partial_step_curr
+ partial_steps = partial_steps & partial_step_curr
+ if partial_steps is not True:
+ if not partial_steps.any():
+ break
+ tensordict.set("_partial_steps", partial_steps)
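
In the `break_when_all_done` branch, a running mask of still-active envs is intersected with the negation of each step's done flags; the rollout only breaks once no env remains active. A sketch with plain tensors:

```python
import torch

# Running "still active" mask, as maintained by the loop above.
partial_steps = torch.tensor([True, True, True])   # all envs active at start
done_this_step = torch.tensor([False, True, False])

partial_steps = partial_steps & ~done_this_step    # env 1 drops out
print(partial_steps)                               # tensor([ True, False,  True])
print(bool(partial_steps.any()))                   # True -> keep stepping
```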
if callback is not None:
callback(self, tensordict)
@@ -3044,6 +3122,7 @@ def __init__(
self._constructor_kwargs = kwargs
self._check_kwargs(kwargs)
+ self._convert_actions_to_numpy = kwargs.pop("convert_actions_to_numpy", True)
self._env = self._build_env(**kwargs) # writes the self._env attribute
self._make_specs(self._env) # writes the self._env attribute
self.is_closed = False
@@ -3169,7 +3248,7 @@ def _do_nothing():
return
-def _has_dynamic_specs(spec: CompositeSpec):
+def _has_dynamic_specs(spec: Composite):
from tensordict.base import _NESTED_TENSORS_AS_LISTS
return any(
diff --git a/torchrl/envs/custom/pendulum.py b/torchrl/envs/custom/pendulum.py
index 8253e3df9b7..e2007227127 100644
--- a/torchrl/envs/custom/pendulum.py
+++ b/torchrl/envs/custom/pendulum.py
@@ -6,11 +6,7 @@
import torch
from tensordict import TensorDict, TensorDictBase
-from torchrl.data.tensor_specs import (
- BoundedTensorSpec,
- CompositeSpec,
- UnboundedContinuousTensorSpec,
-)
+from torchrl.data.tensor_specs import Bounded, Composite, Unbounded
from torchrl.envs.common import EnvBase
from torchrl.envs.utils import make_composite_from_td
@@ -21,125 +17,193 @@ class PendulumEnv(EnvBase):
See the Pendulum tutorial for more details: :ref:`tutorial `.
Specs:
- CompositeSpec(
- output_spec: CompositeSpec(
- full_observation_spec: CompositeSpec(
- th: BoundedTensorSpec(
+ >>> env = PendulumEnv()
+ >>> env.specs
+ Composite(
+ output_spec: Composite(
+ full_observation_spec: Composite(
+ th: BoundedContinuous(
shape=torch.Size([]),
space=ContinuousBox(
- low=Tensor(shape=torch.Size([]), dtype=torch.float32, contiguous=True),
- high=Tensor(shape=torch.Size([]), dtype=torch.float32, contiguous=True)),
+ low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True),
+ high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)),
+ device=cpu,
dtype=torch.float32,
domain=continuous),
- thdot: BoundedTensorSpec(
+ thdot: BoundedContinuous(
shape=torch.Size([]),
space=ContinuousBox(
- low=Tensor(shape=torch.Size([]), dtype=torch.float32, contiguous=True),
- high=Tensor(shape=torch.Size([]), dtype=torch.float32, contiguous=True)),
+ low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True),
+ high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)),
+ device=cpu,
dtype=torch.float32,
domain=continuous),
- params: CompositeSpec(
- max_speed: UnboundedContinuousTensorSpec(
+ params: Composite(
+ max_speed: UnboundedDiscrete(
shape=torch.Size([]),
+ space=ContinuousBox(
+ low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, contiguous=True),
+ high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, contiguous=True)),
+ device=cpu,
dtype=torch.int64,
domain=discrete),
- max_torque: UnboundedContinuousTensorSpec(
+ max_torque: UnboundedContinuous(
shape=torch.Size([]),
+ space=ContinuousBox(
+ low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True),
+ high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)),
+ device=cpu,
dtype=torch.float32,
domain=continuous),
- dt: UnboundedContinuousTensorSpec(
+ dt: UnboundedContinuous(
shape=torch.Size([]),
+ space=ContinuousBox(
+ low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True),
+ high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)),
+ device=cpu,
dtype=torch.float32,
domain=continuous),
- g: UnboundedContinuousTensorSpec(
+ g: UnboundedContinuous(
shape=torch.Size([]),
+ space=ContinuousBox(
+ low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True),
+ high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)),
+ device=cpu,
dtype=torch.float32,
domain=continuous),
- m: UnboundedContinuousTensorSpec(
+ m: UnboundedContinuous(
shape=torch.Size([]),
+ space=ContinuousBox(
+ low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True),
+ high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)),
+ device=cpu,
dtype=torch.float32,
domain=continuous),
- l: UnboundedContinuousTensorSpec(
+ l: UnboundedContinuous(
shape=torch.Size([]),
+ space=ContinuousBox(
+ low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True),
+ high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)),
+ device=cpu,
dtype=torch.float32,
domain=continuous),
+ device=None,
shape=torch.Size([])),
+ device=None,
shape=torch.Size([])),
- full_reward_spec: CompositeSpec(
- reward: UnboundedContinuousTensorSpec(
+ full_reward_spec: Composite(
+ reward: UnboundedContinuous(
shape=torch.Size([1]),
space=ContinuousBox(
- low=Tensor(shape=torch.Size([1]), dtype=torch.float32, contiguous=True),
- high=Tensor(shape=torch.Size([1]), dtype=torch.float32, contiguous=True)),
+ low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True),
+ high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True)),
+ device=cpu,
dtype=torch.float32,
domain=continuous),
+ device=None,
shape=torch.Size([])),
- full_done_spec: CompositeSpec(
- done: DiscreteTensorSpec(
+ full_done_spec: Composite(
+ done: Categorical(
shape=torch.Size([1]),
- space=DiscreteBox(n=2),
+ space=CategoricalBox(n=2),
+ device=cpu,
dtype=torch.bool,
domain=discrete),
- terminated: DiscreteTensorSpec(
+ terminated: Categorical(
shape=torch.Size([1]),
- space=DiscreteBox(n=2),
+ space=CategoricalBox(n=2),
+ device=cpu,
dtype=torch.bool,
domain=discrete),
+ device=None,
shape=torch.Size([])),
+ device=None,
shape=torch.Size([])),
- input_spec: CompositeSpec(
- full_state_spec: CompositeSpec(
- th: BoundedTensorSpec(
+ input_spec: Composite(
+ full_state_spec: Composite(
+ th: BoundedContinuous(
shape=torch.Size([]),
space=ContinuousBox(
- low=Tensor(shape=torch.Size([]), dtype=torch.float32, contiguous=True),
- high=Tensor(shape=torch.Size([]), dtype=torch.float32, contiguous=True)),
+ low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True),
+ high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)),
+ device=cpu,
dtype=torch.float32,
domain=continuous),
- thdot: BoundedTensorSpec(
+ thdot: BoundedContinuous(
shape=torch.Size([]),
space=ContinuousBox(
- low=Tensor(shape=torch.Size([]), dtype=torch.float32, contiguous=True),
- high=Tensor(shape=torch.Size([]), dtype=torch.float32, contiguous=True)),
+ low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True),
+ high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)),
+ device=cpu,
dtype=torch.float32,
domain=continuous),
- params: CompositeSpec(
- max_speed: UnboundedContinuousTensorSpec(
+ params: Composite(
+ max_speed: UnboundedDiscrete(
shape=torch.Size([]),
+ space=ContinuousBox(
+ low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, contiguous=True),
+ high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, contiguous=True)),
+ device=cpu,
dtype=torch.int64,
domain=discrete),
- max_torque: UnboundedContinuousTensorSpec(
+ max_torque: UnboundedContinuous(
shape=torch.Size([]),
+ space=ContinuousBox(
+ low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True),
+ high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)),
+ device=cpu,
dtype=torch.float32,
domain=continuous),
- dt: UnboundedContinuousTensorSpec(
+ dt: UnboundedContinuous(
shape=torch.Size([]),
+ space=ContinuousBox(
+ low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True),
+ high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)),
+ device=cpu,
dtype=torch.float32,
domain=continuous),
- g: UnboundedContinuousTensorSpec(
+ g: UnboundedContinuous(
shape=torch.Size([]),
+ space=ContinuousBox(
+ low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True),
+ high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)),
+ device=cpu,
dtype=torch.float32,
domain=continuous),
- m: UnboundedContinuousTensorSpec(
+ m: UnboundedContinuous(
shape=torch.Size([]),
+ space=ContinuousBox(
+ low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True),
+ high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)),
+ device=cpu,
dtype=torch.float32,
domain=continuous),
- l: UnboundedContinuousTensorSpec(
+ l: UnboundedContinuous(
shape=torch.Size([]),
+ space=ContinuousBox(
+ low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True),
+ high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)),
+ device=cpu,
dtype=torch.float32,
domain=continuous),
+ device=None,
shape=torch.Size([])),
+ device=None,
shape=torch.Size([])),
- full_action_spec: CompositeSpec(
- action: BoundedTensorSpec(
+ full_action_spec: Composite(
+ action: BoundedContinuous(
shape=torch.Size([1]),
space=ContinuousBox(
- low=Tensor(shape=torch.Size([1]), dtype=torch.float32, contiguous=True),
- high=Tensor(shape=torch.Size([1]), dtype=torch.float32, contiguous=True)),
+ low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True),
+ high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True)),
+ device=cpu,
dtype=torch.float32,
domain=continuous),
+ device=None,
shape=torch.Size([])),
+ device=None,
shape=torch.Size([])),
+ device=None,
shape=torch.Size([]))
"""
@@ -152,6 +216,7 @@ class PendulumEnv(EnvBase):
"render_fps": 30,
}
batch_locked = False
+ rng = None
def __init__(self, td_params=None, seed=None, device=None):
if td_params is None:
@@ -160,7 +225,7 @@ def __init__(self, td_params=None, seed=None, device=None):
super().__init__(device=device)
self._make_spec(td_params)
if seed is None:
- seed = torch.empty((), dtype=torch.int64).random_().item()
+ seed = torch.empty((), dtype=torch.int64).random_(generator=self.rng).item()
self.set_seed(seed)
@classmethod
@@ -240,14 +305,14 @@ def _reset(self, tensordict):
def _make_spec(self, td_params):
# Under the hood, this will populate self.output_spec["observation"]
- self.observation_spec = CompositeSpec(
- th=BoundedTensorSpec(
+ self.observation_spec = Composite(
+ th=Bounded(
low=-torch.pi,
high=torch.pi,
shape=(),
dtype=torch.float32,
),
- thdot=BoundedTensorSpec(
+ thdot=Bounded(
low=-td_params["params", "max_speed"],
high=td_params["params", "max_speed"],
shape=(),
@@ -265,22 +330,22 @@ def _make_spec(self, td_params):
self.state_spec = self.observation_spec.clone()
# action-spec will be automatically wrapped in input_spec when
# `self.action_spec = spec` will be called supported
- self.action_spec = BoundedTensorSpec(
+ self.action_spec = Bounded(
low=-td_params["params", "max_torque"],
high=td_params["params", "max_torque"],
shape=(1,),
dtype=torch.float32,
)
- self.reward_spec = UnboundedContinuousTensorSpec(shape=(*td_params.shape, 1))
+ self.reward_spec = Unbounded(shape=(*td_params.shape, 1))
def make_composite_from_td(td):
# custom function to convert a ``tensordict`` in a similar spec structure
# of unbounded values.
- composite = CompositeSpec(
+ composite = Composite(
{
key: make_composite_from_td(tensor)
if isinstance(tensor, TensorDictBase)
- else UnboundedContinuousTensorSpec(
+ else Unbounded(
dtype=tensor.dtype, device=tensor.device, shape=tensor.shape
)
for key, tensor in td.items()
@@ -290,7 +355,8 @@ def make_composite_from_td(td):
return composite
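
`make_composite_from_td` recurses through a tensordict, mapping each leaf to an `Unbounded` spec of matching dtype/shape and each sub-tensordict to a nested composite. The same recursion on plain dicts, for illustration:

```python
import torch

# Plain-dict sketch of the recursion above: leaves become (dtype, shape)
# descriptions, sub-dicts become nested mappings.
def describe(tree):
    return {
        key: describe(val) if isinstance(val, dict) else (val.dtype, tuple(val.shape))
        for key, val in tree.items()
    }

td_like = {"params": {"g": torch.tensor(9.81)}, "th": torch.zeros(())}
print(describe(td_like))
# {'params': {'g': (torch.float32, ())}, 'th': (torch.float32, ())}
```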
def _set_seed(self, seed: int):
- rng = torch.manual_seed(seed)
+ rng = torch.Generator()
+ rng.manual_seed(seed)
self.rng = rng
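
Switching from `torch.manual_seed` to a dedicated `torch.Generator` keeps the env's randomness isolated from the global RNG state. A quick sketch of the reproducibility this buys:

```python
import torch

# A dedicated generator is reproducible and leaves the global RNG untouched.
rng = torch.Generator()
rng.manual_seed(0)
a = torch.rand(1, generator=rng)

rng.manual_seed(0)
b = torch.rand(1, generator=rng)
assert torch.equal(a, b)  # same seed, same draw, no global side effects
```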
@staticmethod
diff --git a/torchrl/envs/custom/tictactoeenv.py b/torchrl/envs/custom/tictactoeenv.py
index 79ea3b2dfb6..2c93a5748ef 100644
--- a/torchrl/envs/custom/tictactoeenv.py
+++ b/torchrl/envs/custom/tictactoeenv.py
@@ -9,12 +9,7 @@
import torch
from tensordict import TensorDict, TensorDictBase
-from torchrl.data.tensor_specs import (
- CompositeSpec,
- DiscreteTensorSpec,
- UnboundedContinuousTensorSpec,
- UnboundedDiscreteTensorSpec,
-)
+from torchrl.data.tensor_specs import Categorical, Composite, Unbounded
from torchrl.envs.common import EnvBase
@@ -39,28 +34,28 @@ class TicTacToeEnv(EnvBase):
output entry).
Specs:
- CompositeSpec(
- output_spec: CompositeSpec(
- full_observation_spec: CompositeSpec(
- board: DiscreteTensorSpec(
+ Composite(
+ output_spec: Composite(
+ full_observation_spec: Composite(
+ board: Categorical(
shape=torch.Size([3, 3]),
space=DiscreteBox(n=2),
dtype=torch.int32,
domain=discrete),
- turn: DiscreteTensorSpec(
+ turn: Categorical(
shape=torch.Size([1]),
space=DiscreteBox(n=2),
dtype=torch.int32,
domain=discrete),
- mask: DiscreteTensorSpec(
+ mask: Categorical(
shape=torch.Size([9]),
space=DiscreteBox(n=2),
dtype=torch.bool,
domain=discrete),
shape=torch.Size([])),
- full_reward_spec: CompositeSpec(
- player0: CompositeSpec(
- reward: UnboundedContinuousTensorSpec(
+ full_reward_spec: Composite(
+ player0: Composite(
+ reward: UnboundedContinuous(
shape=torch.Size([1]),
space=ContinuousBox(
low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True),
@@ -68,8 +63,8 @@ class TicTacToeEnv(EnvBase):
dtype=torch.float32,
domain=continuous),
shape=torch.Size([])),
- player1: CompositeSpec(
- reward: UnboundedContinuousTensorSpec(
+ player1: Composite(
+ reward: UnboundedContinuous(
shape=torch.Size([1]),
space=ContinuousBox(
low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True),
@@ -78,43 +73,43 @@ class TicTacToeEnv(EnvBase):
domain=continuous),
shape=torch.Size([])),
shape=torch.Size([])),
- full_done_spec: CompositeSpec(
- done: DiscreteTensorSpec(
+ full_done_spec: Composite(
+ done: Categorical(
shape=torch.Size([1]),
space=DiscreteBox(n=2),
dtype=torch.bool,
domain=discrete),
- terminated: DiscreteTensorSpec(
+ terminated: Categorical(
shape=torch.Size([1]),
space=DiscreteBox(n=2),
dtype=torch.bool,
domain=discrete),
- truncated: DiscreteTensorSpec(
+ truncated: Categorical(
shape=torch.Size([1]),
space=DiscreteBox(n=2),
dtype=torch.bool,
domain=discrete),
shape=torch.Size([])),
shape=torch.Size([])),
- input_spec: CompositeSpec(
- full_state_spec: CompositeSpec(
- board: DiscreteTensorSpec(
+ input_spec: Composite(
+ full_state_spec: Composite(
+ board: Categorical(
shape=torch.Size([3, 3]),
space=DiscreteBox(n=2),
dtype=torch.int32,
domain=discrete),
- turn: DiscreteTensorSpec(
+ turn: Categorical(
shape=torch.Size([1]),
space=DiscreteBox(n=2),
dtype=torch.int32,
domain=discrete),
- mask: DiscreteTensorSpec(
+ mask: Categorical(
shape=torch.Size([9]),
space=DiscreteBox(n=2),
dtype=torch.bool,
domain=discrete), shape=torch.Size([])),
- full_action_spec: CompositeSpec(
- action: DiscreteTensorSpec(
+ full_action_spec: Composite(
+ action: Categorical(
shape=torch.Size([1]),
space=DiscreteBox(n=9),
dtype=torch.int64,
@@ -172,23 +167,21 @@ class TicTacToeEnv(EnvBase):
def __init__(self, *, single_player: bool = False, device=None):
super().__init__(device=device)
self.single_player = single_player
- self.action_spec: UnboundedDiscreteTensorSpec = DiscreteTensorSpec(
+ self.action_spec: Categorical = Categorical(
n=9,
shape=(),
device=device,
)
- self.full_observation_spec: CompositeSpec = CompositeSpec(
- board=UnboundedContinuousTensorSpec(
- shape=(3, 3), dtype=torch.int, device=device
- ),
- turn=DiscreteTensorSpec(
+ self.full_observation_spec: Composite = Composite(
+ board=Unbounded(shape=(3, 3), dtype=torch.int, device=device),
+ turn=Categorical(
2,
shape=(1,),
dtype=torch.int,
device=device,
),
- mask=DiscreteTensorSpec(
+ mask=Categorical(
2,
shape=(9,),
dtype=torch.bool,
@@ -196,22 +189,18 @@ def __init__(self, *, single_player: bool = False, device=None):
),
device=device,
)
- self.state_spec: CompositeSpec = self.observation_spec.clone()
+ self.state_spec: Composite = self.observation_spec.clone()
- self.reward_spec: UnboundedContinuousTensorSpec = CompositeSpec(
+ self.reward_spec: Composite = Composite(
{
- ("player0", "reward"): UnboundedContinuousTensorSpec(
- shape=(1,), device=device
- ),
- ("player1", "reward"): UnboundedContinuousTensorSpec(
- shape=(1,), device=device
- ),
+ ("player0", "reward"): Unbounded(shape=(1,), device=device),
+ ("player1", "reward"): Unbounded(shape=(1,), device=device),
},
device=device,
)
- self.full_done_spec: DiscreteTensorSpec = CompositeSpec(
- done=DiscreteTensorSpec(2, shape=(1,), dtype=torch.bool, device=device),
+ self.full_done_spec: Composite = Composite(
+ done=Categorical(2, shape=(1,), dtype=torch.bool, device=device),
device=device,
)
self.full_done_spec["terminated"] = self.full_done_spec["done"].clone()
@@ -229,7 +218,7 @@ def _step(self, state: TensorDict) -> TensorDict:
turn = state["turn"].clone()
action = state["action"]
board.flatten(-2, -1).scatter_(index=action.unsqueeze(-1), dim=-1, value=1)
- wins = self.win(state["board"], action)
+ wins = self.win(board, action)
mask = board.flatten(-2, -1) == -1
done = wins | ~mask.any(-1, keepdim=True)
@@ -245,7 +234,7 @@ def _step(self, state: TensorDict) -> TensorDict:
("player0", "reward"): reward_0.float(),
("player1", "reward"): reward_1.float(),
"board": torch.where(board == -1, board, 1 - board),
- "turn": 1 - state["turn"],
+ "turn": 1 - turn,
"mask": mask,
},
batch_size=state.batch_size,
@@ -271,13 +260,15 @@ def _set_seed(self, seed: int | None):
def win(board: torch.Tensor, action: torch.Tensor):
row = action // 3 # type: ignore
col = action % 3 # type: ignore
- return (
- board[..., row, :].sum()
- == 3 | board[..., col].sum()
- == 3 | board.diagonal(0, -2, -1).sum()
- == 3 | board.flip(-1).diagonal(0, -2, -1).sum()
- == 3
- )
+ if board[..., row, :].sum() == 3:
+ return True
+ if board[..., col].sum() == 3:
+ return True
+ if board.diagonal(0, -2, -1).sum() == 3:
+ return True
+ if board.flip(-1).diagonal(0, -2, -1).sum() == 3:
+ return True
+ return False
@staticmethod
def full(board: torch.Tensor) -> bool:
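The rewritten `win` above fixes a genuine precedence bug: in Python, `|` binds tighter than `==`, so the old one-liner was parsed as a chained comparison rather than an OR of four win conditions. A minimal sketch of the pitfall (board values are illustrative):

import torch

row = torch.tensor([1, 0, 1])  # not a winning row (sum == 2)
col = torch.tensor([1, 1, 1])  # a winning column (sum == 3)

# Old form: parses as `row.sum() == (3 | col.sum()) == 3`, a chained
# comparison that fails as soon as the first test fails.
buggy = row.sum() == 3 | col.sum() == 3      # -> tensor(False)

# Fixed form: parenthesize each test before OR-ing.
fixed = (row.sum() == 3) | (col.sum() == 3)  # -> tensor(True)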
diff --git a/torchrl/envs/env_creator.py b/torchrl/envs/env_creator.py
index 89ee8cc5614..f090289214d 100644
--- a/torchrl/envs/env_creator.py
+++ b/torchrl/envs/env_creator.py
@@ -109,7 +109,7 @@ def share_memory(self, state_dict: OrderedDict) -> None:
del state_dict[key]
@property
- def meta_data(self):
+ def meta_data(self) -> EnvMetaData:
if self._meta_data is None:
raise RuntimeError(
"meta_data is None in EnvCreator. " "Make sure init_() has been called."
diff --git a/torchrl/envs/gym_like.py b/torchrl/envs/gym_like.py
index c7935272c91..995f245a8ac 100644
--- a/torchrl/envs/gym_like.py
+++ b/torchrl/envs/gym_like.py
@@ -12,14 +12,10 @@
import numpy as np
import torch
-from tensordict import TensorDict, TensorDictBase
+from tensordict import NonTensorData, TensorDict, TensorDictBase
from torchrl._utils import logger as torchrl_logger
-from torchrl.data.tensor_specs import (
- CompositeSpec,
- TensorSpec,
- UnboundedContinuousTensorSpec,
-)
+from torchrl.data.tensor_specs import Composite, NonTensor, TensorSpec, Unbounded
from torchrl.envs.common import _EnvWrapper, EnvBase
@@ -44,10 +40,10 @@ class default_info_dict_reader(BaseInfoDictReader):
Args:
keys (list of keys, optional): If provided, the list of keys to get from
the info dictionary. Defaults to all keys.
- spec (List[TensorSpec], Dict[str, TensorSpec] or CompositeSpec, optional):
+ spec (List[TensorSpec], Dict[str, TensorSpec] or Composite, optional):
If a list of specs is provided, each spec will be matched to its
- correspondent key to form a :class:`torchrl.data.CompositeSpec`.
- If not provided, a composite spec with :class:`~torchrl.data.UnboundedContinuousTensorSpec`
+ corresponding key to form a :class:`torchrl.data.Composite`.
+ If not provided, a composite spec with :class:`~torchrl.data.Unbounded`
specs will lazily be created.
ignore_private (bool, optional): If ``True``, private infos (starting with
an underscore) will be ignored. Defaults to ``True``.
@@ -72,10 +68,7 @@ class default_info_dict_reader(BaseInfoDictReader):
def __init__(
self,
keys: List[str] | None = None,
- spec: Sequence[TensorSpec]
- | Dict[str, TensorSpec]
- | CompositeSpec
- | None = None,
+ spec: Sequence[TensorSpec] | Dict[str, TensorSpec] | Composite | None = None,
ignore_private: bool = True,
):
self.ignore_private = ignore_private
@@ -87,19 +80,17 @@ def __init__(
if spec is None and keys is None:
_info_spec = None
elif spec is None:
- _info_spec = CompositeSpec(
- {key: UnboundedContinuousTensorSpec(()) for key in keys}, shape=[]
- )
- elif not isinstance(spec, CompositeSpec):
+ _info_spec = Composite({key: Unbounded(()) for key in keys}, shape=[])
+ elif not isinstance(spec, Composite):
if self.keys is not None and len(spec) != len(self.keys):
raise ValueError(
"If specifying specs for info keys with a sequence, the "
"length of the sequence must match the number of keys"
)
if isinstance(spec, dict):
- _info_spec = CompositeSpec(spec, shape=[])
+ _info_spec = Composite(spec, shape=[])
else:
- _info_spec = CompositeSpec(
+ _info_spec = Composite(
{key: spec for key, spec in zip(keys, spec)}, shape=[]
)
else:
@@ -121,7 +112,7 @@ def __call__(
keys = [key for key in keys if not key.startswith("_")]
self.keys = keys
# create an info_spec only if there is none
- info_spec = None if self.info_spec is not None else CompositeSpec()
+ info_spec = None if self.info_spec is not None else Composite()
for key in keys:
if key in info_dict:
val = info_dict[key]
@@ -130,7 +121,7 @@ def __call__(
tensordict.set(key, val)
if info_spec is not None:
val = tensordict.get(key)
- info_spec[key] = UnboundedContinuousTensorSpec(
+ info_spec[key] = Unbounded(
val.shape, device=val.device, dtype=val.dtype
)
elif self.info_spec is not None:
@@ -158,7 +149,7 @@ def info_spec(self) -> Dict[str, TensorSpec]:
class GymLikeEnv(_EnvWrapper):
"""A gym-like env is an environment.
- Its behaviour is similar to gym environments in what common methods (specifically reset and step) are expected to do.
+ Its behavior is similar to that of gym environments in terms of what the common methods (specifically ``reset`` and ``step``) are expected to do.
A :obj:`GymLikeEnv` has a :obj:`.step()` method with the following signature:
@@ -181,6 +172,7 @@ class GymLikeEnv(_EnvWrapper):
def __new__(cls, *args, **kwargs):
self = super().__new__(cls, *args, _batch_locked=True, **kwargs)
self._info_dict_reader = []
+
return self
def read_action(self, action):
@@ -291,14 +283,18 @@ def read_obs(
observations = observations_dict
else:
for key, val in observations.items():
- observations[key] = self.observation_spec[key].encode(
- val, ignore_device=True
- )
+ if isinstance(self.observation_spec[key], NonTensor):
+ observations[key] = NonTensorData(val)
+ else:
+ observations[key] = self.observation_spec[key].encode(
+ val, ignore_device=True
+ )
return observations
def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
action = tensordict.get(self.action_key)
- action_np = self.read_action(action)
+ if self._convert_actions_to_numpy:
+ action = self.read_action(action)
reward = 0
for _ in range(self.wrapper_frame_skip):
@@ -309,7 +305,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
truncated,
done,
info_dict,
- ) = self._output_transform(self._env.step(action_np))
+ ) = self._output_transform(self._env.step(action))
if _reward is not None:
reward = reward + _reward
@@ -515,7 +511,7 @@ def auto_register_info_dict(
the info is filled at reset time.
.. note:: This method requires running a few iterations in the environment to
- manually check that the behaviour matches expectations.
+ manually check that the behavior matches expectations.
Args:
ignore_private (bool, optional): If ``True``, private infos (starting with
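For context on the info-dict reader API touched above, a minimal usage sketch; the ``success`` key and its spec are illustrative, not part of this diff:

import torch
from torchrl.data.tensor_specs import Composite, Unbounded
from torchrl.envs import GymEnv
from torchrl.envs.gym_like import default_info_dict_reader

env = GymEnv("CartPole-v1")
# Pull a hypothetical "success" entry out of the info dict into the output
# tensordict; when a declared key is absent, it is filled against the spec.
reader = default_info_dict_reader(
    keys=["success"],
    spec=Composite(success=Unbounded((), dtype=torch.bool)),
)
env.set_info_dict_reader(reader)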
diff --git a/torchrl/envs/libs/__init__.py b/torchrl/envs/libs/__init__.py
index e322c2cbf01..7ea113ce46d 100644
--- a/torchrl/envs/libs/__init__.py
+++ b/torchrl/envs/libs/__init__.py
@@ -19,7 +19,9 @@
from .jumanji import JumanjiEnv, JumanjiWrapper
from .meltingpot import MeltingpotEnv, MeltingpotWrapper
from .openml import OpenMLEnv
+from .openspiel import OpenSpielEnv, OpenSpielWrapper
from .pettingzoo import PettingZooEnv, PettingZooWrapper
from .robohive import RoboHiveEnv
from .smacv2 import SMACv2Env, SMACv2Wrapper
+from .unity_mlagents import UnityMLAgentsEnv, UnityMLAgentsWrapper
from .vmas import VmasEnv, VmasWrapper
diff --git a/torchrl/envs/libs/_gym_utils.py b/torchrl/envs/libs/_gym_utils.py
index fb01f430fc1..b95bfb335c6 100644
--- a/torchrl/envs/libs/_gym_utils.py
+++ b/torchrl/envs/libs/_gym_utils.py
@@ -12,9 +12,9 @@
from torch.utils._pytree import tree_map
from torchrl._utils import implement_for
-from torchrl.data import CompositeSpec
+from torchrl.data import Composite
from torchrl.envs import step_mdp, TransformedEnv
-from torchrl.envs.libs.gym import _torchrl_to_gym_spec_transform
+from torchrl.envs.libs.gym import _torchrl_to_gym_spec_transform, GYMNASIUM_1_ERROR
_has_gym = importlib.util.find_spec("gym", None) is not None
_has_gymnasium = importlib.util.find_spec("gymnasium", None) is not None
@@ -37,7 +37,7 @@ def __init__(
),
)
self.observation_space = _torchrl_to_gym_spec_transform(
- CompositeSpec(
+ Composite(
{
key: self.torchrl_env.full_observation_spec[key]
for key in self._observation_keys
@@ -125,7 +125,11 @@ def _action_keys(self):
import gymnasium
class _TorchRLGymnasiumWrapper(gymnasium.Env, _BaseGymWrapper):
- @implement_for("gymnasium")
+ @implement_for("gymnasium", "1.0.0")
+ def step(self, action): # noqa: F811
+ raise ImportError(GYMNASIUM_1_ERROR)
+
+ @implement_for("gymnasium", None, "1.0.0")
def step(self, action): # noqa: F811
action_keys = self._action_keys
if len(action_keys) == 1:
@@ -153,7 +157,7 @@ def step(self, action): # noqa: F811
out = tree_map(lambda x: x.detach().cpu().numpy(), out)
return out
- @implement_for("gymnasium")
+ @implement_for("gymnasium", None, "1.0.0")
def reset(self): # noqa: F811
self._tensordict = self.torchrl_env.reset()
observation = self._tensordict
@@ -167,6 +171,10 @@ def reset(self): # noqa: F811
out = tree_map(lambda x: x.detach().cpu().numpy(), out)
return out
+ @implement_for("gymnasium", "1.0.0")
+ def reset(self): # noqa: F811
+ raise ImportError(GYMNASIUM_1_ERROR)
+
else:
class _TorchRLGymnasiumWrapper:
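The `@implement_for` edits above apply one pattern consistently: keep the pre-1.0 implementation as-is and register a gymnasium>=1.0 overload that raises early. A hedged sketch of that pattern (`make_discrete_space` is an illustrative name, not part of this diff):

from torchrl._utils import implement_for
from torchrl.envs.libs.gym import GYMNASIUM_1_ERROR

@implement_for("gymnasium", None, "1.0.0")
def make_discrete_space(n):
    import gymnasium as gym
    return gym.spaces.Discrete(n)

@implement_for("gymnasium", "1.0.0")
def make_discrete_space(n):  # noqa: F811
    # gymnasium>=1.0 changed the auto-reset semantics; refuse up front.
    raise ImportError(GYMNASIUM_1_ERROR)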
diff --git a/torchrl/envs/libs/brax.py b/torchrl/envs/libs/brax.py
index ac4cd71ddad..9542b8e71ff 100644
--- a/torchrl/envs/libs/brax.py
+++ b/torchrl/envs/libs/brax.py
@@ -11,11 +11,7 @@
from packaging import version
from tensordict import TensorDict, TensorDictBase
-from torchrl.data.tensor_specs import (
- BoundedTensorSpec,
- CompositeSpec,
- UnboundedContinuousTensorSpec,
-)
+from torchrl.data.tensor_specs import Bounded, Composite, Unbounded
from torchrl.envs.common import _EnvWrapper
from torchrl.envs.libs.jax_utils import (
_extract_spec,
@@ -55,8 +51,8 @@ class BraxWrapper(_EnvWrapper):
Args:
env (brax.envs.base.PipelineEnv): the environment to wrap.
categorical_action_encoding (bool, optional): if ``True``, categorical
- specs will be converted to the TorchRL equivalent (:class:`torchrl.data.DiscreteTensorSpec`),
- otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHotTensorSpec`).
+ specs will be converted to the TorchRL equivalent (:class:`torchrl.data.Categorical`),
+ otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHot`).
Defaults to ``False``.
Keyword Args:
@@ -255,7 +251,7 @@ def _make_state_spec(self, env: "brax.envs.env.Env"): # noqa: F821
return state_spec
def _make_specs(self, env: "brax.envs.env.Env") -> None: # noqa: F821
- self.action_spec = BoundedTensorSpec(
+ self.action_spec = Bounded(
low=-1,
high=1,
shape=(
@@ -264,15 +260,15 @@ def _make_specs(self, env: "brax.envs.env.Env") -> None: # noqa: F821
),
device=self.device,
)
- self.reward_spec = UnboundedContinuousTensorSpec(
+ self.reward_spec = Unbounded(
shape=[
*self.batch_size,
1,
],
device=self.device,
)
- self.observation_spec = CompositeSpec(
- observation=UnboundedContinuousTensorSpec(
+ self.observation_spec = Composite(
+ observation=Unbounded(
shape=(
*self.batch_size,
env.observation_size,
@@ -439,8 +435,8 @@ class BraxEnv(BraxWrapper):
env_name (str): the environment name of the env to wrap. Must be part of
:attr:`~.available_envs`.
categorical_action_encoding (bool, optional): if ``True``, categorical
- specs will be converted to the TorchRL equivalent (:class:`torchrl.data.DiscreteTensorSpec`),
- otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHotTensorSpec`).
+ specs will be converted to the TorchRL equivalent (:class:`torchrl.data.Categorical`),
+ otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHot`).
Defaults to ``False``.
Keyword Args:
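As a reminder of what the two renamed classes in these docstrings encode, a small sketch:

from torchrl.data import Categorical, OneHot

cat = Categorical(4)  # an integer index in [0, 4), shape torch.Size([])
onehot = OneHot(4)    # a length-4 one-hot vector (bool by default)
print(cat.rand().shape, onehot.rand().shape)  # torch.Size([]) torch.Size([4])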
diff --git a/torchrl/envs/libs/dm_control.py b/torchrl/envs/libs/dm_control.py
index 5558754de26..2ca62e106f6 100644
--- a/torchrl/envs/libs/dm_control.py
+++ b/torchrl/envs/libs/dm_control.py
@@ -16,13 +16,12 @@
from torchrl._utils import logger as torchrl_logger, VERBOSE
from torchrl.data.tensor_specs import (
- BoundedTensorSpec,
- CompositeSpec,
- DiscreteTensorSpec,
- OneHotDiscreteTensorSpec,
+ Bounded,
+ Categorical,
+ Composite,
+ OneHot,
TensorSpec,
- UnboundedContinuousTensorSpec,
- UnboundedDiscreteTensorSpec,
+ Unbounded,
)
from torchrl.data.utils import DEVICE_TYPING, numpy_to_torch_dtype_dict
@@ -57,14 +56,10 @@ def _dmcontrol_to_torchrl_spec_transform(
)
for k, item in spec.items()
}
- return CompositeSpec(**spec)
+ return Composite(**spec)
elif isinstance(spec, dm_env.specs.DiscreteArray):
# DiscreteArray is a type of BoundedArray so this block needs to go first
- action_space_cls = (
- DiscreteTensorSpec
- if categorical_discrete_encoding
- else OneHotDiscreteTensorSpec
- )
+ action_space_cls = Categorical if categorical_discrete_encoding else OneHot
if dtype is None:
dtype = (
numpy_to_torch_dtype_dict[spec.dtype]
@@ -78,7 +73,7 @@ def _dmcontrol_to_torchrl_spec_transform(
shape = spec.shape
if not len(shape):
shape = torch.Size([1])
- return BoundedTensorSpec(
+ return Bounded(
shape=shape,
low=spec.minimum,
high=spec.maximum,
@@ -92,11 +87,9 @@ def _dmcontrol_to_torchrl_spec_transform(
if dtype is None:
dtype = numpy_to_torch_dtype_dict[spec.dtype]
if dtype in (torch.float, torch.double, torch.half):
- return UnboundedContinuousTensorSpec(
- shape=shape, dtype=dtype, device=device
- )
+ return Unbounded(shape=shape, dtype=dtype, device=device)
else:
- return UnboundedDiscreteTensorSpec(shape=shape, dtype=dtype, device=device)
+ return Unbounded(shape=shape, dtype=dtype, device=device)
else:
raise NotImplementedError(type(spec))
@@ -254,10 +247,10 @@ def _make_specs(self, env: "gym.Env") -> None: # noqa: F821
reward_spec.shape = torch.Size([1])
self.reward_spec = reward_spec
# populate default done spec
- done_spec = DiscreteTensorSpec(
+ done_spec = Categorical(
n=2, shape=(*self.batch_size, 1), dtype=torch.bool, device=self.device
)
- self.done_spec = CompositeSpec(
+ self.done_spec = Composite(
done=done_spec.clone(),
truncated=done_spec.clone(),
terminated=done_spec.clone(),
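Note that `_dmcontrol_to_torchrl_spec_transform` above now returns `Unbounded` on both dtype branches: after this change, `Unbounded` subsumes the old continuous/discrete pair and derives its domain from the dtype. A small sketch:

import torch
from torchrl.data import Unbounded

print(Unbounded(shape=(3,), dtype=torch.float32).domain)  # continuous
print(Unbounded(shape=(3,), dtype=torch.int64).domain)    # discrete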
diff --git a/torchrl/envs/libs/envpool.py b/torchrl/envs/libs/envpool.py
index a029a0beb5b..599645dfdfc 100644
--- a/torchrl/envs/libs/envpool.py
+++ b/torchrl/envs/libs/envpool.py
@@ -13,12 +13,7 @@
from tensordict import TensorDict, TensorDictBase
from torchrl._utils import logger as torchrl_logger
-from torchrl.data.tensor_specs import (
- CompositeSpec,
- DiscreteTensorSpec,
- TensorSpec,
- UnboundedContinuousTensorSpec,
-)
+from torchrl.data.tensor_specs import Categorical, Composite, TensorSpec, Unbounded
from torchrl.envs.common import _EnvWrapper
from torchrl.envs.utils import _classproperty
@@ -35,8 +30,8 @@ class MultiThreadedEnvWrapper(_EnvWrapper):
Args:
env (envpool.python.envpool.EnvPoolMixin): the envpool to wrap.
categorical_action_encoding (bool, optional): if ``True``, categorical
- specs will be converted to the TorchRL equivalent (:class:`torchrl.data.DiscreteTensorSpec`),
- otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHotTensorSpec`).
+ specs will be converted to the TorchRL equivalent (:class:`torchrl.data.Categorical`),
+ otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHot`).
Defaults to ``False``.
Keyword Args:
@@ -161,7 +156,7 @@ def _get_action_spec(self) -> TensorSpec:
return action_spec
def _get_output_spec(self) -> TensorSpec:
- return CompositeSpec(
+ return Composite(
full_observation_spec=self._get_observation_spec(),
full_reward_spec=self._get_reward_spec(),
full_done_spec=self._get_done_spec(),
@@ -180,9 +175,9 @@ def _get_observation_spec(self) -> TensorSpec:
categorical_action_encoding=True,
)
observation_spec = self._add_shape_to_spec(observation_spec)
- if isinstance(observation_spec, CompositeSpec):
+ if isinstance(observation_spec, Composite):
return observation_spec
- return CompositeSpec(
+ return Composite(
observation=observation_spec,
shape=(self.num_workers,),
device=self.device,
@@ -192,19 +187,19 @@ def _add_shape_to_spec(self, spec: TensorSpec) -> TensorSpec:
return spec.expand((self.num_workers, *spec.shape))
def _get_reward_spec(self) -> TensorSpec:
- return UnboundedContinuousTensorSpec(
+ return Unbounded(
device=self.device,
shape=self.batch_size,
)
def _get_done_spec(self) -> TensorSpec:
- spec = DiscreteTensorSpec(
+ spec = Categorical(
2,
device=self.device,
shape=self.batch_size,
dtype=torch.bool,
)
- return CompositeSpec(
+ return Composite(
done=spec,
truncated=spec.clone(),
terminated=spec.clone(),
@@ -335,8 +330,8 @@ class MultiThreadedEnv(MultiThreadedEnvWrapper):
create_env_kwargs (Dict[str, Any], optional): kwargs to be passed to envpool
environment constructor.
categorical_action_encoding (bool, optional): if ``True``, categorical
- specs will be converted to the TorchRL equivalent (:class:`torchrl.data.DiscreteTensorSpec`),
- otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHotTensorSpec`).
+ specs will be converted to the TorchRL equivalent (:class:`torchrl.data.Categorical`),
+ otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHot`).
Defaults to ``False``.
disable_env_checker (bool, optional): for gym > 0.24 only. If ``True`` (default
for these versions), the environment checker won't be run.
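For reference, every spec built above carries the envpool batch as a leading ``num_workers`` dimension. A hedged usage sketch (requires ``envpool``; the env name is illustrative):

import torch
from torchrl.envs.libs.envpool import MultiThreadedEnv

env = MultiThreadedEnv(num_workers=4, env_name="Pendulum-v1")
assert env.batch_size == torch.Size([4])
td = env.reset()  # every entry has a leading dimension of 4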
diff --git a/torchrl/envs/libs/gym.py b/torchrl/envs/libs/gym.py
index 9195929e31d..dfe0db92230 100644
--- a/torchrl/envs/libs/gym.py
+++ b/torchrl/envs/libs/gym.py
@@ -23,16 +23,16 @@
from torchrl._utils import implement_for
from torchrl.data.tensor_specs import (
_minmax_dtype,
- BinaryDiscreteTensorSpec,
- BoundedTensorSpec,
- CompositeSpec,
- DiscreteTensorSpec,
- MultiDiscreteTensorSpec,
- MultiOneHotDiscreteTensorSpec,
- OneHotDiscreteTensorSpec,
+ Binary,
+ Bounded,
+ Categorical,
+ Composite,
+ MultiCategorical,
+ MultiOneHot,
+ NonTensor,
+ OneHot,
TensorSpec,
- UnboundedContinuousTensorSpec,
- UnboundedDiscreteTensorSpec,
+ Unbounded,
)
from torchrl.data.utils import numpy_to_torch_dtype_dict, torch_to_numpy_dtype_dict
from torchrl.envs.batched_envs import CloudpickleWrapper
@@ -56,6 +56,30 @@
_has_mo = importlib.util.find_spec("mo_gymnasium") is not None
_has_sb3 = importlib.util.find_spec("stable_baselines3") is not None
+_has_minigrid = importlib.util.find_spec("minigrid") is not None
+
+
+GYMNASIUM_1_ERROR = """TorchRL does not support gymnasium 1.0 or later versions due to incompatible
+changes in the Gym API.
+Using gymnasium 1.0 with TorchRL would require significant modifications to your code and may result in:
+* Inaccurate step counting, as the auto-reset feature can cause unpredictable numbers of steps to be executed.
+* Potential data corruption, as the environment may require/produce garbage data during reset steps.
+* Trajectory overlap during data collection.
+* Increased computational overhead, as the library would need to handle the additional complexity of auto-resets.
+* Manual filtering and boilerplate code to mitigate these issues, which would compromise the modularity and ease of
+use of TorchRL.
+To maintain the integrity and efficiency of our library, we cannot support this version of gymnasium at this time.
+If you need to use gymnasium 1.0 or later, we recommend exploring alternative solutions or waiting for future updates
+to TorchRL and gymnasium that may address this compatibility issue.
+For more information, please refer to discussion https://github.com/pytorch/rl/discussions/2483 in torchrl.
+"""
+
+
+def _minigrid_lib():
+ assert _has_minigrid, "minigrid not found"
+ import minigrid
+
+ return minigrid
class set_gym_backend(_DecoratorContextManager):
@@ -259,11 +283,7 @@ def _gym_to_torchrl_spec_transform(
)
return result
if isinstance(spec, gym_spaces.discrete.Discrete):
- action_space_cls = (
- DiscreteTensorSpec
- if categorical_action_encoding
- else OneHotDiscreteTensorSpec
- )
+ action_space_cls = Categorical if categorical_action_encoding else OneHot
dtype = (
numpy_to_torch_dtype_dict[spec.dtype]
if categorical_action_encoding
@@ -271,7 +291,7 @@ def _gym_to_torchrl_spec_transform(
)
return action_space_cls(spec.n, device=device, dtype=dtype)
elif isinstance(spec, gym_spaces.multi_binary.MultiBinary):
- return BinaryDiscreteTensorSpec(
+ return Binary(
spec.n, device=device, dtype=numpy_to_torch_dtype_dict[spec.dtype]
)
# a spec type cannot be a string, so we're sure that versions of gym that don't have Sequence will just skip through this
@@ -300,11 +320,9 @@ def _gym_to_torchrl_spec_transform(
)
return (
- MultiDiscreteTensorSpec(spec.nvec, device=device, dtype=dtype)
+ MultiCategorical(spec.nvec, device=device, dtype=dtype)
if categorical_action_encoding
- else MultiOneHotDiscreteTensorSpec(
- spec.nvec, device=device, dtype=dtype
- )
+ else MultiOneHot(spec.nvec, device=device, dtype=dtype)
)
return torch.stack(
@@ -337,9 +355,9 @@ def _gym_to_torchrl_spec_transform(
and torch.isclose(high, torch.as_tensor(maxval, dtype=dtype)).all()
)
return (
- UnboundedContinuousTensorSpec(shape, device=device, dtype=dtype)
+ Unbounded(shape, device=device, dtype=dtype)
if is_unbounded
- else BoundedTensorSpec(
+ else Bounded(
low,
high,
shape,
@@ -368,7 +386,7 @@ def _gym_to_torchrl_spec_transform(
remap_state_to_observation=remap_state_to_observation,
)
# the batch-size must be set later
- return CompositeSpec(spec_out, device=device)
+ return Composite(spec_out, device=device)
elif isinstance(spec, gym_spaces.dict.Dict):
return _gym_to_torchrl_spec_transform(
spec.spaces,
@@ -376,6 +394,8 @@ def _gym_to_torchrl_spec_transform(
categorical_action_encoding=categorical_action_encoding,
remap_state_to_observation=remap_state_to_observation,
)
+ elif _has_minigrid and isinstance(spec, _minigrid_lib().core.mission.MissionSpace):
+ return NonTensor((), device=device)
else:
raise NotImplementedError(
f"spec of type {type(spec).__name__} is currently unaccounted for"
@@ -396,13 +416,18 @@ def _box_convert(spec, gym_spaces, shape): # noqa: F811
return gym_spaces.Box(low=low, high=high, shape=shape)
-@implement_for("gymnasium")
+@implement_for("gymnasium", None, "1.0.0")
def _box_convert(spec, gym_spaces, shape): # noqa: F811
low = spec.low.detach().cpu().numpy()
high = spec.high.detach().cpu().numpy()
return gym_spaces.Box(low=low, high=high, shape=shape)
+@implement_for("gymnasium", "1.0.0")
+def _box_convert(spec, gym_spaces, shape): # noqa: F811
+ raise ImportError(GYMNASIUM_1_ERROR)
+
+
@implement_for("gym", "0.21", None)
def _multidiscrete_convert(gym_spaces, spec):
return gym_spaces.multi_discrete.MultiDiscrete(
@@ -410,13 +435,18 @@ def _multidiscrete_convert(gym_spaces, spec):
)
-@implement_for("gymnasium")
+@implement_for("gymnasium", None, "1.0.0")
def _multidiscrete_convert(gym_spaces, spec): # noqa: F811
return gym_spaces.multi_discrete.MultiDiscrete(
spec.nvec, dtype=torch_to_numpy_dtype_dict[spec.dtype]
)
+@implement_for("gymnasium", "1.0.0")
+def _multidiscrete_convert(gym_spaces, spec): # noqa: F811
+ raise ImportError(GYMNASIUM_1_ERROR)
+
+
@implement_for("gym", None, "0.21")
def _multidiscrete_convert(gym_spaces, spec): # noqa: F811
return gym_spaces.multi_discrete.MultiDiscrete(spec.nvec)
@@ -445,19 +475,19 @@ def _torchrl_to_gym_spec_transform(
return gym_spaces.Tuple(
tuple(_torchrl_to_gym_spec_transform(spec) for spec in spec.unbind(0))
)
- if isinstance(spec, MultiDiscreteTensorSpec):
+ if isinstance(spec, MultiCategorical):
return _multidiscrete_convert(gym_spaces, spec)
- if isinstance(spec, MultiOneHotDiscreteTensorSpec):
+ if isinstance(spec, MultiOneHot):
return gym_spaces.multi_discrete.MultiDiscrete(spec.nvec)
- if isinstance(spec, BinaryDiscreteTensorSpec):
+ if isinstance(spec, Binary):
return gym_spaces.multi_binary.MultiBinary(spec.shape[-1])
- if isinstance(spec, DiscreteTensorSpec):
+ if isinstance(spec, Categorical):
return gym_spaces.discrete.Discrete(
spec.n
) # dtype=torch_to_numpy_dtype_dict[spec.dtype])
- if isinstance(spec, OneHotDiscreteTensorSpec):
+ if isinstance(spec, OneHot):
return gym_spaces.discrete.Discrete(spec.n)
- if isinstance(spec, UnboundedContinuousTensorSpec):
+ if isinstance(spec, Unbounded):
minval, maxval = _minmax_dtype(spec.dtype)
return gym_spaces.Box(
low=minval,
@@ -465,7 +495,7 @@ def _torchrl_to_gym_spec_transform(
shape=shape,
dtype=torch_to_numpy_dtype_dict[spec.dtype],
)
- if isinstance(spec, UnboundedDiscreteTensorSpec):
+ if isinstance(spec, Unbounded):
minval, maxval = _minmax_dtype(spec.dtype)
return gym_spaces.Box(
low=minval,
@@ -473,9 +503,9 @@ def _torchrl_to_gym_spec_transform(
shape=shape,
dtype=torch_to_numpy_dtype_dict[spec.dtype],
)
- if isinstance(spec, BoundedTensorSpec):
+ if isinstance(spec, Bounded):
return _box_convert(spec, gym_spaces, shape)
- if isinstance(spec, CompositeSpec):
+ if isinstance(spec, Composite):
# remove batch size
while spec.shape:
spec = spec[0]
@@ -515,12 +545,17 @@ def _get_gym_envs(): # noqa: F811
return gym.envs.registration.registry.keys()
-@implement_for("gymnasium")
+@implement_for("gymnasium", None, "1.0.0")
def _get_gym_envs(): # noqa: F811
gym = gym_backend()
return gym.envs.registration.registry.keys()
+@implement_for("gymnasium", "1.0.0")
+def _get_gym_envs(): # noqa: F811
+ raise ImportError(GYMNASIUM_1_ERROR)
+
+
def _is_from_pixels(env):
observation_spec = env.observation_space
try:
@@ -624,8 +659,8 @@ class GymWrapper(GymLikeEnv, metaclass=_AsyncMeta):
or :class:`gym.VectorEnv`) are supported and the environment batch-size
will reflect the number of environments executed in parallel.
categorical_action_encoding (bool, optional): if ``True``, categorical
- specs will be converted to the TorchRL equivalent (:class:`torchrl.data.DiscreteTensorSpec`),
- otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHotTensorSpec`).
+ specs will be converted to the TorchRL equivalent (:class:`torchrl.data.Categorical`),
+ otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHot`).
Defaults to ``False``.
Keyword Args:
@@ -652,6 +687,11 @@ class GymWrapper(GymLikeEnv, metaclass=_AsyncMeta):
allow_done_after_reset (bool, optional): if ``True``, it is tolerated
for envs to be ``done`` just after :meth:`~.reset` is called.
Defaults to ``False``.
+ convert_actions_to_numpy (bool, optional): if ``True``, actions will be
+ converted from tensors to numpy arrays and moved to CPU before being passed to the
+ env step function. Set this to ``False`` if the environment runs
+ on GPU, as IsaacLab does.
+ Defaults to ``True``.
Attributes:
available_envs (List[str]): a list of environments to build.
@@ -768,14 +808,20 @@ def __init__(self, env=None, categorical_action_encoding=False, **kwargs):
self._seed_calls_reset = None
self._categorical_action_encoding = categorical_action_encoding
if env is not None:
- if "EnvCompatibility" in str(
- env
- ): # a hacky way of knowing if EnvCompatibility is part of the wrappers of env
- raise ValueError(
- "GymWrapper does not support the gym.wrapper.compatibility.EnvCompatibility wrapper. "
- "If this feature is needed, detail your use case in an issue of "
- "https://github.com/pytorch/rl/issues."
- )
+ try:
+ env_str = str(env)
+ except TypeError:
+ # MiniGrid has a bug where the __str__ method fails
+ pass
+ else:
+ if (
+ "EnvCompatibility" in env_str
+ ): # a hacky way of knowing if EnvCompatibility is part of the wrappers of env
+ raise ValueError(
+ "GymWrapper does not support the gym.wrapper.compatibility.EnvCompatibility wrapper. "
+ "If this feature is needed, detail your use case in an issue of "
+ "https://github.com/pytorch/rl/issues."
+ )
libname = self.get_library_name(env)
with set_gym_backend(libname):
kwargs["env"] = env
@@ -820,7 +866,7 @@ def _get_batch_size(self, env):
batch_size = self.batch_size
return batch_size
- @implement_for("gymnasium") # gymnasium wants the unwrapped env
+ @implement_for("gymnasium", None, "1.0.0") # gymnasium wants the unwrapped env
def _get_batch_size(self, env): # noqa: F811
env_unwrapped = env.unwrapped
if hasattr(env_unwrapped, "num_envs"):
@@ -829,6 +875,10 @@ def _get_batch_size(self, env): # noqa: F811
batch_size = self.batch_size
return batch_size
+ @implement_for("gymnasium", "1.0.0")
+ def _get_batch_size(self, env): # noqa: F811
+ raise ImportError(GYMNASIUM_1_ERROR)
+
def _check_kwargs(self, kwargs: Dict):
if "env" not in kwargs:
raise TypeError("Could not find environment key 'env' in kwargs.")
@@ -865,10 +915,7 @@ def _build_env(
def read_action(self, action):
action = super().read_action(action)
- if (
- isinstance(self.action_spec, (OneHotDiscreteTensorSpec, DiscreteTensorSpec))
- and action.size == 1
- ):
+ if isinstance(self.action_spec, (OneHot, Categorical)) and action.size == 1:
# some envs require an integer for indexing
action = int(action)
return action
@@ -908,7 +955,11 @@ def _build_gym_env(self, env, pixels_only): # noqa: F811
return LegacyPixelObservationWrapper(env, pixels_only=pixels_only)
- @implement_for("gymnasium")
+ @implement_for("gymnasium", "1.0.0")
+ def _build_gym_env(self, env, pixels_only): # noqa: F811
+ raise ImportError(GYMNASIUM_1_ERROR)
+
+ @implement_for("gymnasium", None, "1.0.0")
def _build_gym_env(self, env, pixels_only): # noqa: F811
compatibility = gym_backend("wrappers.compatibility")
pixel_observation = gym_backend("wrappers.pixel_observation")
@@ -973,7 +1024,11 @@ def _set_seed_initial(self, seed: int) -> None: # noqa: F811
except AttributeError as err2:
raise err from err2
- @implement_for("gymnasium")
+ @implement_for("gymnasium", "1.0.0")
+ def _set_seed_initial(self, seed: int) -> None: # noqa: F811
+ raise ImportError(GYMNASIUM_1_ERROR)
+
+ @implement_for("gymnasium", None, "1.0.0")
def _set_seed_initial(self, seed: int) -> None: # noqa: F811
try:
self.reset(seed=seed)
@@ -991,7 +1046,11 @@ def _reward_space(self, env):
if hasattr(env, "reward_space") and env.reward_space is not None:
return env.reward_space
- @implement_for("gymnasium")
+ @implement_for("gymnasium", "1.0.0")
+ def _reward_space(self, env): # noqa: F811
+ raise ImportError(GYMNASIUM_1_ERROR)
+
+ @implement_for("gymnasium", None, "1.0.0")
def _reward_space(self, env): # noqa: F811
env = env.unwrapped
if hasattr(env, "reward_space") and env.reward_space is not None:
@@ -1012,13 +1071,13 @@ def _make_specs(self, env: "gym.Env", batch_size=None) -> None: # noqa: F821
device=self.device,
categorical_action_encoding=self._categorical_action_encoding,
)
- if not isinstance(observation_spec, CompositeSpec):
+ if not isinstance(observation_spec, Composite):
if self.from_pixels:
- observation_spec = CompositeSpec(
+ observation_spec = Composite(
pixels=observation_spec, shape=cur_batch_size
)
else:
- observation_spec = CompositeSpec(
+ observation_spec = Composite(
observation=observation_spec, shape=cur_batch_size
)
elif observation_spec.shape[: len(cur_batch_size)] != cur_batch_size:
@@ -1032,7 +1091,7 @@ def _make_specs(self, env: "gym.Env", batch_size=None) -> None: # noqa: F821
categorical_action_encoding=self._categorical_action_encoding,
)
else:
- reward_spec = UnboundedContinuousTensorSpec(
+ reward_spec = Unbounded(
shape=[1],
device=self.device,
)
@@ -1053,15 +1112,15 @@ def _make_specs(self, env: "gym.Env", batch_size=None) -> None: # noqa: F821
@implement_for("gym", None, "0.26")
def _make_done_spec(self): # noqa: F811
- return CompositeSpec(
+ return Composite(
{
- "done": DiscreteTensorSpec(
+ "done": Categorical(
2, dtype=torch.bool, device=self.device, shape=(*self.batch_size, 1)
),
- "terminated": DiscreteTensorSpec(
+ "terminated": Categorical(
2, dtype=torch.bool, device=self.device, shape=(*self.batch_size, 1)
),
- "truncated": DiscreteTensorSpec(
+ "truncated": Categorical(
2, dtype=torch.bool, device=self.device, shape=(*self.batch_size, 1)
),
},
@@ -1070,15 +1129,15 @@ def _make_done_spec(self): # noqa: F811
@implement_for("gym", "0.26", None)
def _make_done_spec(self): # noqa: F811
- return CompositeSpec(
+ return Composite(
{
- "done": DiscreteTensorSpec(
+ "done": Categorical(
2, dtype=torch.bool, device=self.device, shape=(*self.batch_size, 1)
),
- "terminated": DiscreteTensorSpec(
+ "terminated": Categorical(
2, dtype=torch.bool, device=self.device, shape=(*self.batch_size, 1)
),
- "truncated": DiscreteTensorSpec(
+ "truncated": Categorical(
2, dtype=torch.bool, device=self.device, shape=(*self.batch_size, 1)
),
},
@@ -1087,15 +1146,15 @@ def _make_done_spec(self): # noqa: F811
@implement_for("gymnasium", "0.27", None)
def _make_done_spec(self): # noqa: F811
- return CompositeSpec(
+ return Composite(
{
- "done": DiscreteTensorSpec(
+ "done": Categorical(
2, dtype=torch.bool, device=self.device, shape=(*self.batch_size, 1)
),
- "terminated": DiscreteTensorSpec(
+ "terminated": Categorical(
2, dtype=torch.bool, device=self.device, shape=(*self.batch_size, 1)
),
- "truncated": DiscreteTensorSpec(
+ "truncated": Categorical(
2, dtype=torch.bool, device=self.device, shape=(*self.batch_size, 1)
),
},
@@ -1250,8 +1309,8 @@ class GymEnv(GymWrapper):
Args:
env_name (str): the environment id registered in `gym.registry`.
categorical_action_encoding (bool, optional): if ``True``, categorical
- specs will be converted to the TorchRL equivalent (:class:`torchrl.data.DiscreteTensorSpec`),
- otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHotTensorSpec`).
+ specs will be converted to the TorchRL equivalent (:class:`torchrl.data.Categorical`),
+ otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHot`).
Defaults to ``False``.
Keyword Args:
@@ -1269,7 +1328,7 @@ class GymEnv(GymWrapper):
pixels_only (bool, optional): if ``True``, only the pixel observations will
be returned (by default under the ``"pixels"`` entry in the output tensordict).
If ``False``, observations (e.g., states) and pixels will be returned
- whenever ``from_pixels=True``. Defaults to ``True``.
+ whenever ``from_pixels=True``. Defaults to ``False``.
frame_skip (int, optional): if provided, indicates for how many steps the
same action is to be repeated. The observation returned will be the
last observation of the sequence, whereas the reward will be the sum
@@ -1385,7 +1444,14 @@ def _set_gym_args( # noqa: F811
) -> None:
kwargs.setdefault("disable_env_checker", True)
- @implement_for("gymnasium")
+ @implement_for("gymnasium", "1.0.0")
+ def _set_gym_args( # noqa: F811
+ self,
+ kwargs,
+ ) -> None:
+ raise ImportError(GYMNASIUM_1_ERROR)
+
+ @implement_for("gymnasium", None, "1.0.0")
def _set_gym_args( # noqa: F811
self,
kwargs,
@@ -1567,7 +1633,7 @@ class terminal_obs_reader(default_info_dict_reader):
replaced.
Args:
- observation_spec (CompositeSpec): The observation spec of the gym env.
+ observation_spec (Composite): The observation spec of the gym env.
backend (str, optional): the backend of the env. One of `"sb3"` for
stable-baselines3 or `"gym"` for gym/gymnasium.
@@ -1585,7 +1651,7 @@ class terminal_obs_reader(default_info_dict_reader):
"gym": "final_info",
}
- def __init__(self, observation_spec: CompositeSpec, backend, name="final"):
+ def __init__(self, observation_spec: Composite, backend, name="final"):
super().__init__()
self.name = name
self._obs_spec = observation_spec.clone()
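The new MiniGrid branch above maps text mission spaces to `NonTensor`, so the mission string travels through the tensordict unencoded. A hedged sketch of the effect (assumes ``minigrid`` is installed; the env id is illustrative):

from torchrl.envs import GymEnv

env = GymEnv("MiniGrid-Empty-5x5-v0")
td = env.reset()
# "mission" is a plain string wrapped in NonTensorData rather than a tensor.
print(td["mission"])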
diff --git a/torchrl/envs/libs/habitat.py b/torchrl/envs/libs/habitat.py
index 53752147acc..4180c42b2dc 100644
--- a/torchrl/envs/libs/habitat.py
+++ b/torchrl/envs/libs/habitat.py
@@ -54,8 +54,8 @@ class HabitatEnv(GymEnv):
Args:
env_name (str): The environment to execute.
categorical_action_encoding (bool, optional): if ``True``, categorical
- specs will be converted to the TorchRL equivalent (:class:`torchrl.data.DiscreteTensorSpec`),
- otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHotTensorSpec`).
+ specs will be converted to the TorchRL equivalent (:class:`torchrl.data.Categorical`),
+ otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHot`).
Defaults to ``False``.
Keyword Args:
diff --git a/torchrl/envs/libs/isaacgym.py b/torchrl/envs/libs/isaacgym.py
index 4c56bea304a..fb37639ad37 100644
--- a/torchrl/envs/libs/isaacgym.py
+++ b/torchrl/envs/libs/isaacgym.py
@@ -14,7 +14,7 @@
import torch
from tensordict import TensorDictBase
-from torchrl.data import CompositeSpec
+from torchrl.data import Composite
from torchrl.envs.libs.gym import GymWrapper
from torchrl.envs.utils import _classproperty, make_composite_from_td
@@ -59,7 +59,7 @@ def __init__(
def _make_specs(self, env: "gym.Env") -> None: # noqa: F821
super()._make_specs(env, batch_size=self.batch_size)
- self.full_done_spec = CompositeSpec(
+ self.full_done_spec = Composite(
{
key: spec.squeeze(-1)
for key, spec in self.full_done_spec.items(True, True)
diff --git a/torchrl/envs/libs/jax_utils.py b/torchrl/envs/libs/jax_utils.py
index d1d1094a264..052f538f0c4 100644
--- a/torchrl/envs/libs/jax_utils.py
+++ b/torchrl/envs/libs/jax_utils.py
@@ -13,12 +13,7 @@
# from jax import dlpack as jax_dlpack, numpy as jnp
from tensordict import make_tensordict, TensorDictBase
from torch.utils import dlpack as torch_dlpack
-from torchrl.data.tensor_specs import (
- CompositeSpec,
- TensorSpec,
- UnboundedContinuousTensorSpec,
- UnboundedDiscreteTensorSpec,
-)
+from torchrl.data.tensor_specs import Composite, TensorSpec, Unbounded
from torchrl.data.utils import numpy_to_torch_dtype_dict
_has_jax = importlib.util.find_spec("jax") is not None
@@ -155,15 +150,11 @@ def _extract_spec(data: Union[torch.Tensor, TensorDictBase], key=None) -> Tensor
if key in ("reward", "done"):
shape = (*shape, 1)
if data.dtype in (torch.float, torch.double, torch.half):
- return UnboundedContinuousTensorSpec(
- shape=shape, dtype=data.dtype, device=data.device
- )
+ return Unbounded(shape=shape, dtype=data.dtype, device=data.device)
else:
- return UnboundedDiscreteTensorSpec(
- shape=shape, dtype=data.dtype, device=data.device
- )
+ return Unbounded(shape=shape, dtype=data.dtype, device=data.device)
elif isinstance(data, TensorDictBase):
- return CompositeSpec(
+ return Composite(
{key: _extract_spec(value, key=key) for key, value in data.items()}
)
else:
diff --git a/torchrl/envs/libs/jumanji.py b/torchrl/envs/libs/jumanji.py
index 071c8f7f56c..dbbc980e8cc 100644
--- a/torchrl/envs/libs/jumanji.py
+++ b/torchrl/envs/libs/jumanji.py
@@ -18,16 +18,15 @@
_has_jumanji = importlib.util.find_spec("jumanji") is not None
from torchrl.data.tensor_specs import (
- BoundedTensorSpec,
- CompositeSpec,
+ Bounded,
+ Categorical,
+ Composite,
DEVICE_TYPING,
- DiscreteTensorSpec,
- MultiDiscreteTensorSpec,
- MultiOneHotDiscreteTensorSpec,
- OneHotDiscreteTensorSpec,
+ MultiCategorical,
+ MultiOneHot,
+ OneHot,
TensorSpec,
- UnboundedContinuousTensorSpec,
- UnboundedDiscreteTensorSpec,
+ Unbounded,
)
from torchrl.data.utils import numpy_to_torch_dtype_dict
from torchrl.envs.gym_like import GymLikeEnv
@@ -59,19 +58,13 @@ def _jumanji_to_torchrl_spec_transform(
import jumanji
if isinstance(spec, jumanji.specs.DiscreteArray):
- action_space_cls = (
- DiscreteTensorSpec
- if categorical_action_encoding
- else OneHotDiscreteTensorSpec
- )
+ action_space_cls = Categorical if categorical_action_encoding else OneHot
if dtype is None:
dtype = numpy_to_torch_dtype_dict[spec.dtype]
return action_space_cls(spec.num_values, dtype=dtype, device=device)
if isinstance(spec, jumanji.specs.MultiDiscreteArray):
action_space_cls = (
- MultiDiscreteTensorSpec
- if categorical_action_encoding
- else MultiOneHotDiscreteTensorSpec
+ MultiCategorical if categorical_action_encoding else MultiOneHot
)
if dtype is None:
dtype = numpy_to_torch_dtype_dict[spec.dtype]
@@ -82,7 +75,7 @@ def _jumanji_to_torchrl_spec_transform(
shape = spec.shape
if dtype is None:
dtype = numpy_to_torch_dtype_dict[spec.dtype]
- return BoundedTensorSpec(
+ return Bounded(
shape=shape,
low=np.asarray(spec.minimum),
high=np.asarray(spec.maximum),
@@ -94,11 +87,9 @@ def _jumanji_to_torchrl_spec_transform(
if dtype is None:
dtype = numpy_to_torch_dtype_dict[spec.dtype]
if dtype in (torch.float, torch.double, torch.half):
- return UnboundedContinuousTensorSpec(
- shape=shape, dtype=dtype, device=device
- )
+ return Unbounded(shape=shape, dtype=dtype, device=device)
else:
- return UnboundedDiscreteTensorSpec(shape=shape, dtype=dtype, device=device)
+ return Unbounded(shape=shape, dtype=dtype, device=device)
elif isinstance(spec, jumanji.specs.Spec) and hasattr(spec, "__dict__"):
new_spec = {}
for key, value in spec.__dict__.items():
@@ -110,7 +101,7 @@ def _jumanji_to_torchrl_spec_transform(
new_spec[key] = _jumanji_to_torchrl_spec_transform(
value, dtype, device, categorical_action_encoding
)
- return CompositeSpec(**new_spec)
+ return Composite(**new_spec)
else:
raise TypeError(f"Unsupported spec type {type(spec)}")
@@ -140,8 +131,8 @@ class JumanjiWrapper(GymLikeEnv, metaclass=_JumanjiMakeRender):
Args:
env (jumanji.env.Environment): the env to wrap.
categorical_action_encoding (bool, optional): if ``True``, categorical
- specs will be converted to the TorchRL equivalent (:class:`torchrl.data.DiscreteTensorSpec`),
- otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHotTensorSpec`).
+ specs will be converted to the TorchRL equivalent (:class:`torchrl.data.Categorical`),
+ otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHot`).
Defaults to ``False``.
Keyword Args:
@@ -433,9 +424,9 @@ def _make_observation_spec(self, env) -> TensorSpec:
spec = env.observation_spec
new_spec = _jumanji_to_torchrl_spec_transform(spec, device=self.device)
if isinstance(spec, jumanji.specs.Array):
- return CompositeSpec(observation=new_spec).expand(self.batch_size)
+ return Composite(observation=new_spec).expand(self.batch_size)
elif isinstance(spec, jumanji.specs.Spec):
- return CompositeSpec(**{k: v for k, v in new_spec.items()}).expand(
+ return Composite(**{k: v for k, v in new_spec.items()}).expand(
self.batch_size
)
else:
@@ -681,8 +672,8 @@ class JumanjiEnv(JumanjiWrapper):
Args:
env_name (str): the name of the environment to wrap. Must be part of :attr:`~.available_envs`.
categorical_action_encoding (bool, optional): if ``True``, categorical
- specs will be converted to the TorchRL equivalent (:class:`torchrl.data.DiscreteTensorSpec`),
- otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHotTensorSpec`).
+ specs will be converted to the TorchRL equivalent (:class:`torchrl.data.Categorical`),
+ otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHot`).
Defaults to ``False``.
Keyword Args:
diff --git a/torchrl/envs/libs/meltingpot.py b/torchrl/envs/libs/meltingpot.py
index 446b3dac292..b8e52031a23 100644
--- a/torchrl/envs/libs/meltingpot.py
+++ b/torchrl/envs/libs/meltingpot.py
@@ -12,7 +12,7 @@
from tensordict import TensorDict, TensorDictBase
-from torchrl.data import CompositeSpec, DiscreteTensorSpec, TensorSpec
+from torchrl.data import Categorical, Composite, TensorSpec
from torchrl.envs.common import _EnvWrapper
from torchrl.envs.libs.dm_control import _dmcontrol_to_torchrl_spec_transform
from torchrl.envs.utils import _classproperty, check_marl_grouping, MarlGroupMapType
@@ -246,9 +246,9 @@ def _make_specs(
}
self._make_group_map()
- action_spec = CompositeSpec()
- observation_spec = CompositeSpec()
- reward_spec = CompositeSpec()
+ action_spec = Composite()
+ observation_spec = Composite()
+ reward_spec = Composite()
for group in self.group_map.keys():
(
@@ -266,11 +266,9 @@ def _make_specs(
reward_spec[group] = group_reward_spec
observation_spec.update(torchrl_state_spec)
- self.done_spec = CompositeSpec(
+ self.done_spec = Composite(
{
- "done": DiscreteTensorSpec(
- n=2, shape=torch.Size((1,)), dtype=torch.bool
- ),
+ "done": Categorical(n=2, shape=torch.Size((1,)), dtype=torch.bool),
},
)
self.action_spec = action_spec
@@ -292,7 +290,7 @@ def _make_group_specs(
for agent_name in self.group_map[group]:
agent_index = self.agent_names_to_indices_map[agent_name]
action_specs.append(
- CompositeSpec(
+ Composite(
{
"action": torchrl_agent_act_specs[
agent_index
@@ -301,7 +299,7 @@ def _make_group_specs(
)
)
observation_specs.append(
- CompositeSpec(
+ Composite(
{
"observation": torchrl_agent_obs_specs[
agent_index
@@ -310,7 +308,7 @@ def _make_group_specs(
)
)
reward_specs.append(
- CompositeSpec({"reward": torchrl_rew_spec[agent_index]}) # shape = (1,)
+ Composite({"reward": torchrl_rew_spec[agent_index]}) # shape = (1,)
)
# Create multi-agent specs
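The group specs assembled above are built by stacking identical per-agent `Composite` specs, the same pattern the new OpenSpiel wrapper uses later in this diff. A minimal sketch:

import torch
from torchrl.data import Composite, Unbounded

per_agent = [Composite(reward=Unbounded(shape=(1,))) for _ in range(3)]
group_spec = torch.stack(per_agent, dim=0)
print(group_spec.shape)  # torch.Size([3]), i.e. (n_agents,)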
diff --git a/torchrl/envs/libs/openml.py b/torchrl/envs/libs/openml.py
index 7ac318e03cb..55b246bd902 100644
--- a/torchrl/envs/libs/openml.py
+++ b/torchrl/envs/libs/openml.py
@@ -8,12 +8,7 @@
from tensordict import TensorDict, TensorDictBase
from torchrl.data.replay_buffers import SamplerWithoutReplacement
-from torchrl.data.tensor_specs import (
- CompositeSpec,
- DiscreteTensorSpec,
- UnboundedContinuousTensorSpec,
- UnboundedDiscreteTensorSpec,
-)
+from torchrl.data.tensor_specs import Categorical, Composite, Unbounded
from torchrl.envs.common import EnvBase
from torchrl.envs.transforms import Compose, DoubleToFloat, RenameTransform
from torchrl.envs.utils import _classproperty
@@ -24,17 +19,13 @@
def _make_composite_from_td(td):
# custom function to convert a tensordict into a similar spec structure
# of unbounded values.
- composite = CompositeSpec(
+ composite = Composite(
{
key: _make_composite_from_td(tensor)
if isinstance(tensor, TensorDictBase)
- else UnboundedContinuousTensorSpec(
- dtype=tensor.dtype, device=tensor.device, shape=tensor.shape
- )
+ else Unbounded(dtype=tensor.dtype, device=tensor.device, shape=tensor.shape)
if tensor.dtype in (torch.float16, torch.float32, torch.float64)
- else UnboundedDiscreteTensorSpec(
- dtype=tensor.dtype, device=tensor.device, shape=tensor.shape
- )
+ else Unbounded(dtype=tensor.dtype, device=tensor.device, shape=tensor.shape)
for key, tensor in td.items()
},
shape=td.shape,
@@ -115,10 +106,10 @@ def __init__(self, dataset_name, device="cpu", batch_size=None):
.reshape(self.batch_size)
.exclude("index")
)
- self.action_spec = DiscreteTensorSpec(
+ self.action_spec = Categorical(
self._data.max_outcome_val + 1, shape=self.batch_size, device=self.device
)
- self.reward_spec = UnboundedContinuousTensorSpec(shape=(*self.batch_size, 1))
+ self.reward_spec = Unbounded(shape=(*self.batch_size, 1))
def _reset(self, tensordict):
data = self._data.sample()
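`_make_composite_from_td` above mirrors a tensordict's layout with `Unbounded` leaves; a standalone sketch of the same idea:

import torch
from tensordict import TensorDict
from torchrl.data import Composite, Unbounded

td = TensorDict({"obs": torch.randn(3), "label": torch.tensor(1)}, [])
spec = Composite(
    {
        key: Unbounded(shape=val.shape, dtype=val.dtype, device=val.device)
        for key, val in td.items()
    },
    shape=td.shape,
)
assert spec.is_in(td)  # the tensordict conforms to the derived spec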
diff --git a/torchrl/envs/libs/openspiel.py b/torchrl/envs/libs/openspiel.py
new file mode 100644
index 00000000000..8d2d76f453f
--- /dev/null
+++ b/torchrl/envs/libs/openspiel.py
@@ -0,0 +1,655 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from __future__ import annotations
+
+import importlib.util
+from typing import Dict, List
+
+import torch
+from tensordict import TensorDict, TensorDictBase
+
+from torchrl.data.tensor_specs import (
+ Categorical,
+ Composite,
+ NonTensor,
+ OneHot,
+ Unbounded,
+)
+from torchrl.envs.common import _EnvWrapper
+from torchrl.envs.utils import _classproperty, check_marl_grouping, MarlGroupMapType
+
+_has_pyspiel = importlib.util.find_spec("pyspiel") is not None
+
+
+def _get_envs():
+ if not _has_pyspiel:
+ raise ImportError(
+ "open_spiel not found. Consider downloading and installing "
+ f"open_spiel from {OpenSpielWrapper.git_url}."
+ )
+
+ import pyspiel
+
+ return [game.short_name for game in pyspiel.registered_games()]
+
+
+class OpenSpielWrapper(_EnvWrapper):
+ """Google DeepMind OpenSpiel environment wrapper.
+
+ GitHub: https://github.com/google-deepmind/open_spiel
+
+ Documentation: https://openspiel.readthedocs.io/en/latest/index.html
+
+ Args:
+ env (pyspiel.State): the game to wrap.
+
+ Keyword Args:
+ device (torch.device, optional): if provided, the device on which the data
+ is to be cast. Defaults to ``None``.
+ batch_size (torch.Size, optional): the batch size of the environment.
+ Defaults to ``torch.Size([])``.
+ allow_done_after_reset (bool, optional): if ``True``, it is tolerated
+ for envs to be ``done`` just after :meth:`~.reset` is called.
+ Defaults to ``False``.
+ group_map (MarlGroupMapType or Dict[str, List[str]], optional): how to
+ group agents in tensordicts for input/output. See
+ :class:`~torchrl.envs.utils.MarlGroupMapType` for more info.
+ Defaults to
+ :class:`~torchrl.envs.utils.MarlGroupMapType.ALL_IN_ONE_GROUP`.
+ categorical_actions (bool, optional): if ``True``, categorical specs
+ will be converted to the TorchRL equivalent
+ (:class:`torchrl.data.Categorical`), otherwise a one-hot encoding
+ will be used (:class:`torchrl.data.OneHot`). Defaults to ``False``.
+ return_state (bool, optional): if ``True``, "state" is included in the
+ output of :meth:`~.reset` and :meth:`~.step`. The state can be given
+ to :meth:`~.reset` to reset to that state, rather than resetting to
+ the initial state.
+ Defaults to ``False``.
+
+ Attributes:
+ available_envs: environments available to build
+
+ Examples:
+ >>> import pyspiel
+ >>> from torchrl.envs import OpenSpielWrapper
+ >>> from tensordict import TensorDict
+ >>> base_env = pyspiel.load_game('chess').new_initial_state()
+ >>> env = OpenSpielWrapper(base_env, return_state=True)
+ >>> td = env.reset()
+ >>> td = env.step(env.full_action_spec.rand())
+ >>> print(td)
+ TensorDict(
+ fields={
+ agents: TensorDict(
+ fields={
+ action: Tensor(shape=torch.Size([2, 4672]), device=cpu, dtype=torch.int64, is_shared=False)},
+ batch_size=torch.Size([]),
+ device=None,
+ is_shared=False),
+ next: TensorDict(
+ fields={
+ agents: TensorDict(
+ fields={
+ observation: Tensor(shape=torch.Size([2, 20, 8, 8]), device=cpu, dtype=torch.float32, is_shared=False),
+ reward: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
+ batch_size=torch.Size([2]),
+ device=None,
+ is_shared=False),
+ current_player: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False),
+ done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+ state: NonTensorData(data=FEN: rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1
+ 3009
+ , batch_size=torch.Size([]), device=None),
+ terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
+ batch_size=torch.Size([]),
+ device=None,
+ is_shared=False)},
+ batch_size=torch.Size([]),
+ device=None,
+ is_shared=False)
+ >>> print(env.available_envs)
+ ['2048', 'add_noise', 'amazons', 'backgammon', ...]
+
+ :meth:`~.reset` can restore a specific state, rather than the initial
+ state, as long as ``return_state=True``.
+
+ >>> import pyspiel
+ >>> from torchrl.envs import OpenSpielWrapper
+ >>> from tensordict import TensorDict
+ >>> base_env = pyspiel.load_game('chess').new_initial_state()
+ >>> env = OpenSpielWrapper(base_env, return_state=True)
+ >>> td = env.reset()
+ >>> td = env.step(env.full_action_spec.rand())
+ >>> td_restore = td["next"]
+ >>> td = env.step(env.full_action_spec.rand())
+ >>> # The current state is no longer equal to `td_restore`
+ >>> (td["next"] == td_restore).all()
+ False
+ >>> td = env.reset(td_restore)
+ >>> # After resetting, the current state is equal to `td_restore`
+ >>> (td == td_restore).all()
+ True
+ """
+
+ git_url = "https://github.com/google-deepmind/open_spiel"
+ libname = "pyspiel"
+ _lib = None
+
+ @_classproperty
+ def lib(cls):
+ if cls._lib is not None:
+ return cls._lib
+
+ import pyspiel
+
+ cls._lib = pyspiel
+ return pyspiel
+
+ @_classproperty
+ def available_envs(cls):
+ if not _has_pyspiel:
+ return []
+ return _get_envs()
+
+ def __init__(
+ self,
+ env=None,
+ *,
+ group_map: MarlGroupMapType
+ | Dict[str, List[str]] = MarlGroupMapType.ALL_IN_ONE_GROUP,
+ categorical_actions: bool = False,
+ return_state: bool = False,
+ **kwargs,
+ ):
+ if env is not None:
+ kwargs["env"] = env
+
+ self.group_map = group_map
+ self.categorical_actions = categorical_actions
+ self.return_state = return_state
+ self._cached_game = None
+ super().__init__(**kwargs)
+
+ # `reset` allows resetting to any state, including a terminal state
+ self._allow_done_after_reset = True
+
+ def _check_kwargs(self, kwargs: Dict):
+ pyspiel = self.lib
+ if "env" not in kwargs:
+ raise TypeError("Could not find environment key 'env' in kwargs.")
+ env = kwargs["env"]
+ if not isinstance(env, pyspiel.State):
+ raise TypeError("env is not of type 'pyspiel.State'.")
+
+ def _build_env(self, env, requires_grad: bool = False, **kwargs):
+ game = env.get_game()
+ game_type = game.get_type()
+
+ if game.max_chance_outcomes() != 0:
+ raise NotImplementedError(
+ f"The game '{game_type.short_name}' has chance nodes, which are not yet supported."
+ )
+ if game_type.dynamics == self.lib.GameType.Dynamics.MEAN_FIELD:
+ # NOTE: It is unclear from the OpenSpiel documentation what
+ # "mean field" means exactly, and the several games that use it
+ # are undocumented.
+ raise RuntimeError(
+ f"Mean field games like '{game_type.name}' are not yet " "supported."
+ )
+ self.parallel = game_type.dynamics == self.lib.GameType.Dynamics.SIMULTANEOUS
+ self.requires_grad = requires_grad
+ return env
+
+ def _init_env(self):
+ self._update_action_mask()
+
+ def _get_game(self):
+ if self._cached_game is None:
+ self._cached_game = self._env.get_game()
+ return self._cached_game
+
+ def _make_group_map(self, group_map, agent_names):
+ if group_map is None:
+ group_map = MarlGroupMapType.ONE_GROUP_PER_AGENT.get_group_map(agent_names)
+ elif isinstance(group_map, MarlGroupMapType):
+ group_map = group_map.get_group_map(agent_names)
+ check_marl_grouping(group_map, agent_names)
+ return group_map
+
+ def _make_group_specs(
+ self,
+ env,
+ group: str,
+ ):
+ observation_specs = []
+ action_specs = []
+ reward_specs = []
+ game = env.get_game()
+
+ for _ in self.group_map[group]:
+ observation_spec = Composite()
+
+ if self.has_observation:
+ observation_spec["observation"] = Unbounded(
+ shape=(*game.observation_tensor_shape(),),
+ device=self.device,
+ domain="continuous",
+ )
+
+ if self.has_information_state:
+ observation_spec["information_state"] = Unbounded(
+ shape=(*game.information_state_tensor_shape(),),
+ device=self.device,
+ domain="continuous",
+ )
+
+ observation_specs.append(observation_spec)
+
+ action_spec_cls = Categorical if self.categorical_actions else OneHot
+ action_specs.append(
+ Composite(
+ action=action_spec_cls(
+ env.num_distinct_actions(),
+ dtype=torch.int64,
+ device=self.device,
+ )
+ )
+ )
+
+ reward_specs.append(
+ Composite(
+ reward=Unbounded(
+ shape=(1,),
+ device=self.device,
+ domain="continuous",
+ )
+ )
+ )
+
+ group_observation_spec = torch.stack(
+ observation_specs, dim=0
+ ) # shape = (n_agents, n_obs_per_agent)
+ group_action_spec = torch.stack(
+ action_specs, dim=0
+ ) # shape = (n_agents, n_actions_per_agent)
+ group_reward_spec = torch.stack(reward_specs, dim=0) # shape = (n_agents, 1)
+
+ return (
+ group_observation_spec,
+ group_action_spec,
+ group_reward_spec,
+ )
+
+ def _make_specs(self, env: "pyspiel.State") -> None: # noqa: F821
+ self.agent_names = [f"player_{index}" for index in range(env.num_players())]
+ self.agent_names_to_indices_map = {
+ agent_name: i for i, agent_name in enumerate(self.agent_names)
+ }
+ self.group_map = self._make_group_map(self.group_map, self.agent_names)
+ self.done_spec = Categorical(
+ n=2,
+ shape=torch.Size((1,)),
+ dtype=torch.bool,
+ device=self.device,
+ )
+ game = env.get_game()
+ game_type = game.get_type()
+ # In OpenSpiel, a game's state may have either an "observation" tensor,
+ # an "information state" tensor, or both. If the OpenSpiel game does not
+ # have one of these, then its corresponding accessor functions raise an
+ # error, so we must avoid calling them.
+ self.has_observation = game_type.provides_observation_tensor
+ self.has_information_state = game_type.provides_information_state_tensor
+
+ observation_spec = {}
+ action_spec = {}
+ reward_spec = {}
+
+ for group in self.group_map.keys():
+ (
+ group_observation_spec,
+ group_action_spec,
+ group_reward_spec,
+ ) = self._make_group_specs(
+ env,
+ group,
+ )
+ observation_spec[group] = group_observation_spec
+ action_spec[group] = group_action_spec
+ reward_spec[group] = group_reward_spec
+
+ if self.return_state:
+ observation_spec["state"] = NonTensor([])
+
+ observation_spec["current_player"] = Unbounded(
+ shape=(),
+ dtype=torch.int,
+ device=self.device,
+ domain="discrete",
+ )
+
+ self.observation_spec = Composite(observation_spec)
+ self.action_spec = Composite(action_spec)
+ self.reward_spec = Composite(reward_spec)
+
+ def _set_seed(self, seed):
+ if seed is not None:
+ raise NotImplementedError("This environment has no seed.")
+
+ def current_player(self):
+ return self._env.current_player()
+
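+ # Restricts each group's action spec to the currently legal actions: acting
+ # players are masked to their legal moves, while all other players are
+ # masked down to the single placeholder action 0.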
+ def _update_action_mask(self):
+ if self._env.is_terminal():
+ agents_acting = []
+ else:
+ agents_acting = (
+ self.agent_names
+ if self.parallel
+ else [self.agent_names[self._env.current_player()]]
+ )
+ for group, agents in self.group_map.items():
+ action_masks = []
+ for agent in agents:
+ agent_index = self.agent_names_to_indices_map[agent]
+ action_mask = torch.zeros(
+ self._env.num_distinct_actions(),
+ device=self.device,
+ dtype=torch.bool,
+ )
+ if agent in agents_acting:
+ action_mask[self._env.legal_actions(agent_index)] = True
+ else:
+ # In OpenSpiel parallel games, non-acting players are
+ # expected to take action 0.
+ # https://openspiel.readthedocs.io/en/latest/api_reference/state_apply_action.html
+ action_mask[0] = True
+ action_masks.append(action_mask)
+ self.full_action_spec[group, "action"].update_mask(
+ torch.stack(action_masks, dim=0)
+ )
+
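+ # Assembles the output tensordict: shared "done"/"current_player" entries
+ # plus, per group, a stack of each agent's observations and (optionally)
+ # rewards.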
+ def _make_td_out(self, exclude_reward=False):
+ done = torch.tensor(
+ self._env.is_terminal(), device=self.device, dtype=torch.bool
+ )
+ current_player = torch.tensor(
+ self.current_player(), device=self.device, dtype=torch.int
+ )
+
+ source = {
+ "done": done,
+ "terminated": done.clone(),
+ "current_player": current_player,
+ }
+
+ if self.return_state:
+ source["state"] = self._env.serialize()
+
+ reward = self._env.returns()
+
+ for group, agent_names in self.group_map.items():
+ agent_tds = []
+
+ for agent in agent_names:
+ agent_index = self.agent_names_to_indices_map[agent]
+ agent_source = {}
+ if self.has_observation:
+ observation_shape = self._get_game().observation_tensor_shape()
+ agent_source["observation"] = self._to_tensor(
+ self._env.observation_tensor(agent_index)
+ ).reshape(observation_shape)
+
+ if self.has_information_state:
+ information_state_shape = (
+ self._get_game().information_state_tensor_shape()
+ )
+ agent_source["information_state"] = self._to_tensor(
+ self._env.information_state_tensor(agent_index)
+ ).reshape(information_state_shape)
+
+ if not exclude_reward:
+ agent_source["reward"] = self._to_tensor(reward[agent_index])
+
+ agent_td = TensorDict(
+ source=agent_source,
+ batch_size=self.batch_size,
+ device=self.device,
+ )
+ agent_tds.append(agent_td)
+
+ source[group] = torch.stack(agent_tds, dim=0)
+
+ tensordict_out = TensorDict(
+ source=source,
+ batch_size=self.batch_size,
+ device=self.device,
+ )
+
+ return tensordict_out
+
+ def _get_action_from_tensor(self, tensor):
+ if not self.categorical_actions:
+ action = torch.argmax(tensor, dim=-1)
+ else:
+ action = tensor
+ return action
+
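+ # In simultaneous-move (parallel) games, every player submits an action on
+ # each step, so one action per player is gathered before applying them all.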
+ def _step_parallel(self, tensordict: TensorDictBase):
+ actions = [0] * self._env.num_players()
+ for group, agents in self.group_map.items():
+ for index_in_group, agent in enumerate(agents):
+ agent_index = self.agent_names_to_indices_map[agent]
+ action_tensor = tensordict[group, "action"][index_in_group]
+ action = self._get_action_from_tensor(action_tensor)
+ actions[agent_index] = action
+
+ self._env.apply_actions(actions)
+
+ def _step_sequential(self, tensordict: TensorDictBase):
+ agent_index = self._env.current_player()
+
+ # If the game has ended, do nothing
+ if agent_index == self.lib.PlayerId.TERMINAL:
+ return
+
+ agent = self.agent_names[agent_index]
+ agent_group = None
+ agent_index_in_group = None
+
+ for group, agents in self.group_map.items():
+ if agent in agents:
+ agent_group = group
+ agent_index_in_group = agents.index(agent)
+ break
+
+ assert agent_group is not None
+
+ action_tensor = tensordict[agent_group, "action"][agent_index_in_group]
+ action = self._get_action_from_tensor(action_tensor)
+ self._env.apply_action(action)
+
+ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
+ if self.parallel:
+ self._step_parallel(tensordict)
+ else:
+ self._step_sequential(tensordict)
+
+ self._update_action_mask()
+ return self._make_td_out()
+
+ def _to_tensor(self, value):
+ return torch.tensor(value, device=self.device, dtype=torch.float32)
+
+ def _reset(
+ self, tensordict: TensorDictBase | None = None, **kwargs
+ ) -> TensorDictBase:
+ game = self._get_game()
+
+ if tensordict is not None and "state" in tensordict:
+ new_env = game.deserialize_state(tensordict["state"])
+ else:
+ new_env = game.new_initial_state()
+
+ self._env = new_env
+ self._update_action_mask()
+ return self._make_td_out(exclude_reward=True)
+
+
+class OpenSpielEnv(OpenSpielWrapper):
+ """Google DeepMind OpenSpiel environment wrapper built with the game string.
+
+ GitHub: https://github.com/google-deepmind/open_spiel
+
+ Documentation: https://openspiel.readthedocs.io/en/latest/index.html
+
+ Args:
+ game_string (str): the name of the game to wrap. Must be part of
+ :attr:`~.available_envs`.
+
+ Keyword Args:
+ device (torch.device, optional): if provided, the device on which the data
+ is to be cast. Defaults to ``None``.
+ batch_size (torch.Size, optional): the batch size of the environment.
+ Defaults to ``torch.Size([])``.
+ allow_done_after_reset (bool, optional): if ``True``, the environment
+ is allowed to be ``done`` immediately after :meth:`~.reset` is called.
+ Defaults to ``False``.
+ group_map (MarlGroupMapType or Dict[str, List[str]], optional): how to
+ group agents in tensordicts for input/output. See
+ :class:`~torchrl.envs.utils.MarlGroupMapType` for more info.
+ Defaults to
+ :class:`~torchrl.envs.utils.MarlGroupMapType.ALL_IN_ONE_GROUP`.
+ categorical_actions (bool, optional): if ``True``, categorical specs
+ will be converted to the TorchRL equivalent
+ (:class:`torchrl.data.Categorical`), otherwise a one-hot encoding
+ will be used (:class:`torchrl.data.OneHot`). Defaults to ``False``.
+ return_state (bool, optional): if ``True``, "state" is included in the
+ output of :meth:`~.reset` and :meth:`~.step`. The state can be given
+ to :meth:`~.reset` to reset to that state, rather than resetting to
+ the initial state.
+ Defaults to ``False``.
+
+ Attributes:
+ available_envs: environments available to build
+
+ Examples:
+ >>> from torchrl.envs import OpenSpielEnv
+ >>> from tensordict import TensorDict
+ >>> env = OpenSpielEnv("chess", return_state=True)
+ >>> td = env.reset()
+ >>> td = env.step(env.full_action_spec.rand())
+ >>> print(td)
+ TensorDict(
+ fields={
+ agents: TensorDict(
+ fields={
+ action: Tensor(shape=torch.Size([2, 4672]), device=cpu, dtype=torch.int64, is_shared=False)},
+ batch_size=torch.Size([]),
+ device=None,
+ is_shared=False),
+ next: TensorDict(
+ fields={
+ agents: TensorDict(
+ fields={
+ observation: Tensor(shape=torch.Size([2, 20, 8, 8]), device=cpu, dtype=torch.float32, is_shared=False),
+ reward: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
+ batch_size=torch.Size([2]),
+ device=None,
+ is_shared=False),
+ current_player: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False),
+ done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+ state: NonTensorData(data=FEN: rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1
+ 674
+ , batch_size=torch.Size([]), device=None),
+ terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
+ batch_size=torch.Size([]),
+ device=None,
+ is_shared=False)},
+ batch_size=torch.Size([]),
+ device=None,
+ is_shared=False)
+ >>> print(env.available_envs)
+ ['2048', 'add_noise', 'amazons', 'backgammon', ...]
+
+ :meth:`~.reset` can restore a specific state, rather than the initial state,
+ as long as ``return_state=True``.
+
+ >>> from torchrl.envs import OpenSpielEnv
+ >>> from tensordict import TensorDict
+ >>> env = OpenSpielEnv("chess", return_state=True)
+ >>> td = env.reset()
+ >>> td = env.step(env.full_action_spec.rand())
+ >>> td_restore = td["next"]
+ >>> td = env.step(env.full_action_spec.rand())
+ >>> # The current state is not equal to `td_restore`
+ >>> (td["next"] == td_restore).all()
+ False
+ >>> td = env.reset(td_restore)
+ >>> # After resetting, the current state is equal to `td_restore`
+ >>> (td == td_restore).all()
+ True
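+
+ Agents can also be grouped differently. As a minimal sketch (assuming
+ OpenSpiel is installed and provides ``tic_tac_toe``), one group per agent
+ can be requested through
+ :class:`~torchrl.envs.utils.MarlGroupMapType.ONE_GROUP_PER_AGENT`:
+
+ >>> from torchrl.envs.utils import MarlGroupMapType
+ >>> env = OpenSpielEnv(
+ ...     "tic_tac_toe",
+ ...     group_map=MarlGroupMapType.ONE_GROUP_PER_AGENT,
+ ... )
+ >>> list(env.group_map.keys())
+ ['player_0', 'player_1']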
+ """
+
+ def __init__(
+ self,
+ game_string,
+ *,
+ group_map: MarlGroupMapType
+ | Dict[str, List[str]] = MarlGroupMapType.ALL_IN_ONE_GROUP,
+ categorical_actions=False,
+ return_state: bool = False,
+ **kwargs,
+ ):
+ kwargs["game_string"] = game_string
+ super().__init__(
+ group_map=group_map,
+ categorical_actions=categorical_actions,
+ return_state=return_state,
+ **kwargs,
+ )
+
+ def _build_env(
+ self,
+ game_string: str,
+ **kwargs,
+ ) -> "pyspiel.State": # noqa: F821
+ if not _has_pyspiel:
+ raise ImportError(
+ f"open_spiel not found, unable to create {game_string}. Consider "
+ f"downloading and installing open_spiel from {self.git_url}"
+ )
+ requires_grad = kwargs.pop("requires_grad", False)
+ parameters = kwargs.pop("parameters", None)
+ if kwargs:
+ raise ValueError("kwargs not supported.")
+
+ if parameters:
+ game = self.lib.load_game(game_string, parameters=parameters)
+ else:
+ game = self.lib.load_game(game_string)
+
+ env = game.new_initial_state()
+ return super()._build_env(
+ env,
+ requires_grad=requires_grad,
+ )
+
+ @property
+ def game_string(self):
+ return self._constructor_kwargs["game_string"]
+
+ def _check_kwargs(self, kwargs: Dict):
+ if "game_string" not in kwargs:
+ raise TypeError("Expected 'game_string' to be part of kwargs")
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}(env={self.game_string}, batch_size={self.batch_size}, device={self.device})"
diff --git a/torchrl/envs/libs/pettingzoo.py b/torchrl/envs/libs/pettingzoo.py
index eb94a27cbba..9853e8d516d 100644
--- a/torchrl/envs/libs/pettingzoo.py
+++ b/torchrl/envs/libs/pettingzoo.py
@@ -13,12 +13,7 @@
import torch
from tensordict import TensorDictBase
-from torchrl.data.tensor_specs import (
- CompositeSpec,
- DiscreteTensorSpec,
- OneHotDiscreteTensorSpec,
- UnboundedContinuousTensorSpec,
-)
+from torchrl.data.tensor_specs import Categorical, Composite, OneHot, Unbounded
from torchrl.envs.common import _EnvWrapper
from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform, set_gym_backend
from torchrl.envs.utils import _classproperty, check_marl_grouping, MarlGroupMapType
@@ -141,7 +136,7 @@ class PettingZooWrapper(_EnvWrapper):
For example, you can provide ``MarlGroupMapType.ONE_GROUP_PER_AGENT``, telling that each agent should
have its own tensordict (similar to the pettingzoo parallel API).
- Grouping is useful for leveraging vectorisation among agents whose data goes through the same
+ Grouping is useful for leveraging vectorization among agents whose data goes through the same
neural network.
Args:
@@ -308,24 +303,24 @@ def _make_specs(
check_marl_grouping(self.group_map, self.possible_agents)
self.has_action_mask = {group: False for group in self.group_map.keys()}
- action_spec = CompositeSpec()
- observation_spec = CompositeSpec()
- reward_spec = CompositeSpec()
- done_spec = CompositeSpec(
+ action_spec = Composite()
+ observation_spec = Composite()
+ reward_spec = Composite()
+ done_spec = Composite(
{
- "done": DiscreteTensorSpec(
+ "done": Categorical(
n=2,
shape=torch.Size((1,)),
dtype=torch.bool,
device=self.device,
),
- "terminated": DiscreteTensorSpec(
+ "terminated": Categorical(
n=2,
shape=torch.Size((1,)),
dtype=torch.bool,
device=self.device,
),
- "truncated": DiscreteTensorSpec(
+ "truncated": Categorical(
n=2,
shape=torch.Size((1,)),
dtype=torch.bool,
@@ -356,7 +351,7 @@ def _make_group_specs(self, group_name: str, agent_names: List[str]):
observation_specs = []
for agent in agent_names:
action_specs.append(
- CompositeSpec(
+ Composite(
{
"action": _gym_to_torchrl_spec_transform(
self.action_space(agent),
@@ -368,7 +363,7 @@ def _make_group_specs(self, group_name: str, agent_names: List[str]):
)
)
observation_specs.append(
- CompositeSpec(
+ Composite(
{
"observation": _gym_to_torchrl_spec_transform(
self.observation_space(agent),
@@ -386,34 +381,31 @@ def _make_group_specs(self, group_name: str, agent_names: List[str]):
# We uniform this by removing it from both places and optionally set it in a standard location.
group_observation_inner_spec = group_observation_spec["observation"]
if (
- isinstance(group_observation_inner_spec, CompositeSpec)
+ isinstance(group_observation_inner_spec, Composite)
and "action_mask" in group_observation_inner_spec.keys()
):
self.has_action_mask[group_name] = True
del group_observation_inner_spec["action_mask"]
- group_observation_spec["action_mask"] = DiscreteTensorSpec(
+ group_observation_spec["action_mask"] = Categorical(
n=2,
shape=group_action_spec["action"].shape
if not self.categorical_actions
- else (
- *group_action_spec["action"].shape,
- group_action_spec["action"].space.n,
- ),
+ else group_action_spec["action"].to_one_hot_spec().shape,
dtype=torch.bool,
device=self.device,
)
if self.use_mask:
- group_observation_spec["mask"] = DiscreteTensorSpec(
+ group_observation_spec["mask"] = Categorical(
n=2,
shape=torch.Size((n_agents,)),
dtype=torch.bool,
device=self.device,
)
- group_reward_spec = CompositeSpec(
+ group_reward_spec = Composite(
{
- "reward": UnboundedContinuousTensorSpec(
+ "reward": Unbounded(
shape=torch.Size((n_agents, 1)),
device=self.device,
dtype=torch.float32,
@@ -421,21 +413,21 @@ def _make_group_specs(self, group_name: str, agent_names: List[str]):
},
shape=torch.Size((n_agents,)),
)
- group_done_spec = CompositeSpec(
+ group_done_spec = Composite(
{
- "done": DiscreteTensorSpec(
+ "done": Categorical(
n=2,
shape=torch.Size((n_agents, 1)),
dtype=torch.bool,
device=self.device,
),
- "terminated": DiscreteTensorSpec(
+ "terminated": Categorical(
n=2,
shape=torch.Size((n_agents, 1)),
dtype=torch.bool,
device=self.device,
),
- "truncated": DiscreteTensorSpec(
+ "truncated": Categorical(
n=2,
shape=torch.Size((n_agents, 1)),
dtype=torch.bool,
@@ -473,11 +465,11 @@ def _init_env(self):
info_specs = []
for agent in agents:
info_specs.append(
- CompositeSpec(
+ Composite(
{
- "info": CompositeSpec(
+ "info": Composite(
{
- key: UnboundedContinuousTensorSpec(
+ key: Unbounded(
shape=torch.as_tensor(value).shape,
device=self.device,
)
@@ -495,11 +487,11 @@ def _init_env(self):
group_action_spec = self.input_spec[
"full_action_spec", group, "action"
]
- self.observation_spec[group]["action_mask"] = DiscreteTensorSpec(
+ self.observation_spec[group]["action_mask"] = Categorical(
n=2,
shape=group_action_spec.shape
if not self.categorical_actions
- else (*group_action_spec.shape, group_action_spec.space.n),
+ else group_action_spec.to_one_hot_spec().shape,
dtype=torch.bool,
device=self.device,
)
@@ -518,7 +510,7 @@ def _init_env(self):
)
except AttributeError:
state_example = torch.as_tensor(self.state(), device=self.device)
- state_spec = UnboundedContinuousTensorSpec(
+ state_spec = Unbounded(
shape=state_example.shape,
dtype=state_example.dtype,
device=self.device,
@@ -809,9 +801,7 @@ def _update_action_mask(self, td, observation_dict, info_dict):
del agent_info["action_mask"]
group_action_spec = self.input_spec["full_action_spec", group, "action"]
- if isinstance(
- group_action_spec, (DiscreteTensorSpec, OneHotDiscreteTensorSpec)
- ):
+ if isinstance(group_action_spec, (Categorical, OneHot)):
# We update the mask for available actions
group_action_spec.update_mask(group_mask.clone())
@@ -904,7 +894,7 @@ class PettingZooEnv(PettingZooWrapper):
For example, you can provide ``MarlGroupMapType.ONE_GROUP_PER_AGENT``, telling that each agent should
have its own tensordict (similar to the pettingzoo parallel API).
- Grouping is useful for leveraging vectorisation among agents whose data goes through the same
+ Grouping is useful for leveraging vectorization among agents whose data goes through the same
neural network.
Args:
diff --git a/torchrl/envs/libs/robohive.py b/torchrl/envs/libs/robohive.py
index 5e5c8f52393..30d9c644ced 100644
--- a/torchrl/envs/libs/robohive.py
+++ b/torchrl/envs/libs/robohive.py
@@ -12,7 +12,7 @@
import numpy as np
import torch
from tensordict import TensorDict
-from torchrl.data.tensor_specs import UnboundedContinuousTensorSpec
+from torchrl.data.tensor_specs import Unbounded
from torchrl.envs.libs.gym import (
_AsyncMeta,
_gym_to_torchrl_spec_transform,
@@ -80,8 +80,8 @@ class RoboHiveEnv(GymEnv, metaclass=_RoboHiveBuild):
Args:
env_name (str): the environment name to build. Must be one of :attr:`.available_envs`
categorical_action_encoding (bool, optional): if ``True``, categorical
- specs will be converted to the TorchRL equivalent (:class:`torchrl.data.DiscreteTensorSpec`),
- otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHotTensorSpec`).
+ specs will be converted to the TorchRL equivalent (:class:`torchrl.data.Categorical`),
+ otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHot`).
Defaults to ``False``.
Keyword Args:
@@ -305,7 +305,7 @@ def get_obs():
)
self.observation_spec = observation_spec
- self.reward_spec = UnboundedContinuousTensorSpec(
+ self.reward_spec = Unbounded(
shape=(1,),
device=self.device,
) # default
diff --git a/torchrl/envs/libs/smacv2.py b/torchrl/envs/libs/smacv2.py
index d460eb38f1e..67e71da0d5a 100644
--- a/torchrl/envs/libs/smacv2.py
+++ b/torchrl/envs/libs/smacv2.py
@@ -10,13 +10,7 @@
import torch
from tensordict import TensorDict, TensorDictBase
-from torchrl.data.tensor_specs import (
- BoundedTensorSpec,
- CompositeSpec,
- DiscreteTensorSpec,
- OneHotDiscreteTensorSpec,
- UnboundedContinuousTensorSpec,
-)
+from torchrl.data.tensor_specs import Bounded, Categorical, Composite, OneHot, Unbounded
from torchrl.envs.common import _EnvWrapper
from torchrl.envs.utils import _classproperty, ACTION_MASK_ERROR
@@ -224,11 +218,11 @@ def _build_env(
def _make_specs(self, env: "smacv2.env.StarCraft2Env") -> None: # noqa: F821
self.group_map = {"agents": [str(i) for i in range(self.n_agents)]}
- self.reward_spec = UnboundedContinuousTensorSpec(
+ self.reward_spec = Unbounded(
shape=torch.Size((1,)),
device=self.device,
)
- self.done_spec = DiscreteTensorSpec(
+ self.done_spec = Categorical(
n=2,
shape=torch.Size((1,)),
dtype=torch.bool,
@@ -241,54 +235,50 @@ def _init_env(self) -> None:
self._env.reset()
self._update_action_mask()
- def _make_action_spec(self) -> CompositeSpec:
+ def _make_action_spec(self) -> Composite:
if self.categorical_actions:
- action_spec = DiscreteTensorSpec(
+ action_spec = Categorical(
self.n_actions,
shape=torch.Size((self.n_agents,)),
device=self.device,
dtype=torch.long,
)
else:
- action_spec = OneHotDiscreteTensorSpec(
+ action_spec = OneHot(
self.n_actions,
shape=torch.Size((self.n_agents, self.n_actions)),
device=self.device,
dtype=torch.long,
)
- spec = CompositeSpec(
+ spec = Composite(
{
- "agents": CompositeSpec(
+ "agents": Composite(
{"action": action_spec}, shape=torch.Size((self.n_agents,))
)
}
)
return spec
- def _make_observation_spec(self) -> CompositeSpec:
- obs_spec = BoundedTensorSpec(
+ def _make_observation_spec(self) -> Composite:
+ obs_spec = Bounded(
low=-1.0,
high=1.0,
shape=torch.Size([self.n_agents, self.get_obs_size()]),
device=self.device,
dtype=torch.float32,
)
- info_spec = CompositeSpec(
+ info_spec = Composite(
{
- "battle_won": DiscreteTensorSpec(
- 2, dtype=torch.bool, device=self.device
- ),
- "episode_limit": DiscreteTensorSpec(
- 2, dtype=torch.bool, device=self.device
- ),
- "dead_allies": BoundedTensorSpec(
+ "battle_won": Categorical(2, dtype=torch.bool, device=self.device),
+ "episode_limit": Categorical(2, dtype=torch.bool, device=self.device),
+ "dead_allies": Bounded(
low=0,
high=self.n_agents,
dtype=torch.long,
device=self.device,
shape=(),
),
- "dead_enemies": BoundedTensorSpec(
+ "dead_enemies": Bounded(
low=0,
high=self.n_enemies,
dtype=torch.long,
@@ -297,19 +287,19 @@ def _make_observation_spec(self) -> CompositeSpec:
),
}
)
- mask_spec = DiscreteTensorSpec(
+ mask_spec = Categorical(
2,
torch.Size([self.n_agents, self.n_actions]),
device=self.device,
dtype=torch.bool,
)
- spec = CompositeSpec(
+ spec = Composite(
{
- "agents": CompositeSpec(
+ "agents": Composite(
{"observation": obs_spec, "action_mask": mask_spec},
shape=torch.Size((self.n_agents,)),
),
- "state": BoundedTensorSpec(
+ "state": Bounded(
low=-1.0,
high=1.0,
shape=torch.Size((self.get_state_size(),)),
diff --git a/torchrl/envs/libs/unity_mlagents.py b/torchrl/envs/libs/unity_mlagents.py
new file mode 100644
index 00000000000..95c2460bc83
--- /dev/null
+++ b/torchrl/envs/libs/unity_mlagents.py
@@ -0,0 +1,899 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from __future__ import annotations
+
+import importlib.util
+from typing import Dict, List, Optional
+
+import torch
+from tensordict import TensorDict, TensorDictBase
+
+from torchrl.data.tensor_specs import (
+ BoundedContinuous,
+ Categorical,
+ Composite,
+ MultiCategorical,
+ MultiOneHot,
+ Unbounded,
+)
+from torchrl.envs.common import _EnvWrapper
+from torchrl.envs.utils import _classproperty, check_marl_grouping, MarlGroupMapType
+
+_has_unity_mlagents = importlib.util.find_spec("mlagents_envs") is not None
+
+
+def _get_registered_envs():
+ if not _has_unity_mlagents:
+ raise ImportError(
+ "mlagents_envs not found. Consider downloading and installing "
+ f"mlagents from {UnityMLAgentsWrapper.git_url}."
+ )
+
+ from mlagents_envs.registry import default_registry
+
+ return list(default_registry.keys())
+
+
+class UnityMLAgentsWrapper(_EnvWrapper):
+ """Unity ML-Agents environment wrapper.
+
+ GitHub: https://github.com/Unity-Technologies/ml-agents
+
+ Documentation: https://unity-technologies.github.io/ml-agents/Python-LLAPI/
+
+ Args:
+ env (mlagents_envs.environment.UnityEnvironment): the ML-Agents
+ environment to wrap.
+
+ Keyword Args:
+ device (torch.device, optional): if provided, the device on which the data
+ is to be cast. Defaults to ``None``.
+ batch_size (torch.Size, optional): the batch size of the environment.
+ Defaults to ``torch.Size([])``.
+ allow_done_after_reset (bool, optional): if ``True``, the environment
+ is allowed to be ``done`` immediately after :meth:`~.reset` is called.
+ Defaults to ``False``.
+ group_map (MarlGroupMapType or Dict[str, List[str]], optional): how to
+ group agents in tensordicts for input/output. See
+ :class:`~torchrl.envs.utils.MarlGroupMapType` for more info. If not
+ specified, agents are grouped according to the group ID given by the
+ Unity environment. Defaults to ``None``.
+ categorical_actions (bool, optional): if ``True``, categorical specs
+ will be converted to the TorchRL equivalent
+ (:class:`torchrl.data.Categorical`), otherwise a one-hot encoding
+ will be used (:class:`torchrl.data.OneHot`). Defaults to ``False``.
+
+ Attributes:
+ available_envs: list of registered environments available to build
+
+ Examples:
+ >>> from mlagents_envs.environment import UnityEnvironment
+ >>> base_env = UnityEnvironment()
+ >>> from torchrl.envs import UnityMLAgentsWrapper
+ >>> env = UnityMLAgentsWrapper(base_env)
+ >>> td = env.reset()
+ >>> td = env.step(td.update(env.full_action_spec.rand()))
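+
+ By default, agents are grouped by the group ID reported by Unity. As a
+ sketch (assuming a freshly created ``base_env``), all agents can instead
+ share a single tensordict by passing an explicit group map:
+
+ >>> from torchrl.envs.utils import MarlGroupMapType
+ >>> env = UnityMLAgentsWrapper(
+ ...     base_env, group_map=MarlGroupMapType.ALL_IN_ONE_GROUP
+ ... )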
+ """
+
+ git_url = "https://github.com/Unity-Technologies/ml-agents"
+ libname = "mlagents_envs"
+ _lib = None
+
+ @_classproperty
+ def lib(cls):
+ if cls._lib is not None:
+ return cls._lib
+
+ import mlagents_envs
+ import mlagents_envs.environment
+
+ cls._lib = mlagents_envs
+ return mlagents_envs
+
+ def __init__(
+ self,
+ env=None,
+ *,
+ group_map: MarlGroupMapType | Dict[str, List[str]] | None = None,
+ categorical_actions: bool = False,
+ **kwargs,
+ ):
+ if env is not None:
+ kwargs["env"] = env
+
+ self.group_map = group_map
+ self.categorical_actions = categorical_actions
+ super().__init__(**kwargs)
+
+ def _check_kwargs(self, kwargs: Dict):
+ mlagents_envs = self.lib
+ if "env" not in kwargs:
+ raise TypeError("Could not find environment key 'env' in kwargs.")
+ env = kwargs["env"]
+ if not isinstance(env, mlagents_envs.environment.UnityEnvironment):
+ raise TypeError(
+ "env is not of type 'mlagents_envs.environment.UnityEnvironment'"
+ )
+
+ def _build_env(self, env, requires_grad: bool = False, **kwargs):
+ self.requires_grad = requires_grad
+ return env
+
+ def _init_env(self):
+ self._update_action_mask()
+
+ # Collects each agent's behavior name and the group id assigned to it by
+ # the Unity environment.
+ def _collect_agents(self, env):
+ agent_name_to_behavior_map = {}
+ agent_name_to_group_id_map = {}
+
+ for steps_idx in [0, 1]:
+ for behavior in env.behavior_specs.keys():
+ steps = env.get_steps(behavior)[steps_idx]
+ is_terminal = steps_idx == 1
+ agent_ids = steps.agent_id
+ group_ids = steps.group_id
+
+ for agent_id, group_id in zip(agent_ids, group_ids):
+ agent_name = f"agent_{agent_id}"
+ if agent_name in agent_name_to_behavior_map:
+ # Sometimes in an MLAgents environment, an agent may
+ # show up in both the decision steps and the terminal
+ # steps. When that happens, just skip the duplicate.
+ assert is_terminal
+ continue
+ agent_name_to_behavior_map[agent_name] = behavior
+ agent_name_to_group_id_map[agent_name] = group_id
+
+ return (
+ agent_name_to_behavior_map,
+ agent_name_to_group_id_map,
+ )
+
+ # Creates a group map where agents are grouped by their group_id.
+ def _make_default_group_map(self, agent_name_to_group_id_map):
+ group_map = {}
+ for agent_name, group_id in agent_name_to_group_id_map.items():
+ group_name = f"group_{group_id}"
+ if group_name not in group_map:
+ group_map[group_name] = []
+ group_map[group_name].append(agent_name)
+ return group_map
+
+ def _make_group_map(self, group_map, agent_name_to_group_id_map):
+ if group_map is None:
+ group_map = self._make_default_group_map(agent_name_to_group_id_map)
+ elif isinstance(group_map, MarlGroupMapType):
+ group_map = group_map.get_group_map(agent_name_to_group_id_map.keys())
+ check_marl_grouping(group_map, agent_name_to_group_id_map.keys())
+ agent_name_to_group_name_map = {}
+ for group_name, agents in group_map.items():
+ for agent_name in agents:
+ agent_name_to_group_name_map[agent_name] = group_name
+ return group_map, agent_name_to_group_name_map
+
+ def _make_specs(
+ self, env: "mlagents_envs.environment.UnityEnvironment" # noqa: F821
+ ) -> None:
+ # NOTE: We need to reset here because mlagents only initializes the
+ # agents and behaviors after reset. In order to build specs, we make the
+ # following assumptions about the mlagents environment:
+ # * all behaviors are defined on the first step
+ # * all agents request an action on the first step
+ # However, mlagents allows you to break these assumptions, so we probably
+ # will need to detect changes to the behaviors and agents on each step.
+ env.reset()
+ (
+ self.agent_name_to_behavior_map,
+ self.agent_name_to_group_id_map,
+ ) = self._collect_agents(env)
+
+ (self.group_map, self.agent_name_to_group_name_map) = self._make_group_map(
+ self.group_map, self.agent_name_to_group_id_map
+ )
+
+ action_spec = {}
+ observation_spec = {}
+ reward_spec = {}
+ done_spec = {}
+
+ for group_name, agents in self.group_map.items():
+ group_action_spec = {}
+ group_observation_spec = {}
+ group_reward_spec = {}
+ group_done_spec = {}
+ for agent_name in agents:
+ behavior = self.agent_name_to_behavior_map[agent_name]
+ behavior_spec = env.behavior_specs[behavior]
+
+ # Create action spec
+ agent_action_spec = Composite()
+ env_action_spec = behavior_spec.action_spec
+ discrete_branches = env_action_spec.discrete_branches
+ continuous_size = env_action_spec.continuous_size
+ if len(discrete_branches) > 0:
+ discrete_action_spec_cls = (
+ MultiCategorical if self.categorical_actions else MultiOneHot
+ )
+ agent_action_spec["discrete_action"] = discrete_action_spec_cls(
+ discrete_branches,
+ dtype=torch.int32,
+ device=self.device,
+ )
+ if continuous_size > 0:
+ # In mlagents, continuous actions can take values between -1
+ # and 1 by default:
+ # https://github.com/Unity-Technologies/ml-agents/blob/22a59aad34ef46a5de05469735426feed758f8f5/ml-agents-envs/mlagents_envs/base_env.py#L395
+ agent_action_spec["continuous_action"] = BoundedContinuous(
+ -1, 1, (continuous_size,), self.device, torch.float32
+ )
+ group_action_spec[agent_name] = agent_action_spec
+
+ # Create observation spec
+ agent_observation_spec = Composite()
+ for obs_idx, env_observation_spec in enumerate(
+ behavior_spec.observation_specs
+ ):
+ if len(env_observation_spec.name) == 0:
+ obs_name = f"observation_{obs_idx}"
+ else:
+ obs_name = env_observation_spec.name
+ agent_observation_spec[obs_name] = Unbounded(
+ env_observation_spec.shape,
+ dtype=torch.float32,
+ device=self.device,
+ )
+ group_observation_spec[agent_name] = agent_observation_spec
+
+ # Create reward spec
+ agent_reward_spec = Composite()
+ agent_reward_spec["reward"] = Unbounded(
+ (1,),
+ dtype=torch.float32,
+ device=self.device,
+ )
+ agent_reward_spec["group_reward"] = Unbounded(
+ (1,),
+ dtype=torch.float32,
+ device=self.device,
+ )
+ group_reward_spec[agent_name] = agent_reward_spec
+
+ # Create done spec
+ agent_done_spec = Composite()
+ for done_key in ["done", "terminated", "truncated"]:
+ agent_done_spec[done_key] = Categorical(
+ 2, (1,), dtype=torch.bool, device=self.device
+ )
+ group_done_spec[agent_name] = agent_done_spec
+
+ action_spec[group_name] = group_action_spec
+ observation_spec[group_name] = group_observation_spec
+ reward_spec[group_name] = group_reward_spec
+ done_spec[group_name] = group_done_spec
+
+ self.action_spec = Composite(action_spec)
+ self.observation_spec = Composite(observation_spec)
+ self.reward_spec = Composite(reward_spec)
+ self.done_spec = Composite(done_spec)
+
+ def _set_seed(self, seed):
+ if seed is not None:
+ raise NotImplementedError("This environment has no seed.")
+
+ def _check_agent_exists(self, agent_name, group_id):
+ if agent_name not in self.agent_name_to_group_id_map:
+ raise RuntimeError(
+ (
+ "Unity environment added a new agent. This is not yet "
+ "supported in torchrl."
+ )
+ )
+ if self.agent_name_to_group_id_map[agent_name] != group_id:
+ raise RuntimeError(
+ (
+ "Unity environment changed the group of an agent. This "
+ "is not yet supported in torchrl."
+ )
+ )
+
+ def _update_action_mask(self):
+ for behavior, behavior_spec in self._env.behavior_specs.items():
+ env_action_spec = behavior_spec.action_spec
+ discrete_branches = env_action_spec.discrete_branches
+
+ if len(discrete_branches) > 0:
+ steps = self._env.get_steps(behavior)[0]
+ env_action_mask = steps.action_mask
+ if env_action_mask is not None:
+ combined_action_mask = torch.cat(
+ [
+ torch.tensor(m, device=self.device, dtype=torch.bool)
+ for m in env_action_mask
+ ],
+ dim=-1,
+ ).logical_not()
+
+ for agent_id, group_id, agent_action_mask in zip(
+ steps.agent_id, steps.group_id, combined_action_mask
+ ):
+ agent_name = f"agent_{agent_id}"
+ self._check_agent_exists(agent_name, group_id)
+ group_name = self.agent_name_to_group_name_map[agent_name]
+ self.full_action_spec[
+ group_name, agent_name, "discrete_action"
+ ].update_mask(agent_action_mask)
+
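+ # Builds the output tensordict from the decision steps (agents that
+ # requested an action) and the terminal steps (agents that finished),
+ # filling in missing agents from the previous step when needed.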
+ def _make_td_out(self, tensordict_in, is_reset=False):
+ source = {}
+ for behavior, behavior_spec in self._env.behavior_specs.items():
+ for idx, steps in enumerate(self._env.get_steps(behavior)):
+ is_terminal = idx == 1
+ for steps_idx, (agent_id, group_id) in enumerate(
+ zip(steps.agent_id, steps.group_id)
+ ):
+ agent_name = f"agent_{agent_id}"
+ self._check_agent_exists(agent_name, group_id)
+ group_name = self.agent_name_to_group_name_map[agent_name]
+ if group_name not in source:
+ source[group_name] = {}
+ if agent_name not in source[group_name]:
+ source[group_name][agent_name] = {}
+
+ # Add observations
+ for obs_idx, (
+ behavior_observation,
+ env_observation_spec,
+ ) in enumerate(zip(steps.obs, behavior_spec.observation_specs)):
+ observation = torch.tensor(
+ behavior_observation[steps_idx],
+ device=self.device,
+ dtype=torch.float32,
+ )
+ if len(env_observation_spec.name) == 0:
+ obs_name = f"observation_{obs_idx}"
+ else:
+ obs_name = env_observation_spec.name
+ source[group_name][agent_name][obs_name] = observation
+
+ # Add rewards
+ if not is_reset:
+ source[group_name][agent_name]["reward"] = torch.tensor(
+ steps.reward[steps_idx],
+ device=self.device,
+ dtype=torch.float32,
+ )
+ source[group_name][agent_name]["group_reward"] = torch.tensor(
+ steps.group_reward[steps_idx],
+ device=self.device,
+ dtype=torch.float32,
+ )
+
+ # Add done
+ done = is_terminal and not is_reset
+ source[group_name][agent_name]["done"] = torch.tensor(
+ done, device=self.device, dtype=torch.bool
+ )
+ source[group_name][agent_name]["truncated"] = torch.tensor(
+ done and steps.interrupted[steps_idx],
+ device=self.device,
+ dtype=torch.bool,
+ )
+ source[group_name][agent_name]["terminated"] = torch.tensor(
+ done and not steps.interrupted[steps_idx],
+ device=self.device,
+ dtype=torch.bool,
+ )
+
+ if tensordict_in is not None:
+ # In MLAgents, a given step will only contain information for agents
+ # which either terminated or requested a decision during the step.
+ # Some agents may have neither terminated nor requested a decision,
+ # so we need to fill in their information from the previous step.
+ for group_name, agents in self.group_map.items():
+ for agent_name in agents:
+ if group_name not in source.keys():
+ source[group_name] = {}
+ if agent_name not in source[group_name].keys():
+ agent_dict = {}
+ agent_behavior = self.agent_name_to_behavior_map[agent_name]
+ behavior_spec = self._env.behavior_specs[agent_behavior]
+ td_agent_in = tensordict_in[group_name, agent_name]
+
+ # Add observations
+ for obs_idx, env_observation_spec in enumerate(
+ behavior_spec.observation_specs
+ ):
+ if len(env_observation_spec.name) == 0:
+ obs_name = f"observation_{obs_idx}"
+ else:
+ obs_name = env_observation_spec.name
+ agent_dict[obs_name] = td_agent_in[obs_name]
+
+ # Add rewards
+ if not is_reset:
+ # Since the agent didn't request a decision, the
+ # reward is 0
+ agent_dict["reward"] = torch.zeros(
+ (1,), device=self.device, dtype=torch.float32
+ )
+ agent_dict["group_reward"] = torch.zeros(
+ (1,), device=self.device, dtype=torch.float32
+ )
+
+ # Add done
+ agent_dict["done"] = torch.tensor(
+ False, device=self.device, dtype=torch.bool
+ )
+ agent_dict["terminated"] = torch.tensor(
+ False, device=self.device, dtype=torch.bool
+ )
+ agent_dict["truncated"] = torch.tensor(
+ False, device=self.device, dtype=torch.bool
+ )
+
+ source[group_name][agent_name] = agent_dict
+
+ tensordict_out = TensorDict(
+ source=source,
+ batch_size=self.batch_size,
+ device=self.device,
+ )
+
+ return tensordict_out
+
+ def _get_action_from_tensor(self, tensor):
+ if not self.categorical_actions:
+ action = torch.argmax(tensor, dim=-1)
+ else:
+ action = tensor
+ return action
+
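+ # Dispatches each agent's action to Unity as an ActionTuple, keyed by
+ # behavior, then advances the simulation one step.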
+ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
+ # Apply actions
+ for behavior, behavior_spec in self._env.behavior_specs.items():
+ env_action_spec = behavior_spec.action_spec
+ steps = self._env.get_steps(behavior)[0]
+
+ for agent_id, group_id in zip(steps.agent_id, steps.group_id):
+ agent_name = f"agent_{agent_id}"
+ self._check_agent_exists(agent_name, group_id)
+ group_name = self.agent_name_to_group_name_map[agent_name]
+
+ agent_action_spec = self.full_action_spec[group_name, agent_name]
+ action_tuple = self.lib.base_env.ActionTuple()
+ discrete_branches = env_action_spec.discrete_branches
+ continuous_size = env_action_spec.continuous_size
+
+ if len(discrete_branches) > 0:
+ discrete_spec = agent_action_spec["discrete_action"]
+ discrete_action = tensordict[
+ group_name, agent_name, "discrete_action"
+ ]
+ if not self.categorical_actions:
+ discrete_action = discrete_spec.to_categorical(discrete_action)
+ action_tuple.add_discrete(discrete_action[None, ...].numpy())
+
+ if continuous_size > 0:
+ continuous_action = tensordict[
+ group_name, agent_name, "continuous_action"
+ ]
+ action_tuple.add_continuous(continuous_action[None, ...].numpy())
+
+ self._env.set_action_for_agent(behavior, agent_id, action_tuple)
+
+ self._env.step()
+ self._update_action_mask()
+ return self._make_td_out(tensordict)
+
+ def _to_tensor(self, value):
+ return torch.tensor(value, device=self.device, dtype=torch.float32)
+
+ def _reset(
+ self, tensordict: TensorDictBase | None = None, **kwargs
+ ) -> TensorDictBase:
+ self._env.reset()
+ return self._make_td_out(tensordict, is_reset=True)
+
+ def close(self):
+ self._env.close()
+
+ @_classproperty
+ def available_envs(cls):
+ if not _has_unity_mlagents:
+ return []
+ return _get_registered_envs()
+
+
+class UnityMLAgentsEnv(UnityMLAgentsWrapper):
+ """Unity ML-Agents environment wrapper.
+
+ GitHub: https://github.com/Unity-Technologies/ml-agents
+
+ Documentation: https://unity-technologies.github.io/ml-agents/Python-LLAPI/
+
+ This class accepts any of the optional initialization arguments that the
+ :class:`mlagents_envs.environment.UnityEnvironment` class provides. For a
+ list of these arguments, see:
+ https://unity-technologies.github.io/ml-agents/Python-LLAPI-Documentation/#__init__
+
+ If both ``file_name`` and ``registered_name`` are given, an error is raised.
+
+ If neither ``file_name`` nor ``registered_name`` is given, the environment
+ setup waits on a localhost port, and the user must launch a Unity ML-Agents
+ environment binary that connects to it.
+
+ Args:
+ file_name (str, optional): if provided, the path to the Unity
+ environment binary. Defaults to ``None``.
+ registered_name (str, optional): if provided, the Unity environment
+ binary is loaded from the default ML-Agents registry. The list of
+ registered environments is in :attr:`~.available_envs`. Defaults to
+ ``None``.
+
+ Keyword Args:
+ device (torch.device, optional): if provided, the device on which the data
+ is to be cast. Defaults to ``None``.
+ batch_size (torch.Size, optional): the batch size of the environment.
+ Defaults to ``torch.Size([])``.
+ allow_done_after_reset (bool, optional): if ``True``, the environment
+ is allowed to be ``done`` immediately after :meth:`~.reset` is called.
+ Defaults to ``False``.
+ group_map (MarlGroupMapType or Dict[str, List[str]], optional): how to
+ group agents in tensordicts for input/output. See
+ :class:`~torchrl.envs.utils.MarlGroupMapType` for more info. If not
+ specified, agents are grouped according to the group ID given by the
+ Unity environment. Defaults to ``None``.
+ categorical_actions (bool, optional): if ``True``, categorical specs
+ will be converted to the TorchRL equivalent
+ (:class:`torchrl.data.Categorical`), otherwise a one-hot encoding
+ will be used (:class:`torchrl.data.OneHot`). Defaults to ``False``.
+
+ Attributes:
+ available_envs: list of registered environments available to build
+
+ Examples:
+ >>> from torchrl.envs import UnityMLAgentsEnv
+ >>> env = UnityMLAgentsEnv(registered_name='3DBall')
+ >>> td = env.reset()
+ >>> td = env.step(td.update(env.full_action_spec.rand()))
+ >>> td
+ TensorDict(
+ fields={
+ group_0: TensorDict(
+ fields={
+ agent_0: TensorDict(
+ fields={
+ VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False),
+ continuous_action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
+ done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+ terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+ truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
+ batch_size=torch.Size([]),
+ device=None,
+ is_shared=False),
+ agent_10: TensorDict(
+ fields={
+ VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False),
+ continuous_action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
+ done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+ terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+ truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
+ batch_size=torch.Size([]),
+ device=None,
+ is_shared=False),
+ agent_11: TensorDict(
+ fields={
+ VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False),
+ continuous_action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
+ done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+ terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+ truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
+ batch_size=torch.Size([]),
+ device=None,
+ is_shared=False),
+ agent_1: TensorDict(
+ fields={
+ VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False),
+ continuous_action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
+ done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+ terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+ truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
+ batch_size=torch.Size([]),
+ device=None,
+ is_shared=False),
+ agent_2: TensorDict(
+ fields={
+ VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False),
+ continuous_action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
+ done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+ terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+ truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
+ batch_size=torch.Size([]),
+ device=None,
+ is_shared=False),
+ agent_3: TensorDict(
+ fields={
+ VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False),
+ continuous_action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
+ done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+ terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+ truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
+ batch_size=torch.Size([]),
+ device=None,
+ is_shared=False),
+ agent_4: TensorDict(
+ fields={
+ VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False),
+ continuous_action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
+ done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+ terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+ truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
+ batch_size=torch.Size([]),
+ device=None,
+ is_shared=False),
+ agent_5: TensorDict(
+ fields={
+ VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False),
+ continuous_action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
+ done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+ terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+ truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
+ batch_size=torch.Size([]),
+ device=None,
+ is_shared=False),
+ agent_6: TensorDict(
+ fields={
+ VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False),
+ continuous_action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
+ done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+ terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+ truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
+ batch_size=torch.Size([]),
+ device=None,
+ is_shared=False),
+ agent_7: TensorDict(
+ fields={
+ VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False),
+ continuous_action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
+ done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+ terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+ truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
+ batch_size=torch.Size([]),
+ device=None,
+ is_shared=False),
+ agent_8: TensorDict(
+ fields={
+ VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False),
+ continuous_action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
+ done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+ terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+ truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
+ batch_size=torch.Size([]),
+ device=None,
+ is_shared=False),
+ agent_9: TensorDict(
+ fields={
+ VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False),
+ continuous_action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
+ done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+ terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+ truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
+ batch_size=torch.Size([]),
+ device=None,
+ is_shared=False)},
+ batch_size=torch.Size([]),
+ device=None,
+ is_shared=False),
+ next: TensorDict(
+ fields={
+ group_0: TensorDict(
+ fields={
+ agent_0: TensorDict(
+ fields={
+ VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False),
+ done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+ group_reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
+ reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
+ terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+ truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
+ batch_size=torch.Size([]),
+ device=None,
+ is_shared=False),
+ agent_10: TensorDict(
+ fields={
+ VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False),
+ done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+ group_reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
+ reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
+ terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+ truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
+ batch_size=torch.Size([]),
+ device=None,
+ is_shared=False),
+ agent_11: TensorDict(
+ fields={
+ VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False),
+ done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+ group_reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
+ reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
+ terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+ truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
+ batch_size=torch.Size([]),
+ device=None,
+ is_shared=False),
+ agent_1: TensorDict(
+ fields={
+ VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False),
+ done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+ group_reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
+ reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
+ terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+ truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
+ batch_size=torch.Size([]),
+ device=None,
+ is_shared=False),
+ agent_2: TensorDict(
+ fields={
+ VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False),
+ done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+ group_reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
+ reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
+ terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+ truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
+ batch_size=torch.Size([]),
+ device=None,
+ is_shared=False),
+ agent_3: TensorDict(
+ fields={
+ VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False),
+ done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+ group_reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
+ reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
+ terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+ truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
+ batch_size=torch.Size([]),
+ device=None,
+ is_shared=False),
+ agent_4: TensorDict(
+ fields={
+ VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False),
+ done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+ group_reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
+ reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
+ terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+ truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
+ batch_size=torch.Size([]),
+ device=None,
+ is_shared=False),
+ agent_5: TensorDict(
+ fields={
+ VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False),
+ done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+ group_reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
+ reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
+ terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+ truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
+ batch_size=torch.Size([]),
+ device=None,
+ is_shared=False),
+ agent_6: TensorDict(
+ fields={
+ VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False),
+ done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+ group_reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
+ reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
+ terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+ truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
+ batch_size=torch.Size([]),
+ device=None,
+ is_shared=False),
+ agent_7: TensorDict(
+ fields={
+ VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False),
+ done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+ group_reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
+ reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
+ terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+ truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
+ batch_size=torch.Size([]),
+ device=None,
+ is_shared=False),
+ agent_8: TensorDict(
+ fields={
+ VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False),
+ done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+ group_reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
+ reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
+ terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+ truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
+ batch_size=torch.Size([]),
+ device=None,
+ is_shared=False),
+ agent_9: TensorDict(
+ fields={
+ VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False),
+ done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+ group_reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
+ reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
+ terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+ truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
+ batch_size=torch.Size([]),
+ device=None,
+ is_shared=False)},
+ batch_size=torch.Size([]),
+ device=None,
+ is_shared=False)},
+ batch_size=torch.Size([]),
+ device=None,
+ is_shared=False)},
+ batch_size=torch.Size([]),
+ device=None,
+ is_shared=False)
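+
+ The action entries are nested as ``(group, agent, "continuous_action")``
+ or ``(group, agent, "discrete_action")``, depending on the behavior's
+ action spec. As a sketch, one agent's sampled action can be overridden
+ before stepping:
+
+ >>> import torch
+ >>> td = env.reset().update(env.full_action_spec.rand())
+ >>> td["group_0", "agent_0", "continuous_action"] = torch.zeros(2)
+ >>> td = env.step(td)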
+ """
+
+ def __init__(
+ self,
+ file_name: Optional[str] = None,
+ registered_name: Optional[str] = None,
+ *,
+ group_map: MarlGroupMapType | Dict[str, List[str]] | None = None,
+ categorical_actions=False,
+ **kwargs,
+ ):
+ kwargs["file_name"] = file_name
+ kwargs["registered_name"] = registered_name
+ super().__init__(
+ group_map=group_map,
+ categorical_actions=categorical_actions,
+ **kwargs,
+ )
+
+ def _build_env(
+ self,
+ file_name: Optional[str],
+ registered_name: Optional[str],
+ **kwargs,
+ ) -> "mlagents_envs.environment.UnityEnvironment": # noqa: F821
+ if not _has_unity_mlagents:
+ raise ImportError(
+ "mlagents_envs not found, unable to create environment. "
+ "Consider downloading and installing mlagents from "
+ f"{self.git_url}"
+ )
+ if file_name is not None and registered_name is not None:
+ raise ValueError(
+ "Both `file_name` and `registered_name` were specified, which "
+ "is not allowed. Specify one of them or neither."
+ )
+ # Pop the wrapper-only kwarg before forwarding the rest to the Unity env.
+ requires_grad = kwargs.pop("requires_grad", False)
+ if registered_name is not None:
+ from mlagents_envs.registry import default_registry
+
+ env = default_registry[registered_name].make(**kwargs)
+ else:
+ env = self.lib.environment.UnityEnvironment(file_name, **kwargs)
+ return super()._build_env(
+ env,
+ requires_grad=requires_grad,
+ )
+
+ @property
+ def file_name(self):
+ return self._constructor_kwargs["file_name"]
+
+ @property
+ def registered_name(self):
+ return self._constructor_kwargs["registered_name"]
+
+ def _check_kwargs(self, kwargs: Dict):
+ pass
+
+ def __repr__(self) -> str:
+ if self.registered_name is not None:
+ env_name = self.registered_name
+ else:
+ env_name = self.file_name
+ return f"{self.__class__.__name__}(env={env_name}, batch_size={self.batch_size}, device={self.device})"
diff --git a/torchrl/envs/libs/vmas.py b/torchrl/envs/libs/vmas.py
index 9751e84a3ac..22f9835303b 100644
--- a/torchrl/envs/libs/vmas.py
+++ b/torchrl/envs/libs/vmas.py
@@ -12,16 +12,16 @@
from tensordict import LazyStackedTensorDict, TensorDict, TensorDictBase
from torchrl.data.tensor_specs import (
- BoundedTensorSpec,
- CompositeSpec,
+ Bounded,
+ Categorical,
+ Composite,
DEVICE_TYPING,
- DiscreteTensorSpec,
- LazyStackedCompositeSpec,
- MultiDiscreteTensorSpec,
- MultiOneHotDiscreteTensorSpec,
- OneHotDiscreteTensorSpec,
+ MultiCategorical,
+ MultiOneHot,
+ OneHot,
+ StackedComposite,
TensorSpec,
- UnboundedContinuousTensorSpec,
+ Unbounded,
)
from torchrl.data.utils import numpy_to_torch_dtype_dict
from torchrl.envs.common import _EnvWrapper, EnvBase
@@ -57,11 +57,7 @@ def _vmas_to_torchrl_spec_transform(
) -> TensorSpec:
gym_spaces = gym_backend("spaces")
if isinstance(spec, gym_spaces.discrete.Discrete):
- action_space_cls = (
- DiscreteTensorSpec
- if categorical_action_encoding
- else OneHotDiscreteTensorSpec
- )
+ action_space_cls = Categorical if categorical_action_encoding else OneHot
dtype = (
numpy_to_torch_dtype_dict[spec.dtype]
if categorical_action_encoding
@@ -75,9 +71,9 @@ def _vmas_to_torchrl_spec_transform(
else torch.long
)
return (
- MultiDiscreteTensorSpec(spec.nvec, device=device, dtype=dtype)
+ MultiCategorical(spec.nvec, device=device, dtype=dtype)
if categorical_action_encoding
- else MultiOneHotDiscreteTensorSpec(spec.nvec, device=device, dtype=dtype)
+ else MultiOneHot(spec.nvec, device=device, dtype=dtype)
)
elif isinstance(spec, gym_spaces.Box):
shape = spec.shape
@@ -88,9 +84,9 @@ def _vmas_to_torchrl_spec_transform(
high = torch.tensor(spec.high, device=device, dtype=dtype)
is_unbounded = low.isinf().all() and high.isinf().all()
return (
- UnboundedContinuousTensorSpec(shape, device=device, dtype=dtype)
+ Unbounded(shape, device=device, dtype=dtype)
if is_unbounded
- else BoundedTensorSpec(
+ else Bounded(
low,
high,
shape,
@@ -98,6 +94,16 @@ def _vmas_to_torchrl_spec_transform(
device=device,
)
)
+ elif isinstance(spec, gym_spaces.Dict):
+ spec_out = {}
+ for key in spec.keys():
+ spec_out[key] = _vmas_to_torchrl_spec_transform(
+ spec[key],
+ device=device,
+ categorical_action_encoding=categorical_action_encoding,
+ )
+ # the batch-size must be set later
+ return Composite(spec_out, device=device)
else:
raise NotImplementedError(
f"spec of type {type(spec).__name__} is currently unaccounted for vmas"
@@ -322,9 +328,9 @@ def _make_specs(
self.group_map = self.group_map.get_group_map(self.agent_names)
check_marl_grouping(self.group_map, self.agent_names)
- self.unbatched_action_spec = CompositeSpec(device=self.device)
- self.unbatched_observation_spec = CompositeSpec(device=self.device)
- self.unbatched_reward_spec = CompositeSpec(device=self.device)
+ self.unbatched_action_spec = Composite(device=self.device)
+ self.unbatched_observation_spec = Composite(device=self.device)
+ self.unbatched_reward_spec = Composite(device=self.device)
self.het_specs = False
self.het_specs_map = {}
@@ -341,14 +347,14 @@ def _make_specs(
if group_info_spec is not None:
self.unbatched_observation_spec[(group, "info")] = group_info_spec
group_het_specs = isinstance(
- group_observation_spec, LazyStackedCompositeSpec
- ) or isinstance(group_action_spec, LazyStackedCompositeSpec)
+ group_observation_spec, StackedComposite
+ ) or isinstance(group_action_spec, StackedComposite)
self.het_specs_map[group] = group_het_specs
self.het_specs = self.het_specs or group_het_specs
- self.unbatched_done_spec = CompositeSpec(
+ self.unbatched_done_spec = Composite(
{
- "done": DiscreteTensorSpec(
+ "done": Categorical(
n=2,
shape=torch.Size((1,)),
dtype=torch.bool,
@@ -380,7 +386,7 @@ def _make_unbatched_group_specs(self, group: str):
agent_index = self.agent_names_to_indices_map[agent_name]
agent = self.agents[agent_index]
action_specs.append(
- CompositeSpec(
+ Composite(
{
"action": _vmas_to_torchrl_spec_transform(
self.action_space[agent_index],
@@ -391,7 +397,7 @@ def _make_unbatched_group_specs(self, group: str):
)
)
observation_specs.append(
- CompositeSpec(
+ Composite(
{
"observation": _vmas_to_torchrl_spec_transform(
self.observation_space[agent_index],
@@ -402,9 +408,9 @@ def _make_unbatched_group_specs(self, group: str):
)
)
reward_specs.append(
- CompositeSpec(
+ Composite(
{
- "reward": UnboundedContinuousTensorSpec(
+ "reward": Unbounded(
shape=torch.Size((1,)),
device=self.device,
) # shape = (1,)
@@ -414,9 +420,9 @@ def _make_unbatched_group_specs(self, group: str):
agent_info = self.scenario.info(agent)
if len(agent_info):
info_specs.append(
- CompositeSpec(
+ Composite(
{
- key: UnboundedContinuousTensorSpec(
+ key: Unbounded(
shape=_selective_unsqueeze(
value, batch_size=self.batch_size
).shape[1:],
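For reference, a short sketch of the renamed spec classes this hunk switches to (``Bounded``, ``Categorical``, ``Composite``, ``Unbounded`` replacing the former ``*TensorSpec`` names), mirroring what ``_vmas_to_torchrl_spec_transform`` now returns for a ``Dict`` space; import path assumed to be ``torchrl.data``:

    >>> import torch
    >>> from torchrl.data import Bounded, Categorical, Composite, Unbounded
    >>> spec = Composite(
    ...     action=Categorical(n=5, dtype=torch.int64),
    ...     observation=Bounded(low=-1.0, high=1.0, shape=(3,), dtype=torch.float32),
    ...     reward=Unbounded(shape=(1,)),
    ... )  # the batch-size can be set later, as for the Dict branch above
    >>> sample = spec.rand()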
diff --git a/torchrl/envs/model_based/common.py b/torchrl/envs/model_based/common.py
index f6b3f97cd4a..2a3c0198f9c 100644
--- a/torchrl/envs/model_based/common.py
+++ b/torchrl/envs/model_based/common.py
@@ -27,18 +27,18 @@ class ModelBasedEnvBase(EnvBase):
Example:
>>> import torch
>>> from tensordict import TensorDict
- >>> from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec
+ >>> from torchrl.data import Composite, Unbounded
>>> class MyMBEnv(ModelBasedEnvBase):
... def __init__(self, world_model, device="cpu", dtype=None, batch_size=None):
... super().__init__(world_model, device=device, dtype=dtype, batch_size=batch_size)
- ... self.observation_spec = CompositeSpec(
- ... hidden_observation=UnboundedContinuousTensorSpec((4,))
+ ... self.observation_spec = Composite(
+ ... hidden_observation=Unbounded((4,))
... )
- ... self.state_spec = CompositeSpec(
- ... hidden_observation=UnboundedContinuousTensorSpec((4,)),
+ ... self.state_spec = Composite(
+ ... hidden_observation=Unbounded((4,)),
... )
- ... self.action_spec = UnboundedContinuousTensorSpec((1,))
- ... self.reward_spec = UnboundedContinuousTensorSpec((1,))
+ ... self.action_spec = Unbounded((1,))
+ ... self.reward_spec = Unbounded((1,))
...
... def _reset(self, tensordict: TensorDict) -> TensorDict:
... tensordict = TensorDict({},
@@ -84,10 +84,10 @@ class ModelBasedEnvBase(EnvBase):
Properties:
- - observation_spec (CompositeSpec): sampling spec of the observations;
+ - observation_spec (Composite): sampling spec of the observations;
- action_spec (TensorSpec): sampling spec of the actions;
- reward_spec (TensorSpec): sampling spec of the rewards;
- - input_spec (CompositeSpec): sampling spec of the inputs;
+ - input_spec (Composite): sampling spec of the inputs;
     - batch_size (torch.Size): batch_size to be used by the env. If not set, the env accepts tensordicts of all batch sizes.
- device (torch.device): device where the env input and output are expected to live
diff --git a/torchrl/envs/model_based/dreamer.py b/torchrl/envs/model_based/dreamer.py
index 5609861c75f..f5636f76c5a 100644
--- a/torchrl/envs/model_based/dreamer.py
+++ b/torchrl/envs/model_based/dreamer.py
@@ -9,7 +9,7 @@
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
-from torchrl.data.tensor_specs import CompositeSpec
+from torchrl.data.tensor_specs import Composite
from torchrl.data.utils import DEVICE_TYPING
from torchrl.envs.common import EnvBase
from torchrl.envs.model_based import ModelBasedEnvBase
@@ -39,7 +39,7 @@ def set_specs_from_env(self, env: EnvBase):
"""Sets the specs of the environment from the specs of the given environment."""
super().set_specs_from_env(env)
self.action_spec = self.action_spec.to(self.device)
- self.state_spec = CompositeSpec(
+ self.state_spec = Composite(
state=self.observation_spec["state"],
belief=self.observation_spec["belief"],
shape=env.batch_size,
diff --git a/torchrl/envs/transforms/gym_transforms.py b/torchrl/envs/transforms/gym_transforms.py
index 35f122b770a..b3ac334a5d8 100644
--- a/torchrl/envs/transforms/gym_transforms.py
+++ b/torchrl/envs/transforms/gym_transforms.py
@@ -10,7 +10,7 @@
import torchrl.objectives.common
from tensordict import TensorDictBase
from tensordict.utils import expand_as_right, NestedKey
-from torchrl.data.tensor_specs import UnboundedDiscreteTensorSpec
+from torchrl.data.tensor_specs import Unbounded
from torchrl.envs.transforms.transforms import FORWARD_NOT_IMPLEMENTED, Transform
@@ -179,7 +179,7 @@ def _reset(self, tensordict, tensordict_reset):
def transform_observation_spec(self, observation_spec):
full_done_spec = self.parent.output_spec["full_done_spec"]
observation_spec[self.eol_key] = full_done_spec[self.done_key].clone()
- observation_spec[self.lives_key] = UnboundedDiscreteTensorSpec(
+ observation_spec[self.lives_key] = Unbounded(
self.parent.batch_size,
device=self.parent.device,
dtype=torch.int64,
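The swap above relies on ``Unbounded`` dispatching on its dtype: given an integer dtype it behaves as the former ``UnboundedDiscreteTensorSpec``, so the dedicated import is no longer needed. A small sketch, assuming this dtype-based dispatch:

    >>> import torch
    >>> from torchrl.data import Unbounded
    >>> lives_spec = Unbounded(shape=(), dtype=torch.int64)  # discrete, as in the hunk above
    >>> lives_spec.rand().dtype
    torch.int64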
diff --git a/torchrl/envs/transforms/r3m.py b/torchrl/envs/transforms/r3m.py
index 546321d5815..bdc8af1eefa 100644
--- a/torchrl/envs/transforms/r3m.py
+++ b/torchrl/envs/transforms/r3m.py
@@ -11,11 +11,7 @@
from torch.hub import load_state_dict_from_url
from torch.nn import Identity
-from torchrl.data.tensor_specs import (
- CompositeSpec,
- TensorSpec,
- UnboundedContinuousTensorSpec,
-)
+from torchrl.data.tensor_specs import Composite, TensorSpec, Unbounded
from torchrl.data.utils import DEVICE_TYPING
from torchrl.envs.transforms.transforms import (
CatTensors,
@@ -103,8 +99,8 @@ def _apply_transform(self, obs: torch.Tensor) -> None:
return out
def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
- if not isinstance(observation_spec, CompositeSpec):
- raise ValueError("_R3MNet can only infer CompositeSpec")
+ if not isinstance(observation_spec, Composite):
+ raise ValueError("_R3MNet can only infer Composite")
keys = [key for key in observation_spec.keys(True, True) if key in self.in_keys]
device = observation_spec[keys[0]].device
@@ -116,7 +112,7 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec
del observation_spec[in_key]
for out_key in self.out_keys:
- observation_spec[out_key] = UnboundedContinuousTensorSpec(
+ observation_spec[out_key] = Unbounded(
shape=torch.Size([*dim, self.outdim]), device=device
)
@@ -319,7 +315,7 @@ def _init(self):
unsqueeze = UnsqueezeTransform(
in_keys=in_keys,
out_keys=in_keys,
- unsqueeze_dim=-4,
+ dim=-4,
)
transforms.append(unsqueeze)
diff --git a/torchrl/envs/transforms/rlhf.py b/torchrl/envs/transforms/rlhf.py
index 33874393038..6228b0f22b7 100644
--- a/torchrl/envs/transforms/rlhf.py
+++ b/torchrl/envs/transforms/rlhf.py
@@ -9,7 +9,7 @@
from tensordict.nn import ProbabilisticTensorDictModule, TensorDictParams
from tensordict.utils import is_seq_of_nested_key
from torch import nn
-from torchrl.data.tensor_specs import CompositeSpec, UnboundedContinuousTensorSpec
+from torchrl.data.tensor_specs import Composite, Unbounded
from torchrl.envs.transforms.transforms import Transform
from torchrl.envs.transforms.utils import _set_missing_tolerance, _stateless_param
@@ -142,8 +142,8 @@ def _make_detached_param(x):
self.sample_log_prob_key = "sample_log_prob"
def find_sample_log_prob(module):
- if hasattr(module, "SAMPLE_LOG_PROB_KEY"):
- self.sample_log_prob_key = module.SAMPLE_LOG_PROB_KEY
+ if hasattr(module, "log_prob_key"):
+ self.sample_log_prob_key = module.log_prob_key
self.functional_actor.apply(find_sample_log_prob)
@@ -186,7 +186,7 @@ def _step(
forward = _call
- def transform_output_spec(self, output_spec: CompositeSpec) -> CompositeSpec:
+ def transform_output_spec(self, output_spec: Composite) -> Composite:
output_spec = super().transform_output_spec(output_spec)
# todo: here we'll need to use the reward_key once it's implemented
# parent = self.parent
@@ -195,17 +195,17 @@ def transform_output_spec(self, output_spec: CompositeSpec) -> CompositeSpec:
if in_key == "reward" and out_key == "reward":
parent = self.parent
- reward_spec = UnboundedContinuousTensorSpec(
+ reward_spec = Unbounded(
device=output_spec.device,
shape=output_spec["full_reward_spec"][parent.reward_key].shape,
)
- output_spec["full_reward_spec"] = CompositeSpec(
+ output_spec["full_reward_spec"] = Composite(
{parent.reward_key: reward_spec},
shape=output_spec["full_reward_spec"].shape,
)
elif in_key == "reward":
parent = self.parent
- reward_spec = UnboundedContinuousTensorSpec(
+ reward_spec = Unbounded(
device=output_spec.device,
shape=output_spec["full_reward_spec"][parent.reward_key].shape,
)
@@ -214,7 +214,7 @@ def transform_output_spec(self, output_spec: CompositeSpec) -> CompositeSpec:
observation_spec[out_key] = reward_spec
else:
observation_spec = output_spec["full_observation_spec"]
- reward_spec = UnboundedContinuousTensorSpec(
+ reward_spec = Unbounded(
device=output_spec.device, shape=observation_spec[in_key].shape
)
# then we need to populate the output keys
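The hunk above stops looking for the removed ``SAMPLE_LOG_PROB_KEY`` class constant and instead reads the ``log_prob_key`` attribute exposed by probabilistic modules. A self-contained sketch of the same ``apply``-based discovery, with a hypothetical module standing in for the functional actor:

    >>> import torch.nn as nn
    >>> class FakeProbModule(nn.Module):  # hypothetical stand-in
    ...     log_prob_key = "sample_log_prob"
    >>> found = {}
    >>> def find_sample_log_prob(module):
    ...     if hasattr(module, "log_prob_key"):
    ...         found["key"] = module.log_prob_key
    >>> actor = nn.Sequential(nn.Linear(2, 2), FakeProbModule())
    >>> _ = actor.apply(find_sample_log_prob)  # visits every submodule
    >>> found["key"]
    'sample_log_prob'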
diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py
index 7c9dec980f5..b70e05ca431 100644
--- a/torchrl/envs/transforms/transforms.py
+++ b/torchrl/envs/transforms/transforms.py
@@ -39,25 +39,30 @@
unravel_key,
unravel_key_list,
)
-from tensordict._C import _unravel_key_to_tuple
from tensordict.nn import dispatch, TensorDictModuleBase
-from tensordict.utils import expand_as_right, expand_right, NestedKey
+from tensordict.utils import (
+ _unravel_key_to_tuple,
+ _zip_strict,
+ expand_as_right,
+ expand_right,
+ NestedKey,
+)
from torch import nn, Tensor
from torch.utils._pytree import tree_map
from torchrl._utils import _append_last, _ends_with, _make_ordinal_device, _replace_last
from torchrl.data.tensor_specs import (
- BinaryDiscreteTensorSpec,
- BoundedTensorSpec,
- CompositeSpec,
+ Binary,
+ Bounded,
+ Categorical,
+ Composite,
ContinuousBox,
- DiscreteTensorSpec,
- MultiDiscreteTensorSpec,
- MultiOneHotDiscreteTensorSpec,
- OneHotDiscreteTensorSpec,
+ MultiCategorical,
+ MultiOneHot,
+ OneHot,
TensorSpec,
- UnboundedContinuousTensorSpec,
+ Unbounded,
)
from torchrl.envs.common import _do_nothing, _EnvPostInit, EnvBase, make_tensordict
from torchrl.envs.transforms import functional as F
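``_zip_strict`` (imported above from ``tensordict.utils``) replaces plain ``zip`` throughout this file so that mismatched ``in_keys``/``out_keys`` lists raise instead of being silently truncated; its behavior is assumed equivalent to Python 3.10's ``zip(..., strict=True)``:

    >>> in_keys = ["observation", "pixels"]
    >>> out_keys = ["observation_out"]
    >>> list(zip(in_keys, out_keys))  # plain zip drops the unmatched key
    [('observation', 'observation_out')]
    >>> list(zip(in_keys, out_keys, strict=True))  # what _zip_strict is assumed to do
    Traceback (most recent call last):
        ...
    ValueError: zip() argument 2 is shorter than argument 1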
@@ -80,14 +85,14 @@
def _apply_to_composite(function):
@wraps(function)
def new_fun(self, observation_spec):
- if isinstance(observation_spec, CompositeSpec):
+ if isinstance(observation_spec, Composite):
_specs = observation_spec._specs
in_keys = self.in_keys
out_keys = self.out_keys
- for in_key, out_key in zip(in_keys, out_keys):
+ for in_key, out_key in _zip_strict(in_keys, out_keys):
if in_key in observation_spec.keys(True, True):
_specs[out_key] = function(self, observation_spec[in_key].clone())
- return CompositeSpec(
+ return Composite(
_specs, shape=observation_spec.shape, device=observation_spec.device
)
else:
@@ -109,12 +114,12 @@ def new_fun(self, input_spec):
action_spec = input_spec["full_action_spec"].clone()
state_spec = input_spec["full_state_spec"]
if state_spec is None:
- state_spec = CompositeSpec(shape=input_spec.shape, device=input_spec.device)
+ state_spec = Composite(shape=input_spec.shape, device=input_spec.device)
else:
state_spec = state_spec.clone()
in_keys_inv = self.in_keys_inv
out_keys_inv = self.out_keys_inv
- for in_key, out_key in zip(in_keys_inv, out_keys_inv):
+ for in_key, out_key in _zip_strict(in_keys_inv, out_keys_inv):
if in_key != out_key:
# we only change the input spec if the key is the same
continue
@@ -122,7 +127,7 @@ def new_fun(self, input_spec):
action_spec[out_key] = function(self, action_spec[in_key].clone())
elif in_key in state_spec.keys(True, True):
state_spec[out_key] = function(self, state_spec[in_key].clone())
- return CompositeSpec(
+ return Composite(
full_state_spec=state_spec,
full_action_spec=action_spec,
shape=input_spec.shape,
@@ -270,7 +275,7 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
:meth:`TransformedEnv.reset`.
"""
- for in_key, out_key in zip(self.in_keys, self.out_keys):
+ for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
value = tensordict.get(in_key, default=None)
if value is not None:
observation = self._apply_transform(value)
@@ -287,7 +292,7 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
@dispatch(source="in_keys", dest="out_keys")
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
"""Reads the input tensordict, and for the selected keys, applies the transform."""
- for in_key, out_key in zip(self.in_keys, self.out_keys):
+ for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
data = tensordict.get(in_key, None)
if data is not None:
data = self._apply_transform(data)
@@ -328,7 +333,7 @@ def _inv_apply_transform(self, state: torch.Tensor) -> torch.Tensor:
def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
if not self.in_keys_inv:
return tensordict
- for in_key, out_key in zip(self.in_keys_inv, self.out_keys_inv):
+ for in_key, out_key in _zip_strict(self.in_keys_inv, self.out_keys_inv):
data = tensordict.get(in_key, None)
if data is not None:
item = self._inv_apply_transform(data)
@@ -360,7 +365,7 @@ def transform_env_batch_size(self, batch_size: torch.Size):
"""Transforms the batch-size of the parent env."""
return batch_size
- def transform_output_spec(self, output_spec: CompositeSpec) -> CompositeSpec:
+ def transform_output_spec(self, output_spec: Composite) -> Composite:
"""Transforms the output spec such that the resulting spec matches transform mapping.
This method should generally be left untouched. Changes should be implemented using
@@ -599,7 +604,7 @@ def __init__(
device = env.device
super().__init__(device=None, allow_done_after_reset=None, **kwargs)
- # Type matching must be exact here, because subtyping could introduce differences in behaviour that must
+ # Type matching must be exact here, because subtyping could introduce differences in behavior that must
# be contained within the subclass.
if type(env) is TransformedEnv and type(self) is TransformedEnv:
self._set_env(env.base_env, device)
@@ -831,7 +836,7 @@ def _reset_proc_data(self, tensordict, tensordict_reset):
return tensordict_reset
def _complete_done(
- cls, done_spec: CompositeSpec, data: TensorDictBase
+ cls, done_spec: Composite, data: TensorDictBase
) -> TensorDictBase:
        # This step has already been completed. We assume the transform modules do their job correctly.
return data
@@ -1090,7 +1095,7 @@ def transform_env_batch_size(self, batch_size: torch.batch_size):
return batch_size
def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec:
- for t in self.transforms[::-1]:
+ for t in self.transforms:
input_spec = t.transform_input_spec(input_spec)
return input_spec
@@ -1343,11 +1348,11 @@ def _apply_transform(self, observation: torch.FloatTensor) -> torch.Tensor:
@_apply_to_composite
def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
observation_spec = self._pixel_observation(observation_spec)
- unsqueeze_dim = [1] if self._should_unsqueeze(observation_spec) else []
+ dim = [1] if self._should_unsqueeze(observation_spec) else []
if not self.shape_tolerant or observation_spec.shape[-1] == 3:
observation_spec.shape = torch.Size(
[
- *unsqueeze_dim,
+ *dim,
*observation_spec.shape[:-3],
observation_spec.shape[-1],
observation_spec.shape[-3],
@@ -1465,7 +1470,7 @@ def _inv_apply_transform(self, state: torch.Tensor) -> torch.Tensor:
@_apply_to_composite
def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
- return BoundedTensorSpec(
+ return Bounded(
shape=observation_spec.shape,
device=observation_spec.device,
dtype=observation_spec.dtype,
@@ -1477,7 +1482,7 @@ def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec:
for key in self.in_keys:
if key in self.parent.reward_keys:
spec = self.parent.output_spec["full_reward_spec"][key]
- self.parent.output_spec["full_reward_spec"][key] = BoundedTensorSpec(
+ self.parent.output_spec["full_reward_spec"][key] = Bounded(
shape=spec.shape,
device=spec.device,
dtype=spec.dtype,
@@ -1507,7 +1512,7 @@ class TargetReturn(Transform):
In goal-conditioned RL, the :class:`~.TargetReturn` is defined as the
expected cumulative reward obtained from the current state to the goal state
- or the end of the episode. It is used as input for the policy to guide its behaviour.
+ or the end of the episode. It is used as input for the policy to guide its behavior.
For a trained policy typically the maximum return in the environment is
chosen as the target return.
However, as it is used as input to the policy module, it should be scaled
@@ -1633,7 +1638,7 @@ def _reset(self, tensordict: TensorDict, tensordict_reset: TensorDictBase):
return tensordict_reset
def _call(self, tensordict: TensorDict) -> TensorDict:
- for in_key, out_key in zip(self.in_keys, self.out_keys):
+ for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
val_in = tensordict.get(in_key, None)
val_out = tensordict.get(out_key, None)
if val_in is not None:
@@ -1675,7 +1680,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
)
def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
- for in_key, out_key in zip(self.in_keys, self.out_keys):
+ for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
if in_key in self.parent.full_observation_spec.keys(True):
target = self.parent.full_observation_spec[in_key]
elif in_key in self.parent.full_reward_spec.keys(True):
@@ -1685,7 +1690,7 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec
target = self.parent.full_done_spec[in_key]
else:
raise RuntimeError(f"in_key {in_key} not found in output_spec.")
- target_return_spec = UnboundedContinuousTensorSpec(
+ target_return_spec = Unbounded(
shape=target.shape,
dtype=target.dtype,
device=target.device,
@@ -1744,8 +1749,8 @@ def _apply_transform(self, reward: torch.Tensor) -> torch.Tensor:
@_apply_to_composite
def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec:
- if isinstance(reward_spec, UnboundedContinuousTensorSpec):
- return BoundedTensorSpec(
+ if isinstance(reward_spec, Unbounded):
+ return Bounded(
self.clamp_min,
self.clamp_max,
shape=reward_spec.shape,
@@ -1798,7 +1803,7 @@ def _apply_transform(self, reward: torch.Tensor) -> torch.Tensor:
@_apply_to_composite
def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec:
- return BinaryDiscreteTensorSpec(
+ return Binary(
n=1,
device=reward_spec.device,
shape=reward_spec.shape,
@@ -2132,41 +2137,42 @@ class UnsqueezeTransform(Transform):
"""Inserts a dimension of size one at the specified position.
Args:
- unsqueeze_dim (int): dimension to unsqueeze. Must be negative (or allow_positive_dim
+ dim (int): dimension to unsqueeze. Must be negative (or allow_positive_dim
must be turned on).
+
+ Keyword Args:
allow_positive_dim (bool, optional): if ``True``, positive dimensions are accepted.
- :obj:`UnsqueezeTransform` will map these to the n^th feature dimension
+            ``UnsqueezeTransform`` will map these to the n^th feature dimension
(ie n^th dimension after batch size of parent env) of the input tensor,
- independently from the tensordict batch size (ie positive dims may be
+ independently of the tensordict batch size (ie positive dims may be
dangerous in contexts where tensordict of different batch dimension
are passed).
Defaults to False, ie. non-negative dimensions are not permitted.
+        in_keys (list of NestedKey): input entries (read).
+        out_keys (list of NestedKey): output entries (write). Defaults to ``in_keys`` if
+            not provided.
+        in_keys_inv (list of NestedKey): input entries (read) during :meth:`~.inv` calls.
+        out_keys_inv (list of NestedKey): output entries (write) during :meth:`~.inv` calls.
+            Defaults to ``in_keys_inv`` if not provided.
"""
invertible = True
@classmethod
def __new__(cls, *args, **kwargs):
- cls._unsqueeze_dim = None
+ cls._dim = None
return super().__new__(cls)
def __init__(
self,
dim: int = None,
+ *,
allow_positive_dim: bool = False,
in_keys: Sequence[NestedKey] | None = None,
out_keys: Sequence[NestedKey] | None = None,
in_keys_inv: Sequence[NestedKey] | None = None,
out_keys_inv: Sequence[NestedKey] | None = None,
- **kwargs,
):
- if "unsqueeze_dim" in kwargs:
- warnings.warn(
- "The `unsqueeze_dim` kwarg will be removed in v0.6. Please use `dim` instead."
- )
- dim = kwargs["unsqueeze_dim"]
- elif dim is None:
- raise TypeError("dim must be provided.")
if in_keys is None:
in_keys = [] # default
if out_keys is None:
@@ -2186,22 +2192,26 @@ def __init__(
raise RuntimeError(
"dim should be smaller than 0 to accommodate for "
"envs of different batch_sizes. Turn allow_positive_dim to accommodate "
- "for positive unsqueeze_dim."
+ "for positive dim."
)
self._dim = dim
@property
def unsqueeze_dim(self):
+ return self.dim
+
+ @property
+ def dim(self):
if self._dim >= 0 and self.parent is not None:
return len(self.parent.batch_size) + self._dim
return self._dim
def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor:
- observation = observation.unsqueeze(self.unsqueeze_dim)
+ observation = observation.unsqueeze(self.dim)
return observation
def _inv_apply_transform(self, observation: torch.Tensor) -> torch.Tensor:
- observation = observation.squeeze(self.unsqueeze_dim)
+ observation = observation.squeeze(self.dim)
return observation
def _transform_spec(self, spec: TensorSpec):
@@ -2248,7 +2258,7 @@ def _reset(
def __repr__(self) -> str:
s = (
- f"{self.__class__.__name__}(unsqueeze_dim={self.unsqueeze_dim}, in_keys={self.in_keys}, out_keys={self.out_keys},"
+ f"{self.__class__.__name__}(dim={self.dim}, in_keys={self.in_keys}, out_keys={self.out_keys},"
f" in_keys_inv={self.in_keys_inv}, out_keys_inv={self.out_keys_inv})"
)
return s
@@ -2258,14 +2268,14 @@ class SqueezeTransform(UnsqueezeTransform):
"""Removes a dimension of size one at the specified position.
Args:
- squeeze_dim (int): dimension to squeeze.
+ dim (int): dimension to squeeze.
"""
invertible = True
def __init__(
self,
- squeeze_dim: int,
+ dim: int | None = None,
*args,
in_keys: Optional[Sequence[str]] = None,
out_keys: Optional[Sequence[str]] = None,
@@ -2273,8 +2283,19 @@ def __init__(
out_keys_inv: Optional[Sequence[str]] = None,
**kwargs,
):
+ if dim is None:
+ if "squeeze_dim" in kwargs:
+ warnings.warn(
+ f"squeeze_dim will be deprecated in favor of dim arg in {type(self).__name__}."
+ )
+ dim = kwargs.pop("squeeze_dim")
+ else:
+ raise TypeError(
+ f"dim must be passed to {type(self).__name__} constructor."
+ )
+
super().__init__(
- squeeze_dim,
+ dim,
*args,
in_keys=in_keys,
out_keys=out_keys,
@@ -2285,7 +2306,7 @@ def __init__(
@property
def squeeze_dim(self):
- return super().unsqueeze_dim
+ return super().dim
_apply_transform = UnsqueezeTransform._inv_apply_transform
_inv_apply_transform = UnsqueezeTransform._apply_transform
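Summing up the rename: both transforms now take ``dim`` as the leading argument, the old names remain readable as properties, and ``SqueezeTransform(squeeze_dim=...)`` still works with a warning. A short sketch (transforms can be applied to a standalone TensorDict):

    >>> import torch
    >>> from tensordict import TensorDict
    >>> from torchrl.envs.transforms import SqueezeTransform, UnsqueezeTransform
    >>> unsqueeze = UnsqueezeTransform(dim=-1, in_keys=["observation"])
    >>> td = TensorDict({"observation": torch.zeros(3)}, [])
    >>> unsqueeze(td)["observation"].shape
    torch.Size([3, 1])
    >>> squeeze = SqueezeTransform(dim=-1, in_keys=["observation"])  # squeeze_dim=-1 would warn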
@@ -2505,7 +2526,7 @@ class ObservationNorm(ObservationTransform):
loc (number or tensor): location of the affine transform
scale (number or tensor): scale of the affine transform
in_keys (sequence of NestedKey, optional): entries to be normalized. Defaults to ["observation", "pixels"].
- All entries will be normalized with the same values: if a different behaviour is desired
+ All entries will be normalized with the same values: if a different behavior is desired
(e.g. a different normalization for pixels and states) different :obj:`ObservationNorm`
objects should be used.
out_keys (sequence of NestedKey, optional): output entries. Defaults to the value of `in_keys`.
@@ -2569,7 +2590,7 @@ def __init__(
):
if in_keys is None:
raise RuntimeError(
- "Not passing in_keys to ObservationNorm is a deprecated behaviour."
+ "Not passing in_keys to ObservationNorm is a deprecated behavior."
)
if out_keys is None:
@@ -3000,7 +3021,7 @@ def _inv_call(self, tensordict: TensorDictBase) -> torch.Tensor:
def _call(self, tensordict: TensorDictBase, _reset=None) -> TensorDictBase:
"""Update the episode tensordict with max pooled keys."""
_just_reset = _reset is not None
- for in_key, out_key in zip(self.in_keys, self.out_keys):
+ for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
# Lazy init of buffers
buffer_name = f"_cat_buffers_{in_key}"
data = tensordict.get(in_key)
@@ -3082,6 +3103,31 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
else:
return self.unfolding(tensordict)
+    def _apply_same_padding(self, dim, data, done_mask):
+        # "same" padding: entries of each unfolded window that precede a reset
+        # are backfilled with the first valid (post-reset) value of that window.
+        d = data.ndim + dim - 1
+        res = data.clone()
+        # number of done steps to backfill at the start of each window
+        num_repeats_per_sample = done_mask.sum(dim=-1)
+
+        # flatten extra leading batch dims so windows can be visited in a single loop
+        if num_repeats_per_sample.dim() > 2:
+            extra_dims = num_repeats_per_sample.dim() - 2
+            num_repeats_per_sample = num_repeats_per_sample.flatten(0, extra_dims)
+            res_flat_series = res.flatten(0, extra_dims)
+        else:
+            extra_dims = 0
+            res_flat_series = res
+
+        # collapse the remaining non-time dims into a single dim (views into `res`)
+        if d - 1 > extra_dims:
+            res_flat_series_flat_batch = res_flat_series.flatten(1, d - 1)
+        else:
+            res_flat_series_flat_batch = res_flat_series[:, None]
+
+        for sample_idx, num_repeats in enumerate(num_repeats_per_sample):
+            if num_repeats > 0:
+                # repeat the first valid entry over the masked prefix, in place
+                res_slice = res_flat_series_flat_batch[sample_idx]
+                res_slice[:, :num_repeats] = res_slice[:, num_repeats : num_repeats + 1]
+
+        return res
+
@set_lazy_legacy(False)
def unfolding(self, tensordict: TensorDictBase) -> TensorDictBase:
# it is assumed that the last dimension of the tensordict is the time dimension
@@ -3110,12 +3156,12 @@ def unfolding(self, tensordict: TensorDictBase) -> TensorDictBase:
# first sort the in_keys with strings and non-strings
keys = [
(in_key, out_key)
- for in_key, out_key in zip(self.in_keys, self.out_keys)
+ for in_key, out_key in _zip_strict(self.in_keys, self.out_keys)
if isinstance(in_key, str)
]
keys += [
(in_key, out_key)
- for in_key, out_key in zip(self.in_keys, self.out_keys)
+ for in_key, out_key in _zip_strict(self.in_keys, self.out_keys)
if not isinstance(in_key, str)
]
@@ -3151,7 +3197,7 @@ def unfold_done(done, N):
first_val = None
if isinstance(in_key, tuple) and in_key[0] == "next":
# let's get the out_key we have already processed
- prev_out_key = dict(zip(self.in_keys, self.out_keys)).get(
+ prev_out_key = dict(_zip_strict(self.in_keys, self.out_keys)).get(
in_key[1], None
)
if prev_out_key is not None:
@@ -3192,24 +3238,7 @@ def unfold_done(done, N):
if self.padding != "same":
data = torch.where(done_mask_expand, self.padding_value, data)
else:
- # TODO: This is a pretty bad implementation, could be
- # made more efficient but it works!
- reset_any = reset.any(-1, False)
- reset_vals = list(data_orig[reset_any].unbind(0))
- j_ = float("inf")
- reps = []
- d = data.ndim + self.dim - 1
- n_feat = data.shape[data.ndim + self.dim :].numel()
- for j in done_mask_expand.flatten(d, -1).sum(-1).view(-1) // n_feat:
- if j > j_:
- reset_vals = reset_vals[1:]
- reps.extend([reset_vals[0]] * int(j))
- j_ = j
- if reps:
- reps = torch.stack(reps)
- data = torch.masked_scatter(
- data, done_mask_expand, reps.reshape(-1)
- )
+ data = self._apply_same_padding(self.dim, data, done_mask)
if first_val is not None:
# Aggregate reset along last dim
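For intuition, the semantics the new ``_apply_same_padding`` implements on a single unfolded window (a toy sketch of the intended behavior, not the method itself): entries masked as done are backfilled with the first valid value of the window.

    >>> import torch
    >>> window = torch.tensor([10., 11., 12., 13.])
    >>> done_mask = torch.tensor([True, True, False, False])  # two steps precede the reset
    >>> num_repeats = int(done_mask.sum())
    >>> window[:num_repeats] = window[num_repeats]  # "same" padding
    >>> window
    tensor([12., 12., 12., 13.])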
@@ -3321,7 +3350,7 @@ def _apply_transform(self, reward: torch.Tensor) -> torch.Tensor:
@_apply_to_composite
def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec:
- if isinstance(reward_spec, UnboundedContinuousTensorSpec):
+ if isinstance(reward_spec, Unbounded):
return reward_spec
else:
raise NotImplementedError(
@@ -3361,7 +3390,7 @@ class DTypeCastTransform(Transform):
"""Casts one dtype to another for selected keys.
Depending on whether the ``in_keys`` or ``in_keys_inv`` are provided
- during construction, the class behaviour will change:
+ during construction, the class behavior will change:
* If the keys are provided, those entries and those entries only will be
transformed from ``dtype_in`` to ``dtype_out`` entries;
@@ -3417,17 +3446,17 @@ class DTypeCastTransform(Transform):
>>> print(td.get("not_transformed").dtype)
torch.float32
- The same behaviour is the rule when environments are constructedw without
+    The same behavior is the rule when environments are constructed without
specifying the transform keys:
Examples:
>>> class MyEnv(EnvBase):
... def __init__(self):
... super().__init__()
- ... self.observation_spec = CompositeSpec(obs=UnboundedContinuousTensorSpec((), dtype=torch.float64))
- ... self.action_spec = UnboundedContinuousTensorSpec((), dtype=torch.float64)
- ... self.reward_spec = UnboundedContinuousTensorSpec((1,), dtype=torch.float64)
- ... self.done_spec = UnboundedContinuousTensorSpec((1,), dtype=torch.bool)
+ ... self.observation_spec = Composite(obs=Unbounded((), dtype=torch.float64))
+ ... self.action_spec = Unbounded((), dtype=torch.float64)
+ ... self.reward_spec = Unbounded((1,), dtype=torch.float64)
+ ... self.done_spec = Unbounded((1,), dtype=torch.bool)
... def _reset(self, data=None):
... return TensorDict({"done": torch.zeros((1,), dtype=torch.bool), **self.observation_spec.rand()}, [])
... def _step(self, data):
@@ -3601,7 +3630,7 @@ def func(name, item):
return tensordict
else:
# we made sure that if in_keys is not None, out_keys is not None either
- for in_key, out_key in zip(in_keys, out_keys):
+ for in_key, out_key in _zip_strict(in_keys, out_keys):
item = self._apply_transform(tensordict.get(in_key))
tensordict.set(out_key, item)
return tensordict
@@ -3640,7 +3669,7 @@ def _inv_apply_transform(self, state: torch.Tensor) -> torch.Tensor:
return state.to(self.dtype_in)
def _transform_spec(self, spec: TensorSpec) -> None:
- if isinstance(spec, CompositeSpec):
+ if isinstance(spec, Composite):
for key in spec:
self._transform_spec(spec[key])
else:
@@ -3660,7 +3689,7 @@ def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec:
raise NotImplementedError(
f"Calling transform_input_spec without a parent environment isn't supported yet for {type(self)}."
)
- for in_key_inv, out_key_inv in zip(self.in_keys_inv, self.out_keys_inv):
+ for in_key_inv, out_key_inv in _zip_strict(self.in_keys_inv, self.out_keys_inv):
if in_key_inv in full_action_spec.keys(True):
_spec = full_action_spec[in_key_inv]
target = "action"
@@ -3685,7 +3714,7 @@ def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec:
raise RuntimeError
return input_spec
- def transform_output_spec(self, output_spec: CompositeSpec) -> CompositeSpec:
+ def transform_output_spec(self, output_spec: Composite) -> Composite:
if self.in_keys is None:
raise NotImplementedError(
f"Calling transform_reward_spec without a parent environment isn't supported yet for {type(self)}."
@@ -3694,7 +3723,7 @@ def transform_output_spec(self, output_spec: CompositeSpec) -> CompositeSpec:
full_observation_spec = output_spec["full_observation_spec"]
for reward_key, reward_spec in list(full_reward_spec.items(True, True)):
# find out_key that match the in_key
- for in_key, out_key in zip(self.in_keys, self.out_keys):
+ for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
if reward_key == in_key:
if reward_spec.dtype != self.dtype_in:
raise TypeError(f"reward_spec.dtype is not {self.dtype_in}")
@@ -3710,7 +3739,7 @@ def transform_observation_spec(self, observation_spec):
full_observation_spec.items(True, True)
):
# find out_key that match the in_key
- for in_key, out_key in zip(self.in_keys, self.out_keys):
+ for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
if observation_key == in_key:
if observation_spec.dtype != self.dtype_in:
raise TypeError(
@@ -3733,7 +3762,7 @@ class DoubleToFloat(DTypeCastTransform):
"""Casts one dtype to another for selected keys.
Depending on whether the ``in_keys`` or ``in_keys_inv`` are provided
- during construction, the class behaviour will change:
+ during construction, the class behavior will change:
* If the keys are provided, those entries and those entries only will be
transformed from ``float64`` to ``float32`` entries;
@@ -3787,17 +3816,17 @@ class DoubleToFloat(DTypeCastTransform):
>>> print(td.get("not_transformed").dtype)
torch.float32
- The same behaviour is the rule when environments are constructedw without
+    The same behavior is the rule when environments are constructed without
specifying the transform keys:
Examples:
>>> class MyEnv(EnvBase):
... def __init__(self):
... super().__init__()
- ... self.observation_spec = CompositeSpec(obs=UnboundedContinuousTensorSpec((), dtype=torch.float64))
- ... self.action_spec = UnboundedContinuousTensorSpec((), dtype=torch.float64)
- ... self.reward_spec = UnboundedContinuousTensorSpec((1,), dtype=torch.float64)
- ... self.done_spec = UnboundedContinuousTensorSpec((1,), dtype=torch.bool)
+ ... self.observation_spec = Composite(obs=Unbounded((), dtype=torch.float64))
+ ... self.action_spec = Unbounded((), dtype=torch.float64)
+ ... self.reward_spec = Unbounded((1,), dtype=torch.float64)
+ ... self.done_spec = Unbounded((1,), dtype=torch.bool)
... def _reset(self, data=None):
... return TensorDict({"done": torch.zeros((1,), dtype=torch.bool), **self.observation_spec.rand()}, [])
... def _step(self, data):
@@ -3864,6 +3893,19 @@ class DeviceCastTransform(Transform):
        a parent environment exists, it is retrieved from it. In all other cases,
it remains unspecified.
+ Keyword Args:
+ in_keys (list of NestedKey): the list of entries to map to a different device.
+ Defaults to ``None``.
+ out_keys (list of NestedKey): the output names of the entries mapped onto a device.
+ Defaults to the values of ``in_keys``.
+ in_keys_inv (list of NestedKey): the list of entries to map to a different device.
+ ``in_keys_inv`` are the names expected by the base environment.
+ Defaults to ``None``.
+ out_keys_inv (list of NestedKey): the output names of the entries mapped onto a device.
+ ``out_keys_inv`` are the names of the keys as seen from outside the transformed env.
+ Defaults to the values of ``in_keys_inv``.
+
+
Examples:
>>> td = TensorDict(
... {'obs': torch.ones(1, dtype=torch.double),
@@ -3891,6 +3933,10 @@ def __init__(
self.orig_device = (
torch.device(orig_device) if orig_device is not None else orig_device
)
+ if out_keys is None:
+ out_keys = copy(in_keys)
+ if out_keys_inv is None:
+ out_keys_inv = copy(in_keys_inv)
super().__init__(
in_keys=in_keys,
out_keys=out_keys,
@@ -3943,7 +3989,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
return result
tensordict_t = tensordict.named_apply(self._to, nested_keys=True, device=None)
if self._rename_keys:
- for in_key, out_key in zip(self.in_keys, self.out_keys):
+ for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
if out_key != in_key:
tensordict_t.rename_key_(in_key, out_key)
tensordict_t.set(in_key, tensordict.get(in_key))
@@ -3957,7 +4003,7 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
return result
tensordict_t = tensordict.named_apply(self._to, nested_keys=True, device=None)
if self._rename_keys:
- for in_key, out_key in zip(self.in_keys, self.out_keys):
+ for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
if out_key != in_key:
tensordict_t.rename_key_(in_key, out_key)
tensordict_t.set(in_key, tensordict.get(in_key))
@@ -3985,7 +4031,7 @@ def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
device=None,
)
if self._rename_keys_inv:
- for in_key, out_key in zip(self.in_keys_inv, self.out_keys_inv):
+ for in_key, out_key in _zip_strict(self.in_keys_inv, self.out_keys_inv):
if out_key != in_key:
tensordict_t.rename_key_(in_key, out_key)
tensordict_t.set(in_key, tensordict.get(in_key))
@@ -4010,58 +4056,58 @@ def _sync_orig_device(self):
return self._sync_orig_device
return sync_func
- def transform_input_spec(self, input_spec: CompositeSpec) -> CompositeSpec:
+ def transform_input_spec(self, input_spec: Composite) -> Composite:
if self._map_env_device:
return input_spec.to(self.device)
else:
+ input_spec.clear_device_()
return super().transform_input_spec(input_spec)
- def transform_action_spec(self, full_action_spec: CompositeSpec) -> CompositeSpec:
+ def transform_action_spec(self, full_action_spec: Composite) -> Composite:
full_action_spec = full_action_spec.clear_device_()
- for in_key, out_key in zip(self.in_keys_inv, self.out_keys_inv):
- if in_key not in full_action_spec.keys(True, True):
- continue
- full_action_spec[out_key] = full_action_spec[in_key].to(self.device)
+ for in_key, out_key in _zip_strict(self.in_keys_inv, self.out_keys_inv):
+ local_action_spec = full_action_spec.get(in_key, None)
+ if local_action_spec is not None:
+ full_action_spec[out_key] = local_action_spec.to(self.device)
return full_action_spec
- def transform_state_spec(self, full_state_spec: CompositeSpec) -> CompositeSpec:
+ def transform_state_spec(self, full_state_spec: Composite) -> Composite:
full_state_spec = full_state_spec.clear_device_()
- for in_key, out_key in zip(self.in_keys_inv, self.out_keys_inv):
- if in_key not in full_state_spec.keys(True, True):
- continue
- full_state_spec[out_key] = full_state_spec[in_key].to(self.device)
+ for in_key, out_key in _zip_strict(self.in_keys_inv, self.out_keys_inv):
+ local_state_spec = full_state_spec.get(in_key, None)
+ if local_state_spec is not None:
+ full_state_spec[out_key] = local_state_spec.to(self.device)
return full_state_spec
- def transform_output_spec(self, output_spec: CompositeSpec) -> CompositeSpec:
+ def transform_output_spec(self, output_spec: Composite) -> Composite:
if self._map_env_device:
return output_spec.to(self.device)
else:
+ output_spec.clear_device_()
return super().transform_output_spec(output_spec)
- def transform_observation_spec(
- self, observation_spec: CompositeSpec
- ) -> CompositeSpec:
+ def transform_observation_spec(self, observation_spec: Composite) -> Composite:
observation_spec = observation_spec.clear_device_()
- for in_key, out_key in zip(self.in_keys, self.out_keys):
- if in_key not in observation_spec.keys(True, True):
- continue
- observation_spec[out_key] = observation_spec[in_key].to(self.device)
+ for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
+ local_obs_spec = observation_spec.get(in_key, None)
+ if local_obs_spec is not None:
+ observation_spec[out_key] = local_obs_spec.to(self.device)
return observation_spec
- def transform_done_spec(self, full_done_spec: CompositeSpec) -> CompositeSpec:
+ def transform_done_spec(self, full_done_spec: Composite) -> Composite:
full_done_spec = full_done_spec.clear_device_()
- for in_key, out_key in zip(self.in_keys, self.out_keys):
- if in_key not in full_done_spec.keys(True, True):
- continue
- full_done_spec[out_key] = full_done_spec[in_key].to(self.device)
+ for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
+ local_done_spec = full_done_spec.get(in_key, None)
+ if local_done_spec is not None:
+ full_done_spec[out_key] = local_done_spec.to(self.device)
return full_done_spec
- def transform_reward_spec(self, full_reward_spec: CompositeSpec) -> CompositeSpec:
+ def transform_reward_spec(self, full_reward_spec: Composite) -> Composite:
full_reward_spec = full_reward_spec.clear_device_()
- for in_key, out_key in zip(self.in_keys, self.out_keys):
- if in_key not in full_reward_spec.keys(True, True):
- continue
- full_reward_spec[out_key] = full_reward_spec[in_key].to(self.device)
+ for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
+ local_reward_spec = full_reward_spec.get(in_key, None)
+ if local_reward_spec is not None:
+ full_reward_spec[out_key] = local_reward_spec.to(self.device)
return full_reward_spec
def transform_env_device(self, device):
@@ -4092,7 +4138,7 @@ class CatTensors(Transform):
Args:
in_keys (sequence of NestedKey): keys to be concatenated. If `None` (or not provided)
the keys will be retrieved from the parent environment the first time
- the transform is used. This behaviour will only work if a parent is set.
+ the transform is used. This behavior will only work if a parent is set.
out_key (NestedKey): key of the resulting tensor.
dim (int, optional): dimension along which the concatenation will occur.
Default is ``-1``.
@@ -4215,13 +4261,13 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec
self._initialized = True
# check that all keys are in observation_spec
- if len(self.in_keys) > 1 and not isinstance(observation_spec, CompositeSpec):
+ if len(self.in_keys) > 1 and not isinstance(observation_spec, Composite):
raise ValueError(
"CatTensor cannot infer the output observation spec as there are multiple input keys but "
"only one observation_spec."
)
- if isinstance(observation_spec, CompositeSpec) and len(
+ if isinstance(observation_spec, Composite) and len(
[key for key in self.in_keys if key not in observation_spec.keys(True)]
):
raise ValueError(
@@ -4229,7 +4275,7 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec
"Make sure the environment has an observation_spec attribute that includes all the specs needed for CatTensor."
)
- if not isinstance(observation_spec, CompositeSpec):
+ if not isinstance(observation_spec, Composite):
# by def, there must be only one key
return observation_spec
@@ -4249,7 +4295,7 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec
device = spec0.device
shape[self.dim] = sum_shape
shape = torch.Size(shape)
- observation_spec[out_key] = UnboundedContinuousTensorSpec(
+ observation_spec[out_key] = Unbounded(
shape=shape,
dtype=spec0.dtype,
device=device,
@@ -4357,14 +4403,14 @@ def _inv_apply_transform(self, action: torch.Tensor) -> torch.Tensor:
action = nn.functional.one_hot(action, self.num_actions_effective)
return action
- def transform_input_spec(self, input_spec: CompositeSpec):
+ def transform_input_spec(self, input_spec: Composite):
input_spec = input_spec.clone()
for key in input_spec["full_action_spec"].keys(True, True):
key = ("full_action_spec", key)
break
else:
raise KeyError("key not found in action_spec.")
- input_spec[key] = OneHotDiscreteTensorSpec(
+ input_spec[key] = OneHot(
self.max_actions,
shape=(*input_spec[key].shape[:-1], self.max_actions),
device=input_spec.device,
@@ -4456,7 +4502,7 @@ def _reset(
)
# Merge the two tensordicts
tensordict = parent._reset_proc_data(tensordict.clone(False), tensordict_reset)
- # check that there is a single done state -- behaviour is undefined for multiple dones
+ # check that there is a single done state -- behavior is undefined for multiple dones
done_keys = parent.done_keys
reward_key = parent.reward_key
if parent.batch_size.numel() > 1:
@@ -4526,9 +4572,9 @@ class TensorDictPrimer(Transform):
tensordict with the desired features.
Args:
- primers (dict or CompositeSpec, optional): a dictionary containing
+ primers (dict or Composite, optional): a dictionary containing
key-spec pairs which will be used to populate the input tensordict.
- :class:`~torchrl.data.CompositeSpec` instances are supported too.
+ :class:`~torchrl.data.Composite` instances are supported too.
random (bool, optional): if ``True``, the values will be drawn randomly from
the TensorSpec domain (or a unit Gaussian if unbounded). Otherwise a fixed value will be assumed.
Defaults to `False`.
@@ -4557,7 +4603,7 @@ class TensorDictPrimer(Transform):
>>> base_env = SerialEnv(2, lambda: GymEnv("Pendulum-v1"))
>>> env = TransformedEnv(base_env)
>>> # the env is batch-locked, so the leading dims of the spec must match those of the env
- >>> env.append_transform(TensorDictPrimer(mykey=UnboundedContinuousTensorSpec([2, 3])))
+ >>> env.append_transform(TensorDictPrimer(mykey=Unbounded([2, 3])))
>>> td = env.reset()
>>> print(td)
TensorDict(
@@ -4591,14 +4637,14 @@ class TensorDictPrimer(Transform):
.. note:: Some TorchRL modules rely on specific keys being present in the environment TensorDicts,
like :class:`~torchrl.modules.models.LSTM` or :class:`~torchrl.modules.models.GRU`.
- To facilitate this process, the method :func:`~torchrl.models.utils.get_primers_from_module`
+ To facilitate this process, the method :func:`~torchrl.modules.utils.get_primers_from_module`
automatically checks for required primer transforms in a module and its submodules and
generates them.
"""
def __init__(
self,
- primers: dict | CompositeSpec = None,
+ primers: dict | Composite = None,
random: bool | None = None,
default_value: float
| Callable
@@ -4615,8 +4661,8 @@ def __init__(
"as kwargs."
)
kwargs = primers
- if not isinstance(kwargs, CompositeSpec):
- kwargs = CompositeSpec(kwargs)
+ if not isinstance(kwargs, Composite):
+ kwargs = Composite(kwargs)
self.primers = kwargs
if random and default_value:
raise ValueError(
@@ -4658,10 +4704,15 @@ def __init__(
def reset_key(self):
reset_key = self.__dict__.get("_reset_key", None)
if reset_key is None:
+ if self.parent is None:
+ raise RuntimeError(
+ "Missing parent, cannot infer reset_key automatically."
+ )
reset_keys = self.parent.reset_keys
if len(reset_keys) > 1:
raise RuntimeError(
- f"Got more than one reset key in env {self.container}, cannot infer which one to use. Consider providing the reset key in the {type(self)} constructor."
+ f"Got more than one reset key in env {self.container}, cannot infer which one to use. "
+ f"Consider providing the reset key in the {type(self)} constructor."
)
reset_key = self._reset_key = reset_keys[0]
return reset_key
@@ -4698,12 +4749,10 @@ def to(self, *args, **kwargs):
def _expand_shape(self, spec):
return spec.expand((*self.parent.batch_size, *spec.shape))
- def transform_observation_spec(
- self, observation_spec: CompositeSpec
- ) -> CompositeSpec:
- if not isinstance(observation_spec, CompositeSpec):
+ def transform_observation_spec(self, observation_spec: Composite) -> Composite:
+ if not isinstance(observation_spec, Composite):
raise ValueError(
- f"observation_spec was expected to be of type CompositeSpec. Got {type(observation_spec)} instead."
+ f"observation_spec was expected to be of type Composite. Got {type(observation_spec)} instead."
)
if self.primers.shape != observation_spec.shape:
@@ -4718,9 +4767,12 @@ def transform_observation_spec(
return observation_spec
def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec:
- input_spec["full_state_spec"] = self.transform_observation_spec(
- input_spec["full_state_spec"]
- )
+ new_state_spec = self.transform_observation_spec(input_spec["full_state_spec"])
+ for action_key in list(input_spec["full_action_spec"].keys(True, True)):
+ if action_key in new_state_spec.keys(True, True):
+ input_spec["full_action_spec", action_key] = new_state_spec[action_key]
+ del new_state_spec[action_key]
+ input_spec["full_state_spec"] = new_state_spec
return input_spec
@property
@@ -4866,7 +4918,7 @@ def __init__(
)
random = state_dim is not None and action_dim is not None
shape = tuple(shape) + tail_dim
- primers = {"_eps_gSDE": UnboundedContinuousTensorSpec(shape=shape)}
+ primers = {"_eps_gSDE": Unbounded(shape=shape)}
super().__init__(primers=primers, random=random, **kwargs)
@@ -5010,7 +5062,7 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
if self.lock is not None:
self.lock.acquire()
- for key, key_out in zip(self.in_keys, self.out_keys):
+ for key, key_out in _zip_strict(self.in_keys, self.out_keys):
if key not in tensordict.keys(include_nested=True):
# TODO: init missing rewards with this
# for key_suffix in [_append_last(key, suffix) for suffix in ("_sum", "_ssq", "_count")]:
@@ -5148,7 +5200,7 @@ def to_observation_norm(self) -> Union[Compose, ObservationNorm]:
out = []
loc = self.loc
scale = self.scale
- for key, key_out in zip(self.in_keys, self.out_keys):
+ for key, key_out in _zip_strict(self.in_keys, self.out_keys):
_out = ObservationNorm(
loc=loc.get(key),
scale=scale.get(key),
@@ -5325,8 +5377,8 @@ def __setstate__(self, state: Dict[str, Any]):
@_apply_to_composite
def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
- if isinstance(observation_spec, BoundedTensorSpec):
- return UnboundedContinuousTensorSpec(
+ if isinstance(observation_spec, Bounded):
+ return Unbounded(
shape=observation_spec.shape,
dtype=observation_spec.dtype,
device=observation_spec.device,
@@ -5464,10 +5516,12 @@ def reset_keys(self):
# We take the filtered reset keys, which are the only keys that really
# matter when calling reset, and check that they match the in_keys root.
reset_keys = parent._filtered_reset_keys
+ if len(reset_keys) == 1:
+ reset_keys = list(reset_keys) * len(self.in_keys)
def _check_match(reset_keys, in_keys):
# if this is called, the length of reset_keys and in_keys must match
- for reset_key, in_key in zip(reset_keys, in_keys):
+ for reset_key, in_key in _zip_strict(reset_keys, in_keys):
# having _reset at the root and the reward_key ("agent", "reward") is allowed
# but having ("agent", "_reset") and "reward" isn't
if isinstance(reset_key, tuple) and isinstance(in_key, str):
@@ -5511,7 +5565,7 @@ def _reset(
self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
) -> TensorDictBase:
"""Resets episode rewards."""
- for in_key, reset_key, out_key in zip(
+ for in_key, reset_key, out_key in _zip_strict(
self.in_keys, self.reset_keys, self.out_keys
):
_reset = _get_reset(reset_key, tensordict)
@@ -5528,7 +5582,7 @@ def _step(
) -> TensorDictBase:
"""Updates the episode rewards with the step rewards."""
# Update episode rewards
- for in_key, out_key in zip(self.in_keys, self.out_keys):
+ for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
if in_key in next_tensordict.keys(include_nested=True):
reward = next_tensordict.get(in_key)
prev_reward = tensordict.get(out_key, 0.0)
@@ -5540,17 +5594,17 @@ def _step(
def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec:
state_spec = input_spec["full_state_spec"]
if state_spec is None:
- state_spec = CompositeSpec(shape=input_spec.shape, device=input_spec.device)
+ state_spec = Composite(shape=input_spec.shape, device=input_spec.device)
state_spec.update(self._generate_episode_reward_spec())
input_spec["full_state_spec"] = state_spec
return input_spec
- def _generate_episode_reward_spec(self) -> CompositeSpec:
- episode_reward_spec = CompositeSpec()
+ def _generate_episode_reward_spec(self) -> Composite:
+ episode_reward_spec = Composite()
reward_spec = self.parent.full_reward_spec
reward_spec_keys = self.parent.reward_keys
# Define episode specs for all out_keys
- for in_key, out_key in zip(self.in_keys, self.out_keys):
+ for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
if (
in_key in reward_spec_keys
): # if this out_key has a corresponding key in reward_spec
@@ -5559,7 +5613,7 @@ def _generate_episode_reward_spec(self) -> CompositeSpec:
temp_rew_spec = reward_spec
for sub_key in out_key[:-1]:
if (
- not isinstance(temp_rew_spec, CompositeSpec)
+ not isinstance(temp_rew_spec, Composite)
or sub_key not in temp_rew_spec.keys()
):
break
@@ -5580,8 +5634,8 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec
"""Transforms the observation spec, adding the new keys generated by RewardSum."""
if self.reward_spec:
return observation_spec
- if not isinstance(observation_spec, CompositeSpec):
- observation_spec = CompositeSpec(
+ if not isinstance(observation_spec, Composite):
+ observation_spec = Composite(
observation=observation_spec, shape=self.parent.batch_size
)
observation_spec.update(self._generate_episode_reward_spec())
@@ -5600,8 +5654,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
"At least one dimension of the tensordict must be named 'time' in offline mode"
)
time_dim = time_dim[0] - 1
- for in_key, out_key in zip(self.in_keys, self.out_keys):
- reward = tensordict.get(in_key)
+ for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
+ reward = tensordict[in_key]
cumsum = reward.cumsum(time_dim)
tensordict.set(out_key, cumsum)
return tensordict
@@ -5778,7 +5832,13 @@ def _reset(
self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
) -> TensorDictBase:
# get reset signal
- for step_count_key, truncated_key, terminated_key, reset_key, done_key in zip(
+ for (
+ step_count_key,
+ truncated_key,
+ terminated_key,
+ reset_key,
+ done_key,
+ ) in _zip_strict(
self.step_count_keys,
self.truncated_keys,
self.terminated_keys,
@@ -5819,10 +5879,8 @@ def _reset(
def _step(
self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
) -> TensorDictBase:
- for step_count_key, truncated_key, done_key in zip(
- self.step_count_keys,
- self.truncated_keys,
- self.done_keys,
+ for step_count_key, truncated_key, done_key in _zip_strict(
+ self.step_count_keys, self.truncated_keys, self.done_keys
):
step_count = tensordict.get(step_count_key)
next_step_count = step_count + 1
@@ -5844,12 +5902,10 @@ def _step(
next_tensordict.set(truncated_key, truncated)
return next_tensordict
- def transform_observation_spec(
- self, observation_spec: CompositeSpec
- ) -> CompositeSpec:
- if not isinstance(observation_spec, CompositeSpec):
+ def transform_observation_spec(self, observation_spec: Composite) -> Composite:
+ if not isinstance(observation_spec, Composite):
raise ValueError(
- f"observation_spec was expected to be of type CompositeSpec. Got {type(observation_spec)} instead."
+ f"observation_spec was expected to be of type Composite. Got {type(observation_spec)} instead."
)
full_done_spec = self.parent.output_spec["full_done_spec"]
for step_count_key in self.step_count_keys:
@@ -5871,7 +5927,7 @@ def transform_observation_spec(
raise KeyError(
f"Could not find root of step_count_key {step_count_key} in done keys {self.done_keys}."
)
- observation_spec[step_count_key] = BoundedTensorSpec(
+ observation_spec[step_count_key] = Bounded(
shape=shape,
dtype=torch.int64,
device=observation_spec.device,
@@ -5880,7 +5936,7 @@ def transform_observation_spec(
)
return super().transform_observation_spec(observation_spec)
- def transform_output_spec(self, output_spec: CompositeSpec) -> CompositeSpec:
+ def transform_output_spec(self, output_spec: Composite) -> Composite:
if self.max_steps:
full_done_spec = self.parent.output_spec["full_done_spec"]
for truncated_key in self.truncated_keys:
@@ -5902,7 +5958,7 @@ def transform_output_spec(self, output_spec: CompositeSpec) -> CompositeSpec:
raise KeyError(
f"Could not find root of truncated_key {truncated_key} in done keys {self.done_keys}."
)
- full_done_spec[truncated_key] = DiscreteTensorSpec(
+ full_done_spec[truncated_key] = Categorical(
2, dtype=torch.bool, device=output_spec.device, shape=shape
)
if self.update_done:
@@ -5925,19 +5981,19 @@ def transform_output_spec(self, output_spec: CompositeSpec) -> CompositeSpec:
raise KeyError(
f"Could not find root of stop_key {done_key} in done keys {self.done_keys}."
)
- full_done_spec[done_key] = DiscreteTensorSpec(
+ full_done_spec[done_key] = Categorical(
2, dtype=torch.bool, device=output_spec.device, shape=shape
)
output_spec["full_done_spec"] = full_done_spec
return super().transform_output_spec(output_spec)
- def transform_input_spec(self, input_spec: CompositeSpec) -> CompositeSpec:
- if not isinstance(input_spec, CompositeSpec):
+ def transform_input_spec(self, input_spec: Composite) -> Composite:
+ if not isinstance(input_spec, Composite):
raise ValueError(
- f"input_spec was expected to be of type CompositeSpec. Got {type(input_spec)} instead."
+ f"input_spec was expected to be of type Composite. Got {type(input_spec)} instead."
)
if input_spec["full_state_spec"] is None:
- input_spec["full_state_spec"] = CompositeSpec(
+ input_spec["full_state_spec"] = Composite(
shape=input_spec.shape, device=input_spec.device
)
@@ -5962,9 +6018,7 @@ def transform_input_spec(self, input_spec: CompositeSpec) -> CompositeSpec:
f"Could not find root of step_count_key {step_count_key} in done keys {self.done_keys}."
)
- input_spec[
- unravel_key(("full_state_spec", step_count_key))
- ] = BoundedTensorSpec(
+ input_spec[unravel_key(("full_state_spec", step_count_key))] = Bounded(
shape=shape,
dtype=torch.int64,
device=input_spec.device,
@@ -6051,7 +6105,7 @@ def _reset(
return tensordict_reset.exclude(*self.excluded_keys)
return tensordict
- def transform_output_spec(self, output_spec: CompositeSpec) -> CompositeSpec:
+ def transform_output_spec(self, output_spec: Composite) -> Composite:
if not self.inverse:
full_done_spec = output_spec["full_done_spec"]
full_reward_spec = output_spec["full_reward_spec"]
@@ -6171,7 +6225,7 @@ def _reset(
*self.selected_keys, *reward_keys, *done_keys, *input_keys, strict=False
)
- def transform_output_spec(self, output_spec: CompositeSpec) -> CompositeSpec:
+ def transform_output_spec(self, output_spec: Composite) -> Composite:
full_done_spec = output_spec["full_done_spec"]
full_reward_spec = output_spec["full_reward_spec"]
full_observation_spec = output_spec["full_observation_spec"]
@@ -6325,7 +6379,7 @@ def _make_missing_buffer(self, tensordict, in_key, buffer_name):
def _call(self, tensordict: TensorDictBase, _reset=None) -> TensorDictBase:
"""Update the episode tensordict with max pooled keys."""
- for in_key, out_key in zip(self.in_keys, self.out_keys):
+ for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
# Lazy init of buffers
buffer_name = self._buffer_name(in_key)
buffer = getattr(self, buffer_name)
@@ -6381,7 +6435,7 @@ class RandomCropTensorDict(Transform):
This transform is primarily designed to be used with replay buffers and modules.
Currently, it cannot be used as an environment transform.
- Do not hesitate to request for this behaviour through an issue if this is
+ Do not hesitate to request this behavior through an issue if this is
desired.
Args:
@@ -6409,7 +6463,7 @@ def __init__(
if sample_dim > 0:
warnings.warn(
"A positive shape has been passed to the RandomCropTensorDict "
- "constructor. This may have unexpected behaviours when the "
+ "constructor. This may have unexpected behaviors when the "
"passed tensordicts have inconsistent batch dimensions. "
"For context, by convention, TorchRL concatenates time steps "
"along the last dimension of the tensordict."
@@ -6566,7 +6620,7 @@ def _reset(
device = tensordict.device
if device is None:
device = torch.device("cpu")
- for reset_key, init_key in zip(self.reset_keys, self.init_keys):
+ for reset_key, init_key in _zip_strict(self.reset_keys, self.init_keys):
_reset = tensordict.get(reset_key, None)
if _reset is None:
done_key = _replace_last(init_key, "done")
@@ -6610,7 +6664,7 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec
raise KeyError(
f"Could not find root of init_key {init_key} within done_keys {self.parent.done_keys}."
)
- observation_spec[init_key] = DiscreteTensorSpec(
+ observation_spec[init_key] = Categorical(
2,
dtype=torch.bool,
device=self.parent.device,
@@ -6625,15 +6679,15 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
class RenameTransform(Transform):
- """A transform to rename entries in the output tensordict.
+ """A transform to rename entries in the output tensordict (or input tensordict via the inverse keys).
Args:
- in_keys (sequence of NestedKey): the entries to rename
+ in_keys (sequence of NestedKey): the entries to rename.
out_keys (sequence of NestedKey): the name of the entries after renaming.
- in_keys_inv (sequence of NestedKey, optional): the entries to rename before
- passing the input tensordict to :meth:`EnvBase._step`.
- out_keys_inv (sequence of NestedKey, optional): the names of the renamed
- entries passed to :meth:`EnvBase._step`.
+ in_keys_inv (sequence of NestedKey, optional): the entries to rename
+ in the input tensordict, which will be passed to :meth:`EnvBase._step`.
+ out_keys_inv (sequence of NestedKey, optional): the names of the entries
+ in the input tensordict after renaming.
create_copy (bool, optional): if ``True``, the entries will be copied
with a different name rather than being renamed. This allows for
renaming immutable entries such as ``"reward"`` and ``"done"``.
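
A short usage sketch of the renaming behavior documented above; ``GymEnv`` and the ``CartPole-v1`` task are illustrative assumptions:

```python
from torchrl.envs import GymEnv, TransformedEnv
from torchrl.envs.transforms import RenameTransform

env = TransformedEnv(
    GymEnv("CartPole-v1"),
    RenameTransform(in_keys=["observation"], out_keys=["obs"]),
)
td = env.reset()
# the entry now lives under the new name
assert "obs" in td.keys() and "observation" not in td.keys()
```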
@@ -6702,15 +6756,15 @@ def __init__(
def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
if self.create_copy:
out = tensordict.select(*self.in_keys, strict=not self._missing_tolerance)
- for in_key, out_key in zip(self.in_keys, self.out_keys):
+ for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
try:
- tensordict.rename_key_(in_key, out_key)
+ out.rename_key_(in_key, out_key)
except KeyError:
if not self._missing_tolerance:
raise
tensordict = tensordict.update(out)
else:
- for in_key, out_key in zip(self.in_keys, self.out_keys):
+ for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
try:
tensordict.rename_key_(in_key, out_key)
except KeyError:
@@ -6732,7 +6786,7 @@ def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
out = tensordict.select(
*self.out_keys_inv, strict=not self._missing_tolerance
)
- for in_key, out_key in zip(self.in_keys_inv, self.out_keys_inv):
+ for in_key, out_key in _zip_strict(self.in_keys_inv, self.out_keys_inv):
try:
out.rename_key_(out_key, in_key)
except KeyError:
@@ -6741,7 +6795,7 @@ def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
tensordict = tensordict.update(out)
else:
- for in_key, out_key in zip(self.in_keys_inv, self.out_keys_inv):
+ for in_key, out_key in _zip_strict(self.in_keys_inv, self.out_keys_inv):
try:
tensordict.rename_key_(out_key, in_key)
except KeyError:
@@ -6749,7 +6803,7 @@ def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
raise
return tensordict
- def transform_output_spec(self, output_spec: CompositeSpec) -> CompositeSpec:
+ def transform_output_spec(self, output_spec: Composite) -> Composite:
for done_key in self.parent.done_keys:
if done_key in self.in_keys:
for i, out_key in enumerate(self.out_keys): # noqa: B007
@@ -6791,11 +6845,11 @@ def transform_output_spec(self, output_spec: CompositeSpec) -> CompositeSpec:
del output_spec["full_observation_spec"][observation_key]
return output_spec
- def transform_input_spec(self, input_spec: CompositeSpec) -> CompositeSpec:
+ def transform_input_spec(self, input_spec: Composite) -> Composite:
for action_key in self.parent.action_keys:
- if action_key in self.in_keys:
- for i, out_key in enumerate(self.out_keys): # noqa: B007
- if self.in_keys[i] == action_key:
+ if action_key in self.in_keys_inv:
+ for i, out_key in enumerate(self.out_keys_inv): # noqa: B007
+ if self.in_keys_inv[i] == action_key:
break
else:
# unreachable
@@ -6806,9 +6860,9 @@ def transform_input_spec(self, input_spec: CompositeSpec) -> CompositeSpec:
if not self.create_copy:
del input_spec["full_action_spec"][action_key]
for state_key in self.parent.full_state_spec.keys(True):
- if state_key in self.in_keys:
- for i, out_key in enumerate(self.out_keys): # noqa: B007
- if self.in_keys[i] == state_key:
+ if state_key in self.in_keys_inv:
+ for i, out_key in enumerate(self.out_keys_inv): # noqa: B007
+ if self.in_keys_inv[i] == state_key:
break
else:
# unreachable
@@ -6962,7 +7016,7 @@ def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
"No episode ends found to calculate the reward to go. Make sure that the number of frames_per_batch is larger than number of steps per episode."
)
found = False
- for in_key, out_key in zip(self.in_keys_inv, self.out_keys_inv):
+ for in_key, out_key in _zip_strict(self.in_keys_inv, self.out_keys_inv):
if in_key in tensordict.keys(include_nested=True):
found = True
item = self._inv_apply_transform(tensordict.get(in_key), done)
@@ -7003,16 +7057,16 @@ class ActionMask(Transform):
Examples:
>>> import torch
- >>> from torchrl.data.tensor_specs import DiscreteTensorSpec, BinaryDiscreteTensorSpec, UnboundedContinuousTensorSpec, CompositeSpec
+ >>> from torchrl.data.tensor_specs import Categorical, Binary, Unbounded, Composite
>>> from torchrl.envs.transforms import ActionMask, TransformedEnv
>>> from torchrl.envs.common import EnvBase
>>> class MaskedEnv(EnvBase):
... def __init__(self, *args, **kwargs):
... super().__init__(*args, **kwargs)
- ... self.action_spec = DiscreteTensorSpec(4)
- ... self.state_spec = CompositeSpec(action_mask=BinaryDiscreteTensorSpec(4, dtype=torch.bool))
- ... self.observation_spec = CompositeSpec(obs=UnboundedContinuousTensorSpec(3))
- ... self.reward_spec = UnboundedContinuousTensorSpec(1)
+ ... self.action_spec = Categorical(4)
+ ... self.state_spec = Composite(action_mask=Binary(4, dtype=torch.bool))
+ ... self.observation_spec = Composite(obs=Unbounded(3))
+ ... self.reward_spec = Unbounded(1)
...
... def _reset(self, tensordict=None):
... td = self.observation_spec.rand()
@@ -7048,10 +7102,10 @@ class ActionMask(Transform):
"""
ACCEPTED_SPECS = (
- OneHotDiscreteTensorSpec,
- DiscreteTensorSpec,
- MultiOneHotDiscreteTensorSpec,
- MultiDiscreteTensorSpec,
+ OneHot,
+ Categorical,
+ MultiOneHot,
+ MultiCategorical,
)
SPEC_TYPE_ERROR = "The action spec must be one of {}. Got {} instead."
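
A hedged sketch of the masking mechanics this transform builds on, assuming the discrete specs expose ``update_mask`` (which ``ActionMask`` relies on to restrict sampling):

```python
import torch
from torchrl.data.tensor_specs import Categorical

spec = Categorical(4)
spec.update_mask(torch.tensor([True, False, False, True]))
for _ in range(10):
    assert spec.rand().item() in (0, 3)   # masked actions are never drawn
```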
@@ -7477,7 +7531,7 @@ def _inv_apply_transform(self, state: torch.Tensor) -> torch.Tensor:
@_apply_to_composite
def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
- return BoundedTensorSpec(
+ return Bounded(
shape=observation_spec.shape,
device=observation_spec.device,
dtype=observation_spec.dtype,
@@ -7489,7 +7543,7 @@ def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec:
for key in self.in_keys:
if key in self.parent.reward_keys:
spec = self.parent.output_spec["full_reward_spec"][key]
- self.parent.output_spec["full_reward_spec"][key] = BoundedTensorSpec(
+ self.parent.output_spec["full_reward_spec"][key] = Bounded(
shape=spec.shape,
device=spec.device,
dtype=spec.dtype,
@@ -7512,31 +7566,31 @@ class RemoveEmptySpecs(Transform):
Examples:
>>> import torch
>>> from tensordict import TensorDict
- >>> from torchrl.data import UnboundedContinuousTensorSpec, CompositeSpec, \
- ... DiscreteTensorSpec
+ >>> from torchrl.data import Unbounded, Composite, \
+ ... Categorical
>>> from torchrl.envs import EnvBase, TransformedEnv, RemoveEmptySpecs
>>>
>>>
>>> class DummyEnv(EnvBase):
... def __init__(self, *args, **kwargs):
... super().__init__(*args, **kwargs)
- ... self.observation_spec = CompositeSpec(
- ... observation=UnboundedContinuousTensorSpec((*self.batch_size, 3)),
- ... other=CompositeSpec(
- ... another_other=CompositeSpec(shape=self.batch_size),
+ ... self.observation_spec = Composite(
+ ... observation=Unbounded((*self.batch_size, 3)),
+ ... other=Composite(
+ ... another_other=Composite(shape=self.batch_size),
... shape=self.batch_size,
... ),
... shape=self.batch_size,
... )
- ... self.action_spec = UnboundedContinuousTensorSpec((*self.batch_size, 3))
- ... self.done_spec = DiscreteTensorSpec(
+ ... self.action_spec = Unbounded((*self.batch_size, 3))
+ ... self.done_spec = Categorical(
... 2, (*self.batch_size, 1), dtype=torch.bool
... )
... self.full_done_spec["truncated"] = self.full_done_spec[
... "terminated"].clone()
- ... self.reward_spec = CompositeSpec(
- ... reward=UnboundedContinuousTensorSpec(*self.batch_size, 1),
- ... other_reward=CompositeSpec(shape=self.batch_size),
+ ... self.reward_spec = Composite(
+ ... reward=Unbounded(*self.batch_size, 1),
+ ... other_reward=Composite(shape=self.batch_size),
... shape=self.batch_size
... )
...
@@ -7629,7 +7683,7 @@ def _sorter(key_val):
return 0
return len(key)
- def transform_output_spec(self, output_spec: CompositeSpec) -> CompositeSpec:
+ def transform_output_spec(self, output_spec: Composite) -> Composite:
full_done_spec = output_spec["full_done_spec"]
full_reward_spec = output_spec["full_reward_spec"]
full_observation_spec = output_spec["full_observation_spec"]
@@ -7637,19 +7691,19 @@ def transform_output_spec(self, output_spec: CompositeSpec) -> CompositeSpec:
for key, spec in sorted(
full_done_spec.items(True), key=self._sorter, reverse=True
):
- if isinstance(spec, CompositeSpec) and spec.is_empty():
+ if isinstance(spec, Composite) and spec.is_empty():
del full_done_spec[key]
for key, spec in sorted(
full_observation_spec.items(True), key=self._sorter, reverse=True
):
- if isinstance(spec, CompositeSpec) and spec.is_empty():
+ if isinstance(spec, Composite) and spec.is_empty():
del full_observation_spec[key]
for key, spec in sorted(
full_reward_spec.items(True), key=self._sorter, reverse=True
):
- if isinstance(spec, CompositeSpec) and spec.is_empty():
+ if isinstance(spec, Composite) and spec.is_empty():
del full_reward_spec[key]
return output_spec
@@ -7662,14 +7716,14 @@ def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec:
for key, spec in sorted(
full_action_spec.items(True), key=self._sorter, reverse=True
):
- if isinstance(spec, CompositeSpec) and spec.is_empty():
+ if isinstance(spec, Composite) and spec.is_empty():
self._has_empty_input = True
del full_action_spec[key]
for key, spec in sorted(
full_state_spec.items(True), key=self._sorter, reverse=True
):
- if isinstance(spec, CompositeSpec) and spec.is_empty():
+ if isinstance(spec, Composite) and spec.is_empty():
self._has_empty_input = True
del full_state_spec[key]
return input_spec
@@ -7688,7 +7742,7 @@ def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
full_action_spec.items(True), key=self._sorter, reverse=True
):
if (
- isinstance(spec, CompositeSpec)
+ isinstance(spec, Composite)
and spec.is_empty()
and key not in tensordict.keys(True)
):
@@ -7698,7 +7752,7 @@ def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
full_state_spec.items(True), key=self._sorter, reverse=True
):
if (
- isinstance(spec, CompositeSpec)
+ isinstance(spec, Composite)
and spec.is_empty()
and key not in tensordict.keys(True)
):
@@ -7866,9 +7920,9 @@ class BatchSizeTransform(Transform):
... batch_locked = False
... def __init__(self):
... super().__init__()
- ... self.observation_spec = CompositeSpec(observation=UnboundedContinuousTensorSpec(3))
- ... self.reward_spec = UnboundedContinuousTensorSpec(1)
- ... self.action_spec = UnboundedContinuousTensorSpec(1)
+ ... self.observation_spec = Composite(observation=Unbounded(3))
+ ... self.reward_spec = Unbounded(1)
+ ... self.action_spec = Unbounded(1)
...
... def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
... tensordict_batch_size = tensordict.batch_size if tensordict is not None else torch.Size([])
@@ -8016,12 +8070,12 @@ def transform_env_batch_size(self, batch_size: torch.Size):
return self.batch_size
return self.reshape_fn(torch.zeros(batch_size, device="meta")).shape
- def transform_output_spec(self, output_spec: CompositeSpec) -> CompositeSpec:
+ def transform_output_spec(self, output_spec: Composite) -> Composite:
if self.batch_size is not None:
return output_spec.expand(self.batch_size)
return self.reshape_fn(output_spec)
- def transform_input_spec(self, input_spec: CompositeSpec) -> CompositeSpec:
+ def transform_input_spec(self, input_spec: Composite) -> Composite:
if self.batch_size is not None:
return input_spec.expand(self.batch_size)
return self.reshape_fn(input_spec)
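
``transform_env_batch_size`` above computes the post-reshape batch size by reshaping an empty tensor on the ``meta`` device, so no real memory is allocated. A standalone sketch of the trick with a hypothetical ``reshape_fn``:

```python
import torch

reshape_fn = lambda t: t.reshape(-1)    # hypothetical user-provided reshape_fn
batch_size = torch.Size([4, 3])
new_size = reshape_fn(torch.zeros(batch_size, device="meta")).shape
assert new_size == torch.Size([12])     # computed without touching real memory
```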
@@ -8480,7 +8534,7 @@ def _indent(s):
def transform_input_spec(self, input_spec):
try:
action_spec = input_spec["full_action_spec", self.in_keys_inv[0]]
- if not isinstance(action_spec, BoundedTensorSpec):
+ if not isinstance(action_spec, Bounded):
raise TypeError(
f"action spec type {type(action_spec)} is not supported."
)
@@ -8539,9 +8593,9 @@ def custom_arange(nint):
]
cls = (
- functools.partial(MultiDiscreteTensorSpec, remove_singleton=False)
+ functools.partial(MultiCategorical, remove_singleton=False)
if self.categorical
- else MultiOneHotDiscreteTensorSpec
+ else MultiOneHot
)
if not isinstance(num_intervals, torch.Tensor):
diff --git a/torchrl/envs/transforms/vc1.py b/torchrl/envs/transforms/vc1.py
index d8bec1cf524..d394816372d 100644
--- a/torchrl/envs/transforms/vc1.py
+++ b/torchrl/envs/transforms/vc1.py
@@ -14,12 +14,7 @@
from torch import nn
from torchrl._utils import logger as torchrl_logger
-from torchrl.data.tensor_specs import (
- CompositeSpec,
- DEVICE_TYPING,
- TensorSpec,
- UnboundedContinuousTensorSpec,
-)
+from torchrl.data.tensor_specs import Composite, DEVICE_TYPING, TensorSpec, Unbounded
from torchrl.envs.transforms.transforms import (
CenterCrop,
Compose,
@@ -198,8 +193,8 @@ def _apply_transform(self, obs: torch.Tensor) -> None:
return out
def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
- if not isinstance(observation_spec, CompositeSpec):
- raise ValueError("VC1Transform can only infer CompositeSpec")
+ if not isinstance(observation_spec, Composite):
+ raise ValueError("VC1Transform can only infer Composite")
keys = [key for key in observation_spec.keys(True, True) if key in self.in_keys]
device = observation_spec[keys[0]].device
@@ -211,7 +206,7 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec
del observation_spec[in_key]
for out_key in self.out_keys:
- observation_spec[out_key] = UnboundedContinuousTensorSpec(
+ observation_spec[out_key] = Unbounded(
shape=torch.Size([*dim, self.embd_size]), device=device
)
diff --git a/torchrl/envs/transforms/vip.py b/torchrl/envs/transforms/vip.py
index e814f5da476..a28e490c4f1 100644
--- a/torchrl/envs/transforms/vip.py
+++ b/torchrl/envs/transforms/vip.py
@@ -9,11 +9,7 @@
from tensordict import set_lazy_legacy, TensorDict, TensorDictBase
from torch.hub import load_state_dict_from_url
-from torchrl.data.tensor_specs import (
- CompositeSpec,
- TensorSpec,
- UnboundedContinuousTensorSpec,
-)
+from torchrl.data.tensor_specs import Composite, TensorSpec, Unbounded
from torchrl.data.utils import DEVICE_TYPING
from torchrl.envs.transforms.transforms import (
CatTensors,
@@ -92,8 +88,8 @@ def _apply_transform(self, obs: torch.Tensor) -> None:
return out
def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
- if not isinstance(observation_spec, CompositeSpec):
- raise ValueError("_VIPNet can only infer CompositeSpec")
+ if not isinstance(observation_spec, Composite):
+ raise ValueError("_VIPNet can only infer Composite")
keys = [key for key in observation_spec.keys(True, True) if key in self.in_keys]
device = observation_spec[keys[0]].device
@@ -105,7 +101,7 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec
del observation_spec[in_key]
for out_key in self.out_keys:
- observation_spec[out_key] = UnboundedContinuousTensorSpec(
+ observation_spec[out_key] = Unbounded(
shape=torch.Size([*dim, 1024]), device=device
)
@@ -289,7 +285,7 @@ def _init(self):
unsqueeze = UnsqueezeTransform(
in_keys=in_keys,
out_keys=in_keys,
- unsqueeze_dim=-4,
+ dim=-4,
)
transforms.append(unsqueeze)
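
The hunk above tracks the keyword rename in ``UnsqueezeTransform`` (``unsqueeze_dim`` → ``dim``). A minimal sketch of the new call, with illustrative keys:

```python
from torchrl.envs.transforms import UnsqueezeTransform

unsqueeze = UnsqueezeTransform(
    in_keys=["pixels"],
    out_keys=["pixels"],
    dim=-4,  # insert a stacking dimension ahead of C, H, W
)
```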
@@ -399,7 +395,7 @@ def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec:
if "full_state_spec" in input_spec.keys():
full_state_spec = input_spec["full_state_spec"]
else:
- full_state_spec = CompositeSpec(
+ full_state_spec = Composite(
shape=input_spec.shape, device=input_spec.device
)
# find the obs spec
diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py
index ee7649fabe4..f1724326d2a 100644
--- a/torchrl/envs/utils.py
+++ b/torchrl/envs/utils.py
@@ -32,13 +32,8 @@
from tensordict.base import _is_leaf_nontensor
from tensordict.nn import TensorDictModule, TensorDictModuleBase
from tensordict.nn.probabilistic import ( # noqa
- # Note: the `set_interaction_mode` and their associated arg `default_interaction_mode` are being deprecated!
- # Please use the `set_/interaction_type` ones above with the InteractionType enum instead.
- # See more details: https://github.com/pytorch/rl/issues/1016
- interaction_mode as exploration_mode,
interaction_type as exploration_type,
InteractionType as ExplorationType,
- set_interaction_mode as set_exploration_mode,
set_interaction_type as set_exploration_type,
)
from tensordict.utils import is_non_tensor, NestedKey
@@ -47,17 +42,15 @@
from torchrl._utils import _replace_last, _rng_decorator, logger as torchrl_logger
from torchrl.data.tensor_specs import (
- CompositeSpec,
- NO_DEFAULT,
+ Composite,
+ NO_DEFAULT_RL as NO_DEFAULT,
TensorSpec,
- UnboundedContinuousTensorSpec,
+ Unbounded,
)
-from torchrl.data.utils import check_no_exclusive_keys
+from torchrl.data.utils import check_no_exclusive_keys, CloudpickleWrapper
__all__ = [
- "exploration_mode",
"exploration_type",
- "set_exploration_mode",
"set_exploration_type",
"ExplorationType",
"check_env_specs",
@@ -69,22 +62,16 @@
ACTION_MASK_ERROR = RuntimeError(
- "An out-of-bounds actions has been provided to an env with an 'action_mask' output."
- " If you are using a custom policy, make sure to take the action mask into account when computing the output."
- " If you are using a default policy, please add the torchrl.envs.transforms.ActionMask transform to your environment."
+ "An out-of-bounds actions has been provided to an env with an 'action_mask' output. "
+ "If you are using a custom policy, make sure to take the action mask into account when computing the output. "
+ "If you are using a default policy, please add the torchrl.envs.transforms.ActionMask transform to your environment. "
"If you are using a ParallelEnv or another batched inventor, "
- "make sure to add the transform to the ParallelEnv (and not to the sub-environments)."
- " For more info on using action masks, see the docs at: "
- "https://pytorch.org/rl/reference/envs.html#environments-with-masked-actions"
+ "make sure to add the transform to the ParallelEnv (and not to the sub-environments). "
+ "For more info on using action masks, see the docs at: "
+ "https://pytorch.org/rl/main/reference/envs.html#environments-with-masked-actions"
)
-def _convert_exploration_type(*, exploration_mode, exploration_type):
- if exploration_mode is not None:
- return ExplorationType.from_str(exploration_mode)
- return exploration_type
-
-
class _classproperty(property):
def __get__(self, cls, owner):
return classmethod(self.fget).__get__(None, owner)()
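
With the ``*_mode`` aliases removed, only the ``InteractionType``-based API remains. A migration sketch, assuming ``set_exploration_type`` keeps working as a context manager:

```python
from torchrl.envs.utils import ExplorationType, set_exploration_type

# before (removed in this diff): with set_exploration_mode("random"): ...
with set_exploration_type(ExplorationType.RANDOM):
    pass  # rollouts / policy calls sample stochastically inside this block
```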
@@ -360,7 +347,7 @@ def step_mdp(
Given a tensordict retrieved after a step, returns the :obj:`"next"` indexed-tensordict.
The arguments allow for a precise control over what should be kept and what
- should be copied from the ``"next"`` entry. The default behaviour is:
+ should be copied from the ``"next"`` entry. The default behavior is:
move the observation entries, reward and done states to the root, exclude
the current action and keep all extra keys (non-action, non-done, non-reward).
@@ -823,7 +810,7 @@ def check_env_specs(
"you will need to first pass your stack through `torchrl.data.consolidate_spec`."
)
if spec is None:
- spec = CompositeSpec(shape=env.batch_size, device=env.device)
+ spec = Composite(shape=env.batch_size, device=env.device)
td = last_td.select(*spec.keys(True, True), strict=True)
if not spec.contains(td):
raise AssertionError(
@@ -835,7 +822,7 @@ def check_env_specs(
("obs", full_observation_spec),
):
if spec is None:
- spec = CompositeSpec(shape=env.batch_size, device=env.device)
+ spec = Composite(shape=env.batch_size, device=env.device)
td = last_td.get("next").select(*spec.keys(True, True), strict=True)
if not spec.contains(td):
raise AssertionError(
@@ -870,10 +857,10 @@ def _sort_keys(element):
def make_composite_from_td(data, unsqueeze_null_shapes: bool = True):
- """Creates a CompositeSpec instance from a tensordict, assuming all values are unbounded.
+ """Creates a Composite instance from a tensordict, assuming all values are unbounded.
Args:
- data (tensordict.TensorDict): a tensordict to be mapped onto a CompositeSpec.
+ data (tensordict.TensorDict): a tensordict to be mapped onto a Composite.
unsqueeze_null_shapes (bool, optional): if ``True``, every empty shape will be
unsqueezed to (1,). Defaults to ``True``.
@@ -886,25 +873,25 @@ def make_composite_from_td(data, unsqueeze_null_shapes: bool = True):
... }, [])
>>> spec = make_composite_from_td(data)
>>> print(spec)
- CompositeSpec(
- obs: UnboundedContinuousTensorSpec(
+ Composite(
+ obs: UnboundedContinuous(
shape=torch.Size([3]), space=None, device=cpu, dtype=torch.float32, domain=continuous),
- action: UnboundedContinuousTensorSpec(
+ action: UnboundedContinuous(
shape=torch.Size([2]), space=None, device=cpu, dtype=torch.int32, domain=continuous),
- next: CompositeSpec(
- obs: UnboundedContinuousTensorSpec(
+ next: Composite(
+ obs: UnboundedContinuous(
shape=torch.Size([3]), space=None, device=cpu, dtype=torch.float32, domain=continuous),
- reward: UnboundedContinuousTensorSpec(
+ reward: UnboundedContinuous(
shape=torch.Size([1]), space=ContinuousBox(low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True), high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)), device=cpu, dtype=torch.float32, domain=continuous), device=cpu, shape=torch.Size([])), device=cpu, shape=torch.Size([]))
>>> assert (spec.zero() == data.zero_()).all()
"""
# custom function to convert a tensordict into a similar spec structure
# of unbounded values.
- composite = CompositeSpec(
+ composite = Composite(
{
key: make_composite_from_td(tensor)
if isinstance(tensor, TensorDictBase)
- else UnboundedContinuousTensorSpec(
+ else Unbounded(
dtype=tensor.dtype,
device=tensor.device,
shape=tensor.shape
@@ -1094,14 +1081,14 @@ def _terminated_or_truncated(
contained a ``True``.
Examples:
- >>> from torchrl.data.tensor_specs import DiscreteTensorSpec
+ >>> from torchrl.data.tensor_specs import Categorical
>>> from tensordict import TensorDict
- >>> spec = CompositeSpec(
- ... done=DiscreteTensorSpec(2, dtype=torch.bool),
- ... truncated=DiscreteTensorSpec(2, dtype=torch.bool),
- ... nested=CompositeSpec(
- ... done=DiscreteTensorSpec(2, dtype=torch.bool),
- ... truncated=DiscreteTensorSpec(2, dtype=torch.bool),
+ >>> spec = Composite(
+ ... done=Categorical(2, dtype=torch.bool),
+ ... truncated=Categorical(2, dtype=torch.bool),
+ ... nested=Composite(
+ ... done=Categorical(2, dtype=torch.bool),
+ ... truncated=Categorical(2, dtype=torch.bool),
... )
... )
>>> data = TensorDict({
@@ -1147,7 +1134,7 @@ def inner_terminated_or_truncated(data, full_done_spec, key, curr_done_key=()):
composite_spec = {}
found_leaf = 0
for eot_key, item in full_done_spec.items():
- if isinstance(item, CompositeSpec):
+ if isinstance(item, Composite):
composite_spec[eot_key] = item
else:
found_leaf += 1
@@ -1219,14 +1206,14 @@ def terminated_or_truncated(
contained a ``True``.
Examples:
- >>> from torchrl.data.tensor_specs import DiscreteTensorSpec
+ >>> from torchrl.data.tensor_specs import Categorical
>>> from tensordict import TensorDict
- >>> spec = CompositeSpec(
- ... done=DiscreteTensorSpec(2, dtype=torch.bool),
- ... truncated=DiscreteTensorSpec(2, dtype=torch.bool),
- ... nested=CompositeSpec(
- ... done=DiscreteTensorSpec(2, dtype=torch.bool),
- ... truncated=DiscreteTensorSpec(2, dtype=torch.bool),
+ >>> spec = Composite(
+ ... done=Categorical(2, dtype=torch.bool),
+ ... truncated=Categorical(2, dtype=torch.bool),
+ ... nested=Composite(
+ ... done=Categorical(2, dtype=torch.bool),
+ ... truncated=Categorical(2, dtype=torch.bool),
... )
... )
>>> data = TensorDict({
@@ -1274,7 +1261,7 @@ def inner_terminated_or_truncated(data, full_done_spec, key, curr_done_key=()):
)
else:
for eot_key, item in full_done_spec.items():
- if isinstance(item, CompositeSpec):
+ if isinstance(item, Composite):
any_eot = any_eot | inner_terminated_or_truncated(
data=data.get(eot_key),
full_done_spec=item,
@@ -1427,16 +1414,63 @@ def _repr_by_depth(key):
return (len(key) - 1, ".".join(key))
-def _make_compatible_policy(policy, observation_spec, env=None, fast_wrap=False):
+def _make_compatible_policy(
+ policy,
+ observation_spec,
+ env=None,
+ fast_wrap=False,
+ trust_policy=False,
+ env_maker=None,
+ env_maker_kwargs=None,
+):
+ if trust_policy:
+ return policy
if policy is None:
- if env is None:
- raise ValueError(
- "env must be provided to _get_policy_and_device if policy is None"
- )
- policy = RandomPolicy(env.input_spec["full_action_spec"])
- # make sure policy is an nn.Module
- policy = _NonParametricPolicyWrapper(policy)
+ input_spec = None
+ if env_maker is not None:
+ from torchrl.envs import EnvBase, EnvCreator
+
+ if isinstance(env_maker, EnvBase):
+ env = env_maker
+ input_spec = env.input_spec["full_action_spec"]
+ elif isinstance(env_maker, EnvCreator):
+ input_spec = env_maker._meta_data.specs[
+ "input_spec", "full_action_spec"
+ ]
+ else:
+ env = env_maker(**env_maker_kwargs)
+ input_spec = env.full_action_spec
+ if input_spec is None:
+ if env is not None:
+ input_spec = env.input_spec["full_action_spec"]
+ else:
+ raise ValueError(
+ "env must be provided to _get_policy_and_device if policy is None"
+ )
+
+ policy = RandomPolicy(input_spec)
+
+ # make sure policy is an nn.Module - this will return the same policy if conditions are met
+ # policy = CloudpickleWrapper(policy)
+
+ caller = getattr(policy, "forward", policy)
+
if not _policy_is_tensordict_compatible(policy):
+ if observation_spec is None:
+ if env is not None:
+ observation_spec = env.observation_spec
+ elif env_maker is not None:
+ from torchrl.envs import EnvBase, EnvCreator
+
+ if isinstance(env_maker, EnvBase):
+ observation_spec = env_maker.observation_spec
+ elif isinstance(env_maker, EnvCreator):
+ observation_spec = env_maker._meta_data.specs[
+ "output_spec", "full_observation_spec"
+ ]
+ else:
+ observation_spec = env_maker(**env_maker_kwargs).observation_spec
+
# policy is a nn.Module that doesn't operate on tensordicts directly
# so we attempt to auto-wrap policy with TensorDictModule
if observation_spec is None:
@@ -1445,13 +1479,15 @@ def _make_compatible_policy(policy, observation_spec, env=None, fast_wrap=False)
"required to check compatibility of the environment and policy "
"since the policy is a nn.Module that operates on tensors "
"rather than a TensorDictModule or a nn.Module that accepts a "
- "TensorDict as input and defines in_keys and out_keys."
+ "TensorDict as input and defines in_keys and out_keys. "
+ "If your policy is compatible with the environment, you can solve this warning by setting "
+ "trust_policy=True in the constructor."
)
try:
- sig = policy.forward.__signature__
+ sig = caller.__signature__
except AttributeError:
- sig = inspect.signature(policy.forward)
+ sig = inspect.signature(caller)
# we check if all the mandatory params are there
params = list(sig.parameters.keys())
if (
@@ -1480,7 +1516,7 @@ def _make_compatible_policy(policy, observation_spec, env=None, fast_wrap=False)
out_keys = ["action"]
else:
out_keys = list(env.action_keys)
- for p in policy.parameters():
+ for p in getattr(policy, "parameters", list)():
policy_device = p.device
break
else:
@@ -1503,7 +1539,7 @@ def _make_compatible_policy(policy, observation_spec, env=None, fast_wrap=False)
If you want TorchRL to automatically wrap your policy with a TensorDictModule
then the arguments to policy.forward must correspond one-to-one with entries
in env.observation_spec.
- For more complex behaviour and more control you can consider writing your
+ For more complex behavior and more control you can consider writing your
own TensorDictModule.
Check the collector documentation to know more about accepted policies.
"""
@@ -1512,15 +1548,20 @@ def _make_compatible_policy(policy, observation_spec, env=None, fast_wrap=False)
def _policy_is_tensordict_compatible(policy: nn.Module):
- if isinstance(policy, _NonParametricPolicyWrapper) and isinstance(
- policy.policy, RandomPolicy
- ):
- return True
+ def is_compatible(policy):
+ return isinstance(policy, (RandomPolicy, TensorDictModuleBase))
- if isinstance(policy, TensorDictModuleBase):
+ if (
+ is_compatible(policy)
+ or (
+ isinstance(policy, _NonParametricPolicyWrapper)
+ and is_compatible(policy.policy)
+ )
+ or (isinstance(policy, CloudpickleWrapper) and is_compatible(policy.fn))
+ ):
return True
- sig = inspect.signature(policy.forward)
+ sig = inspect.signature(getattr(policy, "forward", policy))
if (
len(sig.parameters) == 1
@@ -1541,7 +1582,7 @@ def _policy_is_tensordict_compatible(policy: nn.Module):
# if in_keys or out_keys were defined but policy is not a TensorDictModule or
# accepts multiple arguments then it's likely the user is trying to do something
- # that will have undetermined behaviour, we raise an error
+ # that will have undetermined behavior, we raise an error
raise TypeError(
"Received a policy that defines in_keys or out_keys and also expects multiple "
"arguments to policy.forward. If the policy is compatible with TensorDict, it "
@@ -1562,8 +1603,8 @@ class RandomPolicy:
Examples:
>>> from tensordict import TensorDict
- >>> from torchrl.data.tensor_specs import BoundedTensorSpec
- >>> action_spec = BoundedTensorSpec(-torch.ones(3), torch.ones(3))
+ >>> from torchrl.data.tensor_specs import Bounded
+ >>> action_spec = Bounded(-torch.ones(3), torch.ones(3))
>>> actor = RandomPolicy(action_spec=action_spec)
>>> td = actor(TensorDict({}, batch_size=[])) # selects a random action in the cube [-1; 1]
"""
@@ -1574,7 +1615,7 @@ def __init__(self, action_spec: TensorSpec, action_key: NestedKey = "action"):
self.action_key = action_key
def __call__(self, td: TensorDictBase) -> TensorDictBase:
- if isinstance(self.action_spec, CompositeSpec):
+ if isinstance(self.action_spec, Composite):
return td.update(self.action_spec.rand())
else:
return td.set(self.action_key, self.action_spec.rand())
@@ -1593,19 +1634,10 @@ class _NonParametricPolicyWrapper(nn.Module, metaclass=_PolicyMetaClass):
def __init__(self, policy):
super().__init__()
- self.policy = policy
-
- @property
- def forward(self):
- forward = self.__dict__.get("_forward", None)
- if forward is None:
-
- @functools.wraps(self.policy)
- def forward(*input, **kwargs):
- return self.policy.__call__(*input, **kwargs)
-
- self.__dict__["_forward"] = forward
- return forward
+ functools.update_wrapper(self, policy)
+ self.policy = CloudpickleWrapper(policy)
+ if hasattr(policy, "forward"):
+ self.forward = self.policy.forward
def __getattr__(self, attr: str) -> Any:
if attr in self.__dir__():
diff --git a/torchrl/modules/__init__.py b/torchrl/modules/__init__.py
index 0a06e5844a0..f65461842bb 100644
--- a/torchrl/modules/__init__.py
+++ b/torchrl/modules/__init__.py
@@ -20,6 +20,8 @@
TruncatedNormal,
)
from .models import (
+ BatchRenorm1d,
+ ConsistentDropoutModule,
Conv3dNet,
ConvNet,
DdpgCnnActor,
@@ -84,4 +86,5 @@
VmapModule,
WorldModelWrapper,
)
+from .utils import get_primers_from_module
from .planners import CEMPlanner, MPCPlannerBase, MPPIPlanner # usort:skip
diff --git a/torchrl/modules/distributions/continuous.py b/torchrl/modules/distributions/continuous.py
index fddc2f3415d..8b0d5654b8d 100644
--- a/torchrl/modules/distributions/continuous.py
+++ b/torchrl/modules/distributions/continuous.py
@@ -5,13 +5,21 @@
from __future__ import annotations
import warnings
+import weakref
from numbers import Number
from typing import Dict, Optional, Sequence, Tuple, Union
import numpy as np
import torch
from torch import distributions as D, nn
+
+try:
+ from torch.compiler import assume_constant_result
+except ImportError:
+ from torch._dynamo import assume_constant_result
+
from torch.distributions import constraints
+from torch.distributions.transforms import _InverseTransform
from torchrl.modules.distributions.truncated_normal import (
TruncatedNormal as _TruncatedNormal,
@@ -20,14 +28,19 @@
from torchrl.modules.distributions.utils import (
_cast_device,
FasterTransformedDistribution,
- safeatanh,
- safetanh,
+ safeatanh_noeps,
+ safetanh_noeps,
)
from torchrl.modules.utils import mappings
# speeds up distribution construction
D.Distribution.set_default_validate_args(False)
+try:
+ from torch.compiler import is_dynamo_compiling
+except ImportError:
+ from torch._dynamo import is_compiling as is_dynamo_compiling
+
class IndependentNormal(D.Independent):
"""Implements a Normal distribution with location scaling.
@@ -39,7 +52,7 @@ class IndependentNormal(D.Independent):
.. math::
loc = tanh(loc / upscale) * upscale.
- This behaviour can be disabled by switching off the tanh_loc parameter (see below).
+ This behavior can be disabled by switching off the tanh_loc parameter (see below).
Args:
@@ -92,19 +105,21 @@ class SafeTanhTransform(D.TanhTransform):
"""TanhTransform subclass that ensured that the transformation is numerically invertible."""
def _call(self, x: torch.Tensor) -> torch.Tensor:
- if x.dtype.is_floating_point:
- eps = torch.finfo(x.dtype).resolution
- else:
- raise NotImplementedError(f"No tanh transform for {x.dtype} inputs.")
- return safetanh(x, eps)
+ return safetanh_noeps(x)
def _inverse(self, y: torch.Tensor) -> torch.Tensor:
- if y.dtype.is_floating_point:
- eps = torch.finfo(y.dtype).resolution
- else:
- raise NotImplementedError(f"No inverse tanh for {y.dtype} inputs.")
- x = safeatanh(y, eps)
- return x
+ return safeatanh_noeps(y)
+
+ @property
+ def inv(self):
+ inv = None
+ if self._inv is not None:
+ inv = self._inv()
+ if inv is None:
+ inv = _InverseTransform(self)
+ if not is_dynamo_compiling():
+ self._inv = weakref.ref(inv)
+ return inv
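
The rewritten transform now defers to the eps-free ops and caches its inverse behind a weakref, skipping the cache under dynamo (which cannot trace weakrefs). A hedged numerical sketch of why the clamping matters:

```python
import torch

x = torch.tensor([20.0])                  # tanh(20.0) rounds to 1.0 in float32
eps = torch.finfo(torch.float32).resolution
y = x.tanh().clamp(-1 + eps, 1 - eps)     # what the "safe" variants do
assert torch.isfinite(y.atanh()).all()    # plain x.tanh().atanh() would be inf
```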
class NormalParamWrapper(nn.Module):
@@ -173,7 +188,7 @@ class TruncatedNormal(D.Independent):
.. math::
loc = tanh(loc / upscale) * upscale.
- This behaviour can be disabled by switching off the tanh_loc parameter (see below).
+ This behavior can be disabled by switching off the tanh_loc parameter (see below).
Args:
@@ -202,13 +217,6 @@ class TruncatedNormal(D.Independent):
"scale": constraints.greater_than(1e-6),
}
- def _warn_minmax(self):
- warnings.warn(
- f"the min / high keyword arguments are deprecated in favor of low / high in {type(self).__name__} "
- f"and will be removed entirely in v0.6. ",
- DeprecationWarning,
- )
-
def __init__(
self,
loc: torch.Tensor,
@@ -217,14 +225,7 @@ def __init__(
low: Union[torch.Tensor, float] = -1.0,
high: Union[torch.Tensor, float] = 1.0,
tanh_loc: bool = False,
- **kwargs,
):
- if "max" in kwargs:
- self._warn_minmax()
- high = kwargs.pop("max")
- if "min" in kwargs:
- self._warn_minmax()
- low = kwargs.pop("min")
err_msg = "TanhNormal high values must be strictly greater than low values"
if isinstance(high, torch.Tensor) or isinstance(low, torch.Tensor):
@@ -316,6 +317,33 @@ def log_prob(self, value, **kwargs):
return lp
+class _PatchedComposeTransform(D.ComposeTransform):
+ @property
+ def inv(self):
+ inv = None
+ if self._inv is not None:
+ inv = self._inv()
+ if inv is None:
+ inv = _PatchedComposeTransform([p.inv for p in reversed(self.parts)])
+ if not is_dynamo_compiling():
+ self._inv = weakref.ref(inv)
+ inv._inv = weakref.ref(self)
+ return inv
+
+
+class _PatchedAffineTransform(D.AffineTransform):
+ @property
+ def inv(self):
+ inv = None
+ if self._inv is not None:
+ inv = self._inv()
+ if inv is None:
+ inv = _InverseTransform(self)
+ if not is_dynamo_compiling():
+ self._inv = weakref.ref(inv)
+ return inv
+
+
class TanhNormal(FasterTransformedDistribution):
"""Implements a TanhNormal distribution with location scaling.
@@ -337,13 +365,15 @@ class TanhNormal(FasterTransformedDistribution):
.. math::
loc = tanh(loc / upscale) * upscale.
- min (torch.Tensor or number, optional): minimum value of the distribution. Default is -1.0;
- max (torch.Tensor or number, optional): maximum value of the distribution. Default is 1.0;
+ low (torch.Tensor or number, optional): minimum value of the distribution. Default is -1.0;
+ high (torch.Tensor or number, optional): maximum value of the distribution. Default is 1.0;
event_dims (int, optional): number of dimensions describing the action.
Default is 1. Setting ``event_dims`` to ``0`` will result in a log-probability that has the same shape
as the input, ``1`` will reduce (sum over) the last dimension, ``2`` the last two etc.
tanh_loc (bool, optional): if ``True``, the above formula is used for the location scaling, otherwise the raw
value is kept. Default is ``False``;
+ safe_tanh (bool, optional): if ``True``, the Tanh transform is done "safely", to avoid numerical overflows.
+ This will currently break with :func:`torch.compile`.
"""
arg_constraints = {
@@ -353,13 +383,6 @@ class TanhNormal(FasterTransformedDistribution):
num_params = 2
- def _warn_minmax(self):
- warnings.warn(
- f"the min / high keyword arguments are deprecated in favor of low / high in {type(self).__name__} "
- f"and will be removed entirely in v0.6. ",
- DeprecationWarning,
- )
-
def __init__(
self,
loc: torch.Tensor,
@@ -369,15 +392,9 @@ def __init__(
high: Union[torch.Tensor, Number] = 1.0,
event_dims: int | None = None,
tanh_loc: bool = False,
+ safe_tanh: bool = True,
**kwargs,
):
- if "max" in kwargs:
- self._warn_minmax()
- high = kwargs.pop("max")
- if "min" in kwargs:
- self._warn_minmax()
- low = kwargs.pop("min")
-
if not isinstance(loc, torch.Tensor):
loc = torch.as_tensor(loc, dtype=torch.get_default_dtype())
if not isinstance(scale, torch.Tensor):
@@ -419,13 +436,20 @@ def __init__(
self.low = low
self.high = high
- t = SafeTanhTransform()
+ if safe_tanh:
+ if is_dynamo_compiling():
+ _err_compile_safetanh()
+ t = SafeTanhTransform()
+ else:
+ t = D.TanhTransform()
# t = D.TanhTransform()
- if self.non_trivial_max or self.non_trivial_min:
- t = D.ComposeTransform(
+ if is_dynamo_compiling() or (self.non_trivial_max or self.non_trivial_min):
+ t = _PatchedComposeTransform(
[
t,
- D.AffineTransform(loc=(high + low) / 2, scale=(high - low) / 2),
+ _PatchedAffineTransform(
+ loc=(high + low) / 2, scale=(high - low) / 2
+ ),
]
)
self._t = t
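
A usage sketch for the new ``safe_tanh`` switch; tensor values are illustrative:

```python
import torch
from torchrl.modules import TanhNormal

dist = TanhNormal(
    loc=torch.zeros(3),
    scale=torch.ones(3),
    low=-2.0,
    high=2.0,
    safe_tanh=False,  # required under torch.compile, per this diff
)
sample = dist.rsample()
assert sample.shape == (3,)
```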
@@ -446,7 +470,7 @@ def update(self, loc: torch.Tensor, scale: torch.Tensor) -> None:
if self.tanh_loc:
loc = (loc / self.upscale).tanh() * self.upscale
# loc must be rescaled if tanh_loc
- if self.non_trivial_max or self.non_trivial_min:
+ if is_dynamo_compiling() or (self.non_trivial_max or self.non_trivial_min):
loc = loc + (self.high - self.low) / 2 + self.low
self.loc = loc
self.scale = scale
@@ -466,6 +490,10 @@ def update(self, loc: torch.Tensor, scale: torch.Tensor) -> None:
base = D.Normal(self.loc, self.scale)
super().__init__(base, self._t)
+ @property
+ def support(self):
+ return D.constraints.real
+
@property
def root_dist(self):
bd = self
@@ -475,15 +503,10 @@ def root_dist(self):
@property
def mode(self):
- warnings.warn(
- "This computation of the mode is based on an inaccurate estimation of the mode "
- "given the base_dist mode. "
- "To use a more stable implementation of the mode, use dist.get_mode() method instead. "
- "To silence this warning, consider using the DETERMINISTIC exploration_type."
- "This implementation will be removed in v0.6.",
- category=DeprecationWarning,
+ raise RuntimeError(
+ f"The distribution {type(self).__name__} has not analytical mode. "
+ f"Use ExplorationMode.DETERMINISTIC to get a deterministic sample from it."
)
- return self.deterministic_sample
@property
def deterministic_sample(self):
@@ -647,13 +670,6 @@ class TanhDelta(FasterTransformedDistribution):
"loc": constraints.real,
}
- def _warn_minmax(self):
- warnings.warn(
- f"the min / high keyword arguments are deprecated in favor of low / high in {type(self).__name__} "
- f"and will be removed entirely in v0.6. ",
- category=DeprecationWarning,
- )
-
def __init__(
self,
param: torch.Tensor,
@@ -662,15 +678,7 @@ def __init__(
event_dims: int = 1,
atol: float = 1e-6,
rtol: float = 1e-6,
- **kwargs,
):
- if "max" in kwargs:
- self._warn_minmax()
- high = kwargs.pop("max")
- if "min" in kwargs:
- self._warn_minmax()
- low = kwargs.pop("min")
-
minmax_msg = "high value has been found to be equal or less than low value"
if isinstance(high, torch.Tensor) or isinstance(low, torch.Tensor):
if not (high > low).all():
@@ -696,10 +704,10 @@ def __init__(
loc = self.update(param)
if self.non_trivial:
- t = D.ComposeTransform(
+ t = _PatchedComposeTransform(
[
t,
- D.AffineTransform(
+ _PatchedAffineTransform(
loc=(self.high + self.low) / 2, scale=(self.high - self.low) / 2
),
]
@@ -712,7 +720,6 @@ def __init__(
rtol=rtol,
batch_shape=batch_shape,
event_shape=event_shape,
- **kwargs,
)
super().__init__(base, t)
@@ -761,3 +768,16 @@ def _uniform_sample_delta(dist: Delta, size=None) -> torch.Tensor:
uniform_sample_delta = _uniform_sample_delta
+
+
+def _err_compile_safetanh():
+ raise RuntimeError(
+ "safe_tanh=True in TanhNormal is not compatible with torch.compile. To deactivate it, pass"
+ "safe_tanh=False. "
+ "If you are using a ProbabilisticTensorDictModule, this can be done via "
+ "`distribution_kwargs={'safe_tanh': False}`. "
+ "See https://github.com/pytorch/pytorch/issues/133529 for more details."
+ )
+
+
+_warn_compile_safetanh = assume_constant_result(_err_compile_safetanh)
diff --git a/torchrl/modules/distributions/discrete.py b/torchrl/modules/distributions/discrete.py
index c48d8168887..d2ffba30686 100644
--- a/torchrl/modules/distributions/discrete.py
+++ b/torchrl/modules/distributions/discrete.py
@@ -389,6 +389,17 @@ def sample(
) -> torch.Tensor:
...
+ @property
+ def deterministic_sample(self):
+ return self.mode
+
+ @property
+ def mode(self) -> torch.Tensor:
+ if hasattr(self, "logits"):
+ return (self.logits == self.logits.max(-1, True)[0]).to(torch.long)
+ else:
+ return (self.probs == self.probs.max(-1, True)[0]).to(torch.long)
+
def log_prob(self, value: torch.Tensor) -> torch.Tensor:
return super().log_prob(value.argmax(dim=-1))
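
The added ``mode`` builds a one-hot vector by comparing logits against their row-wise maximum. The same computation on a raw tensor:

```python
import torch

logits = torch.tensor([0.1, 2.0, -1.0])
mode = (logits == logits.max(-1, True)[0]).to(torch.long)
assert mode.tolist() == [0, 1, 0]   # one-hot argmax over the last dim
```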
diff --git a/torchrl/modules/distributions/utils.py b/torchrl/modules/distributions/utils.py
index 267632c4fd9..546d93cb228 100644
--- a/torchrl/modules/distributions/utils.py
+++ b/torchrl/modules/distributions/utils.py
@@ -6,7 +6,6 @@
from typing import Union
import torch
-from packaging import version
from torch import autograd, distributions as d
from torch.distributions import Independent, Transform, TransformedDistribution
@@ -92,72 +91,133 @@ def __init__(self, base_distribution, transforms, validate_args=None):
)
-if version.parse(torch.__version__) >= version.parse("2.0.0"):
-
- class _SafeTanh(autograd.Function):
- generate_vmap_rule = True
-
- @staticmethod
- def forward(input, eps):
- output = input.tanh()
- lim = 1.0 - eps
- output = output.clamp(-lim, lim)
- # ctx.save_for_backward(output)
- return output
-
- @staticmethod
- def setup_context(ctx, inputs, output):
- # input, eps = inputs
- # ctx.mark_non_differentiable(ind, ind_inv)
- # # Tensors must be saved via ctx.save_for_backward. Please do not
- # # assign them directly onto the ctx object.
- ctx.save_for_backward(output)
-
- @staticmethod
- def backward(ctx, *grad):
- grad = grad[0]
- (output,) = ctx.saved_tensors
- return (grad * (1 - output.pow(2)), None)
-
- class _SafeaTanh(autograd.Function):
- generate_vmap_rule = True
-
- @staticmethod
- def setup_context(ctx, inputs, output):
- tanh_val, eps = inputs
- # ctx.mark_non_differentiable(ind, ind_inv)
- # # Tensors must be saved via ctx.save_for_backward. Please do not
- # # assign them directly onto the ctx object.
- ctx.save_for_backward(tanh_val)
- ctx.eps = eps
-
- @staticmethod
- def forward(tanh_val, eps):
- lim = 1.0 - eps
- output = tanh_val.clamp(-lim, lim)
- # ctx.save_for_backward(output)
- output = output.atanh()
- return output
-
- @staticmethod
- def backward(ctx, *grad):
- grad = grad[0]
- (tanh_val,) = ctx.saved_tensors
- eps = ctx.eps
- lim = 1.0 - eps
- output = tanh_val.clamp(-lim, lim)
- return (grad / (1 - output.pow(2)), None)
-
- safetanh = _SafeTanh.apply
- safeatanh = _SafeaTanh.apply
-
-else:
-
- def safetanh(x, eps): # noqa: D103
+def _safetanh(x, eps): # noqa: D103
+ lim = 1.0 - eps
+ y = x.tanh()
+ return y.clamp(-lim, lim)
+
+
+def _safeatanh(y, eps): # noqa: D103
+ lim = 1.0 - eps
+ return y.clamp(-lim, lim).atanh()
+
+
+class _SafeTanh(autograd.Function):
+ generate_vmap_rule = True
+
+ @staticmethod
+ def forward(input, eps):
+ output = input.tanh()
lim = 1.0 - eps
- y = x.tanh()
- return y.clamp(-lim, lim)
+ output = output.clamp(-lim, lim)
+ # ctx.save_for_backward(output)
+ return output
+
+ @staticmethod
+ def setup_context(ctx, inputs, output):
+ # input, eps = inputs
+ # ctx.mark_non_differentiable(ind, ind_inv)
+ # # Tensors must be saved via ctx.save_for_backward. Please do not
+ # # assign them directly onto the ctx object.
+ ctx.save_for_backward(output)
+
+ @staticmethod
+ def backward(ctx, *grad):
+ grad = grad[0]
+ (output,) = ctx.saved_tensors
+ return (grad * (1 - output.pow(2)), None)
+
+
+class _SafeTanhNoEps(autograd.Function):
+ generate_vmap_rule = True
+
+ @staticmethod
+ def forward(input):
+ output = input.tanh()
+ eps = torch.finfo(input.dtype).resolution
+ lim = 1.0 - eps
+ output = output.clamp(-lim, lim)
+ return output
+
+ @staticmethod
+ def setup_context(ctx, inputs, output):
+ ctx.save_for_backward(output)
+
+ @staticmethod
+ def backward(ctx, *grad):
+ grad = grad[0]
+ (output,) = ctx.saved_tensors
+ return (grad * (1 - output.pow(2)),)
+
+
+class _SafeaTanh(autograd.Function):
+ generate_vmap_rule = True
- def safeatanh(y, eps): # noqa: D103
+ @staticmethod
+ def forward(tanh_val, eps):
+ if eps is None:
+ eps = torch.finfo(tanh_val.dtype).resolution
lim = 1.0 - eps
- return y.clamp(-lim, lim).atanh()
+ output = tanh_val.clamp(-lim, lim)
+ # ctx.save_for_backward(output)
+ output = output.atanh()
+ return output
+
+ @staticmethod
+ def setup_context(ctx, inputs, output):
+ tanh_val, eps = inputs
+
+ # ctx.mark_non_differentiable(ind, ind_inv)
+ # # Tensors must be saved via ctx.save_for_backward. Please do not
+ # # assign them directly onto the ctx object.
+ ctx.save_for_backward(tanh_val)
+ ctx.eps = eps
+
+ @staticmethod
+ def backward(ctx, *grad):
+ grad = grad[0]
+ (tanh_val,) = ctx.saved_tensors
+ eps = ctx.eps
+ lim = 1.0 - eps
+ output = tanh_val.clamp(-lim, lim)
+ return (grad / (1 - output.pow(2)), None)
+
+
+class _SafeaTanhNoEps(autograd.Function):
+ generate_vmap_rule = True
+
+ @staticmethod
+ def forward(tanh_val):
+ eps = torch.finfo(tanh_val.dtype).resolution
+ lim = 1.0 - eps
+ output = tanh_val.clamp(-lim, lim)
+ # ctx.save_for_backward(output)
+ output = output.atanh()
+ return output
+
+ @staticmethod
+ def setup_context(ctx, inputs, output):
+ tanh_val = inputs[0]
+ eps = torch.finfo(tanh_val.dtype).resolution
+
+ # ctx.mark_non_differentiable(ind, ind_inv)
+ # # Tensors must be saved via ctx.save_for_backward. Please do not
+ # # assign them directly onto the ctx object.
+ ctx.save_for_backward(tanh_val)
+ ctx.eps = eps
+
+ @staticmethod
+ def backward(ctx, *grad):
+ grad = grad[0]
+ (tanh_val,) = ctx.saved_tensors
+ eps = ctx.eps
+ lim = 1.0 - eps
+ output = tanh_val.clamp(-lim, lim)
+ return (grad / (1 - output.pow(2)),)
+
+
+safetanh = _SafeTanh.apply
+safeatanh = _SafeaTanh.apply
+
+safetanh_noeps = _SafeTanhNoEps.apply
+safeatanh_noeps = _SafeaTanhNoEps.apply
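
The ``*_noeps`` variants derive ``eps`` from the input dtype instead of taking it as an argument, keeping the signature compile-friendly. A small sketch with the functions defined above:

```python
import torch
from torchrl.modules.distributions.utils import safeatanh_noeps, safetanh_noeps

x = (torch.randn(4) * 10).requires_grad_()
y = safetanh_noeps(x)                   # tanh clamped to +/- (1 - eps(dtype))
assert torch.isfinite(safeatanh_noeps(y)).all()
y.sum().backward()                      # custom backward: grad * (1 - y**2)
assert torch.isfinite(x.grad).all()
```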
diff --git a/torchrl/modules/models/__init__.py b/torchrl/modules/models/__init__.py
index 9a814e35477..90b9fadd747 100644
--- a/torchrl/modules/models/__init__.py
+++ b/torchrl/modules/models/__init__.py
@@ -9,7 +9,12 @@
from .batchrenorm import BatchRenorm1d
from .decision_transformer import DecisionTransformer
-from .exploration import NoisyLazyLinear, NoisyLinear, reset_noise
+from .exploration import (
+ ConsistentDropoutModule,
+ NoisyLazyLinear,
+ NoisyLinear,
+ reset_noise,
+)
from .model_based import (
DreamerActor,
ObsDecoder,
diff --git a/torchrl/modules/models/batchrenorm.py b/torchrl/modules/models/batchrenorm.py
index 26a2f9d50d2..41de0945f70 100644
--- a/torchrl/modules/models/batchrenorm.py
+++ b/torchrl/modules/models/batchrenorm.py
@@ -32,9 +32,9 @@ class BatchRenorm1d(nn.Module):
Defaults to ``5.0``.
warmup_steps (int, optional): Number of warm-up steps for the running mean and variance.
Defaults to ``10000``.
- smooth (bool, optional): if ``True``, the behaviour smoothly transitions from regular
+ smooth (bool, optional): if ``True``, the behavior smoothly transitions from regular
batch-norm (when ``iter=0``) to batch-renorm (when ``iter=warmup_steps``).
- Otherwise, the behaviour will transition from batch-norm to batch-renorm when
+ Otherwise, the behavior will transition from batch-norm to batch-renorm when
``iter=warmup_steps``. Defaults to ``False``.
"""
diff --git a/torchrl/modules/models/exploration.py b/torchrl/modules/models/exploration.py
index 2ec51b46559..d69a85fd685 100644
--- a/torchrl/modules/models/exploration.py
+++ b/torchrl/modules/models/exploration.py
@@ -2,16 +2,24 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+import functools
import math
import warnings
-from typing import Optional, Sequence, Union
+from typing import List, Optional, Sequence, Union
import torch
+
+from tensordict.nn import TensorDictModuleBase
+from tensordict.utils import NestedKey
from torch import distributions as d, nn
+from torch.nn import functional as F
+from torch.nn.modules.dropout import _DropoutNd
from torch.nn.modules.lazy import LazyModuleMixin
from torch.nn.parameter import UninitializedBuffer, UninitializedParameter
-
from torchrl._utils import prod
+from torchrl.data.tensor_specs import Unbounded
from torchrl.data.utils import DEVICE_TYPING, DEVICE_TYPING_ARGS
from torchrl.envs.utils import exploration_type, ExplorationType
from torchrl.modules.distributions.utils import _cast_transform_device
@@ -359,7 +367,7 @@ def sigma(self):
def forward(self, mu, state, _eps_gSDE):
sigma = self.sigma.clamp_max(self.scale_max)
- _err_explo = f"gSDE behaviour for exploration mode {exploration_type()} is not defined. Choose from 'random' or 'mode'."
+ _err_explo = f"gSDE behavior for exploration mode {exploration_type()} is not defined. Choose from 'random' or 'mode'."
if state.shape[:-1] != mu.shape[:-1]:
_err_msg = f"mu and state are expected to have matching batch size, got shapes {mu.shape} and {state.shape}"
@@ -520,3 +528,203 @@ def initialize_parameters(
)
self._sigma.materialize((action_dim, state_dim))
self._sigma.data.copy_(self.sigma_init.expand_as(self._sigma))
+
+
+class ConsistentDropout(_DropoutNd):
+ """Implements a :class:`~torch.nn.Dropout` variant with consistent dropout.
+
+ This method is proposed in `"Consistent Dropout for Policy Gradient Reinforcement Learning" (Hausknecht & Wagener, 2022)
+ `_.
+
+ This :class:`~torch.nn.Dropout` variant attempts to increase training stability and
+ reduce update variance by caching the dropout masks used during rollout
+ and reusing them during the update phase.
+
+ The class you are looking at is independent of the rest of TorchRL's API and can be used without tensordict.
+ :class:`~torchrl.modules.ConsistentDropoutModule` is a wrapper around ``ConsistentDropout`` that capitalizes on the extensibility
+ of ``TensorDict`` by storing generated dropout masks in the transition ``TensorDict`` itself.
+ See this class for a detailed explanation as well as usage examples.
+
+ There is otherwise little conceptual deviance from the PyTorch
+ :class:`~torch.nn.Dropout` implementation.
+
+ .. note:: TorchRL's data collectors perform rollouts in :meth:`~torch.no_grad` mode but not in `eval` mode,
+ so the dropout masks will be applied unless the policy passed to the collector is in eval mode.
+
+ .. note:: Unlike other exploration modules, :class:`~torchrl.modules.ConsistentDropoutModule`
+ uses the ``train``/``eval`` mode to comply with the regular `Dropout` API in PyTorch.
+ The :func:`~torchrl.envs.utils.set_exploration_type` context manager will have no effect on
+ this module.
+
+ Args:
+ p (float, optional): Dropout probability. Defaults to ``0.5``.
+
+ .. seealso::
+
+ - :class:`~torchrl.collectors.SyncDataCollector`:
+ :meth:`~torchrl.collectors.SyncDataCollector.rollout()` and :meth:`~torchrl.collectors.SyncDataCollector.iterator()`
+ - :class:`~torchrl.collectors.MultiSyncDataCollector`:
+ Uses :meth:`~torchrl.collectors.collectors._main_async_collector` (:class:`~torchrl.collectors.SyncDataCollector`)
+ under the hood
+ - :class:`~torchrl.collectors.MultiaSyncDataCollector`, :class:`~torchrl.collectors.aSyncDataCollector`: Ditto.
+
+ """
+
+ def __init__(self, p: float = 0.5):
+ super().__init__()
+ self.p = p
+
+ def forward(
+ self, x: torch.Tensor, mask: torch.Tensor | None = None
+ ) -> torch.Tensor:
+ """During training (rollouts & updates), this call masks a tensor full of ones before multiplying with the input tensor.
+
+ During evaluation, this call results in a no-op and only the input is returned.
+
+ Args:
+ x (torch.Tensor): the input tensor.
+ mask (torch.Tensor, optional): the optional mask for the dropout.
+
+ Returns: a tensor and a corresponding mask in train mode, and only a tensor in eval mode.
+ """
+ if self.training:
+ if mask is None:
+ mask = self.make_mask(input=x)
+ return x * mask, mask
+
+ return x
+
+ def make_mask(self, *, input=None, shape=None):
+ if input is not None:
+ return F.dropout(
+ torch.ones_like(input), self.p, self.training, inplace=False
+ )
+ elif shape is not None:
+ return F.dropout(torch.ones(shape), self.p, self.training, inplace=False)
+ else:
+ raise RuntimeError("input or shape must be passed to make_mask.")
+
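A minimal usage sketch of the class above (an illustration, not part of the diff); the import path is the file this hunk modifies:

```python
import torch
from torchrl.modules.models.exploration import ConsistentDropout

layer = ConsistentDropout(p=0.5)
layer.train()
x = torch.randn(4, 8)
y1, mask = layer(x)           # train mode samples a mask and returns it
y2, _ = layer(x, mask=mask)   # reusing the cached mask reproduces the same output
assert torch.allclose(y1, y2)
layer.eval()
assert torch.equal(layer(x), x)  # eval mode is a no-op returning the input
```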
+
+class ConsistentDropoutModule(TensorDictModuleBase):
+ """A TensorDictModule wrapper for :class:`~ConsistentDropout`.
+
+ Args:
+ p (float, optional): Dropout probability. Default: ``0.5``.
+ in_keys (NestedKey or list of NestedKeys): keys to be read
+ from input tensordict and passed to this module.
+ out_keys (NestedKey or iterable of NestedKeys): keys to be written to the input tensordict.
+ Defaults to ``in_keys`` values.
+
+ Keyword Args:
+ input_shape (tuple, optional): the shape of the input (non-batched), used to generate the
+ tensordict primers with :meth:`~.make_tensordict_primer`.
+ input_dtype (torch.dtype, optional): the dtype of the input for the primer. If none is passed,
+ ``torch.get_default_dtype`` is assumed.
+
+ .. note:: To use this class within a policy, one needs the mask to be reset at reset time.
+ This can be achieved through a :class:`~torchrl.envs.TensorDictPrimer` transform that can be obtained
+ with :meth:`~.make_tensordict_primer`. See this method for more information.
+
+ Examples:
+ >>> import torch
+ >>> from tensordict import TensorDict
+ >>> module = ConsistentDropoutModule(p=0.1, in_keys="x")
+ >>> td = TensorDict({"x": torch.randn(3, 4)}, [3])
+ >>> module(td)
+ TensorDict(
+ fields={
+ mask_6127171760: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.bool, is_shared=False),
+ x: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
+ batch_size=torch.Size([3]),
+ device=None,
+ is_shared=False)
+ """
+
+ def __init__(
+ self,
+ p: float,
+ in_keys: NestedKey | List[NestedKey],
+ out_keys: NestedKey | List[NestedKey] | None = None,
+ input_shape: torch.Size = None,
+ input_dtype: torch.dtype | None = None,
+ ):
+ if isinstance(in_keys, NestedKey):
+ in_keys = [in_keys, f"mask_{id(self)}"]
+ if out_keys is None:
+ out_keys = list(in_keys)
+ if isinstance(out_keys, NestedKey):
+ out_keys = [out_keys, f"mask_{id(self)}"]
+ if len(in_keys) != 2 or len(out_keys) != 2:
+ raise ValueError(
+ "in_keys and out_keys length must be 2 for consistent dropout."
+ )
+ self.in_keys = in_keys
+ self.out_keys = out_keys
+ self.input_shape = input_shape
+ self.input_dtype = input_dtype
+ super().__init__()
+
+ if not 0 <= p < 1:
+ raise ValueError(f"p must be in [0,1), got p={p: 4.4f}.")
+
+ self.consistent_dropout = ConsistentDropout(p)
+
+ def forward(self, tensordict):
+ x = tensordict.get(self.in_keys[0])
+ mask = tensordict.get(self.in_keys[1], default=None)
+ if self.consistent_dropout.training:
+ x, mask = self.consistent_dropout(x, mask=mask)
+ tensordict.set(self.out_keys[0], x)
+ tensordict.set(self.out_keys[1], mask)
+ else:
+ x = self.consistent_dropout(x, mask=mask)
+ tensordict.set(self.out_keys[0], x)
+
+ return tensordict
+
+ def make_tensordict_primer(self):
+ """Makes a tensordict primer for the environment to generate random masks during reset calls.
+
+ .. seealso:: :func:`torchrl.modules.utils.get_primers_from_module` for a method to generate all primers for a given
+ module.
+
+ Examples:
+ >>> import torch
+ >>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod
+ >>> from torchrl.envs import GymEnv, StepCounter, SerialEnv
+ >>> from torchrl.modules.utils import get_primers_from_module
+ >>> m = Seq(
+ ... Mod(torch.nn.Linear(7, 4), in_keys=["observation"], out_keys=["intermediate"]),
+ ... ConsistentDropoutModule(
+ ... p=0.5,
+ ... input_shape=(2, 4),
+ ... in_keys="intermediate",
+ ... ),
+ ... Mod(torch.nn.Linear(4, 7), in_keys=["intermediate"], out_keys=["action"]),
+ ... )
+ >>> primer = get_primers_from_module(m)
+ >>> env0 = GymEnv("Pendulum-v1").append_transform(StepCounter(5))
+ >>> env1 = GymEnv("Pendulum-v1").append_transform(StepCounter(6))
+ >>> env = SerialEnv(2, [lambda env=env0: env, lambda env=env1: env])
+ >>> env = env.append_transform(primer)
+ >>> r = env.rollout(10, m, break_when_any_done=False)
+ >>> mask = [k for k in r.keys() if k.startswith("mask")][0]
+ >>> assert (r[mask][0, :5] != r[mask][0, 5:6]).any()
+ >>> assert (r[mask][0, :4] == r[mask][0, 4:5]).all()
+
+ """
+ from torchrl.envs.transforms.transforms import TensorDictPrimer
+
+ shape = self.input_shape
+ dtype = self.input_dtype
+ if dtype is None:
+ dtype = torch.get_default_dtype()
+ if shape is None:
+ raise RuntimeError(
+ "Cannot infer the shape of the input automatically. "
+ "Please pass the shape of the tensor to `ConstistentDropoutModule` during construction "
+ "with the `input_shape` kwarg."
+ )
+ return TensorDictPrimer(
+ primers={self.in_keys[1]: Unbounded(dtype=dtype, shape=shape)},
+ default_value=functools.partial(
+ self.consistent_dropout.make_mask, shape=shape
+ ),
+ )
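A hedged end-to-end sketch of how ``ConsistentDropoutModule`` and its primer compose; the unbatched ``input_shape=(4,)`` for a single ``GymEnv`` is an assumption (the docstring above uses a batched shape for a ``SerialEnv``):

```python
import torch
from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq
from torchrl.envs import GymEnv, InitTracker, TransformedEnv
from torchrl.modules.models.exploration import ConsistentDropoutModule
from torchrl.modules.utils import get_primers_from_module

# Pendulum-v1 has a 3-dim observation and a 1-dim action.
policy = Seq(
    Mod(torch.nn.Linear(3, 4), in_keys=["observation"], out_keys=["hidden"]),
    ConsistentDropoutModule(p=0.5, input_shape=(4,), in_keys="hidden"),
    Mod(torch.nn.Linear(4, 1), in_keys=["hidden"], out_keys=["action"]),
)
env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker())
env = env.append_transform(get_primers_from_module(policy))  # fresh mask on reset
rollout = env.rollout(5, policy)  # the mask entry travels in the rollout tensordict
```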
diff --git a/torchrl/modules/models/models.py b/torchrl/modules/models/models.py
index 23c229c6524..3faaa396299 100644
--- a/torchrl/modules/models/models.py
+++ b/torchrl/modules/models/models.py
@@ -5,6 +5,7 @@
from __future__ import annotations
import dataclasses
+import warnings
from copy import deepcopy
from numbers import Number
@@ -179,8 +180,15 @@ def __init__(
if out_features is None:
raise ValueError("out_features must be specified for MLP.")
- default_num_cells = 32
if num_cells is None:
+ warnings.warn(
+ "The current behavior of MLP when not providing `num_cells` is that the number of cells is "
+ "set to [default_num_cells] * depth, where `depth=3` by default and `default_num_cells=0`. "
+ "From v0.7, this behavior will switch and `depth=0` will be used. "
+ "To silence tis message, indicate what number of cells you desire.",
+ category=DeprecationWarning,
+ )
+ default_num_cells = 32
if depth is None:
num_cells = [default_num_cells] * 3
depth = 3
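A small sketch of the default the new warning describes; the keyword names follow the ``MLP`` docstring:

```python
from torchrl.modules import MLP

# No num_cells: currently equivalent to num_cells=[32, 32, 32] (depth=3),
# and now emits the DeprecationWarning added above.
implicit = MLP(in_features=3, out_features=5)

# Passing num_cells explicitly silences the warning.
explicit = MLP(in_features=3, out_features=5, num_cells=[32, 32, 32])
```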
diff --git a/torchrl/modules/planners/cem.py b/torchrl/modules/planners/cem.py
index 6d9e6fb3b49..abc0e3d3f95 100644
--- a/torchrl/modules/planners/cem.py
+++ b/torchrl/modules/planners/cem.py
@@ -45,20 +45,20 @@ class CEMPlanner(MPCPlannerBase):
Examples:
>>> from tensordict import TensorDict
- >>> from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec
+ >>> from torchrl.data import Composite, Unbounded
>>> from torchrl.envs.model_based import ModelBasedEnvBase
>>> from torchrl.modules import SafeModule
>>> class MyMBEnv(ModelBasedEnvBase):
... def __init__(self, world_model, device="cpu", dtype=None, batch_size=None):
... super().__init__(world_model, device=device, dtype=dtype, batch_size=batch_size)
- ... self.state_spec = CompositeSpec(
- ... hidden_observation=UnboundedContinuousTensorSpec((4,))
+ ... self.state_spec = Composite(
+ ... hidden_observation=Unbounded((4,))
... )
- ... self.observation_spec = CompositeSpec(
- ... hidden_observation=UnboundedContinuousTensorSpec((4,))
+ ... self.observation_spec = Composite(
+ ... hidden_observation=Unbounded((4,))
... )
- ... self.action_spec = UnboundedContinuousTensorSpec((1,))
- ... self.reward_spec = UnboundedContinuousTensorSpec((1,))
+ ... self.action_spec = Unbounded((1,))
+ ... self.reward_spec = Unbounded((1,))
...
... def _reset(self, tensordict: TensorDict) -> TensorDict:
... tensordict = TensorDict(
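The doctest updates in this file and the ones below all apply the same spec renaming; summarized as a sketch:

```python
# New names on the left, pre-rename names in the comments (both per this diff):
from torchrl.data import Bounded, Composite, OneHot, Unbounded

spec = Composite(                       # was CompositeSpec
    action=Bounded(-1, 1, shape=(4,)),  # was BoundedTensorSpec
    logits=Unbounded((4,)),             # was UnboundedContinuousTensorSpec
    choice=OneHot(4),                   # was OneHotDiscreteTensorSpec
)
```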
diff --git a/torchrl/modules/planners/mppi.py b/torchrl/modules/planners/mppi.py
index 9c0bbc8f147..002094fb5d2 100644
--- a/torchrl/modules/planners/mppi.py
+++ b/torchrl/modules/planners/mppi.py
@@ -43,7 +43,7 @@ class MPPIPlanner(MPCPlannerBase):
Examples:
>>> from tensordict import TensorDict
- >>> from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec
+ >>> from torchrl.data import Composite, Unbounded
>>> from torchrl.envs.model_based import ModelBasedEnvBase
>>> from tensordict.nn import TensorDictModule
>>> from torchrl.modules import ValueOperator
@@ -51,14 +51,14 @@ class MPPIPlanner(MPCPlannerBase):
>>> class MyMBEnv(ModelBasedEnvBase):
... def __init__(self, world_model, device="cpu", dtype=None, batch_size=None):
... super().__init__(world_model, device=device, dtype=dtype, batch_size=batch_size)
- ... self.state_spec = CompositeSpec(
- ... hidden_observation=UnboundedContinuousTensorSpec((4,))
+ ... self.state_spec = Composite(
+ ... hidden_observation=Unbounded((4,))
... )
- ... self.observation_spec = CompositeSpec(
- ... hidden_observation=UnboundedContinuousTensorSpec((4,))
+ ... self.observation_spec = Composite(
+ ... hidden_observation=Unbounded((4,))
... )
- ... self.action_spec = UnboundedContinuousTensorSpec((1,))
- ... self.reward_spec = UnboundedContinuousTensorSpec((1,))
+ ... self.action_spec = Unbounded((1,))
+ ... self.reward_spec = Unbounded((1,))
...
... def _reset(self, tensordict: TensorDict) -> TensorDict:
... tensordict = TensorDict(
diff --git a/torchrl/modules/tensordict_module/actors.py b/torchrl/modules/tensordict_module/actors.py
index 81b7ec1e605..003c35cf0eb 100644
--- a/torchrl/modules/tensordict_module/actors.py
+++ b/torchrl/modules/tensordict_module/actors.py
@@ -22,7 +22,7 @@
from torch.distributions import Categorical
from torchrl._utils import _replace_last
-from torchrl.data.tensor_specs import CompositeSpec, TensorSpec
+from torchrl.data.tensor_specs import Composite, TensorSpec
from torchrl.data.utils import _process_action_space_spec
from torchrl.modules.tensordict_module.common import DistributionalDQNnet, SafeModule
from torchrl.modules.tensordict_module.probabilistic import (
@@ -37,8 +37,8 @@ class Actor(SafeModule):
The Actor class comes with default values for the out_keys (``["action"]``)
and if the spec is provided but not as a
- :class:`~torchrl.data.CompositeSpec` object, it will be
- automatically translated into ``spec = CompositeSpec(action=spec)``.
+ :class:`~torchrl.data.Composite` object, it will be
+ automatically translated into ``spec = Composite(action=spec)``.
Args:
module (nn.Module): a :class:`~torch.nn.Module` used to map the input to
@@ -70,11 +70,11 @@ class Actor(SafeModule):
Examples:
>>> import torch
>>> from tensordict import TensorDict
- >>> from torchrl.data import UnboundedContinuousTensorSpec
+ >>> from torchrl.data import Unbounded
>>> from torchrl.modules import Actor
>>> torch.manual_seed(0)
>>> td = TensorDict({"observation": torch.randn(3, 4)}, [3,])
- >>> action_spec = UnboundedContinuousTensorSpec(4)
+ >>> action_spec = Unbounded(4)
>>> module = torch.nn.Linear(4, 4)
>>> td_module = Actor(
... module=module,
@@ -111,9 +111,9 @@ def __init__(
if (
"action" in out_keys
and spec is not None
- and not isinstance(spec, CompositeSpec)
+ and not isinstance(spec, Composite)
):
- spec = CompositeSpec(action=spec)
+ spec = Composite(action=spec)
super().__init__(
module,
@@ -128,8 +128,8 @@ class ProbabilisticActor(SafeProbabilisticTensorDictSequential):
"""General class for probabilistic actors in RL.
The Actor class comes with default values for the out_keys (["action"])
- and if the spec is provided but not as a CompositeSpec object, it will be
- automatically translated into :obj:`spec = CompositeSpec(action=spec)`
+ and if the spec is provided but not as a Composite object, it will be
+ automatically translated into :obj:`spec = Composite(action=spec)`
Args:
module (nn.Module): a :class:`torch.nn.Module` used to map the input to
@@ -205,10 +205,10 @@ class ProbabilisticActor(SafeProbabilisticTensorDictSequential):
>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
- >>> from torchrl.data import BoundedTensorSpec
+ >>> from torchrl.data import Bounded
>>> from torchrl.modules import ProbabilisticActor, NormalParamExtractor, TanhNormal
>>> td = TensorDict({"observation": torch.randn(3, 4)}, [3,])
- >>> action_spec = BoundedTensorSpec(shape=torch.Size([4]),
+ >>> action_spec = Bounded(shape=torch.Size([4]),
... low=-1, high=1)
>>> module = nn.Sequential(torch.nn.Linear(4, 8), NormalParamExtractor())
>>> tensordict_module = TensorDictModule(module, in_keys=["observation"], out_keys=["loc", "scale"])
@@ -382,12 +382,8 @@ def __init__(
out_keys = list(distribution_map.keys())
else:
out_keys = ["action"]
- if (
- len(out_keys) == 1
- and spec is not None
- and not isinstance(spec, CompositeSpec)
- ):
- spec = CompositeSpec({out_keys[0]: spec})
+ if len(out_keys) == 1 and spec is not None and not isinstance(spec, Composite):
+ spec = Composite({out_keys[0]: spec})
super().__init__(
module,
@@ -424,7 +420,7 @@ class ValueOperator(TensorDictModule):
>>> import torch
>>> from tensordict import TensorDict
>>> from torch import nn
- >>> from torchrl.data import UnboundedContinuousTensorSpec
+ >>> from torchrl.data import Unbounded
>>> from torchrl.modules import ValueOperator
>>> td = TensorDict({"observation": torch.randn(3, 4), "action": torch.randn(3, 2)}, [3,])
>>> class CustomModule(nn.Module):
@@ -577,22 +573,22 @@ def __init__(
)
self.out_keys = out_keys
action_key = out_keys[0]
- if not isinstance(spec, CompositeSpec):
- spec = CompositeSpec({action_key: spec})
+ if not isinstance(spec, Composite):
+ spec = Composite({action_key: spec})
super().__init__()
self.register_spec(safe=safe, spec=spec)
register_spec = SafeModule.register_spec
@property
- def spec(self) -> CompositeSpec:
+ def spec(self) -> Composite:
return self._spec
@spec.setter
- def spec(self, spec: CompositeSpec) -> None:
- if not isinstance(spec, CompositeSpec):
+ def spec(self, spec: Composite) -> None:
+ if not isinstance(spec, Composite):
raise RuntimeError(
- f"Trying to set an object of type {type(spec)} as a tensorspec but expected a CompositeSpec instance."
+ f"Trying to set an object of type {type(spec)} as a tensorspec but expected a Composite instance."
)
self._spec = spec
@@ -891,13 +887,13 @@ class QValueHook:
>>> import torch
>>> from tensordict import TensorDict
>>> from torch import nn
- >>> from torchrl.data import OneHotDiscreteTensorSpec
+ >>> from torchrl.data import OneHot
>>> from torchrl.modules.tensordict_module.actors import QValueHook, Actor
>>> td = TensorDict({'observation': torch.randn(5, 4)}, [5])
>>> module = nn.Linear(4, 4)
>>> hook = QValueHook("one_hot")
>>> module.register_forward_hook(hook)
- >>> action_spec = OneHotDiscreteTensorSpec(4)
+ >>> action_spec = OneHot(4)
>>> qvalue_actor = Actor(module=module, spec=action_spec, out_keys=["action", "action_value"])
>>> td = qvalue_actor(td)
>>> print(td)
@@ -975,7 +971,7 @@ class DistributionalQValueHook(QValueHook):
>>> import torch
>>> from tensordict import TensorDict
>>> from torch import nn
- >>> from torchrl.data import OneHotDiscreteTensorSpec
+ >>> from torchrl.data import OneHot
>>> from torchrl.modules.tensordict_module.actors import DistributionalQValueHook, Actor
>>> td = TensorDict({'observation': torch.randn(5, 4)}, [5])
>>> nbins = 3
@@ -989,7 +985,7 @@ class DistributionalQValueHook(QValueHook):
...
>>> module = CustomDistributionalQval()
>>> params = TensorDict.from_module(module)
- >>> action_spec = OneHotDiscreteTensorSpec(4)
+ >>> action_spec = OneHot(4)
>>> hook = DistributionalQValueHook("one_hot", support = torch.arange(nbins))
>>> module.register_forward_hook(hook)
>>> qvalue_actor = Actor(module=module, spec=action_spec, out_keys=["action", "action_value"])
@@ -1085,12 +1081,12 @@ class QValueActor(SafeSequential):
>>> import torch
>>> from tensordict import TensorDict
>>> from torch import nn
- >>> from torchrl.data import OneHotDiscreteTensorSpec
+ >>> from torchrl.data import OneHot
>>> from torchrl.modules.tensordict_module.actors import QValueActor
>>> td = TensorDict({'observation': torch.randn(5, 4)}, [5])
>>> # with a regular nn.Module
>>> module = nn.Linear(4, 4)
- >>> action_spec = OneHotDiscreteTensorSpec(4)
+ >>> action_spec = OneHot(4)
>>> qvalue_actor = QValueActor(module=module, spec=action_spec)
>>> td = qvalue_actor(td)
>>> print(td)
@@ -1106,7 +1102,7 @@ class QValueActor(SafeSequential):
>>> # with a TensorDictModule
>>> td = TensorDict({'obs': torch.randn(5, 4)}, [5])
>>> module = TensorDictModule(lambda x: x, in_keys=["obs"], out_keys=["action_value"])
- >>> action_spec = OneHotDiscreteTensorSpec(4)
+ >>> action_spec = OneHot(4)
>>> qvalue_actor = QValueActor(module=module, spec=action_spec)
>>> td = qvalue_actor(td)
>>> print(td)
@@ -1161,13 +1157,13 @@ def __init__(
module, in_keys=in_keys, out_keys=[action_value_key]
)
if spec is None:
- spec = CompositeSpec()
- if isinstance(spec, CompositeSpec):
+ spec = Composite()
+ if isinstance(spec, Composite):
spec = spec.clone()
if "action" not in spec.keys():
spec["action"] = None
else:
- spec = CompositeSpec(action=spec, shape=spec.shape[:-1])
+ spec = Composite(action=spec, shape=spec.shape[:-1])
spec[action_value_key] = None
spec["chosen_action_value"] = None
qvalue = QValueModule(
@@ -1237,7 +1233,7 @@ class DistributionalQValueActor(QValueActor):
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule, TensorDictSequential
>>> from torch import nn
- >>> from torchrl.data import OneHotDiscreteTensorSpec
+ >>> from torchrl.data import OneHot
>>> from torchrl.modules import DistributionalQValueActor, MLP
>>> td = TensorDict({'observation': torch.randn(5, 4)}, [5])
>>> nbins = 3
@@ -1247,7 +1243,7 @@ class DistributionalQValueActor(QValueActor):
... TensorDictModule(module, ["observation"], ["action_value"]),
... TensorDictModule(lambda x: x.log_softmax(-2), ["action_value"], ["action_value"]),
... )
- >>> action_spec = OneHotDiscreteTensorSpec(4)
+ >>> action_spec = OneHot(4)
>>> qvalue_actor = DistributionalQValueActor(
... module=module,
... spec=action_spec,
@@ -1299,13 +1295,13 @@ def __init__(
module, in_keys=in_keys, out_keys=[action_value_key]
)
if spec is None:
- spec = CompositeSpec()
- if isinstance(spec, CompositeSpec):
+ spec = Composite()
+ if isinstance(spec, Composite):
spec = spec.clone()
if "action" not in spec.keys():
spec["action"] = None
else:
- spec = CompositeSpec(action=spec, shape=spec.shape[:-1])
+ spec = Composite(action=spec, shape=spec.shape[:-1])
spec[action_value_key] = None
qvalue = DistributionalQValueModule(
@@ -1848,8 +1844,8 @@ def __init__(
self.return_to_go_key = "return_to_go"
self.inference_context = inference_context
if spec is not None:
- if not isinstance(spec, CompositeSpec) and len(self.out_keys) >= 1:
- spec = CompositeSpec({self.action_key: spec}, shape=spec.shape[:-1])
+ if not isinstance(spec, Composite) and len(self.out_keys) >= 1:
+ spec = Composite({self.action_key: spec}, shape=spec.shape[:-1])
self._spec = spec
elif hasattr(self.td_module, "_spec"):
self._spec = self.td_module._spec.clone()
@@ -1860,7 +1856,7 @@ def __init__(
if self.action_key not in self._spec.keys():
self._spec[self.action_key] = None
else:
- self._spec = CompositeSpec({key: None for key in policy.out_keys})
+ self._spec = Composite({key: None for key in policy.out_keys})
self.checked = False
@property
@@ -1989,7 +1985,7 @@ class TanhModule(TensorDictModuleBase):
Keyword Args:
spec (TensorSpec, optional): if provided, the spec of the output.
- If a CompositeSpec is provided, its key(s) must match the key(s)
+ If a Composite is provided, its key(s) must match the key(s)
in out_keys. Otherwise, the key(s) of out_keys are assumed and the
same spec is used for all outputs.
low (float, np.ndarray or torch.Tensor): the lower bound of the space.
@@ -2027,8 +2023,8 @@ class TanhModule(TensorDictModuleBase):
>>> data['action']
tensor([-2.0000, 0.9991, 1.0000, -2.0000, -1.9991])
>>> # A spec can be provided
- >>> from torchrl.data import BoundedTensorSpec
- >>> spec = BoundedTensorSpec(low, high, shape=())
+ >>> from torchrl.data import Bounded
+ >>> spec = Bounded(low, high, shape=())
>>> mod = TanhModule(
... in_keys=in_keys,
... low=low,
@@ -2038,9 +2034,9 @@ class TanhModule(TensorDictModuleBase):
... )
>>> # One can also work with multiple keys
>>> in_keys = ['a', 'b']
- >>> spec = CompositeSpec(
- ... a=BoundedTensorSpec(-3, 0, shape=()),
- ... b=BoundedTensorSpec(0, 3, shape=()))
+ >>> spec = Composite(
+ ... a=Bounded(-3, 0, shape=()),
+ ... b=Bounded(0, 3, shape=()))
>>> mod = TanhModule(
... in_keys=in_keys,
... spec=spec,
@@ -2077,13 +2073,13 @@ def __init__(
)
self.out_keys = out_keys
# action_spec can be a composite spec or not
- if isinstance(spec, CompositeSpec):
+ if isinstance(spec, Composite):
for out_key in self.out_keys:
if out_key not in spec.keys(True, True):
spec[out_key] = None
else:
# if one spec is present, we assume it is the same for all keys
- spec = CompositeSpec(
+ spec = Composite(
{out_key: spec for out_key in out_keys},
)
diff --git a/torchrl/modules/tensordict_module/common.py b/torchrl/modules/tensordict_module/common.py
index 11cc363b461..4018589bfa1 100644
--- a/torchrl/modules/tensordict_module/common.py
+++ b/torchrl/modules/tensordict_module/common.py
@@ -21,7 +21,7 @@
from torch import nn
from torch.nn import functional as F
-from torchrl.data.tensor_specs import CompositeSpec, TensorSpec
+from torchrl.data.tensor_specs import Composite, TensorSpec
from torchrl.data.utils import DEVICE_TYPING
@@ -59,12 +59,12 @@ def _check_all_str(list_of_str, first_level=True):
def _forward_hook_safe_action(module, tensordict_in, tensordict_out):
try:
spec = module.spec
- if len(module.out_keys) > 1 and not isinstance(spec, CompositeSpec):
+ if len(module.out_keys) > 1 and not isinstance(spec, Composite):
raise RuntimeError(
- "safe TensorDictModules with multiple out_keys require a CompositeSpec with matching keys. Got "
+ "safe TensorDictModules with multiple out_keys require a Composite with matching keys. Got "
f"keys {module.out_keys}."
)
- elif not isinstance(spec, CompositeSpec):
+ elif not isinstance(spec, Composite):
out_key = module.out_keys[0]
keys = [out_key]
values = [spec]
@@ -138,10 +138,10 @@ class SafeModule(TensorDictModule):
Examples:
>>> import torch
>>> from tensordict import TensorDict
- >>> from torchrl.data import UnboundedContinuousTensorSpec
+ >>> from torchrl.data import Unbounded
>>> from torchrl.modules import TensorDictModule
>>> td = TensorDict({"input": torch.randn(3, 4), "hidden": torch.randn(3, 8)}, [3,])
- >>> spec = UnboundedContinuousTensorSpec(8)
+ >>> spec = Unbounded(8)
>>> module = torch.nn.GRUCell(4, 8)
>>> td_fmodule = TensorDictModule(
... module=module,
@@ -216,18 +216,18 @@ def register_spec(self, safe, spec):
spec = spec.clone()
if spec is not None and not isinstance(spec, TensorSpec):
raise TypeError("spec must be a TensorSpec subclass")
- elif spec is not None and not isinstance(spec, CompositeSpec):
+ elif spec is not None and not isinstance(spec, Composite):
if len(self.out_keys) > 1:
raise RuntimeError(
f"got more than one out_key for the TensorDictModule: {self.out_keys},\nbut only one spec. "
- "Consider using a CompositeSpec object or no spec at all."
+ "Consider using a Composite object or no spec at all."
)
- spec = CompositeSpec({self.out_keys[0]: spec})
- elif spec is not None and isinstance(spec, CompositeSpec):
+ spec = Composite({self.out_keys[0]: spec})
+ elif spec is not None and isinstance(spec, Composite):
if "_" in spec.keys() and spec["_"] is not None:
warnings.warn('got a spec with key "_": it will be ignored')
elif spec is None:
- spec = CompositeSpec()
+ spec = Composite()
# unravel_key_list(self.out_keys) can be removed once 473 is merged in tensordict
spec_keys = set(unravel_key_list(list(spec.keys(True, True))))
@@ -247,7 +247,7 @@ def register_spec(self, safe, spec):
self.safe = safe
if safe:
if spec is None or (
- isinstance(spec, CompositeSpec)
+ isinstance(spec, Composite)
and all(_spec is None for _spec in spec.values())
):
raise RuntimeError(
@@ -257,14 +257,14 @@ def register_spec(self, safe, spec):
self.register_forward_hook(_forward_hook_safe_action)
@property
- def spec(self) -> CompositeSpec:
+ def spec(self) -> Composite:
return self._spec
@spec.setter
- def spec(self, spec: CompositeSpec) -> None:
- if not isinstance(spec, CompositeSpec):
+ def spec(self, spec: Composite) -> None:
+ if not isinstance(spec, Composite):
raise RuntimeError(
- f"Trying to set an object of type {type(spec)} as a tensorspec but expected a CompositeSpec instance."
+ f"Trying to set an object of type {type(spec)} as a tensorspec but expected a Composite instance."
)
self._spec = spec
@@ -350,7 +350,7 @@ def is_tensordict_compatible(module: Union[TensorDictModule, nn.Module]):
# if in_keys or out_keys were defined but module is not a TensorDictModule or
# accepts multiple arguments then it's likely the user is trying to do something
- # that will have undetermined behaviour, we raise an error
+ # that will have undetermined behavior, we raise an error
raise TypeError(
"Received a module that defines in_keys or out_keys and also expects multiple "
"arguments to module.forward. If the module is compatible with TensorDict, it "
@@ -403,7 +403,7 @@ def ensure_tensordict_compatible(
"env.observation_spec. If you want TorchRL to automatically "
"wrap your module with a TensorDictModule then the arguments "
"to module must correspond one-to-one with entries in "
- "in_keys. For more complex behaviour and more control you can "
+ "in_keys. For more complex behavior and more control you can "
"consider writing your own TensorDictModule."
)
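A minimal sketch of the safe-projection path these ``Composite`` checks guard, with assumed toy shapes:

```python
import torch
from tensordict import TensorDict
from torchrl.data import Bounded, Composite
from torchrl.modules import SafeModule

module = SafeModule(
    torch.nn.Linear(3, 4),
    in_keys=["observation"],
    out_keys=["action"],
    spec=Composite(action=Bounded(-1, 1, shape=(4,))),
    safe=True,  # registers _forward_hook_safe_action to project out-of-bounds outputs
)
td = module(TensorDict({"observation": torch.randn(3)}, []))
assert (td["action"].abs() <= 1).all()
```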
diff --git a/torchrl/modules/tensordict_module/exploration.py b/torchrl/modules/tensordict_module/exploration.py
index 5a41f11bf76..7337d1c94dd 100644
--- a/torchrl/modules/tensordict_module/exploration.py
+++ b/torchrl/modules/tensordict_module/exploration.py
@@ -16,7 +16,7 @@
)
from tensordict.utils import expand_as_right, expand_right, NestedKey
-from torchrl.data.tensor_specs import CompositeSpec, TensorSpec
+from torchrl.data.tensor_specs import Composite, TensorSpec
from torchrl.envs.utils import exploration_type, ExplorationType
from torchrl.modules.tensordict_module.common import _forward_hook_safe_action
@@ -64,9 +64,9 @@ class EGreedyModule(TensorDictModuleBase):
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictSequential
>>> from torchrl.modules import EGreedyModule, Actor
- >>> from torchrl.data import BoundedTensorSpec
+ >>> from torchrl.data import Bounded
>>> torch.manual_seed(0)
- >>> spec = BoundedTensorSpec(-1, 1, torch.Size([4]))
+ >>> spec = Bounded(-1, 1, torch.Size([4]))
>>> module = torch.nn.Linear(4, 4, bias=False)
>>> policy = Actor(spec=spec, module=module)
>>> explorative_policy = TensorDictSequential(policy, EGreedyModule(eps_init=0.2))
@@ -115,8 +115,8 @@ def __init__(
self.register_buffer("eps", torch.as_tensor([eps_init], dtype=torch.float32))
if spec is not None:
- if not isinstance(spec, CompositeSpec) and len(self.out_keys) >= 1:
- spec = CompositeSpec({action_key: spec}, shape=spec.shape[:-1])
+ if not isinstance(spec, Composite) and len(self.out_keys) >= 1:
+ spec = Composite({action_key: spec}, shape=spec.shape[:-1])
self._spec = spec
@property
@@ -155,7 +155,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
cond = expand_as_right(cond, out)
spec = self.spec
if spec is not None:
- if isinstance(spec, CompositeSpec):
+ if isinstance(spec, Composite):
spec = spec[self.action_key]
if spec.shape != out.shape:
# In batched envs if the spec is passed unbatched, the rand() will not
@@ -214,9 +214,9 @@ class EGreedyWrapper(TensorDictModuleWrapper):
>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.modules import EGreedyWrapper, Actor
- >>> from torchrl.data import BoundedTensorSpec
+ >>> from torchrl.data import Bounded
>>> torch.manual_seed(0)
- >>> spec = BoundedTensorSpec(-1, 1, torch.Size([4]))
+ >>> spec = Bounded(-1, 1, torch.Size([4]))
>>> module = torch.nn.Linear(4, 4, bias=False)
>>> policy = Actor(spec=spec, module=module)
>>> explorative_policy = EGreedyWrapper(policy, eps_init=0.2)
@@ -267,7 +267,7 @@ class AdditiveGaussianWrapper(TensorDictModuleWrapper):
mean (float, optional): mean of each output element’s normal distribution.
std (float, optional): standard deviation of each output element’s normal distribution.
action_key (NestedKey, optional): if the policy module has more than one output key,
- its output spec will be of type CompositeSpec. One needs to know where to
+ its output spec will be of type Composite. One needs to know where to
find the action spec.
Default is "action".
spec (TensorSpec, optional): if provided, the sampled action will be
@@ -323,8 +323,8 @@ def __init__(
f"The action key {action_key} was not found in the td_module out_keys {self.td_module.out_keys}."
)
if spec is not None:
- if not isinstance(spec, CompositeSpec) and len(self.out_keys) >= 1:
- spec = CompositeSpec({action_key: spec}, shape=spec.shape[:-1])
+ if not isinstance(spec, Composite) and len(self.out_keys) >= 1:
+ spec = Composite({action_key: spec}, shape=spec.shape[:-1])
self._spec = spec
elif hasattr(self.td_module, "_spec"):
self._spec = self.td_module._spec.clone()
@@ -335,7 +335,7 @@ def __init__(
if action_key not in self._spec.keys(True, True):
self._spec[action_key] = None
else:
- self._spec = CompositeSpec({key: None for key in policy.out_keys})
+ self._spec = Composite({key: None for key in policy.out_keys})
self.safe = safe
if self.safe:
@@ -410,7 +410,7 @@ class AdditiveGaussianModule(TensorDictModuleBase):
Keyword Args:
action_key (NestedKey, optional): if the policy module has more than one output key,
- its output spec will be of type CompositeSpec. One needs to know where to
+ its output spec will be of type Composite. One needs to know where to
find the action spec.
default: "action"
@@ -453,8 +453,8 @@ def __init__(
self.register_buffer("sigma", torch.tensor([sigma_init], dtype=torch.float32))
if spec is not None:
- if not isinstance(spec, CompositeSpec) and len(self.out_keys) >= 1:
- spec = CompositeSpec({action_key: spec}, shape=spec.shape[:-1])
+ if not isinstance(spec, Composite) and len(self.out_keys) >= 1:
+ spec = Composite({action_key: spec}, shape=spec.shape[:-1])
else:
raise RuntimeError("spec cannot be None.")
self._spec = spec
@@ -570,10 +570,10 @@ class OrnsteinUhlenbeckProcessWrapper(TensorDictModuleWrapper):
Examples:
>>> import torch
>>> from tensordict import TensorDict
- >>> from torchrl.data import BoundedTensorSpec
+ >>> from torchrl.data import Bounded
>>> from torchrl.modules import OrnsteinUhlenbeckProcessWrapper, Actor
>>> torch.manual_seed(0)
- >>> spec = BoundedTensorSpec(-1, 1, torch.Size([4]))
+ >>> spec = Bounded(-1, 1, torch.Size([4]))
>>> module = torch.nn.Linear(4, 4, bias=False)
>>> policy = Actor(module=module, spec=spec)
>>> explorative_policy = OrnsteinUhlenbeckProcessWrapper(policy)
@@ -647,8 +647,8 @@ def __init__(
steps_key = self.ou.steps_key
if spec is not None:
- if not isinstance(spec, CompositeSpec) and len(self.out_keys) >= 1:
- spec = CompositeSpec({action_key: spec}, shape=spec.shape[:-1])
+ if not isinstance(spec, Composite) and len(self.out_keys) >= 1:
+ spec = Composite({action_key: spec}, shape=spec.shape[:-1])
self._spec = spec
elif hasattr(self.td_module, "_spec"):
self._spec = self.td_module._spec.clone()
@@ -659,7 +659,7 @@ def __init__(
if action_key not in self._spec.keys(True, True):
self._spec[action_key] = None
else:
- self._spec = CompositeSpec({key: None for key in policy.out_keys})
+ self._spec = Composite({key: None for key in policy.out_keys})
ou_specs = {
noise_key: None,
steps_key: None,
@@ -707,7 +707,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
f"The tensordict passed to {self.__class__.__name__} appears to be "
f"missing the '{self.is_init_key}' entry. This entry is used to "
f"reset the noise at the beginning of a trajectory, without it "
- f"the behaviour of this exploration method is undefined. "
+ f"the behavior of this exploration method is undefined. "
f"This is allowed for BC compatibility purposes but it will be deprecated soon! "
f"To create a '{self.is_init_key}' entry, simply append an torchrl.envs.InitTracker "
f"transform to your environment with `env = TransformedEnv(env, InitTracker())`."
@@ -783,10 +783,10 @@ class OrnsteinUhlenbeckProcessModule(TensorDictModuleBase):
>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictSequential
- >>> from torchrl.data import BoundedTensorSpec
+ >>> from torchrl.data import Bounded
>>> from torchrl.modules import OrnsteinUhlenbeckProcessModule, Actor
>>> torch.manual_seed(0)
- >>> spec = BoundedTensorSpec(-1, 1, torch.Size([4]))
+ >>> spec = Bounded(-1, 1, torch.Size([4]))
>>> module = torch.nn.Linear(4, 4, bias=False)
>>> policy = Actor(module=module, spec=spec)
>>> ou = OrnsteinUhlenbeckProcessModule(spec=spec)
@@ -851,8 +851,8 @@ def __init__(
steps_key = self.ou.steps_key
if spec is not None:
- if not isinstance(spec, CompositeSpec) and len(self.out_keys) >= 1:
- spec = CompositeSpec({action_key: spec}, shape=spec.shape[:-1])
+ if not isinstance(spec, Composite) and len(self.out_keys) >= 1:
+ spec = Composite({action_key: spec}, shape=spec.shape[:-1])
self._spec = spec
else:
raise RuntimeError("spec cannot be None.")
@@ -900,7 +900,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
f"The tensordict passed to {self.__class__.__name__} appears to be "
f"missing the '{self.is_init_key}' entry. This entry is used to "
f"reset the noise at the beginning of a trajectory, without it "
- f"the behaviour of this exploration method is undefined. "
+ f"the behavior of this exploration method is undefined. "
f"This is allowed for BC compatibility purposes but it will be deprecated soon! "
f"To create a '{self.is_init_key}' entry, simply append an torchrl.envs.InitTracker "
f"transform to your environment with `env = TransformedEnv(env, InitTracker())`."
diff --git a/torchrl/modules/tensordict_module/probabilistic.py b/torchrl/modules/tensordict_module/probabilistic.py
index 725323e1a28..483d9b90eea 100644
--- a/torchrl/modules/tensordict_module/probabilistic.py
+++ b/torchrl/modules/tensordict_module/probabilistic.py
@@ -15,7 +15,7 @@
TensorDictModule,
)
from tensordict.utils import NestedKey
-from torchrl.data.tensor_specs import CompositeSpec, TensorSpec
+from torchrl.data.tensor_specs import Composite, TensorSpec
from torchrl.modules.distributions import Delta
from torchrl.modules.tensordict_module.common import _forward_hook_safe_action
from torchrl.modules.tensordict_module.sequence import SafeSequential
@@ -104,7 +104,6 @@ def __init__(
out_keys: Optional[Union[NestedKey, List[NestedKey]]] = None,
spec: Optional[TensorSpec] = None,
safe: bool = False,
- default_interaction_mode: str = None,
default_interaction_type: str = InteractionType.DETERMINISTIC,
distribution_class: Type = Delta,
distribution_kwargs: Optional[dict] = None,
@@ -117,7 +116,6 @@ def __init__(
in_keys=in_keys,
out_keys=out_keys,
default_interaction_type=default_interaction_type,
- default_interaction_mode=default_interaction_mode,
distribution_class=distribution_class,
distribution_kwargs=distribution_kwargs,
return_log_prob=return_log_prob,
@@ -129,18 +127,18 @@ def __init__(
spec = spec.clone()
if spec is not None and not isinstance(spec, TensorSpec):
raise TypeError("spec must be a TensorSpec subclass")
- elif spec is not None and not isinstance(spec, CompositeSpec):
+ elif spec is not None and not isinstance(spec, Composite):
if len(self.out_keys) > 1:
raise RuntimeError(
f"got more than one out_key for the SafeModule: {self.out_keys},\nbut only one spec. "
- "Consider using a CompositeSpec object or no spec at all."
+ "Consider using a Composite object or no spec at all."
)
- spec = CompositeSpec({self.out_keys[0]: spec})
- elif spec is not None and isinstance(spec, CompositeSpec):
+ spec = Composite({self.out_keys[0]: spec})
+ elif spec is not None and isinstance(spec, Composite):
if "_" in spec.keys():
warnings.warn('got a spec with key "_": it will be ignored')
elif spec is None:
- spec = CompositeSpec()
+ spec = Composite()
spec_keys = set(unravel_key_list(list(spec.keys(True, True))))
out_keys = set(unravel_key_list(self.out_keys))
if spec_keys != out_keys:
@@ -159,7 +157,7 @@ def __init__(
self.safe = safe
if safe:
if spec is None or (
- isinstance(spec, CompositeSpec)
+ isinstance(spec, Composite)
and all(_spec is None for _spec in spec.values())
):
raise RuntimeError(
@@ -169,14 +167,14 @@ def __init__(
self.register_forward_hook(_forward_hook_safe_action)
@property
- def spec(self) -> CompositeSpec:
+ def spec(self) -> Composite:
return self._spec
@spec.setter
- def spec(self, spec: CompositeSpec) -> None:
- if not isinstance(spec, CompositeSpec):
+ def spec(self, spec: Composite) -> None:
+ if not isinstance(spec, Composite):
raise RuntimeError(
- f"Trying to set an object of type {type(spec)} as a tensorspec but expected a CompositeSpec instance."
+ f"Trying to set an object of type {type(spec)} as a tensorspec but expected a Composite instance."
)
self._spec = spec
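With ``default_interaction_mode`` removed, only ``default_interaction_type`` remains; a short sketch of the surviving kwarg (the ``TanhNormal`` choice is illustrative):

```python
from tensordict.nn import InteractionType
from torchrl.modules import SafeProbabilisticModule, TanhNormal

# `default_interaction_mode` is gone; pass `default_interaction_type` instead.
module = SafeProbabilisticModule(
    in_keys=["loc", "scale"],
    out_keys=["action"],
    distribution_class=TanhNormal,
    default_interaction_type=InteractionType.DETERMINISTIC,
)
```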
diff --git a/torchrl/modules/tensordict_module/rnn.py b/torchrl/modules/tensordict_module/rnn.py
index 878fb13ebb8..f538f8e95c5 100644
--- a/torchrl/modules/tensordict_module/rnn.py
+++ b/torchrl/modules/tensordict_module/rnn.py
@@ -2,6 +2,8 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
from typing import Optional, Tuple
import torch
@@ -16,7 +18,7 @@
from torch import nn, Tensor
from torch.nn.modules.rnn import RNNCellBase
-from torchrl.data.tensor_specs import UnboundedContinuousTensorSpec
+from torchrl.data.tensor_specs import Unbounded
from torchrl.objectives.value.functional import (
_inv_pad_sequence,
_split_and_pad_sequence,
@@ -387,7 +389,7 @@ class LSTMModule(ModuleBase):
.. note:: This module relies on specific ``recurrent_state`` keys being present in the input
TensorDicts. To generate a :class:`~torchrl.envs.transforms.TensorDictPrimer` transform that will automatically
add hidden states to the environment TensorDicts, use the method :func:`~torchrl.modules.rnn.LSTMModule.make_tensordict_primer`.
- If this class is a submodule in a larger module, the method :func:`~torchrl.models.utils.get_primers_from_module` can be called
+ If this class is a submodule in a larger module, the method :func:`~torchrl.modules.utils.get_primers_from_module` can be called
on the parent module to automatically generate the primer transforms required for all submodules, including this one.
@@ -529,11 +531,14 @@ def make_tensordict_primer(self):
inputs and outputs (recurrent states) during rollout execution. That way, the data can be shared across
processes and dealt with properly.
- Not including a ``TensorDictPrimer`` in the environment may result in poorly defined behaviours, for instance
+ Not including a ``TensorDictPrimer`` in the environment may result in poorly defined behaviors, for instance
in parallel settings where a step involves copying the new recurrent state from ``"next"`` to the root
tensordict, which the :meth:`~torchrl.EnvBase.step_mdp` method will not be able to do as the recurrent states
are not registered within the environment specs.
+ See :func:`torchrl.modules.utils.get_primers_from_module` for a method to generate all primers for a given
+ module.
+
Examples:
>>> from torchrl.collectors import SyncDataCollector
>>> from torchrl.envs import TransformedEnv, InitTracker
@@ -581,12 +586,8 @@ def make_tuple(key):
)
return TensorDictPrimer(
{
- in_key1: UnboundedContinuousTensorSpec(
- shape=(self.lstm.num_layers, self.lstm.hidden_size)
- ),
- in_key2: UnboundedContinuousTensorSpec(
- shape=(self.lstm.num_layers, self.lstm.hidden_size)
- ),
+ in_key1: Unbounded(shape=(self.lstm.num_layers, self.lstm.hidden_size)),
+ in_key2: Unbounded(shape=(self.lstm.num_layers, self.lstm.hidden_size)),
}
)
@@ -609,7 +610,7 @@ def temporal_mode(self):
def set_recurrent_mode(self, mode: bool = True):
"""Returns a new copy of the module that shares the same lstm model but with a different ``recurrent_mode`` attribute (if it differs).
- A copy is created such that the module can be used with divergent behaviour
+ A copy is created such that the module can be used with divergent behavior
in various parts of the code (inference vs training):
Examples:
@@ -623,7 +624,7 @@ def set_recurrent_mode(self, mode: bool = True):
>>> lstm = nn.LSTM(input_size=env.observation_spec["observation"].shape[-1], hidden_size=64, batch_first=True)
>>> lstm_module = LSTMModule(lstm=lstm, in_keys=["observation", "hidden0", "hidden1"], out_keys=["intermediate", ("next", "hidden0"), ("next", "hidden1")])
>>> mlp = MLP(num_cells=[64], out_features=1)
- >>> # building two policies with different behaviours:
+ >>> # building two policies with different behaviors:
>>> policy_inference = Seq(lstm_module, Mod(mlp, in_keys=["intermediate"], out_keys=["action"]))
>>> policy_training = Seq(lstm_module.set_recurrent_mode(True), Mod(mlp, in_keys=["intermediate"], out_keys=["action"]))
>>> traj_td = env.rollout(3) # some random temporal data
@@ -665,7 +666,7 @@ def forward(self, tensordict: TensorDictBase):
else:
tensordict_shaped = tensordict.reshape(-1).unsqueeze(-1)
- is_init = tensordict_shaped.get("is_init").squeeze(-1)
+ is_init = tensordict_shaped["is_init"].squeeze(-1)
splits = None
if self.recurrent_mode and is_init[..., 1:].any():
# if we have consecutive trajectories, things get a little more complicated
@@ -679,7 +680,7 @@ def forward(self, tensordict: TensorDictBase):
tensordict_shaped = _split_and_pad_sequence(
tensordict_shaped.select(*self.in_keys, strict=False), splits
)
- is_init = tensordict_shaped.get("is_init").squeeze(-1)
+ is_init = tensordict_shaped["is_init"].squeeze(-1)
value, hidden0, hidden1 = (
tensordict_shaped.get(key, default)
@@ -691,7 +692,7 @@ def forward(self, tensordict: TensorDictBase):
# packed sequences do not help to get the accurate last hidden values
# if splits is not None:
# value = torch.nn.utils.rnn.pack_padded_sequence(value, splits, batch_first=True)
- if is_init.any() and hidden0 is not None:
+ if hidden0 is not None:
is_init_expand = expand_as_right(is_init, hidden0)
hidden0 = torch.where(is_init_expand, 0, hidden0)
hidden1 = torch.where(is_init_expand, 0, hidden1)
@@ -1112,7 +1113,7 @@ class GRUModule(ModuleBase):
.. note:: This module relies on specific ``recurrent_state`` keys being present in the input
TensorDicts. To generate a :class:`~torchrl.envs.transforms.TensorDictPrimer` transform that will automatically
add hidden states to the environment TensorDicts, use the method :func:`~torchrl.modules.rnn.GRUModule.make_tensordict_primer`.
- If this class is a submodule in a larger module, the method :func:`~torchrl.models.utils.get_primers_from_module` can be called
+ If this class is a submodule in a larger module, the method :func:`~torchrl.modules.utils.get_primers_from_module` can be called
on the parent module to automatically generate the primer transforms required for all submodules, including this one.
Examples:
@@ -1279,11 +1280,14 @@ def make_tensordict_primer(self):
inputs and outputs (recurrent states) during rollout execution. That way, the data can be shared across
processes and dealt with properly.
- Not including a ``TensorDictPrimer`` in the environment may result in poorly defined behaviours, for instance
+ Not including a ``TensorDictPrimer`` in the environment may result in poorly defined behaviors, for instance
in parallel settings where a step involves copying the new recurrent state from ``"next"`` to the root
tensordict, which the :meth:`~torchrl.EnvBase.step_mdp` method will not be able to do as the recurrent states
are not registered within the environment specs.
+ See :func:`torchrl.modules.utils.get_primers_from_module` for a method to generate all primers for a given
+ module.
+
Examples:
>>> from torchrl.collectors import SyncDataCollector
>>> from torchrl.envs import TransformedEnv, InitTracker
@@ -1329,9 +1333,7 @@ def make_tuple(key):
)
return TensorDictPrimer(
{
- in_key1: UnboundedContinuousTensorSpec(
- shape=(self.gru.num_layers, self.gru.hidden_size)
- ),
+ in_key1: Unbounded(shape=(self.gru.num_layers, self.gru.hidden_size)),
}
)
@@ -1354,7 +1356,7 @@ def temporal_mode(self):
def set_recurrent_mode(self, mode: bool = True):
"""Returns a new copy of the module that shares the same gru model but with a different ``recurrent_mode`` attribute (if it differs).
- A copy is created such that the module can be used with divergent behaviour
+ A copy is created such that the module can be used with divergent behavior
in various parts of the code (inference vs training):
Examples:
@@ -1367,7 +1369,7 @@ def set_recurrent_mode(self, mode: bool = True):
>>> gru = nn.GRU(input_size=env.observation_spec["observation"].shape[-1], hidden_size=64, batch_first=True)
>>> gru_module = GRUModule(gru=gru, in_keys=["observation", "hidden"], out_keys=["intermediate", ("next", "hidden")])
>>> mlp = MLP(num_cells=[64], out_features=1)
- >>> # building two policies with different behaviours:
+ >>> # building two policies with different behaviors:
>>> policy_inference = Seq(gru_module, Mod(mlp, in_keys=["intermediate"], out_keys=["action"]))
>>> policy_training = Seq(gru_module.set_recurrent_mode(True), Mod(mlp, in_keys=["intermediate"], out_keys=["action"]))
>>> traj_td = env.rollout(3) # some random temporal data
@@ -1410,7 +1412,7 @@ def forward(self, tensordict: TensorDictBase):
else:
tensordict_shaped = tensordict.reshape(-1).unsqueeze(-1)
- is_init = tensordict_shaped.get("is_init").squeeze(-1)
+ is_init = tensordict_shaped["is_init"].squeeze(-1)
splits = None
if self.recurrent_mode and is_init[..., 1:].any():
# if we have consecutive trajectories, things get a little more complicated
@@ -1424,7 +1426,7 @@ def forward(self, tensordict: TensorDictBase):
tensordict_shaped = _split_and_pad_sequence(
tensordict_shaped.select(*self.in_keys, strict=False), splits
)
- is_init = tensordict_shaped.get("is_init").squeeze(-1)
+ is_init = tensordict_shaped["is_init"].squeeze(-1)
value, hidden = (
tensordict_shaped.get(key, default)
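Dropping the ``is_init.any()`` guard above makes the reset masking unconditional whenever hidden states are present; the masking itself is sketched below with assumed shapes:

```python
import torch
from tensordict.utils import expand_as_right

is_init = torch.tensor([True, False])  # which batch entries restart a trajectory
hidden = torch.randn(2, 1, 64)         # (batch, num_layers, hidden_size)
is_init_expand = expand_as_right(is_init, hidden)
hidden = torch.where(is_init_expand, 0, hidden)  # zero states where a trajectory begins
assert (hidden[0] == 0).all()
```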
diff --git a/torchrl/modules/tensordict_module/sequence.py b/torchrl/modules/tensordict_module/sequence.py
index 41ddb55fb35..938843e624f 100644
--- a/torchrl/modules/tensordict_module/sequence.py
+++ b/torchrl/modules/tensordict_module/sequence.py
@@ -8,7 +8,7 @@
from tensordict.nn import TensorDictModule, TensorDictSequential
from torch import nn
-from torchrl.data.tensor_specs import CompositeSpec
+from torchrl.data.tensor_specs import Composite
from torchrl.modules.tensordict_module.common import SafeModule
@@ -33,11 +33,11 @@ class SafeSequential(TensorDictSequential, SafeModule):
Examples:
>>> import torch
>>> from tensordict import TensorDict
- >>> from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec
+ >>> from torchrl.data import Composite, Unbounded
>>> from torchrl.modules import TanhNormal, SafeSequential, TensorDictModule, NormalParamExtractor
>>> from torchrl.modules.tensordict_module import SafeProbabilisticModule
>>> td = TensorDict({"input": torch.randn(3, 4)}, [3,])
- >>> spec1 = CompositeSpec(hidden=UnboundedContinuousTensorSpec(4), loc=None, scale=None)
+ >>> spec1 = Composite(hidden=Unbounded(4), loc=None, scale=None)
>>> net1 = nn.Sequential(torch.nn.Linear(4, 8), NormalParamExtractor())
>>> module1 = TensorDictModule(net1, in_keys=["input"], out_keys=["loc", "scale"])
>>> td_module1 = SafeProbabilisticModule(
@@ -48,7 +48,7 @@ class SafeSequential(TensorDictSequential, SafeModule):
... distribution_class=TanhNormal,
... return_log_prob=True,
... )
- >>> spec2 = UnboundedContinuousTensorSpec(8)
+ >>> spec2 = Unbounded(8)
>>> module2 = torch.nn.Linear(4, 8)
>>> td_module2 = TensorDictModule(
... module=module2,
@@ -74,12 +74,12 @@ class SafeSequential(TensorDictSequential, SafeModule):
is_shared=False)
>>> # The module spec aggregates all the input specs:
>>> print(td_module.spec)
- CompositeSpec(
- hidden: UnboundedContinuousTensorSpec(
+ Composite(
+ hidden: UnboundedContinuous(
shape=torch.Size([4]), space=None, device=cpu, dtype=torch.float32, domain=continuous),
loc: None,
scale: None,
- output: UnboundedContinuousTensorSpec(
+ output: UnboundedContinuous(
shape=torch.Size([8]), space=None, device=cpu, dtype=torch.float32, domain=continuous))
In the vmap case:
@@ -112,12 +112,12 @@ def __init__(
in_keys, out_keys = self._compute_in_and_out_keys(modules)
- spec = CompositeSpec()
+ spec = Composite()
for module in modules:
try:
spec.update(module.spec)
except AttributeError:
- spec.update(CompositeSpec({key: None for key in module.out_keys}))
+ spec.update(Composite({key: None for key in module.out_keys}))
super(TensorDictSequential, self).__init__(
spec=spec,
diff --git a/torchrl/modules/utils/utils.py b/torchrl/modules/utils/utils.py
index 0f3088a8943..9a8914aab89 100644
--- a/torchrl/modules/utils/utils.py
+++ b/torchrl/modules/utils/utils.py
@@ -46,8 +46,8 @@ def get_primers_from_module(module):
>>> primers = get_primers_from_module(model)
>>> print(primers)
- TensorDictPrimer(primers=CompositeSpec(
- recurrent_state: UnboundedContinuousTensorSpec(
+ TensorDictPrimer(primers=Composite(
+ recurrent_state: UnboundedContinuous(
shape=torch.Size([1, 10]),
space=None,
device=cpu,
diff --git a/torchrl/objectives/__init__.py b/torchrl/objectives/__init__.py
index aa13a88c7e9..1ea9ebb5998 100644
--- a/torchrl/objectives/__init__.py
+++ b/torchrl/objectives/__init__.py
@@ -11,6 +11,7 @@
from .decision_transformer import DTLoss, OnlineDTLoss
from .dqn import DistributionalDQNLoss, DQNLoss
from .dreamer import DreamerActorLoss, DreamerModelLoss, DreamerValueLoss
+from .gail import GAILLoss
from .iql import DiscreteIQLLoss, IQLLoss
from .multiagent import QMixerLoss
from .ppo import ClipPPOLoss, KLPENPPOLoss, PPOLoss
@@ -29,5 +30,3 @@
SoftUpdate,
ValueEstimators,
)
-
-# from .value import bellman_max, c_val, dv_val, vtrace, GAE, TDLambdaEstimate, TDEstimate
diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py
index a236b80d56c..c823788b4c2 100644
--- a/torchrl/objectives/a2c.py
+++ b/torchrl/objectives/a2c.py
@@ -10,7 +10,12 @@
from typing import Tuple
import torch
-from tensordict import TensorDict, TensorDictBase, TensorDictParams
+from tensordict import (
+ is_tensor_collection,
+ TensorDict,
+ TensorDictBase,
+ TensorDictParams,
+)
from tensordict.nn import dispatch, ProbabilisticTensorDictSequential, TensorDictModule
from tensordict.utils import NestedKey
from torch import distributions as d
@@ -56,8 +61,9 @@ class A2CLoss(LossModule):
``samples_mc_entropy`` will control how many
samples will be used to compute this estimate.
Defaults to ``1``.
- entropy_coef (float): the weight of the entropy loss.
- critic_coef (float): the weight of the critic loss.
+ entropy_coef (float): the weight of the entropy loss. Defaults to ``0.01``.
+ critic_coef (float): the weight of the critic loss. Defaults to ``1.0``. If ``None``, the critic
+ loss won't be included and the critic inputs will be omitted from the in-keys.
loss_critic_type (str): loss function for the value discrepancy.
Can be one of "l1", "l2" or "smooth_l1". Defaults to ``"smooth_l1"``.
separate_losses (bool, optional): if ``True``, shared parameters between
@@ -96,14 +102,14 @@ class A2CLoss(LossModule):
Examples:
>>> import torch
>>> from torch import nn
- >>> from torchrl.data import BoundedTensorSpec
+ >>> from torchrl.data import Bounded
>>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal
>>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
>>> from torchrl.modules.tensordict_module.common import SafeModule
>>> from torchrl.objectives.a2c import A2CLoss
>>> from tensordict import TensorDict
>>> n_act, n_obs = 4, 3
- >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,))
+ >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,))
>>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor())
>>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
>>> actor = ProbabilisticActor(
@@ -147,14 +153,14 @@ class A2CLoss(LossModule):
Examples:
>>> import torch
>>> from torch import nn
- >>> from torchrl.data import BoundedTensorSpec
+ >>> from torchrl.data import Bounded
>>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal
>>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
>>> from torchrl.modules.tensordict_module.common import SafeModule
>>> from torchrl.objectives.a2c import A2CLoss
>>> _ = torch.manual_seed(42)
>>> n_act, n_obs = 4, 3
- >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,))
+ >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,))
>>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor())
>>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
>>> actor = ProbabilisticActor(
@@ -190,6 +196,13 @@ class A2CLoss(LossModule):
... next_reward = torch.randn(*batch, 1),
... next_observation = torch.randn(*batch, n_obs))
>>> loss_obj.backward()
+
+ .. note::
+ There is an exception regarding compatibility with non-tensordict-based modules.
+ If the actor network is probabilistic and uses a :class:`~tensordict.nn.distributions.CompositeDistribution`,
+ this class must be used with tensordicts and cannot function as a tensordict-independent module.
+ This is because composite action spaces inherently rely on the structured representation of data provided by
+ tensordicts to handle their actions.
"""
@dataclass
@@ -311,7 +324,13 @@ def __init__(
self.register_buffer(
"entropy_coef", torch.as_tensor(entropy_coef, device=device)
)
- self.register_buffer("critic_coef", torch.as_tensor(critic_coef, device=device))
+ if critic_coef is not None:
+ self.register_buffer(
+ "critic_coef", torch.as_tensor(critic_coef, device=device)
+ )
+ else:
+ self.critic_coef = None
+
if gamma is not None:
raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)
self.loss_critic_type = loss_critic_type
@@ -344,7 +363,7 @@ def in_keys(self):
*self.actor_network.in_keys,
*[("next", key) for key in self.actor_network.in_keys],
]
- if self.critic_coef:
+ if self.critic_coef is not None:
keys.extend(self.critic_network.in_keys)
return list(set(keys))
@@ -352,7 +371,7 @@ def in_keys(self):
def out_keys(self):
if self._out_keys is None:
outs = ["loss_objective"]
- if self.critic_coef:
+ if self.critic_coef is not None:
outs.append("loss_critic")
if self.entropy_bonus:
outs.append("entropy")
@@ -383,7 +402,10 @@ def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor:
entropy = dist.entropy()
except NotImplementedError:
x = dist.rsample((self.samples_mc_entropy,))
- entropy = -dist.log_prob(x).mean(0)
+ log_prob = dist.log_prob(x)
+ if is_tensor_collection(log_prob):
+ log_prob = log_prob.get(self.tensor_keys.sample_log_prob)
+ entropy = -log_prob.mean(0)
return entropy.unsqueeze(-1)
def _log_probs(
@@ -391,10 +413,6 @@ def _log_probs(
) -> Tuple[torch.Tensor, d.Distribution]:
# current log_prob of actions
action = tensordict.get(self.tensor_keys.action)
- if action.requires_grad:
- raise RuntimeError(
- f"tensordict stored {self.tensor_keys.action} require grad."
- )
tensordict_clone = tensordict.select(
*self.actor_network.in_keys, strict=False
).clone()
@@ -402,11 +420,29 @@ def _log_probs(
self.actor_network
) if self.functional else contextlib.nullcontext():
dist = self.actor_network.get_dist(tensordict_clone)
- log_prob = dist.log_prob(action)
+ if action.requires_grad:
+ raise RuntimeError(
+ f"tensordict stored {self.tensor_keys.action} requires grad."
+ )
+ if isinstance(action, torch.Tensor):
+ log_prob = dist.log_prob(action)
+ else:
+ maybe_log_prob = dist.log_prob(tensordict)
+ if not isinstance(maybe_log_prob, torch.Tensor):
+ # In some cases (Composite distribution with aggregate_probabilities toggled off) the returned type may not
+ # be a tensor
+ log_prob = maybe_log_prob.get(self.tensor_keys.sample_log_prob)
+ else:
+ log_prob = maybe_log_prob
log_prob = log_prob.unsqueeze(-1)
return log_prob, dist
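
For intuition on the branch above: with a composite action space, the stored action is a nested tensordict rather than a ``torch.Tensor``, which is what routes ``log_prob`` through the tensordict path. Keys below are illustrative:

>>> import torch
>>> from tensordict import TensorDict
>>> td = TensorDict({"action": {"head0": torch.zeros(3), "head1": torch.zeros(2)}}, [])
>>> isinstance(td.get("action"), torch.Tensor)
False
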
- def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
+ def loss_critic(self, tensordict: TensorDictBase) -> Tuple[torch.Tensor, float]:
+ """Returns the loss value of the critic, multiplied by ``critic_coef`` if it is not ``None``.
+
+ Returns the loss and the clip-fraction.
+
+ """
if self.clip_value:
old_state_value = tensordict.get(
self.tensor_keys.value, None
@@ -456,7 +492,9 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
loss_value,
self.loss_critic_type,
)
- return self.critic_coef * loss_value, clip_fraction
+ if self.critic_coef is not None:
+ return self.critic_coef * loss_value, clip_fraction
+ return loss_value, clip_fraction
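
Hedged usage note for the new ``loss_critic`` signature: it now always returns a pair, and the clip-fraction is ``None`` whenever ``clip_value`` is unset. With ``loss_fn`` an ``A2CLoss`` module and ``data`` a batch tensordict (names assumed):

>>> loss_value, clip_fraction = loss_fn.loss_critic(data)
>>> clip_fraction is None   # True unless ``clip_value`` was configured
True
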
@property
@_cache_values
@@ -483,7 +521,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
entropy = self.get_entropy_bonus(dist)
td_out.set("entropy", entropy.detach().mean()) # for logging
td_out.set("loss_entropy", -self.entropy_coef * entropy)
- if self.critic_coef:
+ if self.critic_coef is not None:
loss_critic, value_clip_fraction = self.loss_critic(tensordict)
td_out.set("loss_critic", loss_critic)
if value_clip_fraction is not None:
diff --git a/torchrl/objectives/common.py b/torchrl/objectives/common.py
index 5ceec84e36a..f6935ceae82 100644
--- a/torchrl/objectives/common.py
+++ b/torchrl/objectives/common.py
@@ -15,6 +15,7 @@
from tensordict import is_tensor_collection, TensorDict, TensorDictBase
from tensordict.nn import TensorDictModule, TensorDictModuleBase, TensorDictParams
+from tensordict.utils import Buffer
from torch import nn
from torch.nn import Parameter
from torchrl._utils import RL_WARNINGS
@@ -23,9 +24,18 @@
from torchrl.objectives.utils import RANDOM_MODULE_LIST, ValueEstimators
from torchrl.objectives.value import ValueEstimatorBase
+try:
+ from torch.compiler import is_dynamo_compiling
+except ImportError:
+ from torch._dynamo import is_compiling as is_dynamo_compiling
+
def _updater_check_forward_prehook(module, *args, **kwargs):
- if not all(module._has_update_associated.values()) and RL_WARNINGS:
+ if (
+ not all(module._has_update_associated.values())
+ and RL_WARNINGS
+ and not is_dynamo_compiling()
+ ):
warnings.warn(
module.TARGET_NET_WARNING,
category=UserWarning,
@@ -87,8 +97,8 @@ class LossModule(TensorDictModuleBase, metaclass=_LossMeta):
>>> loss.set_keys(action="action2")
.. note:: When a policy that is wrapped or augmented with an exploration module is passed
- to the loss, we want to deactivate the exploration through ``set_exploration_mode(<mode>)`` where
- ``<mode>`` is either ``ExplorationType.MEAN``, ``ExplorationType.MODE`` or
+ to the loss, we want to deactivate the exploration through ``set_exploration_type(<type>)`` where
+ ``<type>`` is either ``ExplorationType.MEAN``, ``ExplorationType.MODE`` or
``ExplorationType.DETERMINISTIC``. The default value is ``DETERMINISTIC`` and it is set
through the ``deterministic_sampling_mode`` loss attribute. If another
exploration mode is required (or if ``DETERMINISTIC`` is not available), one can
@@ -217,8 +227,10 @@ def set_keys(self, **kwargs) -> None:
>>> dqn_loss.set_keys(priority_key="td_error", action_value_key="action_value")
"""
for key, value in kwargs.items():
- if key not in self._AcceptedKeys.__dict__:
- raise ValueError(f"{key} is not an accepted tensordict key")
+ if key not in self._AcceptedKeys.__dataclass_fields__:
+ raise ValueError(
+ f"{key} is not an accepted tensordict key. Accepted keys are: {self._AcceptedKeys.__dataclass_fields__}."
+ )
if value is not None:
setattr(self.tensor_keys, key, value)
else:
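
The improved error message is easy to exercise; a sketch, with ``dqn_loss`` as in the ``set_keys`` docstring above:

>>> try:
...     dqn_loss.set_keys(unknown_key="foo")
... except ValueError as err:
...     print("accepted keys listed:", "Accepted keys" in str(err))
accepted keys listed: True
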
@@ -415,7 +427,11 @@ def __getattr__(self, item):
# no target param, take detached data
params = getattr(self, item[7:])
params = params.data
- elif not self._has_update_associated[item[7:-7]] and RL_WARNINGS:
+ elif (
+ not self._has_update_associated[item[7:-7]]
+ and RL_WARNINGS
+ and not is_dynamo_compiling()
+ ):
# no updater associated
warnings.warn(
self.TARGET_NET_WARNING,
@@ -433,7 +449,7 @@ def _apply(self, fn):
def _erase_cache(self):
for key in list(self.__dict__):
if key.startswith("_cache"):
- del self.__dict__[key]
+ delattr(self, key)
def _networks(self) -> Iterator[nn.Module]:
for item in self.__dir__():
@@ -603,11 +619,10 @@ def __init__(self, clone):
self.clone = clone
def __call__(self, x):
+ x = x.data.clone() if self.clone else x.data
if isinstance(x, nn.Parameter):
- return nn.Parameter(
- x.data.clone() if self.clone else x.data, requires_grad=False
- )
- return x.data.clone() if self.clone else x.data
+ return Buffer(x)
+ return x
def add_ramdom_module(module):
diff --git a/torchrl/objectives/cql.py b/torchrl/objectives/cql.py
index f1e2aa9c532..fb8fbff2ccf 100644
--- a/torchrl/objectives/cql.py
+++ b/torchrl/objectives/cql.py
@@ -19,7 +19,7 @@
from tensordict.utils import NestedKey, unravel_key
from torch import Tensor
-from torchrl.data.tensor_specs import CompositeSpec
+from torchrl.data.tensor_specs import Composite
from torchrl.data.utils import _find_action_space
from torchrl.envs.utils import ExplorationType, set_exploration_type
@@ -100,14 +100,14 @@ class CQLLoss(LossModule):
Examples:
>>> import torch
>>> from torch import nn
- >>> from torchrl.data import BoundedTensorSpec
+ >>> from torchrl.data import Bounded
>>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal
>>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
>>> from torchrl.modules.tensordict_module.common import SafeModule
>>> from torchrl.objectives.cql import CQLLoss
>>> from tensordict import TensorDict
>>> n_act, n_obs = 4, 3
- >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,))
+ >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,))
>>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor())
>>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
>>> actor = ProbabilisticActor(
@@ -160,14 +160,14 @@ class CQLLoss(LossModule):
Examples:
>>> import torch
>>> from torch import nn
- >>> from torchrl.data import BoundedTensorSpec
+ >>> from torchrl.data import Bounded
>>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal
>>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
>>> from torchrl.modules.tensordict_module.common import SafeModule
>>> from torchrl.objectives.cql import CQLLoss
>>> _ = torch.manual_seed(42)
>>> n_act, n_obs = 4, 3
- >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,))
+ >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,))
>>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor())
>>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
>>> actor = ProbabilisticActor(
@@ -405,8 +405,8 @@ def target_entropy(self):
"the target entropy explicitely or provide the spec of the "
"action tensor in the actor network."
)
- if not isinstance(action_spec, CompositeSpec):
- action_spec = CompositeSpec({self.tensor_keys.action: action_spec})
+ if not isinstance(action_spec, Composite):
+ action_spec = Composite({self.tensor_keys.action: action_spec})
if (
isinstance(self.tensor_keys.action, tuple)
and len(self.tensor_keys.action) > 1
@@ -521,16 +521,16 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
else:
tensordict_reshape = tensordict
- td_device = tensordict_reshape.to(tensordict.device)
-
- q_loss, metadata = self.q_loss(td_device)
- cql_loss, cql_metadata = self.cql_loss(td_device)
+ q_loss, metadata = self.q_loss(tensordict_reshape)
+ cql_loss, cql_metadata = self.cql_loss(tensordict_reshape)
if self.with_lagrange:
- alpha_prime_loss, alpha_prime_metadata = self.alpha_prime_loss(td_device)
+ alpha_prime_loss, alpha_prime_metadata = self.alpha_prime_loss(
+ tensordict_reshape
+ )
metadata.update(alpha_prime_metadata)
- loss_actor_bc, bc_metadata = self.actor_bc_loss(td_device)
- loss_actor, actor_metadata = self.actor_loss(td_device)
- loss_alpha, alpha_metadata = self.alpha_loss(td_device)
+ loss_actor_bc, bc_metadata = self.actor_bc_loss(tensordict_reshape)
+ loss_actor, actor_metadata = self.actor_loss(tensordict_reshape)
+ loss_alpha, alpha_metadata = self.alpha_loss(actor_metadata)
metadata.update(bc_metadata)
metadata.update(cql_metadata)
metadata.update(actor_metadata)
@@ -547,7 +547,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
"loss_cql": cql_loss,
"loss_alpha": loss_alpha,
"alpha": self._alpha,
- "entropy": -td_device.get(self.tensor_keys.log_prob).mean().detach(),
+ "entropy": -actor_metadata.get(self.tensor_keys.log_prob).mean().detach(),
}
if self.with_lagrange:
out["loss_alpha_prime"] = alpha_prime_loss.mean()
@@ -574,7 +574,7 @@ def actor_bc_loss(self, tensordict: TensorDictBase) -> Tensor:
metadata = {"bc_log_prob": bc_log_prob.mean().detach()}
return bc_actor_loss, metadata
- def actor_loss(self, tensordict: TensorDictBase) -> Tensor:
+ def actor_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]:
with set_exploration_type(
ExplorationType.RANDOM
), self.actor_network_params.to_module(self.actor_network):
@@ -585,6 +585,8 @@ def actor_loss(self, tensordict: TensorDictBase) -> Tensor:
log_prob = dist.log_prob(a_reparm)
td_q = tensordict.select(*self.qvalue_network.in_keys, strict=False)
+ if td_q is tensordict:
+ raise RuntimeError("select should return a new tensordict, not the input itself.")
td_q.set(self.tensor_keys.action, a_reparm)
td_q = self._vmap_qvalue_networkN0(
td_q,
@@ -599,12 +601,12 @@ def actor_loss(self, tensordict: TensorDictBase) -> Tensor:
f"Losses shape mismatch: {log_prob.shape} and {min_q_logprob.shape}"
)
- # write log_prob in tensordict for alpha loss
- tensordict.set(self.tensor_keys.log_prob, log_prob.detach())
+ metadata = {}
+ metadata[self.tensor_keys.log_prob] = log_prob.detach()
actor_loss = self._alpha * log_prob - min_q_logprob
actor_loss = _reduce(actor_loss, reduction=self.reduction)
- return actor_loss, {}
+ return actor_loss, metadata
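
The refactor above changes how the log-probability reaches the alpha loss in ``CQLLoss``: it now travels in the metadata dict returned by ``actor_loss`` instead of being written into the input tensordict. A hedged sketch of the resulting call pattern inside ``forward``:

>>> loss_actor, actor_metadata = loss.actor_loss(tensordict)
>>> # alpha_loss reads tensor_keys.log_prob from the metadata, leaving the
>>> # user-provided tensordict untouched
>>> loss_alpha, alpha_metadata = loss.alpha_loss(actor_metadata)
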
def _get_policy_actions(self, data, actor_params, num_actions=10):
batch_size = data.batch_size
@@ -667,7 +669,7 @@ def _get_value_v(self, tensordict, _alpha, actor_params, qval_params):
if self.max_q_backup:
next_tensordict, _ = self._get_policy_actions(
- tensordict.get("next"),
+ tensordict.get("next").copy(),
actor_params,
num_actions=self.num_random,
)
@@ -691,10 +693,10 @@ def _get_value_v(self, tensordict, _alpha, actor_params, qval_params):
target_value = self.value_estimator.value_estimate(tensordict).squeeze(-1)
return target_value
- def q_loss(self, tensordict: TensorDictBase) -> Tensor:
+ def q_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]:
# we pass the alpha value to the tensordict. Since it's a scalar, we must erase the batch-size first.
target_value = self._get_value_v(
- tensordict,
+ tensordict.copy(),
self._alpha,
self.actor_network_params,
self.target_qvalue_network_params,
@@ -722,7 +724,7 @@ def q_loss(self, tensordict: TensorDictBase) -> Tensor:
metadata = {"td_error": td_error.detach()}
return loss_qval, metadata
- def cql_loss(self, tensordict: TensorDictBase) -> Tensor:
+ def cql_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]:
pred_q1 = tensordict.get(self.tensor_keys.pred_q1)
pred_q2 = tensordict.get(self.tensor_keys.pred_q2)
@@ -746,12 +748,12 @@ def cql_loss(self, tensordict: TensorDictBase) -> Tensor:
.to(tensordict.device)
)
curr_actions_td, curr_log_pis = self._get_policy_actions(
- tensordict,
+ tensordict.copy(),
self.actor_network_params,
num_actions=self.num_random,
)
new_curr_actions_td, new_log_pis = self._get_policy_actions(
- tensordict.get("next"),
+ tensordict.get("next").copy(),
self.actor_network_params,
num_actions=self.num_random,
)
@@ -933,11 +935,11 @@ class DiscreteCQLLoss(LossModule):
Examples:
>>> from torchrl.modules import MLP, QValueActor
- >>> from torchrl.data import OneHotDiscreteTensorSpec
+ >>> from torchrl.data import OneHot
>>> from torchrl.objectives import DiscreteCQLLoss
>>> n_obs, n_act = 4, 3
>>> value_net = MLP(in_features=n_obs, out_features=n_act)
- >>> spec = OneHotDiscreteTensorSpec(n_act)
+ >>> spec = OneHot(n_act)
>>> actor = QValueActor(value_net, in_keys=["observation"], action_space=spec)
>>> loss = DiscreteCQLLoss(actor, action_space=spec)
>>> batch = [10,]
@@ -969,12 +971,12 @@ class DiscreteCQLLoss(LossModule):
Examples:
>>> from torchrl.objectives import DiscreteCQLLoss
- >>> from torchrl.data import OneHotDiscreteTensorSpec
+ >>> from torchrl.data import OneHot
>>> from torch import nn
>>> import torch
>>> n_obs = 3
>>> n_action = 4
- >>> action_spec = OneHotDiscreteTensorSpec(n_action)
+ >>> action_spec = OneHot(n_action)
>>> value_network = nn.Linear(n_obs, n_action) # a simple value model
>>> dcql_loss = DiscreteCQLLoss(value_network, action_space=action_spec)
>>> # define data
@@ -1089,7 +1091,7 @@ def __init__(
if action_space is None:
warnings.warn(
"action_space was not specified. DiscreteCQLLoss will default to 'one-hot'. "
- "This behaviour will be deprecated soon and a space will have to be passed. "
+ "This behavior will be deprecated soon and a space will have to be passed. "
"Check the DiscreteCQLLoss documentation to see how to pass the action space."
)
action_space = "one-hot"
diff --git a/torchrl/objectives/crossq.py b/torchrl/objectives/crossq.py
index e76e3438c09..d86442fca12 100644
--- a/torchrl/objectives/crossq.py
+++ b/torchrl/objectives/crossq.py
@@ -15,7 +15,7 @@
from tensordict.nn import dispatch, TensorDictModule
from tensordict.utils import NestedKey
from torch import Tensor
-from torchrl.data.tensor_specs import CompositeSpec
+from torchrl.data.tensor_specs import Composite
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.modules import ProbabilisticActor
from torchrl.objectives.common import LossModule
@@ -98,14 +98,14 @@ class CrossQLoss(LossModule):
Examples:
>>> import torch
>>> from torch import nn
- >>> from torchrl.data import BoundedTensorSpec
+ >>> from torchrl.data import Bounded
>>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal
>>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
>>> from torchrl.modules.tensordict_module.common import SafeModule
>>> from torchrl.objectives.crossq import CrossQLoss
>>> from tensordict import TensorDict
>>> n_act, n_obs = 4, 3
- >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,))
+ >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,))
>>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor())
>>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
>>> actor = ProbabilisticActor(
@@ -156,14 +156,14 @@ class CrossQLoss(LossModule):
Examples:
>>> import torch
>>> from torch import nn
- >>> from torchrl.data import BoundedTensorSpec
+ >>> from torchrl.data import Bounded
>>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal
>>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
>>> from torchrl.modules.tensordict_module.common import SafeModule
>>> from torchrl.objectives import CrossQLoss
>>> _ = torch.manual_seed(42)
>>> n_act, n_obs = 4, 3
- >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,))
+ >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,))
>>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor())
>>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
>>> actor = ProbabilisticActor(
@@ -375,8 +375,8 @@ def target_entropy(self):
"the target entropy explicitely or provide the spec of the "
"action tensor in the actor network."
)
- if not isinstance(action_spec, CompositeSpec):
- action_spec = CompositeSpec({self.tensor_keys.action: action_spec})
+ if not isinstance(action_spec, Composite):
+ action_spec = Composite({self.tensor_keys.action: action_spec})
if (
isinstance(self.tensor_keys.action, tuple)
and len(self.tensor_keys.action) > 1
diff --git a/torchrl/objectives/ddpg.py b/torchrl/objectives/ddpg.py
index 6e1cf0f5eb3..7dc6b23212a 100644
--- a/torchrl/objectives/ddpg.py
+++ b/torchrl/objectives/ddpg.py
@@ -50,12 +50,12 @@ class DDPGLoss(LossModule):
Examples:
>>> import torch
>>> from torch import nn
- >>> from torchrl.data import BoundedTensorSpec
+ >>> from torchrl.data import Bounded
>>> from torchrl.modules.tensordict_module.actors import Actor, ValueOperator
>>> from torchrl.objectives.ddpg import DDPGLoss
>>> from tensordict import TensorDict
>>> n_act, n_obs = 4, 3
- >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,))
+ >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,))
>>> actor = Actor(spec=spec, module=nn.Linear(n_obs, n_act))
>>> class ValueClass(nn.Module):
... def __init__(self):
@@ -100,12 +100,12 @@ class DDPGLoss(LossModule):
Examples:
>>> import torch
>>> from torch import nn
- >>> from torchrl.data import BoundedTensorSpec
+ >>> from torchrl.data import Bounded
>>> from torchrl.modules.tensordict_module.actors import Actor, ValueOperator
>>> from torchrl.objectives.ddpg import DDPGLoss
>>> _ = torch.manual_seed(42)
>>> n_act, n_obs = 4, 3
- >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,))
+ >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,))
>>> actor = Actor(spec=spec, module=nn.Linear(n_obs, n_act))
>>> class ValueClass(nn.Module):
... def __init__(self):
diff --git a/torchrl/objectives/deprecated.py b/torchrl/objectives/deprecated.py
index c1ed8b2cffe..32394942600 100644
--- a/torchrl/objectives/deprecated.py
+++ b/torchrl/objectives/deprecated.py
@@ -17,7 +17,7 @@
from tensordict.utils import NestedKey
from torch import Tensor
-from torchrl.data.tensor_specs import CompositeSpec
+from torchrl.data.tensor_specs import Composite
from torchrl.envs.utils import ExplorationType, set_exploration_type, step_mdp
from torchrl.objectives import default_value_kwargs, distance_loss, ValueEstimators
from torchrl.objectives.common import LossModule
@@ -251,8 +251,8 @@ def target_entropy(self):
"the target entropy explicitely or provide the spec of the "
"action tensor in the actor network."
)
- if not isinstance(action_spec, CompositeSpec):
- action_spec = CompositeSpec({self.tensor_keys.action: action_spec})
+ if not isinstance(action_spec, Composite):
+ action_spec = Composite({self.tensor_keys.action: action_spec})
target_entropy = -float(
np.prod(action_spec[self.tensor_keys.action].shape)
)
@@ -465,7 +465,7 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams
class DoubleREDQLoss_deprecated(REDQLoss_deprecated):
- """[Deprecated] Class for delayed target-REDQ (which should be the default behaviour)."""
+ """[Deprecated] Class for delayed target-REDQ (which should be the default behavior)."""
delay_qvalue: bool = True
diff --git a/torchrl/objectives/dqn.py b/torchrl/objectives/dqn.py
index 7b35598c474..6cbb8b02426 100644
--- a/torchrl/objectives/dqn.py
+++ b/torchrl/objectives/dqn.py
@@ -47,14 +47,14 @@ class DQNLoss(LossModule):
Defaults to "l2".
delay_value (bool, optional): whether to duplicate the value network
into a new target value network to
- create a DQN with a target network. Default is ``False``.
+ create a DQN with a target network. Default is ``True``.
double_dqn (bool, optional): whether to use Double DQN, as described in
https://arxiv.org/abs/1509.06461. Defaults to ``False``.
action_space (str or TensorSpec, optional): Action space. Must be one of
``"one-hot"``, ``"mult_one_hot"``, ``"binary"`` or ``"categorical"``,
- or an instance of the corresponding specs (:class:`torchrl.data.OneHotDiscreteTensorSpec`,
- :class:`torchrl.data.MultiOneHotDiscreteTensorSpec`,
- :class:`torchrl.data.BinaryDiscreteTensorSpec` or :class:`torchrl.data.DiscreteTensorSpec`).
+ or an instance of the corresponding specs (:class:`torchrl.data.OneHot`,
+ :class:`torchrl.data.MultiOneHot`,
+ :class:`torchrl.data.Binary` or :class:`torchrl.data.Categorical`).
If not provided, an attempt to retrieve it from the value network
will be made.
priority_key (NestedKey, optional): [Deprecated, use .set_keys(priority_key=priority_key) instead]
@@ -68,10 +68,10 @@ class DQNLoss(LossModule):
Examples:
>>> from torchrl.modules import MLP
- >>> from torchrl.data import OneHotDiscreteTensorSpec
+ >>> from torchrl.data import OneHot
>>> n_obs, n_act = 4, 3
>>> value_net = MLP(in_features=n_obs, out_features=n_act)
- >>> spec = OneHotDiscreteTensorSpec(n_act)
+ >>> spec = OneHot(n_act)
>>> actor = QValueActor(value_net, in_keys=["observation"], action_space=spec)
>>> loss = DQNLoss(actor, action_space=spec)
>>> batch = [10,]
@@ -99,12 +99,12 @@ class DQNLoss(LossModule):
Examples:
>>> from torchrl.objectives import DQNLoss
- >>> from torchrl.data import OneHotDiscreteTensorSpec
+ >>> from torchrl.data import OneHot
>>> from torch import nn
>>> import torch
>>> n_obs = 3
>>> n_action = 4
- >>> action_spec = OneHotDiscreteTensorSpec(n_action)
+ >>> action_spec = OneHot(n_action)
>>> value_network = nn.Linear(n_obs, n_action) # a simple value model
>>> dqn_loss = DQNLoss(value_network, action_space=action_spec)
>>> # define data
@@ -224,7 +224,7 @@ def __init__(
if action_space is None:
warnings.warn(
"action_space was not specified. DQNLoss will default to 'one-hot'."
- "This behaviour will be deprecated soon and a space will have to be passed."
+ "This behavior will be deprecated soon and a space will have to be passed."
"Check the DQNLoss documentation to see how to pass the action space. "
)
action_space = "one-hot"
diff --git a/torchrl/objectives/functional.py b/torchrl/objectives/functional.py
index 7c598676794..fd96b2e92a3 100644
--- a/torchrl/objectives/functional.py
+++ b/torchrl/objectives/functional.py
@@ -20,7 +20,7 @@ def cross_entropy_loss(
(integer representation) or log_policy.shape (one-hot).
inplace: fills log_policy in-place with 0.0 at non-selected actions before summing along the last dimensions.
This is usually faster but it will change the value of log-policy in place, which may lead to unwanted
- behaviours.
+ behaviors.
"""
if action.shape == log_policy.shape:
diff --git a/torchrl/objectives/gail.py b/torchrl/objectives/gail.py
new file mode 100644
index 00000000000..3c0050fca84
--- /dev/null
+++ b/torchrl/objectives/gail.py
@@ -0,0 +1,251 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+import torch
+
+import torch.autograd as autograd
+from tensordict import TensorDict, TensorDictBase, TensorDictParams
+from tensordict.nn import dispatch, TensorDictModule
+from tensordict.utils import NestedKey
+
+from torchrl.objectives.common import LossModule
+from torchrl.objectives.utils import _reduce
+
+
+class GAILLoss(LossModule):
+ r"""TorchRL implementation of the Generative Adversarial Imitation Learning (GAIL) loss.
+
+ Presented in `"Generative Adversarial Imitation Learning" <https://arxiv.org/abs/1606.03476>`_
+
+ Args:
+ discriminator_network (TensorDictModule): the discriminator network used to tell expert transitions apart from collected ones
+
+ Keyword Args:
+ use_grad_penalty (bool, optional): Whether to use gradient penalty. Default: ``False``.
+ gp_lambda (float, optional): Gradient penalty lambda. Default: ``10``.
+ reduction (str, optional): Specifies the reduction to apply to the output:
+ ``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied,
+ ``"mean"``: the sum of the output will be divided by the number of
+ elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``.
+ """
+
+ @dataclass
+ class _AcceptedKeys:
+ """Maintains default values for all configurable tensordict keys.
+
+ This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their
+ default values.
+
+ Attributes:
+ expert_action (NestedKey): The input tensordict key where the action is expected.
+ Defaults to ``"action"``.
+ expert_observation (NestedKey): The tensordict key where the observation is expected.
+ Defaults to ``"observation"``.
+ collector_action (NestedKey): The tensordict key where the collector action is expected.
+ Defaults to ``"collector_action"``.
+ collector_observation (NestedKey): The tensordict key where the collector observation is expected.
+ Defaults to ``"collector_observation"``.
+ discriminator_pred (NestedKey): The tensordict key where the discriminator prediction is expected. Defaults to ``"d_logits"``.
+ """
+
+ expert_action: NestedKey = "action"
+ expert_observation: NestedKey = "observation"
+ collector_action: NestedKey = "collector_action"
+ collector_observation: NestedKey = "collector_observation"
+ discriminator_pred: NestedKey = "d_logits"
+
+ default_keys = _AcceptedKeys()
+
+ discriminator_network: TensorDictModule
+ discriminator_network_params: TensorDictParams
+ target_discriminator_network: TensorDictModule
+ target_discriminator_network_params: TensorDictParams
+
+ out_keys = [
+ "loss",
+ "gp_loss",
+ ]
+
+ def __init__(
+ self,
+ discriminator_network: TensorDictModule,
+ *,
+ use_grad_penalty: bool = False,
+ gp_lambda: float = 10,
+ reduction: str = None,
+ ) -> None:
+ self._in_keys = None
+ self._out_keys = None
+ if reduction is None:
+ reduction = "mean"
+ super().__init__()
+
+ # Discriminator Network
+ self.convert_to_functional(
+ discriminator_network,
+ "discriminator_network",
+ create_target_params=False,
+ )
+ self.loss_function = torch.nn.BCELoss(reduction="none")
+ self.use_grad_penalty = use_grad_penalty
+ self.gp_lambda = gp_lambda
+
+ self.reduction = reduction
+
+ def _set_in_keys(self):
+ keys = self.discriminator_network.in_keys
+ keys = set(keys)
+ keys.add(self.tensor_keys.expert_observation)
+ keys.add(self.tensor_keys.expert_action)
+ keys.add(self.tensor_keys.collector_observation)
+ keys.add(self.tensor_keys.collector_action)
+ self._in_keys = sorted(keys, key=str)
+
+ def _forward_value_estimator_keys(self, **kwargs) -> None:
+ pass
+
+ @property
+ def in_keys(self):
+ if self._in_keys is None:
+ self._set_in_keys()
+ return self._in_keys
+
+ @in_keys.setter
+ def in_keys(self, values):
+ self._in_keys = values
+
+ @property
+ def out_keys(self):
+ if self._out_keys is None:
+ keys = ["loss"]
+ if self.use_grad_penalty:
+ keys.append("gp_loss")
+ self._out_keys = keys
+ return self._out_keys
+
+ @out_keys.setter
+ def out_keys(self, values):
+ self._out_keys = values
+
+ @dispatch
+ def forward(
+ self,
+ tensordict: TensorDictBase,
+ ) -> TensorDictBase:
+ """The forward method.
+
+ Computes the discriminator loss. If `use_grad_penalty` is set to True, a gradient penalty is added to the loss and its detached value is also returned for logging purposes.
+ To see what keys are expected in the input tensordict and what keys are expected as output, check the
+ class's `"in_keys"` and `"out_keys"` attributes.
+ """
+ device = self.discriminator_network.device
+ tensordict = tensordict.clone(False)
+ shape = tensordict.shape
+ if len(shape) > 1:
+ batch_size, seq_len = shape
+ else:
+ batch_size = shape[0]
+ collector_obs = tensordict.get(self.tensor_keys.collector_observation)
+ collector_act = tensordict.get(self.tensor_keys.collector_action)
+
+ expert_obs = tensordict.get(self.tensor_keys.expert_observation)
+ expert_act = tensordict.get(self.tensor_keys.expert_action)
+
+ combined_obs_inputs = torch.cat([expert_obs, collector_obs], dim=0)
+ combined_act_inputs = torch.cat([expert_act, collector_act], dim=0)
+
+ combined_inputs = TensorDict(
+ {
+ self.tensor_keys.expert_observation: combined_obs_inputs,
+ self.tensor_keys.expert_action: combined_act_inputs,
+ },
+ batch_size=[2 * batch_size],
+ device=device,
+ )
+
+ # create discriminator targets: expert samples are labeled real (1), collector samples fake (0)
+ if len(shape) > 1:
+ fake_labels = torch.zeros((batch_size, seq_len, 1), dtype=torch.float32).to(
+ device
+ )
+ real_labels = torch.ones((batch_size, seq_len, 1), dtype=torch.float32).to(
+ device
+ )
+ else:
+ fake_labels = torch.zeros((batch_size, 1), dtype=torch.float32).to(device)
+ real_labels = torch.ones((batch_size, 1), dtype=torch.float32).to(device)
+
+ with self.discriminator_network_params.to_module(self.discriminator_network):
+ d_logits = self.discriminator_network(combined_inputs).get(
+ self.tensor_keys.discriminator_pred
+ )
+
+ expert_preds, collection_preds = torch.split(
+ d_logits, [batch_size, batch_size], dim=0
+ )
+
+ expert_loss = self.loss_function(expert_preds, real_labels)
+ collection_loss = self.loss_function(collection_preds, fake_labels)
+
+ loss = expert_loss + collection_loss
+ out = {}
+ if self.use_grad_penalty:
+ obs = tensordict.get(self.tensor_keys.collector_observation)
+ acts = tensordict.get(self.tensor_keys.collector_action)
+ obs_e = tensordict.get(self.tensor_keys.expert_observation)
+ acts_e = tensordict.get(self.tensor_keys.expert_action)
+
+ obss_noise = (
+ torch.distributions.Uniform(0.0, 1.0).sample(obs_e.shape).to(device)
+ )
+ acts_noise = (
+ torch.distributions.Uniform(0.0, 1.0).sample(acts_e.shape).to(device)
+ )
+ obss_mixture = obss_noise * obs + (1 - obss_noise) * obs_e
+ acts_mixture = acts_noise * acts + (1 - acts_noise) * acts_e
+ obss_mixture.requires_grad_(True)
+ acts_mixture.requires_grad_(True)
+
+ pg_input_td = TensorDict(
+ {
+ self.tensor_keys.expert_observation: obss_mixture,
+ self.tensor_keys.expert_action: acts_mixture,
+ },
+ [],
+ device=device,
+ )
+
+ with self.discriminator_network_params.to_module(
+ self.discriminator_network
+ ):
+ d_logits_mixture = self.discriminator_network(pg_input_td).get(
+ self.tensor_keys.discriminator_pred
+ )
+
+ gradients = torch.cat(
+ autograd.grad(
+ outputs=d_logits_mixture,
+ inputs=(obss_mixture, acts_mixture),
+ grad_outputs=torch.ones(d_logits_mixture.size(), device=device),
+ create_graph=True,
+ retain_graph=True,
+ only_inputs=True,
+ ),
+ dim=-1,
+ )
+
+ gp_loss = self.gp_lambda * torch.mean(
+ (torch.linalg.norm(gradients, dim=-1) - 1) ** 2
+ )
+
+ loss += gp_loss
+ out["gp_loss"] = gp_loss.detach()
+ loss = _reduce(loss, reduction=self.reduction)
+ out["loss"] = loss
+ td_out = TensorDict(out, [])
+ return td_out
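
A hedged end-to-end sketch of driving the new ``GAILLoss``. The discriminator architecture below is an illustrative assumption: any ``TensorDictModule`` reading observation/action and writing ``d_logits`` in ``[0, 1]`` should do, since the loss uses ``BCELoss``. The import path assumes the class is re-exported from ``torchrl.objectives``.

>>> import torch
>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
>>> from torchrl.objectives import GAILLoss
>>> n_obs, n_act, batch = 3, 4, 8
>>> class Disc(nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.net = nn.Linear(n_obs + n_act, 1)
...     def forward(self, obs, act):
...         # concatenate inputs and squash to a probability for BCELoss
...         return torch.sigmoid(self.net(torch.cat([obs, act], dim=-1)))
>>> disc = TensorDictModule(Disc(), in_keys=["observation", "action"], out_keys=["d_logits"])
>>> loss = GAILLoss(disc, use_grad_penalty=True)
>>> data = TensorDict({
...     "observation": torch.randn(batch, n_obs),
...     "action": torch.randn(batch, n_act),
...     "collector_observation": torch.randn(batch, n_obs),
...     "collector_action": torch.randn(batch, n_act),
... }, [batch])
>>> out = loss(data)   # contains "loss" and, here, the detached "gp_loss"
>>> out["loss"].backward()
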
diff --git a/torchrl/objectives/iql.py b/torchrl/objectives/iql.py
index 74cfe504e78..c4639b70bdd 100644
--- a/torchrl/objectives/iql.py
+++ b/torchrl/objectives/iql.py
@@ -73,14 +73,14 @@ class IQLLoss(LossModule):
Examples:
>>> import torch
>>> from torch import nn
- >>> from torchrl.data import BoundedTensorSpec
+ >>> from torchrl.data import Bounded
>>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal
>>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
>>> from torchrl.modules.tensordict_module.common import SafeModule
>>> from torchrl.objectives.iql import IQLLoss
>>> from tensordict import TensorDict
>>> n_act, n_obs = 4, 3
- >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,))
+ >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,))
>>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor())
>>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
>>> actor = ProbabilisticActor(
@@ -136,14 +136,14 @@ class IQLLoss(LossModule):
Examples:
>>> import torch
>>> from torch import nn
- >>> from torchrl.data import BoundedTensorSpec
+ >>> from torchrl.data import Bounded
>>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal
>>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
>>> from torchrl.modules.tensordict_module.common import SafeModule
>>> from torchrl.objectives.iql import IQLLoss
>>> _ = torch.manual_seed(42)
>>> n_act, n_obs = 4, 3
- >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,))
+ >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,))
>>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor())
>>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
>>> actor = ProbabilisticActor(
@@ -383,7 +383,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
loss_actor, metadata = self.actor_loss(tensordict_reshape)
loss_qvalue, metadata_qvalue = self.qvalue_loss(tensordict_reshape)
loss_value, metadata_value = self.value_loss(tensordict_reshape)
- metadata.update(**metadata_qvalue, **metadata_value)
+ metadata.update(metadata_qvalue)
+ metadata.update(metadata_value)
if (loss_actor.shape != loss_qvalue.shape) or (
loss_value is not None and loss_actor.shape != loss_value.shape
@@ -410,7 +411,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
[],
)
- def actor_loss(self, tensordict: TensorDictBase) -> Tensor:
+ def actor_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]:
# KL loss
with self.actor_network_params.to_module(self.actor_network):
dist = self.actor_network.get_dist(tensordict)
@@ -446,7 +447,7 @@ def actor_loss(self, tensordict: TensorDictBase) -> Tensor:
loss_actor = _reduce(loss_actor, reduction=self.reduction)
return loss_actor, {}
- def value_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, Tensor]:
+ def value_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]:
# Min Q value
td_q = tensordict.select(*self.qvalue_network.in_keys, strict=False)
td_q = self._vmap_qvalue_networkN0(td_q, self.target_qvalue_network_params)
@@ -460,7 +461,7 @@ def value_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, Tensor]:
value_loss = _reduce(value_loss, reduction=self.reduction)
return value_loss, {}
- def qvalue_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, Tensor]:
+ def qvalue_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]:
obs_keys = self.actor_network.in_keys
tensordict = tensordict.select(
"next", *obs_keys, self.tensor_keys.action, strict=False
@@ -541,9 +542,9 @@ class DiscreteIQLLoss(IQLLoss):
Keyword Args:
action_space (str or TensorSpec): Action space. Must be one of
``"one-hot"``, ``"mult_one_hot"``, ``"binary"`` or ``"categorical"``,
- or an instance of the corresponding specs (:class:`torchrl.data.OneHotDiscreteTensorSpec`,
- :class:`torchrl.data.MultiOneHotDiscreteTensorSpec`,
- :class:`torchrl.data.BinaryDiscreteTensorSpec` or :class:`torchrl.data.DiscreteTensorSpec`).
+ or an instance of the corresponding specs (:class:`torchrl.data.OneHot`,
+ :class:`torchrl.data.MultiOneHot`,
+ :class:`torchrl.data.Binary` or :class:`torchrl.data.Categorical`).
num_qvalue_nets (integer, optional): number of Q-Value networks used.
Defaults to ``2``.
loss_function (str, optional): loss function to be used with
@@ -569,14 +570,14 @@ class DiscreteIQLLoss(IQLLoss):
Examples:
>>> import torch
>>> from torch import nn
- >>> from torchrl.data.tensor_specs import OneHotDiscreteTensorSpec
+ >>> from torchrl.data.tensor_specs import OneHot
>>> from torchrl.modules.distributions.discrete import OneHotCategorical
>>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor
>>> from torchrl.modules.tensordict_module.common import SafeModule
>>> from torchrl.objectives.iql import DiscreteIQLLoss
>>> from tensordict import TensorDict
>>> n_act, n_obs = 4, 3
- >>> spec = OneHotDiscreteTensorSpec(n_act)
+ >>> spec = OneHot(n_act)
>>> module = SafeModule(nn.Linear(n_obs, n_act), in_keys=["observation"], out_keys=["logits"])
>>> actor = ProbabilisticActor(
... module=module,
@@ -627,14 +628,14 @@ class DiscreteIQLLoss(IQLLoss):
>>> import torch
>>> import torch
>>> from torch import nn
- >>> from torchrl.data.tensor_specs import OneHotDiscreteTensorSpec
+ >>> from torchrl.data.tensor_specs import OneHot
>>> from torchrl.modules.distributions.discrete import OneHotCategorical
>>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor
>>> from torchrl.modules.tensordict_module.common import SafeModule
>>> from torchrl.objectives.iql import DiscreteIQLLoss
>>> _ = torch.manual_seed(42)
>>> n_act, n_obs = 4, 3
- >>> spec = OneHotDiscreteTensorSpec(n_act)
+ >>> spec = OneHot(n_act)
>>> module = SafeModule(nn.Linear(n_obs, n_act), in_keys=["observation"], out_keys=["logits"])
>>> actor = ProbabilisticActor(
... module=module,
@@ -774,14 +775,14 @@ def __init__(
if action_space is None:
warnings.warn(
"action_space was not specified. DiscreteIQLLoss will default to 'one-hot'."
- "This behaviour will be deprecated soon and a space will have to be passed."
+ "This behavior will be deprecated soon and a space will have to be passed."
"Check the DiscreteIQLLoss documentation to see how to pass the action space. "
)
action_space = "one-hot"
self.action_space = _find_action_space(action_space)
self.reduction = reduction
- def actor_loss(self, tensordict: TensorDictBase) -> Tensor:
+ def actor_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]:
# KL loss
with self.actor_network_params.to_module(self.actor_network):
dist = self.actor_network.get_dist(tensordict)
@@ -828,7 +829,7 @@ def actor_loss(self, tensordict: TensorDictBase) -> Tensor:
loss_actor = _reduce(loss_actor, reduction=self.reduction)
return loss_actor, {}
- def value_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, Tensor]:
+ def value_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]:
# Min Q value
with torch.no_grad():
# Min Q value
@@ -856,7 +857,7 @@ def value_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, Tensor]:
value_loss = _reduce(value_loss, reduction=self.reduction)
return value_loss, {}
- def qvalue_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, Tensor]:
+ def qvalue_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]:
obs_keys = self.actor_network.in_keys
next_td = tensordict.select(
"next", *obs_keys, self.tensor_keys.action, strict=False
diff --git a/torchrl/objectives/multiagent/qmixer.py b/torchrl/objectives/multiagent/qmixer.py
index c9dc281ef41..39777c59e26 100644
--- a/torchrl/objectives/multiagent/qmixer.py
+++ b/torchrl/objectives/multiagent/qmixer.py
@@ -63,9 +63,9 @@ class QMixerLoss(LossModule):
create a double DQN. Default is ``False``.
action_space (str or TensorSpec, optional): Action space. Must be one of
``"one-hot"``, ``"mult_one_hot"``, ``"binary"`` or ``"categorical"``,
- or an instance of the corresponding specs (:class:`torchrl.data.OneHotDiscreteTensorSpec`,
- :class:`torchrl.data.MultiOneHotDiscreteTensorSpec`,
- :class:`torchrl.data.BinaryDiscreteTensorSpec` or :class:`torchrl.data.DiscreteTensorSpec`).
+ or an instance of the corresponding specs (:class:`torchrl.data.OneHot`,
+ :class:`torchrl.data.MultiOneHot`,
+ :class:`torchrl.data.Binary` or :class:`torchrl.data.Categorical`).
If not provided, an attempt to retrieve it from the value network
will be made.
priority_key (NestedKey, optional): [Deprecated, use .set_keys(priority_key=priority_key) instead]
@@ -254,7 +254,7 @@ def __init__(
if action_space is None:
warnings.warn(
"action_space was not specified. QMixerLoss will default to 'one-hot'."
- "This behaviour will be deprecated soon and a space will have to be passed."
+ "This behavior will be deprecated soon and a space will have to be passed."
"Check the QMixerLoss documentation to see how to pass the action space. "
)
action_space = "one-hot"
diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py
index c29bc73dfa8..efc951b3999 100644
--- a/torchrl/objectives/ppo.py
+++ b/torchrl/objectives/ppo.py
@@ -6,13 +6,17 @@
import contextlib
-import math
from copy import deepcopy
from dataclasses import dataclass
from typing import Tuple
import torch
-from tensordict import TensorDict, TensorDictBase, TensorDictParams
+from tensordict import (
+ is_tensor_collection,
+ TensorDict,
+ TensorDictBase,
+ TensorDictParams,
+)
from tensordict.nn import (
dispatch,
ProbabilisticTensorDictModule,
@@ -45,15 +49,15 @@
class PPOLoss(LossModule):
"""A parent PPO loss class.
- PPO (Proximal Policy Optimisation) is a model-free, online RL algorithm
+ PPO (Proximal Policy Optimization) is a model-free, online RL algorithm
that makes use of a recorded (batch of)
trajectories to perform several optimization steps, while actively
preventing the updated policy to deviate too
much from its original parameter configuration.
- PPO loss can be found in different flavours, depending on the way the
- constrained optimisation is implemented: ClipPPOLoss and KLPENPPOLoss.
- Unlike its subclasses, this class does not implement any regularisation
+ PPO loss can be found in different flavors, depending on the way the
+ constrained optimization is implemented: ClipPPOLoss and KLPENPPOLoss.
+ Unlike its subclasses, this class does not implement any regularization
and should therefore be used cautiously.
For more details regarding PPO, refer to: "Proximal Policy Optimization Algorithms",
@@ -75,7 +79,8 @@ class PPOLoss(LossModule):
entropy_coef (scalar, optional): entropy multiplier when computing the total loss.
Defaults to ``0.01``.
critic_coef (scalar, optional): critic loss multiplier when computing the total
- loss. Defaults to ``1.0``.
+ loss. Defaults to ``1.0``. Set ``critic_coef`` to ``None`` to exclude the value
+ loss from the forward outputs.
loss_critic_type (str, optional): loss function for the value discrepancy.
Can be one of "l1", "l2" or "smooth_l1". Defaults to ``"smooth_l1"``.
normalize_advantage (bool, optional): if ``True``, the advantage will be normalized
@@ -151,14 +156,14 @@ class PPOLoss(LossModule):
Examples:
>>> import torch
>>> from torch import nn
- >>> from torchrl.data.tensor_specs import BoundedTensorSpec
+ >>> from torchrl.data.tensor_specs import Bounded
>>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal
>>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
>>> from torchrl.modules.tensordict_module.common import SafeModule
>>> from torchrl.objectives.ppo import PPOLoss
>>> from tensordict import TensorDict
>>> n_act, n_obs = 4, 3
- >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,))
+ >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,))
>>> base_layer = nn.Linear(n_obs, 5)
>>> net = nn.Sequential(base_layer, nn.Linear(5, 2 * n_act), NormalParamExtractor())
>>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
@@ -204,13 +209,13 @@ class PPOLoss(LossModule):
Examples:
>>> import torch
>>> from torch import nn
- >>> from torchrl.data.tensor_specs import BoundedTensorSpec
+ >>> from torchrl.data.tensor_specs import Bounded
>>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal
>>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
>>> from torchrl.modules.tensordict_module.common import SafeModule
>>> from torchrl.objectives.ppo import PPOLoss
>>> n_act, n_obs = 4, 3
- >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,))
+ >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,))
>>> base_layer = nn.Linear(n_obs, 5)
>>> net = nn.Sequential(base_layer, nn.Linear(5, 2 * n_act), NormalParamExtractor())
>>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
@@ -238,6 +243,12 @@ class PPOLoss(LossModule):
... next_observation=torch.randn(*batch, n_obs))
>>> loss_objective.backward()
+ .. note::
+ There is an exception regarding compatibility with non-tensordict-based modules.
+ If the actor network is probabilistic and uses a :class:`~tensordict.nn.distributions.CompositeDistribution`,
+ this class must be used with tensordicts and cannot function as a tensordict-independent module.
+ This is because composite action spaces inherently rely on the structured representation of data provided by
+ tensordicts to handle their actions.
"""
@dataclass
@@ -360,7 +371,12 @@ def __init__(
device = torch.device("cpu")
self.register_buffer("entropy_coef", torch.tensor(entropy_coef, device=device))
- self.register_buffer("critic_coef", torch.tensor(critic_coef, device=device))
+ if critic_coef is not None:
+ self.register_buffer(
+ "critic_coef", torch.tensor(critic_coef, device=device)
+ )
+ else:
+ self.critic_coef = None
self.loss_critic_type = loss_critic_type
self.normalize_advantage = normalize_advantage
if gamma is not None:
@@ -449,7 +465,10 @@ def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor:
entropy = dist.entropy()
except NotImplementedError:
x = dist.rsample((self.samples_mc_entropy,))
- entropy = -dist.log_prob(x).mean(0)
+ log_prob = dist.log_prob(x)
+ if is_tensor_collection(log_prob):
+ log_prob = log_prob.get(self.tensor_keys.sample_log_prob)
+ entropy = -log_prob.mean(0)
return entropy.unsqueeze(-1)
def _log_weight(
@@ -457,20 +476,32 @@ def _log_weight(
) -> Tuple[torch.Tensor, d.Distribution]:
# current log_prob of actions
action = tensordict.get(self.tensor_keys.action)
- if action.requires_grad:
- raise RuntimeError(
- f"tensordict stored {self.tensor_keys.action} requires grad."
- )
with self.actor_network_params.to_module(
self.actor_network
) if self.functional else contextlib.nullcontext():
dist = self.actor_network.get_dist(tensordict)
- log_prob = dist.log_prob(action)
prev_log_prob = tensordict.get(self.tensor_keys.sample_log_prob)
if prev_log_prob.requires_grad:
- raise RuntimeError("tensordict prev_log_prob requires grad.")
+ raise RuntimeError(
+ f"tensordict stored {self.tensor_keys.sample_log_prob} requires grad."
+ )
+
+ if action.requires_grad:
+ raise RuntimeError(
+ f"tensordict stored {self.tensor_keys.action} requires grad."
+ )
+ if isinstance(action, torch.Tensor):
+ log_prob = dist.log_prob(action)
+ else:
+ maybe_log_prob = dist.log_prob(tensordict)
+ if not isinstance(maybe_log_prob, torch.Tensor):
+ # In some cases (Composite distribution with aggregate_probabilities toggled off) the returned type may not
+ # be a tensor
+ log_prob = maybe_log_prob.get(self.tensor_keys.sample_log_prob)
+ else:
+ log_prob = maybe_log_prob
log_weight = (log_prob - prev_log_prob).unsqueeze(-1)
kl_approx = (prev_log_prob - log_prob).unsqueeze(-1)
@@ -478,6 +509,7 @@ def _log_weight(
return log_weight, dist, kl_approx
def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
+ """Returns the critic loss multiplied by ``critic_coef``, if it is not ``None``."""
# TODO: if the advantage is gathered by forward, this introduces an
# overhead that we could easily reduce.
if self.separate_losses:
@@ -536,7 +568,9 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
self.loss_critic_type,
)
- return self.critic_coef * loss_value, clip_fraction
+ if self.critic_coef is not None:
+ return self.critic_coef * loss_value, clip_fraction
+ return loss_value, clip_fraction
@property
@_cache_values
@@ -569,7 +603,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
td_out.set("entropy", entropy.detach().mean()) # for logging
td_out.set("kl_approx", kl_approx.detach().mean()) # for logging
td_out.set("loss_entropy", -self.entropy_coef * entropy)
- if self.critic_coef:
+ if self.critic_coef is not None:
loss_critic, value_clip_fraction = self.loss_critic(tensordict)
td_out.set("loss_critic", loss_critic)
if value_clip_fraction is not None:
@@ -653,7 +687,8 @@ class ClipPPOLoss(PPOLoss):
entropy_coef (scalar, optional): entropy multiplier when computing the total loss.
Defaults to ``0.01``.
critic_coef (scalar, optional): critic loss multiplier when computing the total
- loss. Defaults to ``1.0``.
+ loss. Defaults to ``1.0``. Set ``critic_coef`` to ``None`` to exclude the value
+ loss from the forward outputs.
loss_critic_type (str, optional): loss function for the value discrepancy.
Can be one of "l1", "l2" or "smooth_l1". Defaults to ``"smooth_l1"``.
normalize_advantage (bool, optional): if ``True``, the advantage will be normalized
@@ -774,13 +809,18 @@ def __init__(
clip_value=clip_value,
**kwargs,
)
- self.register_buffer("clip_epsilon", torch.tensor(clip_epsilon))
+ for p in self.parameters():
+ device = p.device
+ break
+ else:
+ device = None
+ self.register_buffer("clip_epsilon", torch.tensor(clip_epsilon, device=device))
@property
def _clip_bounds(self):
return (
- math.log1p(-self.clip_epsilon),
- math.log1p(self.clip_epsilon),
+ (-self.clip_epsilon).log1p(),
+ self.clip_epsilon.log1p(),
)
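
For reference, the bounds are the log-space images of the clip interval; computing them with ``Tensor.log1p`` (rather than ``math.log1p``) keeps them on the buffer's device and dtype. A quick standalone check:

>>> import torch
>>> clip_epsilon = torch.tensor(0.2)
>>> low, high = (-clip_epsilon).log1p(), clip_epsilon.log1p()
>>> torch.exp(low), torch.exp(high)   # recovers (1 - eps, 1 + eps)
(tensor(0.8000), tensor(1.2000))
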
@property
@@ -843,7 +883,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
td_out.set("entropy", entropy.detach().mean()) # for logging
td_out.set("kl_approx", kl_approx.detach().mean()) # for logging
td_out.set("loss_entropy", -self.entropy_coef * entropy)
- if self.critic_coef:
+ if self.critic_coef is not None:
loss_critic, value_clip_fraction = self.loss_critic(tensordict)
td_out.set("loss_critic", loss_critic)
if value_clip_fraction is not None:
@@ -1107,7 +1147,17 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
kl = torch.distributions.kl.kl_divergence(previous_dist, current_dist)
except NotImplementedError:
x = previous_dist.sample((self.samples_mc_kl,))
- kl = (previous_dist.log_prob(x) - current_dist.log_prob(x)).mean(0)
+ previous_log_prob = previous_dist.log_prob(x)
+ current_log_prob = current_dist.log_prob(x)
+ if is_tensor_collection(current_log_prob):
+ previous_log_prob = previous_log_prob.get(
+ self.tensor_keys.sample_log_prob
+ )
+ current_log_prob = current_log_prob.get(
+ self.tensor_keys.sample_log_prob
+ )
+
+ kl = (previous_log_prob - current_log_prob).mean(0)
kl = kl.unsqueeze(-1)
neg_loss = neg_loss - self.beta * kl
if kl.mean() > self.dtarg * 1.5:
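
The sampled fallback above is the standard Monte-Carlo estimator of the KL divergence. A standalone sanity check with Gaussians, illustrative and not part of the diff:

>>> import torch
>>> from torch import distributions as d
>>> p, q = d.Normal(0.0, 1.0), d.Normal(0.5, 1.0)
>>> x = p.sample((100_000,))
>>> mc_kl = (p.log_prob(x) - q.log_prob(x)).mean(0)
>>> bool(torch.isclose(mc_kl, d.kl_divergence(p, q), atol=1e-2))
True
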
@@ -1127,7 +1177,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
td_out.set("entropy", entropy.detach().mean()) # for logging
td_out.set("kl_approx", kl_approx.detach().mean()) # for logging
td_out.set("loss_entropy", -self.entropy_coef * entropy)
- if self.critic_coef:
+ if self.critic_coef is not None:
loss_critic, value_clip_fraction = self.loss_critic(tensordict_copy)
td_out.set("loss_critic", loss_critic)
if value_clip_fraction is not None:
diff --git a/torchrl/objectives/redq.py b/torchrl/objectives/redq.py
index 1522fd7749e..271f233bae8 100644
--- a/torchrl/objectives/redq.py
+++ b/torchrl/objectives/redq.py
@@ -12,11 +12,11 @@
import torch
from tensordict import TensorDict, TensorDictBase, TensorDictParams
-from tensordict.nn import dispatch, TensorDictModule, TensorDictSequential
+from tensordict.nn import dispatch, TensorDictModule
from tensordict.utils import NestedKey
from torch import Tensor
-from torchrl.data.tensor_specs import CompositeSpec
+from torchrl.data.tensor_specs import Composite
from torchrl.envs.utils import ExplorationType, set_exploration_type, step_mdp
from torchrl.objectives.common import LossModule
@@ -93,14 +93,14 @@ class REDQLoss(LossModule):
Examples:
>>> import torch
>>> from torch import nn
- >>> from torchrl.data import BoundedTensorSpec
+ >>> from torchrl.data import Bounded
>>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal
>>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
>>> from torchrl.modules.tensordict_module.common import SafeModule
>>> from torchrl.objectives.redq import REDQLoss
>>> from tensordict import TensorDict
>>> n_act, n_obs = 4, 3
- >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,))
+ >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,))
>>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor())
>>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
>>> actor = ProbabilisticActor(
@@ -155,13 +155,13 @@ class REDQLoss(LossModule):
Examples:
>>> import torch
>>> from torch import nn
- >>> from torchrl.data import BoundedTensorSpec
+ >>> from torchrl.data import Bounded
>>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal
>>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
>>> from torchrl.modules.tensordict_module.common import SafeModule
>>> from torchrl.objectives.redq import REDQLoss
>>> n_act, n_obs = 4, 3
- >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,))
+ >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,))
>>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor())
>>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
>>> actor = ProbabilisticActor(
@@ -326,7 +326,11 @@ def __init__(
else:
self.register_parameter(
"log_alpha",
- torch.nn.Parameter(torch.tensor(math.log(alpha_init), device=device)),
+ torch.nn.Parameter(
+ torch.tensor(
+ math.log(alpha_init), device=device, requires_grad=True
+ )
+ ),
)
self._target_entropy = target_entropy
@@ -367,8 +371,8 @@ def target_entropy(self):
"the target entropy explicitely or provide the spec of the "
"action tensor in the actor network."
)
- if not isinstance(action_spec, CompositeSpec):
- action_spec = CompositeSpec({self.tensor_keys.action: action_spec})
+ if not isinstance(action_spec, Composite):
+ action_spec = Composite({self.tensor_keys.action: action_spec})
if (
isinstance(self.tensor_keys.action, tuple)
and len(self.tensor_keys.action) > 1
@@ -401,10 +405,8 @@ def _forward_value_estimator_keys(self, **kwargs) -> None:
@property
def alpha(self):
- self.log_alpha.data.clamp_(self.min_log_alpha, self.max_log_alpha)
with torch.no_grad():
- alpha = self.log_alpha.exp()
- return alpha
+ return self.log_alpha.clamp(self.min_log_alpha, self.max_log_alpha).exp()
def _set_in_keys(self):
keys = [
@@ -448,9 +450,12 @@ def _qvalue_params_cat(self, selected_q_params):
@dispatch
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
obs_keys = self.actor_network.in_keys
- tensordict_select = tensordict.clone(False).select(
+ tensordict_select = tensordict.select(
"next", *obs_keys, self.tensor_keys.action, strict=False
)
+ # We need to copy because select does not copy sub-tensordicts
+ tensordict_select = tensordict_select.copy()
+
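
Unpacking the comment above (assumed tensordict semantics): ``select`` builds a new outer tensordict but shares any nested sub-tensordicts with its input, so a subsequent ``copy()`` is needed before the sub-entries can be safely written to.

>>> import torch
>>> from tensordict import TensorDict
>>> td = TensorDict({"next": {"obs": torch.zeros(())}}, [])
>>> td.select("next").get("next") is td.get("next")          # shared sub-td
True
>>> td.select("next").copy().get("next") is td.get("next")   # detached structure
False
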
selected_models_idx = torch.randperm(self.num_qvalue_nets)[
: self.sub_sample_len
].sort()[0]
@@ -467,7 +472,6 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
*self.actor_network.in_keys, strict=False
) # next_observation ->
tensordict_actor = torch.stack([tensordict_actor_grad, next_td_actor], 0)
- # tensordict_actor = tensordict_actor.contiguous()
with set_exploration_type(ExplorationType.RANDOM):
if self.gSDE:
@@ -480,19 +484,12 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
tensordict_actor,
actor_params,
)
- if isinstance(self.actor_network, TensorDictSequential):
- sample_key = self.tensor_keys.action
- tensordict_actor_dist = self.actor_network.build_dist_from_params(
- td_params
- )
- else:
- sample_key = self.tensor_keys.action
- tensordict_actor_dist = self.actor_network.build_dist_from_params(
- td_params
- )
+ sample_key = self.tensor_keys.action
+ sample_key_lp = self.tensor_keys.sample_log_prob
+ tensordict_actor_dist = self.actor_network.build_dist_from_params(td_params)
tensordict_actor.set(sample_key, tensordict_actor_dist.rsample())
tensordict_actor.set(
- self.tensor_keys.sample_log_prob,
+ sample_key_lp,
tensordict_actor_dist.log_prob(tensordict_actor.get(sample_key)),
)
@@ -603,12 +600,22 @@ def _loss_alpha(self, log_pi: Tensor) -> Tensor:
)
if self.target_entropy is not None:
# we can compute this loss even if log_alpha is not a parameter
- alpha_loss = -self.log_alpha.exp() * (log_pi.detach() + self.target_entropy)
+ alpha_loss = -self._safe_log_alpha.exp() * (
+ log_pi.detach() + self.target_entropy
+ )
else:
# placeholder
alpha_loss = torch.zeros_like(log_pi)
return alpha_loss
+ @property
+ def _safe_log_alpha(self):
+ log_alpha = self.log_alpha
+ with torch.no_grad():
+ log_alpha_clamp = log_alpha.clamp(self.min_log_alpha, self.max_log_alpha)
+ log_alpha_det = log_alpha.detach()
+ return log_alpha - log_alpha_det + log_alpha_clamp
+
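
The ``_safe_log_alpha`` property is a straight-through clamp: the forward value is clamped to the bounds, while the backward pass sees the identity, so ``log_alpha`` keeps receiving gradient even when it sits outside the bounds. A standalone sketch of the trick:

>>> import torch
>>> x = torch.tensor(5.0, requires_grad=True)   # pretend the max bound is 2.0
>>> with torch.no_grad():
...     x_clamp = x.clamp(-10.0, 2.0)
>>> st = x - x.detach() + x_clamp
>>> float(st)          # forward: the clamped value
2.0
>>> st.backward()
>>> float(x.grad)      # backward: identity gradient, not zero
1.0
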
def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams):
if value_type is None:
value_type = self.default_value_estimator
diff --git a/torchrl/objectives/reinforce.py b/torchrl/objectives/reinforce.py
index af9f7d99b46..08ff896610c 100644
--- a/torchrl/objectives/reinforce.py
+++ b/torchrl/objectives/reinforce.py
@@ -100,7 +100,7 @@ class ReinforceLoss(LossModule):
Examples:
>>> import torch
>>> from torch import nn
- >>> from torchrl.data.tensor_specs import UnboundedContinuousTensorSpec
+ >>> from torchrl.data.tensor_specs import Unbounded
>>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal
>>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
>>> from torchrl.modules.tensordict_module.common import SafeModule
@@ -115,7 +115,7 @@ class ReinforceLoss(LossModule):
... distribution_class=TanhNormal,
... return_log_prob=True,
... in_keys=["loc", "scale"],
- ... spec=UnboundedContinuousTensorSpec(n_act),)
+ ... spec=Unbounded(n_act),)
>>> loss = ReinforceLoss(actor_net, value_net)
>>> batch = 2
>>> data = TensorDict({
@@ -146,7 +146,7 @@ class ReinforceLoss(LossModule):
Examples:
>>> import torch
>>> from torch import nn
- >>> from torchrl.data.tensor_specs import UnboundedContinuousTensorSpec
+ >>> from torchrl.data.tensor_specs import Unbounded
>>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal
>>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
>>> from torchrl.modules.tensordict_module.common import SafeModule
@@ -160,7 +160,7 @@ class ReinforceLoss(LossModule):
... distribution_class=TanhNormal,
... return_log_prob=True,
... in_keys=["loc", "scale"],
- ... spec=UnboundedContinuousTensorSpec(n_act),)
+ ... spec=Unbounded(n_act),)
>>> loss = ReinforceLoss(actor_net, value_net)
>>> batch = 2
>>> loss_actor, loss_value = loss(
diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py
index df444eac053..6350538db16 100644
--- a/torchrl/objectives/sac.py
+++ b/torchrl/objectives/sac.py
@@ -18,7 +18,7 @@
from tensordict.nn import dispatch, TensorDictModule
from tensordict.utils import NestedKey
from torch import Tensor
-from torchrl.data.tensor_specs import CompositeSpec, TensorSpec
+from torchrl.data.tensor_specs import Composite, TensorSpec
from torchrl.data.utils import _find_action_space
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.modules import ProbabilisticActor
@@ -46,6 +46,19 @@ def new_func(self, *args, **kwargs):
return new_func
+def compute_log_prob(action_dist, action_or_tensordict, tensor_key):
+ """Compute the log probability of an action given a distribution."""
+ if isinstance(action_or_tensordict, torch.Tensor):
+ log_p = action_dist.log_prob(action_or_tensordict)
+ else:
+ maybe_log_prob = action_dist.log_prob(action_or_tensordict)
+ if not isinstance(maybe_log_prob, torch.Tensor):
+ log_p = maybe_log_prob.get(tensor_key)
+ else:
+ log_p = maybe_log_prob
+ return log_p
+
+
class SACLoss(LossModule):
"""TorchRL implementation of the SAC loss.
@@ -117,14 +130,14 @@ class SACLoss(LossModule):
Examples:
>>> import torch
>>> from torch import nn
- >>> from torchrl.data import BoundedTensorSpec
+ >>> from torchrl.data import Bounded
>>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal
>>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
>>> from torchrl.modules.tensordict_module.common import SafeModule
>>> from torchrl.objectives.sac import SACLoss
>>> from tensordict import TensorDict
>>> n_act, n_obs = 4, 3
- >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,))
+ >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,))
>>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor())
>>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
>>> actor = ProbabilisticActor(
@@ -180,14 +193,14 @@ class SACLoss(LossModule):
Examples:
>>> import torch
>>> from torch import nn
- >>> from torchrl.data import BoundedTensorSpec
+ >>> from torchrl.data import Bounded
>>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal
>>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
>>> from torchrl.modules.tensordict_module.common import SafeModule
>>> from torchrl.objectives.sac import SACLoss
>>> _ = torch.manual_seed(42)
>>> n_act, n_obs = 4, 3
- >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,))
+ >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,))
>>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor())
>>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
>>> actor = ProbabilisticActor(
@@ -251,7 +264,7 @@ class _AcceptedKeys:
state_action_value (NestedKey): The input tensordict key where the
state action value is expected. Defaults to ``"state_action_value"``.
log_prob (NestedKey): The input tensordict key where the log probability is expected.
- Defaults to ``"_log_prob"``.
+ Defaults to ``"sample_log_prob"``.
priority (NestedKey): The input tensordict key where the target priority is written to.
Defaults to ``"td_error"``.
reward (NestedKey): The input tensordict key where the reward is expected.
@@ -267,7 +280,7 @@ class _AcceptedKeys:
action: NestedKey = "action"
value: NestedKey = "state_value"
state_action_value: NestedKey = "state_action_value"
- log_prob: NestedKey = "_log_prob"
+ log_prob: NestedKey = "sample_log_prob"
priority: NestedKey = "td_error"
reward: NestedKey = "reward"
done: NestedKey = "done"
@@ -440,8 +453,8 @@ def target_entropy(self):
"the target entropy explicitely or provide the spec of the "
"action tensor in the actor network."
)
- if not isinstance(action_spec, CompositeSpec):
- action_spec = CompositeSpec({self.tensor_keys.action: action_spec})
+ if not isinstance(action_spec, Composite):
+ action_spec = Composite({self.tensor_keys.action: action_spec})
if (
isinstance(self.tensor_keys.action, tuple)
and len(self.tensor_keys.action) > 1
@@ -450,9 +463,7 @@ def target_entropy(self):
else:
action_container_shape = action_spec.shape
target_entropy = -float(
- action_spec[self.tensor_keys.action]
- .shape[len(action_container_shape) :]
- .numel()
+ action_spec.shape[len(action_container_shape) :].numel()
)
delattr(self, "_target_entropy")
self.register_buffer(
@@ -622,7 +633,7 @@ def _actor_loss(
), self.actor_network_params.to_module(self.actor_network):
dist = self.actor_network.get_dist(tensordict)
a_reparm = dist.rsample()
- log_prob = dist.log_prob(a_reparm)
+ log_prob = compute_log_prob(dist, a_reparm, self.tensor_keys.log_prob)
td_q = tensordict.select(*self.qvalue_network.in_keys, strict=False)
td_q.set(self.tensor_keys.action, a_reparm)
@@ -713,7 +724,9 @@ def _compute_target_v2(self, tensordict) -> Tensor:
next_dist = self.actor_network.get_dist(next_tensordict)
next_action = next_dist.rsample()
next_tensordict.set(self.tensor_keys.action, next_action)
- next_sample_log_prob = next_dist.log_prob(next_action)
+ next_sample_log_prob = compute_log_prob(
+ next_dist, next_action, self.tensor_keys.log_prob
+ )
# get q-values
next_tensordict_expand = self._vmap_qnetworkN0(
@@ -780,7 +793,8 @@ def _value_loss(
td_copy.get(self.tensor_keys.state_action_value).squeeze(-1).min(0)[0]
)
- log_p = action_dist.log_prob(action)
+ log_p = compute_log_prob(action_dist, action, self.tensor_keys.log_prob)
+
if log_p.shape != min_qval.shape:
raise RuntimeError(
f"Losses shape mismatch: {min_qval.shape} and {log_p.shape}"
@@ -818,9 +832,9 @@ class DiscreteSACLoss(LossModule):
qvalue_network (TensorDictModule): a single Q-value network that will be duplicated as many times as needed.
action_space (str or TensorSpec): Action space. Must be one of
``"one-hot"``, ``"mult_one_hot"``, ``"binary"`` or ``"categorical"``,
- or an instance of the corresponding specs (:class:`torchrl.data.OneHotDiscreteTensorSpec`,
- :class:`torchrl.data.MultiOneHotDiscreteTensorSpec`,
- :class:`torchrl.data.BinaryDiscreteTensorSpec` or :class:`torchrl.data.DiscreteTensorSpec`).
+ or an instance of the corresponding specs (:class:`torchrl.data.OneHot`,
+ :class:`torchrl.data.MultiOneHot`,
+ :class:`torchrl.data.Binary` or :class:`torchrl.data.Categorical`).
num_actions (int, optional): number of actions in the action space.
To be provided if target_entropy is set to "auto".
num_qvalue_nets (int, optional): Number of Q-value networks to be trained. Default is 2.
@@ -852,7 +866,7 @@ class DiscreteSACLoss(LossModule):
Examples:
>>> import torch
>>> from torch import nn
- >>> from torchrl.data.tensor_specs import OneHotDiscreteTensorSpec
+ >>> from torchrl.data.tensor_specs import OneHot
>>> from torchrl.modules.distributions import NormalParamExtractor, OneHotCategorical
>>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
>>> from torchrl.modules.tensordict_module.common import SafeModule
@@ -860,7 +874,7 @@ class DiscreteSACLoss(LossModule):
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
>>> n_act, n_obs = 4, 3
- >>> spec = OneHotDiscreteTensorSpec(n_act)
+ >>> spec = OneHot(n_act)
>>> module = TensorDictModule(nn.Linear(n_obs, n_act), in_keys=["observation"], out_keys=["logits"])
>>> actor = ProbabilisticActor(
... module=module,
@@ -909,13 +923,13 @@ class DiscreteSACLoss(LossModule):
Examples:
>>> import torch
>>> from torch import nn
- >>> from torchrl.data.tensor_specs import OneHotDiscreteTensorSpec
+ >>> from torchrl.data.tensor_specs import OneHot
>>> from torchrl.modules.distributions import NormalParamExtractor, OneHotCategorical
>>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
>>> from torchrl.modules.tensordict_module.common import SafeModule
>>> from torchrl.objectives.sac import DiscreteSACLoss
>>> n_act, n_obs = 4, 3
- >>> spec = OneHotDiscreteTensorSpec(n_act)
+ >>> spec = OneHot(n_act)
>>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor())
>>> module = SafeModule(net, in_keys=["observation"], out_keys=["logits"])
>>> actor = ProbabilisticActor(
@@ -1088,7 +1102,7 @@ def __init__(
if action_space is None:
warnings.warn(
"action_space was not specified. DiscreteSACLoss will default to 'one-hot'."
- "This behaviour will be deprecated soon and a space will have to be passed. "
+ "This behavior will be deprecated soon and a space will have to be passed. "
"Check the DiscreteSACLoss documentation to see how to pass the action space. "
)
action_space = "one-hot"
diff --git a/torchrl/objectives/td3.py b/torchrl/objectives/td3.py
index eb1027ad936..89ff581991f 100644
--- a/torchrl/objectives/td3.py
+++ b/torchrl/objectives/td3.py
@@ -12,7 +12,7 @@
from tensordict import TensorDict, TensorDictBase, TensorDictParams
from tensordict.nn import dispatch, TensorDictModule
from tensordict.utils import NestedKey
-from torchrl.data.tensor_specs import BoundedTensorSpec, CompositeSpec, TensorSpec
+from torchrl.data.tensor_specs import Bounded, Composite, TensorSpec
from torchrl.envs.utils import step_mdp
from torchrl.objectives.common import LossModule
@@ -83,14 +83,14 @@ class TD3Loss(LossModule):
Examples:
>>> import torch
>>> from torch import nn
- >>> from torchrl.data import BoundedTensorSpec
+ >>> from torchrl.data import Bounded
>>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal
>>> from torchrl.modules.tensordict_module.actors import Actor, ProbabilisticActor, ValueOperator
>>> from torchrl.modules.tensordict_module.common import SafeModule
>>> from torchrl.objectives.td3 import TD3Loss
>>> from tensordict import TensorDict
>>> n_act, n_obs = 4, 3
- >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,))
+ >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,))
>>> module = nn.Linear(n_obs, n_act)
>>> actor = Actor(
... module=module,
@@ -139,11 +139,11 @@ class TD3Loss(LossModule):
Examples:
>>> import torch
>>> from torch import nn
- >>> from torchrl.data import BoundedTensorSpec
+ >>> from torchrl.data import Bounded
>>> from torchrl.modules.tensordict_module.actors import Actor, ValueOperator
>>> from torchrl.objectives.td3 import TD3Loss
>>> n_act, n_obs = 4, 3
- >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,))
+ >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,))
>>> module = nn.Linear(n_obs, n_act)
>>> actor = Actor(
... module=module,
@@ -283,7 +283,7 @@ def __init__(
f"but not both or none. Got bounds={bounds} and action_spec={action_spec}."
)
elif action_spec is not None:
- if isinstance(action_spec, CompositeSpec):
+ if isinstance(action_spec, Composite):
if (
isinstance(self.tensor_keys.action, tuple)
and len(self.tensor_keys.action) > 1
@@ -296,9 +296,9 @@ def __init__(
action_spec = action_spec[self.tensor_keys.action][
(0,) * len(action_container_shape)
]
- if not isinstance(action_spec, BoundedTensorSpec):
+ if not isinstance(action_spec, Bounded):
raise ValueError(
- f"action_spec is not of type BoundedTensorSpec but {type(action_spec)}."
+ f"action_spec is not of type Bounded but {type(action_spec)}."
)
low = action_spec.space.low
high = action_spec.space.high
@@ -372,7 +372,7 @@ def _cached_stack_actor_params(self):
[self.actor_network_params, self.target_actor_network_params], 0
)
- def actor_loss(self, tensordict):
+ def actor_loss(self, tensordict) -> Tuple[torch.Tensor, dict]:
tensordict_actor_grad = tensordict.select(
*self.actor_network.in_keys, strict=False
)
@@ -398,7 +398,7 @@ def actor_loss(self, tensordict):
loss_actor = _reduce(loss_actor, reduction=self.reduction)
return loss_actor, metadata
- def value_loss(self, tensordict):
+ def value_loss(self, tensordict) -> Tuple[torch.Tensor, dict]:
tensordict = tensordict.clone(False)
act = tensordict.get(self.tensor_keys.action)
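
`actor_loss` and `value_loss` now advertise a `(Tensor, dict)` return. A sketch of consuming them separately rather than through `forward` (a hypothetical training step: `loss_module`, `batch`, and the optimizers all stand for objects built elsewhere):

loss_value, value_metadata = loss_module.value_loss(batch)
loss_value.backward()
critic_optim.step(); critic_optim.zero_grad()

loss_actor, actor_metadata = loss_module.actor_loss(batch)
loss_actor.backward()
actor_optim.step(); actor_optim.zero_grad()
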
diff --git a/torchrl/objectives/td3_bc.py b/torchrl/objectives/td3_bc.py
index aa87ea9aa1a..8b394137480 100644
--- a/torchrl/objectives/td3_bc.py
+++ b/torchrl/objectives/td3_bc.py
@@ -12,7 +12,7 @@
from tensordict import TensorDict, TensorDictBase, TensorDictParams
from tensordict.nn import dispatch, TensorDictModule
from tensordict.utils import NestedKey
-from torchrl.data.tensor_specs import BoundedTensorSpec, CompositeSpec, TensorSpec
+from torchrl.data.tensor_specs import Bounded, Composite, TensorSpec
from torchrl.envs.utils import step_mdp
from torchrl.objectives.common import LossModule
@@ -94,14 +94,14 @@ class TD3BCLoss(LossModule):
Examples:
>>> import torch
>>> from torch import nn
- >>> from torchrl.data import BoundedTensorSpec
+ >>> from torchrl.data import Bounded
>>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal
>>> from torchrl.modules.tensordict_module.actors import Actor, ProbabilisticActor, ValueOperator
>>> from torchrl.modules.tensordict_module.common import SafeModule
>>> from torchrl.objectives.td3_bc import TD3BCLoss
>>> from tensordict import TensorDict
>>> n_act, n_obs = 4, 3
- >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,))
+ >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,))
>>> module = nn.Linear(n_obs, n_act)
>>> actor = Actor(
... module=module,
@@ -152,11 +152,11 @@ class TD3BCLoss(LossModule):
Examples:
>>> import torch
>>> from torch import nn
- >>> from torchrl.data import BoundedTensorSpec
+ >>> from torchrl.data import Bounded
>>> from torchrl.modules.tensordict_module.actors import Actor, ValueOperator
>>> from torchrl.objectives.td3_bc import TD3BCLoss
>>> n_act, n_obs = 4, 3
- >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,))
+ >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,))
>>> module = nn.Linear(n_obs, n_act)
>>> actor = Actor(
... module=module,
@@ -299,7 +299,7 @@ def __init__(
f"but not both or none. Got bounds={bounds} and action_spec={action_spec}."
)
elif action_spec is not None:
- if isinstance(action_spec, CompositeSpec):
+ if isinstance(action_spec, Composite):
if (
isinstance(self.tensor_keys.action, tuple)
and len(self.tensor_keys.action) > 1
@@ -312,9 +312,9 @@ def __init__(
action_spec = action_spec[self.tensor_keys.action][
(0,) * len(action_container_shape)
]
- if not isinstance(action_spec, BoundedTensorSpec):
+ if not isinstance(action_spec, Bounded):
raise ValueError(
- f"action_spec is not of type BoundedTensorSpec but {type(action_spec)}."
+ f"action_spec is not of type Bounded but {type(action_spec)}."
)
low = action_spec.space.low
high = action_spec.space.high
@@ -386,7 +386,7 @@ def _cached_stack_actor_params(self):
[self.actor_network_params, self.target_actor_network_params], 0
)
- def actor_loss(self, tensordict):
+ def actor_loss(self, tensordict) -> Tuple[torch.Tensor, dict]:
"""Compute the actor loss.
The actor loss should be computed after the :meth:`~.qvalue_loss` and is usually delayed by 1-3 critic updates.
@@ -433,7 +433,7 @@ def actor_loss(self, tensordict):
loss_actor = _reduce(loss_actor, reduction=self.reduction)
return loss_actor, metadata
- def qvalue_loss(self, tensordict):
+ def qvalue_loss(self, tensordict) -> Tuple[torch.Tensor, dict]:
"""Compute the q-value loss.
The q-value loss should be computed before the :meth:`~.actor_loss`.
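
The docstrings above pin down an ordering: the Q-value loss comes first, and the actor loss is delayed by a few critic updates. A minimal schedule sketch (all names — `loss_module`, `replay_buffer`, the optimizers, `updater` — stand for objects built elsewhere):

policy_delay = 2  # actor is updated every `policy_delay` critic updates

for step in range(total_steps):
    batch = replay_buffer.sample()
    loss_q, _ = loss_module.qvalue_loss(batch)
    loss_q.backward()
    critic_optim.step(); critic_optim.zero_grad()

    if step % policy_delay == 0:
        loss_a, _ = loss_module.actor_loss(batch)
        loss_a.backward()
        actor_optim.step(); actor_optim.zero_grad()
        updater.step()  # e.g. a SoftUpdate target refresh
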
diff --git a/torchrl/objectives/utils.py b/torchrl/objectives/utils.py
index b1077198784..31954005195 100644
--- a/torchrl/objectives/utils.py
+++ b/torchrl/objectives/utils.py
@@ -26,6 +26,11 @@
raise err_ft from err
from torchrl.envs.utils import step_mdp
+try:
+ from torch.compiler import is_dynamo_compiling
+except ImportError:
+ from torch._dynamo import is_compiling as is_dynamo_compiling
+
_GAMMA_LMBDA_DEPREC_ERROR = (
"Passing gamma / lambda parameters through the loss constructor "
"is a deprecated feature. To customize your value function, "
@@ -198,23 +203,37 @@ def __init__(
@property
def _targets(self):
- return TensorDict(
- {name: getattr(self.loss_module, name) for name in self._target_names},
- [],
- )
+ targets = self.__dict__.get("_targets_val", None)
+ if targets is None:
+ targets = self.__dict__["_targets_val"] = TensorDict(
+ {name: getattr(self.loss_module, name) for name in self._target_names},
+ [],
+ )
+ return targets
+
+ @_targets.setter
+ def _targets(self, targets):
+ self.__dict__["_targets_val"] = targets
@property
def _sources(self):
- return TensorDict(
- {name: getattr(self.loss_module, name) for name in self._source_names},
- [],
- )
+ sources = self.__dict__.get("_sources_val", None)
+ if sources is None:
+ sources = self.__dict__["_sources_val"] = TensorDict(
+ {name: getattr(self.loss_module, name) for name in self._source_names},
+ [],
+ )
+ return sources
+
+ @_sources.setter
+ def _sources(self, sources):
+ self.__dict__["_sources_val"] = sources
def init_(self) -> None:
if self.initialized:
warnings.warn("Updated already initialized.")
found_distinct = False
- self._distinct = {}
+ self._distinct_and_params = {}
for key, source in self._sources.items(True, True):
if not isinstance(key, tuple):
key = (key,)
@@ -223,8 +242,12 @@ def init_(self) -> None:
# for p_source, p_target in zip(source, target):
if target.requires_grad:
raise RuntimeError("the target parameter is part of a graph.")
- self._distinct[key] = target.data_ptr() != source.data.data_ptr()
- found_distinct = found_distinct or self._distinct[key]
+ self._distinct_and_params[key] = (
+ target.is_leaf
+ and source.requires_grad
+ and target.data_ptr() != source.data.data_ptr()
+ )
+ found_distinct = found_distinct or self._distinct_and_params[key]
target.data.copy_(source.data)
if not found_distinct:
raise RuntimeError(
@@ -235,6 +258,23 @@ def init_(self) -> None:
f"If no target parameter is needed, do not use a target updater such as {type(self)}."
)
+ # filter the target_ out
+ def filter_target(key):
+ if isinstance(key, tuple):
+ return (filter_target(key[0]), *key[1:])
+ return key[7:]  # strip the leading "target_" prefix (7 characters)
+
+ self._sources = self._sources.select(
+ *[
+ filter_target(key)
+ for (key, val) in self._distinct_and_params.items()
+ if val
+ ]
+ ).lock_()
+ self._targets = self._targets.select(
+ *(key for (key, val) in self._distinct_and_params.items() if val)
+ ).lock_()
+
self.initialized = True
def step(self) -> None:
@@ -243,19 +283,11 @@ def step(self) -> None:
f"{self.__class__.__name__} must be "
f"initialized (`{self.__class__.__name__}.init_()`) before calling step()"
)
- for key, source in self._sources.items(True, True):
- if not isinstance(key, tuple):
- key = (key,)
- key = ("target_" + key[0], *key[1:])
- if not self._distinct[key]:
- continue
- target = self._targets[key]
+ for key, param in self._sources.items():
+ target = self._targets.get("target_{}".format(key))
if target.requires_grad:
raise RuntimeError("the target parameter is part of a graph.")
- if target.is_leaf:
- self._step(source, target)
- else:
- target.copy_(source)
+ self._step(param, target)
def _step(self, p_source: Tensor, p_target: Tensor) -> None:
raise NotImplementedError
@@ -301,7 +333,7 @@ def __init__(
):
if eps is None and tau is None:
raise RuntimeError(
- "Neither eps nor tau was provided. This behaviour is deprecated.",
+ "Neither eps nor tau was provided. This behavior is deprecated.",
)
eps = 0.999
if (eps is None) ^ (tau is None):
@@ -321,8 +353,10 @@ def __init__(
super(SoftUpdate, self).__init__(loss_module)
self.eps = eps
- def _step(self, p_source: Tensor, p_target: Tensor) -> None:
- p_target.data.copy_(p_target.data * self.eps + p_source.data * (1 - self.eps))
+ def _step(
+ self, p_source: Tensor | TensorDictBase, p_target: Tensor | TensorDictBase
+ ) -> None:
+ p_target.data.lerp_(p_source.data, 1 - self.eps)
class HardUpdate(TargetNetUpdater):
@@ -454,11 +488,17 @@ def next_state_value(
return target_value
-def _cache_values(fun):
+def _cache_values(func):
"""Caches the tensordict returned by a property."""
- name = fun.__name__
+ name = func.__name__
- def new_fun(self, netname=None):
+ @functools.wraps(func)
+ def new_func(self, netname=None):
+ if is_dynamo_compiling():
+ if netname is not None:
+ return func(self, netname)
+ else:
+ return func(self)
__dict__ = self.__dict__
_cache = __dict__.setdefault("_cache", {})
attr_name = name
@@ -468,16 +508,16 @@ def new_fun(self, netname=None):
out = _cache[attr_name]
return out
if netname is not None:
- out = fun(self, netname)
+ out = func(self, netname)
else:
- out = fun(self)
+ out = func(self)
# TODO: decide what to do with locked tds in functional calls
# if is_tensor_collection(out):
# out.lock_()
_cache[attr_name] = out
return out
- return new_fun
+ return new_func
def _vmap_func(module, *args, func=None, **kwargs):
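
The `lerp_` rewrite of `SoftUpdate._step` computes the same polyak average as the removed expression: `target.lerp_(source, 1 - eps)` is `target + (1 - eps) * (source - target) = eps * target + (1 - eps) * source`. A quick numerical check:

import torch

eps = 0.999
p_target = torch.randn(4)
p_source = torch.randn(4)

expected = p_target * eps + p_source * (1 - eps)
p_target.lerp_(p_source, 1 - eps)
assert torch.allclose(p_target, expected)
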
diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py
index b7db2e8242e..e396b7e1fcc 100644
--- a/torchrl/objectives/value/advantages.py
+++ b/torchrl/objectives/value/advantages.py
@@ -16,13 +16,12 @@
from tensordict import TensorDictBase
from tensordict.nn import (
dispatch,
- is_functional,
set_skip_existing,
TensorDictModule,
TensorDictModuleBase,
)
from tensordict.utils import NestedKey
-from torch import nn, Tensor
+from torch import Tensor
from torchrl._utils import RL_WARNINGS
from torchrl.envs.utils import step_mdp
@@ -412,18 +411,13 @@ def value_estimate(
@property
def is_functional(self):
- if isinstance(self.value_network, nn.Module):
- return is_functional(self.value_network)
- elif self.value_network is None:
- return None
- else:
- raise RuntimeError("Cannot determine if value network is functional.")
+ # legacy
+ return False
@property
def is_stateless(self):
- if not self.is_functional:
- return False
- return self.value_network._is_stateless
+ # legacy
+ return False
def _next_value(self, tensordict, target_params, kwargs):
step_td = step_mdp(tensordict, keep_other=False)
@@ -1183,17 +1177,17 @@ class GAE(ValueEstimatorBase):
device (torch.device, optional): device of the module.
time_dim (int, optional): the dimension corresponding to the time
in the input tensordict. If not provided, defaults to the dimension
- markes with the ``"time"`` name if any, and to the last dimension
+ marked with the ``"time"`` name if any, and to the last dimension
otherwise. Can be overridden during a call to
:meth:`~.value_estimate`.
Negative dimensions are considered with respect to the input
tensordict.
- GAE will return an :obj:`"advantage"` entry containing the advange value. It will also
+ GAE will return an :obj:`"advantage"` entry containing the advantage value. It will also
return a :obj:`"value_target"` entry with the return value that is to be used
to train the value network. Finally, if :obj:`gradient_mode` is ``True``,
an additional and differentiable :obj:`"value_error"` entry will be returned,
- which simple represents the difference between the return and the value network
+ which simply represents the difference between the return and the value network
output (i.e. an additional distance loss should be applied to that signed value).
.. note::
@@ -1268,7 +1262,7 @@ def forward(
target params to be passed to the functional value network module.
time_dim (int, optional): the dimension corresponding to the time
in the input tensordict. If not provided, defaults to the dimension
- markes with the ``"time"`` name if any, and to the last dimension
+ marked with the ``"time"`` name if any, and to the last dimension
otherwise.
Negative dimensions are considered with respect to the input
tensordict.
@@ -1316,7 +1310,7 @@ def forward(
"""
if tensordict.batch_dims < 1:
raise RuntimeError(
- "Expected input tensordict to have at least one dimensions, got "
+ "Expected input tensordict to have at least one dimension, got "
f"tensordict.batch_size = {tensordict.batch_size}"
)
reward = tensordict.get(("next", self.tensor_keys.reward))
diff --git a/torchrl/objectives/value/functional.py b/torchrl/objectives/value/functional.py
index d3ad8d93ca4..ddd688610c2 100644
--- a/torchrl/objectives/value/functional.py
+++ b/torchrl/objectives/value/functional.py
@@ -230,7 +230,7 @@ def _fast_vec_gae(
``[*Batch x TimeSteps x F]``, with ``F`` feature dimensions.
"""
- # _gen_num_per_traj and _split_and_pad_sequence need
+ # _get_num_per_traj and _split_and_pad_sequence need
# time dimension at last position
done = done.transpose(-2, -1)
terminated = terminated.transpose(-2, -1)
diff --git a/torchrl/record/recorder.py b/torchrl/record/recorder.py
index b7fb8ab4ed2..e533f9e9df9 100644
--- a/torchrl/record/recorder.py
+++ b/torchrl/record/recorder.py
@@ -18,7 +18,7 @@
from torchrl._utils import _can_be_pickled
from torchrl.data import TensorSpec
-from torchrl.data.tensor_specs import NonTensorSpec, UnboundedContinuousTensorSpec
+from torchrl.data.tensor_specs import NonTensor, Unbounded
from torchrl.data.utils import CloudpickleWrapper
from torchrl.envs import EnvBase
from torchrl.envs.transforms import ObservationTransform, Transform
@@ -409,9 +409,9 @@ class PixelRenderTransform(Transform):
>>> env.transform[-1].dump()
The transform can be disabled using the :meth:`~torchrl.record.PixelRenderTransform.switch` method, which will
- turn the rendering on if it's off or off if it's on (an argument can also be passed to control this behaviour).
+ turn the rendering on if it's off or off if it's on (an argument can also be passed to control this behavior).
Since transforms are :class:`~torch.nn.Module` instances, :meth:`~torch.nn.Module.apply` can be used to control
- this behaviour:
+ this behavior:
>>> def switch(module):
... if isinstance(module, PixelRenderTransform):
@@ -506,11 +506,9 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec
self._call(td_in)
obs = td_in.get(self.out_keys[0])
if isinstance(obs, NonTensorData):
- spec = NonTensorSpec(device=obs.device, dtype=obs.dtype, shape=obs.shape)
+ spec = NonTensor(device=obs.device, dtype=obs.dtype, shape=obs.shape)
else:
- spec = UnboundedContinuousTensorSpec(
- device=obs.device, dtype=obs.dtype, shape=obs.shape
- )
+ spec = Unbounded(device=obs.device, dtype=obs.dtype, shape=obs.shape)
observation_spec[self.out_keys[0]] = spec
if switch:
self.switch()
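
The renamed specs used above can be built directly from a rendered frame; a small sketch (the frame below is a stand-in for whatever `render()` returns):

import torch
from torchrl.data.tensor_specs import Unbounded

frame = torch.zeros(200, 200, 3, dtype=torch.uint8)  # e.g. an RGB render
spec = Unbounded(shape=frame.shape, device=frame.device, dtype=frame.dtype)
assert spec.is_in(frame)
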
diff --git a/torchrl/trainers/helpers/collectors.py b/torchrl/trainers/helpers/collectors.py
index b192d115a54..efdde1a1c63 100644
--- a/torchrl/trainers/helpers/collectors.py
+++ b/torchrl/trainers/helpers/collectors.py
@@ -19,7 +19,6 @@
from torchrl.data.postprocs import MultiStep
from torchrl.envs.batched_envs import ParallelEnv
from torchrl.envs.common import EnvBase
-from torchrl.envs.utils import ExplorationType
def sync_async_collector(
@@ -304,7 +303,7 @@ def make_collector_offpolicy(
"init_random_frames": cfg.init_random_frames,
"split_trajs": True,
# trajectories must be separated if multi-step is used
- "exploration_type": ExplorationType.from_str(cfg.exploration_mode),
+ "exploration_type": cfg.exploration_type,
}
collector = collector_helper(**collector_helper_kwargs)
@@ -358,7 +357,7 @@ def make_collector_onpolicy(
"storing_device": cfg.collector_device,
"split_trajs": True,
# trajectories must be separated in online settings
- "exploration_mode": cfg.exploration_mode,
+ "exploration_type": cfg.exploration_type,
}
collector = collector_helper(**collector_helper_kwargs)
@@ -398,7 +397,7 @@ class OnPolicyCollectorConfig:
# for each of these parallel wrappers. If env_per_collector=num_workers, no parallel wrapper is created
seed: int = 42
# seed used for the environment, pytorch and numpy.
- exploration_mode: str = "random"
+ exploration_type: str = "random"
# exploration type of the data collector.
async_collection: bool = False
# whether data collection should be done asynchronously. Asynchronous data collection means
diff --git a/torchrl/trainers/helpers/models.py b/torchrl/trainers/helpers/models.py
index 0c9ec92cff4..a3776f78e5a 100644
--- a/torchrl/trainers/helpers/models.py
+++ b/torchrl/trainers/helpers/models.py
@@ -9,11 +9,7 @@
from tensordict import set_lazy_legacy
from tensordict.nn import InteractionType
from torch import nn
-from torchrl.data.tensor_specs import (
- CompositeSpec,
- DiscreteTensorSpec,
- UnboundedContinuousTensorSpec,
-)
+from torchrl.data.tensor_specs import Categorical, Composite, Unbounded
from torchrl.data.utils import DEVICE_TYPING
from torchrl.envs.common import EnvBase
from torchrl.envs.model_based.dreamer import DreamerEnv
@@ -153,9 +149,9 @@ def make_dqn_actor(
actor_class = QValueActor
actor_kwargs = {}
- if isinstance(action_spec, DiscreteTensorSpec):
+ if isinstance(action_spec, Categorical):
# if action spec is modeled as categorical variable, we still need to have features equal
- # to the number of possible choices and also set categorical behavioural for actors.
+# to the number of possible choices and also set the actors' behavior to categorical.
actor_kwargs.update({"action_space": "categorical"})
out_features = env_specs["input_spec", "full_action_spec", "action"].space.n
else:
@@ -182,7 +178,7 @@ def make_dqn_actor(
model = actor_class(
module=net,
- spec=CompositeSpec(action=action_spec),
+ spec=Composite(action=action_spec),
in_keys=[in_key],
safe=True,
**actor_kwargs,
@@ -385,13 +381,13 @@ def _dreamer_make_actor_sim(action_key, proof_environment, actor_module):
actor_module,
in_keys=["state", "belief"],
out_keys=["loc", "scale"],
- spec=CompositeSpec(
+ spec=Composite(
**{
- "loc": UnboundedContinuousTensorSpec(
+ "loc": Unbounded(
proof_environment.action_spec.shape,
device=proof_environment.action_spec.device,
),
- "scale": UnboundedContinuousTensorSpec(
+ "scale": Unbounded(
proof_environment.action_spec.shape,
device=proof_environment.action_spec.device,
),
@@ -404,7 +400,7 @@ def _dreamer_make_actor_sim(action_key, proof_environment, actor_module):
default_interaction_type=InteractionType.RANDOM,
distribution_class=TanhNormal,
distribution_kwargs={"tanh_loc": True},
- spec=CompositeSpec(**{action_key: proof_environment.action_spec}),
+ spec=Composite(**{action_key: proof_environment.action_spec}),
),
)
return actor_simulator
@@ -436,12 +432,12 @@ def _dreamer_make_actor_real(
actor_module,
in_keys=["state", "belief"],
out_keys=["loc", "scale"],
- spec=CompositeSpec(
+ spec=Composite(
**{
- "loc": UnboundedContinuousTensorSpec(
+ "loc": Unbounded(
proof_environment.action_spec.shape,
),
- "scale": UnboundedContinuousTensorSpec(
+ "scale": Unbounded(
proof_environment.action_spec.shape,
),
}
@@ -453,9 +449,7 @@ def _dreamer_make_actor_real(
default_interaction_type=InteractionType.DETERMINISTIC,
distribution_class=TanhNormal,
distribution_kwargs={"tanh_loc": True},
- spec=CompositeSpec(
- **{action_key: proof_environment.action_spec.to("cpu")}
- ),
+ spec=Composite(**{action_key: proof_environment.action_spec.to("cpu")}),
),
),
SafeModule(
@@ -536,8 +530,8 @@ def _dreamer_make_mbenv(
model_based_env.set_specs_from_env(proof_environment)
model_based_env = TransformedEnv(model_based_env)
default_dict = {
- "state": UnboundedContinuousTensorSpec(state_dim),
- "belief": UnboundedContinuousTensorSpec(rssm_hidden_dim),
+ "state": Unbounded(state_dim),
+ "belief": Unbounded(rssm_hidden_dim),
# "action": proof_environment.action_spec,
}
model_based_env.append_transform(
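
The `default_dict` above feeds a primer transform so that `state`/`belief` entries exist before the model-based rollout reads them. A hedged sketch of that pattern outside Dreamer (the Pendulum env is only an example and assumes gym is installed):

from torchrl.data.tensor_specs import Unbounded
from torchrl.envs import GymEnv, TransformedEnv
from torchrl.envs.transforms import TensorDictPrimer

env = TransformedEnv(GymEnv("Pendulum-v1"))
env.append_transform(
    TensorDictPrimer(state=Unbounded(30), belief=Unbounded(200))
)
td = env.reset()  # "state" and "belief" now exist, zero-initialized
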
diff --git a/torchrl/trainers/trainers.py b/torchrl/trainers/trainers.py
index 247d039eb1e..62ea4a4a109 100644
--- a/torchrl/trainers/trainers.py
+++ b/torchrl/trainers/trainers.py
@@ -1126,7 +1126,7 @@ class Recorder(TrainerHookBase):
"""Recorder hook for :class:`~torchrl.trainers.Trainer`.
Args:
- record_interval (int): total number of optimisation steps
+ record_interval (int): total number of optimization steps
between two calls to the recorder for testing.
record_frames (int): number of frames to be recorded during
testing.
@@ -1145,7 +1145,7 @@ class Recorder(TrainerHookBase):
Given that this instance is supposed to both explore and render
the performance of the policy, it should be possible to turn off
- the explorative behaviour by calling the
+ the explorative behavior by calling the
`set_exploration_type(ExplorationType.DETERMINISTIC)` context manager.
environment (EnvBase): An environment instance to be used
for testing.
diff --git a/tutorials/README.md b/tutorials/README.md
index d774f6b7566..562c5d427a9 100644
--- a/tutorials/README.md
+++ b/tutorials/README.md
@@ -1,21 +1,7 @@
# Tutorials
-Get a sense of TorchRL functionalities through our tutorials.
+Get a sense of TorchRL functionalities through our [tutorials](https://pytorch.org/rl/stable/tutorials).
-For an overview of TorchRL, try the [TorchRL demo](https://pytorch.org/rl/tutorials/torchrl_demo.html).
+The ["Getting Started"](https://pytorch.org/rl/stable/index.html#getting-started) section will help you model your first training loop with the library!
-Make sure you test the [TensorDict tutorial](https://pytorch.org/rl/tutorials/tensordict_tutorial.html) to see what TensorDict
-is about and what it can do.
-
-To understand how to use `TensorDict` with pytorch modules, make sure to check out the [TensorDictModule tutorial](https://pytorch.org/rl/tutorials/tensordict_module.html).
-
-Check out the [environment tutorial](https://pytorch.org/rl/tutorials/torch_envs.html) for a deep dive in the envs
-functionalities.
-
-Read through our short tutorial on [multi-tasking](https://pytorch.org/rl/tutorials/multi_task.html) to see how you can execute diverse
-tasks in batch mode and build task-specific policies.
-This tutorial is also a good example of the advanced features of TensorDict stacking and
-indexing.
-
-Finally, the [DDPG tutorial](https://pytorch.org/rl/tutorials/coding_ddpg.html) and [DQN tutorial](https://pytorch.org/rl/tutorials/coding_dqn.html) will guide you through the steps to code
-your first RL algorithms with TorchRL.
+The rest of the tutorials are split into [Basic](https://pytorch.org/rl/stable/index.html#basics), [Intermediate](https://pytorch.org/rl/stable/index.html#intermediate) and [Advanced](https://pytorch.org/rl/stable/index.html#advanced) sections.
diff --git a/tutorials/sphinx-tutorials/coding_ddpg.py b/tutorials/sphinx-tutorials/coding_ddpg.py
index 1bf7fd57e83..13721b715e3 100644
--- a/tutorials/sphinx-tutorials/coding_ddpg.py
+++ b/tutorials/sphinx-tutorials/coding_ddpg.py
@@ -683,7 +683,7 @@ def get_env_stats():
)
-from torchrl.data import CompositeSpec
+from torchrl.data import Composite
###############################################################################
# Building the model
@@ -756,7 +756,7 @@ def make_ddpg_actor(
actor,
distribution_class=TanhDelta,
in_keys=["param"],
- spec=CompositeSpec(action=proof_environment.action_spec),
+ spec=Composite(action=proof_environment.action_spec),
).to(device)
q_net = DdpgMlpQNet()
@@ -899,7 +899,7 @@ def make_recorder(actor_model_explore, transform_state_dict, record_interval):
record_frames=1000,
policy_exploration=actor_model_explore,
environment=environment,
- exploration_type=ExplorationType.MEAN,
+ exploration_type=ExplorationType.DETERMINISTIC,
record_interval=record_interval,
)
return recorder_obj
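
`ExplorationType.MEAN`/`MODE` are folded into the single `DETERMINISTIC` flag throughout the tutorials. For the recorder above, the full construction then reads as follows (a sketch; `actor_model_explore`, `environment` and `record_interval` come from the surrounding tutorial):

from torchrl.envs.utils import ExplorationType
from torchrl.trainers import Recorder

recorder_obj = Recorder(
    record_frames=1000,
    policy_exploration=actor_model_explore,
    environment=environment,
    exploration_type=ExplorationType.DETERMINISTIC,  # greedy/mean actions at eval time
    record_interval=record_interval,
)
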
diff --git a/tutorials/sphinx-tutorials/coding_dqn.py b/tutorials/sphinx-tutorials/coding_dqn.py
index e9f2085d3df..59188ad21f6 100644
--- a/tutorials/sphinx-tutorials/coding_dqn.py
+++ b/tutorials/sphinx-tutorials/coding_dqn.py
@@ -449,7 +449,7 @@ def get_collector(
policy=actor_explore,
frames_per_batch=frames_per_batch,
total_frames=total_frames,
- # this is the default behaviour: the collector runs in ``"random"`` (or explorative) mode
+ # this is the default behavior: the collector runs in ``"random"`` (or explorative) mode
exploration_type=ExplorationType.RANDOM,
# We set all the devices to be identical. Below is an example of
# heterogeneous devices
@@ -672,7 +672,7 @@ def get_loss_module(actor, gamma):
frame_skip=1,
policy_exploration=actor_explore,
environment=test_env,
- exploration_type=ExplorationType.MODE,
+ exploration_type=ExplorationType.DETERMINISTIC,
log_keys=[("next", "reward")],
out_keys={("next", "reward"): "rewards"},
log_pbar=True,
diff --git a/tutorials/sphinx-tutorials/coding_ppo.py b/tutorials/sphinx-tutorials/coding_ppo.py
index 51229e1880d..25e72dc40f4 100644
--- a/tutorials/sphinx-tutorials/coding_ppo.py
+++ b/tutorials/sphinx-tutorials/coding_ppo.py
@@ -195,7 +195,7 @@
# ~~~~~~~~~~~~~~
#
# At each data collection (or batch collection) we will run the optimization
-# over a certain number of *epochs*, each time consuming the entire data we just
+# over a certain number of *epochs*, consuming the entirety of the data we just
# acquired in a nested training loop. Here, the ``sub_batch_size`` is different from the
# ``frames_per_batch`` above: recall that we are working with a "batch of data"
# coming from our collector, whose size is defined by ``frames_per_batch``, and that
@@ -203,7 +203,7 @@
# The size of these sub-batches is controlled by ``sub_batch_size``.
#
sub_batch_size = 64 # cardinality of the sub-samples gathered from the current data in the inner loop
-num_epochs = 10 # optimisation steps per batch of data collected
+num_epochs = 10 # optimization steps per batch of data collected
clip_epsilon = (
0.2 # clip value for PPO loss: see the equation in the intro for more context.
)
@@ -651,7 +651,7 @@
# number of steps (1000, which is our ``env`` horizon).
# The ``rollout`` method of the ``env`` can take a policy as argument:
# it will then execute this policy at each step.
- with set_exploration_type(ExplorationType.MEAN), torch.no_grad():
+ with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
# execute a rollout with the trained policy
eval_rollout = env.rollout(1000, policy_module)
logs["eval reward"].append(eval_rollout["next", "reward"].mean().item())
diff --git a/tutorials/sphinx-tutorials/dqn_with_rnn.py b/tutorials/sphinx-tutorials/dqn_with_rnn.py
index 28a9638c6f6..8931f483384 100644
--- a/tutorials/sphinx-tutorials/dqn_with_rnn.py
+++ b/tutorials/sphinx-tutorials/dqn_with_rnn.py
@@ -440,7 +440,7 @@
exploration_module.step(data.numel())
updater.step()
- with set_exploration_type(ExplorationType.MODE), torch.no_grad():
+ with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
rollout = env.rollout(10000, stoch_policy)
traj_lens.append(rollout.get(("next", "step_count")).max().item())
diff --git a/tutorials/sphinx-tutorials/getting-started-1.py b/tutorials/sphinx-tutorials/getting-started-1.py
index 437cae26c42..4e8a1b30930 100644
--- a/tutorials/sphinx-tutorials/getting-started-1.py
+++ b/tutorials/sphinx-tutorials/getting-started-1.py
@@ -172,7 +172,7 @@
from torchrl.envs.utils import ExplorationType, set_exploration_type
-with set_exploration_type(ExplorationType.MEAN):
+with set_exploration_type(ExplorationType.DETERMINISTIC):
# takes the mean as action
rollout = env.rollout(max_steps=10, policy=policy)
with set_exploration_type(ExplorationType.RANDOM):
@@ -221,7 +221,7 @@
exploration_policy = TensorDictSequential(policy, exploration_module)
-with set_exploration_type(ExplorationType.MEAN):
+with set_exploration_type(ExplorationType.DETERMINISTIC):
# Turns off exploration
rollout = env.rollout(max_steps=10, policy=exploration_policy)
with set_exploration_type(ExplorationType.RANDOM):
diff --git a/tutorials/sphinx-tutorials/getting-started-5.py b/tutorials/sphinx-tutorials/getting-started-5.py
index 5f95fe1e534..d355d1888c5 100644
--- a/tutorials/sphinx-tutorials/getting-started-5.py
+++ b/tutorials/sphinx-tutorials/getting-started-5.py
@@ -89,7 +89,7 @@
optim_steps = 10
collector = SyncDataCollector(
env,
- policy,
+ policy_explore,
frames_per_batch=frames_per_batch,
total_frames=-1,
init_random_frames=init_rand_steps,
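
The one-line fix above matters: the collector must act with the exploration-wrapped policy, otherwise the epsilon-greedy module never fires. The intended wiring (a sketch; `policy` and `exploration_module` are built earlier in the tutorial):

from tensordict.nn import TensorDictSequential
from torchrl.collectors import SyncDataCollector

policy_explore = TensorDictSequential(policy, exploration_module)
collector = SyncDataCollector(
    env,
    policy_explore,  # not the bare greedy policy
    frames_per_batch=frames_per_batch,
    total_frames=-1,
    init_random_frames=init_rand_steps,
)
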
diff --git a/tutorials/sphinx-tutorials/multiagent_competitive_ddpg.py b/tutorials/sphinx-tutorials/multiagent_competitive_ddpg.py
index 77574b765e7..08b6d83bf5c 100644
--- a/tutorials/sphinx-tutorials/multiagent_competitive_ddpg.py
+++ b/tutorials/sphinx-tutorials/multiagent_competitive_ddpg.py
@@ -89,7 +89,7 @@
# wrapper for either PettingZoo or VMAS.
#
# 3. Following that, we will formulate the policy and critic networks, discussing the effects of various choices on
-# parameter sharing and critic centralisation.
+# parameter sharing and critic centralization.
#
# 4. Afterwards, we will create the sampling collector and the replay buffer.
#
@@ -179,7 +179,7 @@
memory_size = 1_000_000 # The replay buffer of each group can store this many frames
# Training
-n_optimiser_steps = 100 # Number of optimisation steps per training iteration
+n_optimiser_steps = 100 # Number of optimization steps per training iteration
train_batch_size = 128 # Number of frames trained in each optimizer step
lr = 3e-4 # Learning rate
max_grad_norm = 1.0 # Maximum norm for the gradients
@@ -193,7 +193,7 @@
# -----------
#
# Multi-agent environments simulate multiple agents interacting with the world.
-# TorchRL API allows integrating various types of multi-agent environment flavours.
+# TorchRL API allows integrating various types of multi-agent environment flavors.
# In this tutorial we will focus on environments where multiple agent groups interact in parallel.
# That is: at every step all agents will get an observation and take an action synchronously.
#
@@ -310,7 +310,7 @@
# Looking at the ``done_spec``, we can see that there are some keys that are outside of agent groups
# (``"done", "terminated", "truncated"``), which do not have a leading multi-agent dimension.
# These keys are shared by all agents and represent the environment global done state used for resetting.
-# By default, like in this case, parallel PettingZoo environments are done when any agent is done, but this behaviour
+# By default, like in this case, parallel PettingZoo environments are done when any agent is done, but this behavior
# can be overridden by setting ``done_on_any`` at PettingZoo environment construction.
#
# To quickly access the keys for each of these values in tensordicts, we can simply ask the environment for the
@@ -415,7 +415,7 @@
# Another important decision we need to make is whether we want the agents within a team to **share the policy parameters**.
# On the one hand, sharing parameters means that they will all share the same policy, which will allow them to benefit from
# each other's experiences. This will also result in faster training.
-# On the other hand, it will make them behaviourally *homogenous*, as they will in fact share the same model.
+# On the other hand, it will make them behaviorally *homogeneous*, as they will in fact share the same model.
# For this example, we will enable sharing as we do not mind the homogeneity and can benefit from the computational
# speed, but it is important to always think about this decision in your own problems!
#
@@ -424,7 +424,7 @@
# **First**: define a neural network ``n_obs_per_agent`` -> ``n_actions_per_agents``
#
# For this we use the ``MultiAgentMLP``, a TorchRL module made exactly for
-# multiple agents, with much customisation available.
+# multiple agents, with much customization available.
#
# We will define a different policy for each group and store them in a dictionary.
#
@@ -817,7 +817,7 @@ def process_batch(batch: TensorDictBase) -> TensorDictBase:
target_updaters[group].step()
# Exploration sigma anneal update
- exploration_policies[group].step(current_frames)
+ exploration_policies[group][-1].step(current_frames)
# Stop training a certain group when a condition is met (e.g., number of training iterations)
if iteration == iteration_when_stop_training_evaders:
@@ -903,7 +903,7 @@ def process_batch(batch: TensorDictBase) -> TensorDictBase:
env_with_render = env_with_render.append_transform(
VideoRecorder(logger=video_logger, tag="vmas_rendered")
)
- with set_exploration_type(ExplorationType.MODE):
+ with set_exploration_type(ExplorationType.DETERMINISTIC):
print("Rendering rollout...")
env_with_render.rollout(100, policy=agents_exploration_policy)
print("Saving the video...")
diff --git a/tutorials/sphinx-tutorials/multiagent_ppo.py b/tutorials/sphinx-tutorials/multiagent_ppo.py
index d7d906a4fb0..ec24de6cddd 100644
--- a/tutorials/sphinx-tutorials/multiagent_ppo.py
+++ b/tutorials/sphinx-tutorials/multiagent_ppo.py
@@ -99,7 +99,7 @@
# wrapper for the VMAS simulator.
#
# 3. Next, we will design the policy and the critic networks, discussing the impact of the various choices on
-# parameter sharing and critic centralisation.
+# parameter sharing and critic centralization.
#
# 4. Next, we will create the sampling collector and the replay buffer.
#
@@ -184,7 +184,7 @@
# -----------
#
# Multi-agent environments simulate multiple agents interacting with the world.
-# TorchRL API allows integrating various types of multi-agent environment flavours.
+# TorchRL API allows integrating various types of multi-agent environment flavors.
# Some examples include environments with shared or individual agent rewards, done flags, and observations.
# For more information on how the multi-agent environments API works in TorchRL, you can check out the dedicated
# :ref:`doc section `.
@@ -195,7 +195,7 @@
# This means that all its state and physics
# are PyTorch tensors with a first dimension representing the number of parallel environments in a batch.
# This allows leveraging the Single Instruction Multiple Data (SIMD) paradigm of GPUs and significantly
-# speed up parallel computation by leveraging parallelisation in GPU warps. It also means
+# speeds up parallel computation by leveraging parallelization in GPU warps. It also means
# that, when using it in TorchRL, both simulation and training can be run on-device, without ever passing
# data to the CPU.
#
@@ -207,7 +207,7 @@
# avoid colliding into each other.
# Agents act in a 2D continuous world with drag and elastic collisions.
# Their actions are 2D continuous forces which determine their acceleration.
-# The reward is composed of three terms: a collision penalisation, a reward based on the distance to the goal, and a
+# The reward is composed of three terms: a collision penalization, a reward based on the distance to the goal, and a
# final shared reward given when all agents reach their goal.
# The distance-based term is computed as the difference in the relative distance
# between an agent and its goal over two consecutive timesteps.
@@ -391,7 +391,7 @@
# **First**: define a neural network ``n_obs_per_agent`` -> ``2 * n_actions_per_agents``
#
# For this we use the ``MultiAgentMLP``, a TorchRL module made exactly for
-# multiple agents, with much customisation available.
+# multiple agents, with much customization available.
#
share_parameters_policy = True
diff --git a/tutorials/sphinx-tutorials/pendulum.py b/tutorials/sphinx-tutorials/pendulum.py
index d25bc2cdd8a..1593d42a0ec 100644
--- a/tutorials/sphinx-tutorials/pendulum.py
+++ b/tutorials/sphinx-tutorials/pendulum.py
@@ -107,7 +107,7 @@
from tensordict.nn import TensorDictModule
from torch import nn
-from torchrl.data import BoundedTensorSpec, CompositeSpec, UnboundedContinuousTensorSpec
+from torchrl.data import Bounded, Composite, Unbounded
from torchrl.envs import (
CatTensors,
EnvBase,
@@ -128,7 +128,7 @@
# * :meth:`EnvBase._reset`, which codes for the resetting of the simulator
# at a (potentially random) initial state;
# * :meth:`EnvBase._step` which codes for the state transition dynamic;
-# * :meth:`EnvBase._set_seed`` which implements the seeding mechanism;
+# * :meth:`EnvBase._set_seed` which implements the seeding mechanism;
# * the environment specs.
#
# Let us first describe the problem at hand: we would like to model a simple
@@ -410,14 +410,14 @@ def _reset(self, tensordict):
def _make_spec(self, td_params):
# Under the hood, this will populate self.output_spec["observation"]
- self.observation_spec = CompositeSpec(
- th=BoundedTensorSpec(
+ self.observation_spec = Composite(
+ th=Bounded(
low=-torch.pi,
high=torch.pi,
shape=(),
dtype=torch.float32,
),
- thdot=BoundedTensorSpec(
+ thdot=Bounded(
low=-td_params["params", "max_speed"],
high=td_params["params", "max_speed"],
shape=(),
@@ -433,25 +433,23 @@ def _make_spec(self, td_params):
self.state_spec = self.observation_spec.clone()
# action-spec will be automatically wrapped in input_spec when
# `self.action_spec = spec` will be called supported
- self.action_spec = BoundedTensorSpec(
+ self.action_spec = Bounded(
low=-td_params["params", "max_torque"],
high=td_params["params", "max_torque"],
shape=(1,),
dtype=torch.float32,
)
- self.reward_spec = UnboundedContinuousTensorSpec(shape=(*td_params.shape, 1))
+ self.reward_spec = Unbounded(shape=(*td_params.shape, 1))
def make_composite_from_td(td):
# custom function to convert a ``tensordict`` in a similar spec structure
# of unbounded values.
- composite = CompositeSpec(
+ composite = Composite(
{
key: make_composite_from_td(tensor)
if isinstance(tensor, TensorDictBase)
- else UnboundedContinuousTensorSpec(
- dtype=tensor.dtype, device=tensor.device, shape=tensor.shape
- )
+ else Unbounded(dtype=tensor.dtype, device=tensor.device, shape=tensor.shape)
for key, tensor in td.items()
},
shape=td.shape,
@@ -611,7 +609,7 @@ def __init__(self, td_params=None, seed=None, device="cpu"):
env,
# ``Unsqueeze`` the observations that we will concatenate
UnsqueezeTransform(
- unsqueeze_dim=-1,
+ dim=-1,
in_keys=["th", "thdot"],
in_keys_inv=["th", "thdot"],
),
@@ -694,7 +692,7 @@ def _reset(
# is of type ``Composite``
@_apply_to_composite
def transform_observation_spec(self, observation_spec):
- return BoundedTensorSpec(
+ return Bounded(
low=-1,
high=1,
shape=observation_spec.shape,
@@ -718,7 +716,7 @@ def _reset(
# is of type ``Composite``
@_apply_to_composite
def transform_observation_spec(self, observation_spec):
- return BoundedTensorSpec(
+ return Bounded(
low=-1,
high=1,
shape=observation_spec.shape,
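
`make_composite_from_td`, updated above, mirrors a (possibly nested) tensordict as a `Composite` of `Unbounded` leaves. A usage sketch against the tutorial's parameter tensordict:

import torch
from tensordict import TensorDict

params = TensorDict(
    {"params": {"max_speed": torch.tensor(8.0), "dt": torch.tensor(0.05)}},
    batch_size=[],
)
spec = make_composite_from_td(params)  # defined in the tutorial above
assert spec["params", "max_speed"].dtype == torch.float32
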
diff --git a/tutorials/sphinx-tutorials/rb_tutorial.py b/tutorials/sphinx-tutorials/rb_tutorial.py
index fc3a3ae954c..f189888b804 100644
--- a/tutorials/sphinx-tutorials/rb_tutorial.py
+++ b/tutorials/sphinx-tutorials/rb_tutorial.py
@@ -133,7 +133,7 @@
# basic properties (such as shape and dtype) as the first batch of data that
# was used to instantiate the buffer.
# Passing data that does not match this requirement will either raise an
-# exception or lead to some undefined behaviours.
+# exception or lead to some undefined behaviors.
# - The :class:`~torchrl.data.LazyMemmapStorage` works as the
# :class:`~torchrl.data.LazyTensorStorage` in that it is lazy (i.e., it
# expects the first batch of data to be instantiated), and it requires data
diff --git a/tutorials/sphinx-tutorials/torchrl_demo.py b/tutorials/sphinx-tutorials/torchrl_demo.py
index 9d25da0a4cd..84f7be715ad 100644
--- a/tutorials/sphinx-tutorials/torchrl_demo.py
+++ b/tutorials/sphinx-tutorials/torchrl_demo.py
@@ -170,7 +170,7 @@
# * a collection of algorithms: we do not intend to provide SOTA implementations of RL algorithms,
# but we provide these algorithms only as examples of how to use the library.
#
-# * a research framework: modularity in TorchRL comes in two flavours. First, we try
+# * a research framework: modularity in TorchRL comes in two flavors. First, we try
# to build re-usable components, such that they can be easily swapped with each other.
# Second, we make our best such that components can be used independently of the rest
# of the library.
@@ -365,7 +365,7 @@
# Envs
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-from torchrl.envs.libs.gym import GymEnv, GymWrapper
+from torchrl.envs.libs.gym import GymEnv, GymWrapper, set_gym_backend
gym_env = gym.make("Pendulum-v1")
env = GymWrapper(gym_env)
@@ -434,9 +434,16 @@
from torchrl.envs import ParallelEnv
+
+def make_env():
+ # You can control whether to use gym or gymnasium for your env
+ with set_gym_backend("gym"):
+ return GymEnv("Pendulum-v1", frame_skip=3, from_pixels=True, pixels_only=False)
+
+
base_env = ParallelEnv(
4,
- lambda: GymEnv("Pendulum-v1", frame_skip=3, from_pixels=True, pixels_only=False),
+ make_env,
mp_start_method="fork", # This will break on Windows machines! Remove and decorate with if __name__ == "__main__"
)
env = TransformedEnv(
@@ -572,10 +579,10 @@ def exec_sequence(params, data):
# ------------------------------
torch.manual_seed(0)
-from torchrl.data import BoundedTensorSpec
+from torchrl.data import Bounded
from torchrl.modules import SafeModule
-spec = BoundedTensorSpec(-torch.ones(3), torch.ones(3))
+spec = Bounded(-torch.ones(3), torch.ones(3))
base_module = nn.Linear(5, 3)
module = SafeModule(
module=base_module, spec=spec, in_keys=["obs"], out_keys=["action"], safe=True
@@ -652,14 +659,10 @@ def exec_sequence(params, data):
td_module(td)
print("random:", td["action"])
-with set_exploration_type(ExplorationType.MODE):
+with set_exploration_type(ExplorationType.DETERMINISTIC):
td_module(td)
print("mode:", td["action"])
-with set_exploration_type(ExplorationType.MODE):
- td_module(td)
- print("mean:", td["action"])
-
###############################################################################
# Using Environments and Modules
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
diff --git a/tutorials/sphinx-tutorials/torchrl_envs.py b/tutorials/sphinx-tutorials/torchrl_envs.py
index f2ae0372db2..34189396ee9 100644
--- a/tutorials/sphinx-tutorials/torchrl_envs.py
+++ b/tutorials/sphinx-tutorials/torchrl_envs.py
@@ -608,7 +608,7 @@ def env_make(env_name):
###############################################################################
# Transforming parallel environments
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-# There are two equivalent ways of transforming parallen environments: in each
+# There are two equivalent ways of transforming parallel environments: in each
# process separately, or on the main process. It is even possible to do both.
# One can therefore think carefully about the transform design to leverage the
# device capabilities (e.g. transforms on cuda devices) and vectorizing