diff --git a/.github/unittest/linux/scripts/environment.yml b/.github/unittest/linux/scripts/environment.yml
index 30e01cfc4b5..2234683a497 100644
--- a/.github/unittest/linux/scripts/environment.yml
+++ b/.github/unittest/linux/scripts/environment.yml
@@ -25,6 +25,7 @@ dependencies:
   - imageio==2.26.0
   - wandb
   - dm_control
+  - mujoco
   - mlflow
   - av
   - coverage
diff --git a/.github/unittest/linux/scripts/run_all.sh b/.github/unittest/linux/scripts/run_all.sh
index 38235043d3f..22cdad1b479 100755
--- a/.github/unittest/linux/scripts/run_all.sh
+++ b/.github/unittest/linux/scripts/run_all.sh
@@ -3,8 +3,8 @@ set -euxo pipefail
 set -v
 
-# ==================================================================================== #
-# ================================ Setup env ========================================= #
+# =============================================================================== #
+# ================================ Init ========================================= #
 
 if [[ $OSTYPE != 'darwin'* ]]; then
@@ -31,6 +31,10 @@ if [[ $OSTYPE != 'darwin'* ]]; then
     cp $this_dir/10_nvidia.json /usr/share/glvnd/egl_vendor.d/10_nvidia.json
 fi
 
+
+# ==================================================================================== #
+# ================================ Setup env ========================================= #
+
 # Avoid error: "fatal: unsafe repository"
 git config --global --add safe.directory '*'
 root_dir="$(git rev-parse --show-toplevel)"
@@ -61,7 +65,7 @@ if [ ! -d "${env_dir}" ]; then
 fi
 conda activate "${env_dir}"
 
-# 4. Install Conda dependencies
+# 3. Install Conda dependencies
 printf "* Installing dependencies (except PyTorch)\n"
 echo "  - python=${PYTHON_VERSION}" >> "${this_dir}/environment.yml"
 cat "${this_dir}/environment.yml"
@@ -76,7 +80,7 @@ export DISPLAY=:0
 export SDL_VIDEODRIVER=dummy
 
 # legacy from bash scripts: remove?
-conda env config vars set MUJOCO_GL=$MUJOCO_GL PYOPENGL_PLATFORM=$MUJOCO_GL DISPLAY=:0 SDL_VIDEODRIVER=dummy LAZY_LEGACY_OP=False
+conda env config vars set MUJOCO_GL=$MUJOCO_GL PYOPENGL_PLATFORM=$MUJOCO_GL DISPLAY=:0 SDL_VIDEODRIVER=dummy LAZY_LEGACY_OP=False RL_LOGGING_LEVEL=DEBUG
 
 pip3 install pip --upgrade
 pip install virtualenv
@@ -88,10 +92,14 @@ conda deactivate
 conda activate "${env_dir}"
 
 echo "installing gymnasium"
-pip3 install "gymnasium"
-pip3 install ale_py
-pip3 install mo-gymnasium[mujoco] # requires here bc needs mujoco-py
-pip3 install mujoco -U
+if [[ "$PYTHON_VERSION" == "3.12" ]]; then
+  pip3 install ale-py
+  pip3 install sympy
+  pip3 install "gymnasium[accept-rom-license,mujoco]<1.0" mo-gymnasium[mujoco]
+else
+  pip3 install "gymnasium[atari,accept-rom-license,mujoco]<1.0" mo-gymnasium[mujoco]
+fi
+pip3 install "mujoco" -U
 
 # sanity check: remove?
 python3 -c """
@@ -127,13 +135,13 @@ if [[ "$TORCH_VERSION" == "nightly" ]]; then
     if [ "${CU_VERSION:-}" == cpu ] ; then
         pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu -U
     else
-        pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/$CU_VERSION
+        pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/$CU_VERSION -U
     fi
 elif [[ "$TORCH_VERSION" == "stable" ]]; then
     if [ "${CU_VERSION:-}" == cpu ] ; then
-        pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu
+        pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu -U
     else
-        pip3 install torch torchvision --index-url https://download.pytorch.org/whl/$CU_VERSION
+        pip3 install torch torchvision --index-url https://download.pytorch.org/whl/$CU_VERSION -U
     fi
 else
     printf "Failed to install pytorch"
@@ -181,7 +189,9 @@ fi
 
 export PYTORCH_TEST_WITH_SLOW='1'
 python -m torch.utils.collect_env
-# Avoid error: "fatal: unsafe repository"
+## Avoid error: "fatal: unsafe repository"
+#git config --global --add safe.directory '*'
+#root_dir="$(git rev-parse --show-toplevel)"
 
 # solves ImportError: /lib64/libstdc++.so.6: version `GLIBCXX_3.4.21' not found
 #export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$lib_dir
@@ -198,7 +208,8 @@ if [ "${CU_VERSION:-}" != cpu ] ; then
       --timeout=120 --mp_fork_if_no_cuda
 else
   python .github/unittest/helpers/coverage_run_parallel.py -m pytest test \
-    --instafail --durations 200 -vv --capture no --ignore test/test_rlhf.py --ignore test/test_distributed.py \
+    --instafail --durations 200 -vv --capture no --ignore test/test_rlhf.py \
+    --ignore test/test_distributed.py \
     --timeout=120 --mp_fork_if_no_cuda
 fi
diff --git a/.github/unittest/linux_distributed/scripts/environment.yml b/.github/unittest/linux_distributed/scripts/environment.yml
index 6d27071791b..76160f7a16a 100644
--- a/.github/unittest/linux_distributed/scripts/environment.yml
+++ b/.github/unittest/linux_distributed/scripts/environment.yml
@@ -24,6 +24,7 @@ dependencies:
   - imageio==2.26.0
   - wandb
   - dm_control
+  - mujoco
   - mlflow
   - av
   - coverage
diff --git a/.github/unittest/linux_distributed/scripts/setup_env.sh b/.github/unittest/linux_distributed/scripts/setup_env.sh
index 501dbe1c914..2a48ab21459 100755
--- a/.github/unittest/linux_distributed/scripts/setup_env.sh
+++ b/.github/unittest/linux_distributed/scripts/setup_env.sh
@@ -119,7 +119,7 @@ if [[ $OSTYPE != 'darwin'* ]]; then
     rm ale_py-0.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
   fi
   echo "installing gymnasium"
-  pip install "gymnasium[atari,accept-rom-license]"
+  pip install "gymnasium[atari,accept-rom-license]<1.0"
 else
-  pip install "gymnasium[atari,accept-rom-license]"
+  pip install "gymnasium[atari,accept-rom-license]<1.0"
 fi
diff --git a/.github/unittest/linux_examples/scripts/environment.yml b/.github/unittest/linux_examples/scripts/environment.yml
index 688921f826a..f7dddbc5e3c 100644
--- a/.github/unittest/linux_examples/scripts/environment.yml
+++ b/.github/unittest/linux_examples/scripts/environment.yml
@@ -22,6 +22,7 @@ dependencies:
   - hydra-core
   - imageio==2.26.0
   - dm_control
+  - mujoco
   - mlflow
   - av
   - coverage
diff --git a/.github/unittest/linux_examples/scripts/run_all.sh b/.github/unittest/linux_examples/scripts/run_all.sh
index 37719e51074..073ef59ed3f 100755
--- a/.github/unittest/linux_examples/scripts/run_all.sh
+++ b/.github/unittest/linux_examples/scripts/run_all.sh
@@ -130,7 +130,7 @@ elif [[ $PY_VERSION == *"3.11"* ]]; then
   pip install ale_py-0.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
   rm ale_py-0.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
 fi
-pip install "gymnasium[atari,accept-rom-license]"
+pip install "gymnasium[atari,accept-rom-license]<1.0"
 
 # ============================================================================================ #
 # ================================ PyTorch & TorchRL ========================================= #
@@ -150,15 +150,15 @@ git submodule sync && git submodule update --init --recursive
 printf "Installing PyTorch with %s\n" "${CU_VERSION}"
 if [[ "$TORCH_VERSION" == "nightly" ]]; then
     if [ "${CU_VERSION:-}" == cpu ] ; then
-        pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu -U
+        pip3 install --pre torch torchvision numpy==1.26.4 --index-url https://download.pytorch.org/whl/nightly/cpu -U
     else
-        pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/$CU_VERSION
+        pip3 install --pre torch torchvision numpy==1.26.4 --index-url https://download.pytorch.org/whl/nightly/$CU_VERSION
     fi
 elif [[ "$TORCH_VERSION" == "stable" ]]; then
     if [ "${CU_VERSION:-}" == cpu ] ; then
-        pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu
+        pip3 install torch torchvision numpy==1.26.4 --index-url https://download.pytorch.org/whl/cpu
     else
-        pip3 install torch torchvision --index-url https://download.pytorch.org/whl/$CU_VERSION
+        pip3 install torch torchvision numpy==1.26.4 --index-url https://download.pytorch.org/whl/$CU_VERSION
     fi
 else
     printf "Failed to install pytorch"
diff --git a/.github/unittest/linux_examples/scripts/run_test.sh b/.github/unittest/linux_examples/scripts/run_test.sh
index f8b700c0410..ef0d081f8fd 100755
--- a/.github/unittest/linux_examples/scripts/run_test.sh
+++ b/.github/unittest/linux_examples/scripts/run_test.sh
@@ -205,6 +205,13 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/iq
   env.train_num_envs=2 \
   logger.mode=offline \
   logger.backend=
+  python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/gail/gail.py \
+    ppo.collector.total_frames=48 \
+    replay_buffer.batch_size=16 \
+    ppo.loss.mini_batch_size=10 \
+    ppo.collector.frames_per_batch=16 \
+    logger.mode=offline \
+    logger.backend=
 
 # With single envs
 python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dreamer/dreamer.py \
diff --git a/.github/unittest/linux_libs/scripts_brax/install.sh b/.github/unittest/linux_libs/scripts_brax/install.sh
index 80efdc536ab..20a2643dac8 100755
--- a/.github/unittest/linux_libs/scripts_brax/install.sh
+++ b/.github/unittest/linux_libs/scripts_brax/install.sh
@@ -34,7 +34,7 @@ if [[ "$TORCH_VERSION" == "nightly" ]]; then
     fi
 elif [[ "$TORCH_VERSION" == "stable" ]]; then
     if [ "${CU_VERSION:-}" == cpu ] ; then
-        pip3 install torch --index-url https://download.pytorch.org/whl/cpu
+        pip3 install torch --index-url https://download.pytorch.org/whl/cpu -U
     else
         pip3 install torch --index-url https://download.pytorch.org/whl/cu121
     fi
diff --git a/.github/unittest/linux_libs/scripts_envpool/environment.yml b/.github/unittest/linux_libs/scripts_envpool/environment.yml
index 9259a2a4a43..74a3c91cf06 100644
--- a/.github/unittest/linux_libs/scripts_envpool/environment.yml
+++ b/.github/unittest/linux_libs/scripts_envpool/environment.yml
@@ -19,4 +19,5 @@ dependencies:
   - pyyaml
   - scipy
   - dm_control
+  - mujoco
   - coverage
diff --git a/.github/unittest/linux_libs/scripts_envpool/setup_env.sh b/.github/unittest/linux_libs/scripts_envpool/setup_env.sh
index bb5c09079ea..aabc153bde3 100755
--- a/.github/unittest/linux_libs/scripts_envpool/setup_env.sh
+++ b/.github/unittest/linux_libs/scripts_envpool/setup_env.sh
@@ -82,9 +82,9 @@ if [[ $OSTYPE != 'darwin'* ]]; then
   fi
   echo "installing gym"
   # envpool does not currently work with gymnasium
-  pip install "gym[atari,accept-rom-license]"
+  pip install "gym[atari,accept-rom-license]<1.0"
 else
-  pip install "gym[atari,accept-rom-license]"
+  pip install "gym[atari,accept-rom-license]<1.0"
 fi
 pip install envpool treevalue
diff --git a/.github/unittest/linux_libs/scripts_gym/batch_scripts.sh b/.github/unittest/linux_libs/scripts_gym/batch_scripts.sh
index 11921b44821..dc264e07b2d 100755
--- a/.github/unittest/linux_libs/scripts_gym/batch_scripts.sh
+++ b/.github/unittest/linux_libs/scripts_gym/batch_scripts.sh
@@ -6,6 +6,7 @@ DIR="$(cd "$(dirname "$0")" && pwd)"
 
 set -e
+set -v
 
 eval "$(./conda/bin/conda shell.bash hook)"
 conda activate ./env
@@ -139,7 +140,7 @@ conda deactivate
 conda create --prefix ./cloned_env --clone ./env -y
 conda activate ./cloned_env
 
-pip3 install 'gymnasium[accept-rom-license,ale-py,atari]' mo-gymnasium gymnasium-robotics -U
+pip3 install 'gymnasium[accept-rom-license,ale-py,atari]<1.0' mo-gymnasium gymnasium-robotics -U
 
 $DIR/run_test.sh
diff --git a/.github/unittest/linux_libs/scripts_gym/install.sh b/.github/unittest/linux_libs/scripts_gym/install.sh
index d3eac779861..a66fe5fddd1 100755
--- a/.github/unittest/linux_libs/scripts_gym/install.sh
+++ b/.github/unittest/linux_libs/scripts_gym/install.sh
@@ -7,6 +7,7 @@ unset PYTORCH_VERSION
 apt-get update && apt-get install -y git wget gcc g++
 
 set -e
+set -v
 
 eval "$(./conda/bin/conda shell.bash hook)"
 conda activate ./env
@@ -39,7 +40,7 @@ printf "Installing PyTorch with %s\n" "${CU_VERSION}"
 if [ "${CU_VERSION:-}" == cpu ] ; then
     conda install pytorch==2.0 torchvision==0.15 cpuonly -c pytorch -y
 else
-    conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.8 -c pytorch -c nvidia -y
+    conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.8 numpy==1.26 numpy-base==1.26 -c pytorch -c nvidia -y
 fi
 
 # Solving circular import: https://stackoverflow.com/questions/75501048/how-to-fix-attributeerror-partially-initialized-module-charset-normalizer-has
diff --git a/.github/unittest/linux_libs/scripts_habitat/setup_env.sh b/.github/unittest/linux_libs/scripts_habitat/setup_env.sh
index fc182a669ea..6ad970c3f47 100755
--- a/.github/unittest/linux_libs/scripts_habitat/setup_env.sh
+++ b/.github/unittest/linux_libs/scripts_habitat/setup_env.sh
@@ -67,8 +67,9 @@ pip install pip --upgrade
 
 conda env update --file "${this_dir}/environment.yml" --prune
 
-#conda install habitat-sim withbullet headless -c conda-forge -c aihabitat -y
 conda install habitat-sim withbullet headless -c conda-forge -c aihabitat -y
-conda run python -m pip install git+https://github.com/facebookresearch/habitat-lab.git@stable#subdirectory=habitat-lab
-#conda run python -m pip install git+https://github.com/facebookresearch/habitat-lab.git#subdirectory=habitat-baselines
+git clone https://github.com/facebookresearch/habitat-lab.git
+cd habitat-lab
+pip3 install -e habitat-lab
+pip3 install -e habitat-baselines  # install habitat_baselines
 conda run python -m pip install "gym[atari,accept-rom-license]" pygame
diff --git a/.github/unittest/linux_libs/scripts_minari/environment.yml b/.github/unittest/linux_libs/scripts_minari/environment.yml
index 27963a42a24..ad5bfc12650 100644
--- a/.github/unittest/linux_libs/scripts_minari/environment.yml
+++ b/.github/unittest/linux_libs/scripts_minari/environment.yml
@@ -17,4 +17,4 @@ dependencies:
   - pyyaml
   - scipy
   - hydra-core
-  - minari
+  - minari[gcs,hdf5]
diff --git a/.github/unittest/linux_libs/scripts_open_spiel/environment.yml b/.github/unittest/linux_libs/scripts_open_spiel/environment.yml
new file mode 100644
index 00000000000..937c37d58f6
--- /dev/null
+++ b/.github/unittest/linux_libs/scripts_open_spiel/environment.yml
@@ -0,0 +1,20 @@
+channels:
+  - pytorch
+  - defaults
+dependencies:
+  - pip
+  - pip:
+    - hypothesis
+    - future
+    - cloudpickle
+    - pytest
+    - pytest-cov
+    - pytest-mock
+    - pytest-instafail
+    - pytest-rerunfailures
+    - pytest-error-for-skips
+    - expecttest
+    - pyyaml
+    - scipy
+    - hydra-core
+    - open_spiel
diff --git a/.github/unittest/linux_libs/scripts_open_spiel/install.sh b/.github/unittest/linux_libs/scripts_open_spiel/install.sh
new file mode 100755
index 00000000000..95a4a5a0e29
--- /dev/null
+++ b/.github/unittest/linux_libs/scripts_open_spiel/install.sh
@@ -0,0 +1,60 @@
+#!/usr/bin/env bash
+
+unset PYTORCH_VERSION
+# For the unit tests, nightly PyTorch is used in the section below,
+# so there is no need to set PYTORCH_VERSION.
+# In fact, keeping PYTORCH_VERSION forces us to hardcode the PyTorch version in config.
+
+set -e
+
+eval "$(./conda/bin/conda shell.bash hook)"
+conda activate ./env
+
+if [ "${CU_VERSION:-}" == cpu ] ; then
+    version="cpu"
+else
+    if [[ ${#CU_VERSION} -eq 4 ]]; then
+        CUDA_VERSION="${CU_VERSION:2:1}.${CU_VERSION:3:1}"
+    elif [[ ${#CU_VERSION} -eq 5 ]]; then
+        CUDA_VERSION="${CU_VERSION:2:2}.${CU_VERSION:4:1}"
+    fi
+    echo "Using CUDA $CUDA_VERSION as determined by CU_VERSION ($CU_VERSION)"
+    version="$(python -c "print('.'.join(\"${CUDA_VERSION}\".split('.')[:2]))")"
+fi
+
+# submodules
+git submodule sync && git submodule update --init --recursive
+
+printf "Installing PyTorch with cu121"
+if [[ "$TORCH_VERSION" == "nightly" ]]; then
+    if [ "${CU_VERSION:-}" == cpu ] ; then
+        pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu -U
+    else
+        pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 -U
+    fi
+elif [[ "$TORCH_VERSION" == "stable" ]]; then
+    if [ "${CU_VERSION:-}" == cpu ] ; then
+        pip3 install torch --index-url https://download.pytorch.org/whl/cpu
+    else
+        pip3 install torch --index-url https://download.pytorch.org/whl/cu121
+    fi
+else
+    printf "Failed to install pytorch"
+    exit 1
+fi
+
+# install tensordict
+if [[ "$RELEASE" == 0 ]]; then
+    pip3 install git+https://github.com/pytorch/tensordict.git
+else
+    pip3 install tensordict
+fi
+
+# smoke test
+python -c "import functorch;import tensordict"
+
+printf "* Installing torchrl\n"
+python setup.py develop
+
+# smoke test
+python -c "import torchrl"
diff --git a/.github/unittest/linux_libs/scripts_open_spiel/post_process.sh b/.github/unittest/linux_libs/scripts_open_spiel/post_process.sh
new file mode 100755
index 00000000000..e97bf2a7b1b
--- /dev/null
+++ b/.github/unittest/linux_libs/scripts_open_spiel/post_process.sh
@@ -0,0 +1,6 @@
+#!/usr/bin/env bash
+
+set -e
+
+eval "$(./conda/bin/conda shell.bash hook)"
+conda activate ./env
diff --git a/.github/unittest/linux_libs/scripts_open_spiel/run-clang-format.py b/.github/unittest/linux_libs/scripts_open_spiel/run-clang-format.py
new file mode 100755
index 00000000000..5783a885d86
--- /dev/null
+++ b/.github/unittest/linux_libs/scripts_open_spiel/run-clang-format.py
@@ -0,0 +1,356 @@
+#!/usr/bin/env python
+"""
+MIT License
+
+Copyright (c) 2017 Guillaume Papin
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+
+A wrapper script around clang-format, suitable for linting multiple files
+and to use for continuous integration.
+
+This is an alternative API for the clang-format command line.
+It runs over multiple files and directories in parallel.
+A diff output is produced and a sensible exit code is returned.
+
+"""
+
+import argparse
+import difflib
+import fnmatch
+import multiprocessing
+import os
+import signal
+import subprocess
+import sys
+import traceback
+from functools import partial
+
+try:
+    from subprocess import DEVNULL  # py3k
+except ImportError:
+    DEVNULL = open(os.devnull, "wb")
+
+
+DEFAULT_EXTENSIONS = "c,h,C,H,cpp,hpp,cc,hh,c++,h++,cxx,hxx,cu"
+
+
+class ExitStatus:
+    SUCCESS = 0
+    DIFF = 1
+    TROUBLE = 2
+
+
+def list_files(files, recursive=False, extensions=None, exclude=None):
+    if extensions is None:
+        extensions = []
+    if exclude is None:
+        exclude = []
+
+    out = []
+    for file in files:
+        if recursive and os.path.isdir(file):
+            for dirpath, dnames, fnames in os.walk(file):
+                fpaths = [os.path.join(dirpath, fname) for fname in fnames]
+                for pattern in exclude:
+                    # os.walk() supports trimming down the dnames list
+                    # by modifying it in-place,
+                    # to avoid unnecessary directory listings.
+                    dnames[:] = [
+                        x
+                        for x in dnames
+                        if not fnmatch.fnmatch(os.path.join(dirpath, x), pattern)
+                    ]
+                    fpaths = [x for x in fpaths if not fnmatch.fnmatch(x, pattern)]
+                for f in fpaths:
+                    ext = os.path.splitext(f)[1][1:]
+                    if ext in extensions:
+                        out.append(f)
+        else:
+            out.append(file)
+    return out
+
+
+def make_diff(file, original, reformatted):
+    return list(
+        difflib.unified_diff(
+            original,
+            reformatted,
+            fromfile=f"{file}\t(original)",
+            tofile=f"{file}\t(reformatted)",
+            n=3,
+        )
+    )
+
+
+class DiffError(Exception):
+    def __init__(self, message, errs=None):
+        super().__init__(message)
+        self.errs = errs or []
+
+
+class UnexpectedError(Exception):
+    def __init__(self, message, exc=None):
+        super().__init__(message)
+        self.formatted_traceback = traceback.format_exc()
+        self.exc = exc
+
+
+def run_clang_format_diff_wrapper(args, file):
+    try:
+        ret = run_clang_format_diff(args, file)
+        return ret
+    except DiffError:
+        raise
+    except Exception as e:
+        raise UnexpectedError(f"{file}: {e.__class__.__name__}: {e}", e)
+
+
+def run_clang_format_diff(args, file):
+    try:
+        with open(file, encoding="utf-8") as f:
+            original = f.readlines()
+    except OSError as exc:
+        raise DiffError(str(exc))
+    invocation = [args.clang_format_executable, file]
+
+    # Use of utf-8 to decode the process output.
+    #
+    # Hopefully, this is the correct thing to do.
+    #
+    # It's done due to the following assumptions (which may be incorrect):
+    # - clang-format returns the bytes read from the files as-is,
+    #   without conversion, and it is already assumed that the files use utf-8.
+    # - if the diagnostics were internationalized, they would use utf-8:
+    #   > Adding Translations to Clang
+    #   >
+    #   > Not possible yet!
+    #   > Diagnostic strings should be written in UTF-8,
+    #   > the client can translate to the relevant code page if needed.
+    #   > Each translation completely replaces the format string
+    #   > for the diagnostic.
+    #   > -- http://clang.llvm.org/docs/InternalsManual.html#internals-diag-translation
+
+    try:
+        proc = subprocess.Popen(
+            invocation,
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+            universal_newlines=True,
+            encoding="utf-8",
+        )
+    except OSError as exc:
+        raise DiffError(
+            f"Command '{subprocess.list2cmdline(invocation)}' failed to start: {exc}"
+        )
+    proc_stdout = proc.stdout
+    proc_stderr = proc.stderr
+
+    # hopefully the stderr pipe won't get full and block the process
+    outs = list(proc_stdout.readlines())
+    errs = list(proc_stderr.readlines())
+    proc.wait()
+    if proc.returncode:
+        raise DiffError(
+            "Command '{}' returned non-zero exit status {}".format(
+                subprocess.list2cmdline(invocation), proc.returncode
+            ),
+            errs,
+        )
+    return make_diff(file, original, outs), errs
+
+
+def bold_red(s):
+    return "\x1b[1m\x1b[31m" + s + "\x1b[0m"
+
+
+def colorize(diff_lines):
+    def bold(s):
+        return "\x1b[1m" + s + "\x1b[0m"
+
+    def cyan(s):
+        return "\x1b[36m" + s + "\x1b[0m"
+
+    def green(s):
+        return "\x1b[32m" + s + "\x1b[0m"
+
+    def red(s):
+        return "\x1b[31m" + s + "\x1b[0m"
+
+    for line in diff_lines:
+        if line[:4] in ["--- ", "+++ "]:
+            yield bold(line)
+        elif line.startswith("@@ "):
+            yield cyan(line)
+        elif line.startswith("+"):
+            yield green(line)
+        elif line.startswith("-"):
+            yield red(line)
+        else:
+            yield line
+
+
+def print_diff(diff_lines, use_color):
+    if use_color:
+        diff_lines = colorize(diff_lines)
+    sys.stdout.writelines(diff_lines)
+
+
+def print_trouble(prog, message, use_colors):
+    error_text = "error:"
+    if use_colors:
+        error_text = bold_red(error_text)
+    print(f"{prog}: {error_text} {message}", file=sys.stderr)
+
+
+def main():
+    parser = argparse.ArgumentParser(description=__doc__)
+    parser.add_argument(
+        "--clang-format-executable",
+        metavar="EXECUTABLE",
+        help="path to the clang-format executable",
+        default="clang-format",
+    )
+    parser.add_argument(
+        "--extensions",
+        help=f"comma separated list of file extensions (default: {DEFAULT_EXTENSIONS})",
+        default=DEFAULT_EXTENSIONS,
+    )
+    parser.add_argument(
+        "-r",
+        "--recursive",
+        action="store_true",
+        help="run recursively over directories",
+    )
+    parser.add_argument("files", metavar="file", nargs="+")
+    parser.add_argument("-q", "--quiet", action="store_true")
+    parser.add_argument(
+        "-j",
+        metavar="N",
+        type=int,
+        default=0,
+        help="run N clang-format jobs in parallel (default number of cpus + 1)",
+    )
+    parser.add_argument(
+        "--color",
+        default="auto",
+        choices=["auto", "always", "never"],
+        help="show colored diff (default: auto)",
+    )
+    parser.add_argument(
+        "-e",
+        "--exclude",
+        metavar="PATTERN",
+        action="append",
+        default=[],
+        help="exclude paths matching the given glob-like pattern(s) from recursive search",
+    )
+
+    args = parser.parse_args()
+
+    # use default signal handling, like diff return SIGINT value on ^C
+    # https://bugs.python.org/issue14229#msg156446
+    signal.signal(signal.SIGINT, signal.SIG_DFL)
+    try:
+        signal.SIGPIPE
+    except AttributeError:
+        # compatibility, SIGPIPE does not exist on Windows
+        pass
+    else:
+        signal.signal(signal.SIGPIPE, signal.SIG_DFL)
+
+    colored_stdout = False
+    colored_stderr = False
+    if args.color == "always":
+        colored_stdout = True
+        colored_stderr = True
+    elif args.color == "auto":
+        colored_stdout = sys.stdout.isatty()
+        colored_stderr = sys.stderr.isatty()
+
+    version_invocation = [args.clang_format_executable, "--version"]
+    try:
+        subprocess.check_call(version_invocation, stdout=DEVNULL)
+    except subprocess.CalledProcessError as e:
+        print_trouble(parser.prog, str(e), use_colors=colored_stderr)
+        return ExitStatus.TROUBLE
+    except OSError as e:
+        print_trouble(
+            parser.prog,
+            f"Command '{subprocess.list2cmdline(version_invocation)}' failed to start: {e}",
+            use_colors=colored_stderr,
+        )
+        return ExitStatus.TROUBLE
+
+    retcode = ExitStatus.SUCCESS
+    files = list_files(
+        args.files,
+        recursive=args.recursive,
+        exclude=args.exclude,
+        extensions=args.extensions.split(","),
+    )
+
+    if not files:
+        return
+
+    njobs = args.j
+    if njobs == 0:
+        njobs = multiprocessing.cpu_count() + 1
+    njobs = min(len(files), njobs)
+
+    if njobs == 1:
+        # execute directly instead of in a pool,
+        # less overhead, simpler stacktraces
+        it = (run_clang_format_diff_wrapper(args, file) for file in files)
+        pool = None
+    else:
+        pool = multiprocessing.Pool(njobs)
+        it = pool.imap_unordered(partial(run_clang_format_diff_wrapper, args), files)
+    while True:
+        try:
+            outs, errs = next(it)
+        except StopIteration:
+            break
+        except DiffError as e:
+            print_trouble(parser.prog, str(e), use_colors=colored_stderr)
+            retcode = ExitStatus.TROUBLE
+            sys.stderr.writelines(e.errs)
+        except UnexpectedError as e:
+            print_trouble(parser.prog, str(e), use_colors=colored_stderr)
+            sys.stderr.write(e.formatted_traceback)
+            retcode = ExitStatus.TROUBLE
+            # stop at the first unexpected error,
+            # something could be very wrong,
+            # don't process all files unnecessarily
+            if pool:
+                pool.terminate()
+            break
+        else:
+            sys.stderr.writelines(errs)
+            if outs == []:
+                continue
+            if not args.quiet:
+                print_diff(outs, use_color=colored_stdout)
+            if retcode == ExitStatus.SUCCESS:
+                retcode = ExitStatus.DIFF
+    return retcode
+
+
+if __name__ == "__main__":
+    sys.exit(main())
diff --git a/.github/unittest/linux_libs/scripts_open_spiel/run_test.sh b/.github/unittest/linux_libs/scripts_open_spiel/run_test.sh
new file mode 100755
index 00000000000..a09229bf59a
--- /dev/null
+++ b/.github/unittest/linux_libs/scripts_open_spiel/run_test.sh
@@ -0,0 +1,28 @@
+#!/usr/bin/env bash
+
+set -e
+
+eval "$(./conda/bin/conda shell.bash hook)"
+conda activate ./env
+
+apt-get update && apt-get install -y git wget
+
+export PYTORCH_TEST_WITH_SLOW='1'
+export LAZY_LEGACY_OP=False
+python -m torch.utils.collect_env
+# Avoid error: "fatal: unsafe repository"
+git config --global --add safe.directory '*'
+
+root_dir="$(git rev-parse --show-toplevel)"
+env_dir="${root_dir}/env"
+lib_dir="${env_dir}/lib"
+
+conda deactivate && conda activate ./env
+
+# this workflow only tests the libs
+python -c "import pyspiel"
+
+python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 --capture no -k TestOpenSpiel --error-for-skips --runslow
+
+coverage combine
+coverage xml -i
diff --git a/.github/unittest/linux_optdeps/scripts/setup_env.sh b/.github/unittest/linux_libs/scripts_open_spiel/setup_env.sh
similarity index 96%
rename from .github/unittest/linux_optdeps/scripts/setup_env.sh
rename to .github/unittest/linux_libs/scripts_open_spiel/setup_env.sh
index aa83bca32fc..e7b08ab02ff 100755
--- a/.github/unittest/linux_optdeps/scripts/setup_env.sh
+++ b/.github/unittest/linux_libs/scripts_open_spiel/setup_env.sh
@@ -6,8 +6,11 @@
 # Do not install PyTorch and torchvision here, otherwise they also get cached.
 
 set -e
+set -v
 
 this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
+# Avoid error: "fatal: unsafe repository"
+
 git config --global --add safe.directory '*'
 root_dir="$(git rev-parse --show-toplevel)"
 conda_dir="${root_dir}/conda"
diff --git a/.github/unittest/linux_libs/scripts_openx/install.sh b/.github/unittest/linux_libs/scripts_openx/install.sh
index 1be73fc1de0..c657fd48b46 100755
--- a/.github/unittest/linux_libs/scripts_openx/install.sh
+++ b/.github/unittest/linux_libs/scripts_openx/install.sh
@@ -37,9 +37,9 @@ if [[ "$TORCH_VERSION" == "nightly" ]]; then
     fi
 elif [[ "$TORCH_VERSION" == "stable" ]]; then
     if [ "${CU_VERSION:-}" == cpu ] ; then
-        pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu
+        pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu -U
     else
-        pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu121
+        pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu121 -U
     fi
 else
     printf "Failed to install pytorch"
diff --git a/.github/unittest/linux_libs/scripts_pettingzoo/environment.yml b/.github/unittest/linux_libs/scripts_pettingzoo/environment.yml
index f6c79784a0b..8f4e35c8efa 100644
--- a/.github/unittest/linux_libs/scripts_pettingzoo/environment.yml
+++ b/.github/unittest/linux_libs/scripts_pettingzoo/environment.yml
@@ -20,3 +20,4 @@ dependencies:
   - pyyaml
   - autorom[accept-rom-license]
   - pettingzoo[all]==1.24.3
+  - gymnasium<1.0.0
diff --git a/.github/unittest/linux_libs/scripts_rlhf/install.sh b/.github/unittest/linux_libs/scripts_rlhf/install.sh
index d0363186c1a..9a5cf82074b 100755
--- a/.github/unittest/linux_libs/scripts_rlhf/install.sh
+++ b/.github/unittest/linux_libs/scripts_rlhf/install.sh
@@ -31,15 +31,15 @@ git submodule sync && git submodule update --init --recursive
 printf "Installing PyTorch with cu121"
 if [[ "$TORCH_VERSION" == "nightly" ]]; then
     if [ "${CU_VERSION:-}" == cpu ] ; then
-        pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu -U
+        pip3 install --pre torch numpy==1.26.4 --index-url https://download.pytorch.org/whl/nightly/cpu -U
     else
-        pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 -U
+        pip3 install --pre torch numpy==1.26.4 --index-url https://download.pytorch.org/whl/nightly/cu121 -U
     fi
 elif [[ "$TORCH_VERSION" == "stable" ]]; then
     if [ "${CU_VERSION:-}" == cpu ] ; then
-        pip3 install torch --index-url https://download.pytorch.org/whl/cpu
+        pip3 install torch numpy==1.26.4 --index-url https://download.pytorch.org/whl/cpu
     else
-        pip3 install torch --index-url https://download.pytorch.org/whl/cu121
+        pip3 install torch numpy==1.26.4 --index-url https://download.pytorch.org/whl/cu121
     fi
 else
     printf "Failed to install pytorch"
diff --git a/.github/unittest/linux_libs/scripts_robohive/environment.yml b/.github/unittest/linux_libs/scripts_robohive/environment.yml
index cff88245d1e..4b6e4ef4f0e 100644
--- a/.github/unittest/linux_libs/scripts_robohive/environment.yml
+++ b/.github/unittest/linux_libs/scripts_robohive/environment.yml
@@ -6,7 +6,7 @@ dependencies:
   - protobuf
   - pip:
       # Initial version is required to install Atari ROMS in setup_env.sh
-      - gymnasium
+      - gymnasium<1.0
       - hypothesis
       - future
       - cloudpickle
diff --git a/.github/unittest/linux_libs/scripts_unity_mlagents/environment.yml b/.github/unittest/linux_libs/scripts_unity_mlagents/environment.yml
new file mode 100644
index 00000000000..6dc82afbc25
--- /dev/null
+++ b/.github/unittest/linux_libs/scripts_unity_mlagents/environment.yml
@@ -0,0 +1,21 @@
+channels:
+  - pytorch
+  - defaults
+dependencies:
+  - python==3.10.12
+  - pip
+  - pip:
+    - mlagents_envs==1.0.0
+    - hypothesis
+    - future
+    - cloudpickle
+    - pytest
+    - pytest-cov
+    - pytest-mock
+    - pytest-instafail
+    - pytest-rerunfailures
+    - pytest-error-for-skips
+    - expecttest
+    - pyyaml
+    - scipy
+    - hydra-core
diff --git a/.github/unittest/linux_libs/scripts_unity_mlagents/install.sh b/.github/unittest/linux_libs/scripts_unity_mlagents/install.sh
new file mode 100755
index 00000000000..95a4a5a0e29
--- /dev/null
+++ b/.github/unittest/linux_libs/scripts_unity_mlagents/install.sh
@@ -0,0 +1,60 @@
+#!/usr/bin/env bash
+
+unset PYTORCH_VERSION
+# For the unit tests, nightly PyTorch is used in the section below,
+# so there is no need to set PYTORCH_VERSION.
+# In fact, keeping PYTORCH_VERSION forces us to hardcode the PyTorch version in config.
+
+set -e
+
+eval "$(./conda/bin/conda shell.bash hook)"
+conda activate ./env
+
+if [ "${CU_VERSION:-}" == cpu ] ; then
+    version="cpu"
+else
+    if [[ ${#CU_VERSION} -eq 4 ]]; then
+        CUDA_VERSION="${CU_VERSION:2:1}.${CU_VERSION:3:1}"
+    elif [[ ${#CU_VERSION} -eq 5 ]]; then
+        CUDA_VERSION="${CU_VERSION:2:2}.${CU_VERSION:4:1}"
+    fi
+    echo "Using CUDA $CUDA_VERSION as determined by CU_VERSION ($CU_VERSION)"
+    version="$(python -c "print('.'.join(\"${CUDA_VERSION}\".split('.')[:2]))")"
+fi
+
+# submodules
+git submodule sync && git submodule update --init --recursive
+
+printf "Installing PyTorch with cu121"
+if [[ "$TORCH_VERSION" == "nightly" ]]; then
+    if [ "${CU_VERSION:-}" == cpu ] ; then
+        pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu -U
+    else
+        pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 -U
+    fi
+elif [[ "$TORCH_VERSION" == "stable" ]]; then
+    if [ "${CU_VERSION:-}" == cpu ] ; then
+        pip3 install torch --index-url https://download.pytorch.org/whl/cpu
+    else
+        pip3 install torch --index-url https://download.pytorch.org/whl/cu121
+    fi
+else
+    printf "Failed to install pytorch"
+    exit 1
+fi
+
+# install tensordict
+if [[ "$RELEASE" == 0 ]]; then
+    pip3 install git+https://github.com/pytorch/tensordict.git
+else
+    pip3 install tensordict
+fi
+
+# smoke test
+python -c "import functorch;import tensordict"
+
+printf "* Installing torchrl\n"
+python setup.py develop
+
+# smoke test
+python -c "import torchrl"
diff --git a/.github/unittest/linux_libs/scripts_unity_mlagents/post_process.sh b/.github/unittest/linux_libs/scripts_unity_mlagents/post_process.sh
new file mode 100755
index 00000000000..e97bf2a7b1b
--- /dev/null
+++ b/.github/unittest/linux_libs/scripts_unity_mlagents/post_process.sh
@@ -0,0 +1,6 @@
+#!/usr/bin/env bash
+
+set -e
+
+eval "$(./conda/bin/conda shell.bash hook)"
+conda activate ./env
diff --git a/.github/unittest/linux_libs/scripts_unity_mlagents/run-clang-format.py b/.github/unittest/linux_libs/scripts_unity_mlagents/run-clang-format.py
new file mode 100755
index 00000000000..5783a885d86
--- /dev/null
+++ b/.github/unittest/linux_libs/scripts_unity_mlagents/run-clang-format.py
@@ -0,0 +1,356 @@
+#!/usr/bin/env python
+"""
+MIT License
+
+Copyright (c) 2017 Guillaume Papin
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+
+A wrapper script around clang-format, suitable for linting multiple files
+and to use for continuous integration.
+
+This is an alternative API for the clang-format command line.
+It runs over multiple files and directories in parallel.
+A diff output is produced and a sensible exit code is returned.
+
+"""
+
+import argparse
+import difflib
+import fnmatch
+import multiprocessing
+import os
+import signal
+import subprocess
+import sys
+import traceback
+from functools import partial
+
+try:
+    from subprocess import DEVNULL  # py3k
+except ImportError:
+    DEVNULL = open(os.devnull, "wb")
+
+
+DEFAULT_EXTENSIONS = "c,h,C,H,cpp,hpp,cc,hh,c++,h++,cxx,hxx,cu"
+
+
+class ExitStatus:
+    SUCCESS = 0
+    DIFF = 1
+    TROUBLE = 2
+
+
+def list_files(files, recursive=False, extensions=None, exclude=None):
+    if extensions is None:
+        extensions = []
+    if exclude is None:
+        exclude = []
+
+    out = []
+    for file in files:
+        if recursive and os.path.isdir(file):
+            for dirpath, dnames, fnames in os.walk(file):
+                fpaths = [os.path.join(dirpath, fname) for fname in fnames]
+                for pattern in exclude:
+                    # os.walk() supports trimming down the dnames list
+                    # by modifying it in-place,
+                    # to avoid unnecessary directory listings.
+                    dnames[:] = [
+                        x
+                        for x in dnames
+                        if not fnmatch.fnmatch(os.path.join(dirpath, x), pattern)
+                    ]
+                    fpaths = [x for x in fpaths if not fnmatch.fnmatch(x, pattern)]
+                for f in fpaths:
+                    ext = os.path.splitext(f)[1][1:]
+                    if ext in extensions:
+                        out.append(f)
+        else:
+            out.append(file)
+    return out
+
+
+def make_diff(file, original, reformatted):
+    return list(
+        difflib.unified_diff(
+            original,
+            reformatted,
+            fromfile=f"{file}\t(original)",
+            tofile=f"{file}\t(reformatted)",
+            n=3,
+        )
+    )
+
+
+class DiffError(Exception):
+    def __init__(self, message, errs=None):
+        super().__init__(message)
+        self.errs = errs or []
+
+
+class UnexpectedError(Exception):
+    def __init__(self, message, exc=None):
+        super().__init__(message)
+        self.formatted_traceback = traceback.format_exc()
+        self.exc = exc
+
+
+def run_clang_format_diff_wrapper(args, file):
+    try:
+        ret = run_clang_format_diff(args, file)
+        return ret
+    except DiffError:
+        raise
+    except Exception as e:
+        raise UnexpectedError(f"{file}: {e.__class__.__name__}: {e}", e)
+
+
+def run_clang_format_diff(args, file):
+    try:
+        with open(file, encoding="utf-8") as f:
+            original = f.readlines()
+    except OSError as exc:
+        raise DiffError(str(exc))
+    invocation = [args.clang_format_executable, file]
+
+    # Use of utf-8 to decode the process output.
+    #
+    # Hopefully, this is the correct thing to do.
+    #
+    # It's done due to the following assumptions (which may be incorrect):
+    # - clang-format returns the bytes read from the files as-is,
+    #   without conversion, and it is already assumed that the files use utf-8.
+    # - if the diagnostics were internationalized, they would use utf-8:
+    #   > Adding Translations to Clang
+    #   >
+    #   > Not possible yet!
+    #   > Diagnostic strings should be written in UTF-8,
+    #   > the client can translate to the relevant code page if needed.
+    #   > Each translation completely replaces the format string
+    #   > for the diagnostic.
+    #   > -- http://clang.llvm.org/docs/InternalsManual.html#internals-diag-translation
+
+    try:
+        proc = subprocess.Popen(
+            invocation,
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+            universal_newlines=True,
+            encoding="utf-8",
+        )
+    except OSError as exc:
+        raise DiffError(
+            f"Command '{subprocess.list2cmdline(invocation)}' failed to start: {exc}"
+        )
+    proc_stdout = proc.stdout
+    proc_stderr = proc.stderr
+
+    # hopefully the stderr pipe won't get full and block the process
+    outs = list(proc_stdout.readlines())
+    errs = list(proc_stderr.readlines())
+    proc.wait()
+    if proc.returncode:
+        raise DiffError(
+            "Command '{}' returned non-zero exit status {}".format(
+                subprocess.list2cmdline(invocation), proc.returncode
+            ),
+            errs,
+        )
+    return make_diff(file, original, outs), errs
+
+
+def bold_red(s):
+    return "\x1b[1m\x1b[31m" + s + "\x1b[0m"
+
+
+def colorize(diff_lines):
+    def bold(s):
+        return "\x1b[1m" + s + "\x1b[0m"
+
+    def cyan(s):
+        return "\x1b[36m" + s + "\x1b[0m"
+
+    def green(s):
+        return "\x1b[32m" + s + "\x1b[0m"
+
+    def red(s):
+        return "\x1b[31m" + s + "\x1b[0m"
+
+    for line in diff_lines:
+        if line[:4] in ["--- ", "+++ "]:
+            yield bold(line)
+        elif line.startswith("@@ "):
+            yield cyan(line)
+        elif line.startswith("+"):
+            yield green(line)
+        elif line.startswith("-"):
+            yield red(line)
+        else:
+            yield line
+
+
+def print_diff(diff_lines, use_color):
+    if use_color:
+        diff_lines = colorize(diff_lines)
+    sys.stdout.writelines(diff_lines)
+
+
+def print_trouble(prog, message, use_colors):
+    error_text = "error:"
+    if use_colors:
+        error_text = bold_red(error_text)
+    print(f"{prog}: {error_text} {message}", file=sys.stderr)
+
+
+def main():
+    parser = argparse.ArgumentParser(description=__doc__)
+    parser.add_argument(
+        "--clang-format-executable",
+        metavar="EXECUTABLE",
+        help="path to the clang-format executable",
+        default="clang-format",
+    )
+    parser.add_argument(
+        "--extensions",
+        help=f"comma separated list of file extensions (default: {DEFAULT_EXTENSIONS})",
+        default=DEFAULT_EXTENSIONS,
+    )
+    parser.add_argument(
+        "-r",
+        "--recursive",
+        action="store_true",
+        help="run recursively over directories",
+    )
+    parser.add_argument("files", metavar="file", nargs="+")
+    parser.add_argument("-q", "--quiet", action="store_true")
+    parser.add_argument(
+        "-j",
+        metavar="N",
+        type=int,
+        default=0,
+        help="run N clang-format jobs in parallel (default number of cpus + 1)",
+    )
+    parser.add_argument(
+        "--color",
+        default="auto",
+        choices=["auto", "always", "never"],
+        help="show colored diff (default: auto)",
+    )
+    parser.add_argument(
+        "-e",
+        "--exclude",
+        metavar="PATTERN",
+        action="append",
+        default=[],
+        help="exclude paths matching the given glob-like pattern(s) from recursive search",
+    )
+
+    args = parser.parse_args()
+
+    # use default signal handling, like diff return SIGINT value on ^C
+    # https://bugs.python.org/issue14229#msg156446
+    signal.signal(signal.SIGINT, signal.SIG_DFL)
+    try:
+        signal.SIGPIPE
+    except AttributeError:
+        # compatibility, SIGPIPE does not exist on Windows
+        pass
+    else:
+        signal.signal(signal.SIGPIPE, signal.SIG_DFL)
+
+    colored_stdout = False
+    colored_stderr = False
+    if args.color == "always":
+        colored_stdout = True
+        colored_stderr = True
+    elif args.color == "auto":
+        colored_stdout = sys.stdout.isatty()
+        colored_stderr = sys.stderr.isatty()
+
+    version_invocation = [args.clang_format_executable, "--version"]
+    try:
+        subprocess.check_call(version_invocation, stdout=DEVNULL)
+    except subprocess.CalledProcessError as e:
+        print_trouble(parser.prog, str(e), use_colors=colored_stderr)
+        return ExitStatus.TROUBLE
+    except OSError as e:
+        print_trouble(
+            parser.prog,
+            f"Command '{subprocess.list2cmdline(version_invocation)}' failed to start: {e}",
+            use_colors=colored_stderr,
+        )
+        return ExitStatus.TROUBLE
+
+    retcode = ExitStatus.SUCCESS
+    files = list_files(
+        args.files,
+        recursive=args.recursive,
+        exclude=args.exclude,
+        extensions=args.extensions.split(","),
+    )
+
+    if not files:
+        return
+
+    njobs = args.j
+    if njobs == 0:
+        njobs = multiprocessing.cpu_count() + 1
+    njobs = min(len(files), njobs)
+
+    if njobs == 1:
+        # execute directly instead of in a pool,
+        # less overhead, simpler stacktraces
+        it = (run_clang_format_diff_wrapper(args, file) for file in files)
+        pool = None
+    else:
+        pool = multiprocessing.Pool(njobs)
+        it = pool.imap_unordered(partial(run_clang_format_diff_wrapper, args), files)
+    while True:
+        try:
+            outs, errs = next(it)
+        except StopIteration:
+            break
+        except DiffError as e:
+            print_trouble(parser.prog, str(e), use_colors=colored_stderr)
+            retcode = ExitStatus.TROUBLE
+            sys.stderr.writelines(e.errs)
+        except UnexpectedError as e:
+            print_trouble(parser.prog, str(e), use_colors=colored_stderr)
+            sys.stderr.write(e.formatted_traceback)
+            retcode = ExitStatus.TROUBLE
+            # stop at the first unexpected error,
+            # something could be very wrong,
+            # don't process all files unnecessarily
+            if pool:
+                pool.terminate()
+            break
+        else:
+            sys.stderr.writelines(errs)
+            if outs == []:
+                continue
+            if not args.quiet:
+                print_diff(outs, use_color=colored_stdout)
+            if retcode == ExitStatus.SUCCESS:
+                retcode = ExitStatus.DIFF
+    return retcode
+
+
+if __name__ == "__main__":
+    sys.exit(main())
diff --git a/.github/unittest/linux_libs/scripts_unity_mlagents/run_test.sh b/.github/unittest/linux_libs/scripts_unity_mlagents/run_test.sh
new file mode 100755
index 00000000000..d5bb8695c44
--- /dev/null
+++ b/.github/unittest/linux_libs/scripts_unity_mlagents/run_test.sh
@@ -0,0 +1,28 @@
+#!/usr/bin/env bash
+
+set -e
+
+eval "$(./conda/bin/conda shell.bash hook)"
+conda activate ./env
+
+apt-get update && apt-get install -y git wget
+
+export PYTORCH_TEST_WITH_SLOW='1'
+export LAZY_LEGACY_OP=False
+python -m torch.utils.collect_env
+# Avoid error: "fatal: unsafe repository"
+git config --global --add safe.directory '*'
+
+root_dir="$(git rev-parse --show-toplevel)"
+env_dir="${root_dir}/env"
+lib_dir="${env_dir}/lib"
+
+conda deactivate && conda activate ./env
+
+# this workflow only tests the libs
+python -c "import mlagents_envs"
+
+python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 --capture no -k TestUnityMLAgents --runslow
+
+coverage combine
+coverage xml -i
diff --git a/.github/unittest/linux_libs/scripts_unity_mlagents/setup_env.sh b/.github/unittest/linux_libs/scripts_unity_mlagents/setup_env.sh
new file mode 100755
index 00000000000..e7b08ab02ff
--- /dev/null
+++ b/.github/unittest/linux_libs/scripts_unity_mlagents/setup_env.sh
@@ -0,0 +1,49 @@
+#!/usr/bin/env bash
+
+# This script is for setting up the environment in which the unit tests are run.
+# To speed up the CI time, the resulting environment is cached.
+#
+# Do not install PyTorch and torchvision here, otherwise they also get cached.
+
+set -e
+set -v
+
+this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
+# Avoid error: "fatal: unsafe repository"
+
+git config --global --add safe.directory '*'
+root_dir="$(git rev-parse --show-toplevel)"
+conda_dir="${root_dir}/conda"
+env_dir="${root_dir}/env"
+
+cd "${root_dir}"
+
+case "$(uname -s)" in
+    Darwin*) os=MacOSX;;
+    *) os=Linux
+esac
+
+# 1. Install conda at ./conda
+if [ ! -d "${conda_dir}" ]; then
+    printf "* Installing conda\n"
+    wget -O miniconda.sh "http://repo.continuum.io/miniconda/Miniconda3-latest-${os}-x86_64.sh"
+    bash ./miniconda.sh -b -f -p "${conda_dir}"
+fi
+eval "$(${conda_dir}/bin/conda shell.bash hook)"
+
+# 2. Create test environment at ./env
+printf "python: ${PYTHON_VERSION}\n"
+if [ ! -d "${env_dir}" ]; then
+    printf "* Creating a test environment\n"
+    conda create --prefix "${env_dir}" -y python="$PYTHON_VERSION"
+fi
+conda activate "${env_dir}"
+
+# 3. Install Conda dependencies
+printf "* Installing dependencies (except PyTorch)\n"
+echo "  - python=${PYTHON_VERSION}" >> "${this_dir}/environment.yml"
+cat "${this_dir}/environment.yml"
+
+pip install pip --upgrade
+
+conda env update --file "${this_dir}/environment.yml" --prune
diff --git a/.github/unittest/linux_libs/scripts_vd4rl/install.sh b/.github/unittest/linux_libs/scripts_vd4rl/install.sh
index 1be73fc1de0..256f8d065f6 100755
--- a/.github/unittest/linux_libs/scripts_vd4rl/install.sh
+++ b/.github/unittest/linux_libs/scripts_vd4rl/install.sh
@@ -37,7 +37,7 @@ if [[ "$TORCH_VERSION" == "nightly" ]]; then
     fi
 elif [[ "$TORCH_VERSION" == "stable" ]]; then
     if [ "${CU_VERSION:-}" == cpu ] ; then
-        pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu
+        pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu -U
     else
         pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu121
     fi
diff --git a/.github/unittest/linux_olddeps/scripts_gym_0_13/environment.yml b/.github/unittest/linux_olddeps/scripts_gym_0_13/environment.yml
index d34011e7bdc..06c4a112933 100644
--- a/.github/unittest/linux_olddeps/scripts_gym_0_13/environment.yml
+++ b/.github/unittest/linux_olddeps/scripts_gym_0_13/environment.yml
@@ -22,6 +22,7 @@ dependencies:
   - scipy
   - hydra-core
   - dm_control -e git+https://github.com/deepmind/dm_control.git@c053360edea6170acfd9c8f65446703307d9d352#egg={dm_control}
+  - mujoco
   - patchelf
   - pyopengl==3.1.4
   - ray
diff --git a/.github/unittest/linux_olddeps/scripts_gym_0_13/install.sh b/.github/unittest/linux_olddeps/scripts_gym_0_13/install.sh
index 58a33cd43f4..c1dde8bb7d0 100755
--- a/.github/unittest/linux_olddeps/scripts_gym_0_13/install.sh
+++ b/.github/unittest/linux_olddeps/scripts_gym_0_13/install.sh
@@ -39,7 +39,7 @@ printf "Installing PyTorch with %s\n" "${CU_VERSION}"
 if [ "${CU_VERSION:-}" == cpu ] ; then
     conda install pytorch==2.0 torchvision==0.15 cpuonly -c pytorch -y
 else
-    conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.8 -c pytorch -c nvidia -y
+    conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.8 numpy==1.26 -c pytorch -c nvidia -y
 fi
 
 # Solving circular import: https://stackoverflow.com/questions/75501048/how-to-fix-attributeerror-partially-initialized-module-charset-normalizer-has
diff --git a/.github/unittest/linux_optdeps/scripts/install.sh b/.github/unittest/linux_optdeps/scripts/install.sh
deleted file mode 100755
index 8ccbfbb8e19..00000000000
--- a/.github/unittest/linux_optdeps/scripts/install.sh
+++ /dev/null
@@ -1,36 +0,0 @@
-#!/usr/bin/env bash
-
-unset PYTORCH_VERSION
-
-set -e
-set -v
-
-eval "$(./conda/bin/conda shell.bash hook)"
-conda activate ./env
-
-if [[ ${#CU_VERSION} -eq 4 ]]; then
-    CUDA_VERSION="${CU_VERSION:2:1}.${CU_VERSION:3:1}"
-elif [[ ${#CU_VERSION} -eq 5 ]]; then
-    CUDA_VERSION="${CU_VERSION:2:2}.${CU_VERSION:4:1}"
-fi
-echo "Using CUDA $CUDA_VERSION as determined by CU_VERSION ($CU_VERSION)"
-version="$(python -c "print('.'.join(\"${CUDA_VERSION}\".split('.')[:2]))")"
-
-# submodules
-git submodule sync && git submodule update --init --recursive
-
-printf "Installing PyTorch with %s\n" "${CU_VERSION}"
-pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/$CU_VERSION
-
-# install tensordict
-if [[ "$RELEASE" == 0 ]]; then
-    pip3 install git+https://github.com/pytorch/tensordict.git
-else
-    pip3 install tensordict
-fi
-
-printf "* Installing torchrl\n"
-python setup.py develop
-
-# smoke test
-python -c "import torchrl"
diff --git a/.github/unittest/linux_optdeps/scripts/run_all.sh b/.github/unittest/linux_optdeps/scripts/run_all.sh
index 9edfec5ea46..7f34ffd42fd 100755
--- a/.github/unittest/linux_optdeps/scripts/run_all.sh
+++ b/.github/unittest/linux_optdeps/scripts/run_all.sh
@@ -2,9 +2,10 @@
 set -euxo pipefail
 set -v
+set -e
 
-# ==================================================================================== #
-# ================================ Init ============================================== #
+# =============================================================================== #
+# ================================ Init ========================================= #
 
 if [[ $OSTYPE != 'darwin'* ]]; then
@@ -35,18 +36,133 @@ fi
 
 # ==================================================================================== #
 # ================================ Setup env ========================================= #
 
-bash ${this_dir}/setup_env.sh
+# Avoid error: "fatal: unsafe repository"
+git config --global --add safe.directory '*'
+root_dir="$(git rev-parse --show-toplevel)"
+conda_dir="${root_dir}/conda"
+env_dir="${root_dir}/env"
+lib_dir="${env_dir}/lib"
+
+cd "${root_dir}"
+
+case "$(uname -s)" in
+    Darwin*) os=MacOSX;;
+    *) os=Linux
+esac
+
+# 1. Install conda at ./conda
+if [ ! -d "${conda_dir}" ]; then
+    printf "* Installing conda\n"
+    wget -O miniconda.sh "http://repo.continuum.io/miniconda/Miniconda3-latest-${os}-x86_64.sh"
+    bash ./miniconda.sh -b -f -p "${conda_dir}"
+fi
+eval "$(${conda_dir}/bin/conda shell.bash hook)"
+
+# 2. Create test environment at ./env
+printf "python: ${PYTHON_VERSION}\n"
+if [ ! -d "${env_dir}" ]; then
+    printf "* Creating a test environment\n"
+    conda create --prefix "${env_dir}" -y python="$PYTHON_VERSION"
+fi
+conda activate "${env_dir}"
+
+# 3. Install Conda dependencies
+printf "* Installing dependencies (except PyTorch)\n"
+echo "  - python=${PYTHON_VERSION}" >> "${this_dir}/environment.yml"
+cat "${this_dir}/environment.yml"
+
+pip3 install pip --upgrade
+
+conda env update --file "${this_dir}/environment.yml" --prune
 
 # ============================================================================================ #
 # ================================ PyTorch & TorchRL ========================================= #
 
-bash ${this_dir}/install.sh
+unset PYTORCH_VERSION
+
+if [ "${CU_VERSION:-}" == cpu ] ; then
+    version="cpu"
+    echo "Using cpu build"
+else
+    if [[ ${#CU_VERSION} -eq 4 ]]; then
+        CUDA_VERSION="${CU_VERSION:2:1}.${CU_VERSION:3:1}"
+    elif [[ ${#CU_VERSION} -eq 5 ]]; then
+        CUDA_VERSION="${CU_VERSION:2:2}.${CU_VERSION:4:1}"
+    fi
+    echo "Using CUDA $CUDA_VERSION as determined by CU_VERSION ($CU_VERSION)"
+    version="$(python -c "print('.'.join(\"${CUDA_VERSION}\".split('.')[:2]))")"
+fi
+
+# submodules
+git submodule sync && git submodule update --init --recursive
+
+printf "Installing PyTorch with %s\n" "${CU_VERSION}"
+if [[ "$TORCH_VERSION" == "nightly" ]]; then
+    if [ "${CU_VERSION:-}" == cpu ] ; then
+        pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu -U
+    else
+        pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/$CU_VERSION -U
+    fi
+elif [[ "$TORCH_VERSION" == "stable" ]]; then
+    if [ "${CU_VERSION:-}" == cpu ] ; then
+        pip3 install torch --index-url https://download.pytorch.org/whl/cpu -U
+    else
+        pip3 install torch --index-url https://download.pytorch.org/whl/$CU_VERSION -U
+    fi
+else
+    printf "Failed to install pytorch"
+    exit 1
+fi
+
+# smoke test
+python -c "import functorch"
+
+## install snapshot
+#if [[ "$TORCH_VERSION" == "nightly" ]]; then
+#  pip3 install git+https://github.com/pytorch/torchsnapshot
+#else
+#  pip3 install torchsnapshot
+#fi
+
+# install tensordict
+if [[ "$RELEASE" == 0 ]]; then
+    pip3 install git+https://github.com/pytorch/tensordict.git
+else
+    pip3 install tensordict
+fi
+
+printf "* Installing torchrl\n"
+python setup.py develop
+
+# smoke test
+python -c "import torchrl"
 
 # ==================================================================================== #
 # ================================ Run tests ========================================= #
 
-bash ${this_dir}/run_test.sh
+# find libstdc
+STDC_LOC=$(find conda/ -name "libstdc++.so.6" | head -1)
+
+export PYTORCH_TEST_WITH_SLOW='1'
+export LAZY_LEGACY_OP=False
+python -m torch.utils.collect_env
+# Avoid error: "fatal: unsafe repository"
+git config --global --add safe.directory '*'
+root_dir="$(git rev-parse --show-toplevel)"
+
+export MKL_THREADING_LAYER=GNU
+export CKPT_BACKEND=torch
+export MAX_IDLE_COUNT=100
+export BATCHED_PIPE_TIMEOUT=60
+
+python .github/unittest/helpers/coverage_run_parallel.py -m pytest test \
+  --instafail --durations 200 -vv --capture no --ignore test/test_rlhf.py \
+  --ignore test/test_distributed.py \
+  --timeout=120 --mp_fork_if_no_cuda
+
+coverage combine
+coverage xml -i
 
 # ==================================================================================== #
 # ================================ Post-proc ========================================= #
diff --git a/.github/unittest/linux_optdeps/scripts/run_test.sh b/.github/unittest/linux_optdeps/scripts/run_test.sh
deleted file mode 100755
index e3f8c31cfc9..00000000000
--- a/.github/unittest/linux_optdeps/scripts/run_test.sh
+++ /dev/null
@@ -1,25 +0,0 @@
-#!/usr/bin/env bash
-
-set -e
-
-eval "$(./conda/bin/conda shell.bash hook)"
-conda activate ./env
-
-# find libstdc
-STDC_LOC=$(find conda/ -name "libstdc++.so.6" | head -1)
-
-export PYTORCH_TEST_WITH_SLOW='1'
-export LAZY_LEGACY_OP=False
-python -m torch.utils.collect_env
-# Avoid error: "fatal: unsafe repository"
-git config --global --add safe.directory '*'
-root_dir="$(git rev-parse --show-toplevel)"
-export MKL_THREADING_LAYER=GNU
-export CKPT_BACKEND=torch
-export BATCHED_PIPE_TIMEOUT=60
-
-MUJOCO_GL=egl python .github/unittest/helpers/coverage_run_parallel.py -m pytest --instafail \
-  -v --durations 200 --ignore test/test_distributed.py --ignore test/test_rlhf.py --capture no \
-  --timeout=120 --mp_fork_if_no_cuda
-coverage combine
-coverage xml -i
diff --git a/.github/workflows/benchmarks.yml b/.github/workflows/benchmarks.yml
index 8eaed2fb825..28832d5229b 100644
--- a/.github/workflows/benchmarks.yml
+++ b/.github/workflows/benchmarks.yml
@@ -24,25 +24,36 @@ jobs:
     name: CPU Pytest benchmark
     runs-on: ubuntu-20.04
     steps:
-      - uses: actions/checkout@v3
-      - uses: actions/setup-python@v4
+      - name: Who triggered this?
+        run: |
+          echo "Action triggered by ${{ github.event.pull_request.html_url }}"
+      - name: Checkout
+        uses: actions/checkout@v3
+        with:
+          fetch-depth: 50 # this is to make sure we obtain the target base commit
+      - name: Python Setup
+        uses: actions/setup-python@v4
         with:
-          python-version: 3.8
+          python-version: '3.10'
       - name: Setup Environment
         run: |
-          python -m pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu -U
-          python -m pip install git+https://github.com/pytorch/tensordict
-          python setup.py develop
-          python -m pip install pytest pytest-benchmark
+          python3.10 -m venv ./py310
+          source ./py310/bin/activate
+
+          python3 -m pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu -U
+          python3 -m pip install git+https://github.com/pytorch/tensordict
+          python3 setup.py develop
+          python3 -m pip install pytest pytest-benchmark
           python3 -m pip install "gym[accept-rom-license,atari]"
-          python3 -m pip install dm_control
-      - name: Run benchmarks
-        run: |
+          python3 -m pip install "dm_control" "mujoco"
+
           cd benchmarks/
-          python -m pytest --benchmark-json output.json
+          export TORCHDYNAMO_INLINE_INBUILT_NN_MODULES=1
+          export TD_GET_DEFAULTS_TO_NONE=1
+          python3 -m pytest -vvv --rank 0 --benchmark-json output.json --ignore test_collectors_benchmark.py
       - name: Store benchmark results
-        if: ${{ github.ref == 'refs/heads/main' || github.event_name == 'workflow_dispatch' }}
         uses: benchmark-action/github-action-benchmark@v1
+        if: ${{ github.ref == 'refs/heads/main' || github.event_name == 'workflow_dispatch' }}
         with:
           name: CPU Benchmark Results
           tool: 'pytest'
@@ -66,46 +77,73 @@ jobs:
       image: nvidia/cuda:12.3.0-base-ubuntu22.04
       options: --gpus all
     steps:
-      - name: Install deps
+      - name: Set GITHUB_BRANCH environment variable
        run: |
-          export TZ=Europe/London
-          export DEBIAN_FRONTEND=noninteractive  # tzdata bug
-          apt-get update -y
-          apt-get install software-properties-common -y
-          add-apt-repository ppa:git-core/candidate -y
-          apt-get update -y
-          apt-get upgrade -y
-          apt-get -y install libglu1-mesa libgl1-mesa-glx libosmesa6 gcc curl g++ unzip wget libglfw3-dev libgles2-mesa-dev libglew-dev sudo git cmake libz-dev
+          if [ "${{ github.event_name }}" == "push" ]; then
+            export GITHUB_BRANCH=${{ github.event.branch }}
+          elif [ "${{ github.event_name }}" == "pull_request" ]; then
+            export GITHUB_BRANCH=${{ github.event.pull_request.head.ref }}
+          else
+            echo "Unsupported event type"
+            exit 1
+          fi
+          echo "GITHUB_BRANCH=$GITHUB_BRANCH" >> $GITHUB_ENV
+      - name: Who triggered this?
+        run: |
+          echo "Action triggered by ${{ github.event.pull_request.html_url }}"
       - name: Check ldd --version
         run: ldd --version
       - name: Checkout
         uses: actions/checkout@v3
+        with:
+          fetch-depth: 50 # this is to make sure we obtain the target base commit
       - name: Python Setup
         uses: actions/setup-python@v4
         with:
-          python-version: 3.8
+          python-version: '3.10'
+      - name: Setup Environment
+        run: |
+          export TZ=Europe/London
+          export DEBIAN_FRONTEND=noninteractive  # tzdata bug
+          apt-get update -y
+          apt-get install software-properties-common -y
+          add-apt-repository ppa:git-core/candidate -y
+          apt-get update -y
+          apt-get upgrade -y
+          apt-get -y install libglu1-mesa libgl1-mesa-glx libosmesa6 gcc curl g++ unzip wget libglfw3-dev libgles2-mesa-dev libglew-dev sudo git cmake libz-dev libpython3.10-dev
       - name: Setup git
         run: git config --global --add safe.directory /__w/rl/rl
      - name: setup Path
         run: |
           echo /usr/local/bin >> $GITHUB_PATH
-      - name: Setup Environment
+      - name: Setup benchmarks
         run: |
-          python3 -m pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu121 -U
-          python3 -m pip install git+https://github.com/pytorch/tensordict
-          python3 setup.py develop
-          python3 -m pip install pytest pytest-benchmark
-          python3 -m pip install "gym[accept-rom-license,atari]"
-          python3 -m pip install dm_control
-      - name: check GPU presence
+          echo "BASE_SHA=$(echo ${{ github.event.pull_request.base.sha }} | cut -c1-8)" >> $GITHUB_ENV
+          echo "HEAD_SHA=$(echo ${{ github.event.pull_request.head.sha }} | cut -c1-8)" >> $GITHUB_ENV
+          echo "BASELINE_JSON=$(mktemp)" >> $GITHUB_ENV
+          echo "CONTENDER_JSON=$(mktemp)" >> $GITHUB_ENV
+          echo "PR_COMMENT=$(mktemp)" >> $GITHUB_ENV
+      - name: Run
         run: |
-          python -c """import torch
+          python3.10 -m venv --system-site-packages ./py310
+          source ./py310/bin/activate
+          export PYTHON_INCLUDE_DIR=/usr/include/python3.10
+
+          python3.10 -m pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu124 -U
+          python3.10 -m pip install cmake ninja pytest pytest-benchmark mujoco dm_control "gym[accept-rom-license,atari]"
+          python3.10 -m pip install git+https://github.com/pytorch/tensordict
+          python3.10 setup.py develop
+          # python3.10 -m pip install git+https://github.com/pytorch/rl@$GITHUB_BRANCH
+
+          # test import
+          python3 -c """import torch
           assert torch.cuda.device_count()
           """
-      - name: Run benchmarks
-        run: |
+
           cd benchmarks/
+          export TORCHDYNAMO_INLINE_INBUILT_NN_MODULES=1
+          export TD_GET_DEFAULTS_TO_NONE=1
+          python3 -m pytest -vvv --rank 0 --benchmark-json output.json --ignore test_collectors_benchmark.py
       - name: Store benchmark results
         uses: benchmark-action/github-action-benchmark@v1
         if: ${{ github.ref == 'refs/heads/main' || github.event_name == 'workflow_dispatch' }}
diff --git a/.github/workflows/benchmarks_pr.yml b/.github/workflows/benchmarks_pr.yml
index a8a1bc4c8dc..dfd8850a6f7 100644
--- a/.github/workflows/benchmarks_pr.yml
+++ b/.github/workflows/benchmarks_pr.yml
@@ -1,5 +1,4 @@
 name: Continuous Benchmark (PR)
-
 on:
   pull_request:
@@ -12,6 +11,7 @@ concurrency:
   cancel-in-progress: true
 
 jobs:
+
   benchmark_cpu:
     name: CPU Pytest benchmark
     runs-on: ubuntu-20.04
@@ -26,15 +26,7 @@
       - name: Python Setup
         uses: actions/setup-python@v4
         with:
-          python-version: 3.8
-      - name: Setup Environment
-        run: |
-          python -m pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu -U
-          python -m pip install git+https://github.com/pytorch/tensordict
-          python setup.py develop
-          python -m pip install pytest pytest-benchmark
-          python3 -m pip install "gym[accept-rom-license,atari]"
-          python3 -m pip install dm_control
+          python-version: '3.10'
       - name: Setup benchmarks
         run: |
           echo "BASE_SHA=$(echo ${{ github.event.pull_request.base.sha }} | cut -c1-8)" >> $GITHUB_ENV
@@ -42,10 +34,22 @@
           echo "BASELINE_JSON=$(mktemp)" >> $GITHUB_ENV
           echo "CONTENDER_JSON=$(mktemp)" >> $GITHUB_ENV
           echo "PR_COMMENT=$(mktemp)" >> $GITHUB_ENV
-      - name: Run benchmarks
+      - name: Setup Environment and tests
         run: |
+          python3.10 -m venv ./py310
+          source ./py310/bin/activate
+
+          python3 -m pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu -U
+          python3 -m pip install git+https://github.com/pytorch/tensordict
+          python3 setup.py develop
+          python3 -m pip install pytest pytest-benchmark
+          python3 -m pip install "gym[accept-rom-license,atari]"
+          python3 -m pip install "dm_control" "mujoco"
+
           cd benchmarks/
-          RUN_BENCHMARK="pytest --rank 0 --benchmark-json "
+          export TORCHDYNAMO_INLINE_INBUILT_NN_MODULES=1
+          export TD_GET_DEFAULTS_TO_NONE=1
+          RUN_BENCHMARK="python3 -m pytest -vvv --rank 0 --ignore test_collectors_benchmark.py --benchmark-json "
           git checkout ${{ github.event.pull_request.base.sha }}
           $RUN_BENCHMARK ${{ env.BASELINE_JSON }}
           git checkout ${{ github.event.pull_request.head.sha }}
@@ -69,22 +73,23 @@
      run:
        shell: bash -l {0}
     container:
-      image: nvidia/cuda:12.3.0-base-ubuntu22.04
+      image: nvidia/cuda:12.4.1-cudnn-runtime-ubuntu22.04
       options: --gpus all
     steps:
+      - name: Set GITHUB_BRANCH environment variable
+        run: |
+          if [ "${{ github.event_name }}" == "push" ]; then
+            export GITHUB_BRANCH=${{ github.event.branch }}
+          elif [ "${{ github.event_name }}" == "pull_request" ]; then
+            export GITHUB_BRANCH=${{ github.event.pull_request.head.ref }}
+          else
+            echo "Unsupported event type"
+            exit 1
+          fi
+          echo "GITHUB_BRANCH=$GITHUB_BRANCH" >> $GITHUB_ENV
       - name: Who triggered this?
run: | echo "Action triggered by ${{ github.event.pull_request.html_url }}" - - name: Install deps - run: | - export TZ=Europe/London - export DEBIAN_FRONTEND=noninteractive # tzdata bug - apt-get update -y - apt-get install software-properties-common -y - add-apt-repository ppa:git-core/candidate -y - apt-get update -y - apt-get upgrade -y - apt-get -y install libglu1-mesa libgl1-mesa-glx libosmesa6 gcc curl g++ unzip wget libglfw3-dev libgles2-mesa-dev libglew-dev sudo git cmake libz-dev - name: Check ldd --version run: ldd --version - name: Checkout @@ -94,25 +99,22 @@ jobs: - name: Python Setup uses: actions/setup-python@v4 with: - python-version: 3.8 + python-version: '3.10' + - name: Setup Environment + run: | + export TZ=Europe/London + export DEBIAN_FRONTEND=noninteractive # tzdata bug + apt-get update -y + apt-get install software-properties-common -y + add-apt-repository ppa:git-core/candidate -y + apt-get update -y + apt-get upgrade -y + apt-get -y install libglu1-mesa libgl1-mesa-glx libosmesa6 gcc curl g++ unzip wget libglfw3-dev libgles2-mesa-dev libglew-dev sudo git cmake libz-dev libpython3.10-dev - name: Setup git run: git config --global --add safe.directory /__w/rl/rl - name: setup Path run: | echo /usr/local/bin >> $GITHUB_PATH - - name: Setup Environment - run: | - python3 -m pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu121 -U - python3 -m pip install git+https://github.com/pytorch/tensordict - python3 setup.py develop - python3 -m pip install pytest pytest-benchmark - python3 -m pip install "gym[accept-rom-license,atari]" - python3 -m pip install dm_control - - name: check GPU presence - run: | - python -c """import torch - assert torch.cuda.device_count() - """ - name: Setup benchmarks run: | echo "BASE_SHA=$(echo ${{ github.event.pull_request.base.sha }} | cut -c1-8)" >> $GITHUB_ENV @@ -120,10 +122,27 @@ jobs: echo "BASELINE_JSON=$(mktemp)" >> $GITHUB_ENV echo "CONTENDER_JSON=$(mktemp)" >> $GITHUB_ENV echo "PR_COMMENT=$(mktemp)" >> $GITHUB_ENV - - name: Run benchmarks + - name: Run run: | + python3.10 -m venv --system-site-packages ./py310 + source ./py310/bin/activate + export PYTHON_INCLUDE_DIR=/usr/include/python3.10 + + python3.10 -m pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu124 -U + python3.10 -m pip install cmake ninja pytest pytest-benchmark mujoco dm_control "gym[accept-rom-license,atari]" + python3.10 -m pip install git+https://github.com/pytorch/tensordict + python3.10 setup.py develop + # python3.10 -m pip install git+https://github.com/pytorch/rl@$GITHUB_BRANCH + + # test import + python3 -c """import torch + assert torch.cuda.device_count() + """ + cd benchmarks/ - RUN_BENCHMARK="pytest --rank 0 --benchmark-json " + export TORCHDYNAMO_INLINE_INBUILT_NN_MODULES=1 + export TD_GET_DEFAULTS_TO_NONE=1 + RUN_BENCHMARK="python3 -m pytest -vvv --rank 0 --ignore test_collectors_benchmark.py --benchmark-json " git checkout ${{ github.event.pull_request.base.sha }} $RUN_BENCHMARK ${{ env.BASELINE_JSON }} git checkout ${{ github.event.pull_request.head.sha }} diff --git a/.github/workflows/build-wheels-aarch64-linux.yml b/.github/workflows/build-wheels-aarch64-linux.yml new file mode 100644 index 00000000000..63818f07365 --- /dev/null +++ b/.github/workflows/build-wheels-aarch64-linux.yml @@ -0,0 +1,51 @@ +name: Build Aarch64 Linux Wheels + +on: + pull_request: + push: + branches: + - nightly + - main + - release/* + tags: + # NOTE: Binary build pipelines should only get 
triggered on release candidate builds + # Release candidate tags look like: v1.11.0-rc1 + - v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+ + workflow_dispatch: + +permissions: + id-token: write + contents: read + +jobs: + generate-matrix: + uses: pytorch/test-infra/.github/workflows/generate_binary_build_matrix.yml@main + with: + package-type: wheel + os: linux-aarch64 + test-infra-repository: pytorch/test-infra + test-infra-ref: main + with-cuda: disable + build: + needs: generate-matrix + strategy: + fail-fast: false + matrix: + include: + - repository: pytorch/rl + smoke-test-script: test/smoke_test.py + package-name: torchrl + name: pytorch/rl + uses: pytorch/test-infra/.github/workflows/build_wheels_linux.yml@main + with: + repository: ${{ matrix.repository }} + ref: "" + test-infra-repository: pytorch/test-infra + test-infra-ref: main + build-matrix: ${{ needs.generate-matrix.outputs.matrix }} + package-name: ${{ matrix.package-name }} + smoke-test-script: ${{ matrix.smoke-test-script }} + trigger-event: ${{ github.event_name }} + env-var-script: .github/scripts/td_script.sh + architecture: aarch64 + setup-miniconda: false diff --git a/.github/workflows/build-wheels-windows.yml b/.github/workflows/build-wheels-windows.yml index 9f2666ccdbf..556d805c643 100644 --- a/.github/workflows/build-wheels-windows.yml +++ b/.github/workflows/build-wheels-windows.yml @@ -32,10 +32,12 @@ jobs: matrix: include: - repository: pytorch/rl + pre-script: .github/scripts/td_script.sh + env-script: .github/scripts/version_script.bat post-script: "python packaging/wheel/relocate.py" smoke-test-script: test/smoke_test.py package-name: torchrl - name: pytorch/rl + name: ${{ matrix.repository }} uses: pytorch/test-infra/.github/workflows/build_wheels_windows.yml@main with: repository: ${{ matrix.repository }} @@ -43,8 +45,9 @@ jobs: test-infra-repository: pytorch/test-infra test-infra-ref: main build-matrix: ${{ needs.generate-matrix.outputs.matrix }} + pre-script: ${{ matrix.pre-script }} + env-script: ${{ matrix.env-script }} + post-script: ${{ matrix.post-script }} package-name: ${{ matrix.package-name }} smoke-test-script: ${{ matrix.smoke-test-script }} trigger-event: ${{ github.event_name }} - pre-script: .github/scripts/td_script.sh - env-script: .github/scripts/version_script.bat diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 749e85b64dd..e153641e775 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -3,14 +3,13 @@ name: Generate documentation on: push: branches: + - nightly - main - release/* tags: - v[0-9]+.[0-9]+.[0-9] - v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+ pull_request: - branches: - - "*" workflow_dispatch: concurrency: @@ -23,7 +22,7 @@ jobs: build-docs: strategy: matrix: - python_version: ["3.9"] + python_version: ["3.10"] cuda_arch_version: ["12.1"] uses: pytorch/test-infra/.github/workflows/linux_job.yml@main with: @@ -35,7 +34,7 @@ jobs: script: | set -e set -v - apt-get update && apt-get install -y git wget gcc g++ + apt-get update && apt-get install -y -f git wget gcc g++ dialog apt-utils root_dir="$(pwd)" conda_dir="${root_dir}/conda" env_dir="${root_dir}/env" @@ -47,14 +46,14 @@ jobs: bash ./miniconda.sh -b -f -p "${conda_dir}" eval "$(${conda_dir}/bin/conda shell.bash hook)" printf "* Creating a test environment\n" - conda create --prefix "${env_dir}" -y python=3.8 + conda create --prefix "${env_dir}" -y python=3.10 printf "* Activating\n" conda activate "${env_dir}" - + # 2. 
upgrade pip, ninja and packaging - apt-get install python3.8 python3-pip -y + apt-get install python3-pip unzip -y -f python3 -m pip install --upgrade pip - python3 -m pip install setuptools ninja packaging -U + python3 -m pip install setuptools ninja packaging cmake -U # 3. check python version python3 --version diff --git a/.github/workflows/nightly_build.yml b/.github/workflows/nightly_build.yml index c7a7d344157..08eb61bfa6c 100644 --- a/.github/workflows/nightly_build.yml +++ b/.github/workflows/nightly_build.yml @@ -39,7 +39,7 @@ jobs: runs-on: ubuntu-20.04 strategy: matrix: - python_version: [["3.8", "cp38-cp38"], ["3.9", "cp39-cp39"], ["3.10", "cp310-cp310"], ["3.11", "cp311-cp311"], ["3.12", "cp312-cp312"]] + python_version: [["3.9", "cp39-cp39"], ["3.10", "cp310-cp310"], ["3.11", "cp311-cp311"], ["3.12", "cp312-cp312"]] cuda_support: [["", "cpu", "cpu"]] container: pytorch/manylinux-${{ matrix.cuda_support[2] }} steps: @@ -79,7 +79,7 @@ jobs: runs-on: ubuntu-20.04 strategy: matrix: - python_version: [["3.8", "cp38-cp38"], ["3.9", "cp39-cp39"], ["3.10", "cp310-cp310"], ["3.11", "cp311-cp311"], ["3.12", "cp312-cp312"]] + python_version: [["3.9", "cp39-cp39"], ["3.10", "cp310-cp310"], ["3.11", "cp311-cp311"], ["3.12", "cp312-cp312"]] cuda_support: [["", "cpu", "cpu"]] container: pytorch/manylinux-${{ matrix.cuda_support[2] }} steps: @@ -110,7 +110,7 @@ jobs: runs-on: ubuntu-20.04 strategy: matrix: - python_version: [["3.8", "cp38-cp38"], ["3.9", "cp39-cp39"], ["3.10", "cp310-cp310"], ["3.11", "cp311-cp311"], ["3.12", "cp312-cp312"]] + python_version: [["3.9", "cp39-cp39"], ["3.10", "cp310-cp310"], ["3.11", "cp311-cp311"], ["3.12", "cp312-cp312"]] cuda_support: [["", "cpu", "cpu"]] steps: - name: Setup Python @@ -172,7 +172,7 @@ jobs: runs-on: windows-latest strategy: matrix: - python_version: [["3.8", "3.8"], ["3.9", "3.9"], ["3.10", "3.10.3"], ["3.11", "3.11"], ["3.12", "3.12"]] + python_version: [["3.9", "3.9"], ["3.10", "3.10.3"], ["3.11", "3.11"], ["3.12", "3.12"]] steps: - name: Setup Python uses: actions/setup-python@v5 @@ -205,7 +205,7 @@ jobs: runs-on: windows-latest strategy: matrix: - python_version: [["3.8", "3.8"], ["3.9", "3.9"], ["3.10", "3.10.3"], ["3.11", "3.11"], ["3.12", "3.12"]] + python_version: [["3.9", "3.9"], ["3.10", "3.10.3"], ["3.11", "3.11"], ["3.12", "3.12"]] steps: - name: Setup Python uses: actions/setup-python@v5 @@ -262,7 +262,7 @@ jobs: runs-on: windows-latest strategy: matrix: - python_version: [["3.8", "3.8"], ["3.9", "3.9"], ["3.10", "3.10.3"], ["3.11", "3.11"], ["3.12", "3.12"]] + python_version: [["3.9", "3.9"], ["3.10", "3.10.3"], ["3.11", "3.11"], ["3.12", "3.12"]] steps: - name: Checkout torchrl uses: actions/checkout@v3 diff --git a/.github/workflows/test-linux-examples.yml b/.github/workflows/test-linux-examples.yml index fd0adaf6ed5..39c97fae266 100644 --- a/.github/workflows/test-linux-examples.yml +++ b/.github/workflows/test-linux-examples.yml @@ -49,6 +49,7 @@ jobs: echo "PYTHON_VERSION: $PYTHON_VERSION" echo "CU_VERSION: $CU_VERSION" + export TD_GET_DEFAULTS_TO_NONE=1 ## setup_env.sh bash .github/unittest/linux_examples/scripts/run_all.sh diff --git a/.github/workflows/test-linux-habitat.yml b/.github/workflows/test-linux-habitat.yml index 3f6e89a70f9..6a1c52f90fa 100644 --- a/.github/workflows/test-linux-habitat.yml +++ b/.github/workflows/test-linux-habitat.yml @@ -46,5 +46,6 @@ jobs: export CU_VERSION="cu${CUDA_ARCH_VERSION:0:2}${CUDA_ARCH_VERSION:3:1}" # Remove the following line when the GPU tests are working inside docker, and 
uncomment the above lines #export CU_VERSION="cpu" + export TD_GET_DEFAULTS_TO_NONE=1 bash .github/unittest/linux_libs/scripts_habitat/run_all.sh diff --git a/.github/workflows/test-linux-libs.yml b/.github/workflows/test-linux-libs.yml index 9e1875cac18..bd394f39fa7 100644 --- a/.github/workflows/test-linux-libs.yml +++ b/.github/workflows/test-linux-libs.yml @@ -44,6 +44,7 @@ jobs: export TAR_OPTIONS="--no-same-owner" export UPLOAD_CHANNEL="nightly" export TF_CPP_MIN_LOG_LEVEL=0 + export TD_GET_DEFAULTS_TO_NONE=1 bash .github/unittest/linux_libs/scripts_ataridqn/setup_env.sh bash .github/unittest/linux_libs/scripts_ataridqn/install.sh @@ -81,6 +82,7 @@ jobs: export UPLOAD_CHANNEL="nightly" export TF_CPP_MIN_LOG_LEVEL=0 export BATCHED_PIPE_TIMEOUT=60 + export TD_GET_DEFAULTS_TO_NONE=1 nvidia-smi @@ -114,6 +116,7 @@ jobs: export UPLOAD_CHANNEL="nightly" export TF_CPP_MIN_LOG_LEVEL=0 export BATCHED_PIPE_TIMEOUT=60 + export TD_GET_DEFAULTS_TO_NONE=1 bash .github/unittest/linux_libs/scripts_d4rl/setup_env.sh bash .github/unittest/linux_libs/scripts_d4rl/install.sh @@ -148,6 +151,7 @@ jobs: export UPLOAD_CHANNEL="nightly" export TF_CPP_MIN_LOG_LEVEL=0 export BATCHED_PIPE_TIMEOUT=60 + export TD_GET_DEFAULTS_TO_NONE=1 bash .github/unittest/linux_libs/scripts_d4rl/setup_env.sh bash .github/unittest/linux_libs/scripts_d4rl/install.sh @@ -181,6 +185,7 @@ jobs: export TAR_OPTIONS="--no-same-owner" export UPLOAD_CHANNEL="nightly" export TF_CPP_MIN_LOG_LEVEL=0 + export TD_GET_DEFAULTS_TO_NONE=1 bash .github/unittest/linux_libs/scripts_gen-dgrl/setup_env.sh bash .github/unittest/linux_libs/scripts_gen-dgrl/install.sh @@ -216,6 +221,7 @@ jobs: export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/work/mujoco-py/mujoco_py/binaries/linux/mujoco210/bin" export TAR_OPTIONS="--no-same-owner" export BATCHED_PIPE_TIMEOUT=60 + export TD_GET_DEFAULTS_TO_NONE=1 ./.github/unittest/linux_libs/scripts_gym/setup_env.sh ./.github/unittest/linux_libs/scripts_gym/batch_scripts.sh @@ -251,6 +257,7 @@ jobs: export UPLOAD_CHANNEL="nightly" export TF_CPP_MIN_LOG_LEVEL=0 export BATCHED_PIPE_TIMEOUT=60 + export TD_GET_DEFAULTS_TO_NONE=1 nvidia-smi @@ -285,6 +292,7 @@ jobs: export UPLOAD_CHANNEL="nightly" export TF_CPP_MIN_LOG_LEVEL=0 export BATCHED_PIPE_TIMEOUT=60 + export TD_GET_DEFAULTS_TO_NONE=1 nvidia-smi @@ -293,6 +301,82 @@ jobs: bash .github/unittest/linux_libs/scripts_meltingpot/run_test.sh bash .github/unittest/linux_libs/scripts_meltingpot/post_process.sh + unittests-open_spiel: + strategy: + matrix: + python_version: ["3.9"] + cuda_arch_version: ["12.1"] + if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Environments') }} + uses: pytorch/test-infra/.github/workflows/linux_job.yml@main + with: + repository: pytorch/rl + runner: "linux.g5.4xlarge.nvidia.gpu" + gpu-arch-type: cuda + gpu-arch-version: "11.7" + docker-image: "pytorch/manylinux-cuda124" + timeout: 120 + script: | + if [[ "${{ github.ref }}" =~ release/* ]]; then + export RELEASE=1 + export TORCH_VERSION=stable + else + export RELEASE=0 + export TORCH_VERSION=nightly + fi + + set -euo pipefail + export PYTHON_VERSION="3.9" + export CU_VERSION="12.1" + export TAR_OPTIONS="--no-same-owner" + export UPLOAD_CHANNEL="nightly" + export TF_CPP_MIN_LOG_LEVEL=0 + export BATCHED_PIPE_TIMEOUT=60 + + nvidia-smi + + bash .github/unittest/linux_libs/scripts_open_spiel/setup_env.sh + bash .github/unittest/linux_libs/scripts_open_spiel/install.sh + bash .github/unittest/linux_libs/scripts_open_spiel/run_test.sh + bash 
.github/unittest/linux_libs/scripts_open_spiel/post_process.sh + + unittests-unity_mlagents: + strategy: + matrix: + python_version: ["3.10.12"] + cuda_arch_version: ["12.1"] + if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Environments') }} + uses: pytorch/test-infra/.github/workflows/linux_job.yml@main + with: + repository: pytorch/rl + runner: "linux.g5.4xlarge.nvidia.gpu" + gpu-arch-type: cuda + gpu-arch-version: "11.7" + docker-image: "pytorch/manylinux-cuda124" + timeout: 120 + script: | + if [[ "${{ github.ref }}" =~ release/* ]]; then + export RELEASE=1 + export TORCH_VERSION=stable + else + export RELEASE=0 + export TORCH_VERSION=nightly + fi + + set -euo pipefail + export PYTHON_VERSION="3.10.12" + export CU_VERSION="12.1" + export TAR_OPTIONS="--no-same-owner" + export UPLOAD_CHANNEL="nightly" + export TF_CPP_MIN_LOG_LEVEL=0 + export BATCHED_PIPE_TIMEOUT=60 + + nvidia-smi + + bash .github/unittest/linux_libs/scripts_unity_mlagents/setup_env.sh + bash .github/unittest/linux_libs/scripts_unity_mlagents/install.sh + bash .github/unittest/linux_libs/scripts_unity_mlagents/run_test.sh + bash .github/unittest/linux_libs/scripts_unity_mlagents/post_process.sh + unittests-minari: strategy: matrix: @@ -321,6 +405,7 @@ jobs: export UPLOAD_CHANNEL="nightly" export TF_CPP_MIN_LOG_LEVEL=0 export BATCHED_PIPE_TIMEOUT=60 + export TD_GET_DEFAULTS_TO_NONE=1 bash .github/unittest/linux_libs/scripts_minari/setup_env.sh bash .github/unittest/linux_libs/scripts_minari/install.sh @@ -355,6 +440,7 @@ jobs: export UPLOAD_CHANNEL="nightly" export TF_CPP_MIN_LOG_LEVEL=0 export BATCHED_PIPE_TIMEOUT=60 + export TD_GET_DEFAULTS_TO_NONE=1 bash .github/unittest/linux_libs/scripts_openx/setup_env.sh bash .github/unittest/linux_libs/scripts_openx/install.sh @@ -387,6 +473,7 @@ jobs: export UPLOAD_CHANNEL="nightly" export TF_CPP_MIN_LOG_LEVEL=0 export BATCHED_PIPE_TIMEOUT=60 + export TD_GET_DEFAULTS_TO_NONE=1 nvidia-smi @@ -423,6 +510,7 @@ jobs: export UPLOAD_CHANNEL="nightly" export TF_CPP_MIN_LOG_LEVEL=0 export BATCHED_PIPE_TIMEOUT=60 + export TD_GET_DEFAULTS_TO_NONE=1 bash .github/unittest/linux_libs/scripts_robohive/setup_env.sh bash .github/unittest/linux_libs/scripts_robohive/install_and_run_test.sh @@ -456,6 +544,7 @@ jobs: export UPLOAD_CHANNEL="nightly" export TF_CPP_MIN_LOG_LEVEL=0 export BATCHED_PIPE_TIMEOUT=60 + export TD_GET_DEFAULTS_TO_NONE=1 bash .github/unittest/linux_libs/scripts_roboset/setup_env.sh bash .github/unittest/linux_libs/scripts_roboset/install.sh @@ -491,6 +580,7 @@ jobs: export UPLOAD_CHANNEL="nightly" export TF_CPP_MIN_LOG_LEVEL=0 export BATCHED_PIPE_TIMEOUT=60 + export TD_GET_DEFAULTS_TO_NONE=1 bash .github/unittest/linux_libs/scripts_sklearn/setup_env.sh bash .github/unittest/linux_libs/scripts_sklearn/install.sh @@ -527,6 +617,7 @@ jobs: export UPLOAD_CHANNEL="nightly" export TF_CPP_MIN_LOG_LEVEL=0 export BATCHED_PIPE_TIMEOUT=60 + export TD_GET_DEFAULTS_TO_NONE=1 nvidia-smi @@ -563,6 +654,7 @@ jobs: export UPLOAD_CHANNEL="nightly" export TF_CPP_MIN_LOG_LEVEL=0 export BATCHED_PIPE_TIMEOUT=60 + export TD_GET_DEFAULTS_TO_NONE=1 bash .github/unittest/linux_libs/scripts_vd4rl/setup_env.sh bash .github/unittest/linux_libs/scripts_vd4rl/install.sh @@ -599,6 +691,7 @@ jobs: export UPLOAD_CHANNEL="nightly" export TF_CPP_MIN_LOG_LEVEL=0 export BATCHED_PIPE_TIMEOUT=60 + export TD_GET_DEFAULTS_TO_NONE=1 nvidia-smi diff --git a/.github/workflows/test-linux-rlhf.yml b/.github/workflows/test-linux-rlhf.yml index 832d432c997..accbe6e7610 100644 --- 
a/.github/workflows/test-linux-rlhf.yml +++ b/.github/workflows/test-linux-rlhf.yml @@ -44,6 +44,7 @@ jobs: export TAR_OPTIONS="--no-same-owner" export UPLOAD_CHANNEL="nightly" export TF_CPP_MIN_LOG_LEVEL=0 + export TD_GET_DEFAULTS_TO_NONE=1 bash .github/unittest/linux_libs/scripts_rlhf/setup_env.sh bash .github/unittest/linux_libs/scripts_rlhf/install.sh diff --git a/.github/workflows/test-linux.yml b/.github/workflows/test-linux.yml index e8728180c67..75a646c25c4 100644 --- a/.github/workflows/test-linux.yml +++ b/.github/workflows/test-linux.yml @@ -22,7 +22,7 @@ jobs: tests-cpu: strategy: matrix: - python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"] + python_version: ["3.9", "3.10", "3.11", "3.12"] fail-fast: false uses: pytorch/test-infra/.github/workflows/linux_job.yml@main with: @@ -38,6 +38,39 @@ jobs: export RELEASE=0 export TORCH_VERSION=nightly fi + export TD_GET_DEFAULTS_TO_NONE=1 + # Set env vars from matrix + export PYTHON_VERSION=${{ matrix.python_version }} + export CU_VERSION="cpu" + + echo "PYTHON_VERSION: $PYTHON_VERSION" + echo "CU_VERSION: $CU_VERSION" + + ## setup_env.sh + bash .github/unittest/linux/scripts/run_all.sh + + tests-cpu-oldget: + # Tests that TD_GET_DEFAULTS_TO_NONE=0 works fine as this will be the default for TD up to 0.7 + strategy: + matrix: + python_version: ["3.12"] + fail-fast: false + uses: pytorch/test-infra/.github/workflows/linux_job.yml@main + with: + runner: linux.12xlarge + repository: pytorch/rl + docker-image: "nvidia/cuda:12.2.0-devel-ubuntu22.04" + timeout: 90 + script: | + if [[ "${{ github.ref }}" =~ release/* ]]; then + export RELEASE=1 + export TORCH_VERSION=stable + else + export RELEASE=0 + export TORCH_VERSION=nightly + fi + export TD_GET_DEFAULTS_TO_NONE=0 + # Set env vars from matrix export PYTHON_VERSION=${{ matrix.python_version }} export CU_VERSION="cpu" @@ -75,6 +108,8 @@ jobs: export RELEASE=0 export TORCH_VERSION=nightly fi + export TD_GET_DEFAULTS_TO_NONE=1 + # Remove the following line when the GPU tests are working inside docker, and uncomment the above lines #export CU_VERSION="cpu" @@ -110,6 +145,7 @@ jobs: export TORCH_VERSION=nightly fi export TF_CPP_MIN_LOG_LEVEL=0 + export TD_GET_DEFAULTS_TO_NONE=1 bash .github/unittest/linux_olddeps/scripts_gym_0_13/setup_env.sh @@ -119,8 +155,8 @@ jobs: tests-optdeps: strategy: matrix: - python_version: ["3.10"] # "3.8", "3.9", "3.10", "3.11" - cuda_arch_version: ["12.1"] # "11.6", "11.7" + python_version: ["3.11"] + cuda_arch_version: ["12.1"] fail-fast: false uses: pytorch/test-infra/.github/workflows/linux_job.yml@main with: @@ -136,9 +172,6 @@ jobs: # Commenting these out for now because the GPU test are not working inside docker export CUDA_ARCH_VERSION=${{ matrix.cuda_arch_version }} export CU_VERSION="cu${CUDA_ARCH_VERSION:0:2}${CUDA_ARCH_VERSION:3:1}" - # Remove the following line when the GPU tests are working inside docker, and uncomment the above lines - #export CU_VERSION="cpu" - if [[ "${{ github.ref }}" =~ release/* ]]; then export RELEASE=1 export TORCH_VERSION=stable @@ -146,6 +179,7 @@ jobs: export RELEASE=0 export TORCH_VERSION=nightly fi + export TD_GET_DEFAULTS_TO_NONE=1 echo "PYTHON_VERSION: $PYTHON_VERSION" echo "CU_VERSION: $CU_VERSION" @@ -156,7 +190,7 @@ jobs: tests-stable-gpu: strategy: matrix: - python_version: ["3.10"] # "3.8", "3.9", "3.10", "3.11" + python_version: ["3.10"] # "3.9", "3.10", "3.11" cuda_arch_version: ["11.8"] # "11.6", "11.7" fail-fast: false uses: pytorch/test-infra/.github/workflows/linux_job.yml@main @@ -187,6 +221,7 @@ jobs: 
echo "PYTHON_VERSION: $PYTHON_VERSION" echo "CU_VERSION: $CU_VERSION" + export TD_GET_DEFAULTS_TO_NONE=1 ## setup_env.sh bash .github/unittest/linux/scripts/run_all.sh diff --git a/.github/workflows/test-windows-optdepts.yml b/.github/workflows/test-windows-optdepts.yml index e98b6c1810e..14a8dd7ab13 100644 --- a/.github/workflows/test-windows-optdepts.yml +++ b/.github/workflows/test-windows-optdepts.yml @@ -42,6 +42,7 @@ jobs: export RELEASE=0 export TORCH_VERSION=nightly fi + export TD_GET_DEFAULTS_TO_NONE=1 ## setup_env.sh ./.github/unittest/windows_optdepts/scripts/setup_env.sh diff --git a/.github/workflows/wheels-legacy.yml b/.github/workflows/wheels-legacy.yml index 80dd2640e17..998b9eba4b2 100644 --- a/.github/workflows/wheels-legacy.yml +++ b/.github/workflows/wheels-legacy.yml @@ -5,6 +5,7 @@ on: push: branches: - release/* + - main concurrency: # Documentation suggests ${{ github.head_ref }}, but that's only available on pull_request/pull_request_target triggers, so using ${{ github.ref }}. @@ -18,7 +19,7 @@ jobs: runs-on: windows-latest strategy: matrix: - python_version: [["3.8", "3.8"], ["3.9", "3.9"], ["3.10", "3.10.3"], ["3.11", "3.11"], ["3.12", "3.12"]] + python_version: [["3.9", "3.9"], ["3.10", "3.10.3"], ["3.11", "3.11"], ["3.12", "3.12"]] steps: - name: Setup Python uses: actions/setup-python@v2 @@ -36,12 +37,12 @@ jobs: python3 -mpip install wheel TORCHRL_BUILD_VERSION=0.5.0 python3 setup.py bdist_wheel - name: Upload wheel for the test-wheel job - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v3 with: name: torchrl-win-${{ matrix.python_version[0] }}.whl path: dist/torchrl-*.whl - name: Upload wheel for download - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v3 with: name: torchrl-batch.whl path: dist/*.whl @@ -50,7 +51,7 @@ jobs: needs: build-wheel-windows strategy: matrix: - python_version: [ "3.8", "3.9", "3.10", "3.11", "3.12" ] + python_version: ["3.9", "3.10", "3.11", "3.12" ] runs-on: windows-latest steps: - name: Setup Python @@ -76,7 +77,7 @@ jobs: run: | python3 -mpip install numpy pytest pytest-cov codecov unittest-xml-reporting pillow>=4.1.1 scipy av networkx expecttest pyyaml - name: Download built wheels - uses: actions/download-artifact@v2 + uses: actions/download-artifact@v3 with: name: torchrl-win-${{ matrix.python_version }}.whl path: wheels diff --git a/README.md b/README.md index f82a8ff0c4c..abcf7349192 100644 --- a/README.md +++ b/README.md @@ -99,68 +99,69 @@ lines of code*! 
from torchrl.collectors import SyncDataCollector from torchrl.data.replay_buffers import TensorDictReplayBuffer, \ - LazyTensorStorage, SamplerWithoutReplacement + LazyTensorStorage, SamplerWithoutReplacement from torchrl.envs.libs.gym import GymEnv from torchrl.modules import ProbabilisticActor, ValueOperator, TanhNormal from torchrl.objectives import ClipPPOLoss from torchrl.objectives.value import GAE - env = GymEnv("Pendulum-v1") + env = GymEnv("Pendulum-v1") model = TensorDictModule( - nn.Sequential( - nn.Linear(3, 128), nn.Tanh(), - nn.Linear(128, 128), nn.Tanh(), - nn.Linear(128, 128), nn.Tanh(), - nn.Linear(128, 2), - NormalParamExtractor() - ), - in_keys=["observation"], - out_keys=["loc", "scale"] + nn.Sequential( + nn.Linear(3, 128), nn.Tanh(), + nn.Linear(128, 128), nn.Tanh(), + nn.Linear(128, 128), nn.Tanh(), + nn.Linear(128, 2), + NormalParamExtractor() + ), + in_keys=["observation"], + out_keys=["loc", "scale"] ) critic = ValueOperator( - nn.Sequential( - nn.Linear(3, 128), nn.Tanh(), - nn.Linear(128, 128), nn.Tanh(), - nn.Linear(128, 128), nn.Tanh(), - nn.Linear(128, 1), - ), - in_keys=["observation"], + nn.Sequential( + nn.Linear(3, 128), nn.Tanh(), + nn.Linear(128, 128), nn.Tanh(), + nn.Linear(128, 128), nn.Tanh(), + nn.Linear(128, 1), + ), + in_keys=["observation"], ) actor = ProbabilisticActor( - model, - in_keys=["loc", "scale"], - distribution_class=TanhNormal, - distribution_kwargs={"min": -1.0, "max": 1.0}, - return_log_prob=True - ) + model, + in_keys=["loc", "scale"], + distribution_class=TanhNormal, + distribution_kwargs={"low": -1.0, "high": 1.0}, + return_log_prob=True + ) buffer = TensorDictReplayBuffer( - LazyTensorStorage(1000), - SamplerWithoutReplacement() - ) + storage=LazyTensorStorage(1000), + sampler=SamplerWithoutReplacement(), + batch_size=50, + ) collector = SyncDataCollector( - env, - actor, - frames_per_batch=1000, - total_frames=1_000_000 - ) - loss_fn = ClipPPOLoss(actor, critic, gamma=0.99) + env, + actor, + frames_per_batch=1000, + total_frames=1_000_000, + ) + loss_fn = ClipPPOLoss(actor, critic) + adv_fn = GAE(value_network=critic, average_gae=True, gamma=0.99, lmbda=0.95) optim = torch.optim.Adam(loss_fn.parameters(), lr=2e-4) - adv_fn = GAE(value_network=critic, gamma=0.99, lmbda=0.95, average_gae=True) + for data in collector: # collect data - for epoch in range(10): - adv_fn(data) # compute advantage - buffer.extend(data.view(-1)) - for i in range(20): # consume data - sample = buffer.sample(50) # mini-batch - loss_vals = loss_fn(sample) - loss_val = sum( - value for key, value in loss_vals.items() if - key.startswith("loss") - ) - loss_val.backward() - optim.step() - optim.zero_grad() - print(f"avg reward: {data['next', 'reward'].mean().item(): 4.4f}") + for epoch in range(10): + adv_fn(data) # compute advantage + buffer.extend(data) + for sample in buffer: # consume data + loss_vals = loss_fn(sample) + loss_val = sum( + value for key, value in loss_vals.items() if + key.startswith("loss") + ) + loss_val.backward() + optim.step() + optim.zero_grad() + print(f"avg reward: {data['next', 'reward'].mean().item(): 4.4f}") ``` @@ -478,7 +479,7 @@ And it is `functorch` and `torch.compile` compatible! 
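That compile-compatibility claim is exactly what the benchmark changes later in this diff exercise: each loss module is optionally wrapped in `torch.compile` and warmed up before timing (see the `_maybe_compile` helper added below). A minimal sketch of the pattern, with hypothetical `loss_fn`/`td` stand-ins, not the benchmark code itself:

```python
# Minimal sketch of the compile-then-warm-up pattern used in
# benchmarks/test_objectives_benchmarks.py further down this diff.
# `loss_fn` is any TorchRL loss module and `td` a TensorDict it has already
# been called on once; both are placeholders here.
import torch

def maybe_compile(loss_fn, mode=None, fullgraph=False, warmup=3, td=None):
    # torch.compile wraps the module; "reduce-overhead" mode enables
    # CUDA graphs where available. fullgraph=True is only expected to
    # work on torch >= 2.5, as the benchmark file itself checks.
    compiled = torch.compile(loss_fn, mode=mode, fullgraph=fullgraph)
    for _ in range(warmup):
        compiled(td)  # trigger compilation before any timing starts
    return compiled
```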
 policy_explore = EGreedyWrapper(policy)
 with set_exploration_type(ExplorationType.RANDOM):
     tensordict = policy_explore(tensordict)  # will use eps-greedy
-with set_exploration_type(ExplorationType.MODE):
+with set_exploration_type(ExplorationType.DETERMINISTIC):
     tensordict = policy_explore(tensordict)  # will not use eps-greedy
 ```
@@ -521,23 +522,288 @@ If you would like to contribute to new features, check our [call for contributio

 ## Examples, tutorials and demos

-A series of [examples](https://github.com/pytorch/rl/blob/main/examples/) are provided with an illustrative purpose:
-- [DQN](https://github.com/pytorch/rl/blob/main/sota-implementations/dqn)
-- [DDPG](https://github.com/pytorch/rl/blob/main/sota-implementations/ddpg/ddpg.py)
-- [IQL](https://github.com/pytorch/rl/blob/main/sota-implementations/iql/iql_offline.py)
-- [CQL](https://github.com/pytorch/rl/blob/main/sota-implementations/cql/cql_offline.py)
-- [TD3](https://github.com/pytorch/rl/blob/main/sota-implementations/td3/td3.py)
-- [TD3+BC](https://github.com/pytorch/rl/blob/main/sota-implementations/td3+bc/td3+bc.py)
-- [A2C](https://github.com/pytorch/rl/blob/main/examples/a2c_old/a2c.py)
-- [PPO](https://github.com/pytorch/rl/blob/main/sota-implementations/ppo/ppo.py)
-- [SAC](https://github.com/pytorch/rl/blob/main/sota-implementations/sac/sac.py)
-- [REDQ](https://github.com/pytorch/rl/blob/main/sota-implementations/redq/redq.py)
-- [Dreamer](https://github.com/pytorch/rl/blob/main/sota-implementations/dreamer/dreamer.py)
-- [Decision Transformers](https://github.com/pytorch/rl/blob/main/sota-implementations/decision_transformer)
-- [RLHF](https://github.com/pytorch/rl/blob/main/examples/rlhf)
+A series of [State-of-the-Art implementations](https://github.com/pytorch/rl/blob/main/sota-implementations/) are provided with an illustrative purpose:
+
+| Algorithm             | Compile Support** | Tensordict-free API | Modular Losses        | Continuous and Discrete                 |
+| --------------------- | ----------------- | ------------------- | --------------------- | --------------------------------------- |
+| DQN                   | 1.9x              | +                   | NA                    | + (through ActionDiscretizer transform) |
+| DDPG                  | 1.87x             | +                   | +                     | - (continuous only)                     |
+| IQL                   | 3.22x             | +                   | +                     | +                                       |
+| CQL                   | 2.68x             | +                   | +                     | +                                       |
+| TD3                   | 2.27x             | +                   | +                     | - (continuous only)                     |
+| TD3+BC                | untested          | +                   | +                     | - (continuous only)                     |
+| A2C                   | 2.67x             | +                   | -                     | +                                       |
+| PPO                   | 2.42x             | +                   | -                     | +                                       |
+| SAC                   | 2.62x             | +                   | -                     | +                                       |
+| REDQ                  | 2.28x             | +                   | -                     | - (continuous only)                     |
+| Dreamer v1            | untested          | +                   | + (different classes) | - (continuous only)                     |
+| Decision Transformers | untested          | +                   | NA                    | - (continuous only)                     |
+| CrossQ                | untested          | +                   | +                     | - (continuous only)                     |
+| Gail                  | untested          | +                   | NA                    | +                                       |
+| Impala                | untested          | +                   | -                     | +                                       |
+| IQL (MARL)            | untested          | +                   | +                     | +                                       |
+| DDPG (MARL)           | untested          | +                   | +                     | - (continuous only)                     |
+| PPO (MARL)            | untested          | +                   | -                     | +                                       |
+| QMIX-VDN (MARL)       | untested          | +                   | NA                    | +                                       |
+| SAC (MARL)            | untested          | +                   | -                     | +                                       |
+| RLHF                  | NA                | +                   | NA                    | NA                                      |
+
+** The number indicates expected speed-up compared to eager mode when executed on CPU. Numbers may vary depending on
+architecture and device.

 and many more to come!

+[Code examples](examples/) displaying toy code snippets and training scripts are also available
+- [RLHF](examples/rlhf)
+- [Memory-mapped replay buffers](examples/torchrl_features)
+
 Check the [examples](https://github.com/pytorch/rl/blob/main/sota-implementations/) directory for more
 details about handling the various configuration settings.

@@ -592,7 +858,7 @@ Importantly, the nightly builds require the nightly builds of PyTorch too.

 To install extra dependencies, call

 ```bash
-pip3 install "torchrl[atari,dm_control,gym_continuous,rendering,tests,utils,marl,checkpointing]"
+pip3 install "torchrl[atari,dm_control,gym_continuous,rendering,tests,utils,marl,open_spiel,checkpointing]"
 ```

 or a subset of these.
diff --git a/benchmarks/test_collectors_benchmark.py b/benchmarks/test_collectors_benchmark.py
index 1bdd26c0746..f2273d5cc3f 100644
--- a/benchmarks/test_collectors_benchmark.py
+++ b/benchmarks/test_collectors_benchmark.py
@@ -3,16 +3,20 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 import argparse
+import time

 import pytest
 import torch.cuda
+import tqdm

 from torchrl.collectors import SyncDataCollector
 from torchrl.collectors.collectors import (
     MultiaSyncDataCollector,
     MultiSyncDataCollector,
 )
-from torchrl.envs import EnvCreator, GymEnv, StepCounter, TransformedEnv
+from torchrl.data import LazyTensorStorage, ReplayBuffer
+from torchrl.data.utils import CloudpickleWrapper
+from torchrl.envs import EnvCreator, GymEnv, ParallelEnv, StepCounter, TransformedEnv
 from torchrl.envs.libs.dm_control import DMControlEnv
 from torchrl.envs.utils import RandomPolicy

@@ -180,6 +184,57 @@ def test_async_pixels(benchmark):
     benchmark(execute_collector, c)


+class TestRBGCollector:
+    @pytest.mark.parametrize(
+        "n_col,n_workers_per_col",
+        [
+            [2, 2],
+            [4, 2],
+            [8, 2],
+            [16, 2],
+            [2, 1],
+            [4, 1],
+            [8, 1],
+            [16, 1],
+        ],
+    )
+    def test_multiasync_rb(self, n_col, n_workers_per_col):
+        make_env = EnvCreator(lambda: GymEnv("ALE/Pong-v5"))
+        if n_workers_per_col > 1:
+            make_env = ParallelEnv(n_workers_per_col, make_env)
+            env = make_env
+            policy = RandomPolicy(env.action_spec)
+        else:
+            env = make_env()
+            policy = RandomPolicy(env.action_spec)
+
+        storage = LazyTensorStorage(10_000)
+        rb = ReplayBuffer(storage=storage)
+        rb.extend(env.rollout(2, policy).reshape(-1))
+        rb.append_transform(CloudpickleWrapper(lambda x: x.reshape(-1)), invert=True)
+
+        fpb = n_workers_per_col * 100
+        total_frames = n_workers_per_col * 100_000
+        c = MultiaSyncDataCollector(
+            [make_env] * n_col,
+            policy,
+            frames_per_batch=fpb,
+            total_frames=total_frames,
+            replay_buffer=rb,
+        )
+        frames = 0
+        pbar = tqdm.tqdm(total=total_frames - (n_col * fpb))
+        for i, _ in enumerate(c):
+            if i == n_col:
+                t0 = time.time()
+            if i >= n_col:
+                frames += fpb
+            if i > n_col:
+                fps = frames / (time.time() - t0)
+                pbar.update(fpb)
+                pbar.set_description(f"fps: {fps: 4.4f}")
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
diff --git a/benchmarks/test_objectives_benchmarks.py b/benchmarks/test_objectives_benchmarks.py
index 4cfc8470a15..d07b40595bc 100644
--- a/benchmarks/test_objectives_benchmarks.py
+++ b/benchmarks/test_objectives_benchmarks.py
@@ -6,9 +6,11 @@
 import pytest
import torch +from packaging import version from tensordict import TensorDict from tensordict.nn import ( + InteractionType, NormalParamExtractor, ProbabilisticTensorDictModule as ProbMod, ProbabilisticTensorDictSequential as ProbSeq, @@ -16,7 +18,7 @@ TensorDictSequential as Seq, ) from torch.nn import functional as F -from torchrl.data.tensor_specs import BoundedTensorSpec, UnboundedContinuousTensorSpec +from torchrl.data.tensor_specs import Bounded, Unbounded from torchrl.modules import MLP, QValueActor, TanhNormal from torchrl.objectives import ( A2CLoss, @@ -42,6 +44,20 @@ vec_td_lambda_return_estimate, ) +TORCH_VERSION = torch.__version__ +FULLGRAPH = version.parse(".".join(TORCH_VERSION.split(".")[:3])) >= version.parse( + "2.5.0" +) # Anything from 2.5, incl. nightlies, allows for fullgraph + + +@pytest.fixture(scope="module") +def set_default_device(): + cur_device = torch.get_default_device() + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + torch.set_default_device(device) + yield + torch.set_default_device(cur_device) + class setup_value_fn: def __init__(self, has_lmbda, has_state_value): @@ -137,7 +153,26 @@ def test_gae_speed(benchmark, gae_fn, gamma_tensor, batches, timesteps): ) -def test_dqn_speed(benchmark, n_obs=8, n_act=4, depth=3, ncells=128, batch=128): +def _maybe_compile(fn, compile, td, fullgraph=FULLGRAPH, warmup=3): + if compile: + if isinstance(compile, str): + fn = torch.compile(fn, mode=compile, fullgraph=fullgraph) + else: + fn = torch.compile(fn, fullgraph=fullgraph) + + for _ in range(warmup): + fn(td) + + return fn + + +@pytest.mark.parametrize("backward", [None, "backward"]) +@pytest.mark.parametrize("compile", [False, True, "reduce-overhead"]) +def test_dqn_speed( + benchmark, backward, compile, n_obs=8, n_act=4, depth=3, ncells=128, batch=128 +): + if compile: + torch._dynamo.reset_code_caches() net = MLP(in_features=n_obs, out_features=n_act, depth=depth, num_cells=ncells) action_space = "one-hot" mod = QValueActor(net, in_keys=["obs"], action_space=action_space) @@ -155,10 +190,36 @@ def test_dqn_speed(benchmark, n_obs=8, n_act=4, depth=3, ncells=128, batch=128): [batch], ) loss(td) - benchmark(loss, td) + loss = _maybe_compile(loss, compile, td) + + if backward: + + def loss_and_bw(td): + losses = loss(td) + sum( + [val for key, val in losses.items() if key.startswith("loss")] + ).backward() + + benchmark.pedantic( + loss_and_bw, + args=(td,), + setup=loss.zero_grad, + iterations=1, + warmup_rounds=5, + rounds=50, + ) + else: + benchmark(loss, td) -def test_ddpg_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64): + +@pytest.mark.parametrize("backward", [None, "backward"]) +@pytest.mark.parametrize("compile", [False, True, "reduce-overhead"]) +def test_ddpg_speed( + benchmark, backward, compile, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64 +): + if compile: + torch._dynamo.reset_code_caches() common = MLP( num_cells=ncells, in_features=n_obs, @@ -200,10 +261,36 @@ def test_ddpg_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden loss = DDPGLoss(actor, value) loss(td) - benchmark(loss, td) + loss = _maybe_compile(loss, compile, td) + + if backward: + + def loss_and_bw(td): + losses = loss(td) + sum( + [val for key, val in losses.items() if key.startswith("loss")] + ).backward() + + benchmark.pedantic( + loss_and_bw, + args=(td,), + setup=loss.zero_grad, + iterations=1, + warmup_rounds=5, + rounds=50, + ) + else: + benchmark(loss, td) -def test_sac_speed(benchmark, n_obs=8, n_act=4, 
ncells=128, batch=128, n_hidden=64): + +@pytest.mark.parametrize("backward", [None, "backward"]) +@pytest.mark.parametrize("compile", [False, True, "reduce-overhead"]) +def test_sac_speed( + benchmark, backward, compile, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64 +): + if compile: + torch._dynamo.reset_code_caches() common = MLP( num_cells=ncells, in_features=n_obs, @@ -245,23 +332,48 @@ def test_sac_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden= in_keys=["loc", "scale"], out_keys=["action"], distribution_class=TanhNormal, + distribution_kwargs={"safe_tanh": False}, ), ) value_head = Mod( value, in_keys=["hidden", "action"], out_keys=["state_action_value"] ) value = Seq(common, value_head) - value(actor(td)) + value(actor(td.clone())) - loss = SACLoss( - actor, value, action_spec=UnboundedContinuousTensorSpec(shape=(n_act,)) - ) + loss = SACLoss(actor, value, action_spec=Unbounded(shape=(n_act,))) loss(td) - benchmark(loss, td) + loss = _maybe_compile(loss, compile, td) + + if backward: -def test_redq_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64): + def loss_and_bw(td): + losses = loss(td) + sum( + [val for key, val in losses.items() if key.startswith("loss")] + ).backward() + + benchmark.pedantic( + loss_and_bw, + args=(td,), + setup=loss.zero_grad, + iterations=1, + warmup_rounds=5, + rounds=50, + ) + else: + benchmark(loss, td) + + +@pytest.mark.parametrize("backward", [None, "backward"]) +@pytest.mark.parametrize("compile", [False, True, "reduce-overhead"]) +def test_redq_speed( + benchmark, backward, compile, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64 +): + if compile: + torch._dynamo.reset_code_caches() common = MLP( num_cells=ncells, in_features=n_obs, @@ -304,25 +416,50 @@ def test_redq_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden out_keys=["action"], distribution_class=TanhNormal, return_log_prob=True, + distribution_kwargs={"safe_tanh": False}, ), ) value_head = Mod( value, in_keys=["hidden", "action"], out_keys=["state_action_value"] ) value = Seq(common, value_head) - value(actor(td)) + value(actor(td.copy())) - loss = REDQLoss( - actor, value, action_spec=UnboundedContinuousTensorSpec(shape=(n_act,)) - ) + loss = REDQLoss(actor, value, action_spec=Unbounded(shape=(n_act,))) loss(td) - benchmark(loss, td) + loss = _maybe_compile(loss, compile, td) + + if backward: + + def loss_and_bw(td): + losses = loss(td) + totalloss = sum( + [val for key, val in losses.items() if key.startswith("loss")] + ) + totalloss.backward() + + loss_and_bw(td) + + benchmark.pedantic( + loss_and_bw, + args=(td,), + setup=loss.zero_grad, + iterations=1, + warmup_rounds=5, + rounds=50, + ) + else: + benchmark(loss, td) +@pytest.mark.parametrize("backward", [None, "backward"]) +@pytest.mark.parametrize("compile", [False, True, "reduce-overhead"]) def test_redq_deprec_speed( - benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64 + benchmark, backward, compile, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64 ): + if compile: + torch._dynamo.reset_code_caches() common = MLP( num_cells=ncells, in_features=n_obs, @@ -365,23 +502,48 @@ def test_redq_deprec_speed( out_keys=["action"], distribution_class=TanhNormal, return_log_prob=True, + distribution_kwargs={"safe_tanh": False}, ), ) value_head = Mod( value, in_keys=["hidden", "action"], out_keys=["state_action_value"] ) value = Seq(common, value_head) - value(actor(td)) + value(actor(td.copy())) - loss = REDQLoss_deprecated( - actor, value, 
action_spec=UnboundedContinuousTensorSpec(shape=(n_act,)) - ) + loss = REDQLoss_deprecated(actor, value, action_spec=Unbounded(shape=(n_act,))) loss(td) - benchmark(loss, td) + loss = _maybe_compile(loss, compile, td) + + if backward: + + def loss_and_bw(td): + losses = loss(td) + sum( + [val for key, val in losses.items() if key.startswith("loss")] + ).backward() + + benchmark.pedantic( + loss_and_bw, + args=(td,), + setup=loss.zero_grad, + iterations=1, + warmup_rounds=5, + rounds=50, + ) + else: + benchmark(loss, td) -def test_td3_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64): + +@pytest.mark.parametrize("backward", [None, "backward"]) +@pytest.mark.parametrize("compile", [False, True, "reduce-overhead"]) +def test_td3_speed( + benchmark, backward, compile, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64 +): + if compile: + torch._dynamo.reset_code_caches() common = MLP( num_cells=ncells, in_features=n_obs, @@ -423,26 +585,54 @@ def test_td3_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden= in_keys=["loc", "scale"], out_keys=["action"], distribution_class=TanhNormal, + distribution_kwargs={"safe_tanh": False}, return_log_prob=True, + default_interaction_type=InteractionType.DETERMINISTIC, ), ) value_head = Mod( value, in_keys=["hidden", "action"], out_keys=["state_action_value"] ) value = Seq(common, value_head) - value(actor(td)) + value(actor(td.clone())) loss = TD3Loss( actor, value, - action_spec=BoundedTensorSpec(shape=(n_act,), low=-1, high=1), + action_spec=Bounded(shape=(n_act,), low=-1, high=1), ) loss(td) - benchmark.pedantic(loss, args=(td,), rounds=100, iterations=10) + loss = _maybe_compile(loss, compile, td) + + if backward: + + def loss_and_bw(td): + losses = loss(td) + sum( + [val for key, val in losses.items() if key.startswith("loss")] + ).backward() -def test_cql_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64): + benchmark.pedantic( + loss_and_bw, + args=(td,), + setup=loss.zero_grad, + iterations=1, + warmup_rounds=5, + rounds=50, + ) + else: + benchmark.pedantic(loss, args=(td,), rounds=100, iterations=10) + + +@pytest.mark.parametrize("backward", [None, "backward"]) +@pytest.mark.parametrize("compile", [False, True, "reduce-overhead"]) +def test_cql_speed( + benchmark, backward, compile, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64 +): + if compile: + torch._dynamo.reset_code_caches() common = MLP( num_cells=ncells, in_features=n_obs, @@ -481,26 +671,59 @@ def test_cql_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden= Mod(actor_net, in_keys=["hidden"], out_keys=["param"]), Mod(NormalParamExtractor(), in_keys=["param"], out_keys=["loc", "scale"]), ProbMod( - in_keys=["loc", "scale"], out_keys=["action"], distribution_class=TanhNormal + in_keys=["loc", "scale"], + out_keys=["action"], + distribution_class=TanhNormal, + distribution_kwargs={"safe_tanh": False}, ), ) value_head = Mod( value, in_keys=["hidden", "action"], out_keys=["state_action_value"] ) value = Seq(common, value_head) - value(actor(td)) + value(actor(td.copy())) - loss = CQLLoss( - actor, value, action_spec=UnboundedContinuousTensorSpec(shape=(n_act,)) - ) + loss = CQLLoss(actor, value, action_spec=Unbounded(shape=(n_act,))) loss(td) - benchmark(loss, td) + loss = _maybe_compile(loss, compile, td) + + if backward: + + def loss_and_bw(td): + losses = loss(td) + sum( + [val for key, val in losses.items() if key.startswith("loss")] + ).backward() + benchmark.pedantic( + loss_and_bw, + args=(td,), + 
setup=loss.zero_grad, + iterations=1, + warmup_rounds=5, + rounds=50, + ) + else: + benchmark(loss, td) + + +@pytest.mark.parametrize("backward", [None, "backward"]) +@pytest.mark.parametrize("compile", [False, True, "reduce-overhead"]) def test_a2c_speed( - benchmark, n_obs=8, n_act=4, n_hidden=64, ncells=128, batch=128, T=10 + benchmark, + backward, + compile, + n_obs=8, + n_act=4, + n_hidden=64, + ncells=128, + batch=128, + T=10, ): + if compile: + torch._dynamo.reset_code_caches() common_net = MLP( num_cells=ncells, in_features=n_obs, @@ -541,7 +764,10 @@ def test_a2c_speed( Mod(actor_net, in_keys=["hidden"], out_keys=["param"]), Mod(NormalParamExtractor(), in_keys=["param"], out_keys=["loc", "scale"]), ProbMod( - in_keys=["loc", "scale"], out_keys=["action"], distribution_class=TanhNormal + in_keys=["loc", "scale"], + out_keys=["action"], + distribution_class=TanhNormal, + distribution_kwargs={"safe_tanh": False}, ), ) critic = Seq(common, Mod(value_net, in_keys=["hidden"], out_keys=["state_value"])) @@ -552,12 +778,44 @@ def test_a2c_speed( advantage = GAE(value_network=critic, gamma=0.99, lmbda=0.95, shifted=True) advantage(td) loss(td) - benchmark(loss, td) + loss = _maybe_compile(loss, compile, td) + if backward: + + def loss_and_bw(td): + losses = loss(td) + sum( + [val for key, val in losses.items() if key.startswith("loss")] + ).backward() + + benchmark.pedantic( + loss_and_bw, + args=(td,), + setup=loss.zero_grad, + iterations=1, + warmup_rounds=5, + rounds=50, + ) + else: + benchmark(loss, td) + + +@pytest.mark.parametrize("backward", [None, "backward"]) +@pytest.mark.parametrize("compile", [False, True, "reduce-overhead"]) def test_ppo_speed( - benchmark, n_obs=8, n_act=4, n_hidden=64, ncells=128, batch=128, T=10 + benchmark, + backward, + compile, + n_obs=8, + n_act=4, + n_hidden=64, + ncells=128, + batch=128, + T=10, ): + if compile: + torch._dynamo.reset_code_caches() common_net = MLP( num_cells=ncells, in_features=n_obs, @@ -598,7 +856,10 @@ def test_ppo_speed( Mod(actor_net, in_keys=["hidden"], out_keys=["param"]), Mod(NormalParamExtractor(), in_keys=["param"], out_keys=["loc", "scale"]), ProbMod( - in_keys=["loc", "scale"], out_keys=["action"], distribution_class=TanhNormal + in_keys=["loc", "scale"], + out_keys=["action"], + distribution_class=TanhNormal, + distribution_kwargs={"safe_tanh": False}, ), ) critic = Seq(common, Mod(value_net, in_keys=["hidden"], out_keys=["state_value"])) @@ -609,12 +870,44 @@ def test_ppo_speed( advantage = GAE(value_network=critic, gamma=0.99, lmbda=0.95, shifted=True) advantage(td) loss(td) - benchmark(loss, td) + loss = _maybe_compile(loss, compile, td) + + if backward: + + def loss_and_bw(td): + losses = loss(td) + sum( + [val for key, val in losses.items() if key.startswith("loss")] + ).backward() + + benchmark.pedantic( + loss_and_bw, + args=(td,), + setup=loss.zero_grad, + iterations=1, + warmup_rounds=5, + rounds=50, + ) + else: + benchmark(loss, td) + +@pytest.mark.parametrize("backward", [None, "backward"]) +@pytest.mark.parametrize("compile", [False, True, "reduce-overhead"]) def test_reinforce_speed( - benchmark, n_obs=8, n_act=4, n_hidden=64, ncells=128, batch=128, T=10 + benchmark, + backward, + compile, + n_obs=8, + n_act=4, + n_hidden=64, + ncells=128, + batch=128, + T=10, ): + if compile: + torch._dynamo.reset_code_caches() common_net = MLP( num_cells=ncells, in_features=n_obs, @@ -655,7 +948,10 @@ def test_reinforce_speed( Mod(actor_net, in_keys=["hidden"], out_keys=["param"]), Mod(NormalParamExtractor(), 
in_keys=["param"], out_keys=["loc", "scale"]), ProbMod( - in_keys=["loc", "scale"], out_keys=["action"], distribution_class=TanhNormal + in_keys=["loc", "scale"], + out_keys=["action"], + distribution_class=TanhNormal, + distribution_kwargs={"safe_tanh": False}, ), ) critic = Seq(common, Mod(value_net, in_keys=["hidden"], out_keys=["state_value"])) @@ -666,12 +962,44 @@ def test_reinforce_speed( advantage = GAE(value_network=critic, gamma=0.99, lmbda=0.95, shifted=True) advantage(td) loss(td) - benchmark(loss, td) + loss = _maybe_compile(loss, compile, td) + + if backward: + def loss_and_bw(td): + losses = loss(td) + sum( + [val for key, val in losses.items() if key.startswith("loss")] + ).backward() + + benchmark.pedantic( + loss_and_bw, + args=(td,), + setup=loss.zero_grad, + iterations=1, + warmup_rounds=5, + rounds=50, + ) + else: + benchmark(loss, td) + + +@pytest.mark.parametrize("backward", [None, "backward"]) +@pytest.mark.parametrize("compile", [False, True, "reduce-overhead"]) def test_iql_speed( - benchmark, n_obs=8, n_act=4, n_hidden=64, ncells=128, batch=128, T=10 + benchmark, + backward, + compile, + n_obs=8, + n_act=4, + n_hidden=64, + ncells=128, + batch=128, + T=10, ): + if compile: + torch._dynamo.reset_code_caches() common_net = MLP( num_cells=ncells, in_features=n_obs, @@ -718,7 +1046,10 @@ def test_iql_speed( Mod(actor_net, in_keys=["hidden"], out_keys=["param"]), Mod(NormalParamExtractor(), in_keys=["param"], out_keys=["loc", "scale"]), ProbMod( - in_keys=["loc", "scale"], out_keys=["action"], distribution_class=TanhNormal + in_keys=["loc", "scale"], + out_keys=["action"], + distribution_class=TanhNormal, + distribution_kwargs={"safe_tanh": False}, ), ) value = Seq(common, Mod(value_net, in_keys=["hidden"], out_keys=["state_value"])) @@ -731,7 +1062,27 @@ def test_iql_speed( loss = IQLLoss(actor_network=actor, value_network=value, qvalue_network=qvalue) loss(td) - benchmark(loss, td) + + loss = _maybe_compile(loss, compile, td) + + if backward: + + def loss_and_bw(td): + losses = loss(td) + sum( + [val for key, val in losses.items() if key.startswith("loss")] + ).backward() + + benchmark.pedantic( + loss_and_bw, + args=(td,), + setup=loss.zero_grad, + iterations=1, + warmup_rounds=5, + rounds=50, + ) + else: + benchmark(loss, td) if __name__ == "__main__": diff --git a/docs/requirements.txt b/docs/requirements.txt index f6138cac30a..702a2884421 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -15,9 +15,8 @@ sphinx_design torchvision dm_control -atari-py -ale-py -gym[classic_control,accept-rom-license] +mujoco +gym[classic_control,accept-rom-license,ale-py,atari] pygame tqdm ipython diff --git a/docs/source/_static/img/collector-copy.png b/docs/source/_static/img/collector-copy.png new file mode 100644 index 00000000000..8a8921cacca Binary files /dev/null and b/docs/source/_static/img/collector-copy.png differ diff --git a/docs/source/_static/img/rename_transform.png b/docs/source/_static/img/rename_transform.png new file mode 100644 index 00000000000..3de518362cd Binary files /dev/null and b/docs/source/_static/img/rename_transform.png differ diff --git a/docs/source/reference/collectors.rst b/docs/source/reference/collectors.rst index 74bd058b8f0..41b743dce15 100644 --- a/docs/source/reference/collectors.rst +++ b/docs/source/reference/collectors.rst @@ -45,7 +45,7 @@ worker) may also impact the memory management. 
The key parameters to control are :obj:`devices` which controls the execution devices (ie the device of the policy) and :obj:`storing_device` which will control the device where the environment and data are stored during a rollout. A good heuristic is usually to use the same device -for storage and compute, which is the default behaviour when only the `devices` argument +for storage and compute, which is the default behavior when only the `devices` argument is being passed. Besides those compute parameters, users may choose to configure the following parameters: @@ -99,6 +99,25 @@ delivers batches of data on a first-come, first-serve basis, whereas :class:`~torchrl.collectors.MultiSyncDataCollector` gathers data from each sub-collector before delivering it. +Collectors and policy copies +---------------------------- + +When passing a policy to a collector, we can choose the device on which this policy will be run. This can be used to +keep the training version of the policy on a device and the inference version on another. For example, if you have two +CUDA devices, it may be wise to train on one device and execute the policy for inference on the other. If that is the +case, a :meth:`~torchrl.collectors.DataCollector.update_policy_weights_` can be used to copy the parameters from one +device to the other (if no copy is required, this method is a no-op). + +Since the goal is to avoid calling `policy.to(policy_device)` explicitly, the collector will do a deepcopy of the +policy structure and copy the parameters placed on the new device during instantiation if necessary. +Since not all policies support deepcopies (e.g., policies using CUDA graphs or relying on third-party libraries), we +try to limit the cases where a deepcopy will be executed. The following chart shows when this will occur. + +.. figure:: /_static/img/collector-copy.png + + Policy copy decision tree in Collectors. + + Collectors and replay buffers interoperability ---------------------------------------------- diff --git a/docs/source/reference/data.rst b/docs/source/reference/data.rst index 0dca499f4d9..6fbeada5bd0 100644 --- a/docs/source/reference/data.rst +++ b/docs/source/reference/data.rst @@ -171,7 +171,7 @@ using the following components: Storage choice is very influential on replay buffer sampling latency, especially in distributed reinforcement learning settings with larger data volumes. :class:`~torchrl.data.replay_buffers.storages.LazyMemmapStorage` is highly -advised in distributed settings with shared storage due to the lower serialisation +advised in distributed settings with shared storage due to the lower serialization cost of MemoryMappedTensors as well as the ability to specify file storage locations for improved node failure recovery. The following mean sampling latency improvements over using :class:`~torchrl.data.replay_buffers.ListStorage` @@ -877,11 +877,58 @@ TensorSpec .. _ref_specs: -The `TensorSpec` parent class and subclasses define the basic properties of observations and actions in TorchRL, such -as shape, device, dtype and domain. +The :class:`~torchrl.data.TensorSpec` parent class and subclasses define the basic properties of state, observations +actions, rewards and done status in TorchRL, such as their shape, device, dtype and domain. + It is important that your environment specs match the input and output that it sends and receives, as -:obj:`ParallelEnv` will create buffers from these specs to communicate with the spawn processes. 
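The spec machinery referenced here can be exercised directly. A small sketch using the renamed classes documented in this hunk (`Bounded`, `Unbounded`, `Composite`), with illustrative shapes and dtypes only:

```python
# Small sketch of the spec API documented in this data.rst hunk. The class
# names are the renamed ones this PR introduces; shapes are illustrative.
import torch
from torchrl.data import Bounded, Composite, Unbounded

action = Bounded(low=-1.0, high=1.0, shape=(4,))  # float dtype -> BoundedContinuous
obs = Unbounded(shape=(3,), dtype=torch.float32)
spec = Composite(observation=obs, action=action)  # dictionary-style container spec

sample = spec.rand()       # a TensorDict drawn from the spec
assert spec.is_in(sample)  # the sample matches its own spec

# Identical specs stack by expanding their shape, as the text below notes.
stacked = torch.stack([obs, obs], 0)  # expected shape: (2, 3)
```

`check_env_specs`, mentioned next, runs the same kind of validation against a live environment's resets and steps.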
+
+If needed, specs can be automatically generated from data using the
+:func:`~torchrl.envs.utils.make_composite_from_td` function.
+
+Specs fall into two main categories: numerical and categorical.
+
+.. table:: Numerical TensorSpec subclasses.
+
+    +-------------------------------------------------------------------------------+
+    |                                   Numerical                                   |
+    +=====================================+=========================================+
+    |               Bounded               |                Unbounded                |
+    +-----------------+-------------------+-------------------+---------------------+
+    | BoundedDiscrete | BoundedContinuous | UnboundedDiscrete | UnboundedContinuous |
+    +-----------------+-------------------+-------------------+---------------------+
+
+Whenever a :class:`~torchrl.data.Bounded` instance is created, its domain (defined either implicitly by its dtype or
+explicitly by the `"domain"` keyword argument) will determine if the instantiated class will be of
+:class:`~torchrl.data.BoundedContinuous` or :class:`~torchrl.data.BoundedDiscrete` type. The same applies to the
+:class:`~torchrl.data.Unbounded` class. See these classes for further information.
+
+.. table:: Categorical TensorSpec subclasses.
+
+    +------------------------------------------------------------------+
+    |                            Categorical                           |
+    +========+=============+=============+==================+==========+
+    | OneHot | MultiOneHot | Categorical | MultiCategorical |  Binary  |
+    +--------+-------------+-------------+------------------+----------+
+
+Unlike ``gymnasium``, TorchRL does not have the concept of an arbitrary list of specs. If multiple specs have to be
+combined together, TorchRL assumes that the data will be presented as dictionaries (more specifically, as
+:class:`~tensordict.TensorDict` or related formats). The corresponding :class:`~torchrl.data.TensorSpec` class in these
+cases is the :class:`~torchrl.data.Composite` spec.
+
+Nevertheless, specs can be stacked together using :func:`~torch.stack`: if they are identical, their shape will be
+expanded accordingly.
+Otherwise, a lazy stack will be created through the :class:`~torchrl.data.Stacked` class.
+
+Similarly, ``TensorSpecs`` share some behavior with :class:`~torch.Tensor` and
+:class:`~tensordict.TensorDict`: they can be reshaped, indexed, squeezed, unsqueezed, moved to another device (``to``)
+or unbound (``unbind``) as regular :class:`~torch.Tensor` instances would be.
+
+Specs where some dimensions are ``-1`` are said to be "dynamic": the negative dimensions indicate that the corresponding
+data has an inconsistent shape. When seen by an optimizer or an environment (e.g., batched environment such as
+:class:`~torchrl.envs.ParallelEnv`), these negative shapes tell TorchRL to avoid using buffers as the tensor shapes are
+not predictable.
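+
+A short sketch of these behaviors (shapes and values are illustrative):
+
+    >>> import torch
+    >>> from torchrl.data import Bounded
+    >>> spec = Bounded(low=-1, high=1, shape=(2,))  # a continuous, bounded domain
+    >>> spec.is_in(spec.rand())  # random samples always fall within the bounds
+    True
+    >>> torch.stack([spec, spec], 0).shape  # identical specs: shapes expand
+    torch.Size([2, 2])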

 .. currentmodule:: torchrl.data

@@ -890,19 +937,40 @@ Check the :obj:`torchrl.envs.utils.check_env_specs` method for a sanity check.
     :template: rl_template.rst

     TensorSpec
+    Binary
+    Bounded
+    Categorical
+    Composite
+    MultiCategorical
+    MultiOneHot
+    NonTensor
+    OneHot
+    Stacked
+    StackedComposite
+    Unbounded
+    UnboundedContinuous
+    UnboundedDiscrete
+
+The following classes are deprecated and just point to the classes above:
+
+.. currentmodule:: torchrl.data
+
+.. autosummary::
+    :toctree: generated/
+    :template: rl_template.rst
+
     BinaryDiscreteTensorSpec
     BoundedTensorSpec
     CompositeSpec
     DiscreteTensorSpec
+    LazyStackedCompositeSpec
+    LazyStackedTensorSpec
     MultiDiscreteTensorSpec
     MultiOneHotDiscreteTensorSpec
     NonTensorSpec
     OneHotDiscreteTensorSpec
     UnboundedContinuousTensorSpec
     UnboundedDiscreteTensorSpec
-    LazyStackedTensorSpec
-    LazyStackedCompositeSpec
-    NonTensorSpec

Reinforcement Learning From Human Feedback (RLHF)
-------------------------------------------------
diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst
index 11a5bb041a6..960daf0fb12 100644
--- a/docs/source/reference/envs.rst
+++ b/docs/source/reference/envs.rst
@@ -28,9 +28,9 @@ Each env will have the following attributes:
   This is especially useful for transforms (see below). For parametric environments (e.g.
   model-based environments), the device does represent the hardware that will be used to
   compute the operations.
-- :obj:`env.observation_spec`: a :class:`~torchrl.data.CompositeSpec` object
+- :obj:`env.observation_spec`: a :class:`~torchrl.data.Composite` object
   containing all the observation key-spec pairs.
-- :obj:`env.state_spec`: a :class:`~torchrl.data.CompositeSpec` object
+- :obj:`env.state_spec`: a :class:`~torchrl.data.Composite` object
   containing all the input key-spec pairs (except action). For most stateful
   environments, this container will be empty.
 - :obj:`env.action_spec`: a :class:`~torchrl.data.TensorSpec` object
@@ -39,10 +39,10 @@ Each env will have the following attributes:
   the reward spec.
 - :obj:`env.done_spec`: a :class:`~torchrl.data.TensorSpec` object representing
   the done-flag spec. See the section on trajectory termination below.
-- :obj:`env.input_spec`: a :class:`~torchrl.data.CompositeSpec` object containing
+- :obj:`env.input_spec`: a :class:`~torchrl.data.Composite` object containing
   all the input keys (:obj:`"full_action_spec"` and :obj:`"full_state_spec"`).
   It is locked and should not be modified directly.
-- :obj:`env.output_spec`: a :class:`~torchrl.data.CompositeSpec` object containing
+- :obj:`env.output_spec`: a :class:`~torchrl.data.Composite` object containing
   all the output keys (:obj:`"full_observation_spec"`, :obj:`"full_reward_spec"` and :obj:`"full_done_spec"`).
   It is locked and should not be modified directly.
@@ -318,7 +318,7 @@ have on an environment returning zeros after reset:

 We also offer the :class:`~.SerialEnv` class that enjoys the exact same API but is executed
 serially. This is mostly useful for testing purposes, when one wants to assess the
{"reward": torch.stack(reward_specs)}, shape=(env.n_agents,) ... ) ... } ...) - >>> env.observation_spec = CompositeSpec( + >>> env.observation_spec = Composite( ... { - ... "agents": CompositeSpec( + ... "agents": Composite( ... {"observation": torch.stack(observation_specs)}, shape=(env.n_agents,) ... ) ... } ...) - >>> env.done_spec = DiscreteTensorSpec( + >>> env.done_spec = Categorical( ... n=2, ... shape=torch.Size((1,)), ... dtype=torch.bool, @@ -499,7 +499,7 @@ current episode. To handle these cases, torchrl provides a :class:`~torchrl.envs.AutoResetTransform` that will copy the observations that result from the call to `step` to the next `reset` and skip the calls to `reset` during rollouts (in both :meth:`~torchrl.envs.EnvBase.rollout` and :class:`~torchrl.collectors.SyncDataCollector` iterations). -This transform class also provides a fine-grained control over the behaviour to be adopted for the invalid observations, +This transform class also provides a fine-grained control over the behavior to be adopted for the invalid observations, which can be masked with `"nan"` or any other values, or not masked at all. To tell torchrl that an environment is auto-resetting, it is sufficient to provide an ``auto_reset`` argument @@ -582,23 +582,23 @@ the ``return_contiguous=False`` argument. Here is a working example: >>> from torchrl.envs import EnvBase - >>> from torchrl.data import UnboundedContinuousTensorSpec, CompositeSpec, BoundedTensorSpec, BinaryDiscreteTensorSpec + >>> from torchrl.data import Unbounded, Composite, Bounded, Binary >>> import torch >>> from tensordict import TensorDict, TensorDictBase >>> >>> class EnvWithDynamicSpec(EnvBase): ... def __init__(self, max_count=5): ... super().__init__(batch_size=()) - ... self.observation_spec = CompositeSpec( - ... observation=UnboundedContinuousTensorSpec(shape=(3, -1, 2)), + ... self.observation_spec = Composite( + ... observation=Unbounded(shape=(3, -1, 2)), ... ) - ... self.action_spec = BoundedTensorSpec(low=-1, high=1, shape=(2,)) - ... self.full_done_spec = CompositeSpec( - ... done=BinaryDiscreteTensorSpec(1, shape=(1,), dtype=torch.bool), - ... terminated=BinaryDiscreteTensorSpec(1, shape=(1,), dtype=torch.bool), - ... truncated=BinaryDiscreteTensorSpec(1, shape=(1,), dtype=torch.bool), + ... self.action_spec = Bounded(low=-1, high=1, shape=(2,)) + ... self.full_done_spec = Composite( + ... done=Binary(1, shape=(1,), dtype=torch.bool), + ... terminated=Binary(1, shape=(1,), dtype=torch.bool), + ... truncated=Binary(1, shape=(1,), dtype=torch.bool), ... ) - ... self.reward_spec = UnboundedContinuousTensorSpec((1,), dtype=torch.float) + ... self.reward_spec = Unbounded((1,), dtype=torch.float) ... self.count = 0 ... self.max_count = max_count ... @@ -722,6 +722,9 @@ Since each transform uses a ``"in_keys"``/``"out_keys"`` set of keyword argument also easy to root the transform graph to each component of the observation data (e.g. pixels or states etc). +Forward and inverse transforms +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + Transforms also have an ``inv`` method that is called before the action is applied in reverse order over the composed transform chain: this allows to apply transforms to data in the environment before the action is taken @@ -733,6 +736,20 @@ in the environment. 
 The keys to be included in this inverse transform are passed

    >>> env.append_transform(DoubleToFloat(in_keys_inv=["action"]))  # will map the action from float32 to float64 before calling the base_env.step

+The way ``in_keys`` relates to ``in_keys_inv`` can be understood by considering the base environment as the "inner" part
+of the transform. In contrast, the user inputs and outputs to and from the transform are to be considered as the
+outside world. The following figure shows what this means in practice for the :class:`~torchrl.envs.RenameTransform`
+class: the input ``TensorDict`` of the ``step`` function must have the ``out_keys_inv`` listed in its entries as they
+are part of the outside world. The transform changes these names to make them match the names of the inner, base
+environment using the ``in_keys_inv``. The reverse mapping is applied to the output tensordict, where the ``in_keys``
+are mapped to the corresponding ``out_keys``.
+
+.. figure:: /_static/img/rename_transform.png
+
+    Rename transform logic
+
+
 Cloning transforms
 ~~~~~~~~~~~~~~~~~~

@@ -755,10 +772,10 @@ registered buffers:
    >>> TransformedEnv(base_env, third_transform.clone())  # works

 On a single process or if the buffers are placed in shared memory, this will
-result in all the clone transforms to keep the same behaviour even if the
+result in all the clone transforms keeping the same behavior even if the
 buffers are changed in place (which is what will happen with the :class:`CatFrames`
 transform, for instance). In distributed settings, this may not hold and one
-should be careful about the expected behaviour of the cloned transforms in this
+should be careful about the expected behavior of the cloned transforms in this
 context.
 Finally, notice that indexing multiple transforms from a :class:`Compose` transform
 may also result in loss of parenthood for these transforms: the reason is that
@@ -979,11 +996,9 @@ Helpers

     RandomPolicy
     check_env_specs
-    exploration_mode #deprecated
     exploration_type
     get_available_libraries
     make_composite_from_td
-    set_exploration_mode #deprecated
     set_exploration_type
     step_mdp
     terminated_or_truncated
@@ -1061,7 +1076,7 @@ the current gym backend or any of its modules:

 Another tool that comes in handy with gym and other external dependencies is
 the :class:`torchrl._utils.implement_for` class. Decorating a function with
 ``@implement_for`` will tell torchrl that, depending on the version
-indicated, a specific behaviour is to be expected. This allows us to easily
+indicated, a specific behavior is to be expected. This allows us to easily
 support multiple versions of gym without requiring any effort from the user
 side. For example, considering that our virtual environment has the v0.26.2
 installed, the following function will return ``1`` when queried:
@@ -1098,11 +1113,15 @@ the following function will return ``1`` when queried:

     MultiThreadedEnv
     MultiThreadedEnvWrapper
     OpenMLEnv
+    OpenSpielEnv
+    OpenSpielWrapper
     PettingZooEnv
     PettingZooWrapper
     RoboHiveEnv
     SMACv2Env
     SMACv2Wrapper
+    UnityMLAgentsEnv
+    UnityMLAgentsWrapper
     VmasEnv
     VmasWrapper
     gym_backend
diff --git a/docs/source/reference/modules.rst b/docs/source/reference/modules.rst
index c73ed5083fd..e1642868228 100644
--- a/docs/source/reference/modules.rst
+++ b/docs/source/reference/modules.rst
@@ -57,16 +57,21 @@ projected (in a L1-manner) into the desired domain.
     SafeSequential
     TanhModule

-Exploration wrappers
-~~~~~~~~~~~~~~~~~~~~
+Exploration wrappers and modules
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

-To efficiently explore the environment, TorchRL proposes a series of wrappers
+To efficiently explore the environment, TorchRL proposes a series of modules
 that will override the action sampled by the policy by a noisier version.
-Their behaviour is controlled by :func:`~torchrl.envs.utils.exploration_mode`:
-if the exploration is set to ``"random"``, the exploration is active. In all
+Their behavior is controlled by :func:`~torchrl.envs.utils.exploration_type`:
+if the exploration is set to ``ExplorationType.RANDOM``, the exploration is active. In all
 other cases, the action written in the tensordict is simply the network output.
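+
+A minimal sketch of this switch (the linear Q-network and specs are illustrative):
+
+    >>> import torch
+    >>> from tensordict import TensorDict
+    >>> from tensordict.nn import TensorDictSequential
+    >>> from torchrl.data import OneHot
+    >>> from torchrl.envs.utils import ExplorationType, set_exploration_type
+    >>> from torchrl.modules import EGreedyModule, QValueActor
+    >>> policy = TensorDictSequential(
+    ...     QValueActor(torch.nn.Linear(3, 4), spec=OneHot(4)),
+    ...     EGreedyModule(spec=OneHot(4)),
+    ... )
+    >>> td = TensorDict({"observation": torch.randn(5, 3)}, [5])
+    >>> with set_exploration_type(ExplorationType.RANDOM):
+    ...     td = policy(td)  # epsilon-greedy exploration is active
+    >>> with set_exploration_type(ExplorationType.DETERMINISTIC):
+    ...     td = policy(td)  # the greedy network output is written as-is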

-.. currentmodule:: torchrl.modules.tensordict_module
+.. note:: Unlike other exploration modules, :class:`~torchrl.modules.ConsistentDropoutModule`
+    uses the ``train``/``eval`` mode to comply with the regular `Dropout` API in PyTorch.
+    The :func:`~torchrl.envs.utils.set_exploration_type` context manager will have no effect on
+    this module.
+
+.. currentmodule:: torchrl.modules

 .. autosummary::
     :toctree: generated/
@@ -74,6 +79,7 @@ other cases, the action written in the tensordict is simply the network output.

     AdditiveGaussianModule
     AdditiveGaussianWrapper
+    ConsistentDropoutModule
     EGreedyModule
     EGreedyWrapper
     OrnsteinUhlenbeckProcessModule
@@ -163,11 +169,11 @@ resulting action in the input tensordict along with the list of action values.

    >>> from tensordict import TensorDict
    >>> from tensordict.nn.functional_modules import make_functional
    >>> from torch import nn
-   >>> from torchrl.data import OneHotDiscreteTensorSpec
+   >>> from torchrl.data import OneHot
    >>> from torchrl.modules.tensordict_module.actors import QValueActor
    >>> td = TensorDict({'observation': torch.randn(5, 3)}, [5])
    >>> # we have 4 actions to choose from
-   >>> action_spec = OneHotDiscreteTensorSpec(4)
+   >>> action_spec = OneHot(4)
    >>> # the model reads a state of dimension 3 and outputs 4 values, one for each action available
    >>> module = nn.Linear(3, 4)
    >>> qvalue_actor = QValueActor(module=module, spec=action_spec)
@@ -196,13 +202,13 @@ class:

    >>> import torch
    >>> from tensordict import TensorDict
    >>> from torch import nn
-   >>> from torchrl.data import OneHotDiscreteTensorSpec
+   >>> from torchrl.data import OneHot
    >>> from torchrl.modules import DistributionalQValueActor, MLP
    >>> td = TensorDict({'observation': torch.randn(5, 4)}, [5])
    >>> nbins = 3
    >>> # our model reads the observation and outputs a stack of 4 logits (one for each action) of size nbins=3
    >>> module = MLP(out_features=(nbins, 4), depth=2)
-   >>> action_spec = OneHotDiscreteTensorSpec(4)
+   >>> action_spec = OneHot(4)
    >>> qvalue_actor = DistributionalQValueActor(module=module, spec=action_spec, support=torch.arange(nbins))
    >>> td = qvalue_actor(td)
    >>> print(td)
@@ -314,12 +320,13 @@ Regular modules
     :toctree: generated/
     :template: rl_template_noinherit.rst

-    MLP
-    ConvNet
+    BatchRenorm1d
+    ConsistentDropout
     Conv3dNet
-    SqueezeLayer
+    ConvNet
+    MLP
Squeeze2dLayer - BatchRenorm + SqueezeLayer Algorithm-specific modules ~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/reference/objectives.rst b/docs/source/reference/objectives.rst index 1d92c390a4e..b3f8e242a9e 100644 --- a/docs/source/reference/objectives.rst +++ b/docs/source/reference/objectives.rst @@ -157,7 +157,7 @@ CrossQ :toctree: generated/ :template: rl_template_noinherit.rst - CrossQ + CrossQLoss IQL ---- @@ -179,6 +179,15 @@ CQL CQLLoss DiscreteCQLLoss +GAIL +---- + +.. autosummary:: + :toctree: generated/ + :template: rl_template_noinherit.rst + + GAILLoss + DT ---- diff --git a/examples/distributed/collectors/multi_nodes/delayed_dist.py b/examples/distributed/collectors/multi_nodes/delayed_dist.py index b140ee7bc67..7b7e053f498 100644 --- a/examples/distributed/collectors/multi_nodes/delayed_dist.py +++ b/examples/distributed/collectors/multi_nodes/delayed_dist.py @@ -114,7 +114,7 @@ def main(): import gym from torchrl.collectors import MultiSyncDataCollector, SyncDataCollector - from torchrl.data import BoundedTensorSpec + from torchrl.data import Bounded from torchrl.envs.libs.gym import GymEnv, set_gym_backend from torchrl.envs.utils import RandomPolicy @@ -128,7 +128,7 @@ def make_env(): collector = DistributedDataCollector( [EnvCreator(make_env)] * num_jobs, - policy=RandomPolicy(BoundedTensorSpec(-1, 1, shape=(1,))), + policy=RandomPolicy(Bounded(-1, 1, shape=(1,))), launcher="submitit_delayed", frames_per_batch=frames_per_batch, total_frames=total_frames, diff --git a/examples/distributed/collectors/multi_nodes/delayed_rpc.py b/examples/distributed/collectors/multi_nodes/delayed_rpc.py index adff8864413..f63c4d17409 100644 --- a/examples/distributed/collectors/multi_nodes/delayed_rpc.py +++ b/examples/distributed/collectors/multi_nodes/delayed_rpc.py @@ -113,7 +113,7 @@ def main(): import gym from torchrl.collectors import MultiSyncDataCollector, SyncDataCollector - from torchrl.data import BoundedTensorSpec + from torchrl.data import Bounded from torchrl.envs.libs.gym import GymEnv, set_gym_backend from torchrl.envs.utils import RandomPolicy @@ -127,7 +127,7 @@ def make_env(): collector = RPCDataCollector( [EnvCreator(make_env)] * num_jobs, - policy=RandomPolicy(BoundedTensorSpec(-1, 1, shape=(1,))), + policy=RandomPolicy(Bounded(-1, 1, shape=(1,))), launcher="submitit_delayed", frames_per_batch=frames_per_batch, total_frames=total_frames, diff --git a/examples/distributed/collectors/multi_nodes/ray_train.py b/examples/distributed/collectors/multi_nodes/ray_train.py index b05e92619fa..5697d88dc61 100644 --- a/examples/distributed/collectors/multi_nodes/ray_train.py +++ b/examples/distributed/collectors/multi_nodes/ray_train.py @@ -26,7 +26,7 @@ TransformedEnv, ) from torchrl.envs.libs.gym import GymEnv -from torchrl.envs.utils import check_env_specs, set_exploration_mode +from torchrl.envs.utils import check_env_specs, ExplorationType, set_exploration_type from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator from torchrl.objectives import ClipPPOLoss from torchrl.objectives.value import GAE @@ -85,8 +85,8 @@ in_keys=["loc", "scale"], distribution_class=TanhNormal, distribution_kwargs={ - "min": env.action_spec.space.low, - "max": env.action_spec.space.high, + "low": env.action_spec.space.low, + "high": env.action_spec.space.high, }, return_log_prob=True, ) @@ -201,7 +201,7 @@ stepcount_str = f"step count (max): {logs['step_count'][-1]}" logs["lr"].append(optim.param_groups[0]["lr"]) lr_str = f"lr policy: {logs['lr'][-1]: 4.4f}" - with 
set_exploration_mode("mean"), torch.no_grad(): + with set_exploration_type(ExplorationType.MODE), torch.no_grad(): # execute a rollout with the trained policy eval_rollout = env.rollout(1000, policy_module) logs["eval reward"].append(eval_rollout["next", "reward"].mean().item()) diff --git a/examples/distributed/replay_buffers/distributed_replay_buffer.py b/examples/distributed/replay_buffers/distributed_replay_buffer.py index c7504fbf8ee..f25ea0bdc8b 100644 --- a/examples/distributed/replay_buffers/distributed_replay_buffer.py +++ b/examples/distributed/replay_buffers/distributed_replay_buffer.py @@ -150,8 +150,8 @@ def _create_and_launch_data_collectors(self) -> None: class ReplayBufferNode(RemoteTensorDictReplayBuffer): """Experience replay buffer node that is capable of accepting remote connections. Being a `RemoteTensorDictReplayBuffer` - means all of it's public methods are remotely invokable using `torch.rpc`. - Using a LazyMemmapStorage is highly advised in distributed settings with shared storage due to the lower serialisation + means all of its public methods are remotely invokable using `torch.rpc`. + Using a LazyMemmapStorage is highly advised in distributed settings with shared storage due to the lower serialization cost of MemoryMappedTensors as well as the ability to specify file storage locations which can improve ability to recover from node failures. Args: diff --git a/examples/envs/gym-async-info-reader.py b/examples/envs/gym-async-info-reader.py index 3f98e039290..72330f13030 100644 --- a/examples/envs/gym-async-info-reader.py +++ b/examples/envs/gym-async-info-reader.py @@ -48,7 +48,7 @@ def step(self, action): if __name__ == "__main__": import torch - from torchrl.data.tensor_specs import UnboundedContinuousTensorSpec + from torchrl.data.tensor_specs import Unbounded from torchrl.envs import check_env_specs, GymEnv, GymWrapper args = parser.parse_args() @@ -66,7 +66,7 @@ def step(self, action): keys = ["field1"] specs = [ - UnboundedContinuousTensorSpec(shape=(num_envs, 3), dtype=torch.float64), + Unbounded(shape=(num_envs, 3), dtype=torch.float64), ] # Create an info reader: this object will read the info and write its content to the tensordict diff --git a/knowledge_base/VIDEO_CUSTOMISATION.md b/knowledge_base/VIDEO_CUSTOMISATION.md index 956110d89aa..e28334708b2 100644 --- a/knowledge_base/VIDEO_CUSTOMISATION.md +++ b/knowledge_base/VIDEO_CUSTOMISATION.md @@ -50,9 +50,5 @@ as advised by the documentation. We can improve the video quality by appending all our desired settings (as keyword arguments) to `recorder` like so: ```python -# The arguments' types don't appear to matter too much, as long as they are -# appropriate for Python. 
-# For example, this would work as well:
-# logger = CSVLogger(exp_name="my_exp", crf=17, preset="slow")
-logger = CSVLogger(exp_name="my_exp", crf="17", preset="slow")
+recorder = VideoRecorder(logger, tag="my_video", options={"crf": "17", "preset": "slow"})
 ```
diff --git a/pytest.ini b/pytest.ini
index 36d047d3055..39fe36617a1 100644
--- a/pytest.ini
+++ b/pytest.ini
@@ -4,6 +4,8 @@ addopts =
     -ra
     # Make tracebacks shorter
     --tb=native
+markers =
+    unity_editor
 testpaths = test
 xfail_strict = True
diff --git a/setup.py b/setup.py
index 1c4f267bb94..a52711b0ed5 100644
--- a/setup.py
+++ b/setup.py
@@ -152,11 +152,15 @@ def get_extensions():
     }
     sources = list(extension_sources)
+    include_dirs = [this_dir]
+    python_include_dir = os.getenv("PYTHON_INCLUDE_DIR")
+    if python_include_dir is not None:
+        include_dirs.append(python_include_dir)
     ext_modules = [
         extension(
             "torchrl._torchrl",
             sources,
-            include_dirs=[this_dir],
+            include_dirs=include_dirs,
             extra_compile_args=extra_compile_args,
             extra_link_args=extra_link_args,
         )
@@ -191,7 +195,7 @@ def _main(argv):
     #     tag = _run_cmd(["git", "describe", "--tags", "--exact-match", "@"])
     this_directory = Path(__file__).parent
-    long_description = (this_directory / "README.md").read_text()
+    long_description = (this_directory / "README.md").read_text(encoding="utf8")
     sys.argv = [sys.argv[0]] + unknown
     extra_requires = {
@@ -203,7 +207,7 @@ def _main(argv):
             "pygame",
         ],
         "dm_control": ["dm_control"],
-        "gym_continuous": ["gymnasium", "mujoco"],
+        "gym_continuous": ["gymnasium<1.0", "mujoco"],
         "rendering": ["moviepy"],
         "tests": ["pytest", "pyyaml", "pytest-instafail", "scipy"],
         "utils": [
@@ -229,6 +233,7 @@ def _main(argv):
             "pillow",
         ],
         "marl": ["vmas>=1.2.10", "pettingzoo>=1.24.1", "dm-meltingpot"],
+        "open_spiel": ["open_spiel>=1.5"],
     }
     extra_requires["all"] = set()
     for key in list(extra_requires.keys()):
diff --git a/sota-implementations/a2c/a2c_atari.py b/sota-implementations/a2c/a2c_atari.py
index f8c18147306..42ef4301c4d 100644
--- a/sota-implementations/a2c/a2c_atari.py
+++ b/sota-implementations/a2c/a2c_atari.py
@@ -226,6 +226,9 @@ def main(cfg: "DictConfig"):  # noqa: F821
             collector.update_policy_weights_()
             sampling_start = time.time()
+    collector.shutdown()
+    if not test_env.is_closed:
+        test_env.close()
     end_time = time.time()
     execution_time = end_time - start_time
     torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")
diff --git a/sota-implementations/a2c/a2c_mujoco.py b/sota-implementations/a2c/a2c_mujoco.py
index d115174eb9c..2b390d39d2a 100644
--- a/sota-implementations/a2c/a2c_mujoco.py
+++ b/sota-implementations/a2c/a2c_mujoco.py
@@ -212,6 +212,9 @@ def main(cfg: "DictConfig"):  # noqa: F821
             collector.update_policy_weights_()
             sampling_start = time.time()
+    collector.shutdown()
+    if not test_env.is_closed:
+        test_env.close()
     end_time = time.time()
     execution_time = end_time - start_time
     torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")
diff --git a/sota-implementations/a2c/utils_atari.py b/sota-implementations/a2c/utils_atari.py
index 58fa8541d90..6a09ff715e4 100644
--- a/sota-implementations/a2c/utils_atari.py
+++ b/sota-implementations/a2c/utils_atari.py
@@ -7,8 +7,8 @@
 import torch.nn
 import torch.optim
 from tensordict.nn import TensorDictModule
-from torchrl.data import CompositeSpec
-from torchrl.data.tensor_specs import DiscreteBox
+from torchrl.data import Composite
+from torchrl.data.tensor_specs import CategoricalBox
 from torchrl.envs import (
     CatFrames,
     DoubleToFloat,
@@ -92,7 +92,7
@@ def make_ppo_modules_pixels(proof_environment): input_shape = proof_environment.observation_spec["pixels"].shape # Define distribution class and kwargs - if isinstance(proof_environment.action_spec.space, DiscreteBox): + if isinstance(proof_environment.action_spec.space, CategoricalBox): num_outputs = proof_environment.action_spec.space.n distribution_class = OneHotCategorical distribution_kwargs = {} @@ -148,7 +148,7 @@ def make_ppo_modules_pixels(proof_environment): policy_module = ProbabilisticActor( policy_module, in_keys=["logits"], - spec=CompositeSpec(action=proof_environment.action_spec), + spec=Composite(action=proof_environment.action_spec), distribution_class=distribution_class, distribution_kwargs=distribution_kwargs, return_log_prob=True, diff --git a/sota-implementations/a2c/utils_mujoco.py b/sota-implementations/a2c/utils_mujoco.py index 9bb5a1f6307..996706ce4f9 100644 --- a/sota-implementations/a2c/utils_mujoco.py +++ b/sota-implementations/a2c/utils_mujoco.py @@ -8,7 +8,7 @@ import torch.optim from tensordict.nn import AddStateIndependentNormalScale, TensorDictModule -from torchrl.data import CompositeSpec +from torchrl.data import Composite from torchrl.envs import ( ClipTransform, DoubleToFloat, @@ -90,7 +90,7 @@ def make_ppo_models_state(proof_environment): out_keys=["loc", "scale"], ), in_keys=["loc", "scale"], - spec=CompositeSpec(action=proof_environment.action_spec), + spec=Composite(action=proof_environment.action_spec), distribution_class=distribution_class, distribution_kwargs=distribution_kwargs, return_log_prob=True, diff --git a/sota-implementations/cql/cql_offline.py b/sota-implementations/cql/cql_offline.py index 5ca70f83b53..73155d9fa1a 100644 --- a/sota-implementations/cql/cql_offline.py +++ b/sota-implementations/cql/cql_offline.py @@ -58,14 +58,14 @@ def main(cfg: "DictConfig"): # noqa: F821 device = "cpu" device = torch.device(device) + # Create replay buffer + replay_buffer = make_offline_replay_buffer(cfg.replay_buffer) + # Create env train_env, eval_env = make_environment( cfg, train_num_envs=1, eval_num_envs=cfg.logger.eval_envs, logger=logger ) - # Create replay buffer - replay_buffer = make_offline_replay_buffer(cfg.replay_buffer) - # Create agent model = make_cql_model(cfg, train_env, eval_env, device) del train_env @@ -107,9 +107,6 @@ def main(cfg: "DictConfig"): # noqa: F821 q_loss = q_loss + cql_loss - alpha_loss = loss_vals["loss_alpha"] - alpha_prime_loss = loss_vals["loss_alpha_prime"] - # update model alpha_loss = loss_vals["loss_alpha"] alpha_prime_loss = loss_vals["loss_alpha_prime"] diff --git a/sota-implementations/cql/utils.py b/sota-implementations/cql/utils.py index fae54da049a..c1d6fb52024 100644 --- a/sota-implementations/cql/utils.py +++ b/sota-implementations/cql/utils.py @@ -11,7 +11,7 @@ from torchrl.collectors import SyncDataCollector from torchrl.data import ( - CompositeSpec, + Composite, LazyMemmapStorage, TensorDictPrioritizedReplayBuffer, TensorDictReplayBuffer, @@ -252,7 +252,7 @@ def make_discretecql_model(cfg, train_env, eval_env, device="cpu"): actor_net = MLP(**actor_net_kwargs) qvalue_module = QValueActor( module=actor_net, - spec=CompositeSpec(action=action_spec), + spec=Composite(action=action_spec), in_keys=["observation"], ) qvalue_module = qvalue_module.to(device) diff --git a/sota-implementations/crossq/crossq.py b/sota-implementations/crossq/crossq.py index df34d4ae68d..b07ae880046 100644 --- a/sota-implementations/crossq/crossq.py +++ b/sota-implementations/crossq/crossq.py @@ -203,7 +203,7 @@ def 
main(cfg: "DictConfig"): # noqa: F821 # Evaluation if abs(collected_frames % eval_iter) < frames_per_batch: - with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): eval_start = time.time() eval_rollout = eval_env.rollout( eval_rollout_steps, @@ -220,6 +220,10 @@ def main(cfg: "DictConfig"): # noqa: F821 sampling_start = time.time() collector.shutdown() + if not eval_env.is_closed: + eval_env.close() + if not train_env.is_closed: + train_env.close() end_time = time.time() execution_time = end_time - start_time torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish") diff --git a/sota-implementations/ddpg/ddpg.py b/sota-implementations/ddpg/ddpg.py index 1b038d69d15..cebc3685625 100644 --- a/sota-implementations/ddpg/ddpg.py +++ b/sota-implementations/ddpg/ddpg.py @@ -205,6 +205,10 @@ def main(cfg: "DictConfig"): # noqa: F821 collector.shutdown() end_time = time.time() execution_time = end_time - start_time + if not eval_env.is_closed: + eval_env.close() + if not train_env.is_closed: + train_env.close() torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish") diff --git a/sota-implementations/decision_transformer/dt.py b/sota-implementations/decision_transformer/dt.py index 9cca9fd8af5..b892462339c 100644 --- a/sota-implementations/decision_transformer/dt.py +++ b/sota-implementations/decision_transformer/dt.py @@ -131,6 +131,8 @@ def main(cfg: "DictConfig"): # noqa: F821 log_metrics(logger, to_log, i) pbar.close() + if not test_env.is_closed: + test_env.close() torchrl_logger.info(f"Training time: {time.time() - start_time}") diff --git a/sota-implementations/decision_transformer/online_dt.py b/sota-implementations/decision_transformer/online_dt.py index da2241ce9fa..184c850b626 100644 --- a/sota-implementations/decision_transformer/online_dt.py +++ b/sota-implementations/decision_transformer/online_dt.py @@ -145,6 +145,8 @@ def main(cfg: "DictConfig"): # noqa: F821 log_metrics(logger, to_log, i) pbar.close() + if not test_env.is_closed: + test_env.close() torchrl_logger.info(f"Training time: {time.time() - start_time}") diff --git a/sota-implementations/decision_transformer/utils.py b/sota-implementations/decision_transformer/utils.py index 409833c75fa..ee2cc6e424c 100644 --- a/sota-implementations/decision_transformer/utils.py +++ b/sota-implementations/decision_transformer/utils.py @@ -38,7 +38,7 @@ ) from torchrl.envs.libs.dm_control import DMControlEnv from torchrl.envs.libs.gym import set_gym_backend -from torchrl.envs.utils import set_exploration_mode +from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.modules import ( DTActor, OnlineDTActor, @@ -374,13 +374,12 @@ def make_odt_model(cfg): module=actor_module, distribution_class=dist_class, distribution_kwargs=dist_kwargs, - default_interaction_mode="random", cache_dist=False, return_log_prob=False, ) # init the lazy layers - with torch.no_grad(), set_exploration_mode("random"): + with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM): td = proof_environment.rollout(max_steps=100) td["action"] = td["next", "action"] actor(td) @@ -428,13 +427,12 @@ def make_dt_model(cfg): module=actor_module, distribution_class=dist_class, distribution_kwargs=dist_kwargs, - default_interaction_mode="random", cache_dist=False, return_log_prob=False, ) # init the lazy layers - with torch.no_grad(), set_exploration_mode("random"): + with torch.no_grad(), 
set_exploration_type(ExplorationType.RANDOM): td = proof_environment.rollout(max_steps=100) td["action"] = td["next", "action"] actor(td) diff --git a/sota-implementations/discrete_sac/discrete_sac.py b/sota-implementations/discrete_sac/discrete_sac.py index 386f743c7d3..a9a08827f5d 100644 --- a/sota-implementations/discrete_sac/discrete_sac.py +++ b/sota-implementations/discrete_sac/discrete_sac.py @@ -222,6 +222,10 @@ def main(cfg: "DictConfig"): # noqa: F821 sampling_start = time.time() collector.shutdown() + if not eval_env.is_closed: + eval_env.close() + if not train_env.is_closed: + train_env.close() end_time = time.time() execution_time = end_time - start_time torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish") diff --git a/sota-implementations/discrete_sac/utils.py b/sota-implementations/discrete_sac/utils.py index ddffffc2a8e..8051f07fe95 100644 --- a/sota-implementations/discrete_sac/utils.py +++ b/sota-implementations/discrete_sac/utils.py @@ -12,7 +12,7 @@ from torch import nn, optim from torchrl.collectors import SyncDataCollector from torchrl.data import ( - CompositeSpec, + Composite, TensorDictPrioritizedReplayBuffer, TensorDictReplayBuffer, ) @@ -203,7 +203,7 @@ def make_sac_agent(cfg, train_env, eval_env, device): out_keys=["logits"], ) actor = ProbabilisticActor( - spec=CompositeSpec(action=eval_env.action_spec), + spec=Composite(action=eval_env.action_spec), module=actor_module, in_keys=["logits"], out_keys=["action"], diff --git a/sota-implementations/dqn/dqn_atari.py b/sota-implementations/dqn/dqn_atari.py index 906273ee2f5..5d0162080e2 100644 --- a/sota-implementations/dqn/dqn_atari.py +++ b/sota-implementations/dqn/dqn_atari.py @@ -228,6 +228,9 @@ def main(cfg: "DictConfig"): # noqa: F821 sampling_start = time.time() collector.shutdown() + if not test_env.is_closed: + test_env.close() + end_time = time.time() execution_time = end_time - start_time torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish") diff --git a/sota-implementations/dqn/dqn_cartpole.py b/sota-implementations/dqn/dqn_cartpole.py index 173f88f7028..8149c700958 100644 --- a/sota-implementations/dqn/dqn_cartpole.py +++ b/sota-implementations/dqn/dqn_cartpole.py @@ -207,6 +207,8 @@ def main(cfg: "DictConfig"): # noqa: F821 sampling_start = time.time() collector.shutdown() + if not test_env.is_closed: + test_env.close() end_time = time.time() execution_time = end_time - start_time torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish") diff --git a/sota-implementations/dqn/utils_atari.py b/sota-implementations/dqn/utils_atari.py index 3dbbfe87af4..6f39e824c60 100644 --- a/sota-implementations/dqn/utils_atari.py +++ b/sota-implementations/dqn/utils_atari.py @@ -5,7 +5,7 @@ import torch.nn import torch.optim -from torchrl.data import CompositeSpec +from torchrl.data import Composite from torchrl.envs import ( CatFrames, DoubleToFloat, @@ -84,7 +84,7 @@ def make_dqn_modules_pixels(proof_environment): ) qvalue_module = QValueActor( module=torch.nn.Sequential(cnn, mlp), - spec=CompositeSpec(action=action_spec), + spec=Composite(action=action_spec), in_keys=["pixels"], ) return qvalue_module diff --git a/sota-implementations/dqn/utils_cartpole.py b/sota-implementations/dqn/utils_cartpole.py index 2df280a04b4..c7f7491ad15 100644 --- a/sota-implementations/dqn/utils_cartpole.py +++ b/sota-implementations/dqn/utils_cartpole.py @@ -5,7 +5,7 @@ import torch.nn import torch.optim -from torchrl.data import CompositeSpec +from torchrl.data import 
Composite from torchrl.envs import RewardSum, StepCounter, TransformedEnv from torchrl.envs.libs.gym import GymEnv from torchrl.modules import MLP, QValueActor @@ -48,7 +48,7 @@ def make_dqn_modules(proof_environment): qvalue_module = QValueActor( module=mlp, - spec=CompositeSpec(action=action_spec), + spec=Composite(action=action_spec), in_keys=["observation"], ) return qvalue_module diff --git a/sota-implementations/dreamer/dreamer_utils.py b/sota-implementations/dreamer/dreamer_utils.py index 6745b1a079a..849d8c813b6 100644 --- a/sota-implementations/dreamer/dreamer_utils.py +++ b/sota-implementations/dreamer/dreamer_utils.py @@ -20,11 +20,11 @@ from torchrl.collectors import SyncDataCollector from torchrl.data import ( - CompositeSpec, + Composite, LazyMemmapStorage, SliceSampler, TensorDictReplayBuffer, - UnboundedContinuousTensorSpec, + Unbounded, ) from torchrl.envs import ( @@ -92,8 +92,8 @@ def _make_env(cfg, device, from_pixels=False): else: raise NotImplementedError(f"Unknown lib {lib}.") default_dict = { - "state": UnboundedContinuousTensorSpec(shape=(cfg.networks.state_dim,)), - "belief": UnboundedContinuousTensorSpec(shape=(cfg.networks.rssm_hidden_dim,)), + "state": Unbounded(shape=(cfg.networks.state_dim,)), + "belief": Unbounded(shape=(cfg.networks.rssm_hidden_dim,)), } env = env.append_transform( TensorDictPrimer(random=False, default_value=0, **default_dict) @@ -469,13 +469,13 @@ def _dreamer_make_actor_sim(action_key, proof_environment, actor_module): actor_module, in_keys=["state", "belief"], out_keys=["loc", "scale"], - spec=CompositeSpec( + spec=Composite( **{ - "loc": UnboundedContinuousTensorSpec( + "loc": Unbounded( proof_environment.action_spec.shape, device=proof_environment.action_spec.device, ), - "scale": UnboundedContinuousTensorSpec( + "scale": Unbounded( proof_environment.action_spec.shape, device=proof_environment.action_spec.device, ), @@ -488,7 +488,7 @@ def _dreamer_make_actor_sim(action_key, proof_environment, actor_module): default_interaction_type=InteractionType.RANDOM, distribution_class=TanhNormal, distribution_kwargs={"tanh_loc": True}, - spec=CompositeSpec(**{action_key: proof_environment.action_spec}), + spec=Composite(**{action_key: proof_environment.action_spec}), ), ) return actor_simulator @@ -526,12 +526,12 @@ def _dreamer_make_actor_real( actor_module, in_keys=["state", "belief"], out_keys=["loc", "scale"], - spec=CompositeSpec( + spec=Composite( **{ - "loc": UnboundedContinuousTensorSpec( + "loc": Unbounded( proof_environment.action_spec.shape, ), - "scale": UnboundedContinuousTensorSpec( + "scale": Unbounded( proof_environment.action_spec.shape, ), } @@ -543,9 +543,7 @@ def _dreamer_make_actor_real( default_interaction_type=InteractionType.DETERMINISTIC, distribution_class=TanhNormal, distribution_kwargs={"tanh_loc": True}, - spec=CompositeSpec( - **{action_key: proof_environment.action_spec.to("cpu")} - ), + spec=Composite(**{action_key: proof_environment.action_spec.to("cpu")}), ), ), SafeModule( diff --git a/sota-implementations/gail/config.yaml b/sota-implementations/gail/config.yaml new file mode 100644 index 00000000000..cf6c8053037 --- /dev/null +++ b/sota-implementations/gail/config.yaml @@ -0,0 +1,46 @@ +env: + env_name: HalfCheetah-v4 + seed: 42 + backend: gymnasium + +logger: + backend: wandb + project_name: gail + group_name: null + exp_name: gail_ppo + test_interval: 5000 + num_test_episodes: 5 + video: False + mode: online + +ppo: + collector: + frames_per_batch: 2048 + total_frames: 1_000_000 + + optim: + lr: 3e-4 + 
weight_decay: 0.0
+    anneal_lr: True
+
+  loss:
+    gamma: 0.99
+    mini_batch_size: 64
+    ppo_epochs: 10
+    gae_lambda: 0.95
+    clip_epsilon: 0.2
+    anneal_clip_epsilon: False
+    critic_coef: 0.25
+    entropy_coef: 0.0
+    loss_critic_type: l2
+
+gail:
+  hidden_dim: 128
+  lr: 3e-4
+  use_grad_penalty: False
+  gp_lambda: 10.0
+  device: null
+
+replay_buffer:
+  dataset: halfcheetah-expert-v2
+  batch_size: 256
diff --git a/sota-implementations/gail/gail.py b/sota-implementations/gail/gail.py
new file mode 100644
index 00000000000..a3c64693fb3
--- /dev/null
+++ b/sota-implementations/gail/gail.py
@@ -0,0 +1,281 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+"""GAIL Example.
+
+This is a self-contained example of a GAIL training script that learns from offline expert data.
+
+The helper functions for GAIL are coded in gail_utils.py and the helper functions for PPO in ppo_utils.py.
+
+"""
+import hydra
+import numpy as np
+import torch
+import tqdm
+
+from gail_utils import log_metrics, make_gail_discriminator, make_offline_replay_buffer
+from ppo_utils import eval_model, make_env, make_ppo_models
+from torchrl.collectors import SyncDataCollector
+from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
+from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
+
+from torchrl.envs import set_gym_backend
+from torchrl.envs.utils import ExplorationType, set_exploration_type
+from torchrl.objectives import ClipPPOLoss, GAILLoss
+from torchrl.objectives.value.advantages import GAE
+from torchrl.record import VideoRecorder
+from torchrl.record.loggers import generate_exp_name, get_logger
+
+
+@hydra.main(config_path="", config_name="config")
+def main(cfg: "DictConfig"):  # noqa: F821
+    set_gym_backend(cfg.env.backend).set()
+
+    device = cfg.gail.device
+    if device in ("", None):
+        if torch.cuda.is_available():
+            device = "cuda:0"
+        else:
+            device = "cpu"
+    device = torch.device(device)
+    num_mini_batches = (
+        cfg.ppo.collector.frames_per_batch // cfg.ppo.loss.mini_batch_size
+    )
+    total_network_updates = (
+        (cfg.ppo.collector.total_frames // cfg.ppo.collector.frames_per_batch)
+        * cfg.ppo.loss.ppo_epochs
+        * num_mini_batches
+    )
+
+    # Create logger
+    exp_name = generate_exp_name("Gail", cfg.logger.exp_name)
+    logger = None
+    if cfg.logger.backend:
+        logger = get_logger(
+            logger_type=cfg.logger.backend,
+            logger_name="gail_logging",
+            experiment_name=exp_name,
+            wandb_kwargs={
+                "mode": cfg.logger.mode,
+                "config": dict(cfg),
+                "project": cfg.logger.project_name,
+                "group": cfg.logger.group_name,
+            },
+        )
+
+    # Set seeds
+    torch.manual_seed(cfg.env.seed)
+    np.random.seed(cfg.env.seed)
+
+    # Create models (check ppo_utils.py)
+    actor, critic = make_ppo_models(cfg.env.env_name)
+    actor, critic = actor.to(device), critic.to(device)
+
+    # Create collector
+    collector = SyncDataCollector(
+        create_env_fn=make_env(cfg.env.env_name, device),
+        policy=actor,
+        frames_per_batch=cfg.ppo.collector.frames_per_batch,
+        total_frames=cfg.ppo.collector.total_frames,
+        device=device,
+        storing_device=device,
+        max_frames_per_traj=-1,
+    )
+
+    # Create data buffer
+    data_buffer = TensorDictReplayBuffer(
+        storage=LazyMemmapStorage(cfg.ppo.collector.frames_per_batch),
+        sampler=SamplerWithoutReplacement(),
+        batch_size=cfg.ppo.loss.mini_batch_size,
+    )
+
+    # Create loss and adv modules
+    adv_module = GAE(
+        gamma=cfg.ppo.loss.gamma,
+        lmbda=cfg.ppo.loss.gae_lambda,
+        value_network=critic,
+
average_gae=False, + ) + + loss_module = ClipPPOLoss( + actor_network=actor, + critic_network=critic, + clip_epsilon=cfg.ppo.loss.clip_epsilon, + loss_critic_type=cfg.ppo.loss.loss_critic_type, + entropy_coef=cfg.ppo.loss.entropy_coef, + critic_coef=cfg.ppo.loss.critic_coef, + normalize_advantage=True, + ) + + # Create optimizers + actor_optim = torch.optim.Adam(actor.parameters(), lr=cfg.ppo.optim.lr, eps=1e-5) + critic_optim = torch.optim.Adam(critic.parameters(), lr=cfg.ppo.optim.lr, eps=1e-5) + + # Create replay buffer + replay_buffer = make_offline_replay_buffer(cfg.replay_buffer) + + # Create Discriminator + discriminator = make_gail_discriminator(cfg, collector.env, device) + + # Create loss + discriminator_loss = GAILLoss( + discriminator, + use_grad_penalty=cfg.gail.use_grad_penalty, + gp_lambda=cfg.gail.gp_lambda, + ) + + # Create optimizer + discriminator_optim = torch.optim.Adam( + params=discriminator.parameters(), lr=cfg.gail.lr + ) + + # Create test environment + logger_video = cfg.logger.video + test_env = make_env(cfg.env.env_name, device, from_pixels=logger_video) + if logger_video: + test_env = test_env.append_transform( + VideoRecorder(logger, tag="rendering/test", in_keys=["pixels"]) + ) + test_env.eval() + + # Training loop + collected_frames = 0 + num_network_updates = 0 + pbar = tqdm.tqdm(total=cfg.ppo.collector.total_frames) + + # extract cfg variables + cfg_loss_ppo_epochs = cfg.ppo.loss.ppo_epochs + cfg_optim_anneal_lr = cfg.ppo.optim.anneal_lr + cfg_optim_lr = cfg.ppo.optim.lr + cfg_loss_anneal_clip_eps = cfg.ppo.loss.anneal_clip_epsilon + cfg_loss_clip_epsilon = cfg.ppo.loss.clip_epsilon + cfg_logger_test_interval = cfg.logger.test_interval + cfg_logger_num_test_episodes = cfg.logger.num_test_episodes + + for i, data in enumerate(collector): + + log_info = {} + frames_in_batch = data.numel() + collected_frames += frames_in_batch + pbar.update(data.numel()) + + # Update discriminator + # Get expert data + expert_data = replay_buffer.sample() + expert_data = expert_data.to(device) + # Add collector data to expert data + expert_data.set( + discriminator_loss.tensor_keys.collector_action, + data["action"][: expert_data.batch_size[0]], + ) + expert_data.set( + discriminator_loss.tensor_keys.collector_observation, + data["observation"][: expert_data.batch_size[0]], + ) + d_loss = discriminator_loss(expert_data) + + # Backward pass + discriminator_optim.zero_grad() + d_loss.get("loss").backward() + discriminator_optim.step() + + # Compute discriminator reward + with torch.no_grad(): + data = discriminator(data) + d_rewards = -torch.log(1 - data["d_logits"] + 1e-8) + + # Set discriminator rewards to tensordict + data.set(("next", "reward"), d_rewards) + + # Get training rewards and episode lengths + episode_rewards = data["next", "episode_reward"][data["next", "done"]] + if len(episode_rewards) > 0: + episode_length = data["next", "step_count"][data["next", "done"]] + log_info.update( + { + "train/reward": episode_rewards.mean().item(), + "train/episode_length": episode_length.sum().item() + / len(episode_length), + } + ) + # Update PPO + for _ in range(cfg_loss_ppo_epochs): + + # Compute GAE + with torch.no_grad(): + data = adv_module(data) + data_reshape = data.reshape(-1) + + # Update the data buffer + data_buffer.extend(data_reshape) + + for _, batch in enumerate(data_buffer): + + # Get a data batch + batch = batch.to(device) + + # Linearly decrease the learning rate and clip epsilon + alpha = 1.0 + if cfg_optim_anneal_lr: + alpha = 1 - (num_network_updates / 
total_network_updates) + for group in actor_optim.param_groups: + group["lr"] = cfg_optim_lr * alpha + for group in critic_optim.param_groups: + group["lr"] = cfg_optim_lr * alpha + if cfg_loss_anneal_clip_eps: + loss_module.clip_epsilon.copy_(cfg_loss_clip_epsilon * alpha) + num_network_updates += 1 + + # Forward pass PPO loss + loss = loss_module(batch) + critic_loss = loss["loss_critic"] + actor_loss = loss["loss_objective"] + loss["loss_entropy"] + + # Backward pass + actor_loss.backward() + critic_loss.backward() + + # Update the networks + actor_optim.step() + critic_optim.step() + actor_optim.zero_grad() + critic_optim.zero_grad() + + log_info.update( + { + "train/actor_loss": actor_loss.item(), + "train/critic_loss": critic_loss.item(), + "train/discriminator_loss": d_loss["loss"].item(), + "train/lr": alpha * cfg_optim_lr, + "train/clip_epsilon": ( + alpha * cfg_loss_clip_epsilon + if cfg_loss_anneal_clip_eps + else cfg_loss_clip_epsilon + ), + } + ) + + # evaluation + with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC): + if ((i - 1) * frames_in_batch) // cfg_logger_test_interval < ( + i * frames_in_batch + ) // cfg_logger_test_interval: + actor.eval() + test_rewards = eval_model( + actor, test_env, num_episodes=cfg_logger_num_test_episodes + ) + log_info.update( + { + "eval/reward": test_rewards.mean(), + } + ) + actor.train() + if logger is not None: + log_metrics(logger, log_info, i) + + pbar.close() + + +if __name__ == "__main__": + main() diff --git a/sota-implementations/gail/gail_utils.py b/sota-implementations/gail/gail_utils.py new file mode 100644 index 00000000000..067e9c8c927 --- /dev/null +++ b/sota-implementations/gail/gail_utils.py @@ -0,0 +1,69 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
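+"""Helper functions for the GAIL example: offline expert replay buffer, discriminator construction, and metric logging."""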
+ +import torch.nn as nn +import torch.optim + +from torchrl.data.datasets.d4rl import D4RLExperienceReplay +from torchrl.data.replay_buffers import SamplerWithoutReplacement +from torchrl.envs import DoubleToFloat + +from torchrl.modules import SafeModule + + +# ==================================================================== +# Offline Replay buffer +# --------------------------- + + +def make_offline_replay_buffer(rb_cfg): + data = D4RLExperienceReplay( + dataset_id=rb_cfg.dataset, + split_trajs=False, + batch_size=rb_cfg.batch_size, + sampler=SamplerWithoutReplacement(drop_last=False), + prefetch=4, + direct_download=True, + ) + + data.append_transform(DoubleToFloat()) + + return data + + +def make_gail_discriminator(cfg, train_env, device="cpu"): + """Make GAIL discriminator.""" + + state_dim = train_env.observation_spec["observation"].shape[0] + action_dim = train_env.action_spec.shape[0] + + hidden_dim = cfg.gail.hidden_dim + + # Define Discriminator Network + class Discriminator(nn.Module): + def __init__(self, state_dim, action_dim): + super(Discriminator, self).__init__() + self.fc1 = nn.Linear(state_dim + action_dim, hidden_dim) + self.fc2 = nn.Linear(hidden_dim, hidden_dim) + self.fc3 = nn.Linear(hidden_dim, 1) + + def forward(self, state, action): + x = torch.cat([state, action], dim=1) + x = torch.relu(self.fc1(x)) + x = torch.relu(self.fc2(x)) + return torch.sigmoid(self.fc3(x)) + + d_module = SafeModule( + module=Discriminator(state_dim, action_dim), + in_keys=["observation", "action"], + out_keys=["d_logits"], + ) + return d_module.to(device) + + +def log_metrics(logger, metrics, step): + if logger is not None: + for metric_name, metric_value in metrics.items(): + logger.log_scalar(metric_name, metric_value, step) diff --git a/sota-implementations/gail/ppo_utils.py b/sota-implementations/gail/ppo_utils.py new file mode 100644 index 00000000000..7986738f8e6 --- /dev/null +++ b/sota-implementations/gail/ppo_utils.py @@ -0,0 +1,150 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
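+"""PPO helpers for the GAIL example: environment construction, actor-critic model building, and evaluation."""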
+ +import torch.nn +import torch.optim + +from tensordict.nn import AddStateIndependentNormalScale, TensorDictModule +from torchrl.data import CompositeSpec +from torchrl.envs import ( + ClipTransform, + DoubleToFloat, + ExplorationType, + RewardSum, + StepCounter, + TransformedEnv, + VecNorm, +) +from torchrl.envs.libs.gym import GymEnv +from torchrl.modules import MLP, ProbabilisticActor, TanhNormal, ValueOperator +from torchrl.record import VideoRecorder + + +# ==================================================================== +# Environment utils +# -------------------------------------------------------------------- + + +def make_env(env_name="HalfCheetah-v4", device="cpu", from_pixels: bool = False): + env = GymEnv(env_name, device=device, from_pixels=from_pixels, pixels_only=False) + env = TransformedEnv(env) + env.append_transform(VecNorm(in_keys=["observation"], decay=0.99999, eps=1e-2)) + env.append_transform(ClipTransform(in_keys=["observation"], low=-10, high=10)) + env.append_transform(RewardSum()) + env.append_transform(StepCounter()) + env.append_transform(DoubleToFloat(in_keys=["observation"])) + return env + + +# ==================================================================== +# Model utils +# -------------------------------------------------------------------- + + +def make_ppo_models_state(proof_environment): + + # Define input shape + input_shape = proof_environment.observation_spec["observation"].shape + + # Define policy output distribution class + num_outputs = proof_environment.action_spec.shape[-1] + distribution_class = TanhNormal + distribution_kwargs = { + "low": proof_environment.action_spec.space.low, + "high": proof_environment.action_spec.space.high, + "tanh_loc": False, + } + + # Define policy architecture + policy_mlp = MLP( + in_features=input_shape[-1], + activation_class=torch.nn.Tanh, + out_features=num_outputs, # predict only loc + num_cells=[64, 64], + ) + + # Initialize policy weights + for layer in policy_mlp.modules(): + if isinstance(layer, torch.nn.Linear): + torch.nn.init.orthogonal_(layer.weight, 1.0) + layer.bias.data.zero_() + + # Add state-independent normal scale + policy_mlp = torch.nn.Sequential( + policy_mlp, + AddStateIndependentNormalScale( + proof_environment.action_spec.shape[-1], scale_lb=1e-8 + ), + ) + + # Add probabilistic sampling of the actions + policy_module = ProbabilisticActor( + TensorDictModule( + module=policy_mlp, + in_keys=["observation"], + out_keys=["loc", "scale"], + ), + in_keys=["loc", "scale"], + spec=CompositeSpec(action=proof_environment.action_spec), + distribution_class=distribution_class, + distribution_kwargs=distribution_kwargs, + return_log_prob=True, + default_interaction_type=ExplorationType.RANDOM, + ) + + # Define value architecture + value_mlp = MLP( + in_features=input_shape[-1], + activation_class=torch.nn.Tanh, + out_features=1, + num_cells=[64, 64], + ) + + # Initialize value weights + for layer in value_mlp.modules(): + if isinstance(layer, torch.nn.Linear): + torch.nn.init.orthogonal_(layer.weight, 0.01) + layer.bias.data.zero_() + + # Define value module + value_module = ValueOperator( + value_mlp, + in_keys=["observation"], + ) + + return policy_module, value_module + + +def make_ppo_models(env_name): + proof_environment = make_env(env_name, device="cpu") + actor, critic = make_ppo_models_state(proof_environment) + return actor, critic + + +# ==================================================================== +# Evaluation utils +# 
-------------------------------------------------------------------- + + +def dump_video(module): + if isinstance(module, VideoRecorder): + module.dump() + + +def eval_model(actor, test_env, num_episodes=3): + test_rewards = [] + for _ in range(num_episodes): + td_test = test_env.rollout( + policy=actor, + auto_reset=True, + auto_cast_to_device=True, + break_when_any_done=True, + max_steps=10_000_000, + ) + reward = td_test["next", "episode_reward"][td_test["next", "done"]] + test_rewards.append(reward.cpu()) + test_env.apply(dump_video) + del td_test + return torch.cat(test_rewards, 0).mean() diff --git a/sota-implementations/impala/utils.py b/sota-implementations/impala/utils.py index b365dca3867..30293940377 100644 --- a/sota-implementations/impala/utils.py +++ b/sota-implementations/impala/utils.py @@ -6,7 +6,7 @@ import torch.nn import torch.optim from tensordict.nn import TensorDictModule -from torchrl.data import CompositeSpec +from torchrl.data import Composite from torchrl.envs import ( CatFrames, DoubleToFloat, @@ -100,7 +100,7 @@ def make_ppo_modules_pixels(proof_environment): out_keys=["common_features"], ) - # Define on head for the policy + # Define one head for the policy policy_net = MLP( in_features=common_mlp_output.shape[-1], out_features=num_outputs, @@ -117,7 +117,7 @@ def make_ppo_modules_pixels(proof_environment): policy_module = ProbabilisticActor( policy_module, in_keys=["logits"], - spec=CompositeSpec(action=proof_environment.action_spec), + spec=Composite(action=proof_environment.action_spec), distribution_class=distribution_class, distribution_kwargs=distribution_kwargs, return_log_prob=True, diff --git a/sota-implementations/iql/iql_offline.py b/sota-implementations/iql/iql_offline.py index d1a16fd8192..53581782d20 100644 --- a/sota-implementations/iql/iql_offline.py +++ b/sota-implementations/iql/iql_offline.py @@ -141,6 +141,10 @@ def main(cfg: "DictConfig"): # noqa: F821 log_metrics(logger, to_log, i) pbar.close() + if not eval_env.is_closed: + eval_env.close() + if not train_env.is_closed: + train_env.close() torchrl_logger.info(f"Training time: {time.time() - start_time}") diff --git a/sota-implementations/iql/iql_online.py b/sota-implementations/iql/iql_online.py index d50ff806294..3cdff06ffa2 100644 --- a/sota-implementations/iql/iql_online.py +++ b/sota-implementations/iql/iql_online.py @@ -204,6 +204,12 @@ def main(cfg: "DictConfig"): # noqa: F821 collector.shutdown() end_time = time.time() execution_time = end_time - start_time + + if not eval_env.is_closed: + eval_env.close() + if not train_env.is_closed: + train_env.close() + torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish") diff --git a/sota-implementations/iql/utils.py b/sota-implementations/iql/utils.py index 61d31b88eb8..a24c6168375 100644 --- a/sota-implementations/iql/utils.py +++ b/sota-implementations/iql/utils.py @@ -11,7 +11,7 @@ from torchrl.collectors import SyncDataCollector from torchrl.data import ( - CompositeSpec, + Composite, LazyMemmapStorage, TensorDictPrioritizedReplayBuffer, TensorDictReplayBuffer, @@ -306,7 +306,7 @@ def make_discrete_iql_model(cfg, train_env, eval_env, device): out_keys=["logits"], ) actor = ProbabilisticActor( - spec=CompositeSpec(action=eval_env.action_spec), + spec=Composite(action=eval_env.action_spec), module=actor_module, in_keys=["logits"], out_keys=["action"], diff --git a/sota-implementations/multiagent/iql.py b/sota-implementations/multiagent/iql.py index a4d2b88a9d0..39750c5d425 100644 --- 
a/sota-implementations/multiagent/iql.py +++ b/sota-implementations/multiagent/iql.py @@ -225,6 +225,12 @@ def train(cfg: "DictConfig"): # noqa: F821 logger.experiment.log({}, commit=True) sampling_start = time.time() + collector.shutdown() + if not env.is_closed: + env.close() + if not env_test.is_closed: + env_test.close() + if __name__ == "__main__": train() diff --git a/sota-implementations/multiagent/maddpg_iddpg.py b/sota-implementations/multiagent/maddpg_iddpg.py index e9de2ac4e14..aad1df14fff 100644 --- a/sota-implementations/multiagent/maddpg_iddpg.py +++ b/sota-implementations/multiagent/maddpg_iddpg.py @@ -251,6 +251,11 @@ def train(cfg: "DictConfig"): # noqa: F821 if cfg.logger.backend == "wandb": logger.experiment.log({}, commit=True) sampling_start = time.time() + collector.shutdown() + if not env.is_closed: + env.close() + if not env_test.is_closed: + env_test.close() if __name__ == "__main__": diff --git a/sota-implementations/multiagent/mappo_ippo.py b/sota-implementations/multiagent/mappo_ippo.py index fa006a7d4a2..d2e218b843a 100644 --- a/sota-implementations/multiagent/mappo_ippo.py +++ b/sota-implementations/multiagent/mappo_ippo.py @@ -254,6 +254,11 @@ def train(cfg: "DictConfig"): # noqa: F821 if cfg.logger.backend == "wandb": logger.experiment.log({}, commit=True) sampling_start = time.time() + collector.shutdown() + if not env.is_closed: + env.close() + if not env_test.is_closed: + env_test.close() if __name__ == "__main__": diff --git a/sota-implementations/multiagent/qmix_vdn.py b/sota-implementations/multiagent/qmix_vdn.py index 4e6a962c556..c5993f902c6 100644 --- a/sota-implementations/multiagent/qmix_vdn.py +++ b/sota-implementations/multiagent/qmix_vdn.py @@ -259,6 +259,11 @@ def train(cfg: "DictConfig"): # noqa: F821 if cfg.logger.backend == "wandb": logger.experiment.log({}, commit=True) sampling_start = time.time() + collector.shutdown() + if not env.is_closed: + env.close() + if not env_test.is_closed: + env_test.close() if __name__ == "__main__": diff --git a/sota-implementations/multiagent/sac.py b/sota-implementations/multiagent/sac.py index f7b2523010b..cfafdd47c96 100644 --- a/sota-implementations/multiagent/sac.py +++ b/sota-implementations/multiagent/sac.py @@ -318,6 +318,11 @@ def train(cfg: "DictConfig"): # noqa: F821 if cfg.logger.backend == "wandb": logger.experiment.log({}, commit=True) sampling_start = time.time() + collector.shutdown() + if not env.is_closed: + env.close() + if not env_test.is_closed: + env_test.close() if __name__ == "__main__": diff --git a/sota-implementations/ppo/ppo_atari.py b/sota-implementations/ppo/ppo_atari.py index 2b02254032a..6d8883393d5 100644 --- a/sota-implementations/ppo/ppo_atari.py +++ b/sota-implementations/ppo/ppo_atari.py @@ -243,6 +243,9 @@ def main(cfg: "DictConfig"): # noqa: F821 sampling_start = time.time() collector.shutdown() + if not test_env.is_closed: + test_env.close() + end_time = time.time() execution_time = end_time - start_time torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish") diff --git a/sota-implementations/ppo/ppo_mujoco.py b/sota-implementations/ppo/ppo_mujoco.py index 219ae1b59b6..8cfea74d0bc 100644 --- a/sota-implementations/ppo/ppo_mujoco.py +++ b/sota-implementations/ppo/ppo_mujoco.py @@ -235,6 +235,9 @@ def main(cfg: "DictConfig"): # noqa: F821 collector.update_policy_weights_() sampling_start = time.time() + collector.shutdown() + if not test_env.is_closed: + test_env.close() end_time = time.time() execution_time = end_time - start_time 
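The same teardown sequence is grafted onto every multiagent script and both PPO scripts above: shut the collector down, then close any environment that is still open (the is_closed guard avoids double-closing). The PR inlines these checks per script; a hypothetical helper capturing the pattern, shutdown_all, is sketched below and is not part of the PR:

    def shutdown_all(collector=None, *envs):
        # Stop a collector first, then close every env that is still open.
        if collector is not None:
            collector.shutdown()
        for env in envs:
            if env is not None and not env.is_closed:
                env.close()

    # e.g. shutdown_all(collector, env, env_test) at the end of train()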
torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish") diff --git a/sota-implementations/ppo/utils_atari.py b/sota-implementations/ppo/utils_atari.py index 2344da518bc..50f91ed49cd 100644 --- a/sota-implementations/ppo/utils_atari.py +++ b/sota-implementations/ppo/utils_atari.py @@ -6,8 +6,8 @@ import torch.nn import torch.optim from tensordict.nn import TensorDictModule -from torchrl.data import CompositeSpec -from torchrl.data.tensor_specs import DiscreteBox +from torchrl.data import Composite +from torchrl.data.tensor_specs import CategoricalBox from torchrl.envs import ( CatFrames, DoubleToFloat, @@ -92,7 +92,7 @@ def make_ppo_modules_pixels(proof_environment): input_shape = proof_environment.observation_spec["pixels"].shape # Define distribution class and kwargs - if isinstance(proof_environment.action_spec.space, DiscreteBox): + if isinstance(proof_environment.action_spec.space, CategoricalBox): num_outputs = proof_environment.action_spec.space.n distribution_class = OneHotCategorical distribution_kwargs = {} @@ -148,7 +148,7 @@ def make_ppo_modules_pixels(proof_environment): policy_module = ProbabilisticActor( policy_module, in_keys=["logits"], - spec=CompositeSpec(action=proof_environment.action_spec), + spec=Composite(action=proof_environment.action_spec), distribution_class=distribution_class, distribution_kwargs=distribution_kwargs, return_log_prob=True, diff --git a/sota-implementations/ppo/utils_mujoco.py b/sota-implementations/ppo/utils_mujoco.py index 7986738f8e6..a05d205b000 100644 --- a/sota-implementations/ppo/utils_mujoco.py +++ b/sota-implementations/ppo/utils_mujoco.py @@ -7,7 +7,7 @@ import torch.optim from tensordict.nn import AddStateIndependentNormalScale, TensorDictModule -from torchrl.data import CompositeSpec +from torchrl.data import Composite from torchrl.envs import ( ClipTransform, DoubleToFloat, @@ -87,7 +87,7 @@ def make_ppo_models_state(proof_environment): out_keys=["loc", "scale"], ), in_keys=["loc", "scale"], - spec=CompositeSpec(action=proof_environment.action_spec), + spec=Composite(action=proof_environment.action_spec), distribution_class=distribution_class, distribution_kwargs=distribution_kwargs, return_log_prob=True, diff --git a/sota-implementations/redq/config.yaml b/sota-implementations/redq/config.yaml index e60191c0f93..818f3386fda 100644 --- a/sota-implementations/redq/config.yaml +++ b/sota-implementations/redq/config.yaml @@ -36,7 +36,6 @@ collector: multi_step: 1 n_steps_return: 3 max_frames_per_traj: -1 - exploration_mode: random logger: backend: wandb diff --git a/sota-implementations/redq/redq.py b/sota-implementations/redq/redq.py index eb802f6773d..865533aee2f 100644 --- a/sota-implementations/redq/redq.py +++ b/sota-implementations/redq/redq.py @@ -159,7 +159,7 @@ def main(cfg: "DictConfig"): # noqa: F821 use_env_creator=False, )() if isinstance(create_env_fn, ParallelEnv): - raise NotImplementedError("This behaviour is deprecated") + raise NotImplementedError("This behavior is deprecated") elif isinstance(create_env_fn, EnvCreator): recorder.transform[1:].load_state_dict( get_norm_state_dict(create_env_fn()), strict=False diff --git a/sota-implementations/redq/utils.py b/sota-implementations/redq/utils.py index dd922372cbb..8312d359366 100644 --- a/sota-implementations/redq/utils.py +++ b/sota-implementations/redq/utils.py @@ -1021,7 +1021,6 @@ def make_collector_offpolicy( "init_random_frames": cfg.collector.init_random_frames, "split_trajs": True, # trajectories must be separated if multi-step is used - 
"exploration_type": ExplorationType.from_str(cfg.collector.exploration_mode), } collector = collector_helper(**collector_helper_kwargs) diff --git a/sota-implementations/sac/sac.py b/sota-implementations/sac/sac.py index 9904fe072ab..68860500149 100644 --- a/sota-implementations/sac/sac.py +++ b/sota-implementations/sac/sac.py @@ -215,6 +215,10 @@ def main(cfg: "DictConfig"): # noqa: F821 sampling_start = time.time() collector.shutdown() + if not eval_env.is_closed: + eval_env.close() + if not train_env.is_closed: + train_env.close() end_time = time.time() execution_time = end_time - start_time torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish") diff --git a/sota-implementations/td3/td3.py b/sota-implementations/td3/td3.py index 632ee58503d..01a59686ac9 100644 --- a/sota-implementations/td3/td3.py +++ b/sota-implementations/td3/td3.py @@ -213,6 +213,10 @@ def main(cfg: "DictConfig"): # noqa: F821 sampling_start = time.time() collector.shutdown() + if not eval_env.is_closed: + eval_env.close() + if not train_env.is_closed: + train_env.close() end_time = time.time() execution_time = end_time - start_time torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish") diff --git a/sota-implementations/td3_bc/td3_bc.py b/sota-implementations/td3_bc/td3_bc.py index 7c43fdc1a12..930ff509488 100644 --- a/sota-implementations/td3_bc/td3_bc.py +++ b/sota-implementations/td3_bc/td3_bc.py @@ -128,7 +128,7 @@ def main(cfg: "DictConfig"): # noqa: F821 # evaluation if i % evaluation_interval == 0: - with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): eval_td = eval_env.rollout( max_steps=eval_steps, policy=model[0], auto_cast_to_device=True ) @@ -138,6 +138,8 @@ def main(cfg: "DictConfig"): # noqa: F821 if logger is not None: log_metrics(logger, to_log, i) + if not eval_env.is_closed: + eval_env.close() pbar.close() torchrl_logger.info(f"Training time: {time.time() - start_time}") diff --git a/test/_utils_internal.py b/test/_utils_internal.py index 61b0c003f9d..51535afa606 100644 --- a/test/_utils_internal.py +++ b/test/_utils_internal.py @@ -56,7 +56,7 @@ def HALFCHEETAH_VERSIONED(): def PONG_VERSIONED(): # load gym - # Gymnasium says that the ale_py behaviour changes from 1.0 + # Gymnasium says that the ale_py behavior changes from 1.0 # but with python 3.12 it is already the case with 0.29.1 try: import ale_py # noqa @@ -70,7 +70,7 @@ def PONG_VERSIONED(): def BREAKOUT_VERSIONED(): # load gym - # Gymnasium says that the ale_py behaviour changes from 1.0 + # Gymnasium says that the ale_py behavior changes from 1.0 # but with python 3.12 it is already the case with 0.29.1 try: import ale_py # noqa @@ -121,7 +121,7 @@ def _set_gym_environments(): # noqa: F811 _BREAKOUT_VERSIONED = "ALE/Breakout-v5" -@implement_for("gymnasium") +@implement_for("gymnasium", None, "1.0.0") def _set_gym_environments(): # noqa: F811 global _CARTPOLE_VERSIONED, _HALFCHEETAH_VERSIONED, _PENDULUM_VERSIONED, _PONG_VERSIONED, _BREAKOUT_VERSIONED @@ -132,6 +132,11 @@ def _set_gym_environments(): # noqa: F811 _BREAKOUT_VERSIONED = "ALE/Breakout-v5" +@implement_for("gymnasium", "1.0.0", None) +def _set_gym_environments(): # noqa: F811 + raise ImportError + + if _has_gym: _set_gym_environments() @@ -155,6 +160,8 @@ def get_default_devices(): return [torch.device("cpu")] elif num_cuda == 1: return [torch.device("cuda:0")] + elif torch.mps.is_available(): + return [torch.device("mps:0")] else: # then run on 
all devices return get_available_devices() diff --git a/test/conftest.py b/test/conftest.py index ca418d7b6f2..f2648a18041 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -113,6 +113,18 @@ def pytest_addoption(parser): help="Use 'fork' start method for mp dedicated tests only if there is no cuda device available.", ) + parser.addoption( + "--unity_editor", + action="store_true", + default=False, + help="Run tests that require manually pressing play in the Unity editor.", + ) + + +def pytest_runtest_setup(item): + if "unity_editor" in item.keywords and not item.config.getoption("--unity_editor"): + pytest.skip("need --unity_editor option to run this test") + def pytest_configure(config): config.addinivalue_line("markers", "slow: mark test as slow to run") diff --git a/test/mocking_classes.py b/test/mocking_classes.py index ea4327bb460..4d86d8ec0ac 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -11,15 +11,15 @@ from tensordict.utils import expand_right, NestedKey from torchrl.data.tensor_specs import ( - BinaryDiscreteTensorSpec, - BoundedTensorSpec, - CompositeSpec, - DiscreteTensorSpec, - MultiOneHotDiscreteTensorSpec, - NonTensorSpec, - OneHotDiscreteTensorSpec, + Binary, + Bounded, + Categorical, + Composite, + MultiOneHot, + NonTensor, + OneHot, TensorSpec, - UnboundedContinuousTensorSpec, + Unbounded, ) from torchrl.data.utils import consolidate_spec from torchrl.envs.common import EnvBase @@ -27,27 +27,27 @@ from torchrl.envs.utils import _terminated_or_truncated spec_dict = { - "bounded": BoundedTensorSpec, - "one_hot": OneHotDiscreteTensorSpec, - "categorical": DiscreteTensorSpec, - "unbounded": UnboundedContinuousTensorSpec, - "binary": BinaryDiscreteTensorSpec, - "mult_one_hot": MultiOneHotDiscreteTensorSpec, - "composite": CompositeSpec, + "bounded": Bounded, + "one_hot": OneHot, + "categorical": Categorical, + "unbounded": Unbounded, + "binary": Binary, + "mult_one_hot": MultiOneHot, + "composite": Composite, } default_spec_kwargs = { - OneHotDiscreteTensorSpec: {"n": 7}, - DiscreteTensorSpec: {"n": 7}, - BoundedTensorSpec: {"minimum": -torch.ones(4), "maximum": torch.ones(4)}, - UnboundedContinuousTensorSpec: { + OneHot: {"n": 7}, + Categorical: {"n": 7}, + Bounded: {"minimum": -torch.ones(4), "maximum": torch.ones(4)}, + Unbounded: { "shape": [ 7, ] }, - BinaryDiscreteTensorSpec: {"n": 7}, - MultiOneHotDiscreteTensorSpec: {"nvec": [7, 3, 5]}, - CompositeSpec: {}, + Binary: {"n": 7}, + MultiOneHot: {"nvec": [7, 3, 5]}, + Composite: {}, } @@ -68,8 +68,8 @@ def __new__( torch.get_default_dtype() ) reward_spec = cls._output_spec["full_reward_spec"] - if isinstance(reward_spec, CompositeSpec): - reward_spec = CompositeSpec( + if isinstance(reward_spec, Composite): + reward_spec = Composite( { key: item.to(torch.get_default_dtype()) for key, item in reward_spec.items(True, True) @@ -80,19 +80,19 @@ def __new__( else: reward_spec = reward_spec.to(torch.get_default_dtype()) cls._output_spec["full_reward_spec"] = reward_spec - if not isinstance(cls._output_spec["full_reward_spec"], CompositeSpec): - cls._output_spec["full_reward_spec"] = CompositeSpec( + if not isinstance(cls._output_spec["full_reward_spec"], Composite): + cls._output_spec["full_reward_spec"] = Composite( reward=cls._output_spec["full_reward_spec"], shape=cls._output_spec["full_reward_spec"].shape[:-1], ) - if not isinstance(cls._output_spec["full_done_spec"], CompositeSpec): - cls._output_spec["full_done_spec"] = CompositeSpec( + if not isinstance(cls._output_spec["full_done_spec"], 
Composite): + cls._output_spec["full_done_spec"] = Composite( done=cls._output_spec["full_done_spec"].clone(), terminated=cls._output_spec["full_done_spec"].clone(), shape=cls._output_spec["full_done_spec"].shape[:-1], ) - if not isinstance(cls._input_spec["full_action_spec"], CompositeSpec): - cls._input_spec["full_action_spec"] = CompositeSpec( + if not isinstance(cls._input_spec["full_action_spec"], Composite): + cls._input_spec["full_action_spec"] = Composite( action=cls._input_spec["full_action_spec"], shape=cls._input_spec["full_action_spec"].shape[:-1], ) @@ -156,15 +156,15 @@ def __new__( ): batch_size = kwargs.setdefault("batch_size", torch.Size([])) if action_spec is None: - action_spec = UnboundedContinuousTensorSpec( + action_spec = Unbounded( ( *batch_size, 1, ) ) if observation_spec is None: - observation_spec = CompositeSpec( - observation=UnboundedContinuousTensorSpec( + observation_spec = Composite( + observation=Unbounded( ( *batch_size, 1, @@ -173,35 +173,35 @@ def __new__( shape=batch_size, ) if reward_spec is None: - reward_spec = UnboundedContinuousTensorSpec( + reward_spec = Unbounded( ( *batch_size, 1, ) ) if done_spec is None: - done_spec = DiscreteTensorSpec(2, dtype=torch.bool, shape=(*batch_size, 1)) + done_spec = Categorical(2, dtype=torch.bool, shape=(*batch_size, 1)) if state_spec is None: - state_spec = CompositeSpec(shape=batch_size) - input_spec = CompositeSpec( + state_spec = Composite(shape=batch_size) + input_spec = Composite( full_action_spec=action_spec, full_state_spec=state_spec, shape=batch_size ) - cls._output_spec = CompositeSpec(shape=batch_size) + cls._output_spec = Composite(shape=batch_size) cls._output_spec["full_reward_spec"] = reward_spec cls._output_spec["full_done_spec"] = done_spec cls._output_spec["full_observation_spec"] = observation_spec cls._input_spec = input_spec - if not isinstance(cls._output_spec["full_reward_spec"], CompositeSpec): - cls._output_spec["full_reward_spec"] = CompositeSpec( + if not isinstance(cls._output_spec["full_reward_spec"], Composite): + cls._output_spec["full_reward_spec"] = Composite( reward=cls._output_spec["full_reward_spec"], shape=batch_size ) - if not isinstance(cls._output_spec["full_done_spec"], CompositeSpec): - cls._output_spec["full_done_spec"] = CompositeSpec( + if not isinstance(cls._output_spec["full_done_spec"], Composite): + cls._output_spec["full_done_spec"] = Composite( done=cls._output_spec["full_done_spec"], shape=batch_size ) - if not isinstance(cls._input_spec["full_action_spec"], CompositeSpec): - cls._input_spec["full_action_spec"] = CompositeSpec( + if not isinstance(cls._input_spec["full_action_spec"], Composite): + cls._input_spec["full_action_spec"] = Composite( action=cls._input_spec["full_action_spec"], shape=batch_size ) return super().__new__(*args, **kwargs) @@ -268,15 +268,15 @@ def __new__( ): batch_size = kwargs.setdefault("batch_size", torch.Size([])) if action_spec is None: - action_spec = UnboundedContinuousTensorSpec( + action_spec = Unbounded( ( *batch_size, 1, ) ) if state_spec is None: - state_spec = CompositeSpec( - observation=UnboundedContinuousTensorSpec( + state_spec = Composite( + observation=Unbounded( ( *batch_size, 1, @@ -285,8 +285,8 @@ def __new__( shape=batch_size, ) if observation_spec is None: - observation_spec = CompositeSpec( - observation=UnboundedContinuousTensorSpec( + observation_spec = Composite( + observation=Unbounded( ( *batch_size, 1, @@ -295,33 +295,33 @@ def __new__( shape=batch_size, ) if reward_spec is None: - reward_spec = 
UnboundedContinuousTensorSpec( + reward_spec = Unbounded( ( *batch_size, 1, ) ) if done_spec is None: - done_spec = DiscreteTensorSpec(2, dtype=torch.bool, shape=(*batch_size, 1)) - cls._output_spec = CompositeSpec(shape=batch_size) + done_spec = Categorical(2, dtype=torch.bool, shape=(*batch_size, 1)) + cls._output_spec = Composite(shape=batch_size) cls._output_spec["full_reward_spec"] = reward_spec cls._output_spec["full_done_spec"] = done_spec cls._output_spec["full_observation_spec"] = observation_spec - cls._input_spec = CompositeSpec( + cls._input_spec = Composite( full_action_spec=action_spec, full_state_spec=state_spec, shape=batch_size, ) - if not isinstance(cls._output_spec["full_reward_spec"], CompositeSpec): - cls._output_spec["full_reward_spec"] = CompositeSpec( + if not isinstance(cls._output_spec["full_reward_spec"], Composite): + cls._output_spec["full_reward_spec"] = Composite( reward=cls._output_spec["full_reward_spec"], shape=batch_size ) - if not isinstance(cls._output_spec["full_done_spec"], CompositeSpec): - cls._output_spec["full_done_spec"] = CompositeSpec( + if not isinstance(cls._output_spec["full_done_spec"], Composite): + cls._output_spec["full_done_spec"] = Composite( done=cls._output_spec["full_done_spec"], shape=batch_size ) - if not isinstance(cls._input_spec["full_action_spec"], CompositeSpec): - cls._input_spec["full_action_spec"] = CompositeSpec( + if not isinstance(cls._input_spec["full_action_spec"], Composite): + cls._input_spec["full_action_spec"] = Composite( action=cls._input_spec["full_action_spec"], shape=batch_size ) return super().__new__(cls, *args, **kwargs) @@ -442,46 +442,38 @@ def __new__( size = cls.size = 7 if observation_spec is None: cls.out_key = "observation" - observation_spec = CompositeSpec( - observation=UnboundedContinuousTensorSpec( - shape=torch.Size([*batch_size, size]) - ), - observation_orig=UnboundedContinuousTensorSpec( - shape=torch.Size([*batch_size, size]) - ), + observation_spec = Composite( + observation=Unbounded(shape=torch.Size([*batch_size, size])), + observation_orig=Unbounded(shape=torch.Size([*batch_size, size])), shape=batch_size, ) if action_spec is None: if categorical_action_encoding: - action_spec_cls = DiscreteTensorSpec + action_spec_cls = Categorical action_spec = action_spec_cls(n=7, shape=batch_size) else: - action_spec_cls = OneHotDiscreteTensorSpec + action_spec_cls = OneHot action_spec = action_spec_cls(n=7, shape=(*batch_size, 7)) if reward_spec is None: - reward_spec = CompositeSpec( - reward=UnboundedContinuousTensorSpec(shape=(1,)) - ) + reward_spec = Composite(reward=Unbounded(shape=(1,))) if done_spec is None: - done_spec = CompositeSpec( - terminated=DiscreteTensorSpec( - 2, dtype=torch.bool, shape=(*batch_size, 1) - ) + done_spec = Composite( + terminated=Categorical(2, dtype=torch.bool, shape=(*batch_size, 1)) ) if state_spec is None: cls._out_key = "observation_orig" - state_spec = CompositeSpec( + state_spec = Composite( { cls._out_key: observation_spec["observation"], }, shape=batch_size, ) - cls._output_spec = CompositeSpec(shape=batch_size) + cls._output_spec = Composite(shape=batch_size) cls._output_spec["full_reward_spec"] = reward_spec cls._output_spec["full_done_spec"] = done_spec cls._output_spec["full_observation_spec"] = observation_spec - cls._input_spec = CompositeSpec( + cls._input_spec = Composite( full_action_spec=action_spec, full_state_spec=state_spec, shape=batch_size, @@ -553,17 +545,13 @@ def __new__( size = cls.size = 7 if observation_spec is None: cls.out_key = 
"observation" - observation_spec = CompositeSpec( - observation=UnboundedContinuousTensorSpec( - shape=torch.Size([*batch_size, size]) - ), - observation_orig=UnboundedContinuousTensorSpec( - shape=torch.Size([*batch_size, size]) - ), + observation_spec = Composite( + observation=Unbounded(shape=torch.Size([*batch_size, size])), + observation_orig=Unbounded(shape=torch.Size([*batch_size, size])), shape=batch_size, ) if action_spec is None: - action_spec = BoundedTensorSpec( + action_spec = Bounded( -1, 1, ( @@ -572,23 +560,23 @@ def __new__( ), ) if reward_spec is None: - reward_spec = UnboundedContinuousTensorSpec(shape=(*batch_size, 1)) + reward_spec = Unbounded(shape=(*batch_size, 1)) if done_spec is None: - done_spec = DiscreteTensorSpec(2, dtype=torch.bool, shape=(*batch_size, 1)) + done_spec = Categorical(2, dtype=torch.bool, shape=(*batch_size, 1)) if state_spec is None: cls._out_key = "observation_orig" - state_spec = CompositeSpec( + state_spec = Composite( { cls._out_key: observation_spec["observation"], }, shape=batch_size, ) - cls._output_spec = CompositeSpec(shape=batch_size) + cls._output_spec = Composite(shape=batch_size) cls._output_spec["full_reward_spec"] = reward_spec cls._output_spec["full_done_spec"] = done_spec cls._output_spec["full_observation_spec"] = observation_spec - cls._input_spec = CompositeSpec( + cls._input_spec = Composite( full_action_spec=action_spec, full_state_spec=state_spec, shape=batch_size, @@ -681,25 +669,21 @@ def __new__( batch_size = kwargs.setdefault("batch_size", torch.Size([])) if observation_spec is None: cls.out_key = "pixels" - observation_spec = CompositeSpec( - pixels=UnboundedContinuousTensorSpec( - shape=torch.Size([*batch_size, 1, 7, 7]) - ), - pixels_orig=UnboundedContinuousTensorSpec( - shape=torch.Size([*batch_size, 1, 7, 7]) - ), + observation_spec = Composite( + pixels=Unbounded(shape=torch.Size([*batch_size, 1, 7, 7])), + pixels_orig=Unbounded(shape=torch.Size([*batch_size, 1, 7, 7])), shape=batch_size, ) if action_spec is None: - action_spec = OneHotDiscreteTensorSpec(7, shape=(*batch_size, 7)) + action_spec = OneHot(7, shape=(*batch_size, 7)) if reward_spec is None: - reward_spec = UnboundedContinuousTensorSpec(shape=(*batch_size, 1)) + reward_spec = Unbounded(shape=(*batch_size, 1)) if done_spec is None: - done_spec = DiscreteTensorSpec(2, dtype=torch.bool, shape=(*batch_size, 1)) + done_spec = Categorical(2, dtype=torch.bool, shape=(*batch_size, 1)) if state_spec is None: cls._out_key = "pixels_orig" - state_spec = CompositeSpec( + state_spec = Composite( { cls._out_key: observation_spec["pixels_orig"].clone(), }, @@ -741,25 +725,17 @@ def __new__( batch_size = kwargs.setdefault("batch_size", torch.Size([])) if observation_spec is None: cls.out_key = "pixels" - observation_spec = CompositeSpec( - pixels=UnboundedContinuousTensorSpec( - shape=torch.Size([*batch_size, 7, 7, 3]) - ), - pixels_orig=UnboundedContinuousTensorSpec( - shape=torch.Size([*batch_size, 7, 7, 3]) - ), + observation_spec = Composite( + pixels=Unbounded(shape=torch.Size([*batch_size, 7, 7, 3])), + pixels_orig=Unbounded(shape=torch.Size([*batch_size, 7, 7, 3])), shape=batch_size, ) if action_spec is None: - action_spec_cls = ( - DiscreteTensorSpec - if categorical_action_encoding - else OneHotDiscreteTensorSpec - ) + action_spec_cls = Categorical if categorical_action_encoding else OneHot action_spec = action_spec_cls(7, shape=(*batch_size, 7)) if state_spec is None: cls._out_key = "pixels_orig" - state_spec = CompositeSpec( + state_spec = Composite( { 
cls._out_key: observation_spec["pixels_orig"], }, @@ -808,25 +784,21 @@ def __new__( pixel_shape = [1, 7, 7] if observation_spec is None: cls.out_key = "pixels" - observation_spec = CompositeSpec( - pixels=UnboundedContinuousTensorSpec( - shape=torch.Size([*batch_size, *pixel_shape]) - ), - pixels_orig=UnboundedContinuousTensorSpec( - shape=torch.Size([*batch_size, *pixel_shape]) - ), + observation_spec = Composite( + pixels=Unbounded(shape=torch.Size([*batch_size, *pixel_shape])), + pixels_orig=Unbounded(shape=torch.Size([*batch_size, *pixel_shape])), shape=batch_size, ) if action_spec is None: - action_spec = BoundedTensorSpec(-1, 1, [*batch_size, pixel_shape[-1]]) + action_spec = Bounded(-1, 1, [*batch_size, pixel_shape[-1]]) if reward_spec is None: - reward_spec = UnboundedContinuousTensorSpec(shape=(*batch_size, 1)) + reward_spec = Unbounded(shape=(*batch_size, 1)) if done_spec is None: - done_spec = DiscreteTensorSpec(2, dtype=torch.bool, shape=(*batch_size, 1)) + done_spec = Categorical(2, dtype=torch.bool, shape=(*batch_size, 1)) if state_spec is None: cls._out_key = "pixels_orig" - state_spec = CompositeSpec( + state_spec = Composite( {cls._out_key: observation_spec["pixels"]}, shape=batch_size ) return super().__new__( @@ -865,13 +837,9 @@ def __new__( batch_size = kwargs.setdefault("batch_size", torch.Size([])) if observation_spec is None: cls.out_key = "pixels" - observation_spec = CompositeSpec( - pixels=UnboundedContinuousTensorSpec( - shape=torch.Size([*batch_size, 7, 7, 3]) - ), - pixels_orig=UnboundedContinuousTensorSpec( - shape=torch.Size([*batch_size, 7, 7, 3]) - ), + observation_spec = Composite( + pixels=Unbounded(shape=torch.Size([*batch_size, 7, 7, 3])), + pixels_orig=Unbounded(shape=torch.Size([*batch_size, 7, 7, 3])), ) return super().__new__( *args, @@ -928,8 +896,8 @@ def __init__( device=device, batch_size=batch_size, ) - self.observation_spec = CompositeSpec( - hidden_observation=UnboundedContinuousTensorSpec( + self.observation_spec = Composite( + hidden_observation=Unbounded( ( *self.batch_size, 4, @@ -937,8 +905,8 @@ def __init__( ), shape=self.batch_size, ) - self.state_spec = CompositeSpec( - hidden_observation=UnboundedContinuousTensorSpec( + self.state_spec = Composite( + hidden_observation=Unbounded( ( *self.batch_size, 4, @@ -946,13 +914,13 @@ def __init__( ), shape=self.batch_size, ) - self.action_spec = UnboundedContinuousTensorSpec( + self.action_spec = Unbounded( ( *self.batch_size, 1, ) ) - self.reward_spec = UnboundedContinuousTensorSpec( + self.reward_spec = Unbounded( ( *self.batch_size, 1, @@ -1012,8 +980,8 @@ def __init__(self, max_steps: int = 5, start_val: int = 0, **kwargs): self.max_steps = max_steps self.start_val = start_val - self.observation_spec = CompositeSpec( - observation=UnboundedContinuousTensorSpec( + self.observation_spec = Composite( + observation=Unbounded( ( *self.batch_size, 1, @@ -1024,14 +992,14 @@ def __init__(self, max_steps: int = 5, start_val: int = 0, **kwargs): shape=self.batch_size, device=self.device, ) - self.reward_spec = UnboundedContinuousTensorSpec( + self.reward_spec = Unbounded( ( *self.batch_size, 1, ), device=self.device, ) - self.done_spec = DiscreteTensorSpec( + self.done_spec = Categorical( 2, dtype=torch.bool, shape=( @@ -1040,9 +1008,7 @@ def __init__(self, max_steps: int = 5, start_val: int = 0, **kwargs): ), device=self.device, ) - self.action_spec = BinaryDiscreteTensorSpec( - n=1, shape=[*self.batch_size, 1], device=self.device - ) + self.action_spec = Binary(n=1, shape=[*self.batch_size, 1], 
device=self.device) self.register_buffer( "count", torch.zeros((*self.batch_size, 1), device=self.device, dtype=torch.int), @@ -1072,7 +1038,10 @@ def _step( tensordict: TensorDictBase, ) -> TensorDictBase: action = tensordict.get(self.action_key) - self.count += action.to(dtype=torch.int, device=self.device) + self.count += action.to( + dtype=torch.int, + device=self.action_spec.device if self.device is None else self.device, + ) tensordict = TensorDict( source={ "observation": self.count.clone(), @@ -1129,9 +1098,9 @@ def __init__( self.nested_reward = nest_reward if self.nested_obs_action: - self.observation_spec = CompositeSpec( + self.observation_spec = Composite( { - "data": CompositeSpec( + "data": Composite( { "states": self.observation_spec["observation"] .unsqueeze(-1) @@ -1145,9 +1114,9 @@ def __init__( }, shape=self.batch_size, ) - self.action_spec = CompositeSpec( + self.action_spec = Composite( { - "data": CompositeSpec( + "data": Composite( { "action": self.action_spec.unsqueeze(-1).expand( *self.batch_size, self.nested_dim, 1 @@ -1163,9 +1132,9 @@ def __init__( ) if self.nested_reward: - self.reward_spec = CompositeSpec( + self.reward_spec = Composite( { - "data": CompositeSpec( + "data": Composite( { "reward": self.reward_spec.unsqueeze(-1).expand( *self.batch_size, self.nested_dim, 1 @@ -1184,12 +1153,12 @@ def __init__( done_spec = self.full_done_spec.unsqueeze(-1).expand( *self.batch_size, self.nested_dim ) - done_spec = CompositeSpec( + done_spec = Composite( {"data": done_spec}, shape=self.batch_size, ) if self.has_root_done: - done_spec["done"] = DiscreteTensorSpec( + done_spec["done"] = Categorical( 2, shape=( *self.batch_size, @@ -1309,8 +1278,8 @@ def __init__( self.max_steps = max_steps - self.observation_spec = CompositeSpec( - observation=UnboundedContinuousTensorSpec( + self.observation_spec = Composite( + observation=Unbounded( ( *self.batch_size, 1, @@ -1319,13 +1288,13 @@ def __init__( ), shape=self.batch_size, ) - self.reward_spec = UnboundedContinuousTensorSpec( + self.reward_spec = Unbounded( ( *self.batch_size, 1, ) ) - self.done_spec = DiscreteTensorSpec( + self.done_spec = Categorical( 2, dtype=torch.bool, shape=( @@ -1333,7 +1302,7 @@ def __init__( 1, ), ) - self.action_spec = BinaryDiscreteTensorSpec(n=1, shape=[*self.batch_size, 1]) + self.action_spec = Binary(n=1, shape=[*self.batch_size, 1]) self.count = torch.zeros( (*self.batch_size, 1), device=self.device, dtype=torch.int @@ -1419,34 +1388,30 @@ def _make_specs(self): obs_spec_unlazy = consolidate_spec(obs_specs) action_specs = torch.stack(action_specs, dim=0) - self.unbatched_observation_spec = CompositeSpec( + self.unbatched_observation_spec = Composite( lazy=obs_spec_unlazy, - state=UnboundedContinuousTensorSpec(shape=(64, 64, 3)), + state=Unbounded(shape=(64, 64, 3)), device=self.device, ) - self.unbatched_action_spec = CompositeSpec( + self.unbatched_action_spec = Composite( lazy=action_specs, device=self.device, ) - self.unbatched_reward_spec = CompositeSpec( + self.unbatched_reward_spec = Composite( { - "lazy": CompositeSpec( - { - "reward": UnboundedContinuousTensorSpec( - shape=(self.n_nested_dim, 1) - ) - }, + "lazy": Composite( + {"reward": Unbounded(shape=(self.n_nested_dim, 1))}, shape=(self.n_nested_dim,), ) }, device=self.device, ) - self.unbatched_done_spec = CompositeSpec( + self.unbatched_done_spec = Composite( { - "lazy": CompositeSpec( + "lazy": Composite( { - "done": DiscreteTensorSpec( + "done": Categorical( n=2, shape=(self.n_nested_dim, 1), dtype=torch.bool, @@ -1472,17 
+1437,17 @@ def _make_specs(self): ) def get_agent_obs_spec(self, i): - camera = BoundedTensorSpec(low=0, high=200, shape=(7, 7, 3)) - vector_3d = UnboundedContinuousTensorSpec(shape=(3,)) - vector_2d = UnboundedContinuousTensorSpec(shape=(2,)) - lidar = BoundedTensorSpec(low=0, high=5, shape=(8,)) + camera = Bounded(low=0, high=200, shape=(7, 7, 3)) + vector_3d = Unbounded(shape=(3,)) + vector_2d = Unbounded(shape=(2,)) + lidar = Bounded(low=0, high=5, shape=(8,)) - tensor_0 = UnboundedContinuousTensorSpec(shape=(1,)) - tensor_1 = BoundedTensorSpec(low=0, high=3, shape=(1, 2)) - tensor_2 = UnboundedContinuousTensorSpec(shape=(1, 2, 3)) + tensor_0 = Unbounded(shape=(1,)) + tensor_1 = Bounded(low=0, high=3, shape=(1, 2)) + tensor_2 = Unbounded(shape=(1, 2, 3)) if i == 0: - return CompositeSpec( + return Composite( { "camera": camera, "lidar": lidar, @@ -1492,7 +1457,7 @@ def get_agent_obs_spec(self, i): device=self.device, ) elif i == 1: - return CompositeSpec( + return Composite( { "camera": camera, "lidar": lidar, @@ -1502,7 +1467,7 @@ def get_agent_obs_spec(self, i): device=self.device, ) elif i == 2: - return CompositeSpec( + return Composite( { "camera": camera, "vector": vector_2d, @@ -1514,8 +1479,8 @@ def get_agent_obs_spec(self, i): raise ValueError(f"Index {i} undefined for index 3") def get_agent_action_spec(self, i): - action_3d = BoundedTensorSpec(low=-1, high=1, shape=(3,)) - action_2d = BoundedTensorSpec(low=-1, high=1, shape=(2,)) + action_3d = Bounded(low=-1, high=1, shape=(3,)) + action_2d = Bounded(low=-1, high=1, shape=(2,)) # Some have 2d action and some 3d # TODO Introduce composite heterogeneous actions @@ -1528,7 +1493,7 @@ def get_agent_action_spec(self, i): else: raise ValueError(f"Index {i} undefined for index 3") - return CompositeSpec({"action": ret}) + return Composite({"action": ret}) def _reset( self, @@ -1659,18 +1624,16 @@ def __init__(self, max_steps: int = 5, start_val: int = 0, **kwargs): ) def make_specs(self): - self.unbatched_observation_spec = CompositeSpec( - nested_1=CompositeSpec( - observation=BoundedTensorSpec( - low=0, high=200, shape=(self.nested_dim_1, 3) - ), + self.unbatched_observation_spec = Composite( + nested_1=Composite( + observation=Bounded(low=0, high=200, shape=(self.nested_dim_1, 3)), shape=(self.nested_dim_1,), ), - nested_2=CompositeSpec( - observation=UnboundedContinuousTensorSpec(shape=(self.nested_dim_2, 2)), + nested_2=Composite( + observation=Unbounded(shape=(self.nested_dim_2, 2)), shape=(self.nested_dim_2,), ), - observation=UnboundedContinuousTensorSpec( + observation=Unbounded( shape=( 10, 10, @@ -1679,51 +1642,51 @@ def make_specs(self): ), ) - self.unbatched_action_spec = CompositeSpec( - nested_1=CompositeSpec( - action=DiscreteTensorSpec(n=2, shape=(self.nested_dim_1,)), + self.unbatched_action_spec = Composite( + nested_1=Composite( + action=Categorical(n=2, shape=(self.nested_dim_1,)), shape=(self.nested_dim_1,), ), - nested_2=CompositeSpec( - azione=BoundedTensorSpec(low=0, high=100, shape=(self.nested_dim_2, 1)), + nested_2=Composite( + azione=Bounded(low=0, high=100, shape=(self.nested_dim_2, 1)), shape=(self.nested_dim_2,), ), - action=OneHotDiscreteTensorSpec(n=2), + action=OneHot(n=2), ) - self.unbatched_reward_spec = CompositeSpec( - nested_1=CompositeSpec( - gift=UnboundedContinuousTensorSpec(shape=(self.nested_dim_1, 1)), + self.unbatched_reward_spec = Composite( + nested_1=Composite( + gift=Unbounded(shape=(self.nested_dim_1, 1)), shape=(self.nested_dim_1,), ), - nested_2=CompositeSpec( - 
reward=UnboundedContinuousTensorSpec(shape=(self.nested_dim_2, 1)), + nested_2=Composite( + reward=Unbounded(shape=(self.nested_dim_2, 1)), shape=(self.nested_dim_2,), ), - reward=UnboundedContinuousTensorSpec(shape=(1,)), + reward=Unbounded(shape=(1,)), ) - self.unbatched_done_spec = CompositeSpec( - nested_1=CompositeSpec( - done=DiscreteTensorSpec( + self.unbatched_done_spec = Composite( + nested_1=Composite( + done=Categorical( n=2, shape=(self.nested_dim_1, 1), dtype=torch.bool, ), - terminated=DiscreteTensorSpec( + terminated=Categorical( n=2, shape=(self.nested_dim_1, 1), dtype=torch.bool, ), shape=(self.nested_dim_1,), ), - nested_2=CompositeSpec( - done=DiscreteTensorSpec( + nested_2=Composite( + done=Categorical( n=2, shape=(self.nested_dim_2, 1), dtype=torch.bool, ), - terminated=DiscreteTensorSpec( + terminated=Categorical( n=2, shape=(self.nested_dim_2, 1), dtype=torch.bool, @@ -1731,12 +1694,12 @@ def make_specs(self): shape=(self.nested_dim_2,), ), # done at the root always prevail - done=DiscreteTensorSpec( + done=Categorical( n=2, shape=(1,), dtype=torch.bool, ), - terminated=DiscreteTensorSpec( + terminated=Categorical( n=2, shape=(1,), dtype=torch.bool, @@ -1829,15 +1792,15 @@ def _set_seed(self, seed: Optional[int]): class EnvWithMetadata(EnvBase): def __init__(self): super().__init__() - self.observation_spec = CompositeSpec( - tensor=UnboundedContinuousTensorSpec(3), - non_tensor=NonTensorSpec(shape=()), + self.observation_spec = Composite( + tensor=Unbounded(3), + non_tensor=NonTensor(shape=()), ) - self.state_spec = CompositeSpec( - non_tensor=NonTensorSpec(shape=()), + self.state_spec = Composite( + non_tensor=NonTensor(shape=()), ) - self.reward_spec = UnboundedContinuousTensorSpec(1) - self.action_spec = UnboundedContinuousTensorSpec(1) + self.reward_spec = Unbounded(1) + self.action_spec = Unbounded(1) def _reset(self, tensordict): data = self.observation_spec.zero() @@ -1935,16 +1898,16 @@ def _reset(self, tensordict=None): class EnvWithDynamicSpec(EnvBase): def __init__(self, max_count=5): super().__init__(batch_size=()) - self.observation_spec = CompositeSpec( - observation=UnboundedContinuousTensorSpec(shape=(3, -1, 2)), + self.observation_spec = Composite( + observation=Unbounded(shape=(3, -1, 2)), ) - self.action_spec = BoundedTensorSpec(low=-1, high=1, shape=(2,)) - self.full_done_spec = CompositeSpec( - done=BinaryDiscreteTensorSpec(1, shape=(1,), dtype=torch.bool), - terminated=BinaryDiscreteTensorSpec(1, shape=(1,), dtype=torch.bool), - truncated=BinaryDiscreteTensorSpec(1, shape=(1,), dtype=torch.bool), + self.action_spec = Bounded(low=-1, high=1, shape=(2,)) + self.full_done_spec = Composite( + done=Binary(1, shape=(1,), dtype=torch.bool), + terminated=Binary(1, shape=(1,), dtype=torch.bool), + truncated=Binary(1, shape=(1,), dtype=torch.bool), ) - self.reward_spec = UnboundedContinuousTensorSpec((1,), dtype=torch.float) + self.reward_spec = Unbounded((1,), dtype=torch.float) self.count = 0 self.max_count = max_count diff --git a/test/test_actors.py b/test/test_actors.py index 2d160e31bba..b81f322b708 100644 --- a/test/test_actors.py +++ b/test/test_actors.py @@ -14,14 +14,7 @@ from tensordict.nn.distributions import NormalParamExtractor from torch import distributions as dist, nn -from torchrl.data import ( - BinaryDiscreteTensorSpec, - BoundedTensorSpec, - CompositeSpec, - DiscreteTensorSpec, - MultiOneHotDiscreteTensorSpec, - OneHotDiscreteTensorSpec, -) +from torchrl.data import Binary, Bounded, Categorical, Composite, MultiOneHot, OneHot from 
torchrl.data.rlhf.dataset import _has_transformers from torchrl.modules import MLP, SafeModule, TanhDelta, TanhNormal from torchrl.modules.tensordict_module.actors import ( @@ -50,9 +43,7 @@ ) def test_probabilistic_actor_nested_delta(log_prob_key, nested_dim=5, n_actions=3): env = NestedCountingEnv(nested_dim=nested_dim) - action_spec = BoundedTensorSpec( - shape=torch.Size((nested_dim, n_actions)), high=1, low=-1 - ) + action_spec = Bounded(shape=torch.Size((nested_dim, n_actions)), high=1, low=-1) policy_module = TensorDictModule( nn.Linear(1, 1), in_keys=[("data", "states")], out_keys=[("data", "param")] ) @@ -63,8 +54,8 @@ def test_probabilistic_actor_nested_delta(log_prob_key, nested_dim=5, n_actions= out_keys=[("data", "action")], distribution_class=TanhDelta, distribution_kwargs={ - "min": action_spec.space.low, - "max": action_spec.space.high, + "low": action_spec.space.low, + "high": action_spec.space.high, }, log_prob_key=log_prob_key, return_log_prob=True, @@ -86,8 +77,8 @@ def test_probabilistic_actor_nested_delta(log_prob_key, nested_dim=5, n_actions= out_keys=[("data", "action")], distribution_class=TanhDelta, distribution_kwargs={ - "min": action_spec.space.low, - "max": action_spec.space.high, + "low": action_spec.space.low, + "high": action_spec.space.high, }, log_prob_key=log_prob_key, return_log_prob=True, @@ -111,9 +102,7 @@ def test_probabilistic_actor_nested_delta(log_prob_key, nested_dim=5, n_actions= ) def test_probabilistic_actor_nested_normal(log_prob_key, nested_dim=5, n_actions=3): env = NestedCountingEnv(nested_dim=nested_dim) - action_spec = BoundedTensorSpec( - shape=torch.Size((nested_dim, n_actions)), high=1, low=-1 - ) + action_spec = Bounded(shape=torch.Size((nested_dim, n_actions)), high=1, low=-1) actor_net = nn.Sequential( nn.Linear(1, 2), NormalParamExtractor(), @@ -181,7 +170,7 @@ def test_distributional_qvalue_hook_wrong_action_space(self): DistributionalQValueHook(action_space="wrong_value", support=None) def test_distributional_qvalue_hook_conflicting_spec(self): - spec = OneHotDiscreteTensorSpec(3) + spec = OneHot(3) _process_action_space_spec("one-hot", spec) _process_action_space_spec("one_hot", spec) _process_action_space_spec("one_hot", None) @@ -190,19 +179,19 @@ def test_distributional_qvalue_hook_conflicting_spec(self): ValueError, match="The action spec and the action space do not match" ): _process_action_space_spec("multi-one-hot", spec) - spec = MultiOneHotDiscreteTensorSpec([3, 3]) + spec = MultiOneHot([3, 3]) _process_action_space_spec("multi-one-hot", spec) _process_action_space_spec(spec, spec) with pytest.raises( ValueError, match="Passing an action_space as a TensorSpec and a spec" ): - _process_action_space_spec(OneHotDiscreteTensorSpec(3), spec) + _process_action_space_spec(OneHot(3), spec) with pytest.raises( - ValueError, match="action_space cannot be of type CompositeSpec" + ValueError, match="action_space cannot be of type Composite" ): - _process_action_space_spec(CompositeSpec(), spec) + _process_action_space_spec(Composite(), spec) with pytest.raises(KeyError, match="action could not be found in the spec"): - _process_action_space_spec(None, CompositeSpec()) + _process_action_space_spec(None, Composite()) with pytest.raises( ValueError, match="Neither action_space nor spec was defined" ): @@ -248,10 +237,10 @@ def test_nested_keys(self, nested_action, batch_size, nested_dim=5): ValueError, match="Passing an action_space as a TensorSpec and a spec isn't allowed, unless they match.", ): - 
_process_action_space_spec(BinaryDiscreteTensorSpec(n=1), action_spec) - _process_action_space_spec(BinaryDiscreteTensorSpec(n=1), leaf_action_spec) + _process_action_space_spec(Binary(n=1), action_spec) + _process_action_space_spec(Binary(n=1), leaf_action_spec) with pytest.raises( - ValueError, match="action_space cannot be of type CompositeSpec" + ValueError, match="action_space cannot be of type Composite" ): _process_action_space_spec(action_spec, None) @@ -652,7 +641,7 @@ def test_value_based_policy(device): torch.manual_seed(0) obs_dim = 4 action_dim = 5 - action_spec = OneHotDiscreteTensorSpec(action_dim) + action_spec = OneHot(action_dim) def make_net(): net = MLP(in_features=obs_dim, out_features=action_dim, depth=2, device=device) @@ -681,9 +670,7 @@ def make_net(): assert (action.sum(-1) == 1).all() -@pytest.mark.parametrize( - "spec", [None, OneHotDiscreteTensorSpec(3), MultiOneHotDiscreteTensorSpec([3, 2])] -) +@pytest.mark.parametrize("spec", [None, OneHot(3), MultiOneHot([3, 2])]) @pytest.mark.parametrize( "action_space", [None, "one-hot", "one_hot", "mult-one-hot", "mult_one_hot"] ) @@ -706,12 +693,9 @@ def test_qvalactor_construct( QValueActor(**kwargs) return if ( - type(spec) is MultiOneHotDiscreteTensorSpec + type(spec) is MultiOneHot and action_space not in ("mult-one-hot", "mult_one_hot", None) - ) or ( - type(spec) is OneHotDiscreteTensorSpec - and action_space not in ("one-hot", "one_hot", None) - ): + ) or (type(spec) is OneHot and action_space not in ("one-hot", "one_hot", None)): with pytest.raises( ValueError, match="The action spec and the action space do not match" ): @@ -725,7 +709,7 @@ def test_value_based_policy_categorical(device): torch.manual_seed(0) obs_dim = 4 action_dim = 5 - action_spec = DiscreteTensorSpec(action_dim) + action_spec = Categorical(action_dim) def make_net(): net = MLP(in_features=obs_dim, out_features=action_dim, depth=2, device=device) diff --git a/test/test_collector.py b/test/test_collector.py index 7d7208aead0..9e6ccd79408 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -50,7 +50,12 @@ TensorDict, TensorDictBase, ) -from tensordict.nn import TensorDictModule, TensorDictModuleBase, TensorDictSequential +from tensordict.nn import ( + CudaGraphModule, + TensorDictModule, + TensorDictModuleBase, + TensorDictSequential, +) from torch import nn from torchrl._utils import ( @@ -68,13 +73,15 @@ ) from torchrl.collectors.utils import split_trajectories from torchrl.data import ( - CompositeSpec, + Composite, + LazyMemmapStorage, LazyTensorStorage, - NonTensorSpec, + NonTensor, ReplayBuffer, TensorSpec, - UnboundedContinuousTensorSpec, + Unbounded, ) +from torchrl.data.utils import CloudpickleWrapper from torchrl.envs import ( EnvBase, EnvCreator, @@ -114,19 +121,6 @@ def forward(self, observation): return self.linear(observation) -class TensorDictCompatiblePolicy(nn.Module): - def __init__(self, out_features: int): - super().__init__() - self.in_keys = ["observation"] - self.out_keys = ["action"] - self.linear = nn.LazyLinear(out_features) - - def forward(self, tensordict): - return tensordict.set( - self.out_keys[0], self.linear(tensordict.get(self.in_keys[0])) - ) - - class UnwrappablePolicy(nn.Module): def __init__(self, out_features: int): super().__init__() @@ -210,22 +204,16 @@ class DeviceLessEnv(EnvBase): def __init__(self, default_device): self.default_device = default_device super().__init__(device=None) - self.observation_spec = CompositeSpec( - observation=UnboundedContinuousTensorSpec((), 
device=default_device) + self.observation_spec = Composite( + observation=Unbounded((), device=default_device) ) - self.reward_spec = UnboundedContinuousTensorSpec(1, device=default_device) - self.full_done_spec = CompositeSpec( - done=UnboundedContinuousTensorSpec( - 1, dtype=torch.bool, device=self.default_device - ), - truncated=UnboundedContinuousTensorSpec( - 1, dtype=torch.bool, device=self.default_device - ), - terminated=UnboundedContinuousTensorSpec( - 1, dtype=torch.bool, device=self.default_device - ), + self.reward_spec = Unbounded(1, device=default_device) + self.full_done_spec = Composite( + done=Unbounded(1, dtype=torch.bool, device=self.default_device), + truncated=Unbounded(1, dtype=torch.bool, device=self.default_device), + terminated=Unbounded(1, dtype=torch.bool, device=self.default_device), ) - self.action_spec = UnboundedContinuousTensorSpec((), device=None) + self.action_spec = Unbounded((), device=None) assert self.device is None assert self.full_observation_spec is not None assert self.full_done_spec is not None @@ -268,29 +256,17 @@ class EnvWithDevice(EnvBase): def __init__(self, default_device): self.default_device = default_device super().__init__(device=self.default_device) - self.observation_spec = CompositeSpec( - observation=UnboundedContinuousTensorSpec( - (), device=self.default_device - ) - ) - self.reward_spec = UnboundedContinuousTensorSpec( - 1, device=self.default_device + self.observation_spec = Composite( + observation=Unbounded((), device=self.default_device) ) - self.full_done_spec = CompositeSpec( - done=UnboundedContinuousTensorSpec( - 1, dtype=torch.bool, device=self.default_device - ), - truncated=UnboundedContinuousTensorSpec( - 1, dtype=torch.bool, device=self.default_device - ), - terminated=UnboundedContinuousTensorSpec( - 1, dtype=torch.bool, device=self.default_device - ), + self.reward_spec = Unbounded(1, device=self.default_device) + self.full_done_spec = Composite( + done=Unbounded(1, dtype=torch.bool, device=self.default_device), + truncated=Unbounded(1, dtype=torch.bool, device=self.default_device), + terminated=Unbounded(1, dtype=torch.bool, device=self.default_device), device=self.default_device, ) - self.action_spec = UnboundedContinuousTensorSpec( - (), device=self.default_device - ) + self.action_spec = Unbounded((), device=self.default_device) assert self.device == _make_ordinal_device( torch.device(self.default_device) ) @@ -1295,7 +1271,7 @@ def make_env(): policy, copier, OrnsteinUhlenbeckProcessModule( - spec=CompositeSpec({key: None for key in policy.out_keys}) + spec=Composite({key: None for key in policy.out_keys}) ), ) @@ -1368,12 +1344,12 @@ def test_collector_output_keys( ], } if explicit_spec: - hidden_spec = UnboundedContinuousTensorSpec((1, hidden_size)) - policy_kwargs["spec"] = CompositeSpec( - action=UnboundedContinuousTensorSpec(), + hidden_spec = Unbounded((1, hidden_size)) + policy_kwargs["spec"] = Composite( + action=Unbounded(), hidden1=hidden_spec, hidden2=hidden_spec, - next=CompositeSpec(hidden1=hidden_spec, hidden2=hidden_spec), + next=Composite(hidden1=hidden_spec, hidden2=hidden_spec), ) policy = SafeModule(**policy_kwargs) @@ -1614,8 +1590,8 @@ def test_auto_wrap_error(self, collector_class, env_maker): policy = UnwrappablePolicy(out_features=env_maker().action_spec.shape[-1]) with pytest.raises( TypeError, - match=(r"Arguments to policy.forward are incompatible with entries in"), - ) if collector_class is SyncDataCollector else pytest.raises(EOFError): + match=("Arguments to policy.forward are 
incompatible with entries in"), + ): collector_class( **self._create_collector_kwargs(env_maker, collector_class, policy) ) @@ -1844,10 +1820,15 @@ def test_set_truncated(collector_cls): NestedCountingEnv(), InitTracker() ).add_truncated_keys() env = env_fn() - policy = env.rand_action + policy = CloudpickleWrapper(env.rand_action) if collector_cls == SyncDataCollector: collector = collector_cls( - env, policy=policy, frames_per_batch=20, total_frames=-1, set_truncated=True + env, + policy=policy, + frames_per_batch=20, + total_frames=-1, + set_truncated=True, + trust_policy=True, ) else: collector = collector_cls( @@ -1857,6 +1838,7 @@ def test_set_truncated(collector_cls): total_frames=-1, cat_results="stack", set_truncated=True, + trust_policy=True, ) try: for data in collector: @@ -2164,21 +2146,18 @@ def test_multi_collector_consistency( assert_allclose_td(c2.unsqueeze(0), d2) -@pytest.mark.skipif(not torch.cuda.device_count(), reason="No casting if no cuda") +@pytest.mark.skipif( + not torch.cuda.is_available() and not torch.mps.is_available(), + reason="No casting if no cuda", +) class TestUpdateParams: class DummyEnv(EnvBase): def __init__(self, device, batch_size=[]): # noqa: B006 super().__init__(batch_size=batch_size, device=device) self.state = torch.zeros(self.batch_size, device=device) - self.observation_spec = CompositeSpec( - state=UnboundedContinuousTensorSpec(shape=(), device=device) - ) - self.action_spec = UnboundedContinuousTensorSpec( - shape=batch_size, device=device - ) - self.reward_spec = UnboundedContinuousTensorSpec( - shape=(*batch_size, 1), device=device - ) + self.observation_spec = Composite(state=Unbounded(shape=(), device=device)) + self.action_spec = Unbounded(shape=batch_size, device=device) + self.reward_spec = Unbounded(shape=(*batch_size, 1), device=device) def _step( self, @@ -2228,8 +2207,8 @@ def forward(self, td): @pytest.mark.parametrize( "policy_device,env_device", [ - ["cpu", "cuda"], - ["cuda", "cpu"], + ["cpu", get_default_devices()[0]], + [get_default_devices()[0], "cpu"], # ["cpu", "cuda:0"], # 1226: faster execution # ["cuda:0", "cpu"], # ["cuda", "cuda:0"], @@ -2253,9 +2232,7 @@ def test_param_sync(self, give_weights, collector, policy_device, env_device): policy.param.data += 1 policy.buf.data += 2 if give_weights: - d = dict(policy.named_parameters()) - d.update(policy.named_buffers()) - p_w = TensorDict(d, []) + p_w = TensorDict.from_module(policy) else: p_w = None col.update_policy_weights_(p_w) @@ -2609,8 +2586,15 @@ def test_unique_traj_sync(self, cat_results): buffer.extend(d) assert c._use_buffers traj_ids = buffer[:].get(("collector", "traj_ids")) - # check that we have as many trajs as expected (no skip) - assert traj_ids.unique().numel() == traj_ids.max() + 1 + # Ideally, we'd like that (sorted_traj.values == sorted_traj.indices).all() + # but in practice, one env can reach the end of the rollout and do a reset + # (which we don't want to prevent) and increment the global traj count, + # when the others have not finished yet. In that case, this traj number will never + # appear. 
+ # sorted_traj = traj_ids.unique().sort() + # assert (sorted_traj.values == sorted_traj.indices).all() + # assert traj_ids.unique().numel() == traj_ids.max() + 1 + # check that trajs are not overlapping if stack_results: sets = [ @@ -2670,6 +2654,89 @@ def test_dynamic_multiasync_collector(self): assert data.names[-1] == "time" +class TestCompile: + @pytest.mark.parametrize( + "collector_cls", + # Clearing compiled policies causes segfault on machines with cuda + [SyncDataCollector, MultiaSyncDataCollector, MultiSyncDataCollector] + if not torch.cuda.is_available() + else [SyncDataCollector], + ) + @pytest.mark.parametrize("compile_policy", [True, {}, {"mode": "default"}]) + @pytest.mark.parametrize( + "device", [torch.device("cuda:0" if torch.cuda.is_available() else "cpu")] + ) + def test_compiled_policy(self, collector_cls, compile_policy, device): + policy = TensorDictModule( + nn.Linear(7, 7, device=device), in_keys=["observation"], out_keys=["action"] + ) + make_env = functools.partial(ContinuousActionVecMockEnv, device=device) + if collector_cls is SyncDataCollector: + torch._dynamo.reset_code_caches() + collector = SyncDataCollector( + make_env(), + policy, + frames_per_batch=30, + total_frames=120, + compile_policy=compile_policy, + ) + assert collector.compiled_policy + else: + collector = collector_cls( + [make_env] * 2, + policy, + frames_per_batch=30, + total_frames=120, + compile_policy=compile_policy, + ) + assert collector.compiled_policy + try: + for data in collector: + assert data is not None + finally: + collector.shutdown() + del collector + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available") + @pytest.mark.parametrize( + "collector_cls", + [SyncDataCollector], + ) + @pytest.mark.parametrize("cudagraph_policy", [True, {}, {"warmup": 10}]) + def test_cudagraph_policy(self, collector_cls, cudagraph_policy): + device = torch.device("cuda:0") + policy = TensorDictModule( + nn.Linear(7, 7, device=device), in_keys=["observation"], out_keys=["action"] + ) + make_env = functools.partial(ContinuousActionVecMockEnv, device=device) + if collector_cls is SyncDataCollector: + collector = SyncDataCollector( + make_env(), + policy, + frames_per_batch=30, + total_frames=120, + cudagraph_policy=cudagraph_policy, + device=device, + ) + assert collector.cudagraphed_policy + else: + collector = collector_cls( + [make_env] * 2, + policy, + frames_per_batch=30, + total_frames=120, + cudagraph_policy=cudagraph_policy, + device=device, + ) + assert collector.cudagraphed_policy + try: + for data in collector: + assert data is not None + finally: + collector.shutdown() + del collector + + @pytest.mark.skipif(not _has_gym, reason="gym required for this test") class TestCollectorsNonTensor: class AddNontTensorData(Transform): @@ -2685,7 +2752,7 @@ def _reset( def transform_observation_spec( self, observation_spec: TensorSpec ) -> TensorSpec: - observation_spec["nt"] = NonTensorSpec(shape=()) + observation_spec["nt"] = NonTensor(shape=()) return observation_spec @classmethod @@ -2775,6 +2842,289 @@ def test_async(self, use_buffers): del collector +class TestCollectorRB: + @pytest.mark.skipif(not _has_gym, reason="requires gym.") + def test_collector_rb_sync(self): + env = SerialEnv(8, lambda cp=CARTPOLE_VERSIONED(): GymEnv(cp)) + env.set_seed(0) + rb = ReplayBuffer(storage=LazyTensorStorage(256, ndim=2), batch_size=5) + collector = SyncDataCollector( + env, + RandomPolicy(env.action_spec), + replay_buffer=rb, + total_frames=256, + frames_per_batch=16, + ) + 
torch.manual_seed(0) + + for c in collector: + assert c is None + rb.sample() + rbdata0 = rb[:].clone() + collector.shutdown() + if not env.is_closed: + env.close() + del collector, env + + env = SerialEnv(8, lambda cp=CARTPOLE_VERSIONED(): GymEnv(cp)) + env.set_seed(0) + rb = ReplayBuffer(storage=LazyTensorStorage(256, ndim=2), batch_size=5) + collector = SyncDataCollector( + env, RandomPolicy(env.action_spec), total_frames=256, frames_per_batch=16 + ) + torch.manual_seed(0) + + for i, c in enumerate(collector): + rb.extend(c) + torch.testing.assert_close( + rbdata0[:, : (i + 1) * 2]["observation"], rb[:]["observation"] + ) + assert c is not None + rb.sample() + + rbdata1 = rb[:].clone() + collector.shutdown() + if not env.is_closed: + env.close() + del collector, env + assert assert_allclose_td(rbdata0, rbdata1) + + @pytest.mark.skipif(not _has_gym, reason="requires gym.") + @pytest.mark.parametrize("replay_buffer_chunk", [False, True]) + @pytest.mark.parametrize("env_creator", [False, True]) + @pytest.mark.parametrize("storagetype", [LazyTensorStorage, LazyMemmapStorage]) + def test_collector_rb_multisync( + self, replay_buffer_chunk, env_creator, storagetype, tmpdir + ): + if not env_creator: + env = GymEnv(CARTPOLE_VERSIONED()).append_transform(StepCounter()) + env.set_seed(0) + action_spec = env.action_spec + env = lambda env=env: env + else: + env = EnvCreator( + lambda cp=CARTPOLE_VERSIONED(): GymEnv(cp).append_transform( + StepCounter() + ) + ) + action_spec = env.meta_data.specs["input_spec", "full_action_spec"] + + if storagetype == LazyMemmapStorage: + storagetype = functools.partial(LazyMemmapStorage, scratch_dir=tmpdir) + rb = ReplayBuffer(storage=storagetype(256), batch_size=5) + + collector = MultiSyncDataCollector( + [env, env], + RandomPolicy(action_spec), + replay_buffer=rb, + total_frames=256, + frames_per_batch=32, + replay_buffer_chunk=replay_buffer_chunk, + ) + torch.manual_seed(0) + pred_len = 0 + for c in collector: + pred_len += 32 + assert c is None + assert len(rb) == pred_len + collector.shutdown() + assert len(rb) == 256 + if not replay_buffer_chunk: + steps_counts = rb["step_count"].squeeze().split(16) + collector_ids = rb["collector", "traj_ids"].squeeze().split(16) + for step_count, ids in zip(steps_counts, collector_ids): + step_countdiff = step_count.diff() + idsdiff = ids.diff() + assert ( + (step_countdiff == 1) | (step_countdiff < 0) + ).all(), steps_counts + assert (idsdiff >= 0).all() + + @pytest.mark.skipif(not _has_gym, reason="requires gym.") + @pytest.mark.parametrize("replay_buffer_chunk", [False, True]) + @pytest.mark.parametrize("env_creator", [False, True]) + @pytest.mark.parametrize("storagetype", [LazyTensorStorage, LazyMemmapStorage]) + def test_collector_rb_multiasync( + self, replay_buffer_chunk, env_creator, storagetype, tmpdir + ): + if not env_creator: + env = GymEnv(CARTPOLE_VERSIONED()).append_transform(StepCounter()) + env.set_seed(0) + action_spec = env.action_spec + env = lambda env=env: env + else: + env = EnvCreator( + lambda cp=CARTPOLE_VERSIONED(): GymEnv(cp).append_transform( + StepCounter() + ) + ) + action_spec = env.meta_data.specs["input_spec", "full_action_spec"] + + if storagetype == LazyMemmapStorage: + storagetype = functools.partial(LazyMemmapStorage, scratch_dir=tmpdir) + rb = ReplayBuffer(storage=storagetype(256), batch_size=5) + + collector = MultiaSyncDataCollector( + [env, env], + RandomPolicy(action_spec), + replay_buffer=rb, + total_frames=256, + frames_per_batch=16, + replay_buffer_chunk=replay_buffer_chunk, 
+        )
+        torch.manual_seed(0)
+        pred_len = 0
+        for c in collector:
+            pred_len += 16
+            assert c is None
+            assert len(rb) >= pred_len
+        collector.shutdown()
+        assert len(rb) == 256
+        if not replay_buffer_chunk:
+            steps_counts = rb["step_count"].squeeze().split(16)
+            collector_ids = rb["collector", "traj_ids"].squeeze().split(16)
+            for step_count, ids in zip(steps_counts, collector_ids):
+                step_countdiff = step_count.diff()
+                idsdiff = ids.diff()
+                assert (
+                    (step_countdiff == 1) | (step_countdiff < 0)
+                ).all(), steps_counts
+                assert (idsdiff >= 0).all()
+
+
+def __deepcopy_error__(*args, **kwargs):
+    raise RuntimeError("deepcopy not allowed")
+
+
+@pytest.mark.filterwarnings("error")
+@pytest.mark.filterwarnings("ignore:Tensordict is registered in PyTree")
+@pytest.mark.parametrize(
+    "collector_type",
+    [
+        SyncDataCollector,
+        MultiaSyncDataCollector,
+        functools.partial(MultiSyncDataCollector, cat_results="stack"),
+    ],
+)
+def test_no_deepcopy_policy(collector_type):
+    # Tests that the collector instantiation does not make a deepcopy of the policy if not necessary.
+    #
+    # The only situation where we want to deepcopy the policy is when the policy_device differs from the actual device
+    # of the policy. This can only be checked if the policy is an nn.Module and any of the params is not on the desired
+    # device.
+    #
+    # If the policy is not an nn.Module or has no parameters, policy_device should warn (we don't know what to do but we
+    # can trust that the user knows what to do).
+
+    shared_device = torch.device("cpu")
+    if torch.cuda.is_available():
+        original_device = torch.device("cuda:0")
+    elif torch.mps.is_available():
+        original_device = torch.device("mps")
+    else:
+        pytest.skip("No GPU or MPS device")
+
+    def make_policy(device=None, nn_module=True):
+        if nn_module:
+            return TensorDictModule(
+                nn.Linear(7, 7, device=device),
+                in_keys=["observation"],
+                out_keys=["action"],
+            )
+        policy = make_policy(device=device)
+        return CloudpickleWrapper(policy)
+
+    def make_and_test_policy(
+        policy,
+        policy_device=None,
+        env_device=None,
+        device=None,
+        trust_policy=None,
+    ):
+        # make sure policy errors when copied
+        policy.__deepcopy__ = __deepcopy_error__
+        envs = ContinuousActionVecMockEnv(device=env_device)
+        if collector_type is not SyncDataCollector:
+            envs = [envs, envs]
+        c = collector_type(
+            envs,
+            policy=policy,
+            total_frames=1000,
+            frames_per_batch=10,
+            policy_device=policy_device,
+            env_device=env_device,
+            device=device,
+            trust_policy=trust_policy,
+        )
+        for _ in c:
+            return
+
+    # Simplest use cases
+    policy = make_policy()
+    make_and_test_policy(policy)
+
+    if collector_type is SyncDataCollector or original_device.type != "mps":
+        # mps cannot be shared
+        policy = make_policy(device=original_device)
+        make_and_test_policy(policy, env_device=original_device)
+
+    if collector_type is SyncDataCollector or original_device.type != "mps":
+        policy = make_policy(device=original_device)
+        make_and_test_policy(
+            policy, policy_device=original_device, env_device=original_device
+        )
+
+    # a deepcopy must occur when the policy_device differs from the actual device
+    with pytest.raises(RuntimeError, match="deepcopy not allowed"):
+        policy = make_policy(device=original_device)
+        make_and_test_policy(
+            policy, policy_device=shared_device, env_device=shared_device
+        )
+
+    # a deepcopy must occur when device differs from the actual device
+    with pytest.raises(RuntimeError, match="deepcopy not allowed"):
+        policy = make_policy(device=original_device)
+        make_and_test_policy(policy,
device=shared_device) + + # If the policy is not an nn.Module, we can't cast it to device, so we assume that the policy device + # is there to inform us + substitute_device = ( + original_device if torch.cuda.is_available() else torch.device("cpu") + ) + policy = make_policy(substitute_device, nn_module=False) + with pytest.warns(UserWarning): + make_and_test_policy( + policy, policy_device=substitute_device, env_device=substitute_device + ) + # For instance, if the env is on CPU, knowing the policy device helps with casting stuff on the right device + with pytest.warns(UserWarning): + make_and_test_policy( + policy, policy_device=substitute_device, env_device=shared_device + ) + make_and_test_policy( + policy, + policy_device=substitute_device, + env_device=shared_device, + trust_policy=True, + ) + + # If there is no policy_device, we assume that the user is doing things right too but don't warn + if collector_type is SyncDataCollector or original_device.type != "mps": + policy = make_policy(original_device, nn_module=False) + make_and_test_policy(policy, env_device=original_device) + + # If the policy is a CudaGraphModule, we know it's on cuda - no need to warn + if torch.cuda.is_available() and collector_type is SyncDataCollector: + policy = make_policy(original_device) + cudagraph_policy = CudaGraphModule(policy) + make_and_test_policy( + cudagraph_policy, + policy_device=original_device, + env_device=shared_device, + ) + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/test/test_cost.py b/test/test_cost.py index 871d9170aa1..3530fff825d 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -13,8 +13,8 @@ from packaging import version as pack_version from tensordict._C import unravel_keys - from tensordict.nn import ( + CompositeDistribution, InteractionType, ProbabilisticTensorDictModule, ProbabilisticTensorDictModule as ProbMod, @@ -25,7 +25,6 @@ TensorDictSequential as Seq, ) from torchrl.envs.utils import exploration_type, ExplorationType, set_exploration_type - from torchrl.modules.models import QMixer _has_functorch = True @@ -49,19 +48,12 @@ from mocking_classes import ContinuousActionConvMockEnv # from torchrl.data.postprocs.utils import expand_as_right -from tensordict import assert_allclose_td, TensorDict +from tensordict import assert_allclose_td, TensorDict, TensorDictBase from tensordict.nn import NormalParamExtractor, TensorDictModule from tensordict.nn.utils import Buffer from tensordict.utils import unravel_key from torch import autograd, nn -from torchrl.data import ( - BoundedTensorSpec, - CompositeSpec, - DiscreteTensorSpec, - MultiOneHotDiscreteTensorSpec, - OneHotDiscreteTensorSpec, - UnboundedContinuousTensorSpec, -) +from torchrl.data import Bounded, Categorical, Composite, MultiOneHot, OneHot, Unbounded from torchrl.data.postprocs.postprocs import MultiStep from torchrl.envs.model_based.dreamer import DreamerEnv from torchrl.envs.transforms import TensorDictPrimer, TransformedEnv @@ -105,6 +97,7 @@ DreamerModelLoss, DreamerValueLoss, DTLoss, + GAILLoss, IQLLoss, KLPENPPOLoss, OnlineDTLoss, @@ -153,9 +146,15 @@ _split_and_pad_sequence, ) +TORCH_VERSION = torch.__version__ # Capture all warnings -pytestmark = pytest.mark.filterwarnings("error") +pytestmark = [ + pytest.mark.filterwarnings("error"), + pytest.mark.filterwarnings( + "ignore:The current behavior of MLP when not providing `num_cells` is that the number" + ), +] class 
_check_td_steady: @@ -303,9 +302,9 @@ def _create_mock_actor( ): # Actor if action_spec_type == "one_hot": - action_spec = OneHotDiscreteTensorSpec(action_dim) + action_spec = OneHot(action_dim) elif action_spec_type == "categorical": - action_spec = DiscreteTensorSpec(action_dim) + action_spec = Categorical(action_dim) # elif action_spec_type == "nd_bounded": # action_spec = BoundedTensorSpec( # -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) @@ -317,7 +316,7 @@ def _create_mock_actor( if is_nn_module: return module.to(device) actor = QValueActor( - spec=CompositeSpec( + spec=Composite( { "action": action_spec, ( @@ -348,14 +347,12 @@ def _create_mock_distributional_actor( # Actor var_nums = None if action_spec_type == "mult_one_hot": - action_spec = MultiOneHotDiscreteTensorSpec( - [action_dim // 2, action_dim // 2] - ) + action_spec = MultiOneHot([action_dim // 2, action_dim // 2]) var_nums = action_spec.nvec elif action_spec_type == "one_hot": - action_spec = OneHotDiscreteTensorSpec(action_dim) + action_spec = OneHot(action_dim) elif action_spec_type == "categorical": - action_spec = DiscreteTensorSpec(action_dim) + action_spec = Categorical(action_dim) else: raise ValueError(f"Wrong {action_spec_type}") support = torch.linspace(vmin, vmax, atoms, dtype=torch.float) @@ -366,7 +363,7 @@ def _create_mock_distributional_actor( # if is_nn_module: # return module actor = DistributionalQValueActor( - spec=CompositeSpec( + spec=Composite( { "action": action_spec, action_value_key: None, @@ -775,7 +772,7 @@ def test_dqn_notensordict( ): n_obs = 3 n_action = 4 - action_spec = OneHotDiscreteTensorSpec(n_action) + action_spec = OneHot(n_action) module = nn.Linear(n_obs, n_action) # a simple value model actor = QValueActor( spec=action_spec, @@ -936,9 +933,9 @@ def _create_mock_actor( ): # Actor if action_spec_type == "one_hot": - action_spec = OneHotDiscreteTensorSpec(action_dim) + action_spec = OneHot(action_dim) elif action_spec_type == "categorical": - action_spec = DiscreteTensorSpec(action_dim) + action_spec = Categorical(action_dim) else: raise ValueError(f"Wrong {action_spec_type}") @@ -1385,7 +1382,7 @@ class TestDDPG(LossModuleTestBase): def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"): # Actor - action_spec = BoundedTensorSpec( + action_spec = Bounded( -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) module = nn.Linear(obs_dim, action_dim) @@ -2023,7 +2020,7 @@ def _create_mock_actor( dropout=0.0, ): # Actor - action_spec = BoundedTensorSpec( + action_spec = Bounded( -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) module = nn.Sequential( @@ -2375,7 +2372,7 @@ def test_td3_separate_losses( loss_fn = TD3Loss( actor, value, - action_spec=BoundedTensorSpec(shape=(n_act,), low=-1, high=1), + action_spec=Bounded(shape=(n_act,), low=-1, high=1), loss_function="l2", separate_losses=separate_losses, ) @@ -2729,7 +2726,7 @@ def _create_mock_actor( dropout=0.0, ): # Actor - action_spec = BoundedTensorSpec( + action_spec = Bounded( -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) module = nn.Sequential( @@ -3088,7 +3085,7 @@ def test_td3bc_separate_losses( loss_fn = TD3BCLoss( actor, value, - action_spec=BoundedTensorSpec(shape=(n_act,), low=-1, high=1), + action_spec=Bounded(shape=(n_act,), low=-1, high=1), loss_function="l2", separate_losses=separate_losses, ) @@ -3453,21 +3450,40 @@ def _create_mock_actor( device="cpu", observation_key="observation", action_key="action", + composite_action_dist=False, ): 
# Actor - action_spec = BoundedTensorSpec( + action_spec = Bounded( -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) + if composite_action_dist: + action_spec = Composite({action_key: {"action1": action_spec}}) net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor()) + if composite_action_dist: + distribution_class = functools.partial( + CompositeDistribution, + distribution_map={ + "action1": TanhNormal, + }, + aggregate_probabilities=True, + ) + module_out_keys = [ + ("params", "action1", "loc"), + ("params", "action1", "scale"), + ] + actor_in_keys = ["params"] + else: + distribution_class = TanhNormal + module_out_keys = actor_in_keys = ["loc", "scale"] module = TensorDictModule( - net, in_keys=[observation_key], out_keys=["loc", "scale"] + net, in_keys=[observation_key], out_keys=module_out_keys ) actor = ProbabilisticActor( module=module, - in_keys=["loc", "scale"], - spec=action_spec, - distribution_class=TanhNormal, + distribution_class=distribution_class, + in_keys=actor_in_keys, out_keys=[action_key], + spec=action_spec, ) return actor.to(device) @@ -3487,6 +3503,8 @@ def __init__(self): self.linear = nn.Linear(obs_dim + action_dim, 1) def forward(self, obs, act): + if isinstance(act, TensorDictBase): + act = act.get("action1") return self.linear(torch.cat([obs, act], -1)) module = ValueClass() @@ -3515,8 +3533,26 @@ def _create_mock_value( return value.to(device) def _create_mock_common_layer_setup( - self, n_obs=3, n_act=4, ncells=4, batch=2, n_hidden=2 + self, + n_obs=3, + n_act=4, + ncells=4, + batch=2, + n_hidden=2, + composite_action_dist=False, ): + class QValueClass(nn.Module): + def __init__(self): + super().__init__() + self.linear1 = nn.Linear(n_hidden + n_act, n_hidden) + self.relu = nn.ReLU() + self.linear2 = nn.Linear(n_hidden, 1) + + def forward(self, obs, act): + if isinstance(act, TensorDictBase): + act = act.get("action1") + return self.linear2(self.relu(self.linear1(torch.cat([obs, act], -1)))) + common = MLP( num_cells=ncells, in_features=n_obs, @@ -3529,17 +3565,13 @@ def _create_mock_common_layer_setup( depth=1, out_features=2 * n_act, ) - qvalue = MLP( - in_features=n_hidden + n_act, - num_cells=ncells, - depth=1, - out_features=1, - ) + qvalue = QValueClass() batch = [batch] + action = torch.randn(*batch, n_act) td = TensorDict( { "obs": torch.randn(*batch, n_obs), - "action": torch.randn(*batch, n_act), + "action": {"action1": action} if composite_action_dist else action, "done": torch.zeros(*batch, 1, dtype=torch.bool), "terminated": torch.zeros(*batch, 1, dtype=torch.bool), "next": { @@ -3552,14 +3584,30 @@ def _create_mock_common_layer_setup( batch, ) common = Mod(common, in_keys=["obs"], out_keys=["hidden"]) + if composite_action_dist: + distribution_class = functools.partial( + CompositeDistribution, + distribution_map={ + "action1": TanhNormal, + }, + aggregate_probabilities=True, + ) + module_out_keys = [ + ("params", "action1", "loc"), + ("params", "action1", "scale"), + ] + actor_in_keys = ["params"] + else: + distribution_class = TanhNormal + module_out_keys = actor_in_keys = ["loc", "scale"] actor = ProbSeq( common, Mod(actor_net, in_keys=["hidden"], out_keys=["param"]), - Mod(NormalParamExtractor(), in_keys=["param"], out_keys=["loc", "scale"]), + Mod(NormalParamExtractor(), in_keys=["param"], out_keys=module_out_keys), ProbMod( - in_keys=["loc", "scale"], + in_keys=actor_in_keys, out_keys=["action"], - distribution_class=TanhNormal, + distribution_class=distribution_class, ), ) qvalue_head = Mod( @@ 
-3585,6 +3633,7 @@ def _create_mock_data_sac( done_key="done", terminated_key="terminated", reward_key="reward", + composite_action_dist=False, ): # create a tensordict obs = torch.randn(batch, obs_dim, device=device) @@ -3606,14 +3655,21 @@ def _create_mock_data_sac( terminated_key: terminated, reward_key: reward, }, - action_key: action, + action_key: {"action1": action} if composite_action_dist else action, }, device=device, ) return td def _create_seq_mock_data_sac( - self, batch=8, T=4, obs_dim=3, action_dim=4, atoms=None, device="cpu" + self, + batch=8, + T=4, + obs_dim=3, + action_dim=4, + atoms=None, + device="cpu", + composite_action_dist=False, ): # create a tensordict total_obs = torch.randn(batch, T + 1, obs_dim, device=device) @@ -3629,6 +3685,7 @@ def _create_seq_mock_data_sac( done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) terminated = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) mask = torch.ones(batch, T, dtype=torch.bool, device=device) + action = action.masked_fill_(~mask.unsqueeze(-1), 0.0) td = TensorDict( batch_size=(batch, T), source={ @@ -3640,7 +3697,7 @@ def _create_seq_mock_data_sac( "reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0), }, "collector": {"mask": mask}, - "action": action.masked_fill_(~mask.unsqueeze(-1), 0.0), + "action": {"action1": action} if composite_action_dist else action, }, names=[None, "time"], device=device, @@ -3653,6 +3710,7 @@ def _create_seq_mock_data_sac( @pytest.mark.parametrize("num_qvalue", [1, 2, 4, 8]) @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None]) + @pytest.mark.parametrize("composite_action_dist", [True, False]) def test_sac( self, delay_value, @@ -3662,14 +3720,19 @@ def test_sac( device, version, td_est, + composite_action_dist, ): if (delay_actor or delay_qvalue) and not delay_value: pytest.skip("incompatible config") torch.manual_seed(self.seed) - td = self._create_mock_data_sac(device=device) + td = self._create_mock_data_sac( + device=device, composite_action_dist=composite_action_dist + ) - actor = self._create_mock_actor(device=device) + actor = self._create_mock_actor( + device=device, composite_action_dist=composite_action_dist + ) qvalue = self._create_mock_qvalue(device=device) if version == 1: value = self._create_mock_value(device=device) @@ -3819,6 +3882,7 @@ def test_sac( @pytest.mark.parametrize("delay_qvalue", (True, False)) @pytest.mark.parametrize("num_qvalue", [2]) @pytest.mark.parametrize("device", get_default_devices()) + @pytest.mark.parametrize("composite_action_dist", [True, False]) def test_sac_state_dict( self, delay_value, @@ -3827,13 +3891,16 @@ def test_sac_state_dict( num_qvalue, device, version, + composite_action_dist, ): if (delay_actor or delay_qvalue) and not delay_value: pytest.skip("incompatible config") torch.manual_seed(self.seed) - actor = self._create_mock_actor(device=device) + actor = self._create_mock_actor( + device=device, composite_action_dist=composite_action_dist + ) qvalue = self._create_mock_qvalue(device=device) if version == 1: value = self._create_mock_value(device=device) @@ -3869,20 +3936,24 @@ def test_sac_state_dict( @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("separate_losses", [False, True]) + @pytest.mark.parametrize("composite_action_dist", [True, False]) def test_sac_separate_losses( self, device, separate_losses, version, + composite_action_dist, n_act=4, ): torch.manual_seed(self.seed) - actor, qvalue, common, 
td = self._create_mock_common_layer_setup(n_act=n_act) + actor, qvalue, common, td = self._create_mock_common_layer_setup( + n_act=n_act, composite_action_dist=composite_action_dist + ) loss_fn = SACLoss( actor_network=actor, qvalue_network=qvalue, - action_spec=UnboundedContinuousTensorSpec(shape=(n_act,)), + action_spec=Unbounded(shape=(n_act,)), num_qvalue_nets=1, separate_losses=separate_losses, ) @@ -3963,6 +4034,7 @@ def test_sac_separate_losses( @pytest.mark.parametrize("delay_qvalue", (True, False)) @pytest.mark.parametrize("num_qvalue", [1, 2, 4, 8]) @pytest.mark.parametrize("device", get_default_devices()) + @pytest.mark.parametrize("composite_action_dist", [True, False]) def test_sac_batcher( self, n, @@ -3972,13 +4044,18 @@ def test_sac_batcher( num_qvalue, device, version, + composite_action_dist, ): if (delay_actor or delay_qvalue) and not delay_value: pytest.skip("incompatible config") torch.manual_seed(self.seed) - td = self._create_seq_mock_data_sac(device=device) + td = self._create_seq_mock_data_sac( + device=device, composite_action_dist=composite_action_dist + ) - actor = self._create_mock_actor(device=device) + actor = self._create_mock_actor( + device=device, composite_action_dist=composite_action_dist + ) qvalue = self._create_mock_qvalue(device=device) if version == 1: value = self._create_mock_value(device=device) @@ -4129,10 +4206,11 @@ def test_sac_batcher( @pytest.mark.parametrize( "td_est", [ValueEstimators.TD1, ValueEstimators.TD0, ValueEstimators.TDLambda] ) - def test_sac_tensordict_keys(self, td_est, version): - td = self._create_mock_data_sac() + @pytest.mark.parametrize("composite_action_dist", [True, False]) + def test_sac_tensordict_keys(self, td_est, version, composite_action_dist): + td = self._create_mock_data_sac(composite_action_dist=composite_action_dist) - actor = self._create_mock_actor() + actor = self._create_mock_actor(composite_action_dist=composite_action_dist) qvalue = self._create_mock_qvalue() if version == 1: value = self._create_mock_value() @@ -4152,7 +4230,7 @@ def test_sac_tensordict_keys(self, td_est, version): "value": "state_value", "state_action_value": "state_action_value", "action": "action", - "log_prob": "_log_prob", + "log_prob": "sample_log_prob", "reward": "reward", "done": "done", "terminated": "terminated", @@ -4286,14 +4364,14 @@ def test_state_dict(self, version): loss = SACLoss( actor_network=policy, qvalue_network=value, - action_spec=UnboundedContinuousTensorSpec(shape=(2,)), + action_spec=Unbounded(shape=(2,)), ) state = loss.state_dict() loss = SACLoss( actor_network=policy, qvalue_network=value, - action_spec=UnboundedContinuousTensorSpec(shape=(2,)), + action_spec=Unbounded(shape=(2,)), ) loss.load_state_dict(state) @@ -4301,7 +4379,7 @@ def test_state_dict(self, version): loss = SACLoss( actor_network=policy, qvalue_network=value, - action_spec=UnboundedContinuousTensorSpec(shape=(2,)), + action_spec=Unbounded(shape=(2,)), ) loss.target_entropy state = loss.state_dict() @@ -4309,20 +4387,25 @@ def test_state_dict(self, version): loss = SACLoss( actor_network=policy, qvalue_network=value, - action_spec=UnboundedContinuousTensorSpec(shape=(2,)), + action_spec=Unbounded(shape=(2,)), ) loss.load_state_dict(state) @pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"]) - def test_sac_reduction(self, reduction, version): + @pytest.mark.parametrize("composite_action_dist", [True, False]) + def test_sac_reduction(self, reduction, version, composite_action_dist): torch.manual_seed(self.seed) device = ( 
torch.device("cpu") if torch.cuda.device_count() == 0 else torch.device("cuda") ) - td = self._create_mock_data_sac(device=device) - actor = self._create_mock_actor(device=device) + td = self._create_mock_data_sac( + device=device, composite_action_dist=composite_action_dist + ) + actor = self._create_mock_actor( + device=device, composite_action_dist=composite_action_dist + ) qvalue = self._create_mock_qvalue(device=device) if version == 1: value = self._create_mock_value(device=device) @@ -4367,7 +4450,7 @@ def _create_mock_actor( action_key="action", ): # Actor - action_spec = OneHotDiscreteTensorSpec(action_dim) + action_spec = OneHot(action_dim) net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor()) module = TensorDictModule(net, in_keys=[observation_key], out_keys=["logits"]) actor = ProbabilisticActor( @@ -4953,7 +5036,7 @@ def _create_mock_actor( action_key="action", ): # Actor - action_spec = BoundedTensorSpec( + action_spec = Bounded( -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor()) @@ -5269,7 +5352,7 @@ def test_crossq_separate_losses( loss_fn = CrossQLoss( actor_network=actor, qvalue_network=qvalue, - action_spec=UnboundedContinuousTensorSpec(shape=(n_act,)), + action_spec=Unbounded(shape=(n_act,)), num_qvalue_nets=1, separate_losses=separate_losses, ) @@ -5574,14 +5657,14 @@ def test_state_dict( loss = CrossQLoss( actor_network=policy, qvalue_network=value, - action_spec=UnboundedContinuousTensorSpec(shape=(2,)), + action_spec=Unbounded(shape=(2,)), ) state = loss.state_dict() loss = CrossQLoss( actor_network=policy, qvalue_network=value, - action_spec=UnboundedContinuousTensorSpec(shape=(2,)), + action_spec=Unbounded(shape=(2,)), ) loss.load_state_dict(state) @@ -5589,7 +5672,7 @@ def test_state_dict( loss = CrossQLoss( actor_network=policy, qvalue_network=value, - action_spec=UnboundedContinuousTensorSpec(shape=(2,)), + action_spec=Unbounded(shape=(2,)), ) loss.target_entropy state = loss.state_dict() @@ -5597,7 +5680,7 @@ def test_state_dict( loss = CrossQLoss( actor_network=policy, qvalue_network=value, - action_spec=UnboundedContinuousTensorSpec(shape=(2,)), + action_spec=Unbounded(shape=(2,)), ) loss.load_state_dict(state) @@ -5648,7 +5731,7 @@ def _create_mock_actor( action_key="action", ): # Actor - action_spec = BoundedTensorSpec( + action_spec = Bounded( -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor()) @@ -6593,7 +6676,7 @@ class TestCQL(LossModuleTestBase): def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"): # Actor - action_spec = BoundedTensorSpec( + action_spec = Bounded( -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor()) @@ -7156,9 +7239,9 @@ def _create_mock_actor( ): # Actor if action_spec_type == "one_hot": - action_spec = OneHotDiscreteTensorSpec(action_dim) + action_spec = OneHot(action_dim) elif action_spec_type == "categorical": - action_spec = DiscreteTensorSpec(action_dim) + action_spec = Categorical(action_dim) else: raise ValueError(f"Wrong action spec type: {action_spec_type}") @@ -7166,7 +7249,7 @@ def _create_mock_actor( if is_nn_module: return module.to(device) actor = QValueActor( - spec=CompositeSpec( + spec=Composite( { "action": action_spec, ( @@ -7476,7 +7559,7 @@ def test_dcql_notensordict( ): n_obs = 3 
n_action = 4 - action_spec = OneHotDiscreteTensorSpec(n_action) + action_spec = OneHot(n_action) module = nn.Linear(n_obs, n_action) # a simple value model actor = QValueActor( spec=action_spec, @@ -7547,21 +7630,46 @@ def _create_mock_actor( obs_dim=3, action_dim=4, device="cpu", + action_key="action", observation_key="observation", sample_log_prob_key="sample_log_prob", + composite_action_dist=False, ): # Actor - action_spec = BoundedTensorSpec( + action_spec = Bounded( -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) + if composite_action_dist: + action_spec = Composite({action_key: {"action1": action_spec}}) net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor()) + if composite_action_dist: + distribution_class = functools.partial( + CompositeDistribution, + distribution_map={ + "action1": TanhNormal, + }, + name_map={ + "action1": (action_key, "action1"), + }, + log_prob_key=sample_log_prob_key, + aggregate_probabilities=True, + ) + module_out_keys = [ + ("params", "action1", "loc"), + ("params", "action1", "scale"), + ] + actor_in_keys = ["params"] + else: + distribution_class = TanhNormal + module_out_keys = actor_in_keys = ["loc", "scale"] module = TensorDictModule( - net, in_keys=[observation_key], out_keys=["loc", "scale"] + net, in_keys=[observation_key], out_keys=module_out_keys ) actor = ProbabilisticActor( module=module, - distribution_class=TanhNormal, - in_keys=["loc", "scale"], + distribution_class=distribution_class, + in_keys=actor_in_keys, + out_keys=[action_key], spec=action_spec, return_log_prob=True, log_prob_key=sample_log_prob_key, @@ -7585,22 +7693,52 @@ def _create_mock_value( ) return value.to(device) - def _create_mock_actor_value(self, batch=2, obs_dim=3, action_dim=4, device="cpu"): + def _create_mock_actor_value( + self, + batch=2, + obs_dim=3, + action_dim=4, + device="cpu", + composite_action_dist=False, + sample_log_prob_key="sample_log_prob", + ): # Actor - action_spec = BoundedTensorSpec( + action_spec = Bounded( -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) + if composite_action_dist: + action_spec = Composite({"action": {"action1": action_spec}}) base_layer = nn.Linear(obs_dim, 5) net = nn.Sequential( base_layer, nn.Linear(5, 2 * action_dim), NormalParamExtractor() ) + if composite_action_dist: + distribution_class = functools.partial( + CompositeDistribution, + distribution_map={ + "action1": TanhNormal, + }, + name_map={ + "action1": ("action", "action1"), + }, + log_prob_key=sample_log_prob_key, + aggregate_probabilities=True, + ) + module_out_keys = [ + ("params", "action1", "loc"), + ("params", "action1", "scale"), + ] + actor_in_keys = ["params"] + else: + distribution_class = TanhNormal + module_out_keys = actor_in_keys = ["loc", "scale"] module = TensorDictModule( - net, in_keys=["observation"], out_keys=["loc", "scale"] + net, in_keys=["observation"], out_keys=module_out_keys ) actor = ProbabilisticActor( module=module, - distribution_class=TanhNormal, - in_keys=["loc", "scale"], + distribution_class=distribution_class, + in_keys=actor_in_keys, spec=action_spec, return_log_prob=True, ) @@ -7612,22 +7750,50 @@ def _create_mock_actor_value(self, batch=2, obs_dim=3, action_dim=4, device="cpu return actor.to(device), value.to(device) def _create_mock_actor_value_shared( - self, batch=2, obs_dim=3, action_dim=4, device="cpu" + self, + batch=2, + obs_dim=3, + action_dim=4, + device="cpu", + composite_action_dist=False, + sample_log_prob_key="sample_log_prob", ): # Actor - action_spec = 
BoundedTensorSpec( + action_spec = Bounded( -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) + if composite_action_dist: + action_spec = Composite({"action": {"action1": action_spec}}) base_layer = nn.Linear(obs_dim, 5) common = TensorDictModule( base_layer, in_keys=["observation"], out_keys=["hidden"] ) net = nn.Sequential(nn.Linear(5, 2 * action_dim), NormalParamExtractor()) - module = TensorDictModule(net, in_keys=["hidden"], out_keys=["loc", "scale"]) + if composite_action_dist: + distribution_class = functools.partial( + CompositeDistribution, + distribution_map={ + "action1": TanhNormal, + }, + name_map={ + "action1": ("action", "action1"), + }, + log_prob_key=sample_log_prob_key, + aggregate_probabilities=True, + ) + module_out_keys = [ + ("params", "action1", "loc"), + ("params", "action1", "scale"), + ] + actor_in_keys = ["params"] + else: + distribution_class = TanhNormal + module_out_keys = actor_in_keys = ["loc", "scale"] + module = TensorDictModule(net, in_keys=["hidden"], out_keys=module_out_keys) actor_head = ProbabilisticActor( module=module, - distribution_class=TanhNormal, - in_keys=["loc", "scale"], + distribution_class=distribution_class, + in_keys=actor_in_keys, spec=action_spec, return_log_prob=True, ) @@ -7657,6 +7823,7 @@ def _create_mock_data_ppo( done_key="done", terminated_key="terminated", sample_log_prob_key="sample_log_prob", + composite_action_dist=False, ): # create a tensordict obs = torch.randn(batch, obs_dim, device=device) @@ -7682,13 +7849,17 @@ def _create_mock_data_ppo( terminated_key: terminated, reward_key: reward, }, - action_key: action, + action_key: {"action1": action} if composite_action_dist else action, sample_log_prob_key: torch.randn_like(action[..., 1]) / 10, - loc_key: loc, - scale_key: scale, }, device=device, ) + if composite_action_dist: + td[("params", "action1", loc_key)] = loc + td[("params", "action1", scale_key)] = scale + else: + td[loc_key] = loc + td[scale_key] = scale return td def _create_seq_mock_data_ppo( @@ -7701,6 +7872,7 @@ def _create_seq_mock_data_ppo( device="cpu", sample_log_prob_key="sample_log_prob", action_key="action", + composite_action_dist=False, ): # create a tensordict total_obs = torch.randn(batch, T + 1, obs_dim, device=device) @@ -7716,8 +7888,11 @@ def _create_seq_mock_data_ppo( done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) terminated = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) mask = torch.ones(batch, T, dtype=torch.bool, device=device) + action = action.masked_fill_(~mask.unsqueeze(-1), 0.0) params_mean = torch.randn_like(action) / 10 params_scale = torch.rand_like(action) / 10 + loc = params_mean.masked_fill_(~mask.unsqueeze(-1), 0.0) + scale = params_scale.masked_fill_(~mask.unsqueeze(-1), 0.0) td = TensorDict( batch_size=(batch, T), source={ @@ -7729,16 +7904,21 @@ def _create_seq_mock_data_ppo( "reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0), }, "collector": {"mask": mask}, - action_key: action.masked_fill_(~mask.unsqueeze(-1), 0.0), + action_key: {"action1": action} if composite_action_dist else action, sample_log_prob_key: ( torch.randn_like(action[..., 1]) / 10 ).masked_fill_(~mask, 0.0), - "loc": params_mean.masked_fill_(~mask.unsqueeze(-1), 0.0), - "scale": params_scale.masked_fill_(~mask.unsqueeze(-1), 0.0), }, device=device, names=[None, "time"], ) + if composite_action_dist: + td[("params", "action1", "loc")] = loc + td[("params", "action1", "scale")] = scale + else: + td["loc"] = loc + td["scale"] = scale + return td 
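The `composite_action_dist` fixtures above all wire the same pattern: a ProbabilisticActor whose `distribution_class` is a `functools.partial` of tensordict's CompositeDistribution, with the distribution parameters routed through a nested ("params", "action1", ...) sub-tensordict. A minimal standalone sketch of that wiring follows; it assumes the torchrl/tensordict imports used at the top of test_cost.py, and every key name in it ("params", "action1", "sample_log_prob") is taken from the test helpers above rather than from any API default.

import functools

import torch
from tensordict import TensorDict
from tensordict.nn import CompositeDistribution, NormalParamExtractor, TensorDictModule
from torch import nn
from torchrl.data import Bounded, Composite
from torchrl.modules import ProbabilisticActor, TanhNormal

obs_dim, action_dim = 3, 4

# The action spec nests the leaf "action1" under "action", as in _create_mock_actor.
action_spec = Composite(
    {
        "action": {
            "action1": Bounded(
                -torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
            )
        }
    }
)

# A single network emits loc/scale for the lone "action1" component; both outputs
# are written under the nested ("params", "action1", ...) keys.
net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor())
module = TensorDictModule(
    net,
    in_keys=["observation"],
    out_keys=[("params", "action1", "loc"), ("params", "action1", "scale")],
)

actor = ProbabilisticActor(
    module=module,
    in_keys=["params"],  # the whole "params" sub-tensordict parametrizes the dist
    out_keys=["action"],
    spec=action_spec,
    distribution_class=functools.partial(
        CompositeDistribution,
        distribution_map={"action1": TanhNormal},
        name_map={"action1": ("action", "action1")},
        log_prob_key="sample_log_prob",
        aggregate_probabilities=True,
    ),
    return_log_prob=True,
    log_prob_key="sample_log_prob",
)

td = actor(TensorDict({"observation": torch.randn(2, obs_dim)}, [2]))
assert ("action", "action1") in td.keys(True, True)  # nested action sample
assert "sample_log_prob" in td.keys()  # aggregated log-probability

Sampling through such an actor stores the draw under the nested ("action", "action1") entry plus a single aggregated "sample_log_prob", which is exactly the layout the PPO, A2C, and SAC losses consume in the tests below.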
@pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss)) @@ -7747,6 +7927,7 @@ def _create_seq_mock_data_ppo( @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None]) @pytest.mark.parametrize("functional", [True, False]) + @pytest.mark.parametrize("composite_action_dist", [True, False]) def test_ppo( self, loss_class, @@ -7755,11 +7936,16 @@ def test_ppo( advantage, td_est, functional, + composite_action_dist, ): torch.manual_seed(self.seed) - td = self._create_seq_mock_data_ppo(device=device) + td = self._create_seq_mock_data_ppo( + device=device, composite_action_dist=composite_action_dist + ) - actor = self._create_mock_actor(device=device) + actor = self._create_mock_actor( + device=device, composite_action_dist=composite_action_dist + ) value = self._create_mock_value(device=device) if advantage == "gae": advantage = GAE( @@ -7799,7 +7985,10 @@ def test_ppo( loss = loss_fn(td) if isinstance(loss_fn, KLPENPPOLoss): - kl = loss.pop("kl") + if composite_action_dist: + kl = loss.pop("kl_approx") + else: + kl = loss.pop("kl") assert (kl != 0).any() loss_critic = loss["loss_critic"] @@ -7836,10 +8025,15 @@ def test_ppo( @pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss)) @pytest.mark.parametrize("gradient_mode", (True,)) @pytest.mark.parametrize("device", get_default_devices()) - def test_ppo_state_dict(self, loss_class, device, gradient_mode): + @pytest.mark.parametrize("composite_action_dist", [True, False]) + def test_ppo_state_dict( + self, loss_class, device, gradient_mode, composite_action_dist + ): torch.manual_seed(self.seed) - actor = self._create_mock_actor(device=device) + actor = self._create_mock_actor( + device=device, composite_action_dist=composite_action_dist + ) value = self._create_mock_value(device=device) loss_fn = loss_class(actor, value, loss_critic_type="l2") sd = loss_fn.state_dict() @@ -7849,11 +8043,16 @@ def test_ppo_state_dict(self, loss_class, device, gradient_mode): @pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss)) @pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None)) @pytest.mark.parametrize("device", get_default_devices()) - def test_ppo_shared(self, loss_class, device, advantage): + @pytest.mark.parametrize("composite_action_dist", [True, False]) + def test_ppo_shared(self, loss_class, device, advantage, composite_action_dist): torch.manual_seed(self.seed) - td = self._create_seq_mock_data_ppo(device=device) + td = self._create_seq_mock_data_ppo( + device=device, composite_action_dist=composite_action_dist + ) - actor, value = self._create_mock_actor_value(device=device) + actor, value = self._create_mock_actor_value( + device=device, composite_action_dist=composite_action_dist + ) if advantage == "gae": advantage = GAE( gamma=0.9, @@ -7935,18 +8134,24 @@ def test_ppo_shared(self, loss_class, device, advantage): ) @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("separate_losses", [True, False]) + @pytest.mark.parametrize("composite_action_dist", [True, False]) def test_ppo_shared_seq( self, loss_class, device, advantage, separate_losses, + composite_action_dist, ): """Tests PPO with shared module with and without passing twice across the common module.""" torch.manual_seed(self.seed) - td = self._create_seq_mock_data_ppo(device=device) + td = self._create_seq_mock_data_ppo( + device=device, composite_action_dist=composite_action_dist + ) - model, actor, 
value = self._create_mock_actor_value_shared(device=device)
+        model, actor, value = self._create_mock_actor_value_shared(
+            device=device, composite_action_dist=composite_action_dist
+        )
         value2 = value[-1]  # prune the common module
         if advantage == "gae":
             advantage = GAE(
@@ -8004,8 +8209,20 @@ def test_ppo_shared_seq(
         grad2 = TensorDict(dict(model.named_parameters()), []).apply(
             lambda x: x.grad.clone()
         )
-        assert_allclose_td(loss, loss2)
-        assert_allclose_td(grad, grad2)
+        if composite_action_dist and loss_class is KLPENPPOLoss:
+            # KL computation for composite dist is based on randomly
+            # sampled data, thus will not be the same.
+            # Similarly, the objective loss depends on the KL, so it will
+            # not be the same either.
+            # Finally, gradients will be different too.
+            loss.pop("kl", None)
+            loss2.pop("kl", None)
+            loss.pop("loss_objective", None)
+            loss2.pop("loss_objective", None)
+            assert_allclose_td(loss, loss2)
+        else:
+            assert_allclose_td(loss, loss2)
+            assert_allclose_td(grad, grad2)
         model.zero_grad()

     @pytest.mark.skipif(
@@ -8015,11 +8232,18 @@ def test_ppo_shared_seq(
     @pytest.mark.parametrize("gradient_mode", (True, False))
     @pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None))
     @pytest.mark.parametrize("device", get_default_devices())
-    def test_ppo_diff(self, loss_class, device, gradient_mode, advantage):
+    @pytest.mark.parametrize("composite_action_dist", [True, False])
+    def test_ppo_diff(
+        self, loss_class, device, gradient_mode, advantage, composite_action_dist
+    ):
         torch.manual_seed(self.seed)
-        td = self._create_seq_mock_data_ppo(device=device)
+        td = self._create_seq_mock_data_ppo(
+            device=device, composite_action_dist=composite_action_dist
+        )

-        actor = self._create_mock_actor(device=device)
+        actor = self._create_mock_actor(
+            device=device, composite_action_dist=composite_action_dist
+        )
         value = self._create_mock_value(device=device)
         if advantage == "gae":
             advantage = GAE(
@@ -8108,8 +8332,9 @@ def zero_param(p):
             ValueEstimators.TDLambda,
         ],
     )
-    def test_ppo_tensordict_keys(self, loss_class, td_est):
-        actor = self._create_mock_actor()
+    @pytest.mark.parametrize("composite_action_dist", [True, False])
+    def test_ppo_tensordict_keys(self, loss_class, td_est, composite_action_dist):
+        actor = self._create_mock_actor(composite_action_dist=composite_action_dist)
         value = self._create_mock_value()

         loss_fn = loss_class(actor, value, loss_critic_type="l2")
@@ -8148,7 +8373,10 @@ def test_ppo_tensordict_keys(self, loss_class, td_est):
     @pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss))
     @pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None))
     @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None])
-    def test_ppo_tensordict_keys_run(self, loss_class, advantage, td_est):
+    @pytest.mark.parametrize("composite_action_dist", [True, False])
+    def test_ppo_tensordict_keys_run(
+        self, loss_class, advantage, td_est, composite_action_dist
+    ):
         """Test PPO loss module with non-default tensordict keys."""
         torch.manual_seed(self.seed)
         gradient_mode = True
@@ -8163,9 +8391,12 @@ def test_ppo_tensordict_keys_run(self, loss_class, advantage, td_est):
         td = self._create_seq_mock_data_ppo(
             sample_log_prob_key=tensor_keys["sample_log_prob"],
             action_key=tensor_keys["action"],
+            composite_action_dist=composite_action_dist,
         )
         actor = self._create_mock_actor(
-            sample_log_prob_key=tensor_keys["sample_log_prob"]
+            sample_log_prob_key=tensor_keys["sample_log_prob"],
+            composite_action_dist=composite_action_dist,
+
action_key=tensor_keys["action"], ) value = self._create_mock_value(out_keys=[tensor_keys["value"]]) @@ -8256,6 +8487,12 @@ def test_ppo_tensordict_keys_run(self, loss_class, advantage, td_est): @pytest.mark.parametrize("reward_key", ["reward", "reward2"]) @pytest.mark.parametrize("done_key", ["done", "done2"]) @pytest.mark.parametrize("terminated_key", ["terminated", "terminated2"]) + @pytest.mark.parametrize( + "composite_action_dist", + [ + False, + ], + ) def test_ppo_notensordict( self, loss_class, @@ -8265,6 +8502,7 @@ def test_ppo_notensordict( reward_key, done_key, terminated_key, + composite_action_dist, ): torch.manual_seed(self.seed) td = self._create_mock_data_ppo( @@ -8274,10 +8512,14 @@ def test_ppo_notensordict( reward_key=reward_key, done_key=done_key, terminated_key=terminated_key, + composite_action_dist=composite_action_dist, ) actor = self._create_mock_actor( - observation_key=observation_key, sample_log_prob_key=sample_log_prob_key + observation_key=observation_key, + sample_log_prob_key=sample_log_prob_key, + composite_action_dist=composite_action_dist, + action_key=action_key, ) value = self._create_mock_value(observation_key=observation_key) @@ -8300,7 +8542,9 @@ def test_ppo_notensordict( f"next_{observation_key}": td.get(("next", observation_key)), } if loss_class is KLPENPPOLoss: - kwargs.update({"loc": td.get("loc"), "scale": td.get("scale")}) + loc_key = "params" if composite_action_dist else "loc" + scale_key = "params" if composite_action_dist else "scale" + kwargs.update({loc_key: td.get(loc_key), scale_key: td.get(scale_key)}) td = TensorDict(kwargs, td.batch_size, names=["time"]).unflatten_keys("_") @@ -8313,6 +8557,7 @@ def test_ppo_notensordict( loss_val = loss(**kwargs) torch.manual_seed(self.seed) if beta is not None: + loss.beta = beta.clone() loss_val_td = loss(td) @@ -8340,15 +8585,20 @@ def test_ppo_notensordict( @pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss)) @pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"]) - def test_ppo_reduction(self, reduction, loss_class): + @pytest.mark.parametrize("composite_action_dist", [True, False]) + def test_ppo_reduction(self, reduction, loss_class, composite_action_dist): torch.manual_seed(self.seed) device = ( torch.device("cpu") if torch.cuda.device_count() == 0 else torch.device("cuda") ) - td = self._create_seq_mock_data_ppo(device=device) - actor = self._create_mock_actor(device=device) + td = self._create_seq_mock_data_ppo( + device=device, composite_action_dist=composite_action_dist + ) + actor = self._create_mock_actor( + device=device, composite_action_dist=composite_action_dist + ) value = self._create_mock_value(device=device) advantage = GAE( gamma=0.9, @@ -8376,10 +8626,17 @@ def test_ppo_reduction(self, reduction, loss_class): @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss)) @pytest.mark.parametrize("clip_value", [True, False, None, 0.5, torch.tensor(0.5)]) - def test_ppo_value_clipping(self, clip_value, loss_class, device): + @pytest.mark.parametrize("composite_action_dist", [True, False]) + def test_ppo_value_clipping( + self, clip_value, loss_class, device, composite_action_dist + ): torch.manual_seed(self.seed) - td = self._create_seq_mock_data_ppo(device=device) - actor = self._create_mock_actor(device=device) + td = self._create_seq_mock_data_ppo( + device=device, composite_action_dist=composite_action_dist + ) + actor = self._create_mock_actor( + device=device, 
composite_action_dist=composite_action_dist + ) value = self._create_mock_value(device=device) advantage = GAE( gamma=0.9, @@ -8438,22 +8695,47 @@ def _create_mock_actor( obs_dim=3, action_dim=4, device="cpu", + action_key="action", observation_key="observation", sample_log_prob_key="sample_log_prob", + composite_action_dist=False, ): # Actor - action_spec = BoundedTensorSpec( + action_spec = Bounded( -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) + if composite_action_dist: + action_spec = Composite({action_key: {"action1": action_spec}}) net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor()) + if composite_action_dist: + distribution_class = functools.partial( + CompositeDistribution, + distribution_map={ + "action1": TanhNormal, + }, + name_map={ + "action1": (action_key, "action1"), + }, + log_prob_key=sample_log_prob_key, + aggregate_probabilities=True, + ) + module_out_keys = [ + ("params", "action1", "loc"), + ("params", "action1", "scale"), + ] + actor_in_keys = ["params"] + else: + distribution_class = TanhNormal + module_out_keys = actor_in_keys = ["loc", "scale"] module = TensorDictModule( - net, in_keys=[observation_key], out_keys=["loc", "scale"] + net, in_keys=[observation_key], out_keys=module_out_keys ) actor = ProbabilisticActor( module=module, - in_keys=["loc", "scale"], + in_keys=actor_in_keys, + out_keys=[action_key], spec=action_spec, - distribution_class=TanhNormal, + distribution_class=distribution_class, return_log_prob=True, log_prob_key=sample_log_prob_key, ) @@ -8477,7 +8759,15 @@ def _create_mock_value( return value.to(device) def _create_mock_common_layer_setup( - self, n_obs=3, n_act=4, ncells=4, batch=2, n_hidden=2, T=10 + self, + n_obs=3, + n_act=4, + ncells=4, + batch=2, + n_hidden=2, + T=10, + composite_action_dist=False, + sample_log_prob_key="sample_log_prob", ): common_net = MLP( num_cells=ncells, @@ -8498,10 +8788,11 @@ def _create_mock_common_layer_setup( out_features=1, ) batch = [batch, T] + action = torch.randn(*batch, n_act) td = TensorDict( { "obs": torch.randn(*batch, n_obs), - "action": torch.randn(*batch, n_act), + "action": {"action1": action} if composite_action_dist else action, "sample_log_prob": torch.randn(*batch), "done": torch.zeros(*batch, 1, dtype=torch.bool), "terminated": torch.zeros(*batch, 1, dtype=torch.bool), @@ -8516,14 +8807,36 @@ def _create_mock_common_layer_setup( names=[None, "time"], ) common = Mod(common_net, in_keys=["obs"], out_keys=["hidden"]) + + if composite_action_dist: + distribution_class = functools.partial( + CompositeDistribution, + distribution_map={ + "action1": TanhNormal, + }, + name_map={ + "action1": ("action", "action1"), + }, + log_prob_key=sample_log_prob_key, + aggregate_probabilities=True, + ) + module_out_keys = [ + ("params", "action1", "loc"), + ("params", "action1", "scale"), + ] + actor_in_keys = ["params"] + else: + distribution_class = TanhNormal + module_out_keys = actor_in_keys = ["loc", "scale"] + actor = ProbSeq( common, Mod(actor_net, in_keys=["hidden"], out_keys=["param"]), - Mod(NormalParamExtractor(), in_keys=["param"], out_keys=["loc", "scale"]), + Mod(NormalParamExtractor(), in_keys=["param"], out_keys=module_out_keys), ProbMod( - in_keys=["loc", "scale"], + in_keys=actor_in_keys, out_keys=["action"], - distribution_class=TanhNormal, + distribution_class=distribution_class, ), ) critic = Seq( @@ -8547,6 +8860,7 @@ def _create_seq_mock_data_a2c( done_key="done", terminated_key="terminated", sample_log_prob_key="sample_log_prob", + 
composite_action_dist=False, ): # create a tensordict total_obs = torch.randn(batch, T + 1, obs_dim, device=device) @@ -8562,8 +8876,11 @@ def _create_seq_mock_data_a2c( done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) terminated = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) mask = ~torch.zeros(batch, T, dtype=torch.bool, device=device) + action = action.masked_fill_(~mask.unsqueeze(-1), 0.0) params_mean = torch.randn_like(action) / 10 params_scale = torch.rand_like(action) / 10 + loc = params_mean.masked_fill_(~mask.unsqueeze(-1), 0.0) + scale = params_scale.masked_fill_(~mask.unsqueeze(-1), 0.0) td = TensorDict( batch_size=(batch, T), source={ @@ -8575,17 +8892,21 @@ def _create_seq_mock_data_a2c( reward_key: reward.masked_fill_(~mask.unsqueeze(-1), 0.0), }, "collector": {"mask": mask}, - action_key: action.masked_fill_(~mask.unsqueeze(-1), 0.0), + action_key: {"action1": action} if composite_action_dist else action, sample_log_prob_key: torch.randn_like(action[..., 1]).masked_fill_( ~mask, 0.0 ) / 10, - "loc": params_mean.masked_fill_(~mask.unsqueeze(-1), 0.0), - "scale": params_scale.masked_fill_(~mask.unsqueeze(-1), 0.0), }, device=device, names=[None, "time"], ) + if composite_action_dist: + td[("params", "action1", "loc")] = loc + td[("params", "action1", "scale")] = scale + else: + td["loc"] = loc + td["scale"] = scale return td @pytest.mark.parametrize("gradient_mode", (True, False)) @@ -8593,11 +8914,24 @@ def _create_seq_mock_data_a2c( @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None]) @pytest.mark.parametrize("functional", (True, False)) - def test_a2c(self, device, gradient_mode, advantage, td_est, functional): + @pytest.mark.parametrize("composite_action_dist", [True, False]) + def test_a2c( + self, + device, + gradient_mode, + advantage, + td_est, + functional, + composite_action_dist, + ): torch.manual_seed(self.seed) - td = self._create_seq_mock_data_a2c(device=device) + td = self._create_seq_mock_data_a2c( + device=device, composite_action_dist=composite_action_dist + ) - actor = self._create_mock_actor(device=device) + actor = self._create_mock_actor( + device=device, composite_action_dist=composite_action_dist + ) value = self._create_mock_value(device=device) if advantage == "gae": advantage = GAE( @@ -8630,14 +8964,24 @@ def test_a2c(self, device, gradient_mode, advantage, td_est, functional): functional=functional, ) + def set_requires_grad(tensor, requires_grad): + tensor.requires_grad = requires_grad + return tensor + # Check error is raised when actions require grads - td["action"].requires_grad = True + if composite_action_dist: + td["action"].apply_(lambda x: set_requires_grad(x, True)) + else: + td["action"].requires_grad = True with pytest.raises( RuntimeError, - match="tensordict stored action require grad.", + match="tensordict stored action requires grad.", ): _ = loss_fn._log_probs(td) - td["action"].requires_grad = False + if composite_action_dist: + td["action"].apply_(lambda x: set_requires_grad(x, False)) + else: + td["action"].requires_grad = False td = td.exclude(loss_fn.tensor_keys.value_target) if advantage is not None: @@ -8678,9 +9022,12 @@ def test_a2c(self, device, gradient_mode, advantage, td_est, functional): @pytest.mark.parametrize("gradient_mode", (True, False)) @pytest.mark.parametrize("device", get_default_devices()) - def test_a2c_state_dict(self, device, gradient_mode): + @pytest.mark.parametrize("composite_action_dist", [True, 
False]) + def test_a2c_state_dict(self, device, gradient_mode, composite_action_dist): torch.manual_seed(self.seed) - actor = self._create_mock_actor(device=device) + actor = self._create_mock_actor( + device=device, composite_action_dist=composite_action_dist + ) value = self._create_mock_value(device=device) loss_fn = A2CLoss(actor, value, loss_critic_type="l2") sd = loss_fn.state_dict() @@ -8688,23 +9035,36 @@ def test_a2c_state_dict(self, device, gradient_mode): loss_fn2.load_state_dict(sd) @pytest.mark.parametrize("separate_losses", [False, True]) - def test_a2c_separate_losses(self, separate_losses): + @pytest.mark.parametrize("composite_action_dist", [True, False]) + def test_a2c_separate_losses(self, separate_losses, composite_action_dist): torch.manual_seed(self.seed) - actor, critic, common, td = self._create_mock_common_layer_setup() + actor, critic, common, td = self._create_mock_common_layer_setup( + composite_action_dist=composite_action_dist + ) loss_fn = A2CLoss( actor_network=actor, critic_network=critic, separate_losses=separate_losses, ) + def set_requires_grad(tensor, requires_grad): + tensor.requires_grad = requires_grad + return tensor + # Check error is raised when actions require grads - td["action"].requires_grad = True + if composite_action_dist: + td["action"].apply_(lambda x: set_requires_grad(x, True)) + else: + td["action"].requires_grad = True with pytest.raises( RuntimeError, - match="tensordict stored action require grad.", + match="tensordict stored action requires grad.", ): _ = loss_fn._log_probs(td) - td["action"].requires_grad = False + if composite_action_dist: + td["action"].apply_(lambda x: set_requires_grad(x, False)) + else: + td["action"].requires_grad = False td = td.exclude(loss_fn.tensor_keys.value_target) loss = loss_fn(td) @@ -8748,13 +9108,18 @@ def test_a2c_separate_losses(self, separate_losses): @pytest.mark.parametrize("gradient_mode", (True, False)) @pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None)) @pytest.mark.parametrize("device", get_default_devices()) - def test_a2c_diff(self, device, gradient_mode, advantage): + @pytest.mark.parametrize("composite_action_dist", [True, False]) + def test_a2c_diff(self, device, gradient_mode, advantage, composite_action_dist): if pack_version.parse(torch.__version__) > pack_version.parse("1.14"): raise pytest.skip("make_functional_with_buffers needs to be changed") torch.manual_seed(self.seed) - td = self._create_seq_mock_data_a2c(device=device) + td = self._create_seq_mock_data_a2c( + device=device, composite_action_dist=composite_action_dist + ) - actor = self._create_mock_actor(device=device) + actor = self._create_mock_actor( + device=device, composite_action_dist=composite_action_dist + ) value = self._create_mock_value(device=device) if advantage == "gae": advantage = GAE( @@ -8824,8 +9189,9 @@ def test_a2c_diff(self, device, gradient_mode, advantage): ValueEstimators.TDLambda, ], ) - def test_a2c_tensordict_keys(self, td_est): - actor = self._create_mock_actor() + @pytest.mark.parametrize("composite_action_dist", [True, False]) + def test_a2c_tensordict_keys(self, td_est, composite_action_dist): + actor = self._create_mock_actor(composite_action_dist=composite_action_dist) value = self._create_mock_value() loss_fn = A2CLoss(actor, value, loss_critic_type="l2") @@ -8870,7 +9236,10 @@ def test_a2c_tensordict_keys(self, td_est): ) @pytest.mark.parametrize("advantage", ("gae", "vtrace", None)) @pytest.mark.parametrize("device", get_default_devices()) - def 
test_a2c_tensordict_keys_run(self, device, advantage, td_est): + @pytest.mark.parametrize("composite_action_dist", [True, False]) + def test_a2c_tensordict_keys_run( + self, device, advantage, td_est, composite_action_dist + ): """Test A2C loss module with non-default tensordict keys.""" torch.manual_seed(self.seed) gradient_mode = True @@ -8890,10 +9259,14 @@ def test_a2c_tensordict_keys_run(self, device, advantage, td_est): done_key=done_key, terminated_key=terminated_key, sample_log_prob_key=sample_log_prob_key, + composite_action_dist=composite_action_dist, ) actor = self._create_mock_actor( - device=device, sample_log_prob_key=sample_log_prob_key + device=device, + sample_log_prob_key=sample_log_prob_key, + composite_action_dist=composite_action_dist, + action_key=action_key, ) value = self._create_mock_value(device=device, out_keys=[value_key]) if advantage == "gae": @@ -8972,12 +9345,26 @@ def test_a2c_tensordict_keys_run(self, device, advantage, td_est): @pytest.mark.parametrize("reward_key", ["reward", "reward2"]) @pytest.mark.parametrize("done_key", ["done", "done2"]) @pytest.mark.parametrize("terminated_key", ["terminated", "terminated2"]) + @pytest.mark.parametrize( + "composite_action_dist", + [ + False, + ], + ) def test_a2c_notensordict( - self, action_key, observation_key, reward_key, done_key, terminated_key + self, + action_key, + observation_key, + reward_key, + done_key, + terminated_key, + composite_action_dist, ): torch.manual_seed(self.seed) - actor = self._create_mock_actor(observation_key=observation_key) + actor = self._create_mock_actor( + observation_key=observation_key, composite_action_dist=composite_action_dist + ) value = self._create_mock_value(observation_key=observation_key) td = self._create_seq_mock_data_a2c( action_key=action_key, @@ -8985,6 +9372,7 @@ def test_a2c_notensordict( reward_key=reward_key, done_key=done_key, terminated_key=terminated_key, + composite_action_dist=composite_action_dist, ) loss = A2CLoss(actor, value) @@ -9029,15 +9417,20 @@ def test_a2c_notensordict( assert loss_critic == loss_val_td["loss_critic"] @pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"]) - def test_a2c_reduction(self, reduction): + @pytest.mark.parametrize("composite_action_dist", [True, False]) + def test_a2c_reduction(self, reduction, composite_action_dist): torch.manual_seed(self.seed) device = ( torch.device("cpu") if torch.cuda.device_count() == 0 else torch.device("cuda") ) - td = self._create_seq_mock_data_a2c(device=device) - actor = self._create_mock_actor(device=device) + td = self._create_seq_mock_data_a2c( + device=device, composite_action_dist=composite_action_dist + ) + actor = self._create_mock_actor( + device=device, composite_action_dist=composite_action_dist + ) value = self._create_mock_value(device=device) advantage = GAE( gamma=0.9, @@ -9064,10 +9457,15 @@ def test_a2c_reduction(self, reduction): @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("clip_value", [True, None, 0.5, torch.tensor(0.5)]) - def test_a2c_value_clipping(self, clip_value, device): + @pytest.mark.parametrize("composite_action_dist", [True, False]) + def test_a2c_value_clipping(self, clip_value, device, composite_action_dist): torch.manual_seed(self.seed) - td = self._create_seq_mock_data_a2c(device=device) - actor = self._create_mock_actor(device=device) + td = self._create_seq_mock_data_a2c( + device=device, composite_action_dist=composite_action_dist + ) + actor = self._create_mock_actor( + device=device, 
composite_action_dist=composite_action_dist + ) value = self._create_mock_value(device=device) advantage = GAE( gamma=0.9, @@ -9151,7 +9549,7 @@ def test_reinforce_value_net( distribution_class=TanhNormal, return_log_prob=True, in_keys=["loc", "scale"], - spec=UnboundedContinuousTensorSpec(n_act), + spec=Unbounded(n_act), ) if advantage == "gae": advantage = GAE( @@ -9261,7 +9659,7 @@ def test_reinforce_tensordict_keys(self, td_est): distribution_class=TanhNormal, return_log_prob=True, in_keys=["loc", "scale"], - spec=UnboundedContinuousTensorSpec(n_act), + spec=Unbounded(n_act), ) loss_fn = ReinforceLoss( @@ -9455,7 +9853,7 @@ def test_reinforce_notensordict( distribution_class=TanhNormal, return_log_prob=True, in_keys=["loc", "scale"], - spec=UnboundedContinuousTensorSpec(n_act), + spec=Unbounded(n_act), ) loss = ReinforceLoss(actor_network=actor_net, critic_network=value_net) loss.set_keys( @@ -9631,8 +10029,8 @@ def _create_world_model_model(self, rssm_hidden_dim, state_dim, mlp_num_units=13 ContinuousActionConvMockEnv(pixel_shape=[3, *self.img_size]) ) default_dict = { - "state": UnboundedContinuousTensorSpec(state_dim), - "belief": UnboundedContinuousTensorSpec(rssm_hidden_dim), + "state": Unbounded(state_dim), + "belief": Unbounded(rssm_hidden_dim), } mock_env.append_transform( TensorDictPrimer(random=False, default_value=0, **default_dict) @@ -9708,8 +10106,8 @@ def _create_mb_env(self, rssm_hidden_dim, state_dim, mlp_num_units=13): ContinuousActionConvMockEnv(pixel_shape=[3, *self.img_size]) ) default_dict = { - "state": UnboundedContinuousTensorSpec(state_dim), - "belief": UnboundedContinuousTensorSpec(rssm_hidden_dim), + "state": Unbounded(state_dim), + "belief": Unbounded(rssm_hidden_dim), } mock_env.append_transform( TensorDictPrimer(random=False, default_value=0, **default_dict) @@ -9759,8 +10157,8 @@ def _create_actor_model(self, rssm_hidden_dim, state_dim, mlp_num_units=13): ContinuousActionConvMockEnv(pixel_shape=[3, *self.img_size]) ) default_dict = { - "state": UnboundedContinuousTensorSpec(state_dim), - "belief": UnboundedContinuousTensorSpec(rssm_hidden_dim), + "state": Unbounded(state_dim), + "belief": Unbounded(rssm_hidden_dim), } mock_env.append_transform( TensorDictPrimer(random=False, default_value=0, **default_dict) @@ -10049,7 +10447,7 @@ class TestOnlineDT(LossModuleTestBase): def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"): # Actor - action_spec = BoundedTensorSpec( + action_spec = Bounded( -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor()) @@ -10281,7 +10679,7 @@ class TestDT(LossModuleTestBase): def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"): # Actor - action_spec = BoundedTensorSpec( + action_spec = Bounded( -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor()) @@ -10459,6 +10857,227 @@ def test_dt_reduction(self, reduction): assert loss["loss"].shape == torch.Size([]) +class TestGAIL(LossModuleTestBase): + seed = 0 + + def _create_mock_discriminator( + self, batch=2, obs_dim=3, action_dim=4, device="cpu" + ): + # Discriminator + body = TensorDictModule( + MLP( + in_features=obs_dim + action_dim, + out_features=32, + depth=1, + num_cells=32, + activation_class=torch.nn.ReLU, + activate_last_layer=True, + ), + in_keys=["observation", "action"], + out_keys="hidden", + ) + head = TensorDictModule( + MLP( + 
in_features=32, + out_features=1, + depth=0, + num_cells=32, + activation_class=torch.nn.Sigmoid, + activate_last_layer=True, + ), + in_keys="hidden", + out_keys="d_logits", + ) + discriminator = TensorDictSequential(body, head) + + return discriminator.to(device) + + def _create_mock_data_gail(self, batch=2, obs_dim=3, action_dim=4, device="cpu"): + # create a tensordict + obs = torch.randn(batch, obs_dim, device=device) + action = torch.randn(batch, action_dim, device=device).clamp(-1, 1) + td = TensorDict( + batch_size=(batch,), + source={ + "observation": obs, + "action": action, + "collector_action": action, + "collector_observation": obs, + }, + device=device, + ) + return td + + def _create_seq_mock_data_gail( + self, batch=2, T=4, obs_dim=3, action_dim=4, device="cpu" + ): + # create a tensordict + obs = torch.randn(batch, T, obs_dim, device=device) + action = torch.randn(batch, T, action_dim, device=device).clamp(-1, 1) + + td = TensorDict( + batch_size=(batch, T), + source={ + "observation": obs, + "action": action, + "collector_action": action, + "collector_observation": obs, + }, + device=device, + ) + return td + + def test_gail_tensordict_keys(self): + discriminator = self._create_mock_discriminator() + loss_fn = GAILLoss(discriminator) + + default_keys = { + "expert_action": "action", + "expert_observation": "observation", + "collector_action": "collector_action", + "collector_observation": "collector_observation", + "discriminator_pred": "d_logits", + } + + self.tensordict_keys_test( + loss_fn, + default_keys=default_keys, + ) + + @pytest.mark.parametrize("device", get_default_devices()) + @pytest.mark.parametrize("use_grad_penalty", [True, False]) + @pytest.mark.parametrize("gp_lambda", [0.1, 1.0]) + def test_gail_notensordict(self, device, use_grad_penalty, gp_lambda): + torch.manual_seed(self.seed) + discriminator = self._create_mock_discriminator(device=device) + loss_fn = GAILLoss( + discriminator, use_grad_penalty=use_grad_penalty, gp_lambda=gp_lambda + ) + + tensordict = self._create_mock_data_gail(device=device) + + in_keys = self._flatten_in_keys(loss_fn.in_keys) + kwargs = dict(tensordict.flatten_keys("_").select(*in_keys)) + + loss_val_td = loss_fn(tensordict) + if use_grad_penalty: + loss_val, _ = loss_fn(**kwargs) + else: + loss_val = loss_fn(**kwargs) + + torch.testing.assert_close(loss_val_td.get("loss"), loss_val) + # test select + loss_fn.select_out_keys("loss") + if torch.__version__ >= "2.0.0": + loss_discriminator = loss_fn(**kwargs) + else: + with pytest.raises( + RuntimeError, + match="You are likely using tensordict.nn.dispatch with keyword arguments", + ): + loss_discriminator = loss_fn(**kwargs) + return + assert loss_discriminator == loss_val_td["loss"] + + @pytest.mark.parametrize("device", get_available_devices()) + @pytest.mark.parametrize("use_grad_penalty", [True, False]) + @pytest.mark.parametrize("gp_lambda", [0.1, 1.0]) + def test_gail(self, device, use_grad_penalty, gp_lambda): + torch.manual_seed(self.seed) + td = self._create_mock_data_gail(device=device) + + discriminator = self._create_mock_discriminator(device=device) + + loss_fn = GAILLoss( + discriminator, use_grad_penalty=use_grad_penalty, gp_lambda=gp_lambda + ) + loss = loss_fn(td) + loss_transformer = loss["loss"] + loss_transformer.backward(retain_graph=True) + named_parameters = loss_fn.named_parameters() + + for name, p in named_parameters: + if p.grad is not None and p.grad.norm() > 0.0: + assert "discriminator" in name + if p.grad is None: + assert "discriminator" not in 
name + loss_fn.zero_grad() + + sum([loss_transformer]).backward() + named_parameters = list(loss_fn.named_parameters()) + named_buffers = list(loss_fn.named_buffers()) + + assert len({p for n, p in named_parameters}) == len(list(named_parameters)) + assert len({p for n, p in named_buffers}) == len(list(named_buffers)) + + for name, p in named_parameters: + assert p.grad.norm() > 0.0, f"parameter {name} has a null gradient" + + @pytest.mark.parametrize("device", get_available_devices()) + def test_gail_state_dict(self, device): + torch.manual_seed(self.seed) + + discriminator = self._create_mock_discriminator(device=device) + + loss_fn = GAILLoss(discriminator) + sd = loss_fn.state_dict() + loss_fn2 = GAILLoss(discriminator) + loss_fn2.load_state_dict(sd) + + @pytest.mark.parametrize("device", get_available_devices()) + @pytest.mark.parametrize("use_grad_penalty", [True, False]) + @pytest.mark.parametrize("gp_lambda", [0.1, 1.0]) + def test_seq_gail(self, device, use_grad_penalty, gp_lambda): + torch.manual_seed(self.seed) + td = self._create_seq_mock_data_gail(device=device) + + discriminator = self._create_mock_discriminator(device=device) + + loss_fn = GAILLoss( + discriminator, use_grad_penalty=use_grad_penalty, gp_lambda=gp_lambda + ) + loss = loss_fn(td) + loss_transformer = loss["loss"] + loss_transformer.backward(retain_graph=True) + named_parameters = loss_fn.named_parameters() + + for name, p in named_parameters: + if p.grad is not None and p.grad.norm() > 0.0: + assert "discriminator" in name + if p.grad is None: + assert "discriminator" not in name + loss_fn.zero_grad() + + sum([loss_transformer]).backward() + named_parameters = list(loss_fn.named_parameters()) + named_buffers = list(loss_fn.named_buffers()) + + assert len({p for n, p in named_parameters}) == len(list(named_parameters)) + assert len({p for n, p in named_buffers}) == len(list(named_buffers)) + + for name, p in named_parameters: + assert p.grad.norm() > 0.0, f"parameter {name} has a null gradient" + + @pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"]) + @pytest.mark.parametrize("use_grad_penalty", [True, False]) + @pytest.mark.parametrize("gp_lambda", [0.1, 1.0]) + def test_gail_reduction(self, reduction, use_grad_penalty, gp_lambda): + torch.manual_seed(self.seed) + device = ( + torch.device("cpu") + if torch.cuda.device_count() == 0 + else torch.device("cuda") + ) + td = self._create_mock_data_gail(device=device) + discriminator = self._create_mock_discriminator(device=device) + loss_fn = GAILLoss(discriminator, reduction=reduction) + loss = loss_fn(td) + if reduction == "none": + assert loss["loss"].shape == (td["observation"].shape[0], 1) + else: + assert loss["loss"].shape == torch.Size([]) + + @pytest.mark.skipif( not _has_functorch, reason=f"functorch not installed: {FUNCTORCH_ERR}" ) @@ -10474,7 +11093,7 @@ def _create_mock_actor( observation_key="observation", ): # Actor - action_spec = BoundedTensorSpec( + action_spec = Bounded( -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor()) @@ -11285,7 +11904,7 @@ def _create_mock_actor( action_key="action", ): # Actor - action_spec = OneHotDiscreteTensorSpec(action_dim) + action_spec = OneHot(action_dim) net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor()) module = TensorDictModule(net, in_keys=[observation_key], out_keys=["logits"]) actor = ProbabilisticActor( @@ -14750,7 +15369,7 @@ def _forward_value_estimator_keys(self, 
**kwargs) -> None: class MyLoss3(MyLoss2): @dataclass class _AcceptedKeys: - some_key = "some_value" + some_key: str = "some_value" loss_module = MyLoss3() assert loss_module.tensor_keys.some_key == "some_value" @@ -15112,6 +15731,68 @@ def __init__(self): assert p.device == dest +@pytest.mark.skipif(TORCH_VERSION < "2.5", reason="requires torch>=2.5") +def test_exploration_compile(): + m = ProbabilisticTensorDictModule( + in_keys=["loc", "scale"], + out_keys=["sample"], + distribution_class=torch.distributions.Normal, + ) + + # class set_exploration_type_random(set_exploration_type): + # __init__ = object.__init__ + # type = ExplorationType.RANDOM + it = exploration_type() + + @torch.compile(fullgraph=True) + def func(t): + with set_exploration_type(ExplorationType.RANDOM): + t0 = m(t.clone()) + t1 = m(t.clone()) + return t0, t1 + + t = TensorDict(loc=torch.randn(3), scale=torch.rand(3)) + t0, t1 = func(t) + assert (t0["sample"] != t1["sample"]).any() + assert it == exploration_type() + + @torch.compile(fullgraph=True) + def func(t): + with set_exploration_type(ExplorationType.MEAN): + t0 = m(t.clone()) + t1 = m(t.clone()) + return t0, t1 + + t = TensorDict(loc=torch.randn(3), scale=torch.rand(3)) + t0, t1 = func(t) + assert (t0["sample"] == t1["sample"]).all() + assert it == exploration_type() + + @torch.compile(fullgraph=True) + @set_exploration_type(ExplorationType.RANDOM) + def func(t): + t0 = m(t.clone()) + t1 = m(t.clone()) + return t0, t1 + + t = TensorDict(loc=torch.randn(3), scale=torch.rand(3)) + t0, t1 = func(t) + assert (t0["sample"] != t1["sample"]).any() + assert it == exploration_type() + + @torch.compile(fullgraph=True) + @set_exploration_type(ExplorationType.MEAN) + def func(t): + t0 = m(t.clone()) + t1 = m(t.clone()) + return t0, t1 + + t = TensorDict(loc=torch.randn(3), scale=torch.rand(3)) + t0, t1 = func(t) + assert (t0["sample"] == t1["sample"]).all() + assert it == exploration_type() + + def test_loss_exploration(): class DummyLoss(LossModule): def forward(self, td, mode): diff --git a/test/test_distributed.py b/test/test_distributed.py index 40b4f5eae44..fd369f64962 100644 --- a/test/test_distributed.py +++ b/test/test_distributed.py @@ -405,7 +405,7 @@ def _test_distributed_collector_updatepolicy( MultiaSyncDataCollector, ], ) - @pytest.mark.parametrize("update_interval", [1_000_000, 1]) + @pytest.mark.parametrize("update_interval", [1]) def test_distributed_collector_updatepolicy(self, collector_class, update_interval): """Testing various collector classes to be used in nodes.""" queue = mp.Queue(1) diff --git a/test/test_distributions.py b/test/test_distributions.py index 53bfda343a2..8a5b651531e 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -190,8 +190,8 @@ def test_truncnormal(self, min, max, vecs, upscale, shape, device): d = TruncatedNormal( *vecs, upscale=upscale, - min=min, - max=max, + low=min, + high=max, ) assert d.device == device for _ in range(100): @@ -218,7 +218,7 @@ def test_truncnormal_against_scipy(self): high = 2 low = -1 log_pi_x = TruncatedNormal( - mu, sigma, min=low, max=high, tanh_loc=False + mu, sigma, low=low, high=high, tanh_loc=False ).log_prob(x) pi_x = torch.exp(log_pi_x) log_pi_x.backward(torch.ones_like(log_pi_x)) @@ -264,8 +264,8 @@ def test_truncnormal_mode(self, min, max, vecs, upscale, shape, device): d = TruncatedNormal( *vecs, upscale=upscale, - min=min, - max=max, + low=min, + high=max, ) assert d.mode is not None assert d.entropy() is not None diff --git a/test/test_env.py b/test/test_env.py 
index dee03c06e7d..9602c596f22 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. import argparse +import contextlib import functools import gc import os.path @@ -66,12 +67,7 @@ from torch import nn from torchrl.collectors import MultiSyncDataCollector, SyncDataCollector -from torchrl.data.tensor_specs import ( - CompositeSpec, - DiscreteTensorSpec, - NonTensorSpec, - UnboundedContinuousTensorSpec, -) +from torchrl.data.tensor_specs import Categorical, Composite, NonTensor, Unbounded from torchrl.envs import ( CatFrames, CatTensors, @@ -495,6 +491,26 @@ def test_mb_env_batch_lock(self, device, seed=0): class TestParallel: + def test_create_env_fn(self, maybe_fork_ParallelEnv): + def make_env(): + return GymEnv(PENDULUM_VERSIONED()) + + with pytest.raises( + RuntimeError, match="len\\(create_env_fn\\) and num_workers mismatch" + ): + maybe_fork_ParallelEnv(4, [make_env, make_env]) + + def test_create_env_kwargs(self, maybe_fork_ParallelEnv): + def make_env(): + return GymEnv(PENDULUM_VERSIONED()) + + with pytest.raises( + RuntimeError, match="len\\(create_env_kwargs\\) and num_workers mismatch" + ): + maybe_fork_ParallelEnv( + 4, make_env, create_env_kwargs=[{"seed": 0}, {"seed": 1}] + ) + @pytest.mark.skipif( not torch.cuda.device_count(), reason="No cuda device detected." ) @@ -1125,6 +1141,25 @@ def env_fn2(seed): env1.close() env2.close() + @pytest.mark.parametrize("parallel", [True, False]) + def test_parallel_env_update_kwargs(self, parallel, maybe_fork_ParallelEnv): + def make_env(seed=None): + env = DiscreteActionConvMockEnv() + if seed is not None: + env.set_seed(seed) + return env + + _class = maybe_fork_ParallelEnv if parallel else SerialEnv + env = _class( + num_workers=2, + create_env_fn=make_env, + create_env_kwargs=[{"seed": 0}, {"seed": 1}], + ) + with pytest.raises( + RuntimeError, match="len\\(kwargs\\) and num_workers mismatch" + ): + env.update_kwargs([{"seed": 42}]) + @pytest.mark.parametrize("batch_size", [(32, 5), (4,), (1,), ()]) @pytest.mark.parametrize("n_workers", [2, 1]) def test_parallel_env_reset_flag( @@ -1908,18 +1943,12 @@ def test_info_dict_reader(self, device, seed=0): env.set_info_dict_reader( default_info_dict_reader( ["x_position"], - spec=CompositeSpec( - x_position=UnboundedContinuousTensorSpec( - dtype=torch.float64, shape=() - ) - ), + spec=Composite(x_position=Unbounded(dtype=torch.float64, shape=())), ) ) assert "x_position" in env.observation_spec.keys() - assert isinstance( - env.observation_spec["x_position"], UnboundedContinuousTensorSpec - ) + assert isinstance(env.observation_spec["x_position"], Unbounded) tensordict = env.reset() tensordict = env.rand_step(tensordict) @@ -1932,13 +1961,13 @@ def test_info_dict_reader(self, device, seed=0): ) for spec in ( - {"x_position": UnboundedContinuousTensorSpec((), dtype=torch.float64)}, + {"x_position": Unbounded((), dtype=torch.float64)}, # None, - CompositeSpec( - x_position=UnboundedContinuousTensorSpec((), dtype=torch.float64), + Composite( + x_position=Unbounded((), dtype=torch.float64), shape=[], ), - [UnboundedContinuousTensorSpec((), dtype=torch.float64)], + [Unbounded((), dtype=torch.float64)], ): env2 = GymWrapper(gym.make("HalfCheetah-v4")) env2.set_info_dict_reader( @@ -2079,7 +2108,7 @@ def main_penv(j, q=None): ], ) spec = env_p.action_spec - policy = TestConcurrentEnvs.Policy(CompositeSpec(action=spec.to(device))) + policy = TestConcurrentEnvs.Policy(Composite(action=spec.to(device))) N = 10 r_p = [] r_s = [] 
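The recurring mechanical change in these test_env.py hunks (and in the other test modules further down) is the migration to the shortened tensor-spec names: CompositeSpec -> Composite, UnboundedContinuousTensorSpec -> Unbounded, DiscreteTensorSpec -> Categorical, BoundedTensorSpec -> Bounded, OneHotDiscreteTensorSpec -> OneHot, NonTensorSpec -> NonTensor. A minimal sketch of new-style spec construction, mirroring the calls in the hunks above; the keys, shapes and dtypes below are illustrative only, and the sketch assumes the short names are importable from torchrl.data, as the updated imports in this patch do:

import torch
from torchrl.data import Bounded, Categorical, Composite, Unbounded

# Composite plays the role of the old CompositeSpec: a container of named sub-specs.
spec = Composite(
    observation=Unbounded(3),                           # was UnboundedContinuousTensorSpec(3)
    action=Bounded(-1.0, 1.0, shape=(3,)),              # was BoundedTensorSpec(-1, 1, (3,))
    done=Categorical(2, shape=(1,), dtype=torch.bool),  # was DiscreteTensorSpec(2, ...)
)
sample = spec.rand()  # one draw per sub-spec, e.g. a uniform action within [-1, 1]
assert sample["action"].shape == torch.Size([3])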
@@ -2113,7 +2142,7 @@ def main_collector(j, q=None): lambda i=i: CountingEnv(i, device=device) for i in range(j, j + n_workers) ] spec = make_envs[0]().action_spec - policy = TestConcurrentEnvs.Policy(CompositeSpec(action=spec)) + policy = TestConcurrentEnvs.Policy(Composite(action=spec)) collector = MultiSyncDataCollector( make_envs, policy, @@ -2225,7 +2254,7 @@ def test_nested_env(self, envclass): else: raise NotImplementedError reset = env.reset() - assert not isinstance(env.reward_spec, CompositeSpec) + assert not isinstance(env.reward_spec, Composite) for done_key in env.done_keys: assert ( env.full_done_spec[done_key] @@ -2496,8 +2525,8 @@ def test_mocking_envs(envclass): class TestTerminatedOrTruncated: @pytest.mark.parametrize("done_key", ["done", "terminated", "truncated"]) def test_root_prevail(self, done_key): - _spec = DiscreteTensorSpec(2, shape=(), dtype=torch.bool) - spec = CompositeSpec({done_key: _spec, ("agent", done_key): _spec}) + _spec = Categorical(2, shape=(), dtype=torch.bool) + spec = Composite({done_key: _spec, ("agent", done_key): _spec}) data = TensorDict({done_key: [False], ("agent", done_key): [True, False]}, []) assert not _terminated_or_truncated(data) assert not _terminated_or_truncated(data, full_done_spec=spec) @@ -2560,8 +2589,8 @@ def test_terminated_or_truncated_nospec(self): def test_terminated_or_truncated_spec(self): done_shape = (2, 1) nested_done_shape = (2, 3, 1) - spec = CompositeSpec( - done=DiscreteTensorSpec(2, shape=done_shape, dtype=torch.bool), + spec = Composite( + done=Categorical(2, shape=done_shape, dtype=torch.bool), shape=[ 2, ], @@ -2578,12 +2607,12 @@ def test_terminated_or_truncated_spec(self): ) assert data.get("_reset", None) is None - spec = CompositeSpec( + spec = Composite( { - ("agent", "done"): DiscreteTensorSpec( + ("agent", "done"): Categorical( 2, shape=nested_done_shape, dtype=torch.bool ), - ("nested", "done"): DiscreteTensorSpec( + ("nested", "done"): Categorical( 2, shape=nested_done_shape, dtype=torch.bool ), }, @@ -2618,11 +2647,11 @@ def test_terminated_or_truncated_spec(self): assert data["agent", "_reset"].shape == nested_done_shape assert data["nested", "_reset"].shape == nested_done_shape - spec = CompositeSpec( + spec = Composite( { - "truncated": DiscreteTensorSpec(2, shape=done_shape, dtype=torch.bool), - "terminated": DiscreteTensorSpec(2, shape=done_shape, dtype=torch.bool), - ("nested", "terminated"): DiscreteTensorSpec( + "truncated": Categorical(2, shape=done_shape, dtype=torch.bool), + "terminated": Categorical(2, shape=done_shape, dtype=torch.bool), + ("nested", "terminated"): Categorical( 2, shape=nested_done_shape, dtype=torch.bool ), }, @@ -2774,15 +2803,15 @@ def test_backprop(device, maybe_fork_ParallelEnv, share_individual_td): class DifferentiableEnv(EnvBase): def __init__(self, device): super().__init__(device=device) - self.observation_spec = CompositeSpec( - observation=UnboundedContinuousTensorSpec(3, device=device), + self.observation_spec = Composite( + observation=Unbounded(3, device=device), device=device, ) - self.action_spec = CompositeSpec( - action=UnboundedContinuousTensorSpec(3, device=device), device=device + self.action_spec = Composite( + action=Unbounded(3, device=device), device=device ) - self.reward_spec = CompositeSpec( - reward=UnboundedContinuousTensorSpec(1, device=device), device=device + self.reward_spec = Composite( + reward=Unbounded(1, device=device), device=device ) self.seed = 0 @@ -3283,7 +3312,7 @@ def _reset( return tensordict_reset def 
transform_observation_spec(self, observation_spec):
-        observation_spec["string"] = NonTensorSpec(())
+        observation_spec["string"] = NonTensor(())
         return observation_spec
 
     @pytest.mark.parametrize("batched", ["serial", "parallel"])
@@ -3351,6 +3380,98 @@ def test_pendulum_env(self):
         assert r.shape == torch.Size((5, 10))
 
+@pytest.mark.parametrize("device", [None, *get_default_devices()])
+@pytest.mark.parametrize("env_device", [None, *get_default_devices()])
+class TestPartialSteps:
+    @pytest.mark.parametrize("use_buffers", [False, True])
+    def test_parallel_partial_steps(
+        self, use_buffers, device, env_device, maybe_fork_ParallelEnv
+    ):
+        with torch.device(device) if device is not None else contextlib.nullcontext():
+            penv = maybe_fork_ParallelEnv(
+                4,
+                lambda: CountingEnv(max_steps=10, start_val=2, device=env_device),
+                use_buffers=use_buffers,
+                device=device,
+            )
+            td = penv.reset()
+            psteps = torch.zeros(4, dtype=torch.bool)
+            psteps[[1, 3]] = True
+            td.set("_step", psteps)
+
+            td.set("action", penv.action_spec.one())
+            td = penv.step(td)
+            assert (td[0].get("next") == 0).all()
+            assert (td[1].get("next") != 0).any()
+            assert (td[2].get("next") == 0).all()
+            assert (td[3].get("next") != 0).any()
+
+    @pytest.mark.parametrize("use_buffers", [False, True])
+    def test_parallel_partial_step_and_maybe_reset(
+        self, use_buffers, device, env_device, maybe_fork_ParallelEnv
+    ):
+        with torch.device(device) if device is not None else contextlib.nullcontext():
+            penv = maybe_fork_ParallelEnv(
+                4,
+                lambda: CountingEnv(max_steps=10, start_val=2, device=env_device),
+                use_buffers=use_buffers,
+                device=device,
+            )
+            td = penv.reset()
+            psteps = torch.zeros(4, dtype=torch.bool)
+            psteps[[1, 3]] = True
+            td.set("_step", psteps)
+
+            td.set("action", penv.action_spec.one())
+            td, tdreset = penv.step_and_maybe_reset(td)
+            assert (td[0].get("next") == 0).all()
+            assert (td[1].get("next") != 0).any()
+            assert (td[2].get("next") == 0).all()
+            assert (td[3].get("next") != 0).any()
+
+    @pytest.mark.parametrize("use_buffers", [False, True])
+    def test_serial_partial_steps(self, use_buffers, device, env_device):
+        with torch.device(device) if device is not None else contextlib.nullcontext():
+            penv = SerialEnv(
+                4,
+                lambda: CountingEnv(max_steps=10, start_val=2, device=env_device),
+                use_buffers=use_buffers,
+                device=device,
+            )
+            td = penv.reset()
+            psteps = torch.zeros(4, dtype=torch.bool)
+            psteps[[1, 3]] = True
+            td.set("_step", psteps)
+
+            td.set("action", penv.action_spec.one())
+            td = penv.step(td)
+            assert (td[0].get("next") == 0).all()
+            assert (td[1].get("next") != 0).any()
+            assert (td[2].get("next") == 0).all()
+            assert (td[3].get("next") != 0).any()
+
+    @pytest.mark.parametrize("use_buffers", [False, True])
+    def test_serial_partial_step_and_maybe_reset(self, use_buffers, device, env_device):
+        with torch.device(device) if device is not None else contextlib.nullcontext():
+            penv = SerialEnv(
+                4,
+                lambda: CountingEnv(max_steps=10, start_val=2, device=env_device),
+                use_buffers=use_buffers,
+                device=device,
+            )
+            td = penv.reset()
+            psteps = torch.zeros(4, dtype=torch.bool)
+            psteps[[1, 3]] = True
+            td.set("_step", psteps)
+
+            td.set("action", penv.action_spec.one())
+            # step_and_maybe_reset is the method under test here; the plain
+            # step() path is already covered by test_serial_partial_steps above.
+            td, tdreset = penv.step_and_maybe_reset(td)
+            assert (td[0].get("next") == 0).all()
+            assert (td[1].get("next") != 0).any()
+            assert (td[2].get("next") == 0).all()
+            assert (td[3].get("next") != 0).any()
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + 
unknown) diff --git a/test/test_exploration.py b/test/test_exploration.py index 83ee4bc4220..f6a3ab7041b 100644 --- a/test/test_exploration.py +++ b/test/test_exploration.py @@ -21,12 +21,7 @@ from torchrl._utils import _replace_last from torchrl.collectors import SyncDataCollector -from torchrl.data import ( - BoundedTensorSpec, - CompositeSpec, - DiscreteTensorSpec, - OneHotDiscreteTensorSpec, -) +from torchrl.data import Bounded, Categorical, Composite, OneHot from torchrl.envs import SerialEnv from torchrl.envs.transforms.transforms import gSDENoise, InitTracker, TransformedEnv from torchrl.envs.utils import set_exploration_type @@ -36,7 +31,7 @@ NormalParamExtractor, TanhNormal, ) -from torchrl.modules.models.exploration import LazygSDEModule +from torchrl.modules.models.exploration import ConsistentDropoutModule, LazygSDEModule from torchrl.modules.tensordict_module.actors import ( Actor, ProbabilisticActor, @@ -59,7 +54,7 @@ class TestEGreedy: @set_exploration_type(InteractionType.RANDOM) def test_egreedy(self, eps_init, module): torch.manual_seed(0) - spec = BoundedTensorSpec(1, 1, torch.Size([4])) + spec = Bounded(1, 1, torch.Size([4])) module = torch.nn.Linear(4, 4, bias=False) policy = Actor(spec=spec, module=module) @@ -91,9 +86,9 @@ def test_egreedy_masked(self, module, eps_init, spec_class): batch_size = (3, 4, 2) module = torch.nn.Linear(action_size, action_size, bias=False) if spec_class == "discrete": - spec = DiscreteTensorSpec(action_size) + spec = Categorical(action_size) else: - spec = OneHotDiscreteTensorSpec( + spec = OneHot( action_size, shape=(action_size,), ) @@ -166,7 +161,7 @@ def test_no_spec_error( action_size = 4 batch_size = (3, 4, 2) module = torch.nn.Linear(action_size, action_size, bias=False) - spec = OneHotDiscreteTensorSpec(action_size, shape=(action_size,)) + spec = OneHot(action_size, shape=(action_size,)) policy = QValueActor(spec=spec, module=module) explorative_policy = TensorDictSequential( policy, @@ -187,7 +182,7 @@ def test_no_spec_error( @pytest.mark.parametrize("module", [True, False]) def test_wrong_action_shape(self, module): torch.manual_seed(0) - spec = BoundedTensorSpec(1, 1, torch.Size([4])) + spec = Bounded(1, 1, torch.Size([4])) module = torch.nn.Linear(4, 5, bias=False) policy = Actor(spec=spec, module=module) @@ -240,7 +235,7 @@ def test_ou( device ) module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) - action_spec = BoundedTensorSpec(-torch.ones(d_act), torch.ones(d_act), (d_act,)) + action_spec = Bounded(-torch.ones(d_act), torch.ones(d_act), (d_act,)) policy = ProbabilisticActor( spec=action_spec, module=module, @@ -444,7 +439,7 @@ def test_additivegaussian_sd( pytest.skip("module raises an error if given spec=None") torch.manual_seed(seed) - action_spec = BoundedTensorSpec( + action_spec = Bounded( -torch.ones(d_act, device=device), torch.ones(d_act, device=device), (d_act,), @@ -463,9 +458,7 @@ def test_additivegaussian_sd( spec=None, ) policy = ProbabilisticActor( - spec=CompositeSpec(action=action_spec) - if spec_origin is not None - else None, + spec=Composite(action=action_spec) if spec_origin is not None else None, module=module, in_keys=["loc", "scale"], distribution_class=TanhNormal, @@ -541,7 +534,7 @@ def test_additivegaussian( device ) module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) - action_spec = BoundedTensorSpec( + action_spec = Bounded( -torch.ones(d_act, device=device), torch.ones(d_act, device=device), (d_act,), @@ -644,7 +637,7 @@ def 
test_no_spec_error(self, device):
 @pytest.mark.parametrize("safe", [True, False])
 @pytest.mark.parametrize("device", get_default_devices())
 @pytest.mark.parametrize(
-    "exploration_type", [InteractionType.RANDOM, InteractionType.MODE]
+    "exploration_type", [InteractionType.RANDOM, InteractionType.DETERMINISTIC]
 )
 def test_gsde(
     state_dim, action_dim, gSDE, device, safe, exploration_type, batch=16, bound=0.1
@@ -670,7 +663,7 @@ def test_gsde(
     module = SafeModule(wrapper, in_keys=in_keys, out_keys=["loc", "scale"])
     distribution_class = TanhNormal
     distribution_kwargs = {"low": -bound, "high": bound}
-    spec = BoundedTensorSpec(
+    spec = Bounded(
         -torch.ones(action_dim) * bound, torch.ones(action_dim) * bound, (action_dim,)
     ).to(device)
@@ -708,7 +701,10 @@ def test_gsde(
     with set_exploration_type(exploration_type):
         action1 = module(td).get("action")
         action2 = actor(td.exclude("action")).get("action")
-        if gSDE or exploration_type == InteractionType.MODE:
+        if gSDE or exploration_type in (
+            InteractionType.DETERMINISTIC,
+            InteractionType.MODE,
+        ):
             torch.testing.assert_close(action1, action2)
         else:
             with pytest.raises(AssertionError):
@@ -742,6 +738,156 @@ def test_gsde_init(sigma_init, state_dim, action_dim, mean, std, device, learn_s
     ), f"failed: mean={mean}, std={std}, sigma_init={sigma_init}, actual: {sigma.mean()}"
 
 
+class TestConsistentDropout:
+    @pytest.mark.parametrize("dropout_p", [0.0, 0.1, 0.5])
+    @pytest.mark.parametrize("parallel_spec", [False, True])
+    @pytest.mark.parametrize("device", get_default_devices())
+    def test_consistent_dropout(self, dropout_p, parallel_spec, device):
+        """Check that dropout masks are generated and applied consistently.
+
+        This preliminary test ensures two things for both
+        ConsistentDropout and ConsistentDropoutModule:
+        1. Rollout transitions generate a dropout mask as desired,
+           which we verify by checking that exactly one mask exists.
+        2. The dropout mask is applied consistently: with a stochastic
+           policy, we check that loc and scale stay the same when the
+           module is re-applied with the stored mask.
+        """
+        torch.manual_seed(0)
+
+        # NOTE: use a module with exactly one dropout layer;
+        # the test is built around that assumption.
+        @torch.no_grad
+        def inner_verify_routine(module, env):
+            # Perform transitions.
+            collector = SyncDataCollector(
+                create_env_fn=env,
+                policy=module,
+                frames_per_batch=1,
+                total_frames=10,
+                device=device,
+            )
+            for frames in collector:
+                masks = [
+                    (key, value)
+                    for key, value in frames.items()
+                    if key.startswith("mask_")
+                ]
+                # Check that rollouts do indeed generate the masks.
+                assert len(masks) == 1, (
+                    "Expected exactly ONE mask since we only put "
+                    f"one dropout module, got {len(masks)}."
+                )
+
+                # Verify that re-running the module on this batch reproduces
+                # the result (a Monte Carlo-style spot check).
+                sentinel_mask = masks[0][1].clone()
+                sentinel_outputs = frames.select("loc", "scale").clone()
+
+                desired_dropout_mask = torch.full_like(
+                    sentinel_mask, 1 / (1 - dropout_p)
+                )
+                desired_dropout_mask[sentinel_mask == 0.0] = 0.0
+                # As of 15/08/24, :meth:`~torch.nn.functional.dropout`
+                # is used under the hood; checking the scaling explicitly
+                # never hurts.
+                assert torch.allclose(
+                    sentinel_mask, desired_dropout_mask
+                ), "Dropout was not scaled properly."
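+                # Inverted dropout scales the kept activations by 1 / (1 - p) at
+                # training time, so a correct mask contains only 0 and 1 / (1 - p).
+                # Re-applying the module below must reuse the stored mask and
+                # therefore reproduce loc and scale exactly.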
+
+                new_frames = module(frames.clone())
+                infer_mask = new_frames[masks[0][0]]
+                infer_outputs = new_frames.select("loc", "scale")
+                assert (infer_mask == sentinel_mask).all(), "Mask does not match"
+
+                assert all(
+                    [
+                        torch.allclose(infer_outputs[key], sentinel_outputs[key])
+                        for key in ("loc", "scale")
+                    ]
+                ), (
+                    "Outputs do not match:\n "
+                    f"{infer_outputs['loc']}\n--- vs ---\n{sentinel_outputs['loc']}"
+                    f"{infer_outputs['scale']}\n--- vs ---\n{sentinel_outputs['scale']}"
+                )
+
+        env = SerialEnv(
+            2,
+            ContinuousActionVecMockEnv,
+        )
+        env = TransformedEnv(env.to(device), InitTracker())
+        env = env.to(device)
+        # The module must work with the action spec of a single env or of a SerialEnv.
+        if parallel_spec:
+            action_spec = env.action_spec
+        else:
+            action_spec = ContinuousActionVecMockEnv(device=device).action_spec
+        d_act = action_spec.shape[-1]
+
+        # NOTE: use a module with exactly one dropout layer;
+        # the test is built around that assumption.
+        module_td_seq = TensorDictSequential(
+            TensorDictModule(
+                nn.LazyLinear(2 * d_act), in_keys=["observation"], out_keys=["out"]
+            ),
+            ConsistentDropoutModule(p=dropout_p, in_keys="out"),
+            TensorDictModule(
+                NormalParamExtractor(), in_keys=["out"], out_keys=["loc", "scale"]
+            ),
+        )
+
+        policy_td_seq = ProbabilisticActor(
+            module=module_td_seq,
+            in_keys=["loc", "scale"],
+            distribution_class=TanhNormal,
+            default_interaction_type=InteractionType.RANDOM,
+            spec=action_spec,
+        ).to(device)
+
+        # Initialize the lazy modules with a dummy forward pass.
+        policy_td_seq(env.reset())
+
+        # Run the verification routine.
+        inner_verify_routine(policy_td_seq, env)
+
+    def test_consistent_dropout_primer(self):
+        import torch
+
+        from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq
+        from torchrl.envs import SerialEnv, StepCounter
+        from torchrl.modules import ConsistentDropoutModule, get_primers_from_module
+
+        torch.manual_seed(0)
+
+        m = Seq(
+            Mod(
+                torch.nn.Linear(7, 4),
+                in_keys=["observation"],
+                out_keys=["intermediate"],
+            ),
+            ConsistentDropoutModule(
+                p=0.5,
+                input_shape=(
+                    2,
+                    4,
+                ),
+                in_keys="intermediate",
+            ),
+            Mod(torch.nn.Linear(4, 7), in_keys=["intermediate"], out_keys=["action"]),
+        )
+        primer = get_primers_from_module(m)
+        env0 = ContinuousActionVecMockEnv().append_transform(StepCounter(5))
+        env1 = ContinuousActionVecMockEnv().append_transform(StepCounter(6))
+        env = SerialEnv(2, [lambda env=env0: env, lambda env=env1: env])
+        env = env.append_transform(primer)
+        r = env.rollout(10, m, break_when_any_done=False)
+        mask = [k for k in r.keys() if k.startswith("mask")][0]
+        # The mask is refreshed on reset and held constant within an episode:
+        # env0 resets every 5 steps (StepCounter(5)), env1 every 6.
+        assert (r[mask][0, :5] != r[mask][0, 5:6]).any()
+        assert (r[mask][0, :4] == r[mask][0, 4:5]).all()
+
+        assert (r[mask][1, :6] != r[mask][1, 6:7]).any()
+        assert (r[mask][1, :5] == r[mask][1, 5:6]).all()
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
diff --git a/test/test_helpers.py b/test/test_helpers.py
index f468eddf6ed..cf28252a318 100644
--- a/test/test_helpers.py
+++ b/test/test_helpers.py
@@ -30,7 +30,7 @@
     MockSerialEnv,
 )
 from packaging import version
-from torchrl.data import BoundedTensorSpec, CompositeSpec
+from torchrl.data import Bounded, Composite
 from torchrl.envs.libs.gym import _has_gym
 from torchrl.envs.transforms import ObservationNorm
 from torchrl.envs.transforms.transforms import (
@@ -259,17 +259,14 @@ def test_transformed_env_constructor_with_state_dict(from_pixels):
 def test_initialize_stats_from_observation_norms(device, keys, composed, 
initialized): obs_spec, stat_key = None, None if keys: - obs_spec = CompositeSpec( - **{ - key: BoundedTensorSpec(high=1, low=1, shape=torch.Size([1])) - for key in keys - } + obs_spec = Composite( + **{key: Bounded(high=1, low=1, shape=torch.Size([1])) for key in keys} ) stat_key = keys[0] env = ContinuousActionVecMockEnv( device=device, observation_spec=obs_spec, - action_spec=BoundedTensorSpec(low=1, high=2, shape=torch.Size((1,))), + action_spec=Bounded(low=1, high=2, shape=torch.Size((1,))), ) env.out_key = "observation" else: diff --git a/test/test_libs.py b/test/test_libs.py index 6ccbf2788a9..3d04648fd4e 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -5,6 +5,7 @@ import functools import gc import importlib.util +import urllib.error _has_isaac = importlib.util.find_spec("isaacgym") is not None @@ -18,10 +19,12 @@ import os import time +import urllib from contextlib import nullcontext from pathlib import Path from sys import platform from typing import Optional, Union +from unittest import mock import numpy as np import pytest @@ -36,6 +39,7 @@ PENDULUM_VERSIONED, PONG_VERSIONED, rand_reset, + retry, rollout_consistency_assertion, ) from packaging import version @@ -55,16 +59,16 @@ from torchrl._utils import implement_for, logger as torchrl_logger from torchrl.collectors.collectors import SyncDataCollector from torchrl.data import ( - BinaryDiscreteTensorSpec, - BoundedTensorSpec, - CompositeSpec, - DiscreteTensorSpec, - MultiDiscreteTensorSpec, - MultiOneHotDiscreteTensorSpec, - OneHotDiscreteTensorSpec, + Binary, + Bounded, + Categorical, + Composite, + MultiCategorical, + MultiOneHot, + OneHot, ReplayBuffer, ReplayBufferEnsemble, - UnboundedContinuousTensorSpec, + Unbounded, UnboundedDiscreteTensorSpec, ) from torchrl.data.datasets.atari_dqn import AtariDQNExperienceReplay @@ -107,9 +111,15 @@ from torchrl.envs.libs.jumanji import _has_jumanji, JumanjiEnv from torchrl.envs.libs.meltingpot import MeltingpotEnv, MeltingpotWrapper from torchrl.envs.libs.openml import OpenMLEnv +from torchrl.envs.libs.openspiel import _has_pyspiel, OpenSpielEnv, OpenSpielWrapper from torchrl.envs.libs.pettingzoo import _has_pettingzoo, PettingZooEnv from torchrl.envs.libs.robohive import _has_robohive, RoboHiveEnv from torchrl.envs.libs.smacv2 import _has_smacv2, SMACv2Env +from torchrl.envs.libs.unity_mlagents import ( + _has_unity_mlagents, + UnityMLAgentsEnv, + UnityMLAgentsWrapper, +) from torchrl.envs.libs.vmas import _has_vmas, VmasEnv, VmasWrapper from torchrl.envs.transforms import ActionMask, TransformedEnv @@ -152,6 +162,16 @@ _has_meltingpot = importlib.util.find_spec("meltingpot") is not None +_has_minigrid = importlib.util.find_spec("minigrid") is not None + + +@pytest.fixture(scope="session", autouse=True) +def maybe_init_minigrid(): + if _has_minigrid and _has_gymnasium: + import minigrid + + minigrid.register_minigrid_envs() + def get_gym_pixel_wrapper(): try: @@ -206,18 +226,16 @@ def __init__(self, arg1, *, arg2, **kwargs): assert arg1 == 1 assert arg2 == 2 - self.observation_spec = CompositeSpec( - observation=UnboundedContinuousTensorSpec((*self.batch_size, 3)), - other=CompositeSpec( - another_other=UnboundedContinuousTensorSpec((*self.batch_size, 3)), + self.observation_spec = Composite( + observation=Unbounded((*self.batch_size, 3)), + other=Composite( + another_other=Unbounded((*self.batch_size, 3)), shape=self.batch_size, ), shape=self.batch_size, ) - self.action_spec = UnboundedContinuousTensorSpec((*self.batch_size, 3)) - self.done_spec = DiscreteTensorSpec( - 2, 
(*self.batch_size, 1), dtype=torch.bool - ) + self.action_spec = Unbounded((*self.batch_size, 3)) + self.done_spec = Categorical(2, (*self.batch_size, 1), dtype=torch.bool) self.full_done_spec["truncated"] = self.full_done_spec["terminated"].clone() def _reset(self, tensordict): @@ -242,16 +260,14 @@ def _set_seed(self, seed): @implement_for("gym", None, "0.18") def _make_spec(self, batch_size, cat, cat_shape, multicat, multicat_shape): - return CompositeSpec( - a=UnboundedContinuousTensorSpec(shape=(*batch_size, 1)), - b=CompositeSpec( - c=cat(5, shape=cat_shape, dtype=torch.int64), shape=batch_size - ), + return Composite( + a=Unbounded(shape=(*batch_size, 1)), + b=Composite(c=cat(5, shape=cat_shape, dtype=torch.int64), shape=batch_size), d=cat(5, shape=cat_shape, dtype=torch.int64), e=multicat([2, 3], shape=(*batch_size, multicat_shape), dtype=torch.int64), - f=BoundedTensorSpec(-3, 4, shape=(*batch_size, 1)), + f=Bounded(-3, 4, shape=(*batch_size, 1)), # g=UnboundedDiscreteTensorSpec(shape=(*batch_size, 1), dtype=torch.long), - h=BinaryDiscreteTensorSpec(n=5, shape=(*batch_size, 5)), + h=Binary(n=5, shape=(*batch_size, 5)), shape=batch_size, ) @@ -259,44 +275,38 @@ def _make_spec(self, batch_size, cat, cat_shape, multicat, multicat_shape): def _make_spec( # noqa: F811 self, batch_size, cat, cat_shape, multicat, multicat_shape ): - return CompositeSpec( - a=UnboundedContinuousTensorSpec(shape=(*batch_size, 1)), - b=CompositeSpec( - c=cat(5, shape=cat_shape, dtype=torch.int64), shape=batch_size - ), + return Composite( + a=Unbounded(shape=(*batch_size, 1)), + b=Composite(c=cat(5, shape=cat_shape, dtype=torch.int64), shape=batch_size), d=cat(5, shape=cat_shape, dtype=torch.int64), e=multicat([2, 3], shape=(*batch_size, multicat_shape), dtype=torch.int64), - f=BoundedTensorSpec(-3, 4, shape=(*batch_size, 1)), + f=Bounded(-3, 4, shape=(*batch_size, 1)), g=UnboundedDiscreteTensorSpec(shape=(*batch_size, 1), dtype=torch.long), - h=BinaryDiscreteTensorSpec(n=5, shape=(*batch_size, 5)), + h=Binary(n=5, shape=(*batch_size, 5)), shape=batch_size, ) - @implement_for("gymnasium") + @implement_for("gymnasium", None, "1.0.0") def _make_spec( # noqa: F811 self, batch_size, cat, cat_shape, multicat, multicat_shape ): - return CompositeSpec( - a=UnboundedContinuousTensorSpec(shape=(*batch_size, 1)), - b=CompositeSpec( - c=cat(5, shape=cat_shape, dtype=torch.int64), shape=batch_size - ), + return Composite( + a=Unbounded(shape=(*batch_size, 1)), + b=Composite(c=cat(5, shape=cat_shape, dtype=torch.int64), shape=batch_size), d=cat(5, shape=cat_shape, dtype=torch.int64), e=multicat([2, 3], shape=(*batch_size, multicat_shape), dtype=torch.int64), - f=BoundedTensorSpec(-3, 4, shape=(*batch_size, 1)), + f=Bounded(-3, 4, shape=(*batch_size, 1)), g=UnboundedDiscreteTensorSpec(shape=(*batch_size, 1), dtype=torch.long), - h=BinaryDiscreteTensorSpec(n=5, shape=(*batch_size, 5)), + h=Binary(n=5, shape=(*batch_size, 5)), shape=batch_size, ) @pytest.mark.parametrize("categorical", [True, False]) def test_gym_spec_cast(self, categorical): batch_size = [3, 4] - cat = DiscreteTensorSpec if categorical else OneHotDiscreteTensorSpec + cat = Categorical if categorical else OneHot cat_shape = batch_size if categorical else (*batch_size, 5) - multicat = ( - MultiDiscreteTensorSpec if categorical else MultiOneHotDiscreteTensorSpec - ) + multicat = MultiCategorical if categorical else MultiOneHot multicat_shape = 2 if categorical else 5 spec = self._make_spec(batch_size, cat, cat_shape, multicat, multicat_shape) recon = 
_gym_to_torchrl_spec_transform( @@ -321,7 +331,7 @@ def test_gym_spec_cast_tuple_sequential(self, order): # @pytest.mark.parametrize("order", ["seq_tuple", "tuple_seq"]) @pytest.mark.parametrize("order", ["tuple_seq"]) - @implement_for("gymnasium") + @implement_for("gymnasium", None, "1.0.0") def test_gym_spec_cast_tuple_sequential(self, order): # noqa: F811 with set_gym_backend("gymnasium"): if order == "seq_tuple": @@ -837,7 +847,7 @@ def info_reader(info, tensordict): finally: set_gym_backend(gb).set() - @implement_for("gymnasium") + @implement_for("gymnasium", None, "1.0.0") def test_one_hot_and_categorical(self): # tests that one-hot and categorical work ok when an integer is expected as action cliff_walking = GymEnv("CliffWalking-v0", categorical_action_encoding=True) @@ -856,7 +866,7 @@ def test_one_hot_and_categorical(self): # noqa: F811 # versions. return - @implement_for("gymnasium") + @implement_for("gymnasium", None, "1.0.0") @pytest.mark.parametrize( "envname", ["HalfCheetah-v4", "CartPole-v1", "ALE/Pong-v5"] @@ -882,7 +892,7 @@ def test_vecenvs_wrapper(self, envname): assert env.batch_size == torch.Size([2]) check_env_specs(env) - @implement_for("gymnasium") + @implement_for("gymnasium", None, "1.0.0") # this env has Dict-based observation which is a nice thing to test @pytest.mark.parametrize( "envname", @@ -1044,7 +1054,7 @@ def test_gym_output_num(self, wrapper): # noqa: F811 finally: set_gym_backend(gym).set() - @implement_for("gymnasium") + @implement_for("gymnasium", None, "1.0.0") @pytest.mark.parametrize("wrapper", [True, False]) def test_gym_output_num(self, wrapper): # noqa: F811 # gym has 5 outputs, with truncation @@ -1147,7 +1157,7 @@ def test_vecenvs_nan(self): # noqa: F811 del c return - @implement_for("gymnasium") + @implement_for("gymnasium", None, "1.0.0") def test_vecenvs_nan(self): # noqa: F811 # new versions of gym must never return nan for next values when there is a done state torch.manual_seed(0) @@ -1288,6 +1298,24 @@ def test_resetting_strategies(self, heterogeneous): gc.collect() +@pytest.mark.skipif( + not _has_minigrid or not _has_gymnasium, reason="MiniGrid not found" +) +class TestMiniGrid: + @pytest.mark.parametrize( + "id", + [ + "BabyAI-KeyCorridorS6R3-v0", + "MiniGrid-Empty-16x16-v0", + "MiniGrid-BlockedUnlockPickup-v0", + ], + ) + def test_minigrid(self, id): + env_base = gymnasium.make(id) + env = GymWrapper(env_base) + check_env_specs(env) + + @implement_for("gym", None, "0.26") def _make_gym_environment(env_name): # noqa: F811 gym = gym_backend() @@ -1300,7 +1328,7 @@ def _make_gym_environment(env_name): # noqa: F811 return gym.make(env_name, render_mode="rgb_array") -@implement_for("gymnasium") +@implement_for("gymnasium", None, "1.0.0") def _make_gym_environment(env_name): # noqa: F811 gym = gym_backend() return gym.make(env_name, render_mode="rgb_array") @@ -2804,34 +2832,11 @@ def _minari_selected_datasets(): torch.manual_seed(0) - # We rely on sorting the keys as v0 < v1 but if the version is greater than 9 this won't work - total_keys = sorted(minari.list_remote_datasets()) - assert not any( - key[-2:] == "10" for key in total_keys - ), "You should adapt the Minari test scripts as some dataset have a version >= 10 and sorting will fail." 
- total_keys_splits = [key.split("-") for key in total_keys] + total_keys = sorted( + minari.list_remote_datasets(latest_version=True, compatible_minari_version=True) + ) indices = torch.randperm(len(total_keys))[:20] keys = [total_keys[idx] for idx in indices] - keys = [ - key - for key in keys - if "=0.4" in minari.list_remote_datasets()[key]["minari_version"] - ] - - def _replace_with_max(key): - key_split = key.split("-") - same_entries = ( - torch.tensor( - [total_key[:-1] == key_split[:-1] for total_key in total_keys_splits] - ) - .nonzero() - .squeeze() - .tolist() - ) - last_same_entry = same_entries[-1] - return total_keys[last_same_entry] - - keys = [_replace_with_max(key) for key in keys] assert len(keys) > 5, keys _MINARI_DATASETS += keys @@ -2861,12 +2866,8 @@ def test_load(self, selected_dataset, split): break def test_minari_preproc(self, tmpdir): - global _MINARI_DATASETS - if not _MINARI_DATASETS: - _minari_selected_datasets() - selected_dataset = _MINARI_DATASETS[0] dataset = MinariExperienceReplay( - selected_dataset, + "D4RL/pointmaze/large-v2", batch_size=32, split_trajs=False, download="force", @@ -3073,7 +3074,7 @@ def test_atari_preproc(self, dataset_id, tmpdir): t = Compose( UnsqueezeTransform( - unsqueeze_dim=-3, in_keys=["observation", ("next", "observation")] + dim=-3, in_keys=["observation", ("next", "observation")] ), Resize(32, in_keys=["observation", ("next", "observation")]), RenameTransform(in_keys=["action"], out_keys=["other_action"]), @@ -3692,7 +3693,7 @@ class TestRoboHive: # The other option would be not to use parametrize but that also # means less informative error trace stacks. # In the CI, robohive should not coexist with other libs so that's fine. - # Robohive logging behaviour can be controlled via ROBOHIVE_VERBOSITY=ALL/INFO/(WARN)/ERROR/ONCE/ALWAYS/SILENT + # Robohive logging behavior can be controlled via ROBOHIVE_VERBOSITY=ALL/INFO/(WARN)/ERROR/ONCE/ALWAYS/SILENT @pytest.mark.parametrize("from_pixels", [False, True]) @pytest.mark.parametrize("from_depths", [False, True]) @pytest.mark.parametrize("envname", RoboHiveEnv.available_envs) @@ -3812,6 +3813,270 @@ def test_collector(self): collector.shutdown() +# List of OpenSpiel games to test +# TODO: Some of the games in `OpenSpielWrapper.available_envs` raise errors for +# a few different reasons, mostly because we do not support chance nodes yet. So +# we cannot run tests on all of them yet. 
+_openspiel_games = [ + # ---------------- + # Sequential games + # 1-player + "morpion_solitaire", + # 2-player + "amazons", + "battleship", + "breakthrough", + "checkers", + "chess", + "cliff_walking", + "clobber", + "connect_four", + "cursor_go", + "dark_chess", + "dark_hex", + "dark_hex_ir", + "dots_and_boxes", + "go", + "havannah", + "hex", + "kriegspiel", + "mancala", + "nim", + "nine_mens_morris", + "othello", + "oware", + "pentago", + "phantom_go", + "phantom_ttt", + "phantom_ttt_ir", + "sheriff", + "tic_tac_toe", + "twixt", + "ultimate_tic_tac_toe", + "y", + # -------------- + # Parallel games + # 2-player + "blotto", + "matrix_bos", + "matrix_brps", + "matrix_cd", + "matrix_coordination", + "matrix_mp", + "matrix_pd", + "matrix_rps", + "matrix_rpsw", + "matrix_sh", + "matrix_shapleys_game", + "oshi_zumo", + # 3-player + "matching_pennies_3p", +] + + +@pytest.mark.skipif(not _has_pyspiel, reason="open_spiel not found") +class TestOpenSpiel: + @pytest.mark.parametrize("game_string", _openspiel_games) + @pytest.mark.parametrize("return_state", [False, True]) + @pytest.mark.parametrize("categorical_actions", [False, True]) + def test_all_envs(self, game_string, return_state, categorical_actions): + env = OpenSpielEnv( + game_string, + categorical_actions=categorical_actions, + return_state=return_state, + ) + check_env_specs(env) + + @pytest.mark.parametrize("game_string", _openspiel_games) + @pytest.mark.parametrize("return_state", [False, True]) + @pytest.mark.parametrize("categorical_actions", [False, True]) + def test_wrapper(self, game_string, return_state, categorical_actions): + import pyspiel + + base_env = pyspiel.load_game(game_string).new_initial_state() + env_torchrl = OpenSpielWrapper( + base_env, categorical_actions=categorical_actions, return_state=return_state + ) + env_torchrl.rollout(max_steps=5) + + @pytest.mark.parametrize("game_string", _openspiel_games) + @pytest.mark.parametrize("return_state", [False, True]) + @pytest.mark.parametrize("categorical_actions", [False, True]) + def test_reset_state(self, game_string, return_state, categorical_actions): + env = OpenSpielEnv( + game_string, + categorical_actions=categorical_actions, + return_state=return_state, + ) + td = env.reset() + td_init = td.clone() + + # Perform an action + td = env.step(env.full_action_spec.rand()) + + # Save the current td for reset + td_reset = td["next"].clone() + + # Perform a second action + td = env.step(env.full_action_spec.rand()) + + # Resetting to a specific state can only happen if `return_state` is + # enabled. Otherwise, it is reset to the initial state. + if return_state: + # Check that the state was reset to the specified state + td = env.reset(td_reset) + assert (td == td_reset).all() + else: + # Check that the state was reset to the initial state + td = env.reset() + assert (td == td_init).all() + + def test_chance_not_implemented(self): + with pytest.raises( + NotImplementedError, + match="not yet supported", + ): + OpenSpielEnv("bridge") + + +# NOTE: Each of the registered envs are around 180 MB, so only test a few. 
+_mlagents_registered_envs = [ + "3DBall", + "StrikersVsGoalie", +] + + +@pytest.mark.skipif(not _has_unity_mlagents, reason="mlagents_envs not found") +class TestUnityMLAgents: + @mock.patch("mlagents_envs.env_utils.launch_executable") + @mock.patch("mlagents_envs.environment.UnityEnvironment._get_communicator") + @pytest.mark.parametrize( + "group_map", + [None, MarlGroupMapType.ONE_GROUP_PER_AGENT, MarlGroupMapType.ALL_IN_ONE_GROUP], + ) + def test_env(self, mock_communicator, mock_launcher, group_map): + from mlagents_envs.mock_communicator import MockCommunicator + + mock_communicator.return_value = MockCommunicator( + discrete_action=False, visual_inputs=0 + ) + env = UnityMLAgentsEnv(" ", group_map=group_map) + try: + check_env_specs(env) + finally: + env.close() + + @mock.patch("mlagents_envs.env_utils.launch_executable") + @mock.patch("mlagents_envs.environment.UnityEnvironment._get_communicator") + @pytest.mark.parametrize( + "group_map", + [None, MarlGroupMapType.ONE_GROUP_PER_AGENT, MarlGroupMapType.ALL_IN_ONE_GROUP], + ) + def test_wrapper(self, mock_communicator, mock_launcher, group_map): + from mlagents_envs.environment import UnityEnvironment + from mlagents_envs.mock_communicator import MockCommunicator + + mock_communicator.return_value = MockCommunicator( + discrete_action=False, visual_inputs=0 + ) + env = UnityMLAgentsWrapper(UnityEnvironment(" "), group_map=group_map) + try: + check_env_specs(env) + finally: + env.close() + + @mock.patch("mlagents_envs.env_utils.launch_executable") + @mock.patch("mlagents_envs.environment.UnityEnvironment._get_communicator") + @pytest.mark.parametrize( + "group_map", + [None, MarlGroupMapType.ONE_GROUP_PER_AGENT, MarlGroupMapType.ALL_IN_ONE_GROUP], + ) + def test_rollout(self, mock_communicator, mock_launcher, group_map): + from mlagents_envs.environment import UnityEnvironment + from mlagents_envs.mock_communicator import MockCommunicator + + mock_communicator.return_value = MockCommunicator( + discrete_action=False, visual_inputs=0 + ) + env = UnityMLAgentsWrapper(UnityEnvironment(" "), group_map=group_map) + try: + env.rollout( + max_steps=500, break_when_any_done=False, break_when_all_done=False + ) + finally: + env.close() + + @pytest.mark.unity_editor + def test_with_editor(self): + print("Please press play in the Unity editor") # noqa: T201 + env = UnityMLAgentsEnv(timeout_wait=30) + try: + env.reset() + check_env_specs(env) + + # Perform a rollout + td = env.reset() + env.rollout( + max_steps=100, break_when_any_done=False, break_when_all_done=False + ) + + # Step manually + tensordicts = [] + td = env.reset() + tensordicts.append(td) + traj_len = 200 + for _ in range(traj_len - 1): + td = env.step(td.update(env.full_action_spec.rand())) + tensordicts.append(td) + + traj = torch.stack(tensordicts) + assert traj.batch_size == torch.Size([traj_len]) + finally: + env.close() + + @retry( + ( + urllib.error.HTTPError, + urllib.error.URLError, + urllib.error.ContentTooShortError, + ), + 5, + ) + @pytest.mark.parametrize("registered_name", _mlagents_registered_envs) + @pytest.mark.parametrize( + "group_map", + [None, MarlGroupMapType.ONE_GROUP_PER_AGENT, MarlGroupMapType.ALL_IN_ONE_GROUP], + ) + def test_registered_envs(self, registered_name, group_map): + env = UnityMLAgentsEnv( + registered_name=registered_name, + no_graphics=True, + group_map=group_map, + ) + try: + check_env_specs(env) + + # Perform a rollout + td = env.reset() + env.rollout( + max_steps=20, break_when_any_done=False, break_when_all_done=False + ) + + # Step 
manually + tensordicts = [] + td = env.reset() + tensordicts.append(td) + traj_len = 20 + for _ in range(traj_len - 1): + td = env.step(td.update(env.full_action_spec.rand())) + tensordicts.append(td) + + traj = torch.stack(tensordicts) + assert traj.batch_size == torch.Size([traj_len]) + finally: + env.close() + + @pytest.mark.skipif(not _has_meltingpot, reason="Meltingpot not found") class TestMeltingpot: @pytest.mark.parametrize("substrate", MeltingpotWrapper.available_envs) diff --git a/test/test_loggers.py b/test/test_loggers.py index 735911bd95c..eb40ca1fdb8 100644 --- a/test/test_loggers.py +++ b/test/test_loggers.py @@ -281,25 +281,27 @@ def test_log_video(self, wandb_logger): # C - number of image channels (e.g. 3 for RGB), H, W - image dimensions. # the first 64 frames are black and the next 64 are white video = torch.cat( - (torch.zeros(64, 1, 32, 32), torch.full((64, 1, 32, 32), 255)) + (torch.zeros(128, 1, 32, 32), torch.full((128, 1, 32, 32), 255)) ) video = video[None, :] wandb_logger.log_video( name="foo", video=video, - fps=6, + fps=4, + format="mp4", ) wandb_logger.log_video( - name="foo_12fps", + name="foo_16fps", video=video, - fps=24, + fps=16, + format="mp4", ) sleep(0.01) # wait until events are registered # check that fps can be passed and that it has impact on the length of the video - video_6fps_size = wandb_logger.experiment.summary["foo"]["size"] - video_24fps_size = wandb_logger.experiment.summary["foo_12fps"]["size"] - assert video_6fps_size > video_24fps_size, video_6fps_size + video_4fps_size = wandb_logger.experiment.summary["foo"]["size"] + video_16fps_size = wandb_logger.experiment.summary["foo_16fps"]["size"] + assert video_4fps_size > video_16fps_size, (video_4fps_size, video_16fps_size) # check that we catch the error in case the format of the tensor is wrong video_wrong_format = torch.zeros(64, 2, 32, 32) diff --git a/test/test_modules.py b/test/test_modules.py index 00e58678788..8966b61154c 100644 --- a/test/test_modules.py +++ b/test/test_modules.py @@ -16,7 +16,7 @@ from packaging import version from tensordict import TensorDict from torch import nn -from torchrl.data.tensor_specs import BoundedTensorSpec, CompositeSpec +from torchrl.data.tensor_specs import Bounded, Composite from torchrl.modules import ( CEMPlanner, DTActor, @@ -466,9 +466,7 @@ def test_dreamer_decoder( @pytest.mark.parametrize("deter_size", [20, 30]) @pytest.mark.parametrize("action_size", [3, 6]) def test_rssm_prior(self, device, batch_size, stoch_size, deter_size, action_size): - action_spec = BoundedTensorSpec( - shape=(action_size,), dtype=torch.float32, low=-1, high=1 - ) + action_spec = Bounded(shape=(action_size,), dtype=torch.float32, low=-1, high=1) rssm_prior = RSSMPrior( action_spec, hidden_dim=stoch_size, @@ -521,9 +519,7 @@ def test_rssm_posterior(self, device, batch_size, stoch_size, deter_size): def test_rssm_rollout( self, device, batch_size, temporal_size, stoch_size, deter_size, action_size ): - action_spec = BoundedTensorSpec( - shape=(action_size,), dtype=torch.float32, low=-1, high=1 - ) + action_spec = Bounded(shape=(action_size,), dtype=torch.float32, low=-1, high=1) rssm_prior = RSSMPrior( action_spec, hidden_dim=stoch_size, @@ -650,10 +646,10 @@ def test_errors(self): ): TanhModule(in_keys=["a", "b"], out_keys=["a"]) with pytest.raises(ValueError, match=r"The minimum value \(-2\) provided"): - spec = BoundedTensorSpec(-1, 1, shape=()) + spec = Bounded(-1, 1, shape=()) TanhModule(in_keys=["act"], low=-2, spec=spec) with pytest.raises(ValueError, 
match=r"The maximum value \(-2\) provided to"): - spec = BoundedTensorSpec(-1, 1, shape=()) + spec = Bounded(-1, 1, shape=()) TanhModule(in_keys=["act"], high=-2, spec=spec) with pytest.raises(ValueError, match="Got high < low"): TanhModule(in_keys=["act"], high=-2, low=-1) @@ -709,12 +705,12 @@ def test_multi_inputs(self, out_keys, has_spec): if any(has_spec): spec = {} if has_spec[0]: - spec.update({real_out_keys[0]: BoundedTensorSpec(-2.0, 2.0, shape=())}) + spec.update({real_out_keys[0]: Bounded(-2.0, 2.0, shape=())}) low, high = -2.0, 2.0 if has_spec[1]: - spec.update({real_out_keys[1]: BoundedTensorSpec(-3.0, 3.0, shape=())}) + spec.update({real_out_keys[1]: Bounded(-3.0, 3.0, shape=())}) low, high = None, None - spec = CompositeSpec(spec) + spec = Composite(spec) else: spec = None low, high = -2.0, 2.0 diff --git a/test/test_rb.py b/test/test_rb.py index e17cd410c49..24b33f89795 100644 --- a/test/test_rb.py +++ b/test/test_rb.py @@ -26,6 +26,7 @@ assert_allclose_td, is_tensor_collection, is_tensorclass, + LazyStackedTensorDict, tensorclass, TensorDict, TensorDictBase, @@ -58,6 +59,11 @@ SliceSampler, SliceSamplerWithoutReplacement, ) +from torchrl.data.replay_buffers.scheduler import ( + LinearScheduler, + SchedulerList, + StepScheduler, +) from torchrl.data.replay_buffers.storages import ( LazyMemmapStorage, @@ -99,6 +105,7 @@ VecNorm, ) + OLD_TORCH = parse(torch.__version__) < parse("2.0.0") _has_tv = importlib.util.find_spec("torchvision") is not None _has_gym = importlib.util.find_spec("gym") is not None @@ -109,6 +116,11 @@ ".".join([str(s) for s in version.parse(str(torch.__version__)).release]) ) >= version.parse("2.3.0") +ReplayBufferRNG = functools.partial(ReplayBuffer, generator=torch.Generator()) +TensorDictReplayBufferRNG = functools.partial( + TensorDictReplayBuffer, generator=torch.Generator() +) + @pytest.mark.parametrize( "sampler", @@ -125,17 +137,27 @@ "rb_type,storage,datatype", [ [ReplayBuffer, ListStorage, None], + [ReplayBufferRNG, ListStorage, None], [TensorDictReplayBuffer, ListStorage, "tensordict"], + [TensorDictReplayBufferRNG, ListStorage, "tensordict"], [RemoteTensorDictReplayBuffer, ListStorage, "tensordict"], [ReplayBuffer, LazyTensorStorage, "tensor"], [ReplayBuffer, LazyTensorStorage, "tensordict"], [ReplayBuffer, LazyTensorStorage, "pytree"], + [ReplayBufferRNG, LazyTensorStorage, "tensor"], + [ReplayBufferRNG, LazyTensorStorage, "tensordict"], + [ReplayBufferRNG, LazyTensorStorage, "pytree"], [TensorDictReplayBuffer, LazyTensorStorage, "tensordict"], + [TensorDictReplayBufferRNG, LazyTensorStorage, "tensordict"], [RemoteTensorDictReplayBuffer, LazyTensorStorage, "tensordict"], [ReplayBuffer, LazyMemmapStorage, "tensor"], [ReplayBuffer, LazyMemmapStorage, "tensordict"], [ReplayBuffer, LazyMemmapStorage, "pytree"], + [ReplayBufferRNG, LazyMemmapStorage, "tensor"], + [ReplayBufferRNG, LazyMemmapStorage, "tensordict"], + [ReplayBufferRNG, LazyMemmapStorage, "pytree"], [TensorDictReplayBuffer, LazyMemmapStorage, "tensordict"], + [TensorDictReplayBufferRNG, LazyMemmapStorage, "tensordict"], [RemoteTensorDictReplayBuffer, LazyMemmapStorage, "tensordict"], ], ) @@ -531,6 +553,20 @@ def test_errors(self, storage_type): ): storage_type(data, max_size=4) + def test_existsok_lazymemmap(self, tmpdir): + storage0 = LazyMemmapStorage(10, scratch_dir=tmpdir) + rb = ReplayBuffer(storage=storage0) + rb.extend(TensorDict(a=torch.randn(3), batch_size=[3])) + + storage1 = LazyMemmapStorage(10, scratch_dir=tmpdir) + rb = ReplayBuffer(storage=storage1) + with 
pytest.raises(RuntimeError, match="existsok"): + rb.extend(TensorDict(a=torch.randn(3), batch_size=[3])) + + storage2 = LazyMemmapStorage(10, scratch_dir=tmpdir, existsok=True) + rb = ReplayBuffer(storage=storage2) + rb.extend(TensorDict(a=torch.randn(3), batch_size=[3])) + @pytest.mark.parametrize( "data_type", ["tensor", "tensordict", "tensorclass", "pytree"] ) @@ -686,6 +722,20 @@ def test_storage_state_dict(self, storage_in, storage_out, init_out, backend): s = new_replay_buffer.sample() assert (s.exclude("index") == 1).all() + @pytest.mark.parametrize("storage_type", [LazyMemmapStorage, LazyTensorStorage]) + def test_extend_lazystack(self, storage_type): + + rb = ReplayBuffer( + storage=storage_type(6), + batch_size=2, + ) + td1 = TensorDict(a=torch.rand(5, 4, 8), batch_size=5) + td2 = TensorDict(a=torch.rand(5, 3, 8), batch_size=5) + ltd = LazyStackedTensorDict(td1, td2, stack_dim=1) + rb.extend(ltd) + rb.sample(3) + assert len(rb) == 5 + @pytest.mark.parametrize("device_data", get_default_devices()) @pytest.mark.parametrize("storage_type", [LazyMemmapStorage, LazyTensorStorage]) @pytest.mark.parametrize("data_type", ["tensor", "tc", "td", "pytree"]) @@ -1155,17 +1205,115 @@ def test_replay_buffer_trajectories(stack, reduction, datatype): # sampled_td_filtered.batch_size = [3, 4] +class TestRNG: + def test_rb_rng(self): + state = torch.random.get_rng_state() + rb = ReplayBufferRNG(sampler=RandomSampler(), storage=LazyTensorStorage(100)) + rb.extend(torch.arange(100)) + rb._rng.set_state(state) + a = rb.sample(32) + rb._rng.set_state(state) + b = rb.sample(32) + assert (a == b).all() + c = rb.sample(32) + assert (a != c).any() + + def test_prb_rng(self): + state = torch.random.get_rng_state() + rb = ReplayBuffer( + sampler=PrioritizedSampler(100, 1.0, 1.0), + storage=LazyTensorStorage(100), + generator=torch.Generator(), + ) + rb.extend(torch.arange(100)) + rb.update_priority(index=torch.arange(100), priority=torch.arange(1, 101)) + + rb._rng.set_state(state) + a = rb.sample(32) + + rb._rng.set_state(state) + b = rb.sample(32) + assert (a == b).all() + + c = rb.sample(32) + assert (a != c).any() + + def test_slice_rng(self): + state = torch.random.get_rng_state() + rb = ReplayBuffer( + sampler=SliceSampler(num_slices=4), + storage=LazyTensorStorage(100), + generator=torch.Generator(), + ) + done = torch.zeros(100, 1, dtype=torch.bool) + done[49] = 1 + done[-1] = 1 + data = TensorDict( + { + "data": torch.arange(100), + ("next", "done"): done, + }, + batch_size=[100], + ) + rb.extend(data) + + rb._rng.set_state(state) + a = rb.sample(32) + + rb._rng.set_state(state) + b = rb.sample(32) + assert (a == b).all() + + c = rb.sample(32) + assert (a != c).any() + + def test_rng_state_dict(self): + state = torch.random.get_rng_state() + rb = ReplayBufferRNG(sampler=RandomSampler(), storage=LazyTensorStorage(100)) + rb.extend(torch.arange(100)) + rb._rng.set_state(state) + sd = rb.state_dict() + assert sd.get("_rng") is not None + a = rb.sample(32) + + rb.load_state_dict(sd) + b = rb.sample(32) + assert (a == b).all() + c = rb.sample(32) + assert (a != c).any() + + def test_rng_dumps(self, tmpdir): + state = torch.random.get_rng_state() + rb = ReplayBufferRNG(sampler=RandomSampler(), storage=LazyTensorStorage(100)) + rb.extend(torch.arange(100)) + rb._rng.set_state(state) + rb.dumps(tmpdir) + a = rb.sample(32) + + rb.loads(tmpdir) + b = rb.sample(32) + assert (a == b).all() + c = rb.sample(32) + assert (a != c).any() + + @pytest.mark.parametrize( "rbtype,storage", [ (ReplayBuffer, None), 
(ReplayBuffer, ListStorage), + (ReplayBufferRNG, None), + (ReplayBufferRNG, ListStorage), (PrioritizedReplayBuffer, None), (PrioritizedReplayBuffer, ListStorage), (TensorDictReplayBuffer, None), (TensorDictReplayBuffer, ListStorage), (TensorDictReplayBuffer, LazyTensorStorage), (TensorDictReplayBuffer, LazyMemmapStorage), + (TensorDictReplayBufferRNG, None), + (TensorDictReplayBufferRNG, ListStorage), + (TensorDictReplayBufferRNG, LazyTensorStorage), + (TensorDictReplayBufferRNG, LazyMemmapStorage), (TensorDictPrioritizedReplayBuffer, None), (TensorDictPrioritizedReplayBuffer, ListStorage), (TensorDictPrioritizedReplayBuffer, LazyTensorStorage), @@ -1175,33 +1323,34 @@ def test_replay_buffer_trajectories(stack, reduction, datatype): @pytest.mark.parametrize("size", [3, 5, 100]) @pytest.mark.parametrize("prefetch", [0]) class TestBuffers: - _default_params_rb = {} - _default_params_td_rb = {} - _default_params_prb = {"alpha": 0.8, "beta": 0.9} - _default_params_td_prb = {"alpha": 0.8, "beta": 0.9} + + default_constr = { + ReplayBuffer: ReplayBuffer, + PrioritizedReplayBuffer: functools.partial( + PrioritizedReplayBuffer, alpha=0.8, beta=0.9 + ), + TensorDictReplayBuffer: TensorDictReplayBuffer, + TensorDictPrioritizedReplayBuffer: functools.partial( + TensorDictPrioritizedReplayBuffer, alpha=0.8, beta=0.9 + ), + TensorDictReplayBufferRNG: TensorDictReplayBufferRNG, + ReplayBufferRNG: ReplayBufferRNG, + } def _get_rb(self, rbtype, size, storage, prefetch): if storage is not None: storage = storage(size) - if rbtype is ReplayBuffer: - params = self._default_params_rb - elif rbtype is PrioritizedReplayBuffer: - params = self._default_params_prb - elif rbtype is TensorDictReplayBuffer: - params = self._default_params_td_rb - elif rbtype is TensorDictPrioritizedReplayBuffer: - params = self._default_params_td_prb - else: - raise NotImplementedError(rbtype) - rb = rbtype(storage=storage, prefetch=prefetch, batch_size=3, **params) + rb = self.default_constr[rbtype]( + storage=storage, prefetch=prefetch, batch_size=3 + ) return rb def _get_datum(self, rbtype): - if rbtype is ReplayBuffer: + if rbtype in (ReplayBuffer, ReplayBufferRNG): data = torch.randint(100, (1,)) elif rbtype is PrioritizedReplayBuffer: data = torch.randint(100, (1,)) - elif rbtype is TensorDictReplayBuffer: + elif rbtype in (TensorDictReplayBuffer, TensorDictReplayBufferRNG): data = TensorDict({"a": torch.randint(100, (1,))}, []) elif rbtype is TensorDictPrioritizedReplayBuffer: data = TensorDict({"a": torch.randint(100, (1,))}, []) @@ -1210,11 +1359,11 @@ def _get_datum(self, rbtype): return data def _get_data(self, rbtype, size): - if rbtype is ReplayBuffer: + if rbtype in (ReplayBuffer, ReplayBufferRNG): data = [torch.randint(100, (1,)) for _ in range(size)] elif rbtype is PrioritizedReplayBuffer: data = [torch.randint(100, (1,)) for _ in range(size)] - elif rbtype is TensorDictReplayBuffer: + elif rbtype in (TensorDictReplayBuffer, TensorDictReplayBufferRNG): data = TensorDict( { "a": torch.randint(100, (size,)), @@ -1627,10 +1776,8 @@ def test_insert_transform(self): not _has_tv, reason="needs torchvision dependency" ), ), - pytest.param( - partial(UnsqueezeTransform, unsqueeze_dim=-1), id="UnsqueezeTransform" - ), - pytest.param(partial(SqueezeTransform, squeeze_dim=-1), id="SqueezeTransform"), + pytest.param(partial(UnsqueezeTransform, dim=-1), id="UnsqueezeTransform"), + pytest.param(partial(SqueezeTransform, dim=-1), id="SqueezeTransform"), GrayScale, pytest.param(partial(ObservationNorm, loc=1, scale=2), 
id="ObservationNorm"), pytest.param(partial(CatFrames, dim=-3, N=4), id="CatFrames"), @@ -1950,13 +2097,16 @@ def exec_multiproc_rb( init=True, writer_type=TensorDictRoundRobinWriter, sampler_type=RandomSampler, + device=None, ): rb = TensorDictReplayBuffer( storage=storage_type(21), writer=writer_type(), sampler=sampler_type() ) if init: td = TensorDict( - {"a": torch.zeros(10), "next": {"reward": torch.ones(10)}}, [10] + {"a": torch.zeros(10), "next": {"reward": torch.ones(10)}}, + [10], + device=device, ) rb.extend(td) q0 = mp.Queue(1) @@ -1984,13 +2134,6 @@ def test_error_list(self): with pytest.raises(RuntimeError, match="Cannot share a storage of type"): self.exec_multiproc_rb(storage_type=ListStorage) - def test_error_nonshared(self): - # non shared tensor storage cannot be shared - with pytest.raises( - RuntimeError, match="The storage must be place in shared memory" - ): - self.exec_multiproc_rb(storage_type=LazyTensorStorage) - def test_error_maxwriter(self): # TensorDictMaxValueWriter cannot be shared with pytest.raises(RuntimeError, match="cannot be shared between processes"): @@ -2902,6 +3045,77 @@ def test_prioritized_slice_sampler_episodes(device): ), "after priority update, only episode 1 and 3 are expected to be sampled" +@pytest.mark.parametrize("alpha", [0.6, torch.tensor(1.0)]) +@pytest.mark.parametrize("beta", [0.7, torch.tensor(0.1)]) +@pytest.mark.parametrize("gamma", [0.1]) +@pytest.mark.parametrize("total_steps", [200]) +@pytest.mark.parametrize("n_annealing_steps", [100]) +@pytest.mark.parametrize("anneal_every_n", [10, 159]) +@pytest.mark.parametrize("alpha_min", [0, 0.2]) +@pytest.mark.parametrize("beta_max", [1, 1.4]) +def test_prioritized_parameter_scheduler( + alpha, + beta, + gamma, + total_steps, + n_annealing_steps, + anneal_every_n, + alpha_min, + beta_max, +): + rb = TensorDictPrioritizedReplayBuffer( + alpha=alpha, beta=beta, storage=ListStorage(max_size=1000) + ) + data = TensorDict({"data": torch.randn(1000, 5)}, batch_size=1000) + rb.extend(data) + alpha_scheduler = LinearScheduler( + rb, param_name="alpha", final_value=alpha_min, num_steps=n_annealing_steps + ) + beta_scheduler = StepScheduler( + rb, + param_name="beta", + gamma=gamma, + n_steps=anneal_every_n, + max_value=beta_max, + mode="additive", + ) + + scheduler = SchedulerList(schedulers=(alpha_scheduler, beta_scheduler)) + + alpha = alpha if torch.is_tensor(alpha) else torch.tensor(alpha) + alpha_min = torch.tensor(alpha_min) + expected_alpha_vals = torch.linspace(alpha, alpha_min, n_annealing_steps + 1) + expected_alpha_vals = torch.nn.functional.pad( + expected_alpha_vals, (0, total_steps - n_annealing_steps), value=alpha_min + ) + + expected_beta_vals = [beta] + annealing_steps = total_steps // anneal_every_n + gammas = torch.arange(0, annealing_steps + 1, dtype=torch.float32) * gamma + expected_beta_vals = ( + (beta + gammas).repeat_interleave(anneal_every_n).clip(None, beta_max) + ) + for i in range(total_steps): + curr_alpha = rb.sampler.alpha + torch.testing.assert_close( + curr_alpha + if torch.is_tensor(curr_alpha) + else torch.tensor(curr_alpha).float(), + expected_alpha_vals[i], + msg=f"expected {expected_alpha_vals[i]}, got {curr_alpha}", + ) + curr_beta = rb.sampler.beta + torch.testing.assert_close( + curr_beta + if torch.is_tensor(curr_beta) + else torch.tensor(curr_beta).float(), + expected_beta_vals[i], + msg=f"expected {expected_beta_vals[i]}, got {curr_beta}", + ) + rb.sample(20) + scheduler.step() + + class TestEnsemble: def _make_data(self, data_type): if data_type is 
torch.Tensor: diff --git a/test/test_specs.py b/test/test_specs.py index 2d597d770f0..82d2b7f2e1d 100644 --- a/test/test_specs.py +++ b/test/test_specs.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. import argparse import contextlib +import warnings import numpy as np import pytest @@ -14,19 +15,32 @@ from tensordict import LazyStackedTensorDict, TensorDict, TensorDictBase from tensordict.utils import _unravel_key_to_tuple from torchrl._utils import _make_ordinal_device + from torchrl.data.tensor_specs import ( _keys_to_empty_composite_spec, + Binary, BinaryDiscreteTensorSpec, + Bounded, BoundedTensorSpec, + Categorical, + Composite, CompositeSpec, + ContinuousBox, DiscreteTensorSpec, - LazyStackedCompositeSpec, + MultiCategorical, MultiDiscreteTensorSpec, + MultiOneHot, MultiOneHotDiscreteTensorSpec, + NonTensor, NonTensorSpec, + OneHot, OneHotDiscreteTensorSpec, + StackedComposite, TensorSpec, + Unbounded, + UnboundedContinuous, UnboundedContinuousTensorSpec, + UnboundedDiscrete, UnboundedDiscreteTensorSpec, ) from torchrl.data.utils import check_no_exclusive_keys, consolidate_spec @@ -38,9 +52,7 @@ def test_bounded(dtype): np.random.seed(0) for _ in range(100): bounds = torch.randn(2).sort()[0] - ts = BoundedTensorSpec( - bounds[0].item(), bounds[1].item(), torch.Size((1,)), dtype=dtype - ) + ts = Bounded(bounds[0].item(), bounds[1].item(), torch.Size((1,)), dtype=dtype) _dtype = dtype if dtype is None: _dtype = torch.get_default_dtype() @@ -53,7 +65,7 @@ def test_bounded(dtype): assert (ts.encode(ts.to_numpy(r)) == r).all() -@pytest.mark.parametrize("cls", [OneHotDiscreteTensorSpec, DiscreteTensorSpec]) +@pytest.mark.parametrize("cls", [OneHot, Categorical]) def test_discrete(cls): torch.manual_seed(0) np.random.seed(0) @@ -78,7 +90,7 @@ def test_discrete(cls): def test_unbounded(dtype): torch.manual_seed(0) np.random.seed(0) - ts = UnboundedContinuousTensorSpec(dtype=dtype) + ts = Unbounded(dtype=dtype) if dtype is None: dtype = torch.get_default_dtype() @@ -99,7 +111,7 @@ def test_ndbounded(dtype, shape): for _ in range(100): lb = torch.rand(10) - 1 ub = torch.rand(10) + 1 - ts = BoundedTensorSpec(lb, ub, dtype=dtype) + ts = Bounded(lb, ub, dtype=dtype) _dtype = dtype if dtype is None: _dtype = torch.get_default_dtype() @@ -150,7 +162,7 @@ def test_ndunbounded(dtype, n, shape): torch.manual_seed(0) np.random.seed(0) - ts = UnboundedContinuousTensorSpec( + ts = Unbounded( shape=[ n, ], @@ -195,7 +207,7 @@ def test_binary(n, shape): torch.manual_seed(0) np.random.seed(0) - ts = BinaryDiscreteTensorSpec(n) + ts = Binary(n) for _ in range(100): r = ts.rand(shape) assert r.shape == torch.Size( @@ -238,7 +250,7 @@ def test_binary(n, shape): def test_mult_onehot(shape, ns): torch.manual_seed(0) np.random.seed(0) - ts = MultiOneHotDiscreteTensorSpec(nvec=ns) + ts = MultiOneHot(nvec=ns) for _ in range(100): r = ts.rand(shape) assert r.shape == torch.Size( @@ -279,7 +291,7 @@ def test_mult_onehot(shape, ns): def test_multi_discrete(shape, ns, dtype): torch.manual_seed(0) np.random.seed(0) - ts = MultiDiscreteTensorSpec(ns, dtype=dtype) + ts = MultiCategorical(ns, dtype=dtype) _real_shape = shape if shape is not None else [] nvec_shape = torch.tensor(ns).size() for _ in range(100): @@ -315,9 +327,9 @@ def test_multi_discrete(shape, ns, dtype): @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("shape", [None, [], [1], [1, 2]]) def test_discrete_conversion(n, device, shape): - categorical = DiscreteTensorSpec(n, device=device, 
shape=shape) + categorical = Categorical(n, device=device, shape=shape) shape_one_hot = [n] if not shape else [*shape, n] - one_hot = OneHotDiscreteTensorSpec(n, device=device, shape=shape_one_hot) + one_hot = OneHot(n, device=device, shape=shape_one_hot) assert categorical != one_hot assert categorical.to_one_hot_spec() == one_hot @@ -333,8 +345,8 @@ def test_discrete_conversion(n, device, shape): @pytest.mark.parametrize("shape", [torch.Size([3]), torch.Size([4, 5])]) @pytest.mark.parametrize("device", get_default_devices()) def test_multi_discrete_conversion(ns, shape, device): - categorical = MultiDiscreteTensorSpec(ns, device=device) - one_hot = MultiOneHotDiscreteTensorSpec(ns, device=device) + categorical = MultiCategorical(ns, device=device) + one_hot = MultiOneHot(ns, device=device) assert categorical != one_hot assert categorical.to_one_hot_spec() == one_hot @@ -356,14 +368,14 @@ def _composite_spec(shape, is_complete=True, device=None, dtype=None): torch.manual_seed(0) np.random.seed(0) - return CompositeSpec( - obs=BoundedTensorSpec( + return Composite( + obs=Bounded( torch.zeros(*shape, 3, 32, 32), torch.ones(*shape, 3, 32, 32), dtype=dtype, device=device, ), - act=UnboundedContinuousTensorSpec( + act=Unbounded( ( *shape, 7, @@ -379,9 +391,9 @@ def _composite_spec(shape, is_complete=True, device=None, dtype=None): def test_getitem(self, shape, is_complete, device, dtype): ts = self._composite_spec(shape, is_complete, device, dtype) - assert isinstance(ts["obs"], BoundedTensorSpec) + assert isinstance(ts["obs"], Bounded) if is_complete: - assert isinstance(ts["act"], UnboundedContinuousTensorSpec) + assert isinstance(ts["act"], Unbounded) else: assert ts["act"] is None with pytest.raises(KeyError): @@ -397,21 +409,17 @@ def test_setitem_forbidden_keys(self, shape, is_complete, device, dtype): def test_setitem_matches_device(self, shape, is_complete, device, dtype, dest): ts = self._composite_spec(shape, is_complete, device, dtype) - ts["good"] = UnboundedContinuousTensorSpec( - shape=shape, device=device, dtype=dtype - ) + ts["good"] = Unbounded(shape=shape, device=device, dtype=dtype) cm = ( contextlib.nullcontext() if (device == dest) or (device is None) else pytest.raises( - RuntimeError, match="All devices of CompositeSpec must match" + RuntimeError, match="All devices of Composite must match" ) ) with cm: # auto-casting is introduced since v0.3 - ts["bad"] = UnboundedContinuousTensorSpec( - shape=shape, device=dest, dtype=dtype - ) + ts["bad"] = Unbounded(shape=shape, device=dest, dtype=dtype) assert ts.device == device assert ts["good"].device == ( device if device is not None else torch.zeros(()).device @@ -490,7 +498,7 @@ def test_rand(self, shape, is_complete, device, dtype, shape_other): def test_repr(self, shape, is_complete, device, dtype): ts = self._composite_spec(shape, is_complete, device, dtype) output = repr(ts) - assert output.startswith("CompositeSpec") + assert output.startswith("Composite") assert "obs: " in output assert "act: " in output @@ -606,7 +614,7 @@ def test_nested_composite_spec_delitem(self, shape, is_complete, device, dtype): def test_nested_composite_spec_update(self, shape, is_complete, device, dtype): ts = self._composite_spec(shape, is_complete, device, dtype) ts["nested_cp"] = self._composite_spec(shape, is_complete, device, dtype) - td2 = CompositeSpec(new=None) + td2 = Composite(new=None) ts.update(td2) assert set(ts.keys(include_nested=True)) == { "obs", @@ -619,7 +627,7 @@ def test_nested_composite_spec_update(self, shape, 
is_complete, device, dtype): ts = self._composite_spec(shape, is_complete, device, dtype) ts["nested_cp"] = self._composite_spec(shape, is_complete, device, dtype) - td2 = CompositeSpec(nested_cp=CompositeSpec(new=None).to(device)) + td2 = Composite(nested_cp=Composite(new=None).to(device)) ts.update(td2) assert set(ts.keys(include_nested=True)) == { "obs", @@ -632,7 +640,7 @@ def test_nested_composite_spec_update(self, shape, is_complete, device, dtype): ts = self._composite_spec(shape, is_complete, device, dtype) ts["nested_cp"] = self._composite_spec(shape, is_complete, device, dtype) - td2 = CompositeSpec(nested_cp=CompositeSpec(act=None).to(device)) + td2 = Composite(nested_cp=Composite(act=None).to(device)) ts.update(td2) assert set(ts.keys(include_nested=True)) == { "obs", @@ -645,13 +653,13 @@ def test_nested_composite_spec_update(self, shape, is_complete, device, dtype): ts = self._composite_spec(shape, is_complete, device, dtype) ts["nested_cp"] = self._composite_spec(shape, is_complete, device, dtype) - td2 = CompositeSpec( - nested_cp=CompositeSpec(act=None, shape=shape).to(device), shape=shape + td2 = Composite( + nested_cp=Composite(act=None, shape=shape).to(device), shape=shape ) ts.update(td2) - td2 = CompositeSpec( - nested_cp=CompositeSpec( - act=UnboundedContinuousTensorSpec(shape=shape, device=device), + td2 = Composite( + nested_cp=Composite( + act=Unbounded(shape=shape, device=device), shape=shape, ), shape=shape, @@ -668,8 +676,8 @@ def test_nested_composite_spec_update(self, shape, is_complete, device, dtype): def test_change_batch_size(self, shape, is_complete, device, dtype): ts = self._composite_spec(shape, is_complete, device, dtype) - ts["nested"] = CompositeSpec( - leaf=UnboundedContinuousTensorSpec(shape, device=device), + ts["nested"] = Composite( + leaf=Unbounded(shape, device=device), shape=shape, device=device, ) @@ -690,12 +698,12 @@ def test_change_batch_size(self, shape, is_complete, device, dtype): @pytest.mark.parametrize("device", get_default_devices()) def test_create_composite_nested(shape, device): d = [ - {("a", "b"): UnboundedContinuousTensorSpec(shape=shape, device=device)}, - {"a": {"b": UnboundedContinuousTensorSpec(shape=shape, device=device)}}, + {("a", "b"): Unbounded(shape=shape, device=device)}, + {"a": {"b": Unbounded(shape=shape, device=device)}}, ] for _d in d: - c = CompositeSpec(_d, shape=shape) - assert isinstance(c["a", "b"], UnboundedContinuousTensorSpec) + c = Composite(_d, shape=shape) + assert isinstance(c["a", "b"], Unbounded) assert c["a"].shape == torch.Size(shape) assert c.device is None # device not explicitly passed assert c["a"].device is None # device not explicitly passed @@ -708,10 +716,8 @@ def test_create_composite_nested(shape, device): @pytest.mark.parametrize("recurse", [True, False]) def test_lock(recurse): shape = [3, 4, 5] - spec = CompositeSpec( - a=CompositeSpec( - b=CompositeSpec(shape=shape[:3], device="cpu"), shape=shape[:2] - ), + spec = Composite( + a=Composite(b=Composite(shape=shape[:3], device="cpu"), shape=shape[:2]), shape=shape[:1], ) spec["a"] = spec["a"].clone() @@ -719,15 +725,15 @@ def test_lock(recurse): assert not spec.locked spec.lock_(recurse=recurse) assert spec.locked - with pytest.raises(RuntimeError, match="Cannot modify a locked CompositeSpec."): + with pytest.raises(RuntimeError, match="Cannot modify a locked Composite."): spec["a"] = spec["a"].clone() - with pytest.raises(RuntimeError, match="Cannot modify a locked CompositeSpec."): + with pytest.raises(RuntimeError, match="Cannot 
modify a locked Composite."): spec.set("a", spec["a"].clone()) if recurse: assert spec["a"].locked - with pytest.raises(RuntimeError, match="Cannot modify a locked CompositeSpec."): + with pytest.raises(RuntimeError, match="Cannot modify a locked Composite."): spec["a"].set("b", spec["a", "b"].clone()) - with pytest.raises(RuntimeError, match="Cannot modify a locked CompositeSpec."): + with pytest.raises(RuntimeError, match="Cannot modify a locked Composite."): spec["a", "b"] = spec["a", "b"].clone() else: assert not spec["a"].locked @@ -763,33 +769,25 @@ def test_equality_bounded(self): device = "cpu" dtype = torch.float16 - ts = BoundedTensorSpec(minimum, maximum, torch.Size((1,)), device, dtype) + ts = Bounded(minimum, maximum, torch.Size((1,)), device, dtype) - ts_same = BoundedTensorSpec(minimum, maximum, torch.Size((1,)), device, dtype) + ts_same = Bounded(minimum, maximum, torch.Size((1,)), device, dtype) assert ts == ts_same - ts_other = BoundedTensorSpec( - minimum + 1, maximum, torch.Size((1,)), device, dtype - ) + ts_other = Bounded(minimum + 1, maximum, torch.Size((1,)), device, dtype) assert ts != ts_other - ts_other = BoundedTensorSpec( - minimum, maximum + 1, torch.Size((1,)), device, dtype - ) + ts_other = Bounded(minimum, maximum + 1, torch.Size((1,)), device, dtype) assert ts != ts_other if torch.cuda.device_count(): - ts_other = BoundedTensorSpec( - minimum, maximum, torch.Size((1,)), "cuda:0", dtype - ) + ts_other = Bounded(minimum, maximum, torch.Size((1,)), "cuda:0", dtype) assert ts != ts_other - ts_other = BoundedTensorSpec( - minimum, maximum, torch.Size((1,)), device, torch.float64 - ) + ts_other = Bounded(minimum, maximum, torch.Size((1,)), device, torch.float64) assert ts != ts_other ts_other = TestEquality._ts_make_all_fields_equal( - UnboundedContinuousTensorSpec(device=device, dtype=dtype), ts + Unbounded(device=device, dtype=dtype), ts ) assert ts != ts_other @@ -799,38 +797,34 @@ def test_equality_onehot(self): dtype = torch.float16 use_register = False - ts = OneHotDiscreteTensorSpec( - n=n, device=device, dtype=dtype, use_register=use_register - ) + ts = OneHot(n=n, device=device, dtype=dtype, use_register=use_register) - ts_same = OneHotDiscreteTensorSpec( - n=n, device=device, dtype=dtype, use_register=use_register - ) + ts_same = OneHot(n=n, device=device, dtype=dtype, use_register=use_register) assert ts == ts_same - ts_other = OneHotDiscreteTensorSpec( + ts_other = OneHot( n=n + 1, device=device, dtype=dtype, use_register=use_register ) assert ts != ts_other if torch.cuda.device_count(): - ts_other = OneHotDiscreteTensorSpec( + ts_other = OneHot( n=n, device="cuda:0", dtype=dtype, use_register=use_register ) assert ts != ts_other - ts_other = OneHotDiscreteTensorSpec( + ts_other = OneHot( n=n, device=device, dtype=torch.float64, use_register=use_register ) assert ts != ts_other - ts_other = OneHotDiscreteTensorSpec( + ts_other = OneHot( n=n, device=device, dtype=dtype, use_register=not use_register ) assert ts != ts_other ts_other = TestEquality._ts_make_all_fields_equal( - UnboundedContinuousTensorSpec(device=device, dtype=dtype), ts + Unbounded(device=device, dtype=dtype), ts ) assert ts != ts_other @@ -838,21 +832,25 @@ def test_equality_unbounded(self): device = "cpu" dtype = torch.float16 - ts = UnboundedContinuousTensorSpec(device=device, dtype=dtype) + ts = Unbounded(device=device, dtype=dtype) - ts_same = UnboundedContinuousTensorSpec(device=device, dtype=dtype) + ts_same = Unbounded(device=device, dtype=dtype) assert ts == ts_same if 
torch.cuda.device_count(): - ts_other = UnboundedContinuousTensorSpec(device="cuda:0", dtype=dtype) + ts_other = Unbounded(device="cuda:0", dtype=dtype) assert ts != ts_other - ts_other = UnboundedContinuousTensorSpec(device=device, dtype=torch.float64) + ts_other = Unbounded(device=device, dtype=torch.float64) assert ts != ts_other ts_other = TestEquality._ts_make_all_fields_equal( - BoundedTensorSpec(0, 1, torch.Size((1,)), device, dtype), ts + Bounded(0, 1, torch.Size((1,)), device, dtype), ts + ) + ts_other.space = ContinuousBox( + ts_other.space.low * 0, ts_other.space.high * 0 + 1 ) + assert ts.space != ts_other.space, (ts.space, ts_other.space) assert ts != ts_other def test_equality_ndbounded(self): @@ -861,36 +859,28 @@ def test_equality_ndbounded(self): device = "cpu" dtype = torch.float16 - ts = BoundedTensorSpec(low=minimum, high=maximum, device=device, dtype=dtype) + ts = Bounded(low=minimum, high=maximum, device=device, dtype=dtype) - ts_same = BoundedTensorSpec( - low=minimum, high=maximum, device=device, dtype=dtype - ) + ts_same = Bounded(low=minimum, high=maximum, device=device, dtype=dtype) assert ts == ts_same - ts_other = BoundedTensorSpec( - low=minimum + 1, high=maximum, device=device, dtype=dtype - ) + ts_other = Bounded(low=minimum + 1, high=maximum, device=device, dtype=dtype) assert ts != ts_other - ts_other = BoundedTensorSpec( - low=minimum, high=maximum + 1, device=device, dtype=dtype - ) + ts_other = Bounded(low=minimum, high=maximum + 1, device=device, dtype=dtype) assert ts != ts_other if torch.cuda.device_count(): - ts_other = BoundedTensorSpec( - low=minimum, high=maximum, device="cuda:0", dtype=dtype - ) + ts_other = Bounded(low=minimum, high=maximum, device="cuda:0", dtype=dtype) assert ts != ts_other - ts_other = BoundedTensorSpec( + ts_other = Bounded( low=minimum, high=maximum, device=device, dtype=torch.float64 ) assert ts != ts_other ts_other = TestEquality._ts_make_all_fields_equal( - UnboundedContinuousTensorSpec(device=device, dtype=dtype), ts + Unbounded(device=device, dtype=dtype), ts ) assert ts != ts_other @@ -900,32 +890,28 @@ def test_equality_discrete(self): device = "cpu" dtype = torch.float16 - ts = DiscreteTensorSpec(n=n, shape=shape, device=device, dtype=dtype) + ts = Categorical(n=n, shape=shape, device=device, dtype=dtype) - ts_same = DiscreteTensorSpec(n=n, shape=shape, device=device, dtype=dtype) + ts_same = Categorical(n=n, shape=shape, device=device, dtype=dtype) assert ts == ts_same - ts_other = DiscreteTensorSpec(n=n + 1, shape=shape, device=device, dtype=dtype) + ts_other = Categorical(n=n + 1, shape=shape, device=device, dtype=dtype) assert ts != ts_other if torch.cuda.device_count(): - ts_other = DiscreteTensorSpec( - n=n, shape=shape, device="cuda:0", dtype=dtype - ) + ts_other = Categorical(n=n, shape=shape, device="cuda:0", dtype=dtype) assert ts != ts_other - ts_other = DiscreteTensorSpec( - n=n, shape=shape, device=device, dtype=torch.float64 - ) + ts_other = Categorical(n=n, shape=shape, device=device, dtype=torch.float64) assert ts != ts_other - ts_other = DiscreteTensorSpec( + ts_other = Categorical( n=n, shape=torch.Size([2]), device=device, dtype=torch.float64 ) assert ts != ts_other ts_other = TestEquality._ts_make_all_fields_equal( - UnboundedContinuousTensorSpec(device=device, dtype=dtype), ts + Unbounded(device=device, dtype=dtype), ts ) assert ts != ts_other @@ -941,30 +927,24 @@ def test_equality_ndunbounded(self, shape): device = "cpu" dtype = torch.float16 - ts = UnboundedContinuousTensorSpec(shape=shape, 
device=device, dtype=dtype) + ts = Unbounded(shape=shape, device=device, dtype=dtype) - ts_same = UnboundedContinuousTensorSpec(shape=shape, device=device, dtype=dtype) + ts_same = Unbounded(shape=shape, device=device, dtype=dtype) assert ts == ts_same - other_shape = 13 if type(shape) == int else torch.Size(np.array(shape) + 10) - ts_other = UnboundedContinuousTensorSpec( - shape=other_shape, device=device, dtype=dtype - ) + other_shape = 13 if isinstance(shape, int) else torch.Size(np.array(shape) + 10) + ts_other = Unbounded(shape=other_shape, device=device, dtype=dtype) assert ts != ts_other if torch.cuda.device_count(): - ts_other = UnboundedContinuousTensorSpec( - shape=shape, device="cuda:0", dtype=dtype - ) + ts_other = Unbounded(shape=shape, device="cuda:0", dtype=dtype) assert ts != ts_other - ts_other = UnboundedContinuousTensorSpec( - shape=shape, device=device, dtype=torch.float64 - ) + ts_other = Unbounded(shape=shape, device=device, dtype=torch.float64) assert ts != ts_other ts_other = TestEquality._ts_make_all_fields_equal( - BoundedTensorSpec(0, 1, torch.Size((1,)), device, dtype), ts + Bounded(0, 1, torch.Size((1,)), device, dtype), ts ) # Unbounded and bounded without space are technically the same assert ts == ts_other @@ -974,23 +954,23 @@ def test_equality_binary(self): device = "cpu" dtype = torch.float16 - ts = BinaryDiscreteTensorSpec(n=n, device=device, dtype=dtype) + ts = Binary(n=n, device=device, dtype=dtype) - ts_same = BinaryDiscreteTensorSpec(n=n, device=device, dtype=dtype) + ts_same = Binary(n=n, device=device, dtype=dtype) assert ts == ts_same - ts_other = BinaryDiscreteTensorSpec(n=n + 5, device=device, dtype=dtype) + ts_other = Binary(n=n + 5, device=device, dtype=dtype) assert ts != ts_other if torch.cuda.device_count(): - ts_other = BinaryDiscreteTensorSpec(n=n, device="cuda:0", dtype=dtype) + ts_other = Binary(n=n, device="cuda:0", dtype=dtype) assert ts != ts_other - ts_other = BinaryDiscreteTensorSpec(n=n, device=device, dtype=torch.float64) + ts_other = Binary(n=n, device=device, dtype=torch.float64) assert ts != ts_other ts_other = TestEquality._ts_make_all_fields_equal( - BoundedTensorSpec(0, 1, torch.Size((1,)), device, dtype), ts + Bounded(0, 1, torch.Size((1,)), device, dtype), ts ) assert ts != ts_other @@ -999,42 +979,32 @@ def test_equality_multi_onehot(self, nvec): device = "cpu" dtype = torch.float16 - ts = MultiOneHotDiscreteTensorSpec(nvec=nvec, device=device, dtype=dtype) + ts = MultiOneHot(nvec=nvec, device=device, dtype=dtype) - ts_same = MultiOneHotDiscreteTensorSpec(nvec=nvec, device=device, dtype=dtype) + ts_same = MultiOneHot(nvec=nvec, device=device, dtype=dtype) assert ts == ts_same other_nvec = np.array(nvec) + 3 - ts_other = MultiOneHotDiscreteTensorSpec( - nvec=other_nvec, device=device, dtype=dtype - ) + ts_other = MultiOneHot(nvec=other_nvec, device=device, dtype=dtype) assert ts != ts_other other_nvec = [12] - ts_other = MultiOneHotDiscreteTensorSpec( - nvec=other_nvec, device=device, dtype=dtype - ) + ts_other = MultiOneHot(nvec=other_nvec, device=device, dtype=dtype) assert ts != ts_other other_nvec = [12, 13] - ts_other = MultiOneHotDiscreteTensorSpec( - nvec=other_nvec, device=device, dtype=dtype - ) + ts_other = MultiOneHot(nvec=other_nvec, device=device, dtype=dtype) assert ts != ts_other if torch.cuda.device_count(): - ts_other = MultiOneHotDiscreteTensorSpec( - nvec=nvec, device="cuda:0", dtype=dtype - ) + ts_other = MultiOneHot(nvec=nvec, device="cuda:0", dtype=dtype) assert ts != ts_other - ts_other = 
MultiOneHotDiscreteTensorSpec( - nvec=nvec, device=device, dtype=torch.float64 - ) + ts_other = MultiOneHot(nvec=nvec, device=device, dtype=torch.float64) assert ts != ts_other ts_other = TestEquality._ts_make_all_fields_equal( - BoundedTensorSpec(0, 1, torch.Size((1,)), device, dtype), ts + Bounded(0, 1, torch.Size((1,)), device, dtype), ts ) assert ts != ts_other @@ -1043,34 +1013,32 @@ def test_equality_multi_discrete(self, nvec): device = "cpu" dtype = torch.float16 - ts = MultiDiscreteTensorSpec(nvec=nvec, device=device, dtype=dtype) + ts = MultiCategorical(nvec=nvec, device=device, dtype=dtype) - ts_same = MultiDiscreteTensorSpec(nvec=nvec, device=device, dtype=dtype) + ts_same = MultiCategorical(nvec=nvec, device=device, dtype=dtype) assert ts == ts_same other_nvec = np.array(nvec) + 3 - ts_other = MultiDiscreteTensorSpec(nvec=other_nvec, device=device, dtype=dtype) + ts_other = MultiCategorical(nvec=other_nvec, device=device, dtype=dtype) assert ts != ts_other other_nvec = [12] - ts_other = MultiDiscreteTensorSpec(nvec=other_nvec, device=device, dtype=dtype) + ts_other = MultiCategorical(nvec=other_nvec, device=device, dtype=dtype) assert ts != ts_other other_nvec = [12, 13] - ts_other = MultiDiscreteTensorSpec(nvec=other_nvec, device=device, dtype=dtype) + ts_other = MultiCategorical(nvec=other_nvec, device=device, dtype=dtype) assert ts != ts_other if torch.cuda.device_count(): - ts_other = MultiDiscreteTensorSpec(nvec=nvec, device="cuda:0", dtype=dtype) + ts_other = MultiCategorical(nvec=nvec, device="cuda:0", dtype=dtype) assert ts != ts_other - ts_other = MultiDiscreteTensorSpec( - nvec=nvec, device=device, dtype=torch.float64 - ) + ts_other = MultiCategorical(nvec=nvec, device=device, dtype=torch.float64) assert ts != ts_other ts_other = TestEquality._ts_make_all_fields_equal( - BoundedTensorSpec(0, 1, torch.Size((1,)), device, dtype), ts + Bounded(0, 1, torch.Size((1,)), device, dtype), ts ) assert ts != ts_other @@ -1080,69 +1048,63 @@ def test_equality_composite(self): device = "cpu" dtype = torch.float16 - bounded = BoundedTensorSpec(0, 1, torch.Size((1,)), device, dtype) - bounded_same = BoundedTensorSpec(0, 1, torch.Size((1,)), device, dtype) - bounded_other = BoundedTensorSpec(0, 2, torch.Size((1,)), device, dtype) + bounded = Bounded(0, 1, torch.Size((1,)), device, dtype) + bounded_same = Bounded(0, 1, torch.Size((1,)), device, dtype) + bounded_other = Bounded(0, 2, torch.Size((1,)), device, dtype) - nd = BoundedTensorSpec( - low=minimum, high=maximum + 1, device=device, dtype=dtype - ) - nd_same = BoundedTensorSpec( - low=minimum, high=maximum + 1, device=device, dtype=dtype - ) - _ = BoundedTensorSpec(low=minimum, high=maximum + 3, device=device, dtype=dtype) + nd = Bounded(low=minimum, high=maximum + 1, device=device, dtype=dtype) + nd_same = Bounded(low=minimum, high=maximum + 1, device=device, dtype=dtype) + _ = Bounded(low=minimum, high=maximum + 3, device=device, dtype=dtype) # Equality tests - ts = CompositeSpec(ts1=bounded) - ts_same = CompositeSpec(ts1=bounded) + ts = Composite(ts1=bounded) + ts_same = Composite(ts1=bounded) assert ts == ts_same - ts = CompositeSpec(ts1=bounded) - ts_same = CompositeSpec(ts1=bounded_same) + ts = Composite(ts1=bounded) + ts_same = Composite(ts1=bounded_same) assert ts == ts_same - ts = CompositeSpec(ts1=bounded, ts2=nd) - ts_same = CompositeSpec(ts1=bounded, ts2=nd) + ts = Composite(ts1=bounded, ts2=nd) + ts_same = Composite(ts1=bounded, ts2=nd) assert ts == ts_same - ts = CompositeSpec(ts1=bounded, ts2=nd) - ts_same = 
CompositeSpec(ts1=bounded_same, ts2=nd_same) + ts = Composite(ts1=bounded, ts2=nd) + ts_same = Composite(ts1=bounded_same, ts2=nd_same) assert ts == ts_same - ts = CompositeSpec(ts1=bounded, ts2=nd) - ts_same = CompositeSpec(ts2=nd_same, ts1=bounded_same) + ts = Composite(ts1=bounded, ts2=nd) + ts_same = Composite(ts2=nd_same, ts1=bounded_same) assert ts == ts_same # Inequality tests - ts = CompositeSpec(ts1=bounded) - ts_other = CompositeSpec(ts5=bounded) + ts = Composite(ts1=bounded) + ts_other = Composite(ts5=bounded) assert ts != ts_other - ts = CompositeSpec(ts1=bounded) - ts_other = CompositeSpec(ts1=bounded_other) + ts = Composite(ts1=bounded) + ts_other = Composite(ts1=bounded_other) assert ts != ts_other - ts = CompositeSpec(ts1=bounded) - ts_other = CompositeSpec(ts1=nd) + ts = Composite(ts1=bounded) + ts_other = Composite(ts1=nd) assert ts != ts_other - ts = CompositeSpec(ts1=bounded) - ts_other = CompositeSpec(ts1=bounded, ts2=nd) + ts = Composite(ts1=bounded) + ts_other = Composite(ts1=bounded, ts2=nd) assert ts != ts_other - ts = CompositeSpec(ts1=bounded, ts2=nd) - ts_other = CompositeSpec(ts2=nd) + ts = Composite(ts1=bounded, ts2=nd) + ts_other = Composite(ts2=nd) assert ts != ts_other - ts = CompositeSpec(ts1=bounded, ts2=nd) - ts_other = CompositeSpec(ts1=bounded, ts2=nd, ts3=bounded_other) + ts = Composite(ts1=bounded, ts2=nd) + ts_other = Composite(ts1=bounded, ts2=nd, ts3=bounded_other) assert ts != ts_other class TestSpec: - @pytest.mark.parametrize( - "action_spec_cls", [OneHotDiscreteTensorSpec, DiscreteTensorSpec] - ) + @pytest.mark.parametrize("action_spec_cls", [OneHot, Categorical]) def test_discrete_action_spec_reconstruct(self, action_spec_cls): torch.manual_seed(0) action_spec = action_spec_cls(10) @@ -1161,7 +1123,7 @@ def test_discrete_action_spec_reconstruct(self, action_spec_cls): def test_mult_discrete_action_spec_reconstruct(self): torch.manual_seed(0) - action_spec = MultiOneHotDiscreteTensorSpec((10, 5)) + action_spec = MultiOneHot((10, 5)) actions_tensors = [action_spec.rand() for _ in range(10)] actions_categorical = [action_spec.to_categorical(a) for a in actions_tensors] @@ -1183,7 +1145,7 @@ def test_mult_discrete_action_spec_reconstruct(self): def test_one_hot_discrete_action_spec_rand(self): torch.manual_seed(0) - action_spec = OneHotDiscreteTensorSpec(10) + action_spec = OneHot(10) sample = action_spec.rand((100000,)) @@ -1197,7 +1159,7 @@ def test_one_hot_discrete_action_spec_rand(self): def test_categorical_action_spec_rand(self): torch.manual_seed(1) - action_spec = DiscreteTensorSpec(10) + action_spec = Categorical(10) sample = action_spec.rand((10000,)) @@ -1213,7 +1175,7 @@ def test_mult_discrete_action_spec_rand(self): torch.manual_seed(0) ns = (10, 5) N = 100000 - action_spec = MultiOneHotDiscreteTensorSpec((10, 5)) + action_spec = MultiOneHot((10, 5)) actions_tensors = [action_spec.rand() for _ in range(10)] actions_categorical = [action_spec.to_categorical(a) for a in actions_tensors] @@ -1238,7 +1200,7 @@ def test_mult_discrete_action_spec_rand(self): assert chisquare(sample_list).pvalue > 0.1 def test_categorical_action_spec_encode(self): - action_spec = DiscreteTensorSpec(10) + action_spec = Categorical(10) projected = action_spec.project( torch.tensor([-100, -1, 0, 1, 9, 10, 100], dtype=torch.long) @@ -1255,12 +1217,12 @@ def test_categorical_action_spec_encode(self): ).all() def test_bounded_rand(self): - spec = BoundedTensorSpec(-3, 3, torch.Size((1,))) + spec = Bounded(-3, 3, torch.Size((1,))) sample = torch.stack([spec.rand() 
for _ in range(100)]) assert (-3 <= sample).all() and (3 >= sample).all() def test_ndbounded_shape(self): - spec = BoundedTensorSpec(-3, 3 * torch.ones(10, 5), shape=[10, 5]) + spec = Bounded(-3, 3 * torch.ones(10, 5), shape=[10, 5]) sample = torch.stack([spec.rand() for _ in range(100)], 0) assert (-3 <= sample).all() and (3 >= sample).all() assert sample.shape == torch.Size([100, 10, 5]) @@ -1270,9 +1232,7 @@ class TestExpand: @pytest.mark.parametrize("shape1", [None, (4,), (5, 4)]) @pytest.mark.parametrize("shape2", [(), (10,)]) def test_binary(self, shape1, shape2): - spec = BinaryDiscreteTensorSpec( - n=4, shape=shape1, device="cpu", dtype=torch.bool - ) + spec = Binary(n=4, shape=shape1, device="cpu", dtype=torch.bool) if shape1 is not None: shape2_real = (*shape2, *shape1) else: @@ -1304,9 +1264,7 @@ def test_binary(self, shape1, shape2): ], ) def test_bounded(self, shape1, shape2, mini, maxi): - spec = BoundedTensorSpec( - mini, maxi, shape=shape1, device="cpu", dtype=torch.bool - ) + spec = Bounded(mini, maxi, shape=shape1, device="cpu", dtype=torch.bool) shape1 = spec.shape assert shape1 == torch.Size([10]) shape2_real = (*shape2, *shape1) @@ -1326,7 +1284,7 @@ def test_bounded(self, shape1, shape2, mini, maxi): def test_composite(self): batch_size = (5,) - spec1 = BoundedTensorSpec( + spec1 = Bounded( -torch.ones([*batch_size, 10]), torch.ones([*batch_size, 10]), shape=( @@ -1336,22 +1294,16 @@ def test_composite(self): device="cpu", dtype=torch.bool, ) - spec2 = BinaryDiscreteTensorSpec( - n=4, shape=(*batch_size, 4), device="cpu", dtype=torch.bool - ) - spec3 = DiscreteTensorSpec( - n=4, shape=batch_size, device="cpu", dtype=torch.long - ) - spec4 = MultiDiscreteTensorSpec( + spec2 = Binary(n=4, shape=(*batch_size, 4), device="cpu", dtype=torch.bool) + spec3 = Categorical(n=4, shape=batch_size, device="cpu", dtype=torch.long) + spec4 = MultiCategorical( nvec=(4, 5, 6), shape=(*batch_size, 3), device="cpu", dtype=torch.long ) - spec5 = MultiOneHotDiscreteTensorSpec( + spec5 = MultiOneHot( nvec=(4, 5, 6), shape=(*batch_size, 15), device="cpu", dtype=torch.long ) - spec6 = OneHotDiscreteTensorSpec( - n=15, shape=(*batch_size, 15), device="cpu", dtype=torch.long - ) - spec7 = UnboundedContinuousTensorSpec( + spec6 = OneHot(n=15, shape=(*batch_size, 15), device="cpu", dtype=torch.long) + spec7 = Unbounded( shape=(*batch_size, 9), device="cpu", dtype=torch.float64, @@ -1361,7 +1313,7 @@ def test_composite(self): device="cpu", dtype=torch.long, ) - spec = CompositeSpec( + spec = Composite( spec1=spec1, spec2=spec2, spec3=spec3, @@ -1392,7 +1344,7 @@ def test_composite(self): @pytest.mark.parametrize("shape1", [None, (), (5,)]) @pytest.mark.parametrize("shape2", [(), (10,)]) def test_discrete(self, shape1, shape2): - spec = DiscreteTensorSpec(n=4, shape=shape1, device="cpu", dtype=torch.long) + spec = Categorical(n=4, shape=shape1, device="cpu", dtype=torch.long) if shape1 is not None: shape2_real = (*shape2, *shape1) else: @@ -1418,7 +1370,7 @@ def test_multidiscrete(self, shape1, shape2): shape1 = (3,) else: shape1 = (*shape1, 3) - spec = MultiDiscreteTensorSpec( + spec = MultiCategorical( nvec=(4, 5, 6), shape=shape1, device="cpu", dtype=torch.long ) if shape1 is not None: @@ -1446,9 +1398,7 @@ def test_multionehot(self, shape1, shape2): shape1 = (15,) else: shape1 = (*shape1, 15) - spec = MultiOneHotDiscreteTensorSpec( - nvec=(4, 5, 6), shape=shape1, device="cpu", dtype=torch.long - ) + spec = MultiOneHot(nvec=(4, 5, 6), shape=shape1, device="cpu", dtype=torch.long) if shape1 is 
not None: shape2_real = (*shape2, *shape1) else: @@ -1468,11 +1418,11 @@ def test_multionehot(self, shape1, shape2): assert spec2.zero().shape == spec2.shape def test_non_tensor(self): - spec = NonTensorSpec((3, 4), device="cpu") + spec = NonTensor((3, 4), device="cpu") assert ( spec.expand(2, 3, 4) == spec.expand((2, 3, 4)) - == NonTensorSpec((2, 3, 4), device="cpu") + == NonTensor((2, 3, 4), device="cpu") ) @pytest.mark.parametrize("shape1", [None, (), (5,)]) @@ -1482,9 +1432,7 @@ def test_onehot(self, shape1, shape2): shape1 = (15,) else: shape1 = (*shape1, 15) - spec = OneHotDiscreteTensorSpec( - n=15, shape=shape1, device="cpu", dtype=torch.long - ) + spec = OneHot(n=15, shape=shape1, device="cpu", dtype=torch.long) if shape1 is not None: shape2_real = (*shape2, *shape1) else: @@ -1510,9 +1458,7 @@ def test_unbounded(self, shape1, shape2): shape1 = (15,) else: shape1 = (*shape1, 15) - spec = UnboundedContinuousTensorSpec( - shape=shape1, device="cpu", dtype=torch.float64 - ) + spec = Unbounded(shape=shape1, device="cpu", dtype=torch.float64) if shape1 is not None: shape2_real = (*shape2, *shape1) else: @@ -1571,9 +1517,7 @@ class TestClone: ], ) def test_binary(self, shape1): - spec = BinaryDiscreteTensorSpec( - n=4, shape=shape1, device="cpu", dtype=torch.bool - ) + spec = Binary(n=4, shape=shape1, device="cpu", dtype=torch.bool) assert spec == spec.clone() assert spec is not spec.clone() @@ -1589,15 +1533,13 @@ def test_binary(self, shape1): ], ) def test_bounded(self, shape1, mini, maxi): - spec = BoundedTensorSpec( - mini, maxi, shape=shape1, device="cpu", dtype=torch.bool - ) + spec = Bounded(mini, maxi, shape=shape1, device="cpu", dtype=torch.bool) assert spec == spec.clone() assert spec is not spec.clone() def test_composite(self): batch_size = (5,) - spec1 = BoundedTensorSpec( + spec1 = Bounded( -torch.ones([*batch_size, 10]), torch.ones([*batch_size, 10]), shape=( @@ -1607,22 +1549,16 @@ def test_composite(self): device="cpu", dtype=torch.bool, ) - spec2 = BinaryDiscreteTensorSpec( - n=4, shape=(*batch_size, 4), device="cpu", dtype=torch.bool - ) - spec3 = DiscreteTensorSpec( - n=4, shape=batch_size, device="cpu", dtype=torch.long - ) - spec4 = MultiDiscreteTensorSpec( + spec2 = Binary(n=4, shape=(*batch_size, 4), device="cpu", dtype=torch.bool) + spec3 = Categorical(n=4, shape=batch_size, device="cpu", dtype=torch.long) + spec4 = MultiCategorical( nvec=(4, 5, 6), shape=(*batch_size, 3), device="cpu", dtype=torch.long ) - spec5 = MultiOneHotDiscreteTensorSpec( + spec5 = MultiOneHot( nvec=(4, 5, 6), shape=(*batch_size, 15), device="cpu", dtype=torch.long ) - spec6 = OneHotDiscreteTensorSpec( - n=15, shape=(*batch_size, 15), device="cpu", dtype=torch.long - ) - spec7 = UnboundedContinuousTensorSpec( + spec6 = OneHot(n=15, shape=(*batch_size, 15), device="cpu", dtype=torch.long) + spec7 = Unbounded( shape=(*batch_size, 9), device="cpu", dtype=torch.float64, @@ -1632,7 +1568,7 @@ def test_composite(self): device="cpu", dtype=torch.long, ) - spec = CompositeSpec( + spec = Composite( spec1=spec1, spec2=spec2, spec3=spec3, @@ -1654,7 +1590,7 @@ def test_discrete( self, shape1, ): - spec = DiscreteTensorSpec(n=4, shape=shape1, device="cpu", dtype=torch.long) + spec = Categorical(n=4, shape=shape1, device="cpu", dtype=torch.long) assert spec == spec.clone() assert spec is not spec.clone() @@ -1667,7 +1603,7 @@ def test_multidiscrete( shape1 = (3,) else: shape1 = (*shape1, 3) - spec = MultiDiscreteTensorSpec( + spec = MultiCategorical( nvec=(4, 5, 6), shape=shape1, device="cpu", 
             dtype=torch.long
         )
         assert spec == spec.clone()
@@ -1682,14 +1618,12 @@ def test_multionehot(
             shape1 = (15,)
         else:
             shape1 = (*shape1, 15)
-        spec = MultiOneHotDiscreteTensorSpec(
-            nvec=(4, 5, 6), shape=shape1, device="cpu", dtype=torch.long
-        )
+        spec = MultiOneHot(nvec=(4, 5, 6), shape=shape1, device="cpu", dtype=torch.long)
         assert spec == spec.clone()
         assert spec is not spec.clone()
 
     def test_non_tensor(self):
-        spec = NonTensorSpec(shape=(3, 4), device="cpu")
+        spec = NonTensor(shape=(3, 4), device="cpu")
         assert spec.clone() == spec
         assert spec.clone() is not spec
@@ -1702,9 +1636,7 @@ def test_onehot(
             shape1 = (15,)
         else:
             shape1 = (*shape1, 15)
-        spec = OneHotDiscreteTensorSpec(
-            n=15, shape=shape1, device="cpu", dtype=torch.long
-        )
+        spec = OneHot(n=15, shape=shape1, device="cpu", dtype=torch.long)
         assert spec == spec.clone()
         assert spec is not spec.clone()
@@ -1717,9 +1649,7 @@ def test_unbounded(
             shape1 = (15,)
         else:
             shape1 = (*shape1, 15)
-        spec = UnboundedContinuousTensorSpec(
-            shape=shape1, device="cpu", dtype=torch.float64
-        )
+        spec = Unbounded(shape=shape1, device="cpu", dtype=torch.float64)
         assert spec == spec.clone()
         assert spec is not spec.clone()
@@ -1740,9 +1670,7 @@ class TestUnbind:
     @pytest.mark.parametrize("shape1", [(5, 4)])
     def test_binary(self, shape1):
-        spec = BinaryDiscreteTensorSpec(
-            n=4, shape=shape1, device="cpu", dtype=torch.bool
-        )
+        spec = Binary(n=4, shape=shape1, device="cpu", dtype=torch.bool)
         assert spec == torch.stack(spec.unbind(0), 0)
         with pytest.raises(ValueError):
             spec.unbind(-1)
@@ -1759,16 +1687,14 @@ def test_binary(self, shape1):
         ],
     )
     def test_bounded(self, shape1, mini, maxi):
-        spec = BoundedTensorSpec(
-            mini, maxi, shape=shape1, device="cpu", dtype=torch.bool
-        )
+        spec = Bounded(mini, maxi, shape=shape1, device="cpu", dtype=torch.bool)
         assert spec == torch.stack(spec.unbind(0), 0)
         with pytest.raises(ValueError):
             spec.unbind(-1)
 
     def test_composite(self):
         batch_size = (5,)
-        spec1 = BoundedTensorSpec(
+        spec1 = Bounded(
             -torch.ones([*batch_size, 10]),
             torch.ones([*batch_size, 10]),
             shape=(
@@ -1778,22 +1704,16 @@ def test_composite(self):
             device="cpu",
             dtype=torch.bool,
         )
-        spec2 = BinaryDiscreteTensorSpec(
-            n=4, shape=(*batch_size, 4), device="cpu", dtype=torch.bool
-        )
-        spec3 = DiscreteTensorSpec(
-            n=4, shape=batch_size, device="cpu", dtype=torch.long
-        )
-        spec4 = MultiDiscreteTensorSpec(
+        spec2 = Binary(n=4, shape=(*batch_size, 4), device="cpu", dtype=torch.bool)
+        spec3 = Categorical(n=4, shape=batch_size, device="cpu", dtype=torch.long)
+        spec4 = MultiCategorical(
             nvec=(4, 5, 6), shape=(*batch_size, 3), device="cpu", dtype=torch.long
         )
-        spec5 = MultiOneHotDiscreteTensorSpec(
+        spec5 = MultiOneHot(
             nvec=(4, 5, 6), shape=(*batch_size, 15), device="cpu", dtype=torch.long
         )
-        spec6 = OneHotDiscreteTensorSpec(
-            n=15, shape=(*batch_size, 15), device="cpu", dtype=torch.long
-        )
-        spec7 = UnboundedContinuousTensorSpec(
+        spec6 = OneHot(n=15, shape=(*batch_size, 15), device="cpu", dtype=torch.long)
+        spec7 = Unbounded(
             shape=(*batch_size, 9),
             device="cpu",
             dtype=torch.float64,
@@ -1803,7 +1723,7 @@ def test_composite(self):
             device="cpu",
             dtype=torch.long,
         )
-        spec = CompositeSpec(
+        spec = Composite(
             spec1=spec1,
             spec2=spec2,
             spec3=spec3,
@@ -1822,7 +1742,7 @@ def test_discrete(
         self,
         shape1,
     ):
-        spec = DiscreteTensorSpec(n=4, shape=shape1, device="cpu", dtype=torch.long)
+        spec = Categorical(n=4, shape=shape1, device="cpu", dtype=torch.long)
         assert spec == torch.stack(spec.unbind(0), 0)
         assert spec == torch.stack(spec.unbind(-1), -1)
@@ -1835,7 +1755,7 @@ def test_multidiscrete(
             shape1 = (3,)
         else:
             shape1 = (*shape1, 3)
-        spec = MultiDiscreteTensorSpec(
+        spec = MultiCategorical(
             nvec=(4, 5, 6), shape=shape1, device="cpu", dtype=torch.long
         )
         assert spec == torch.stack(spec.unbind(0), 0)
@@ -1851,15 +1771,13 @@ def test_multionehot(
             shape1 = (15,)
         else:
             shape1 = (*shape1, 15)
-        spec = MultiOneHotDiscreteTensorSpec(
-            nvec=(4, 5, 6), shape=shape1, device="cpu", dtype=torch.long
-        )
+        spec = MultiOneHot(nvec=(4, 5, 6), shape=shape1, device="cpu", dtype=torch.long)
         assert spec == torch.stack(spec.unbind(0), 0)
         with pytest.raises(ValueError):
             spec.unbind(-1)
 
     def test_non_tensor(self):
-        spec = NonTensorSpec(shape=(3, 4), device="cpu")
+        spec = NonTensor(shape=(3, 4), device="cpu")
         assert spec.unbind(1)[0] == spec[:, 0]
         assert spec.unbind(1)[0] is not spec[:, 0]
@@ -1872,9 +1790,7 @@ def test_onehot(
             shape1 = (15,)
         else:
             shape1 = (*shape1, 15)
-        spec = OneHotDiscreteTensorSpec(
-            n=15, shape=shape1, device="cpu", dtype=torch.long
-        )
+        spec = OneHot(n=15, shape=shape1, device="cpu", dtype=torch.long)
         assert spec == torch.stack(spec.unbind(0), 0)
         with pytest.raises(ValueError):
             spec.unbind(-1)
@@ -1888,9 +1804,7 @@ def test_unbounded(
             shape1 = (15,)
         else:
             shape1 = (*shape1, 15)
-        spec = UnboundedContinuousTensorSpec(
-            shape=shape1, device="cpu", dtype=torch.float64
-        )
+        spec = Unbounded(shape=shape1, device="cpu", dtype=torch.float64)
         assert spec == torch.stack(spec.unbind(0), 0)
         assert spec == torch.stack(spec.unbind(-1), -1)
@@ -1908,15 +1822,15 @@ def test_unboundeddiscrete(
         assert spec == torch.stack(spec.unbind(-1), -1)
 
     def test_composite_encode_err(self):
-        c = CompositeSpec(
-            a=UnboundedContinuousTensorSpec(
+        c = Composite(
+            a=Unbounded(
                 1,
             ),
-            b=UnboundedContinuousTensorSpec(
+            b=Unbounded(
                 2,
             ),
         )
-        with pytest.raises(KeyError, match="The CompositeSpec instance with keys"):
+        with pytest.raises(KeyError, match="The Composite instance with keys"):
             c.encode({"c": 0})
         with pytest.raises(
             RuntimeError, match="raised a RuntimeError. Scroll up to know more"
@@ -1932,9 +1846,7 @@ def test_composite_encode_err(self):
 class TestTo:
     @pytest.mark.parametrize("shape1", [(5, 4)])
     def test_binary(self, shape1, device):
-        spec = BinaryDiscreteTensorSpec(
-            n=4, shape=shape1, device="cpu", dtype=torch.bool
-        )
+        spec = Binary(n=4, shape=shape1, device="cpu", dtype=torch.bool)
         assert spec.to(device).device == device
@@ -1949,14 +1861,12 @@ def test_binary(self, shape1, device):
         ],
     )
     def test_bounded(self, shape1, mini, maxi, device):
-        spec = BoundedTensorSpec(
-            mini, maxi, shape=shape1, device="cpu", dtype=torch.bool
-        )
+        spec = Bounded(mini, maxi, shape=shape1, device="cpu", dtype=torch.bool)
         assert spec.to(device).device == device
 
     def test_composite(self, device):
         batch_size = (5,)
-        spec1 = BoundedTensorSpec(
+        spec1 = Bounded(
             -torch.ones([*batch_size, 10]),
             torch.ones([*batch_size, 10]),
             shape=(
@@ -1966,22 +1876,16 @@ def test_composite(self, device):
             device="cpu",
             dtype=torch.bool,
         )
-        spec2 = BinaryDiscreteTensorSpec(
-            n=4, shape=(*batch_size, 4), device="cpu", dtype=torch.bool
-        )
-        spec3 = DiscreteTensorSpec(
-            n=4, shape=batch_size, device="cpu", dtype=torch.long
-        )
-        spec4 = MultiDiscreteTensorSpec(
+        spec2 = Binary(n=4, shape=(*batch_size, 4), device="cpu", dtype=torch.bool)
+        spec3 = Categorical(n=4, shape=batch_size, device="cpu", dtype=torch.long)
+        spec4 = MultiCategorical(
             nvec=(4, 5, 6), shape=(*batch_size, 3), device="cpu", dtype=torch.long
         )
-        spec5 = MultiOneHotDiscreteTensorSpec(
+        spec5 = MultiOneHot(
             nvec=(4, 5, 6), shape=(*batch_size, 15), device="cpu", dtype=torch.long
         )
-        spec6 = OneHotDiscreteTensorSpec(
-            n=15, shape=(*batch_size, 15), device="cpu", dtype=torch.long
-        )
-        spec7 = UnboundedContinuousTensorSpec(
+        spec6 = OneHot(n=15, shape=(*batch_size, 15), device="cpu", dtype=torch.long)
+        spec7 = Unbounded(
             shape=(*batch_size, 9),
             device="cpu",
             dtype=torch.float64,
@@ -1991,7 +1895,7 @@ def test_composite(self, device):
             device="cpu",
             dtype=torch.long,
         )
-        spec = CompositeSpec(
+        spec = Composite(
             spec1=spec1,
             spec2=spec2,
             spec3=spec3,
@@ -2010,7 +1914,7 @@ def test_discrete(
         shape1,
         device,
     ):
-        spec = DiscreteTensorSpec(n=4, shape=shape1, device="cpu", dtype=torch.long)
+        spec = Categorical(n=4, shape=shape1, device="cpu", dtype=torch.long)
         assert spec.to(device).device == device
 
     @pytest.mark.parametrize("shape1", [(5,), (5, 6)])
@@ -2019,7 +1923,7 @@ def test_multidiscrete(self, shape1, device):
             shape1 = (3,)
         else:
             shape1 = (*shape1, 3)
-        spec = MultiDiscreteTensorSpec(
+        spec = MultiCategorical(
             nvec=(4, 5, 6), shape=shape1, device="cpu", dtype=torch.long
         )
         assert spec.to(device).device == device
@@ -2030,13 +1934,11 @@ def test_multionehot(self, shape1, device):
             shape1 = (15,)
         else:
             shape1 = (*shape1, 15)
-        spec = MultiOneHotDiscreteTensorSpec(
-            nvec=(4, 5, 6), shape=shape1, device="cpu", dtype=torch.long
-        )
+        spec = MultiOneHot(nvec=(4, 5, 6), shape=shape1, device="cpu", dtype=torch.long)
         assert spec.to(device).device == device
 
     def test_non_tensor(self, device):
-        spec = NonTensorSpec(shape=(3, 4), device="cpu")
+        spec = NonTensor(shape=(3, 4), device="cpu")
         assert spec.to(device).device == device
 
     @pytest.mark.parametrize("shape1", [(5,), (5, 6)])
@@ -2045,9 +1947,7 @@ def test_onehot(self, shape1, device):
             shape1 = (15,)
         else:
             shape1 = (*shape1, 15)
-        spec = OneHotDiscreteTensorSpec(
-            n=15, shape=shape1, device="cpu", dtype=torch.long
-        )
+        spec = OneHot(n=15, shape=shape1, device="cpu", dtype=torch.long)
         assert spec.to(device).device == device
@pytest.mark.parametrize("shape1", [(5,), (5, 6)]) @@ -2056,9 +1956,7 @@ def test_unbounded(self, shape1, device): shape1 = (15,) else: shape1 = (*shape1, 15) - spec = UnboundedContinuousTensorSpec( - shape=shape1, device="cpu", dtype=torch.float64 - ) + spec = Unbounded(shape=shape1, device="cpu", dtype=torch.float64) assert spec.to(device).device == device @pytest.mark.parametrize("shape1", [(5,), (5, 6)]) @@ -2079,10 +1977,10 @@ class TestStack: def test_stack_binarydiscrete(self, shape, stack_dim): n = 5 shape = (*shape, n) - c1 = BinaryDiscreteTensorSpec(n=n, shape=shape) + c1 = Binary(n=n, shape=shape) c2 = c1.clone() c = torch.stack([c1, c2], stack_dim) - assert isinstance(c, BinaryDiscreteTensorSpec) + assert isinstance(c, Binary) shape = list(shape) if stack_dim < 0: stack_dim = len(shape) + stack_dim + 1 @@ -2092,7 +1990,7 @@ def test_stack_binarydiscrete(self, shape, stack_dim): def test_stack_binarydiscrete_expand(self, shape, stack_dim): n = 5 shape = (*shape, n) - c1 = BinaryDiscreteTensorSpec(n=n, shape=shape) + c1 = Binary(n=n, shape=shape) c2 = c1.clone() c = torch.stack([c1, c2], stack_dim) shape = list(shape) @@ -2105,7 +2003,7 @@ def test_stack_binarydiscrete_expand(self, shape, stack_dim): def test_stack_binarydiscrete_rand(self, shape, stack_dim): n = 5 shape = (*shape, n) - c1 = BinaryDiscreteTensorSpec(n=n, shape=shape) + c1 = Binary(n=n, shape=shape) c2 = c1.clone() c = torch.stack([c1, c2], 0) r = c.rand() @@ -2114,7 +2012,7 @@ def test_stack_binarydiscrete_rand(self, shape, stack_dim): def test_stack_binarydiscrete_zero(self, shape, stack_dim): n = 5 shape = (*shape, n) - c1 = BinaryDiscreteTensorSpec(n=n, shape=shape) + c1 = Binary(n=n, shape=shape) c2 = c1.clone() c = torch.stack([c1, c2], 0) r = c.zero() @@ -2124,10 +2022,10 @@ def test_stack_bounded(self, shape, stack_dim): mini = -1 maxi = 1 shape = (*shape,) - c1 = BoundedTensorSpec(mini, maxi, shape=shape) + c1 = Bounded(mini, maxi, shape=shape) c2 = c1.clone() c = torch.stack([c1, c2], stack_dim) - assert isinstance(c, BoundedTensorSpec) + assert isinstance(c, Bounded) shape = list(shape) if stack_dim < 0: stack_dim = len(shape) + stack_dim + 1 @@ -2138,7 +2036,7 @@ def test_stack_bounded_expand(self, shape, stack_dim): mini = -1 maxi = 1 shape = (*shape,) - c1 = BoundedTensorSpec(mini, maxi, shape=shape) + c1 = Bounded(mini, maxi, shape=shape) c2 = c1.clone() c = torch.stack([c1, c2], stack_dim) shape = list(shape) @@ -2152,7 +2050,7 @@ def test_stack_bounded_rand(self, shape, stack_dim): mini = -1 maxi = 1 shape = (*shape,) - c1 = BoundedTensorSpec(mini, maxi, shape=shape) + c1 = Bounded(mini, maxi, shape=shape) c2 = c1.clone() c = torch.stack([c1, c2], 0) r = c.rand() @@ -2162,7 +2060,7 @@ def test_stack_bounded_zero(self, shape, stack_dim): mini = -1 maxi = 1 shape = (*shape,) - c1 = BoundedTensorSpec(mini, maxi, shape=shape) + c1 = Bounded(mini, maxi, shape=shape) c2 = c1.clone() c = torch.stack([c1, c2], 0) r = c.zero() @@ -2171,10 +2069,10 @@ def test_stack_bounded_zero(self, shape, stack_dim): def test_stack_discrete(self, shape, stack_dim): n = 4 shape = (*shape,) - c1 = DiscreteTensorSpec(n, shape=shape) + c1 = Categorical(n, shape=shape) c2 = c1.clone() c = torch.stack([c1, c2], stack_dim) - assert isinstance(c, DiscreteTensorSpec) + assert isinstance(c, Categorical) shape = list(shape) if stack_dim < 0: stack_dim = len(shape) + stack_dim + 1 @@ -2184,7 +2082,7 @@ def test_stack_discrete(self, shape, stack_dim): def test_stack_discrete_expand(self, shape, stack_dim): n = 4 shape = (*shape,) - 
c1 = DiscreteTensorSpec(n, shape=shape) + c1 = Categorical(n, shape=shape) c2 = c1.clone() c = torch.stack([c1, c2], stack_dim) shape = list(shape) @@ -2197,7 +2095,7 @@ def test_stack_discrete_expand(self, shape, stack_dim): def test_stack_discrete_rand(self, shape, stack_dim): n = 4 shape = (*shape,) - c1 = DiscreteTensorSpec(n, shape=shape) + c1 = Categorical(n, shape=shape) c2 = c1.clone() c = torch.stack([c1, c2], 0) r = c.rand() @@ -2206,7 +2104,7 @@ def test_stack_discrete_rand(self, shape, stack_dim): def test_stack_discrete_zero(self, shape, stack_dim): n = 4 shape = (*shape,) - c1 = DiscreteTensorSpec(n, shape=shape) + c1 = Categorical(n, shape=shape) c2 = c1.clone() c = torch.stack([c1, c2], 0) r = c.zero() @@ -2215,10 +2113,10 @@ def test_stack_discrete_zero(self, shape, stack_dim): def test_stack_multidiscrete(self, shape, stack_dim): nvec = [4, 5] shape = (*shape, 2) - c1 = MultiDiscreteTensorSpec(nvec, shape=shape) + c1 = MultiCategorical(nvec, shape=shape) c2 = c1.clone() c = torch.stack([c1, c2], stack_dim) - assert isinstance(c, MultiDiscreteTensorSpec) + assert isinstance(c, MultiCategorical) shape = list(shape) if stack_dim < 0: stack_dim = len(shape) + stack_dim + 1 @@ -2228,7 +2126,7 @@ def test_stack_multidiscrete(self, shape, stack_dim): def test_stack_multidiscrete_expand(self, shape, stack_dim): nvec = [4, 5] shape = (*shape, 2) - c1 = MultiDiscreteTensorSpec(nvec, shape=shape) + c1 = MultiCategorical(nvec, shape=shape) c2 = c1.clone() c = torch.stack([c1, c2], stack_dim) shape = list(shape) @@ -2241,7 +2139,7 @@ def test_stack_multidiscrete_expand(self, shape, stack_dim): def test_stack_multidiscrete_rand(self, shape, stack_dim): nvec = [4, 5] shape = (*shape, 2) - c1 = MultiDiscreteTensorSpec(nvec, shape=shape) + c1 = MultiCategorical(nvec, shape=shape) c2 = c1.clone() c = torch.stack([c1, c2], 0) r = c.rand() @@ -2250,7 +2148,7 @@ def test_stack_multidiscrete_rand(self, shape, stack_dim): def test_stack_multidiscrete_zero(self, shape, stack_dim): nvec = [4, 5] shape = (*shape, 2) - c1 = MultiDiscreteTensorSpec(nvec, shape=shape) + c1 = MultiCategorical(nvec, shape=shape) c2 = c1.clone() c = torch.stack([c1, c2], 0) r = c.zero() @@ -2259,10 +2157,10 @@ def test_stack_multidiscrete_zero(self, shape, stack_dim): def test_stack_multionehot(self, shape, stack_dim): nvec = [4, 5] shape = (*shape, 9) - c1 = MultiOneHotDiscreteTensorSpec(nvec, shape=shape) + c1 = MultiOneHot(nvec, shape=shape) c2 = c1.clone() c = torch.stack([c1, c2], stack_dim) - assert isinstance(c, MultiOneHotDiscreteTensorSpec) + assert isinstance(c, MultiOneHot) shape = list(shape) if stack_dim < 0: stack_dim = len(shape) + stack_dim + 1 @@ -2272,7 +2170,7 @@ def test_stack_multionehot(self, shape, stack_dim): def test_stack_multionehot_expand(self, shape, stack_dim): nvec = [4, 5] shape = (*shape, 9) - c1 = MultiOneHotDiscreteTensorSpec(nvec, shape=shape) + c1 = MultiOneHot(nvec, shape=shape) c2 = c1.clone() c = torch.stack([c1, c2], stack_dim) shape = list(shape) @@ -2285,7 +2183,7 @@ def test_stack_multionehot_expand(self, shape, stack_dim): def test_stack_multionehot_rand(self, shape, stack_dim): nvec = [4, 5] shape = (*shape, 9) - c1 = MultiOneHotDiscreteTensorSpec(nvec, shape=shape) + c1 = MultiOneHot(nvec, shape=shape) c2 = c1.clone() c = torch.stack([c1, c2], 0) r = c.rand() @@ -2294,15 +2192,15 @@ def test_stack_multionehot_rand(self, shape, stack_dim): def test_stack_multionehot_zero(self, shape, stack_dim): nvec = [4, 5] shape = (*shape, 9) - c1 = MultiOneHotDiscreteTensorSpec(nvec, 
shape=shape) + c1 = MultiOneHot(nvec, shape=shape) c2 = c1.clone() c = torch.stack([c1, c2], 0) r = c.zero() assert r.shape == c.shape def test_stack_non_tensor(self, shape, stack_dim): - spec0 = NonTensorSpec(shape=shape, device="cpu") - spec1 = NonTensorSpec(shape=shape, device="cpu") + spec0 = NonTensor(shape=shape, device="cpu") + spec1 = NonTensor(shape=shape, device="cpu") new_spec = torch.stack([spec0, spec1], stack_dim) shape_insert = list(shape) shape_insert.insert(stack_dim, 2) @@ -2312,10 +2210,10 @@ def test_stack_non_tensor(self, shape, stack_dim): def test_stack_onehot(self, shape, stack_dim): n = 5 shape = (*shape, 5) - c1 = OneHotDiscreteTensorSpec(n, shape=shape) + c1 = OneHot(n, shape=shape) c2 = c1.clone() c = torch.stack([c1, c2], stack_dim) - assert isinstance(c, OneHotDiscreteTensorSpec) + assert isinstance(c, OneHot) shape = list(shape) if stack_dim < 0: stack_dim = len(shape) + stack_dim + 1 @@ -2325,7 +2223,7 @@ def test_stack_onehot(self, shape, stack_dim): def test_stack_onehot_expand(self, shape, stack_dim): n = 5 shape = (*shape, 5) - c1 = OneHotDiscreteTensorSpec(n, shape=shape) + c1 = OneHot(n, shape=shape) c2 = c1.clone() c = torch.stack([c1, c2], stack_dim) shape = list(shape) @@ -2338,7 +2236,7 @@ def test_stack_onehot_expand(self, shape, stack_dim): def test_stack_onehot_rand(self, shape, stack_dim): n = 5 shape = (*shape, 5) - c1 = OneHotDiscreteTensorSpec(n, shape=shape) + c1 = OneHot(n, shape=shape) c2 = c1.clone() c = torch.stack([c1, c2], 0) r = c.rand() @@ -2347,7 +2245,7 @@ def test_stack_onehot_rand(self, shape, stack_dim): def test_stack_onehot_zero(self, shape, stack_dim): n = 5 shape = (*shape, 5) - c1 = OneHotDiscreteTensorSpec(n, shape=shape) + c1 = OneHot(n, shape=shape) c2 = c1.clone() c = torch.stack([c1, c2], 0) r = c.zero() @@ -2355,10 +2253,10 @@ def test_stack_onehot_zero(self, shape, stack_dim): def test_stack_unboundedcont(self, shape, stack_dim): shape = (*shape,) - c1 = UnboundedContinuousTensorSpec(shape=shape) + c1 = Unbounded(shape=shape) c2 = c1.clone() c = torch.stack([c1, c2], stack_dim) - assert isinstance(c, UnboundedContinuousTensorSpec) + assert isinstance(c, Unbounded) shape = list(shape) if stack_dim < 0: stack_dim = len(shape) + stack_dim + 1 @@ -2367,7 +2265,7 @@ def test_stack_unboundedcont(self, shape, stack_dim): def test_stack_unboundedcont_expand(self, shape, stack_dim): shape = (*shape,) - c1 = UnboundedContinuousTensorSpec(shape=shape) + c1 = Unbounded(shape=shape) c2 = c1.clone() c = torch.stack([c1, c2], stack_dim) shape = list(shape) @@ -2379,7 +2277,7 @@ def test_stack_unboundedcont_expand(self, shape, stack_dim): def test_stack_unboundedcont_rand(self, shape, stack_dim): shape = (*shape,) - c1 = UnboundedContinuousTensorSpec(shape=shape) + c1 = Unbounded(shape=shape) c2 = c1.clone() c = torch.stack([c1, c2], 0) r = c.rand() @@ -2434,8 +2332,8 @@ def test_stack_unboundeddiscrete_zero(self, shape, stack_dim): assert r.shape == c.shape def test_to_numpy(self, shape, stack_dim): - c1 = BoundedTensorSpec(-1, 1, shape=shape, dtype=torch.float64) - c2 = BoundedTensorSpec(-1, 1, shape=shape, dtype=torch.float64) + c1 = Bounded(-1, 1, shape=shape, dtype=torch.float64) + c2 = Bounded(-1, 1, shape=shape, dtype=torch.float64) c = torch.stack([c1, c2], stack_dim) @@ -2455,13 +2353,13 @@ def test_to_numpy(self, shape, stack_dim): c.to_numpy(val + 1, safe=True) def test_malformed_stack(self, shape, stack_dim): - c1 = BoundedTensorSpec(-1, 1, shape=shape, dtype=torch.float64) - c2 = BoundedTensorSpec(-1, 1, shape=shape, 
dtype=torch.float32) + c1 = Bounded(-1, 1, shape=shape, dtype=torch.float64) + c2 = Bounded(-1, 1, shape=shape, dtype=torch.float32) with pytest.raises(RuntimeError, match="Dtypes differ"): torch.stack([c1, c2], stack_dim) - c1 = BoundedTensorSpec(-1, 1, shape=shape, dtype=torch.float32) - c2 = UnboundedContinuousTensorSpec(shape=shape, dtype=torch.float32) + c1 = Bounded(-1, 1, shape=shape, dtype=torch.float32) + c2 = Unbounded(shape=shape, dtype=torch.float32) c3 = UnboundedDiscreteTensorSpec(shape=shape, dtype=torch.float32) with pytest.raises( RuntimeError, @@ -2470,40 +2368,40 @@ def test_malformed_stack(self, shape, stack_dim): torch.stack([c1, c2], stack_dim) torch.stack([c3, c2], stack_dim) - c1 = BoundedTensorSpec(-1, 1, shape=shape, dtype=torch.float32) - c2 = BoundedTensorSpec(-1, 1, shape=shape + (3,), dtype=torch.float32) + c1 = Bounded(-1, 1, shape=shape, dtype=torch.float32) + c2 = Bounded(-1, 1, shape=shape + (3,), dtype=torch.float32) with pytest.raises(RuntimeError, match="Ndims differ"): torch.stack([c1, c2], stack_dim) -class TestDenseStackedCompositeSpecs: +class TestDenseStackedComposite: def test_stack(self): - c1 = CompositeSpec(a=UnboundedContinuousTensorSpec()) + c1 = Composite(a=Unbounded()) c2 = c1.clone() c = torch.stack([c1, c2], 0) - assert isinstance(c, CompositeSpec) + assert isinstance(c, Composite) -class TestLazyStackedCompositeSpecs: +class TestLazyStackedComposite: def _get_heterogeneous_specs( self, batch_size=(), stack_dim: int = 0, ): - shared = BoundedTensorSpec(low=0, high=1, shape=(*batch_size, 32, 32, 3)) - hetero_3d = UnboundedContinuousTensorSpec( + shared = Bounded(low=0, high=1, shape=(*batch_size, 32, 32, 3)) + hetero_3d = Unbounded( shape=( *batch_size, 3, ) ) - hetero_2d = UnboundedContinuousTensorSpec( + hetero_2d = Unbounded( shape=( *batch_size, 2, ) ) - lidar = BoundedTensorSpec( + lidar = Bounded( low=0, high=5, shape=( @@ -2512,9 +2410,9 @@ def _get_heterogeneous_specs( ), ) - individual_0_obs = CompositeSpec( + individual_0_obs = Composite( { - "individual_0_obs_0": UnboundedContinuousTensorSpec( + "individual_0_obs_0": Unbounded( shape=( *batch_size, 3, @@ -2524,25 +2422,21 @@ def _get_heterogeneous_specs( }, shape=(*batch_size, 3), ) - individual_1_obs = CompositeSpec( + individual_1_obs = Composite( { - "individual_1_obs_0": BoundedTensorSpec( + "individual_1_obs_0": Bounded( low=0, high=3, shape=(*batch_size, 3, 1, 2) ) }, shape=(*batch_size, 3), ) - individual_2_obs = CompositeSpec( - { - "individual_1_obs_0": UnboundedContinuousTensorSpec( - shape=(*batch_size, 3, 1, 2, 3) - ) - }, + individual_2_obs = Composite( + {"individual_1_obs_0": Unbounded(shape=(*batch_size, 3, 1, 2, 3))}, shape=(*batch_size, 3), ) spec_list = [ - CompositeSpec( + Composite( { "shared": shared, "lidar": lidar, @@ -2551,7 +2445,7 @@ def _get_heterogeneous_specs( }, shape=batch_size, ), - CompositeSpec( + Composite( { "shared": shared, "lidar": lidar, @@ -2560,7 +2454,7 @@ def _get_heterogeneous_specs( }, shape=batch_size, ), - CompositeSpec( + Composite( { "shared": shared, "hetero": hetero_2d, @@ -2573,10 +2467,8 @@ def _get_heterogeneous_specs( return torch.stack(spec_list, dim=stack_dim).cpu() def test_stack_index(self): - c1 = CompositeSpec(a=UnboundedContinuousTensorSpec()) - c2 = CompositeSpec( - a=UnboundedContinuousTensorSpec(), b=UnboundedDiscreteTensorSpec() - ) + c1 = Composite(a=Unbounded()) + c2 = Composite(a=Unbounded(), b=UnboundedDiscreteTensorSpec()) c = torch.stack([c1, c2], 0) assert c.shape == torch.Size([2]) assert c[0] is c1 @@ 
-2585,19 +2477,19 @@ def test_stack_index(self): assert c[..., 1] is c2 assert c[0, ...] is c1 assert c[1, ...] is c2 - assert isinstance(c[:], LazyStackedCompositeSpec) + assert isinstance(c[:], StackedComposite) @pytest.mark.parametrize("stack_dim", [0, 1, 2, -3, -2, -1]) def test_stack_index_multdim(self, stack_dim): - c1 = CompositeSpec(a=UnboundedContinuousTensorSpec(shape=(1, 3)), shape=(1, 3)) - c2 = CompositeSpec( - a=UnboundedContinuousTensorSpec(shape=(1, 3)), + c1 = Composite(a=Unbounded(shape=(1, 3)), shape=(1, 3)) + c2 = Composite( + a=Unbounded(shape=(1, 3)), b=UnboundedDiscreteTensorSpec(shape=(1, 3)), shape=(1, 3), ) c = torch.stack([c1, c2], stack_dim) if stack_dim in (0, -3): - assert isinstance(c[:], LazyStackedCompositeSpec) + assert isinstance(c[:], StackedComposite) assert c.shape == torch.Size([2, 1, 3]) assert c[0] is c1 assert c[1] is c2 @@ -2614,7 +2506,7 @@ def test_stack_index_multdim(self, stack_dim): assert c[0, ...] is c1 assert c[1, ...] is c2 elif stack_dim == (1, -2): - assert isinstance(c[:, :], LazyStackedCompositeSpec) + assert isinstance(c[:, :], StackedComposite) assert c.shape == torch.Size([1, 2, 3]) assert c[:, 0] is c1 assert c[:, 1] is c2 @@ -2641,7 +2533,7 @@ def test_stack_index_multdim(self, stack_dim): assert c[:, 0, ...] is c1 assert c[:, 1, ...] is c2 elif stack_dim == (2, -1): - assert isinstance(c[:, :, :], LazyStackedCompositeSpec) + assert isinstance(c[:, :, :], StackedComposite) with pytest.raises( IndexError, match="along dimension 0 when the stack dimension is 2." ): @@ -2660,9 +2552,9 @@ def test_stack_index_multdim(self, stack_dim): @pytest.mark.parametrize("stack_dim", [0, 1, 2, -3, -2, -1]) def test_stack_expand_multi(self, stack_dim): - c1 = CompositeSpec(a=UnboundedContinuousTensorSpec(shape=(1, 3)), shape=(1, 3)) - c2 = CompositeSpec( - a=UnboundedContinuousTensorSpec(shape=(1, 3)), + c1 = Composite(a=Unbounded(shape=(1, 3)), shape=(1, 3)) + c2 = Composite( + a=Unbounded(shape=(1, 3)), b=UnboundedDiscreteTensorSpec(shape=(1, 3)), shape=(1, 3), ) @@ -2691,9 +2583,9 @@ def test_stack_expand_multi(self, stack_dim): @pytest.mark.parametrize("stack_dim", [0, 1, 2, -3, -2, -1]) def test_stack_rand(self, stack_dim): - c1 = CompositeSpec(a=UnboundedContinuousTensorSpec(shape=(1, 3)), shape=(1, 3)) - c2 = CompositeSpec( - a=UnboundedContinuousTensorSpec(shape=(1, 3)), + c1 = Composite(a=Unbounded(shape=(1, 3)), shape=(1, 3)) + c2 = Composite( + a=Unbounded(shape=(1, 3)), b=UnboundedDiscreteTensorSpec(shape=(1, 3)), shape=(1, 3), ) @@ -2713,9 +2605,9 @@ def test_stack_rand(self, stack_dim): @pytest.mark.parametrize("stack_dim", [0, 1, 2, -3, -2, -1]) def test_stack_rand_shape(self, stack_dim): - c1 = CompositeSpec(a=UnboundedContinuousTensorSpec(shape=(1, 3)), shape=(1, 3)) - c2 = CompositeSpec( - a=UnboundedContinuousTensorSpec(shape=(1, 3)), + c1 = Composite(a=Unbounded(shape=(1, 3)), shape=(1, 3)) + c2 = Composite( + a=Unbounded(shape=(1, 3)), b=UnboundedDiscreteTensorSpec(shape=(1, 3)), shape=(1, 3), ) @@ -2736,9 +2628,9 @@ def test_stack_rand_shape(self, stack_dim): @pytest.mark.parametrize("stack_dim", [0, 1, 2, -3, -2, -1]) def test_stack_zero(self, stack_dim): - c1 = CompositeSpec(a=UnboundedContinuousTensorSpec(shape=(1, 3)), shape=(1, 3)) - c2 = CompositeSpec( - a=UnboundedContinuousTensorSpec(shape=(1, 3)), + c1 = Composite(a=Unbounded(shape=(1, 3)), shape=(1, 3)) + c2 = Composite( + a=Unbounded(shape=(1, 3)), b=UnboundedDiscreteTensorSpec(shape=(1, 3)), shape=(1, 3), ) @@ -2758,9 +2650,9 @@ def test_stack_zero(self, 
stack_dim): @pytest.mark.parametrize("stack_dim", [0, 1, 2, -3, -2, -1]) def test_stack_zero_shape(self, stack_dim): - c1 = CompositeSpec(a=UnboundedContinuousTensorSpec(shape=(1, 3)), shape=(1, 3)) - c2 = CompositeSpec( - a=UnboundedContinuousTensorSpec(shape=(1, 3)), + c1 = Composite(a=Unbounded(shape=(1, 3)), shape=(1, 3)) + c2 = Composite( + a=Unbounded(shape=(1, 3)), b=UnboundedDiscreteTensorSpec(shape=(1, 3)), shape=(1, 3), ) @@ -2782,14 +2674,14 @@ def test_stack_zero_shape(self, stack_dim): @pytest.mark.skipif(not torch.cuda.device_count(), reason="no cuda") @pytest.mark.parametrize("stack_dim", [0, 1, 2, -3, -2, -1]) def test_to(self, stack_dim): - c1 = CompositeSpec(a=UnboundedContinuousTensorSpec(shape=(1, 3)), shape=(1, 3)) - c2 = CompositeSpec( - a=UnboundedContinuousTensorSpec(shape=(1, 3)), + c1 = Composite(a=Unbounded(shape=(1, 3)), shape=(1, 3)) + c2 = Composite( + a=Unbounded(shape=(1, 3)), b=UnboundedDiscreteTensorSpec(shape=(1, 3)), shape=(1, 3), ) c = torch.stack([c1, c2], stack_dim) - assert isinstance(c, LazyStackedCompositeSpec) + assert isinstance(c, StackedComposite) cdevice = c.to("cuda:0") assert cdevice.device != c.device assert cdevice.device == torch.device("cuda:0") @@ -2799,9 +2691,9 @@ def test_to(self, stack_dim): assert cdevice[index].device == torch.device("cuda:0") def test_clone(self): - c1 = CompositeSpec(a=UnboundedContinuousTensorSpec(shape=(1, 3)), shape=(1, 3)) - c2 = CompositeSpec( - a=UnboundedContinuousTensorSpec(shape=(1, 3)), + c1 = Composite(a=Unbounded(shape=(1, 3)), shape=(1, 3)) + c2 = Composite( + a=Unbounded(shape=(1, 3)), b=UnboundedDiscreteTensorSpec(shape=(1, 3)), shape=(1, 3), ) @@ -2811,9 +2703,9 @@ def test_clone(self): assert cclone[0] == c[0] def test_to_numpy(self): - c1 = CompositeSpec(a=BoundedTensorSpec(-1, 1, shape=(1, 3)), shape=(1, 3)) - c2 = CompositeSpec( - a=BoundedTensorSpec(-1, 1, shape=(1, 3)), + c1 = Composite(a=Bounded(-1, 1, shape=(1, 3)), shape=(1, 3)) + c2 = Composite( + a=Bounded(-1, 1, shape=(1, 3)), b=UnboundedDiscreteTensorSpec(shape=(1, 3)), shape=(1, 3), ) @@ -2829,9 +2721,9 @@ def test_to_numpy(self): c.to_numpy(td_fail, safe=True) def test_unsqueeze(self): - c1 = CompositeSpec(a=BoundedTensorSpec(-1, 1, shape=(1, 3)), shape=(1, 3)) - c2 = CompositeSpec( - a=BoundedTensorSpec(-1, 1, shape=(1, 3)), + c1 = Composite(a=Bounded(-1, 1, shape=(1, 3)), shape=(1, 3)) + c2 = Composite( + a=Bounded(-1, 1, shape=(1, 3)), b=UnboundedDiscreteTensorSpec(shape=(1, 3)), shape=(1, 3), ) @@ -2984,12 +2876,11 @@ def test_project(self, batch_size): def test_repr(self): c = self._get_heterogeneous_specs() - - expected = f"""LazyStackedCompositeSpec( + expected = f"""StackedComposite( fields={{ - hetero: LazyStackedUnboundedContinuousTensorSpec( + hetero: StackedUnboundedContinuous( shape=torch.Size([3, -1]), device=cpu, dtype=torch.float32, domain=continuous), - shared: BoundedTensorSpec( + shared: BoundedContinuous( shape=torch.Size([3, 32, 32, 3]), space=ContinuousBox( low=Tensor(shape=torch.Size([3, 32, 32, 3]), device=cpu, dtype=torch.float32, contiguous=True), @@ -2999,7 +2890,7 @@ def test_repr(self): domain=continuous)}}, exclusive_fields={{ 0 -> - lidar: BoundedTensorSpec( + lidar: BoundedContinuous( shape=torch.Size([20]), space=ContinuousBox( low=Tensor(shape=torch.Size([20]), device=cpu, dtype=torch.float32, contiguous=True), @@ -3007,17 +2898,19 @@ def test_repr(self): device=cpu, dtype=torch.float32, domain=continuous), - individual_0_obs: CompositeSpec( - individual_0_obs_0: UnboundedContinuousTensorSpec( + 
individual_0_obs: Composite( + individual_0_obs_0: UnboundedContinuous( shape=torch.Size([3, 1]), - space=None, + space=ContinuousBox( + low=Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, contiguous=True)), device=cpu, dtype=torch.float32, domain=continuous), device=cpu, shape=torch.Size([3])), 1 -> - lidar: BoundedTensorSpec( + lidar: BoundedContinuous( shape=torch.Size([20]), space=ContinuousBox( low=Tensor(shape=torch.Size([20]), device=cpu, dtype=torch.float32, contiguous=True), @@ -3025,8 +2918,8 @@ def test_repr(self): device=cpu, dtype=torch.float32, domain=continuous), - individual_1_obs: CompositeSpec( - individual_1_obs_0: BoundedTensorSpec( + individual_1_obs: Composite( + individual_1_obs_0: BoundedContinuous( shape=torch.Size([3, 1, 2]), space=ContinuousBox( low=Tensor(shape=torch.Size([3, 1, 2]), device=cpu, dtype=torch.float32, contiguous=True), @@ -3037,10 +2930,12 @@ def test_repr(self): device=cpu, shape=torch.Size([3])), 2 -> - individual_2_obs: CompositeSpec( - individual_1_obs_0: UnboundedContinuousTensorSpec( + individual_2_obs: Composite( + individual_1_obs_0: UnboundedContinuous( shape=torch.Size([3, 1, 2, 3]), - space=None, + space=ContinuousBox( + low=Tensor(shape=torch.Size([3, 1, 2, 3]), device=cpu, dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([3, 1, 2, 3]), device=cpu, dtype=torch.float32, contiguous=True)), device=cpu, dtype=torch.float32, domain=continuous), @@ -3054,11 +2949,11 @@ def test_repr(self): c = c[0:2] del c["individual_0_obs"] del c["individual_1_obs"] - expected = f"""LazyStackedCompositeSpec( + expected = f"""StackedComposite( fields={{ - hetero: LazyStackedUnboundedContinuousTensorSpec( + hetero: StackedUnboundedContinuous( shape=torch.Size([2, -1]), device=cpu, dtype=torch.float32, domain=continuous), - lidar: BoundedTensorSpec( + lidar: BoundedContinuous( shape=torch.Size([2, 20]), space=ContinuousBox( low=Tensor(shape=torch.Size([2, 20]), device=cpu, dtype=torch.float32, contiguous=True), @@ -3066,7 +2961,7 @@ def test_repr(self): device=cpu, dtype=torch.float32, domain=continuous), - shared: BoundedTensorSpec( + shared: BoundedContinuous( shape=torch.Size([2, 32, 32, 3]), space=ContinuousBox( low=Tensor(shape=torch.Size([2, 32, 32, 3]), device=cpu, dtype=torch.float32, contiguous=True), @@ -3100,7 +2995,7 @@ def test_consolidate_spec(self, batch_size): @pytest.mark.parametrize("batch_size", [(), (2,), (2, 1)]) def test_consolidate_spec_exclusive_lazy_stacked(self, batch_size): - shared = UnboundedContinuousTensorSpec( + shared = Unbounded( shape=( *batch_size, 5, @@ -3110,29 +3005,29 @@ def test_consolidate_spec_exclusive_lazy_stacked(self, batch_size): ) lazy_spec = torch.stack( [ - UnboundedContinuousTensorSpec(shape=(*batch_size, 5, 6, 7)), - UnboundedContinuousTensorSpec(shape=(*batch_size, 5, 7, 7)), - UnboundedContinuousTensorSpec(shape=(*batch_size, 5, 8, 7)), - UnboundedContinuousTensorSpec(shape=(*batch_size, 5, 8, 7)), + Unbounded(shape=(*batch_size, 5, 6, 7)), + Unbounded(shape=(*batch_size, 5, 7, 7)), + Unbounded(shape=(*batch_size, 5, 8, 7)), + Unbounded(shape=(*batch_size, 5, 8, 7)), ], dim=len(batch_size), ) spec_list = [ - CompositeSpec( + Composite( { "shared": shared, "lazy_spec": lazy_spec, }, shape=batch_size, ), - CompositeSpec( + Composite( { "shared": shared, }, shape=batch_size, ), - CompositeSpec( + Composite( {}, shape=batch_size, device="cpu", @@ -3168,9 +3063,7 @@ def 
test_update(self, batch_size, stack_dim=0): spec[1]["individual_1_obs"]["individual_1_obs_0"].space.low.sum() == 0 ) # Only non exclusive keys will be updated - new = torch.stack( - [UnboundedContinuousTensorSpec(shape=(*batch_size, i)) for i in range(3)], 0 - ) + new = torch.stack([Unbounded(shape=(*batch_size, i)) for i in range(3)], 0) spec2["new"] = new spec.update(spec2) assert spec["new"] == new @@ -3181,7 +3074,7 @@ def test_set_item(self, batch_size, stack_dim): spec = self._get_heterogeneous_specs(batch_size, stack_dim) new = torch.stack( - [UnboundedContinuousTensorSpec(shape=(*batch_size, i)) for i in range(3)], + [Unbounded(shape=(*batch_size, i)) for i in range(3)], stack_dim, ) spec["new"] = new @@ -3196,15 +3089,15 @@ def test_set_item(self, batch_size, stack_dim): spec[("other", "key")] = new assert spec[("other", "key")] == new - assert isinstance(spec["other"], LazyStackedCompositeSpec) + assert isinstance(spec["other"], StackedComposite) with pytest.raises(RuntimeError, match="key should be a Sequence"): spec[0] = new comp = torch.stack( [ - CompositeSpec( - {"a": UnboundedContinuousTensorSpec(shape=(*batch_size, i))}, + Composite( + {"a": Unbounded(shape=(*batch_size, i))}, shape=batch_size, ) for i in range(3) @@ -3220,10 +3113,10 @@ def test_set_item(self, batch_size, stack_dim): @pytest.mark.parametrize( "spec_class", [ - BinaryDiscreteTensorSpec, - OneHotDiscreteTensorSpec, - MultiOneHotDiscreteTensorSpec, - CompositeSpec, + Binary, + OneHot, + MultiOneHot, + Composite, ], ) @pytest.mark.parametrize( @@ -3240,13 +3133,13 @@ def test_set_item(self, batch_size, stack_dim): ], # [:,1:2,1] ) def test_invalid_indexing(spec_class, idx): - if spec_class in [BinaryDiscreteTensorSpec, OneHotDiscreteTensorSpec]: + if spec_class in [Binary, OneHot]: spec = spec_class(n=4, shape=[3, 4]) - elif spec_class == MultiDiscreteTensorSpec: + elif spec_class == MultiCategorical: spec = spec_class([2, 2, 2], shape=[3]) - elif spec_class == MultiOneHotDiscreteTensorSpec: + elif spec_class == MultiOneHot: spec = spec_class([4], shape=[3, 4]) - elif spec_class == CompositeSpec: + elif spec_class == Composite: spec = spec_class(k=UnboundedDiscreteTensorSpec(shape=(3, 4)), shape=(3,)) with pytest.raises(IndexError): spec[idx] @@ -3256,13 +3149,13 @@ def test_invalid_indexing(spec_class, idx): @pytest.mark.parametrize( "spec_class", [ - BinaryDiscreteTensorSpec, - DiscreteTensorSpec, - MultiOneHotDiscreteTensorSpec, - OneHotDiscreteTensorSpec, - UnboundedContinuousTensorSpec, + Binary, + Categorical, + MultiOneHot, + OneHot, + Unbounded, UnboundedDiscreteTensorSpec, - CompositeSpec, + Composite, ], ) def test_valid_indexing(spec_class): @@ -3270,14 +3163,14 @@ def test_valid_indexing(spec_class): args = {"0d": [], "2d": [], "3d": [], "4d": [], "5d": []} kwargs = {} if spec_class in [ - BinaryDiscreteTensorSpec, - DiscreteTensorSpec, - OneHotDiscreteTensorSpec, + Binary, + Categorical, + OneHot, ]: args = {"0d": [0], "2d": [3], "3d": [4], "4d": [6], "5d": [7]} - elif spec_class == MultiOneHotDiscreteTensorSpec: + elif spec_class == MultiOneHot: args = {"0d": [[0]], "2d": [[3]], "3d": [[4]], "4d": [[6]], "5d": [[7]]} - elif spec_class == MultiDiscreteTensorSpec: + elif spec_class == MultiCategorical: args = { "0d": [[0]], "2d": [[2] * 3], @@ -3285,7 +3178,7 @@ def test_valid_indexing(spec_class): "4d": [[1] * 6], "5d": [[2] * 7], } - elif spec_class == BoundedTensorSpec: + elif spec_class == Bounded: min_max = (-1, -1) args = { "0d": min_max, @@ -3294,17 +3187,17 @@ def 
test_valid_indexing(spec_class): "4d": min_max, "5d": min_max, } - elif spec_class == CompositeSpec: + elif spec_class == Composite: kwargs = { "k1": UnboundedDiscreteTensorSpec(shape=(5, 3, 4, 6, 7, 8)), - "k2": OneHotDiscreteTensorSpec(n=7, shape=(5, 3, 4, 6, 7)), + "k2": OneHot(n=7, shape=(5, 3, 4, 6, 7)), } spec_0d = spec_class(*args["0d"], **kwargs) if spec_class in [ - UnboundedContinuousTensorSpec, + Unbounded, UnboundedDiscreteTensorSpec, - CompositeSpec, + Composite, ]: spec_0d = spec_class(*args["0d"], shape=[], **kwargs) spec_2d = spec_class(*args["2d"], shape=[5, 3], **kwargs) @@ -3374,10 +3267,10 @@ def test_valid_indexing(spec_class): # Specific tests when specs have non-indexable dimensions if spec_class in [ - BinaryDiscreteTensorSpec, - OneHotDiscreteTensorSpec, - MultiDiscreteTensorSpec, - MultiOneHotDiscreteTensorSpec, + Binary, + OneHot, + MultiCategorical, + MultiOneHot, ]: # Ellipsis assert spec_0d[None].shape == torch.Size([1, 0]) @@ -3390,7 +3283,6 @@ def test_valid_indexing(spec_class): assert spec_3d[None, 1, ..., None].shape == torch.Size([1, 3, 1, 4]) assert spec_4d[:, None, ..., None, :].shape == torch.Size([5, 1, 3, 1, 4, 6]) - # BoundedTensorSpec, DiscreteTensorSpec, UnboundedContinuousTensorSpec, UnboundedDiscreteTensorSpec, CompositeSpec else: # Integers assert spec_2d[0, 1].shape == torch.Size([]) @@ -3407,7 +3299,7 @@ def test_valid_indexing(spec_class): assert spec_4d[:, None, ..., None, :].shape == torch.Size([5, 1, 3, 4, 1, 6]) # Additional tests for composite spec - if spec_class == CompositeSpec: + if spec_class == Composite: assert spec_2d[1]["k1"].shape == torch.Size([3, 4, 6, 7, 8]) assert spec_3d[[1, 2]]["k1"].shape == torch.Size([2, 3, 4, 6, 7, 8]) assert spec_2d[torch.randint(3, (3, 2))]["k1"].shape == torch.Size( @@ -3422,9 +3314,7 @@ def test_valid_indexing(spec_class): def test_composite_contains(): - spec = CompositeSpec( - a=CompositeSpec(b=CompositeSpec(c=UnboundedContinuousTensorSpec())) - ) + spec = Composite(a=Composite(b=Composite(c=Unbounded()))) assert "a" in spec.keys() assert "a" in spec.keys(True) assert ("a",) in spec.keys() @@ -3444,10 +3334,10 @@ def get_all_keys(spec: TensorSpec, include_exclusive: bool): """ keys = set() - if isinstance(spec, LazyStackedCompositeSpec) and include_exclusive: + if isinstance(spec, StackedComposite) and include_exclusive: for t in spec._specs: keys = keys.union(get_all_keys(t, include_exclusive)) - if isinstance(spec, CompositeSpec): + if isinstance(spec, Composite): for key in spec.keys(): keys.add((key,)) inner_keys = get_all_keys(spec[key], include_exclusive) @@ -3481,7 +3371,7 @@ def _make_mask(self, shape): def _one_hot_spec(self, shape, device, n): shape = torch.Size([*shape, n]) mask = self._make_mask(shape).to(device) - return OneHotDiscreteTensorSpec(n, shape, device, mask=mask) + return OneHot(n, shape, device, mask=mask) def _mult_one_hot_spec(self, shape, device, n): shape = torch.Size([*shape, n + n + 2]) @@ -3492,11 +3382,11 @@ def _mult_one_hot_spec(self, shape, device, n): ], -1, ) - return MultiOneHotDiscreteTensorSpec([n, n + 2], shape, device, mask=mask) + return MultiOneHot([n, n + 2], shape, device, mask=mask) def _discrete_spec(self, shape, device, n): mask = self._make_mask(torch.Size([*shape, n])).to(device) - return DiscreteTensorSpec(n, shape, device, mask=mask) + return Categorical(n, shape, device, mask=mask) def _mult_discrete_spec(self, shape, device, n): shape = torch.Size([*shape, 2]) @@ -3507,7 +3397,7 @@ def _mult_discrete_spec(self, shape, device, n): ], -1, ) 
-        return MultiDiscreteTensorSpec([n, n + 2], shape, device, mask=mask)
+        return MultiCategorical([n, n + 2], shape, device, mask=mask)
 
     def test_equal(self, shape, device, spectype, rand_shape, n=5):
         shape = torch.Size(shape)
@@ -3579,7 +3469,7 @@ def test_project(self, shape, device, spectype, rand_shape, n=5):
 
 class TestDynamicSpec:
     def test_all(self):
-        spec = UnboundedContinuousTensorSpec((-1, 1, 2))
+        spec = Unbounded((-1, 1, 2))
         unb = spec
         assert spec.shape == (-1, 1, 2)
         x = torch.randn(3, 1, 2)
@@ -3593,14 +3483,14 @@ def test_all(self):
         xunbd = x
         assert spec.is_in(x)
 
-        spec = BoundedTensorSpec(shape=(-1, 1, 2), low=-1, high=1)
+        spec = Bounded(shape=(-1, 1, 2), low=-1, high=1)
         bound = spec
         assert spec.shape == (-1, 1, 2)
         x = torch.rand((3, 1, 2))
         xbound = x
         assert spec.is_in(x)
 
-        spec = OneHotDiscreteTensorSpec(shape=(-1, 1, 2, 4), n=4)
+        spec = OneHot(shape=(-1, 1, 2, 4), n=4)
         oneh = spec
         assert spec.shape == (-1, 1, 2, 4)
         x = torch.zeros((3, 1, 2, 4), dtype=torch.bool)
@@ -3608,14 +3498,14 @@ def test_all(self):
         xoneh = x
         assert spec.is_in(x)
 
-        spec = DiscreteTensorSpec(shape=(-1, 1, 2), n=4)
+        spec = Categorical(shape=(-1, 1, 2), n=4)
         disc = spec
         assert spec.shape == (-1, 1, 2)
         x = torch.randint(4, (3, 1, 2))
         xdisc = x
         assert spec.is_in(x)
 
-        spec = MultiOneHotDiscreteTensorSpec(shape=(-1, 1, 2, 7), nvec=[3, 4])
+        spec = MultiOneHot(shape=(-1, 1, 2, 7), nvec=[3, 4])
         moneh = spec
         assert spec.shape == (-1, 1, 2, 7)
         x = torch.zeros((3, 1, 2, 7), dtype=torch.bool)
@@ -3624,7 +3514,7 @@ def test_all(self):
         xmoneh = x
         assert spec.is_in(x)
 
-        spec = MultiDiscreteTensorSpec(shape=(-1, 1, 2, 2), nvec=[3, 4])
+        spec = MultiCategorical(shape=(-1, 1, 2, 2), nvec=[3, 4])
         mdisc = spec
         assert spec.mask is None
         assert spec.shape == (-1, 1, 2, 2)
@@ -3632,7 +3522,7 @@ def test_all(self):
         xmdisc = x
         assert spec.is_in(x)
 
-        spec = CompositeSpec(
+        spec = Composite(
             unb=unb,
             unbd=unbd,
             bound=bound,
@@ -3659,15 +3549,15 @@ def test_all(self):
         assert spec.is_in(data)
 
     def test_expand(self):
-        unb = UnboundedContinuousTensorSpec((-1, 1, 2))
+        unb = Unbounded((-1, 1, 2))
         unbd = UnboundedDiscreteTensorSpec((-1, 1, 2))
-        bound = BoundedTensorSpec(shape=(-1, 1, 2), low=-1, high=1)
-        oneh = OneHotDiscreteTensorSpec(shape=(-1, 1, 2, 4), n=4)
-        disc = DiscreteTensorSpec(shape=(-1, 1, 2), n=4)
-        moneh = MultiOneHotDiscreteTensorSpec(shape=(-1, 1, 2, 7), nvec=[3, 4])
-        mdisc = MultiDiscreteTensorSpec(shape=(-1, 1, 2, 2), nvec=[3, 4])
+        bound = Bounded(shape=(-1, 1, 2), low=-1, high=1)
+        oneh = OneHot(shape=(-1, 1, 2, 4), n=4)
+        disc = Categorical(shape=(-1, 1, 2), n=4)
+        moneh = MultiOneHot(shape=(-1, 1, 2, 7), nvec=[3, 4])
+        mdisc = MultiCategorical(shape=(-1, 1, 2, 2), nvec=[3, 4])
 
-        spec = CompositeSpec(
+        spec = Composite(
             unb=unb,
             unbd=unbd,
             bound=bound,
@@ -3689,7 +3579,7 @@ def test_expand(self):
 
 class TestNonTensorSpec:
     def test_sample(self):
-        nts = NonTensorSpec(shape=(3, 4))
+        nts = NonTensor(shape=(3, 4))
         assert nts.one((2,)).shape == (2, 3, 4)
         assert nts.rand((2,)).shape == (2, 3, 4)
         assert nts.zero((2,)).shape == (2, 3, 4)
@@ -3707,26 +3597,24 @@ def test_device_ordinal():
     assert _make_ordinal_device(device) is None
 
     device = torch.device("cuda")
-    unb = UnboundedContinuousTensorSpec((-1, 1, 2), device=device)
+    unb = Unbounded((-1, 1, 2), device=device)
     assert unb.device == torch.device("cuda:0")
     unbd = UnboundedDiscreteTensorSpec((-1, 1, 2), device=device)
     assert unbd.device == torch.device("cuda:0")
-    bound = BoundedTensorSpec(shape=(-1, 1, 2), low=-1, high=1, device=device)
+    bound = Bounded(shape=(-1, 1, 2), low=-1, high=1, device=device)
    assert bound.device == torch.device("cuda:0")
-    oneh = OneHotDiscreteTensorSpec(shape=(-1, 1, 2, 4), n=4, device=device)
+    oneh = OneHot(shape=(-1, 1, 2, 4), n=4, device=device)
     assert oneh.device == torch.device("cuda:0")
-    disc = DiscreteTensorSpec(shape=(-1, 1, 2), n=4, device=device)
+    disc = Categorical(shape=(-1, 1, 2), n=4, device=device)
     assert disc.device == torch.device("cuda:0")
-    moneh = MultiOneHotDiscreteTensorSpec(
-        shape=(-1, 1, 2, 7), nvec=[3, 4], device=device
-    )
+    moneh = MultiOneHot(shape=(-1, 1, 2, 7), nvec=[3, 4], device=device)
     assert moneh.device == torch.device("cuda:0")
-    mdisc = MultiDiscreteTensorSpec(shape=(-1, 1, 2, 2), nvec=[3, 4], device=device)
+    mdisc = MultiCategorical(shape=(-1, 1, 2, 2), nvec=[3, 4], device=device)
     assert mdisc.device == torch.device("cuda:0")
-    mdisc = NonTensorSpec(shape=(-1, 1, 2, 2), device=device)
+    mdisc = NonTensor(shape=(-1, 1, 2, 2), device=device)
     assert mdisc.device == torch.device("cuda:0")
 
-    spec = CompositeSpec(
+    spec = Composite(
         unb=unb,
         unbd=unbd,
         bound=bound,
@@ -3740,6 +3628,181 @@ def test_device_ordinal():
     assert spec.device == torch.device("cuda:0")
 
 
+class TestLegacy:
+    def test_one_hot(self):
+        with pytest.warns(
+            DeprecationWarning,
+            match="The OneHotDiscreteTensorSpec has been deprecated and will be removed in v0.7. Please use OneHot instead.",
+        ):
+            one_hot = OneHotDiscreteTensorSpec(n=4)
+        assert isinstance(one_hot, OneHotDiscreteTensorSpec)
+        assert isinstance(one_hot, OneHot)
+        assert not isinstance(one_hot, Categorical)
+        one_hot = OneHot(n=4)
+        assert isinstance(one_hot, OneHotDiscreteTensorSpec)
+        assert isinstance(one_hot, OneHot)
+        assert not isinstance(one_hot, Categorical)
+
+    def test_discrete(self):
+        with pytest.warns(
+            DeprecationWarning,
+            match="The DiscreteTensorSpec has been deprecated and will be removed in v0.7. Please use Categorical instead.",
+        ):
+            discrete = DiscreteTensorSpec(n=4)
+        assert isinstance(discrete, DiscreteTensorSpec)
+        assert isinstance(discrete, Categorical)
+        assert not isinstance(discrete, OneHot)
+        discrete = Categorical(n=4)
+        assert isinstance(discrete, DiscreteTensorSpec)
+        assert isinstance(discrete, Categorical)
+        assert not isinstance(discrete, OneHot)
+
+    def test_unbounded(self):
+
+        unbounded_continuous_impl = Unbounded(dtype=torch.float)
+        assert isinstance(unbounded_continuous_impl, Unbounded)
+        assert isinstance(unbounded_continuous_impl, UnboundedContinuous)
+        assert isinstance(unbounded_continuous_impl, UnboundedContinuousTensorSpec)
+        assert not isinstance(unbounded_continuous_impl, UnboundedDiscreteTensorSpec)
+
+        unbounded_discrete_impl = Unbounded(dtype=torch.int)
+        assert isinstance(unbounded_discrete_impl, Unbounded)
+        assert isinstance(unbounded_discrete_impl, UnboundedDiscrete)
+        assert isinstance(unbounded_discrete_impl, UnboundedDiscreteTensorSpec)
+        assert not isinstance(unbounded_discrete_impl, UnboundedContinuousTensorSpec)
+
+        with pytest.warns(
+            DeprecationWarning,
+            match="The UnboundedContinuousTensorSpec has been deprecated and will be removed in v0.7. Please use Unbounded instead.",
+        ):
+            unbounded_continuous = UnboundedContinuousTensorSpec()
+        assert isinstance(unbounded_continuous, Unbounded)
+        assert isinstance(unbounded_continuous, UnboundedContinuous)
+        assert isinstance(unbounded_continuous, UnboundedContinuousTensorSpec)
+        assert not isinstance(unbounded_continuous, UnboundedDiscreteTensorSpec)
+
+        with warnings.catch_warnings():
+            unbounded_continuous = UnboundedContinuous()
+
+        with pytest.warns(
+            DeprecationWarning,
+            match="The UnboundedDiscreteTensorSpec has been deprecated and will be removed in v0.7. Please use Unbounded instead.",
+        ):
+            unbounded_discrete = UnboundedDiscreteTensorSpec()
+        assert isinstance(unbounded_discrete, Unbounded)
+        assert isinstance(unbounded_discrete, UnboundedDiscrete)
+        assert isinstance(unbounded_discrete, UnboundedDiscreteTensorSpec)
+        assert not isinstance(unbounded_discrete, UnboundedContinuousTensorSpec)
+
+        with warnings.catch_warnings():
+            unbounded_discrete = UnboundedDiscrete()
+
+        # What if we mess with dtypes?
+        with pytest.warns(DeprecationWarning):
+            unbounded_continuous_fake = UnboundedContinuousTensorSpec(dtype=torch.int32)
+        assert isinstance(unbounded_continuous_fake, Unbounded)
+        assert not isinstance(unbounded_continuous_fake, UnboundedContinuous)
+        assert not isinstance(unbounded_continuous_fake, UnboundedContinuousTensorSpec)
+        assert isinstance(unbounded_continuous_fake, UnboundedDiscrete)
+        assert isinstance(unbounded_continuous_fake, UnboundedDiscreteTensorSpec)
+
+        with pytest.warns(DeprecationWarning):
+            unbounded_discrete_fake = UnboundedDiscreteTensorSpec(dtype=torch.float32)
+        assert isinstance(unbounded_discrete_fake, Unbounded)
+        assert isinstance(unbounded_discrete_fake, UnboundedContinuous)
+        assert isinstance(unbounded_discrete_fake, UnboundedContinuousTensorSpec)
+        assert not isinstance(unbounded_discrete_fake, UnboundedDiscrete)
+        assert not isinstance(unbounded_discrete_fake, UnboundedDiscreteTensorSpec)
+
+    def test_multi_one_hot(self):
+        with pytest.warns(
+            DeprecationWarning,
+            match="The MultiOneHotDiscreteTensorSpec has been deprecated and will be removed in v0.7. Please use MultiOneHot instead.",
+        ):
+            one_hot = MultiOneHotDiscreteTensorSpec(nvec=[4, 3])
+        assert isinstance(one_hot, MultiOneHotDiscreteTensorSpec)
+        assert isinstance(one_hot, MultiOneHot)
+        assert not isinstance(one_hot, MultiCategorical)
+        one_hot = MultiOneHot(nvec=[4, 3])
+        assert isinstance(one_hot, MultiOneHotDiscreteTensorSpec)
+        assert isinstance(one_hot, MultiOneHot)
+        assert not isinstance(one_hot, MultiCategorical)
+
+    def test_multi_categorical(self):
+        with pytest.warns(
+            DeprecationWarning,
+            match="The MultiDiscreteTensorSpec has been deprecated and will be removed in v0.7. Please use MultiCategorical instead.",
+        ):
+            categorical = MultiDiscreteTensorSpec(nvec=[4, 3])
+        assert isinstance(categorical, MultiDiscreteTensorSpec)
+        assert isinstance(categorical, MultiCategorical)
+        assert not isinstance(categorical, MultiOneHot)
+        categorical = MultiCategorical(nvec=[4, 3])
+        assert isinstance(categorical, MultiDiscreteTensorSpec)
+        assert isinstance(categorical, MultiCategorical)
+        assert not isinstance(categorical, MultiOneHot)
+
+    def test_binary(self):
+        with pytest.warns(
+            DeprecationWarning,
+            match="The BinaryDiscreteTensorSpec has been deprecated and will be removed in v0.7. Please use Binary instead.",
+        ):
+            binary = BinaryDiscreteTensorSpec(5)
+        assert isinstance(binary, BinaryDiscreteTensorSpec)
+        assert isinstance(binary, Binary)
+        assert not isinstance(binary, MultiOneHot)
+        binary = Binary(5)
+        assert isinstance(binary, BinaryDiscreteTensorSpec)
+        assert isinstance(binary, Binary)
+        assert not isinstance(binary, MultiOneHot)
+
+    def test_bounded(self):
+        with pytest.warns(
+            DeprecationWarning,
+            match="The BoundedTensorSpec has been deprecated and will be removed in v0.7. Please use Bounded instead.",
+        ):
+            bounded = BoundedTensorSpec(-2, 2, shape=())
+        assert isinstance(bounded, BoundedTensorSpec)
+        assert isinstance(bounded, Bounded)
+        assert not isinstance(bounded, MultiOneHot)
+        bounded = Bounded(-2, 2, shape=())
+        assert isinstance(bounded, BoundedTensorSpec)
+        assert isinstance(bounded, Bounded)
+        assert not isinstance(bounded, MultiOneHot)
+
+    def test_composite(self):
+        with (
+            pytest.warns(
+                DeprecationWarning,
+                match="The CompositeSpec has been deprecated and will be removed in v0.7. Please use Composite instead.",
+            )
+        ):
+            composite = CompositeSpec()
+        assert isinstance(composite, CompositeSpec)
+        assert isinstance(composite, Composite)
+        assert not isinstance(composite, MultiOneHot)
+        composite = Composite()
+        assert isinstance(composite, CompositeSpec)
+        assert isinstance(composite, Composite)
+        assert not isinstance(composite, MultiOneHot)
+
+    def test_non_tensor(self):
+        with (
+            pytest.warns(
+                DeprecationWarning,
+                match="The NonTensorSpec has been deprecated and will be removed in v0.7. Please use NonTensor instead.",
+            )
+        ):
+            non_tensor = NonTensorSpec()
+        assert isinstance(non_tensor, NonTensorSpec)
+        assert isinstance(non_tensor, NonTensor)
+        assert not isinstance(non_tensor, MultiOneHot)
+        non_tensor = NonTensor()
+        assert isinstance(non_tensor, NonTensorSpec)
+        assert isinstance(non_tensor, NonTensor)
+        assert not isinstance(non_tensor, MultiOneHot)
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
diff --git a/test/test_tensordictmodules.py b/test/test_tensordictmodules.py
index 38360a464e0..ea177cb9f96 100644
--- a/test/test_tensordictmodules.py
+++ b/test/test_tensordictmodules.py
@@ -11,11 +11,7 @@
 from tensordict import LazyStackedTensorDict, pad, TensorDict, unravel_key_list
 from tensordict.nn import InteractionType, TensorDictModule, TensorDictSequential
 from torch import nn
-from torchrl.data.tensor_specs import (
-    BoundedTensorSpec,
-    CompositeSpec,
-    UnboundedContinuousTensorSpec,
-)
+from torchrl.data.tensor_specs import Bounded, Composite, Unbounded
 from torchrl.envs import (
     CatFrames,
     Compose,
@@ -119,8 +115,8 @@ def forward(self, x):
             return self.linear_1(x), self.linear_2(x)
 
         spec_dict = {
-            "_": UnboundedContinuousTensorSpec((4,)),
-            "out_2": UnboundedContinuousTensorSpec((3,)),
+            "_": Unbounded((4,)),
+            "out_2": Unbounded((3,)),
         }
 
         # warning due to "_" in spec keys
@@ -129,7 +125,7 @@ def forward(self, x):
             MultiHeadLinear(5, 4, 3),
             in_keys=["input"],
             out_keys=["_", "out_2"],
-            spec=CompositeSpec(**spec_dict),
+            spec=Composite(**spec_dict),
         )
 
     @pytest.mark.parametrize("safe", [True, False])
@@ -146,9 +142,9 @@ def test_stateful(self, safe, spec_type, lazy):
         if spec_type is None:
             spec = None
         elif spec_type == "bounded":
-            spec = BoundedTensorSpec(-0.1, 0.1, 4)
+            spec = Bounded(-0.1, 0.1, 4)
         elif spec_type == "unbounded":
-            spec = UnboundedContinuousTensorSpec(4)
+            spec = Unbounded(4)
 
         if safe and spec is None:
             with pytest.raises(
@@ -189,7 +185,7 @@ def test_stateful(self, safe, spec_type, lazy):
     @pytest.mark.parametrize("out_keys", [["loc", "scale"], ["loc_1", "scale_1"]])
     @pytest.mark.parametrize("lazy", [True, False])
     @pytest.mark.parametrize(
-        "exp_mode", [InteractionType.MODE, InteractionType.RANDOM, None]
+        "exp_mode", [InteractionType.DETERMINISTIC, InteractionType.RANDOM, None]
     )
     def test_stateful_probabilistic(self, safe, spec_type, lazy, exp_mode, out_keys):
         torch.manual_seed(0)
@@ -210,9 +206,9 @@ def test_stateful_probabilistic(self, safe, spec_type, lazy, exp_mode, out_keys)
         if spec_type is None:
             spec = None
         elif spec_type == "bounded":
-            spec = BoundedTensorSpec(-0.1, 0.1, 4)
+            spec = Bounded(-0.1, 0.1, 4)
         elif spec_type == "unbounded":
-            spec = UnboundedContinuousTensorSpec(4)
+            spec = Unbounded(4)
         else:
             raise NotImplementedError
@@ -291,9 +287,9 @@ def test_stateful(self, safe, spec_type, lazy):
         if spec_type is None:
             spec = None
         elif spec_type == "bounded":
-            spec = BoundedTensorSpec(-0.1, 0.1, 4)
+            spec = Bounded(-0.1, 0.1, 4)
         elif spec_type == "unbounded":
-            spec = UnboundedContinuousTensorSpec(4)
+            spec = Unbounded(4)
 
         kwargs = {}
@@ -368,9 +364,9 @@ def test_stateful_probabilistic(self, safe, spec_type, lazy):
         if spec_type is None:
             spec = None
         elif spec_type == "bounded":
-            spec = BoundedTensorSpec(-0.1, 0.1, 4)
+            spec = Bounded(-0.1, 0.1, 4)
         elif spec_type == "unbounded":
-            spec = UnboundedContinuousTensorSpec(4)
+            spec = Unbounded(4)
         else:
             raise NotImplementedError
@@ -481,7 +477,7 @@ def test_sequential_partial(self, stack):
         net3 = nn.Sequential(net3, NormalParamExtractor())
         net3 = SafeModule(net3, in_keys=["c"], out_keys=["loc", "scale"])
 
-        spec = BoundedTensorSpec(-0.1, 0.1, 4)
+        spec = Bounded(-0.1, 0.1, 4)
 
         kwargs = {"distribution_class": TanhNormal}
 
@@ -1340,7 +1336,7 @@ def call(data, params):
 
 def test_safe_specs():
     out_key = ("a", "b")
-    spec = CompositeSpec(CompositeSpec({out_key: UnboundedContinuousTensorSpec()}))
+    spec = Composite(Composite({out_key: Unbounded()}))
     original_spec = spec.clone()
     mod = SafeModule(
         module=nn.Linear(3, 1),
@@ -1354,9 +1350,7 @@ def test_safe_specs():
 
 def test_actor_critic_specs():
     action_key = ("agents", "action")
-    spec = CompositeSpec(
-        CompositeSpec({action_key: UnboundedContinuousTensorSpec(shape=(3,))})
-    )
+    spec = Composite(Composite({action_key: Unbounded(shape=(3,))}))
     policy_module = TensorDictModule(
         nn.Linear(3, 1),
         in_keys=[("agents", "observation")],
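[Editor's note — illustrative sketch, not part of the patch. Beyond the spec renames, the only behavioral change in test_tensordictmodules.py is the switch from InteractionType.MODE to InteractionType.DETERMINISTIC in the exploration-mode parametrization. A hedged sketch of what that interaction type controls; module names, keys, and shapes below are invented for the example:]

    import torch
    from tensordict import TensorDict
    from tensordict.nn import (
        InteractionType,
        NormalParamExtractor,
        ProbabilisticTensorDictModule,
        TensorDictModule,
        TensorDictSequential,
        set_interaction_type,
    )
    from torchrl.modules import TanhNormal

    # A small stochastic policy: obs -> (loc, scale) -> TanhNormal -> action.
    backbone = TensorDictModule(
        torch.nn.Sequential(torch.nn.Linear(3, 8), NormalParamExtractor()),
        in_keys=["obs"],
        out_keys=["loc", "scale"],
    )
    head = ProbabilisticTensorDictModule(
        in_keys=["loc", "scale"],
        out_keys=["action"],
        distribution_class=TanhNormal,
    )
    policy = TensorDictSequential(backbone, head)

    td = TensorDict({"obs": torch.randn(3)}, [])
    # DETERMINISTIC asks the distribution for its deterministic sample,
    # roughly what these tests previously requested through MODE.
    with set_interaction_type(InteractionType.DETERMINISTIC):
        action = policy(td)["action"]
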
diff --git a/test/test_transforms.py b/test/test_transforms.py
index c38908eba1d..ca9a031bb2f 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -15,7 +15,6 @@
 from functools import partial
 from sys import platform
 
-import numpy as np
 import pytest
 
 import tensordict.tensordict
@@ -51,15 +50,15 @@
 from torch import multiprocessing as mp, nn, Tensor
 from torchrl._utils import _replace_last, prod
 from torchrl.data import (
-    BoundedTensorSpec,
-    CompositeSpec,
-    DiscreteTensorSpec,
+    Bounded,
+    Categorical,
+    Composite,
     LazyTensorStorage,
     ReplayBuffer,
     TensorDictReplayBuffer,
     TensorSpec,
     TensorStorage,
-    UnboundedContinuousTensorSpec,
+    Unbounded,
 )
 from torchrl.envs import (
     ActionMask,
@@ -147,7 +146,7 @@ class TransformBase:
     We ask for every new transform tests to be coded following this minimum
     requirement class.
 
-    Of course, specific behaviours can also be tested separately.
+    Of course, specific behaviors can also be tested separately.
 
     If your transform identifies an issue with the EnvBase or
     _BatchedEnv abstraction(s), this needs to be corrected independently.
@@ -934,21 +933,17 @@ def test_catframes_transform_observation_spec(self):
         )
         mins = [0, 0.5]
         maxes = [0.5, 1]
-        observation_spec = CompositeSpec(
+        observation_spec = Composite(
             {
-                key: BoundedTensorSpec(
-                    space_min, space_max, (1, 3, 3), dtype=torch.double
-                )
+                key: Bounded(space_min, space_max, (1, 3, 3), dtype=torch.double)
                 for key, space_min, space_max in zip(keys, mins, maxes)
             }
         )
 
         result = cat_frames.transform_observation_spec(observation_spec)
-        observation_spec = CompositeSpec(
+        observation_spec = Composite(
             {
-                key: BoundedTensorSpec(
-                    space_min, space_max, (1, 3, 3), dtype=torch.double
-                )
+                key: Bounded(space_min, space_max, (1, 3, 3), dtype=torch.double)
                 for key, space_min, space_max in zip(keys, mins, maxes)
             }
         )
@@ -1502,15 +1497,12 @@ def test_r3mnet_transform_observation_spec(
     ):
         r3m_net = _R3MNet(in_keys, out_keys, model, del_keys)
 
-        observation_spec = CompositeSpec(
-            {key: BoundedTensorSpec(-1, 1, (3, 16, 16), device) for key in in_keys}
+        observation_spec = Composite(
+            {key: Bounded(-1, 1, (3, 16, 16), device) for key in in_keys}
         )
         if del_keys:
-            exp_ts = CompositeSpec(
-                {
-                    key: UnboundedContinuousTensorSpec(r3m_net.outdim, device)
-                    for key in out_keys
-                }
+            exp_ts = Composite(
+                {key: Unbounded(r3m_net.outdim, device) for key in out_keys}
             )
 
             observation_spec_out = r3m_net.transform_observation_spec(observation_spec)
@@ -1526,8 +1518,8 @@ def test_r3mnet_transform_observation_spec(
             for key in in_keys:
                 ts_dict[key] = observation_spec[key]
             for key in out_keys:
-                ts_dict[key] = UnboundedContinuousTensorSpec(r3m_net.outdim, device)
-            exp_ts = CompositeSpec(ts_dict)
+                ts_dict[key] = Unbounded(r3m_net.outdim, device)
+            exp_ts = Composite(ts_dict)
 
             observation_spec_out = r3m_net.transform_observation_spec(observation_spec)
@@ -2020,12 +2012,12 @@ def test_transform_no_env(self, keys, device, out_key):
         assert tdc.get("dont touch").shape == dont_touch.shape
 
         if len(keys) == 1:
-            observation_spec = BoundedTensorSpec(0, 1, (1, 4, 32))
+            observation_spec = Bounded(0, 1, (1, 4, 32))
             observation_spec = cattensors.transform_observation_spec(observation_spec)
             assert observation_spec.shape == torch.Size([1, len(keys) * 4, 32])
         else:
-            observation_spec = CompositeSpec(
-                {key: BoundedTensorSpec(0, 1, (1, 4, 32)) for key in keys}
+            observation_spec = Composite(
+                {key: Bounded(0, 1, (1, 4, 32)) for key in keys}
             )
             observation_spec = cattensors.transform_observation_spec(observation_spec)
             assert observation_spec[out_key].shape == torch.Size([1, len(keys) * 4, 32])
@@ -2166,12 +2158,12 @@ def test_transform_no_env(self, keys, h, nchannels, batch, device):
         assert (td.get("dont touch") == dont_touch).all()
 
         if len(keys) == 1:
-            observation_spec = BoundedTensorSpec(-1, 1, (nchannels, 16, 16))
+            observation_spec = Bounded(-1, 1, (nchannels, 16, 16))
             observation_spec = crop.transform_observation_spec(observation_spec)
             assert observation_spec.shape == torch.Size([nchannels, 20, h])
         else:
-            observation_spec = CompositeSpec(
-                {key: BoundedTensorSpec(-1, 1, (nchannels, 16, 16)) for key in keys}
+            observation_spec = Composite(
+                {key: Bounded(-1, 1, (nchannels, 16, 16)) for key in keys}
             )
             observation_spec = crop.transform_observation_spec(observation_spec)
             for key in keys:
@@ -2373,12 +2365,12 @@ def test_transform_no_env(self, keys, h, nchannels, batch, device):
         assert (td.get("dont touch") == dont_touch).all()
 
         if len(keys) == 1:
-            observation_spec = BoundedTensorSpec(-1, 1, (nchannels, 16, 16))
+            observation_spec = Bounded(-1, 1, (nchannels, 16, 16))
             observation_spec = cc.transform_observation_spec(observation_spec)
             assert observation_spec.shape == torch.Size([nchannels, 20, h])
         else:
-            observation_spec = CompositeSpec(
-                {key: BoundedTensorSpec(-1, 1, (nchannels, 16, 16)) for key in keys}
+            observation_spec = Composite(
+                {key: Bounded(-1, 1, (nchannels, 16, 16)) for key in keys}
             )
             observation_spec = cc.transform_observation_spec(observation_spec)
             for key in keys:
@@ -2722,18 +2714,15 @@ def test_double2float(self, keys, keys_inv, device):
         assert td.get("dont touch").dtype != torch.double
 
         if len(keys_total) == 1 and len(keys_inv) and keys[0] == "action":
-            action_spec = BoundedTensorSpec(0, 1, (1, 3, 3), dtype=torch.double)
-            input_spec = CompositeSpec(
-                full_action_spec=CompositeSpec(action=action_spec), full_state_spec=None
+            action_spec = Bounded(0, 1, (1, 3, 3), dtype=torch.double)
+            input_spec = Composite(
+                full_action_spec=Composite(action=action_spec), full_state_spec=None
             )
             action_spec = double2float.transform_input_spec(input_spec)
             assert action_spec.dtype == torch.float
         else:
-            observation_spec = CompositeSpec(
-                {
-                    key: BoundedTensorSpec(0, 1, (1, 3, 3), dtype=torch.double)
-                    for key in keys
-                }
+            observation_spec = Composite(
+                {key: Bounded(0, 1, (1, 3, 3), dtype=torch.double) for key in keys}
             )
             observation_spec = double2float.transform_observation_spec(observation_spec)
             for key in keys:
@@ -2950,13 +2939,13 @@ class TestExcludeTransform(TransformBase):
     class EnvWithManyKeys(EnvBase):
         def __init__(self):
             super().__init__()
-            self.observation_spec = CompositeSpec(
-                a=UnboundedContinuousTensorSpec(3),
-                b=UnboundedContinuousTensorSpec(3),
-                c=UnboundedContinuousTensorSpec(3),
+            self.observation_spec = Composite(
+                a=Unbounded(3),
+                b=Unbounded(3),
+                c=Unbounded(3),
             )
-            self.reward_spec = UnboundedContinuousTensorSpec(1)
-            self.action_spec = UnboundedContinuousTensorSpec(2)
+            self.reward_spec = Unbounded(1)
+            self.action_spec = Unbounded(2)
 
         def _step(
             self,
@@ -3188,13 +3177,13 @@ class TestSelectTransform(TransformBase):
     class EnvWithManyKeys(EnvBase):
         def __init__(self):
             super().__init__()
-            self.observation_spec = CompositeSpec(
-                a=UnboundedContinuousTensorSpec(3),
-                b=UnboundedContinuousTensorSpec(3),
-                c=UnboundedContinuousTensorSpec(3),
+            self.observation_spec = Composite(
+                a=Unbounded(3),
+                b=Unbounded(3),
+                c=Unbounded(3),
             )
-            self.reward_spec = UnboundedContinuousTensorSpec(1)
-            self.action_spec = UnboundedContinuousTensorSpec(2)
+            self.reward_spec = Unbounded(1)
+            self.action_spec = Unbounded(2)
 
         def _step(
             self,
@@ -3513,15 +3502,12 @@ def test_transform_no_env(self, keys, size, nchannels, batch, device):
         assert (td.get("dont touch") == dont_touch).all()
 
         if len(keys) == 1:
-            observation_spec = BoundedTensorSpec(-1, 1, (*size, nchannels, 16, 16))
+            observation_spec = Bounded(-1, 1, (*size, nchannels, 16, 16))
             observation_spec = flatten.transform_observation_spec(observation_spec)
             assert observation_spec.shape[-3] == expected_size
         else:
-            observation_spec = CompositeSpec(
-                {
-                    key: BoundedTensorSpec(-1, 1, (*size, nchannels, 16, 16))
-                    for key in keys
-                }
+            observation_spec = Composite(
+                {key: Bounded(-1, 1, (*size, nchannels, 16, 16)) for key in keys}
             )
             observation_spec = flatten.transform_observation_spec(observation_spec)
             for key in keys:
@@ -3556,15 +3542,12 @@ def test_transform_compose(self, keys, size, nchannels, batch, device):
         assert (td.get("dont touch") == dont_touch).all()
 
         if len(keys) == 1:
-            observation_spec = BoundedTensorSpec(-1, 1, (*size, nchannels, 16, 16))
+            observation_spec = Bounded(-1, 1, (*size, nchannels, 16, 16))
             observation_spec = flatten.transform_observation_spec(observation_spec)
             assert observation_spec.shape[-3] == expected_size
         else:
-            observation_spec = CompositeSpec(
-                {
-                    key: BoundedTensorSpec(-1, 1, (*size, nchannels, 16, 16))
-                    for key in keys
-                }
+            observation_spec = Composite(
+                {key: Bounded(-1, 1, (*size, nchannels, 16, 16)) for key in keys}
             )
             observation_spec = flatten.transform_observation_spec(observation_spec)
             for key in keys:
@@ -3801,12 +3784,12 @@ def test_transform_no_env(self, keys, device):
         assert (td.get("dont touch") == dont_touch).all()
 
         if len(keys) == 1:
-            observation_spec = BoundedTensorSpec(-1, 1, (nchannels, 16, 16))
+            observation_spec = Bounded(-1, 1, (nchannels, 16, 16))
             observation_spec = gs.transform_observation_spec(observation_spec)
             assert observation_spec.shape == torch.Size([1, 16, 16])
         else:
-            observation_spec = CompositeSpec(
-                {key: BoundedTensorSpec(-1, 1, (nchannels, 16, 16)) for key in keys}
+            observation_spec = Composite(
+                {key: Bounded(-1, 1, (nchannels, 16, 16)) for key in keys}
             )
             observation_spec = gs.transform_observation_spec(observation_spec)
             for key in keys:
@@ -3838,12 +3821,12 @@ def test_transform_compose(self, keys, device):
         assert (td.get("dont touch") == dont_touch).all()
 
         if len(keys) == 1:
-            observation_spec = BoundedTensorSpec(-1, 1, (nchannels, 16, 16))
+            observation_spec = Bounded(-1, 1, (nchannels, 16, 16))
             observation_spec = gs.transform_observation_spec(observation_spec)
             assert observation_spec.shape == torch.Size([1, 16, 16])
         else:
-            observation_spec = CompositeSpec(
-                {key: BoundedTensorSpec(-1, 1, (nchannels, 16, 16)) for key in keys}
+            observation_spec = Composite(
+                {key: Bounded(-1, 1, (nchannels, 16, 16)) for key in keys}
             )
             observation_spec = gs.transform_observation_spec(observation_spec)
             for key in keys:
@@ -4443,9 +4426,7 @@ def test_observationnorm(
         assert (td.get("dont touch") == dont_touch).all()
 
         if len(keys) == 1:
-            observation_spec = BoundedTensorSpec(
-                0, 1, (nchannels, 16, 16), device=device
-            )
+            observation_spec = Bounded(0, 1, (nchannels, 16, 16), device=device)
             observation_spec = on.transform_observation_spec(observation_spec)
             if standard_normal:
                 assert (observation_spec.space.low == -loc / scale).all()
@@ -4455,11 +4436,8 @@ def test_observationnorm(
                 assert (observation_spec.space.high == scale + loc).all()
 
         else:
-            observation_spec = CompositeSpec(
-                {
-                    key: BoundedTensorSpec(0, 1, (nchannels, 16, 16), device=device)
-                    for key in keys
-                }
+            observation_spec = Composite(
+                {key: Bounded(0, 1, (nchannels, 16, 16), device=device) for key in keys}
             )
             observation_spec = on.transform_observation_spec(observation_spec)
             for key in keys:
@@ -4480,15 +4458,11 @@ def test_observationnorm_init_stats(
     ):
         def make_env():
             base_env = ContinuousActionVecMockEnv(
-                observation_spec=CompositeSpec(
-                    observation=BoundedTensorSpec(
-                        low=1, high=1, shape=torch.Size([size])
-                    ),
-                    observation_orig=BoundedTensorSpec(
-                        low=1, high=1, shape=torch.Size([size])
-                    ),
+                observation_spec=Composite(
+                    observation=Bounded(low=1, high=1, shape=torch.Size([size])),
+                    observation_orig=Bounded(low=1, high=1, shape=torch.Size([size])),
                 ),
-                action_spec=BoundedTensorSpec(low=1, high=1, shape=torch.Size((size,))),
+                action_spec=Bounded(low=1, high=1, shape=torch.Size((size,))),
                 seed=0,
             )
             base_env.out_key = "observation"
@@ -4669,12 +4643,12 @@ def test_transform_no_env(self, interpolation, keys, nchannels, batch, device):
         assert (td.get("dont touch") == dont_touch).all()
 
         if len(keys) == 1:
-            observation_spec = BoundedTensorSpec(-1, 1, (nchannels, 16, 16))
+            observation_spec = Bounded(-1, 1, (nchannels, 16, 16))
             observation_spec = resize.transform_observation_spec(observation_spec)
             assert observation_spec.shape == torch.Size([nchannels, 20, 21])
         else:
-            observation_spec = CompositeSpec(
-                {key: BoundedTensorSpec(-1, 1, (nchannels, 16, 16)) for key in keys}
+            observation_spec = Composite(
+                {key: Bounded(-1, 1, (nchannels, 16, 16)) for key in keys}
             )
             observation_spec = resize.transform_observation_spec(observation_spec)
             for key in keys:
@@ -4706,12 +4680,12 @@ def test_transform_compose(self, interpolation, keys, nchannels, batch, device):
         assert (td.get("dont touch") == dont_touch).all()
 
         if len(keys) == 1:
-            observation_spec = BoundedTensorSpec(-1, 1, (nchannels, 16, 16))
+            observation_spec = Bounded(-1, 1, (nchannels, 16, 16))
             observation_spec = resize.transform_observation_spec(observation_spec)
             assert observation_spec.shape == torch.Size([nchannels, 20, 21])
         else:
-            observation_spec = CompositeSpec(
-                {key: BoundedTensorSpec(-1, 1, (nchannels, 16, 16)) for key in keys}
+            observation_spec = Composite(
+                {key: Bounded(-1, 1, (nchannels, 16, 16)) for key in keys}
             )
             observation_spec = resize.transform_observation_spec(observation_spec)
             for key in keys:
@@ -4947,7 +4921,7 @@ def test_reward_scaling(self, batch, scale, loc, keys, device, standard_normal):
         assert (td.get("dont touch") == td_copy.get("dont touch")).all()
 
         if len(keys_total) == 1:
-            reward_spec = UnboundedContinuousTensorSpec(device=device)
+            reward_spec = Unbounded(device=device)
             reward_spec = reward_scaling.transform_reward_spec(reward_spec)
             assert reward_spec.shape == torch.Size([1])
@@ -5140,7 +5114,9 @@ def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv):
         pass
 
     @pytest.mark.parametrize("has_in_keys,", [True, False])
-    @pytest.mark.parametrize("reset_keys,", [None, ["_reset"] * 3])
+    @pytest.mark.parametrize(
+        "reset_keys,", [[("some", "nested", "reset")], ["_reset"] * 3, None]
+    )
     def test_trans_multi_key(
         self, has_in_keys, reset_keys, n_workers=2, batch_size=(3, 2), max_steps=5
     ):
@@ -5162,9 +5138,9 @@ def test_trans_multi_key(
         )
         with pytest.raises(
             ValueError, match="Could not match the env reset_keys"
-        ) if reset_keys is None else contextlib.nullcontext():
+        ) if reset_keys == [("some", "nested", "reset")] else contextlib.nullcontext():
             check_env_specs(env)
-        if reset_keys is not None:
+        if reset_keys != [("some", "nested", "reset")]:
             td = env.rollout(max_steps, policy=policy)
             for reward_key in env.reward_keys:
                 reward_key = _unravel_key_to_tuple(reward_key)
@@ -5341,24 +5317,24 @@ def test_sum_reward(self, keys, device):
 
         # test transform_observation_spec
         base_env = ContinuousActionVecMockEnv(
-            reward_spec=UnboundedContinuousTensorSpec(shape=(3, 16, 16)),
+            reward_spec=Unbounded(shape=(3, 16, 16)),
         )
         transfomed_env = TransformedEnv(base_env, RewardSum())
         transformed_observation_spec1 = transfomed_env.observation_spec
-        assert isinstance(transformed_observation_spec1, CompositeSpec)
+        assert isinstance(transformed_observation_spec1, Composite)
         assert "episode_reward" in transformed_observation_spec1.keys()
         assert "observation" in transformed_observation_spec1.keys()
 
         base_env = ContinuousActionVecMockEnv(
-            reward_spec=UnboundedContinuousTensorSpec(),
-            observation_spec=CompositeSpec(
-                observation=UnboundedContinuousTensorSpec(),
-                some_extra_observation=UnboundedContinuousTensorSpec(),
+
reward_spec=Unbounded(), + observation_spec=Composite( + observation=Unbounded(), + some_extra_observation=Unbounded(), ), ) transfomed_env = TransformedEnv(base_env, RewardSum()) transformed_observation_spec2 = transfomed_env.observation_spec - assert isinstance(transformed_observation_spec2, CompositeSpec) + assert isinstance(transformed_observation_spec2, Composite) assert "some_extra_observation" in transformed_observation_spec2.keys() assert "episode_reward" in transformed_observation_spec2.keys() @@ -5653,7 +5629,7 @@ def test_transform_model(self): class TestUnsqueezeTransform(TransformBase): - @pytest.mark.parametrize("unsqueeze_dim", [1, -2]) + @pytest.mark.parametrize("dim", [1, -2]) @pytest.mark.parametrize("nchannels", [1, 3]) @pytest.mark.parametrize("batch", [[], [2], [2, 4]]) @pytest.mark.parametrize("size", [[], [4]]) @@ -5661,14 +5637,10 @@ class TestUnsqueezeTransform(TransformBase): "keys", [["observation", ("some_other", "nested_key")], ["observation_pixels"]] ) @pytest.mark.parametrize("device", get_default_devices()) - def test_transform_no_env( - self, keys, size, nchannels, batch, device, unsqueeze_dim - ): + def test_transform_no_env(self, keys, size, nchannels, batch, device, dim): torch.manual_seed(0) dont_touch = torch.randn(*batch, *size, nchannels, 16, 16, device=device) - unsqueeze = UnsqueezeTransform( - unsqueeze_dim, in_keys=keys, allow_positive_dim=True - ) + unsqueeze = UnsqueezeTransform(dim, in_keys=keys, allow_positive_dim=True) td = TensorDict( { key: torch.randn(*batch, *size, nchannels, 16, 16, device=device) @@ -5678,16 +5650,16 @@ def test_transform_no_env( device=device, ) td.set("dont touch", dont_touch.clone()) - if unsqueeze_dim >= 0 and unsqueeze_dim < len(batch): + if dim >= 0 and dim < len(batch): with pytest.raises(RuntimeError, match="batch dimension mismatch"): unsqueeze(td) return unsqueeze(td) expected_size = [*batch, *size, nchannels, 16, 16] - if unsqueeze_dim < 0: - expected_size.insert(len(expected_size) + unsqueeze_dim + 1, 1) + if dim < 0: + expected_size.insert(len(expected_size) + dim + 1, 1) else: - expected_size.insert(unsqueeze_dim, 1) + expected_size.insert(dim, 1) expected_size = torch.Size(expected_size) for key in keys: @@ -5695,20 +5667,18 @@ def test_transform_no_env( batch, size, nchannels, - unsqueeze_dim, + dim, ) assert (td.get("dont touch") == dont_touch).all() if len(keys) == 1: - observation_spec = BoundedTensorSpec( - -1, 1, (*batch, *size, nchannels, 16, 16) - ) + observation_spec = Bounded(-1, 1, (*batch, *size, nchannels, 16, 16)) observation_spec = unsqueeze.transform_observation_spec(observation_spec) assert observation_spec.shape == expected_size else: - observation_spec = CompositeSpec( + observation_spec = Composite( { - key: BoundedTensorSpec(-1, 1, (*batch, *size, nchannels, 16, 16)) + key: Bounded(-1, 1, (*batch, *size, nchannels, 16, 16)) for key in keys } ) @@ -5716,7 +5686,7 @@ def test_transform_no_env( for key in keys: assert observation_spec[key].shape == expected_size - @pytest.mark.parametrize("unsqueeze_dim", [1, -2]) + @pytest.mark.parametrize("dim", [1, -2]) @pytest.mark.parametrize("nchannels", [1, 3]) @pytest.mark.parametrize("batch", [[], [2], [2, 4]]) @pytest.mark.parametrize("size", [[], [4]]) @@ -5732,13 +5702,11 @@ def test_transform_no_env( [("next", "observation_pixels")], ], ) - def test_unsqueeze_inv( - self, keys, keys_inv, size, nchannels, batch, device, unsqueeze_dim - ): + def test_unsqueeze_inv(self, keys, keys_inv, size, nchannels, batch, device, dim): 
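The test hunks above chiefly apply one mechanical rename of TorchRL's spec classes — `CompositeSpec` → `Composite`, `BoundedTensorSpec` → `Bounded`, `UnboundedContinuousTensorSpec` → `Unbounded` — and the parallel `unsqueeze_dim`/`squeeze_dim` → `dim` argument rename begins here as well. A minimal sketch of the new-style construction, assuming `torchrl.data` re-exports the shortened names (the later `Binary`/`Categorical` import change in this diff suggests it does); keys and shapes are illustrative:

```python
import torch
from torchrl.data import Bounded, Composite, Unbounded

# Same spec as before, spelled with the shortened class names used in this diff.
observation_spec = Composite(
    pixels=Bounded(low=-1, high=1, shape=(3, 16, 16), dtype=torch.float32),
    features=Unbounded(shape=(7,)),
)
assert observation_spec["pixels"].shape == torch.Size([3, 16, 16])
assert observation_spec["features"].shape == torch.Size([7])
```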
torch.manual_seed(0) keys_total = set(keys + keys_inv) unsqueeze = UnsqueezeTransform( - unsqueeze_dim, in_keys=keys, in_keys_inv=keys_inv, allow_positive_dim=True + dim, in_keys=keys, in_keys_inv=keys_inv, allow_positive_dim=True ) td = TensorDict( { @@ -5754,8 +5722,8 @@ def test_unsqueeze_inv( for key in keys_total.difference(keys_inv): assert td.get(key).shape == torch.Size(expected_size) - if expected_size[unsqueeze_dim] == 1: - del expected_size[unsqueeze_dim] + if expected_size[dim] == 1: + del expected_size[dim] for key in keys_inv: assert td_modif.get(key).shape == torch.Size(expected_size) # for key in keys_inv: @@ -5815,7 +5783,7 @@ def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): except RuntimeError: pass - @pytest.mark.parametrize("unsqueeze_dim", [1, -2]) + @pytest.mark.parametrize("dim", [1, -2]) @pytest.mark.parametrize("nchannels", [1, 3]) @pytest.mark.parametrize("batch", [[], [2], [2, 4]]) @pytest.mark.parametrize("size", [[], [4]]) @@ -5823,13 +5791,11 @@ def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): "keys", [["observation", "some_other_key"], ["observation_pixels"]] ) @pytest.mark.parametrize("device", get_default_devices()) - def test_transform_compose( - self, keys, size, nchannels, batch, device, unsqueeze_dim - ): + def test_transform_compose(self, keys, size, nchannels, batch, device, dim): torch.manual_seed(0) dont_touch = torch.randn(*batch, *size, nchannels, 16, 16, device=device) unsqueeze = Compose( - UnsqueezeTransform(unsqueeze_dim, in_keys=keys, allow_positive_dim=True) + UnsqueezeTransform(dim, in_keys=keys, allow_positive_dim=True) ) td = TensorDict( { @@ -5840,16 +5806,16 @@ def test_transform_compose( device=device, ) td.set("dont touch", dont_touch.clone()) - if unsqueeze_dim >= 0 and unsqueeze_dim < len(batch): + if dim >= 0 and dim < len(batch): with pytest.raises(RuntimeError, match="batch dimension mismatch"): unsqueeze(td) return unsqueeze(td) expected_size = [*batch, *size, nchannels, 16, 16] - if unsqueeze_dim < 0: - expected_size.insert(len(expected_size) + unsqueeze_dim + 1, 1) + if dim < 0: + expected_size.insert(len(expected_size) + dim + 1, 1) else: - expected_size.insert(unsqueeze_dim, 1) + expected_size.insert(dim, 1) expected_size = torch.Size(expected_size) for key in keys: @@ -5857,20 +5823,18 @@ def test_transform_compose( batch, size, nchannels, - unsqueeze_dim, + dim, ) assert (td.get("dont touch") == dont_touch).all() if len(keys) == 1: - observation_spec = BoundedTensorSpec( - -1, 1, (*batch, *size, nchannels, 16, 16) - ) + observation_spec = Bounded(-1, 1, (*batch, *size, nchannels, 16, 16)) observation_spec = unsqueeze.transform_observation_spec(observation_spec) assert observation_spec.shape == expected_size else: - observation_spec = CompositeSpec( + observation_spec = Composite( { - key: BoundedTensorSpec(-1, 1, (*batch, *size, nchannels, 16, 16)) + key: Bounded(-1, 1, (*batch, *size, nchannels, 16, 16)) for key in keys } ) @@ -5895,10 +5859,10 @@ def test_transform_env(self, out_keys): check_env_specs(env) @pytest.mark.parametrize("out_keys", [None, ["stuff"]]) - @pytest.mark.parametrize("unsqueeze_dim", [-1, 1]) - def test_transform_model(self, out_keys, unsqueeze_dim): + @pytest.mark.parametrize("dim", [-1, 1]) + def test_transform_model(self, out_keys, dim): t = UnsqueezeTransform( - unsqueeze_dim, + dim, in_keys=["observation"], out_keys=out_keys, allow_positive_dim=True, @@ -5908,21 +5872,21 @@ def test_transform_model(self, out_keys, unsqueeze_dim): ) t(td) expected_shape = [3, 4] - 
if unsqueeze_dim >= 0: - expected_shape.insert(unsqueeze_dim, 1) + if dim >= 0: + expected_shape.insert(dim, 1) else: - expected_shape.insert(len(expected_shape) + unsqueeze_dim + 1, 1) + expected_shape.insert(len(expected_shape) + dim + 1, 1) if out_keys is None: assert td["observation"].shape == torch.Size(expected_shape) else: assert td[out_keys[0]].shape == torch.Size(expected_shape) @pytest.mark.parametrize("out_keys", [None, ["stuff"]]) - @pytest.mark.parametrize("unsqueeze_dim", [-1, 1]) + @pytest.mark.parametrize("dim", [-1, 1]) @pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer]) - def test_transform_rb(self, rbclass, out_keys, unsqueeze_dim): + def test_transform_rb(self, rbclass, out_keys, dim): t = UnsqueezeTransform( - unsqueeze_dim, + dim, in_keys=["observation"], out_keys=out_keys, allow_positive_dim=True, @@ -5935,10 +5899,10 @@ def test_transform_rb(self, rbclass, out_keys, unsqueeze_dim): rb.extend(td) td = rb.sample(2) expected_shape = [2, 3, 4] - if unsqueeze_dim >= 0: - expected_shape.insert(unsqueeze_dim, 1) + if dim >= 0: + expected_shape.insert(dim, 1) else: - expected_shape.insert(len(expected_shape) + unsqueeze_dim + 1, 1) + expected_shape.insert(len(expected_shape) + dim + 1, 1) if out_keys is None: assert td["observation"].shape == torch.Size(expected_shape) else: @@ -5962,7 +5926,7 @@ def test_transform_inverse(self): class TestSqueezeTransform(TransformBase): - @pytest.mark.parametrize("squeeze_dim", [1, -2]) + @pytest.mark.parametrize("dim", [1, -2]) @pytest.mark.parametrize("nchannels", [1, 3]) @pytest.mark.parametrize("batch", [[], [2], [2, 4]]) @pytest.mark.parametrize("size", [[], [4]]) @@ -5983,12 +5947,12 @@ class TestSqueezeTransform(TransformBase): ], ) def test_transform_no_env( - self, keys, keys_inv, size, nchannels, batch, device, squeeze_dim + self, keys, keys_inv, size, nchannels, batch, device, dim ): torch.manual_seed(0) keys_total = set(keys + keys_inv) squeeze = SqueezeTransform( - squeeze_dim, in_keys=keys, in_keys_inv=keys_inv, allow_positive_dim=True + dim, in_keys=keys, in_keys_inv=keys_inv, allow_positive_dim=True ) td = TensorDict( { @@ -6003,12 +5967,12 @@ def test_transform_no_env( for key in keys_total.difference(keys): assert td.get(key).shape == torch.Size(expected_size) - if expected_size[squeeze_dim] == 1: - del expected_size[squeeze_dim] + if expected_size[dim] == 1: + del expected_size[dim] for key in keys: assert td.get(key).shape == torch.Size(expected_size) - @pytest.mark.parametrize("squeeze_dim", [1, -2]) + @pytest.mark.parametrize("dim", [1, -2]) @pytest.mark.parametrize("nchannels", [1, 3]) @pytest.mark.parametrize("batch", [[], [2], [2, 4]]) @pytest.mark.parametrize("size", [[], [4]]) @@ -6028,15 +5992,13 @@ def test_transform_no_env( [("next", "observation_pixels")], ], ) - def test_squeeze_inv( - self, keys, keys_inv, size, nchannels, batch, device, squeeze_dim - ): + def test_squeeze_inv(self, keys, keys_inv, size, nchannels, batch, device, dim): torch.manual_seed(0) - if squeeze_dim >= 0: - squeeze_dim = squeeze_dim + len(batch) + if dim >= 0: + dim = dim + len(batch) keys_total = set(keys + keys_inv) squeeze = SqueezeTransform( - squeeze_dim, in_keys=keys, in_keys_inv=keys_inv, allow_positive_dim=True + dim, in_keys=keys, in_keys_inv=keys_inv, allow_positive_dim=True ) td = TensorDict( { @@ -6051,14 +6013,14 @@ def test_squeeze_inv( for key in keys_total.difference(keys_inv): assert td.get(key).shape == torch.Size(expected_size) - if squeeze_dim < 0: - expected_size.insert(len(expected_size) 
+ squeeze_dim + 1, 1) + if dim < 0: + expected_size.insert(len(expected_size) + dim + 1, 1) else: - expected_size.insert(squeeze_dim, 1) + expected_size.insert(dim, 1) expected_size = torch.Size(expected_size) for key in keys_inv: - assert td.get(key).shape == torch.Size(expected_size), squeeze_dim + assert td.get(key).shape == torch.Size(expected_size), dim @property def _circular_transform(self): @@ -6131,7 +6093,7 @@ def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): except RuntimeError: pass - @pytest.mark.parametrize("squeeze_dim", [1, -2]) + @pytest.mark.parametrize("dim", [1, -2]) @pytest.mark.parametrize("nchannels", [1, 3]) @pytest.mark.parametrize("batch", [[], [2], [2, 4]]) @pytest.mark.parametrize("size", [[], [4]]) @@ -6144,13 +6106,13 @@ def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): "keys_inv", [[], ["action", "some_other_key"], [("next", "observation_pixels")]] ) def test_transform_compose( - self, keys, keys_inv, size, nchannels, batch, device, squeeze_dim + self, keys, keys_inv, size, nchannels, batch, device, dim ): torch.manual_seed(0) keys_total = set(keys + keys_inv) squeeze = Compose( SqueezeTransform( - squeeze_dim, in_keys=keys, in_keys_inv=keys_inv, allow_positive_dim=True + dim, in_keys=keys, in_keys_inv=keys_inv, allow_positive_dim=True ) ) td = TensorDict( @@ -6166,8 +6128,8 @@ def test_transform_compose( for key in keys_total.difference(keys): assert td.get(key).shape == torch.Size(expected_size) - if expected_size[squeeze_dim] == 1: - del expected_size[squeeze_dim] + if expected_size[dim] == 1: + del expected_size[dim] for key in keys: assert td.get(key).shape == torch.Size(expected_size) @@ -6184,9 +6146,9 @@ def test_transform_env(self, keys_inv): @pytest.mark.parametrize("out_keys", [None, ["obs_sq"]]) def test_transform_model(self, out_keys): - squeeze_dim = 1 + dim = 1 t = SqueezeTransform( - squeeze_dim, + dim, in_keys=["observation"], out_keys=out_keys, allow_positive_dim=True, @@ -6205,9 +6167,9 @@ def test_transform_model(self, out_keys): @pytest.mark.parametrize("out_keys", [None, ["obs_sq"]]) @pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer]) def test_transform_rb(self, out_keys, rbclass): - squeeze_dim = -2 + dim = -2 t = SqueezeTransform( - squeeze_dim, + dim, in_keys=["observation"], out_keys=out_keys, allow_positive_dim=True, @@ -6466,7 +6428,7 @@ def test_transform_no_env(self, keys, batch, device): assert (td.get("dont touch") == dont_touch).all() if len(keys) == 1: - observation_spec = BoundedTensorSpec(0, 255, (16, 16, 3), dtype=torch.uint8) + observation_spec = Bounded(0, 255, (16, 16, 3), dtype=torch.uint8) observation_spec = totensorimage.transform_observation_spec( observation_spec ) @@ -6474,11 +6436,8 @@ def test_transform_no_env(self, keys, batch, device): assert (observation_spec.space.low == 0).all() assert (observation_spec.space.high == 1).all() else: - observation_spec = CompositeSpec( - { - key: BoundedTensorSpec(0, 255, (16, 16, 3), dtype=torch.uint8) - for key in keys - } + observation_spec = Composite( + {key: Bounded(0, 255, (16, 16, 3), dtype=torch.uint8) for key in keys} ) observation_spec = totensorimage.transform_observation_spec( observation_spec @@ -6515,7 +6474,7 @@ def test_transform_compose(self, keys, batch, device): assert (td.get("dont touch") == dont_touch).all() if len(keys) == 1: - observation_spec = BoundedTensorSpec(0, 255, (16, 16, 3), dtype=torch.uint8) + observation_spec = Bounded(0, 255, (16, 16, 3), dtype=torch.uint8) observation_spec = 
totensorimage.transform_observation_spec( observation_spec ) @@ -6523,11 +6482,8 @@ def test_transform_compose(self, keys, batch, device): assert (observation_spec.space.low == 0).all() assert (observation_spec.space.high == 1).all() else: - observation_spec = CompositeSpec( - { - key: BoundedTensorSpec(0, 255, (16, 16, 3), dtype=torch.uint8) - for key in keys - } + observation_spec = Composite( + {key: Bounded(0, 255, (16, 16, 3), dtype=torch.uint8) for key in keys} ) observation_spec = totensorimage.transform_observation_spec( observation_spec @@ -6670,7 +6626,7 @@ class TestTensorDictPrimer(TransformBase): def test_single_trans_env_check(self): env = TransformedEnv( ContinuousActionVecMockEnv(), - TensorDictPrimer(mykey=UnboundedContinuousTensorSpec([3])), + TensorDictPrimer(mykey=Unbounded([3])), ) check_env_specs(env) assert "mykey" in env.reset().keys() @@ -6682,14 +6638,10 @@ def test_nested_key_env(self): env = TransformedEnv( env, TensorDictPrimer( - CompositeSpec( + Composite( { - "nested_1": CompositeSpec( - { - "mykey": UnboundedContinuousTensorSpec( - (env.nested_dim_1, 4) - ) - }, + "nested_1": Composite( + {"mykey": Unbounded((env.nested_dim_1, 4))}, shape=(env.nested_dim_1,), ) } @@ -6707,13 +6659,13 @@ def test_nested_key_env(self): assert ("next", "nested_1", "mykey") in env.rollout(3).keys(True, True) def test_transform_no_env(self): - t = TensorDictPrimer(mykey=UnboundedContinuousTensorSpec([3])) + t = TensorDictPrimer(mykey=Unbounded([3])) td = TensorDict({"a": torch.zeros(())}, []) t(td) assert "mykey" in td.keys() def test_transform_model(self): - t = TensorDictPrimer(mykey=UnboundedContinuousTensorSpec([3])) + t = TensorDictPrimer(mykey=Unbounded([3])) model = nn.Sequential(t, nn.Identity()) td = TensorDict({}, []) model(td) @@ -6722,7 +6674,7 @@ def test_transform_model(self): @pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer]) def test_transform_rb(self, rbclass): batch_size = (2,) - t = TensorDictPrimer(mykey=UnboundedContinuousTensorSpec([*batch_size, 3])) + t = TensorDictPrimer(mykey=Unbounded([*batch_size, 3])) rb = rbclass(storage=LazyTensorStorage(10)) rb.append_transform(t) td = TensorDict({"a": torch.zeros(())}, []) @@ -6734,7 +6686,7 @@ def test_transform_inverse(self): raise pytest.skip("No inverse method for TensorDictPrimer") def test_transform_compose(self): - t = Compose(TensorDictPrimer(mykey=UnboundedContinuousTensorSpec([3]))) + t = Compose(TensorDictPrimer(mykey=Unbounded([3]))) td = TensorDict({"a": torch.zeros(())}, []) t(td) assert "mykey" in td.keys() @@ -6743,7 +6695,7 @@ def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv): def make_env(): return TransformedEnv( ContinuousActionVecMockEnv(), - TensorDictPrimer(mykey=UnboundedContinuousTensorSpec([3])), + TensorDictPrimer(mykey=Unbounded([3])), ) env = maybe_fork_ParallelEnv(2, make_env) @@ -6761,7 +6713,7 @@ def test_serial_trans_env_check(self): def make_env(): return TransformedEnv( ContinuousActionVecMockEnv(), - TensorDictPrimer(mykey=UnboundedContinuousTensorSpec([3])), + TensorDictPrimer(mykey=Unbounded([3])), ) env = SerialEnv(2, make_env) @@ -6778,7 +6730,7 @@ def make_env(): def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): env = TransformedEnv( maybe_fork_ParallelEnv(2, ContinuousActionVecMockEnv), - TensorDictPrimer(mykey=UnboundedContinuousTensorSpec([2, 4])), + TensorDictPrimer(mykey=Unbounded([2, 4])), ) try: check_env_specs(env) @@ -6796,7 +6748,7 @@ def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): def 
test_trans_serial_env_check(self, spec_shape): env = TransformedEnv( SerialEnv(2, ContinuousActionVecMockEnv), - TensorDictPrimer(mykey=UnboundedContinuousTensorSpec(spec_shape)), + TensorDictPrimer(mykey=Unbounded(spec_shape)), ) check_env_specs(env) assert "mykey" in env.reset().keys() @@ -6810,8 +6762,8 @@ def test_trans_serial_env_check(self, spec_shape): @pytest.mark.parametrize( "spec", [ - CompositeSpec(b=BoundedTensorSpec(-3, 3, [4])), - BoundedTensorSpec(-3, 3, [4]), + Composite(b=Bounded(-3, 3, [4])), + Bounded(-3, 3, [4]), ], ) @pytest.mark.parametrize("random", [True, False]) @@ -6861,9 +6813,7 @@ def make_env(): else: assert (tensordict_select == value).all() - if isinstance(spec, CompositeSpec) and any( - key != "action" for key in default_keys - ): + if isinstance(spec, Composite) and any(key != "action" for key in default_keys): for key in default_keys: if key in ("action",): continue @@ -6878,7 +6828,7 @@ def test_tensordictprimer_batching(self, batched_class, break_when_any_done): env = TransformedEnv( batched_class(2, lambda: GymEnv(CARTPOLE_VERSIONED())), - TensorDictPrimer(mykey=UnboundedContinuousTensorSpec([2, 4])), + TensorDictPrimer(mykey=Unbounded([2, 4])), ) torch.manual_seed(0) env.set_seed(0) @@ -6888,7 +6838,7 @@ def test_tensordictprimer_batching(self, batched_class, break_when_any_done): 2, lambda: TransformedEnv( GymEnv(CARTPOLE_VERSIONED()), - TensorDictPrimer(mykey=UnboundedContinuousTensorSpec([4])), + TensorDictPrimer(mykey=Unbounded([4])), ), ) torch.manual_seed(0) @@ -6902,9 +6852,7 @@ def create_tensor(): env = TransformedEnv( ContinuousActionVecMockEnv(), - TensorDictPrimer( - mykey=UnboundedContinuousTensorSpec([3]), default_value=create_tensor - ), + TensorDictPrimer(mykey=Unbounded([3]), default_value=create_tensor), ) check_env_specs(env) assert "mykey" in env.reset().keys() @@ -6913,8 +6861,8 @@ def create_tensor(): def test_dict_default_value(self): # Test with a dict of float default values - key1_spec = UnboundedContinuousTensorSpec([3]) - key2_spec = UnboundedContinuousTensorSpec([3]) + key1_spec = Unbounded([3]) + key2_spec = Unbounded([3]) env = TransformedEnv( ContinuousActionVecMockEnv(), TensorDictPrimer( @@ -6937,8 +6885,8 @@ def test_dict_default_value(self): assert (rollout_td.get(("next", "mykey2")) == 2.0).all() # Test with a dict of callable default values - key1_spec = UnboundedContinuousTensorSpec([3]) - key2_spec = DiscreteTensorSpec(3, dtype=torch.int64) + key1_spec = Unbounded([3]) + key2_spec = Categorical(3, dtype=torch.int64) env = TransformedEnv( ContinuousActionVecMockEnv(), TensorDictPrimer( @@ -7751,13 +7699,11 @@ def test_vipnet_transform_observation_spec( ): vip_net = _VIPNet(in_keys, out_keys, model, del_keys) - observation_spec = CompositeSpec( - {key: BoundedTensorSpec(-1, 1, (3, 16, 16), device) for key in in_keys} + observation_spec = Composite( + {key: Bounded(-1, 1, (3, 16, 16), device) for key in in_keys} ) if del_keys: - exp_ts = CompositeSpec( - {key: UnboundedContinuousTensorSpec(1024, device) for key in out_keys} - ) + exp_ts = Composite({key: Unbounded(1024, device) for key in out_keys}) observation_spec_out = vip_net.transform_observation_spec(observation_spec) @@ -7772,8 +7718,8 @@ def test_vipnet_transform_observation_spec( for key in in_keys: ts_dict[key] = observation_spec[key] for key in out_keys: - ts_dict[key] = UnboundedContinuousTensorSpec(1024, device) - exp_ts = CompositeSpec(ts_dict) + ts_dict[key] = Unbounded(1024, device) + exp_ts = Composite(ts_dict) observation_spec_out = 
vip_net.transform_observation_spec(observation_spec) @@ -8466,8 +8412,8 @@ def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec: env.transform.transform_reward_spec(env.base_env.full_reward_spec) def test_independent_obs_specs_from_shared_env(self): - obs_spec = CompositeSpec( - observation=BoundedTensorSpec(low=0, high=10, shape=torch.Size((1,))) + obs_spec = Composite( + observation=Bounded(low=0, high=10, shape=torch.Size((1,))) ) base_env = ContinuousActionVecMockEnv(observation_spec=obs_spec) t1 = TransformedEnv( @@ -8490,7 +8436,7 @@ def test_independent_obs_specs_from_shared_env(self): assert base_env.observation_spec["observation"].space.high == 10 def test_independent_reward_specs_from_shared_env(self): - reward_spec = UnboundedContinuousTensorSpec() + reward_spec = Unbounded() base_env = ContinuousActionVecMockEnv(reward_spec=reward_spec) t1 = TransformedEnv( base_env, transform=RewardClipping(clamp_min=0, clamp_max=4) @@ -8508,8 +8454,14 @@ def test_independent_reward_specs_from_shared_env(self): assert t2_reward_spec.space.low == -2 assert t2_reward_spec.space.high == 2 - assert base_env.reward_spec.space.low == -np.inf - assert base_env.reward_spec.space.high == np.inf + assert ( + base_env.reward_spec.space.low + == torch.finfo(base_env.reward_spec.dtype).min + ) + assert ( + base_env.reward_spec.space.high + == torch.finfo(base_env.reward_spec.dtype).max + ) def test_allow_done_after_reset(self): base_env = ContinuousActionVecMockEnv(allow_done_after_reset=True) @@ -8637,13 +8589,13 @@ def test_compose(self, keys, batch, device, nchannels=1, N=4): assert (td.get("dont touch") == dont_touch).all() if len(keys) == 1: - observation_spec = BoundedTensorSpec(0, 255, (nchannels, 16, 16)) + observation_spec = Bounded(0, 255, (nchannels, 16, 16)) # StepCounter does not want non composite specs observation_spec = compose[:2].transform_observation_spec(observation_spec) assert observation_spec.shape == torch.Size([nchannels * N, 16, 16]) else: - observation_spec = CompositeSpec( - {key: BoundedTensorSpec(0, 255, (nchannels, 16, 16)) for key in keys} + observation_spec = Composite( + {key: Bounded(0, 255, (nchannels, 16, 16)) for key in keys} ) observation_spec = compose.transform_observation_spec(observation_spec) for key in keys: @@ -8715,6 +8667,35 @@ def test_compose_indexing(self): assert last_t.scale == 4 assert last_t2.scale == 4 + def test_compose_action_spec(self): + # Create a Compose transform that renames "action" to "action_1" and then to "action_2" + c = Compose( + RenameTransform( + in_keys=(), + out_keys=(), + in_keys_inv=("action",), + out_keys_inv=("action_1",), + ), + RenameTransform( + in_keys=(), + out_keys=(), + in_keys_inv=("action_1",), + out_keys_inv=("action_2",), + ), + ) + base_env = ContinuousActionVecMockEnv() + env = TransformedEnv(base_env, c) + + # Check the `full_action_spec`s + assert "action_2" in env.full_action_spec + # Ensure intermediate keys are no longer in the action spec + assert "action_1" not in env.full_action_spec + assert "action" not in env.full_action_spec + + # Final check to ensure clean sampling from the action_spec + action = env.rand_action() + assert "action_2" in action + @pytest.mark.parametrize("device", get_default_devices()) def test_finitetensordictcheck(self, device): ftd = FiniteTensorDictCheck() @@ -8936,10 +8917,8 @@ def test_batch_unlocked_with_batch_size_transformed(device): pytest.param( partial(FlattenObservation, first_dim=-3, last_dim=-3), id="FlattenObservation" ), - pytest.param( - 
partial(UnsqueezeTransform, unsqueeze_dim=-1), id="UnsqueezeTransform" - ), - pytest.param(partial(SqueezeTransform, squeeze_dim=-1), id="SqueezeTransform"), + pytest.param(partial(UnsqueezeTransform, dim=-1), id="UnsqueezeTransform"), + pytest.param(partial(SqueezeTransform, dim=-1), id="SqueezeTransform"), GrayScale, pytest.param( partial(ObservationNorm, in_keys=["observation"]), id="ObservationNorm" @@ -9371,6 +9350,28 @@ def test_transform_inverse(self, create_copy): else: assert "b" not in tensordict.keys() + def test_rename_action(self, create_copy): + base_env = ContinuousActionVecMockEnv() + env = base_env.append_transform( + RenameTransform( + in_keys=[], + out_keys=[], + in_keys_inv=["action"], + out_keys_inv=[("renamed", "action")], + create_copy=create_copy, + ) + ) + r = env.rollout(3) + assert ("renamed", "action") in env.action_keys, env.action_keys + assert ("renamed", "action") in r + assert env.full_action_spec[("renamed", "action")] is not None + if create_copy: + assert "action" in env.action_keys + assert "action" in r + else: + assert "action" not in env.action_keys + assert "action" not in r + class TestInitTracker(TransformBase): @pytest.mark.skipif(not _has_gym, reason="no gym detected") @@ -9600,9 +9601,7 @@ def _make_transform_env(self, out_key, base_env): return Compose( TensorDictPrimer( primers={ - "sample_log_prob": UnboundedContinuousTensorSpec( - shape=base_env.action_spec.shape[:-1] - ) + "sample_log_prob": Unbounded(shape=base_env.action_spec.shape[:-1]) } ), transform, @@ -9836,20 +9835,18 @@ def test_kl_lstm(self): class TestActionMask(TransformBase): @property def _env_class(self): - from torchrl.data import BinaryDiscreteTensorSpec, DiscreteTensorSpec + from torchrl.data import Binary, Categorical class MaskedEnv(EnvBase): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.action_spec = DiscreteTensorSpec(4) - self.state_spec = CompositeSpec( - action_mask=BinaryDiscreteTensorSpec(4, dtype=torch.bool) + self.action_spec = Categorical(4) + self.state_spec = Composite(action_mask=Binary(4, dtype=torch.bool)) + self.observation_spec = Composite( + obs=Unbounded(3), + action_mask=Binary(4, dtype=torch.bool), ) - self.observation_spec = CompositeSpec( - obs=UnboundedContinuousTensorSpec(3), - action_mask=BinaryDiscreteTensorSpec(4, dtype=torch.bool), - ) - self.reward_spec = UnboundedContinuousTensorSpec(1) + self.reward_spec = Unbounded(1) def _reset(self, tensordict): td = self.observation_spec.rand() @@ -9960,16 +9957,27 @@ def test_transform_inverse(self): class TestDeviceCastTransformPart(TransformBase): + @pytest.fixture(scope="class") + def _cast_device(self): + if torch.cuda.is_available(): + yield torch.device("cuda:0") + elif torch.backends.mps.is_available(): + yield torch.device("mps:0") + else: + yield torch.device("cpu:1") + @pytest.mark.parametrize("in_keys", ["observation"]) @pytest.mark.parametrize("out_keys", [None, ["obs_device"]]) @pytest.mark.parametrize("in_keys_inv", ["action"]) @pytest.mark.parametrize("out_keys_inv", [None, ["action_device"]]) - def test_single_trans_env_check(self, in_keys, out_keys, in_keys_inv, out_keys_inv): + def test_single_trans_env_check( + self, in_keys, out_keys, in_keys_inv, out_keys_inv, _cast_device + ): env = ContinuousActionVecMockEnv(device="cpu:0") env = TransformedEnv( env, DeviceCastTransform( - "cpu:1", + _cast_device, in_keys=in_keys, out_keys=out_keys, in_keys_inv=in_keys_inv, @@ -9983,12 +9991,14 @@ def test_single_trans_env_check(self, in_keys, out_keys, 
in_keys_inv, out_keys_i @pytest.mark.parametrize("out_keys", [None, ["obs_device"]]) @pytest.mark.parametrize("in_keys_inv", ["action"]) @pytest.mark.parametrize("out_keys_inv", [None, ["action_device"]]) - def test_serial_trans_env_check(self, in_keys, out_keys, in_keys_inv, out_keys_inv): + def test_serial_trans_env_check( + self, in_keys, out_keys, in_keys_inv, out_keys_inv, _cast_device + ): def make_env(): return TransformedEnv( ContinuousActionVecMockEnv(device="cpu:0"), DeviceCastTransform( - "cpu:1", + _cast_device, in_keys=in_keys, out_keys=out_keys, in_keys_inv=in_keys_inv, @@ -10005,13 +10015,13 @@ def make_env(): @pytest.mark.parametrize("in_keys_inv", ["action"]) @pytest.mark.parametrize("out_keys_inv", [None, ["action_device"]]) def test_parallel_trans_env_check( - self, in_keys, out_keys, in_keys_inv, out_keys_inv + self, in_keys, out_keys, in_keys_inv, out_keys_inv, _cast_device ): def make_env(): return TransformedEnv( ContinuousActionVecMockEnv(device="cpu:0"), DeviceCastTransform( - "cpu:1", + _cast_device, in_keys=in_keys, out_keys=out_keys, in_keys_inv=in_keys_inv, @@ -10037,14 +10047,16 @@ def make_env(): @pytest.mark.parametrize("out_keys", [None, ["obs_device"]]) @pytest.mark.parametrize("in_keys_inv", ["action"]) @pytest.mark.parametrize("out_keys_inv", [None, ["action_device"]]) - def test_trans_serial_env_check(self, in_keys, out_keys, in_keys_inv, out_keys_inv): + def test_trans_serial_env_check( + self, in_keys, out_keys, in_keys_inv, out_keys_inv, _cast_device + ): def make_env(): return ContinuousActionVecMockEnv(device="cpu:0") env = TransformedEnv( SerialEnv(2, make_env), DeviceCastTransform( - "cpu:1", + _cast_device, in_keys=in_keys, out_keys=out_keys, in_keys_inv=in_keys_inv, @@ -10059,7 +10071,7 @@ def make_env(): @pytest.mark.parametrize("in_keys_inv", ["action"]) @pytest.mark.parametrize("out_keys_inv", [None, ["action_device"]]) def test_trans_parallel_env_check( - self, in_keys, out_keys, in_keys_inv, out_keys_inv + self, in_keys, out_keys, in_keys_inv, out_keys_inv, _cast_device ): def make_env(): return ContinuousActionVecMockEnv(device="cpu:0") @@ -10071,7 +10083,7 @@ def make_env(): mp_start_method=mp_ctx if not torch.cuda.is_available() else "spawn", ), DeviceCastTransform( - "cpu:1", + _cast_device, in_keys=in_keys, out_keys=out_keys, in_keys_inv=in_keys_inv, @@ -10087,8 +10099,8 @@ def make_env(): except RuntimeError: pass - def test_transform_no_env(self): - t = DeviceCastTransform("cpu:1", "cpu:0", in_keys=["a"], out_keys=["b"]) + def test_transform_no_env(self, _cast_device): + t = DeviceCastTransform(_cast_device, "cpu:0", in_keys=["a"], out_keys=["b"]) td = TensorDict({"a": torch.randn((), device="cpu:0")}, [], device="cpu:0") tdt = t._call(td) assert tdt.device is None @@ -10097,12 +10109,14 @@ def test_transform_no_env(self): @pytest.mark.parametrize("out_keys", [None, ["obs_device"]]) @pytest.mark.parametrize("in_keys_inv", ["action"]) @pytest.mark.parametrize("out_keys_inv", [None, ["action_device"]]) - def test_transform_env(self, in_keys, out_keys, in_keys_inv, out_keys_inv): + def test_transform_env( + self, in_keys, out_keys, in_keys_inv, out_keys_inv, _cast_device + ): env = ContinuousActionVecMockEnv(device="cpu:0") env = TransformedEnv( env, DeviceCastTransform( - "cpu:1", + _cast_device, in_keys=in_keys, out_keys=out_keys, in_keys_inv=in_keys_inv, @@ -10110,13 +10124,13 @@ def test_transform_env(self, in_keys, out_keys, in_keys_inv, out_keys_inv): ), ) assert env.device is None - assert env.transform.device == 
torch.device("cpu:1") + assert env.transform.device == _cast_device assert env.transform.orig_device == torch.device("cpu:0") - def test_transform_compose(self): + def test_transform_compose(self, _cast_device): t = Compose( DeviceCastTransform( - "cpu:1", + _cast_device, "cpu:0", in_keys=["a"], out_keys=["b"], @@ -10128,7 +10142,7 @@ def test_transform_compose(self): td = TensorDict( { "a": torch.randn((), device="cpu:0"), - "c": torch.randn((), device="cpu:1"), + "c": torch.randn((), device=_cast_device), }, [], device="cpu:0", @@ -10139,11 +10153,11 @@ def test_transform_compose(self): assert tdt.device is None assert tdit.device is None - def test_transform_model(self): + def test_transform_model(self, _cast_device): t = nn.Sequential( Compose( DeviceCastTransform( - "cpu:1", + _cast_device, "cpu:0", in_keys=["a"], out_keys=["b"], @@ -10166,11 +10180,11 @@ def test_transform_model(self): @pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer]) @pytest.mark.parametrize("storage", [LazyTensorStorage]) - def test_transform_rb(self, rbclass, storage): + def test_transform_rb(self, rbclass, storage, _cast_device): # we don't test casting to cuda on Memmap tensor storage since it's discouraged t = Compose( DeviceCastTransform( - "cpu:1", + _cast_device, "cpu:0", in_keys=["a"], out_keys=["b"], @@ -10183,7 +10197,7 @@ def test_transform_rb(self, rbclass, storage): td = TensorDict( { "a": torch.randn((), device="cpu:0"), - "c": torch.randn((), device="cpu:1"), + "c": torch.randn((), device=_cast_device), }, [], device="cpu:0", @@ -10467,17 +10481,18 @@ def test_transform_no_env(self, batch): reason="EndOfLifeTransform can only be tested when Gym is present.", ) class TestEndOfLife(TransformBase): + pytest.mark.filterwarnings("ignore:The base_env is not a gym env") + def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): def make(): with set_gym_backend("gymnasium"): return GymEnv(BREAKOUT_VERSIONED()) - with pytest.warns(UserWarning, match="The base_env is not a gym env"): - with pytest.raises(AttributeError): - env = TransformedEnv( - maybe_fork_ParallelEnv(2, make), transform=EndOfLifeTransform() - ) - check_env_specs(env) + with pytest.raises(AttributeError): + env = TransformedEnv( + maybe_fork_ParallelEnv(2, make), transform=EndOfLifeTransform() + ) + check_env_specs(env) def test_trans_serial_env_check(self): def make(): @@ -10987,27 +11002,25 @@ class TestRemoveEmptySpecs(TransformBase): class DummyEnv(EnvBase): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.observation_spec = CompositeSpec( - observation=UnboundedContinuousTensorSpec((*self.batch_size, 3)), - other=CompositeSpec( - another_other=CompositeSpec(shape=self.batch_size), + self.observation_spec = Composite( + observation=Unbounded((*self.batch_size, 3)), + other=Composite( + another_other=Composite(shape=self.batch_size), shape=self.batch_size, ), shape=self.batch_size, ) - self.action_spec = UnboundedContinuousTensorSpec((*self.batch_size, 3)) - self.done_spec = DiscreteTensorSpec( - 2, (*self.batch_size, 1), dtype=torch.bool - ) + self.action_spec = Unbounded((*self.batch_size, 3)) + self.done_spec = Categorical(2, (*self.batch_size, 1), dtype=torch.bool) self.full_done_spec["truncated"] = self.full_done_spec["terminated"].clone() - self.reward_spec = CompositeSpec( - reward=UnboundedContinuousTensorSpec(*self.batch_size, 1), - other_reward=CompositeSpec(shape=self.batch_size), + self.reward_spec = Composite( + reward=Unbounded(*self.batch_size, 1), + 
other_reward=Composite(shape=self.batch_size), shape=self.batch_size, ) - self.state_spec = CompositeSpec( - state=CompositeSpec( - sub=CompositeSpec(shape=self.batch_size), shape=self.batch_size + self.state_spec = Composite( + state=Composite( + sub=Composite(shape=self.batch_size), shape=self.batch_size ), shape=self.batch_size, ) @@ -11213,11 +11226,9 @@ class MyEnv(EnvBase): def __init__(self): super().__init__() - self.observation_spec = CompositeSpec( - observation=UnboundedContinuousTensorSpec(3) - ) - self.reward_spec = UnboundedContinuousTensorSpec(1) - self.action_spec = UnboundedContinuousTensorSpec(1) + self.observation_spec = Composite(observation=Unbounded(3)) + self.reward_spec = Unbounded(1) + self.action_spec = Unbounded(1) def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: tensordict_batch_size = ( diff --git a/test/test_utils.py b/test/test_utils.py index c2ce2eae6b9..4224a36b54f 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -174,8 +174,8 @@ def test_implement_for_reset(): ("0.9.0", "0.1.0", "0.21.0", True), ("0.19.99", "0.19.9", "0.21.0", True), ("0.19.99", None, "0.19.0", False), - ("5.61.77", "0.21.0", None, True), - ("5.61.77", None, "0.21.0", False), + ("0.99.0", "0.21.0", None, True), + ("0.99.0", None, "0.21.0", False), ], ) def test_implement_for_check_versions( @@ -189,9 +189,9 @@ def test_implement_for_check_versions( @pytest.mark.parametrize( "gymnasium_version, expected_from_version_gymnasium, expected_to_version_gymnasium", [ - ("0.27.0", None, None), - ("0.27.2", None, None), - ("5.1.77", None, None), + ("0.27.0", None, "1.0.0"), + ("0.27.2", None, "1.0.0"), + # ("1.0.1", "1.0.0", None), ], ) @pytest.mark.parametrize( @@ -199,7 +199,7 @@ def test_implement_for_check_versions( [ ("0.21.0", "0.21.0", None), ("0.22.0", "0.21.0", None), - ("5.61.77", "0.21.0", None), + ("0.99.0", "0.21.0", None), ("0.9.0", None, "0.21.0"), ("0.20.0", None, "0.21.0"), ("0.19.99", None, "0.21.0"), @@ -228,6 +228,8 @@ def test_set_gym_environments( import gymnasium # look for the right function that should be called according to gym versions (and same for gymnasium) + expected_fn_gymnasium = None + expected_fn_gym = None for impfor in implement_for._setters: if impfor.fn.__name__ == "_set_gym_environments": if (impfor.module_name, impfor.from_version, impfor.to_version) == ( @@ -242,20 +244,22 @@ def test_set_gym_environments( expected_to_version_gymnasium, ): expected_fn_gymnasium = impfor.fn + if expected_fn_gym is not None and expected_fn_gymnasium is not None: + break with set_gym_backend(gymnasium): assert ( - _utils_internal._set_gym_environments == expected_fn_gymnasium + _utils_internal._set_gym_environments is expected_fn_gymnasium ), expected_fn_gym with set_gym_backend(gym): assert ( - _utils_internal._set_gym_environments == expected_fn_gym + _utils_internal._set_gym_environments is expected_fn_gym ), expected_fn_gymnasium with set_gym_backend(gymnasium): assert ( - _utils_internal._set_gym_environments == expected_fn_gymnasium + _utils_internal._set_gym_environments is expected_fn_gymnasium ), expected_fn_gym diff --git a/torchrl/__init__.py b/torchrl/__init__.py index 25103423cac..7a41bf0ab8f 100644 --- a/torchrl/__init__.py +++ b/torchrl/__init__.py @@ -3,6 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
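The version-matrix updates above replace the fictitious "5.61.77"/"5.1.77" versions with realistic ones and cap the gymnasium branch below 1.0.0, matching the `@implement_for("gymnasium", None, "1.0.0")` decorator shown further down. A hedged sketch of the dispatch pattern those tests pin down, assuming gym and/or gymnasium is installed; `backend_name` is a hypothetical function, not part of this diff:

```python
from torchrl._utils import implement_for

# One implementation per (backend, version range); implement_for selects the
# body matching the installed/active backend when the function is called.
@implement_for("gym", "0.21.0", None)
def backend_name():  # hypothetical example
    return "gym>=0.21"


@implement_for("gymnasium", None, "1.0.0")  # noqa: F811
def backend_name():  # noqa: F811
    return "gymnasium<1.0"
```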
import os +import weakref from warnings import warn import torch @@ -10,6 +11,7 @@ from tensordict import set_lazy_legacy from torch import multiprocessing as mp +from torch.distributions.transforms import _InverseTransform, ComposeTransform set_lazy_legacy(False).set() @@ -25,6 +27,11 @@ except ImportError: __version__ = None +try: + from torch.compiler import is_dynamo_compiling +except ImportError: + from torch._dynamo import is_compiling as is_dynamo_compiling + _init_extension() try: @@ -51,3 +58,45 @@ filter_warnings_subprocess = True _THREAD_POOL_INIT = torch.get_num_threads() + + +# monkey-patch dist transforms until https://github.com/pytorch/pytorch/pull/135001/ finds a home +@property +def _inv(self): + """Patched version of Transform.inv. + + Returns the inverse :class:`Transform` of this transform. + + This should satisfy ``t.inv.inv is t``. + """ + inv = None + if self._inv is not None: + inv = self._inv() + if inv is None: + inv = _InverseTransform(self) + if not is_dynamo_compiling(): + self._inv = weakref.ref(inv) + return inv + + +torch.distributions.transforms.Transform.inv = _inv + + +@property +def _inv(self): + inv = None + if self._inv is not None: + inv = self._inv() + if inv is None: + inv = ComposeTransform([p.inv for p in reversed(self.parts)]) + if not is_dynamo_compiling(): + self._inv = weakref.ref(inv) + inv._inv = weakref.ref(self) + else: + # We need inv.inv to be equal to self, but weakref can cause a graph break + inv._inv = lambda out=self: out + + return inv + + +ComposeTransform.inv = _inv diff --git a/torchrl/_utils.py b/torchrl/_utils.py index 895f3d80fdc..3af44ee0ed7 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -46,7 +46,7 @@ console_handler.setFormatter(formatter) logger.addHandler(console_handler) -VERBOSE = strtobool(os.environ.get("VERBOSE", "0")) +VERBOSE = strtobool(os.environ.get("VERBOSE", str(logger.isEnabledFor(logging.DEBUG)))) _os_is_windows = sys.platform == "win32" RL_WARNINGS = strtobool(os.environ.get("RL_WARNINGS", "1")) if RL_WARNINGS: @@ -269,7 +269,7 @@ class implement_for: ... # More recent gym versions will return x + 2 ... return x + 2 ... - >>> @implement_for("gymnasium") + >>> @implement_for("gymnasium", None, "1.0.0") >>> def fun(self, x): ... # If gymnasium is to be used instead of gym, x+3 will be returned ... 
return x + 3 @@ -785,4 +785,6 @@ def _make_ordinal_device(device: torch.device): return device if device.type == "cuda" and device.index is None: return torch.device("cuda", index=torch.cuda.current_device()) + if device.type == "mps" and device.index is None: + return torch.device("mps", index=0) return device diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index be24a06e39c..91355ae261f 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -10,13 +10,13 @@ import contextlib import functools - import os import queue import sys import time +import typing import warnings -from collections import OrderedDict +from collections import defaultdict, OrderedDict from copy import deepcopy from multiprocessing import connection, queues from multiprocessing.managers import SyncManager @@ -33,7 +33,8 @@ TensorDictBase, TensorDictParams, ) -from tensordict.nn import TensorDictModule +from tensordict.base import NO_DEFAULT +from tensordict.nn import CudaGraphModule, TensorDictModule from torch import multiprocessing as mp from torch.utils.data import IterableDataset @@ -50,20 +51,22 @@ VERBOSE, ) from torchrl.collectors.utils import split_trajectories +from torchrl.data import ReplayBuffer from torchrl.data.tensor_specs import TensorSpec from torchrl.data.utils import CloudpickleWrapper, DEVICE_TYPING from torchrl.envs.common import _do_nothing, EnvBase +from torchrl.envs.env_creator import EnvCreator from torchrl.envs.transforms import StepCounter, TransformedEnv from torchrl.envs.utils import ( _aggregate_end_of_traj, - _convert_exploration_type, _make_compatible_policy, - _NonParametricPolicyWrapper, ExplorationType, + RandomPolicy, set_exploration_type, ) _TIMEOUT = 1.0 +INSTANTIATE_TIMEOUT = 20 _MIN_TIMEOUT = 1e-3 # should be several orders of magnitude inferior wrt time spent collecting a trajectory # MAX_IDLE_COUNT is the maximum number of times a Dataloader worker can timeout with his queue. _MAX_IDLE_COUNT = int(os.environ.get("MAX_IDLE_COUNT", 1000)) @@ -131,59 +134,97 @@ class DataCollectorBase(IterableDataset, metaclass=abc.ABCMeta): """Base class for data collectors.""" _iterator = None + total_frames: int + frames_per_batch: int + trust_policy: bool + compiled_policy: bool + cudagraphed_policy: bool def _get_policy_and_device( self, - policy: Optional[ - Union[ - TensorDictModule, - Callable[[TensorDictBase], TensorDictBase], - ] - ] = None, + policy: Callable[[Any], Any] | None = None, observation_spec: TensorSpec = None, + policy_device: Any = NO_DEFAULT, + env_maker: Any | None = None, + env_maker_kwargs: dict | None = None, ) -> Tuple[TensorDictModule, Union[None, Callable[[], dict]]]: """Util method to get a policy and its device given the collector __init__ inputs. Args: - create_env_fn (Callable or list of callables): an env creator - function (or a list of creators) - create_env_kwargs (dictionary): kwargs for the env creator policy (TensorDictModule, optional): a policy to be used observation_spec (TensorSpec, optional): spec of the observations + policy_device (torch.device, optional): the device where the policy should be placed. + Defaults to self.policy_device + env_maker (a callable or a batched env, optional): the env_maker function for this device/policy pair. + env_maker_kwargs (a dict, optional): the env_maker function kwargs. 
""" - policy = _make_compatible_policy( - policy, observation_spec, env=getattr(self, "env", None) - ) - param_and_buf = TensorDict.from_module(policy, as_module=True) - - def get_weights_fn(param_and_buf=param_and_buf): - return param_and_buf.data - - if self.policy_device: - # create a stateless policy and populate it with params - def _map_to_device_params(param, device): - is_param = isinstance(param, nn.Parameter) - - pd = param.detach().to(device, non_blocking=True) - - if is_param: - pd = nn.Parameter(pd, requires_grad=False) - return pd + if policy_device is NO_DEFAULT: + policy_device = self.policy_device + + if not self.trust_policy: + env = getattr(self, "env", None) + policy = _make_compatible_policy( + policy, + observation_spec, + env=env, + env_maker=env_maker, + env_maker_kwargs=env_maker_kwargs, + ) + if not policy_device: + return policy, None - # Create a stateless policy, then populate this copy with params on device - with param_and_buf.apply( - functools.partial(_map_to_device_params, device="meta"), - filter_empty=False, - ).to_module(policy): - policy = deepcopy(policy) + if isinstance(policy, nn.Module): + param_and_buf = TensorDict.from_module(policy, as_module=True) + else: + # Because we want to reach the warning + param_and_buf = TensorDict() + + i = -1 + for p in param_and_buf.values(True, True): + i += 1 + if p.device != policy_device: + # Then we need casting + break + else: + if i == -1 and not self.trust_policy: + # We trust that the policy policy device is adequate + warnings.warn( + "A policy device was provided but no parameter/buffer could be found in " + "the policy. Casting to policy_device is therefore impossible. " + "The collector will trust that the devices match. To suppress this " + "warning, set `trust_policy=True` when building the collector." 
+ ) + return policy, None - param_and_buf.apply( - functools.partial(_map_to_device_params, device=self.policy_device), - filter_empty=False, - ).to_module(policy) + def map_weight( + weight, + policy_device=policy_device, + ): - return policy, get_weights_fn + is_param = isinstance(weight, nn.Parameter) + is_buffer = isinstance(weight, nn.Buffer) + weight = weight.data + if weight.device != policy_device: + weight = weight.to(policy_device) + elif weight.device.type in ("cpu", "mps"): + weight = weight.share_memory_() + if is_param: + weight = nn.Parameter(weight, requires_grad=False) + elif is_buffer: + weight = nn.Buffer(weight) + return weight + + # Create a stateless policy, then populate this copy with params on device + get_original_weights = functools.partial(TensorDict.from_module, policy) + with param_and_buf.to("meta").to_module(policy): + policy = deepcopy(policy) + + param_and_buf.apply( + functools.partial(map_weight), + filter_empty=False, + ).to_module(policy) + return policy, get_original_weights def update_policy_weights_( self, policy_weights: Optional[TensorDictBase] = None @@ -234,6 +275,16 @@ def state_dict(self) -> OrderedDict: def load_state_dict(self, state_dict: OrderedDict) -> None: raise NotImplementedError + def _read_compile_kwargs(self, compile_policy, cudagraph_policy): + self.compiled_policy = compile_policy not in (False, None) + self.cudagraphed_policy = cudagraph_policy not in (False, None) + self.compiled_policy_kwargs = ( + {} if not isinstance(compile_policy, typing.Mapping) else compile_policy + ) + self.cudagraphed_policy_kwargs = ( + {} if not isinstance(cudagraph_policy, typing.Mapping) else cudagraph_policy + ) + def __repr__(self) -> str: string = f"{self.__class__.__name__}()" return string @@ -241,6 +292,11 @@ def __repr__(self) -> str: def __class_getitem__(self, index): raise NotImplementedError + def __len__(self) -> int: + if self.total_frames > 0: + return -(self.total_frames // -self.frames_per_batch) + raise RuntimeError("Non-terminating collectors do not have a length") + @accept_remote_rref_udf_invocation class SyncDataCollector(DataCollectorBase): @@ -357,6 +413,17 @@ class SyncDataCollector(DataCollectorBase): use_buffers (bool, optional): if ``True``, a buffer will be used to stack the data. This isn't compatible with environments with dynamic specs. Defaults to ``True`` for envs without dynamic specs, ``False`` for others. + replay_buffer (ReplayBuffer, optional): if provided, the collector will not yield tensordicts + but will populate the buffer instead. Defaults to ``None``. + trust_policy (bool, optional): if ``True``, a non-TensorDictModule policy will be + assumed to be compatible with the collector. This defaults to ``True`` for CudaGraphModules + and ``False`` otherwise. + compile_policy (bool or Dict[str, Any], optional): if ``True``, the policy will be compiled + using the :func:`~torch.compile` default behavior. If a dictionary of kwargs is passed, it + will be used to compile the policy. + cudagraph_policy (bool or Dict[str, Any], optional): if ``True``, the policy will be wrapped + in :class:`~tensordict.nn.CudaGraphModule` with default kwargs. + If a dictionary of kwargs is passed, it will be used to wrap the policy. 
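The three parameters documented above change how a collector is built and consumed: with `replay_buffer` set, iteration fills the buffer and yields `None` instead of tensordicts (see the `iterator`/`rollout` hunks below). A minimal sketch of that path, assuming a gym/gymnasium backend is installed; the env name and linear policy are illustrative:

```python
import torch.nn as nn
from tensordict.nn import TensorDictModule
from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyTensorStorage, ReplayBuffer
from torchrl.envs import GymEnv

env = GymEnv("Pendulum-v1")  # 3-dim observation, 1-dim action
policy = TensorDictModule(
    nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]
)
rb = ReplayBuffer(storage=LazyTensorStorage(10_000))
collector = SyncDataCollector(
    env,
    policy,
    frames_per_batch=64,
    total_frames=256,
    replay_buffer=rb,  # frames go to the buffer; the iterator yields None
)
for batch in collector:
    assert batch is None
assert len(rb) > 0
collector.shutdown()
```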
Examples: >>> from torchrl.envs.libs.gym import GymEnv @@ -440,19 +507,20 @@ def __init__( postproc: Callable[[TensorDictBase], TensorDictBase] | None = None, split_trajs: bool | None = None, exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE, - exploration_mode: str | None = None, return_same_td: bool = False, reset_when_done: bool = True, interruptor=None, set_truncated: bool = False, use_buffers: bool | None = None, + replay_buffer: ReplayBuffer | None = None, + trust_policy: bool = None, + compile_policy: bool | Dict[str, Any] | None = None, + cudagraph_policy: bool | Dict[str, Any] | None = None, + **kwargs, ): from torchrl.envs.batched_envs import BatchedEnvBase self.closed = True - exploration_type = _convert_exploration_type( - exploration_mode=exploration_mode, exploration_type=exploration_type - ) if create_env_kwargs is None: create_env_kwargs = {} if not isinstance(create_env_fn, EnvBase): @@ -468,9 +536,20 @@ def __init__( env.update_kwargs(create_env_kwargs) if policy is None: - from torchrl.collectors import RandomPolicy policy = RandomPolicy(env.full_action_spec) + if trust_policy is None: + trust_policy = isinstance(policy, (RandomPolicy, CudaGraphModule)) + self.trust_policy = trust_policy + self._read_compile_kwargs(compile_policy, cudagraph_policy) + + ########################## + # Trajectory pool + self._traj_pool_val = kwargs.pop("traj_pool", None) + if kwargs: + raise TypeError( + f"Keys {list(kwargs.keys())} are unknown to {type(self).__name__}." + ) ########################## # Setting devices: @@ -493,7 +572,8 @@ def __init__( # Cuda handles sync if torch.cuda.is_available(): self._sync_storage = torch.cuda.synchronize - elif torch.backends.mps.is_available(): + elif torch.backends.mps.is_available() and hasattr(torch, "mps"): + # Will break for older PT versions which don't have torch.mps self._sync_storage = torch.mps.synchronize elif self.storing_device.type == "cpu": self._sync_storage = _do_nothing @@ -507,7 +587,7 @@ def __init__( # Cuda handles sync if torch.cuda.is_available(): self._sync_env = torch.cuda.synchronize - elif torch.backends.mps.is_available(): + elif torch.backends.mps.is_available() and hasattr(torch, "mps"): self._sync_env = torch.mps.synchronize elif self.env_device.type == "cpu": self._sync_env = _do_nothing @@ -520,7 +600,7 @@ def __init__( # Cuda handles sync if torch.cuda.is_available(): self._sync_policy = torch.cuda.synchronize - elif torch.backends.mps.is_available(): + elif torch.backends.mps.is_available() and hasattr(torch, "mps"): self._sync_policy = torch.mps.synchronize elif self.policy_device.type == "cpu": self._sync_policy = _do_nothing @@ -538,10 +618,19 @@ def __init__( self.env: EnvBase = env del env + self.replay_buffer = replay_buffer + if self.replay_buffer is not None: + if postproc is not None: + raise TypeError("postproc must be None when a replay buffer is passed.") + if use_buffers: + raise TypeError("replay_buffer is exclusive with use_buffers.") if use_buffers is None: - use_buffers = not self.env._has_dynamic_specs + use_buffers = not self.env._has_dynamic_specs and self.replay_buffer is None self._use_buffers = use_buffers + self.replay_buffer = replay_buffer + self.closed = False + if not reset_when_done: raise ValueError("reset_when_done is deprectated.") self.reset_when_done = reset_when_done @@ -551,11 +640,15 @@ def __init__( policy=policy, observation_spec=self.env.observation_spec, ) - if isinstance(self.policy, nn.Module): self.policy_weights = TensorDict.from_module(self.policy, 
         else:
-            self.policy_weights = TensorDict({}, [])
+            self.policy_weights = TensorDict()
+
+        if self.compiled_policy:
+            self.policy = torch.compile(self.policy, **self.compiled_policy_kwargs)
+        if self.cudagraphed_policy:
+            self.policy = CudaGraphModule(self.policy, **self.cudagraphed_policy_kwargs)

         if self.env_device:
             self.env: EnvBase = self.env.to(self.env_device)
@@ -655,6 +748,13 @@ def __init__(
         self._frames = 0
         self._iter = -1

+    @property
+    def _traj_pool(self):
+        pool = getattr(self, "_traj_pool_val", None)
+        if pool is None:
+            pool = self._traj_pool_val = _TrajectoryPool()
+        return pool
+
     def _make_shuttle(self):
         # Shuttle is a deviceless tensordict that just carries data from env to policy and policy to env
         with torch.no_grad():
@@ -665,9 +765,9 @@ def _make_shuttle(self):
         else:
             self._shuttle_has_no_device = False

-        traj_ids = torch.arange(self.n_env, device=self.storing_device).view(
-            self.env.batch_size
-        )
+        traj_ids = self._traj_pool.get_traj_and_increment(
+            self.n_env, device=self.storing_device
+        ).view(self.env.batch_size)
         self._shuttle.set(
             ("collector", "traj_ids"),
             traj_ids,
@@ -753,11 +853,13 @@ def check_exclusive(val):
         # changed them here).
         # This will cause a failure to update entries when policy and env device mismatch and
         # casting is necessary.
-        def filter_policy(value_output, value_input, value_input_clone):
-            if (
-                (value_input is None)
-                or (value_output is not value_input)
-                or ~torch.isclose(value_output, value_input_clone).any()
+        def filter_policy(name, value_output, value_input, value_input_clone):
+            if (value_input is None) or (
+                (value_output is not value_input)
+                and (
+                    value_output.device != value_input_clone.device
+                    or ~torch.isclose(value_output, value_input_clone).any()
+                )
             ):
                 return value_output

@@ -767,6 +869,7 @@ def filter_policy(value_output, value_input, value_input_clone):
             policy_input_clone,
             default=None,
             filter_empty=True,
+            named=True,
         )
         self._policy_output_keys = list(
             self._policy_output_keys.union(
@@ -871,7 +974,15 @@ def set_seed(self, seed: int, static_seed: bool = False) -> int:
             >>> out_seed = collector.set_seed(1)  # out_seed = 6

         """
-        return self.env.set_seed(seed, static_seed=static_seed)
+        out = self.env.set_seed(seed, static_seed=static_seed)
+        return out
+
+    def _increment_frames(self, numel):
+        self._frames += numel
+        completed = self._frames >= self.total_frames
+        if completed:
+            self.env.close()
+        return completed

     def iterator(self) -> Iterator[TensorDictBase]:
         """Iterates through the DataCollector.
@@ -917,14 +1028,15 @@ def cuda_check(tensor: torch.Tensor):
             for stream in streams:
                 stack.enter_context(torch.cuda.stream(stream))

-            total_frames = self.total_frames
-
             while self._frames < self.total_frames:
                 self._iter += 1
                 tensordict_out = self.rollout()
-                self._frames += tensordict_out.numel()
-                if self._frames >= total_frames:
-                    self.env.close()
+                if tensordict_out is None:
+                    # if a replay buffer is passed, there is no tensordict_out
+                    # frames are updated within the rollout function
+                    yield
+                    continue
+                self._increment_frames(tensordict_out.numel())

                 if self.split_trajs:
                     tensordict_out = split_trajectories(
@@ -976,14 +1088,20 @@ def _update_traj_ids(self, env_output) -> None:
                 env_output.get("next"), done_keys=self.env.done_keys
             )
             if traj_sop.any():
+                device = self.storing_device
+
                 traj_ids = self._shuttle.get(("collector", "traj_ids"))
-                traj_sop = traj_sop.to(self.storing_device)
-                traj_ids = traj_ids.clone().to(self.storing_device)
-                traj_ids[traj_sop] = traj_ids.max() + torch.arange(
-                    1,
-                    traj_sop.sum() + 1,
-                    device=self.storing_device,
+                if device is not None:
+                    traj_ids = traj_ids.to(device)
+                    traj_sop = traj_sop.to(device)
+                elif traj_sop.device != traj_ids.device:
+                    traj_sop = traj_sop.to(traj_ids.device)
+
+                pool = self._traj_pool
+                new_traj = pool.get_traj_and_increment(
+                    traj_sop.sum(), device=traj_sop.device
                 )
+                traj_ids = traj_ids.masked_scatter(traj_sop, new_traj)
                 self._shuttle.set(("collector", "traj_ids"), traj_ids)

     @torch.no_grad()
@@ -1053,13 +1171,18 @@ def rollout(self) -> TensorDictBase:
                     next_data.clear_device_()
                 self._shuttle.set("next", next_data)

-                if self.storing_device is not None:
-                    tensordicts.append(
-                        self._shuttle.to(self.storing_device, non_blocking=True)
-                    )
-                    self._sync_storage()
+                if self.replay_buffer is not None:
+                    self.replay_buffer.add(self._shuttle)
+                    if self._increment_frames(self._shuttle.numel()):
+                        return
                 else:
-                    tensordicts.append(self._shuttle)
+                    if self.storing_device is not None:
+                        tensordicts.append(
+                            self._shuttle.to(self.storing_device, non_blocking=True)
+                        )
+                        self._sync_storage()
+                    else:
+                        tensordicts.append(self._shuttle)

                 # carry over collector data without messing up devices
                 collector_data = self._shuttle.get("collector").copy()
@@ -1067,13 +1190,14 @@ def rollout(self) -> TensorDictBase:
                 if self._shuttle_has_no_device:
                     self._shuttle.clear_device_()
                 self._shuttle.set("collector", collector_data)
-
                 self._update_traj_ids(env_output)

                 if (
                     self.interruptor is not None
                     and self.interruptor.collection_stopped()
                 ):
+                    if self.replay_buffer is not None:
+                        return
                     result = self._final_rollout
                     if self._use_buffers:
                         try:
@@ -1109,6 +1233,8 @@ def rollout(self) -> TensorDictBase:
                             self._final_rollout.ndim - 1,
                             out=self._final_rollout,
                         )
+                elif self.replay_buffer is not None:
+                    return
                 else:
                     result = TensorDict.maybe_dense_stack(tensordicts, dim=-1)
                     result.refine_names(..., "time")
@@ -1359,7 +1485,7 @@ class _MultiDataCollector(DataCollectorBase):
            workers may overload the CPU and harm performance.
        cat_results (str, int or None): (:class:`~torchrl.collectors.MultiSyncDataCollector` exclusively).
            If ``"stack"``, the data collected from the workers will be stacked along the
-            first dimension. This is the preferred behaviour as it is the most compatible
+            first dimension. This is the preferred behavior as it is the most compatible
            with the rest of the library.
            If ``0``, results will be concatenated along the first
            dimension of the outputs, which can be the batched dimension if the environments are
@@ -1367,7 +1493,7 @@ class _MultiDataCollector(DataCollectorBase):
            A ``cat_results`` value of ``-1`` will always concatenate results along the
            time dimension. This should be preferred over the default. Intermediate values
            are also accepted.
-            Defaults to ``0``.
+            Defaults to ``"stack"``.

            .. note:: From v0.5, this argument will default to ``"stack"`` for better
              interoperability with the rest of the library.
@@ -1380,6 +1506,17 @@ class _MultiDataCollector(DataCollectorBase):
        use_buffers (bool, optional): if ``True``, a buffer will be used to stack the data.
            This isn't compatible with environments with dynamic specs. Defaults to ``True``
            for envs without dynamic specs, ``False`` for others.
+        replay_buffer (ReplayBuffer, optional): if provided, the collector will not yield tensordicts
+            but will populate the buffer instead. Defaults to ``None``.
+        trust_policy (bool, optional): if ``True``, a non-TensorDictModule policy will be assumed to be
+            compatible with the collector. This defaults to ``True`` for CudaGraphModules
+            and ``False`` otherwise.
+        compile_policy (bool or Dict[str, Any], optional): if ``True``, the policy will be compiled
+            using :func:`~torch.compile` default behavior. If a dictionary of kwargs is passed, it
+            will be used to compile the policy.
+        cudagraph_policy (bool or Dict[str, Any], optional): if ``True``, the policy will be wrapped
+            in :class:`~tensordict.nn.CudaGraphModule` with default kwargs.
+            If a dictionary of kwargs is passed, it will be used to wrap the policy.

    """

@@ -1406,7 +1543,6 @@ def __init__(
        postproc: Optional[Callable[[TensorDictBase], TensorDictBase]] = None,
        split_trajs: Optional[bool] = None,
        exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE,
-        exploration_mode=None,
        reset_when_done: bool = True,
        update_at_each_batch: bool = False,
        preemptive_threshold: float = None,
@@ -1415,10 +1551,12 @@ def __init__(
        cat_results: str | int | None = None,
        set_truncated: bool = False,
        use_buffers: bool | None = None,
+        replay_buffer: ReplayBuffer | None = None,
+        replay_buffer_chunk: bool = True,
+        trust_policy: bool = None,
+        compile_policy: bool | Dict[str, Any] | None = None,
+        cudagraph_policy: bool | Dict[str, Any] | None = None,
    ):
-        exploration_type = _convert_exploration_type(
-            exploration_mode=exploration_mode, exploration_type=exploration_type
-        )
        self.closed = True
        self.num_workers = len(create_env_fn)
@@ -1426,6 +1564,7 @@ def __init__(
        self.num_sub_threads = num_sub_threads
        self.num_threads = num_threads
        self.create_env_fn = create_env_fn
+        self._read_compile_kwargs(compile_policy, cudagraph_policy)
        self.create_env_kwargs = (
            create_env_kwargs
            if create_env_kwargs is not None
@@ -1458,79 +1597,42 @@ def __init__(
        del storing_device, env_device, policy_device, device

        self._use_buffers = use_buffers
+        self.replay_buffer = replay_buffer
+        self._check_replay_buffer_init()
+        self.replay_buffer_chunk = replay_buffer_chunk
+        if (
+            replay_buffer is not None
+            and hasattr(replay_buffer, "shared")
+            and not replay_buffer.shared
+        ):
+            replay_buffer.share()

-        _policy_weights_dict = {}
-        _get_weights_fn_dict = {}
-
-        if policy is not None:
-            policy = _NonParametricPolicyWrapper(policy)
-            policy_weights = TensorDict.from_module(policy, as_module=True)
-
-            # store a stateless policy
-            with policy_weights.apply(_make_meta_params).to_module(policy):
-                # TODO:
-                self.policy = deepcopy(policy)
-
-        else:
- policy_weights = TensorDict() - self.policy = None - - for policy_device in policy_devices: - # if we have already mapped onto that device, get that value - if policy_device in _policy_weights_dict: - continue - # If policy device is None, the only thing we need to do is - # make sure that the weights are shared. - if policy_device is None: + self._policy_weights_dict = {} + self._get_weights_fn_dict = {} - def map_weight( - weight, - ): - is_param = isinstance(weight, nn.Parameter) - weight = weight.data - if weight.device.type in ("cpu", "mps"): - weight = weight.share_memory_() - if is_param: - weight = nn.Parameter(weight, requires_grad=False) - return weight - - # in other cases, we need to cast the policy if and only if not all the weights - # are on the appropriate device - else: - # check the weights devices - has_different_device = [False] + if trust_policy is None: + trust_policy = isinstance(policy, CudaGraphModule) + self.trust_policy = trust_policy - def map_weight( - weight, - policy_device=policy_device, - has_different_device=has_different_device, - ): - is_param = isinstance(weight, nn.Parameter) - weight = weight.data - if weight.device != policy_device: - has_different_device[0] = True - weight = weight.to(policy_device) - elif weight.device.type in ("cpu", "mps"): - weight = weight.share_memory_() - if is_param: - weight = nn.Parameter(weight, requires_grad=False) - return weight - - local_policy_weights = TensorDictParams( - policy_weights.apply(map_weight, filter_empty=False) + for policy_device, env_maker, env_maker_kwargs in zip( + self.policy_device, self.create_env_fn, self.create_env_kwargs + ): + (policy_copy, get_weights_fn,) = self._get_policy_and_device( + policy=policy, + policy_device=policy_device, + env_maker=env_maker, + env_maker_kwargs=env_maker_kwargs, ) - - def _get_weight_fn(weights=policy_weights): - # This function will give the local_policy_weight the original weights. 
- # see self.update_policy_weights_ to see how this is used - return weights - - # We lock the weights to be able to cache a bunch of ops and to avoid modifying it - _policy_weights_dict[policy_device] = local_policy_weights.lock_() - _get_weights_fn_dict[policy_device] = _get_weight_fn - - self._policy_weights_dict = _policy_weights_dict - self._get_weights_fn_dict = _get_weights_fn_dict + if type(policy_copy) is not type(policy): + policy = policy_copy + weights = ( + TensorDict.from_module(policy_copy) + if isinstance(policy_copy, nn.Module) + else TensorDict() + ) + self._policy_weights_dict[policy_device] = weights + self._get_weights_fn_dict[policy_device] = get_weights_fn + self.policy = policy if total_frames is None or total_frames < 0: total_frames = float("inf") @@ -1598,6 +1700,26 @@ def _get_weight_fn(weights=policy_weights): ) self.cat_results = cat_results + def _check_replay_buffer_init(self): + if self.replay_buffer is None: + return + is_init = getattr(self.replay_buffer._storage, "initialized", True) + if not is_init: + if isinstance(self.create_env_fn[0], EnvCreator): + fake_td = self.create_env_fn[0].meta_data.tensordict + elif isinstance(self.create_env_fn[0], EnvBase): + fake_td = self.create_env_fn[0].fake_tensordict() + else: + fake_td = self.create_env_fn[0]( + **self.create_env_kwargs[0] + ).fake_tensordict() + fake_td["collector", "traj_ids"] = torch.zeros( + fake_td.shape, dtype=torch.long + ) + + self.replay_buffer.add(fake_td) + self.replay_buffer.empty() + @classmethod def _total_workers_from_env(cls, env_creators): if isinstance(env_creators, (tuple, list)): @@ -1665,10 +1787,10 @@ def frames_per_batch_worker(self): raise NotImplementedError def update_policy_weights_(self, policy_weights=None) -> None: + if isinstance(policy_weights, TensorDictParams): + policy_weights = policy_weights.data for _device in self._policy_weights_dict: if policy_weights is not None: - if isinstance(policy_weights, TensorDictParams): - policy_weights = policy_weights.data self._policy_weights_dict[_device].data.update_(policy_weights) elif self._get_weights_fn_dict[_device] is not None: original_weights = self._get_weights_fn_dict[_device]() @@ -1694,6 +1816,8 @@ def _run_processes(self) -> None: queue_out = mp.Queue(self._queue_len) # sends data from proc to main self.procs = [] self.pipes = [] + self._traj_pool = _TrajectoryPool(lock=True) + for i, (env_fun, env_fun_kwargs) in enumerate( zip(self.create_env_fn, self.create_env_kwargs) ): @@ -1708,9 +1832,12 @@ def _run_processes(self) -> None: storing_device = self.storing_device[i] env_device = self.env_device[i] policy = self.policy - with self._policy_weights_dict[policy_device].to_module( - policy - ) if policy is not None else contextlib.nullcontext(): + policy_weights = self._policy_weights_dict[policy_device] + if policy is not None and policy_weights is not None: + cm = policy_weights.to_module(policy) + else: + cm = contextlib.nullcontext() + with cm: kwargs = { "pipe_parent": pipe_parent, "pipe_child": pipe_child, @@ -1730,6 +1857,16 @@ def _run_processes(self) -> None: "interruptor": self.interruptor, "set_truncated": self.set_truncated, "use_buffers": self._use_buffers, + "replay_buffer": self.replay_buffer, + "replay_buffer_chunk": self.replay_buffer_chunk, + "traj_pool": self._traj_pool, + "trust_policy": self.trust_policy, + "compile_policy": self.compiled_policy_kwargs + if self.compiled_policy + else False, + "cudagraph_policy": self.cudagraphed_policy_kwargs + if self.cudagraphed_policy + else False, } proc 
= _ProcessNoWarn( target=_main_async_collector, @@ -1754,6 +1891,7 @@ def _run_processes(self) -> None: self.procs.append(proc) self.pipes.append(pipe_parent) for pipe_parent in self.pipes: + pipe_parent.poll(timeout=INSTANTIATE_TIMEOUT) msg = pipe_parent.recv() if msg != "instantiated": raise RuntimeError(msg) @@ -2069,29 +2207,12 @@ def iterator(self) -> Iterator[TensorDictBase]: cat_results = self.cat_results if cat_results is None: cat_results = "stack" - warnings.warn( - f"`cat_results` was not specified in the constructor of {type(self).__name__}. " - f"For MultiSyncDataCollector, `cat_results` indicates how the data should " - f"be packed: the preferred option and current default is `cat_results='stack'` " - f"which provides the best interoperability across torchrl components. " - f"Other accepted values are `cat_results=0` (previous behaviour) and " - f"`cat_results=-1` (cat along time dimension). Among these two, the latter " - f"should be preferred for consistency across environment configurations. " - f"Currently, the default value is `'stack'`." - f"From v0.6 onward, this warning will be removed. " - f"To suppress this warning, set `cat_results` to the desired value.", - category=DeprecationWarning, - ) self.buffers = {} dones = [False for _ in range(self.num_workers)] workers_frames = [0 for _ in range(self.num_workers)] same_device = None self.out_buffer = None - last_traj_ids = [-10 for _ in range(self.num_workers)] - last_traj_ids_subs = [None for _ in range(self.num_workers)] - traj_max = -1 - traj_ids_list = [None for _ in range(self.num_workers)] preempt = self.interruptor is not None and self.preemptive_threshold < 1.0 while not all(dones) and self._frames < self.total_frames: @@ -2125,7 +2246,13 @@ def iterator(self) -> Iterator[TensorDictBase]: for _ in range(self.num_workers): new_data, j = self.queue_out.get() use_buffers = self._use_buffers - if j == 0 or not use_buffers: + if self.replay_buffer is not None: + idx = new_data + workers_frames[idx] = ( + workers_frames[idx] + self.frames_per_batch_worker + ) + continue + elif j == 0 or not use_buffers: try: data, idx = new_data self.buffers[idx] = data @@ -2167,51 +2294,25 @@ def iterator(self) -> Iterator[TensorDictBase]: if workers_frames[idx] >= self.total_frames: dones[idx] = True + if self.replay_buffer is not None: + yield + self._frames += self.frames_per_batch_worker * self.num_workers + continue + # we have to correct the traj_ids to make sure that they don't overlap # We can count the number of frames collected for free in this loop n_collected = 0 for idx in range(self.num_workers): buffer = buffers[idx] traj_ids = buffer.get(("collector", "traj_ids")) - is_last = traj_ids == last_traj_ids[idx] - # If we `cat` interrupted data, we have already filtered out - # non-valid steps. If we stack, we haven't. 
-                    if preempt and cat_results == "stack":
-                        valid = buffer.get(("collector", "traj_ids")) != -1
-                        if valid.ndim > 2:
-                            valid = valid.flatten(0, -2)
-                        if valid.ndim == 2:
-                            valid = valid.any(0)
-                        last_traj_ids[idx] = traj_ids[..., valid][..., -1:].clone()
-                    else:
-                        last_traj_ids[idx] = traj_ids[..., -1:].clone()
-                    if not is_last.all():
-                        traj_to_correct = traj_ids[~is_last]
-                        traj_to_correct = (
-                            traj_to_correct + (traj_max + 1) - traj_to_correct.min()
-                        )
-                        traj_ids = traj_ids.masked_scatter(~is_last, traj_to_correct)
-                    # is_last can only be true if we're after the first iteration
-                    if is_last.any():
-                        traj_ids = torch.where(
-                            is_last, last_traj_ids_subs[idx].expand_as(traj_ids), traj_ids
-                        )
-
                 if preempt:
                     if cat_results == "stack":
                         mask_frames = buffer.get(("collector", "traj_ids")) != -1
-                        traj_ids = torch.where(mask_frames, traj_ids, -1)
                         n_collected += mask_frames.sum().cpu()
-                        last_traj_ids_subs[idx] = traj_ids[..., valid][..., -1:].clone()
                     else:
-                        last_traj_ids_subs[idx] = traj_ids[..., -1:].clone()
                         n_collected += traj_ids.numel()
                 else:
-                    last_traj_ids_subs[idx] = traj_ids[..., -1:].clone()
                     n_collected += traj_ids.numel()
-                traj_ids_list[idx] = traj_ids
-
-                traj_max = max(traj_max, traj_ids.max())

             if same_device is None:
                 prev_device = None
@@ -2232,9 +2333,6 @@ def iterator(self) -> Iterator[TensorDictBase]:
                     self.out_buffer = stack(
                         [item.cpu() for item in buffers.values()], 0
                     )
-                    self.out_buffer.set(
-                        ("collector", "traj_ids"), torch.stack(traj_ids_list), inplace=True
-                    )
                 else:
                     if self._use_buffers is None:
                         torchrl_logger.warning(
@@ -2251,9 +2349,6 @@ def iterator(self) -> Iterator[TensorDictBase]:
                         self.out_buffer = torch.cat(
                             [item.cpu() for item in buffers.values()], cat_results
                         )
-                        self.out_buffer.set_(
-                            ("collector", "traj_ids"), torch.cat(traj_ids_list, cat_results)
-                        )
                 except RuntimeError as err:
                     if (
                         preempt
@@ -2397,7 +2492,7 @@ class MultiaSyncDataCollector(_MultiDataCollector):

     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.out_tensordicts = {}
+        self.out_tensordicts = defaultdict(lambda: None)
         self.running = False

         if self.postprocs is not None:
@@ -2442,7 +2537,9 @@ def frames_per_batch_worker(self):
     def _get_from_queue(self, timeout=None) -> Tuple[int, int, TensorDictBase]:
         new_data, j = self.queue_out.get(timeout=timeout)
         use_buffers = self._use_buffers
-        if j == 0 or not use_buffers:
+        if self.replay_buffer is not None:
+            idx = new_data
+        elif j == 0 or not use_buffers:
             try:
                 data, idx = new_data
                 self.out_tensordicts[idx] = data
@@ -2457,7 +2554,7 @@ def _get_from_queue(self, timeout=None) -> Tuple[int, int, TensorDictBase]:
         else:
             idx = new_data
         out = self.out_tensordicts[idx]
-        if j == 0 or use_buffers:
+        if not self.replay_buffer and (j == 0 or use_buffers):
             # we clone the data to make sure that we'll be working with a fixed copy
             out = out.clone()
         return idx, j, out
@@ -2479,12 +2576,19 @@ def iterator(self) -> Iterator[TensorDictBase]:
         workers_frames = [0 for _ in range(self.num_workers)]
         while self._frames < self.total_frames:
-            _check_for_faulty_process(self.procs)
             self._iter += 1
-            idx, j, out = self._get_from_queue()
-            worker_frames = out.numel()
-            if self.split_trajs:
-                out = split_trajectories(out, prefix="collector")
+            while True:
+                try:
+                    idx, j, out = self._get_from_queue(timeout=10.0)
+                    break
+                except (queue.Empty, TimeoutError):
+                    # mp.Queue.get signals a timeout by raising queue.Empty
+                    _check_for_faulty_process(self.procs)
+            if self.replay_buffer is None:
+                worker_frames = out.numel()
+                if self.split_trajs:
+                    out = split_trajectories(out, prefix="collector")
+            else:
+                worker_frames =
self.frames_per_batch_worker self._frames += worker_frames workers_frames[idx] = workers_frames[idx] + worker_frames if self.postprocs: @@ -2500,7 +2604,7 @@ def iterator(self) -> Iterator[TensorDictBase]: else: msg = "continue" self.pipes[idx].send((idx, msg)) - if self._exclude_private_keys: + if out is not None and self._exclude_private_keys: excluded_keys = [key for key in out.keys() if key.startswith("_")] out = out.exclude(*excluded_keys) yield out @@ -2687,7 +2791,6 @@ def __init__( postproc: Optional[Callable[[TensorDictBase], TensorDictBase]] = None, split_trajs: Optional[bool] = None, exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE, - exploration_mode=None, reset_when_done: bool = True, update_at_each_batch: bool = False, preemptive_threshold: float = None, @@ -2712,7 +2815,6 @@ def __init__( env_device=env_device, storing_device=storing_device, exploration_type=exploration_type, - exploration_mode=exploration_mode, reset_when_done=reset_when_done, update_at_each_batch=update_at_each_batch, preemptive_threshold=preemptive_threshold, @@ -2762,6 +2864,12 @@ def _main_async_collector( interruptor=None, set_truncated: bool = False, use_buffers: bool | None = None, + replay_buffer: ReplayBuffer | None = None, + replay_buffer_chunk: bool = True, + traj_pool: _TrajectoryPool = None, + trust_policy: bool = False, + compile_policy: bool = False, + cudagraph_policy: bool = False, ) -> None: pipe_parent.close() # init variables that will be cleared when closing @@ -2782,10 +2890,15 @@ def _main_async_collector( env_device=env_device, exploration_type=exploration_type, reset_when_done=reset_when_done, - return_same_td=True, + return_same_td=replay_buffer is None, interruptor=interruptor, set_truncated=set_truncated, use_buffers=use_buffers, + replay_buffer=replay_buffer if replay_buffer_chunk else None, + traj_pool=traj_pool, + trust_policy=trust_policy, + compile_policy=compile_policy, + cudagraph_policy=cudagraph_policy, ) use_buffers = inner_collector._use_buffers if verbose: @@ -2848,6 +2961,25 @@ def _main_async_collector( # In that case, we skip the collected trajectory and get the message from main. This is faster than # sending the trajectory in the queue until timeout when it's never going to be received. continue + + if replay_buffer is not None: + if not replay_buffer_chunk: + next_data.names = None + replay_buffer.extend(next_data) + + try: + queue_out.put((idx, j), timeout=_TIMEOUT) + if verbose: + torchrl_logger.info(f"worker {idx} successfully sent data") + j += 1 + has_timed_out = False + continue + except queue.Full: + if verbose: + torchrl_logger.info(f"worker {idx} has timed out") + has_timed_out = True + continue + if j == 0 or not use_buffers: collected_tensordict = next_data if ( @@ -2862,10 +2994,16 @@ def _main_async_collector( # if policy is on cuda and env on cuda, we are fine with this # If policy is on cuda and env on cpu (or opposite) we put tensors that # are on cpu in shared mem. + MPS_ERROR = ( + "tensors on mps device cannot be put in shared memory. Make sure " + "the shared device (aka storing_device) is set to CPU." 
+                )
                if collected_tensordict.device is not None:
-                    # placehoder in case we need different behaviours
-                    if collected_tensordict.device.type in ("cpu", "mps"):
+                    # placeholder in case we need different behaviors
+                    if collected_tensordict.device.type in ("cpu",):
                        collected_tensordict.share_memory_()
+                    elif collected_tensordict.device.type in ("mps",):
+                        raise RuntimeError(MPS_ERROR)
                    elif collected_tensordict.device.type == "cuda":
                        collected_tensordict.share_memory_()
                    else:
@@ -2874,11 +3012,13 @@
                        )
                else:
                    # make sure each cpu tensor is shared - assuming non-cpu devices are shared
-                    collected_tensordict.apply(
-                        lambda x: x.share_memory_()
-                        if x.device.type in ("cpu", "mps")
-                        else x
-                    )
+                    def cast_tensor(x, MPS_ERROR=MPS_ERROR):
+                        if x.device.type in ("cpu",):
+                            x.share_memory_()
+                        if x.device.type in ("mps",):
+                            raise RuntimeError(MPS_ERROR)
+
+                    collected_tensordict.apply(cast_tensor, filter_empty=True)
                data = (collected_tensordict, idx)
            else:
                if next_data is not collected_tensordict:
@@ -2956,3 +3096,20 @@ def _make_meta_params(param):
    if is_param:
        pd = nn.Parameter(pd, requires_grad=False)
    return pd
+
+
+class _TrajectoryPool:
+    # A counter kept in shared memory that hands out globally unique trajectory
+    # ids, optionally guarded by a multiprocessing lock.
+    def __init__(self, ctx=None, lock: bool = False):
+        self.ctx = ctx
+        self._traj_id = torch.zeros((), device="cpu", dtype=torch.int).share_memory_()
+        if ctx is None:
+            self.lock = contextlib.nullcontext() if not lock else mp.RLock()
+        else:
+            self.lock = contextlib.nullcontext() if not lock else ctx.RLock()
+
+    def get_traj_and_increment(self, n=1, device=None):
+        with self.lock:
+            v = self._traj_id.item()
+            out = torch.arange(v, v + n).to(device)
+            self._traj_id.copy_(1 + out[-1].item())
+        return out
diff --git a/torchrl/collectors/distributed/generic.py b/torchrl/collectors/distributed/generic.py
index 596c1f5d191..729b8a48171 100644
--- a/torchrl/collectors/distributed/generic.py
+++ b/torchrl/collectors/distributed/generic.py
@@ -30,11 +30,10 @@
    MAX_TIME_TO_CONNECT,
    TCP_PORT,
 )
-from torchrl.collectors.utils import split_trajectories
+from torchrl.collectors.utils import _NON_NN_POLICY_WEIGHTS, split_trajectories
 from torchrl.data.utils import CloudpickleWrapper
 from torchrl.envs.common import EnvBase
 from torchrl.envs.env_creator import EnvCreator
-from torchrl.envs.utils import _convert_exploration_type

 SUBMITIT_ERR = None
 try:
@@ -172,18 +171,11 @@ def _run_collector(
    )

    if isinstance(policy, nn.Module):
-        policy_weights = TensorDict(dict(policy.named_parameters()), [])
-        # TODO: Do we want this?
- # updates the policy weights to avoid them to be shared - if all( - param.device == torch.device("cpu") for param in policy_weights.values() - ): - policy = deepcopy(policy) - policy_weights = TensorDict(dict(policy.named_parameters()), []) - - policy_weights = policy_weights.apply(lambda x: x.data) + policy_weights = TensorDict.from_module(policy) + policy_weights = policy_weights.data.lock_() else: - policy_weights = TensorDict({}, []) + warnings.warn(_NON_NN_POLICY_WEIGHTS) + policy_weights = TensorDict(lock=True) collector = collector_class( env_make, @@ -426,7 +418,6 @@ def __init__( postproc: Callable | None = None, split_trajs: bool = False, exploration_type: "ExporationType" = DEFAULT_EXPLORATION_TYPE, # noqa - exploration_mode: str = None, collector_class: Type = SyncDataCollector, collector_kwargs: dict = None, num_workers_per_collector: int = 1, @@ -438,9 +429,6 @@ def __init__( launcher: str = "submitit", tcp_port: int = None, ): - exploration_type = _convert_exploration_type( - exploration_mode=exploration_mode, exploration_type=exploration_type - ) if collector_class == "async": collector_class = MultiaSyncDataCollector @@ -452,10 +440,11 @@ def __init__( self.env_constructors = create_env_fn self.policy = policy if isinstance(policy, nn.Module): - policy_weights = TensorDict(dict(policy.named_parameters()), []) - policy_weights = policy_weights.apply(lambda x: x.data) + policy_weights = TensorDict.from_module(policy) + policy_weights = policy_weights.data.lock_() else: - policy_weights = TensorDict({}, []) + warnings.warn(_NON_NN_POLICY_WEIGHTS) + policy_weights = TensorDict(lock=True) self.policy_weights = policy_weights self.num_workers = len(create_env_fn) self.frames_per_batch = frames_per_batch @@ -820,6 +809,8 @@ def _iterator_dist(self): for i in range(self.num_workers): rank = i + 1 + if self._VERBOSE: + torchrl_logger.info(f"shutting down rank {rank}.") self._store.set(f"NODE_{rank}_in", b"shutdown") def _next_sync(self, total_frames): diff --git a/torchrl/collectors/distributed/ray.py b/torchrl/collectors/distributed/ray.py index 79b3ee9063c..1f088c2c404 100644 --- a/torchrl/collectors/distributed/ray.py +++ b/torchrl/collectors/distributed/ray.py @@ -20,7 +20,7 @@ MultiSyncDataCollector, SyncDataCollector, ) -from torchrl.collectors.utils import split_trajectories +from torchrl.collectors.utils import _NON_NN_POLICY_WEIGHTS, split_trajectories from torchrl.envs.common import EnvBase from torchrl.envs.env_creator import EnvCreator @@ -99,7 +99,7 @@ class RayCollector(DataCollectorBase): The class dictionary input parameter "ray_init_config" can be used to provide the kwargs to call Ray initialization method ray.init(). If "ray_init_config" is not provided, the default - behaviour is to autodetect an existing Ray cluster or start a new Ray instance locally if no + behavior is to autodetect an existing Ray cluster or start a new Ray instance locally if no existing cluster is found. Refer to Ray documentation for advanced initialization kwargs. 
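For instance, a construction sketch (the env maker and policy are placeholders, and the kwargs shown are plain ray.init() arguments rather than anything defined by this patch):

>>> from torchrl.collectors.distributed import RayCollector
>>> collector = RayCollector(
...     [env_maker] * 2,  # one remote collector per env constructor
...     policy,
...     frames_per_batch=50,
...     total_frames=200,
...     ray_init_config={"num_cpus": 4},  # forwarded to ray.init()
... )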
Similarly, dictionary input parameter "remote_configs" can be used to specify the kwargs for @@ -401,9 +401,11 @@ def check_list_length_consistency(*lists): self._local_policy = policy if isinstance(self._local_policy, nn.Module): - policy_weights = TensorDict(dict(policy.named_parameters()), []) + policy_weights = TensorDict.from_module(self._local_policy) + policy_weights = policy_weights.data.lock_() else: - policy_weights = TensorDict({}, []) + warnings.warn(_NON_NN_POLICY_WEIGHTS) + policy_weights = TensorDict(lock=True) self.policy_weights = policy_weights self.collector_class = collector_class self.collected_frames = 0 diff --git a/torchrl/collectors/distributed/rpc.py b/torchrl/collectors/distributed/rpc.py index b6c324bb7b5..73247df4b0c 100644 --- a/torchrl/collectors/distributed/rpc.py +++ b/torchrl/collectors/distributed/rpc.py @@ -22,9 +22,8 @@ IDLE_TIMEOUT, TCP_PORT, ) -from torchrl.collectors.utils import split_trajectories +from torchrl.collectors.utils import _NON_NN_POLICY_WEIGHTS, split_trajectories from torchrl.data.utils import CloudpickleWrapper -from torchrl.envs.utils import _convert_exploration_type SUBMITIT_ERR = None try: @@ -275,7 +274,6 @@ def __init__( postproc: Callable | None = None, split_trajs: bool = False, exploration_type: "ExporationType" = DEFAULT_EXPLORATION_TYPE, # noqa - exploration_mode: str = None, collector_class=SyncDataCollector, collector_kwargs=None, num_workers_per_collector=1, @@ -288,9 +286,6 @@ def __init__( visible_devices=None, tensorpipe_options=None, ): - exploration_type = _convert_exploration_type( - exploration_mode=exploration_mode, exploration_type=exploration_type - ) if collector_class == "async": collector_class = MultiaSyncDataCollector elif collector_class == "sync": @@ -301,9 +296,11 @@ def __init__( self.env_constructors = create_env_fn self.policy = policy if isinstance(policy, nn.Module): - policy_weights = TensorDict(dict(policy.named_parameters()), []) + policy_weights = TensorDict.from_module(policy) + policy_weights = policy_weights.data.lock_() else: - policy_weights = TensorDict({}, []) + warnings.warn(_NON_NN_POLICY_WEIGHTS) + policy_weights = TensorDict(lock=True) self.policy_weights = policy_weights self.num_workers = len(create_env_fn) self.frames_per_batch = frames_per_batch diff --git a/torchrl/collectors/distributed/sync.py b/torchrl/collectors/distributed/sync.py index 6f959086c83..481fb70cc31 100644 --- a/torchrl/collectors/distributed/sync.py +++ b/torchrl/collectors/distributed/sync.py @@ -8,6 +8,7 @@ import os import socket +import warnings from copy import copy, deepcopy from datetime import timedelta from typing import Callable, List, OrderedDict @@ -29,11 +30,10 @@ DEFAULT_SLURM_CONF, MAX_TIME_TO_CONNECT, ) -from torchrl.collectors.utils import split_trajectories +from torchrl.collectors.utils import _NON_NN_POLICY_WEIGHTS, split_trajectories from torchrl.data.utils import CloudpickleWrapper from torchrl.envs.common import EnvBase from torchrl.envs.env_creator import EnvCreator -from torchrl.envs.utils import _convert_exploration_type SUBMITIT_ERR = None try: @@ -78,18 +78,11 @@ def _distributed_init_collection_node( ) if isinstance(policy, nn.Module): - policy_weights = TensorDict(dict(policy.named_parameters()), []) - # TODO: Do we want this? 
- # updates the policy weights to avoid them to be shared - if all( - param.device == torch.device("cpu") for param in policy_weights.values() - ): - policy = deepcopy(policy) - policy_weights = TensorDict(dict(policy.named_parameters()), []) - - policy_weights = policy_weights.apply(lambda x: x.data) + policy_weights = TensorDict.from_module(policy) + policy_weights = policy_weights.data.lock_() else: - policy_weights = TensorDict({}, []) + warnings.warn(_NON_NN_POLICY_WEIGHTS) + policy_weights = TensorDict(lock=True) collector = collector_class( env_make, @@ -291,7 +284,6 @@ def __init__( postproc: Callable | None = None, split_trajs: bool = False, exploration_type: "ExporationType" = DEFAULT_EXPLORATION_TYPE, # noqa - exploration_mode: str = None, collector_class=SyncDataCollector, collector_kwargs=None, num_workers_per_collector=1, @@ -302,9 +294,6 @@ def __init__( launcher="submitit", tcp_port=None, ): - exploration_type = _convert_exploration_type( - exploration_mode=exploration_mode, exploration_type=exploration_type - ) if collector_class == "async": collector_class = MultiaSyncDataCollector @@ -315,11 +304,14 @@ def __init__( self.collector_class = collector_class self.env_constructors = create_env_fn self.policy = policy + if isinstance(policy, nn.Module): - policy_weights = TensorDict(dict(policy.named_parameters()), []) - policy_weights = policy_weights.apply(lambda x: x.data) + policy_weights = TensorDict.from_module(policy) + policy_weights = policy_weights.data.lock_() else: - policy_weights = TensorDict({}, []) + warnings.warn(_NON_NN_POLICY_WEIGHTS) + policy_weights = TensorDict(lock=True) + self.policy_weights = policy_weights self.num_workers = len(create_env_fn) self.frames_per_batch = frames_per_batch diff --git a/torchrl/collectors/distributed/utils.py b/torchrl/collectors/distributed/utils.py index aeee573f8dc..2dd6fcf6c93 100644 --- a/torchrl/collectors/distributed/utils.py +++ b/torchrl/collectors/distributed/utils.py @@ -53,10 +53,10 @@ class submitit_delayed_launcher: ... def main(): ... from torchrl.envs.utils import RandomPolicy from torchrl.envs.libs.gym import GymEnv - ... from torchrl.data import BoundedTensorSpec + ... from torchrl.data import BoundedContinuous ... collector = DistributedDataCollector( ... [EnvCreator(lambda: GymEnv("Pendulum-v1"))] * num_jobs, - ... policy=RandomPolicy(BoundedTensorSpec(-1, 1, shape=(1,))), + ... policy=RandomPolicy(BoundedContinuous(-1, 1, shape=(1,))), ... launcher="submitit_delayed", ... ) ... for data in collector: diff --git a/torchrl/collectors/utils.py b/torchrl/collectors/utils.py index d777da3de2a..74bea267c22 100644 --- a/torchrl/collectors/utils.py +++ b/torchrl/collectors/utils.py @@ -11,6 +11,12 @@ from tensordict import NestedKey, pad, set_lazy_legacy, TensorDictBase +_NON_NN_POLICY_WEIGHTS = ( + "The policy is not an nn.Module. TorchRL will assume that the parameter set is empty and " + "update_policy_weights_ will be a no-op." 
+)
+
+
 def _stack_output(fun) -> Callable:
     def stacked_output_fun(*args, **kwargs):
         out = fun(*args, **kwargs)
diff --git a/torchrl/data/__init__.py b/torchrl/data/__init__.py
index c894c15724b..026a0b3baf2 100644
--- a/torchrl/data/__init__.py
+++ b/torchrl/data/__init__.py
@@ -59,19 +59,32 @@
     TokenizedDatasetLoader,
 )
 from .tensor_specs import (
+    Binary,
     BinaryDiscreteTensorSpec,
+    Bounded,
     BoundedTensorSpec,
+    Categorical,
+    Composite,
     CompositeSpec,
     DEVICE_TYPING,
     DiscreteTensorSpec,
     LazyStackedCompositeSpec,
     LazyStackedTensorSpec,
+    MultiCategorical,
     MultiDiscreteTensorSpec,
+    MultiOneHot,
     MultiOneHotDiscreteTensorSpec,
+    NonTensor,
     NonTensorSpec,
+    OneHot,
     OneHotDiscreteTensorSpec,
+    Stacked,
+    StackedComposite,
     TensorSpec,
+    Unbounded,
+    UnboundedContinuous,
     UnboundedContinuousTensorSpec,
+    UnboundedDiscrete,
     UnboundedDiscreteTensorSpec,
 )
 from .utils import check_no_exclusive_keys, consolidate_spec, contains_lazy_spec
diff --git a/torchrl/data/datasets/minari_data.py b/torchrl/data/datasets/minari_data.py
index 3cc8d7437c0..d6a49f17113 100644
--- a/torchrl/data/datasets/minari_data.py
+++ b/torchrl/data/datasets/minari_data.py
@@ -25,12 +25,7 @@
 from torchrl.data.replay_buffers.samplers import Sampler
 from torchrl.data.replay_buffers.storages import TensorStorage
 from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
-from torchrl.data.tensor_specs import (
-    BoundedTensorSpec,
-    CompositeSpec,
-    DiscreteTensorSpec,
-    UnboundedContinuousTensorSpec,
-)
+from torchrl.data.tensor_specs import Bounded, Categorical, Composite, Unbounded
 from torchrl.envs.utils import _classproperty

 _has_tqdm = importlib.util.find_spec("tqdm", None) is not None
@@ -398,24 +393,22 @@ def _proc_spec(spec):
     if spec is None:
         return
     if spec["type"] == "Dict":
-        return CompositeSpec(
+        return Composite(
             {key: _proc_spec(subspec) for key, subspec in spec["subspaces"].items()}
         )
     elif spec["type"] == "Box":
         if all(item == -float("inf") for item in spec["low"]) and all(
             item == float("inf") for item in spec["high"]
         ):
-            return UnboundedContinuousTensorSpec(
-                spec["shape"], dtype=_DTYPE_DIR[spec["dtype"]]
-            )
-        return BoundedTensorSpec(
+            return Unbounded(spec["shape"], dtype=_DTYPE_DIR[spec["dtype"]])
+        return Bounded(
             shape=spec["shape"],
             low=torch.as_tensor(spec["low"]),
             high=torch.as_tensor(spec["high"]),
             dtype=_DTYPE_DIR[spec["dtype"]],
         )
     elif spec["type"] == "Discrete":
-        return DiscreteTensorSpec(
+        return Categorical(
             spec["n"], shape=spec["shape"], dtype=_DTYPE_DIR[spec["dtype"]]
         )
     else:
diff --git a/torchrl/data/datasets/openx.py b/torchrl/data/datasets/openx.py
index 975384a3662..2dbf0720a37 100644
--- a/torchrl/data/datasets/openx.py
+++ b/torchrl/data/datasets/openx.py
@@ -77,7 +77,7 @@ class for more information on how to interact with non-tensor data
            shuffle=False will also impact the sampling. We advise users to create
            a copy of the dataset where the ``shuffle`` attribute of the sampler is
            set to ``False`` if they wish to enjoy the two different
-            behaviours (shuffled and not) within the same code base.
+            behaviors (shuffled and not) within the same code base.

        num_slices (int, optional): the number of slices in a batch. This
            corresponds to the number of trajectories present in a batch.
@@ -134,7 +134,7 @@ class for more information on how to interact with non-tensor data
            the dataset.
            This isn't possible at a reasonable cost with `streaming=True`: in this case, trajectories
            will be sampled one at a time and delivered as such (with cropping to comply with
-            the batch-size etc). The behaviour of the two modalities is
+            the batch-size etc). The behavior of the two modalities is
            much more similar when `num_slices` and `slice_len` are specified, as in these
            cases, views of sub-episodes will be returned in both cases.
diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py
index a688bd8585e..2e0eeb80705 100644
--- a/torchrl/data/replay_buffers/replay_buffers.py
+++ b/torchrl/data/replay_buffers/replay_buffers.py
@@ -7,6 +7,7 @@
 import collections
 import contextlib
 import json
+import multiprocessing
 import textwrap
 import threading
 import warnings
@@ -23,6 +24,7 @@
     is_tensorclass,
     LazyStackedTensorDict,
     NestedKey,
+    TensorDict,
     TensorDictBase,
     unravel_key,
 )
@@ -120,7 +122,15 @@ class ReplayBuffer:
             >>> for d in data.unbind(1):
             ...     rb.add(d)
             >>> rb.extend(data)
+        generator (torch.Generator, optional): a generator to use for sampling.
+            Using a dedicated generator for the replay buffer can allow fine-grained control
+            over seeding, for instance keeping the global seed different but the RB seed identical
+            for distributed jobs.
+            Defaults to ``None`` (global default generator).
+
+            .. warning:: As of now, the generator has no effect on the transforms.
+        shared (bool, optional): whether the buffer will be shared using multiprocessing or not.
+            Defaults to ``False``.

     Examples:
         >>> import torch
@@ -204,6 +214,8 @@ def __init__(
         batch_size: int | None = None,
         dim_extend: int | None = None,
         checkpointer: "StorageCheckpointerBase" | None = None,  # noqa: F821
+        generator: torch.Generator | None = None,
+        shared: bool = False,
     ) -> None:
         self._storage = storage if storage is not None else ListStorage(max_size=1_000)
         self._storage.attach(self)
@@ -220,6 +232,9 @@ def __init__(
         if self._prefetch_cap:
             self._prefetch_executor = ThreadPoolExecutor(max_workers=self._prefetch_cap)

+        self.shared = shared
+        self.share(self.shared)
+
         self._replay_lock = threading.RLock()
         self._futures_lock = threading.RLock()
         from torchrl.envs.transforms.transforms import (
@@ -262,6 +277,20 @@ def __init__(
                 raise ValueError("dim_extend must be a positive value.")
             self.dim_extend = dim_extend
         self._storage.checkpointer = checkpointer
+        self.set_rng(generator=generator)
+
+    def share(self, shared: bool = True):
+        self.shared = shared
+        if self.shared:
+            self._write_lock = multiprocessing.Lock()
+        else:
+            self._write_lock = contextlib.nullcontext()
+
+    def set_rng(self, generator):
+        self._rng = generator
+        self._storage._rng = generator
+        self._sampler._rng = generator
+        self._writer._rng = generator

     @property
     def dim_extend(self):
@@ -335,6 +364,11 @@ def __len__(self) -> int:
         with self._replay_lock:
             return len(self._storage)

+    @property
+    def write_count(self):
+        """The total number of items written so far in the buffer through add and extend."""
+        return self._writer._write_count
+
     def __repr__(self) -> str:
         from torchrl.envs.transforms import Compose

@@ -414,6 +448,9 @@ def state_dict(self) -> Dict[str, Any]:
             "_writer": self._writer.state_dict(),
             "_transforms": self._transform.state_dict(),
             "_batch_size": self._batch_size,
+            "_rng": (self._rng.get_state().clone(), str(self._rng.device))
+            if self._rng is not None
+            else None,
         }

     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
@@ -422,6 +459,12 @@ def load_state_dict(self, state_dict: Dict[str,
Any]) -> None: self._writer.load_state_dict(state_dict["_writer"]) self._transform.load_state_dict(state_dict["_transforms"]) self._batch_size = state_dict["_batch_size"] + rng = state_dict.get("_rng") + if rng is not None: + state, device = rng + rng = torch.Generator(device=device) + rng.set_state(state) + self.set_rng(generator=rng) def dumps(self, path): """Saves the replay buffer on disk at the specified path. @@ -465,6 +508,13 @@ def dumps(self, path): self._storage.dumps(path / "storage") self._sampler.dumps(path / "sampler") self._writer.dumps(path / "writer") + if self._rng is not None: + rng_state = TensorDict( + rng_state=self._rng.get_state().clone(), + device=self._rng.device, + ) + rng_state.memmap(path / "rng_state") + # fall back on state_dict for transforms transform_sd = self._transform.state_dict() if transform_sd: @@ -487,6 +537,11 @@ def loads(self, path): self._storage.loads(path / "storage") self._sampler.loads(path / "sampler") self._writer.loads(path / "writer") + if (path / "rng_state").exists(): + rng_state = TensorDict.load_memmap(path / "rng_state") + rng = torch.Generator(device=rng_state.device) + rng.set_state(rng_state["rng_state"]) + self.set_rng(rng) # fall back on state_dict for transforms if (path / "transform.t").exists(): self._transform.load_state_dict(torch.load(path / "transform.t")) @@ -540,13 +595,13 @@ def add(self, data: Any) -> int: return self._add(data) def _add(self, data): - with self._replay_lock: + with self._replay_lock, self._write_lock: index = self._writer.add(data) self._sampler.add(index) return index def _extend(self, data: Sequence) -> torch.Tensor: - with self._replay_lock: + with self._replay_lock, self._write_lock: if self.dim_extend > 0: data = self._transpose(data) index = self._writer.extend(data) @@ -594,7 +649,7 @@ def update_priority( if self.dim_extend > 0 and priority.ndim > 1: priority = self._transpose(priority).flatten() # priority = priority.flatten() - with self._replay_lock: + with self._replay_lock, self._write_lock: self._sampler.update_priority(index, priority, storage=self.storage) @pin_memory_output @@ -753,6 +808,12 @@ def __iter__(self): def __getstate__(self) -> Dict[str, Any]: state = self.__dict__.copy() + if self._rng is not None: + rng_state = TensorDict( + rng_state=self._rng.get_state().clone(), + device=self._rng.device, + ) + state["_rng"] = rng_state _replay_lock = state.pop("_replay_lock", None) _futures_lock = state.pop("_futures_lock", None) if _replay_lock is not None: @@ -762,6 +823,13 @@ def __getstate__(self) -> Dict[str, Any]: return state def __setstate__(self, state: Dict[str, Any]): + rngstate = None + if "_rng" in state: + rngstate = state["_rng"] + if rngstate is not None: + rng = torch.Generator(device=rngstate.device) + rng.set_state(rngstate["rng_state"]) + if "_replay_lock_placeholder" in state: state.pop("_replay_lock_placeholder") _replay_lock = threading.RLock() @@ -771,6 +839,8 @@ def __setstate__(self, state: Dict[str, Any]): _futures_lock = threading.RLock() state["_futures_lock"] = _futures_lock self.__dict__.update(state) + if rngstate is not None: + self.set_rng(rng) @property def sampler(self): @@ -995,6 +1065,15 @@ class TensorDictReplayBuffer(ReplayBuffer): >>> for d in data.unbind(1): ... rb.add(d) >>> rb.extend(data) + generator (torch.Generator, optional): a generator to use for sampling. 
+            Using a dedicated generator for the replay buffer can allow fine-grained control
+            over seeding, for instance keeping the global seed different but the RB seed identical
+            for distributed jobs.
+            Defaults to ``None`` (global default generator).
+
+            .. warning:: As of now, the generator has no effect on the transforms.
+        shared (bool, optional): whether the buffer will be shared using multiprocessing or not.
+            Defaults to ``False``.

     Examples:
         >>> import torch
@@ -1207,7 +1286,7 @@ def sample(
         if include_info is not None:
             warnings.warn(
                 "include_info is going to be deprecated soon."
-                "The default behaviour has changed to `include_info=True` "
+                "The default behavior has changed to `include_info=True` "
                 "to avoid bugs linked to wrongly preassigned values in the "
                 "output tensordict."
             )
@@ -1327,6 +1406,15 @@ class TensorDictPrioritizedReplayBuffer(TensorDictReplayBuffer):
             >>> for d in data.unbind(1):
             ...     rb.add(d)
             >>> rb.extend(data)
+        generator (torch.Generator, optional): a generator to use for sampling.
+            Using a dedicated generator for the replay buffer can allow fine-grained control
+            over seeding, for instance keeping the global seed different but the RB seed identical
+            for distributed jobs.
+            Defaults to ``None`` (global default generator).
+
+            .. warning:: As of now, the generator has no effect on the transforms.
+        shared (bool, optional): whether the buffer will be shared using multiprocessing or not.
+            Defaults to ``False``.

     Examples:
         >>> import torch
@@ -1400,6 +1488,8 @@ def __init__(
         reduction: str = "max",
         batch_size: int | None = None,
         dim_extend: int | None = None,
+        generator: torch.Generator | None = None,
+        shared: bool = False,
     ) -> None:
         if storage is None:
             storage = ListStorage(max_size=1_000)
@@ -1416,6 +1506,8 @@ def __init__(
             transform=transform,
             batch_size=batch_size,
             dim_extend=dim_extend,
+            generator=generator,
+            shared=shared,
         )


@@ -1454,12 +1546,18 @@ def update_tensordict_priority(self, data: TensorDictBase) -> None:
 class InPlaceSampler:
     """A sampler to write tensordicts in-place.

-    To be used cautiously as this may lead to unexpected behaviour (i.e. tensordicts
+    .. warning:: This class is deprecated and will be removed in v0.7.
+
+    To be used cautiously as this may lead to unexpected behavior (i.e. tensordicts
     overwritten during execution).
     """

     def __init__(self, device: DEVICE_TYPING | None = None):
+        warnings.warn(
+            "InPlaceSampler has been deprecated and will be removed in v0.7.",
+            category=DeprecationWarning,
+        )
         self.out = None
         if device is None:
             device = "cpu"
@@ -1555,6 +1653,15 @@ class ReplayBufferEnsemble(ReplayBuffer):
             sampled according to the probabilities ``p``. Can also be passed to
             :class:`~torchrl.data.replay_buffers.samplers.SamplerEnsemble` if the buffer is
             built explicitly.
+        generator (torch.Generator, optional): a generator to use for sampling.
+            Using a dedicated generator for the replay buffer can allow fine-grained control
+            over seeding, for instance keeping the global seed different but the RB seed identical
+            for distributed jobs.
+            Defaults to ``None`` (global default generator).
+
+            .. warning:: As of now, the generator has no effect on the transforms.
+        shared (bool, optional): whether the buffer will be shared using multiprocessing or not.
+            Defaults to ``False``.
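Putting the two new keyword arguments together, a minimal sketch (sizes and seed are illustrative; it assumes the default sampler draws from the buffer-local generator wired up by ``set_rng`` above):

>>> import torch
>>> from torchrl.data import LazyTensorStorage, ReplayBuffer
>>> rb = ReplayBuffer(
...     storage=LazyTensorStorage(100),
...     batch_size=8,
...     generator=torch.Generator().manual_seed(42),  # RB-local RNG, independent of the global seed
...     shared=True,  # protect writes with a multiprocessing lock
... )
>>> index = rb.extend(torch.arange(100))
>>> batch = rb.sample()
>>> rb.write_count
100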
     Examples:
         >>> from torchrl.envs import Compose, ToTensorImage, Resize, RenameTransform
@@ -1644,6 +1751,8 @@ def __init__(
         p: Tensor = None,
         sample_from_all: bool = False,
         num_buffer_sampled: int | None = None,
+        generator: torch.Generator | None = None,
+        shared: bool = False,
         **kwargs,
     ):

@@ -1680,6 +1789,8 @@ def __init__(
             transform=transform,
             batch_size=batch_size,
             collate_fn=collate_fn,
+            generator=generator,
+            shared=shared,
             **kwargs,
         )
diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py
index 582ac88f52d..45fede16cf5 100644
--- a/torchrl/data/replay_buffers/samplers.py
+++ b/torchrl/data/replay_buffers/samplers.py
@@ -22,7 +22,7 @@

 from torchrl._extension import EXTENSION_WARNING

-from torchrl._utils import _replace_last, implement_for, logger
+from torchrl._utils import _replace_last, logger
 from torchrl.data.replay_buffers.storages import Storage, StorageEnsemble, TensorStorage
 from torchrl.data.replay_buffers.utils import _is_int, unravel_index
@@ -46,6 +46,9 @@ class Sampler(ABC):
     # need to keep track of the number of remaining batches
     _remaining_batches = int(torch.iinfo(torch.int64).max)

+    # The RNG is set by the replay buffer
+    _rng: torch.Generator | None = None
+
     @abstractmethod
     def sample(self, storage: Storage, batch_size: int) -> Tuple[Any, dict]:
         ...
@@ -64,7 +67,7 @@ def update_priority(
         storage: Storage | None = None,
     ) -> dict | None:
         warnings.warn(
-            f"Calling update_priority() on a sampler {type(self).__name__} that is not prioritized. Make sure this is the indented behaviour."
+            f"Calling update_priority() on a sampler {type(self).__name__} that is not prioritized. Make sure this is the intended behavior."
         )
         return
@@ -105,6 +108,11 @@ def loads(self, path):
     def __repr__(self):
         return f"{self.__class__.__name__}()"

+    def __getstate__(self):
+        state = copy(self.__dict__)
+        state["_rng"] = None
+        return state
+

 class RandomSampler(Sampler):
     """A uniformly random sampler for composable replay buffers.
@@ -192,7 +200,9 @@ def _get_sample_list(self, storage: Storage, len_storage: int, batch_size: int):
             device = storage.device if hasattr(storage, "device") else None

         if self.shuffle:
-            _sample_list = torch.randperm(len_storage, device=device)
+            _sample_list = torch.randperm(
+                len_storage, device=device, generator=self._rng
+            )
         else:
             _sample_list = torch.arange(len_storage, device=device)
         self._sample_list = _sample_list
@@ -385,13 +395,28 @@ def __repr__(self):
     def max_size(self):
         return self._max_capacity

+    @property
+    def alpha(self):
+        return self._alpha
+
+    @alpha.setter
+    def alpha(self, value):
+        self._alpha = value
+
+    @property
+    def beta(self):
+        return self._beta
+
+    @beta.setter
+    def beta(self, value):
+        self._beta = value
+
     def __getstate__(self):
         if get_spawning_popen() is not None:
             raise RuntimeError(
                 f"Samplers of type {type(self)} cannot be shared between processes."
             )
-        state = copy(self.__dict__)
-        return state
+        return super().__getstate__()

     def _init(self):
         if self.dtype in (torch.float, torch.FloatType, torch.float32):
@@ -473,7 +498,11 @@ def sample(self, storage: Storage, batch_size: int) -> torch.Tensor:
             raise RuntimeError("non-positive p_min")
         # For some undefined reason, only np.random works here.
# All PT attempts fail, even when subsequently transformed into numpy - mass = np.random.uniform(0.0, p_sum, size=batch_size) + if self._rng is None: + mass = np.random.uniform(0.0, p_sum, size=batch_size) + else: + mass = torch.rand(batch_size, generator=self._rng) * p_sum + # mass = torch.zeros(batch_size, dtype=torch.double).uniform_(0.0, p_sum) # mass = torch.rand(batch_size).mul_(p_sum) index = self._sum_tree.scan_lower_bound(mass) @@ -929,7 +958,7 @@ def __getstate__(self): f"one process will NOT erase the cache on another process's sampler, " f"which will cause synchronization issues." ) - state = copy(self.__dict__) + state = super().__getstate__() state["_cache"] = {} return state @@ -969,22 +998,20 @@ def _find_start_stop_traj( # faster end = trajectory[:-1] != trajectory[1:] - end = torch.cat([end, trajectory[-1:] != trajectory[:1]], 0) + if not at_capacity: + end = torch.cat([end, torch.ones_like(end[:1])], 0) + else: + end = torch.cat([end, trajectory[-1:] != trajectory[:1]], 0) length = trajectory.shape[0] else: - # TODO: check that storage is at capacity here, if not we need to assume that the last element of end is True - # We presume that not done at the end means that the traj spans across end and beginning of storage length = end.shape[0] + if not at_capacity: + end = end.clone() + end[length - 1] = True + ndim = end.ndim - if not at_capacity: - end = torch.index_fill( - end, - index=torch.tensor(-1, device=end.device, dtype=torch.long), - dim=0, - value=1, - ) - else: + if at_capacity: # we must have at least one end by traj to individuate trajectories # so if no end can be found we set it manually if cursor is not None: @@ -1006,7 +1033,6 @@ def _find_start_stop_traj( mask = ~end.any(0, True) mask = torch.cat([torch.zeros_like(end[:-1]), mask]) end = torch.masked_fill(mask, end, 1) - ndim = end.ndim if ndim == 0: raise RuntimeError( "Expected the end-of-trajectory signal to be at least 1-dimensional." @@ -1063,9 +1089,28 @@ def _tensor_slices_from_startend(self, seq_length, start, storage_length): # seq_length is a 1d tensor indicating the desired length of each sequence if isinstance(seq_length, int): - result = torch.cat( - [self._start_to_end(_start, length=seq_length) for _start in start] + arange = torch.arange(seq_length, device=start.device, dtype=start.dtype) + ndims = start.shape[-1] - 1 if (start.ndim - 1) else 0 + if ndims: + arange_reshaped = torch.empty( + arange.shape + torch.Size([ndims + 1]), + device=start.device, + dtype=start.dtype, + ) + arange_reshaped[..., 0] = arange + arange_reshaped[..., 1:] = 0 + else: + arange_reshaped = arange.unsqueeze(-1) + arange_expanded = arange_reshaped.expand( + torch.Size([start.shape[0]]) + arange_reshaped.shape ) + if start.shape != arange_expanded.shape: + n_missing_dims = arange_expanded.dim() - start.dim() + start_expanded = start[ + (slice(None),) + (None,) * n_missing_dims + ].expand_as(arange_expanded) + result = (start_expanded + arange_expanded).flatten(0, 1) + else: # when padding is needed result = torch.cat( @@ -1094,7 +1139,7 @@ def _get_stop_and_length(self, storage, fallback=True): "Could not get a tensordict out of the storage, which is required for SliceSampler to compute the trajectories." 
) vals = self._find_start_stop_traj( - trajectory=trajectory, + trajectory=trajectory.clone(), at_capacity=storage._is_full, cursor=getattr(storage, "_last_cursor", None), ) @@ -1187,7 +1232,9 @@ def _sample_slices( # start_idx and stop_idx are 2d tensors organized like a non-zero def get_traj_idx(maxval): - return torch.randint(maxval, (num_slices,), device=lengths.device) + return torch.randint( + maxval, (num_slices,), device=lengths.device, generator=self._rng + ) if (lengths < seq_length).any(): if self.strict_length: @@ -1290,7 +1337,8 @@ def _get_index( start_point = -span_right relative_starts = ( - torch.rand(num_slices, device=lengths.device) * (end_point - start_point) + torch.rand(num_slices, device=lengths.device, generator=self._rng) + * (end_point - start_point) ).floor().to(start_idx.dtype) + start_point if self.span[0]: @@ -1800,34 +1848,38 @@ def __repr__(self): def __getstate__(self): state = SliceSampler.__getstate__(self) state.update(PrioritizedSampler.__getstate__(self)) + return state def mark_update( self, index: Union[int, torch.Tensor], *, storage: Storage | None = None ) -> None: return PrioritizedSampler.mark_update(self, index, storage=storage) - @implement_for("torch", "2.4") def _padded_indices(self, shapes, arange) -> torch.Tensor: # this complex mumbo jumbo creates a left padded tensor with valid indices on the right, e.g. # tensor([[ 0, 1, 2, 3, 4], # [-1, -1, 5, 6, 7], # [-1, 8, 9, 10, 11]]) # where the -1 items on the left are padded values - st, off = torch._nested_compute_contiguous_strides_offsets(shapes.flip(0)) - nt = torch._nested_view_from_buffer( - arange.flip(0).contiguous(), shapes.flip(0), st, off + num_groups = shapes.shape[0] + max_group_len = shapes.max() + pad_lengths = max_group_len - shapes + + # Get all the start and end indices within arange for each group + group_ends = shapes.cumsum(0) + group_starts = torch.empty_like(group_ends) + group_starts[0] = 0 + group_starts[1:] = group_ends[:-1] + pad = torch.empty( + (num_groups, max_group_len), dtype=arange.dtype, device=arange.device ) - pad = nt.to_padded_tensor(-1).flip(-1).flip(0) - return pad + for pad_row, group_start, group_end, pad_len in zip( + pad, group_starts, group_ends, pad_lengths + ): + pad_row[:pad_len] = -1 + pad_row[pad_len:] = arange[group_start:group_end] - @implement_for("torch", None, "2.4") - def _padded_indices(self, shapes, arange) -> torch.Tensor: # noqa: F811 - arange = arange.flip(0).split(shapes.flip(0).squeeze().unbind()) - return ( - torch.nn.utils.rnn.pad_sequence(arange, batch_first=True, padding_value=-1) - .flip(-1) - .flip(0) - ) + return pad def _preceding_stop_idx(self, storage, lengths, seq_length, start_idx): preceding_stop_idx = self._cache.get("preceding_stop_idx") @@ -2033,6 +2085,7 @@ class SamplerEnsemble(Sampler): def __init__( self, *samplers, p=None, sample_from_all=False, num_buffer_sampled=None ): + self._rng_private = None self._samplers = samplers self.sample_from_all = sample_from_all if sample_from_all and p is not None: @@ -2042,6 +2095,16 @@ def __init__( self.p = p self.num_buffer_sampled = num_buffer_sampled + @property + def _rng(self): + return self._rng_private + + @_rng.setter + def _rng(self, value): + self._rng_private = value + for sampler in self._samplers: + sampler._rng = value + @property def p(self): return self._p @@ -2082,7 +2145,10 @@ def sample(self, storage, batch_size): else: if self.p is None: buffer_ids = torch.randint( - len(self._samplers), (self.num_buffer_sampled,) + len(self._samplers), + 
(self.num_buffer_sampled,), + generator=self._rng, + device=getattr(storage, "device", None), ) else: buffer_ids = torch.multinomial(self.p, self.num_buffer_sampled, True) diff --git a/torchrl/data/replay_buffers/scheduler.py b/torchrl/data/replay_buffers/scheduler.py new file mode 100644 index 00000000000..6829424c620 --- /dev/null +++ b/torchrl/data/replay_buffers/scheduler.py @@ -0,0 +1,267 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from __future__ import annotations + +from abc import ABC, abstractmethod + +from typing import Any, Callable, Dict + +import numpy as np + +import torch + +from torchrl.data.replay_buffers.replay_buffers import ReplayBuffer +from torchrl.data.replay_buffers.samplers import Sampler + + +class ParameterScheduler(ABC): + """Scheduler to adjust the value of a given parameter of a replay buffer's sampler. + + A scheduler can, for example, be used to alter the alpha and beta values in the PrioritizedSampler. + + Args: + obj (ReplayBuffer or Sampler): the replay buffer whose sampler to adjust, or the sampler itself + param_name (str): the name of the attribute to adjust, e.g. `beta` to adjust the beta parameter + min_value (Union[int, float], optional): a lower bound for the parameter to be adjusted + Defaults to `None`. + max_value (Union[int, float], optional): an upper bound for the parameter to be adjusted + Defaults to `None`. + + """ + + def __init__( + self, + obj: ReplayBuffer | Sampler, + param_name: str, + min_value: int | float | None = None, + max_value: int | float | None = None, + ): + if not isinstance(obj, (ReplayBuffer, Sampler)): + raise TypeError( + f"ParameterScheduler only supports `ReplayBuffer` and `Sampler` objects; got {type(obj)} instead." + ) + self.sampler = obj.sampler if isinstance(obj, ReplayBuffer) else obj + self.param_name = param_name + self._min_val = min_value if min_value is not None else float("-inf") + self._max_val = max_value if max_value is not None else float("inf") + if not hasattr(self.sampler, self.param_name): + raise ValueError( + f"Provided class {type(obj).__name__} does not have an attribute {param_name}" + ) + initial_val = getattr(self.sampler, self.param_name) + if isinstance(initial_val, torch.Tensor): + initial_val = initial_val.clone() + self.backend = torch + else: + self.backend = np + self.initial_val = initial_val + self._step_cnt = 0 + + def state_dict(self): + """Returns the state of the scheduler as a :class:`dict`. + + It contains an entry for every variable in ``self.__dict__`` which + is not the sampler. + """ + sd = dict(self.__dict__) + del sd["sampler"] + return sd + + def load_state_dict(self, state_dict: Dict[str, Any]): + """Load the scheduler's state. + + Args: + state_dict (dict): scheduler state. Should be an object returned + from a call to :meth:`state_dict`. + """ + self.__dict__.update(state_dict) + + def step(self): + self._step_cnt += 1 + # Apply the step function + new_value = self._step() + # clip value to specified range + new_value_clipped = self.backend.clip(new_value, self._min_val, self._max_val) + # Set the new value of the parameter dynamically + setattr(self.sampler, self.param_name, new_value_clipped) + + @abstractmethod + def _step(self): + ... + + +class LambdaScheduler(ParameterScheduler): + """Sets a parameter to its initial value times a given function. + + Similar to :class:`~torch.optim.lr_scheduler.LambdaLR`.
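A minimal usage sketch of the scheduler machinery defined above (the construction is illustrative; it assumes the sampler exposes a float attribute under the scheduled name, as ``PrioritizedSampler`` does for ``beta``):

import torch
from torchrl.data.replay_buffers.samplers import PrioritizedSampler
from torchrl.data.replay_buffers.scheduler import LambdaScheduler

sampler = PrioritizedSampler(max_capacity=100, alpha=0.7, beta=0.4)
# scale the initial beta by (1 + 0.1 * step_count) at every step
scheduler = LambdaScheduler(sampler, param_name="beta", lambda_fn=lambda step: 1.0 + 0.1 * step)
scheduler.step()
print(sampler.beta)  # expected: 0.4 * 1.1 = 0.44, up to floating-point rounding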
+ + Args: + obj (ReplayBuffer or Sampler): the replay buffer whose sampler to adjust (or the sampler itself). + param_name (str): the name of the attribute to adjust, e.g. `beta` to adjust the + beta parameter. + lambda_fn (Callable[[int], float]): A function which computes a multiplicative factor given an integer + parameter ``step_count``. + min_value (Union[int, float], optional): a lower bound for the parameter to be adjusted + Defaults to `None`. + max_value (Union[int, float], optional): an upper bound for the parameter to be adjusted + Defaults to `None`. + + """ + + def __init__( + self, + obj: ReplayBuffer | Sampler, + param_name: str, + lambda_fn: Callable[[int], float], + min_value: int | float | None = None, + max_value: int | float | None = None, + ): + super().__init__(obj, param_name, min_value, max_value) + self.lambda_fn = lambda_fn + + def _step(self): + return self.initial_val * self.lambda_fn(self._step_cnt) + + +class LinearScheduler(ParameterScheduler): + """A linear scheduler for gradually altering a parameter in an object over a given number of steps. + + This scheduler linearly interpolates between the initial value of the parameter and a final target value. + + Args: + obj (ReplayBuffer or Sampler): the replay buffer whose sampler to adjust (or the sampler itself). + param_name (str): the name of the attribute to adjust, e.g. `beta` to adjust the + beta parameter. + final_value (number): The final value that the parameter will reach after the + specified number of steps. + num_steps (int): The total number of steps over which the parameter + will be linearly altered. + + Example: + >>> # xdoctest: +SKIP + >>> # Assuming sampler uses initial beta = 0.6 + >>> # beta = 0.7 if step == 1 + >>> # beta = 0.8 if step == 2 + >>> # beta = 0.9 if step == 3 + >>> # beta = 1.0 if step >= 4 + >>> scheduler = LinearScheduler(sampler, param_name='beta', final_value=1.0, num_steps=4) + >>> for epoch in range(100): + >>> train(...) + >>> validate(...) + >>> scheduler.step() + """ + + def __init__( + self, + obj: ReplayBuffer | Sampler, + param_name: str, + final_value: int | float, + num_steps: int, + ): + super().__init__(obj, param_name) + if isinstance(self.initial_val, torch.Tensor): + # cast to same type as initial value + final_value = torch.tensor(final_value).to(self.initial_val) + self.final_val = final_value + self.num_steps = num_steps + self._delta = (self.final_val - self.initial_val) / self.num_steps + + def _step(self): + # Nit: we should use torch.where instead of if/else here to make the scheduler compatible with compile + # without graph breaks + if self._step_cnt < self.num_steps: + return self.initial_val + (self._delta * self._step_cnt) + else: + return self.final_val + + +class StepScheduler(ParameterScheduler): + """A step scheduler that alters a parameter after every n steps using either multiplicative or additive changes. + + The scheduler can apply: + 1. Multiplicative changes: `new_val = curr_val * gamma` + 2. Additive changes: `new_val = curr_val + gamma` + + Args: + obj (ReplayBuffer or Sampler): the replay buffer whose sampler to adjust (or the sampler itself). + param_name (str): the name of the attribute to adjust, e.g. `beta` to adjust the + beta parameter. + gamma (int or float, optional): The value by which to adjust the parameter, + either in a multiplicative or additive way. + n_steps (int, optional): The number of steps after which the parameter should be altered. + Defaults to 1. + mode (str, optional): The mode of scheduling.
Can be either `'multiplicative'` or `'additive'`. + Defaults to `'multiplicative'`. + min_value (int or float, optional): a lower bound for the parameter to be adjusted. + Defaults to `None`. + max_value (int or float, optional): an upper bound for the parameter to be adjusted. + Defaults to `None`. + + Example: + >>> # xdoctest: +SKIP + >>> # Assuming sampler uses initial beta = 0.6 + >>> # beta = 0.6 if 0 <= step < 10 + >>> # beta = 0.7 if 10 <= step < 20 + >>> # beta = 0.8 if 20 <= step < 30 + >>> # beta = 0.9 if 30 <= step < 40 + >>> # beta = 1.0 if 40 <= step + >>> scheduler = StepScheduler(sampler, param_name='beta', gamma=0.1, mode='additive', max_value=1.0) + >>> for epoch in range(100): + >>> train(...) + >>> validate(...) + >>> scheduler.step() + """ + + def __init__( + self, + obj: ReplayBuffer | Sampler, + param_name: str, + gamma: int | float = 0.9, + n_steps: int = 1, + mode: str = "multiplicative", + min_value: int | float | None = None, + max_value: int | float | None = None, + ): + + super().__init__(obj, param_name, min_value, max_value) + self.gamma = gamma + self.n_steps = n_steps + self.mode = mode + if mode == "additive": + operator = self.backend.add + elif mode == "multiplicative": + operator = self.backend.multiply + else: + raise ValueError( + f"Invalid mode: {mode}. Choose 'multiplicative' or 'additive'." + ) + self.operator = operator + + def _step(self): + """Applies the scheduling logic to alter the parameter value every `n_steps`.""" + # Check if the current step count is a multiple of n_steps + current_val = getattr(self.sampler, self.param_name) + # Nit: we should use torch.where instead of if/else here to make the scheduler compatible with compile + # without graph breaks + if self._step_cnt % self.n_steps == 0: + return self.operator(current_val, self.gamma) + else: + return current_val + + +class SchedulerList: + """Simple container abstracting a list of schedulers.""" + + def __init__(self, schedulers: list[ParameterScheduler]) -> None: + if isinstance(schedulers, ParameterScheduler): + schedulers = [schedulers] + self.schedulers = schedulers + + def append(self, scheduler: ParameterScheduler): + self.schedulers.append(scheduler) + + def step(self): + for scheduler in self.schedulers: + scheduler.step() diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index 58b1729296d..a36c59b66d9 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -5,6 +5,8 @@ from __future__ import annotations import abc + +import logging import os import textwrap import warnings @@ -22,6 +24,7 @@ TensorDict, TensorDictBase, ) +from tensordict.base import _NESTED_TENSORS_AS_LISTS from tensordict.memmap import MemoryMappedTensor from torch import multiprocessing as mp from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten @@ -54,6 +57,7 @@ class Storage: ndim = 1 max_size: int _default_checkpointer: StorageCheckpointerBase = StorageCheckpointerBase + _rng: torch.Generator | None = None def __init__( self, max_size: int, checkpointer: StorageCheckpointerBase | None = None ) @@ -142,7 +146,13 @@ def _empty(self): def _rand_given_ndim(self, batch_size): # a method to return random indices given the storage ndim if self.ndim == 1: - return torch.randint(0, len(self), (batch_size,)) + return torch.randint( + 0, + len(self), + (batch_size,), + generator=self._rng, + device=getattr(self, "device", None), + ) raise RuntimeError( f"Random number generation is not implemented for
storage of type {type(self)} with ndim {self.ndim}. " f"Please report this exception as well as the use case (incl. buffer construction) on github." @@ -185,6 +195,11 @@ def load(self, *args, **kwargs): """Alias for :meth:`~.loads`.""" return self.loads(*args, **kwargs) + def __getstate__(self): + state = copy(self.__dict__) + state["_rng"] = None + return state + class ListStorage(Storage): """A storage stored in a list. @@ -299,7 +314,7 @@ def __getstate__(self): raise RuntimeError( f"Cannot share a storage of type {type(self)} between processes." ) - state = copy(self.__dict__) + state = super().__getstate__() return state def __repr__(self): @@ -497,7 +512,10 @@ def _rand_given_ndim(self, batch_size): if self.ndim == 1: return super()._rand_given_ndim(batch_size) shape = self.shape - return tuple(torch.randint(_dim, (batch_size,)) for _dim in shape) + return tuple( + torch.randint(_dim, (batch_size,), generator=self._rng, device=self.device) + for _dim in shape + ) def flatten(self): if self.ndim == 1: @@ -522,7 +540,7 @@ def flatten(self): ) def __getstate__(self): - state = copy(self.__dict__) + state = super().__getstate__() if get_spawning_popen() is None: length = self._len del state["_len_value"] @@ -539,15 +557,24 @@ def __getstate__(self): # check that the content is shared, otherwise tell the user we can't help storage = self._storage STORAGE_ERR = "The storage must be placed in shared memory or memmapped before being shared between processes." + + # If the content is on cpu, it will be placed in shared memory. + # If it's on cuda, it's already shared. + # If it's memmapped, there is no worry in this case either. + # Only if the device is neither "cpu" nor "cuda" may we have a problem. + def assert_is_sharable(tensor): + if tensor.device is None or tensor.device.type in ( + "cuda", + "cpu", + "meta", + ): + return + raise RuntimeError(STORAGE_ERR) + if is_tensor_collection(storage): - if not storage.is_memmap() and not storage.is_shared(): - raise RuntimeError(STORAGE_ERR) + storage.apply(assert_is_sharable, filter_empty=True) else: - if ( - not isinstance(storage, MemoryMappedTensor) - and not storage.is_shared() - ): - raise RuntimeError(STORAGE_ERR) + tree_map(assert_is_sharable, storage) return state @@ -722,7 +749,7 @@ def set( # noqa: F811 "A cursor longer than the storage capacity was provided. " "To accommodate this, the cursor will be truncated to its last " "element such that its length matches the length of the storage. " - "This may **not** be the optimal behaviour for your application! " + "This may **not** be the optimal behavior for your application! " "Make sure that the storage capacity is big enough to support the " "batch size provided." ) @@ -882,9 +909,7 @@ def max_size_along_dim0(data_shape): if is_tensor_collection(data): out = data.to(self.device) - out = out.expand(max_size_along_dim0(data.shape)) - out = out.clone() - out = out.zero_() + out = torch.empty_like(out.expand(max_size_along_dim0(data.shape))) else: # if Tensor, we just create a MemoryMappedTensor of the desired shape, device and dtype out = tree_map( @@ -906,6 +931,8 @@ class LazyMemmapStorage(LazyTensorStorage): Args: max_size (int): size of the storage, i.e. maximum number of elements stored in the buffer. + + Keyword Args: scratch_dir (str or path): directory where memmap-tensors will be written. device (torch.device, optional): device where the sampled tensors will be stored and sent. Default is :obj:`torch.device("cpu")`.
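An aside on the generator plumbing above: once a seeded ``torch.Generator`` is attached as ``_rng``, index sampling becomes reproducible. A self-contained illustration of the underlying mechanism (plain ``torch``, no TorchRL API):

import torch

rng = torch.Generator().manual_seed(0)
idx_a = torch.randint(0, 100, (8,), generator=rng)
rng.manual_seed(0)
idx_b = torch.randint(0, 100, (8,), generator=rng)
assert (idx_a == idx_b).all()  # same seed, same sampled indices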
@@ -916,6 +943,9 @@ class LazyMemmapStorage(LazyTensorStorage): measuring the storage size. For instance, a storage of shape ``[3, 4]`` has capacity ``3`` if ``ndim=1`` and ``12`` if ``ndim=2``. Defaults to ``1``. + existsok (bool, optional): whether it is acceptable to write over tensors that already exist + on disk. Defaults to ``False``, in which case an error is raised if any of the tensors + already exists on disk; if ``True``, the existing tensor will be opened and overwritten. .. note:: When checkpointing a ``LazyMemmapStorage``, one can provide a path identical to where the storage is already stored to avoid executing long copies of data that is already stored on disk. @@ -992,10 +1022,12 @@ def __init__( scratch_dir=None, device: torch.device = "cpu", ndim: int = 1, + existsok: bool = False, ): super().__init__(max_size, ndim=ndim) self.initialized = False self.scratch_dir = None + self.existsok = existsok if scratch_dir is not None: self.scratch_dir = str(scratch_dir) if self.scratch_dir[-1] != "/": @@ -1091,17 +1123,23 @@ def max_size_along_dim0(data_shape): if is_tensor_collection(data): out = data.clone().to(self.device) out = out.expand(max_size_along_dim0(data.shape)) - out = out.memmap_like(prefix=self.scratch_dir) - for key, tensor in sorted( - out.items(include_nested=True, leaves_only=True), key=str - ): - try: - filesize = os.path.getsize(tensor.filename) / 1024 / 1024 - torchrl_logger.debug( - f"\t{key}: {tensor.filename}, {filesize} Mb of storage (size: {tensor.shape})." - ) - except (AttributeError, RuntimeError): - pass + out = out.memmap_like(prefix=self.scratch_dir, existsok=self.existsok) + if torchrl_logger.isEnabledFor(logging.DEBUG): + for key, tensor in sorted( + out.items( + include_nested=True, + leaves_only=True, + is_leaf=_NESTED_TENSORS_AS_LISTS, + ), + key=str, + ): + try: + filesize = os.path.getsize(tensor.filename) / 1024 / 1024 + torchrl_logger.debug( + f"\t{key}: {tensor.filename}, {filesize} Mb of storage (size: {tensor.shape})." + ) + except (AttributeError, RuntimeError): + pass else: out = _init_pytree(self.scratch_dir, max_size_along_dim0, data) self._storage = out @@ -1142,6 +1180,7 @@ def __init__( *storages: Storage, transforms: List["Transform"] = None, # noqa: F821 ): + self._rng_private = None self._storages = storages self._transforms = transforms if transforms is not None and len(transforms) != len(storages): @@ -1149,6 +1188,16 @@ "transforms must have the same length as the storages " "provided." ) + @property + def _rng(self): + return self._rng_private + + @_rng.setter + def _rng(self, value): + self._rng_private = value + for storage in self._storages: + storage._rng = value + @property def _attached_entities(self): return set() diff --git a/torchrl/data/replay_buffers/utils.py b/torchrl/data/replay_buffers/utils.py index 3a4141fd218..15a90f1a8f5 100644 --- a/torchrl/data/replay_buffers/utils.py +++ b/torchrl/data/replay_buffers/utils.py @@ -802,7 +802,7 @@ def _path2str(path, default_name=None): if result == default_name: raise RuntimeError( "A tensor had the same identifier as the default name used when the buffer contains " - f"a single tensor (name={default_name}). This behaviour is not allowed. Please rename your " + f"a single tensor (name={default_name}). This behavior is not allowed. Please rename your " f"tensor in the map/dict or set a new default name with the environment variable SINGLE_TENSOR_BUFFER_NAME."
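A hedged usage sketch for the new ``existsok`` flag (the path is illustrative; files are only created once the storage is initialized with data):

from torchrl.data import LazyMemmapStorage

# First buffer: memmap files are created under the scratch directory.
storage = LazyMemmapStorage(max_size=1000, scratch_dir="/tmp/buffer")
# Pointing a second storage at the same directory is only safe with
# existsok=True; otherwise an error is raised once the files exist.
storage_reuse = LazyMemmapStorage(max_size=1000, scratch_dir="/tmp/buffer", existsok=True)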
) return result diff --git a/torchrl/data/replay_buffers/writers.py b/torchrl/data/replay_buffers/writers.py index ea3b2b4a047..3a95c3975cc 100644 --- a/torchrl/data/replay_buffers/writers.py +++ b/torchrl/data/replay_buffers/writers.py @@ -38,6 +38,7 @@ class Writer(ABC): """A ReplayBuffer base Writer class.""" _storage: Storage + _rng: torch.Generator | None = None def __init__(self) -> None: self._storage = None @@ -103,6 +104,11 @@ def _replicate_index(self, index): def __repr__(self): return f"{self.__class__.__name__}()" + def __getstate__(self): + state = copy(self.__dict__) + state["_rng"] = None + return state + class ImmutableDatasetWriter(Writer): """A blocking writer for immutable datasets.""" @@ -157,6 +163,7 @@ def add(self, data: Any) -> int | torch.Tensor: self._cursor = (self._cursor + 1) % self._storage._max_size_along_dim0( single_data=data ) + self._write_count += 1 # Replicate index requires the shape of the storage to be known # Other than that, a "flat" (1d) index is ok to write the data self._storage.set(_cursor, data) @@ -185,6 +192,7 @@ def extend(self, data: Sequence) -> torch.Tensor: ) # we need to update the cursor first to avoid race conditions between workers self._cursor = (batch_size + cur_size) % max_size_along0 + self._write_count += batch_size # Replicate index requires the shape of the storage to be known # Other than that, a "flat" (1d) index is ok to write the data self._storage.set(index, data) @@ -216,8 +224,22 @@ def _cursor(self, value): _cursor_value = self._cursor_value = mp.Value("i", 0) _cursor_value.value = value + @property + def _write_count(self): + _write_count = self.__dict__.get("_write_count_value", None) + if _write_count is None: + _write_count = self._write_count_value = mp.Value("i", 0) + return _write_count.value + + @_write_count.setter + def _write_count(self, value): + _write_count = self.__dict__.get("_write_count_value", None) + if _write_count is None: + _write_count = self._write_count_value = mp.Value("i", 0) + _write_count.value = value + def __getstate__(self): - state = copy(self.__dict__) + state = super().__getstate__() if get_spawning_popen() is None: cursor = self._cursor del state["_cursor_value"] @@ -243,6 +265,7 @@ def add(self, data: Any) -> int | torch.Tensor: # we need to update the cursor first to avoid race conditions between workers max_size_along_dim0 = self._storage._max_size_along_dim0(single_data=data) self._cursor = (index + 1) % max_size_along_dim0 + self._write_count += 1 if not is_tensorclass(data): data.set( "index", @@ -269,6 +292,7 @@ def extend(self, data: Sequence) -> torch.Tensor: ) # we need to update the cursor first to avoid race conditions between workers self._cursor = (batch_size + cur_size) % max_size_along_dim0 + self._write_count += batch_size # storage must convert the data to the appropriate format if needed if not is_tensorclass(data): data.set( @@ -454,6 +478,20 @@ def get_insert_index(self, data: Any) -> int: return ret + @property + def _write_count(self): + _write_count = self.__dict__.get("_write_count_value", None) + if _write_count is None: + _write_count = self._write_count_value = mp.Value("i", 0) + return _write_count.value + + @_write_count.setter + def _write_count(self, value): + _write_count = self.__dict__.get("_write_count_value", None) + if _write_count is None: + _write_count = self._write_count_value = mp.Value("i", 0) + _write_count.value = value + def add(self, data: Any) -> int | torch.Tensor: """Inserts a single element of data at an appropriate index, and 
returns that index. @@ -463,6 +501,7 @@ def add(self, data: Any) -> int | torch.Tensor: index = self.get_insert_index(data) if index is not None: data.set("index", index) + self._write_count += 1 # Replicate index requires the shape of the storage to be known # Other than that, a "flat" (1d) index is ok to write the data self._storage.set(index, data) @@ -482,6 +521,7 @@ def extend(self, data: TensorDictBase) -> None: for data_idx, sample in enumerate(data): storage_idx = self.get_insert_index(sample) if storage_idx is not None: + self._write_count += 1 data_to_replace[storage_idx] = data_idx # -1 will be interpreted as invalid by prioritized buffers @@ -511,9 +551,10 @@ def _empty(self) -> None: def __getstate__(self): if get_spawning_popen() is not None: raise RuntimeError( - f"Writers of type {type(self)} cannot be shared between processes." + f"Writers of type {type(self)} cannot be shared between processes. " + f"Please submit an issue at https://github.com/pytorch/rl if this feature is needed." ) - state = copy(self.__dict__) + state = super().__getstate__() return state def dumps(self, path): @@ -582,8 +623,19 @@ class WriterEnsemble(Writer): """ def __init__(self, *writers): + self._rng_private = None self._writers = writers + @property + def _rng(self): + return self._rng_private + + @_rng.setter + def _rng(self, value): + self._rng_private = value + for writer in self._writers: + writer._rng = value + def _empty(self): raise NotImplementedError diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 7c787b3ccfc..98a32de5715 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -6,6 +6,7 @@ from __future__ import annotations import abc +import enum import math import warnings from collections.abc import Iterable @@ -20,6 +21,7 @@ Generic, List, Optional, + overload, Sequence, Tuple, TypeVar, @@ -38,6 +40,7 @@ TensorDictBase, unravel_key, ) +from tensordict.base import NO_DEFAULT from tensordict.utils import _getitem_batch_size, NestedKey from torchrl._utils import _make_ordinal_device, get_binary_env_var @@ -72,14 +75,27 @@ _DEFAULT_SHAPE = torch.Size((1,)) -DEVICE_ERR_MSG = "device of empty CompositeSpec is not defined." +DEVICE_ERR_MSG = "device of empty Composite is not defined." NOT_IMPLEMENTED_ERROR = NotImplementedError( "method is not currently implemented." 
" If you are interested in this feature please submit" " an issue at https://github.com/pytorch/rl/issues" ) -NO_DEFAULT = object() + +def _size(list_of_ints): + # ensures that np int64 elements don't slip through Size + # see https://github.com/pytorch/pytorch/issues/127194 + return torch.Size([int(i) for i in list_of_ints]) + + +# Akin to TD's NO_DEFAULT but won't raise a KeyError when found in a TD or used as default +class _NoDefault(enum.IntEnum): + ZERO = 0 + ONE = 1 + + +NO_DEFAULT_RL = _NoDefault.ONE def _default_dtype_and_device( @@ -190,7 +206,7 @@ def _shape_indexing( Shape of the resulting spec Examples: >>> idx = (2, ..., None) - >>> DiscreteTensorSpec(2, shape=(3, 4))[idx].shape + >>> Categorical(2, shape=(3, 4))[idx].shape torch.Size([4, 1]) >>> _shape_indexing([3, 4], idx) torch.Size([4, 1]) @@ -350,7 +366,7 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> ContinuousBox: def __repr__(self): return f"{self.__class__.__name__}()" - def clone(self) -> DiscreteBox: + def clone(self) -> CategoricalBox: return deepcopy(self) @@ -387,16 +403,6 @@ def high(self, value): self.device = value.device self._high = value.cpu() - @low.setter - def low(self, value): - self.device = value.device - self._low = value.cpu() - - @high.setter - def high(self, value): - self.device = value.device - self._high = value.cpu() - def __post_init__(self): self.low = self.low.clone() self.high = self.high.clone() @@ -450,19 +456,25 @@ def __eq__(self, other): @dataclass(repr=False) -class DiscreteBox(Box): - """A box of discrete values.""" +class CategoricalBox(Box): + """A box of discrete, categorical values.""" n: int register = invertible_dict() - def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> DiscreteBox: + def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CategoricalBox: return deepcopy(self) def __repr__(self): return f"{self.__class__.__name__}(n={self.n})" +class DiscreteBox(CategoricalBox): + """Deprecated version of :class:`CategoricalBox`.""" + + ... + + @dataclass(repr=False) class BoxList(Box): """A box of discrete values.""" @@ -485,7 +497,7 @@ def __len__(self): @staticmethod def from_nvec(nvec: torch.Tensor): if nvec.ndim == 0: - return DiscreteBox(nvec.item()) + return CategoricalBox(nvec.item()) else: return BoxList([BoxList.from_nvec(n) for n in nvec.unbind(-1)]) @@ -505,14 +517,30 @@ def __repr__(self): @dataclass(repr=False) class TensorSpec: - """Parent class of the tensor meta-data containers for observation, actions and rewards. + """Parent class of the tensor meta-data containers. + + TorchRL's TensorSpec are used to present what input/output is to be expected for a specific class, + or sometimes to simulate simple behaviors by generating random data within a defined space. + + TensorSpecs are primarily used in environments to specify their input/output structure without needing to + execute the environment (or starting it). They can also be used to instantiate shared buffers to pass + data from worker to worker. + + TensorSpecs are dataclasses that always share the following fields: `shape`, `space, `dtype` and `device`. + + As such, TensorSpecs possess some common behavior with :class:`~torch.Tensor` and :class:`~tensordict.TensorDict`: + they can be reshaped, indexed, squeezed, unsqueezed, moved to another device etc. 
Args: - shape (torch.Size): size of the tensor - space (Box): Box instance describing what kind of values can be - expected - device (torch.device): device of the tensor - dtype (torch.dtype): dtype of the tensor + shape (torch.Size): size of the tensor. The shape includes the batch dimensions as well as the feature + dimension. A negative shape (``-1``) means that the dimension has a variable number of elements. + space (Box): Box instance describing what kind of values can be expected. + device (torch.device): device of the tensor. + dtype (torch.dtype): dtype of the tensor. + + .. note:: A spec can be constructed from a :class:`~tensordict.TensorDict` using the :func:`~torchrl.envs.utils.make_composite_from_td` + function. This function makes a low-assumption educated guess on the specs that may correspond to the input + tensordict and can help to build specs automatically without an in-depth knowledge of the `TensorSpec` API. """ @@ -537,21 +565,35 @@ def decorator(func): @property def device(self) -> torch.device: + """The device of the spec. + + Only :class:`Composite` specs can have a ``None`` device. All leaves must have a non-null device. + """ return self._device @device.setter def device(self, device: torch.device | None) -> None: self._device = _make_ordinal_device(device) - def clear_device_(self): - """A no-op for all leaf specs (which must have a device).""" + def clear_device_(self) -> T: + """A no-op for all leaf specs (which must have a device). + + For :class:`Composite` specs, this method will erase the device. + """ return self def encode( - self, val: Union[np.ndarray, torch.Tensor], *, ignore_device=False - ) -> torch.Tensor: + self, + val: np.ndarray | torch.Tensor | TensorDictBase, + *, + ignore_device: bool = False, + ) -> torch.Tensor | TensorDictBase: """Encodes a value given the specified spec, and return the corresponding tensor. + This method is to be used in environments that return a value (eg, a numpy array) that can be + easily mapped to the TorchRL required domain. + If the value is already a tensor, the spec will not change its value and return it as-is. + Args: val (np.ndarray or torch.Tensor): value to be encoded as tensor. @@ -604,11 +646,15 @@ def __ne__(self, other): def __setattr__(self, key, value): if key == "shape": - value = torch.Size(value) + value = _size(value) super().__setattr__(key, value) - def to_numpy(self, val: torch.Tensor, safe: bool = None) -> np.ndarray: - """Returns the np.ndarray correspondent of an input tensor. + def to_numpy( + self, val: torch.Tensor | TensorDictBase, safe: bool = None + ) -> np.ndarray | dict: + """Returns the ``np.ndarray`` correspondent of an input tensor. + + This is intended to be the inverse operation of :meth:`.encode`. Args: val (torch.Tensor): tensor to be transformed_in to numpy. @@ -617,7 +663,7 @@ def to_numpy(self, val: torch.Tensor, safe: bool = None) -> np.ndarray: Defaults to the value of the ``CHECK_SPEC_ENCODE`` environment variable. Returns: - a np.ndarray + a np.ndarray. """ if safe is None: @@ -627,19 +673,31 @@ def to_numpy(self, val: torch.Tensor, safe: bool = None) -> np.ndarray: return val.detach().cpu().numpy() @property - def ndim(self): + def ndim(self) -> int: + """Number of dimensions of the spec shape. + + Shortcut for ``len(spec.shape)``. + + """ return self.ndimension() - def ndimension(self): + def ndimension(self) -> int: + """Number of dimensions of the spec shape. + + Shortcut for ``len(spec.shape)``. 
+ + """ return len(self.shape) @property - def _safe_shape(self): + def _safe_shape(self) -> torch.Size: """Returns a shape where all heterogeneous values are replaced by one (to be expandable).""" - return torch.Size([int(v) if v >= 0 else 1 for v in self.shape]) + return _size([int(v) if v >= 0 else 1 for v in self.shape]) @abc.abstractmethod - def index(self, index: INDEX_TYPING, tensor_to_index: torch.Tensor) -> torch.Tensor: + def index( + self, index: INDEX_TYPING, tensor_to_index: torch.Tensor | TensorDictBase + ) -> torch.Tensor | TensorDictBase: """Indexes the input tensor. Args: @@ -652,20 +710,25 @@ def index(self, index: INDEX_TYPING, tensor_to_index: torch.Tensor) -> torch.Ten """ ... + @overload + def expand(self, shape: torch.Size): + ... + @abc.abstractmethod - def expand(self, *shape): - """Returns a new Spec with the extended shape. + def expand(self, *shape: int) -> T: + """Returns a new Spec with the expanded shape. Args: - *shape (tuple or iterable of int): the new shape of the Spec. Must comply with the current shape: + *shape (tuple or iterable of int): the new shape of the Spec. + Must be broadcastable with the current shape: its length must be at least as long as the current shape length, - and its last values must be complient too; ie they can only differ + and its last values must be compliant too; ie they can only differ from it if the current dimension is a singleton. """ ... - def squeeze(self, dim: int | None = None): + def squeeze(self, dim: int | None = None) -> T: """Returns a new Spec with all the dimensions of size ``1`` removed. When ``dim`` is given, a squeeze operation is done only in that dimension. @@ -679,21 +742,30 @@ def squeeze(self, dim: int | None = None): return self return self.__class__(shape=shape, device=self.device, dtype=self.dtype) - def unsqueeze(self, dim: int): + def unsqueeze(self, dim: int) -> T: + """Returns a new Spec with one more singleton dimension (at the position indicated by ``dim``). + + Args: + dim (int or None): the dimension to apply the unsqueeze operation to. + + """ shape = _unsqueezed_shape(self.shape, dim) return self.__class__(shape=shape, device=self.device, dtype=self.dtype) - def make_neg_dim(self, dim): + def make_neg_dim(self, dim: int) -> T: + """Converts a specific dimension to ``-1``.""" if dim < 0: dim = self.ndim + dim if dim < 0 or dim > self.ndim - 1: raise ValueError(f"dim={dim} is out of bound for ndim={self.ndim}") - self.shape = torch.Size( - [s if i != dim else -1 for i, s in enumerate(self.shape)] - ) + self.shape = _size([s if i != dim else -1 for i, s in enumerate(self.shape)]) + + @overload + def reshape(self, shape) -> T: + ... - def reshape(self, *shape): - """Reshapes a tensorspec. + def reshape(self, *shape) -> T: + """Reshapes a ``TensorSpec``. Check :func:`~torch.reshape` for more information on this method. @@ -705,23 +777,23 @@ def reshape(self, *shape): view = reshape @abc.abstractmethod - def _reshape(self, shape): + def _reshape(self, shape: torch.Size) -> T: ... - def unflatten(self, dim, sizes): - """Unflattens a tensorspec. + def unflatten(self, dim: int, sizes: Tuple[int]) -> T: + """Unflattens a ``TensorSpec``. Check :func:`~torch.unflatten` for more information on this method. """ return self._unflatten(dim, sizes) - def _unflatten(self, dim, sizes): + def _unflatten(self, dim: int, sizes: Tuple[int]) -> T: shape = torch.zeros(self.shape, device="meta").unflatten(dim, sizes).shape return self._reshape(shape) - def flatten(self, start_dim, end_dim): - """Flattens a tensorspec. 
+ def flatten(self, start_dim: int, end_dim: int) -> T: + """Flattens a ``TensorSpec``. Check :func:`~torch.flatten` for more information on this method. @@ -733,31 +805,39 @@ def _flatten(self, start_dim, end_dim): return self._reshape(shape) @abc.abstractmethod - def _project(self, val: torch.Tensor) -> torch.Tensor: + def _project( + self, val: torch.Tensor | TensorDictBase + ) -> torch.Tensor | TensorDictBase: raise NotImplementedError(type(self)) @abc.abstractmethod - def is_in(self, val: torch.Tensor) -> bool: - """If the value :obj:`val` is in the box defined by the TensorSpec, returns True, otherwise False. + def is_in(self, val: torch.Tensor | TensorDictBase) -> bool: + """If the value ``val`` could have been generated by the ``TensorSpec``, returns ``True``, otherwise ``False``. + + More precisely, the ``is_in`` method checks that the value ``val`` is within the limits defined by the ``space`` + attribute (the box), and that the ``dtype``, ``device``, ``shape`` and potentially other metadata match those + of the spec. If any of these checks fails, the ``is_in`` method will return ``False``. Args: - val (torch.Tensor): value to be checked + val (torch.Tensor): value to be checked. Returns: - boolean indicating if values belongs to the TensorSpec box + boolean indicating whether the value belongs to the TensorSpec box. """ ... - def contains(self, item): - """Returns whether a sample is contained within the space defined by the TensorSpec. + def contains(self, item: torch.Tensor | TensorDictBase) -> bool: + """If the value ``item`` could have been generated by the ``TensorSpec``, returns ``True``, otherwise ``False``. See :meth:`~.is_in` for more information. """ return self.is_in(item) - def project(self, val: torch.Tensor) -> torch.Tensor: - """If the input tensor is not in the TensorSpec box, it maps it back to it given some heuristic. + def project( + self, val: torch.Tensor | TensorDictBase + ) -> torch.Tensor | TensorDictBase: + """If the input tensor is not in the TensorSpec box, it maps it back to the box given some defined heuristic. Args: val (torch.Tensor): tensor to be mapped to the box. @@ -785,10 +865,10 @@ def assert_is_in(self, value: torch.Tensor) -> None: ) def type_check(self, value: torch.Tensor, key: NestedKey = None) -> None: - """Checks the input value dtype against the TensorSpec dtype and raises an exception if they don't match. + """Checks the input value ``dtype`` against the ``TensorSpec`` ``dtype`` and raises an exception if they don't match. Args: - value (torch.Tensor): tensor whose dtype has to be checked + value (torch.Tensor): tensor whose dtype has to be checked. key (str, optional): if the TensorSpec has keys, the value dtype will be checked against the spec pointed by the indicated key. @@ -801,8 +881,11 @@ def type_check(self, value: torch.Tensor, key: NestedKey = None) -> None: ) @abc.abstractmethod - def rand(self, shape=None) -> torch.Tensor: - """Returns a random tensor in the space defined by the spec. The sampling will be uniform unless the box is unbounded. + def rand(self, shape: torch.Size = None) -> torch.Tensor | TensorDictBase: + """Returns a random tensor in the space defined by the spec. + + The sampling will be done uniformly over the space, unless the box is unbounded in which case normal values + will be drawn. Args: shape (torch.Size): shape of the random tensor @@ -811,19 +894,22 @@ def rand(self, shape=None) -> torch.Tensor: a random tensor sampled in the TensorSpec box. """ - raise NotImplementedError + ...
- @property - def sample(self): + def sample(self, shape: torch.Size = None) -> torch.Tensor | TensorDictBase: """Returns a random tensor in the space defined by the spec. See :meth:`~.rand` for details. """ - return self.rand + return self.rand(shape=shape) - def zero(self, shape=None) -> torch.Tensor: + def zero(self, shape: torch.Size = None) -> torch.Tensor | TensorDictBase: """Returns a zero-filled tensor in the box. + .. note:: Even though there is no guarantee that ``0`` belongs to the spec domain, + this method will not raise an exception when this condition is violated. + The primary use case of ``zero`` is to generate empty data buffers, not meaningful data. + Args: shape (torch.Size): shape of the zero-tensor @@ -832,26 +918,59 @@ def zero(self, shape=None) -> torch.Tensor: """ if shape is None: - shape = torch.Size([]) + shape = _size([]) return torch.zeros( (*shape, *self._safe_shape), dtype=self.dtype, device=self.device ) + def zeros(self, shape: torch.Size = None) -> torch.Tensor | TensorDictBase: + """Proxy to :meth:`~.zero`.""" + return self.zero(shape=shape) + + def one(self, shape: torch.Size = None) -> torch.Tensor | TensorDictBase: + """Returns a one-filled tensor in the box. + + .. note:: Even though there is no guarantee that ``1`` belongs to the spec domain, + this method will not raise an exception when this condition is violated. + The primary use case of ``one`` is to generate empty data buffers, not meaningful data. + + Args: + shape (torch.Size): shape of the one-tensor + + Returns: + a one-filled tensor sampled in the TensorSpec box. + + """ + if self.dtype == torch.bool: + return ~self.zero(shape=shape) + return self.zero(shape) + 1 + + def ones(self, shape: torch.Size = None) -> torch.Tensor | TensorDictBase: + """Proxy to :meth:`~.one`.""" + return self.one(shape=shape) + @abc.abstractmethod def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> "TensorSpec": - raise NotImplementedError + """Casts a TensorSpec to a device or a dtype. + + Returns the same spec if no change is made. + """ + ... def cpu(self): + """Casts the TensorSpec to 'cpu' device.""" return self.to("cpu") def cuda(self, device=None): + """Casts the TensorSpec to 'cuda' device.""" if device is None: return self.to("cuda") return self.to(f"cuda:{device}") @abc.abstractmethod def clone(self) -> "TensorSpec": - raise NotImplementedError + """Creates a copy of the TensorSpec.""" + ... 
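To make the intent of the new ``one``/``ones`` helpers concrete, a small sketch, including the boolean special case handled with ``~self.zero(shape)`` above (it uses the ``Unbounded`` spec renamed later in this diff):

import torch
from torchrl.data import Unbounded

spec = Unbounded(shape=(3,), dtype=torch.bool)
print(spec.zero())            # tensor([False, False, False])
print(spec.one())             # tensor([True, True, True]): bool specs flip zero()
print(spec.ones((2,)).shape)  # torch.Size([2, 3]): the batch shape is prepended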
def __repr__(self): shape_str = indent("shape=" + str(self.shape), " " * 4) @@ -898,7 +1017,7 @@ def __init__(self, *specs: tuple[T, ...], dim: int) -> None: self.dim = len(self.shape) + self.dim def clear_device_(self): - """Clears the device of the CompositeSpec.""" + """Clears the device of the Composite.""" for spec in self._specs: spec.clear_device_() return self @@ -997,7 +1116,7 @@ def clone(self) -> T: def stack_dim(self): return self.dim - def zero(self, shape=None) -> TensorDictBase: + def zero(self, shape: torch.Size = None) -> TensorDictBase: if shape is not None: dim = self.dim + len(shape) else: @@ -1008,7 +1127,7 @@ def zero(self, shape=None) -> TensorDictBase: ) return torch.nested.nested_tensor([spec.zero(shape) for spec in self._specs]) - def one(self, shape=None) -> TensorDictBase: + def one(self, shape: torch.Size = None) -> TensorDictBase: if shape is not None: dim = self.dim + len(shape) else: @@ -1019,7 +1138,7 @@ def one(self, shape=None) -> TensorDictBase: ) return torch.nested.nested_tensor([spec.one(shape) for spec in self._specs]) - def rand(self, shape=None) -> TensorDictBase: + def rand(self, shape: torch.Size = None) -> TensorDictBase: if shape is not None: dim = self.dim + len(shape) else: @@ -1125,7 +1244,7 @@ def squeeze(self, dim: int = None): ) -class LazyStackedTensorSpec(_LazyStackedMixin[TensorSpec], TensorSpec): +class Stacked(_LazyStackedMixin[TensorSpec], TensorSpec): """A lazy representation of a stack of tensor specs. Stacks tensor-specs together along one dimension. Indexing is allowed but only along the stack dimension. - This class is aimed to be used in multi-task and multi-agent settings, where + This class is designed to be used in multi-task and multi-agent settings, where heterogeneous specs may occur (same semantics but different shape). """ def __eq__(self, other): - if not isinstance(other, LazyStackedTensorSpec): + if not isinstance(other, Stacked): return False if self.device != other.device: raise RuntimeError((self, other)) @@ -1161,8 +1280,7 @@ def to_numpy(self, val: torch.Tensor, safe: bool = None) -> dict: if safe: if val.shape[self.dim] != len(self._specs): raise ValueError( - "Size of LazyStackedTensorSpec and val differ along the stacking " - "dimension" + "Size of Stacked and val differ along the stacking dimension" ) for spec, v in zip(self._specs, torch.unbind(val, dim=self.dim)): spec.assert_is_in(v) @@ -1174,7 +1292,7 @@ def __repr__(self): dtype_str = "dtype=" + str(self.dtype) domain_str = "domain=" + str(self._specs[0].domain) sub_string = ", ".join([shape_str, device_str, dtype_str, domain_str]) - string = f"LazyStacked{self._specs[0].__class__.__name__}(\n {sub_string})" + string = f"Stacked{self._specs[0].__class__.__name__}(\n {sub_string})" return string @property @@ -1204,7 +1322,7 @@ def shape(self): if dim < 0: dim = len(shape) + dim + 1 shape.insert(dim, len(self._specs)) - return torch.Size(shape) + return _size(shape) @shape.setter def shape(self, shape): raise RuntimeError( f"Shape mismatch between the input {shape} and self.shape={self.shape}."
) - shape_strip = torch.Size([s for i, s in enumerate(self.shape) if i != self.dim]) + shape_strip = _size([s for i, s in enumerate(self.shape) if i != self.dim]) for spec in self._specs: spec.shape = shape_strip @@ -1295,7 +1413,7 @@ def encode( @dataclass(repr=False) -class OneHotDiscreteTensorSpec(TensorSpec): +class OneHot(TensorSpec): """A unidimensional, one-hot discrete tensor spec. By default, TorchRL assumes that categorical variables are encoded as @@ -1316,10 +1434,10 @@ class OneHotDiscreteTensorSpec(TensorSpec): Args: n (int): number of possible outcomes. shape (torch.Size, optional): total shape of the sampled tensors. - If provided, the last dimension must match n. + If provided, the last dimension must match ``n``. device (str, int or torch.device, optional): device of the tensors. dtype (str or torch.dtype, optional): dtype of the tensors. - user_register (bool): experimental feature. If True, every integer + use_register (bool): experimental feature. If ``True``, every integer will be mapped onto a binary vector in the order in which they appear. This feature is designed for environment with no a-priori definition of the number of possible outcomes (e.g. @@ -1329,16 +1447,29 @@ class OneHotDiscreteTensorSpec(TensorSpec): mask (torch.Tensor or None): mask some of the possible outcomes when a sample is taken. See :meth:`~.update_mask` for more information. + Examples: + >>> from torchrl.data.tensor_specs import OneHot + >>> spec = OneHot(5, shape=(2, 5)) + >>> spec.rand() + tensor([[False, True, False, False, False], + [False, True, False, False, False]]) + >>> mask = torch.tensor([ + ... [False, False, False, False, True], + ... [False, False, False, False, True] + ... ]) + >>> spec.update_mask(mask) + >>> spec.rand() + tensor([[False, False, False, False, True], + [False, False, False, False, True]]) + """ shape: torch.Size - space: DiscreteBox + space: CategoricalBox device: torch.device | None = None dtype: torch.dtype = torch.float domain: str = "" - # SPEC_HANDLED_FUNCTIONS = {} - def __init__( self, n: int, @@ -1350,11 +1481,11 @@ def __init__( ): dtype, device = _default_dtype_and_device(dtype, device) self.use_register = use_register - space = DiscreteBox(n) + space = CategoricalBox(n) if shape is None: - shape = torch.Size((space.n,)) + shape = _size((space.n,)) else: - shape = torch.Size(shape) + shape = _size(shape) if not len(shape) or shape[-1] != space.n: raise ValueError( f"The last value of the shape must match n for transform of type {self.__class__}. " @@ -1378,12 +1509,12 @@ def update_mask(self, mask): mask (torch.Tensor or None): boolean mask. If None, the mask is disabled. Otherwise, the shape of the mask must be expandable to the shape of the spec. ``False`` masks an outcome and ``True`` - leaves the outcome unmasked. If all of the possible outcomes are + leaves the outcome unmasked. If all the possible outcomes are masked, then an error is raised when a sample is taken. 
Examples: >>> mask = torch.tensor([True, False, False]) - >>> ts = OneHotDiscreteTensorSpec(3, (2, 3,), dtype=torch.int64, mask=mask) + >>> ts = OneHot(3, (2, 3,), dtype=torch.int64, mask=mask) >>> # All but one of the three possible outcomes are masked >>> ts.rand() tensor([[1, 0, 0], @@ -1398,7 +1529,7 @@ def update_mask(self, mask): raise ValueError("Only boolean masks are accepted.") self.mask = mask - def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: + def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> OneHot: if dest is None: return self if isinstance(dest, torch.dtype): @@ -1418,7 +1549,7 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: mask=self.mask.to(dest) if self.mask is not None else None, ) - def clone(self) -> OneHotDiscreteTensorSpec: + def clone(self) -> OneHot: return self.__class__( n=self.space.n, shape=self.shape, @@ -1536,11 +1667,11 @@ def unbind(self, dim: int = 0): for i in range(self.shape[dim]) ) - def rand(self, shape=None) -> torch.Tensor: + def rand(self, shape: torch.Size = None) -> torch.Tensor: if shape is None: shape = self.shape[:-1] else: - shape = torch.Size([*shape, *self.shape[:-1]]) + shape = _size([*shape, *self.shape[:-1]]) mask = self.mask if mask is None: n = self.space.n @@ -1561,7 +1692,7 @@ def rand(self, shape=None) -> torch.Tensor: def encode( self, val: Union[np.ndarray, torch.Tensor], - space: Optional[DiscreteBox] = None, + space: Optional[CategoricalBox] = None, *, ignore_device: bool = False, ) -> torch.Tensor: @@ -1619,7 +1750,7 @@ def __getitem__(self, idx: SHAPE_INDEX_TYPING): indexed_shape = _shape_indexing(self.shape[:-1], idx) return self.__class__( n=self.space.n, - shape=torch.Size(indexed_shape + [self.shape[-1]]), + shape=_size(indexed_shape + [self.shape[-1]]), device=self.device, dtype=self.dtype, use_register=self.use_register, @@ -1689,6 +1820,16 @@ def to_categorical(self, val: torch.Tensor, safe: bool = None) -> torch.Tensor: Returns: The categorical tensor. + + Examples: + >>> one_hot = OneHot(3, shape=(2, 3)) + >>> one_hot_sample = one_hot.rand() + >>> one_hot_sample + tensor([[False, True, False], + [False, True, False]]) + >>> categ_sample = one_hot.to_categorical(one_hot_sample) + >>> categ_sample + tensor([1, 1]) """ if safe is None: safe = _CHECK_SPEC_ENCODE @@ -1696,25 +1837,103 @@ def to_categorical(self, val: torch.Tensor, safe: bool = None) -> torch.Tensor: self.assert_is_in(val) return val.long().argmax(-1) - def to_categorical_spec(self) -> DiscreteTensorSpec: - """Converts the spec to the equivalent categorical spec.""" - return DiscreteTensorSpec( + def to_categorical_spec(self) -> Categorical: + """Converts the spec to the equivalent categorical spec. 
+ + Examples: + >>> one_hot = OneHot(3, shape=(2, 3)) + >>> one_hot.to_categorical_spec() + Categorical( + shape=torch.Size([2]), + space=CategoricalBox(n=3), + device=cpu, + dtype=torch.int64, + domain=discrete) + + """ + return Categorical( self.space.n, device=self.device, shape=self.shape[:-1], mask=self.mask, ) + def to_one_hot(self, val: torch.Tensor, safe: bool = None) -> torch.Tensor: + """No-op for OneHot.""" + return val + + def to_one_hot_spec(self) -> OneHot: + """No-op for OneHot.""" + return self + + +class _BoundedMeta(abc.ABCMeta): + def __call__(cls, *args, **kwargs): + instance = super().__call__(*args, **kwargs) + if instance.domain == "continuous": + instance.__class__ = BoundedContinuous + else: + instance.__class__ = BoundedDiscrete + return instance + @dataclass(repr=False) -class BoundedTensorSpec(TensorSpec): - """A bounded continuous tensor spec. +class Bounded(TensorSpec, metaclass=_BoundedMeta): + """A bounded tensor spec. + + ``Bounded`` specs will never appear as such and always be subclassed as :class:`BoundedContinuous` + or :class:`BoundedDiscrete` depending on their dtype (floating points dtypes will result in + :class:`BoundedContinuous` instances, all others in :class:`BoundedDiscrete` instances). Args: low (np.ndarray, torch.Tensor or number): lower bound of the box. high (np.ndarray, torch.Tensor or number): upper bound of the box. + shape (torch.Size): the shape of the ``Bounded`` spec. The shape must be specified. + Inputs ``low``, ``high`` and ``shape`` must be broadcastable. device (str, int or torch.device, optional): device of the tensors. dtype (str or torch.dtype, optional): dtype of the tensors. + domain (str): `"continuous"` or `"discrete"`. Can be used to override the automatic type assignment. + + Examples: + >>> spec = Bounded(low=-1, high=1, shape=(), dtype=torch.float) + >>> spec + BoundedContinuous( + shape=torch.Size([]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)), + device=cpu, + dtype=torch.float32, + domain=continuous) + >>> spec = Bounded(low=-1, high=1, shape=(), dtype=torch.int) + >>> spec + BoundedDiscrete( + shape=torch.Size([]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, contiguous=True), + high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, contiguous=True)), + device=cpu, + dtype=torch.int32, + domain=discrete) + >>> spec.to(torch.float) + BoundedContinuous( + shape=torch.Size([]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)), + device=cpu, + dtype=torch.float32, + domain=continuous) + >>> spec = Bounded(low=-1, high=1, shape=(), dtype=torch.int, domain="continuous") + >>> spec + BoundedContinuous( + shape=torch.Size([]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, contiguous=True), + high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, contiguous=True)), + device=cpu, + dtype=torch.int32, + domain=continuous) """ @@ -1748,13 +1967,18 @@ def __init__( "Minimum is deprecated since v0.4.0, using low instead.", category=DeprecationWarning, ) - domain = kwargs.pop("domain", "continuous") + domain = kwargs.pop("domain", None) if len(kwargs): raise TypeError(f"Got unrecognised kwargs {tuple(kwargs.keys())}.") dtype, 
device = _default_dtype_and_device(dtype, device) if dtype is None: dtype = torch.get_default_dtype() + if domain is None: + if dtype.is_floating_point: + domain = "continuous" + else: + domain = "discrete" if not isinstance(low, torch.Tensor): low = torch.tensor(low, dtype=dtype, device=device) @@ -1769,7 +1993,7 @@ def __init__( if dtype is not None and high.dtype is not dtype: high = high.to(dtype) err_msg = ( - "BoundedTensorSpec requires the shape to be explicitely (via " + "Bounded requires the shape to be explicitly (via " "the shape argument) or implicitly defined (via either the " "minimum or the maximum or both). If the maximum and/or the " "minimum have a non-singleton shape, they must match the " ) if shape is not None and not isinstance(shape, torch.Size): if isinstance(shape, int): - shape = torch.Size([shape]) + shape = _size([shape]) else: - shape = torch.Size(list(shape)) + shape = _size(list(shape)) if shape is not None: shape_corr = _remove_neg_shapes(shape) else: @@ -1812,9 +2036,9 @@ def __init__( shape = low.shape else: if isinstance(shape_corr, float): - shape_corr = torch.Size([shape_corr]) + shape_corr = _size([shape_corr]) elif not isinstance(shape_corr, torch.Size): - shape_corr = torch.Size(shape_corr) + shape_corr = _size(shape_corr) shape_corr_err_msg = ( f"low and shape_corr mismatch, got {low.shape} and {shape_corr}" ) @@ -1945,9 +2169,9 @@ def unbind(self, dim: int = 0): for low, high in zip(low, high) ) - def rand(self, shape=None) -> torch.Tensor: + def rand(self, shape: torch.Size = None) -> torch.Tensor: if shape is None: - shape = torch.Size([]) + shape = _size([]) a, b = self.space if self.dtype in (torch.float, torch.double, torch.half): shape = [*shape, *self._safe_shape] @@ -1971,9 +2195,7 @@ def rand(self, shape=None) -> torch.Tensor: else: mini = self.space.low interval = maxi - mini - r = torch.rand( - torch.Size([*shape, *self._safe_shape]), device=interval.device - ) + r = torch.rand(_size([*shape, *self._safe_shape]), device=interval.device) r = interval * r r = self.space.low + r r = r.to(self.dtype).to(self.device) @@ -2028,7 +2250,7 @@ def is_in(self, val: torch.Tensor) -> bool: return False raise err - def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: + def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> Bounded: if isinstance(dest, torch.dtype): dest_dtype = dest dest_device = self.device @@ -2039,15 +2261,16 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: dest_device = torch.device(dest) if dest_device == self.device and dest_dtype == self.dtype: return self - return self.__class__( - low=self.space.low.to(dest), - high=self.space.high.to(dest), + self.space.device = dest_device + return Bounded( + low=self.space.low, + high=self.space.high, shape=self.shape, device=dest_device, dtype=dest_dtype, ) - def clone(self) -> BoundedTensorSpec: + def clone(self) -> Bounded: return self.__class__( low=self.space.low.clone(), high=self.space.high.clone(), @@ -2063,7 +2286,7 @@ def __getitem__(self, idx: SHAPE_INDEX_TYPING): "Pending resolution of https://github.com/pytorch/pytorch/issues/100080."
) - indexed_shape = torch.Size(_shape_indexing(self.shape, idx)) + indexed_shape = _size(_shape_indexing(self.shape, idx)) # Expand is required as pytorch.tensor indexing return self.__class__( low=self.space.low[idx].clone().expand(indexed_shape), high=self.space.high[idx].clone().expand(indexed_shape), shape=indexed_shape, device=self.device, dtype=self.dtype, ) @@ -2074,6 +2297,45 @@ +class BoundedContinuous(Bounded, metaclass=_BoundedMeta): + """A specialized version of :class:`torchrl.data.Bounded` with continuous space.""" + + def __init__( + self, + low: Union[float, torch.Tensor, np.ndarray] = None, + high: Union[float, torch.Tensor, np.ndarray] = None, + shape: Optional[Union[torch.Size, int]] = None, + device: Optional[DEVICE_TYPING] = None, + dtype: Optional[Union[torch.dtype, str]] = None, + domain: str = "continuous", + ): + super().__init__( + low=low, high=high, shape=shape, device=device, dtype=dtype, domain=domain + ) + + +class BoundedDiscrete(Bounded, metaclass=_BoundedMeta): + """A specialized version of :class:`torchrl.data.Bounded` with discrete space.""" + + def __init__( + self, + low: Union[float, torch.Tensor, np.ndarray] = None, + high: Union[float, torch.Tensor, np.ndarray] = None, + shape: Optional[Union[torch.Size, int]] = None, + device: Optional[DEVICE_TYPING] = None, + dtype: Optional[Union[torch.dtype, str]] = None, + domain: str = "discrete", + ): + super().__init__( + low=low, + high=high, + shape=shape, + device=device, + dtype=dtype, + domain=domain, + ) + + def _is_nested_list(index, notuple=False): if not notuple and isinstance(index, tuple): for idx in index: @@ -2088,8 +2350,14 @@ def _is_nested_list(index, notuple=False): return False -class NonTensorSpec(TensorSpec): - """A spec for non-tensor data.""" +class NonTensor(TensorSpec): + """A spec for non-tensor data. + + This spec has a shape, device and dtype like :class:`~tensordict.NonTensorData`. + + :meth:`.rand` will return a :class:`~tensordict.NonTensorData` object with a ``None`` data value + (the same goes for :meth:`.zero` and :meth:`.one`).
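A hedged sketch of the ``NonTensor`` behavior documented above (``NonTensorData`` is tensordict's carrier for non-tensor payloads; the printed values follow the docstring):

from torchrl.data import NonTensor

spec = NonTensor(shape=(3,))
val = spec.rand()   # a NonTensorData object
print(val.data)     # None, per the docstring above
print(val.shape)    # torch.Size([3])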
+ """ def __init__( self, @@ -2099,7 +2367,7 @@ def __init__( **kwargs, ): if isinstance(shape, int): - shape = torch.Size([shape]) + shape = _size([shape]) _, device = _default_dtype_and_device(None, device) domain = None @@ -2107,7 +2375,7 @@ def __init__( shape=shape, space=None, device=device, dtype=dtype, domain=domain, **kwargs ) - def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> NonTensorSpec: + def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> NonTensor: if isinstance(dest, torch.dtype): dest_dtype = dest dest_device = self.device @@ -2120,7 +2388,7 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> NonTensorSpec: return self return self.__class__(shape=self.shape, device=dest_device, dtype=None) - def clone(self) -> NonTensorSpec: + def clone(self) -> NonTensor: return self.__class__(shape=self.shape, device=self.device, dtype=self.dtype) def rand(self, shape=None): @@ -2158,7 +2426,7 @@ def is_in(self, val: torch.Tensor) -> bool: def expand(self, *shape): if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)): shape = shape[0] - shape = torch.Size(shape) + shape = _size(shape) if not all( (old == 1) or (old == new) for old, new in zip(self.shape, shape[-len(self.shape) :]) @@ -2181,7 +2449,7 @@ def _unflatten(self, dim, sizes): def __getitem__(self, idx: SHAPE_INDEX_TYPING): """Indexes the current TensorSpec based on the provided index.""" - indexed_shape = torch.Size(_shape_indexing(self.shape, idx)) + indexed_shape = _size(_shape_indexing(self.shape, idx)) return self.__class__(shape=indexed_shape, device=self.device, dtype=self.dtype) def unbind(self, dim: int = 0): @@ -2203,17 +2471,76 @@ def unbind(self, dim: int = 0): ) +class _UnboundedMeta(abc.ABCMeta): + def __call__(cls, *args, **kwargs): + instance = super().__call__(*args, **kwargs) + if instance.domain == "continuous": + instance.__class__ = UnboundedContinuous + else: + instance.__class__ = UnboundedDiscrete + return instance + + @dataclass(repr=False) -class UnboundedContinuousTensorSpec(TensorSpec): - """An unbounded continuous tensor spec. +class Unbounded(TensorSpec, metaclass=_UnboundedMeta): + """An unbounded tensor spec. + + ``Unbounded`` specs will never appear as such and always be subclassed as :class:`UnboundedContinuous` + or :class:`UnboundedDiscrete` depending on their dtype (floating points dtypes will result in + :class:`UnboundedContinuous` instances, all others in :class:`UnboundedDiscrete` instances). + + Although it is not properly limited above and below, this class still has a :attr:`Box` space that encodes + the maximum and minimum value that the dtype accepts. Args: + shape (torch.Size): the shape of the ``Bounded`` spec. The shape must be specified. + Inputs ``low``, ``high`` and ``shape`` must be broadcastable. device (str, int or torch.device, optional): device of the tensors. - dtype (str or torch.dtype, optional): dtype of the tensors - (should be an floating point dtype such as float, double etc.) - """ + dtype (str or torch.dtype, optional): dtype of the tensors. + domain (str): `"continuous"` or `"discrete"`. Can be used to override the automatic type assignment. 
- # SPEC_HANDLED_FUNCTIONS = {} + Examples: + >>> spec = Unbounded(shape=(), dtype=torch.float) + >>> spec + UnboundedContinuous( + shape=torch.Size([]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)), + device=cpu, + dtype=torch.float32, + domain=continuous) + >>> spec = Unbounded(shape=(), dtype=torch.int) + >>> spec + UnboundedDiscrete( + shape=torch.Size([]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, contiguous=True), + high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, contiguous=True)), + device=cpu, + dtype=torch.int32, + domain=discrete) + >>> spec.to(torch.float) + UnboundedContinuous( + shape=torch.Size([]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)), + device=cpu, + dtype=torch.float32, + domain=continuous) + >>> spec = Unbounded(shape=(), dtype=torch.int, domain="continuous") + >>> spec + UnboundedContinuous( + shape=torch.Size([]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, contiguous=True), + high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, contiguous=True)), + device=cpu, + dtype=torch.int32, + domain=continuous) + + """ def __init__( self, @@ -2223,26 +2550,37 @@ def __init__( **kwargs, ): if isinstance(shape, int): - shape = torch.Size([shape]) + shape = _size([shape]) dtype, device = _default_dtype_and_device(dtype, device) - box = ( - ContinuousBox( - torch.as_tensor(-np.inf, device=device).expand(shape), - torch.as_tensor(np.inf, device=device).expand(shape), - ) - if shape == _DEFAULT_SHAPE - else None + if dtype == torch.bool: + min_value = False + max_value = True + default_domain = "discrete" + else: + if dtype.is_floating_point: + min_value = torch.finfo(dtype).min + max_value = torch.finfo(dtype).max + default_domain = "continuous" + else: + min_value = torch.iinfo(dtype).min + max_value = torch.iinfo(dtype).max + default_domain = "discrete" + box = ContinuousBox( + torch.full( + _remove_neg_shapes(shape), min_value, device=device, dtype=dtype + ), + torch.full( + _remove_neg_shapes(shape), max_value, device=device, dtype=dtype + ), ) - default_domain = "continuous" if dtype.is_floating_point else "discrete" + domain = kwargs.pop("domain", default_domain) super().__init__( shape=shape, space=box, device=device, dtype=dtype, domain=domain, **kwargs ) - def to( - self, dest: Union[torch.dtype, DEVICE_TYPING] - ) -> UnboundedContinuousTensorSpec: + def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> Unbounded: if isinstance(dest, torch.dtype): dest_dtype = dest dest_device = self.device @@ -2253,14 +2591,14 @@ def to( dest_device = torch.device(dest) if dest_device == self.device and dest_dtype == self.dtype: return self - return self.__class__(shape=self.shape, device=dest_device, dtype=dest_dtype) + return Unbounded(shape=self.shape, device=dest_device, dtype=dest_dtype) - def clone(self) -> UnboundedContinuousTensorSpec: + def clone(self) -> Unbounded: return self.__class__(shape=self.shape, device=self.device, dtype=self.dtype) - def rand(self, shape=None) -> torch.Tensor: + def rand(self, shape: torch.Size = None) -> torch.Tensor: if shape is None: - shape = torch.Size([]) + shape = _size([]) shape = [*shape, *self.shape] if 
self.dtype.is_floating_point: return torch.randn(shape, device=self.device, dtype=self.dtype) @@ -2301,7 +2639,7 @@ def _unflatten(self, dim, sizes): def __getitem__(self, idx: SHAPE_INDEX_TYPING): """Indexes the current TensorSpec based on the provided index.""" - indexed_shape = torch.Size(_shape_indexing(self.shape, idx)) + indexed_shape = _size(_shape_indexing(self.shape, idx)) return self.__class__(shape=indexed_shape, device=self.device, dtype=self.dtype) def unbind(self, dim: int = 0): @@ -2324,21 +2662,12 @@ def unbind(self, dim: int = 0): def __eq__(self, other): # those specs are equivalent to a discrete spec - if isinstance(other, UnboundedDiscreteTensorSpec): - return ( - UnboundedDiscreteTensorSpec( - shape=self.shape, - device=self.device, - dtype=self.dtype, - ) - == other - ) - if isinstance(other, BoundedTensorSpec): + if isinstance(other, Bounded): minval, maxval = _minmax_dtype(self.dtype) minval = torch.as_tensor(minval).to(self.device, self.dtype) maxval = torch.as_tensor(maxval).to(self.device, self.dtype) return ( - BoundedTensorSpec( + Bounded( shape=self.shape, high=maxval, low=minval, @@ -2348,185 +2677,43 @@ def __eq__(self, other): ) == other ) + elif isinstance(other, Unbounded): + if self.dtype != other.dtype: + return False + if self.shape != other.shape: + return False + if self.device != other.device: + return False + return True return super().__eq__(other) -@dataclass(repr=False) -class UnboundedDiscreteTensorSpec(TensorSpec): - """An unbounded discrete tensor spec. +class UnboundedContinuous(Unbounded): + """A specialized version of :class:`torchrl.data.Unbounded` with continuous space.""" - Args: - device (str, int or torch.device, optional): device of the tensors. - dtype (str or torch.dtype, optional): dtype of the tensors - (should be an integer dtype such as long, uint8 etc.) - """ + ... 
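The `__eq__` rework above should make an `Unbounded` spec compare equal to a `Bounded` spec spanning the full dtype range. A sketch of that equivalence, under the same `torchrl.data` import assumption as before:

    import torch
    from torchrl.data import Bounded, Unbounded

    dtype = torch.float32
    u = Unbounded(shape=(2,), dtype=dtype)
    b = Bounded(
        low=torch.finfo(dtype).min,
        high=torch.finfo(dtype).max,
        shape=(2,),
        dtype=dtype,
    )
    print(u == b)  # expected: True, via the isinstance(other, Bounded) branch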
- # SPEC_HANDLED_FUNCTIONS = {} + +class UnboundedDiscrete(Unbounded): + """A specialized version of :class:`torchrl.data.Unbounded` with discrete space.""" def __init__( self, shape: Union[torch.Size, int] = _DEFAULT_SHAPE, device: Optional[DEVICE_TYPING] = None, dtype: Optional[Union[str, torch.dtype]] = torch.int64, + **kwargs, ): - if isinstance(shape, int): - shape = torch.Size([shape]) - - dtype, device = _default_dtype_and_device(dtype, device) - if dtype == torch.bool: - min_value = False - max_value = True - else: - if dtype.is_floating_point: - min_value = torch.finfo(dtype).min - max_value = torch.finfo(dtype).max - else: - min_value = torch.iinfo(dtype).min - max_value = torch.iinfo(dtype).max - space = ContinuousBox( - torch.full(_remove_neg_shapes(shape), min_value, device=device), - torch.full(_remove_neg_shapes(shape), max_value, device=device), - ) - - super().__init__( - shape=shape, - space=space, - device=device, - dtype=dtype, - domain="discrete", - ) - - def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: - if isinstance(dest, torch.dtype): - dest_dtype = dest - dest_device = self.device - elif dest is None: - return self - else: - dest_dtype = self.dtype - dest_device = torch.device(dest) - if dest_device == self.device and dest_dtype == self.dtype: - return self - return self.__class__(shape=self.shape, device=dest_device, dtype=dest_dtype) - - def clone(self) -> UnboundedDiscreteTensorSpec: - return self.__class__(shape=self.shape, device=self.device, dtype=self.dtype) - - def rand(self, shape=None) -> torch.Tensor: - if shape is None: - shape = torch.Size([]) - interval = self.space.high - self.space.low - r = torch.rand(torch.Size([*shape, *interval.shape]), device=interval.device) - r = r * interval - r = self.space.low + r - r = r.to(self.dtype) - return r.to(self.device) - - def is_in(self, val: torch.Tensor) -> bool: - shape = torch.broadcast_shapes(self._safe_shape, val.shape) - return val.shape == shape and val.dtype == self.dtype - - def expand(self, *shape): - if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)): - shape = shape[0] - if any(s1 != s2 and s2 != 1 for s1, s2 in zip(shape[-self.ndim :], self.shape)): - raise ValueError( - f"The last {self.ndim} of the expanded shape {shape} must match the" - f"shape of the {self.__class__.__name__} spec in expand()." - ) - return self.__class__(shape=shape, device=self.device, dtype=self.dtype) - - def _reshape(self, shape): - return self.__class__(shape=shape, device=self.device, dtype=self.dtype) - - def _unflatten(self, dim, sizes): - shape = torch.zeros(self.shape, device="meta").unflatten(dim, sizes).shape - return self.__class__( - shape=shape, - device=self.device, - dtype=self.dtype, - ) - - def __getitem__(self, idx: SHAPE_INDEX_TYPING): - """Indexes the current TensorSpec based on the provided index.""" - indexed_shape = torch.Size(_shape_indexing(self.shape, idx)) - return self.__class__(shape=indexed_shape, device=self.device, dtype=self.dtype) - - def unbind(self, dim: int = 0): - orig_dim = dim - if dim < 0: - dim = len(self.shape) + dim - if dim < 0: - raise ValueError( - f"Cannot unbind along dim {orig_dim} with shape {self.shape}." 
- ) - shape = tuple(s for i, s in enumerate(self.shape) if i != dim) - return tuple( - self.__class__( - shape=shape, - device=self.device, - dtype=self.dtype, - ) - for i in range(self.shape[dim]) - ) - - def __eq__(self, other): - # those specs are equivalent to a discrete spec - if isinstance(other, UnboundedContinuousTensorSpec): - return ( - UnboundedContinuousTensorSpec( - shape=self.shape, - device=self.device, - dtype=self.dtype, - domain=self.domain, - ) - == other - ) - if isinstance(other, BoundedTensorSpec): - return ( - BoundedTensorSpec( - shape=self.shape, - high=self.space.high, - low=self.space.low, - dtype=self.dtype, - device=self.device, - domain=self.domain, - ) - == other - ) - return super().__eq__(other) - - def __ne__(self, other): - # those specs are equivalent to a discrete spec - if isinstance(other, UnboundedContinuousTensorSpec): - return ( - UnboundedContinuousTensorSpec( - shape=self.shape, - device=self.device, - dtype=self.dtype, - domain=self.domain, - ) - != other - ) - if isinstance(other, BoundedTensorSpec): - return ( - BoundedTensorSpec( - shape=self.shape, - high=self.space.high, - low=self.space.low, - dtype=self.dtype, - device=self.device, - domain=self.domain, - ) - != other - ) - return super().__ne__(other) + super().__init__(shape=shape, device=device, dtype=dtype, **kwargs) @dataclass(repr=False) -class MultiOneHotDiscreteTensorSpec(OneHotDiscreteTensorSpec): +class MultiOneHot(OneHot): """A concatenation of one-hot discrete tensor spec. + This class can be used when a single tensor must carry information about multiple one-hot encoded + values. + The last dimension of the shape (domain of the tensor elements) cannot be indexed. Args: @@ -2541,20 +2728,22 @@ class MultiOneHotDiscreteTensorSpec(OneHotDiscreteTensorSpec): sample is taken. See :meth:`~.update_mask` for more information. Examples: - >>> ts = MultiOneHotDiscreteTensorSpec((3,2,3)) - >>> ts.is_in(torch.tensor([0,0,1, - ... 0,1, - ... 1,0,0])) + >>> ts = MultiOneHot((3,2,3)) + >>> ts.rand() + tensor([ True, False, False, True, False, False, False, True]) + >>> ts.is_in(torch.tensor([ + ... 0, 0, 1, + ... 0, 1, + ... 1, 0, 0], dtype=torch.bool)) True - >>> ts.is_in(torch.tensor([1,0,1, - ... 0,1, - ... 1,0,0])) # False + >>> ts.is_in(torch.tensor([ + ... 1, 0, 1, + ... 0, 1, + ... 1, 0, 0], dtype=torch.bool)) False """ - # SPEC_HANDLED_FUNCTIONS = {} - def __init__( self, nvec: Sequence[int], @@ -2567,17 +2756,17 @@ def __init__( self.nvec = nvec dtype, device = _default_dtype_and_device(dtype, device) if shape is None: - shape = torch.Size((sum(nvec),)) + shape = _size((sum(nvec),)) else: - shape = torch.Size(shape) + shape = _size(shape) if shape[-1] != sum(nvec): raise ValueError( f"The last value of the shape must match sum(nvec) for transform of type {self.__class__}. " f"Got sum(nvec)={sum(nvec)} and shape={shape}." ) - space = BoxList([DiscreteBox(n) for n in nvec]) + space = BoxList([CategoricalBox(n) for n in nvec]) self.use_register = use_register - super(OneHotDiscreteTensorSpec, self).__init__( + super(OneHot, self).__init__( shape, space, device, @@ -2601,7 +2790,7 @@ def update_mask(self, mask): Examples: >>> mask = torch.tensor([True, False, False, ... 
True, True]) - >>> ts = MultiOneHotDiscreteTensorSpec((3, 2), (2, 5), dtype=torch.int64, mask=mask) + >>> ts = MultiOneHot((3, 2), (2, 5), dtype=torch.int64, mask=mask) >>> # All but one of the three possible outcomes for the first >>> # one-hot group are masked, but neither of the two possible >>> # outcomes for the second one-hot group are masked. @@ -2618,7 +2807,7 @@ def update_mask(self, mask): raise ValueError("Only boolean masks are accepted.") self.mask = mask - def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: + def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> MultiOneHot: if isinstance(dest, torch.dtype): dest_dtype = dest dest_device = self.device @@ -2629,7 +2818,7 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: dest_device = torch.device(dest) if dest_device == self.device and dest_dtype == self.dtype: return self - return self.__class__( + return MultiOneHot( nvec=deepcopy(self.nvec), shape=self.shape, device=dest_device, @@ -2637,7 +2826,7 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: mask=self.mask.to(dest) if self.mask is not None else None, ) - def clone(self) -> MultiOneHotDiscreteTensorSpec: + def clone(self) -> MultiOneHot: return self.__class__( nvec=deepcopy(self.nvec), shape=self.shape, @@ -2670,7 +2859,7 @@ def rand(self, shape: Optional[torch.Size] = None) -> torch.Tensor: if shape is None: shape = self.shape[:-1] else: - shape = torch.Size([*shape, *self.shape[:-1]]) + shape = _size([*shape, *self.shape[:-1]]) mask = self.mask if mask is None: @@ -2722,9 +2911,7 @@ def encode( f"value {v} is greater than the allowed max {space.n}" ) x.append( - super(MultiOneHotDiscreteTensorSpec, self).encode( - v, space, ignore_device=ignore_device - ) + super(MultiOneHot, self).encode(v, space, ignore_device=ignore_device) ) return torch.cat(x, -1).reshape(self.shape) @@ -2776,7 +2963,7 @@ def _split_self(self): n = space.n shape = self.shape[:-1] + (n,) result.append( - OneHotDiscreteTensorSpec( + OneHot( n=n, shape=shape, device=device, @@ -2798,6 +2985,16 @@ def to_categorical(self, val: torch.Tensor, safe: bool = None) -> torch.Tensor: Returns: The categorical tensor. + + Examples: + >>> mone_hot = MultiOneHot((2, 3, 4)) + >>> onehot_sample = mone_hot.rand() + >>> onehot_sample + tensor([False, True, False, False, True, False, True, False, False]) + >>> categ_sample = mone_hot.to_categorical(onehot_sample) + >>> categ_sample + tensor([1, 2, 1]) + """ if safe is None: safe = _CHECK_SPEC_ENCODE @@ -2806,15 +3003,36 @@ def to_categorical(self, val: torch.Tensor, safe: bool = None) -> torch.Tensor: vals = self._split(val) return torch.stack([val.long().argmax(-1) for val in vals], -1) - def to_categorical_spec(self) -> MultiDiscreteTensorSpec: - """Converts the spec to the equivalent categorical spec.""" - return MultiDiscreteTensorSpec( + def to_categorical_spec(self) -> MultiCategorical: + """Converts the spec to the equivalent categorical spec. 
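A round-trip sketch of the conversion documented above, assuming `torchrl.data.MultiOneHot` as in this diff:

    import torch
    from torchrl.data import MultiOneHot

    spec = MultiOneHot((2, 3, 4))        # one flat vector of length 2 + 3 + 4 = 9
    onehot = spec.rand()                 # bool vector with one hit per group
    categ = spec.to_categorical(onehot)  # one index per group
    print(onehot.shape, categ.shape)     # expected: torch.Size([9]) torch.Size([3])
    print(spec.to_categorical_spec().is_in(categ))  # expected: True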
+ + Examples: + >>> mone_hot = MultiOneHot((2, 3, 4)) + >>> categ = mone_hot.to_categorical_spec() + >>> categ + MultiCategorical( + shape=torch.Size([3]), + space=BoxList(boxes=[CategoricalBox(n=2), CategoricalBox(n=3), CategoricalBox(n=4)]), + device=cpu, + dtype=torch.int64, + domain=discrete) + + """ + return MultiCategorical( [_space.n for _space in self.space], device=self.device, shape=[*self.shape[:-1], len(self.space)], mask=self.mask, ) + def to_one_hot(self, val: torch.Tensor, safe: bool = None) -> torch.Tensor: + """No-op for MultiOneHot.""" + return val + + def to_one_hot_spec(self) -> OneHot: + """No-op for MultiOneHot.""" + return self + def expand(self, *shape): nvecs = [space.n for space in self.space] if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)): shape = shape[0] @@ -2917,28 +3135,21 @@ def __getitem__(self, idx: SHAPE_INDEX_TYPING): indexed_shape = _shape_indexing(self.shape[:-1], idx) return self.__class__( nvec=self.nvec, - shape=torch.Size(indexed_shape + [self.shape[-1]]), + shape=_size(indexed_shape + [self.shape[-1]]), device=self.device, dtype=self.dtype, ) -class DiscreteTensorSpec(TensorSpec): +class Categorical(TensorSpec): """A discrete tensor spec. - An alternative to OneHotTensorSpec for categorical variables in TorchRL. Instead of - using multiplication, categorical variables perform indexing which can speed up + An alternative to :class:`OneHot` for categorical variables in TorchRL. + Categorical variables perform indexing instead of masking, which can speed up computation and reduce memory cost for large categorical variables. - The last dimension of the spec (length n of the binary vector) cannot be indexed - Example: - >>> batch, size = 3, 4 - >>> action_value = torch.arange(batch*size) - >>> action_value = action_value.view(batch, size).to(torch.float) - >>> action = torch.argmax(action_value, dim=-1).to(torch.long) - >>> chosen_action_value = action_value[range(batch), action] - >>> print(chosen_action_value) - tensor([ 3., 7., 11.]) + The spec will have the shape defined by the ``shape`` argument: if a singleton dimension is + desired for the training dimension, one should specify it explicitly. Args: n (int): number of possible outcomes. @@ -2948,10 +3159,32 @@ class DiscreteTensorSpec(TensorSpec): mask (torch.Tensor or None): mask some of the possible outcomes when a sample is taken. See :meth:`~.update_mask` for more information.
+ Examples: + >>> categ = Categorical(3) + >>> categ + Categorical( + shape=torch.Size([]), + space=CategoricalBox(n=3), + device=cpu, + dtype=torch.int64, + domain=discrete) + >>> categ.rand() + tensor(2) + >>> categ = Categorical(3, shape=(1,)) + >>> categ + Categorical( + shape=torch.Size([1]), + space=CategoricalBox(n=3), + device=cpu, + dtype=torch.int64, + domain=discrete) + >>> categ.rand() + tensor([1]) + """ shape: torch.Size - space: DiscreteBox + space: CategoricalBox device: torch.device | None = None dtype: torch.dtype = torch.float domain: str = "" @@ -2967,9 +3200,9 @@ def __init__( mask: torch.Tensor | None = None, ): if shape is None: - shape = torch.Size([]) + shape = _size([]) dtype, device = _default_dtype_and_device(dtype, device) - space = DiscreteBox(n) + space = CategoricalBox(n) super().__init__( shape=shape, space=space, device=device, dtype=dtype, domain="discrete" ) @@ -2994,7 +3227,7 @@ def update_mask(self, mask): Examples: >>> mask = torch.tensor([True, False, True]) - >>> ts = DiscreteTensorSpec(3, (10,), dtype=torch.int64, mask=mask) + >>> ts = Categorical(3, (10,), dtype=torch.int64, mask=mask) >>> # One of the three possible outcomes is masked >>> ts.rand() tensor([0, 2, 2, 0, 2, 0, 2, 2, 0, 2]) @@ -3008,14 +3241,14 @@ def update_mask(self, mask): raise ValueError("Only boolean masks are accepted.") self.mask = mask - def rand(self, shape=None) -> torch.Tensor: + def rand(self, shape: torch.Size = None) -> torch.Tensor: if shape is None: - shape = torch.Size([]) + shape = _size([]) if self.mask is None: return torch.randint( 0, self.space.n, - torch.Size([*shape, *self.shape]), + _size([*shape, *self.shape]), device=self.device, dtype=self.dtype, ) @@ -3035,7 +3268,7 @@ def _project(self, val: torch.Tensor) -> torch.Tensor: if self.mask is None: return val.clamp_(min=0, max=self.space.n - 1) shape = self.mask.shape - shape = torch.Size([*torch.broadcast_shapes(shape[:-1], val.shape), shape[-1]]) + shape = _size([*torch.broadcast_shapes(shape[:-1], val.shape), shape[-1]]) mask_expand = self.mask.expand(shape) gathered = mask_expand.gather(-1, val.unsqueeze(-1)) oob = ~gathered.all(-1) @@ -3054,14 +3287,14 @@ def is_in(self, val: torch.Tensor) -> bool: return False return (0 <= val).all() and (val < self.space.n).all() shape = self.mask.shape - shape = torch.Size([*torch.broadcast_shapes(shape[:-1], val.shape), shape[-1]]) + shape = _size([*torch.broadcast_shapes(shape[:-1], val.shape), shape[-1]]) mask_expand = self.mask.expand(shape) gathered = mask_expand.gather(-1, val.unsqueeze(-1)) return gathered.all() def __getitem__(self, idx: SHAPE_INDEX_TYPING): """Indexes the current TensorSpec based on the provided index.""" - indexed_shape = torch.Size(_shape_indexing(self.shape, idx)) + indexed_shape = _size(_shape_indexing(self.shape, idx)) return self.__class__( n=self.space.n, shape=indexed_shape, @@ -3106,6 +3339,15 @@ def to_one_hot(self, val: torch.Tensor, safe: bool = None) -> torch.Tensor: Returns: The one-hot encoded tensor. 
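A sketch of the masking behaviour documented above: masked outcomes are excluded both from sampling and from membership checks.

    import torch
    from torchrl.data import Categorical

    ts = Categorical(3, shape=(10,))
    ts.update_mask(torch.tensor([True, False, True]))   # outcome 1 is forbidden
    sample = ts.rand()
    print((sample == 1).any())                          # expected: tensor(False)
    print(ts.is_in(torch.ones(10, dtype=torch.int64)))  # expected: False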
+ + Examples: + >>> categ = Categorical(3) + >>> categ_sample = categ.zero() + >>> categ_sample + tensor(0) + >>> onehot_sample = categ.to_one_hot(categ_sample) + >>> onehot_sample + tensor([ True, False, False]) """ if safe is None: safe = _CHECK_SPEC_ENCODE @@ -3113,15 +3355,35 @@ def to_one_hot(self, val: torch.Tensor, safe: bool = None) -> torch.Tensor: self.assert_is_in(val) return torch.nn.functional.one_hot(val, self.space.n).bool() - def to_one_hot_spec(self) -> OneHotDiscreteTensorSpec: - """Converts the spec to the equivalent one-hot spec.""" + def to_categorical(self, val: torch.Tensor, safe: bool = None) -> torch.Tensor: + """No-op for categorical.""" + return val + + def to_one_hot_spec(self) -> OneHot: + """Converts the spec to the equivalent one-hot spec. + + Examples: + >>> categ = Categorical(3) + >>> categ.to_one_hot_spec() + OneHot( + shape=torch.Size([3]), + space=CategoricalBox(n=3), + device=cpu, + dtype=torch.bool, + domain=discrete) + + """ shape = [*self.shape, self.space.n] - return OneHotDiscreteTensorSpec( + return OneHot( n=self.space.n, shape=shape, device=self.device, ) + def to_categorical_spec(self) -> Categorical: + """No-op for categorical.""" + return self + def expand(self, *shape): if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)): shape = shape[0] @@ -3199,7 +3461,7 @@ def unbind(self, dim: int = 0): for i in range(self.shape[dim]) ) - def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: + def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> Categorical: if isinstance(dest, torch.dtype): dest_dtype = dest dest_device = self.device @@ -3214,7 +3476,7 @@ def to( dest_device = torch.device(dest) if dest_device == self.device and dest_dtype == self.dtype: return self return self.__class__( n=self.space.n, shape=self.shape, device=dest_device, dtype=dest_dtype ) - def clone(self) -> DiscreteTensorSpec: + def clone(self) -> Categorical: return self.__class__( n=self.space.n, shape=self.shape, @@ -3225,32 +3487,59 @@ def clone(self) -> DiscreteTensorSpec: @dataclass(repr=False) -class BinaryDiscreteTensorSpec(DiscreteTensorSpec): +class Binary(Categorical): """A binary discrete tensor spec. + A binary tensor spec encodes tensors of arbitrary size where the values are either 0 or 1 (or ``True`` or ``False`` + if the dtype is ``torch.bool``). + + Unlike :class:`OneHot`, `Binary` can have more than one non-null element along the last dimension. + Args: - n (int): length of the binary vector. + n (int): length of the binary vector. If provided along with ``shape``, ``shape[-1]`` must match ``n``. + If not provided, ``shape`` must be passed. + + .. warning:: the ``n`` argument from ``Binary`` must not be confused with the ``n`` argument from :class:`Categorical` + or :class:`OneHot` which denotes the maximum number of elements that can be sampled. + For clarity, use ``shape`` instead. + shape (torch.Size, optional): total shape of the sampled tensors. - If provided, the last dimension must match n. + If provided, the last dimension must match ``n``. device (str, int or torch.device, optional): device of the tensors. - dtype (str or torch.dtype, optional): dtype of the tensors. Defaults to torch.long. + dtype (str or torch.dtype, optional): dtype of the tensors. + Defaults to ``torch.int8``.
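A sketch of the index/one-hot conversions added above, under the same import assumption:

    import torch
    from torchrl.data import Categorical

    categ = Categorical(4)
    onehot_spec = categ.to_one_hot_spec()        # OneHot over the same 4 outcomes
    onehot = categ.to_one_hot(torch.tensor(2))
    print(onehot)                                # expected: tensor([False, False, True, False])
    print(onehot_spec.is_in(onehot))             # expected: True
    print(categ.to_categorical_spec() is categ)  # expected: True (no-op)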
Examples: - >>> spec = BinaryDiscreteTensorSpec(n=4, shape=(5, 4), device="cpu", dtype=torch.bool) - >>> print(spec.zero()) + >>> torch.manual_seed(0) + >>> spec = Binary(n=4, shape=(2, 4)) + >>> print(spec.rand()) + tensor([[0, 1, 1, 0], + [1, 1, 1, 1]], dtype=torch.int8) + >>> spec = Binary(shape=(2, 4)) + >>> print(spec.rand()) + tensor([[1, 1, 1, 0], + [0, 1, 0, 0]], dtype=torch.int8) + >>> spec = Binary(n=4) + >>> print(spec.rand()) + tensor([0, 0, 0, 1], dtype=torch.int8) + """ def __init__( self, - n: int, + n: int | None = None, shape: Optional[torch.Size] = None, device: Optional[DEVICE_TYPING] = None, dtype: Union[str, torch.dtype] = torch.int8, ): + if n is None and not shape: + raise TypeError("Must provide either n or shape.") + if n is None: + n = shape[-1] if shape is None or not len(shape): - shape = torch.Size((n,)) + shape = _size((n,)) else: - shape = torch.Size(shape) + shape = _size(shape) if shape[-1] != n: raise ValueError( f"The last value of the shape must match n for spec {self.__class__}. " @@ -3318,7 +3607,7 @@ def unbind(self, dim: int = 0): for i in range(self.shape[dim]) ) - def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: + def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> Binary: if isinstance(dest, torch.dtype): dest_dtype = dest dest_device = self.device @@ -3333,7 +3622,7 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: n=self.shape[-1], shape=self.shape, device=dest_device, dtype=dest_dtype ) - def clone(self) -> BinaryDiscreteTensorSpec: + def clone(self) -> Binary: return self.__class__( n=self.shape[-1], shape=self.shape, @@ -3349,14 +3638,14 @@ def __getitem__(self, idx: SHAPE_INDEX_TYPING): indexed_shape = _shape_indexing(self.shape[:-1], idx) return self.__class__( n=self.shape[-1], - shape=torch.Size(indexed_shape + [self.shape[-1]]), + shape=_size(indexed_shape + [self.shape[-1]]), device=self.device, dtype=self.dtype, ) def __eq__(self, other): - if not isinstance(other, BinaryDiscreteTensorSpec): - if isinstance(other, DiscreteTensorSpec): + if not isinstance(other, Binary): + if isinstance(other, Categorical): return ( other.n == 2 and other.device == self.device @@ -3368,7 +3657,7 @@ def __eq__(self, other): @dataclass(repr=False) -class MultiDiscreteTensorSpec(DiscreteTensorSpec): +class MultiCategorical(Categorical): """A concatenation of discrete tensor spec. Args: @@ -3385,15 +3674,13 @@ class MultiDiscreteTensorSpec(DiscreteTensorSpec): sample is taken. See :meth:`~.update_mask` for more information. Examples: - >>> ts = MultiDiscreteTensorSpec((3, 2, 3)) + >>> ts = MultiCategorical((3, 2, 3)) >>> ts.is_in(torch.tensor([2, 0, 1])) True - >>> ts.is_in(torch.tensor([2, 2, 1])) + >>> ts.is_in(torch.tensor([2, 10, 1])) False """ - # SPEC_HANDLED_FUNCTIONS = {} - def __init__( self, nvec: Union[Sequence[int], torch.Tensor, int], @@ -3412,7 +3699,7 @@ def __init__( if shape is None: shape = nvec.shape else: - shape = torch.Size(shape) + shape = _size(shape) if shape[-1] != nvec.shape[-1]: raise ValueError( f"The last value of the shape must match nvec.shape[-1] for transform of type {self.__class__}. " @@ -3422,7 +3709,7 @@ def __init__( self.nvec = self.nvec.expand(_remove_neg_shapes(shape)) space = BoxList.from_nvec(self.nvec) - super(DiscreteTensorSpec, self).__init__( + super(Categorical, self).__init__( shape, space, device, dtype, domain="discrete" ) self.update_mask(mask) @@ -3442,9 +3729,10 @@ def update_mask(self, mask): sample is taken. 
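A sketch of the per-group bounds checked by `MultiCategorical.is_in`, mirroring the docstring example above (the masking example follows):

    import torch
    from torchrl.data import MultiCategorical

    ts = MultiCategorical((3, 2, 3))
    print(ts.is_in(torch.tensor([2, 0, 1])))   # expected: True
    print(ts.is_in(torch.tensor([2, 10, 1])))  # expected: False (10 >= nvec[1])
    oh = ts.to_one_hot_spec()                  # MultiOneHot over 3 + 2 + 3 slots
    print(oh.shape)                            # expected: torch.Size([8])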
Examples: + >>> torch.manual_seed(0) >>> mask = torch.tensor([False, False, True, ... True, True]) - >>> ts = MultiDiscreteTensorSpec((3, 2), (5, 2,), dtype=torch.int64, mask=mask) + >>> ts = MultiCategorical((3, 2), (5, 2,), dtype=torch.int64, mask=mask) >>> # All but one of the three possible outcomes for the first >>> # group are masked, but neither of the two possible >>> # outcomes for the second group are masked. @@ -3453,7 +3741,7 @@ def update_mask(self, mask): [2, 0], [2, 1], [2, 1], - [2, 0]]) + [2, 1]]) """ if mask is not None: try: @@ -3464,7 +3752,7 @@ def update_mask(self, mask): raise ValueError("Only boolean masks are accepted.") self.mask = mask - def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: + def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> MultiCategorical: if isinstance(dest, torch.dtype): dest_dtype = dest dest_device = self.device @@ -3503,7 +3791,7 @@ def __eq__(self, other): and mask_equal ) - def clone(self) -> MultiDiscreteTensorSpec: + def clone(self) -> MultiCategorical: return self.__class__( nvec=self.nvec.clone(), shape=None, @@ -3541,7 +3829,7 @@ def rand(self, shape: Optional[torch.Size] = None) -> torch.Tensor: *self.shape[:-1], ) x = self._rand(space=self.space, shape=shape, i=self.nvec.ndim) - if self.remove_singleton and self.shape == torch.Size([1]): + if self.remove_singleton and self.shape == _size([1]): x = x.squeeze(-1) return x @@ -3565,9 +3853,7 @@ def _split_self(self): for n, _mask in zip(nvec, mask): shape = self.shape[:-1] result.append( - DiscreteTensorSpec( - n=n, shape=shape, device=device, dtype=dtype, mask=_mask - ) + Categorical(n=n, shape=shape, device=device, dtype=dtype, mask=_mask) ) return result @@ -3620,7 +3906,7 @@ def is_in(self, val: torch.Tensor) -> bool: def to_one_hot( self, val: torch.Tensor, safe: bool = None - ) -> Union[MultiOneHotDiscreteTensorSpec, torch.Tensor]: + ) -> Union[MultiOneHot, torch.Tensor]: """Encodes a discrete tensor from the spec domain into its one-hot correspondent. Args: @@ -3644,16 +3930,24 @@ def to_one_hot( -1, ).to(self.device) - def to_one_hot_spec(self) -> MultiOneHotDiscreteTensorSpec: + def to_one_hot_spec(self) -> MultiOneHot: """Converts the spec to the equivalent one-hot spec.""" nvec = [_space.n for _space in self.space] - return MultiOneHotDiscreteTensorSpec( + return MultiOneHot( nvec, device=self.device, shape=[*self.shape[:-1], sum(nvec)], mask=self.mask, ) + def to_categorical(self, val: torch.Tensor, safe: bool = None) -> torch.Tensor: + """No-op for MultiCategorical.""" + return val + + def to_categorical_spec(self) -> MultiCategorical: + """No-op for MultiCategorical.""" + return self + def expand(self, *shape): if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)): shape = shape[0] @@ -3770,12 +4064,16 @@ def __getitem__(self, idx: SHAPE_INDEX_TYPING): ) -class CompositeSpec(TensorSpec): +class Composite(TensorSpec): """A composition of TensorSpecs. + If a ``TensorSpec`` is the set-description of a Tensor category, the ``Composite`` class is akin to + the :class:`~tensordict.TensorDict` class. Like :class:`~tensordict.TensorDict`, it has a ``shape`` (akin to the + ``TensorDict``'s ``batch_size``) and an optional ``device``. + Args: *args: if an unnamed argument is passed, it must be a dictionary with keys - matching the expected keys to be found in the :obj:`CompositeSpec` object. + matching the expected keys to be found in the :obj:`Composite` object. This is useful to build nested CompositeSpecs with tuple indices.
**kwargs (key (str): value (TensorSpec)): dictionary of tensorspecs to be stored. Values can be None, in which case is_in will be assumed @@ -3792,53 +4090,59 @@ class CompositeSpec(TensorSpec): to the batch-size of the corresponding tensordicts. Examples: - >>> pixels_spec = BoundedTensorSpec( - ... torch.zeros(3,32,32), - ... torch.ones(3, 32, 32)) - >>> observation_vector_spec = BoundedTensorSpec(torch.zeros(33), - ... torch.ones(33)) - >>> composite_spec = CompositeSpec( + >>> pixels_spec = Bounded( + ... low=torch.zeros(4, 3, 32, 32), + ... high=torch.ones(4, 3, 32, 32), + ... dtype=torch.uint8 + ... ) + >>> observation_vector_spec = Bounded( + ... low=torch.zeros(4, 33), + ... high=torch.ones(4, 33), + ... dtype=torch.float) + >>> composite_spec = Composite( ... pixels=pixels_spec, - ... observation_vector=observation_vector_spec) - >>> td = TensorDict({"pixels": torch.rand(10,3,32,32), - ... "observation_vector": torch.rand(10,33)}, batch_size=[10]) - >>> print("td (rand) is within bounds: ", composite_spec.is_in(td)) - td (rand) is within bounds: True - >>> td = TensorDict({"pixels": torch.randn(10,3,32,32), - ... "observation_vector": torch.randn(10,33)}, batch_size=[10]) - >>> print("td (randn) is within bounds: ", composite_spec.is_in(td)) - td (randn) is within bounds: False - >>> td_project = composite_spec.project(td) - >>> print("td modification done in place: ", td_project is td) - td modification done in place: True - >>> print("check td is within bounds after projection: ", - ... composite_spec.is_in(td_project)) - check td is within bounds after projection: True - >>> print("random td: ", composite_spec.rand([3,])) - random td: TensorDict( + ... observation_vector=observation_vector_spec, + ... shape=(4,) + ... ) + >>> composite_spec + Composite( + pixels: BoundedDiscrete( + shape=torch.Size([4, 3, 32, 32]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([4, 3, 32, 32]), device=cpu, dtype=torch.uint8, contiguous=True), + high=Tensor(shape=torch.Size([4, 3, 32, 32]), device=cpu, dtype=torch.uint8, contiguous=True)), + device=cpu, + dtype=torch.uint8, + domain=discrete), + observation_vector: BoundedContinuous( + shape=torch.Size([4, 33]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([4, 33]), device=cpu, dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([4, 33]), device=cpu, dtype=torch.float32, contiguous=True)), + device=cpu, + dtype=torch.float32, + domain=continuous), + device=None, + shape=torch.Size([4])) + >>> td = composite_spec.rand() + >>> td + TensorDict( fields={ - observation_vector: Tensor(torch.Size([3, 33]), dtype=torch.float32), - pixels: Tensor(torch.Size([3, 3, 32, 32]), dtype=torch.float32)}, - batch_size=torch.Size([3]), + observation_vector: Tensor(shape=torch.Size([4, 33]), device=cpu, dtype=torch.float32, is_shared=False), + pixels: Tensor(shape=torch.Size([4, 3, 32, 32]), device=cpu, dtype=torch.uint8, is_shared=False)}, + batch_size=torch.Size([4]), device=None, is_shared=False) - - Examples: >>> # we can build a nested composite spec using unnamed arguments - >>> print(CompositeSpec({("a", "b"): None, ("a", "c"): None})) - CompositeSpec( - a: CompositeSpec( + >>> print(Composite({("a", "b"): None, ("a", "c"): None})) + Composite( + a: Composite( b: None, - c: None)) - - CompositeSpec supports nested indexing: - >>> spec = CompositeSpec(obs=None) - >>> spec["nested", "x"] = None - >>> print(spec) - CompositeSpec( - nested: CompositeSpec( - x: None), - x: None) + c: None, + device=None, + shape=torch.Size([])), + 
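A compact sketch of the `Composite` usage shown in the rewritten docstring above (the following hunk also aliases `batch_size` to `shape` for TensorDict compatibility):

    import torch
    from torchrl.data import Bounded, Composite, Unbounded

    spec = Composite(
        pixels=Bounded(low=0, high=255, shape=(4, 3, 32, 32), dtype=torch.uint8),
        reward=Unbounded(shape=(4, 1)),
        shape=(4,),  # leading dims of every leaf must match this batch shape
    )
    td = spec.rand()       # a TensorDict with batch_size=torch.Size([4])
    print(spec.is_in(td))  # expected: True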
device=None, + shape=torch.Size([])) """ @@ -3862,17 +4166,17 @@ def shape(self, value: torch.Size): if self.locked: raise RuntimeError("Cannot modify shape of locked composite spec.") for key, spec in self.items(): - if isinstance(spec, CompositeSpec): + if isinstance(spec, Composite): if spec.shape[: len(value)] != value: spec.shape = value elif spec is not None: if spec.shape[: len(value)] != value: raise ValueError( - f"The shape of the spec and the CompositeSpec mismatch during shape resetting: the " + f"The shape of the spec and the Composite mismatch during shape resetting: the " f"{self.ndim} first dimensions should match but got self['{key}'].shape={spec.shape} and " - f"CompositeSpec.shape={self.shape}." + f"Composite.shape={self.shape}." ) - self._shape = torch.Size(value) + self._shape = _size(value) def is_empty(self): """Whether the composite spec contains specs or not.""" @@ -3887,25 +4191,30 @@ def ndimension(self): def set(self, name, spec): if self.locked: - raise RuntimeError("Cannot modify a locked CompositeSpec.") + raise RuntimeError("Cannot modify a locked Composite.") if spec is not None: shape = spec.shape if shape[: self.ndim] != self.shape: raise ValueError( - "The shape of the spec and the CompositeSpec mismatch: the first " + "The shape of the spec and the Composite mismatch: the first " f"{self.ndim} dimensions should match but got spec.shape={spec.shape} and " - f"CompositeSpec.shape={self.shape}." + f"Composite.shape={self.shape}." ) self._specs[name] = spec - def __init__(self, *args, shape=None, device=None, **kwargs): + def __init__( + self, *args, shape: torch.Size = None, device: torch.device = None, **kwargs + ): + # For compatibility with TensorDict + batch_size = kwargs.pop("batch_size", None) + if batch_size is not None: + if shape is not None: + raise TypeError("Cannot specify both batch_size and shape.") + shape = batch_size + if shape is None: - # Should we do this? Other specs have a default empty shape, maybe it would make sense to keep it - # optional for composite (for clarity and easiness of use). - # warnings.warn("shape=None for CompositeSpec will soon be deprecated. Make sure you set the " - # "batch size of your CompositeSpec as you would do for a tensordict.") - shape = [] - self._shape = torch.Size(shape) + shape = _size(()) + self._shape = _size(shape) self._specs = {} for key, value in kwargs.items(): self.set(key, value) @@ -3918,7 +4227,7 @@ def __init__(self, *args, shape=None, device=None, **kwargs): if item is None: continue if ( - isinstance(item, CompositeSpec) + isinstance(item, Composite) and item.device is None and _device is not None ): @@ -3927,22 +4236,22 @@ def __init__(self, *args, shape=None, device=None, **kwargs): raise RuntimeError( f"Setting a new attribute ({key}) on another device " f"({item.device} against {_device}). All devices of " - "CompositeSpec must match." + "Composite must match." ) self._device = _device if len(args): if len(args) > 1: raise RuntimeError( - "Got multiple arguments, when at most one is expected for CompositeSpec." + "Got multiple arguments, when at most one is expected for Composite." ) argdict = args[0] - if not isinstance(argdict, (dict, CompositeSpec)): + if not isinstance(argdict, (dict, Composite)): raise RuntimeError( f"Expected a dictionary of specs, but got an argument of type {type(argdict)}." 
) for k, item in argdict.items(): if isinstance(item, dict): - item = CompositeSpec(item, shape=shape, device=_device) + item = Composite(item, shape=shape, device=_device) self[k] = item @property @@ -3959,14 +4268,14 @@ def device(self, device: DEVICE_TYPING): self.to(device) def clear_device_(self): - """Clears the device of the CompositeSpec.""" + """Clears the device of the Composite.""" self._device = None for spec in self._specs.values(): spec.clear_device_() return self def __getitem__(self, idx): - """Indexes the current CompositeSpec based on the provided index.""" + """Indexes the current Composite based on the provided index.""" if isinstance(idx, (str, tuple)): idx_unravel = unravel_key(idx) else: @@ -3975,7 +4284,7 @@ def __getitem__(self, idx): if isinstance(idx_unravel, tuple): return self[idx[0]][idx[1:]] if idx_unravel in {"shape", "device", "dtype", "space"}: - raise AttributeError(f"CompositeSpec has no key {idx_unravel}") + raise AttributeError(f"Composite has no key {idx_unravel}") return self._specs[idx_unravel] indexed_shape = _shape_indexing(self.shape, idx) @@ -3987,9 +4296,9 @@ def __getitem__(self, idx): if any( isinstance(v, spec_class) for spec_class in [ - BinaryDiscreteTensorSpec, - MultiDiscreteTensorSpec, - OneHotDiscreteTensorSpec, + Binary, + MultiCategorical, + OneHot, ] ): protected_dims = 1 @@ -4011,7 +4320,7 @@ def __getitem__(self, idx): ) def get(self, item, default=NO_DEFAULT): - """Gets an item from the CompositeSpec. + """Gets an item from the Composite. If the item is absent, a default value can be passed. @@ -4026,7 +4335,7 @@ def get(self, item, default=NO_DEFAULT): def __setitem__(self, key, value): if isinstance(key, tuple) and len(key) > 1: if key[0] not in self.keys(True): - self[key[0]] = CompositeSpec(shape=self.shape, device=self.device) + self[key[0]] = Composite(shape=self.shape, device=self.device) self[key[0]][key[1:]] = value return elif isinstance(key, tuple): @@ -4035,20 +4344,20 @@ def __setitem__(self, key, value): elif not isinstance(key, str): raise TypeError(f"Got key of type {type(key)} when a string was expected.") if key in {"shape", "device", "dtype", "space"}: - raise AttributeError(f"CompositeSpec[{key}] cannot be set") + raise AttributeError(f"Composite[{key}] cannot be set") if isinstance(value, dict): - value = CompositeSpec(value, device=self._device, shape=self.shape) + value = Composite(value, device=self._device, shape=self.shape) if ( value is not None and self.device is not None and value.device != self.device ): - if isinstance(value, CompositeSpec) and value.device is None: + if isinstance(value, Composite) and value.device is None: value = value.clone().to(self.device) else: raise RuntimeError( f"Setting a new attribute ({key}) on another device ({value.device} against {self.device}). " - f"All devices of CompositeSpec must match." + f"All devices of Composite must match." ) self.set(key, value) @@ -4077,17 +4386,17 @@ def encode( if isinstance(vals, TensorDict): out = vals.empty() # create and empty tensordict similar to vals else: - out = TensorDict._new_unsafe({}, torch.Size([])) + out = TensorDict._new_unsafe({}, _size([])) for key, item in vals.items(): if item is None: raise RuntimeError( - "CompositeSpec.encode cannot be used with missing values." + "Composite.encode cannot be used with missing values." ) try: out[key] = self[key].encode(item, ignore_device=ignore_device) except KeyError: raise KeyError( - f"The CompositeSpec instance with keys {self.keys()} does not have a '{key}' key." 
+ f"The Composite instance with keys {self.keys()} does not have a '{key}' key." ) except RuntimeError as err: raise RuntimeError( @@ -4100,7 +4409,7 @@ def __repr__(self) -> str: indent(f"{k}: {str(item)}", 4 * " ") for k, item in self._specs.items() ] sub_str = ",\n".join(sub_str) - return f"CompositeSpec(\n{sub_str},\n device={self._device},\n shape={self.shape})" + return f"Composite(\n{sub_str},\n device={self._device},\n shape={self.shape})" def type_check( self, @@ -4119,9 +4428,9 @@ def type_check( def is_in(self, val: Union[dict, TensorDictBase]) -> bool: for key, item in self._specs.items(): - if item is None or (isinstance(item, CompositeSpec) and item.is_empty()): + if item is None or (isinstance(item, Composite) and item.is_empty()): continue - val_item = val.get(key) + val_item = val.get(key, NO_DEFAULT) if not item.is_in(val_item): return False return True @@ -4135,9 +4444,9 @@ def project(self, val: TensorDictBase) -> TensorDictBase: val.set(key, self._specs[key].project(_val)) return val - def rand(self, shape=None) -> TensorDictBase: + def rand(self, shape: torch.Size = None) -> TensorDictBase: if shape is None: - shape = torch.Size([]) + shape = _size([]) _dict = {} for key, item in self.items(): if item is not None: @@ -4146,7 +4455,7 @@ def rand(self, shape=None) -> TensorDictBase: # TensorDict requirements return TensorDict._new_unsafe( _dict, - batch_size=torch.Size([*shape, *self.shape]), + batch_size=_size([*shape, *self.shape]), device=self._device, ) @@ -4157,24 +4466,24 @@ def keys( *, is_leaf: Callable[[type], bool] | None = None, ) -> _CompositeSpecKeysView: # noqa: D417 - """Keys of the CompositeSpec. + """Keys of the Composite. The keys argument reflect those of :class:`tensordict.TensorDict`. Args: include_nested (bool, optional): if ``False``, the returned keys will not be nested. They will represent only the immediate children of the root, and not the whole nested sequence, i.e. a - :obj:`CompositeSpec(next=CompositeSpec(obs=None))` will lead to the keys + :obj:`Composite(next=Composite(obs=None))` will lead to the keys :obj:`["next"]. Default is ``False``, i.e. nested keys will not be returned. leaves_only (bool, optional): if ``False``, the values returned - will contain every level of nesting, i.e. a :obj:`CompositeSpec(next=CompositeSpec(obs=None))` + will contain every level of nesting, i.e. a :obj:`Composite(next=Composite(obs=None))` will lead to the keys :obj:`["next", ("next", "obs")]`. Default is ``False``. Keyword Args: is_leaf (callable, optional): reads a type and returns a boolean indicating if that type - should be seen as a leaf. By default, all non-CompositeSpec nodes are considered as + should be seen as a leaf. By default, all non-Composite nodes are considered as leaves. """ @@ -4192,22 +4501,22 @@ def items( *, is_leaf: Callable[[type], bool] | None = None, ) -> _CompositeSpecItemsView: # noqa: D417 - """Items of the CompositeSpec. + """Items of the Composite. Args: include_nested (bool, optional): if ``False``, the returned keys will not be nested. They will represent only the immediate children of the root, and not the whole nested sequence, i.e. a - :obj:`CompositeSpec(next=CompositeSpec(obs=None))` will lead to the keys + :obj:`Composite(next=Composite(obs=None))` will lead to the keys :obj:`["next"]. Default is ``False``, i.e. nested keys will not be returned. leaves_only (bool, optional): if ``False``, the values returned - will contain every level of nesting, i.e. 
a :obj:`CompositeSpec(next=CompositeSpec(obs=None))` + will contain every level of nesting, i.e. a :obj:`Composite(next=Composite(obs=None))` will lead to the keys :obj:`["next", ("next", "obs")]`. Default is ``False``. Keyword Args: is_leaf (callable, optional): reads a type and returns a boolean indicating if that type - should be seen as a leaf. By default, all non-CompositeSpec nodes are considered as + should be seen as a leaf. By default, all non-Composite nodes are considered as leaves. """ return _CompositeSpecItemsView( @@ -4224,22 +4533,22 @@ def values( *, is_leaf: Callable[[type], bool] | None = None, ) -> _CompositeSpecValuesView: # noqa: D417 - """Values of the CompositeSpec. + """Values of the Composite. Args: include_nested (bool, optional): if ``False``, the returned keys will not be nested. They will represent only the immediate children of the root, and not the whole nested sequence, i.e. a - :obj:`CompositeSpec(next=CompositeSpec(obs=None))` will lead to the keys + :obj:`Composite(next=Composite(obs=None))` will lead to the keys :obj:`["next"]. Default is ``False``, i.e. nested keys will not be returned. leaves_only (bool, optional): if ``False``, the values returned - will contain every level of nesting, i.e. a :obj:`CompositeSpec(next=CompositeSpec(obs=None))` + will contain every level of nesting, i.e. a :obj:`Composite(next=Composite(obs=None))` will lead to the keys :obj:`["next", ("next", "obs")]`. Default is ``False``. Keyword Args: is_leaf (callable, optional): reads a type and returns a boolean indicating if that type - should be seen as a leaf. By default, all non-CompositeSpec nodes are considered as + should be seen as a leaf. By default, all non-Composite nodes are considered as leaves. """ return _CompositeSpecItemsView( @@ -4254,7 +4563,7 @@ def _reshape(self, shape): key: val.reshape((*shape, *val.shape[self.ndimension() :])) for key, val in self._specs.items() } - return CompositeSpec(_specs, shape=shape) + return Composite(_specs, shape=shape) def _unflatten(self, dim, sizes): shape = torch.zeros(self.shape, device="meta").unflatten(dim, sizes).shape @@ -4263,12 +4572,12 @@ def _unflatten(self, dim, sizes): def __len__(self): return len(self.keys()) - def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: + def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> Composite: if dest is None: return self if not isinstance(dest, (str, int, torch.device)): raise ValueError( - "Only device casting is allowed with specs of type CompositeSpec." + "Only device casting is allowed with specs of type Composite." 
) if self._device and self._device == torch.device(dest): return self @@ -4283,7 +4592,7 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: kwargs[key] = value.to(dest) return self.__class__(**kwargs, device=_device, shape=self.shape) - def clone(self) -> CompositeSpec: + def clone(self) -> Composite: try: device = self.device except RuntimeError: @@ -4312,9 +4621,9 @@ def empty(self): def to_numpy(self, val: TensorDict, safe: bool = None) -> dict: return {key: self[key].to_numpy(val) for key, val in val.items()} - def zero(self, shape=None) -> TensorDictBase: + def zero(self, shape: torch.Size = None) -> TensorDictBase: if shape is None: - shape = torch.Size([]) + shape = _size([]) try: device = self.device except RuntimeError: @@ -4325,7 +4634,7 @@ def zero(self, shape=None) -> TensorDictBase: for key in self.keys(True) if isinstance(key, str) and self[key] is not None }, - torch.Size([*shape, *self._safe_shape]), + _size([*shape, *self._safe_shape]), device=device, ) @@ -4338,9 +4647,9 @@ def __eq__(self, other): and all((self._specs[key] == spec) for (key, spec) in other._specs.items()) ) - def update(self, dict_or_spec: Union[CompositeSpec, Dict[str, TensorSpec]]) -> None: + def update(self, dict_or_spec: Union[Composite, Dict[str, TensorSpec]]) -> None: for key, item in dict_or_spec.items(): - if key in self.keys(True) and isinstance(self[key], CompositeSpec): + if key in self.keys(True) and isinstance(self[key], Composite): self[key].update(item) continue try: @@ -4381,7 +4690,7 @@ def expand(self, *shape): else None for key, value in tuple(self.items()) } - out = CompositeSpec( + out = Composite( specs, shape=shape, device=device, @@ -4402,7 +4711,7 @@ def squeeze(self, dim: int | None = None): except RuntimeError: device = self._device - return CompositeSpec( + return Composite( {key: value.squeeze(dim) for key, value in self.items()}, shape=shape, device=device, @@ -4428,7 +4737,7 @@ def unsqueeze(self, dim: int): except RuntimeError: device = self._device - return CompositeSpec( + return Composite( { key: value.unsqueeze(dim) if value is not None else None for key, value in self.items() @@ -4457,19 +4766,19 @@ def unbind(self, dim: int = 0): ) def lock_(self, recurse=False): - """Locks the CompositeSpec and prevents modification of its content. + """Locks the Composite and prevents modification of its content. This is only a first-level lock, unless specified otherwise through the ``recurse`` arg. Leaf specs can always be modified in place, but they cannot be replaced - in their CompositeSpec parent. + in their Composite parent. Examples: >>> shape = [3, 4, 5] - >>> spec = CompositeSpec( - ... a=CompositeSpec( - ... b=CompositeSpec(shape=shape[:3], device="cpu"), shape=shape[:2] + >>> spec = Composite( + ... a=Composite( + ... b=Composite(shape=shape[:3], device="cpu"), shape=shape[:2] ... ), ... shape=shape[:1], ... ) @@ -4500,12 +4809,12 @@ def lock_(self, recurse=False): self._locked = True if recurse: for value in self.values(): - if isinstance(value, CompositeSpec): + if isinstance(value, Composite): value.lock_(recurse) return self def unlock_(self, recurse=False): - """Unlocks the CompositeSpec and allows modification of its content. + """Unlocks the Composite and allows modification of its content. This is only a first-level lock modification, unless specified otherwise through the ``recurse`` arg. 
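A sketch of the locking semantics described above:

    from torchrl.data import Composite, Unbounded

    spec = Composite(a=Composite(b=Unbounded(shape=())))
    spec.lock_(recurse=True)
    try:
        spec["c"] = Unbounded(shape=())
    except RuntimeError as err:
        print(err)  # expected: Cannot modify a locked Composite.
    spec.unlock_(recurse=True)
    spec["c"] = Unbounded(shape=())  # allowed again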
@@ -4514,7 +4823,7 @@ def unlock_(self, recurse=False): self._locked = False if recurse: for value in self.values(): - if isinstance(value, CompositeSpec): + if isinstance(value, Composite): value.unlock_(recurse) return self @@ -4523,7 +4832,7 @@ def locked(self): return self._locked -class LazyStackedCompositeSpec(_LazyStackedMixin[CompositeSpec], CompositeSpec): +class StackedComposite(_LazyStackedMixin[Composite], Composite): """A lazy representation of a stack of composite specs. Stacks composite specs together along one dimension. @@ -4539,7 +4848,7 @@ class LazyStackedCompositeSpec(_LazyStackedMixin[CompositeSpec], CompositeSpec): def update(self, dict) -> None: for key, item in dict.items(): if key in self.keys() and isinstance( - item, (Dict, CompositeSpec, LazyStackedCompositeSpec) + item, (Dict, Composite, StackedComposite) ): for spec, sub_item in zip(self._specs, item.unbind(self.dim)): spec[key].update(sub_item) @@ -4548,7 +4857,7 @@ def update(self, dict) -> None: return self def __eq__(self, other): - if not isinstance(other, LazyStackedCompositeSpec): + if not isinstance(other, StackedComposite): return False if len(self._specs) != len(other._specs): return False @@ -4567,7 +4876,7 @@ def to_numpy(self, val: TensorDict, safe: bool = None) -> dict: if safe: if val.shape[self.dim] != len(self._specs): raise ValueError( - "Size of LazyStackedCompositeSpec and val differ along the " + "Size of StackedComposite and val differ along the " "stacking dimension" ) for spec, v in zip(self._specs, torch.unbind(val, dim=self.dim)): @@ -4665,7 +4974,7 @@ def __repr__(self) -> str: string = ",\n".join( [sub_str, exclusive_key_str, device_str, shape_str, stack_dim] ) - return f"LazyStackedCompositeSpec(\n{string})" + return f"StackedComposite(\n{string})" def repr_exclusive_keys(self): keys = set(self.keys()) @@ -4771,7 +5080,7 @@ def shape(self): if dim < 0: dim = len(shape) + dim + 1 shape.insert(dim, len(self._specs)) - return torch.Size(shape) + return _size(shape) def expand(self, *shape): if len(shape) == 1 and not isinstance(shape[0], (int,)): @@ -4803,7 +5112,7 @@ def expand(self, *shape): ) def empty(self): - return LazyStackedCompositeSpec.maybe_dense_stack( + return StackedComposite.maybe_dense_stack( [spec.empty() for spec in self._specs], dim=self.stack_dim ) @@ -4812,7 +5121,7 @@ def encode( ) -> Dict[str, torch.Tensor]: raise NOT_IMPLEMENTED_ERROR - def zero(self, shape=None) -> TensorDictBase: + def zero(self, shape: torch.Size = None) -> TensorDictBase: if shape is not None: dim = self.dim + len(shape) else: @@ -4821,7 +5130,7 @@ def zero(self, shape=None) -> TensorDictBase: [spec.zero(shape) for spec in self._specs], dim ) - def one(self, shape=None) -> TensorDictBase: + def one(self, shape: torch.Size = None) -> TensorDictBase: if shape is not None: dim = self.dim + len(shape) else: @@ -4830,7 +5139,7 @@ def one(self, shape=None) -> TensorDictBase: [spec.one(shape) for spec in self._specs], dim ) - def rand(self, shape=None) -> TensorDictBase: + def rand(self, shape: torch.Size = None) -> TensorDictBase: if shape is not None: dim = self.dim + len(shape) else: @@ -4840,7 +5149,6 @@ def rand(self, shape=None) -> TensorDictBase: ) -# for SPEC_CLASS in [BinaryDiscreteTensorSpec, BoundedTensorSpec, DiscreteTensorSpec, MultiDiscreteTensorSpec, MultiOneHotDiscreteTensorSpec, OneHotDiscreteTensorSpec, UnboundedContinuousTensorSpec, UnboundedDiscreteTensorSpec]: @TensorSpec.implements_for_spec(torch.stack) def _stack_specs(list_of_spec, dim, out=None): if out is not None: @@ 
-4873,12 +5181,12 @@ def _stack_specs(list_of_spec, dim, out=None): dim += len(shape) + 1 shape.insert(dim, len(list_of_spec)) return spec0.clone().unsqueeze(dim).expand(shape) - return LazyStackedTensorSpec(*list_of_spec, dim=dim) + return Stacked(*list_of_spec, dim=dim) else: raise NotImplementedError -@CompositeSpec.implements_for_spec(torch.stack) +@Composite.implements_for_spec(torch.stack) def _stack_composite_specs(list_of_spec, dim, out=None): if out is not None: raise NotImplementedError( @@ -4888,7 +5196,7 @@ def _stack_composite_specs(list_of_spec, dim, out=None): if not len(list_of_spec): raise ValueError("Cannot stack an empty list of specs.") spec0 = list_of_spec[0] - if isinstance(spec0, CompositeSpec): + if isinstance(spec0, Composite): devices = {spec.device for spec in list_of_spec} if len(devices) == 1: device = list(devices)[0] @@ -4903,7 +5211,7 @@ def _stack_composite_specs(list_of_spec, dim, out=None): all_equal = True for spec in list_of_spec[1:]: - if not isinstance(spec, CompositeSpec): + if not isinstance(spec, Composite): raise RuntimeError( "Stacking specs cannot occur: Found more than one type of spec in " "the list." @@ -4920,7 +5228,7 @@ def _stack_composite_specs(list_of_spec, dim, out=None): dim += len(shape) + 1 shape.insert(dim, len(list_of_spec)) return spec0.clone().unsqueeze(dim).expand(shape) - return LazyStackedCompositeSpec(*list_of_spec, dim=dim) + return StackedComposite(*list_of_spec, dim=dim) else: raise NotImplementedError @@ -4930,8 +5238,8 @@ def _squeeze_spec(spec: TensorSpec, *args, **kwargs) -> TensorSpec: return spec.squeeze(*args, **kwargs) -@CompositeSpec.implements_for_spec(torch.squeeze) -def _squeeze_composite_spec(spec: CompositeSpec, *args, **kwargs) -> CompositeSpec: +@Composite.implements_for_spec(torch.squeeze) +def _squeeze_composite_spec(spec: Composite, *args, **kwargs) -> Composite: return spec.squeeze(*args, **kwargs) @@ -4940,16 +5248,16 @@ def _unsqueeze_spec(spec: TensorSpec, *args, **kwargs) -> TensorSpec: return spec.unsqueeze(*args, **kwargs) -@CompositeSpec.implements_for_spec(torch.unsqueeze) -def _unsqueeze_composite_spec(spec: CompositeSpec, *args, **kwargs) -> CompositeSpec: +@Composite.implements_for_spec(torch.unsqueeze) +def _unsqueeze_composite_spec(spec: Composite, *args, **kwargs) -> Composite: return spec.unsqueeze(*args, **kwargs) def _keys_to_empty_composite_spec(keys): - """Given a list of keys, creates a CompositeSpec tree where each leaf is assigned a None value.""" + """Given a list of keys, creates a Composite tree where each leaf is assigned a None value.""" if not len(keys): return - c = CompositeSpec() + c = Composite() for key in keys: if isinstance(key, str): c[key] = None @@ -4957,7 +5265,7 @@ def _keys_to_empty_composite_spec(keys): if c[key[0]] is None: # if the value is None we just replace it c[key[0]] = _keys_to_empty_composite_spec([key[1:]]) - elif isinstance(c[key[0]], CompositeSpec): + elif isinstance(c[key[0]], Composite): # if the value is Composite, we update it out = _keys_to_empty_composite_spec([key[1:]]) if out is not None: @@ -4973,7 +5281,7 @@ def _squeezed_shape(shape: torch.Size, dim: int | None) -> torch.Size | None: if dim is None: if len(shape) == 1 or shape.count(1) == 0: return None - new_shape = torch.Size([s for s in shape if s != 1]) + new_shape = _size([s for s in shape if s != 1]) else: if dim < 0: dim += len(shape) @@ -4981,7 +5289,7 @@ def _squeezed_shape(shape: torch.Size, dim: int | None) -> torch.Size | None: if shape[dim] != 1: return None - new_shape = 
torch.Size([s for i, s in enumerate(shape) if i != dim]) + new_shape = _size([s for i, s in enumerate(shape) if i != dim]) return new_shape @@ -4997,15 +5305,15 @@ def _unsqueezed_shape(shape: torch.Size, dim: int) -> torch.Size: new_shape = list(shape) new_shape.insert(dim, 1) - return torch.Size(new_shape) + return _size(new_shape) class _CompositeSpecItemsView: - """Wrapper class that enables richer behaviour of `items` for CompositeSpec.""" + """Wrapper class that enables richer behavior of `items` for Composite.""" def __init__( self, - composite: CompositeSpec, + composite: Composite, include_nested, leaves_only, *, @@ -5023,13 +5331,13 @@ def __iter__(self): if is_leaf in (None, _NESTED_TENSORS_AS_LISTS): def _is_leaf(cls): - return not issubclass(cls, CompositeSpec) + return not issubclass(cls, Composite) else: _is_leaf = is_leaf def _iter_from_item(key, item): - if self.include_nested and isinstance(item, CompositeSpec): + if self.include_nested and isinstance(item, Composite): for subkey, subitem in item.items( include_nested=True, leaves_only=self.leaves_only, @@ -5054,7 +5362,7 @@ def _iter_from_item(key, item): def _get_composite_items(self, is_leaf): - if isinstance(self.composite, LazyStackedCompositeSpec): + if isinstance(self.composite, StackedComposite): from tensordict.base import _NESTED_TENSORS_AS_LISTS if is_leaf is _NESTED_TENSORS_AS_LISTS: @@ -5141,5 +5449,149 @@ def _minmax_dtype(dtype): def _remove_neg_shapes(*shape): if len(shape) == 1 and not isinstance(shape[0], int): - return _remove_neg_shapes(*shape[0]) - return torch.Size([int(d) if d >= 0 else 1 for d in shape]) + shape = shape[0] + if isinstance(shape, np.integer): + shape = (int(shape),) + return _remove_neg_shapes(*shape) + return _size([int(d) if d >= 0 else 1 for d in shape]) + + +############## +# Legacy +# +class _LegacySpecMeta(abc.ABCMeta): + def __call__(cls, *args, **kwargs): + warnings.warn( + f"The {cls.__name__} has been deprecated and will be removed in v0.7. Please use " + f"{cls.__bases__[-1].__name__} instead.", + category=DeprecationWarning, + ) + instance = super().__call__(*args, **kwargs) + if ( + type(instance) in (UnboundedDiscreteTensorSpec, UnboundedDiscrete) + and instance.domain == "continuous" + ): + instance.__class__ = UnboundedContinuous + elif ( + type(instance) in (UnboundedContinuousTensorSpec, UnboundedContinuous) + and instance.domain == "discrete" + ): + instance.__class__ = UnboundedDiscrete + return instance + + def __instancecheck__(cls, instance): + check0 = super().__instancecheck__(instance) + if check0: + return True + parent_cls = cls.__bases__[-1] + return isinstance(instance, parent_cls) + + +class CompositeSpec(Composite, metaclass=_LegacySpecMeta): + """Deprecated version of :class:`torchrl.data.Composite`.""" + + ... + + +class OneHotDiscreteTensorSpec(OneHot, metaclass=_LegacySpecMeta): + """Deprecated version of :class:`torchrl.data.OneHot`.""" + + ... + + +class MultiOneHotDiscreteTensorSpec(MultiOneHot, metaclass=_LegacySpecMeta): + """Deprecated version of :class:`torchrl.data.MultiOneHot`.""" + + ... + + +class NonTensorSpec(NonTensor, metaclass=_LegacySpecMeta): + """Deprecated version of :class:`torchrl.data.NonTensor`.""" + + ... + + +class MultiDiscreteTensorSpec(MultiCategorical, metaclass=_LegacySpecMeta): + """Deprecated version of :class:`torchrl.data.MultiCategorical`.""" + + ... + + +class LazyStackedTensorSpec(Stacked, metaclass=_LegacySpecMeta): + """Deprecated version of :class:`torchrl.data.Stacked`.""" + + ... 
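The legacy block introduced above keeps the old spec names importable for a deprecation window: `_LegacySpecMeta.__call__` emits a `DeprecationWarning` at construction time, and `__instancecheck__` falls through to the renamed parent class, so `isinstance` checks written against either generation of names keep passing. A hedged sketch of the intended behavior, assuming both names are exported from `torchrl.data.tensor_specs` as defined in this diff:

    import warnings
    from torchrl.data.tensor_specs import Composite, CompositeSpec

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        spec = CompositeSpec()  # legacy name still constructs a spec
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)
    assert isinstance(spec, Composite)             # usable as the new class
    assert isinstance(Composite(), CompositeSpec)  # old checks still pass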
+ + +class LazyStackedCompositeSpec(StackedComposite, metaclass=_LegacySpecMeta): + """Deprecated version of :class:`torchrl.data.StackedComposite`.""" + + ... + + +class DiscreteTensorSpec(Categorical, metaclass=_LegacySpecMeta): + """Deprecated version of :class:`torchrl.data.Categorical`.""" + + ... + + +class BinaryDiscreteTensorSpec(Binary, metaclass=_LegacySpecMeta): + """Deprecated version of :class:`torchrl.data.Binary`.""" + + ... + + +_BoundedLegacyMeta = type("_BoundedLegacyMeta", (_LegacySpecMeta, _BoundedMeta), {}) + + +class BoundedTensorSpec(Bounded, metaclass=_BoundedLegacyMeta): + """Deprecated version of :class:`torchrl.data.Bounded`.""" + + ... + + +class _UnboundedContinuousMetaclass(_UnboundedMeta): + def __instancecheck__(cls, instance): + return isinstance(instance, Unbounded) and instance.domain == "continuous" + + +_LegacyUnboundedContinuousMetaclass = type( + "_LegacyUnboundedDiscreteMetaclass", + (_UnboundedContinuousMetaclass, _LegacySpecMeta), + {}, +) + + +class UnboundedContinuousTensorSpec( + Unbounded, metaclass=_LegacyUnboundedContinuousMetaclass +): + """Deprecated version of :class:`torchrl.data.Unbounded` with continuous space.""" + + ... + + +class _UnboundedDiscreteMetaclass(_UnboundedMeta): + def __instancecheck__(cls, instance): + return isinstance(instance, Unbounded) and instance.domain == "discrete" + + +_LegacyUnboundedDiscreteMetaclass = type( + "_LegacyUnboundedDiscreteMetaclass", + (_UnboundedDiscreteMetaclass, _LegacySpecMeta), + {}, +) + + +class UnboundedDiscreteTensorSpec( + Unbounded, metaclass=_LegacyUnboundedDiscreteMetaclass +): + """Deprecated version of :class:`torchrl.data.Unbounded` with discrete space.""" + + def __init__( + self, + shape: Union[torch.Size, int] = _DEFAULT_SHAPE, + device: Optional[DEVICE_TYPING] = None, + dtype: Optional[Union[str, torch.dtype]] = torch.int64, + **kwargs, + ): + super().__init__(shape=shape, device=device, dtype=dtype, **kwargs) diff --git a/torchrl/data/utils.py b/torchrl/data/utils.py index fb4ec30daed..db2c8afca10 100644 --- a/torchrl/data/utils.py +++ b/torchrl/data/utils.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. 
from __future__ import annotations +import functools import typing from typing import Any, Callable, List, Tuple, Union @@ -13,14 +14,14 @@ from torch import Tensor from torchrl.data.tensor_specs import ( - BinaryDiscreteTensorSpec, - CompositeSpec, - DiscreteTensorSpec, - LazyStackedCompositeSpec, - LazyStackedTensorSpec, - MultiDiscreteTensorSpec, - MultiOneHotDiscreteTensorSpec, - OneHotDiscreteTensorSpec, + Binary, + Categorical, + Composite, + MultiCategorical, + MultiOneHot, + OneHot, + Stacked, + StackedComposite, TensorSpec, ) @@ -50,10 +51,10 @@ ACTION_SPACE_MAP = { - OneHotDiscreteTensorSpec: "one_hot", - MultiOneHotDiscreteTensorSpec: "mult_one_hot", - BinaryDiscreteTensorSpec: "binary", - DiscreteTensorSpec: "categorical", + OneHot: "one_hot", + MultiOneHot: "mult_one_hot", + Binary: "binary", + Categorical: "categorical", "one_hot": "one_hot", "one-hot": "one_hot", "mult_one_hot": "mult_one_hot", @@ -62,7 +63,7 @@ "multi-one-hot": "mult_one_hot", "binary": "binary", "categorical": "categorical", - MultiDiscreteTensorSpec: "multi_categorical", + MultiCategorical: "multi_categorical", "multi_categorical": "multi_categorical", "multi-categorical": "multi_categorical", "multi_discrete": "multi_categorical", @@ -71,14 +72,14 @@ def consolidate_spec( - spec: CompositeSpec, + spec: Composite, recurse_through_entries: bool = True, recurse_through_stack: bool = True, ): """Given a TensorSpec, removes exclusive keys by adding 0 shaped specs. Args: - spec (CompositeSpec): the spec to be consolidated. + spec (Composite): the spec to be consolidated. recurse_through_entries (bool): if True, call the function recursively on all entries of the spec. Default is True. recurse_through_stack (bool): if True, if the provided spec is lazy, the function recursively @@ -87,10 +88,10 @@ def consolidate_spec( """ spec = spec.clone() - if not isinstance(spec, (CompositeSpec, LazyStackedCompositeSpec)): + if not isinstance(spec, (Composite, StackedComposite)): return spec - if isinstance(spec, LazyStackedCompositeSpec): + if isinstance(spec, StackedComposite): keys = set(spec.keys()) # shared keys exclusive_keys_per_spec = [ set() for _ in range(len(spec._specs)) @@ -128,7 +129,7 @@ def consolidate_spec( if recurse_through_entries: for key, value in spec.items(): - if isinstance(value, (CompositeSpec, LazyStackedCompositeSpec)): + if isinstance(value, (Composite, StackedComposite)): spec.set( key, consolidate_spec( @@ -145,16 +146,16 @@ def _empty_like_spec(specs: List[TensorSpec], shape): "Found same key in lazy specs corresponding to entries with different classes" ) spec = specs[0] - if isinstance(spec, (CompositeSpec, LazyStackedCompositeSpec)): + if isinstance(spec, (Composite, StackedComposite)): # the exclusive key has values which are CompositeSpecs -> # we create an empty composite spec with same batch size return spec.empty() - elif isinstance(spec, LazyStackedTensorSpec): + elif isinstance(spec, Stacked): # the exclusive key has values which are LazyStackedTensorSpecs -> # we create a LazyStackedTensorSpec with the same shape (aka same -1s) as the first in the list. 
# this will not add any new -1s when they are stacked shape = list(shape[: spec.stack_dim]) + list(shape[spec.stack_dim + 1 :]) - return LazyStackedTensorSpec( + return Stacked( *[_empty_like_spec(spec._specs, shape) for _ in spec._specs], dim=spec.stack_dim, ) @@ -191,14 +192,14 @@ def check_no_exclusive_keys(spec: TensorSpec, recurse: bool = True): spec (TensorSpec): the spec to check recurse (bool): if True, check recursively in nested specs. Default is True. """ - if isinstance(spec, LazyStackedCompositeSpec): + if isinstance(spec, StackedComposite): keys = set(spec.keys()) for inner_td in spec._specs: if recurse and not check_no_exclusive_keys(inner_td): return False if set(inner_td.keys()) != keys: return False - elif isinstance(spec, CompositeSpec) and recurse: + elif isinstance(spec, Composite) and recurse: for value in spec.values(): if not check_no_exclusive_keys(value): return False @@ -214,9 +215,9 @@ def contains_lazy_spec(spec: TensorSpec) -> bool: spec (TensorSpec): the spec to check """ - if isinstance(spec, (LazyStackedTensorSpec, LazyStackedCompositeSpec)): + if isinstance(spec, (Stacked, StackedComposite)): return True - elif isinstance(spec, CompositeSpec): + elif isinstance(spec, Composite): for inner_spec in spec.values(): if contains_lazy_spec(inner_spec): return True @@ -235,6 +236,8 @@ def __init__(self, fn: Callable, **kwargs): self.fn = fn self.kwargs = kwargs + functools.update_wrapper(self, getattr(fn, "forward", fn)) + def __getstate__(self): import cloudpickle @@ -244,6 +247,7 @@ def __setstate__(self, ob: bytes): import pickle self.fn, self.kwargs = pickle.loads(ob) + functools.update_wrapper(self, self.fn) def __call__(self, *args, **kwargs) -> Any: kwargs.update(self.kwargs) @@ -253,7 +257,7 @@ def __call__(self, *args, **kwargs) -> Any: def _process_action_space_spec(action_space, spec): original_spec = spec composite_spec = False - if isinstance(spec, CompositeSpec): + if isinstance(spec, Composite): # this will break whenever our action is more complex than a single tensor try: if "action" in spec.keys(): @@ -274,8 +278,8 @@ def _process_action_space_spec(action_space, spec): "with a leaf 'action' entry. Otherwise, simply remove the spec and use the action_space only." 
) if action_space is not None: - if isinstance(action_space, CompositeSpec): - raise ValueError("action_space cannot be of type CompositeSpec.") + if isinstance(action_space, Composite): + raise ValueError("action_space cannot be of type Composite.") if ( spec is not None and isinstance(action_space, TensorSpec) @@ -305,7 +309,7 @@ def _process_action_space_spec(action_space, spec): def _find_action_space(action_space): if isinstance(action_space, TensorSpec): - if isinstance(action_space, CompositeSpec): + if isinstance(action_space, Composite): if "action" in action_space.keys(): _key = "action" else: diff --git a/torchrl/envs/__init__.py b/torchrl/envs/__init__.py index ced185d7e00..047550fa9d7 100644 --- a/torchrl/envs/__init__.py +++ b/torchrl/envs/__init__.py @@ -28,12 +28,16 @@ MultiThreadedEnv, MultiThreadedEnvWrapper, OpenMLEnv, + OpenSpielEnv, + OpenSpielWrapper, PettingZooEnv, PettingZooWrapper, RoboHiveEnv, set_gym_backend, SMACv2Env, SMACv2Wrapper, + UnityMLAgentsEnv, + UnityMLAgentsWrapper, VmasEnv, VmasWrapper, ) @@ -100,12 +104,10 @@ from .utils import ( check_env_specs, check_marl_grouping, - exploration_mode, exploration_type, ExplorationType, make_composite_from_td, MarlGroupMapType, - set_exploration_mode, set_exploration_type, step_mdp, ) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 4996e527527..02c7f5893dc 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -28,6 +28,7 @@ TensorDictBase, unravel_key, ) +from tensordict.utils import _zip_strict from torch import multiprocessing as mp from torchrl._utils import ( _check_for_faulty_process, @@ -36,7 +37,7 @@ logger as torchrl_logger, VERBOSE, ) -from torchrl.data.tensor_specs import CompositeSpec, NonTensorSpec +from torchrl.data.tensor_specs import Composite, NonTensor from torchrl.data.utils import CloudpickleWrapper, contains_lazy_spec, DEVICE_TYPING from torchrl.envs.common import _do_nothing, _EnvPostInit, EnvBase, EnvMetaData from torchrl.envs.env_creator import get_env_metadata @@ -318,14 +319,20 @@ def __init__( create_env_fn = [create_env_fn for _ in range(num_workers)] elif len(create_env_fn) != num_workers: raise RuntimeError( - f"num_workers and len(create_env_fn) mismatch, " - f"got {len(create_env_fn)} and {num_workers}" + f"len(create_env_fn) and num_workers mismatch, " + f"got {len(create_env_fn)} and {num_workers}." ) + create_env_kwargs = {} if create_env_kwargs is None else create_env_kwargs if isinstance(create_env_kwargs, dict): create_env_kwargs = [ deepcopy(create_env_kwargs) for _ in range(num_workers) ] + elif len(create_env_kwargs) != num_workers: + raise RuntimeError( + f"len(create_env_kwargs) and num_workers mismatch, " + f"got {len(create_env_kwargs)} and {num_workers}." + ) self.policy_proof = policy_proof self.num_workers = num_workers @@ -534,7 +541,11 @@ def update_kwargs(self, kwargs: Union[dict, List[dict]]) -> None: for _kwargs in self.create_env_kwargs: _kwargs.update(kwargs) else: - for _kwargs, _new_kwargs in zip(self.create_env_kwargs, kwargs): + if len(kwargs) != self.num_workers: + raise RuntimeError( + f"len(kwargs) and num_workers mismatch, got {len(kwargs)} and {self.num_workers}." 
+ ) + for _kwargs, _new_kwargs in _zip_strict(self.create_env_kwargs, kwargs): _kwargs.update(_new_kwargs) def _get_in_keys_to_exclude(self, tensordict): @@ -550,7 +561,7 @@ def _set_properties(self): cls = type(self) - def _check_for_empty_spec(specs: CompositeSpec): + def _check_for_empty_spec(specs: Composite): for subspec in ( "full_state_spec", "full_action_spec", @@ -559,9 +570,9 @@ def _check_for_empty_spec(specs: CompositeSpec): "full_observation_spec", ): for key, spec in reversed( - list(specs.get(subspec, default=CompositeSpec()).items(True)) + list(specs.get(subspec, default=Composite()).items(True)) ): - if isinstance(spec, CompositeSpec) and spec.is_empty(): + if isinstance(spec, Composite) and spec.is_empty(): raise RuntimeError( f"The environment passed to {cls.__name__} has empty specs in {key}. Consider using " f"torchrl.envs.transforms.RemoveEmptySpecs to remove the empty specs." @@ -675,7 +686,7 @@ def _create_td(self) -> None: self.full_done_spec, ): for key, _spec in spec.items(True, True): - if isinstance(_spec, NonTensorSpec): + if isinstance(_spec, NonTensor): non_tensor_keys.append(key) self._non_tensor_keys = non_tensor_keys @@ -1031,12 +1042,18 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: if out_tds is not None: out_tds[i] = _td + device = self.device if not self._use_buffers: result = LazyStackedTensorDict.maybe_dense_stack(out_tds) + if result.device != device: + if device is None: + result = result.clear_device_() + else: + result = result.to(device, non_blocking=self.non_blocking) + self._sync_w2m() return result selected_output_keys = self._selected_reset_keys_filt - device = self.device # select + clone creates 2 tds, but we can create one only def select_and_clone(name, tensor): @@ -1066,18 +1083,30 @@ def _step( self, tensordict: TensorDict, ) -> TensorDict: - tensordict_in = tensordict.clone(False) + partial_steps = tensordict.get("_step", None) + tensordict_save = tensordict + if partial_steps is not None and partial_steps.all(): + partial_steps = None + if partial_steps is not None: + partial_steps = partial_steps.view(tensordict.shape) + tensordict = tensordict[partial_steps] + workers_range = partial_steps.nonzero(as_tuple=True)[0].tolist() + tensordict_in = tensordict + else: + workers_range = range(self.num_workers) + tensordict_in = tensordict.clone(False) + # if self._use_buffers: + # shared_tensordict_parent = self.shared_tensordict_parent + data_in = [] - for i in range(self.num_workers): + for i, td_ in zip(workers_range, tensordict_in): # shared_tensordicts are locked, and we need to select the keys since we update in-place. # There may be unexpected keys, such as "_reset", that we should comfortably ignore here. 
env_device = self._envs[i].device if env_device != self.device and env_device is not None: - data_in.append( - tensordict_in[i].to(env_device, non_blocking=self.non_blocking) - ) + data_in.append(td_.to(env_device, non_blocking=self.non_blocking)) else: - data_in.append(tensordict_in[i]) + data_in.append(td_) self._sync_m2w() out_tds = None @@ -1086,7 +1115,7 @@ def _step( if self._use_buffers: next_td = self.shared_tensordict_parent.get("next") - for i, _data_in in enumerate(data_in): + for i, _data_in in zip(workers_range, data_in): out_td = self._envs[i]._step(_data_in) next_td[i].update_( out_td, @@ -1095,32 +1124,43 @@ def _step( ) if out_tds is not None: out_tds.append(out_td) - else: - for i, _data_in in enumerate(data_in): - out_td = self._envs[i]._step(_data_in) - out_tds.append(out_td) - return LazyStackedTensorDict.maybe_dense_stack(out_tds) - # We must pass a clone of the tensordict, as the values of this tensordict - # will be modified in-place at further steps - device = self.device + # We must pass a clone of the tensordict, as the values of this tensordict + # will be modified in-place at further steps + device = self.device - def select_and_clone(name, tensor): - if name in self._selected_step_keys: - return tensor.clone() + def select_and_clone(name, tensor): + if name in self._selected_step_keys: + return tensor.clone() - out = next_td.named_apply(select_and_clone, nested_keys=True, filter_empty=True) - if out_tds is not None: - out.update( - LazyStackedTensorDict(*out_tds), keys_to_update=self._non_tensor_keys + if partial_steps is not None: + next_td = TensorDict.lazy_stack([next_td[i] for i in workers_range]) + out = next_td.named_apply( + select_and_clone, nested_keys=True, filter_empty=True ) + if out_tds is not None: + out.update( + LazyStackedTensorDict(*out_tds), + keys_to_update=self._non_tensor_keys, + ) + + if out.device != device: + if device is None: + out = out.clear_device_() + elif out.device != device: + out = out.to(device, non_blocking=self.non_blocking) + self._sync_w2m() + else: + for i, _data_in in zip(workers_range, data_in): + out_td = self._envs[i]._step(_data_in) + out_tds.append(out_td) + out = LazyStackedTensorDict.maybe_dense_stack(out_tds) + + if partial_steps is not None: + result = out.new_zeros(tensordict_save.shape) + result[partial_steps] = out + return result - if out.device != device: - if device is None: - out = out.clear_device_() - elif out.device != device: - out = out.to(device, non_blocking=self.non_blocking) - self._sync_w2m() return out def __getattr__(self, attr: str) -> Any: @@ -1435,20 +1475,30 @@ def load_state_dict(self, state_dict: OrderedDict) -> None: def _step_and_maybe_reset_no_buffers( self, tensordict: TensorDictBase ) -> Tuple[TensorDictBase, TensorDictBase]: + partial_steps = tensordict.get("_step", None) + tensordict_save = tensordict + if partial_steps is not None and partial_steps.all(): + partial_steps = None + if partial_steps is not None: + partial_steps = partial_steps.view(tensordict.shape) + tensordict = tensordict[partial_steps] + workers_range = partial_steps.nonzero(as_tuple=True)[0].tolist() + else: + workers_range = range(self.num_workers) td = tensordict.consolidate(share_memory=True, inplace=True, num_threads=1) - for i in range(td.shape[0]): + for i in workers_range: # We send the same td multiple times as it is in shared mem and we just need to index it # in each process. # If we don't do this, we need to unbind it but then the custom pickler will require # some extra metadata to be collected. 
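The batched-env changes in this hunk and in the buffered variants below all follow one idiom: a boolean `"_step"` entry masks which workers actually step, inputs are gathered with `tensordict[partial_steps]`, only the workers in `workers_range` are messaged, and outputs are scattered back into a zeroed result of the original batch shape via `result[partial_steps] = out`. A standalone sketch of that mask/gather/scatter pattern with plain tensors (toy stand-ins, not the TensorDict machinery used here):

    import torch

    def masked_step(batch: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        # gather: indices of the workers that actually step
        workers = mask.nonzero(as_tuple=True)[0].tolist()
        # "step" each active worker (stand-in for env._step)
        stepped = torch.stack([batch[i] + 1 for i in workers])
        # scatter: write results back into a zeroed full-size output
        out = batch.new_zeros(batch.shape)
        out[mask] = stepped
        return out

    batch = torch.arange(8.0).reshape(4, 2)
    mask = torch.tensor([True, False, True, False])
    print(masked_step(batch, mask))  # rows 0 and 2 stepped; others zeroed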
self.parent_channels[i].send(("step_and_maybe_reset", (td, i))) - results = [None] * self.num_workers + results = [None] * len(workers_range) consumed_indices = [] - events = set(range(self.num_workers)) - while len(consumed_indices) < self.num_workers: + events = set(workers_range) + while len(consumed_indices) < len(workers_range): for i in list(events): if self._events[i].is_set(): results[i] = self.parent_channels[i].recv() @@ -1457,9 +1507,14 @@ def _step_and_maybe_reset_no_buffers( events.discard(i) out_next, out_root = zip(*(future for future in results)) - return TensorDict.maybe_dense_stack(out_next), TensorDict.maybe_dense_stack( + out = TensorDict.maybe_dense_stack(out_next), TensorDict.maybe_dense_stack( out_root ) + if partial_steps is not None: + result = out.new_zeros(tensordict_save.shape) + result[partial_steps] = out + return result + return out @torch.no_grad() @_check_start @@ -1471,6 +1526,42 @@ def step_and_maybe_reset( # return self._step_and_maybe_reset_no_buffers(tensordict) return super().step_and_maybe_reset(tensordict) + partial_steps = tensordict.get("_step", None) + tensordict_save = tensordict + if partial_steps is not None and partial_steps.all(): + partial_steps = None + if partial_steps is not None: + partial_steps = partial_steps.view(tensordict.shape) + workers_range = partial_steps.nonzero(as_tuple=True)[0].tolist() + shared_tensordict_parent = TensorDict.lazy_stack( + [self.shared_tensordict_parent[i] for i in workers_range] + ) + next_td = TensorDict.lazy_stack( + [self._shared_tensordict_parent_next[i] for i in workers_range] + ) + tensordict_ = TensorDict.lazy_stack( + [self._shared_tensordict_parent_root[i] for i in workers_range] + ) + if self.shared_tensordict_parent.device is None: + tensordict = tensordict._fast_apply( + lambda x, y: x[partial_steps].to(y.device) + if y is not None + else x[partial_steps], + self.shared_tensordict_parent, + default=None, + device=None, + batch_size=shared_tensordict_parent.shape, + ) + else: + tensordict = tensordict[partial_steps].to( + self.shared_tensordict_parent.device + ) + else: + workers_range = range(self.num_workers) + shared_tensordict_parent = self.shared_tensordict_parent + next_td = self._shared_tensordict_parent_next + tensordict_ = self._shared_tensordict_parent_root + # We must use the in_keys and nothing else for the following reasons: # - efficiency: copying all the keys will in practice mean doing a lot # of writing operations since the input tensordict may (and often will) @@ -1479,7 +1570,7 @@ def step_and_maybe_reset( # and this transform overrides an observation key (eg, CatFrames) # the shape, dtype or device may not necessarily match and writing # the value in-place will fail. - self.shared_tensordict_parent.update_( + shared_tensordict_parent.update_( tensordict, keys_to_update=self._env_input_keys, non_blocking=self.non_blocking, @@ -1489,46 +1580,41 @@ def step_and_maybe_reset( # if we have input "next" data (eg, RNNs which pass the next state) # the sub-envs will need to process them through step_and_maybe_reset. 
# We keep track of which keys are present to let the worker know what - # should be passd to the env (we don't want to pass done states for instance) + # should be passed to the env (we don't want to pass done states for instance) next_td_keys = list(next_td_passthrough.keys(True, True)) - data = [ - {"next_td_passthrough_keys": next_td_keys} - for _ in range(self.num_workers) - ] - self.shared_tensordict_parent.get("next").update_( + data = [{"next_td_passthrough_keys": next_td_keys} for _ in workers_range] + shared_tensordict_parent.get("next").update_( next_td_passthrough, non_blocking=self.non_blocking ) else: # next_td_keys = None - data = [{} for _ in range(self.num_workers)] + data = [{} for _ in workers_range] if self._non_tensor_keys: - for i in range(self.num_workers): + for i in workers_range: data[i]["non_tensor_data"] = tensordict[i].select( *self._non_tensor_keys, strict=False ) self._sync_m2w() - for i in range(self.num_workers): - self.parent_channels[i].send(("step_and_maybe_reset", data[i])) + for i, _data in zip(workers_range, data): + self.parent_channels[i].send(("step_and_maybe_reset", _data)) - for i in range(self.num_workers): + for i in workers_range: event = self._events[i] event.wait(self._timeout) event.clear() if self._non_tensor_keys: non_tensor_tds = [] - for i in range(self.num_workers): + for i in workers_range: msg, non_tensor_td = self.parent_channels[i].recv() non_tensor_tds.append(non_tensor_td) # We must pass a clone of the tensordict, as the values of this tensordict # will be modified in-place at further steps - next_td = self._shared_tensordict_parent_next - tensordict_ = self._shared_tensordict_parent_root device = self.device - if self.shared_tensordict_parent.device == device: + if shared_tensordict_parent.device == device: next_td = next_td.clone() tensordict_ = tensordict_.clone() elif device is not None: @@ -1558,22 +1644,49 @@ def step_and_maybe_reset( keys_to_update=[("next", key) for key in self._non_tensor_keys], ) tensordict_.update(non_tensor_tds, keys_to_update=self._non_tensor_keys) + + if partial_steps is not None: + result = tensordict.new_zeros(tensordict_save.shape) + result_ = tensordict_.new_zeros(tensordict_save.shape) + result[partial_steps] = tensordict + result_[partial_steps] = tensordict_ + return result, result_ + return tensordict, tensordict_ def _step_no_buffers( self, tensordict: TensorDictBase ) -> Tuple[TensorDictBase, TensorDictBase]: + partial_steps = tensordict.get("_step", None) + tensordict_save = tensordict + if partial_steps is not None and partial_steps.all(): + partial_steps = None + if partial_steps is not None: + partial_steps = partial_steps.view(tensordict.shape) + tensordict = tensordict[partial_steps] + workers_range = partial_steps.nonzero(as_tuple=True)[0].tolist() + else: + workers_range = range(self.num_workers) + data = tensordict.consolidate(share_memory=True, inplace=True, num_threads=1) - for i, local_data in enumerate(data.unbind(0)): + for i, local_data in zip(workers_range, data.unbind(0)): self.parent_channels[i].send(("step", local_data)) # for i in range(data.shape[0]): # self.parent_channels[i].send(("step", (data, i))) out_tds = [] - for i, channel in enumerate(self.parent_channels): + for i in workers_range: + channel = self.parent_channels[i] self._events[i].wait() td = channel.recv() out_tds.append(td) - return LazyStackedTensorDict.maybe_dense_stack(out_tds) + out = LazyStackedTensorDict.maybe_dense_stack(out_tds) + if self.device is not None and out.device != self.device: + out = 
out.to(self.device, non_blocking=self.non_blocking) + if partial_steps is not None: + result = out.new_zeros(tensordict_save.shape) + result[partial_steps] = out + return result + return out @torch.no_grad() @_check_start @@ -1588,8 +1701,35 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: # and this transform overrides an observation key (eg, CatFrames) # the shape, dtype or device may not necessarily match and writing # the value in-place will fail. + partial_steps = tensordict.get("_step", None) + tensordict_save = tensordict + if partial_steps is not None and partial_steps.all(): + partial_steps = None + if partial_steps is not None: + partial_steps = partial_steps.view(tensordict.shape) + workers_range = partial_steps.nonzero(as_tuple=True)[0].tolist() + shared_tensordict_parent = TensorDict.lazy_stack( + [self.shared_tensordicts[i] for i in workers_range] + ) + if self.shared_tensordict_parent.device is None: + tensordict = tensordict._fast_apply( + lambda x, y: x[partial_steps].to(y.device) + if y is not None + else x[partial_steps], + self.shared_tensordict_parent, + default=None, + device=None, + batch_size=shared_tensordict_parent.shape, + ) + else: + tensordict = tensordict[partial_steps].to( + self.shared_tensordict_parent.device + ) + else: + workers_range = range(self.num_workers) + shared_tensordict_parent = self.shared_tensordict_parent - self.shared_tensordict_parent.update_( + shared_tensordict_parent.update_( tensordict, keys_to_update=list(self._env_input_keys), non_blocking=self.non_blocking, @@ -1599,20 +1739,20 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: # if we have input "next" data (eg, RNNs which pass the next state) # the sub-envs will need to process them through step_and_maybe_reset. 
# We keep track of which keys are present to let the worker know what - # should be passd to the env (we don't want to pass done states for instance) + # should be passed to the env (we don't want to pass done states for instance) next_td_keys = list(next_td_passthrough.keys(True, True)) data = [ {"next_td_passthrough_keys": next_td_keys} for _ in range(self.num_workers) ] - self.shared_tensordict_parent.get("next").update_( + shared_tensordict_parent.get("next").update_( next_td_passthrough, non_blocking=self.non_blocking ) else: data = [{} for _ in range(self.num_workers)] if self._non_tensor_keys: - for i in range(self.num_workers): + for i in workers_range: data[i]["non_tensor_data"] = tensordict[i].select( *self._non_tensor_keys, strict=False ) @@ -1622,23 +1762,23 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: if self.event is not None: self.event.record() self.event.synchronize() - for i in range(self.num_workers): + for i in workers_range: self.parent_channels[i].send(("step", data[i])) - for i in range(self.num_workers): + for i in workers_range: event = self._events[i] event.wait(self._timeout) event.clear() if self._non_tensor_keys: non_tensor_tds = [] - for i in range(self.num_workers): + for i in workers_range: msg, non_tensor_td = self.parent_channels[i].recv() non_tensor_tds.append(non_tensor_td) # We must pass a clone of the tensordict, as the values of this tensordict # will be modified in-place at further steps - next_td = self.shared_tensordict_parent.get("next") + next_td = shared_tensordict_parent.get("next") device = self.device if next_td.device != device and device is not None: @@ -1665,6 +1805,10 @@ def select_and_clone(name, tensor): keys_to_update=self._non_tensor_keys, ) self._sync_w2m() + if partial_steps is not None: + result = out.new_zeros(tensordict_save.shape) + result[partial_steps] = out + return result return out def _reset_no_buffers( @@ -1698,7 +1842,11 @@ def _reset_no_buffers( self._events[i].wait() td = channel.recv() out_tds[i] = td - return LazyStackedTensorDict.maybe_dense_stack(out_tds) + result = LazyStackedTensorDict.maybe_dense_stack(out_tds) + device = self.device + if device is not None and result.device != device: + return result.to(self.device, non_blocking=self.non_blocking) + return result @torch.no_grad() @_check_start diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index b9216b58e86..31ff0b905af 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -25,12 +25,7 @@ seed_generator, ) -from torchrl.data.tensor_specs import ( - CompositeSpec, - DiscreteTensorSpec, - TensorSpec, - UnboundedContinuousTensorSpec, -) +from torchrl.data.tensor_specs import Categorical, Composite, TensorSpec, Unbounded from torchrl.data.utils import DEVICE_TYPING from torchrl.envs.utils import ( _make_compatible_policy, @@ -62,7 +57,7 @@ def __init__( self, *, tensordict: TensorDictBase, - specs: CompositeSpec, + specs: Composite, batch_size: torch.Size, env_str: str, device: torch.device, @@ -91,7 +86,7 @@ def tensordict(self, value: TensorDictBase): self._tensordict = value.to("cpu") @specs.setter - def specs(self, value: CompositeSpec): + def specs(self, value: Composite): self._specs = value.to("cpu") @staticmethod @@ -191,6 +186,27 @@ def __call__(cls, *args, **kwargs): return AutoResetEnv( instance, AutoResetTransform(replace=auto_reset_replace) ) + + done_keys = set(instance.full_done_spec.keys(True, True)) + obs_keys = set(instance.full_observation_spec.keys(True, True)) + reward_keys = 
set(instance.full_reward_spec.keys(True, True)) + # state_keys can match obs_keys so we don't test that + action_keys = set(instance.full_action_spec.keys(True, True)) + state_keys = set(instance.full_state_spec.keys(True, True)) + total_set = set() + for keyset in (done_keys, obs_keys, reward_keys): + if total_set.intersection(keyset): + raise RuntimeError( + f"The set of keys of one spec collides (culprit: {total_set.intersection(keyset)}) with another." + ) + total_set = total_set.union(keyset) + total_set = set() + for keyset in (state_keys, action_keys): + if total_set.intersection(keyset): + raise RuntimeError( + f"The set of keys of one spec collides (culprit: {total_set.intersection(keyset)}) with another." + ) + total_set = total_set.union(keyset) return instance @@ -212,29 +228,29 @@ class EnvBase(nn.Module, metaclass=_EnvPostInit): be done after a call to :meth:`~.reset` is made. Defaults to ``False``. Attributes: - done_spec (CompositeSpec): equivalent to ``full_done_spec`` as all + done_spec (Composite): equivalent to ``full_done_spec`` as all ``done_specs`` contain at least a ``"done"`` and a ``"terminated"`` entry action_spec (TensorSpec): the spec of the action. Links to the spec of the leaf action if only one action tensor is to be expected. Otherwise links to ``full_action_spec``. - observation_spec (CompositeSpec): equivalent to ``full_observation_spec``. + observation_spec (Composite): equivalent to ``full_observation_spec``. reward_spec (TensorSpec): the spec of the reward. Links to the spec of the leaf reward if only one reward tensor is to be expected. Otherwise links to ``full_reward_spec``. - state_spec (CompositeSpec): equivalent to ``full_state_spec``. - full_done_spec (CompositeSpec): a composite spec such that ``full_done_spec.zero()`` + state_spec (Composite): equivalent to ``full_state_spec``. + full_done_spec (Composite): a composite spec such that ``full_done_spec.zero()`` returns a tensordict containing only the leaves encoding the done status of the environment. - full_action_spec (CompositeSpec): a composite spec such that ``full_action_spec.zero()`` + full_action_spec (Composite): a composite spec such that ``full_action_spec.zero()`` returns a tensordict containing only the leaves encoding the action of the environment. - full_observation_spec (CompositeSpec): a composite spec such that ``full_observation_spec.zero()`` + full_observation_spec (Composite): a composite spec such that ``full_observation_spec.zero()`` returns a tensordict containing only the leaves encoding the observation of the environment. - full_reward_spec (CompositeSpec): a composite spec such that ``full_reward_spec.zero()`` + full_reward_spec (Composite): a composite spec such that ``full_reward_spec.zero()`` returns a tensordict containing only the leaves encoding the reward of the environment. - full_state_spec (CompositeSpec): a composite spec such that ``full_state_spec.zero()`` + full_state_spec (Composite): a composite spec such that ``full_state_spec.zero()`` returns a tensordict containing only the leaves encoding the inputs (actions excluded) of the environment. batch_size (torch.Size): The batch-size of the environment. @@ -253,9 +269,9 @@ class EnvBase(nn.Module, metaclass=_EnvPostInit): >>> from torchrl.envs import EnvBase >>> class CounterEnv(EnvBase): ... def __init__(self, batch_size=(), device=None, **kwargs): - ... self.observation_spec = CompositeSpec( - ... count=UnboundedContinuousTensorSpec(batch_size, device=device, dtype=torch.int64)) - ... 
self.action_spec = UnboundedContinuousTensorSpec(batch_size, device=device, dtype=torch.int8) + ... self.observation_spec = Composite( + ... count=Unbounded(batch_size, device=device, dtype=torch.int64)) + ... self.action_spec = Unbounded(batch_size, device=device, dtype=torch.int8) ... # done spec and reward spec are set automatically ... def _step(self, tensordict): ... @@ -264,10 +280,10 @@ class EnvBase(nn.Module, metaclass=_EnvPostInit): >>> env.batch_size # how many envs are run at once torch.Size([]) >>> env.input_spec - CompositeSpec( + Composite( full_state_spec: None, - full_action_spec: CompositeSpec( - action: BoundedTensorSpec( + full_action_spec: Composite( + action: BoundedContinuous( shape=torch.Size([1]), space=ContinuousBox( low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True), @@ -276,7 +292,7 @@ class EnvBase(nn.Module, metaclass=_EnvPostInit): dtype=torch.float32, domain=continuous), device=cpu, shape=torch.Size([])), device=cpu, shape=torch.Size([])) >>> env.action_spec - BoundedTensorSpec( + BoundedContinuous( shape=torch.Size([1]), space=ContinuousBox( low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True), @@ -285,8 +301,8 @@ class EnvBase(nn.Module, metaclass=_EnvPostInit): dtype=torch.float32, domain=continuous) >>> env.observation_spec - CompositeSpec( - observation: BoundedTensorSpec( + Composite( + observation: BoundedContinuous( shape=torch.Size([3]), space=ContinuousBox( low=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True), @@ -295,14 +311,14 @@ class EnvBase(nn.Module, metaclass=_EnvPostInit): dtype=torch.float32, domain=continuous), device=cpu, shape=torch.Size([])) >>> env.reward_spec - UnboundedContinuousTensorSpec( + UnboundedContinuous( shape=torch.Size([1]), space=None, device=cpu, dtype=torch.float32, domain=continuous) >>> env.done_spec - DiscreteTensorSpec( + Categorical( shape=torch.Size([1]), space=DiscreteBox(n=2), device=cpu, @@ -310,16 +326,16 @@ class EnvBase(nn.Module, metaclass=_EnvPostInit): domain=discrete) >>> # the output_spec contains all the expected outputs >>> env.output_spec - CompositeSpec( - full_reward_spec: CompositeSpec( - reward: UnboundedContinuousTensorSpec( + Composite( + full_reward_spec: Composite( + reward: UnboundedContinuous( shape=torch.Size([1]), space=None, device=cpu, dtype=torch.float32, domain=continuous), device=cpu, shape=torch.Size([])), - full_observation_spec: CompositeSpec( - observation: BoundedTensorSpec( + full_observation_spec: Composite( + observation: BoundedContinuous( shape=torch.Size([3]), space=ContinuousBox( low=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True), @@ -327,8 +343,8 @@ class EnvBase(nn.Module, metaclass=_EnvPostInit): device=cpu, dtype=torch.float32, domain=continuous), device=cpu, shape=torch.Size([])), - full_done_spec: CompositeSpec( - done: DiscreteTensorSpec( + full_done_spec: Composite( + done: Categorical( shape=torch.Size([1]), space=DiscreteBox(n=2), device=cpu, @@ -544,10 +560,10 @@ def input_spec(self) -> TensorSpec: >>> from torchrl.envs.libs.gym import GymEnv >>> env = GymEnv("Pendulum-v1") >>> env.input_spec - CompositeSpec( + Composite( full_state_spec: None, - full_action_spec: CompositeSpec( - action: BoundedTensorSpec( + full_action_spec: Composite( + action: BoundedContinuous( shape=torch.Size([1]), space=ContinuousBox( low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True), @@ -560,7 +576,7 @@ def input_spec(self) 
-> TensorSpec: """ input_spec = self.__dict__.get("_input_spec") if input_spec is None: - input_spec = CompositeSpec( + input_spec = Composite( full_state_spec=None, shape=self.batch_size, device=self.device, @@ -591,16 +607,16 @@ def output_spec(self) -> TensorSpec: >>> from torchrl.envs.libs.gym import GymEnv >>> env = GymEnv("Pendulum-v1") >>> env.output_spec - CompositeSpec( - full_reward_spec: CompositeSpec( - reward: UnboundedContinuousTensorSpec( + Composite( + full_reward_spec: Composite( + reward: UnboundedContinuous( shape=torch.Size([1]), space=None, device=cpu, dtype=torch.float32, domain=continuous), device=cpu, shape=torch.Size([])), - full_observation_spec: CompositeSpec( - observation: BoundedTensorSpec( + full_observation_spec: Composite( + observation: BoundedContinuous( shape=torch.Size([3]), space=ContinuousBox( low=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True), @@ -608,8 +624,8 @@ def output_spec(self) -> TensorSpec: device=cpu, dtype=torch.float32, domain=continuous), device=cpu, shape=torch.Size([])), - full_done_spec: CompositeSpec( - done: DiscreteTensorSpec( + full_done_spec: Composite( + done: Categorical( shape=torch.Size([1]), space=DiscreteBox(n=2), device=cpu, @@ -620,7 +636,7 @@ def output_spec(self) -> TensorSpec: """ output_spec = self.__dict__.get("_output_spec") if output_spec is None: - output_spec = CompositeSpec( + output_spec = Composite( shape=self.batch_size, device=self.device, ).lock_() @@ -688,9 +704,9 @@ def action_spec(self) -> TensorSpec: If the action spec is provided as a simple spec, this will be returned. - >>> env.action_spec = UnboundedContinuousTensorSpec(1) + >>> env.action_spec = Unbounded(1) >>> env.action_spec - UnboundedContinuousTensorSpec( + UnboundedContinuous( shape=torch.Size([1]), space=ContinuousBox( low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True), @@ -702,9 +718,9 @@ def action_spec(self) -> TensorSpec: If the action spec is provided as a composite spec and contains only one leaf, this function will return just the leaf. - >>> env.action_spec = CompositeSpec({"nested": {"action": UnboundedContinuousTensorSpec(1)}}) + >>> env.action_spec = Composite({"nested": {"action": Unbounded(1)}}) >>> env.action_spec - UnboundedContinuousTensorSpec( + UnboundedContinuous( shape=torch.Size([1]), space=ContinuousBox( low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True), @@ -716,11 +732,11 @@ def action_spec(self) -> TensorSpec: If the action spec is provided as a composite spec and has more than one leaf, this function will return the whole spec. 
- >>> env.action_spec = CompositeSpec({"nested": {"action": UnboundedContinuousTensorSpec(1), "another_action": DiscreteTensorSpec(1)}}) + >>> env.action_spec = Composite({"nested": {"action": Unbounded(1), "another_action": Categorical(1)}}) >>> env.action_spec - CompositeSpec( - nested: CompositeSpec( - action: UnboundedContinuousTensorSpec( + Composite( + nested: Composite( + action: UnboundedContinuous( shape=torch.Size([1]), space=ContinuousBox( low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True), @@ -728,7 +744,7 @@ def action_spec(self) -> TensorSpec: device=cpu, dtype=torch.float32, domain=continuous), - another_action: DiscreteTensorSpec( + another_action: Categorical( shape=torch.Size([]), space=DiscreteBox(n=1), device=cpu, @@ -745,7 +761,7 @@ def action_spec(self) -> TensorSpec: >>> from torchrl.envs.libs.gym import GymEnv >>> env = GymEnv("Pendulum-v1") >>> env.action_spec - BoundedTensorSpec( + BoundedContinuous( shape=torch.Size([1]), space=ContinuousBox( low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True), @@ -794,16 +810,16 @@ def action_spec(self, value: TensorSpec) -> None: f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size})." ) - if isinstance(value, CompositeSpec): + if isinstance(value, Composite): for _ in value.values(True, True): # noqa: B007 break else: raise RuntimeError( - "An empty CompositeSpec was passed for the action spec. " + "An empty Composite was passed for the action spec. " "This is currently not permitted." ) else: - value = CompositeSpec( + value = Composite( action=value.to(device), shape=self.batch_size, device=device ) @@ -812,10 +828,10 @@ def action_spec(self, value: TensorSpec) -> None: self.input_spec.lock_() @property - def full_action_spec(self) -> CompositeSpec: + def full_action_spec(self) -> Composite: """The full action spec. - ``full_action_spec`` is a :class:`~torchrl.data.CompositeSpec`` instance + ``full_action_spec`` is a :class:`~torchrl.data.Composite`` instance that contains all the action entries. Examples: @@ -824,8 +840,8 @@ def full_action_spec(self) -> CompositeSpec: ... break >>> env = BraxEnv(envname) >>> env.full_action_spec - CompositeSpec( - action: BoundedTensorSpec( + Composite( + action: BoundedContinuous( shape=torch.Size([8]), space=ContinuousBox( low=Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, contiguous=True), @@ -835,10 +851,16 @@ def full_action_spec(self) -> CompositeSpec: domain=continuous), device=cpu, shape=torch.Size([])) """ - return self.input_spec["full_action_spec"] + full_action_spec = self.input_spec.get("full_action_spec", None) + if full_action_spec is None: + full_action_spec = Composite(shape=self.batch_size, device=self.device) + self.input_spec.unlock_() + self.input_spec["full_action_spec"] = full_action_spec + self.input_spec.lock_() + return full_action_spec @full_action_spec.setter - def full_action_spec(self, spec: CompositeSpec) -> None: + def full_action_spec(self, spec: Composite) -> None: self.action_spec = spec # Reward spec @@ -881,9 +903,9 @@ def reward_spec(self) -> TensorSpec: If the reward spec is provided as a simple spec, this will be returned. 
- >>> env.reward_spec = UnboundedContinuousTensorSpec(1) + >>> env.reward_spec = Unbounded(1) >>> env.reward_spec - UnboundedContinuousTensorSpec( + UnboundedContinuous( shape=torch.Size([1]), space=ContinuousBox( low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True), @@ -895,9 +917,9 @@ def reward_spec(self) -> TensorSpec: If the reward spec is provided as a composite spec and contains only one leaf, this function will return just the leaf. - >>> env.reward_spec = CompositeSpec({"nested": {"reward": UnboundedContinuousTensorSpec(1)}}) + >>> env.reward_spec = Composite({"nested": {"reward": Unbounded(1)}}) >>> env.reward_spec - UnboundedContinuousTensorSpec( + UnboundedContinuous( shape=torch.Size([1]), space=ContinuousBox( low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True), @@ -909,11 +931,11 @@ def reward_spec(self) -> TensorSpec: If the reward spec is provided as a composite spec and has more than one leaf, this function will return the whole spec. - >>> env.reward_spec = CompositeSpec({"nested": {"reward": UnboundedContinuousTensorSpec(1), "another_reward": DiscreteTensorSpec(1)}}) + >>> env.reward_spec = Composite({"nested": {"reward": Unbounded(1), "another_reward": Categorical(1)}}) >>> env.reward_spec - CompositeSpec( - nested: CompositeSpec( - reward: UnboundedContinuousTensorSpec( + Composite( + nested: Composite( + reward: UnboundedContinuous( shape=torch.Size([1]), space=ContinuousBox( low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True), @@ -921,7 +943,7 @@ def reward_spec(self) -> TensorSpec: device=cpu, dtype=torch.float32, domain=continuous), - another_reward: DiscreteTensorSpec( + another_reward: Categorical( shape=torch.Size([]), space=DiscreteBox(n=1), device=cpu, @@ -938,7 +960,7 @@ def reward_spec(self) -> TensorSpec: >>> from torchrl.envs.libs.gym import GymEnv >>> env = GymEnv("Pendulum-v1") >>> env.reward_spec - UnboundedContinuousTensorSpec( + UnboundedContinuous( shape=torch.Size([1]), space=None, device=cpu, @@ -952,7 +974,7 @@ def reward_spec(self) -> TensorSpec: # this will be raised if there is not full_reward_spec (unlikely) or no reward_key # Since output_spec is lazily populated with an empty composite spec for # reward_spec, the second case is much more likely to occur. - self.reward_spec = UnboundedContinuousTensorSpec( + self.reward_spec = Unbounded( shape=(*self.batch_size, 1), device=self.device, ) @@ -982,16 +1004,16 @@ def reward_spec(self, value: TensorSpec) -> None: raise ValueError( f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size})." ) - if isinstance(value, CompositeSpec): + if isinstance(value, Composite): for _ in value.values(True, True): # noqa: B007 break else: raise RuntimeError( - "An empty CompositeSpec was passed for the reward spec. " + "An empty Composite was passed for the reward spec. " "This is currently not permitted." ) else: - value = CompositeSpec( + value = Composite( reward=value.to(device), shape=self.batch_size, device=device ) for leaf in value.values(True, True): @@ -1007,10 +1029,10 @@ def reward_spec(self, value: TensorSpec) -> None: self.output_spec.lock_() @property - def full_reward_spec(self) -> CompositeSpec: + def full_reward_spec(self) -> Composite: """The full reward spec. - ``full_reward_spec`` is a :class:`~torchrl.data.CompositeSpec`` instance + ``full_reward_spec`` is a :class:`~torchrl.data.Composite`` instance that contains all the reward entries. 
Examples: @@ -1019,9 +1041,9 @@ def full_reward_spec(self) -> CompositeSpec: >>> base_env = GymWrapper(gymnasium.make("Pendulum-v1")) >>> env = TransformedEnv(base_env, RenameTransform("reward", ("nested", "reward"))) >>> env.full_reward_spec - CompositeSpec( - nested: CompositeSpec( - reward: UnboundedContinuousTensorSpec( + Composite( + nested: Composite( + reward: UnboundedContinuous( shape=torch.Size([1]), space=ContinuousBox( low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True), @@ -1034,7 +1056,7 @@ def full_reward_spec(self) -> CompositeSpec: return self.output_spec["full_reward_spec"] @full_reward_spec.setter - def full_reward_spec(self, spec: CompositeSpec) -> None: + def full_reward_spec(self, spec: Composite) -> None: self.reward_spec = spec.to(self.device) if self.device is not None else spec # done spec @@ -1068,10 +1090,10 @@ def done_key(self): return self.done_keys[0] @property - def full_done_spec(self) -> CompositeSpec: + def full_done_spec(self) -> Composite: """The full done spec. - ``full_done_spec`` is a :class:`~torchrl.data.CompositeSpec`` instance + ``full_done_spec`` is a :class:`~torchrl.data.Composite`` instance that contains all the done entries. It can be used to generate fake data with a structure that mimics the one obtained at runtime. @@ -1081,14 +1103,14 @@ def full_done_spec(self) -> CompositeSpec: >>> from torchrl.envs import GymWrapper >>> env = GymWrapper(gymnasium.make("Pendulum-v1")) >>> env.full_done_spec - CompositeSpec( - done: DiscreteTensorSpec( + Composite( + done: Categorical( shape=torch.Size([1]), space=DiscreteBox(n=2), device=cpu, dtype=torch.bool, domain=discrete), - truncated: DiscreteTensorSpec( + truncated: Categorical( shape=torch.Size([1]), space=DiscreteBox(n=2), device=cpu, @@ -1099,7 +1121,7 @@ def full_done_spec(self) -> CompositeSpec: return self.output_spec["full_done_spec"] @full_done_spec.setter - def full_done_spec(self, spec: CompositeSpec) -> None: + def full_done_spec(self, spec: Composite) -> None: self.done_spec = spec.to(self.device) if self.device is not None else spec # Done spec: done specs belong to output_spec @@ -1111,9 +1133,9 @@ def done_spec(self) -> TensorSpec: If the done spec is provided as a simple spec, this will be returned. - >>> env.done_spec = DiscreteTensorSpec(2, dtype=torch.bool) + >>> env.done_spec = Categorical(2, dtype=torch.bool) >>> env.done_spec - DiscreteTensorSpec( + Categorical( shape=torch.Size([]), space=DiscreteBox(n=2), device=cpu, @@ -1123,9 +1145,9 @@ def done_spec(self) -> TensorSpec: If the done spec is provided as a composite spec and contains only one leaf, this function will return just the leaf. - >>> env.done_spec = CompositeSpec({"nested": {"done": DiscreteTensorSpec(2, dtype=torch.bool)}}) + >>> env.done_spec = Composite({"nested": {"done": Categorical(2, dtype=torch.bool)}}) >>> env.done_spec - DiscreteTensorSpec( + Categorical( shape=torch.Size([]), space=DiscreteBox(n=2), device=cpu, @@ -1135,17 +1157,17 @@ def done_spec(self) -> TensorSpec: If the done spec is provided as a composite spec and has more than one leaf, this function will return the whole spec. 
- >>> env.done_spec = CompositeSpec({"nested": {"done": DiscreteTensorSpec(2, dtype=torch.bool), "another_done": DiscreteTensorSpec(2, dtype=torch.bool)}}) + >>> env.done_spec = Composite({"nested": {"done": Categorical(2, dtype=torch.bool), "another_done": Categorical(2, dtype=torch.bool)}}) >>> env.done_spec - CompositeSpec( - nested: CompositeSpec( - done: DiscreteTensorSpec( + Composite( + nested: Composite( + done: Categorical( shape=torch.Size([]), space=DiscreteBox(n=2), device=cpu, dtype=torch.bool, domain=discrete), - another_done: DiscreteTensorSpec( + another_done: Categorical( shape=torch.Size([]), space=DiscreteBox(n=2), device=cpu, @@ -1162,7 +1184,7 @@ def done_spec(self) -> TensorSpec: >>> from torchrl.envs.libs.gym import GymEnv >>> env = GymEnv("Pendulum-v1") >>> env.done_spec - DiscreteTensorSpec( + Categorical( shape=torch.Size([1]), space=DiscreteBox(n=2), device=cpu, @@ -1185,16 +1207,16 @@ def _create_done_specs(self): try: full_done_spec = self.output_spec["full_done_spec"] except KeyError: - full_done_spec = CompositeSpec( + full_done_spec = Composite( shape=self.output_spec.shape, device=self.output_spec.device ) - full_done_spec["done"] = DiscreteTensorSpec( + full_done_spec["done"] = Categorical( n=2, shape=(*full_done_spec.shape, 1), dtype=torch.bool, device=self.device, ) - full_done_spec["terminated"] = DiscreteTensorSpec( + full_done_spec["terminated"] = Categorical( n=2, shape=(*full_done_spec.shape, 1), dtype=torch.bool, @@ -1215,7 +1237,7 @@ def check_local_done(spec): spec["terminated"] = item.clone() elif key == "terminated" and "done" not in spec.keys(): spec["done"] = item.clone() - elif isinstance(item, CompositeSpec): + elif isinstance(item, Composite): check_local_done(item) else: if shape is None: @@ -1229,10 +1251,10 @@ def check_local_done(spec): # if the spec is empty, we need to add a done and terminated manually if spec.is_empty(): - spec["done"] = DiscreteTensorSpec( + spec["done"] = Categorical( n=2, shape=(*spec.shape, 1), dtype=torch.bool, device=self.device ) - spec["terminated"] = DiscreteTensorSpec( + spec["terminated"] = Categorical( n=2, shape=(*spec.shape, 1), dtype=torch.bool, device=self.device ) @@ -1260,16 +1282,16 @@ def done_spec(self, value: TensorSpec) -> None: raise ValueError( f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size})." ) - if isinstance(value, CompositeSpec): + if isinstance(value, Composite): for _ in value.values(True, True): # noqa: B007 break else: raise RuntimeError( - "An empty CompositeSpec was passed for the done spec. " + "An empty Composite was passed for the done spec. " "This is currently not permitted." ) else: - value = CompositeSpec( + value = Composite( done=value.to(device), terminated=value.to(device), shape=self.batch_size, @@ -1290,10 +1312,10 @@ def done_spec(self, value: TensorSpec) -> None: # observation spec: observation specs belong to output_spec @property - def observation_spec(self) -> CompositeSpec: + def observation_spec(self) -> Composite: """Observation spec. - Must be a :class:`torchrl.data.CompositeSpec` instance. + Must be a :class:`torchrl.data.Composite` instance. The keys listed in the spec are directly accessible after reset and step. 
In TorchRL, even though they are not properly speaking "observations" @@ -1307,8 +1329,8 @@ def observation_spec(self) -> CompositeSpec: >>> from torchrl.envs.libs.gym import GymEnv >>> env = GymEnv("Pendulum-v1") >>> env.observation_spec - CompositeSpec( - observation: BoundedTensorSpec( + Composite( + observation: BoundedContinuous( shape=torch.Size([3]), space=ContinuousBox( low=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True), @@ -1318,9 +1340,9 @@ def observation_spec(self) -> CompositeSpec: domain=continuous), device=cpu, shape=torch.Size([])) """ - observation_spec = self.output_spec["full_observation_spec"] + observation_spec = self.output_spec.get("full_observation_spec", default=None) if observation_spec is None: - observation_spec = CompositeSpec(shape=self.batch_size, device=self.device) + observation_spec = Composite(shape=self.batch_size, device=self.device) self.output_spec.unlock_() self.output_spec["full_observation_spec"] = observation_spec self.output_spec.lock_() @@ -1330,7 +1352,7 @@ def observation_spec(self) -> CompositeSpec: def observation_spec(self, value: TensorSpec) -> None: try: self.output_spec.unlock_() - if not isinstance(value, CompositeSpec): + if not isinstance(value, Composite): raise TypeError("The type of an observation_spec must be Composite.") elif value.shape[: len(self.batch_size)] != self.batch_size: raise ValueError( @@ -1348,19 +1370,19 @@ def observation_spec(self, value: TensorSpec) -> None: self.output_spec.lock_() @property - def full_observation_spec(self) -> CompositeSpec: + def full_observation_spec(self) -> Composite: return self.observation_spec @full_observation_spec.setter - def full_observation_spec(self, spec: CompositeSpec): + def full_observation_spec(self, spec: Composite): self.observation_spec = spec # state spec: state specs belong to input_spec @property - def state_spec(self) -> CompositeSpec: + def state_spec(self) -> Composite: """State spec. - Must be a :class:`torchrl.data.CompositeSpec` instance. + Must be a :class:`torchrl.data.Composite` instance. The keys listed here should be provided as input alongside actions to the environment. In TorchRL, even though they are not properly speaking "state" @@ -1376,10 +1398,10 @@ def state_spec(self) -> CompositeSpec: ... 
break >>> env = BraxEnv(envname) >>> env.state_spec - CompositeSpec( - state: CompositeSpec( - pipeline_state: CompositeSpec( - q: UnboundedContinuousTensorSpec( + Composite( + state: Composite( + pipeline_state: Composite( + q: UnboundedContinuous( shape=torch.Size([15]), space=None, device=cpu, @@ -1391,14 +1413,14 @@ def state_spec(self) -> CompositeSpec: """ state_spec = self.input_spec["full_state_spec"] if state_spec is None: - state_spec = CompositeSpec(shape=self.batch_size, device=self.device) + state_spec = Composite(shape=self.batch_size, device=self.device) self.input_spec.unlock_() self.input_spec["full_state_spec"] = state_spec self.input_spec.lock_() return state_spec @state_spec.setter - def state_spec(self, value: CompositeSpec) -> None: + def state_spec(self, value: Composite) -> None: try: self.input_spec.unlock_() try: @@ -1406,12 +1428,12 @@ def state_spec(self, value: CompositeSpec) -> None: except AttributeError: pass if value is None: - self.input_spec["full_state_spec"] = CompositeSpec( + self.input_spec["full_state_spec"] = Composite( device=self.device, shape=self.batch_size ) else: device = self.input_spec.device - if not isinstance(value, CompositeSpec): + if not isinstance(value, Composite): raise TypeError("The type of a state_spec must be Composite.") elif value.shape[: len(self.batch_size)] != self.batch_size: raise ValueError( @@ -1428,10 +1450,10 @@ def state_spec(self, value: CompositeSpec) -> None: self.input_spec.lock_() @property - def full_state_spec(self) -> CompositeSpec: + def full_state_spec(self) -> Composite: """The full state spec. - ``full_state_spec`` is a :class:`~torchrl.data.CompositeSpec`` instance + ``full_state_spec`` is a :class:`~torchrl.data.Composite` instance that contains all the state entries (i.e., the input data that is not action). Examples: @@ -1440,10 +1462,10 @@ def full_state_spec(self) -> CompositeSpec: >>> from torchrl.envs import BraxEnv >>> for envname in BraxEnv.available_envs: ...
break >>> env = BraxEnv(envname) >>> env.full_state_spec - CompositeSpec( - state: CompositeSpec( - pipeline_state: CompositeSpec( - q: UnboundedContinuousTensorSpec( + Composite( + state: Composite( + pipeline_state: Composite( + q: UnboundedContinuous( shape=torch.Size([15]), space=None, device=cpu, @@ -1455,7 +1477,7 @@ def full_state_spec(self) -> CompositeSpec: return self.state_spec @full_state_spec.setter - def full_state_spec(self, spec: CompositeSpec) -> None: + def full_state_spec(self, spec: Composite) -> None: self.state_spec = spec def step(self, tensordict: TensorDictBase) -> TensorDictBase: @@ -1494,7 +1516,7 @@ def step(self, tensordict: TensorDictBase) -> TensorDictBase: @classmethod def _complete_done( - cls, done_spec: CompositeSpec, data: TensorDictBase + cls, done_spec: Composite, data: TensorDictBase ) -> TensorDictBase: """Completes the data structure at step time to put missing done keys.""" # by default, if a done key is missing, it is assumed that it is False @@ -1508,7 +1530,7 @@ def _complete_done( i = -1 for i, (key, item) in enumerate(done_spec.items()): # noqa: B007 val = data.get(key, None) - if isinstance(item, CompositeSpec): + if isinstance(item, Composite): if val is not None: cls._complete_done(item, val) continue @@ -2141,9 +2163,9 @@ def reset( self._assert_tensordict_shape(tensordict) tensordict_reset = self._reset(tensordict, **kwargs) - # We assume that this is done properly - # if reset.device != self.device: - # reset = reset.to(self.device, non_blocking=True) + # We assume that this is done properly + # if reset.device != self.device: + # reset = reset.to(self.device, non_blocking=True) if tensordict_reset is tensordict: raise RuntimeError( "EnvBase._reset should return outplace changes to the input " @@ -2300,14 +2322,14 @@ def rand_step(self, tensordict: Optional[TensorDictBase] = None) -> TensorDictBase return self.step(tensordict) @property - def specs(self) -> CompositeSpec: + def specs(self) -> Composite: """Returns a Composite container where all the environment specs are present. This feature allows one to create an environment, retrieve all of the specs in a single data container and then erase the environment from the workspace. """ - return CompositeSpec( + return Composite( output_spec=self.output_spec, input_spec=self.input_spec, shape=self.batch_size, @@ -2322,13 +2344,16 @@ def rollout( max_steps: int, policy: Optional[Callable[[TensorDictBase], TensorDictBase]] = None, callback: Optional[Callable[[TensorDictBase, ...], Any]] = None, + *, auto_reset: bool = True, auto_cast_to_device: bool = False, - break_when_any_done: bool = True, + break_when_any_done: bool | None = None, + break_when_all_done: bool | None = None, return_contiguous: bool = True, tensordict: Optional[TensorDictBase] = None, set_truncated: bool = False, out=None, + trust_policy: bool = False, ): """Executes a rollout in the environment. @@ -2347,6 +2372,8 @@ def rollout( TensorDict. Defaults to ``None``. The output of ``callback`` will not be collected, it is the user responsibility to save any result within the callback call if data needs to be carried over beyond the call to ``rollout``. + + Keyword Args: auto_reset (bool, optional): if ``True``, automatically resets the environment if it is in a done state when the rollout is initiated. Default is ``True``. auto_cast_to_device (bool, optional): if ``True``, the device of the tensordict is automatically cast to the policy device before the policy is used. Default is ``False``. break_when_any_done (bool): breaks if any of the done states is True.
If False, a reset() is called on the sub-envs that are done. Default is True. + break_when_all_done (bool): breaks if all of the done states are True. Sub-envs that finish before the others keep their last state until every sub-env is done; no reset is performed on them. Cannot be True together with break_when_any_done. Default is False. return_contiguous (bool): if False, a LazyStackedTensorDict will be returned. Default is True. tensordict (TensorDict, optional): if ``auto_reset`` is False, an initial tensordict must be provided. Rollout will check if this tensordict has done flags and reset the @@ -2367,6 +2395,9 @@ def rollout( ``done_spec``, an exception is raised. Truncated keys can be set through ``env.add_truncated_keys``. Defaults to ``False``. + trust_policy (bool, optional): if ``True``, a non-TensorDictModule policy will be + assumed to be compatible with the collector. This defaults to ``True`` for CudaGraphModules + and ``False`` otherwise. Returns: TensorDict object containing the resulting trajectory. @@ -2550,9 +2581,26 @@ def rollout( ... ) """ + if break_when_any_done is None: # True by default + if break_when_all_done: # all overrides + break_when_any_done = False + else: + break_when_any_done = True + if break_when_all_done is None: + # There is no case where break_when_all_done is True by default + break_when_all_done = False + if break_when_all_done and break_when_any_done: + raise TypeError( + "Cannot have both break_when_all_done and break_when_any_done True at the same time." + ) + if policy is not None: policy = _make_compatible_policy( - policy, self.observation_spec, env=self, fast_wrap=True + policy, + self.observation_spec, + env=self, + fast_wrap=True, + trust_policy=trust_policy, ) if auto_cast_to_device: try: @@ -2583,8 +2631,12 @@ def rollout( "env_device": env_device, "callback": callback, } - if break_when_any_done: - tensordicts = self._rollout_stop_early(**kwargs) + if break_when_any_done or break_when_all_done: + tensordicts = self._rollout_stop_early( + break_when_all_done=break_when_all_done, + break_when_any_done=break_when_any_done, + **kwargs, + ) else: tensordicts = self._rollout_nonstop(**kwargs) batch_size = self.batch_size if tensordict is None else tensordict.batch_size @@ -2644,6 +2696,8 @@ def _step_mdp(self): def _rollout_stop_early( self, *, + break_when_any_done, + break_when_all_done, tensordict, auto_cast_to_device, max_steps, @@ -2656,6 +2710,7 @@ def _rollout_stop_early( if auto_cast_to_device: sync_func = _get_sync_func(policy_device, env_device) tensordicts = [] + partial_steps = True for i in range(max_steps): if auto_cast_to_device: if policy_device is not None: @@ -2673,6 +2728,14 @@ def _rollout_stop_early( tensordict.clear_device_() tensordict = self.step(tensordict) td_append = tensordict.copy() + if break_when_all_done: + if partial_steps is not True: + # At least one partial step has been done + del td_append["_partial_steps"] + td_append = torch.where( + partial_steps.view(td_append.shape), td_append, tensordicts[-1] + ) + tensordicts.append(td_append) if i == max_steps - 1: @@ -2680,16 +2743,31 @@ def _rollout_stop_early( break tensordict = self._step_mdp(tensordict) - # done and truncated are in done_keys - # We read if any key is done. - any_done = _terminated_or_truncated( - tensordict, - full_done_spec=self.output_spec["full_done_spec"], - key=None, - ) - - if any_done: - break + if break_when_any_done: + # done and truncated are in done_keys + # We read if any key is done.
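# --- Illustrative aside, not part of the patch: how the two early-stop modes
# implemented in this hunk are meant to be used. A minimal sketch; the env
# name and the batch size of 2 are arbitrary, hypothetical choices.
from torchrl.envs import GymEnv, SerialEnv

env = SerialEnv(2, lambda: GymEnv("Pendulum-v1"))
# Historical behavior: stop stepping as soon as *any* sub-env reports done.
td_any = env.rollout(max_steps=100, break_when_any_done=True)
# New behavior: keep stepping until *all* sub-envs are done; sub-envs that
# finish early are frozen in place via the "_partial_steps" mask handled here.
td_all = env.rollout(max_steps=100, break_when_all_done=True)
# Passing both flags as True raises a TypeError, as checked at the top of rollout().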
+ any_done = _terminated_or_truncated( + tensordict, + full_done_spec=self.output_spec["full_done_spec"], + key=None, + ) + if any_done: + break + else: + _terminated_or_truncated( + tensordict, + full_done_spec=self.output_spec["full_done_spec"], + key="_partial_steps", + write_full_false=False, + ) + partial_step_curr = tensordict.get("_partial_steps", None) + if partial_step_curr is not None: + partial_step_curr = ~partial_step_curr + partial_steps = partial_steps & partial_step_curr + if partial_steps is not True: + if not partial_steps.any(): + break + tensordict.set("_partial_steps", partial_steps) if callback is not None: callback(self, tensordict) @@ -3044,6 +3122,7 @@ def __init__( self._constructor_kwargs = kwargs self._check_kwargs(kwargs) + self._convert_actions_to_numpy = kwargs.pop("convert_actions_to_numpy", True) self._env = self._build_env(**kwargs) # writes the self._env attribute self._make_specs(self._env) # writes the self._env attribute self.is_closed = False @@ -3169,7 +3248,7 @@ def _do_nothing(): return -def _has_dynamic_specs(spec: CompositeSpec): +def _has_dynamic_specs(spec: Composite): from tensordict.base import _NESTED_TENSORS_AS_LISTS return any( diff --git a/torchrl/envs/custom/pendulum.py b/torchrl/envs/custom/pendulum.py index 8253e3df9b7..e2007227127 100644 --- a/torchrl/envs/custom/pendulum.py +++ b/torchrl/envs/custom/pendulum.py @@ -6,11 +6,7 @@ import torch from tensordict import TensorDict, TensorDictBase -from torchrl.data.tensor_specs import ( - BoundedTensorSpec, - CompositeSpec, - UnboundedContinuousTensorSpec, -) +from torchrl.data.tensor_specs import Bounded, Composite, Unbounded from torchrl.envs.common import EnvBase from torchrl.envs.utils import make_composite_from_td @@ -21,125 +17,193 @@ class PendulumEnv(EnvBase): See the Pendulum tutorial for more details: :ref:`tutorial `. 
Specs: - CompositeSpec( - output_spec: CompositeSpec( - full_observation_spec: CompositeSpec( - th: BoundedTensorSpec( + >>> env = PendulumEnv() + >>> env.specs + Composite( + output_spec: Composite( + full_observation_spec: Composite( + th: BoundedContinuous( shape=torch.Size([]), space=ContinuousBox( - low=Tensor(shape=torch.Size([]), dtype=torch.float32, contiguous=True), - high=Tensor(shape=torch.Size([]), dtype=torch.float32, contiguous=True)), + low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)), + device=cpu, dtype=torch.float32, domain=continuous), - thdot: BoundedTensorSpec( + thdot: BoundedContinuous( shape=torch.Size([]), space=ContinuousBox( - low=Tensor(shape=torch.Size([]), dtype=torch.float32, contiguous=True), - high=Tensor(shape=torch.Size([]), dtype=torch.float32, contiguous=True)), + low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)), + device=cpu, dtype=torch.float32, domain=continuous), - params: CompositeSpec( - max_speed: UnboundedContinuousTensorSpec( + params: Composite( + max_speed: UnboundedDiscrete( shape=torch.Size([]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, contiguous=True), + high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, contiguous=True)), + device=cpu, dtype=torch.int64, domain=discrete), - max_torque: UnboundedContinuousTensorSpec( + max_torque: UnboundedContinuous( shape=torch.Size([]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)), + device=cpu, dtype=torch.float32, domain=continuous), - dt: UnboundedContinuousTensorSpec( + dt: UnboundedContinuous( shape=torch.Size([]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)), + device=cpu, dtype=torch.float32, domain=continuous), - g: UnboundedContinuousTensorSpec( + g: UnboundedContinuous( shape=torch.Size([]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)), + device=cpu, dtype=torch.float32, domain=continuous), - m: UnboundedContinuousTensorSpec( + m: UnboundedContinuous( shape=torch.Size([]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)), + device=cpu, dtype=torch.float32, domain=continuous), - l: UnboundedContinuousTensorSpec( + l: UnboundedContinuous( shape=torch.Size([]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)), + device=cpu, dtype=torch.float32, domain=continuous), + device=None, shape=torch.Size([])), + device=None, shape=torch.Size([])), - full_reward_spec: CompositeSpec( - reward: UnboundedContinuousTensorSpec( + full_reward_spec: Composite( + reward: UnboundedContinuous( shape=torch.Size([1]), space=ContinuousBox( - low=Tensor(shape=torch.Size([1]), 
dtype=torch.float32, contiguous=True), - high=Tensor(shape=torch.Size([1]), dtype=torch.float32, contiguous=True)), + low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True)), + device=cpu, dtype=torch.float32, domain=continuous), + device=None, shape=torch.Size([])), - full_done_spec: CompositeSpec( - done: DiscreteTensorSpec( + full_done_spec: Composite( + done: Categorical( shape=torch.Size([1]), - space=DiscreteBox(n=2), + space=CategoricalBox(n=2), + device=cpu, dtype=torch.bool, domain=discrete), - terminated: DiscreteTensorSpec( + terminated: Categorical( shape=torch.Size([1]), - space=DiscreteBox(n=2), + space=CategoricalBox(n=2), + device=cpu, dtype=torch.bool, domain=discrete), + device=None, shape=torch.Size([])), + device=None, shape=torch.Size([])), - input_spec: CompositeSpec( - full_state_spec: CompositeSpec( - th: BoundedTensorSpec( + input_spec: Composite( + full_state_spec: Composite( + th: BoundedContinuous( shape=torch.Size([]), space=ContinuousBox( - low=Tensor(shape=torch.Size([]), dtype=torch.float32, contiguous=True), - high=Tensor(shape=torch.Size([]), dtype=torch.float32, contiguous=True)), + low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)), + device=cpu, dtype=torch.float32, domain=continuous), - thdot: BoundedTensorSpec( + thdot: BoundedContinuous( shape=torch.Size([]), space=ContinuousBox( - low=Tensor(shape=torch.Size([]), dtype=torch.float32, contiguous=True), - high=Tensor(shape=torch.Size([]), dtype=torch.float32, contiguous=True)), + low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)), + device=cpu, dtype=torch.float32, domain=continuous), - params: CompositeSpec( - max_speed: UnboundedContinuousTensorSpec( + params: Composite( + max_speed: UnboundedDiscrete( shape=torch.Size([]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, contiguous=True), + high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, contiguous=True)), + device=cpu, dtype=torch.int64, domain=discrete), - max_torque: UnboundedContinuousTensorSpec( + max_torque: UnboundedContinuous( shape=torch.Size([]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)), + device=cpu, dtype=torch.float32, domain=continuous), - dt: UnboundedContinuousTensorSpec( + dt: UnboundedContinuous( shape=torch.Size([]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)), + device=cpu, dtype=torch.float32, domain=continuous), - g: UnboundedContinuousTensorSpec( + g: UnboundedContinuous( shape=torch.Size([]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)), + device=cpu, dtype=torch.float32, domain=continuous), - m: UnboundedContinuousTensorSpec( + m: UnboundedContinuous( shape=torch.Size([]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, 
contiguous=True), + high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)), + device=cpu, dtype=torch.float32, domain=continuous), - l: UnboundedContinuousTensorSpec( + l: UnboundedContinuous( shape=torch.Size([]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)), + device=cpu, dtype=torch.float32, domain=continuous), + device=None, shape=torch.Size([])), + device=None, shape=torch.Size([])), - full_action_spec: CompositeSpec( - action: BoundedTensorSpec( + full_action_spec: Composite( + action: BoundedContinuous( shape=torch.Size([1]), space=ContinuousBox( - low=Tensor(shape=torch.Size([1]), dtype=torch.float32, contiguous=True), - high=Tensor(shape=torch.Size([1]), dtype=torch.float32, contiguous=True)), + low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True)), + device=cpu, dtype=torch.float32, domain=continuous), + device=None, shape=torch.Size([])), + device=None, shape=torch.Size([])), + device=None, shape=torch.Size([])) """ @@ -152,6 +216,7 @@ class PendulumEnv(EnvBase): "render_fps": 30, } batch_locked = False + rng = None def __init__(self, td_params=None, seed=None, device=None): if td_params is None: @@ -160,7 +225,7 @@ def __init__(self, td_params=None, seed=None, device=None): super().__init__(device=device) self._make_spec(td_params) if seed is None: - seed = torch.empty((), dtype=torch.int64).random_().item() + seed = torch.empty((), dtype=torch.int64).random_(generator=self.rng).item() self.set_seed(seed) @classmethod @@ -240,14 +305,14 @@ def _reset(self, tensordict): def _make_spec(self, td_params): # Under the hood, this will populate self.output_spec["observation"] - self.observation_spec = CompositeSpec( - th=BoundedTensorSpec( + self.observation_spec = Composite( + th=Bounded( low=-torch.pi, high=torch.pi, shape=(), dtype=torch.float32, ), - thdot=BoundedTensorSpec( + thdot=Bounded( low=-td_params["params", "max_speed"], high=td_params["params", "max_speed"], shape=(), @@ -265,22 +330,22 @@ def _make_spec(self, td_params): self.state_spec = self.observation_spec.clone() # action-spec will be automatically wrapped in input_spec when # `self.action_spec = spec` will be called supported - self.action_spec = BoundedTensorSpec( + self.action_spec = Bounded( low=-td_params["params", "max_torque"], high=td_params["params", "max_torque"], shape=(1,), dtype=torch.float32, ) - self.reward_spec = UnboundedContinuousTensorSpec(shape=(*td_params.shape, 1)) + self.reward_spec = Unbounded(shape=(*td_params.shape, 1)) def make_composite_from_td(td): # custom function to convert a ``tensordict`` in a similar spec structure # of unbounded values. 
- composite = CompositeSpec( + composite = Composite( { key: make_composite_from_td(tensor) if isinstance(tensor, TensorDictBase) - else UnboundedContinuousTensorSpec( + else Unbounded( dtype=tensor.dtype, device=tensor.device, shape=tensor.shape ) for key, tensor in td.items() @@ -290,7 +355,8 @@ def make_composite_from_td(td): return composite def _set_seed(self, seed: int): - rng = torch.manual_seed(seed) + rng = torch.Generator() + rng.manual_seed(seed) self.rng = rng @staticmethod diff --git a/torchrl/envs/custom/tictactoeenv.py b/torchrl/envs/custom/tictactoeenv.py index 79ea3b2dfb6..2c93a5748ef 100644 --- a/torchrl/envs/custom/tictactoeenv.py +++ b/torchrl/envs/custom/tictactoeenv.py @@ -9,12 +9,7 @@ import torch from tensordict import TensorDict, TensorDictBase -from torchrl.data.tensor_specs import ( - CompositeSpec, - DiscreteTensorSpec, - UnboundedContinuousTensorSpec, - UnboundedDiscreteTensorSpec, -) +from torchrl.data.tensor_specs import Categorical, Composite, Unbounded from torchrl.envs.common import EnvBase @@ -39,28 +34,28 @@ class TicTacToeEnv(EnvBase): output entry). Specs: - CompositeSpec( - output_spec: CompositeSpec( - full_observation_spec: CompositeSpec( - board: DiscreteTensorSpec( + Composite( + output_spec: Composite( + full_observation_spec: Composite( + board: Categorical( shape=torch.Size([3, 3]), space=DiscreteBox(n=2), dtype=torch.int32, domain=discrete), - turn: DiscreteTensorSpec( + turn: Categorical( shape=torch.Size([1]), space=DiscreteBox(n=2), dtype=torch.int32, domain=discrete), - mask: DiscreteTensorSpec( + mask: Categorical( shape=torch.Size([9]), space=DiscreteBox(n=2), dtype=torch.bool, domain=discrete), shape=torch.Size([])), - full_reward_spec: CompositeSpec( - player0: CompositeSpec( - reward: UnboundedContinuousTensorSpec( + full_reward_spec: Composite( + player0: Composite( + reward: UnboundedContinuous( shape=torch.Size([1]), space=ContinuousBox( low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True), @@ -68,8 +63,8 @@ class TicTacToeEnv(EnvBase): dtype=torch.float32, domain=continuous), shape=torch.Size([])), - player1: CompositeSpec( - reward: UnboundedContinuousTensorSpec( + player1: Composite( + reward: UnboundedContinuous( shape=torch.Size([1]), space=ContinuousBox( low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True), @@ -78,43 +73,43 @@ class TicTacToeEnv(EnvBase): domain=continuous), shape=torch.Size([])), shape=torch.Size([])), - full_done_spec: CompositeSpec( - done: DiscreteTensorSpec( + full_done_spec: Composite( + done: Categorical( shape=torch.Size([1]), space=DiscreteBox(n=2), dtype=torch.bool, domain=discrete), - terminated: DiscreteTensorSpec( + terminated: Categorical( shape=torch.Size([1]), space=DiscreteBox(n=2), dtype=torch.bool, domain=discrete), - truncated: DiscreteTensorSpec( + truncated: Categorical( shape=torch.Size([1]), space=DiscreteBox(n=2), dtype=torch.bool, domain=discrete), shape=torch.Size([])), shape=torch.Size([])), - input_spec: CompositeSpec( - full_state_spec: CompositeSpec( - board: DiscreteTensorSpec( + input_spec: Composite( + full_state_spec: Composite( + board: Categorical( shape=torch.Size([3, 3]), space=DiscreteBox(n=2), dtype=torch.int32, domain=discrete), - turn: DiscreteTensorSpec( + turn: Categorical( shape=torch.Size([1]), space=DiscreteBox(n=2), dtype=torch.int32, domain=discrete), - mask: DiscreteTensorSpec( + mask: Categorical( shape=torch.Size([9]), space=DiscreteBox(n=2), dtype=torch.bool, domain=discrete), 
shape=torch.Size([])), - full_action_spec: CompositeSpec( - action: DiscreteTensorSpec( + full_action_spec: Composite( + action: Categorical( shape=torch.Size([1]), space=DiscreteBox(n=9), dtype=torch.int64, @@ -172,23 +167,21 @@ class TicTacToeEnv(EnvBase): def __init__(self, *, single_player: bool = False, device=None): super().__init__(device=device) self.single_player = single_player - self.action_spec: UnboundedDiscreteTensorSpec = DiscreteTensorSpec( + self.action_spec: Categorical = Categorical( n=9, shape=(), device=device, ) - self.full_observation_spec: CompositeSpec = CompositeSpec( - board=UnboundedContinuousTensorSpec( - shape=(3, 3), dtype=torch.int, device=device - ), - turn=DiscreteTensorSpec( + self.full_observation_spec: Composite = Composite( + board=Unbounded(shape=(3, 3), dtype=torch.int, device=device), + turn=Categorical( 2, shape=(1,), dtype=torch.int, device=device, ), - mask=DiscreteTensorSpec( + mask=Categorical( 2, shape=(9,), dtype=torch.bool, @@ -196,22 +189,18 @@ def __init__(self, *, single_player: bool = False, device=None): ), device=device, ) - self.state_spec: CompositeSpec = self.observation_spec.clone() + self.state_spec: Composite = self.observation_spec.clone() - self.reward_spec: UnboundedContinuousTensorSpec = CompositeSpec( + self.reward_spec: Composite = Composite( { - ("player0", "reward"): UnboundedContinuousTensorSpec( - shape=(1,), device=device - ), - ("player1", "reward"): UnboundedContinuousTensorSpec( - shape=(1,), device=device - ), + ("player0", "reward"): Unbounded(shape=(1,), device=device), + ("player1", "reward"): Unbounded(shape=(1,), device=device), }, device=device, ) - self.full_done_spec: DiscreteTensorSpec = CompositeSpec( - done=DiscreteTensorSpec(2, shape=(1,), dtype=torch.bool, device=device), + self.full_done_spec: Composite = Composite( + done=Categorical(2, shape=(1,), dtype=torch.bool, device=device), device=device, ) self.full_done_spec["terminated"] = self.full_done_spec["done"].clone() @@ -229,7 +218,7 @@ def _step(self, state: TensorDict) -> TensorDict: turn = state["turn"].clone() action = state["action"] board.flatten(-2, -1).scatter_(index=action.unsqueeze(-1), dim=-1, value=1) - wins = self.win(state["board"], action) + wins = self.win(board, action) mask = board.flatten(-2, -1) == -1 done = wins | ~mask.any(-1, keepdim=True) @@ -245,7 +234,7 @@ def _step(self, state: TensorDict) -> TensorDict: ("player0", "reward"): reward_0.float(), ("player1", "reward"): reward_1.float(), "board": torch.where(board == -1, board, 1 - board), - "turn": 1 - state["turn"], + "turn": 1 - turn, "mask": mask, }, batch_size=state.batch_size, @@ -271,13 +260,15 @@ def _set_seed(self, seed: int | None): def win(board: torch.Tensor, action: torch.Tensor): row = action // 3 # type: ignore col = action % 3 # type: ignore - return ( - board[..., row, :].sum() - == 3 | board[..., col].sum() - == 3 | board.diagonal(0, -2, -1).sum() - == 3 | board.flip(-1).diagonal(0, -2, -1).sum() - == 3 - ) + if board[..., row, :].sum() == 3: + return True + if board[..., col].sum() == 3: + return True + if board.diagonal(0, -2, -1).sum() == 3: + return True + if board.flip(-1).diagonal(0, -2, -1).sum() == 3: + return True + return False @staticmethod def full(board: torch.Tensor) -> bool: diff --git a/torchrl/envs/env_creator.py b/torchrl/envs/env_creator.py index 89ee8cc5614..f090289214d 100644 --- a/torchrl/envs/env_creator.py +++ b/torchrl/envs/env_creator.py @@ -109,7 +109,7 @@ def share_memory(self, state_dict: OrderedDict) -> None: del
state_dict[key] @property - def meta_data(self): + def meta_data(self) -> EnvMetaData: if self._meta_data is None: raise RuntimeError( "meta_data is None in EnvCreator. " "Make sure init_() has been called." diff --git a/torchrl/envs/gym_like.py b/torchrl/envs/gym_like.py index c7935272c91..995f245a8ac 100644 --- a/torchrl/envs/gym_like.py +++ b/torchrl/envs/gym_like.py @@ -12,14 +12,10 @@ import numpy as np import torch -from tensordict import TensorDict, TensorDictBase +from tensordict import NonTensorData, TensorDict, TensorDictBase from torchrl._utils import logger as torchrl_logger -from torchrl.data.tensor_specs import ( - CompositeSpec, - TensorSpec, - UnboundedContinuousTensorSpec, -) +from torchrl.data.tensor_specs import Composite, NonTensor, TensorSpec, Unbounded from torchrl.envs.common import _EnvWrapper, EnvBase @@ -44,10 +40,10 @@ class default_info_dict_reader(BaseInfoDictReader): Args: keys (list of keys, optional): If provided, the list of keys to get from the info dictionary. Defaults to all keys. - spec (List[TensorSpec], Dict[str, TensorSpec] or CompositeSpec, optional): + spec (List[TensorSpec], Dict[str, TensorSpec] or Composite, optional): If a list of specs is provided, each spec will be matched to its - correspondent key to form a :class:`torchrl.data.CompositeSpec`. - If not provided, a composite spec with :class:`~torchrl.data.UnboundedContinuousTensorSpec` + corresponding key to form a :class:`torchrl.data.Composite`. + If not provided, a composite spec with :class:`~torchrl.data.Unbounded` specs will be lazily created. ignore_private (bool, optional): If ``True``, private infos (starting with an underscore) will be ignored. Defaults to ``True``. @@ -72,10 +68,7 @@ class default_info_dict_reader(BaseInfoDictReader): def __init__( self, keys: List[str] | None = None, - spec: Sequence[TensorSpec] - | Dict[str, TensorSpec] - | CompositeSpec - | None = None, + spec: Sequence[TensorSpec] | Dict[str, TensorSpec] | Composite | None = None, ignore_private: bool = True, ): self.ignore_private = ignore_private @@ -87,19 +80,17 @@ def __init__( if spec is None and keys is None: _info_spec = None elif spec is None: - _info_spec = CompositeSpec( - {key: UnboundedContinuousTensorSpec(()) for key in keys}, shape=[] - ) - elif not isinstance(spec, CompositeSpec): + _info_spec = Composite({key: Unbounded(()) for key in keys}, shape=[]) + elif not isinstance(spec, Composite): if self.keys is not None and len(spec) != len(self.keys): raise ValueError( "If specifying specs for info keys with a sequence, the " "length of the sequence must match the number of keys" ) if isinstance(spec, dict): - _info_spec = CompositeSpec(spec, shape=[]) + _info_spec = Composite(spec, shape=[]) else: - _info_spec = CompositeSpec( + _info_spec = Composite( {key: spec for key, spec in zip(keys, spec)}, shape=[] ) else: @@ -121,7 +112,7 @@ def __call__( keys = [key for key in keys if not key.startswith("_")] self.keys = keys # create an info_spec only if there is none - info_spec = None if self.info_spec is not None else CompositeSpec() + info_spec = None if self.info_spec is not None else Composite() for key in keys: if key in info_dict: val = info_dict[key] @@ -130,7 +121,7 @@ def __call__( tensordict.set(key, val) if info_spec is not None: val = tensordict.get(key) - info_spec[key] = UnboundedContinuousTensorSpec( + info_spec[key] = Unbounded( val.shape, device=val.device, dtype=val.dtype ) elif self.info_spec is not None: @@ -158,7 +149,7 @@ def info_spec(self) -> Dict[str, TensorSpec]: class
GymLikeEnv(_EnvWrapper): """A gym-like env is an environment. - Its behaviour is similar to gym environments in what common methods (specifically reset and step) are expected to do. + Its behavior is similar to that of gym environments in what common methods (specifically reset and step) are expected to do. A :obj:`GymLikeEnv` has a :obj:`.step()` method with the following signature: @@ -181,6 +172,7 @@ class GymLikeEnv(_EnvWrapper): def __new__(cls, *args, **kwargs): self = super().__new__(cls, *args, _batch_locked=True, **kwargs) self._info_dict_reader = [] + return self def read_action(self, action): @@ -291,14 +283,18 @@ def read_obs( observations = observations_dict else: for key, val in observations.items(): - observations[key] = self.observation_spec[key].encode( - val, ignore_device=True - ) + if isinstance(self.observation_spec[key], NonTensor): + observations[key] = NonTensorData(val) + else: + observations[key] = self.observation_spec[key].encode( + val, ignore_device=True + ) return observations def _step(self, tensordict: TensorDictBase) -> TensorDictBase: action = tensordict.get(self.action_key) - action_np = self.read_action(action) + if self._convert_actions_to_numpy: + action = self.read_action(action) reward = 0 for _ in range(self.wrapper_frame_skip): @@ -309,7 +305,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: truncated, done, info_dict, - ) = self._output_transform(self._env.step(action_np)) + ) = self._output_transform(self._env.step(action)) if _reward is not None: reward = reward + _reward @@ -515,7 +511,7 @@ def auto_register_info_dict( the info is filled at reset time. .. note:: This method requires running a few iterations in the environment to - manually check that the behaviour matches expectations. + manually check that the behavior matches expectations.
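# --- Illustrative aside, not part of the patch: typical use of the info-dict
# hooks touched above. A minimal sketch; the env name and the info key
# "my_info_key" are hypothetical placeholders.
from torchrl.envs.gym_like import default_info_dict_reader
from torchrl.envs.libs.gym import GymEnv

env = GymEnv("Pendulum-v1")
# Explicit registration: copy one entry of the step info dict into the output
# tensordict; with no spec given, an Unbounded spec is lazily created for it.
env.set_info_dict_reader(default_info_dict_reader(["my_info_key"]))
# Alternatively, let TorchRL discover the info entries by stepping the env once.
env = GymEnv("Pendulum-v1").auto_register_info_dict()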
Args: ignore_private (bool, optional): If ``True``, private infos (starting with diff --git a/torchrl/envs/libs/__init__.py b/torchrl/envs/libs/__init__.py index e322c2cbf01..7ea113ce46d 100644 --- a/torchrl/envs/libs/__init__.py +++ b/torchrl/envs/libs/__init__.py @@ -19,7 +19,9 @@ from .jumanji import JumanjiEnv, JumanjiWrapper from .meltingpot import MeltingpotEnv, MeltingpotWrapper from .openml import OpenMLEnv +from .openspiel import OpenSpielEnv, OpenSpielWrapper from .pettingzoo import PettingZooEnv, PettingZooWrapper from .robohive import RoboHiveEnv from .smacv2 import SMACv2Env, SMACv2Wrapper +from .unity_mlagents import UnityMLAgentsEnv, UnityMLAgentsWrapper from .vmas import VmasEnv, VmasWrapper diff --git a/torchrl/envs/libs/_gym_utils.py b/torchrl/envs/libs/_gym_utils.py index fb01f430fc1..b95bfb335c6 100644 --- a/torchrl/envs/libs/_gym_utils.py +++ b/torchrl/envs/libs/_gym_utils.py @@ -12,9 +12,9 @@ from torch.utils._pytree import tree_map from torchrl._utils import implement_for -from torchrl.data import CompositeSpec +from torchrl.data import Composite from torchrl.envs import step_mdp, TransformedEnv -from torchrl.envs.libs.gym import _torchrl_to_gym_spec_transform +from torchrl.envs.libs.gym import _torchrl_to_gym_spec_transform, GYMNASIUM_1_ERROR _has_gym = importlib.util.find_spec("gym", None) is not None _has_gymnasium = importlib.util.find_spec("gymnasium", None) is not None @@ -37,7 +37,7 @@ def __init__( ), ) self.observation_space = _torchrl_to_gym_spec_transform( - CompositeSpec( + Composite( { key: self.torchrl_env.full_observation_spec[key] for key in self._observation_keys @@ -125,7 +125,11 @@ def _action_keys(self): import gymnasium class _TorchRLGymnasiumWrapper(gymnasium.Env, _BaseGymWrapper): - @implement_for("gymnasium") + @implement_for("gymnasium", "1.0.0") + def step(self, action): # noqa: F811 + raise ImportError(GYMNASIUM_1_ERROR) + + @implement_for("gymnasium", None, "1.0.0") def step(self, action): # noqa: F811 action_keys = self._action_keys if len(action_keys) == 1: @@ -153,7 +157,7 @@ def step(self, action): # noqa: F811 out = tree_map(lambda x: x.detach().cpu().numpy(), out) return out - @implement_for("gymnasium") + @implement_for("gymnasium", None, "1.0.0") def reset(self): # noqa: F811 self._tensordict = self.torchrl_env.reset() observation = self._tensordict @@ -167,6 +171,10 @@ def reset(self): # noqa: F811 out = tree_map(lambda x: x.detach().cpu().numpy(), out) return out + @implement_for("gymnasium", "1.0.0") + def reset(self): # noqa: F811 + raise ImportError(GYMNASIUM_1_ERROR) + else: class _TorchRLGymnasiumWrapper: diff --git a/torchrl/envs/libs/brax.py b/torchrl/envs/libs/brax.py index ac4cd71ddad..9542b8e71ff 100644 --- a/torchrl/envs/libs/brax.py +++ b/torchrl/envs/libs/brax.py @@ -11,11 +11,7 @@ from packaging import version from tensordict import TensorDict, TensorDictBase -from torchrl.data.tensor_specs import ( - BoundedTensorSpec, - CompositeSpec, - UnboundedContinuousTensorSpec, -) +from torchrl.data.tensor_specs import Bounded, Composite, Unbounded from torchrl.envs.common import _EnvWrapper from torchrl.envs.libs.jax_utils import ( _extract_spec, @@ -55,8 +51,8 @@ class BraxWrapper(_EnvWrapper): Args: env (brax.envs.base.PipelineEnv): the environment to wrap. categorical_action_encoding (bool, optional): if ``True``, categorical - specs will be converted to the TorchRL equivalent (:class:`torchrl.data.DiscreteTensorSpec`), - otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHotTensorSpec`). 
+ specs will be converted to the TorchRL equivalent (:class:`torchrl.data.Categorical`), + otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHot`). Defaults to ``False``. Keyword Args: @@ -255,7 +251,7 @@ def _make_state_spec(self, env: "brax.envs.env.Env"): # noqa: F821 return state_spec def _make_specs(self, env: "brax.envs.env.Env") -> None: # noqa: F821 - self.action_spec = BoundedTensorSpec( + self.action_spec = Bounded( low=-1, high=1, shape=( @@ -264,15 +260,15 @@ def _make_specs(self, env: "brax.envs.env.Env") -> None: # noqa: F821 ), device=self.device, ) - self.reward_spec = UnboundedContinuousTensorSpec( + self.reward_spec = Unbounded( shape=[ *self.batch_size, 1, ], device=self.device, ) - self.observation_spec = CompositeSpec( - observation=UnboundedContinuousTensorSpec( + self.observation_spec = Composite( + observation=Unbounded( shape=( *self.batch_size, env.observation_size, @@ -439,8 +435,8 @@ class BraxEnv(BraxWrapper): env_name (str): the environment name of the env to wrap. Must be part of :attr:`~.available_envs`. categorical_action_encoding (bool, optional): if ``True``, categorical - specs will be converted to the TorchRL equivalent (:class:`torchrl.data.DiscreteTensorSpec`), - otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHotTensorSpec`). + specs will be converted to the TorchRL equivalent (:class:`torchrl.data.Categorical`), + otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHot`). Defaults to ``False``. Keyword Args: diff --git a/torchrl/envs/libs/dm_control.py b/torchrl/envs/libs/dm_control.py index 5558754de26..2ca62e106f6 100644 --- a/torchrl/envs/libs/dm_control.py +++ b/torchrl/envs/libs/dm_control.py @@ -16,13 +16,12 @@ from torchrl._utils import logger as torchrl_logger, VERBOSE from torchrl.data.tensor_specs import ( - BoundedTensorSpec, - CompositeSpec, - DiscreteTensorSpec, - OneHotDiscreteTensorSpec, + Bounded, + Categorical, + Composite, + OneHot, TensorSpec, - UnboundedContinuousTensorSpec, - UnboundedDiscreteTensorSpec, + Unbounded, ) from torchrl.data.utils import DEVICE_TYPING, numpy_to_torch_dtype_dict @@ -57,14 +56,10 @@ def _dmcontrol_to_torchrl_spec_transform( ) for k, item in spec.items() } - return CompositeSpec(**spec) + return Composite(**spec) elif isinstance(spec, dm_env.specs.DiscreteArray): # DiscreteArray is a type of BoundedArray so this block needs to go first - action_space_cls = ( - DiscreteTensorSpec - if categorical_discrete_encoding - else OneHotDiscreteTensorSpec - ) + action_space_cls = Categorical if categorical_discrete_encoding else OneHot if dtype is None: dtype = ( numpy_to_torch_dtype_dict[spec.dtype] @@ -78,7 +73,7 @@ def _dmcontrol_to_torchrl_spec_transform( shape = spec.shape if not len(shape): shape = torch.Size([1]) - return BoundedTensorSpec( + return Bounded( shape=shape, low=spec.minimum, high=spec.maximum, @@ -92,11 +87,9 @@ def _dmcontrol_to_torchrl_spec_transform( if dtype is None: dtype = numpy_to_torch_dtype_dict[spec.dtype] if dtype in (torch.float, torch.double, torch.half): - return UnboundedContinuousTensorSpec( - shape=shape, dtype=dtype, device=device - ) + return Unbounded(shape=shape, dtype=dtype, device=device) else: - return UnboundedDiscreteTensorSpec(shape=shape, dtype=dtype, device=device) + return Unbounded(shape=shape, dtype=dtype, device=device) else: raise NotImplementedError(type(spec)) @@ -254,10 +247,10 @@ def _make_specs(self, env: "gym.Env") -> None: # noqa: F821 reward_spec.shape = torch.Size([1]) self.reward_spec = reward_spec # 
populate default done spec - done_spec = DiscreteTensorSpec( + done_spec = Categorical( n=2, shape=(*self.batch_size, 1), dtype=torch.bool, device=self.device ) - self.done_spec = CompositeSpec( + self.done_spec = Composite( done=done_spec.clone(), truncated=done_spec.clone(), terminated=done_spec.clone(), diff --git a/torchrl/envs/libs/envpool.py b/torchrl/envs/libs/envpool.py index a029a0beb5b..599645dfdfc 100644 --- a/torchrl/envs/libs/envpool.py +++ b/torchrl/envs/libs/envpool.py @@ -13,12 +13,7 @@ from tensordict import TensorDict, TensorDictBase from torchrl._utils import logger as torchrl_logger -from torchrl.data.tensor_specs import ( - CompositeSpec, - DiscreteTensorSpec, - TensorSpec, - UnboundedContinuousTensorSpec, -) +from torchrl.data.tensor_specs import Categorical, Composite, TensorSpec, Unbounded from torchrl.envs.common import _EnvWrapper from torchrl.envs.utils import _classproperty @@ -35,8 +30,8 @@ class MultiThreadedEnvWrapper(_EnvWrapper): Args: env (envpool.python.envpool.EnvPoolMixin): the envpool to wrap. categorical_action_encoding (bool, optional): if ``True``, categorical - specs will be converted to the TorchRL equivalent (:class:`torchrl.data.DiscreteTensorSpec`), - otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHotTensorSpec`). + specs will be converted to the TorchRL equivalent (:class:`torchrl.data.Categorical`), + otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHot`). Defaults to ``False``. Keyword Args: @@ -161,7 +156,7 @@ def _get_action_spec(self) -> TensorSpec: return action_spec def _get_output_spec(self) -> TensorSpec: - return CompositeSpec( + return Composite( full_observation_spec=self._get_observation_spec(), full_reward_spec=self._get_reward_spec(), full_done_spec=self._get_done_spec(), @@ -180,9 +175,9 @@ def _get_observation_spec(self) -> TensorSpec: categorical_action_encoding=True, ) observation_spec = self._add_shape_to_spec(observation_spec) - if isinstance(observation_spec, CompositeSpec): + if isinstance(observation_spec, Composite): return observation_spec - return CompositeSpec( + return Composite( observation=observation_spec, shape=(self.num_workers,), device=self.device, @@ -192,19 +187,19 @@ def _add_shape_to_spec(self, spec: TensorSpec) -> TensorSpec: return spec.expand((self.num_workers, *spec.shape)) def _get_reward_spec(self) -> TensorSpec: - return UnboundedContinuousTensorSpec( + return Unbounded( device=self.device, shape=self.batch_size, ) def _get_done_spec(self) -> TensorSpec: - spec = DiscreteTensorSpec( + spec = Categorical( 2, device=self.device, shape=self.batch_size, dtype=torch.bool, ) - return CompositeSpec( + return Composite( done=spec, truncated=spec.clone(), terminated=spec.clone(), @@ -335,8 +330,8 @@ class MultiThreadedEnv(MultiThreadedEnvWrapper): create_env_kwargs (Dict[str, Any], optional): kwargs to be passed to envpool environment constructor. categorical_action_encoding (bool, optional): if ``True``, categorical - specs will be converted to the TorchRL equivalent (:class:`torchrl.data.DiscreteTensorSpec`), - otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHotTensorSpec`). + specs will be converted to the TorchRL equivalent (:class:`torchrl.data.Categorical`), + otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHot`). Defaults to ``False``. disable_env_checker (bool, optional): for gym > 0.24 only. If ``True`` (default for these versions), the environment checker won't be run. 
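For context on the envpool hunks above, a minimal usage sketch (assuming envpool is installed; the env name and worker count are arbitrary choices, not taken from the patch):

    from torchrl.envs.libs.envpool import MultiThreadedEnv

    # Four copies of Pendulum-v1 run inside a single envpool instance; the
    # resulting env has batch_size=torch.Size([4]) and exposes the Composite
    # observation/reward/done specs assembled in _get_output_spec above.
    env = MultiThreadedEnv(num_workers=4, env_name="Pendulum-v1")
    td = env.reset()
    td = env.rand_step(td)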
diff --git a/torchrl/envs/libs/gym.py b/torchrl/envs/libs/gym.py index 9195929e31d..dfe0db92230 100644 --- a/torchrl/envs/libs/gym.py +++ b/torchrl/envs/libs/gym.py @@ -23,16 +23,16 @@ from torchrl._utils import implement_for from torchrl.data.tensor_specs import ( _minmax_dtype, - BinaryDiscreteTensorSpec, - BoundedTensorSpec, - CompositeSpec, - DiscreteTensorSpec, - MultiDiscreteTensorSpec, - MultiOneHotDiscreteTensorSpec, - OneHotDiscreteTensorSpec, + Binary, + Bounded, + Categorical, + Composite, + MultiCategorical, + MultiOneHot, + NonTensor, + OneHot, TensorSpec, - UnboundedContinuousTensorSpec, - UnboundedDiscreteTensorSpec, + Unbounded, ) from torchrl.data.utils import numpy_to_torch_dtype_dict, torch_to_numpy_dtype_dict from torchrl.envs.batched_envs import CloudpickleWrapper @@ -56,6 +56,30 @@ _has_mo = importlib.util.find_spec("mo_gymnasium") is not None _has_sb3 = importlib.util.find_spec("stable_baselines3") is not None +_has_minigrid = importlib.util.find_spec("minigrid") is not None + + +GYMNASIUM_1_ERROR = """RuntimeError: TorchRL does not support gymnasium 1.0 or later versions due to incompatible +changes in the Gym API. +Using gymnasium 1.0 with TorchRL would require significant modifications to your code and may result in: +* Inaccurate step counting, as the auto-reset feature can cause unpredictable numbers of steps to be executed. +* Potential data corruption, as the environment may require/produce garbage data during reset steps. +* Trajectory overlap during data collection. +* Increased computational overhead, as the library would need to handle the additional complexity of auto-resets. +* Manual filtering and boilerplate code to mitigate these issues, which would compromise the modularity and ease of +use of TorchRL. +To maintain the integrity and efficiency of our library, we cannot support this version of gymnasium at this time. +If you need to use gymnasium 1.0 or later, we recommend exploring alternative solutions or waiting for future updates +to TorchRL and gymnasium that may address this compatibility issue. +For more information, please refer to discussion https://github.com/pytorch/rl/discussions/2483 in torchrl. 
+""" + + +def _minigrid_lib(): + assert _has_minigrid, "minigrid not found" + import minigrid + + return minigrid class set_gym_backend(_DecoratorContextManager): @@ -259,11 +283,7 @@ def _gym_to_torchrl_spec_transform( ) return result if isinstance(spec, gym_spaces.discrete.Discrete): - action_space_cls = ( - DiscreteTensorSpec - if categorical_action_encoding - else OneHotDiscreteTensorSpec - ) + action_space_cls = Categorical if categorical_action_encoding else OneHot dtype = ( numpy_to_torch_dtype_dict[spec.dtype] if categorical_action_encoding @@ -271,7 +291,7 @@ def _gym_to_torchrl_spec_transform( ) return action_space_cls(spec.n, device=device, dtype=dtype) elif isinstance(spec, gym_spaces.multi_binary.MultiBinary): - return BinaryDiscreteTensorSpec( + return Binary( spec.n, device=device, dtype=numpy_to_torch_dtype_dict[spec.dtype] ) # a spec type cannot be a string, so we're sure that versions of gym that don't have Sequence will just skip through this @@ -300,11 +320,9 @@ def _gym_to_torchrl_spec_transform( ) return ( - MultiDiscreteTensorSpec(spec.nvec, device=device, dtype=dtype) + MultiCategorical(spec.nvec, device=device, dtype=dtype) if categorical_action_encoding - else MultiOneHotDiscreteTensorSpec( - spec.nvec, device=device, dtype=dtype - ) + else MultiOneHot(spec.nvec, device=device, dtype=dtype) ) return torch.stack( @@ -337,9 +355,9 @@ def _gym_to_torchrl_spec_transform( and torch.isclose(high, torch.as_tensor(maxval, dtype=dtype)).all() ) return ( - UnboundedContinuousTensorSpec(shape, device=device, dtype=dtype) + Unbounded(shape, device=device, dtype=dtype) if is_unbounded - else BoundedTensorSpec( + else Bounded( low, high, shape, @@ -368,7 +386,7 @@ def _gym_to_torchrl_spec_transform( remap_state_to_observation=remap_state_to_observation, ) # the batch-size must be set later - return CompositeSpec(spec_out, device=device) + return Composite(spec_out, device=device) elif isinstance(spec, gym_spaces.dict.Dict): return _gym_to_torchrl_spec_transform( spec.spaces, @@ -376,6 +394,8 @@ def _gym_to_torchrl_spec_transform( categorical_action_encoding=categorical_action_encoding, remap_state_to_observation=remap_state_to_observation, ) + elif _has_minigrid and isinstance(spec, _minigrid_lib().core.mission.MissionSpace): + return NonTensor((), device=device) else: raise NotImplementedError( f"spec of type {type(spec).__name__} is currently unaccounted for" @@ -396,13 +416,18 @@ def _box_convert(spec, gym_spaces, shape): # noqa: F811 return gym_spaces.Box(low=low, high=high, shape=shape) -@implement_for("gymnasium") +@implement_for("gymnasium", None, "1.0.0") def _box_convert(spec, gym_spaces, shape): # noqa: F811 low = spec.low.detach().cpu().numpy() high = spec.high.detach().cpu().numpy() return gym_spaces.Box(low=low, high=high, shape=shape) +@implement_for("gymnasium", "1.0.0") +def _box_convert(spec, gym_spaces, shape): # noqa: F811 + raise ImportError(GYMNASIUM_1_ERROR) + + @implement_for("gym", "0.21", None) def _multidiscrete_convert(gym_spaces, spec): return gym_spaces.multi_discrete.MultiDiscrete( @@ -410,13 +435,18 @@ def _multidiscrete_convert(gym_spaces, spec): ) -@implement_for("gymnasium") +@implement_for("gymnasium", None, "1.0.0") def _multidiscrete_convert(gym_spaces, spec): # noqa: F811 return gym_spaces.multi_discrete.MultiDiscrete( spec.nvec, dtype=torch_to_numpy_dtype_dict[spec.dtype] ) +@implement_for("gymnasium", "1.0.0") +def _multidiscrete_convert(gym_spaces, spec): # noqa: F811 + raise ImportError(GYMNASIUM_1_ERROR) + + @implement_for("gym", None, 
"0.21") def _multidiscrete_convert(gym_spaces, spec): # noqa: F811 return gym_spaces.multi_discrete.MultiDiscrete(spec.nvec) @@ -445,19 +475,19 @@ def _torchrl_to_gym_spec_transform( return gym_spaces.Tuple( tuple(_torchrl_to_gym_spec_transform(spec) for spec in spec.unbind(0)) ) - if isinstance(spec, MultiDiscreteTensorSpec): + if isinstance(spec, MultiCategorical): return _multidiscrete_convert(gym_spaces, spec) - if isinstance(spec, MultiOneHotDiscreteTensorSpec): + if isinstance(spec, MultiOneHot): return gym_spaces.multi_discrete.MultiDiscrete(spec.nvec) - if isinstance(spec, BinaryDiscreteTensorSpec): + if isinstance(spec, Binary): return gym_spaces.multi_binary.MultiBinary(spec.shape[-1]) - if isinstance(spec, DiscreteTensorSpec): + if isinstance(spec, Categorical): return gym_spaces.discrete.Discrete( spec.n ) # dtype=torch_to_numpy_dtype_dict[spec.dtype]) - if isinstance(spec, OneHotDiscreteTensorSpec): + if isinstance(spec, OneHot): return gym_spaces.discrete.Discrete(spec.n) - if isinstance(spec, UnboundedContinuousTensorSpec): + if isinstance(spec, Unbounded): minval, maxval = _minmax_dtype(spec.dtype) return gym_spaces.Box( low=minval, @@ -465,7 +495,7 @@ def _torchrl_to_gym_spec_transform( shape=shape, dtype=torch_to_numpy_dtype_dict[spec.dtype], ) - if isinstance(spec, UnboundedDiscreteTensorSpec): + if isinstance(spec, Unbounded): minval, maxval = _minmax_dtype(spec.dtype) return gym_spaces.Box( low=minval, @@ -473,9 +503,9 @@ def _torchrl_to_gym_spec_transform( shape=shape, dtype=torch_to_numpy_dtype_dict[spec.dtype], ) - if isinstance(spec, BoundedTensorSpec): + if isinstance(spec, Bounded): return _box_convert(spec, gym_spaces, shape) - if isinstance(spec, CompositeSpec): + if isinstance(spec, Composite): # remove batch size while spec.shape: spec = spec[0] @@ -515,12 +545,17 @@ def _get_gym_envs(): # noqa: F811 return gym.envs.registration.registry.keys() -@implement_for("gymnasium") +@implement_for("gymnasium", None, "1.0.0") def _get_gym_envs(): # noqa: F811 gym = gym_backend() return gym.envs.registration.registry.keys() +@implement_for("gymnasium", "1.0.0") +def _get_gym_envs(): # noqa: F811 + raise ImportError(GYMNASIUM_1_ERROR) + + def _is_from_pixels(env): observation_spec = env.observation_space try: @@ -624,8 +659,8 @@ class GymWrapper(GymLikeEnv, metaclass=_AsyncMeta): or :class:`gym.VectorEnv`) are supported and the environment batch-size will reflect the number of environments executed in parallel. categorical_action_encoding (bool, optional): if ``True``, categorical - specs will be converted to the TorchRL equivalent (:class:`torchrl.data.DiscreteTensorSpec`), - otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHotTensorSpec`). + specs will be converted to the TorchRL equivalent (:class:`torchrl.data.Categorical`), + otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHot`). Defaults to ``False``. Keyword Args: @@ -652,6 +687,11 @@ class GymWrapper(GymLikeEnv, metaclass=_AsyncMeta): allow_done_after_reset (bool, optional): if ``True``, it is tolerated for envs to be ``done`` just after :meth:`~.reset` is called. Defaults to ``False``. + convert_actions_to_numpy (bool, optional): if ``True``, actions will be + converted from tensors to numpy arrays and moved to CPU before being passed to the + env step function. Set this to ``False`` if the environment is evaluated + on GPU, such as IsaacLab. + Defaults to ``True``. Attributes: available_envs (List[str]): a list of environments to build. 
@@ -768,14 +808,20 @@ def __init__(self, env=None, categorical_action_encoding=False, **kwargs): self._seed_calls_reset = None self._categorical_action_encoding = categorical_action_encoding if env is not None: - if "EnvCompatibility" in str( - env - ): # a hacky way of knowing if EnvCompatibility is part of the wrappers of env - raise ValueError( - "GymWrapper does not support the gym.wrapper.compatibility.EnvCompatibility wrapper. " - "If this feature is needed, detail your use case in an issue of " - "https://github.com/pytorch/rl/issues." - ) + try: + env_str = str(env) + except TypeError: + # MiniGrid has a bug where the __str__ method fails + pass + else: + if ( + "EnvCompatibility" in env_str + ): # a hacky way of knowing if EnvCompatibility is part of the wrappers of env + raise ValueError( + "GymWrapper does not support the gym.wrapper.compatibility.EnvCompatibility wrapper. " + "If this feature is needed, detail your use case in an issue of " + "https://github.com/pytorch/rl/issues." + ) libname = self.get_library_name(env) with set_gym_backend(libname): kwargs["env"] = env @@ -820,7 +866,7 @@ def _get_batch_size(self, env): batch_size = self.batch_size return batch_size - @implement_for("gymnasium") # gymnasium wants the unwrapped env + @implement_for("gymnasium", None, "1.0.0") # gymnasium wants the unwrapped env def _get_batch_size(self, env): # noqa: F811 env_unwrapped = env.unwrapped if hasattr(env_unwrapped, "num_envs"): @@ -829,6 +875,10 @@ def _get_batch_size(self, env): # noqa: F811 batch_size = self.batch_size return batch_size + @implement_for("gymnasium", "1.0.0") + def _get_batch_size(self, env): # noqa: F811 + raise ImportError(GYMNASIUM_1_ERROR) + def _check_kwargs(self, kwargs: Dict): if "env" not in kwargs: raise TypeError("Could not find environment key 'env' in kwargs.") @@ -865,10 +915,7 @@ def _build_env( def read_action(self, action): action = super().read_action(action) - if ( - isinstance(self.action_spec, (OneHotDiscreteTensorSpec, DiscreteTensorSpec)) - and action.size == 1 - ): + if isinstance(self.action_spec, (OneHot, Categorical)) and action.size == 1: # some envs require an integer for indexing action = int(action) return action @@ -908,7 +955,11 @@ def _build_gym_env(self, env, pixels_only): # noqa: F811 return LegacyPixelObservationWrapper(env, pixels_only=pixels_only) - @implement_for("gymnasium") + @implement_for("gymnasium", "1.0.0") + def _build_gym_env(self, env, pixels_only): # noqa: F811 + raise ImportError(GYMNASIUM_1_ERROR) + + @implement_for("gymnasium", None, "1.0.0") def _build_gym_env(self, env, pixels_only): # noqa: F811 compatibility = gym_backend("wrappers.compatibility") pixel_observation = gym_backend("wrappers.pixel_observation") @@ -973,7 +1024,11 @@ def _set_seed_initial(self, seed: int) -> None: # noqa: F811 except AttributeError as err2: raise err from err2 - @implement_for("gymnasium") + @implement_for("gymnasium", "1.0.0") + def _set_seed_initial(self, seed: int) -> None: # noqa: F811 + raise ImportError(GYMNASIUM_1_ERROR) + + @implement_for("gymnasium", None, "1.0.0") def _set_seed_initial(self, seed: int) -> None: # noqa: F811 try: self.reset(seed=seed) @@ -991,7 +1046,11 @@ def _reward_space(self, env): if hasattr(env, "reward_space") and env.reward_space is not None: return env.reward_space - @implement_for("gymnasium") + @implement_for("gymnasium", "1.0.0") + def _reward_space(self, env): # noqa: F811 + raise ImportError(GYMNASIUM_1_ERROR) + + @implement_for("gymnasium", None, "1.0.0") def _reward_space(self, env): # 
noqa: F811 env = env.unwrapped if hasattr(env, "reward_space") and env.reward_space is not None: @@ -1012,13 +1071,13 @@ def _make_specs(self, env: "gym.Env", batch_size=None) -> None: # noqa: F821 device=self.device, categorical_action_encoding=self._categorical_action_encoding, ) - if not isinstance(observation_spec, CompositeSpec): + if not isinstance(observation_spec, Composite): if self.from_pixels: - observation_spec = CompositeSpec( + observation_spec = Composite( pixels=observation_spec, shape=cur_batch_size ) else: - observation_spec = CompositeSpec( + observation_spec = Composite( observation=observation_spec, shape=cur_batch_size ) elif observation_spec.shape[: len(cur_batch_size)] != cur_batch_size: @@ -1032,7 +1091,7 @@ def _make_specs(self, env: "gym.Env", batch_size=None) -> None: # noqa: F821 categorical_action_encoding=self._categorical_action_encoding, ) else: - reward_spec = UnboundedContinuousTensorSpec( + reward_spec = Unbounded( shape=[1], device=self.device, ) @@ -1053,15 +1112,15 @@ def _make_specs(self, env: "gym.Env", batch_size=None) -> None: # noqa: F821 @implement_for("gym", None, "0.26") def _make_done_spec(self): # noqa: F811 - return CompositeSpec( + return Composite( { - "done": DiscreteTensorSpec( + "done": Categorical( 2, dtype=torch.bool, device=self.device, shape=(*self.batch_size, 1) ), - "terminated": DiscreteTensorSpec( + "terminated": Categorical( 2, dtype=torch.bool, device=self.device, shape=(*self.batch_size, 1) ), - "truncated": DiscreteTensorSpec( + "truncated": Categorical( 2, dtype=torch.bool, device=self.device, shape=(*self.batch_size, 1) ), }, @@ -1070,15 +1129,15 @@ def _make_done_spec(self): # noqa: F811 @implement_for("gym", "0.26", None) def _make_done_spec(self): # noqa: F811 - return CompositeSpec( + return Composite( { - "done": DiscreteTensorSpec( + "done": Categorical( 2, dtype=torch.bool, device=self.device, shape=(*self.batch_size, 1) ), - "terminated": DiscreteTensorSpec( + "terminated": Categorical( 2, dtype=torch.bool, device=self.device, shape=(*self.batch_size, 1) ), - "truncated": DiscreteTensorSpec( + "truncated": Categorical( 2, dtype=torch.bool, device=self.device, shape=(*self.batch_size, 1) ), }, @@ -1087,15 +1146,15 @@ def _make_done_spec(self): # noqa: F811 @implement_for("gymnasium", "0.27", None) def _make_done_spec(self): # noqa: F811 - return CompositeSpec( + return Composite( { - "done": DiscreteTensorSpec( + "done": Categorical( 2, dtype=torch.bool, device=self.device, shape=(*self.batch_size, 1) ), - "terminated": DiscreteTensorSpec( + "terminated": Categorical( 2, dtype=torch.bool, device=self.device, shape=(*self.batch_size, 1) ), - "truncated": DiscreteTensorSpec( + "truncated": Categorical( 2, dtype=torch.bool, device=self.device, shape=(*self.batch_size, 1) ), }, @@ -1250,8 +1309,8 @@ class GymEnv(GymWrapper): Args: env_name (str): the environment id registered in `gym.registry`. categorical_action_encoding (bool, optional): if ``True``, categorical - specs will be converted to the TorchRL equivalent (:class:`torchrl.data.DiscreteTensorSpec`), - otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHotTensorSpec`). + specs will be converted to the TorchRL equivalent (:class:`torchrl.data.Categorical`), + otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHot`). Defaults to ``False``. 
Keyword Args: @@ -1269,7 +1328,7 @@ class GymEnv(GymWrapper): pixels_only (bool, optional): if ``True``, only the pixel observations will be returned (by default under the ``"pixels"`` entry in the output tensordict). If ``False``, observations (eg, states) and pixels will be returned - whenever ``from_pixels=True``. Defaults to ``True``. + whenever ``from_pixels=True``. Defaults to ``False``. frame_skip (int, optional): if provided, indicates for how many steps the same action is to be repeated. The observation returned will be the last observation of the sequence, whereas the reward will be the sum @@ -1385,7 +1444,14 @@ def _set_gym_args( # noqa: F811 ) -> None: kwargs.setdefault("disable_env_checker", True) - @implement_for("gymnasium") + @implement_for("gymnasium", "1.0.0") + def _set_gym_args( # noqa: F811 + self, + kwargs, + ) -> None: + raise ImportError(GYMNASIUM_1_ERROR) + + @implement_for("gymnasium", None, "1.0.0") def _set_gym_args( # noqa: F811 self, kwargs, @@ -1567,7 +1633,7 @@ class terminal_obs_reader(default_info_dict_reader): replaced. Args: - observation_spec (CompositeSpec): The observation spec of the gym env. + observation_spec (Composite): The observation spec of the gym env. backend (str, optional): the backend of the env. One of `"sb3"` for stable-baselines3 or `"gym"` for gym/gymnasium. @@ -1585,7 +1651,7 @@ class terminal_obs_reader(default_info_dict_reader): "gym": "final_info", } - def __init__(self, observation_spec: CompositeSpec, backend, name="final"): + def __init__(self, observation_spec: Composite, backend, name="final"): super().__init__() self.name = name self._obs_spec = observation_spec.clone() diff --git a/torchrl/envs/libs/habitat.py b/torchrl/envs/libs/habitat.py index 53752147acc..4180c42b2dc 100644 --- a/torchrl/envs/libs/habitat.py +++ b/torchrl/envs/libs/habitat.py @@ -54,8 +54,8 @@ class HabitatEnv(GymEnv): Args: env_name (str): The environment to execute. categorical_action_encoding (bool, optional): if ``True``, categorical - specs will be converted to the TorchRL equivalent (:class:`torchrl.data.DiscreteTensorSpec`), - otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHotTensorSpec`). + specs will be converted to the TorchRL equivalent (:class:`torchrl.data.Categorical`), + otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHot`). Defaults to ``False``. 
Keyword Args: diff --git a/torchrl/envs/libs/isaacgym.py b/torchrl/envs/libs/isaacgym.py index 4c56bea304a..fb37639ad37 100644 --- a/torchrl/envs/libs/isaacgym.py +++ b/torchrl/envs/libs/isaacgym.py @@ -14,7 +14,7 @@ import torch from tensordict import TensorDictBase -from torchrl.data import CompositeSpec +from torchrl.data import Composite from torchrl.envs.libs.gym import GymWrapper from torchrl.envs.utils import _classproperty, make_composite_from_td @@ -59,7 +59,7 @@ def __init__( def _make_specs(self, env: "gym.Env") -> None: # noqa: F821 super()._make_specs(env, batch_size=self.batch_size) - self.full_done_spec = CompositeSpec( + self.full_done_spec = Composite( { key: spec.squeeze(-1) for key, spec in self.full_done_spec.items(True, True) diff --git a/torchrl/envs/libs/jax_utils.py b/torchrl/envs/libs/jax_utils.py index d1d1094a264..052f538f0c4 100644 --- a/torchrl/envs/libs/jax_utils.py +++ b/torchrl/envs/libs/jax_utils.py @@ -13,12 +13,7 @@ # from jax import dlpack as jax_dlpack, numpy as jnp from tensordict import make_tensordict, TensorDictBase from torch.utils import dlpack as torch_dlpack -from torchrl.data.tensor_specs import ( - CompositeSpec, - TensorSpec, - UnboundedContinuousTensorSpec, - UnboundedDiscreteTensorSpec, -) +from torchrl.data.tensor_specs import Composite, TensorSpec, Unbounded from torchrl.data.utils import numpy_to_torch_dtype_dict _has_jax = importlib.util.find_spec("jax") is not None @@ -155,15 +150,11 @@ def _extract_spec(data: Union[torch.Tensor, TensorDictBase], key=None) -> Tensor if key in ("reward", "done"): shape = (*shape, 1) if data.dtype in (torch.float, torch.double, torch.half): - return UnboundedContinuousTensorSpec( - shape=shape, dtype=data.dtype, device=data.device - ) + return Unbounded(shape=shape, dtype=data.dtype, device=data.device) else: - return UnboundedDiscreteTensorSpec( - shape=shape, dtype=data.dtype, device=data.device - ) + return Unbounded(shape=shape, dtype=data.dtype, device=data.device) elif isinstance(data, TensorDictBase): - return CompositeSpec( + return Composite( {key: _extract_spec(value, key=key) for key, value in data.items()} ) else: diff --git a/torchrl/envs/libs/jumanji.py b/torchrl/envs/libs/jumanji.py index 071c8f7f56c..dbbc980e8cc 100644 --- a/torchrl/envs/libs/jumanji.py +++ b/torchrl/envs/libs/jumanji.py @@ -18,16 +18,15 @@ _has_jumanji = importlib.util.find_spec("jumanji") is not None from torchrl.data.tensor_specs import ( - BoundedTensorSpec, - CompositeSpec, + Bounded, + Categorical, + Composite, DEVICE_TYPING, - DiscreteTensorSpec, - MultiDiscreteTensorSpec, - MultiOneHotDiscreteTensorSpec, - OneHotDiscreteTensorSpec, + MultiCategorical, + MultiOneHot, + OneHot, TensorSpec, - UnboundedContinuousTensorSpec, - UnboundedDiscreteTensorSpec, + Unbounded, ) from torchrl.data.utils import numpy_to_torch_dtype_dict from torchrl.envs.gym_like import GymLikeEnv @@ -59,19 +58,13 @@ def _jumanji_to_torchrl_spec_transform( import jumanji if isinstance(spec, jumanji.specs.DiscreteArray): - action_space_cls = ( - DiscreteTensorSpec - if categorical_action_encoding - else OneHotDiscreteTensorSpec - ) + action_space_cls = Categorical if categorical_action_encoding else OneHot if dtype is None: dtype = numpy_to_torch_dtype_dict[spec.dtype] return action_space_cls(spec.num_values, dtype=dtype, device=device) if isinstance(spec, jumanji.specs.MultiDiscreteArray): action_space_cls = ( - MultiDiscreteTensorSpec - if categorical_action_encoding - else MultiOneHotDiscreteTensorSpec + MultiCategorical if 
categorical_action_encoding else MultiOneHot ) if dtype is None: dtype = numpy_to_torch_dtype_dict[spec.dtype] @@ -82,7 +75,7 @@ def _jumanji_to_torchrl_spec_transform( shape = spec.shape if dtype is None: dtype = numpy_to_torch_dtype_dict[spec.dtype] - return BoundedTensorSpec( + return Bounded( shape=shape, low=np.asarray(spec.minimum), high=np.asarray(spec.maximum), @@ -94,11 +87,9 @@ def _jumanji_to_torchrl_spec_transform( if dtype is None: dtype = numpy_to_torch_dtype_dict[spec.dtype] if dtype in (torch.float, torch.double, torch.half): - return UnboundedContinuousTensorSpec( - shape=shape, dtype=dtype, device=device - ) + return Unbounded(shape=shape, dtype=dtype, device=device) else: - return UnboundedDiscreteTensorSpec(shape=shape, dtype=dtype, device=device) + return Unbounded(shape=shape, dtype=dtype, device=device) elif isinstance(spec, jumanji.specs.Spec) and hasattr(spec, "__dict__"): new_spec = {} for key, value in spec.__dict__.items(): @@ -110,7 +101,7 @@ def _jumanji_to_torchrl_spec_transform( new_spec[key] = _jumanji_to_torchrl_spec_transform( value, dtype, device, categorical_action_encoding ) - return CompositeSpec(**new_spec) + return Composite(**new_spec) else: raise TypeError(f"Unsupported spec type {type(spec)}") @@ -140,8 +131,8 @@ class JumanjiWrapper(GymLikeEnv, metaclass=_JumanjiMakeRender): Args: env (jumanji.env.Environment): the env to wrap. categorical_action_encoding (bool, optional): if ``True``, categorical - specs will be converted to the TorchRL equivalent (:class:`torchrl.data.DiscreteTensorSpec`), - otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHotTensorSpec`). + specs will be converted to the TorchRL equivalent (:class:`torchrl.data.Categorical`), + otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHot`). Defaults to ``False``. Keyword Args: @@ -433,9 +424,9 @@ def _make_observation_spec(self, env) -> TensorSpec: spec = env.observation_spec new_spec = _jumanji_to_torchrl_spec_transform(spec, device=self.device) if isinstance(spec, jumanji.specs.Array): - return CompositeSpec(observation=new_spec).expand(self.batch_size) + return Composite(observation=new_spec).expand(self.batch_size) elif isinstance(spec, jumanji.specs.Spec): - return CompositeSpec(**{k: v for k, v in new_spec.items()}).expand( + return Composite(**{k: v for k, v in new_spec.items()}).expand( self.batch_size ) else: @@ -681,8 +672,8 @@ class JumanjiEnv(JumanjiWrapper): Args: env_name (str): the name of the environment to wrap. Must be part of :attr:`~.available_envs`. categorical_action_encoding (bool, optional): if ``True``, categorical - specs will be converted to the TorchRL equivalent (:class:`torchrl.data.DiscreteTensorSpec`), - otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHotTensorSpec`). + specs will be converted to the TorchRL equivalent (:class:`torchrl.data.Categorical`), + otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHot`). Defaults to ``False``. 
Keyword Args: diff --git a/torchrl/envs/libs/meltingpot.py b/torchrl/envs/libs/meltingpot.py index 446b3dac292..b8e52031a23 100644 --- a/torchrl/envs/libs/meltingpot.py +++ b/torchrl/envs/libs/meltingpot.py @@ -12,7 +12,7 @@ from tensordict import TensorDict, TensorDictBase -from torchrl.data import CompositeSpec, DiscreteTensorSpec, TensorSpec +from torchrl.data import Categorical, Composite, TensorSpec from torchrl.envs.common import _EnvWrapper from torchrl.envs.libs.dm_control import _dmcontrol_to_torchrl_spec_transform from torchrl.envs.utils import _classproperty, check_marl_grouping, MarlGroupMapType @@ -246,9 +246,9 @@ def _make_specs( } self._make_group_map() - action_spec = CompositeSpec() - observation_spec = CompositeSpec() - reward_spec = CompositeSpec() + action_spec = Composite() + observation_spec = Composite() + reward_spec = Composite() for group in self.group_map.keys(): ( @@ -266,11 +266,9 @@ def _make_specs( reward_spec[group] = group_reward_spec observation_spec.update(torchrl_state_spec) - self.done_spec = CompositeSpec( + self.done_spec = Composite( { - "done": DiscreteTensorSpec( - n=2, shape=torch.Size((1,)), dtype=torch.bool - ), + "done": Categorical(n=2, shape=torch.Size((1,)), dtype=torch.bool), }, ) self.action_spec = action_spec @@ -292,7 +290,7 @@ def _make_group_specs( for agent_name in self.group_map[group]: agent_index = self.agent_names_to_indices_map[agent_name] action_specs.append( - CompositeSpec( + Composite( { "action": torchrl_agent_act_specs[ agent_index @@ -301,7 +299,7 @@ def _make_group_specs( ) ) observation_specs.append( - CompositeSpec( + Composite( { "observation": torchrl_agent_obs_specs[ agent_index @@ -310,7 +308,7 @@ def _make_group_specs( ) ) reward_specs.append( - CompositeSpec({"reward": torchrl_rew_spec[agent_index]}) # shape = (1,) + Composite({"reward": torchrl_rew_spec[agent_index]}) # shape = (1,) ) # Create multi-agent specs diff --git a/torchrl/envs/libs/openml.py b/torchrl/envs/libs/openml.py index 7ac318e03cb..55b246bd902 100644 --- a/torchrl/envs/libs/openml.py +++ b/torchrl/envs/libs/openml.py @@ -8,12 +8,7 @@ from tensordict import TensorDict, TensorDictBase from torchrl.data.replay_buffers import SamplerWithoutReplacement -from torchrl.data.tensor_specs import ( - CompositeSpec, - DiscreteTensorSpec, - UnboundedContinuousTensorSpec, - UnboundedDiscreteTensorSpec, -) +from torchrl.data.tensor_specs import Categorical, Composite, Unbounded from torchrl.envs.common import EnvBase from torchrl.envs.transforms import Compose, DoubleToFloat, RenameTransform from torchrl.envs.utils import _classproperty @@ -24,17 +19,13 @@ def _make_composite_from_td(td): # custom function to convert a tensordict into a similar spec structure # of unbounded values.
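 # For example (entry names illustrative): a float32 "obs" entry becomes # Unbounded(dtype=torch.float32, ...), an int64 "label" entry becomes # Unbounded(dtype=torch.int64, ...), recursing into nested tensordicts.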
- composite = CompositeSpec( + composite = Composite( { key: _make_composite_from_td(tensor) if isinstance(tensor, TensorDictBase) - else UnboundedContinuousTensorSpec( - dtype=tensor.dtype, device=tensor.device, shape=tensor.shape - ) + else Unbounded(dtype=tensor.dtype, device=tensor.device, shape=tensor.shape) if tensor.dtype in (torch.float16, torch.float32, torch.float64) - else UnboundedDiscreteTensorSpec( - dtype=tensor.dtype, device=tensor.device, shape=tensor.shape - ) + else Unbounded(dtype=tensor.dtype, device=tensor.device, shape=tensor.shape) for key, tensor in td.items() }, shape=td.shape, @@ -115,10 +106,10 @@ def __init__(self, dataset_name, device="cpu", batch_size=None): .reshape(self.batch_size) .exclude("index") ) - self.action_spec = DiscreteTensorSpec( + self.action_spec = Categorical( self._data.max_outcome_val + 1, shape=self.batch_size, device=self.device ) - self.reward_spec = UnboundedContinuousTensorSpec(shape=(*self.batch_size, 1)) + self.reward_spec = Unbounded(shape=(*self.batch_size, 1)) def _reset(self, tensordict): data = self._data.sample() diff --git a/torchrl/envs/libs/openspiel.py b/torchrl/envs/libs/openspiel.py new file mode 100644 index 00000000000..8d2d76f453f --- /dev/null +++ b/torchrl/envs/libs/openspiel.py @@ -0,0 +1,655 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +import importlib.util +from typing import Dict, List + +import torch +from tensordict import TensorDict, TensorDictBase + +from torchrl.data.tensor_specs import ( + Categorical, + Composite, + NonTensor, + OneHot, + Unbounded, +) +from torchrl.envs.common import _EnvWrapper +from torchrl.envs.utils import _classproperty, check_marl_grouping, MarlGroupMapType + +_has_pyspiel = importlib.util.find_spec("pyspiel") is not None + + +def _get_envs(): + if not _has_pyspiel: + raise ImportError( + "open_spiel not found. Consider downloading and installing " + f"open_spiel from {OpenSpielWrapper.git_url}." + ) + + import pyspiel + + return [game.short_name for game in pyspiel.registered_games()] + + +class OpenSpielWrapper(_EnvWrapper): + """Google DeepMind OpenSpiel environment wrapper. + + GitHub: https://github.com/google-deepmind/open_spiel + + Documentation: https://openspiel.readthedocs.io/en/latest/index.html + + Args: + env (pyspiel.State): the game to wrap. + + Keyword Args: + device (torch.device, optional): if provided, the device on which the data + is to be cast. Defaults to ``None``. + batch_size (torch.Size, optional): the batch size of the environment. + Defaults to ``torch.Size([])``. + allow_done_after_reset (bool, optional): if ``True``, it is tolerated + for envs to be ``done`` just after :meth:`~.reset` is called. + Defaults to ``False``. + group_map (MarlGroupMapType or Dict[str, List[str]], optional): how to + group agents in tensordicts for input/output. See + :class:`~torchrl.envs.utils.MarlGroupMapType` for more info. + Defaults to + :class:`~torchrl.envs.utils.MarlGroupMapType.ALL_IN_ONE_GROUP`. + categorical_actions (bool, optional): if ``True``, categorical specs + will be converted to the TorchRL equivalent + (:class:`torchrl.data.Categorical`), otherwise a one-hot encoding + will be used (:class:`torchrl.data.OneHot`). Defaults to ``False``. + return_state (bool, optional): if ``True``, "state" is included in the + output of :meth:`~.reset` and :meth:`~.step`.
The state can be given + to :meth:`~.reset` to reset to that state, rather than resetting to + the initial state. + Defaults to ``False``. + + Attributes: + available_envs: environments available to build + + Examples: + >>> import pyspiel + >>> from torchrl.envs import OpenSpielWrapper + >>> from tensordict import TensorDict + >>> base_env = pyspiel.load_game('chess').new_initial_state() + >>> env = OpenSpielWrapper(base_env, return_state=True) + >>> td = env.reset() + >>> td = env.step(env.full_action_spec.rand()) + >>> print(td) + TensorDict( + fields={ + agents: TensorDict( + fields={ + action: Tensor(shape=torch.Size([2, 4672]), device=cpu, dtype=torch.int64, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False), + next: TensorDict( + fields={ + agents: TensorDict( + fields={ + observation: Tensor(shape=torch.Size([2, 20, 8, 8]), device=cpu, dtype=torch.float32, is_shared=False), + reward: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([2]), + device=None, + is_shared=False), + current_player: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False), + done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + state: NonTensorData(data=FEN: rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1 + 3009 + , batch_size=torch.Size([]), device=None), + terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False) + >>> print(env.available_envs) + ['2048', 'add_noise', 'amazons', 'backgammon', ...] + + :meth:`~.reset` can restore a specific state, rather than the initial + state, as long as ``return_state=True``. 
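+ Internally, the state is serialized with ``State.serialize()`` and + restored with ``Game.deserialize_state()``.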
+ + >>> import pyspiel + >>> from torchrl.envs import OpenSpielWrapper + >>> from tensordict import TensorDict + >>> base_env = pyspiel.load_game('chess').new_initial_state() + >>> env = OpenSpielWrapper(base_env, return_state=True) + >>> td = env.reset() + >>> td = env.step(env.full_action_spec.rand()) + >>> td_restore = td["next"] + >>> td = env.step(env.full_action_spec.rand()) + >>> # Current state is not equal to `td_restore` + >>> (td["next"] == td_restore).all() + False + >>> td = env.reset(td_restore) + >>> # After resetting, now the current state is equal to `td_restore` + >>> (td == td_restore).all() + True + """ + + git_url = "https://github.com/google-deepmind/open_spiel" + libname = "pyspiel" + _lib = None + + @_classproperty + def lib(cls): + if cls._lib is not None: + return cls._lib + + import pyspiel + + cls._lib = pyspiel + return pyspiel + + @_classproperty + def available_envs(cls): + if not _has_pyspiel: + return [] + return _get_envs() + + def __init__( + self, + env=None, + *, + group_map: MarlGroupMapType + | Dict[str, List[str]] = MarlGroupMapType.ALL_IN_ONE_GROUP, + categorical_actions: bool = False, + return_state: bool = False, + **kwargs, + ): + if env is not None: + kwargs["env"] = env + + self.group_map = group_map + self.categorical_actions = categorical_actions + self.return_state = return_state + self._cached_game = None + super().__init__(**kwargs) + + # `reset` allows resetting to any state, including a terminal state + self._allow_done_after_reset = True + + def _check_kwargs(self, kwargs: Dict): + pyspiel = self.lib + if "env" not in kwargs: + raise TypeError("Could not find environment key 'env' in kwargs.") + env = kwargs["env"] + if not isinstance(env, pyspiel.State): + raise TypeError("env is not of type 'pyspiel.State'.") + + def _build_env(self, env, requires_grad: bool = False, **kwargs): + game = env.get_game() + game_type = game.get_type() + + if game.max_chance_outcomes() != 0: + raise NotImplementedError( + f"The game '{game_type.short_name}' has chance nodes, which are not yet supported." + ) + if game_type.dynamics == self.lib.GameType.Dynamics.MEAN_FIELD: + # NOTE: It is unclear from the OpenSpiel documentation what exactly + # "mean field" means, and there is no documentation on the + # several games which have it. + raise RuntimeError( + f"Mean field games like '{game_type.name}' are not yet " "supported."
+ ) + self.parallel = game_type.dynamics == self.lib.GameType.Dynamics.SIMULTANEOUS + self.requires_grad = requires_grad + return env + + def _init_env(self): + self._update_action_mask() + + def _get_game(self): + if self._cached_game is None: + self._cached_game = self._env.get_game() + return self._cached_game + + def _make_group_map(self, group_map, agent_names): + if group_map is None: + group_map = MarlGroupMapType.ONE_GROUP_PER_AGENT.get_group_map(agent_names) + elif isinstance(group_map, MarlGroupMapType): + group_map = group_map.get_group_map(agent_names) + check_marl_grouping(group_map, agent_names) + return group_map + + def _make_group_specs( + self, + env, + group: str, + ): + observation_specs = [] + action_specs = [] + reward_specs = [] + game = env.get_game() + + for _ in self.group_map[group]: + observation_spec = Composite() + + if self.has_observation: + observation_spec["observation"] = Unbounded( + shape=(*game.observation_tensor_shape(),), + device=self.device, + domain="continuous", + ) + + if self.has_information_state: + observation_spec["information_state"] = Unbounded( + shape=(*game.information_state_tensor_shape(),), + device=self.device, + domain="continuous", + ) + + observation_specs.append(observation_spec) + + action_spec_cls = Categorical if self.categorical_actions else OneHot + action_specs.append( + Composite( + action=action_spec_cls( + env.num_distinct_actions(), + dtype=torch.int64, + device=self.device, + ) + ) + ) + + reward_specs.append( + Composite( + reward=Unbounded( + shape=(1,), + device=self.device, + domain="continuous", + ) + ) + ) + + group_observation_spec = torch.stack( + observation_specs, dim=0 + ) # shape = (n_agents, n_obs_per_agent) + group_action_spec = torch.stack( + action_specs, dim=0 + ) # shape = (n_agents, n_actions_per_agent) + group_reward_spec = torch.stack(reward_specs, dim=0) # shape = (n_agents, 1) + + return ( + group_observation_spec, + group_action_spec, + group_reward_spec, + ) + + def _make_specs(self, env: "pyspiel.State") -> None: # noqa: F821 + self.agent_names = [f"player_{index}" for index in range(env.num_players())] + self.agent_names_to_indices_map = { + agent_name: i for i, agent_name in enumerate(self.agent_names) + } + self.group_map = self._make_group_map(self.group_map, self.agent_names) + self.done_spec = Categorical( + n=2, + shape=torch.Size((1,)), + dtype=torch.bool, + device=self.device, + ) + game = env.get_game() + game_type = game.get_type() + # In OpenSpiel, a game's state may have either an "observation" tensor, + # an "information state" tensor, or both. If the OpenSpiel game does not + # have one of these, then its corresponding accessor functions raise an + # error, so we must avoid calling them.
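+ # For example, chess (as in the docstring example above) provides an + # observation tensor, while some imperfect-information games provide + # only an information state tensor; the flags below record which + # accessors are safe to call.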
+ self.has_observation = game_type.provides_observation_tensor + self.has_information_state = game_type.provides_information_state_tensor + + observation_spec = {} + action_spec = {} + reward_spec = {} + + for group in self.group_map.keys(): + ( + group_observation_spec, + group_action_spec, + group_reward_spec, + ) = self._make_group_specs( + env, + group, + ) + observation_spec[group] = group_observation_spec + action_spec[group] = group_action_spec + reward_spec[group] = group_reward_spec + + if self.return_state: + observation_spec["state"] = NonTensor([]) + + observation_spec["current_player"] = Unbounded( + shape=(), + dtype=torch.int, + device=self.device, + domain="discrete", + ) + + self.observation_spec = Composite(observation_spec) + self.action_spec = Composite(action_spec) + self.reward_spec = Composite(reward_spec) + + def _set_seed(self, seed): + if seed is not None: + raise NotImplementedError("This environment has no seed.") + + def current_player(self): + return self._env.current_player() + + def _update_action_mask(self): + if self._env.is_terminal(): + agents_acting = [] + else: + agents_acting = [ + self.agent_names + if self.parallel + else self.agent_names[self._env.current_player()] + ] + for group, agents in self.group_map.items(): + action_masks = [] + for agent in agents: + agent_index = self.agent_names_to_indices_map[agent] + if agent in agents_acting: + action_mask = torch.zeros( + self._env.num_distinct_actions(), + device=self.device, + dtype=torch.bool, + ) + action_mask[self._env.legal_actions(agent_index)] = True + else: + action_mask = torch.zeros( + self._env.num_distinct_actions(), + device=self.device, + dtype=torch.bool, + ) + # In OpenSpiel parallel games, non-acting players are + # expected to take action 0. 
+ # https://openspiel.readthedocs.io/en/latest/api_reference/state_apply_action.html + action_mask[0] = True + action_masks.append(action_mask) + self.full_action_spec[group, "action"].update_mask( + torch.stack(action_masks, dim=0) + ) + + def _make_td_out(self, exclude_reward=False): + done = torch.tensor( + self._env.is_terminal(), device=self.device, dtype=torch.bool + ) + current_player = torch.tensor( + self.current_player(), device=self.device, dtype=torch.int + ) + + source = { + "done": done, + "terminated": done.clone(), + "current_player": current_player, + } + + if self.return_state: + source["state"] = self._env.serialize() + + reward = self._env.returns() + + for group, agent_names in self.group_map.items(): + agent_tds = [] + + for agent in agent_names: + agent_index = self.agent_names_to_indices_map[agent] + agent_source = {} + if self.has_observation: + observation_shape = self._get_game().observation_tensor_shape() + agent_source["observation"] = self._to_tensor( + self._env.observation_tensor(agent_index) + ).reshape(observation_shape) + + if self.has_information_state: + information_state_shape = ( + self._get_game().information_state_tensor_shape() + ) + agent_source["information_state"] = self._to_tensor( + self._env.information_state_tensor(agent_index) + ).reshape(information_state_shape) + + if not exclude_reward: + agent_source["reward"] = self._to_tensor(reward[agent_index]) + + agent_td = TensorDict( + source=agent_source, + batch_size=self.batch_size, + device=self.device, + ) + agent_tds.append(agent_td) + + source[group] = torch.stack(agent_tds, dim=0) + + tensordict_out = TensorDict( + source=source, + batch_size=self.batch_size, + device=self.device, + ) + + return tensordict_out + + def _get_action_from_tensor(self, tensor): + if not self.categorical_actions: + action = torch.argmax(tensor, dim=-1) + else: + action = tensor + return action + + def _step_parallel(self, tensordict: TensorDictBase): + actions = [0] * self._env.num_players() + for group, agents in self.group_map.items(): + for index_in_group, agent in enumerate(agents): + agent_index = self.agent_names_to_indices_map[agent] + action_tensor = tensordict[group, "action"][index_in_group] + action = self._get_action_from_tensor(action_tensor) + actions[agent_index] = action + + self._env.apply_actions(actions) + + def _step_sequential(self, tensordict: TensorDictBase): + agent_index = self._env.current_player() + + # If the game has ended, do nothing + if agent_index == self.lib.PlayerId.TERMINAL: + return + + agent = self.agent_names[agent_index] + agent_group = None + agent_index_in_group = None + + for group, agents in self.group_map.items(): + if agent in agents: + agent_group = group + agent_index_in_group = agents.index(agent) + break + + assert agent_group is not None + + action_tensor = tensordict[agent_group, "action"][agent_index_in_group] + action = self._get_action_from_tensor(action_tensor) + self._env.apply_action(action) + + def _step(self, tensordict: TensorDictBase) -> TensorDictBase: + if self.parallel: + self._step_parallel(tensordict) + else: + self._step_sequential(tensordict) + + self._update_action_mask() + return self._make_td_out() + + def _to_tensor(self, value): + return torch.tensor(value, device=self.device, dtype=torch.float32) + + def _reset( + self, tensordict: TensorDictBase | None = None, **kwargs + ) -> TensorDictBase: + game = self._get_game() + + if tensordict is not None and "state" in tensordict: + new_env = game.deserialize_state(tensordict["state"]) + else: 
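+ # No saved state was provided: start a new game from its initial state.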
+ new_env = game.new_initial_state() + + self._env = new_env + self._update_action_mask() + return self._make_td_out(exclude_reward=True) + + +class OpenSpielEnv(OpenSpielWrapper): + """Google DeepMind OpenSpiel environment wrapper built with the game string. + + GitHub: https://github.com/google-deepmind/open_spiel + + Documentation: https://openspiel.readthedocs.io/en/latest/index.html + + Args: + game_string (str): the name of the game to wrap. Must be part of + :attr:`~.available_envs`. + + Keyword Args: + device (torch.device, optional): if provided, the device on which the data + is to be cast. Defaults to ``None``. + batch_size (torch.Size, optional): the batch size of the environment. + Defaults to ``torch.Size([])``. + allow_done_after_reset (bool, optional): if ``True``, it is tolerated + for envs to be ``done`` just after :meth:`~.reset` is called. + Defaults to ``False``. + group_map (MarlGroupMapType or Dict[str, List[str]], optional): how to + group agents in tensordicts for input/output. See + :class:`~torchrl.envs.utils.MarlGroupMapType` for more info. + Defaults to + :class:`~torchrl.envs.utils.MarlGroupMapType.ALL_IN_ONE_GROUP`. + categorical_actions (bool, optional): if ``True``, categorical specs + will be converted to the TorchRL equivalent + (:class:`torchrl.data.Categorical`), otherwise a one-hot encoding + will be used (:class:`torchrl.data.OneHot`). Defaults to ``False``. + return_state (bool, optional): if ``True``, "state" is included in the + output of :meth:`~.reset` and :meth:`~.step`. The state can be given + to :meth:`~.reset` to reset to that state, rather than resetting to + the initial state. + Defaults to ``False``. + + Attributes: + available_envs: environments available to build + + Examples: + >>> from torchrl.envs import OpenSpielEnv + >>> from tensordict import TensorDict + >>> env = OpenSpielEnv("chess", return_state=True) + >>> td = env.reset() + >>> td = env.step(env.full_action_spec.rand()) + >>> print(td) + TensorDict( + fields={ + agents: TensorDict( + fields={ + action: Tensor(shape=torch.Size([2, 4672]), device=cpu, dtype=torch.int64, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False), + next: TensorDict( + fields={ + agents: TensorDict( + fields={ + observation: Tensor(shape=torch.Size([2, 20, 8, 8]), device=cpu, dtype=torch.float32, is_shared=False), + reward: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([2]), + device=None, + is_shared=False), + current_player: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False), + done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + state: NonTensorData(data=FEN: rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1 + 674 + , batch_size=torch.Size([]), device=None), + terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False) + >>> print(env.available_envs) + ['2048', 'add_noise', 'amazons', 'backgammon', ...] + + :meth:`~.reset` can restore a specific state, rather than the initial state, + as long as ``return_state=True``.
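+ (Restoring relies on the serialized ``"state"`` entry, so the tensordict + must come from an env built with the same game string.)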
+ + >>> from torchrl.envs import OpenSpielEnv + >>> from tensordict import TensorDict + >>> env = OpenSpielEnv("chess", return_state=True) + >>> td = env.reset() + >>> td = env.step(env.full_action_spec.rand()) + >>> td_restore = td["next"] + >>> td = env.step(env.full_action_spec.rand()) + >>> # Current state is not equal to `td_restore` + >>> (td["next"] == td_restore).all() + False + >>> td = env.reset(td_restore) + >>> # After resetting, now the current state is equal to `td_restore` + >>> (td == td_restore).all() + True + """ + + def __init__( + self, + game_string, + *, + group_map: MarlGroupMapType + | Dict[str, List[str]] = MarlGroupMapType.ALL_IN_ONE_GROUP, + categorical_actions=False, + return_state: bool = False, + **kwargs, + ): + kwargs["game_string"] = game_string + super().__init__( + group_map=group_map, + categorical_actions=categorical_actions, + return_state=return_state, + **kwargs, + ) + + def _build_env( + self, + game_string: str, + **kwargs, + ) -> "pyspiel.State": # noqa: F821 + if not _has_pyspiel: + raise ImportError( + f"open_spiel not found, unable to create {game_string}. Consider " + f"downloading and installing open_spiel from {self.git_url}" + ) + requires_grad = kwargs.pop("requires_grad", False) + parameters = kwargs.pop("parameters", None) + if kwargs: + raise ValueError("kwargs not supported.") + + if parameters: + game = self.lib.load_game(game_string, parameters=parameters) + else: + game = self.lib.load_game(game_string) + + env = game.new_initial_state() + return super()._build_env( + env, + requires_grad=requires_grad, + ) + + @property + def game_string(self): + return self._constructor_kwargs["game_string"] + + def _check_kwargs(self, kwargs: Dict): + if "game_string" not in kwargs: + raise TypeError("Expected 'game_string' to be part of kwargs") + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(env={self.game_string}, batch_size={self.batch_size}, device={self.device})" diff --git a/torchrl/envs/libs/pettingzoo.py b/torchrl/envs/libs/pettingzoo.py index eb94a27cbba..9853e8d516d 100644 --- a/torchrl/envs/libs/pettingzoo.py +++ b/torchrl/envs/libs/pettingzoo.py @@ -13,12 +13,7 @@ import torch from tensordict import TensorDictBase -from torchrl.data.tensor_specs import ( - CompositeSpec, - DiscreteTensorSpec, - OneHotDiscreteTensorSpec, - UnboundedContinuousTensorSpec, -) +from torchrl.data.tensor_specs import Categorical, Composite, OneHot, Unbounded from torchrl.envs.common import _EnvWrapper from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform, set_gym_backend from torchrl.envs.utils import _classproperty, check_marl_grouping, MarlGroupMapType @@ -141,7 +136,7 @@ class PettingZooWrapper(_EnvWrapper): For example, you can provide ``MarlGroupMapType.ONE_GROUP_PER_AGENT``, telling that each agent should have its own tensordict (similar to the pettingzoo parallel API). - Grouping is useful for leveraging vectorisation among agents whose data goes through the same + Grouping is useful for leveraging vectorization among agents whose data goes through the same neural network.
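 For example, ``group_map={"agents": ["player_0", "player_1"]}`` would place both players in a single ``"agents"`` group (agent names here are illustrative).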
Args: @@ -308,24 +303,24 @@ def _make_specs( check_marl_grouping(self.group_map, self.possible_agents) self.has_action_mask = {group: False for group in self.group_map.keys()} - action_spec = CompositeSpec() - observation_spec = CompositeSpec() - reward_spec = CompositeSpec() - done_spec = CompositeSpec( + action_spec = Composite() + observation_spec = Composite() + reward_spec = Composite() + done_spec = Composite( { - "done": DiscreteTensorSpec( + "done": Categorical( n=2, shape=torch.Size((1,)), dtype=torch.bool, device=self.device, ), - "terminated": DiscreteTensorSpec( + "terminated": Categorical( n=2, shape=torch.Size((1,)), dtype=torch.bool, device=self.device, ), - "truncated": DiscreteTensorSpec( + "truncated": Categorical( n=2, shape=torch.Size((1,)), dtype=torch.bool, @@ -356,7 +351,7 @@ def _make_group_specs(self, group_name: str, agent_names: List[str]): observation_specs = [] for agent in agent_names: action_specs.append( - CompositeSpec( + Composite( { "action": _gym_to_torchrl_spec_transform( self.action_space(agent), @@ -368,7 +363,7 @@ def _make_group_specs(self, group_name: str, agent_names: List[str]): ) ) observation_specs.append( - CompositeSpec( + Composite( { "observation": _gym_to_torchrl_spec_transform( self.observation_space(agent), @@ -386,34 +381,31 @@ def _make_group_specs(self, group_name: str, agent_names: List[str]): # We uniform this by removing it from both places and optionally set it in a standard location. group_observation_inner_spec = group_observation_spec["observation"] if ( - isinstance(group_observation_inner_spec, CompositeSpec) + isinstance(group_observation_inner_spec, Composite) and "action_mask" in group_observation_inner_spec.keys() ): self.has_action_mask[group_name] = True del group_observation_inner_spec["action_mask"] - group_observation_spec["action_mask"] = DiscreteTensorSpec( + group_observation_spec["action_mask"] = Categorical( n=2, shape=group_action_spec["action"].shape if not self.categorical_actions - else ( - *group_action_spec["action"].shape, - group_action_spec["action"].space.n, - ), + else group_action_spec["action"].to_one_hot_spec().shape, dtype=torch.bool, device=self.device, ) if self.use_mask: - group_observation_spec["mask"] = DiscreteTensorSpec( + group_observation_spec["mask"] = Categorical( n=2, shape=torch.Size((n_agents,)), dtype=torch.bool, device=self.device, ) - group_reward_spec = CompositeSpec( + group_reward_spec = Composite( { - "reward": UnboundedContinuousTensorSpec( + "reward": Unbounded( shape=torch.Size((n_agents, 1)), device=self.device, dtype=torch.float32, @@ -421,21 +413,21 @@ def _make_group_specs(self, group_name: str, agent_names: List[str]): }, shape=torch.Size((n_agents,)), ) - group_done_spec = CompositeSpec( + group_done_spec = Composite( { - "done": DiscreteTensorSpec( + "done": Categorical( n=2, shape=torch.Size((n_agents, 1)), dtype=torch.bool, device=self.device, ), - "terminated": DiscreteTensorSpec( + "terminated": Categorical( n=2, shape=torch.Size((n_agents, 1)), dtype=torch.bool, device=self.device, ), - "truncated": DiscreteTensorSpec( + "truncated": Categorical( n=2, shape=torch.Size((n_agents, 1)), dtype=torch.bool, @@ -473,11 +465,11 @@ def _init_env(self): info_specs = [] for agent in agents: info_specs.append( - CompositeSpec( + Composite( { - "info": CompositeSpec( + "info": Composite( { - key: UnboundedContinuousTensorSpec( + key: Unbounded( shape=torch.as_tensor(value).shape, device=self.device, ) @@ -495,11 +487,11 @@ def _init_env(self): group_action_spec = 
self.input_spec[ "full_action_spec", group, "action" ] - self.observation_spec[group]["action_mask"] = DiscreteTensorSpec( + self.observation_spec[group]["action_mask"] = Categorical( n=2, shape=group_action_spec.shape if not self.categorical_actions - else (*group_action_spec.shape, group_action_spec.space.n), + else group_action_spec.to_one_hot_spec().shape, dtype=torch.bool, device=self.device, ) @@ -518,7 +510,7 @@ def _init_env(self): ) except AttributeError: state_example = torch.as_tensor(self.state(), device=self.device) - state_spec = UnboundedContinuousTensorSpec( + state_spec = Unbounded( shape=state_example.shape, dtype=state_example.dtype, device=self.device, @@ -809,9 +801,7 @@ def _update_action_mask(self, td, observation_dict, info_dict): del agent_info["action_mask"] group_action_spec = self.input_spec["full_action_spec", group, "action"] - if isinstance( - group_action_spec, (DiscreteTensorSpec, OneHotDiscreteTensorSpec) - ): + if isinstance(group_action_spec, (Categorical, OneHot)): # We update the mask for available actions group_action_spec.update_mask(group_mask.clone()) @@ -904,7 +894,7 @@ class PettingZooEnv(PettingZooWrapper): For example, you can provide ``MarlGroupMapType.ONE_GROUP_PER_AGENT``, telling that each agent should have its own tensordict (similar to the pettingzoo parallel API). - Grouping is useful for leveraging vectorisation among agents whose data goes through the same + Grouping is useful for leveraging vectorization among agents whose data goes through the same neural network. Args: diff --git a/torchrl/envs/libs/robohive.py b/torchrl/envs/libs/robohive.py index 5e5c8f52393..30d9c644ced 100644 --- a/torchrl/envs/libs/robohive.py +++ b/torchrl/envs/libs/robohive.py @@ -12,7 +12,7 @@ import numpy as np import torch from tensordict import TensorDict -from torchrl.data.tensor_specs import UnboundedContinuousTensorSpec +from torchrl.data.tensor_specs import Unbounded from torchrl.envs.libs.gym import ( _AsyncMeta, _gym_to_torchrl_spec_transform, @@ -80,8 +80,8 @@ class RoboHiveEnv(GymEnv, metaclass=_RoboHiveBuild): Args: env_name (str): the environment name to build. Must be one of :attr:`.available_envs` categorical_action_encoding (bool, optional): if ``True``, categorical - specs will be converted to the TorchRL equivalent (:class:`torchrl.data.DiscreteTensorSpec`), - otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHotTensorSpec`). + specs will be converted to the TorchRL equivalent (:class:`torchrl.data.Categorical`), + otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHot`). Defaults to ``False``. 
Keyword Args: @@ -305,7 +305,7 @@ def get_obs(): ) self.observation_spec = observation_spec - self.reward_spec = UnboundedContinuousTensorSpec( + self.reward_spec = Unbounded( shape=(1,), device=self.device, ) # default diff --git a/torchrl/envs/libs/smacv2.py b/torchrl/envs/libs/smacv2.py index d460eb38f1e..67e71da0d5a 100644 --- a/torchrl/envs/libs/smacv2.py +++ b/torchrl/envs/libs/smacv2.py @@ -10,13 +10,7 @@ import torch from tensordict import TensorDict, TensorDictBase -from torchrl.data.tensor_specs import ( - BoundedTensorSpec, - CompositeSpec, - DiscreteTensorSpec, - OneHotDiscreteTensorSpec, - UnboundedContinuousTensorSpec, -) +from torchrl.data.tensor_specs import Bounded, Categorical, Composite, OneHot, Unbounded from torchrl.envs.common import _EnvWrapper from torchrl.envs.utils import _classproperty, ACTION_MASK_ERROR @@ -224,11 +218,11 @@ def _build_env( def _make_specs(self, env: "smacv2.env.StarCraft2Env") -> None: # noqa: F821 self.group_map = {"agents": [str(i) for i in range(self.n_agents)]} - self.reward_spec = UnboundedContinuousTensorSpec( + self.reward_spec = Unbounded( shape=torch.Size((1,)), device=self.device, ) - self.done_spec = DiscreteTensorSpec( + self.done_spec = Categorical( n=2, shape=torch.Size((1,)), dtype=torch.bool, @@ -241,54 +235,50 @@ def _init_env(self) -> None: self._env.reset() self._update_action_mask() - def _make_action_spec(self) -> CompositeSpec: + def _make_action_spec(self) -> Composite: if self.categorical_actions: - action_spec = DiscreteTensorSpec( + action_spec = Categorical( self.n_actions, shape=torch.Size((self.n_agents,)), device=self.device, dtype=torch.long, ) else: - action_spec = OneHotDiscreteTensorSpec( + action_spec = OneHot( self.n_actions, shape=torch.Size((self.n_agents, self.n_actions)), device=self.device, dtype=torch.long, ) - spec = CompositeSpec( + spec = Composite( { - "agents": CompositeSpec( + "agents": Composite( {"action": action_spec}, shape=torch.Size((self.n_agents,)) ) } ) return spec - def _make_observation_spec(self) -> CompositeSpec: - obs_spec = BoundedTensorSpec( + def _make_observation_spec(self) -> Composite: + obs_spec = Bounded( low=-1.0, high=1.0, shape=torch.Size([self.n_agents, self.get_obs_size()]), device=self.device, dtype=torch.float32, ) - info_spec = CompositeSpec( + info_spec = Composite( { - "battle_won": DiscreteTensorSpec( - 2, dtype=torch.bool, device=self.device - ), - "episode_limit": DiscreteTensorSpec( - 2, dtype=torch.bool, device=self.device - ), - "dead_allies": BoundedTensorSpec( + "battle_won": Categorical(2, dtype=torch.bool, device=self.device), + "episode_limit": Categorical(2, dtype=torch.bool, device=self.device), + "dead_allies": Bounded( low=0, high=self.n_agents, dtype=torch.long, device=self.device, shape=(), ), - "dead_enemies": BoundedTensorSpec( + "dead_enemies": Bounded( low=0, high=self.n_enemies, dtype=torch.long, @@ -297,19 +287,19 @@ def _make_observation_spec(self) -> CompositeSpec: ), } ) - mask_spec = DiscreteTensorSpec( + mask_spec = Categorical( 2, torch.Size([self.n_agents, self.n_actions]), device=self.device, dtype=torch.bool, ) - spec = CompositeSpec( + spec = Composite( { - "agents": CompositeSpec( + "agents": Composite( {"observation": obs_spec, "action_mask": mask_spec}, shape=torch.Size((self.n_agents,)), ), - "state": BoundedTensorSpec( + "state": Bounded( low=-1.0, high=1.0, shape=torch.Size((self.get_state_size(),)), diff --git a/torchrl/envs/libs/unity_mlagents.py b/torchrl/envs/libs/unity_mlagents.py new file mode 100644 index 
00000000000..95c2460bc83 --- /dev/null +++ b/torchrl/envs/libs/unity_mlagents.py @@ -0,0 +1,899 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +import importlib.util +from typing import Dict, List, Optional + +import torch +from tensordict import TensorDict, TensorDictBase + +from torchrl.data.tensor_specs import ( + BoundedContinuous, + Categorical, + Composite, + MultiCategorical, + MultiOneHot, + Unbounded, +) +from torchrl.envs.common import _EnvWrapper +from torchrl.envs.utils import _classproperty, check_marl_grouping, MarlGroupMapType + +_has_unity_mlagents = importlib.util.find_spec("mlagents_envs") is not None + + +def _get_registered_envs(): + if not _has_unity_mlagents: + raise ImportError( + "mlagents_envs not found. Consider downloading and installing " + f"mlagents from {UnityMLAgentsWrapper.git_url}." + ) + + from mlagents_envs.registry import default_registry + + return list(default_registry.keys()) + + +class UnityMLAgentsWrapper(_EnvWrapper): + """Unity ML-Agents environment wrapper. + + GitHub: https://github.com/Unity-Technologies/ml-agents + + Documentation: https://unity-technologies.github.io/ml-agents/Python-LLAPI/ + + Args: + env (mlagents_envs.environment.UnityEnvironment): the ML-Agents + environment to wrap. + + Keyword Args: + device (torch.device, optional): if provided, the device on which the data + is to be cast. Defaults to ``None``. + batch_size (torch.Size, optional): the batch size of the environment. + Defaults to ``torch.Size([])``. + allow_done_after_reset (bool, optional): if ``True``, it is tolerated + for envs to be ``done`` just after :meth:`~.reset` is called. + Defaults to ``False``. + group_map (MarlGroupMapType or Dict[str, List[str]], optional): how to + group agents in tensordicts for input/output. See + :class:`~torchrl.envs.utils.MarlGroupMapType` for more info. If not + specified, agents are grouped according to the group ID given by the + Unity environment. Defaults to ``None``. + categorical_actions (bool, optional): if ``True``, categorical specs + will be converted to the TorchRL equivalent + (:class:`torchrl.data.Categorical`), otherwise a one-hot encoding + will be used (:class:`torchrl.data.OneHot`). Defaults to ``False``.
+ + Attributes: + available_envs: list of registered environments available to build + + Examples: + >>> from mlagents_envs.environment import UnityEnvironment + >>> base_env = UnityEnvironment() + >>> from torchrl.envs import UnityMLAgentsWrapper + >>> env = UnityMLAgentsWrapper(base_env) + >>> td = env.reset() + >>> td = env.step(td.update(env.full_action_spec.rand())) + """ + + git_url = "https://github.com/Unity-Technologies/ml-agents" + libname = "mlagents_envs" + _lib = None + + @_classproperty + def lib(cls): + if cls._lib is not None: + return cls._lib + + import mlagents_envs + import mlagents_envs.environment + + cls._lib = mlagents_envs + return mlagents_envs + + def __init__( + self, + env=None, + *, + group_map: MarlGroupMapType | Dict[str, List[str]] | None = None, + categorical_actions: bool = False, + **kwargs, + ): + if env is not None: + kwargs["env"] = env + + self.group_map = group_map + self.categorical_actions = categorical_actions + super().__init__(**kwargs) + + def _check_kwargs(self, kwargs: Dict): + mlagents_envs = self.lib + if "env" not in kwargs: + raise TypeError("Could not find environment key 'env' in kwargs.") + env = kwargs["env"] + if not isinstance(env, mlagents_envs.environment.UnityEnvironment): + raise TypeError( + "env is not of type 'mlagents_envs.environment.UnityEnvironment'" + ) + + def _build_env(self, env, requires_grad: bool = False, **kwargs): + self.requires_grad = requires_grad + return env + + def _init_env(self): + self._update_action_mask() + + # Collects the agents from the environment, mapping each agent name to its + # behavior and to the group ID given by the Unity environment. + def _collect_agents(self, env): + agent_name_to_behavior_map = {} + agent_name_to_group_id_map = {} + + for steps_idx in [0, 1]: + for behavior in env.behavior_specs.keys(): + steps = env.get_steps(behavior)[steps_idx] + is_terminal = steps_idx == 1 + agent_ids = steps.agent_id + group_ids = steps.group_id + + for agent_id, group_id in zip(agent_ids, group_ids): + agent_name = f"agent_{agent_id}" + if agent_name in agent_name_to_behavior_map: + # Sometimes in an MLAgents environment, an agent may + # show up in both the decision steps and the terminal + # steps. When that happens, just skip the duplicate. + assert is_terminal + continue + agent_name_to_behavior_map[agent_name] = behavior + agent_name_to_group_id_map[agent_name] = group_id + + return ( + agent_name_to_behavior_map, + agent_name_to_group_id_map, + ) + + # Creates a group map where agents are grouped by their group_id.
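+ # For example, agent_name_to_group_id_map = {"agent_0": 0, "agent_1": 0, + # "agent_2": 1} yields {"group_0": ["agent_0", "agent_1"], "group_1": ["agent_2"]}.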
+ def _make_default_group_map(self, agent_name_to_group_id_map): + group_map = {} + for agent_name, group_id in agent_name_to_group_id_map.items(): + group_name = f"group_{group_id}" + if group_name not in group_map: + group_map[group_name] = [] + group_map[group_name].append(agent_name) + return group_map + + def _make_group_map(self, group_map, agent_name_to_group_id_map): + if group_map is None: + group_map = self._make_default_group_map(agent_name_to_group_id_map) + elif isinstance(group_map, MarlGroupMapType): + group_map = group_map.get_group_map(agent_name_to_group_id_map.keys()) + check_marl_grouping(group_map, agent_name_to_group_id_map.keys()) + agent_name_to_group_name_map = {} + for group_name, agents in group_map.items(): + for agent_name in agents: + agent_name_to_group_name_map[agent_name] = group_name + return group_map, agent_name_to_group_name_map + + def _make_specs( + self, env: "mlagents_envs.environment.UnityEnvironment" # noqa: F821 + ) -> None: + # NOTE: We need to reset here because mlagents only initializes the + # agents and behaviors after reset. In order to build specs, we make the + # following assumptions about the mlagents environment: + # * all behaviors are defined on the first step + # * all agents request an action on the first step + # However, mlagents allows you to break these assumptions, so we probably + # will need to detect changes to the behaviors and agents on each step. + env.reset() + ( + self.agent_name_to_behavior_map, + self.agent_name_to_group_id_map, + ) = self._collect_agents(env) + + (self.group_map, self.agent_name_to_group_name_map) = self._make_group_map( + self.group_map, self.agent_name_to_group_id_map + ) + + action_spec = {} + observation_spec = {} + reward_spec = {} + done_spec = {} + + for group_name, agents in self.group_map.items(): + group_action_spec = {} + group_observation_spec = {} + group_reward_spec = {} + group_done_spec = {} + for agent_name in agents: + behavior = self.agent_name_to_behavior_map[agent_name] + behavior_spec = env.behavior_specs[behavior] + + # Create action spec + agent_action_spec = Composite() + env_action_spec = behavior_spec.action_spec + discrete_branches = env_action_spec.discrete_branches + continuous_size = env_action_spec.continuous_size + if len(discrete_branches) > 0: + discrete_action_spec_cls = ( + MultiCategorical if self.categorical_actions else MultiOneHot + ) + agent_action_spec["discrete_action"] = discrete_action_spec_cls( + discrete_branches, + dtype=torch.int32, + device=self.device, + ) + if continuous_size > 0: + # In mlagents, continuous actions can take values between -1 + # and 1 by default: + # https://github.com/Unity-Technologies/ml-agents/blob/22a59aad34ef46a5de05469735426feed758f8f5/ml-agents-envs/mlagents_envs/base_env.py#L395 + agent_action_spec["continuous_action"] = BoundedContinuous( + -1, 1, (continuous_size,), self.device, torch.float32 + ) + group_action_spec[agent_name] = agent_action_spec + + # Create observation spec + agent_observation_spec = Composite() + for obs_idx, env_observation_spec in enumerate( + behavior_spec.observation_specs + ): + if len(env_observation_spec.name) == 0: + obs_name = f"observation_{obs_idx}" + else: + obs_name = env_observation_spec.name + agent_observation_spec[obs_name] = Unbounded( + env_observation_spec.shape, + dtype=torch.float32, + device=self.device, + ) + group_observation_spec[agent_name] = agent_observation_spec + + # Create reward spec + agent_reward_spec = Composite() + agent_reward_spec["reward"] = Unbounded( + 
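+ # Per-agent scalar reward; a shared "group_reward" entry is added just below.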
(1,), + dtype=torch.float32, + device=self.device, + ) + agent_reward_spec["group_reward"] = Unbounded( + (1,), + dtype=torch.float32, + device=self.device, + ) + group_reward_spec[agent_name] = agent_reward_spec + + # Create done spec + agent_done_spec = Composite() + for done_key in ["done", "terminated", "truncated"]: + agent_done_spec[done_key] = Categorical( + 2, (1,), dtype=torch.bool, device=self.device + ) + group_done_spec[agent_name] = agent_done_spec + + action_spec[group_name] = group_action_spec + observation_spec[group_name] = group_observation_spec + reward_spec[group_name] = group_reward_spec + done_spec[group_name] = group_done_spec + + self.action_spec = Composite(action_spec) + self.observation_spec = Composite(observation_spec) + self.reward_spec = Composite(reward_spec) + self.done_spec = Composite(done_spec) + + def _set_seed(self, seed): + if seed is not None: + raise NotImplementedError("This environment has no seed.") + + def _check_agent_exists(self, agent_name, group_id): + if agent_name not in self.agent_name_to_group_id_map: + raise RuntimeError( + ( + "Unity environment added a new agent. This is not yet " + "supported in torchrl." + ) + ) + if self.agent_name_to_group_id_map[agent_name] != group_id: + raise RuntimeError( + ( + "Unity environment changed the group of an agent. This " + "is not yet supported in torchrl." + ) + ) + + def _update_action_mask(self): + for behavior, behavior_spec in self._env.behavior_specs.items(): + env_action_spec = behavior_spec.action_spec + discrete_branches = env_action_spec.discrete_branches + + if len(discrete_branches) > 0: + steps = self._env.get_steps(behavior)[0] + env_action_mask = steps.action_mask + if env_action_mask is not None: + combined_action_mask = torch.cat( + [ + torch.tensor(m, device=self.device, dtype=torch.bool) + for m in env_action_mask + ], + dim=-1, + ).logical_not() + + for agent_id, group_id, agent_action_mask in zip( + steps.agent_id, steps.group_id, combined_action_mask + ): + agent_name = f"agent_{agent_id}" + self._check_agent_exists(agent_name, group_id) + group_name = self.agent_name_to_group_name_map[agent_name] + self.full_action_spec[ + group_name, agent_name, "discrete_action" + ].update_mask(agent_action_mask) + + def _make_td_out(self, tensordict_in, is_reset=False): + source = {} + for behavior, behavior_spec in self._env.behavior_specs.items(): + for idx, steps in enumerate(self._env.get_steps(behavior)): + is_terminal = idx == 1 + for steps_idx, (agent_id, group_id) in enumerate( + zip(steps.agent_id, steps.group_id) + ): + agent_name = f"agent_{agent_id}" + self._check_agent_exists(agent_name, group_id) + group_name = self.agent_name_to_group_name_map[agent_name] + if group_name not in source: + source[group_name] = {} + if agent_name not in source[group_name]: + source[group_name][agent_name] = {} + + # Add observations + for obs_idx, ( + behavior_observation, + env_observation_spec, + ) in enumerate(zip(steps.obs, behavior_spec.observation_specs)): + observation = torch.tensor( + behavior_observation[steps_idx], + device=self.device, + dtype=torch.float32, + ) + if len(env_observation_spec.name) == 0: + obs_name = f"observation_{obs_idx}" + else: + obs_name = env_observation_spec.name + source[group_name][agent_name][obs_name] = observation + + # Add rewards + if not is_reset: + source[group_name][agent_name]["reward"] = torch.tensor( + steps.reward[steps_idx], + device=self.device, + dtype=torch.float32, + ) + source[group_name][agent_name]["group_reward"] = torch.tensor( + 
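+ # Unity also reports a per-group reward shared by every agent in the group.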
steps.group_reward[steps_idx],
+                            device=self.device,
+                            dtype=torch.float32,
+                        )
+
+                    # Add done
+                    done = is_terminal and not is_reset
+                    source[group_name][agent_name]["done"] = torch.tensor(
+                        done, device=self.device, dtype=torch.bool
+                    )
+                    source[group_name][agent_name]["truncated"] = torch.tensor(
+                        done and steps.interrupted[steps_idx],
+                        device=self.device,
+                        dtype=torch.bool,
+                    )
+                    source[group_name][agent_name]["terminated"] = torch.tensor(
+                        done and not steps.interrupted[steps_idx],
+                        device=self.device,
+                        dtype=torch.bool,
+                    )
+
+        if tensordict_in is not None:
+            # In MLAgents, a given step will only contain information for agents
+            # which either terminated or requested a decision during the step.
+            # Some agents may have neither terminated nor requested a decision,
+            # so we need to fill in their information from the previous step.
+            for group_name, agents in self.group_map.items():
+                for agent_name in agents:
+                    if group_name not in source.keys():
+                        source[group_name] = {}
+                    if agent_name not in source[group_name].keys():
+                        agent_dict = {}
+                        agent_behavior = self.agent_name_to_behavior_map[agent_name]
+                        behavior_spec = self._env.behavior_specs[agent_behavior]
+                        td_agent_in = tensordict_in[group_name, agent_name]
+
+                        # Add observations
+                        for obs_idx, env_observation_spec in enumerate(
+                            behavior_spec.observation_specs
+                        ):
+                            if len(env_observation_spec.name) == 0:
+                                obs_name = f"observation_{obs_idx}"
+                            else:
+                                obs_name = env_observation_spec.name
+                            agent_dict[obs_name] = td_agent_in[obs_name]
+
+                        # Add rewards
+                        if not is_reset:
+                            # Since the agent didn't request a decision, the
+                            # reward is 0
+                            agent_dict["reward"] = torch.zeros(
+                                (1,), device=self.device, dtype=torch.float32
+                            )
+                            agent_dict["group_reward"] = torch.zeros(
+                                (1,), device=self.device, dtype=torch.float32
+                            )
+
+                        # Add done
+                        agent_dict["done"] = torch.tensor(
+                            False, device=self.device, dtype=torch.bool
+                        )
+                        agent_dict["terminated"] = torch.tensor(
+                            False, device=self.device, dtype=torch.bool
+                        )
+                        agent_dict["truncated"] = torch.tensor(
+                            False, device=self.device, dtype=torch.bool
+                        )
+
+                        source[group_name][agent_name] = agent_dict
+
+        tensordict_out = TensorDict(
+            source=source,
+            batch_size=self.batch_size,
+            device=self.device,
+        )
+
+        return tensordict_out
+
+    def _get_action_from_tensor(self, tensor):
+        if not self.categorical_actions:
+            action = torch.argmax(tensor, dim=-1)
+        else:
+            action = tensor
+        return action
+
+    def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
+        # Apply actions
+        for behavior, behavior_spec in self._env.behavior_specs.items():
+            env_action_spec = behavior_spec.action_spec
+            steps = self._env.get_steps(behavior)[0]
+
+            for agent_id, group_id in zip(steps.agent_id, steps.group_id):
+                agent_name = f"agent_{agent_id}"
+                self._check_agent_exists(agent_name, group_id)
+                group_name = self.agent_name_to_group_name_map[agent_name]
+
+                agent_action_spec = self.full_action_spec[group_name, agent_name]
+                action_tuple = self.lib.base_env.ActionTuple()
+                discrete_branches = env_action_spec.discrete_branches
+                continuous_size = env_action_spec.continuous_size
+
+                if len(discrete_branches) > 0:
+                    discrete_spec = agent_action_spec["discrete_action"]
+                    discrete_action = tensordict[
+                        group_name, agent_name, "discrete_action"
+                    ]
+                    if not self.categorical_actions:
+                        discrete_action = discrete_spec.to_categorical(discrete_action)
+                    action_tuple.add_discrete(discrete_action[None, ...].numpy())
+
+                if continuous_size > 0:
+                    continuous_action = tensordict[
+                        group_name, agent_name, "continuous_action"
+                    ]
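+                    # NOTE: mlagents's ActionTuple stores actions as 2D
+                    # (num_agents, size) arrays; the [None, ...] below adds the
+                    # leading dimension for the single agent being set.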
+                    action_tuple.add_continuous(continuous_action[None, ...].numpy())
+
+                self._env.set_action_for_agent(behavior, agent_id, action_tuple)
+
+        self._env.step()
+        self._update_action_mask()
+        return self._make_td_out(tensordict)
+
+    def _to_tensor(self, value):
+        return torch.tensor(value, device=self.device, dtype=torch.float32)
+
+    def _reset(
+        self, tensordict: TensorDictBase | None = None, **kwargs
+    ) -> TensorDictBase:
+        self._env.reset()
+        return self._make_td_out(tensordict, is_reset=True)
+
+    def close(self):
+        self._env.close()
+
+    @_classproperty
+    def available_envs(cls):
+        if not _has_unity_mlagents:
+            return []
+        return _get_registered_envs()
+
+
+class UnityMLAgentsEnv(UnityMLAgentsWrapper):
+    """Unity ML-Agents environment wrapper.
+
+    GitHub: https://github.com/Unity-Technologies/ml-agents
+
+    Documentation: https://unity-technologies.github.io/ml-agents/Python-LLAPI/
+
+    This class can be provided any of the optional initialization arguments
+    that the :class:`mlagents_envs.environment.UnityEnvironment` class
+    provides. For a list of these arguments, see:
+    https://unity-technologies.github.io/ml-agents/Python-LLAPI-Documentation/#__init__
+
+    If both ``file_name`` and ``registered_name`` are given, an error is raised.
+
+    If neither ``file_name`` nor ``registered_name`` is given, the environment
+    setup waits on a localhost port, and the user must execute a Unity ML-Agents
+    environment binary to connect to it.
+
+    Args:
+        file_name (str, optional): if provided, the path to the Unity
+            environment binary. Defaults to ``None``.
+        registered_name (str, optional): if provided, the Unity environment
+            binary is loaded from the default ML-Agents registry. The list of
+            registered environments is in :attr:`~.available_envs`. Defaults to
+            ``None``.
+
+    Keyword Args:
+        device (torch.device, optional): if provided, the device on which the data
+            is to be cast. Defaults to ``None``.
+        batch_size (torch.Size, optional): the batch size of the environment.
+            Defaults to ``torch.Size([])``.
+        allow_done_after_reset (bool, optional): if ``True``, it is tolerated
+            for envs to be ``done`` just after :meth:`~.reset` is called.
+            Defaults to ``False``.
+        group_map (MarlGroupMapType or Dict[str, List[str]], optional): how to
+            group agents in tensordicts for input/output. See
+            :class:`~torchrl.envs.utils.MarlGroupMapType` for more info. If not
+            specified, agents are grouped according to the group ID given by the
+            Unity environment. Defaults to ``None``.
+        categorical_actions (bool, optional): if ``True``, categorical specs
+            will be converted to the TorchRL equivalent
+            (:class:`torchrl.data.Categorical`), otherwise a one-hot encoding
+            will be used (:class:`torchrl.data.OneHot`). Defaults to ``False``.
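+
+    As a purely illustrative sketch, a custom ``group_map`` can also be passed
+    as a plain dict (this assumes the ``3DBall`` registered environment, whose
+    12 agents this wrapper names ``agent_0`` through ``agent_11``, as shown in
+    the example output below; the grouping must cover every agent exactly once):
+
+        >>> env = UnityMLAgentsEnv(
+        ...     registered_name='3DBall',
+        ...     group_map={"agents": [f"agent_{i}" for i in range(12)]},
+        ... )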
+ + Attributes: + available_envs: list of registered environments available to build + + Examples: + >>> from torchrl.envs import UnityMLAgentsEnv + >>> env = UnityMLAgentsEnv(registered_name='3DBall') + >>> td = env.reset() + >>> td = env.step(td.update(env.full_action_spec.rand())) + >>> td + TensorDict( + fields={ + group_0: TensorDict( + fields={ + agent_0: TensorDict( + fields={ + VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False), + continuous_action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False), + done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False), + agent_10: TensorDict( + fields={ + VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False), + continuous_action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False), + done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False), + agent_11: TensorDict( + fields={ + VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False), + continuous_action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False), + done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False), + agent_1: TensorDict( + fields={ + VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False), + continuous_action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False), + done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False), + agent_2: TensorDict( + fields={ + VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False), + continuous_action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False), + done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False), + agent_3: TensorDict( + fields={ + VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False), + continuous_action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False), + done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + 
terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False), + agent_4: TensorDict( + fields={ + VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False), + continuous_action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False), + done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False), + agent_5: TensorDict( + fields={ + VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False), + continuous_action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False), + done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False), + agent_6: TensorDict( + fields={ + VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False), + continuous_action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False), + done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False), + agent_7: TensorDict( + fields={ + VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False), + continuous_action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False), + done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False), + agent_8: TensorDict( + fields={ + VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False), + continuous_action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False), + done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False), + agent_9: TensorDict( + fields={ + VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False), + continuous_action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False), + done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: 
Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False), + next: TensorDict( + fields={ + group_0: TensorDict( + fields={ + agent_0: TensorDict( + fields={ + VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False), + done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + group_reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False), + reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False), + terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False), + agent_10: TensorDict( + fields={ + VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False), + done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + group_reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False), + reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False), + terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False), + agent_11: TensorDict( + fields={ + VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False), + done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + group_reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False), + reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False), + terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False), + agent_1: TensorDict( + fields={ + VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False), + done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + group_reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False), + reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False), + terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False), + agent_2: TensorDict( + fields={ + VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False), + done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + group_reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False), + reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False), + terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + 
is_shared=False), + agent_3: TensorDict( + fields={ + VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False), + done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + group_reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False), + reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False), + terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False), + agent_4: TensorDict( + fields={ + VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False), + done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + group_reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False), + reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False), + terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False), + agent_5: TensorDict( + fields={ + VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False), + done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + group_reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False), + reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False), + terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False), + agent_6: TensorDict( + fields={ + VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False), + done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + group_reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False), + reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False), + terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False), + agent_7: TensorDict( + fields={ + VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False), + done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + group_reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False), + reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False), + terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False), + agent_8: TensorDict( + fields={ + VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False), + done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + group_reward: 
Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False), + reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False), + terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False), + agent_9: TensorDict( + fields={ + VectorSensor_size8: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, is_shared=False), + done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + group_reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False), + reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False), + terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False) + """ + + def __init__( + self, + file_name: Optional[str] = None, + registered_name: Optional[str] = None, + *, + group_map: MarlGroupMapType | Dict[str, List[str]] | None = None, + categorical_actions=False, + **kwargs, + ): + kwargs["file_name"] = file_name + kwargs["registered_name"] = registered_name + super().__init__( + group_map=group_map, + categorical_actions=categorical_actions, + **kwargs, + ) + + def _build_env( + self, + file_name: Optional[str], + registered_name: Optional[str], + **kwargs, + ) -> "mlagents_envs.environment.UnityEnvironment": # noqa: F821 + if not _has_unity_mlagents: + raise ImportError( + "mlagents_envs not found, unable to create environment. " + "Consider downloading and installing mlagents from " + f"{self.git_url}" + ) + if file_name is not None and registered_name is not None: + raise ValueError( + "Both `file_name` and `registered_name` were specified, which " + "is not allowed. Specify one of them or neither." 
+ ) + elif registered_name is not None: + from mlagents_envs.registry import default_registry + + env = default_registry[registered_name].make(**kwargs) + else: + env = self.lib.environment.UnityEnvironment(file_name, **kwargs) + requires_grad = kwargs.pop("requires_grad", False) + return super()._build_env( + env, + requires_grad=requires_grad, + ) + + @property + def file_name(self): + return self._constructor_kwargs["file_name"] + + @property + def registered_name(self): + return self._constructor_kwargs["registered_name"] + + def _check_kwargs(self, kwargs: Dict): + pass + + def __repr__(self) -> str: + if self.registered_name is not None: + env_name = self.registered_name + else: + env_name = self.file_name + return f"{self.__class__.__name__}(env={env_name}, batch_size={self.batch_size}, device={self.device})" diff --git a/torchrl/envs/libs/vmas.py b/torchrl/envs/libs/vmas.py index 9751e84a3ac..22f9835303b 100644 --- a/torchrl/envs/libs/vmas.py +++ b/torchrl/envs/libs/vmas.py @@ -12,16 +12,16 @@ from tensordict import LazyStackedTensorDict, TensorDict, TensorDictBase from torchrl.data.tensor_specs import ( - BoundedTensorSpec, - CompositeSpec, + Bounded, + Categorical, + Composite, DEVICE_TYPING, - DiscreteTensorSpec, - LazyStackedCompositeSpec, - MultiDiscreteTensorSpec, - MultiOneHotDiscreteTensorSpec, - OneHotDiscreteTensorSpec, + MultiCategorical, + MultiOneHot, + OneHot, + StackedComposite, TensorSpec, - UnboundedContinuousTensorSpec, + Unbounded, ) from torchrl.data.utils import numpy_to_torch_dtype_dict from torchrl.envs.common import _EnvWrapper, EnvBase @@ -57,11 +57,7 @@ def _vmas_to_torchrl_spec_transform( ) -> TensorSpec: gym_spaces = gym_backend("spaces") if isinstance(spec, gym_spaces.discrete.Discrete): - action_space_cls = ( - DiscreteTensorSpec - if categorical_action_encoding - else OneHotDiscreteTensorSpec - ) + action_space_cls = Categorical if categorical_action_encoding else OneHot dtype = ( numpy_to_torch_dtype_dict[spec.dtype] if categorical_action_encoding @@ -75,9 +71,9 @@ def _vmas_to_torchrl_spec_transform( else torch.long ) return ( - MultiDiscreteTensorSpec(spec.nvec, device=device, dtype=dtype) + MultiCategorical(spec.nvec, device=device, dtype=dtype) if categorical_action_encoding - else MultiOneHotDiscreteTensorSpec(spec.nvec, device=device, dtype=dtype) + else MultiOneHot(spec.nvec, device=device, dtype=dtype) ) elif isinstance(spec, gym_spaces.Box): shape = spec.shape @@ -88,9 +84,9 @@ def _vmas_to_torchrl_spec_transform( high = torch.tensor(spec.high, device=device, dtype=dtype) is_unbounded = low.isinf().all() and high.isinf().all() return ( - UnboundedContinuousTensorSpec(shape, device=device, dtype=dtype) + Unbounded(shape, device=device, dtype=dtype) if is_unbounded - else BoundedTensorSpec( + else Bounded( low, high, shape, @@ -98,6 +94,16 @@ def _vmas_to_torchrl_spec_transform( device=device, ) ) + elif isinstance(spec, gym_spaces.Dict): + spec_out = {} + for key in spec.keys(): + spec_out[key] = _vmas_to_torchrl_spec_transform( + spec[key], + device=device, + categorical_action_encoding=categorical_action_encoding, + ) + # the batch-size must be set later + return Composite(spec_out, device=device) else: raise NotImplementedError( f"spec of type {type(spec).__name__} is currently unaccounted for vmas" @@ -322,9 +328,9 @@ def _make_specs( self.group_map = self.group_map.get_group_map(self.agent_names) check_marl_grouping(self.group_map, self.agent_names) - self.unbatched_action_spec = CompositeSpec(device=self.device) - 
self.unbatched_observation_spec = CompositeSpec(device=self.device) - self.unbatched_reward_spec = CompositeSpec(device=self.device) + self.unbatched_action_spec = Composite(device=self.device) + self.unbatched_observation_spec = Composite(device=self.device) + self.unbatched_reward_spec = Composite(device=self.device) self.het_specs = False self.het_specs_map = {} @@ -341,14 +347,14 @@ def _make_specs( if group_info_spec is not None: self.unbatched_observation_spec[(group, "info")] = group_info_spec group_het_specs = isinstance( - group_observation_spec, LazyStackedCompositeSpec - ) or isinstance(group_action_spec, LazyStackedCompositeSpec) + group_observation_spec, StackedComposite + ) or isinstance(group_action_spec, StackedComposite) self.het_specs_map[group] = group_het_specs self.het_specs = self.het_specs or group_het_specs - self.unbatched_done_spec = CompositeSpec( + self.unbatched_done_spec = Composite( { - "done": DiscreteTensorSpec( + "done": Categorical( n=2, shape=torch.Size((1,)), dtype=torch.bool, @@ -380,7 +386,7 @@ def _make_unbatched_group_specs(self, group: str): agent_index = self.agent_names_to_indices_map[agent_name] agent = self.agents[agent_index] action_specs.append( - CompositeSpec( + Composite( { "action": _vmas_to_torchrl_spec_transform( self.action_space[agent_index], @@ -391,7 +397,7 @@ def _make_unbatched_group_specs(self, group: str): ) ) observation_specs.append( - CompositeSpec( + Composite( { "observation": _vmas_to_torchrl_spec_transform( self.observation_space[agent_index], @@ -402,9 +408,9 @@ def _make_unbatched_group_specs(self, group: str): ) ) reward_specs.append( - CompositeSpec( + Composite( { - "reward": UnboundedContinuousTensorSpec( + "reward": Unbounded( shape=torch.Size((1,)), device=self.device, ) # shape = (1,) @@ -414,9 +420,9 @@ def _make_unbatched_group_specs(self, group: str): agent_info = self.scenario.info(agent) if len(agent_info): info_specs.append( - CompositeSpec( + Composite( { - key: UnboundedContinuousTensorSpec( + key: Unbounded( shape=_selective_unsqueeze( value, batch_size=self.batch_size ).shape[1:], diff --git a/torchrl/envs/model_based/common.py b/torchrl/envs/model_based/common.py index f6b3f97cd4a..2a3c0198f9c 100644 --- a/torchrl/envs/model_based/common.py +++ b/torchrl/envs/model_based/common.py @@ -27,18 +27,18 @@ class ModelBasedEnvBase(EnvBase): Example: >>> import torch >>> from tensordict import TensorDict - >>> from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec + >>> from torchrl.data import Composite, Unbounded >>> class MyMBEnv(ModelBasedEnvBase): ... def __init__(self, world_model, device="cpu", dtype=None, batch_size=None): ... super().__init__(world_model, device=device, dtype=dtype, batch_size=batch_size) - ... self.observation_spec = CompositeSpec( - ... hidden_observation=UnboundedContinuousTensorSpec((4,)) + ... self.observation_spec = Composite( + ... hidden_observation=Unbounded((4,)) ... ) - ... self.state_spec = CompositeSpec( - ... hidden_observation=UnboundedContinuousTensorSpec((4,)), + ... self.state_spec = Composite( + ... hidden_observation=Unbounded((4,)), ... ) - ... self.action_spec = UnboundedContinuousTensorSpec((1,)) - ... self.reward_spec = UnboundedContinuousTensorSpec((1,)) + ... self.action_spec = Unbounded((1,)) + ... self.reward_spec = Unbounded((1,)) ... ... def _reset(self, tensordict: TensorDict) -> TensorDict: ... 
tensordict = TensorDict({}, @@ -84,10 +84,10 @@ class ModelBasedEnvBase(EnvBase): Properties: - - observation_spec (CompositeSpec): sampling spec of the observations; + - observation_spec (Composite): sampling spec of the observations; - action_spec (TensorSpec): sampling spec of the actions; - reward_spec (TensorSpec): sampling spec of the rewards; - - input_spec (CompositeSpec): sampling spec of the inputs; + - input_spec (Composite): sampling spec of the inputs; - batch_size (torch.Size): batch_size to be used by the env. If not set, the env accept tensordicts of all batch sizes. - device (torch.device): device where the env input and output are expected to live diff --git a/torchrl/envs/model_based/dreamer.py b/torchrl/envs/model_based/dreamer.py index 5609861c75f..f5636f76c5a 100644 --- a/torchrl/envs/model_based/dreamer.py +++ b/torchrl/envs/model_based/dreamer.py @@ -9,7 +9,7 @@ from tensordict import TensorDict from tensordict.nn import TensorDictModule -from torchrl.data.tensor_specs import CompositeSpec +from torchrl.data.tensor_specs import Composite from torchrl.data.utils import DEVICE_TYPING from torchrl.envs.common import EnvBase from torchrl.envs.model_based import ModelBasedEnvBase @@ -39,7 +39,7 @@ def set_specs_from_env(self, env: EnvBase): """Sets the specs of the environment from the specs of the given environment.""" super().set_specs_from_env(env) self.action_spec = self.action_spec.to(self.device) - self.state_spec = CompositeSpec( + self.state_spec = Composite( state=self.observation_spec["state"], belief=self.observation_spec["belief"], shape=env.batch_size, diff --git a/torchrl/envs/transforms/gym_transforms.py b/torchrl/envs/transforms/gym_transforms.py index 35f122b770a..b3ac334a5d8 100644 --- a/torchrl/envs/transforms/gym_transforms.py +++ b/torchrl/envs/transforms/gym_transforms.py @@ -10,7 +10,7 @@ import torchrl.objectives.common from tensordict import TensorDictBase from tensordict.utils import expand_as_right, NestedKey -from torchrl.data.tensor_specs import UnboundedDiscreteTensorSpec +from torchrl.data.tensor_specs import Unbounded from torchrl.envs.transforms.transforms import FORWARD_NOT_IMPLEMENTED, Transform @@ -179,7 +179,7 @@ def _reset(self, tensordict, tensordict_reset): def transform_observation_spec(self, observation_spec): full_done_spec = self.parent.output_spec["full_done_spec"] observation_spec[self.eol_key] = full_done_spec[self.done_key].clone() - observation_spec[self.lives_key] = UnboundedDiscreteTensorSpec( + observation_spec[self.lives_key] = Unbounded( self.parent.batch_size, device=self.parent.device, dtype=torch.int64, diff --git a/torchrl/envs/transforms/r3m.py b/torchrl/envs/transforms/r3m.py index 546321d5815..bdc8af1eefa 100644 --- a/torchrl/envs/transforms/r3m.py +++ b/torchrl/envs/transforms/r3m.py @@ -11,11 +11,7 @@ from torch.hub import load_state_dict_from_url from torch.nn import Identity -from torchrl.data.tensor_specs import ( - CompositeSpec, - TensorSpec, - UnboundedContinuousTensorSpec, -) +from torchrl.data.tensor_specs import Composite, TensorSpec, Unbounded from torchrl.data.utils import DEVICE_TYPING from torchrl.envs.transforms.transforms import ( CatTensors, @@ -103,8 +99,8 @@ def _apply_transform(self, obs: torch.Tensor) -> None: return out def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: - if not isinstance(observation_spec, CompositeSpec): - raise ValueError("_R3MNet can only infer CompositeSpec") + if not isinstance(observation_spec, Composite): + raise ValueError("_R3MNet 
can only infer Composite") keys = [key for key in observation_spec.keys(True, True) if key in self.in_keys] device = observation_spec[keys[0]].device @@ -116,7 +112,7 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec del observation_spec[in_key] for out_key in self.out_keys: - observation_spec[out_key] = UnboundedContinuousTensorSpec( + observation_spec[out_key] = Unbounded( shape=torch.Size([*dim, self.outdim]), device=device ) @@ -319,7 +315,7 @@ def _init(self): unsqueeze = UnsqueezeTransform( in_keys=in_keys, out_keys=in_keys, - unsqueeze_dim=-4, + dim=-4, ) transforms.append(unsqueeze) diff --git a/torchrl/envs/transforms/rlhf.py b/torchrl/envs/transforms/rlhf.py index 33874393038..6228b0f22b7 100644 --- a/torchrl/envs/transforms/rlhf.py +++ b/torchrl/envs/transforms/rlhf.py @@ -9,7 +9,7 @@ from tensordict.nn import ProbabilisticTensorDictModule, TensorDictParams from tensordict.utils import is_seq_of_nested_key from torch import nn -from torchrl.data.tensor_specs import CompositeSpec, UnboundedContinuousTensorSpec +from torchrl.data.tensor_specs import Composite, Unbounded from torchrl.envs.transforms.transforms import Transform from torchrl.envs.transforms.utils import _set_missing_tolerance, _stateless_param @@ -142,8 +142,8 @@ def _make_detached_param(x): self.sample_log_prob_key = "sample_log_prob" def find_sample_log_prob(module): - if hasattr(module, "SAMPLE_LOG_PROB_KEY"): - self.sample_log_prob_key = module.SAMPLE_LOG_PROB_KEY + if hasattr(module, "log_prob_key"): + self.sample_log_prob_key = module.log_prob_key self.functional_actor.apply(find_sample_log_prob) @@ -186,7 +186,7 @@ def _step( forward = _call - def transform_output_spec(self, output_spec: CompositeSpec) -> CompositeSpec: + def transform_output_spec(self, output_spec: Composite) -> Composite: output_spec = super().transform_output_spec(output_spec) # todo: here we'll need to use the reward_key once it's implemented # parent = self.parent @@ -195,17 +195,17 @@ def transform_output_spec(self, output_spec: CompositeSpec) -> CompositeSpec: if in_key == "reward" and out_key == "reward": parent = self.parent - reward_spec = UnboundedContinuousTensorSpec( + reward_spec = Unbounded( device=output_spec.device, shape=output_spec["full_reward_spec"][parent.reward_key].shape, ) - output_spec["full_reward_spec"] = CompositeSpec( + output_spec["full_reward_spec"] = Composite( {parent.reward_key: reward_spec}, shape=output_spec["full_reward_spec"].shape, ) elif in_key == "reward": parent = self.parent - reward_spec = UnboundedContinuousTensorSpec( + reward_spec = Unbounded( device=output_spec.device, shape=output_spec["full_reward_spec"][parent.reward_key].shape, ) @@ -214,7 +214,7 @@ def transform_output_spec(self, output_spec: CompositeSpec) -> CompositeSpec: observation_spec[out_key] = reward_spec else: observation_spec = output_spec["full_observation_spec"] - reward_spec = UnboundedContinuousTensorSpec( + reward_spec = Unbounded( device=output_spec.device, shape=observation_spec[in_key].shape ) # then we need to populate the output keys diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 7c9dec980f5..b70e05ca431 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -39,25 +39,30 @@ unravel_key, unravel_key_list, ) -from tensordict._C import _unravel_key_to_tuple from tensordict.nn import dispatch, TensorDictModuleBase -from tensordict.utils import expand_as_right, expand_right, NestedKey +from 
tensordict.utils import ( + _unravel_key_to_tuple, + _zip_strict, + expand_as_right, + expand_right, + NestedKey, +) from torch import nn, Tensor from torch.utils._pytree import tree_map from torchrl._utils import _append_last, _ends_with, _make_ordinal_device, _replace_last from torchrl.data.tensor_specs import ( - BinaryDiscreteTensorSpec, - BoundedTensorSpec, - CompositeSpec, + Binary, + Bounded, + Categorical, + Composite, ContinuousBox, - DiscreteTensorSpec, - MultiDiscreteTensorSpec, - MultiOneHotDiscreteTensorSpec, - OneHotDiscreteTensorSpec, + MultiCategorical, + MultiOneHot, + OneHot, TensorSpec, - UnboundedContinuousTensorSpec, + Unbounded, ) from torchrl.envs.common import _do_nothing, _EnvPostInit, EnvBase, make_tensordict from torchrl.envs.transforms import functional as F @@ -80,14 +85,14 @@ def _apply_to_composite(function): @wraps(function) def new_fun(self, observation_spec): - if isinstance(observation_spec, CompositeSpec): + if isinstance(observation_spec, Composite): _specs = observation_spec._specs in_keys = self.in_keys out_keys = self.out_keys - for in_key, out_key in zip(in_keys, out_keys): + for in_key, out_key in _zip_strict(in_keys, out_keys): if in_key in observation_spec.keys(True, True): _specs[out_key] = function(self, observation_spec[in_key].clone()) - return CompositeSpec( + return Composite( _specs, shape=observation_spec.shape, device=observation_spec.device ) else: @@ -109,12 +114,12 @@ def new_fun(self, input_spec): action_spec = input_spec["full_action_spec"].clone() state_spec = input_spec["full_state_spec"] if state_spec is None: - state_spec = CompositeSpec(shape=input_spec.shape, device=input_spec.device) + state_spec = Composite(shape=input_spec.shape, device=input_spec.device) else: state_spec = state_spec.clone() in_keys_inv = self.in_keys_inv out_keys_inv = self.out_keys_inv - for in_key, out_key in zip(in_keys_inv, out_keys_inv): + for in_key, out_key in _zip_strict(in_keys_inv, out_keys_inv): if in_key != out_key: # we only change the input spec if the key is the same continue @@ -122,7 +127,7 @@ def new_fun(self, input_spec): action_spec[out_key] = function(self, action_spec[in_key].clone()) elif in_key in state_spec.keys(True, True): state_spec[out_key] = function(self, state_spec[in_key].clone()) - return CompositeSpec( + return Composite( full_state_spec=state_spec, full_action_spec=action_spec, shape=input_spec.shape, @@ -270,7 +275,7 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase: :meth:`TransformedEnv.reset`. 
""" - for in_key, out_key in zip(self.in_keys, self.out_keys): + for in_key, out_key in _zip_strict(self.in_keys, self.out_keys): value = tensordict.get(in_key, default=None) if value is not None: observation = self._apply_transform(value) @@ -287,7 +292,7 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase: @dispatch(source="in_keys", dest="out_keys") def forward(self, tensordict: TensorDictBase) -> TensorDictBase: """Reads the input tensordict, and for the selected keys, applies the transform.""" - for in_key, out_key in zip(self.in_keys, self.out_keys): + for in_key, out_key in _zip_strict(self.in_keys, self.out_keys): data = tensordict.get(in_key, None) if data is not None: data = self._apply_transform(data) @@ -328,7 +333,7 @@ def _inv_apply_transform(self, state: torch.Tensor) -> torch.Tensor: def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase: if not self.in_keys_inv: return tensordict - for in_key, out_key in zip(self.in_keys_inv, self.out_keys_inv): + for in_key, out_key in _zip_strict(self.in_keys_inv, self.out_keys_inv): data = tensordict.get(in_key, None) if data is not None: item = self._inv_apply_transform(data) @@ -360,7 +365,7 @@ def transform_env_batch_size(self, batch_size: torch.Size): """Transforms the batch-size of the parent env.""" return batch_size - def transform_output_spec(self, output_spec: CompositeSpec) -> CompositeSpec: + def transform_output_spec(self, output_spec: Composite) -> Composite: """Transforms the output spec such that the resulting spec matches transform mapping. This method should generally be left untouched. Changes should be implemented using @@ -599,7 +604,7 @@ def __init__( device = env.device super().__init__(device=None, allow_done_after_reset=None, **kwargs) - # Type matching must be exact here, because subtyping could introduce differences in behaviour that must + # Type matching must be exact here, because subtyping could introduce differences in behavior that must # be contained within the subclass. if type(env) is TransformedEnv and type(self) is TransformedEnv: self._set_env(env.base_env, device) @@ -831,7 +836,7 @@ def _reset_proc_data(self, tensordict, tensordict_reset): return tensordict_reset def _complete_done( - cls, done_spec: CompositeSpec, data: TensorDictBase + cls, done_spec: Composite, data: TensorDictBase ) -> TensorDictBase: # This step has already been completed. We assume the transform module do their job correctly. 
return data @@ -1090,7 +1095,7 @@ def transform_env_batch_size(self, batch_size: torch.batch_size): return batch_size def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec: - for t in self.transforms[::-1]: + for t in self.transforms: input_spec = t.transform_input_spec(input_spec) return input_spec @@ -1343,11 +1348,11 @@ def _apply_transform(self, observation: torch.FloatTensor) -> torch.Tensor: @_apply_to_composite def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: observation_spec = self._pixel_observation(observation_spec) - unsqueeze_dim = [1] if self._should_unsqueeze(observation_spec) else [] + dim = [1] if self._should_unsqueeze(observation_spec) else [] if not self.shape_tolerant or observation_spec.shape[-1] == 3: observation_spec.shape = torch.Size( [ - *unsqueeze_dim, + *dim, *observation_spec.shape[:-3], observation_spec.shape[-1], observation_spec.shape[-3], @@ -1465,7 +1470,7 @@ def _inv_apply_transform(self, state: torch.Tensor) -> torch.Tensor: @_apply_to_composite def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: - return BoundedTensorSpec( + return Bounded( shape=observation_spec.shape, device=observation_spec.device, dtype=observation_spec.dtype, @@ -1477,7 +1482,7 @@ def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec: for key in self.in_keys: if key in self.parent.reward_keys: spec = self.parent.output_spec["full_reward_spec"][key] - self.parent.output_spec["full_reward_spec"][key] = BoundedTensorSpec( + self.parent.output_spec["full_reward_spec"][key] = Bounded( shape=spec.shape, device=spec.device, dtype=spec.dtype, @@ -1507,7 +1512,7 @@ class TargetReturn(Transform): In goal-conditioned RL, the :class:`~.TargetReturn` is defined as the expected cumulative reward obtained from the current state to the goal state - or the end of the episode. It is used as input for the policy to guide its behaviour. + or the end of the episode. It is used as input for the policy to guide its behavior. For a trained policy typically the maximum return in the environment is chosen as the target return. 
However, as it is used as input to the policy module, it should be scaled
@@ -1633,7 +1638,7 @@ def _reset(self, tensordict: TensorDict, tensordict_reset: TensorDictBase):
         return tensordict_reset

     def _call(self, tensordict: TensorDict) -> TensorDict:
-        for in_key, out_key in zip(self.in_keys, self.out_keys):
+        for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
             val_in = tensordict.get(in_key, None)
             val_out = tensordict.get(out_key, None)
             if val_in is not None:
@@ -1675,7 +1680,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
         )

     def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
-        for in_key, out_key in zip(self.in_keys, self.out_keys):
+        for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
             if in_key in self.parent.full_observation_spec.keys(True):
                 target = self.parent.full_observation_spec[in_key]
             elif in_key in self.parent.full_reward_spec.keys(True):
@@ -1685,7 +1690,7 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec
                 target = self.parent.full_done_spec[in_key]
             else:
                 raise RuntimeError(f"in_key {in_key} not found in output_spec.")
-            target_return_spec = UnboundedContinuousTensorSpec(
+            target_return_spec = Unbounded(
                 shape=target.shape,
                 dtype=target.dtype,
                 device=target.device,
@@ -1744,8 +1749,8 @@ def _apply_transform(self, reward: torch.Tensor) -> torch.Tensor:

     @_apply_to_composite
     def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec:
-        if isinstance(reward_spec, UnboundedContinuousTensorSpec):
-            return BoundedTensorSpec(
+        if isinstance(reward_spec, Unbounded):
+            return Bounded(
                 self.clamp_min,
                 self.clamp_max,
                 shape=reward_spec.shape,
@@ -1798,7 +1803,7 @@ def _apply_transform(self, reward: torch.Tensor) -> torch.Tensor:

     @_apply_to_composite
     def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec:
-        return BinaryDiscreteTensorSpec(
+        return Binary(
             n=1,
             device=reward_spec.device,
             shape=reward_spec.shape,
@@ -2132,41 +2137,42 @@ class UnsqueezeTransform(Transform):
     """Inserts a dimension of size one at the specified position.

     Args:
-        unsqueeze_dim (int): dimension to unsqueeze. Must be negative (or allow_positive_dim
+        dim (int): dimension to unsqueeze. Must be negative (or allow_positive_dim
             must be turned on).
+
+    Keyword Args:
         allow_positive_dim (bool, optional): if ``True``, positive dimensions are accepted.
-            :obj:`UnsqueezeTransform` will map these to the n^th feature dimension
+            ``UnsqueezeTransform`` will map these to the n^th feature dimension
             (ie n^th dimension after batch size of parent env) of the input tensor,
-            independently from the tensordict batch size (ie positive dims may be
+            independently of the tensordict batch size (ie positive dims may be
             dangerous in contexts where tensordict of different batch dimension are passed).
             Defaults to False, ie. non-negative dimensions are not permitted.
+        in_keys (list of NestedKeys): input entries (read).
+        out_keys (list of NestedKeys): output entries (write). Defaults to ``in_keys`` if
+            not provided.
+        in_keys_inv (list of NestedKeys): input entries (read) during :meth:`~.inv` calls.
+        out_keys_inv (list of NestedKeys): output entries (write) during :meth:`~.inv` calls.
+            Defaults to ``in_keys_inv`` if not provided.
""" invertible = True @classmethod def __new__(cls, *args, **kwargs): - cls._unsqueeze_dim = None + cls._dim = None return super().__new__(cls) def __init__( self, dim: int = None, + *, allow_positive_dim: bool = False, in_keys: Sequence[NestedKey] | None = None, out_keys: Sequence[NestedKey] | None = None, in_keys_inv: Sequence[NestedKey] | None = None, out_keys_inv: Sequence[NestedKey] | None = None, - **kwargs, ): - if "unsqueeze_dim" in kwargs: - warnings.warn( - "The `unsqueeze_dim` kwarg will be removed in v0.6. Please use `dim` instead." - ) - dim = kwargs["unsqueeze_dim"] - elif dim is None: - raise TypeError("dim must be provided.") if in_keys is None: in_keys = [] # default if out_keys is None: @@ -2186,22 +2192,26 @@ def __init__( raise RuntimeError( "dim should be smaller than 0 to accommodate for " "envs of different batch_sizes. Turn allow_positive_dim to accommodate " - "for positive unsqueeze_dim." + "for positive dim." ) self._dim = dim @property def unsqueeze_dim(self): + return self.dim + + @property + def dim(self): if self._dim >= 0 and self.parent is not None: return len(self.parent.batch_size) + self._dim return self._dim def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor: - observation = observation.unsqueeze(self.unsqueeze_dim) + observation = observation.unsqueeze(self.dim) return observation def _inv_apply_transform(self, observation: torch.Tensor) -> torch.Tensor: - observation = observation.squeeze(self.unsqueeze_dim) + observation = observation.squeeze(self.dim) return observation def _transform_spec(self, spec: TensorSpec): @@ -2248,7 +2258,7 @@ def _reset( def __repr__(self) -> str: s = ( - f"{self.__class__.__name__}(unsqueeze_dim={self.unsqueeze_dim}, in_keys={self.in_keys}, out_keys={self.out_keys}," + f"{self.__class__.__name__}(dim={self.dim}, in_keys={self.in_keys}, out_keys={self.out_keys}," f" in_keys_inv={self.in_keys_inv}, out_keys_inv={self.out_keys_inv})" ) return s @@ -2258,14 +2268,14 @@ class SqueezeTransform(UnsqueezeTransform): """Removes a dimension of size one at the specified position. Args: - squeeze_dim (int): dimension to squeeze. + dim (int): dimension to squeeze. """ invertible = True def __init__( self, - squeeze_dim: int, + dim: int | None = None, *args, in_keys: Optional[Sequence[str]] = None, out_keys: Optional[Sequence[str]] = None, @@ -2273,8 +2283,19 @@ def __init__( out_keys_inv: Optional[Sequence[str]] = None, **kwargs, ): + if dim is None: + if "squeeze_dim" in kwargs: + warnings.warn( + f"squeeze_dim will be deprecated in favor of dim arg in {type(self).__name__}." + ) + dim = kwargs.pop("squeeze_dim") + else: + raise TypeError( + f"dim must be passed to {type(self).__name__} constructor." + ) + super().__init__( - squeeze_dim, + dim, *args, in_keys=in_keys, out_keys=out_keys, @@ -2285,7 +2306,7 @@ def __init__( @property def squeeze_dim(self): - return super().unsqueeze_dim + return super().dim _apply_transform = UnsqueezeTransform._inv_apply_transform _inv_apply_transform = UnsqueezeTransform._apply_transform @@ -2505,7 +2526,7 @@ class ObservationNorm(ObservationTransform): loc (number or tensor): location of the affine transform scale (number or tensor): scale of the affine transform in_keys (sequence of NestedKey, optional): entries to be normalized. Defaults to ["observation", "pixels"]. - All entries will be normalized with the same values: if a different behaviour is desired + All entries will be normalized with the same values: if a different behavior is desired (e.g. 
a different normalization for pixels and states) different :obj:`ObservationNorm` objects should be used. out_keys (sequence of NestedKey, optional): output entries. Defaults to the value of `in_keys`. @@ -2569,7 +2590,7 @@ def __init__( ): if in_keys is None: raise RuntimeError( - "Not passing in_keys to ObservationNorm is a deprecated behaviour." + "Not passing in_keys to ObservationNorm is a deprecated behavior." ) if out_keys is None: @@ -3000,7 +3021,7 @@ def _inv_call(self, tensordict: TensorDictBase) -> torch.Tensor: def _call(self, tensordict: TensorDictBase, _reset=None) -> TensorDictBase: """Update the episode tensordict with max pooled keys.""" _just_reset = _reset is not None - for in_key, out_key in zip(self.in_keys, self.out_keys): + for in_key, out_key in _zip_strict(self.in_keys, self.out_keys): # Lazy init of buffers buffer_name = f"_cat_buffers_{in_key}" data = tensordict.get(in_key) @@ -3082,6 +3103,31 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: else: return self.unfolding(tensordict) + def _apply_same_padding(self, dim, data, done_mask): + d = data.ndim + dim - 1 + res = data.clone() + num_repeats_per_sample = done_mask.sum(dim=-1) + + if num_repeats_per_sample.dim() > 2: + extra_dims = num_repeats_per_sample.dim() - 2 + num_repeats_per_sample = num_repeats_per_sample.flatten(0, extra_dims) + res_flat_series = res.flatten(0, extra_dims) + else: + extra_dims = 0 + res_flat_series = res + + if d - 1 > extra_dims: + res_flat_series_flat_batch = res_flat_series.flatten(1, d - 1) + else: + res_flat_series_flat_batch = res_flat_series[:, None] + + for sample_idx, num_repeats in enumerate(num_repeats_per_sample): + if num_repeats > 0: + res_slice = res_flat_series_flat_batch[sample_idx] + res_slice[:, :num_repeats] = res_slice[:, num_repeats : num_repeats + 1] + + return res + @set_lazy_legacy(False) def unfolding(self, tensordict: TensorDictBase) -> TensorDictBase: # it is assumed that the last dimension of the tensordict is the time dimension @@ -3110,12 +3156,12 @@ def unfolding(self, tensordict: TensorDictBase) -> TensorDictBase: # first sort the in_keys with strings and non-strings keys = [ (in_key, out_key) - for in_key, out_key in zip(self.in_keys, self.out_keys) + for in_key, out_key in _zip_strict(self.in_keys, self.out_keys) if isinstance(in_key, str) ] keys += [ (in_key, out_key) - for in_key, out_key in zip(self.in_keys, self.out_keys) + for in_key, out_key in _zip_strict(self.in_keys, self.out_keys) if not isinstance(in_key, str) ] @@ -3151,7 +3197,7 @@ def unfold_done(done, N): first_val = None if isinstance(in_key, tuple) and in_key[0] == "next": # let's get the out_key we have already processed - prev_out_key = dict(zip(self.in_keys, self.out_keys)).get( + prev_out_key = dict(_zip_strict(self.in_keys, self.out_keys)).get( in_key[1], None ) if prev_out_key is not None: @@ -3192,24 +3238,7 @@ def unfold_done(done, N): if self.padding != "same": data = torch.where(done_mask_expand, self.padding_value, data) else: - # TODO: This is a pretty bad implementation, could be - # made more efficient but it works! 
-                    reset_any = reset.any(-1, False)
-                    reset_vals = list(data_orig[reset_any].unbind(0))
-                    j_ = float("inf")
-                    reps = []
-                    d = data.ndim + self.dim - 1
-                    n_feat = data.shape[data.ndim + self.dim :].numel()
-                    for j in done_mask_expand.flatten(d, -1).sum(-1).view(-1) // n_feat:
-                        if j > j_:
-                            reset_vals = reset_vals[1:]
-                        reps.extend([reset_vals[0]] * int(j))
-                        j_ = j
-                    if reps:
-                        reps = torch.stack(reps)
-                        data = torch.masked_scatter(
-                            data, done_mask_expand, reps.reshape(-1)
-                        )
+                    data = self._apply_same_padding(self.dim, data, done_mask)

             if first_val is not None:
                 # Aggregate reset along last dim
@@ -3321,7 +3350,7 @@ def _apply_transform(self, reward: torch.Tensor) -> torch.Tensor:

     @_apply_to_composite
     def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec:
-        if isinstance(reward_spec, UnboundedContinuousTensorSpec):
+        if isinstance(reward_spec, Unbounded):
             return reward_spec
         else:
             raise NotImplementedError(
@@ -3361,7 +3390,7 @@ class DTypeCastTransform(Transform):
     """Casts one dtype to another for selected keys.

     Depending on whether the ``in_keys`` or ``in_keys_inv`` are provided
-    during construction, the class behaviour will change:
+    during construction, the class behavior will change:

       * If the keys are provided, those entries and those entries only will be
         transformed from ``dtype_in`` to ``dtype_out`` entries;
@@ -3417,17 +3446,17 @@ class DTypeCastTransform(Transform):
         >>> print(td.get("not_transformed").dtype)
         torch.float32

-    The same behaviour is the rule when environments are constructedw without
+    The same behavior applies when environments are constructed without
     specifying the transform keys:

     Examples:
         >>> class MyEnv(EnvBase):
         ...     def __init__(self):
         ...         super().__init__()
-        ...         self.observation_spec = CompositeSpec(obs=UnboundedContinuousTensorSpec((), dtype=torch.float64))
-        ...         self.action_spec = UnboundedContinuousTensorSpec((), dtype=torch.float64)
-        ...         self.reward_spec = UnboundedContinuousTensorSpec((1,), dtype=torch.float64)
-        ...         self.done_spec = UnboundedContinuousTensorSpec((1,), dtype=torch.bool)
+        ...         self.observation_spec = Composite(obs=Unbounded((), dtype=torch.float64))
+        ...         self.action_spec = Unbounded((), dtype=torch.float64)
+        ...         self.reward_spec = Unbounded((1,), dtype=torch.float64)
+        ...         self.done_spec = Unbounded((1,), dtype=torch.bool)
         ...     def _reset(self, data=None):
         ...         return TensorDict({"done": torch.zeros((1,), dtype=torch.bool), **self.observation_spec.rand()}, [])
         ...     def _step(self, data):
@@ -3601,7 +3630,7 @@ def func(name, item):
             return tensordict
         else:
             # we made sure that if in_keys is not None, out_keys is not None either
-            for in_key, out_key in zip(in_keys, out_keys):
+            for in_key, out_key in _zip_strict(in_keys, out_keys):
                 item = self._apply_transform(tensordict.get(in_key))
                 tensordict.set(out_key, item)
             return tensordict
@@ -3640,7 +3669,7 @@ def _inv_apply_transform(self, state: torch.Tensor) -> torch.Tensor:
         return state.to(self.dtype_in)

     def _transform_spec(self, spec: TensorSpec) -> None:
-        if isinstance(spec, CompositeSpec):
+        if isinstance(spec, Composite):
             for key in spec:
                 self._transform_spec(spec[key])
         else:
@@ -3660,7 +3689,7 @@ def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec:
             raise NotImplementedError(
                 f"Calling transform_input_spec without a parent environment isn't supported yet for {type(self)}."
) - for in_key_inv, out_key_inv in zip(self.in_keys_inv, self.out_keys_inv): + for in_key_inv, out_key_inv in _zip_strict(self.in_keys_inv, self.out_keys_inv): if in_key_inv in full_action_spec.keys(True): _spec = full_action_spec[in_key_inv] target = "action" @@ -3685,7 +3714,7 @@ def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec: raise RuntimeError return input_spec - def transform_output_spec(self, output_spec: CompositeSpec) -> CompositeSpec: + def transform_output_spec(self, output_spec: Composite) -> Composite: if self.in_keys is None: raise NotImplementedError( f"Calling transform_reward_spec without a parent environment isn't supported yet for {type(self)}." ) @@ -3694,7 +3723,7 @@ def transform_output_spec(self, output_spec: CompositeSpec) -> CompositeSpec: full_observation_spec = output_spec["full_observation_spec"] for reward_key, reward_spec in list(full_reward_spec.items(True, True)): # find out_key that match the in_key - for in_key, out_key in zip(self.in_keys, self.out_keys): + for in_key, out_key in _zip_strict(self.in_keys, self.out_keys): if reward_key == in_key: if reward_spec.dtype != self.dtype_in: raise TypeError(f"reward_spec.dtype is not {self.dtype_in}") @@ -3710,7 +3739,7 @@ def transform_observation_spec(self, observation_spec): full_observation_spec.items(True, True) ): # find out_key that match the in_key - for in_key, out_key in zip(self.in_keys, self.out_keys): + for in_key, out_key in _zip_strict(self.in_keys, self.out_keys): if observation_key == in_key: if observation_spec.dtype != self.dtype_in: raise TypeError( @@ -3733,7 +3762,7 @@ class DoubleToFloat(DTypeCastTransform): """Casts one dtype to another for selected keys. Depending on whether the ``in_keys`` or ``in_keys_inv`` are provided - during construction, the class behaviour will change: + during construction, the class behavior will change: * If the keys are provided, those entries and those entries only will be transformed from ``float64`` to ``float32`` entries; @@ -3787,17 +3816,17 @@ class DoubleToFloat(DTypeCastTransform): >>> print(td.get("not_transformed").dtype) torch.float32 - The same behaviour is the rule when environments are constructedw without + The same behavior is the rule when environments are constructed without specifying the transform keys: Examples: >>> class MyEnv(EnvBase): ... def __init__(self): ... super().__init__() - ... self.observation_spec = CompositeSpec(obs=UnboundedContinuousTensorSpec((), dtype=torch.float64)) - ... self.action_spec = UnboundedContinuousTensorSpec((), dtype=torch.float64) - ... self.reward_spec = UnboundedContinuousTensorSpec((1,), dtype=torch.float64) - ... self.done_spec = UnboundedContinuousTensorSpec((1,), dtype=torch.bool) + ... self.observation_spec = Composite(obs=Unbounded((), dtype=torch.float64)) + ... self.action_spec = Unbounded((), dtype=torch.float64) + ... self.reward_spec = Unbounded((1,), dtype=torch.float64) + ... self.done_spec = Unbounded((1,), dtype=torch.bool) ... def _reset(self, data=None): ... return TensorDict({"done": torch.zeros((1,), dtype=torch.bool), **self.observation_spec.rand()}, []) ... def _step(self, data): @@ -3864,6 +3893,19 @@ class DeviceCastTransform(Transform): a parent environment exists, it is retrieved from it. In all other cases, it remains unspecified. + Keyword Args: + in_keys (list of NestedKey): the list of entries to map to a different device. + Defaults to ``None``. + out_keys (list of NestedKey): the output names of the entries mapped onto a device.
+ Defaults to the values of ``in_keys``. + in_keys_inv (list of NestedKey): the list of entries to map to a different device. + ``in_keys_inv`` are the names expected by the base environment. + Defaults to ``None``. + out_keys_inv (list of NestedKey): the output names of the entries mapped onto a device. + ``out_keys_inv`` are the names of the keys as seen from outside the transformed env. + Defaults to the values of ``in_keys_inv``. + + Examples: >>> td = TensorDict( ... {'obs': torch.ones(1, dtype=torch.double), @@ -3891,6 +3933,10 @@ def __init__( self.orig_device = ( torch.device(orig_device) if orig_device is not None else orig_device ) + if out_keys is None: + out_keys = copy(in_keys) + if out_keys_inv is None: + out_keys_inv = copy(in_keys_inv) super().__init__( in_keys=in_keys, out_keys=out_keys, @@ -3943,7 +3989,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: return result tensordict_t = tensordict.named_apply(self._to, nested_keys=True, device=None) if self._rename_keys: - for in_key, out_key in zip(self.in_keys, self.out_keys): + for in_key, out_key in _zip_strict(self.in_keys, self.out_keys): if out_key != in_key: tensordict_t.rename_key_(in_key, out_key) tensordict_t.set(in_key, tensordict.get(in_key)) @@ -3957,7 +4003,7 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase: return result tensordict_t = tensordict.named_apply(self._to, nested_keys=True, device=None) if self._rename_keys: - for in_key, out_key in zip(self.in_keys, self.out_keys): + for in_key, out_key in _zip_strict(self.in_keys, self.out_keys): if out_key != in_key: tensordict_t.rename_key_(in_key, out_key) tensordict_t.set(in_key, tensordict.get(in_key)) @@ -3985,7 +4031,7 @@ def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase: device=None, ) if self._rename_keys_inv: - for in_key, out_key in zip(self.in_keys_inv, self.out_keys_inv): + for in_key, out_key in _zip_strict(self.in_keys_inv, self.out_keys_inv): if out_key != in_key: tensordict_t.rename_key_(in_key, out_key) tensordict_t.set(in_key, tensordict.get(in_key)) @@ -4010,58 +4056,58 @@ def _sync_orig_device(self): return self._sync_orig_device return sync_func - def transform_input_spec(self, input_spec: CompositeSpec) -> CompositeSpec: + def transform_input_spec(self, input_spec: Composite) -> Composite: if self._map_env_device: return input_spec.to(self.device) else: + input_spec.clear_device_() return super().transform_input_spec(input_spec) - def transform_action_spec(self, full_action_spec: CompositeSpec) -> CompositeSpec: + def transform_action_spec(self, full_action_spec: Composite) -> Composite: full_action_spec = full_action_spec.clear_device_() - for in_key, out_key in zip(self.in_keys_inv, self.out_keys_inv): - if in_key not in full_action_spec.keys(True, True): - continue - full_action_spec[out_key] = full_action_spec[in_key].to(self.device) + for in_key, out_key in _zip_strict(self.in_keys_inv, self.out_keys_inv): + local_action_spec = full_action_spec.get(in_key, None) + if local_action_spec is not None: + full_action_spec[out_key] = local_action_spec.to(self.device) return full_action_spec - def transform_state_spec(self, full_state_spec: CompositeSpec) -> CompositeSpec: + def transform_state_spec(self, full_state_spec: Composite) -> Composite: full_state_spec = full_state_spec.clear_device_() - for in_key, out_key in zip(self.in_keys_inv, self.out_keys_inv): - if in_key not in full_state_spec.keys(True, True): - continue - full_state_spec[out_key] = full_state_spec[in_key].to(self.device) + 
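Given the keyword arguments documented above and the new `out_keys`/`out_keys_inv` defaulting in `__init__`, key-wise device casting can be sketched as follows; the key names are illustrative and only the CPU device is assumed available:

import torch
from tensordict import TensorDict
from torchrl.envs.transforms import DeviceCastTransform

t = DeviceCastTransform(device="cpu", in_keys=["obs"], out_keys=["obs_cpu"])
td = TensorDict({"obs": torch.ones(3)}, [])
td = t(td)  # "obs_cpu" holds the tensor moved to the target device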
for in_key, out_key in _zip_strict(self.in_keys_inv, self.out_keys_inv): + local_state_spec = full_state_spec.get(in_key, None) + if local_state_spec is not None: + full_state_spec[out_key] = local_state_spec.to(self.device) return full_state_spec - def transform_output_spec(self, output_spec: CompositeSpec) -> CompositeSpec: + def transform_output_spec(self, output_spec: Composite) -> Composite: if self._map_env_device: return output_spec.to(self.device) else: + output_spec.clear_device_() return super().transform_output_spec(output_spec) - def transform_observation_spec( - self, observation_spec: CompositeSpec - ) -> CompositeSpec: + def transform_observation_spec(self, observation_spec: Composite) -> Composite: observation_spec = observation_spec.clear_device_() - for in_key, out_key in zip(self.in_keys, self.out_keys): - if in_key not in observation_spec.keys(True, True): - continue - observation_spec[out_key] = observation_spec[in_key].to(self.device) + for in_key, out_key in _zip_strict(self.in_keys, self.out_keys): + local_obs_spec = observation_spec.get(in_key, None) + if local_obs_spec is not None: + observation_spec[out_key] = local_obs_spec.to(self.device) return observation_spec - def transform_done_spec(self, full_done_spec: CompositeSpec) -> CompositeSpec: + def transform_done_spec(self, full_done_spec: Composite) -> Composite: full_done_spec = full_done_spec.clear_device_() - for in_key, out_key in zip(self.in_keys, self.out_keys): - if in_key not in full_done_spec.keys(True, True): - continue - full_done_spec[out_key] = full_done_spec[in_key].to(self.device) + for in_key, out_key in _zip_strict(self.in_keys, self.out_keys): + local_done_spec = full_done_spec.get(in_key, None) + if local_done_spec is not None: + full_done_spec[out_key] = local_done_spec.to(self.device) return full_done_spec - def transform_reward_spec(self, full_reward_spec: CompositeSpec) -> CompositeSpec: + def transform_reward_spec(self, full_reward_spec: Composite) -> Composite: full_reward_spec = full_reward_spec.clear_device_() - for in_key, out_key in zip(self.in_keys, self.out_keys): - if in_key not in full_reward_spec.keys(True, True): - continue - full_reward_spec[out_key] = full_reward_spec[in_key].to(self.device) + for in_key, out_key in _zip_strict(self.in_keys, self.out_keys): + local_reward_spec = full_reward_spec.get(in_key, None) + if local_reward_spec is not None: + full_reward_spec[out_key] = local_reward_spec.to(self.device) return full_reward_spec def transform_env_device(self, device): @@ -4092,7 +4138,7 @@ class CatTensors(Transform): Args: in_keys (sequence of NestedKey): keys to be concatenated. If `None` (or not provided) the keys will be retrieved from the parent environment the first time - the transform is used. This behaviour will only work if a parent is set. + the transform is used. This behavior will only work if a parent is set. out_key (NestedKey): key of the resulting tensor. dim (int, optional): dimension along which the concatenation will occur. Default is ``-1``. @@ -4215,13 +4261,13 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec self._initialized = True # check that all keys are in observation_spec - if len(self.in_keys) > 1 and not isinstance(observation_spec, CompositeSpec): + if len(self.in_keys) > 1 and not isinstance(observation_spec, Composite): raise ValueError( "CatTensor cannot infer the output observation spec as there are multiple input keys but " "only one observation_spec." 
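Two recurring changes in the hunks above: plain `zip` over `(in_keys, out_keys)` pairs becomes `_zip_strict`, and the membership-then-index pattern becomes a single `spec.get(in_key, None)` lookup. Assuming `_zip_strict` mirrors Python 3.10's `zip(..., strict=True)` (it is an internal `torchrl._utils` helper, so details may differ), its semantics can be sketched as:

def _zip_strict_sketch(*iterables):
    # Materialize so lengths can be checked before pairing.
    lists = [list(it) for it in iterables]
    if len({len(lst) for lst in lists}) > 1:
        raise ValueError("zip() arguments have mismatched lengths")
    return zip(*lists)

# Silent truncation with zip() would drop "b"; the strict variant raises.
list(_zip_strict_sketch(["x", "y"], ["a", "b"]))   # fine
# _zip_strict_sketch(["x"], ["a", "b"])            # raises ValueError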
) - if isinstance(observation_spec, CompositeSpec) and len( + if isinstance(observation_spec, Composite) and len( [key for key in self.in_keys if key not in observation_spec.keys(True)] ): raise ValueError( @@ -4229,7 +4275,7 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec "Make sure the environment has an observation_spec attribute that includes all the specs needed for CatTensor." ) - if not isinstance(observation_spec, CompositeSpec): + if not isinstance(observation_spec, Composite): # by def, there must be only one key return observation_spec @@ -4249,7 +4295,7 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec device = spec0.device shape[self.dim] = sum_shape shape = torch.Size(shape) - observation_spec[out_key] = UnboundedContinuousTensorSpec( + observation_spec[out_key] = Unbounded( shape=shape, dtype=spec0.dtype, device=device, @@ -4357,14 +4403,14 @@ def _inv_apply_transform(self, action: torch.Tensor) -> torch.Tensor: action = nn.functional.one_hot(action, self.num_actions_effective) return action - def transform_input_spec(self, input_spec: CompositeSpec): + def transform_input_spec(self, input_spec: Composite): input_spec = input_spec.clone() for key in input_spec["full_action_spec"].keys(True, True): key = ("full_action_spec", key) break else: raise KeyError("key not found in action_spec.") - input_spec[key] = OneHotDiscreteTensorSpec( + input_spec[key] = OneHot( self.max_actions, shape=(*input_spec[key].shape[:-1], self.max_actions), device=input_spec.device, @@ -4456,7 +4502,7 @@ def _reset( ) # Merge the two tensordicts tensordict = parent._reset_proc_data(tensordict.clone(False), tensordict_reset) - # check that there is a single done state -- behaviour is undefined for multiple dones + # check that there is a single done state -- behavior is undefined for multiple dones done_keys = parent.done_keys reward_key = parent.reward_key if parent.batch_size.numel() > 1: @@ -4526,9 +4572,9 @@ class TensorDictPrimer(Transform): tensordict with the desired features. Args: - primers (dict or CompositeSpec, optional): a dictionary containing + primers (dict or Composite, optional): a dictionary containing key-spec pairs which will be used to populate the input tensordict. - :class:`~torchrl.data.CompositeSpec` instances are supported too. + :class:`~torchrl.data.Composite` instances are supported too. random (bool, optional): if ``True``, the values will be drawn randomly from the TensorSpec domain (or a unit Gaussian if unbounded). Otherwise a fixed value will be assumed. Defaults to `False`. @@ -4557,7 +4603,7 @@ class TensorDictPrimer(Transform): >>> base_env = SerialEnv(2, lambda: GymEnv("Pendulum-v1")) >>> env = TransformedEnv(base_env) >>> # the env is batch-locked, so the leading dims of the spec must match those of the env - >>> env.append_transform(TensorDictPrimer(mykey=UnboundedContinuousTensorSpec([2, 3]))) + >>> env.append_transform(TensorDictPrimer(mykey=Unbounded([2, 3]))) >>> td = env.reset() >>> print(td) TensorDict( @@ -4591,14 +4637,14 @@ class TensorDictPrimer(Transform): .. note:: Some TorchRL modules rely on specific keys being present in the environment TensorDicts, like :class:`~torchrl.modules.models.LSTM` or :class:`~torchrl.modules.models.GRU`. 
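The CatTensors spec logic above sums the input shapes along `dim` into a single `Unbounded` entry; at the tensor level this corresponds to the following (key names illustrative):

import torch
from tensordict import TensorDict
from torchrl.envs.transforms import CatTensors

t = CatTensors(in_keys=["obs_a", "obs_b"], out_key="obs_cat", dim=-1)
td = t(TensorDict({"obs_a": torch.zeros(3), "obs_b": torch.ones(2)}, []))
assert td["obs_cat"].shape == torch.Size([5])  # 3 + 2 along dim=-1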
- To facilitate this process, the method :func:`~torchrl.models.utils.get_primers_from_module` + To facilitate this process, the method :func:`~torchrl.modules.utils.get_primers_from_module` automatically checks for required primer transforms in a module and its submodules and generates them. """ def __init__( self, - primers: dict | CompositeSpec = None, + primers: dict | Composite = None, random: bool | None = None, default_value: float | Callable @@ -4615,8 +4661,8 @@ def __init__( "as kwargs." ) kwargs = primers - if not isinstance(kwargs, CompositeSpec): - kwargs = CompositeSpec(kwargs) + if not isinstance(kwargs, Composite): + kwargs = Composite(kwargs) self.primers = kwargs if random and default_value: raise ValueError( @@ -4658,10 +4704,15 @@ def __init__( def reset_key(self): reset_key = self.__dict__.get("_reset_key", None) if reset_key is None: + if self.parent is None: + raise RuntimeError( + "Missing parent, cannot infer reset_key automatically." + ) reset_keys = self.parent.reset_keys if len(reset_keys) > 1: raise RuntimeError( - f"Got more than one reset key in env {self.container}, cannot infer which one to use. Consider providing the reset key in the {type(self)} constructor." + f"Got more than one reset key in env {self.container}, cannot infer which one to use. " + f"Consider providing the reset key in the {type(self)} constructor." ) reset_key = self._reset_key = reset_keys[0] return reset_key @@ -4698,12 +4749,10 @@ def to(self, *args, **kwargs): def _expand_shape(self, spec): return spec.expand((*self.parent.batch_size, *spec.shape)) - def transform_observation_spec( - self, observation_spec: CompositeSpec - ) -> CompositeSpec: - if not isinstance(observation_spec, CompositeSpec): + def transform_observation_spec(self, observation_spec: Composite) -> Composite: + if not isinstance(observation_spec, Composite): raise ValueError( - f"observation_spec was expected to be of type CompositeSpec. Got {type(observation_spec)} instead." + f"observation_spec was expected to be of type Composite. Got {type(observation_spec)} instead." 
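With the cross-reference above corrected to `torchrl.modules.utils`, the helper can generate and attach the primers a recurrent module expects. A hedged sketch; the module arguments and env choice below are illustrative, and a gymnasium install is assumed:

from torchrl.envs import GymEnv, TransformedEnv
from torchrl.modules import LSTMModule
from torchrl.modules.utils import get_primers_from_module

lstm = LSTMModule(input_size=3, hidden_size=8, in_key="observation", out_key="embed")
env = TransformedEnv(GymEnv("Pendulum-v1"))
env.append_transform(get_primers_from_module(lstm))  # adds hidden-state primers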
) if self.primers.shape != observation_spec.shape: @@ -4718,9 +4767,12 @@ def transform_observation_spec( return observation_spec def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec: - input_spec["full_state_spec"] = self.transform_observation_spec( - input_spec["full_state_spec"] - ) + new_state_spec = self.transform_observation_spec(input_spec["full_state_spec"]) + for action_key in list(input_spec["full_action_spec"].keys(True, True)): + if action_key in new_state_spec.keys(True, True): + input_spec["full_action_spec", action_key] = new_state_spec[action_key] + del new_state_spec[action_key] + input_spec["full_state_spec"] = new_state_spec return input_spec @property @@ -4866,7 +4918,7 @@ def __init__( ) random = state_dim is not None and action_dim is not None shape = tuple(shape) + tail_dim - primers = {"_eps_gSDE": UnboundedContinuousTensorSpec(shape=shape)} + primers = {"_eps_gSDE": Unbounded(shape=shape)} super().__init__(primers=primers, random=random, **kwargs) @@ -5010,7 +5062,7 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase: if self.lock is not None: self.lock.acquire() - for key, key_out in zip(self.in_keys, self.out_keys): + for key, key_out in _zip_strict(self.in_keys, self.out_keys): if key not in tensordict.keys(include_nested=True): # TODO: init missing rewards with this # for key_suffix in [_append_last(key, suffix) for suffix in ("_sum", "_ssq", "_count")]: @@ -5148,7 +5200,7 @@ def to_observation_norm(self) -> Union[Compose, ObservationNorm]: out = [] loc = self.loc scale = self.scale - for key, key_out in zip(self.in_keys, self.out_keys): + for key, key_out in _zip_strict(self.in_keys, self.out_keys): _out = ObservationNorm( loc=loc.get(key), scale=scale.get(key), @@ -5325,8 +5377,8 @@ def __setstate__(self, state: Dict[str, Any]): @_apply_to_composite def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: - if isinstance(observation_spec, BoundedTensorSpec): - return UnboundedContinuousTensorSpec( + if isinstance(observation_spec, Bounded): + return Unbounded( shape=observation_spec.shape, dtype=observation_spec.dtype, device=observation_spec.device, @@ -5464,10 +5516,12 @@ def reset_keys(self): # We take the filtered reset keys, which are the only keys that really # matter when calling reset, and check that they match the in_keys root. 
reset_keys = parent._filtered_reset_keys + if len(reset_keys) == 1: + reset_keys = list(reset_keys) * len(self.in_keys) def _check_match(reset_keys, in_keys): # if this is called, the length of reset_keys and in_keys must match - for reset_key, in_key in zip(reset_keys, in_keys): + for reset_key, in_key in _zip_strict(reset_keys, in_keys): # having _reset at the root and the reward_key ("agent", "reward") is allowed # but having ("agent", "_reset") and "reward" isn't if isinstance(reset_key, tuple) and isinstance(in_key, str): @@ -5511,7 +5565,7 @@ def _reset( self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase ) -> TensorDictBase: """Resets episode rewards.""" - for in_key, reset_key, out_key in zip( + for in_key, reset_key, out_key in _zip_strict( self.in_keys, self.reset_keys, self.out_keys ): _reset = _get_reset(reset_key, tensordict) @@ -5528,7 +5582,7 @@ def _step( ) -> TensorDictBase: """Updates the episode rewards with the step rewards.""" # Update episode rewards - for in_key, out_key in zip(self.in_keys, self.out_keys): + for in_key, out_key in _zip_strict(self.in_keys, self.out_keys): if in_key in next_tensordict.keys(include_nested=True): reward = next_tensordict.get(in_key) prev_reward = tensordict.get(out_key, 0.0) @@ -5540,17 +5594,17 @@ def _step( def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec: state_spec = input_spec["full_state_spec"] if state_spec is None: - state_spec = CompositeSpec(shape=input_spec.shape, device=input_spec.device) + state_spec = Composite(shape=input_spec.shape, device=input_spec.device) state_spec.update(self._generate_episode_reward_spec()) input_spec["full_state_spec"] = state_spec return input_spec - def _generate_episode_reward_spec(self) -> CompositeSpec: - episode_reward_spec = CompositeSpec() + def _generate_episode_reward_spec(self) -> Composite: + episode_reward_spec = Composite() reward_spec = self.parent.full_reward_spec reward_spec_keys = self.parent.reward_keys # Define episode specs for all out_keys - for in_key, out_key in zip(self.in_keys, self.out_keys): + for in_key, out_key in _zip_strict(self.in_keys, self.out_keys): if ( in_key in reward_spec_keys ): # if this out_key has a corresponding key in reward_spec @@ -5559,7 +5613,7 @@ def _generate_episode_reward_spec(self) -> CompositeSpec: temp_rew_spec = reward_spec for sub_key in out_key[:-1]: if ( - not isinstance(temp_rew_spec, CompositeSpec) + not isinstance(temp_rew_spec, Composite) or sub_key not in temp_rew_spec.keys() ): break @@ -5580,8 +5634,8 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec """Transforms the observation spec, adding the new keys generated by RewardSum.""" if self.reward_spec: return observation_spec - if not isinstance(observation_spec, CompositeSpec): - observation_spec = CompositeSpec( + if not isinstance(observation_spec, Composite): + observation_spec = Composite( observation=observation_spec, shape=self.parent.batch_size ) observation_spec.update(self._generate_episode_reward_spec()) @@ -5600,8 +5654,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: "At least one dimension of the tensordict must be named 'time' in offline mode" ) time_dim = time_dim[0] - 1 - for in_key, out_key in zip(self.in_keys, self.out_keys): - reward = tensordict.get(in_key) + for in_key, out_key in _zip_strict(self.in_keys, self.out_keys): + reward = tensordict[in_key] cumsum = reward.cumsum(time_dim) tensordict.set(out_key, cumsum) return tensordict @@ -5778,7 +5832,13 @@ def _reset( 
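The `reset_keys` property above now broadcasts a single reset key across all `in_keys` before the strict length check, and `_step` accumulates each step reward into its matching out-key. End to end, under the default key configuration (gymnasium assumed installed):

from torchrl.envs import GymEnv, TransformedEnv
from torchrl.envs.transforms import RewardSum

env = TransformedEnv(GymEnv("Pendulum-v1"), RewardSum())
td = env.rollout(3)
# td["next", "episode_reward"] carries the running sum of td["next", "reward"],
# reset to zero whenever the matching reset key fires.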
self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase ) -> TensorDictBase: # get reset signal - for step_count_key, truncated_key, terminated_key, reset_key, done_key in zip( + for ( + step_count_key, + truncated_key, + terminated_key, + reset_key, + done_key, + ) in _zip_strict( self.step_count_keys, self.truncated_keys, self.terminated_keys, @@ -5819,10 +5879,8 @@ def _reset( def _step( self, tensordict: TensorDictBase, next_tensordict: TensorDictBase ) -> TensorDictBase: - for step_count_key, truncated_key, done_key in zip( - self.step_count_keys, - self.truncated_keys, - self.done_keys, + for step_count_key, truncated_key, done_key in _zip_strict( + self.step_count_keys, self.truncated_keys, self.done_keys ): step_count = tensordict.get(step_count_key) next_step_count = step_count + 1 @@ -5844,12 +5902,10 @@ def _step( next_tensordict.set(truncated_key, truncated) return next_tensordict - def transform_observation_spec( - self, observation_spec: CompositeSpec - ) -> CompositeSpec: - if not isinstance(observation_spec, CompositeSpec): + def transform_observation_spec(self, observation_spec: Composite) -> Composite: + if not isinstance(observation_spec, Composite): raise ValueError( - f"observation_spec was expected to be of type CompositeSpec. Got {type(observation_spec)} instead." + f"observation_spec was expected to be of type Composite. Got {type(observation_spec)} instead." ) full_done_spec = self.parent.output_spec["full_done_spec"] for step_count_key in self.step_count_keys: @@ -5871,7 +5927,7 @@ def transform_observation_spec( raise KeyError( f"Could not find root of step_count_key {step_count_key} in done keys {self.done_keys}." ) - observation_spec[step_count_key] = BoundedTensorSpec( + observation_spec[step_count_key] = Bounded( shape=shape, dtype=torch.int64, device=observation_spec.device, @@ -5880,7 +5936,7 @@ def transform_observation_spec( ) return super().transform_observation_spec(observation_spec) - def transform_output_spec(self, output_spec: CompositeSpec) -> CompositeSpec: + def transform_output_spec(self, output_spec: Composite) -> Composite: if self.max_steps: full_done_spec = self.parent.output_spec["full_done_spec"] for truncated_key in self.truncated_keys: @@ -5902,7 +5958,7 @@ def transform_output_spec(self, output_spec: CompositeSpec) -> CompositeSpec: raise KeyError( f"Could not find root of truncated_key {truncated_key} in done keys {self.done_keys}." ) - full_done_spec[truncated_key] = DiscreteTensorSpec( + full_done_spec[truncated_key] = Categorical( 2, dtype=torch.bool, device=output_spec.device, shape=shape ) if self.update_done: @@ -5925,19 +5981,19 @@ def transform_output_spec(self, output_spec: CompositeSpec) -> CompositeSpec: raise KeyError( f"Could not find root of stop_key {done_key} in done keys {self.done_keys}." ) - full_done_spec[done_key] = DiscreteTensorSpec( + full_done_spec[done_key] = Categorical( 2, dtype=torch.bool, device=output_spec.device, shape=shape ) output_spec["full_done_spec"] = full_done_spec return super().transform_output_spec(output_spec) - def transform_input_spec(self, input_spec: CompositeSpec) -> CompositeSpec: - if not isinstance(input_spec, CompositeSpec): + def transform_input_spec(self, input_spec: Composite) -> Composite: + if not isinstance(input_spec, Composite): raise ValueError( - f"input_spec was expected to be of type CompositeSpec. Got {type(input_spec)} instead." + f"input_spec was expected to be of type Composite. Got {type(input_spec)} instead." 
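The `_step` logic above increments `step_count` and raises `truncated` once `max_steps` is reached; end to end (gymnasium assumed installed):

from torchrl.envs import GymEnv, TransformedEnv
from torchrl.envs.transforms import StepCounter

env = TransformedEnv(GymEnv("CartPole-v1"), StepCounter(max_steps=5))
td = env.rollout(100)     # stops early: "truncated" is set on the fifth step
assert td.shape[-1] <= 5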
) if input_spec["full_state_spec"] is None: - input_spec["full_state_spec"] = CompositeSpec( + input_spec["full_state_spec"] = Composite( shape=input_spec.shape, device=input_spec.device ) @@ -5962,9 +6018,7 @@ def transform_input_spec(self, input_spec: CompositeSpec) -> CompositeSpec: f"Could not find root of step_count_key {step_count_key} in done keys {self.done_keys}." ) - input_spec[ - unravel_key(("full_state_spec", step_count_key)) - ] = BoundedTensorSpec( + input_spec[unravel_key(("full_state_spec", step_count_key))] = Bounded( shape=shape, dtype=torch.int64, device=input_spec.device, @@ -6051,7 +6105,7 @@ def _reset( return tensordict_reset.exclude(*self.excluded_keys) return tensordict - def transform_output_spec(self, output_spec: CompositeSpec) -> CompositeSpec: + def transform_output_spec(self, output_spec: Composite) -> Composite: if not self.inverse: full_done_spec = output_spec["full_done_spec"] full_reward_spec = output_spec["full_reward_spec"] @@ -6171,7 +6225,7 @@ def _reset( *self.selected_keys, *reward_keys, *done_keys, *input_keys, strict=False ) - def transform_output_spec(self, output_spec: CompositeSpec) -> CompositeSpec: + def transform_output_spec(self, output_spec: Composite) -> Composite: full_done_spec = output_spec["full_done_spec"] full_reward_spec = output_spec["full_reward_spec"] full_observation_spec = output_spec["full_observation_spec"] @@ -6325,7 +6379,7 @@ def _make_missing_buffer(self, tensordict, in_key, buffer_name): def _call(self, tensordict: TensorDictBase, _reset=None) -> TensorDictBase: """Update the episode tensordict with max pooled keys.""" - for in_key, out_key in zip(self.in_keys, self.out_keys): + for in_key, out_key in _zip_strict(self.in_keys, self.out_keys): # Lazy init of buffers buffer_name = self._buffer_name(in_key) buffer = getattr(self, buffer_name) @@ -6381,7 +6435,7 @@ class RandomCropTensorDict(Transform): This transform is primarily designed to be used with replay buffers and modules. Currently, it cannot be used as an environment transform. - Do not hesitate to request for this behaviour through an issue if this is + Do not hesitate to request for this behavior through an issue if this is desired. Args: @@ -6409,7 +6463,7 @@ def __init__( if sample_dim > 0: warnings.warn( "A positive shape has been passed to the RandomCropTensorDict " - "constructor. This may have unexpected behaviours when the " + "constructor. This may have unexpected behaviors when the " "passed tensordicts have inconsistent batch dimensions. " "For context, by convention, TorchRL concatenates time steps " "along the last dimension of the tensordict." @@ -6566,7 +6620,7 @@ def _reset( device = tensordict.device if device is None: device = torch.device("cpu") - for reset_key, init_key in zip(self.reset_keys, self.init_keys): + for reset_key, init_key in _zip_strict(self.reset_keys, self.init_keys): _reset = tensordict.get(reset_key, None) if _reset is None: done_key = _replace_last(init_key, "done") @@ -6610,7 +6664,7 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec raise KeyError( f"Could not find root of init_key {init_key} within done_keys {self.parent.done_keys}." ) - observation_spec[init_key] = DiscreteTensorSpec( + observation_spec[init_key] = Categorical( 2, dtype=torch.bool, device=self.parent.device, @@ -6625,15 +6679,15 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: class RenameTransform(Transform): - """A transform to rename entries in the output tensordict. 
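For context on the two transforms patched above: `ExcludeTransform` drops the listed entries, while `SelectTransform` keeps only the listed ones plus, by default, reward/done/input entries, which is why its `_reset` re-selects those with `strict=False`. A runnable sketch of the latter (gymnasium assumed installed):

from torchrl.envs import GymEnv, TransformedEnv
from torchrl.envs.transforms import SelectTransform

env = TransformedEnv(GymEnv("Pendulum-v1"), SelectTransform("observation"))
td = env.rollout(2)  # rewards and dones are kept by default alongside "observation"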
+ """A transform to rename entries in the output tensordict (or input tensordict via the inverse keys). Args: - in_keys (sequence of NestedKey): the entries to rename + in_keys (sequence of NestedKey): the entries to rename. out_keys (sequence of NestedKey): the name of the entries after renaming. - in_keys_inv (sequence of NestedKey, optional): the entries to rename before - passing the input tensordict to :meth:`EnvBase._step`. - out_keys_inv (sequence of NestedKey, optional): the names of the renamed - entries passed to :meth:`EnvBase._step`. + in_keys_inv (sequence of NestedKey, optional): the entries to rename + in the input tensordict, which will be passed to :meth:`EnvBase._step`. + out_keys_inv (sequence of NestedKey, optional): the names of the entries + in the input tensordict after renaming. create_copy (bool, optional): if ``True``, the entries will be copied with a different name rather than being renamed. This allows for renaming immutable entries such as ``"reward"`` and ``"done"``. @@ -6702,15 +6756,15 @@ def __init__( def _call(self, tensordict: TensorDictBase) -> TensorDictBase: if self.create_copy: out = tensordict.select(*self.in_keys, strict=not self._missing_tolerance) - for in_key, out_key in zip(self.in_keys, self.out_keys): + for in_key, out_key in _zip_strict(self.in_keys, self.out_keys): try: - tensordict.rename_key_(in_key, out_key) + out.rename_key_(in_key, out_key) except KeyError: if not self._missing_tolerance: raise tensordict = tensordict.update(out) else: - for in_key, out_key in zip(self.in_keys, self.out_keys): + for in_key, out_key in _zip_strict(self.in_keys, self.out_keys): try: tensordict.rename_key_(in_key, out_key) except KeyError: @@ -6732,7 +6786,7 @@ def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase: out = tensordict.select( *self.out_keys_inv, strict=not self._missing_tolerance ) - for in_key, out_key in zip(self.in_keys_inv, self.out_keys_inv): + for in_key, out_key in _zip_strict(self.in_keys_inv, self.out_keys_inv): try: out.rename_key_(out_key, in_key) except KeyError: @@ -6741,7 +6795,7 @@ def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase: tensordict = tensordict.update(out) else: - for in_key, out_key in zip(self.in_keys_inv, self.out_keys_inv): + for in_key, out_key in _zip_strict(self.in_keys_inv, self.out_keys_inv): try: tensordict.rename_key_(out_key, in_key) except KeyError: @@ -6749,7 +6803,7 @@ def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase: raise return tensordict - def transform_output_spec(self, output_spec: CompositeSpec) -> CompositeSpec: + def transform_output_spec(self, output_spec: Composite) -> Composite: for done_key in self.parent.done_keys: if done_key in self.in_keys: for i, out_key in enumerate(self.out_keys): # noqa: B007 @@ -6791,11 +6845,11 @@ def transform_output_spec(self, output_spec: CompositeSpec) -> CompositeSpec: del output_spec["full_observation_spec"][observation_key] return output_spec - def transform_input_spec(self, input_spec: CompositeSpec) -> CompositeSpec: + def transform_input_spec(self, input_spec: Composite) -> Composite: for action_key in self.parent.action_keys: - if action_key in self.in_keys: - for i, out_key in enumerate(self.out_keys): # noqa: B007 - if self.in_keys[i] == action_key: + if action_key in self.in_keys_inv: + for i, out_key in enumerate(self.out_keys_inv): # noqa: B007 + if self.in_keys_inv[i] == action_key: break else: # unreachable @@ -6806,9 +6860,9 @@ def transform_input_spec(self, input_spec: CompositeSpec) -> 
CompositeSpec: if not self.create_copy: del input_spec["full_action_spec"][action_key] for state_key in self.parent.full_state_spec.keys(True): - if state_key in self.in_keys: - for i, out_key in enumerate(self.out_keys): # noqa: B007 - if self.in_keys[i] == state_key: + if state_key in self.in_keys_inv: + for i, out_key in enumerate(self.out_keys_inv): # noqa: B007 + if self.in_keys_inv[i] == state_key: break else: # unreachable @@ -6962,7 +7016,7 @@ def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase: "No episode ends found to calculate the reward to go. Make sure that the number of frames_per_batch is larger than number of steps per episode." ) found = False - for in_key, out_key in zip(self.in_keys_inv, self.out_keys_inv): + for in_key, out_key in _zip_strict(self.in_keys_inv, self.out_keys_inv): if in_key in tensordict.keys(include_nested=True): found = True item = self._inv_apply_transform(tensordict.get(in_key), done) @@ -7003,16 +7057,16 @@ class ActionMask(Transform): Examples: >>> import torch - >>> from torchrl.data.tensor_specs import DiscreteTensorSpec, BinaryDiscreteTensorSpec, UnboundedContinuousTensorSpec, CompositeSpec + >>> from torchrl.data.tensor_specs import Categorical, Binary, Unbounded, Composite >>> from torchrl.envs.transforms import ActionMask, TransformedEnv >>> from torchrl.envs.common import EnvBase >>> class MaskedEnv(EnvBase): ... def __init__(self, *args, **kwargs): ... super().__init__(*args, **kwargs) - ... self.action_spec = DiscreteTensorSpec(4) - ... self.state_spec = CompositeSpec(action_mask=BinaryDiscreteTensorSpec(4, dtype=torch.bool)) - ... self.observation_spec = CompositeSpec(obs=UnboundedContinuousTensorSpec(3)) - ... self.reward_spec = UnboundedContinuousTensorSpec(1) + ... self.action_spec = Categorical(4) + ... self.state_spec = Composite(action_mask=Binary(4, dtype=torch.bool)) + ... self.observation_spec = Composite(obs=Unbounded(3)) + ... self.reward_spec = Unbounded(1) ... ... def _reset(self, tensordict=None): ... td = self.observation_spec.rand() @@ -7048,10 +7102,10 @@ class ActionMask(Transform): """ ACCEPTED_SPECS = ( - OneHotDiscreteTensorSpec, - DiscreteTensorSpec, - MultiOneHotDiscreteTensorSpec, - MultiDiscreteTensorSpec, + OneHot, + Categorical, + MultiOneHot, + MultiCategorical, ) SPEC_TYPE_ERROR = "The action spec must be one of {}. Got {} instead." @@ -7477,7 +7531,7 @@ def _inv_apply_transform(self, state: torch.Tensor) -> torch.Tensor: @_apply_to_composite def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: - return BoundedTensorSpec( + return Bounded( shape=observation_spec.shape, device=observation_spec.device, dtype=observation_spec.dtype, @@ -7489,7 +7543,7 @@ def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec: for key in self.in_keys: if key in self.parent.reward_keys: spec = self.parent.output_spec["full_reward_spec"][key] - self.parent.output_spec["full_reward_spec"][key] = BoundedTensorSpec( + self.parent.output_spec["full_reward_spec"][key] = Bounded( shape=spec.shape, device=spec.device, dtype=spec.dtype, @@ -7512,31 +7566,31 @@ class RemoveEmptySpecs(Transform): Examples: >>> import torch >>> from tensordict import TensorDict - >>> from torchrl.data import UnboundedContinuousTensorSpec, CompositeSpec, \ - ... DiscreteTensorSpec + >>> from torchrl.data import Unbounded, Composite, \ + ... Categorical >>> from torchrl.envs import EnvBase, TransformedEnv, RemoveEmptySpecs >>> >>> >>> class DummyEnv(EnvBase): ... def __init__(self, *args, **kwargs): ... 
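The `_call` fix above renames keys on the copied selection (`out.rename_key_`) rather than on the source tensordict when `create_copy=True`, so both names survive the update. Typical usage, with illustrative keys and gymnasium assumed installed:

from torchrl.envs import GymEnv, TransformedEnv
from torchrl.envs.transforms import RenameTransform

env = TransformedEnv(
    GymEnv("Pendulum-v1"),
    RenameTransform(in_keys=["observation"], out_keys=["obs"], create_copy=True),
)
td = env.reset()  # both "obs" and "observation" are present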
super().__init__(*args, **kwargs) - ... self.observation_spec = CompositeSpec( - ... observation=UnboundedContinuousTensorSpec((*self.batch_size, 3)), - ... other=CompositeSpec( - ... another_other=CompositeSpec(shape=self.batch_size), + ... self.observation_spec = Composite( + ... observation=UnboundedContinuous((*self.batch_size, 3)), + ... other=Composite( + ... another_other=Composite(shape=self.batch_size), ... shape=self.batch_size, ... ), ... shape=self.batch_size, ... ) - ... self.action_spec = UnboundedContinuousTensorSpec((*self.batch_size, 3)) - ... self.done_spec = DiscreteTensorSpec( + ... self.action_spec = UnboundedContinuous((*self.batch_size, 3)) + ... self.done_spec = Categorical( ... 2, (*self.batch_size, 1), dtype=torch.bool ... ) ... self.full_done_spec["truncated"] = self.full_done_spec[ ... "terminated"].clone() - ... self.reward_spec = CompositeSpec( - ... reward=UnboundedContinuousTensorSpec(*self.batch_size, 1), - ... other_reward=CompositeSpec(shape=self.batch_size), + ... self.reward_spec = Composite( + ... reward=UnboundedContinuous(*self.batch_size, 1), + ... other_reward=Composite(shape=self.batch_size), ... shape=self.batch_size ... ) ... @@ -7629,7 +7683,7 @@ def _sorter(key_val): return 0 return len(key) - def transform_output_spec(self, output_spec: CompositeSpec) -> CompositeSpec: + def transform_output_spec(self, output_spec: Composite) -> Composite: full_done_spec = output_spec["full_done_spec"] full_reward_spec = output_spec["full_reward_spec"] full_observation_spec = output_spec["full_observation_spec"] @@ -7637,19 +7691,19 @@ def transform_output_spec(self, output_spec: CompositeSpec) -> CompositeSpec: for key, spec in sorted( full_done_spec.items(True), key=self._sorter, reverse=True ): - if isinstance(spec, CompositeSpec) and spec.is_empty(): + if isinstance(spec, Composite) and spec.is_empty(): del full_done_spec[key] for key, spec in sorted( full_observation_spec.items(True), key=self._sorter, reverse=True ): - if isinstance(spec, CompositeSpec) and spec.is_empty(): + if isinstance(spec, Composite) and spec.is_empty(): del full_observation_spec[key] for key, spec in sorted( full_reward_spec.items(True), key=self._sorter, reverse=True ): - if isinstance(spec, CompositeSpec) and spec.is_empty(): + if isinstance(spec, Composite) and spec.is_empty(): del full_reward_spec[key] return output_spec @@ -7662,14 +7716,14 @@ def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec: for key, spec in sorted( full_action_spec.items(True), key=self._sorter, reverse=True ): - if isinstance(spec, CompositeSpec) and spec.is_empty(): + if isinstance(spec, Composite) and spec.is_empty(): self._has_empty_input = True del full_action_spec[key] for key, spec in sorted( full_state_spec.items(True), key=self._sorter, reverse=True ): - if isinstance(spec, CompositeSpec) and spec.is_empty(): + if isinstance(spec, Composite) and spec.is_empty(): self._has_empty_input = True del full_state_spec[key] return input_spec @@ -7688,7 +7742,7 @@ def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase: full_action_spec.items(True), key=self._sorter, reverse=True ): if ( - isinstance(spec, CompositeSpec) + isinstance(spec, Composite) and spec.is_empty() and key not in tensordict.keys(True) ): @@ -7698,7 +7752,7 @@ def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase: full_state_spec.items(True), key=self._sorter, reverse=True ): if ( - isinstance(spec, CompositeSpec) + isinstance(spec, Composite) and spec.is_empty() and key not in 
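The pruning passes above iterate `sorted(..., key=self._sorter, reverse=True)` so that the deepest nested specs are visited first, ensuring children are deleted before their parents are tested for emptiness. A sketch of the ordering, mirroring the `_sorter` shown earlier:

# Depth-first ordering sketch: tuple keys sort by their nesting depth.
def sorter(key_val):
    key, _ = key_val
    return len(key) if isinstance(key, tuple) else 0

items = [("other", None), (("other", "another_other"), None)]
print(sorted(items, key=sorter, reverse=True)[0][0])  # ('other', 'another_other')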
tensordict.keys(True) ): @@ -7866,9 +7920,9 @@ class BatchSizeTransform(Transform): ... batch_locked = False ... def __init__(self): ... super().__init__() - ... self.observation_spec = CompositeSpec(observation=UnboundedContinuousTensorSpec(3)) - ... self.reward_spec = UnboundedContinuousTensorSpec(1) - ... self.action_spec = UnboundedContinuousTensorSpec(1) + ... self.observation_spec = Composite(observation=Unbounded(3)) + ... self.reward_spec = Unbounded(1) + ... self.action_spec = Unbounded(1) ... ... def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: ... tensordict_batch_size = tensordict.batch_size if tensordict is not None else torch.Size([]) @@ -8016,12 +8070,12 @@ def transform_env_batch_size(self, batch_size: torch.Size): return self.batch_size return self.reshape_fn(torch.zeros(batch_size, device="meta")).shape - def transform_output_spec(self, output_spec: CompositeSpec) -> CompositeSpec: + def transform_output_spec(self, output_spec: Composite) -> Composite: if self.batch_size is not None: return output_spec.expand(self.batch_size) return self.reshape_fn(output_spec) - def transform_input_spec(self, input_spec: CompositeSpec) -> CompositeSpec: + def transform_input_spec(self, input_spec: Composite) -> Composite: if self.batch_size is not None: return input_spec.expand(self.batch_size) return self.reshape_fn(input_spec) @@ -8480,7 +8534,7 @@ def _indent(s): def transform_input_spec(self, input_spec): try: action_spec = input_spec["full_action_spec", self.in_keys_inv[0]] - if not isinstance(action_spec, BoundedTensorSpec): + if not isinstance(action_spec, Bounded): raise TypeError( f"action spec type {type(action_spec)} is not supported." ) @@ -8539,9 +8593,9 @@ def custom_arange(nint): ] cls = ( - functools.partial(MultiDiscreteTensorSpec, remove_singleton=False) + functools.partial(MultiCategorical, remove_singleton=False) if self.categorical - else MultiOneHotDiscreteTensorSpec + else MultiOneHot ) if not isinstance(num_intervals, torch.Tensor): diff --git a/torchrl/envs/transforms/vc1.py b/torchrl/envs/transforms/vc1.py index d8bec1cf524..d394816372d 100644 --- a/torchrl/envs/transforms/vc1.py +++ b/torchrl/envs/transforms/vc1.py @@ -14,12 +14,7 @@ from torch import nn from torchrl._utils import logger as torchrl_logger -from torchrl.data.tensor_specs import ( - CompositeSpec, - DEVICE_TYPING, - TensorSpec, - UnboundedContinuousTensorSpec, -) +from torchrl.data.tensor_specs import Composite, DEVICE_TYPING, TensorSpec, Unbounded from torchrl.envs.transforms.transforms import ( CenterCrop, Compose, @@ -198,8 +193,8 @@ def _apply_transform(self, obs: torch.Tensor) -> None: return out def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: - if not isinstance(observation_spec, CompositeSpec): - raise ValueError("VC1Transform can only infer CompositeSpec") + if not isinstance(observation_spec, Composite): + raise ValueError("VC1Transform can only infer Composite") keys = [key for key in observation_spec.keys(True, True) if key in self.in_keys] device = observation_spec[keys[0]].device @@ -211,7 +206,7 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec del observation_spec[in_key] for out_key in self.out_keys: - observation_spec[out_key] = UnboundedContinuousTensorSpec( + observation_spec[out_key] = Unbounded( shape=torch.Size([*dim, self.embd_size]), device=device ) diff --git a/torchrl/envs/transforms/vip.py b/torchrl/envs/transforms/vip.py index e814f5da476..a28e490c4f1 100644 --- 
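The `transform_output_spec`/`transform_input_spec` methods above either expand the specs to a fixed `batch_size` or pass them through `reshape_fn`; assuming the constructor takes these as keyword arguments, the two mutually exclusive modes can be sketched as:

import torch
from torchrl.envs.transforms import BatchSizeTransform

expand_t = BatchSizeTransform(batch_size=torch.Size([4]))        # batch-free env -> batch of 4
reshape_t = BatchSizeTransform(reshape_fn=lambda t: t.reshape(2, 2))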
a/torchrl/envs/transforms/vip.py +++ b/torchrl/envs/transforms/vip.py @@ -9,11 +9,7 @@ from tensordict import set_lazy_legacy, TensorDict, TensorDictBase from torch.hub import load_state_dict_from_url -from torchrl.data.tensor_specs import ( - CompositeSpec, - TensorSpec, - UnboundedContinuousTensorSpec, -) +from torchrl.data.tensor_specs import Composite, TensorSpec, Unbounded from torchrl.data.utils import DEVICE_TYPING from torchrl.envs.transforms.transforms import ( CatTensors, @@ -92,8 +88,8 @@ def _apply_transform(self, obs: torch.Tensor) -> None: return out def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: - if not isinstance(observation_spec, CompositeSpec): - raise ValueError("_VIPNet can only infer CompositeSpec") + if not isinstance(observation_spec, Composite): + raise ValueError("_VIPNet can only infer Composite") keys = [key for key in observation_spec.keys(True, True) if key in self.in_keys] device = observation_spec[keys[0]].device @@ -105,7 +101,7 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec del observation_spec[in_key] for out_key in self.out_keys: - observation_spec[out_key] = UnboundedContinuousTensorSpec( + observation_spec[out_key] = Unbounded( shape=torch.Size([*dim, 1024]), device=device ) @@ -289,7 +285,7 @@ def _init(self): unsqueeze = UnsqueezeTransform( in_keys=in_keys, out_keys=in_keys, - unsqueeze_dim=-4, + dim=-4, ) transforms.append(unsqueeze) @@ -399,7 +395,7 @@ def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec: if "full_state_spec" in input_spec.keys(): full_state_spec = input_spec["full_state_spec"] else: - full_state_spec = CompositeSpec( + full_state_spec = Composite( shape=input_spec.shape, device=input_spec.device ) # find the obs spec diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index ee7649fabe4..f1724326d2a 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -32,13 +32,8 @@ from tensordict.base import _is_leaf_nontensor from tensordict.nn import TensorDictModule, TensorDictModuleBase from tensordict.nn.probabilistic import ( # noqa - # Note: the `set_interaction_mode` and their associated arg `default_interaction_mode` are being deprecated! - # Please use the `set_/interaction_type` ones above with the InteractionType enum instead. - # See more details: https://github.com/pytorch/rl/issues/1016 - interaction_mode as exploration_mode, interaction_type as exploration_type, InteractionType as ExplorationType, - set_interaction_mode as set_exploration_mode, set_interaction_type as set_exploration_type, ) from tensordict.utils import is_non_tensor, NestedKey @@ -47,17 +42,15 @@ from torchrl._utils import _replace_last, _rng_decorator, logger as torchrl_logger from torchrl.data.tensor_specs import ( - CompositeSpec, - NO_DEFAULT, + Composite, + NO_DEFAULT_RL as NO_DEFAULT, TensorSpec, - UnboundedContinuousTensorSpec, + Unbounded, ) -from torchrl.data.utils import check_no_exclusive_keys +from torchrl.data.utils import check_no_exclusive_keys, CloudpickleWrapper __all__ = [ - "exploration_mode", "exploration_type", - "set_exploration_mode", "set_exploration_type", "ExplorationType", "check_env_specs", @@ -69,22 +62,16 @@ ACTION_MASK_ERROR = RuntimeError( - "An out-of-bounds actions has been provided to an env with an 'action_mask' output." - " If you are using a custom policy, make sure to take the action mask into account when computing the output." 
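The vip.py hunk above reflects the renaming of UnsqueezeTransform's `unsqueeze_dim` argument to `dim`; new-style construction looks like this (keys illustrative):

from torchrl.envs.transforms import UnsqueezeTransform

# Inserts a singleton dim four positions from the end, e.g. ahead of C, H, W.
t = UnsqueezeTransform(dim=-4, in_keys=["pixels"], out_keys=["pixels"])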
- " If you are using a default policy, please add the torchrl.envs.transforms.ActionMask transform to your environment." + "An out-of-bounds actions has been provided to an env with an 'action_mask' output. " + "If you are using a custom policy, make sure to take the action mask into account when computing the output. " + "If you are using a default policy, please add the torchrl.envs.transforms.ActionMask transform to your environment. " "If you are using a ParallelEnv or another batched inventor, " - "make sure to add the transform to the ParallelEnv (and not to the sub-environments)." - " For more info on using action masks, see the docs at: " - "https://pytorch.org/rl/reference/envs.html#environments-with-masked-actions" + "make sure to add the transform to the ParallelEnv (and not to the sub-environments). " + "For more info on using action masks, see the docs at: " + "https://pytorch.org/rl/main/reference/envs.html#environments-with-masked-actions" ) -def _convert_exploration_type(*, exploration_mode, exploration_type): - if exploration_mode is not None: - return ExplorationType.from_str(exploration_mode) - return exploration_type - - class _classproperty(property): def __get__(self, cls, owner): return classmethod(self.fget).__get__(None, owner)() @@ -360,7 +347,7 @@ def step_mdp( Given a tensordict retrieved after a step, returns the :obj:`"next"` indexed-tensordict. The arguments allow for a precise control over what should be kept and what - should be copied from the ``"next"`` entry. The default behaviour is: + should be copied from the ``"next"`` entry. The default behavior is: move the observation entries, reward and done states to the root, exclude the current action and keep all extra keys (non-action, non-done, non-reward). @@ -823,7 +810,7 @@ def check_env_specs( "you will need to first pass your stack through `torchrl.data.consolidate_spec`." ) if spec is None: - spec = CompositeSpec(shape=env.batch_size, device=env.device) + spec = Composite(shape=env.batch_size, device=env.device) td = last_td.select(*spec.keys(True, True), strict=True) if not spec.contains(td): raise AssertionError( @@ -835,7 +822,7 @@ def check_env_specs( ("obs", full_observation_spec), ): if spec is None: - spec = CompositeSpec(shape=env.batch_size, device=env.device) + spec = Composite(shape=env.batch_size, device=env.device) td = last_td.get("next").select(*spec.keys(True, True), strict=True) if not spec.contains(td): raise AssertionError( @@ -870,10 +857,10 @@ def _sort_keys(element): def make_composite_from_td(data, unsqueeze_null_shapes: bool = True): - """Creates a CompositeSpec instance from a tensordict, assuming all values are unbounded. + """Creates a Composite instance from a tensordict, assuming all values are unbounded. Args: - data (tensordict.TensorDict): a tensordict to be mapped onto a CompositeSpec. + data (tensordict.TensorDict): a tensordict to be mapped onto a Composite. unsqueeze_null_shapes (bool, optional): if ``True``, every empty shape will be unsqueezed to (1,). Defaults to ``True``. @@ -886,25 +873,25 @@ def make_composite_from_td(data, unsqueeze_null_shapes: bool = True): ... 
}, []) >>> spec = make_composite_from_td(data) >>> print(spec) - CompositeSpec( - obs: UnboundedContinuousTensorSpec( + Composite( + obs: UnboundedContinuous( shape=torch.Size([3]), space=None, device=cpu, dtype=torch.float32, domain=continuous), - action: UnboundedContinuousTensorSpec( + action: UnboundedContinuous( shape=torch.Size([2]), space=None, device=cpu, dtype=torch.int32, domain=continuous), - next: CompositeSpec( - obs: UnboundedContinuousTensorSpec( + next: Composite( + obs: UnboundedContinuous( shape=torch.Size([3]), space=None, device=cpu, dtype=torch.float32, domain=continuous), - reward: UnboundedContinuousTensorSpec( + reward: UnboundedContinuous( shape=torch.Size([1]), space=ContinuousBox(low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True), high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)), device=cpu, dtype=torch.float32, domain=continuous), device=cpu, shape=torch.Size([])), device=cpu, shape=torch.Size([])) >>> assert (spec.zero() == data.zero_()).all() """ # custom funtion to convert a tensordict in a similar spec structure # of unbounded values. - composite = CompositeSpec( + composite = Composite( { key: make_composite_from_td(tensor) if isinstance(tensor, TensorDictBase) - else UnboundedContinuousTensorSpec( + else Unbounded( dtype=tensor.dtype, device=tensor.device, shape=tensor.shape @@ -1094,14 +1081,14 @@ def _terminated_or_truncated( contained a ``True``. Examples: - >>> from torchrl.data.tensor_specs import DiscreteTensorSpec + >>> from torchrl.data.tensor_specs import Categorical >>> from tensordict import TensorDict - >>> spec = CompositeSpec( - ... done=DiscreteTensorSpec(2, dtype=torch.bool), - ... truncated=DiscreteTensorSpec(2, dtype=torch.bool), - ... nested=CompositeSpec( - ... done=DiscreteTensorSpec(2, dtype=torch.bool), - ... truncated=DiscreteTensorSpec(2, dtype=torch.bool), + >>> spec = Composite( + ... done=Categorical(2, dtype=torch.bool), + ... truncated=Categorical(2, dtype=torch.bool), + ... nested=Composite( + ... done=Categorical(2, dtype=torch.bool), + ... truncated=Categorical(2, dtype=torch.bool), ... ) ... ) >>> data = TensorDict({ @@ -1147,7 +1134,7 @@ def inner_terminated_or_truncated(data, full_done_spec, key, curr_done_key=()): composite_spec = {} found_leaf = 0 for eot_key, item in full_done_spec.items(): - if isinstance(item, CompositeSpec): + if isinstance(item, Composite): composite_spec[eot_key] = item else: found_leaf += 1 @@ -1219,14 +1206,14 @@ def terminated_or_truncated( contained a ``True``. Examples: - >>> from torchrl.data.tensor_specs import DiscreteTensorSpec + >>> from torchrl.data.tensor_specs import Categorical >>> from tensordict import TensorDict - >>> spec = CompositeSpec( - ... done=DiscreteTensorSpec(2, dtype=torch.bool), - ... truncated=DiscreteTensorSpec(2, dtype=torch.bool), - ... nested=CompositeSpec( - ... done=DiscreteTensorSpec(2, dtype=torch.bool), - ... truncated=DiscreteTensorSpec(2, dtype=torch.bool), + >>> spec = Composite( + ... done=Categorical(2, dtype=torch.bool), + ... truncated=Categorical(2, dtype=torch.bool), + ... nested=Composite( + ... done=Categorical(2, dtype=torch.bool), + ... truncated=Categorical(2, dtype=torch.bool), ... ) ... 
) >>> data = TensorDict({ @@ -1274,7 +1261,7 @@ def inner_terminated_or_truncated(data, full_done_spec, key, curr_done_key=()): ) else: for eot_key, item in full_done_spec.items(): - if isinstance(item, CompositeSpec): + if isinstance(item, Composite): any_eot = any_eot | inner_terminated_or_truncated( data=data.get(eot_key), full_done_spec=item, @@ -1427,16 +1414,63 @@ def _repr_by_depth(key): return (len(key) - 1, ".".join(key)) -def _make_compatible_policy(policy, observation_spec, env=None, fast_wrap=False): +def _make_compatible_policy( + policy, + observation_spec, + env=None, + fast_wrap=False, + trust_policy=False, + env_maker=None, + env_maker_kwargs=None, +): + if trust_policy: + return policy if policy is None: - if env is None: - raise ValueError( - "env must be provided to _get_policy_and_device if policy is None" - ) - policy = RandomPolicy(env.input_spec["full_action_spec"]) - # make sure policy is an nn.Module - policy = _NonParametricPolicyWrapper(policy) + input_spec = None + if env_maker is not None: + from torchrl.envs import EnvBase, EnvCreator + + if isinstance(env_maker, EnvBase): + env = env_maker + input_spec = env.input_spec["full_action_spec"] + elif isinstance(env_maker, EnvCreator): + input_spec = env_maker._meta_data.specs[ + "input_spec", "full_action_spec" + ] + else: + env = env_maker(**env_maker_kwargs) + input_spec = env.full_action_spec + if input_spec is None: + if env is not None: + input_spec = env.input_spec["full_action_spec"] + else: + raise ValueError( + "env must be provided to _get_policy_and_device if policy is None" + ) + + policy = RandomPolicy(input_spec) + + # make sure policy is an nn.Module - this will return the same policy if conditions are met + # policy = CloudpickleWrapper(policy) + + caller = getattr(policy, "forward", policy) + if not _policy_is_tensordict_compatible(policy): + if observation_spec is None: + if env is not None: + observation_spec = env.observation_spec + elif env_maker is not None: + from torchrl.envs import EnvBase, EnvCreator + + if isinstance(env_maker, EnvBase): + observation_spec = env_maker.observation_spec + elif isinstance(env_maker, EnvCreator): + observation_spec = env_maker._meta_data.specs[ + "output_spec", "full_observation_spec" + ] + else: + observation_spec = env_maker(**env_maker_kwargs).observation_spec + # policy is a nn.Module that doesn't operate on tensordicts directly # so we attempt to auto-wrap policy with TensorDictModule if observation_spec is None: @@ -1445,13 +1479,15 @@ def _make_compatible_policy(policy, observation_spec, env=None, fast_wrap=False) "required to check compatibility of the environment and policy " "since the policy is a nn.Module that operates on tensors " "rather than a TensorDictModule or a nn.Module that accepts a " - "TensorDict as input and defines in_keys and out_keys." + "TensorDict as input and defines in_keys and out_keys. " + "If your policy is compatible with the environment, you can solve this warning by setting " + "trust_policy=True in the constructor." 
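The rewritten `_make_compatible_policy` above short-circuits on `trust_policy`, returning the policy untouched and skipping both the RandomPolicy fallback and the TensorDictModule auto-wrapping. A minimal illustration; the stand-in policy below is hypothetical:

from torchrl.envs.utils import _make_compatible_policy

def my_policy(td):  # stand-in policy; never inspected when trusted
    return td

policy = _make_compatible_policy(my_policy, None, trust_policy=True)
assert policy is my_policy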
) try: - sig = policy.forward.__signature__ + sig = caller.__signature__ except AttributeError: - sig = inspect.signature(policy.forward) + sig = inspect.signature(caller) # we check if all the mandatory params are there params = list(sig.parameters.keys()) if ( @@ -1480,7 +1516,7 @@ def _make_compatible_policy(policy, observation_spec, env=None, fast_wrap=False) out_keys = ["action"] else: out_keys = list(env.action_keys) - for p in policy.parameters(): + for p in getattr(policy, "parameters", list)(): policy_device = p.device break else: @@ -1503,7 +1539,7 @@ def _make_compatible_policy(policy, observation_spec, env=None, fast_wrap=False) If you want TorchRL to automatically wrap your policy with a TensorDictModule then the arguments to policy.forward must correspond one-to-one with entries in env.observation_spec. - For more complex behaviour and more control you can consider writing your + For more complex behavior and more control you can consider writing your own TensorDictModule. Check the collector documentation to know more about accepted policies. """ @@ -1512,15 +1548,20 @@ def _make_compatible_policy(policy, observation_spec, env=None, fast_wrap=False) def _policy_is_tensordict_compatible(policy: nn.Module): - if isinstance(policy, _NonParametricPolicyWrapper) and isinstance( - policy.policy, RandomPolicy - ): - return True + def is_compatible(policy): + return isinstance(policy, (RandomPolicy, TensorDictModuleBase)) - if isinstance(policy, TensorDictModuleBase): + if ( + is_compatible(policy) + or ( + isinstance(policy, _NonParametricPolicyWrapper) + and is_compatible(policy.policy) + ) + or (isinstance(policy, CloudpickleWrapper) and is_compatible(policy.fn)) + ): return True - sig = inspect.signature(policy.forward) + sig = inspect.signature(getattr(policy, "forward", policy)) if ( len(sig.parameters) == 1 @@ -1541,7 +1582,7 @@ def _policy_is_tensordict_compatible(policy: nn.Module): # if in_keys or out_keys were defined but policy is not a TensorDictModule or # accepts multiple arguments then it's likely the user is trying to do something - # that will have undetermined behaviour, we raise an error + # that will have undetermined behavior, we raise an error raise TypeError( "Received a policy that defines in_keys or out_keys and also expects multiple " "arguments to policy.forward. 
If the policy is compatible with TensorDict, it " @@ -1562,8 +1603,8 @@ class RandomPolicy: Examples: >>> from tensordict import TensorDict - >>> from torchrl.data.tensor_specs import BoundedTensorSpec - >>> action_spec = BoundedTensorSpec(-torch.ones(3), torch.ones(3)) + >>> from torchrl.data.tensor_specs import Bounded + >>> action_spec = Bounded(-torch.ones(3), torch.ones(3)) >>> actor = RandomPolicy(action_spec=action_spec) >>> td = actor(TensorDict({}, batch_size=[])) # selects a random action in the cube [-1; 1] """ @@ -1574,7 +1615,7 @@ def __init__(self, action_spec: TensorSpec, action_key: NestedKey = "action"): self.action_key = action_key def __call__(self, td: TensorDictBase) -> TensorDictBase: - if isinstance(self.action_spec, CompositeSpec): + if isinstance(self.action_spec, Composite): return td.update(self.action_spec.rand()) else: return td.set(self.action_key, self.action_spec.rand()) @@ -1593,19 +1634,10 @@ class _NonParametricPolicyWrapper(nn.Module, metaclass=_PolicyMetaClass): def __init__(self, policy): super().__init__() - self.policy = policy - - @property - def forward(self): - forward = self.__dict__.get("_forward", None) - if forward is None: - - @functools.wraps(self.policy) - def forward(*input, **kwargs): - return self.policy.__call__(*input, **kwargs) - - self.__dict__["_forward"] = forward - return forward + functools.update_wrapper(self, policy) + self.policy = CloudpickleWrapper(policy) + if hasattr(policy, "forward"): + self.forward = self.policy.forward def __getattr__(self, attr: str) -> Any: if attr in self.__dir__(): diff --git a/torchrl/modules/__init__.py b/torchrl/modules/__init__.py index 0a06e5844a0..f65461842bb 100644 --- a/torchrl/modules/__init__.py +++ b/torchrl/modules/__init__.py @@ -20,6 +20,8 @@ TruncatedNormal, ) from .models import ( + BatchRenorm1d, + ConsistentDropoutModule, Conv3dNet, ConvNet, DdpgCnnActor, @@ -84,4 +86,5 @@ VmapModule, WorldModelWrapper, ) +from .utils import get_primers_from_module from .planners import CEMPlanner, MPCPlannerBase, MPPIPlanner # usort:skip diff --git a/torchrl/modules/distributions/continuous.py b/torchrl/modules/distributions/continuous.py index fddc2f3415d..8b0d5654b8d 100644 --- a/torchrl/modules/distributions/continuous.py +++ b/torchrl/modules/distributions/continuous.py @@ -5,13 +5,21 @@ from __future__ import annotations import warnings +import weakref from numbers import Number from typing import Dict, Optional, Sequence, Tuple, Union import numpy as np import torch from torch import distributions as D, nn + +try: + from torch.compiler import assume_constant_result +except ImportError: + from torch._dynamo import assume_constant_result + from torch.distributions import constraints +from torch.distributions.transforms import _InverseTransform from torchrl.modules.distributions.truncated_normal import ( TruncatedNormal as _TruncatedNormal, @@ -20,14 +28,19 @@ from torchrl.modules.distributions.utils import ( _cast_device, FasterTransformedDistribution, - safeatanh, - safetanh, + safeatanh_noeps, + safetanh_noeps, ) from torchrl.modules.utils import mappings # speeds up distribution construction D.Distribution.set_default_validate_args(False) +try: + from torch.compiler import is_dynamo_compiling +except ImportError: + from torch._dynamo import is_compiling as is_dynamo_compiling + class IndependentNormal(D.Independent): """Implements a Normal distribution with location scaling. @@ -39,7 +52,7 @@ class IndependentNormal(D.Independent): .. math:: loc = tanh(loc / upscale) * upscale. 
- This behaviour can be disabled by switching off the tanh_loc parameter (see below). + This behavior can be disabled by switching off the tanh_loc parameter (see below). Args: @@ -92,19 +105,21 @@ class SafeTanhTransform(D.TanhTransform): """TanhTransform subclass that ensured that the transformation is numerically invertible.""" def _call(self, x: torch.Tensor) -> torch.Tensor: - if x.dtype.is_floating_point: - eps = torch.finfo(x.dtype).resolution - else: - raise NotImplementedError(f"No tanh transform for {x.dtype} inputs.") - return safetanh(x, eps) + return safetanh_noeps(x) def _inverse(self, y: torch.Tensor) -> torch.Tensor: - if y.dtype.is_floating_point: - eps = torch.finfo(y.dtype).resolution - else: - raise NotImplementedError(f"No inverse tanh for {y.dtype} inputs.") - x = safeatanh(y, eps) - return x + return safeatanh_noeps(y) + + @property + def inv(self): + inv = None + if self._inv is not None: + inv = self._inv() + if inv is None: + inv = _InverseTransform(self) + if not is_dynamo_compiling(): + self._inv = weakref.ref(inv) + return inv class NormalParamWrapper(nn.Module): @@ -173,7 +188,7 @@ class TruncatedNormal(D.Independent): .. math:: loc = tanh(loc / upscale) * upscale. - This behaviour can be disabled by switching off the tanh_loc parameter (see below). + This behavior can be disabled by switching off the tanh_loc parameter (see below). Args: @@ -202,13 +217,6 @@ class TruncatedNormal(D.Independent): "scale": constraints.greater_than(1e-6), } - def _warn_minmax(self): - warnings.warn( - f"the min / high keyword arguments are deprecated in favor of low / high in {type(self).__name__} " - f"and will be removed entirely in v0.6. ", - DeprecationWarning, - ) - def __init__( self, loc: torch.Tensor, @@ -217,14 +225,7 @@ def __init__( low: Union[torch.Tensor, float] = -1.0, high: Union[torch.Tensor, float] = 1.0, tanh_loc: bool = False, - **kwargs, ): - if "max" in kwargs: - self._warn_minmax() - high = kwargs.pop("max") - if "min" in kwargs: - self._warn_minmax() - low = kwargs.pop("min") err_msg = "TanhNormal high values must be strictly greater than low values" if isinstance(high, torch.Tensor) or isinstance(low, torch.Tensor): @@ -316,6 +317,33 @@ def log_prob(self, value, **kwargs): return lp +class _PatchedComposeTransform(D.ComposeTransform): + @property + def inv(self): + inv = None + if self._inv is not None: + inv = self._inv() + if inv is None: + inv = _PatchedComposeTransform([p.inv for p in reversed(self.parts)]) + if not is_dynamo_compiling(): + self._inv = weakref.ref(inv) + inv._inv = weakref.ref(self) + return inv + + +class _PatchedAffineTransform(D.AffineTransform): + @property + def inv(self): + inv = None + if self._inv is not None: + inv = self._inv() + if inv is None: + inv = _InverseTransform(self) + if not is_dynamo_compiling(): + self._inv = weakref.ref(inv) + return inv + + class TanhNormal(FasterTransformedDistribution): """Implements a TanhNormal distribution with location scaling. @@ -337,13 +365,15 @@ class TanhNormal(FasterTransformedDistribution): .. math:: loc = tanh(loc / upscale) * upscale. - min (torch.Tensor or number, optional): minimum value of the distribution. Default is -1.0; - max (torch.Tensor or number, optional): maximum value of the distribution. Default is 1.0; + low (torch.Tensor or number, optional): minimum value of the distribution. Default is -1.0; + high (torch.Tensor or number, optional): maximum value of the distribution. Default is 1.0; event_dims (int, optional): number of dimensions describing the action. 
Default is 1. Setting ``event_dims`` to ``0`` will result in a log-probability that has the same shape as the input, ``1`` will reduce (sum over) the last dimension, ``2`` the last two etc. tanh_loc (bool, optional): if ``True``, the above formula is used for the location scaling, otherwise the raw value is kept. Default is ``False``; + safe_tanh (bool, optional): if ``True``, the Tanh transform is done "safely", to avoid numerical overflows. + This will currently break with :func:`torch.compile`. """ arg_constraints = { @@ -353,13 +383,6 @@ class TanhNormal(FasterTransformedDistribution): num_params = 2 - def _warn_minmax(self): - warnings.warn( - f"the min / high keyword arguments are deprecated in favor of low / high in {type(self).__name__} " - f"and will be removed entirely in v0.6. ", - DeprecationWarning, - ) - def __init__( self, loc: torch.Tensor, @@ -369,15 +392,9 @@ def __init__( high: Union[torch.Tensor, Number] = 1.0, event_dims: int | None = None, tanh_loc: bool = False, + safe_tanh: bool = True, **kwargs, ): - if "max" in kwargs: - self._warn_minmax() - high = kwargs.pop("max") - if "min" in kwargs: - self._warn_minmax() - low = kwargs.pop("min") - if not isinstance(loc, torch.Tensor): loc = torch.as_tensor(loc, dtype=torch.get_default_dtype()) if not isinstance(scale, torch.Tensor): @@ -419,13 +436,20 @@ def __init__( self.low = low self.high = high - t = SafeTanhTransform() + if safe_tanh: + if is_dynamo_compiling(): + _err_compile_safetanh() + t = SafeTanhTransform() + else: + t = D.TanhTransform() # t = D.TanhTransform() - if self.non_trivial_max or self.non_trivial_min: - t = D.ComposeTransform( + if is_dynamo_compiling() or (self.non_trivial_max or self.non_trivial_min): + t = _PatchedComposeTransform( [ t, - D.AffineTransform(loc=(high + low) / 2, scale=(high - low) / 2), + _PatchedAffineTransform( + loc=(high + low) / 2, scale=(high - low) / 2 + ), ] ) self._t = t @@ -446,7 +470,7 @@ def update(self, loc: torch.Tensor, scale: torch.Tensor) -> None: if self.tanh_loc: loc = (loc / self.upscale).tanh() * self.upscale # loc must be rescaled if tanh_loc - if self.non_trivial_max or self.non_trivial_min: + if is_dynamo_compiling() or (self.non_trivial_max or self.non_trivial_min): loc = loc + (self.high - self.low) / 2 + self.low self.loc = loc self.scale = scale @@ -466,6 +490,10 @@ def update(self, loc: torch.Tensor, scale: torch.Tensor) -> None: base = D.Normal(self.loc, self.scale) super().__init__(base, self._t) + @property + def support(self): + return D.constraints.real() + @property def root_dist(self): bd = self @@ -475,15 +503,10 @@ def root_dist(self): @property def mode(self): - warnings.warn( - "This computation of the mode is based on an inaccurate estimation of the mode " - "given the base_dist mode. " - "To use a more stable implementation of the mode, use dist.get_mode() method instead. " - "To silence this warning, consider using the DETERMINISTIC exploration_type." - "This implementation will be removed in v0.6.", - category=DeprecationWarning, + raise RuntimeError( + f"The distribution {type(self).__name__} has no analytical mode. " + f"Use ExplorationType.DETERMINISTIC to get a deterministic sample from it."
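# --- Illustration: bounded TanhNormal after this change -----------------------
# A hedged sketch of the surface changes above: bounds are now passed as
# low/high (the deprecated min/max kwargs are gone), `safe_tanh` selects the
# clamped transform, and `.mode` raises instead of returning an approximation.
import torch
from torchrl.modules import TanhNormal

dist = TanhNormal(
    loc=torch.zeros(3),
    scale=torch.ones(3),
    low=-2.0,
    high=2.0,
    safe_tanh=True,  # default; pass False when running under torch.compile
)
sample = dist.sample()
assert ((sample > -2.0) & (sample < 2.0)).all()  # samples live inside (low, high)
det = dist.deterministic_sample  # the supported replacement for .mode
# dist.mode  # would now raise a RuntimeError, per the diff above
# -----------------------------------------------------------------------------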
) - return self.deterministic_sample @property def deterministic_sample(self): @@ -647,13 +670,6 @@ class TanhDelta(FasterTransformedDistribution): "loc": constraints.real, } - def _warn_minmax(self): - warnings.warn( - f"the min / high keyword arguments are deprecated in favor of low / high in {type(self).__name__} " - f"and will be removed entirely in v0.6. ", - category=DeprecationWarning, - ) - def __init__( self, param: torch.Tensor, @@ -662,15 +678,7 @@ def __init__( event_dims: int = 1, atol: float = 1e-6, rtol: float = 1e-6, - **kwargs, ): - if "max" in kwargs: - self._warn_minmax() - high = kwargs.pop("max") - if "min" in kwargs: - self._warn_minmax() - low = kwargs.pop("min") - minmax_msg = "high value has been found to be equal or less than low value" if isinstance(high, torch.Tensor) or isinstance(low, torch.Tensor): if not (high > low).all(): @@ -696,10 +704,10 @@ def __init__( loc = self.update(param) if self.non_trivial: - t = D.ComposeTransform( + t = _PatchedComposeTransform( [ t, - D.AffineTransform( + _PatchedAffineTransform( loc=(self.high + self.low) / 2, scale=(self.high - self.low) / 2 ), ] @@ -712,7 +720,6 @@ def __init__( rtol=rtol, batch_shape=batch_shape, event_shape=event_shape, - **kwargs, ) super().__init__(base, t) @@ -761,3 +768,16 @@ def _uniform_sample_delta(dist: Delta, size=None) -> torch.Tensor: uniform_sample_delta = _uniform_sample_delta + + +def _err_compile_safetanh(): + raise RuntimeError( + "safe_tanh=True in TanhNormal is not compatible with torch.compile. To deactivate it, pass " + "safe_tanh=False. " + "If you are using a ProbabilisticTensorDictModule, this can be done via " + "`distribution_kwargs={'safe_tanh': False}`. " + "See https://github.com/pytorch/pytorch/issues/133529 for more details." + ) + + +_warn_compile_safetanh = assume_constant_result(_err_compile_safetanh) diff --git a/torchrl/modules/distributions/discrete.py index c48d8168887..d2ffba30686 100644 --- a/torchrl/modules/distributions/discrete.py +++ b/torchrl/modules/distributions/discrete.py @@ -389,6 +389,17 @@ def sample( ) -> torch.Tensor: ... + @property + def deterministic_sample(self): + return self.mode + + @property + def mode(self) -> torch.Tensor: + if hasattr(self, "logits"): + return (self.logits == self.logits.max(-1, True)[0]).to(torch.long) + else: + return (self.probs == self.probs.max(-1, True)[0]).to(torch.long) + def log_prob(self, value: torch.Tensor) -> torch.Tensor: return super().log_prob(value.argmax(dim=-1)) diff --git a/torchrl/modules/distributions/utils.py index 267632c4fd9..546d93cb228 100644 --- a/torchrl/modules/distributions/utils.py +++ b/torchrl/modules/distributions/utils.py @@ -6,7 +6,6 @@ from typing import Union import torch -from packaging import version from torch import autograd, distributions as d from torch.distributions import Independent, Transform, TransformedDistribution @@ -92,72 +91,133 @@ def __init__(self, base_distribution, transforms, validate_args=None): ) -if version.parse(torch.__version__) >= version.parse("2.0.0"): - - class _SafeTanh(autograd.Function): - generate_vmap_rule = True - - @staticmethod - def forward(input, eps): - output = input.tanh() - lim = 1.0 - eps - output = output.clamp(-lim, lim) - # ctx.save_for_backward(output) - return output - - @staticmethod - def setup_context(ctx, inputs, output): - # input, eps = inputs - # ctx.mark_non_differentiable(ind, ind_inv) - # # Tensors must be saved via ctx.save_for_backward.
Please do not - # # assign them directly onto the ctx object. - ctx.save_for_backward(output) - - @staticmethod - def backward(ctx, *grad): - grad = grad[0] - (output,) = ctx.saved_tensors - return (grad * (1 - output.pow(2)), None) - - class _SafeaTanh(autograd.Function): - generate_vmap_rule = True - - @staticmethod - def setup_context(ctx, inputs, output): - tanh_val, eps = inputs - # ctx.mark_non_differentiable(ind, ind_inv) - # # Tensors must be saved via ctx.save_for_backward. Please do not - # # assign them directly onto the ctx object. - ctx.save_for_backward(tanh_val) - ctx.eps = eps - - @staticmethod - def forward(tanh_val, eps): - lim = 1.0 - eps - output = tanh_val.clamp(-lim, lim) - # ctx.save_for_backward(output) - output = output.atanh() - return output - - @staticmethod - def backward(ctx, *grad): - grad = grad[0] - (tanh_val,) = ctx.saved_tensors - eps = ctx.eps - lim = 1.0 - eps - output = tanh_val.clamp(-lim, lim) - return (grad / (1 - output.pow(2)), None) - - safetanh = _SafeTanh.apply - safeatanh = _SafeaTanh.apply - -else: - - def safetanh(x, eps): # noqa: D103 +def _safetanh(x, eps): # noqa: D103 + lim = 1.0 - eps + y = x.tanh() + return y.clamp(-lim, lim) + + +def _safeatanh(y, eps): # noqa: D103 + lim = 1.0 - eps + return y.clamp(-lim, lim).atanh() + + +class _SafeTanh(autograd.Function): + generate_vmap_rule = True + + @staticmethod + def forward(input, eps): + output = input.tanh() lim = 1.0 - eps - y = x.tanh() - return y.clamp(-lim, lim) + output = output.clamp(-lim, lim) + # ctx.save_for_backward(output) + return output + + @staticmethod + def setup_context(ctx, inputs, output): + # input, eps = inputs + # ctx.mark_non_differentiable(ind, ind_inv) + # # Tensors must be saved via ctx.save_for_backward. Please do not + # # assign them directly onto the ctx object. + ctx.save_for_backward(output) + + @staticmethod + def backward(ctx, *grad): + grad = grad[0] + (output,) = ctx.saved_tensors + return (grad * (1 - output.pow(2)), None) + + +class _SafeTanhNoEps(autograd.Function): + generate_vmap_rule = True + + @staticmethod + def forward(input): + output = input.tanh() + eps = torch.finfo(input.dtype).resolution + lim = 1.0 - eps + output = output.clamp(-lim, lim) + return output + + @staticmethod + def setup_context(ctx, inputs, output): + ctx.save_for_backward(output) + + @staticmethod + def backward(ctx, *grad): + grad = grad[0] + (output,) = ctx.saved_tensors + return (grad * (1 - output.pow(2)),) + + +class _SafeaTanh(autograd.Function): + generate_vmap_rule = True - def safeatanh(y, eps): # noqa: D103 + @staticmethod + def forward(tanh_val, eps): + if eps is None: + eps = torch.finfo(tanh_val.dtype).resolution lim = 1.0 - eps - return y.clamp(-lim, lim).atanh() + output = tanh_val.clamp(-lim, lim) + # ctx.save_for_backward(output) + output = output.atanh() + return output + + @staticmethod + def setup_context(ctx, inputs, output): + tanh_val, eps = inputs + + # ctx.mark_non_differentiable(ind, ind_inv) + # # Tensors must be saved via ctx.save_for_backward. Please do not + # # assign them directly onto the ctx object. 
+ ctx.save_for_backward(tanh_val) + ctx.eps = eps + + @staticmethod + def backward(ctx, *grad): + grad = grad[0] + (tanh_val,) = ctx.saved_tensors + eps = ctx.eps + lim = 1.0 - eps + output = tanh_val.clamp(-lim, lim) + return (grad / (1 - output.pow(2)), None) + + +class _SafeaTanhNoEps(autograd.Function): + generate_vmap_rule = True + + @staticmethod + def forward(tanh_val): + eps = torch.finfo(tanh_val.dtype).resolution + lim = 1.0 - eps + output = tanh_val.clamp(-lim, lim) + # ctx.save_for_backward(output) + output = output.atanh() + return output + + @staticmethod + def setup_context(ctx, inputs, output): + tanh_val = inputs[0] + eps = torch.finfo(tanh_val.dtype).resolution + + # ctx.mark_non_differentiable(ind, ind_inv) + # # Tensors must be saved via ctx.save_for_backward. Please do not + # # assign them directly onto the ctx object. + ctx.save_for_backward(tanh_val) + ctx.eps = eps + + @staticmethod + def backward(ctx, *grad): + grad = grad[0] + (tanh_val,) = ctx.saved_tensors + eps = ctx.eps + lim = 1.0 - eps + output = tanh_val.clamp(-lim, lim) + return (grad / (1 - output.pow(2)),) + + +safetanh = _SafeTanh.apply +safeatanh = _SafeaTanh.apply + +safetanh_noeps = _SafeTanhNoEps.apply +safeatanh_noeps = _SafeaTanhNoEps.apply diff --git a/torchrl/modules/models/__init__.py b/torchrl/modules/models/__init__.py index 9a814e35477..90b9fadd747 100644 --- a/torchrl/modules/models/__init__.py +++ b/torchrl/modules/models/__init__.py @@ -9,7 +9,12 @@ from .batchrenorm import BatchRenorm1d from .decision_transformer import DecisionTransformer -from .exploration import NoisyLazyLinear, NoisyLinear, reset_noise +from .exploration import ( + ConsistentDropoutModule, + NoisyLazyLinear, + NoisyLinear, + reset_noise, +) from .model_based import ( DreamerActor, ObsDecoder, diff --git a/torchrl/modules/models/batchrenorm.py b/torchrl/modules/models/batchrenorm.py index 26a2f9d50d2..41de0945f70 100644 --- a/torchrl/modules/models/batchrenorm.py +++ b/torchrl/modules/models/batchrenorm.py @@ -32,9 +32,9 @@ class BatchRenorm1d(nn.Module): Defaults to ``5.0``. warmup_steps (int, optional): Number of warm-up steps for the running mean and variance. Defaults to ``10000``. - smooth (bool, optional): if ``True``, the behaviour smoothly transitions from regular + smooth (bool, optional): if ``True``, the behavior smoothly transitions from regular batch-norm (when ``iter=0``) to batch-renorm (when ``iter=warmup_steps``). - Otherwise, the behaviour will transition from batch-norm to batch-renorm when + Otherwise, the behavior will transition from batch-norm to batch-renorm when ``iter=warmup_steps``. Defaults to ``False``. """ diff --git a/torchrl/modules/models/exploration.py b/torchrl/modules/models/exploration.py index 2ec51b46559..d69a85fd685 100644 --- a/torchrl/modules/models/exploration.py +++ b/torchrl/modules/models/exploration.py @@ -2,16 +2,24 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
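# --- Illustration: the new eps-free safe-tanh helpers from utils.py above -----
# A small sketch: the `*_noeps` variants derive the clamping epsilon from the
# input dtype (torch.finfo(dtype).resolution) instead of taking it as an
# argument, so the tanh/atanh round trip and its gradients stay finite even
# deep in the saturation region.
import torch
from torchrl.modules.distributions.utils import safeatanh_noeps, safetanh_noeps

x = torch.tensor([-20.0, 0.0, 20.0], requires_grad=True)
y = safetanh_noeps(x)        # tanh output clamped to +/-(1 - eps)
x_back = safeatanh_noeps(y)  # atanh never sees exactly +/-1
assert torch.isfinite(x_back).all()
x_back.sum().backward()      # custom backward: grad / (1 - y**2), finite by construction
assert torch.isfinite(x.grad).all()
# -----------------------------------------------------------------------------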
+from __future__ import annotations + +import functools import math import warnings -from typing import Optional, Sequence, Union +from typing import List, Optional, Sequence, Union import torch + +from tensordict.nn import TensorDictModuleBase +from tensordict.utils import NestedKey from torch import distributions as d, nn +from torch.nn import functional as F +from torch.nn.modules.dropout import _DropoutNd from torch.nn.modules.lazy import LazyModuleMixin from torch.nn.parameter import UninitializedBuffer, UninitializedParameter - from torchrl._utils import prod +from torchrl.data.tensor_specs import Unbounded from torchrl.data.utils import DEVICE_TYPING, DEVICE_TYPING_ARGS from torchrl.envs.utils import exploration_type, ExplorationType from torchrl.modules.distributions.utils import _cast_transform_device @@ -359,7 +367,7 @@ def sigma(self): def forward(self, mu, state, _eps_gSDE): sigma = self.sigma.clamp_max(self.scale_max) - _err_explo = f"gSDE behaviour for exploration mode {exploration_type()} is not defined. Choose from 'random' or 'mode'." + _err_explo = f"gSDE behavior for exploration mode {exploration_type()} is not defined. Choose from 'random' or 'mode'." if state.shape[:-1] != mu.shape[:-1]: _err_msg = f"mu and state are expected to have matching batch size, got shapes {mu.shape} and {state.shape}" @@ -520,3 +528,203 @@ def initialize_parameters( ) self._sigma.materialize((action_dim, state_dim)) self._sigma.data.copy_(self.sigma_init.expand_as(self._sigma)) + + +class ConsistentDropout(_DropoutNd): + """Implements a :class:`~torch.nn.Dropout` variant with consistent dropout. + + This method is proposed in `"Consistent Dropout for Policy Gradient Reinforcement Learning" (Hausknecht & Wagener, 2022) + `_. + + This :class:`~torch.nn.Dropout` variant attempts to increase training stability and + reduce update variance by caching the dropout masks used during rollout + and reusing them during the update phase. + + The class you are looking at is independent of the rest of TorchRL's API and does not require tensordict to run. + :class:`~torchrl.modules.ConsistentDropoutModule` is a wrapper around ``ConsistentDropout`` that capitalizes on the extensibility + of ``TensorDict``s by storing generated dropout masks in the transition ``TensorDict``s themselves. + See this class for a detailed explanation as well as usage examples. + + There is otherwise little conceptual deviance from the PyTorch + :class:`~torch.nn.Dropout` implementation. + + .. note:: TorchRL's data collectors perform rollouts in :meth:`~torch.no_grad` mode but not in `eval` mode, + so the dropout masks will be applied unless the policy passed to the collector is in eval mode. + + .. note:: Unlike other exploration modules, :class:`~torchrl.modules.ConsistentDropoutModule` + uses the ``train``/``eval`` mode to comply with the regular `Dropout` API in PyTorch. + The :func:`~torchrl.envs.utils.set_exploration_type` context manager will have no effect on + this module. + + Args: + p (float, optional): Dropout probability. Defaults to ``0.5``. + + ..
seealso:: + + - :class:`~torchrl.collectors.SyncDataCollector`: + :meth:`~torchrl.collectors.SyncDataCollector.rollout()` and :meth:`~torchrl.collectors.SyncDataCollector.iterator()` + - :class:`~torchrl.collectors.MultiSyncDataCollector`: + Uses :meth:`~torchrl.collectors.collectors._main_async_collector` (:class:`~torchrl.collectors.SyncDataCollector`) + under the hood + - :class:`~torchrl.collectors.MultiaSyncDataCollector`, :class:`~torchrl.collectors.aSyncDataCollector`: Ditto. + + """ + + def __init__(self, p: float = 0.5): + super().__init__() + self.p = p + + def forward( + self, x: torch.Tensor, mask: torch.Tensor | None = None + ) -> torch.Tensor: + """During training (rollouts & updates), this call masks a tensor full of ones before multiplying with the input tensor. + + During evaluation, this call results in a no-op and only the input is returned. + + Args: + x (torch.Tensor): the input tensor. + mask (torch.Tensor, optional): the optional mask for the dropout. + + Returns: a tensor and a corresponding mask in train mode, and only a tensor in eval mode. + """ + if self.training: + if mask is None: + mask = self.make_mask(input=x) + return x * mask, mask + + return x + + def make_mask(self, *, input=None, shape=None): + if input is not None: + return F.dropout( + torch.ones_like(input), self.p, self.training, inplace=False + ) + elif shape is not None: + return F.dropout(torch.ones(shape), self.p, self.training, inplace=False) + else: + raise RuntimeError("input or shape must be passed to make_mask.") + + +class ConsistentDropoutModule(TensorDictModuleBase): + """A TensorDictModule wrapper for :class:`~ConsistentDropout`. + + Args: + p (float, optional): Dropout probability. Default: ``0.5``. + in_keys (NestedKey or list of NestedKeys): keys to be read + from input tensordict and passed to this module. + out_keys (NestedKey or iterable of NestedKeys): keys to be written to the input tensordict. + Defaults to ``in_keys`` values. + + Keyword Args: + input_shape (tuple, optional): the shape of the input (non-batched), used to generate the + tensordict primers with :meth:`~.make_tensordict_primer`. + input_dtype (torch.dtype, optional): the dtype of the input for the primer. If none is passed, + ``torch.get_default_dtype`` is assumed. + + .. note:: To use this class within a policy, one needs the mask to be reset at reset time. + This can be achieved through a :class:`~torchrl.envs.TensorDictPrimer` transform that can be obtained + with :meth:`~.make_tensordict_primer`. See this method for more information.
+ + Examples: + >>> from tensordict import TensorDict + >>> module = ConsistentDropoutModule(p = 0.1) + >>> td = TensorDict({"x": torch.randn(3, 4)}, [3]) + >>> module(td) + TensorDict( + fields={ + mask_6127171760: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.bool, is_shared=False), + x: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([3]), + device=None, + is_shared=False) + """ + + def __init__( + self, + p: float, + in_keys: NestedKey | List[NestedKey], + out_keys: NestedKey | List[NestedKey] | None = None, + input_shape: torch.Size = None, + input_dtype: torch.dtype | None = None, + ): + if isinstance(in_keys, NestedKey): + in_keys = [in_keys, f"mask_{id(self)}"] + if out_keys is None: + out_keys = list(in_keys) + if isinstance(out_keys, NestedKey): + out_keys = [out_keys, f"mask_{id(self)}"] + if len(in_keys) != 2 or len(out_keys) != 2: + raise ValueError( + "in_keys and out_keys length must be 2 for consistent dropout." + ) + self.in_keys = in_keys + self.out_keys = out_keys + self.input_shape = input_shape + self.input_dtype = input_dtype + super().__init__() + + if not 0 <= p < 1: + raise ValueError(f"p must be in [0,1), got p={p: 4.4f}.") + + self.consistent_dropout = ConsistentDropout(p) + + def forward(self, tensordict): + x = tensordict.get(self.in_keys[0]) + mask = tensordict.get(self.in_keys[1], default=None) + if self.consistent_dropout.training: + x, mask = self.consistent_dropout(x, mask=mask) + tensordict.set(self.out_keys[0], x) + tensordict.set(self.out_keys[1], mask) + else: + x = self.consistent_dropout(x, mask=mask) + tensordict.set(self.out_keys[0], x) + + return tensordict + + def make_tensordict_primer(self): + """Makes a tensordict primer for the environment to generate random masks during reset calls. + + .. seealso:: :func:`torchrl.modules.utils.get_primers_from_module` for a method to generate all primers for a given + module. + + Examples: + >>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod + >>> from torchrl.envs import GymEnv, StepCounter, SerialEnv + >>> m = Seq( + ... Mod(torch.nn.Linear(7, 4), in_keys=["observation"], out_keys=["intermediate"]), + ... ConsistentDropoutModule( + ... p=0.5, + ... input_shape=(2, 4), + ... in_keys="intermediate", + ... ), + ... Mod(torch.nn.Linear(4, 7), in_keys=["intermediate"], out_keys=["action"]), + ... ) + >>> primer = get_primers_from_module(m) + >>> env0 = GymEnv("Pendulum-v1").append_transform(StepCounter(5)) + >>> env1 = GymEnv("Pendulum-v1").append_transform(StepCounter(6)) + >>> env = SerialEnv(2, [lambda env=env0: env, lambda env=env1: env]) + >>> env = env.append_transform(primer) + >>> r = env.rollout(10, m, break_when_any_done=False) + >>> mask = [k for k in r.keys() if k.startswith("mask")][0] + >>> assert (r[mask][0, :5] != r[mask][0, 5:6]).any() + >>> assert (r[mask][0, :4] == r[mask][0, 4:5]).all() + + """ + from torchrl.envs.transforms.transforms import TensorDictPrimer + + shape = self.input_shape + dtype = self.input_dtype + if dtype is None: + dtype = torch.get_default_dtype() + if shape is None: + raise RuntimeError( + "Cannot infer the shape of the input automatically. " + "Please pass the shape of the tensor to `ConsistentDropoutModule` during construction " + "with the `input_shape` kwarg."
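# --- Illustration: mask reuse in ConsistentDropout ----------------------------
# A plain-tensor sketch of the idea implemented above: the mask drawn during
# the rollout is handed back at update time, so the gradient step sees the
# same thinned network that generated the behavior. The import path follows
# this file; the tensors are illustrative.
import torch
from torchrl.modules.models.exploration import ConsistentDropout

dropout = ConsistentDropout(p=0.5)
dropout.train()

x = torch.randn(4, 8)
out_rollout, mask = dropout(x)                  # train mode: mask drawn and returned
out_update, mask_again = dropout(x, mask=mask)  # update: the cached mask is reused
assert torch.equal(mask, mask_again)
assert torch.equal(out_rollout, out_update)     # identical outputs, hence lower update variance

dropout.eval()
assert torch.equal(dropout(x), x)               # eval mode is a no-op, mirroring nn.Dropout
# -----------------------------------------------------------------------------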
) + return TensorDictPrimer( + primers={self.in_keys[1]: Unbounded(dtype=dtype, shape=shape)}, + default_value=functools.partial( + self.consistent_dropout.make_mask, shape=shape + ), + ) diff --git a/torchrl/modules/models/models.py index 23c229c6524..3faaa396299 100644 --- a/torchrl/modules/models/models.py +++ b/torchrl/modules/models/models.py @@ -5,6 +5,7 @@ from __future__ import annotations import dataclasses +import warnings from copy import deepcopy from numbers import Number @@ -179,8 +180,15 @@ def __init__( if out_features is None: raise ValueError("out_features must be specified for MLP.") - default_num_cells = 32 if num_cells is None: + warnings.warn( + "The current behavior of MLP when not providing `num_cells` is that the number of cells is " + "set to [default_num_cells] * depth, where `depth=3` by default and `default_num_cells=32`. " + "From v0.7, this behavior will switch and `depth=0` will be used. " + "To silence this message, indicate what number of cells you desire.", + category=DeprecationWarning, + ) + default_num_cells = 32 if depth is None: num_cells = [default_num_cells] * 3 depth = 3 diff --git a/torchrl/modules/planners/cem.py index 6d9e6fb3b49..abc0e3d3f95 100644 --- a/torchrl/modules/planners/cem.py +++ b/torchrl/modules/planners/cem.py @@ -45,20 +45,20 @@ class CEMPlanner(MPCPlannerBase): Examples: >>> from tensordict import TensorDict - >>> from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec + >>> from torchrl.data import Composite, Unbounded >>> from torchrl.envs.model_based import ModelBasedEnvBase >>> from torchrl.modules import SafeModule >>> class MyMBEnv(ModelBasedEnvBase): ... def __init__(self, world_model, device="cpu", dtype=None, batch_size=None): ... super().__init__(world_model, device=device, dtype=dtype, batch_size=batch_size) - ... self.state_spec = CompositeSpec( - ... hidden_observation=UnboundedContinuousTensorSpec((4,)) + ... self.state_spec = Composite( + ... hidden_observation=Unbounded((4,)) ... ) - ... self.observation_spec = CompositeSpec( - ... hidden_observation=UnboundedContinuousTensorSpec((4,)) + ... self.observation_spec = Composite( + ... hidden_observation=Unbounded((4,)) ... ) - ... self.action_spec = UnboundedContinuousTensorSpec((1,)) - ... self.reward_spec = UnboundedContinuousTensorSpec((1,)) + ... self.action_spec = Unbounded((1,)) + ... self.reward_spec = Unbounded((1,)) ... ... def _reset(self, tensordict: TensorDict) -> TensorDict: ... tensordict = TensorDict(
hidden_observation=Unbounded((4,)) ... ) - ... self.observation_spec = CompositeSpec( - ... hidden_observation=UnboundedContinuousTensorSpec((4,)) + ... self.observation_spec = Composite( + ... hidden_observation=Unbounded((4,)) ... ) - ... self.action_spec = UnboundedContinuousTensorSpec((1,)) - ... self.reward_spec = UnboundedContinuousTensorSpec((1,)) + ... self.action_spec = Unbounded((1,)) + ... self.reward_spec = Unbounded((1,)) ... ... def _reset(self, tensordict: TensorDict) -> TensorDict: ... tensordict = TensorDict( diff --git a/torchrl/modules/tensordict_module/actors.py b/torchrl/modules/tensordict_module/actors.py index 81b7ec1e605..003c35cf0eb 100644 --- a/torchrl/modules/tensordict_module/actors.py +++ b/torchrl/modules/tensordict_module/actors.py @@ -22,7 +22,7 @@ from torch.distributions import Categorical from torchrl._utils import _replace_last -from torchrl.data.tensor_specs import CompositeSpec, TensorSpec +from torchrl.data.tensor_specs import Composite, TensorSpec from torchrl.data.utils import _process_action_space_spec from torchrl.modules.tensordict_module.common import DistributionalDQNnet, SafeModule from torchrl.modules.tensordict_module.probabilistic import ( @@ -37,8 +37,8 @@ class Actor(SafeModule): The Actor class comes with default values for the out_keys (``["action"]``) and if the spec is provided but not as a - :class:`~torchrl.data.CompositeSpec` object, it will be - automatically translated into ``spec = CompositeSpec(action=spec)``. + :class:`~torchrl.data.Composite` object, it will be + automatically translated into ``spec = Composite(action=spec)``. Args: module (nn.Module): a :class:`~torch.nn.Module` used to map the input to @@ -70,11 +70,11 @@ class Actor(SafeModule): Examples: >>> import torch >>> from tensordict import TensorDict - >>> from torchrl.data import UnboundedContinuousTensorSpec + >>> from torchrl.data import Unbounded >>> from torchrl.modules import Actor >>> torch.manual_seed(0) >>> td = TensorDict({"observation": torch.randn(3, 4)}, [3,]) - >>> action_spec = UnboundedContinuousTensorSpec(4) + >>> action_spec = Unbounded(4) >>> module = torch.nn.Linear(4, 4) >>> td_module = Actor( ... module=module, @@ -111,9 +111,9 @@ def __init__( if ( "action" in out_keys and spec is not None - and not isinstance(spec, CompositeSpec) + and not isinstance(spec, Composite) ): - spec = CompositeSpec(action=spec) + spec = Composite(action=spec) super().__init__( module, @@ -128,8 +128,8 @@ class ProbabilisticActor(SafeProbabilisticTensorDictSequential): """General class for probabilistic actors in RL. 
The Actor class comes with default values for the out_keys (["action"]) - and if the spec is provided but not as a CompositeSpec object, it will be - automatically translated into :obj:`spec = CompositeSpec(action=spec)` + and if the spec is provided but not as a Composite object, it will be + automatically translated into :obj:`spec = Composite(action=spec)` Args: module (nn.Module): a :class:`torch.nn.Module` used to map the input to @@ -205,10 +205,10 @@ class ProbabilisticActor(SafeProbabilisticTensorDictSequential): >>> import torch >>> from tensordict import TensorDict >>> from tensordict.nn import TensorDictModule - >>> from torchrl.data import BoundedTensorSpec + >>> from torchrl.data import Bounded >>> from torchrl.modules import ProbabilisticActor, NormalParamExtractor, TanhNormal >>> td = TensorDict({"observation": torch.randn(3, 4)}, [3,]) - >>> action_spec = BoundedTensorSpec(shape=torch.Size([4]), + >>> action_spec = Bounded(shape=torch.Size([4]), ... low=-1, high=1) >>> module = nn.Sequential(torch.nn.Linear(4, 8), NormalParamExtractor()) >>> tensordict_module = TensorDictModule(module, in_keys=["observation"], out_keys=["loc", "scale"]) @@ -382,12 +382,8 @@ def __init__( out_keys = list(distribution_map.keys()) else: out_keys = ["action"] - if ( - len(out_keys) == 1 - and spec is not None - and not isinstance(spec, CompositeSpec) - ): - spec = CompositeSpec({out_keys[0]: spec}) + if len(out_keys) == 1 and spec is not None and not isinstance(spec, Composite): + spec = Composite({out_keys[0]: spec}) super().__init__( module, @@ -424,7 +420,7 @@ class ValueOperator(TensorDictModule): >>> import torch >>> from tensordict import TensorDict >>> from torch import nn - >>> from torchrl.data import UnboundedContinuousTensorSpec + >>> from torchrl.data import Unbounded >>> from torchrl.modules import ValueOperator >>> td = TensorDict({"observation": torch.randn(3, 4), "action": torch.randn(3, 2)}, [3,]) >>> class CustomModule(nn.Module): @@ -577,22 +573,22 @@ def __init__( ) self.out_keys = out_keys action_key = out_keys[0] - if not isinstance(spec, CompositeSpec): - spec = CompositeSpec({action_key: spec}) + if not isinstance(spec, Composite): + spec = Composite({action_key: spec}) super().__init__() self.register_spec(safe=safe, spec=spec) register_spec = SafeModule.register_spec @property - def spec(self) -> CompositeSpec: + def spec(self) -> Composite: return self._spec @spec.setter - def spec(self, spec: CompositeSpec) -> None: - if not isinstance(spec, CompositeSpec): + def spec(self, spec: Composite) -> None: + if not isinstance(spec, Composite): raise RuntimeError( - f"Trying to set an object of type {type(spec)} as a tensorspec but expected a CompositeSpec instance." + f"Trying to set an object of type {type(spec)} as a tensorspec but expected a Composite instance." 
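# --- Illustration: leaf specs are promoted to Composite -----------------------
# A short sketch of the wrapping rule stated above, using the new spec names
# (module shapes are illustrative): handing Actor a leaf spec is equivalent to
# handing it a Composite keyed by "action".
import torch
from torchrl.data import Bounded, Composite
from torchrl.modules import Actor

module = torch.nn.Linear(4, 4)
actor_a = Actor(module=module, spec=Bounded(-1, 1, torch.Size([4])))
actor_b = Actor(module=module, spec=Composite(action=Bounded(-1, 1, torch.Size([4]))))
assert isinstance(actor_a.spec, Composite)  # the leaf spec was wrapped automatically
assert actor_a.spec["action"].shape == actor_b.spec["action"].shape
# -----------------------------------------------------------------------------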
) self._spec = spec @@ -891,13 +887,13 @@ class QValueHook: >>> import torch >>> from tensordict import TensorDict >>> from torch import nn - >>> from torchrl.data import OneHotDiscreteTensorSpec + >>> from torchrl.data import OneHot >>> from torchrl.modules.tensordict_module.actors import QValueHook, Actor >>> td = TensorDict({'observation': torch.randn(5, 4)}, [5]) >>> module = nn.Linear(4, 4) >>> hook = QValueHook("one_hot") >>> module.register_forward_hook(hook) - >>> action_spec = OneHotDiscreteTensorSpec(4) + >>> action_spec = OneHot(4) >>> qvalue_actor = Actor(module=module, spec=action_spec, out_keys=["action", "action_value"]) >>> td = qvalue_actor(td) >>> print(td) @@ -975,7 +971,7 @@ class DistributionalQValueHook(QValueHook): >>> import torch >>> from tensordict import TensorDict >>> from torch import nn - >>> from torchrl.data import OneHotDiscreteTensorSpec + >>> from torchrl.data import OneHot >>> from torchrl.modules.tensordict_module.actors import DistributionalQValueHook, Actor >>> td = TensorDict({'observation': torch.randn(5, 4)}, [5]) >>> nbins = 3 @@ -989,7 +985,7 @@ class DistributionalQValueHook(QValueHook): ... >>> module = CustomDistributionalQval() >>> params = TensorDict.from_module(module) - >>> action_spec = OneHotDiscreteTensorSpec(4) + >>> action_spec = OneHot(4) >>> hook = DistributionalQValueHook("one_hot", support = torch.arange(nbins)) >>> module.register_forward_hook(hook) >>> qvalue_actor = Actor(module=module, spec=action_spec, out_keys=["action", "action_value"]) @@ -1085,12 +1081,12 @@ class QValueActor(SafeSequential): >>> import torch >>> from tensordict import TensorDict >>> from torch import nn - >>> from torchrl.data import OneHotDiscreteTensorSpec + >>> from torchrl.data import OneHot >>> from torchrl.modules.tensordict_module.actors import QValueActor >>> td = TensorDict({'observation': torch.randn(5, 4)}, [5]) >>> # with a regular nn.Module >>> module = nn.Linear(4, 4) - >>> action_spec = OneHotDiscreteTensorSpec(4) + >>> action_spec = OneHot(4) >>> qvalue_actor = QValueActor(module=module, spec=action_spec) >>> td = qvalue_actor(td) >>> print(td) @@ -1106,7 +1102,7 @@ class QValueActor(SafeSequential): >>> # with a TensorDictModule >>> td = TensorDict({'obs': torch.randn(5, 4)}, [5]) >>> module = TensorDictModule(lambda x: x, in_keys=["obs"], out_keys=["action_value"]) - >>> action_spec = OneHotDiscreteTensorSpec(4) + >>> action_spec = OneHot(4) >>> qvalue_actor = QValueActor(module=module, spec=action_spec) >>> td = qvalue_actor(td) >>> print(td) @@ -1161,13 +1157,13 @@ def __init__( module, in_keys=in_keys, out_keys=[action_value_key] ) if spec is None: - spec = CompositeSpec() - if isinstance(spec, CompositeSpec): + spec = Composite() + if isinstance(spec, Composite): spec = spec.clone() if "action" not in spec.keys(): spec["action"] = None else: - spec = CompositeSpec(action=spec, shape=spec.shape[:-1]) + spec = Composite(action=spec, shape=spec.shape[:-1]) spec[action_value_key] = None spec["chosen_action_value"] = None qvalue = QValueModule( @@ -1237,7 +1233,7 @@ class DistributionalQValueActor(QValueActor): >>> from tensordict import TensorDict >>> from tensordict.nn import TensorDictModule, TensorDictSequential >>> from torch import nn - >>> from torchrl.data import OneHotDiscreteTensorSpec + >>> from torchrl.data import OneHot >>> from torchrl.modules import DistributionalQValueActor, MLP >>> td = TensorDict({'observation': torch.randn(5, 4)}, [5]) >>> nbins = 3 @@ -1247,7 +1243,7 @@ class DistributionalQValueActor(QValueActor): 
... TensorDictModule(module, ["observation"], ["action_value"]), ... TensorDictModule(lambda x: x.log_softmax(-2), ["action_value"], ["action_value"]), ... ) - >>> action_spec = OneHotDiscreteTensorSpec(4) + >>> action_spec = OneHot(4) >>> qvalue_actor = DistributionalQValueActor( ... module=module, ... spec=action_spec, @@ -1299,13 +1295,13 @@ def __init__( module, in_keys=in_keys, out_keys=[action_value_key] ) if spec is None: - spec = CompositeSpec() - if isinstance(spec, CompositeSpec): + spec = Composite() + if isinstance(spec, Composite): spec = spec.clone() if "action" not in spec.keys(): spec["action"] = None else: - spec = CompositeSpec(action=spec, shape=spec.shape[:-1]) + spec = Composite(action=spec, shape=spec.shape[:-1]) spec[action_value_key] = None qvalue = DistributionalQValueModule( @@ -1848,8 +1844,8 @@ def __init__( self.return_to_go_key = "return_to_go" self.inference_context = inference_context if spec is not None: - if not isinstance(spec, CompositeSpec) and len(self.out_keys) >= 1: - spec = CompositeSpec({self.action_key: spec}, shape=spec.shape[:-1]) + if not isinstance(spec, Composite) and len(self.out_keys) >= 1: + spec = Composite({self.action_key: spec}, shape=spec.shape[:-1]) self._spec = spec elif hasattr(self.td_module, "_spec"): self._spec = self.td_module._spec.clone() @@ -1860,7 +1856,7 @@ def __init__( if self.action_key not in self._spec.keys(): self._spec[self.action_key] = None else: - self._spec = CompositeSpec({key: None for key in policy.out_keys}) + self._spec = Composite({key: None for key in policy.out_keys}) self.checked = False @property @@ -1989,7 +1985,7 @@ class TanhModule(TensorDictModuleBase): Keyword Args: spec (TensorSpec, optional): if provided, the spec of the output. - If a CompositeSpec is provided, its key(s) must match the key(s) + If a Composite is provided, its key(s) must match the key(s) in out_keys. Otherwise, the key(s) of out_keys are assumed and the same spec is used for all outputs. low (float, np.ndarray or torch.Tensor): the lower bound of the space. @@ -2027,8 +2023,8 @@ class TanhModule(TensorDictModuleBase): >>> data['action'] tensor([-2.0000, 0.9991, 1.0000, -2.0000, -1.9991]) >>> # A spec can be provided - >>> from torchrl.data import BoundedTensorSpec - >>> spec = BoundedTensorSpec(low, high, shape=()) + >>> from torchrl.data import Bounded + >>> spec = Bounded(low, high, shape=()) >>> mod = TanhModule( ... in_keys=in_keys, ... low=low, @@ -2038,9 +2034,9 @@ class TanhModule(TensorDictModuleBase): ... ) >>> # One can also work with multiple keys >>> in_keys = ['a', 'b'] - >>> spec = CompositeSpec( - ... a=BoundedTensorSpec(-3, 0, shape=()), - ... b=BoundedTensorSpec(0, 3, shape=())) + >>> spec = Composite( + ... a=Bounded(-3, 0, shape=()), + ... b=Bounded(0, 3, shape=())) >>> mod = TanhModule( ... in_keys=in_keys, ... 
spec=spec, @@ -2077,13 +2073,13 @@ def __init__( ) self.out_keys = out_keys # action_spec can be a composite spec or not - if isinstance(spec, CompositeSpec): + if isinstance(spec, Composite): for out_key in self.out_keys: if out_key not in spec.keys(True, True): spec[out_key] = None else: # if one spec is present, we assume it is the same for all keys - spec = CompositeSpec( + spec = Composite( {out_key: spec for out_key in out_keys}, ) diff --git a/torchrl/modules/tensordict_module/common.py b/torchrl/modules/tensordict_module/common.py index 11cc363b461..4018589bfa1 100644 --- a/torchrl/modules/tensordict_module/common.py +++ b/torchrl/modules/tensordict_module/common.py @@ -21,7 +21,7 @@ from torch import nn from torch.nn import functional as F -from torchrl.data.tensor_specs import CompositeSpec, TensorSpec +from torchrl.data.tensor_specs import Composite, TensorSpec from torchrl.data.utils import DEVICE_TYPING @@ -59,12 +59,12 @@ def _check_all_str(list_of_str, first_level=True): def _forward_hook_safe_action(module, tensordict_in, tensordict_out): try: spec = module.spec - if len(module.out_keys) > 1 and not isinstance(spec, CompositeSpec): + if len(module.out_keys) > 1 and not isinstance(spec, Composite): raise RuntimeError( - "safe TensorDictModules with multiple out_keys require a CompositeSpec with matching keys. Got " + "safe TensorDictModules with multiple out_keys require a Composite with matching keys. Got " f"keys {module.out_keys}." ) - elif not isinstance(spec, CompositeSpec): + elif not isinstance(spec, Composite): out_key = module.out_keys[0] keys = [out_key] values = [spec] @@ -138,10 +138,10 @@ class SafeModule(TensorDictModule): Examples: >>> import torch >>> from tensordict import TensorDict - >>> from torchrl.data import UnboundedContinuousTensorSpec + >>> from torchrl.data import Unbounded >>> from torchrl.modules import TensorDictModule >>> td = TensorDict({"input": torch.randn(3, 4), "hidden": torch.randn(3, 8)}, [3,]) - >>> spec = UnboundedContinuousTensorSpec(8) + >>> spec = Unbounded(8) >>> module = torch.nn.GRUCell(4, 8) >>> td_fmodule = TensorDictModule( ... module=module, @@ -216,18 +216,18 @@ def register_spec(self, safe, spec): spec = spec.clone() if spec is not None and not isinstance(spec, TensorSpec): raise TypeError("spec must be a TensorSpec subclass") - elif spec is not None and not isinstance(spec, CompositeSpec): + elif spec is not None and not isinstance(spec, Composite): if len(self.out_keys) > 1: raise RuntimeError( f"got more than one out_key for the TensorDictModule: {self.out_keys},\nbut only one spec. " - "Consider using a CompositeSpec object or no spec at all." + "Consider using a Composite object or no spec at all." 
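# --- Illustration: what `safe=True` does with these specs ---------------------
# A hedged sketch of the forward hook registered above: when an output leaves
# the bounds of its spec, it is projected back inside. The oversized weights
# below are just a way to force out-of-bound actions.
import torch
from tensordict import TensorDict
from torchrl.data import Bounded
from torchrl.modules import Actor

module = torch.nn.Linear(4, 4)
with torch.no_grad():
    module.weight.fill_(10.0)  # guarantees raw outputs far outside [-1, 1]
actor = Actor(module=module, spec=Bounded(-1, 1, torch.Size([4])), safe=True)
td = actor(TensorDict({"observation": torch.randn(3, 4)}, [3]))
assert (td["action"] >= -1).all() and (td["action"] <= 1).all()  # projected by the hook
# -----------------------------------------------------------------------------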
) - spec = CompositeSpec({self.out_keys[0]: spec}) - elif spec is not None and isinstance(spec, CompositeSpec): + spec = Composite({self.out_keys[0]: spec}) + elif spec is not None and isinstance(spec, Composite): if "_" in spec.keys() and spec["_"] is not None: warnings.warn('got a spec with key "_": it will be ignored') elif spec is None: - spec = CompositeSpec() + spec = Composite() # unravel_key_list(self.out_keys) can be removed once 473 is merged in tensordict spec_keys = set(unravel_key_list(list(spec.keys(True, True)))) @@ -247,7 +247,7 @@ def register_spec(self, safe, spec): self.safe = safe if safe: if spec is None or ( - isinstance(spec, CompositeSpec) + isinstance(spec, Composite) and all(_spec is None for _spec in spec.values()) ): raise RuntimeError( @@ -257,14 +257,14 @@ def register_spec(self, safe, spec): self.register_forward_hook(_forward_hook_safe_action) @property - def spec(self) -> CompositeSpec: + def spec(self) -> Composite: return self._spec @spec.setter - def spec(self, spec: CompositeSpec) -> None: - if not isinstance(spec, CompositeSpec): + def spec(self, spec: Composite) -> None: + if not isinstance(spec, Composite): raise RuntimeError( - f"Trying to set an object of type {type(spec)} as a tensorspec but expected a CompositeSpec instance." + f"Trying to set an object of type {type(spec)} as a tensorspec but expected a Composite instance." ) self._spec = spec @@ -350,7 +350,7 @@ def is_tensordict_compatible(module: Union[TensorDictModule, nn.Module]): # if in_keys or out_keys were defined but module is not a TensorDictModule or # accepts multiple arguments then it's likely the user is trying to do something - # that will have undetermined behaviour, we raise an error + # that will have undetermined behavior, we raise an error raise TypeError( "Received a module that defines in_keys or out_keys and also expects multiple " "arguments to module.forward. If the module is compatible with TensorDict, it " @@ -403,7 +403,7 @@ def ensure_tensordict_compatible( "env.observation_spec. If you want TorchRL to automatically " "wrap your module with a TensorDictModule then the arguments " "to module must correspond one-to-one with entries in " - "in_keys. For more complex behaviour and more control you can " + "in_keys. For more complex behavior and more control you can " "consider writing your own TensorDictModule." 
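# --- Illustration: the one-to-one wrapping rule -------------------------------
# A sketch of the rule the error message above describes: a plain nn.Module is
# not TensorDict-compatible on its own, but if its forward parameters map
# one-to-one onto the intended in_keys it can simply be wrapped in a
# TensorDictModule. The Net class below is hypothetical.
import torch
from tensordict.nn import TensorDictModule
from torch import nn
from torchrl.modules.tensordict_module.common import is_tensordict_compatible

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 1)

    def forward(self, observation):  # parameter name matches the in_key below
        return self.linear(observation)

assert not is_tensordict_compatible(Net())
wrapped = TensorDictModule(Net(), in_keys=["observation"], out_keys=["action"])
assert is_tensordict_compatible(wrapped)
# -----------------------------------------------------------------------------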
) diff --git a/torchrl/modules/tensordict_module/exploration.py b/torchrl/modules/tensordict_module/exploration.py index 5a41f11bf76..7337d1c94dd 100644 --- a/torchrl/modules/tensordict_module/exploration.py +++ b/torchrl/modules/tensordict_module/exploration.py @@ -16,7 +16,7 @@ ) from tensordict.utils import expand_as_right, expand_right, NestedKey -from torchrl.data.tensor_specs import CompositeSpec, TensorSpec +from torchrl.data.tensor_specs import Composite, TensorSpec from torchrl.envs.utils import exploration_type, ExplorationType from torchrl.modules.tensordict_module.common import _forward_hook_safe_action @@ -64,9 +64,9 @@ class EGreedyModule(TensorDictModuleBase): >>> from tensordict import TensorDict >>> from tensordict.nn import TensorDictSequential >>> from torchrl.modules import EGreedyModule, Actor - >>> from torchrl.data import BoundedTensorSpec + >>> from torchrl.data import Bounded >>> torch.manual_seed(0) - >>> spec = BoundedTensorSpec(-1, 1, torch.Size([4])) + >>> spec = Bounded(-1, 1, torch.Size([4])) >>> module = torch.nn.Linear(4, 4, bias=False) >>> policy = Actor(spec=spec, module=module) >>> explorative_policy = TensorDictSequential(policy, EGreedyModule(eps_init=0.2)) @@ -115,8 +115,8 @@ def __init__( self.register_buffer("eps", torch.as_tensor([eps_init], dtype=torch.float32)) if spec is not None: - if not isinstance(spec, CompositeSpec) and len(self.out_keys) >= 1: - spec = CompositeSpec({action_key: spec}, shape=spec.shape[:-1]) + if not isinstance(spec, Composite) and len(self.out_keys) >= 1: + spec = Composite({action_key: spec}, shape=spec.shape[:-1]) self._spec = spec @property @@ -155,7 +155,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: cond = expand_as_right(cond, out) spec = self.spec if spec is not None: - if isinstance(spec, CompositeSpec): + if isinstance(spec, Composite): spec = spec[self.action_key] if spec.shape != out.shape: # In batched envs if the spec is passed unbatched, the rand() will not @@ -214,9 +214,9 @@ class EGreedyWrapper(TensorDictModuleWrapper): >>> import torch >>> from tensordict import TensorDict >>> from torchrl.modules import EGreedyWrapper, Actor - >>> from torchrl.data import BoundedTensorSpec + >>> from torchrl.data import Bounded >>> torch.manual_seed(0) - >>> spec = BoundedTensorSpec(-1, 1, torch.Size([4])) + >>> spec = Bounded(-1, 1, torch.Size([4])) >>> module = torch.nn.Linear(4, 4, bias=False) >>> policy = Actor(spec=spec, module=module) >>> explorative_policy = EGreedyWrapper(policy, eps_init=0.2) @@ -267,7 +267,7 @@ class AdditiveGaussianWrapper(TensorDictModuleWrapper): mean (float, optional): mean of each output element’s normal distribution. std (float, optional): standard deviation of each output element’s normal distribution. action_key (NestedKey, optional): if the policy module has more than one output key, - its output spec will be of type CompositeSpec. One needs to know where to + its output spec will be of type Composite. One needs to know where to find the action spec. Default is "action". spec (TensorSpec, optional): if provided, the sampled action will be @@ -323,8 +323,8 @@ def __init__( f"The action key {action_key} was not found in the td_module out_keys {self.td_module.out_keys}." 
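# --- Illustration: module-style additive Gaussian exploration -----------------
# A hedged sketch of the module variant documented below: unlike the wrapper,
# it is appended to the policy with TensorDictSequential, and its `spec` is
# mandatory per the constructor change in this patch. Shapes are illustrative.
import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictSequential
from torchrl.data import Bounded
from torchrl.modules import Actor, AdditiveGaussianModule

spec = Bounded(-1, 1, torch.Size([4]))
policy = Actor(spec=spec, module=torch.nn.Linear(4, 4, bias=False))
explorative_policy = TensorDictSequential(policy, AdditiveGaussianModule(spec=spec))
td = explorative_policy(TensorDict({"observation": torch.randn(3, 4)}, [3]))
print(td["action"])  # the action with annealed Gaussian noise added
# -----------------------------------------------------------------------------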
) if spec is not None: - if not isinstance(spec, CompositeSpec) and len(self.out_keys) >= 1: - spec = CompositeSpec({action_key: spec}, shape=spec.shape[:-1]) + if not isinstance(spec, Composite) and len(self.out_keys) >= 1: + spec = Composite({action_key: spec}, shape=spec.shape[:-1]) self._spec = spec elif hasattr(self.td_module, "_spec"): self._spec = self.td_module._spec.clone() @@ -335,7 +335,7 @@ def __init__( if action_key not in self._spec.keys(True, True): self._spec[action_key] = None else: - self._spec = CompositeSpec({key: None for key in policy.out_keys}) + self._spec = Composite({key: None for key in policy.out_keys}) self.safe = safe if self.safe: @@ -410,7 +410,7 @@ class AdditiveGaussianModule(TensorDictModuleBase): Keyword Args: action_key (NestedKey, optional): if the policy module has more than one output key, - its output spec will be of type CompositeSpec. One needs to know where to + its output spec will be of type Composite. One needs to know where to find the action spec. default: "action" @@ -453,8 +453,8 @@ def __init__( self.register_buffer("sigma", torch.tensor([sigma_init], dtype=torch.float32)) if spec is not None: - if not isinstance(spec, CompositeSpec) and len(self.out_keys) >= 1: - spec = CompositeSpec({action_key: spec}, shape=spec.shape[:-1]) + if not isinstance(spec, Composite) and len(self.out_keys) >= 1: + spec = Composite({action_key: spec}, shape=spec.shape[:-1]) else: raise RuntimeError("spec cannot be None.") self._spec = spec @@ -570,10 +570,10 @@ class OrnsteinUhlenbeckProcessWrapper(TensorDictModuleWrapper): Examples: >>> import torch >>> from tensordict import TensorDict - >>> from torchrl.data import BoundedTensorSpec + >>> from torchrl.data import Bounded >>> from torchrl.modules import OrnsteinUhlenbeckProcessWrapper, Actor >>> torch.manual_seed(0) - >>> spec = BoundedTensorSpec(-1, 1, torch.Size([4])) + >>> spec = Bounded(-1, 1, torch.Size([4])) >>> module = torch.nn.Linear(4, 4, bias=False) >>> policy = Actor(module=module, spec=spec) >>> explorative_policy = OrnsteinUhlenbeckProcessWrapper(policy) @@ -647,8 +647,8 @@ def __init__( steps_key = self.ou.steps_key if spec is not None: - if not isinstance(spec, CompositeSpec) and len(self.out_keys) >= 1: - spec = CompositeSpec({action_key: spec}, shape=spec.shape[:-1]) + if not isinstance(spec, Composite) and len(self.out_keys) >= 1: + spec = Composite({action_key: spec}, shape=spec.shape[:-1]) self._spec = spec elif hasattr(self.td_module, "_spec"): self._spec = self.td_module._spec.clone() @@ -659,7 +659,7 @@ def __init__( if action_key not in self._spec.keys(True, True): self._spec[action_key] = None else: - self._spec = CompositeSpec({key: None for key in policy.out_keys}) + self._spec = Composite({key: None for key in policy.out_keys}) ou_specs = { noise_key: None, steps_key: None, @@ -707,7 +707,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: f"The tensordict passed to {self.__class__.__name__} appears to be " f"missing the '{self.is_init_key}' entry. This entry is used to " f"reset the noise at the beginning of a trajectory, without it " - f"the behaviour of this exploration method is undefined. " + f"the behavior of this exploration method is undefined. " f"This is allowed for BC compatibility purposes but it will be deprecated soon! " f"To create a '{self.is_init_key}' entry, simply append an torchrl.envs.InitTracker " f"transform to your environment with `env = TransformedEnv(env, InitTracker())`." 
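# --- Illustration: OU exploration needs an "is_init" entry --------------------
# A sketch of the setup the warning above asks for: InitTracker writes the
# "is_init" flag that lets the OU process reset its noise at trajectory
# starts. The env name is illustrative.
import torch
from tensordict.nn import TensorDictSequential
from torchrl.data import Bounded
from torchrl.envs import GymEnv, InitTracker, TransformedEnv
from torchrl.modules import Actor, OrnsteinUhlenbeckProcessModule

env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker())
spec = Bounded(-1, 1, torch.Size([1]))
policy = Actor(spec=spec, module=torch.nn.Linear(3, 1))
explorative_policy = TensorDictSequential(
    policy, OrnsteinUhlenbeckProcessModule(spec=spec)
)
rollout = env.rollout(5, explorative_policy)  # no "is_init" warning is emitted
# -----------------------------------------------------------------------------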
@@ -783,10 +783,10 @@ class OrnsteinUhlenbeckProcessModule(TensorDictModuleBase): >>> import torch >>> from tensordict import TensorDict >>> from tensordict.nn import TensorDictSequential - >>> from torchrl.data import BoundedTensorSpec + >>> from torchrl.data import Bounded >>> from torchrl.modules import OrnsteinUhlenbeckProcessModule, Actor >>> torch.manual_seed(0) - >>> spec = BoundedTensorSpec(-1, 1, torch.Size([4])) + >>> spec = Bounded(-1, 1, torch.Size([4])) >>> module = torch.nn.Linear(4, 4, bias=False) >>> policy = Actor(module=module, spec=spec) >>> ou = OrnsteinUhlenbeckProcessModule(spec=spec) @@ -851,8 +851,8 @@ def __init__( steps_key = self.ou.steps_key if spec is not None: - if not isinstance(spec, CompositeSpec) and len(self.out_keys) >= 1: - spec = CompositeSpec({action_key: spec}, shape=spec.shape[:-1]) + if not isinstance(spec, Composite) and len(self.out_keys) >= 1: + spec = Composite({action_key: spec}, shape=spec.shape[:-1]) self._spec = spec else: raise RuntimeError("spec cannot be None.") @@ -900,7 +900,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: f"The tensordict passed to {self.__class__.__name__} appears to be " f"missing the '{self.is_init_key}' entry. This entry is used to " f"reset the noise at the beginning of a trajectory, without it " - f"the behaviour of this exploration method is undefined. " + f"the behavior of this exploration method is undefined. " f"This is allowed for BC compatibility purposes but it will be deprecated soon! " f"To create a '{self.is_init_key}' entry, simply append an torchrl.envs.InitTracker " f"transform to your environment with `env = TransformedEnv(env, InitTracker())`." diff --git a/torchrl/modules/tensordict_module/probabilistic.py b/torchrl/modules/tensordict_module/probabilistic.py index 725323e1a28..483d9b90eea 100644 --- a/torchrl/modules/tensordict_module/probabilistic.py +++ b/torchrl/modules/tensordict_module/probabilistic.py @@ -15,7 +15,7 @@ TensorDictModule, ) from tensordict.utils import NestedKey -from torchrl.data.tensor_specs import CompositeSpec, TensorSpec +from torchrl.data.tensor_specs import Composite, TensorSpec from torchrl.modules.distributions import Delta from torchrl.modules.tensordict_module.common import _forward_hook_safe_action from torchrl.modules.tensordict_module.sequence import SafeSequential @@ -104,7 +104,6 @@ def __init__( out_keys: Optional[Union[NestedKey, List[NestedKey]]] = None, spec: Optional[TensorSpec] = None, safe: bool = False, - default_interaction_mode: str = None, default_interaction_type: str = InteractionType.DETERMINISTIC, distribution_class: Type = Delta, distribution_kwargs: Optional[dict] = None, @@ -117,7 +116,6 @@ def __init__( in_keys=in_keys, out_keys=out_keys, default_interaction_type=default_interaction_type, - default_interaction_mode=default_interaction_mode, distribution_class=distribution_class, distribution_kwargs=distribution_kwargs, return_log_prob=return_log_prob, @@ -129,18 +127,18 @@ def __init__( spec = spec.clone() if spec is not None and not isinstance(spec, TensorSpec): raise TypeError("spec must be a TensorSpec subclass") - elif spec is not None and not isinstance(spec, CompositeSpec): + elif spec is not None and not isinstance(spec, Composite): if len(self.out_keys) > 1: raise RuntimeError( f"got more than one out_key for the SafeModule: {self.out_keys},\nbut only one spec. " - "Consider using a CompositeSpec object or no spec at all." + "Consider using a Composite object or no spec at all." 
) - spec = CompositeSpec({self.out_keys[0]: spec}) - elif spec is not None and isinstance(spec, CompositeSpec): + spec = Composite({self.out_keys[0]: spec}) + elif spec is not None and isinstance(spec, Composite): if "_" in spec.keys(): warnings.warn('got a spec with key "_": it will be ignored') elif spec is None: - spec = CompositeSpec() + spec = Composite() spec_keys = set(unravel_key_list(list(spec.keys(True, True)))) out_keys = set(unravel_key_list(self.out_keys)) if spec_keys != out_keys: @@ -159,7 +157,7 @@ def __init__( self.safe = safe if safe: if spec is None or ( - isinstance(spec, CompositeSpec) + isinstance(spec, Composite) and all(_spec is None for _spec in spec.values()) ): raise RuntimeError( @@ -169,14 +167,14 @@ def __init__( self.register_forward_hook(_forward_hook_safe_action) @property - def spec(self) -> CompositeSpec: + def spec(self) -> Composite: return self._spec @spec.setter - def spec(self, spec: CompositeSpec) -> None: - if not isinstance(spec, CompositeSpec): + def spec(self, spec: Composite) -> None: + if not isinstance(spec, Composite): raise RuntimeError( - f"Trying to set an object of type {type(spec)} as a tensorspec but expected a CompositeSpec instance." + f"Trying to set an object of type {type(spec)} as a tensorspec but expected a Composite instance." ) self._spec = spec diff --git a/torchrl/modules/tensordict_module/rnn.py b/torchrl/modules/tensordict_module/rnn.py index 878fb13ebb8..f538f8e95c5 100644 --- a/torchrl/modules/tensordict_module/rnn.py +++ b/torchrl/modules/tensordict_module/rnn.py @@ -2,6 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations + from typing import Optional, Tuple import torch @@ -16,7 +18,7 @@ from torch import nn, Tensor from torch.nn.modules.rnn import RNNCellBase -from torchrl.data.tensor_specs import UnboundedContinuousTensorSpec +from torchrl.data.tensor_specs import Unbounded from torchrl.objectives.value.functional import ( _inv_pad_sequence, _split_and_pad_sequence, @@ -387,7 +389,7 @@ class LSTMModule(ModuleBase): .. note:: This module relies on specific ``recurrent_state`` keys being present in the input TensorDicts. To generate a :class:`~torchrl.envs.transforms.TensorDictPrimer` transform that will automatically add hidden states to the environment TensorDicts, use the method :func:`~torchrl.modules.rnn.LSTMModule.make_tensordict_primer`. - If this class is a submodule in a larger module, the method :func:`~torchrl.models.utils.get_primers_from_module` can be called + If this class is a submodule in a larger module, the method :func:`~torchrl.modules.utils.get_primers_from_module` can be called on the parent module to automatically generate the primer transforms required for all submodules, including this one. @@ -529,11 +531,14 @@ def make_tensordict_primer(self): inputs and outputs (recurrent states) during rollout execution. That way, the data can be shared across processes and dealt with properly. - Not including a ``TensorDictPrimer`` in the environment may result in poorly defined behaviours, for instance + Not including a ``TensorDictPrimer`` in the environment may result in poorly defined behaviors, for instance in parallel settings where a step involves copying the new recurrent state from ``"next"`` to the root tensordict, which the meth:`~torchrl.EnvBase.step_mdp` method will not be able to do as the recurrent states are not registered within the environment specs. 
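For illustration, a hand-rolled primer equivalent to the one built below, using the renamed ``Unbounded`` spec (key names and LSTM dimensions are placeholders; ``make_tensordict_primer`` derives the real ones from the wrapped LSTM):

>>> from torchrl.data import Unbounded
>>> from torchrl.envs.transforms import TensorDictPrimer
>>> num_layers, hidden_size = 1, 64
>>> primer = TensorDictPrimer({
...     "recurrent_state_h": Unbounded(shape=(num_layers, hidden_size)),
...     "recurrent_state_c": Unbounded(shape=(num_layers, hidden_size)),
... })
>>> # env = TransformedEnv(env, primer)  # append to the environment as above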
+ See :func:`torchrl.modules.utils.get_primers_from_module` for a method to generate all primers for a given + module. + Examples: >>> from torchrl.collectors import SyncDataCollector >>> from torchrl.envs import TransformedEnv, InitTracker @@ -581,12 +586,8 @@ def make_tuple(key): ) return TensorDictPrimer( { - in_key1: UnboundedContinuousTensorSpec( - shape=(self.lstm.num_layers, self.lstm.hidden_size) - ), - in_key2: UnboundedContinuousTensorSpec( - shape=(self.lstm.num_layers, self.lstm.hidden_size) - ), + in_key1: Unbounded(shape=(self.lstm.num_layers, self.lstm.hidden_size)), + in_key2: Unbounded(shape=(self.lstm.num_layers, self.lstm.hidden_size)), } ) @@ -609,7 +610,7 @@ def temporal_mode(self): def set_recurrent_mode(self, mode: bool = True): """Returns a new copy of the module that shares the same lstm model but with a different ``recurrent_mode`` attribute (if it differs). - A copy is created such that the module can be used with divergent behaviour + A copy is created such that the module can be used with divergent behavior in various parts of the code (inference vs training): Examples: @@ -623,7 +624,7 @@ def set_recurrent_mode(self, mode: bool = True): >>> lstm = nn.LSTM(input_size=env.observation_spec["observation"].shape[-1], hidden_size=64, batch_first=True) >>> lstm_module = LSTMModule(lstm=lstm, in_keys=["observation", "hidden0", "hidden1"], out_keys=["intermediate", ("next", "hidden0"), ("next", "hidden1")]) >>> mlp = MLP(num_cells=[64], out_features=1) - >>> # building two policies with different behaviours: + >>> # building two policies with different behaviors: >>> policy_inference = Seq(lstm_module, Mod(mlp, in_keys=["intermediate"], out_keys=["action"])) >>> policy_training = Seq(lstm_module.set_recurrent_mode(True), Mod(mlp, in_keys=["intermediate"], out_keys=["action"])) >>> traj_td = env.rollout(3) # some random temporal data @@ -665,7 +666,7 @@ def forward(self, tensordict: TensorDictBase): else: tensordict_shaped = tensordict.reshape(-1).unsqueeze(-1) - is_init = tensordict_shaped.get("is_init").squeeze(-1) + is_init = tensordict_shaped["is_init"].squeeze(-1) splits = None if self.recurrent_mode and is_init[..., 1:].any(): # if we have consecutive trajectories, things get a little more complicated @@ -679,7 +680,7 @@ def forward(self, tensordict: TensorDictBase): tensordict_shaped = _split_and_pad_sequence( tensordict_shaped.select(*self.in_keys, strict=False), splits ) - is_init = tensordict_shaped.get("is_init").squeeze(-1) + is_init = tensordict_shaped["is_init"].squeeze(-1) value, hidden0, hidden1 = ( tensordict_shaped.get(key, default) @@ -691,7 +692,7 @@ def forward(self, tensordict: TensorDictBase): # packed sequences do not help to get the accurate last hidden values # if splits is not None: # value = torch.nn.utils.rnn.pack_padded_sequence(value, splits, batch_first=True) - if is_init.any() and hidden0 is not None: + if hidden0 is not None: is_init_expand = expand_as_right(is_init, hidden0) hidden0 = torch.where(is_init_expand, 0, hidden0) hidden1 = torch.where(is_init_expand, 0, hidden1) @@ -1112,7 +1113,7 @@ class GRUModule(ModuleBase): .. note:: This module relies on specific ``recurrent_state`` keys being present in the input TensorDicts. To generate a :class:`~torchrl.envs.transforms.TensorDictPrimer` transform that will automatically add hidden states to the environment TensorDicts, use the method :func:`~torchrl.modules.rnn.GRUModule.make_tensordict_primer`. 
- If this class is a submodule in a larger module, the method :func:`~torchrl.models.utils.get_primers_from_module` can be called + If this class is a submodule in a larger module, the method :func:`~torchrl.modules.utils.get_primers_from_module` can be called on the parent module to automatically generate the primer transforms required for all submodules, including this one. Examples: @@ -1279,11 +1280,14 @@ def make_tensordict_primer(self): inputs and outputs (recurrent states) during rollout execution. That way, the data can be shared across processes and dealt with properly. - Not including a ``TensorDictPrimer`` in the environment may result in poorly defined behaviours, for instance + Not including a ``TensorDictPrimer`` in the environment may result in poorly defined behaviors, for instance in parallel settings where a step involves copying the new recurrent state from ``"next"`` to the root tensordict, which the meth:`~torchrl.EnvBase.step_mdp` method will not be able to do as the recurrent states are not registered within the environment specs. + See :func:`torchrl.modules.utils.get_primers_from_module` for a method to generate all primers for a given + module. + Examples: >>> from torchrl.collectors import SyncDataCollector >>> from torchrl.envs import TransformedEnv, InitTracker @@ -1329,9 +1333,7 @@ def make_tuple(key): ) return TensorDictPrimer( { - in_key1: UnboundedContinuousTensorSpec( - shape=(self.gru.num_layers, self.gru.hidden_size) - ), + in_key1: Unbounded(shape=(self.gru.num_layers, self.gru.hidden_size)), } ) @@ -1354,7 +1356,7 @@ def temporal_mode(self): def set_recurrent_mode(self, mode: bool = True): """Returns a new copy of the module that shares the same gru model but with a different ``recurrent_mode`` attribute (if it differs). 
- A copy is created such that the module can be used with divergent behaviour + A copy is created such that the module can be used with divergent behavior in various parts of the code (inference vs training): Examples: @@ -1367,7 +1369,7 @@ def set_recurrent_mode(self, mode: bool = True): >>> gru = nn.GRU(input_size=env.observation_spec["observation"].shape[-1], hidden_size=64, batch_first=True) >>> gru_module = GRUModule(gru=gru, in_keys=["observation", "hidden"], out_keys=["intermediate", ("next", "hidden")]) >>> mlp = MLP(num_cells=[64], out_features=1) - >>> # building two policies with different behaviours: + >>> # building two policies with different behaviors: >>> policy_inference = Seq(gru_module, Mod(mlp, in_keys=["intermediate"], out_keys=["action"])) >>> policy_training = Seq(gru_module.set_recurrent_mode(True), Mod(mlp, in_keys=["intermediate"], out_keys=["action"])) >>> traj_td = env.rollout(3) # some random temporal data @@ -1410,7 +1412,7 @@ def forward(self, tensordict: TensorDictBase): else: tensordict_shaped = tensordict.reshape(-1).unsqueeze(-1) - is_init = tensordict_shaped.get("is_init").squeeze(-1) + is_init = tensordict_shaped["is_init"].squeeze(-1) splits = None if self.recurrent_mode and is_init[..., 1:].any(): # if we have consecutive trajectories, things get a little more complicated @@ -1424,7 +1426,7 @@ def forward(self, tensordict: TensorDictBase): tensordict_shaped = _split_and_pad_sequence( tensordict_shaped.select(*self.in_keys, strict=False), splits ) - is_init = tensordict_shaped.get("is_init").squeeze(-1) + is_init = tensordict_shaped["is_init"].squeeze(-1) value, hidden = ( tensordict_shaped.get(key, default) diff --git a/torchrl/modules/tensordict_module/sequence.py b/torchrl/modules/tensordict_module/sequence.py index 41ddb55fb35..938843e624f 100644 --- a/torchrl/modules/tensordict_module/sequence.py +++ b/torchrl/modules/tensordict_module/sequence.py @@ -8,7 +8,7 @@ from tensordict.nn import TensorDictModule, TensorDictSequential from torch import nn -from torchrl.data.tensor_specs import CompositeSpec +from torchrl.data.tensor_specs import Composite from torchrl.modules.tensordict_module.common import SafeModule @@ -33,11 +33,11 @@ class SafeSequential(TensorDictSequential, SafeModule): Examples: >>> import torch >>> from tensordict import TensorDict - >>> from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec + >>> from torchrl.data import Composite, Unbounded >>> from torchrl.modules import TanhNormal, SafeSequential, TensorDictModule, NormalParamExtractor >>> from torchrl.modules.tensordict_module import SafeProbabilisticModule >>> td = TensorDict({"input": torch.randn(3, 4)}, [3,]) - >>> spec1 = CompositeSpec(hidden=UnboundedContinuousTensorSpec(4), loc=None, scale=None) + >>> spec1 = Composite(hidden=Unbounded(4), loc=None, scale=None) >>> net1 = nn.Sequential(torch.nn.Linear(4, 8), NormalParamExtractor()) >>> module1 = TensorDictModule(net1, in_keys=["input"], out_keys=["loc", "scale"]) >>> td_module1 = SafeProbabilisticModule( @@ -48,7 +48,7 @@ class SafeSequential(TensorDictSequential, SafeModule): ... distribution_class=TanhNormal, ... return_log_prob=True, ... ) - >>> spec2 = UnboundedContinuousTensorSpec(8) + >>> spec2 = Unbounded(8) >>> module2 = torch.nn.Linear(4, 8) >>> td_module2 = TensorDictModule( ... 
module=module2, @@ -74,12 +74,12 @@ class SafeSequential(TensorDictSequential, SafeModule): is_shared=False) >>> # The module spec aggregates all the input specs: >>> print(td_module.spec) - CompositeSpec( - hidden: UnboundedContinuousTensorSpec( + Composite( + hidden: UnboundedContinuous( shape=torch.Size([4]), space=None, device=cpu, dtype=torch.float32, domain=continuous), loc: None, scale: None, - output: UnboundedContinuousTensorSpec( + output: UnboundedContinuous( shape=torch.Size([8]), space=None, device=cpu, dtype=torch.float32, domain=continuous)) In the vmap case: @@ -112,12 +112,12 @@ def __init__( in_keys, out_keys = self._compute_in_and_out_keys(modules) - spec = CompositeSpec() + spec = Composite() for module in modules: try: spec.update(module.spec) except AttributeError: - spec.update(CompositeSpec({key: None for key in module.out_keys})) + spec.update(Composite({key: None for key in module.out_keys})) super(TensorDictSequential, self).__init__( spec=spec, diff --git a/torchrl/modules/utils/utils.py b/torchrl/modules/utils/utils.py index 0f3088a8943..9a8914aab89 100644 --- a/torchrl/modules/utils/utils.py +++ b/torchrl/modules/utils/utils.py @@ -46,8 +46,8 @@ def get_primers_from_module(module): >>> primers = get_primers_from_module(model) >>> print(primers) - TensorDictPrimer(primers=CompositeSpec( - recurrent_state: UnboundedContinuousTensorSpec( + TensorDictPrimer(primers=Composite( + recurrent_state: UnboundedContinuous( shape=torch.Size([1, 10]), space=None, device=cpu, diff --git a/torchrl/objectives/__init__.py b/torchrl/objectives/__init__.py index aa13a88c7e9..1ea9ebb5998 100644 --- a/torchrl/objectives/__init__.py +++ b/torchrl/objectives/__init__.py @@ -11,6 +11,7 @@ from .decision_transformer import DTLoss, OnlineDTLoss from .dqn import DistributionalDQNLoss, DQNLoss from .dreamer import DreamerActorLoss, DreamerModelLoss, DreamerValueLoss +from .gail import GAILLoss from .iql import DiscreteIQLLoss, IQLLoss from .multiagent import QMixerLoss from .ppo import ClipPPOLoss, KLPENPPOLoss, PPOLoss @@ -29,5 +30,3 @@ SoftUpdate, ValueEstimators, ) - -# from .value import bellman_max, c_val, dv_val, vtrace, GAE, TDLambdaEstimate, TDEstimate diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py index a236b80d56c..c823788b4c2 100644 --- a/torchrl/objectives/a2c.py +++ b/torchrl/objectives/a2c.py @@ -10,7 +10,12 @@ from typing import Tuple import torch -from tensordict import TensorDict, TensorDictBase, TensorDictParams +from tensordict import ( + is_tensor_collection, + TensorDict, + TensorDictBase, + TensorDictParams, +) from tensordict.nn import dispatch, ProbabilisticTensorDictSequential, TensorDictModule from tensordict.utils import NestedKey from torch import distributions as d @@ -56,8 +61,9 @@ class A2CLoss(LossModule): ``samples_mc_entropy`` will control how many samples will be used to compute this estimate. Defaults to ``1``. - entropy_coef (float): the weight of the entropy loss. - critic_coef (float): the weight of the critic loss. + entropy_coef (float): the weight of the entropy loss. Defaults to `0.01``. + critic_coef (float): the weight of the critic loss. Defaults to ``1.0``. If ``None``, the critic + loss won't be included and the in-keys will miss the critic inputs. loss_critic_type (str): loss function for the value discrepancy. Can be one of "l1", "l2" or "smooth_l1". Defaults to ``"smooth_l1"``. 
separate_losses (bool, optional): if ``True``, shared parameters between @@ -96,14 +102,14 @@ class A2CLoss(LossModule): Examples: >>> import torch >>> from torch import nn - >>> from torchrl.data import BoundedTensorSpec + >>> from torchrl.data import Bounded >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.a2c import A2CLoss >>> from tensordict import TensorDict >>> n_act, n_obs = 4, 3 - >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) + >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,)) >>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()) >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) >>> actor = ProbabilisticActor( @@ -147,14 +153,14 @@ class A2CLoss(LossModule): Examples: >>> import torch >>> from torch import nn - >>> from torchrl.data import BoundedTensorSpec + >>> from torchrl.data import Bounded >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.a2c import A2CLoss >>> _ = torch.manual_seed(42) >>> n_act, n_obs = 4, 3 - >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) + >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,)) >>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()) >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) >>> actor = ProbabilisticActor( @@ -190,6 +196,13 @@ class A2CLoss(LossModule): ... next_reward = torch.randn(*batch, 1), ... next_observation = torch.randn(*batch, n_obs)) >>> loss_obj.backward() + + .. note:: + There is an exception regarding compatibility with non-tensordict-based modules. + If the actor network is probabilistic and uses a :class:`~tensordict.nn.distributions.CompositeDistribution`, + this class must be used with tensordicts and cannot function as a tensordict-independent module. + This is because composite action spaces inherently rely on the structured representation of data provided by + tensordicts to handle their actions. 
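A quick sketch of the ``critic_coef=None`` behavior documented above, reusing the ``actor`` and ``value`` networks built in the examples (the variable names are an assumption):

>>> loss_fn = A2CLoss(actor, value, critic_coef=None)
>>> "loss_critic" in loss_fn.out_keys  # the critic term is dropped from the outputs
False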
""" @dataclass @@ -311,7 +324,13 @@ def __init__( self.register_buffer( "entropy_coef", torch.as_tensor(entropy_coef, device=device) ) - self.register_buffer("critic_coef", torch.as_tensor(critic_coef, device=device)) + if critic_coef is not None: + self.register_buffer( + "critic_coef", torch.as_tensor(critic_coef, device=device) + ) + else: + self.critic_coef = None + if gamma is not None: raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR) self.loss_critic_type = loss_critic_type @@ -344,7 +363,7 @@ def in_keys(self): *self.actor_network.in_keys, *[("next", key) for key in self.actor_network.in_keys], ] - if self.critic_coef: + if self.critic_coef is not None: keys.extend(self.critic_network.in_keys) return list(set(keys)) @@ -352,7 +371,7 @@ def in_keys(self): def out_keys(self): if self._out_keys is None: outs = ["loss_objective"] - if self.critic_coef: + if self.critic_coef is not None: outs.append("loss_critic") if self.entropy_bonus: outs.append("entropy") @@ -383,7 +402,10 @@ def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor: entropy = dist.entropy() except NotImplementedError: x = dist.rsample((self.samples_mc_entropy,)) - entropy = -dist.log_prob(x).mean(0) + log_prob = dist.log_prob(x) + if is_tensor_collection(log_prob): + log_prob = log_prob.get(self.tensor_keys.sample_log_prob) + entropy = -log_prob.mean(0) return entropy.unsqueeze(-1) def _log_probs( @@ -391,10 +413,6 @@ def _log_probs( ) -> Tuple[torch.Tensor, d.Distribution]: # current log_prob of actions action = tensordict.get(self.tensor_keys.action) - if action.requires_grad: - raise RuntimeError( - f"tensordict stored {self.tensor_keys.action} require grad." - ) tensordict_clone = tensordict.select( *self.actor_network.in_keys, strict=False ).clone() @@ -402,11 +420,29 @@ def _log_probs( self.actor_network ) if self.functional else contextlib.nullcontext(): dist = self.actor_network.get_dist(tensordict_clone) - log_prob = dist.log_prob(action) + if action.requires_grad: + raise RuntimeError( + f"tensordict stored {self.tensor_keys.action} requires grad." + ) + if isinstance(action, torch.Tensor): + log_prob = dist.log_prob(action) + else: + maybe_log_prob = dist.log_prob(tensordict) + if not isinstance(maybe_log_prob, torch.Tensor): + # In some cases (Composite distribution with aggregate_probabilities toggled off) the returned type may not + # be a tensor + log_prob = maybe_log_prob.get(self.tensor_keys.sample_log_prob) + else: + log_prob = maybe_log_prob log_prob = log_prob.unsqueeze(-1) return log_prob, dist - def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor: + def loss_critic(self, tensordict: TensorDictBase) -> Tuple[torch.Tensor, float]: + """Returns the loss value of the critic, multiplied by ``critic_coef`` if it is not ``None``. + + Returns the loss and the clip-fraction. 
+ + """ if self.clip_value: old_state_value = tensordict.get( self.tensor_keys.value, None @@ -456,7 +492,9 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor: loss_value, self.loss_critic_type, ) - return self.critic_coef * loss_value, clip_fraction + if self.critic_coef is not None: + return self.critic_coef * loss_value, clip_fraction + return loss_value, clip_fraction @property @_cache_values @@ -483,7 +521,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: entropy = self.get_entropy_bonus(dist) td_out.set("entropy", entropy.detach().mean()) # for logging td_out.set("loss_entropy", -self.entropy_coef * entropy) - if self.critic_coef: + if self.critic_coef is not None: loss_critic, value_clip_fraction = self.loss_critic(tensordict) td_out.set("loss_critic", loss_critic) if value_clip_fraction is not None: diff --git a/torchrl/objectives/common.py b/torchrl/objectives/common.py index 5ceec84e36a..f6935ceae82 100644 --- a/torchrl/objectives/common.py +++ b/torchrl/objectives/common.py @@ -15,6 +15,7 @@ from tensordict import is_tensor_collection, TensorDict, TensorDictBase from tensordict.nn import TensorDictModule, TensorDictModuleBase, TensorDictParams +from tensordict.utils import Buffer from torch import nn from torch.nn import Parameter from torchrl._utils import RL_WARNINGS @@ -23,9 +24,18 @@ from torchrl.objectives.utils import RANDOM_MODULE_LIST, ValueEstimators from torchrl.objectives.value import ValueEstimatorBase +try: + from torch.compiler import is_dynamo_compiling +except ImportError: + from torch._dynamo import is_compiling as is_dynamo_compiling + def _updater_check_forward_prehook(module, *args, **kwargs): - if not all(module._has_update_associated.values()) and RL_WARNINGS: + if ( + not all(module._has_update_associated.values()) + and RL_WARNINGS + and not is_dynamo_compiling() + ): warnings.warn( module.TARGET_NET_WARNING, category=UserWarning, @@ -87,8 +97,8 @@ class LossModule(TensorDictModuleBase, metaclass=_LossMeta): >>> loss.set_keys(action="action2") .. note:: When a policy that is wrapped or augmented with an exploration module is passed - to the loss, we want to deactivate the exploration through ``set_exploration_mode()`` where - ```` is either ``ExplorationType.MEAN``, ``ExplorationType.MODE`` or + to the loss, we want to deactivate the exploration through ``set_exploration_type()`` where + ```` is either ``ExplorationType.MEAN``, ``ExplorationType.MODE`` or ``ExplorationType.DETERMINISTIC``. The default value is ``DETERMINISTIC`` and it is set through the ``deterministic_sampling_mode`` loss attribute. If another exploration mode is required (or if ``DETERMINISTIC`` is not available), one can @@ -217,8 +227,10 @@ def set_keys(self, **kwargs) -> None: >>> dqn_loss.set_keys(priority_key="td_error", action_value_key="action_value") """ for key, value in kwargs.items(): - if key not in self._AcceptedKeys.__dict__: - raise ValueError(f"{key} is not an accepted tensordict key") + if key not in self._AcceptedKeys.__dataclass_fields__: + raise ValueError( + f"{key} is not an accepted tensordict key. Accepted keys are: {self._AcceptedKeys.__dataclass_fields__}." 
+ ) if value is not None: setattr(self.tensor_keys, key, value) else: @@ -415,7 +427,11 @@ def __getattr__(self, item): # no target param, take detached data params = getattr(self, item[7:]) params = params.data - elif not self._has_update_associated[item[7:-7]] and RL_WARNINGS: + elif ( + not self._has_update_associated[item[7:-7]] + and RL_WARNINGS + and not is_dynamo_compiling() + ): # no updater associated warnings.warn( self.TARGET_NET_WARNING, @@ -433,7 +449,7 @@ def _apply(self, fn): def _erase_cache(self): for key in list(self.__dict__): if key.startswith("_cache"): - del self.__dict__[key] + delattr(self, key) def _networks(self) -> Iterator[nn.Module]: for item in self.__dir__(): @@ -603,11 +619,10 @@ def __init__(self, clone): self.clone = clone def __call__(self, x): + x = x.data.clone() if self.clone else x.data if isinstance(x, nn.Parameter): - return nn.Parameter( - x.data.clone() if self.clone else x.data, requires_grad=False - ) - return x.data.clone() if self.clone else x.data + return Buffer(x) + return x def add_ramdom_module(module): diff --git a/torchrl/objectives/cql.py b/torchrl/objectives/cql.py index f1e2aa9c532..fb8fbff2ccf 100644 --- a/torchrl/objectives/cql.py +++ b/torchrl/objectives/cql.py @@ -19,7 +19,7 @@ from tensordict.utils import NestedKey, unravel_key from torch import Tensor -from torchrl.data.tensor_specs import CompositeSpec +from torchrl.data.tensor_specs import Composite from torchrl.data.utils import _find_action_space from torchrl.envs.utils import ExplorationType, set_exploration_type @@ -100,14 +100,14 @@ class CQLLoss(LossModule): Examples: >>> import torch >>> from torch import nn - >>> from torchrl.data import BoundedTensorSpec + >>> from torchrl.data import Bounded >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.cql import CQLLoss >>> from tensordict import TensorDict >>> n_act, n_obs = 4, 3 - >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) + >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,)) >>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()) >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) >>> actor = ProbabilisticActor( @@ -160,14 +160,14 @@ class CQLLoss(LossModule): Examples: >>> import torch >>> from torch import nn - >>> from torchrl.data import BoundedTensorSpec + >>> from torchrl.data import Bounded >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.cql import CQLLoss >>> _ = torch.manual_seed(42) >>> n_act, n_obs = 4, 3 - >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) + >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,)) >>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()) >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) >>> actor = ProbabilisticActor( @@ -405,8 +405,8 @@ def target_entropy(self): "the target entropy explicitely or provide the spec of the " "action tensor in the actor network." 
) - if not isinstance(action_spec, CompositeSpec): - action_spec = CompositeSpec({self.tensor_keys.action: action_spec}) + if not isinstance(action_spec, Composite): + action_spec = Composite({self.tensor_keys.action: action_spec}) if ( isinstance(self.tensor_keys.action, tuple) and len(self.tensor_keys.action) > 1 @@ -521,16 +521,16 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: else: tensordict_reshape = tensordict - td_device = tensordict_reshape.to(tensordict.device) - - q_loss, metadata = self.q_loss(td_device) - cql_loss, cql_metadata = self.cql_loss(td_device) + q_loss, metadata = self.q_loss(tensordict_reshape) + cql_loss, cql_metadata = self.cql_loss(tensordict_reshape) if self.with_lagrange: - alpha_prime_loss, alpha_prime_metadata = self.alpha_prime_loss(td_device) + alpha_prime_loss, alpha_prime_metadata = self.alpha_prime_loss( + tensordict_reshape + ) metadata.update(alpha_prime_metadata) - loss_actor_bc, bc_metadata = self.actor_bc_loss(td_device) - loss_actor, actor_metadata = self.actor_loss(td_device) - loss_alpha, alpha_metadata = self.alpha_loss(td_device) + loss_actor_bc, bc_metadata = self.actor_bc_loss(tensordict_reshape) + loss_actor, actor_metadata = self.actor_loss(tensordict_reshape) + loss_alpha, alpha_metadata = self.alpha_loss(actor_metadata) metadata.update(bc_metadata) metadata.update(cql_metadata) metadata.update(actor_metadata) @@ -547,7 +547,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: "loss_cql": cql_loss, "loss_alpha": loss_alpha, "alpha": self._alpha, - "entropy": -td_device.get(self.tensor_keys.log_prob).mean().detach(), + "entropy": -actor_metadata.get(self.tensor_keys.log_prob).mean().detach(), } if self.with_lagrange: out["loss_alpha_prime"] = alpha_prime_loss.mean() @@ -574,7 +574,7 @@ def actor_bc_loss(self, tensordict: TensorDictBase) -> Tensor: metadata = {"bc_log_prob": bc_log_prob.mean().detach()} return bc_actor_loss, metadata - def actor_loss(self, tensordict: TensorDictBase) -> Tensor: + def actor_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]: with set_exploration_type( ExplorationType.RANDOM ), self.actor_network_params.to_module(self.actor_network): @@ -585,6 +585,8 @@ def actor_loss(self, tensordict: TensorDictBase) -> Tensor: log_prob = dist.log_prob(a_reparm) td_q = tensordict.select(*self.qvalue_network.in_keys, strict=False) + if td_q is tensordict: + raise RuntimeError td_q.set(self.tensor_keys.action, a_reparm) td_q = self._vmap_qvalue_networkN0( td_q, @@ -599,12 +601,12 @@ def actor_loss(self, tensordict: TensorDictBase) -> Tensor: f"Losses shape mismatch: {log_prob.shape} and {min_q_logprob.shape}" ) - # write log_prob in tensordict for alpha loss - tensordict.set(self.tensor_keys.log_prob, log_prob.detach()) + metadata = {} + metadata[self.tensor_keys.log_prob] = log_prob.detach() actor_loss = self._alpha * log_prob - min_q_logprob actor_loss = _reduce(actor_loss, reduction=self.reduction) - return actor_loss, {} + return actor_loss, metadata def _get_policy_actions(self, data, actor_params, num_actions=10): batch_size = data.batch_size @@ -667,7 +669,7 @@ def _get_value_v(self, tensordict, _alpha, actor_params, qval_params): if self.max_q_backup: next_tensordict, _ = self._get_policy_actions( - tensordict.get("next"), + tensordict.get("next").copy(), actor_params, num_actions=self.num_random, ) @@ -691,10 +693,10 @@ def _get_value_v(self, tensordict, _alpha, actor_params, qval_params): target_value = self.value_estimator.value_estimate(tensordict).squeeze(-1) return 
target_value - def q_loss(self, tensordict: TensorDictBase) -> Tensor: + def q_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]: # we pass the alpha value to the tensordict. Since it's a scalar, we must erase the batch-size first. target_value = self._get_value_v( - tensordict, + tensordict.copy(), self._alpha, self.actor_network_params, self.target_qvalue_network_params, @@ -722,7 +724,7 @@ def q_loss(self, tensordict: TensorDictBase) -> Tensor: metadata = {"td_error": td_error.detach()} return loss_qval, metadata - def cql_loss(self, tensordict: TensorDictBase) -> Tensor: + def cql_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]: pred_q1 = tensordict.get(self.tensor_keys.pred_q1) pred_q2 = tensordict.get(self.tensor_keys.pred_q2) @@ -746,12 +748,12 @@ def cql_loss(self, tensordict: TensorDictBase) -> Tensor: .to(tensordict.device) ) curr_actions_td, curr_log_pis = self._get_policy_actions( - tensordict, + tensordict.copy(), self.actor_network_params, num_actions=self.num_random, ) new_curr_actions_td, new_log_pis = self._get_policy_actions( - tensordict.get("next"), + tensordict.get("next").copy(), self.actor_network_params, num_actions=self.num_random, ) @@ -933,11 +935,11 @@ class DiscreteCQLLoss(LossModule): Examples: >>> from torchrl.modules import MLP, QValueActor - >>> from torchrl.data import OneHotDiscreteTensorSpec + >>> from torchrl.data import OneHot >>> from torchrl.objectives import DiscreteCQLLoss >>> n_obs, n_act = 4, 3 >>> value_net = MLP(in_features=n_obs, out_features=n_act) - >>> spec = OneHotDiscreteTensorSpec(n_act) + >>> spec = OneHot(n_act) >>> actor = QValueActor(value_net, in_keys=["observation"], action_space=spec) >>> loss = DiscreteCQLLoss(actor, action_space=spec) >>> batch = [10,] @@ -969,12 +971,12 @@ class DiscreteCQLLoss(LossModule): Examples: >>> from torchrl.objectives import DiscreteCQLLoss - >>> from torchrl.data import OneHotDiscreteTensorSpec + >>> from torchrl.data import OneHot >>> from torch import nn >>> import torch >>> n_obs = 3 >>> n_action = 4 - >>> action_spec = OneHotDiscreteTensorSpec(n_action) + >>> action_spec = OneHot(n_action) >>> value_network = nn.Linear(n_obs, n_action) # a simple value model >>> dcql_loss = DiscreteCQLLoss(value_network, action_space=action_spec) >>> # define data @@ -1089,7 +1091,7 @@ def __init__( if action_space is None: warnings.warn( "action_space was not specified. DiscreteCQLLoss will default to 'one-hot'. " - "This behaviour will be deprecated soon and a space will have to be passed. " + "This behavior will be deprecated soon and a space will have to be passed. " "Check the DiscreteCQLLoss documentation to see how to pass the action space." 
) action_space = "one-hot" diff --git a/torchrl/objectives/crossq.py b/torchrl/objectives/crossq.py index e76e3438c09..d86442fca12 100644 --- a/torchrl/objectives/crossq.py +++ b/torchrl/objectives/crossq.py @@ -15,7 +15,7 @@ from tensordict.nn import dispatch, TensorDictModule from tensordict.utils import NestedKey from torch import Tensor -from torchrl.data.tensor_specs import CompositeSpec +from torchrl.data.tensor_specs import Composite from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.modules import ProbabilisticActor from torchrl.objectives.common import LossModule @@ -98,14 +98,14 @@ class CrossQLoss(LossModule): Examples: >>> import torch >>> from torch import nn - >>> from torchrl.data import BoundedTensorSpec + >>> from torchrl.data import Bounded >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.crossq import CrossQLoss >>> from tensordict import TensorDict >>> n_act, n_obs = 4, 3 - >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) + >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,)) >>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()) >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) >>> actor = ProbabilisticActor( @@ -156,14 +156,14 @@ class CrossQLoss(LossModule): Examples: >>> import torch >>> from torch import nn - >>> from torchrl.data import BoundedTensorSpec + >>> from torchrl.data import Bounded >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives import CrossQLoss >>> _ = torch.manual_seed(42) >>> n_act, n_obs = 4, 3 - >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) + >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,)) >>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()) >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) >>> actor = ProbabilisticActor( @@ -375,8 +375,8 @@ def target_entropy(self): "the target entropy explicitely or provide the spec of the " "action tensor in the actor network." 
) - if not isinstance(action_spec, CompositeSpec): - action_spec = CompositeSpec({self.tensor_keys.action: action_spec}) + if not isinstance(action_spec, Composite): + action_spec = Composite({self.tensor_keys.action: action_spec}) if ( isinstance(self.tensor_keys.action, tuple) and len(self.tensor_keys.action) > 1 diff --git a/torchrl/objectives/ddpg.py b/torchrl/objectives/ddpg.py index 6e1cf0f5eb3..7dc6b23212a 100644 --- a/torchrl/objectives/ddpg.py +++ b/torchrl/objectives/ddpg.py @@ -50,12 +50,12 @@ class DDPGLoss(LossModule): Examples: >>> import torch >>> from torch import nn - >>> from torchrl.data import BoundedTensorSpec + >>> from torchrl.data import Bounded >>> from torchrl.modules.tensordict_module.actors import Actor, ValueOperator >>> from torchrl.objectives.ddpg import DDPGLoss >>> from tensordict import TensorDict >>> n_act, n_obs = 4, 3 - >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) + >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,)) >>> actor = Actor(spec=spec, module=nn.Linear(n_obs, n_act)) >>> class ValueClass(nn.Module): ... def __init__(self): @@ -100,12 +100,12 @@ class DDPGLoss(LossModule): Examples: >>> import torch >>> from torch import nn - >>> from torchrl.data import BoundedTensorSpec + >>> from torchrl.data import Bounded >>> from torchrl.modules.tensordict_module.actors import Actor, ValueOperator >>> from torchrl.objectives.ddpg import DDPGLoss >>> _ = torch.manual_seed(42) >>> n_act, n_obs = 4, 3 - >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) + >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,)) >>> actor = Actor(spec=spec, module=nn.Linear(n_obs, n_act)) >>> class ValueClass(nn.Module): ... def __init__(self): diff --git a/torchrl/objectives/deprecated.py b/torchrl/objectives/deprecated.py index c1ed8b2cffe..32394942600 100644 --- a/torchrl/objectives/deprecated.py +++ b/torchrl/objectives/deprecated.py @@ -17,7 +17,7 @@ from tensordict.utils import NestedKey from torch import Tensor -from torchrl.data.tensor_specs import CompositeSpec +from torchrl.data.tensor_specs import Composite from torchrl.envs.utils import ExplorationType, set_exploration_type, step_mdp from torchrl.objectives import default_value_kwargs, distance_loss, ValueEstimators from torchrl.objectives.common import LossModule @@ -251,8 +251,8 @@ def target_entropy(self): "the target entropy explicitely or provide the spec of the " "action tensor in the actor network." ) - if not isinstance(action_spec, CompositeSpec): - action_spec = CompositeSpec({self.tensor_keys.action: action_spec}) + if not isinstance(action_spec, Composite): + action_spec = Composite({self.tensor_keys.action: action_spec}) target_entropy = -float( np.prod(action_spec[self.tensor_keys.action].shape) ) @@ -465,7 +465,7 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams class DoubleREDQLoss_deprecated(REDQLoss_deprecated): - """[Deprecated] Class for delayed target-REDQ (which should be the default behaviour).""" + """[Deprecated] Class for delayed target-REDQ (which should be the default behavior).""" delay_qvalue: bool = True diff --git a/torchrl/objectives/dqn.py b/torchrl/objectives/dqn.py index 7b35598c474..6cbb8b02426 100644 --- a/torchrl/objectives/dqn.py +++ b/torchrl/objectives/dqn.py @@ -47,14 +47,14 @@ class DQNLoss(LossModule): Defaults to "l2". 
delay_value (bool, optional): whether to duplicate the value network into a new target value network to - create a DQN with a target network. Default is ``False``. + create a DQN with a target network. Default is ``True``. double_dqn (bool, optional): whether to use Double DQN, as described in https://arxiv.org/abs/1509.06461. Defaults to ``False``. action_space (str or TensorSpec, optional): Action space. Must be one of ``"one-hot"``, ``"mult_one_hot"``, ``"binary"`` or ``"categorical"``, - or an instance of the corresponding specs (:class:`torchrl.data.OneHotDiscreteTensorSpec`, - :class:`torchrl.data.MultiOneHotDiscreteTensorSpec`, - :class:`torchrl.data.BinaryDiscreteTensorSpec` or :class:`torchrl.data.DiscreteTensorSpec`). + or an instance of the corresponding specs (:class:`torchrl.data.OneHot`, + :class:`torchrl.data.MultiOneHot`, + :class:`torchrl.data.Binary` or :class:`torchrl.data.Categorical`). If not provided, an attempt to retrieve it from the value network will be made. priority_key (NestedKey, optional): [Deprecated, use .set_keys(priority_key=priority_key) instead] @@ -68,10 +68,10 @@ class DQNLoss(LossModule): Examples: >>> from torchrl.modules import MLP - >>> from torchrl.data import OneHotDiscreteTensorSpec + >>> from torchrl.data import OneHot >>> n_obs, n_act = 4, 3 >>> value_net = MLP(in_features=n_obs, out_features=n_act) - >>> spec = OneHotDiscreteTensorSpec(n_act) + >>> spec = OneHot(n_act) >>> actor = QValueActor(value_net, in_keys=["observation"], action_space=spec) >>> loss = DQNLoss(actor, action_space=spec) >>> batch = [10,] @@ -99,12 +99,12 @@ class DQNLoss(LossModule): Examples: >>> from torchrl.objectives import DQNLoss - >>> from torchrl.data import OneHotDiscreteTensorSpec + >>> from torchrl.data import OneHot >>> from torch import nn >>> import torch >>> n_obs = 3 >>> n_action = 4 - >>> action_spec = OneHotDiscreteTensorSpec(n_action) + >>> action_spec = OneHot(n_action) >>> value_network = nn.Linear(n_obs, n_action) # a simple value model >>> dqn_loss = DQNLoss(value_network, action_space=action_spec) >>> # define data @@ -224,7 +224,7 @@ def __init__( if action_space is None: warnings.warn( "action_space was not specified. DQNLoss will default to 'one-hot'." - "This behaviour will be deprecated soon and a space will have to be passed." + "This behavior will be deprecated soon and a space will have to be passed." "Check the DQNLoss documentation to see how to pass the action space. " ) action_space = "one-hot" diff --git a/torchrl/objectives/functional.py b/torchrl/objectives/functional.py index 7c598676794..fd96b2e92a3 100644 --- a/torchrl/objectives/functional.py +++ b/torchrl/objectives/functional.py @@ -20,7 +20,7 @@ def cross_entropy_loss( (integer representation) or log_policy.shape (one-hot). inplace: fills log_policy in-place with 0.0 at non-selected actions before summing along the last dimensions. This is usually faster but it will change the value of log-policy in place, which may lead to unwanted - behaviours. + behaviors. """ if action.shape == log_policy.shape: diff --git a/torchrl/objectives/gail.py b/torchrl/objectives/gail.py new file mode 100644 index 00000000000..3c0050fca84 --- /dev/null +++ b/torchrl/objectives/gail.py @@ -0,0 +1,251 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
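For reference, the discriminator introduced in the new module below is trained on the usual GAIL objective; the ``BCELoss`` against real/fake labels in the code is an equivalent formulation, and the optional gradient penalty is evaluated on random mixtures \(\hat{x}\) of expert and collected observation-action pairs:

\max_D \; \mathbb{E}_{(s,a)\sim\pi_E}\big[\log D(s,a)\big]
        + \mathbb{E}_{(s,a)\sim\pi}\big[\log\big(1 - D(s,a)\big)\big]
        - \lambda\,\mathbb{E}_{\hat{x}}\big[\big(\lVert\nabla_{\hat{x}} D(\hat{x})\rVert_2 - 1\big)^2\big]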
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+import torch
+
+import torch.autograd as autograd
+from tensordict import TensorDict, TensorDictBase, TensorDictParams
+from tensordict.nn import dispatch, TensorDictModule
+from tensordict.utils import NestedKey
+
+from torchrl.objectives.common import LossModule
+from torchrl.objectives.utils import _reduce
+
+
+class GAILLoss(LossModule):
+    r"""TorchRL implementation of the Generative Adversarial Imitation Learning (GAIL) loss.
+
+    Presented in `"Generative Adversarial Imitation Learning" `
+
+    Args:
+        discriminator_network (TensorDictModule): the discriminator network, mapping
+            observation-action pairs to the probability that they come from the expert.
+
+    Keyword Args:
+        use_grad_penalty (bool, optional): Whether to use gradient penalty. Default: ``False``.
+        gp_lambda (float, optional): Gradient penalty lambda. Default: ``10``.
+        reduction (str, optional): Specifies the reduction to apply to the output:
+            ``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied,
+            ``"mean"``: the sum of the output will be divided by the number of
+            elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``.
+    """
+
+    @dataclass
+    class _AcceptedKeys:
+        """Maintains default values for all configurable tensordict keys.
+
+        This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their
+        default values.
+
+        Attributes:
+            expert_action (NestedKey): The input tensordict key where the action is expected.
+                Defaults to ``"action"``.
+            expert_observation (NestedKey): The tensordict key where the observation is expected.
+                Defaults to ``"observation"``.
+            collector_action (NestedKey): The tensordict key where the collector action is expected.
+                Defaults to ``"collector_action"``.
+            collector_observation (NestedKey): The tensordict key where the collector observation is expected.
+                Defaults to ``"collector_observation"``.
+            discriminator_pred (NestedKey): The tensordict key where the discriminator prediction is expected.
+                Defaults to ``"d_logits"``.
+ """ + + expert_action: NestedKey = "action" + expert_observation: NestedKey = "observation" + collector_action: NestedKey = "collector_action" + collector_observation: NestedKey = "collector_observation" + discriminator_pred: NestedKey = "d_logits" + + default_keys = _AcceptedKeys() + + discriminator_network: TensorDictModule + discriminator_network_params: TensorDictParams + target_discriminator_network: TensorDictModule + target_discriminator_network_params: TensorDictParams + + out_keys = [ + "loss", + "gp_loss", + ] + + def __init__( + self, + discriminator_network: TensorDictModule, + *, + use_grad_penalty: bool = False, + gp_lambda: float = 10, + reduction: str = None, + ) -> None: + self._in_keys = None + self._out_keys = None + if reduction is None: + reduction = "mean" + super().__init__() + + # Discriminator Network + self.convert_to_functional( + discriminator_network, + "discriminator_network", + create_target_params=False, + ) + self.loss_function = torch.nn.BCELoss(reduction="none") + self.use_grad_penalty = use_grad_penalty + self.gp_lambda = gp_lambda + + self.reduction = reduction + + def _set_in_keys(self): + keys = self.discriminator_network.in_keys + keys = set(keys) + keys.add(self.tensor_keys.expert_observation) + keys.add(self.tensor_keys.expert_action) + keys.add(self.tensor_keys.collector_observation) + keys.add(self.tensor_keys.collector_action) + self._in_keys = sorted(keys, key=str) + + def _forward_value_estimator_keys(self, **kwargs) -> None: + pass + + @property + def in_keys(self): + if self._in_keys is None: + self._set_in_keys() + return self._in_keys + + @in_keys.setter + def in_keys(self, values): + self._in_keys = values + + @property + def out_keys(self): + if self._out_keys is None: + keys = ["loss"] + if self.use_grad_penalty: + keys.append("gp_loss") + self._out_keys = keys + return self._out_keys + + @out_keys.setter + def out_keys(self, values): + self._out_keys = values + + @dispatch + def forward( + self, + tensordict: TensorDictBase, + ) -> TensorDictBase: + """The forward method. + + Computes the discriminator loss and gradient penalty if `use_grad_penalty` is set to True. If `use_grad_penalty` is set to True, the detached gradient penalty loss is also returned for logging purposes. + To see what keys are expected in the input tensordict and what keys are expected as output, check the + class's `"in_keys"` and `"out_keys"` attributes. 
+ """ + device = self.discriminator_network.device + tensordict = tensordict.clone(False) + shape = tensordict.shape + if len(shape) > 1: + batch_size, seq_len = shape + else: + batch_size = shape[0] + collector_obs = tensordict.get(self.tensor_keys.collector_observation) + collector_act = tensordict.get(self.tensor_keys.collector_action) + + expert_obs = tensordict.get(self.tensor_keys.expert_observation) + expert_act = tensordict.get(self.tensor_keys.expert_action) + + combined_obs_inputs = torch.cat([expert_obs, collector_obs], dim=0) + combined_act_inputs = torch.cat([expert_act, collector_act], dim=0) + + combined_inputs = TensorDict( + { + self.tensor_keys.expert_observation: combined_obs_inputs, + self.tensor_keys.expert_action: combined_act_inputs, + }, + batch_size=[2 * batch_size], + device=device, + ) + + # create + if len(shape) > 1: + fake_labels = torch.zeros((batch_size, seq_len, 1), dtype=torch.float32).to( + device + ) + real_labels = torch.ones((batch_size, seq_len, 1), dtype=torch.float32).to( + device + ) + else: + fake_labels = torch.zeros((batch_size, 1), dtype=torch.float32).to(device) + real_labels = torch.ones((batch_size, 1), dtype=torch.float32).to(device) + + with self.discriminator_network_params.to_module(self.discriminator_network): + d_logits = self.discriminator_network(combined_inputs).get( + self.tensor_keys.discriminator_pred + ) + + expert_preds, collection_preds = torch.split( + d_logits, [batch_size, batch_size], dim=0 + ) + + expert_loss = self.loss_function(expert_preds, real_labels) + collection_loss = self.loss_function(collection_preds, fake_labels) + + loss = expert_loss + collection_loss + out = {} + if self.use_grad_penalty: + obs = tensordict.get(self.tensor_keys.collector_observation) + acts = tensordict.get(self.tensor_keys.collector_action) + obs_e = tensordict.get(self.tensor_keys.expert_observation) + acts_e = tensordict.get(self.tensor_keys.expert_action) + + obss_noise = ( + torch.distributions.Uniform(0.0, 1.0).sample(obs_e.shape).to(device) + ) + acts_noise = ( + torch.distributions.Uniform(0.0, 1.0).sample(acts_e.shape).to(device) + ) + obss_mixture = obss_noise * obs + (1 - obss_noise) * obs_e + acts_mixture = acts_noise * acts + (1 - acts_noise) * acts_e + obss_mixture.requires_grad_(True) + acts_mixture.requires_grad_(True) + + pg_input_td = TensorDict( + { + self.tensor_keys.expert_observation: obss_mixture, + self.tensor_keys.expert_action: acts_mixture, + }, + [], + device=device, + ) + + with self.discriminator_network_params.to_module( + self.discriminator_network + ): + d_logits_mixture = self.discriminator_network(pg_input_td).get( + self.tensor_keys.discriminator_pred + ) + + gradients = torch.cat( + autograd.grad( + outputs=d_logits_mixture, + inputs=(obss_mixture, acts_mixture), + grad_outputs=torch.ones(d_logits_mixture.size(), device=device), + create_graph=True, + retain_graph=True, + only_inputs=True, + ), + dim=-1, + ) + + gp_loss = self.gp_lambda * torch.mean( + (torch.linalg.norm(gradients, dim=-1) - 1) ** 2 + ) + + loss += gp_loss + out["gp_loss"] = gp_loss.detach() + loss = _reduce(loss, reduction=self.reduction) + out["loss"] = loss + td_out = TensorDict(out, []) + return td_out diff --git a/torchrl/objectives/iql.py b/torchrl/objectives/iql.py index 74cfe504e78..c4639b70bdd 100644 --- a/torchrl/objectives/iql.py +++ b/torchrl/objectives/iql.py @@ -73,14 +73,14 @@ class IQLLoss(LossModule): Examples: >>> import torch >>> from torch import nn - >>> from torchrl.data import BoundedTensorSpec + >>> from 
torchrl.data import Bounded >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.iql import IQLLoss >>> from tensordict import TensorDict >>> n_act, n_obs = 4, 3 - >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) + >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,)) >>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()) >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) >>> actor = ProbabilisticActor( @@ -136,14 +136,14 @@ class IQLLoss(LossModule): Examples: >>> import torch >>> from torch import nn - >>> from torchrl.data import BoundedTensorSpec + >>> from torchrl.data import Bounded >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.iql import IQLLoss >>> _ = torch.manual_seed(42) >>> n_act, n_obs = 4, 3 - >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) + >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,)) >>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()) >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) >>> actor = ProbabilisticActor( @@ -383,7 +383,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: loss_actor, metadata = self.actor_loss(tensordict_reshape) loss_qvalue, metadata_qvalue = self.qvalue_loss(tensordict_reshape) loss_value, metadata_value = self.value_loss(tensordict_reshape) - metadata.update(**metadata_qvalue, **metadata_value) + metadata.update(metadata_qvalue) + metadata.update(metadata_value) if (loss_actor.shape != loss_qvalue.shape) or ( loss_value is not None and loss_actor.shape != loss_value.shape @@ -410,7 +411,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: [], ) - def actor_loss(self, tensordict: TensorDictBase) -> Tensor: + def actor_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]: # KL loss with self.actor_network_params.to_module(self.actor_network): dist = self.actor_network.get_dist(tensordict) @@ -446,7 +447,7 @@ def actor_loss(self, tensordict: TensorDictBase) -> Tensor: loss_actor = _reduce(loss_actor, reduction=self.reduction) return loss_actor, {} - def value_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, Tensor]: + def value_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]: # Min Q value td_q = tensordict.select(*self.qvalue_network.in_keys, strict=False) td_q = self._vmap_qvalue_networkN0(td_q, self.target_qvalue_network_params) @@ -460,7 +461,7 @@ def value_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, Tensor]: value_loss = _reduce(value_loss, reduction=self.reduction) return value_loss, {} - def qvalue_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, Tensor]: + def qvalue_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]: obs_keys = self.actor_network.in_keys tensordict = tensordict.select( "next", *obs_keys, self.tensor_keys.action, strict=False @@ -541,9 +542,9 @@ class DiscreteIQLLoss(IQLLoss): Keyword Args: action_space (str or TensorSpec): Action space. 
Must be one of ``"one-hot"``, ``"mult_one_hot"``, ``"binary"`` or ``"categorical"``, - or an instance of the corresponding specs (:class:`torchrl.data.OneHotDiscreteTensorSpec`, - :class:`torchrl.data.MultiOneHotDiscreteTensorSpec`, - :class:`torchrl.data.BinaryDiscreteTensorSpec` or :class:`torchrl.data.DiscreteTensorSpec`). + or an instance of the corresponding specs (:class:`torchrl.data.OneHot`, + :class:`torchrl.data.MultiOneHot`, + :class:`torchrl.data.Binary` or :class:`torchrl.data.Categorical`). num_qvalue_nets (integer, optional): number of Q-Value networks used. Defaults to ``2``. loss_function (str, optional): loss function to be used with @@ -569,14 +570,14 @@ class DiscreteIQLLoss(IQLLoss): Examples: >>> import torch >>> from torch import nn - >>> from torchrl.data.tensor_specs import OneHotDiscreteTensorSpec + >>> from torchrl.data.tensor_specs import OneHot >>> from torchrl.modules.distributions.discrete import OneHotCategorical >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.iql import DiscreteIQLLoss >>> from tensordict import TensorDict >>> n_act, n_obs = 4, 3 - >>> spec = OneHotDiscreteTensorSpec(n_act) + >>> spec = OneHot(n_act) >>> module = SafeModule(nn.Linear(n_obs, n_act), in_keys=["observation"], out_keys=["logits"]) >>> actor = ProbabilisticActor( ... module=module, @@ -627,14 +628,14 @@ class DiscreteIQLLoss(IQLLoss): >>> import torch >>> import torch >>> from torch import nn - >>> from torchrl.data.tensor_specs import OneHotDiscreteTensorSpec + >>> from torchrl.data.tensor_specs import OneHot >>> from torchrl.modules.distributions.discrete import OneHotCategorical >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.iql import DiscreteIQLLoss >>> _ = torch.manual_seed(42) >>> n_act, n_obs = 4, 3 - >>> spec = OneHotDiscreteTensorSpec(n_act) + >>> spec = OneHot(n_act) >>> module = SafeModule(nn.Linear(n_obs, n_act), in_keys=["observation"], out_keys=["logits"]) >>> actor = ProbabilisticActor( ... module=module, @@ -774,14 +775,14 @@ def __init__( if action_space is None: warnings.warn( "action_space was not specified. DiscreteIQLLoss will default to 'one-hot'." - "This behaviour will be deprecated soon and a space will have to be passed." + "This behavior will be deprecated soon and a space will have to be passed." "Check the DiscreteIQLLoss documentation to see how to pass the action space. 
" ) action_space = "one-hot" self.action_space = _find_action_space(action_space) self.reduction = reduction - def actor_loss(self, tensordict: TensorDictBase) -> Tensor: + def actor_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]: # KL loss with self.actor_network_params.to_module(self.actor_network): dist = self.actor_network.get_dist(tensordict) @@ -828,7 +829,7 @@ def actor_loss(self, tensordict: TensorDictBase) -> Tensor: loss_actor = _reduce(loss_actor, reduction=self.reduction) return loss_actor, {} - def value_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, Tensor]: + def value_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]: # Min Q value with torch.no_grad(): # Min Q value @@ -856,7 +857,7 @@ def value_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, Tensor]: value_loss = _reduce(value_loss, reduction=self.reduction) return value_loss, {} - def qvalue_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, Tensor]: + def qvalue_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]: obs_keys = self.actor_network.in_keys next_td = tensordict.select( "next", *obs_keys, self.tensor_keys.action, strict=False diff --git a/torchrl/objectives/multiagent/qmixer.py b/torchrl/objectives/multiagent/qmixer.py index c9dc281ef41..39777c59e26 100644 --- a/torchrl/objectives/multiagent/qmixer.py +++ b/torchrl/objectives/multiagent/qmixer.py @@ -63,9 +63,9 @@ class QMixerLoss(LossModule): create a double DQN. Default is ``False``. action_space (str or TensorSpec, optional): Action space. Must be one of ``"one-hot"``, ``"mult_one_hot"``, ``"binary"`` or ``"categorical"``, - or an instance of the corresponding specs (:class:`torchrl.data.OneHotDiscreteTensorSpec`, - :class:`torchrl.data.MultiOneHotDiscreteTensorSpec`, - :class:`torchrl.data.BinaryDiscreteTensorSpec` or :class:`torchrl.data.DiscreteTensorSpec`). + or an instance of the corresponding specs (:class:`torchrl.data.OneHot`, + :class:`torchrl.data.MultiOneHot`, + :class:`torchrl.data.Binary` or :class:`torchrl.data.Categorical`). If not provided, an attempt to retrieve it from the value network will be made. priority_key (NestedKey, optional): [Deprecated, use .set_keys(priority_key=priority_key) instead] @@ -254,7 +254,7 @@ def __init__( if action_space is None: warnings.warn( "action_space was not specified. QMixerLoss will default to 'one-hot'." - "This behaviour will be deprecated soon and a space will have to be passed." + "This behavior will be deprecated soon and a space will have to be passed." "Check the QMixerLoss documentation to see how to pass the action space. " ) action_space = "one-hot" diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index c29bc73dfa8..efc951b3999 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -6,13 +6,17 @@ import contextlib -import math from copy import deepcopy from dataclasses import dataclass from typing import Tuple import torch -from tensordict import TensorDict, TensorDictBase, TensorDictParams +from tensordict import ( + is_tensor_collection, + TensorDict, + TensorDictBase, + TensorDictParams, +) from tensordict.nn import ( dispatch, ProbabilisticTensorDictModule, @@ -45,15 +49,15 @@ class PPOLoss(LossModule): """A parent PPO loss class. 
- PPO (Proximal Policy Optimisation) is a model-free, online RL algorithm + PPO (Proximal Policy Optimization) is a model-free, online RL algorithm that makes use of a recorded (batch of) trajectories to perform several optimization steps, while actively preventing the updated policy to deviate too much from its original parameter configuration. - PPO loss can be found in different flavours, depending on the way the - constrained optimisation is implemented: ClipPPOLoss and KLPENPPOLoss. - Unlike its subclasses, this class does not implement any regularisation + PPO loss can be found in different flavors, depending on the way the + constrained optimization is implemented: ClipPPOLoss and KLPENPPOLoss. + Unlike its subclasses, this class does not implement any regularization and should therefore be used cautiously. For more details regarding PPO, refer to: "Proximal Policy Optimization Algorithms", @@ -75,7 +79,8 @@ class PPOLoss(LossModule): entropy_coef (scalar, optional): entropy multiplier when computing the total loss. Defaults to ``0.01``. critic_coef (scalar, optional): critic loss multiplier when computing the total - loss. Defaults to ``1.0``. + loss. Defaults to ``1.0``. Set ``critic_coef`` to ``None`` to exclude the value + loss from the forward outputs. loss_critic_type (str, optional): loss function for the value discrepancy. Can be one of "l1", "l2" or "smooth_l1". Defaults to ``"smooth_l1"``. normalize_advantage (bool, optional): if ``True``, the advantage will be normalized @@ -151,14 +156,14 @@ class PPOLoss(LossModule): Examples: >>> import torch >>> from torch import nn - >>> from torchrl.data.tensor_specs import BoundedTensorSpec + >>> from torchrl.data.tensor_specs import Bounded >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.ppo import PPOLoss >>> from tensordict import TensorDict >>> n_act, n_obs = 4, 3 - >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) + >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,)) >>> base_layer = nn.Linear(n_obs, 5) >>> net = nn.Sequential(base_layer, nn.Linear(5, 2 * n_act), NormalParamExtractor()) >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) @@ -204,13 +209,13 @@ class PPOLoss(LossModule): Examples: >>> import torch >>> from torch import nn - >>> from torchrl.data.tensor_specs import BoundedTensorSpec + >>> from torchrl.data.tensor_specs import Bounded >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.ppo import PPOLoss >>> n_act, n_obs = 4, 3 - >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) + >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,)) >>> base_layer = nn.Linear(n_obs, 5) >>> net = nn.Sequential(base_layer, nn.Linear(5, 2 * n_act), NormalParamExtractor()) >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) @@ -238,6 +243,12 @@ class PPOLoss(LossModule): ... next_observation=torch.randn(*batch, n_obs)) >>> loss_objective.backward() + .. note:: + There is an exception regarding compatibility with non-tensordict-based modules. 
+ If the actor network is probabilistic and uses a :class:`~tensordict.nn.distributions.CompositeDistribution`, + this class must be used with tensordicts and cannot function as a tensordict-independent module. + This is because composite action spaces inherently rely on the structured representation of data provided by + tensordicts to handle their actions. """ @dataclass @@ -360,7 +371,12 @@ def __init__( device = torch.device("cpu") self.register_buffer("entropy_coef", torch.tensor(entropy_coef, device=device)) - self.register_buffer("critic_coef", torch.tensor(critic_coef, device=device)) + if critic_coef is not None: + self.register_buffer( + "critic_coef", torch.tensor(critic_coef, device=device) + ) + else: + self.critic_coef = None self.loss_critic_type = loss_critic_type self.normalize_advantage = normalize_advantage if gamma is not None: @@ -449,7 +465,10 @@ def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor: entropy = dist.entropy() except NotImplementedError: x = dist.rsample((self.samples_mc_entropy,)) - entropy = -dist.log_prob(x).mean(0) + log_prob = dist.log_prob(x) + if is_tensor_collection(log_prob): + log_prob = log_prob.get(self.tensor_keys.sample_log_prob) + entropy = -log_prob.mean(0) return entropy.unsqueeze(-1) def _log_weight( @@ -457,20 +476,32 @@ def _log_weight( ) -> Tuple[torch.Tensor, d.Distribution]: # current log_prob of actions action = tensordict.get(self.tensor_keys.action) - if action.requires_grad: - raise RuntimeError( - f"tensordict stored {self.tensor_keys.action} requires grad." - ) with self.actor_network_params.to_module( self.actor_network ) if self.functional else contextlib.nullcontext(): dist = self.actor_network.get_dist(tensordict) - log_prob = dist.log_prob(action) prev_log_prob = tensordict.get(self.tensor_keys.sample_log_prob) if prev_log_prob.requires_grad: - raise RuntimeError("tensordict prev_log_prob requires grad.") + raise RuntimeError( + f"tensordict stored {self.tensor_keys.sample_log_prob} requires grad." + ) + + if action.requires_grad: + raise RuntimeError( + f"tensordict stored {self.tensor_keys.action} requires grad." + ) + if isinstance(action, torch.Tensor): + log_prob = dist.log_prob(action) + else: + maybe_log_prob = dist.log_prob(tensordict) + if not isinstance(maybe_log_prob, torch.Tensor): + # In some cases (Composite distribution with aggregate_probabilities toggled off) the returned type may not + # be a tensor + log_prob = maybe_log_prob.get(self.tensor_keys.sample_log_prob) + else: + log_prob = maybe_log_prob log_weight = (log_prob - prev_log_prob).unsqueeze(-1) kl_approx = (prev_log_prob - log_prob).unsqueeze(-1) @@ -478,6 +509,7 @@ def _log_weight( return log_weight, dist, kl_approx def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor: + """Returns the critic loss multiplied by ``critic_coef``, if it is not ``None``.""" # TODO: if the advantage is gathered by forward, this introduces an # overhead that we could easily reduce. 
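With ``critic_coef=None`` stored as a plain attribute above, the value loss disappears from the ``forward`` outputs while ``loss_critic`` stays callable on its own. A minimal runnable sketch modeled on the docstring examples above (network sizes and keys are illustrative):

```python
import torch
from torch import nn
from tensordict import TensorDict
from torchrl.data import Bounded
from torchrl.modules.distributions import NormalParamExtractor, TanhNormal
from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
from torchrl.modules.tensordict_module.common import SafeModule
from torchrl.objectives.ppo import PPOLoss

n_act, n_obs = 4, 3
spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,))
net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor())
module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
actor = ProbabilisticActor(
    module=module,
    in_keys=["loc", "scale"],
    spec=spec,
    distribution_class=TanhNormal,
)
value = ValueOperator(nn.Linear(n_obs, 1), in_keys=["observation"])

loss_fn = PPOLoss(actor, value, critic_coef=None)  # value loss excluded
batch = [2]
data = TensorDict(
    {
        "observation": torch.randn(*batch, n_obs),
        "action": spec.rand(batch),
        "sample_log_prob": torch.randn(*batch),
        ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool),
        ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool),
        ("next", "reward"): torch.randn(*batch, 1),
        ("next", "observation"): torch.randn(*batch, n_obs),
    },
    batch,
)
td_out = loss_fn(data)
assert "loss_critic" not in td_out.keys()  # excluded from forward outputs
```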
if self.separate_losses: @@ -536,7 +568,9 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor: self.loss_critic_type, ) - return self.critic_coef * loss_value, clip_fraction + if self.critic_coef is not None: + return self.critic_coef * loss_value, clip_fraction + return loss_value, clip_fraction @property @_cache_values @@ -569,7 +603,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: td_out.set("entropy", entropy.detach().mean()) # for logging td_out.set("kl_approx", kl_approx.detach().mean()) # for logging td_out.set("loss_entropy", -self.entropy_coef * entropy) - if self.critic_coef: + if self.critic_coef is not None: loss_critic, value_clip_fraction = self.loss_critic(tensordict) td_out.set("loss_critic", loss_critic) if value_clip_fraction is not None: @@ -653,7 +687,8 @@ class ClipPPOLoss(PPOLoss): entropy_coef (scalar, optional): entropy multiplier when computing the total loss. Defaults to ``0.01``. critic_coef (scalar, optional): critic loss multiplier when computing the total - loss. Defaults to ``1.0``. + loss. Defaults to ``1.0``. Set ``critic_coef`` to ``None`` to exclude the value + loss from the forward outputs. loss_critic_type (str, optional): loss function for the value discrepancy. Can be one of "l1", "l2" or "smooth_l1". Defaults to ``"smooth_l1"``. normalize_advantage (bool, optional): if ``True``, the advantage will be normalized @@ -774,13 +809,18 @@ def __init__( clip_value=clip_value, **kwargs, ) - self.register_buffer("clip_epsilon", torch.tensor(clip_epsilon)) + for p in self.parameters(): + device = p.device + break + else: + device = None + self.register_buffer("clip_epsilon", torch.tensor(clip_epsilon, device=device)) @property def _clip_bounds(self): return ( - math.log1p(-self.clip_epsilon), - math.log1p(self.clip_epsilon), + (-self.clip_epsilon).log1p(), + self.clip_epsilon.log1p(), ) @property @@ -843,7 +883,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: td_out.set("entropy", entropy.detach().mean()) # for logging td_out.set("kl_approx", kl_approx.detach().mean()) # for logging td_out.set("loss_entropy", -self.entropy_coef * entropy) - if self.critic_coef: + if self.critic_coef is not None: loss_critic, value_clip_fraction = self.loss_critic(tensordict) td_out.set("loss_critic", loss_critic) if value_clip_fraction is not None: @@ -1107,7 +1147,17 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict: kl = torch.distributions.kl.kl_divergence(previous_dist, current_dist) except NotImplementedError: x = previous_dist.sample((self.samples_mc_kl,)) - kl = (previous_dist.log_prob(x) - current_dist.log_prob(x)).mean(0) + previous_log_prob = previous_dist.log_prob(x) + current_log_prob = current_dist.log_prob(x) + if is_tensor_collection(current_log_prob): + previous_log_prob = previous_log_prob.get( + self.tensor_keys.sample_log_prob + ) + current_log_prob = current_log_prob.get( + self.tensor_keys.sample_log_prob + ) + + kl = (previous_log_prob - current_log_prob).mean(0) kl = kl.unsqueeze(-1) neg_loss = neg_loss - self.beta * kl if kl.mean() > self.dtarg * 1.5: @@ -1127,7 +1177,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict: td_out.set("entropy", entropy.detach().mean()) # for logging td_out.set("kl_approx", kl_approx.detach().mean()) # for logging td_out.set("loss_entropy", -self.entropy_coef * entropy) - if self.critic_coef: + if self.critic_coef is not None: loss_critic, value_clip_fraction = self.loss_critic(tensordict_copy) td_out.set("loss_critic", loss_critic) if 
value_clip_fraction is not None: diff --git a/torchrl/objectives/redq.py b/torchrl/objectives/redq.py index 1522fd7749e..271f233bae8 100644 --- a/torchrl/objectives/redq.py +++ b/torchrl/objectives/redq.py @@ -12,11 +12,11 @@ import torch from tensordict import TensorDict, TensorDictBase, TensorDictParams -from tensordict.nn import dispatch, TensorDictModule, TensorDictSequential +from tensordict.nn import dispatch, TensorDictModule from tensordict.utils import NestedKey from torch import Tensor -from torchrl.data.tensor_specs import CompositeSpec +from torchrl.data.tensor_specs import Composite from torchrl.envs.utils import ExplorationType, set_exploration_type, step_mdp from torchrl.objectives.common import LossModule @@ -93,14 +93,14 @@ class REDQLoss(LossModule): Examples: >>> import torch >>> from torch import nn - >>> from torchrl.data import BoundedTensorSpec + >>> from torchrl.data import Bounded >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.redq import REDQLoss >>> from tensordict import TensorDict >>> n_act, n_obs = 4, 3 - >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) + >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,)) >>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()) >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) >>> actor = ProbabilisticActor( @@ -155,13 +155,13 @@ class REDQLoss(LossModule): Examples: >>> import torch >>> from torch import nn - >>> from torchrl.data import BoundedTensorSpec + >>> from torchrl.data import Bounded >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.redq import REDQLoss >>> n_act, n_obs = 4, 3 - >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) + >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,)) >>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()) >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) >>> actor = ProbabilisticActor( @@ -326,7 +326,11 @@ def __init__( else: self.register_parameter( "log_alpha", - torch.nn.Parameter(torch.tensor(math.log(alpha_init), device=device)), + torch.nn.Parameter( + torch.tensor( + math.log(alpha_init), device=device, requires_grad=True + ) + ), ) self._target_entropy = target_entropy @@ -367,8 +371,8 @@ def target_entropy(self): "the target entropy explicitely or provide the spec of the " "action tensor in the actor network." 
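The KL-penalty ``forward`` above now unwraps tensordict-valued log-probabilities before the Monte-Carlo KL estimate; the estimator itself is unchanged. A sketch of that fallback, using a distribution whose closed form is available for comparison:

```python
import torch
from torch import distributions as d

p = d.Normal(torch.zeros(3), torch.ones(3))
q = d.Normal(torch.ones(3), torch.ones(3))

samples_mc_kl = 10_000
x = p.sample((samples_mc_kl,))
kl_mc = (p.log_prob(x) - q.log_prob(x)).mean(0)  # Monte-Carlo fallback
kl_exact = d.kl_divergence(p, q)                 # closed form, 0.5 per dim here
print(kl_mc, kl_exact)                           # close for large sample counts
```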
) - if not isinstance(action_spec, CompositeSpec): - action_spec = CompositeSpec({self.tensor_keys.action: action_spec}) + if not isinstance(action_spec, Composite): + action_spec = Composite({self.tensor_keys.action: action_spec}) if ( isinstance(self.tensor_keys.action, tuple) and len(self.tensor_keys.action) > 1 @@ -401,10 +405,8 @@ def _forward_value_estimator_keys(self, **kwargs) -> None: @property def alpha(self): - self.log_alpha.data.clamp_(self.min_log_alpha, self.max_log_alpha) with torch.no_grad(): - alpha = self.log_alpha.exp() - return alpha + return self.log_alpha.clamp(self.min_log_alpha, self.max_log_alpha).exp() def _set_in_keys(self): keys = [ @@ -448,9 +450,12 @@ def _qvalue_params_cat(self, selected_q_params): @dispatch def forward(self, tensordict: TensorDictBase) -> TensorDictBase: obs_keys = self.actor_network.in_keys - tensordict_select = tensordict.clone(False).select( + tensordict_select = tensordict.select( "next", *obs_keys, self.tensor_keys.action, strict=False ) + # We need to copy bc select does not copy sub-tds + tensordict_select = tensordict_select.copy() + selected_models_idx = torch.randperm(self.num_qvalue_nets)[ : self.sub_sample_len ].sort()[0] @@ -467,7 +472,6 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: *self.actor_network.in_keys, strict=False ) # next_observation -> tensordict_actor = torch.stack([tensordict_actor_grad, next_td_actor], 0) - # tensordict_actor = tensordict_actor.contiguous() with set_exploration_type(ExplorationType.RANDOM): if self.gSDE: @@ -480,19 +484,12 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: tensordict_actor, actor_params, ) - if isinstance(self.actor_network, TensorDictSequential): - sample_key = self.tensor_keys.action - tensordict_actor_dist = self.actor_network.build_dist_from_params( - td_params - ) - else: - sample_key = self.tensor_keys.action - tensordict_actor_dist = self.actor_network.build_dist_from_params( - td_params - ) + sample_key = self.tensor_keys.action + sample_key_lp = self.tensor_keys.sample_log_prob + tensordict_actor_dist = self.actor_network.build_dist_from_params(td_params) tensordict_actor.set(sample_key, tensordict_actor_dist.rsample()) tensordict_actor.set( - self.tensor_keys.sample_log_prob, + sample_key_lp, tensordict_actor_dist.log_prob(tensordict_actor.get(sample_key)), ) @@ -603,12 +600,22 @@ def _loss_alpha(self, log_pi: Tensor) -> Tensor: ) if self.target_entropy is not None: # we can compute this loss even if log_alpha is not a parameter - alpha_loss = -self.log_alpha.exp() * (log_pi.detach() + self.target_entropy) + alpha_loss = -self._safe_log_alpha.exp() * ( + log_pi.detach() + self.target_entropy + ) else: # placeholder alpha_loss = torch.zeros_like(log_pi) return alpha_loss + @property + def _safe_log_alpha(self): + log_alpha = self.log_alpha + with torch.no_grad(): + log_alpha_clamp = log_alpha.clamp(self.min_log_alpha, self.max_log_alpha) + log_alpha_det = log_alpha.detach() + return log_alpha - log_alpha_det + log_alpha_clamp + def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams): if value_type is None: value_type = self.default_value_estimator diff --git a/torchrl/objectives/reinforce.py b/torchrl/objectives/reinforce.py index af9f7d99b46..08ff896610c 100644 --- a/torchrl/objectives/reinforce.py +++ b/torchrl/objectives/reinforce.py @@ -100,7 +100,7 @@ class ReinforceLoss(LossModule): Examples: >>> import torch >>> from torch import nn - >>> from torchrl.data.tensor_specs import 
UnboundedContinuousTensorSpec + >>> from torchrl.data.tensor_specs import Unbounded >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule @@ -115,7 +115,7 @@ class ReinforceLoss(LossModule): ... distribution_class=TanhNormal, ... return_log_prob=True, ... in_keys=["loc", "scale"], - ... spec=UnboundedContinuousTensorSpec(n_act),) + ... spec=Unbounded(n_act),) >>> loss = ReinforceLoss(actor_net, value_net) >>> batch = 2 >>> data = TensorDict({ @@ -146,7 +146,7 @@ class ReinforceLoss(LossModule): Examples: >>> import torch >>> from torch import nn - >>> from torchrl.data.tensor_specs import UnboundedContinuousTensorSpec + >>> from torchrl.data.tensor_specs import Unbounded >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule @@ -160,7 +160,7 @@ class ReinforceLoss(LossModule): ... distribution_class=TanhNormal, ... return_log_prob=True, ... in_keys=["loc", "scale"], - ... spec=UnboundedContinuousTensorSpec(n_act),) + ... spec=Unbounded(n_act),) >>> loss = ReinforceLoss(actor_net, value_net) >>> batch = 2 >>> loss_actor, loss_value = loss( diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py index df444eac053..6350538db16 100644 --- a/torchrl/objectives/sac.py +++ b/torchrl/objectives/sac.py @@ -18,7 +18,7 @@ from tensordict.nn import dispatch, TensorDictModule from tensordict.utils import NestedKey from torch import Tensor -from torchrl.data.tensor_specs import CompositeSpec, TensorSpec +from torchrl.data.tensor_specs import Composite, TensorSpec from torchrl.data.utils import _find_action_space from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.modules import ProbabilisticActor @@ -46,6 +46,19 @@ def new_func(self, *args, **kwargs): return new_func +def compute_log_prob(action_dist, action_or_tensordict, tensor_key): + """Compute the log probability of an action given a distribution.""" + if isinstance(action_or_tensordict, torch.Tensor): + log_p = action_dist.log_prob(action_or_tensordict) + else: + maybe_log_prob = action_dist.log_prob(action_or_tensordict) + if not isinstance(maybe_log_prob, torch.Tensor): + log_p = maybe_log_prob.get(tensor_key) + else: + log_p = maybe_log_prob + return log_p + + class SACLoss(LossModule): """TorchRL implementation of the SAC loss. 
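The ``compute_log_prob`` helper just above routes tensordict-valued log-probabilities (e.g. from composite distributions) through a configurable key, while plain tensors pass through unchanged. A stripped-down sketch of that dispatch, with an illustrative function name:

```python
import torch
from tensordict import TensorDict

def extract_log_prob(maybe_log_prob, key="sample_log_prob"):
    # Plain tensors pass through; tensordict-like results yield their
    # aggregate log-probability entry instead.
    if isinstance(maybe_log_prob, torch.Tensor):
        return maybe_log_prob
    return maybe_log_prob.get(key)

print(extract_log_prob(torch.tensor(-1.3)))
print(extract_log_prob(TensorDict({"sample_log_prob": torch.tensor(-0.7)}, [])))
```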
@@ -117,14 +130,14 @@ class SACLoss(LossModule): Examples: >>> import torch >>> from torch import nn - >>> from torchrl.data import BoundedTensorSpec + >>> from torchrl.data import Bounded >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.sac import SACLoss >>> from tensordict import TensorDict >>> n_act, n_obs = 4, 3 - >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) + >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,)) >>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()) >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) >>> actor = ProbabilisticActor( @@ -180,14 +193,14 @@ class SACLoss(LossModule): Examples: >>> import torch >>> from torch import nn - >>> from torchrl.data import BoundedTensorSpec + >>> from torchrl.data import Bounded >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.sac import SACLoss >>> _ = torch.manual_seed(42) >>> n_act, n_obs = 4, 3 - >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) + >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,)) >>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()) >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) >>> actor = ProbabilisticActor( @@ -251,7 +264,7 @@ class _AcceptedKeys: state_action_value (NestedKey): The input tensordict key where the state action value is expected. Defaults to ``"state_action_value"``. log_prob (NestedKey): The input tensordict key where the log probability is expected. - Defaults to ``"_log_prob"``. + Defaults to ``"sample_log_prob"``. priority (NestedKey): The input tensordict key where the target priority is written to. Defaults to ``"td_error"``. reward (NestedKey): The input tensordict key where the reward is expected. @@ -267,7 +280,7 @@ class _AcceptedKeys: action: NestedKey = "action" value: NestedKey = "state_value" state_action_value: NestedKey = "state_action_value" - log_prob: NestedKey = "_log_prob" + log_prob: NestedKey = "sample_log_prob" priority: NestedKey = "td_error" reward: NestedKey = "reward" done: NestedKey = "done" @@ -440,8 +453,8 @@ def target_entropy(self): "the target entropy explicitely or provide the spec of the " "action tensor in the actor network." 
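SAC keeps ``log_alpha`` as a raw parameter, and the REDQ ``_safe_log_alpha`` property earlier in this diff bounds it with a straight-through clamp: the forward value is clamped, but the gradient bypasses the clamp and reaches the unbounded parameter. A self-contained sketch (the bounds are illustrative):

```python
import torch

log_alpha = torch.tensor(5.0, requires_grad=True)
min_log_alpha, max_log_alpha = -10.0, 2.0

with torch.no_grad():
    log_alpha_clamp = log_alpha.clamp(min_log_alpha, max_log_alpha)
# value: clamped; gradient: identity with respect to log_alpha
safe_log_alpha = log_alpha - log_alpha.detach() + log_alpha_clamp

safe_log_alpha.exp().backward()
print(safe_log_alpha.item())  # 2.0, the clamped forward value
print(log_alpha.grad.item())  # exp(2.0), gradient of the unclamped path
```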
) - if not isinstance(action_spec, CompositeSpec): - action_spec = CompositeSpec({self.tensor_keys.action: action_spec}) + if not isinstance(action_spec, Composite): + action_spec = Composite({self.tensor_keys.action: action_spec}) if ( isinstance(self.tensor_keys.action, tuple) and len(self.tensor_keys.action) > 1 @@ -450,9 +463,7 @@ def target_entropy(self): else: action_container_shape = action_spec.shape target_entropy = -float( - action_spec[self.tensor_keys.action] - .shape[len(action_container_shape) :] - .numel() + action_spec.shape[len(action_container_shape) :].numel() ) delattr(self, "_target_entropy") self.register_buffer( @@ -622,7 +633,7 @@ def _actor_loss( ), self.actor_network_params.to_module(self.actor_network): dist = self.actor_network.get_dist(tensordict) a_reparm = dist.rsample() - log_prob = dist.log_prob(a_reparm) + log_prob = compute_log_prob(dist, a_reparm, self.tensor_keys.log_prob) td_q = tensordict.select(*self.qvalue_network.in_keys, strict=False) td_q.set(self.tensor_keys.action, a_reparm) @@ -713,7 +724,9 @@ def _compute_target_v2(self, tensordict) -> Tensor: next_dist = self.actor_network.get_dist(next_tensordict) next_action = next_dist.rsample() next_tensordict.set(self.tensor_keys.action, next_action) - next_sample_log_prob = next_dist.log_prob(next_action) + next_sample_log_prob = compute_log_prob( + next_dist, next_action, self.tensor_keys.log_prob + ) # get q-values next_tensordict_expand = self._vmap_qnetworkN0( @@ -780,7 +793,8 @@ def _value_loss( td_copy.get(self.tensor_keys.state_action_value).squeeze(-1).min(0)[0] ) - log_p = action_dist.log_prob(action) + log_p = compute_log_prob(action_dist, action, self.tensor_keys.log_prob) + if log_p.shape != min_qval.shape: raise RuntimeError( f"Losses shape mismatch: {min_qval.shape} and {log_p.shape}" @@ -818,9 +832,9 @@ class DiscreteSACLoss(LossModule): qvalue_network (TensorDictModule): a single Q-value network that will be multiplicated as many times as needed. action_space (str or TensorSpec): Action space. Must be one of ``"one-hot"``, ``"mult_one_hot"``, ``"binary"`` or ``"categorical"``, - or an instance of the corresponding specs (:class:`torchrl.data.OneHotDiscreteTensorSpec`, - :class:`torchrl.data.MultiOneHotDiscreteTensorSpec`, - :class:`torchrl.data.BinaryDiscreteTensorSpec` or :class:`torchrl.data.DiscreteTensorSpec`). + or an instance of the corresponding specs (:class:`torchrl.data.OneHot`, + :class:`torchrl.data.MultiOneHot`, + :class:`torchrl.data.Binary` or :class:`torchrl.data.Categorical`). num_actions (int, optional): number of actions in the action space. To be provided if target_entropy is set to "auto". num_qvalue_nets (int, optional): Number of Q-value networks to be trained. Default is 2. 
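The target-entropy fix above takes ``numel`` on the already-indexed action spec instead of indexing a second time. A hedged sketch of the "auto" heuristic (target entropy = minus the number of action dimensions), assuming the renamed specs:

```python
import torch
from torchrl.data import Bounded, Composite

action_spec = Bounded(-torch.ones(4), torch.ones(4), (4,))
if not isinstance(action_spec, Composite):
    action_spec = Composite({"action": action_spec})

action_container_shape = action_spec.shape  # torch.Size([]) here
target_entropy = -float(
    action_spec["action"].shape[len(action_container_shape):].numel()
)
print(target_entropy)  # -4.0
```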
@@ -852,7 +866,7 @@ class DiscreteSACLoss(LossModule): Examples: >>> import torch >>> from torch import nn - >>> from torchrl.data.tensor_specs import OneHotDiscreteTensorSpec + >>> from torchrl.data.tensor_specs import OneHot >>> from torchrl.modules.distributions import NormalParamExtractor, OneHotCategorical >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule @@ -860,7 +874,7 @@ class DiscreteSACLoss(LossModule): >>> from tensordict import TensorDict >>> from tensordict.nn import TensorDictModule >>> n_act, n_obs = 4, 3 - >>> spec = OneHotDiscreteTensorSpec(n_act) + >>> spec = OneHot(n_act) >>> module = TensorDictModule(nn.Linear(n_obs, n_act), in_keys=["observation"], out_keys=["logits"]) >>> actor = ProbabilisticActor( ... module=module, @@ -909,13 +923,13 @@ class DiscreteSACLoss(LossModule): Examples: >>> import torch >>> from torch import nn - >>> from torchrl.data.tensor_specs import OneHotDiscreteTensorSpec + >>> from torchrl.data.tensor_specs import OneHot >>> from torchrl.modules.distributions import NormalParamExtractor, OneHotCategorical >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.sac import DiscreteSACLoss >>> n_act, n_obs = 4, 3 - >>> spec = OneHotDiscreteTensorSpec(n_act) + >>> spec = OneHot(n_act) >>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()) >>> module = SafeModule(net, in_keys=["observation"], out_keys=["logits"]) >>> actor = ProbabilisticActor( @@ -1088,7 +1102,7 @@ def __init__( if action_space is None: warnings.warn( "action_space was not specified. DiscreteSACLoss will default to 'one-hot'." - "This behaviour will be deprecated soon and a space will have to be passed. " + "This behavior will be deprecated soon and a space will have to be passed. " "Check the DiscreteSACLoss documentation to see how to pass the action space. " ) action_space = "one-hot" diff --git a/torchrl/objectives/td3.py b/torchrl/objectives/td3.py index eb1027ad936..89ff581991f 100644 --- a/torchrl/objectives/td3.py +++ b/torchrl/objectives/td3.py @@ -12,7 +12,7 @@ from tensordict import TensorDict, TensorDictBase, TensorDictParams from tensordict.nn import dispatch, TensorDictModule from tensordict.utils import NestedKey -from torchrl.data.tensor_specs import BoundedTensorSpec, CompositeSpec, TensorSpec +from torchrl.data.tensor_specs import Bounded, Composite, TensorSpec from torchrl.envs.utils import step_mdp from torchrl.objectives.common import LossModule @@ -83,14 +83,14 @@ class TD3Loss(LossModule): Examples: >>> import torch >>> from torch import nn - >>> from torchrl.data import BoundedTensorSpec + >>> from torchrl.data import Bounded >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal >>> from torchrl.modules.tensordict_module.actors import Actor, ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.td3 import TD3Loss >>> from tensordict import TensorDict >>> n_act, n_obs = 4, 3 - >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) + >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,)) >>> module = nn.Linear(n_obs, n_act) >>> actor = Actor( ... 
module=module, @@ -139,11 +139,11 @@ class TD3Loss(LossModule): Examples: >>> import torch >>> from torch import nn - >>> from torchrl.data import BoundedTensorSpec + >>> from torchrl.data import Bounded >>> from torchrl.modules.tensordict_module.actors import Actor, ValueOperator >>> from torchrl.objectives.td3 import TD3Loss >>> n_act, n_obs = 4, 3 - >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) + >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,)) >>> module = nn.Linear(n_obs, n_act) >>> actor = Actor( ... module=module, @@ -283,7 +283,7 @@ def __init__( f"but not both or none. Got bounds={bounds} and action_spec={action_spec}." ) elif action_spec is not None: - if isinstance(action_spec, CompositeSpec): + if isinstance(action_spec, Composite): if ( isinstance(self.tensor_keys.action, tuple) and len(self.tensor_keys.action) > 1 @@ -296,9 +296,9 @@ def __init__( action_spec = action_spec[self.tensor_keys.action][ (0,) * len(action_container_shape) ] - if not isinstance(action_spec, BoundedTensorSpec): + if not isinstance(action_spec, Bounded): raise ValueError( - f"action_spec is not of type BoundedTensorSpec but {type(action_spec)}." + f"action_spec is not of type Bounded but {type(action_spec)}." ) low = action_spec.space.low high = action_spec.space.high @@ -372,7 +372,7 @@ def _cached_stack_actor_params(self): [self.actor_network_params, self.target_actor_network_params], 0 ) - def actor_loss(self, tensordict): + def actor_loss(self, tensordict) -> Tuple[torch.Tensor, dict]: tensordict_actor_grad = tensordict.select( *self.actor_network.in_keys, strict=False ) @@ -398,7 +398,7 @@ def actor_loss(self, tensordict): loss_actor = _reduce(loss_actor, reduction=self.reduction) return loss_actor, metadata - def value_loss(self, tensordict): + def value_loss(self, tensordict) -> Tuple[torch.Tensor, dict]: tensordict = tensordict.clone(False) act = tensordict.get(self.tensor_keys.action) diff --git a/torchrl/objectives/td3_bc.py b/torchrl/objectives/td3_bc.py index aa87ea9aa1a..8b394137480 100644 --- a/torchrl/objectives/td3_bc.py +++ b/torchrl/objectives/td3_bc.py @@ -12,7 +12,7 @@ from tensordict import TensorDict, TensorDictBase, TensorDictParams from tensordict.nn import dispatch, TensorDictModule from tensordict.utils import NestedKey -from torchrl.data.tensor_specs import BoundedTensorSpec, CompositeSpec, TensorSpec +from torchrl.data.tensor_specs import Bounded, Composite, TensorSpec from torchrl.envs.utils import step_mdp from torchrl.objectives.common import LossModule @@ -94,14 +94,14 @@ class TD3BCLoss(LossModule): Examples: >>> import torch >>> from torch import nn - >>> from torchrl.data import BoundedTensorSpec + >>> from torchrl.data import Bounded >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal >>> from torchrl.modules.tensordict_module.actors import Actor, ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.td3_bc import TD3BCLoss >>> from tensordict import TensorDict >>> n_act, n_obs = 4, 3 - >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) + >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,)) >>> module = nn.Linear(n_obs, n_act) >>> actor = Actor( ... 
module=module, @@ -152,11 +152,11 @@ class TD3BCLoss(LossModule): Examples: >>> import torch >>> from torch import nn - >>> from torchrl.data import BoundedTensorSpec + >>> from torchrl.data import Bounded >>> from torchrl.modules.tensordict_module.actors import Actor, ValueOperator >>> from torchrl.objectives.td3_bc import TD3BCLoss >>> n_act, n_obs = 4, 3 - >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) + >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,)) >>> module = nn.Linear(n_obs, n_act) >>> actor = Actor( ... module=module, @@ -299,7 +299,7 @@ def __init__( f"but not both or none. Got bounds={bounds} and action_spec={action_spec}." ) elif action_spec is not None: - if isinstance(action_spec, CompositeSpec): + if isinstance(action_spec, Composite): if ( isinstance(self.tensor_keys.action, tuple) and len(self.tensor_keys.action) > 1 @@ -312,9 +312,9 @@ def __init__( action_spec = action_spec[self.tensor_keys.action][ (0,) * len(action_container_shape) ] - if not isinstance(action_spec, BoundedTensorSpec): + if not isinstance(action_spec, Bounded): raise ValueError( - f"action_spec is not of type BoundedTensorSpec but {type(action_spec)}." + f"action_spec is not of type Bounded but {type(action_spec)}." ) low = action_spec.space.low high = action_spec.space.high @@ -386,7 +386,7 @@ def _cached_stack_actor_params(self): [self.actor_network_params, self.target_actor_network_params], 0 ) - def actor_loss(self, tensordict): + def actor_loss(self, tensordict) -> Tuple[torch.Tensor, dict]: """Compute the actor loss. The actor loss should be computed after the :meth:`~.qvalue_loss` and is usually delayed 1-3 critic updates. @@ -433,7 +433,7 @@ def actor_loss(self, tensordict): loss_actor = _reduce(loss_actor, reduction=self.reduction) return loss_actor, metadata - def qvalue_loss(self, tensordict): + def qvalue_loss(self, tensordict) -> Tuple[torch.Tensor, dict]: """Compute the q-value loss. The q-value loss should be computed before the :meth:`~.actor_loss`. diff --git a/torchrl/objectives/utils.py b/torchrl/objectives/utils.py index b1077198784..31954005195 100644 --- a/torchrl/objectives/utils.py +++ b/torchrl/objectives/utils.py @@ -26,6 +26,11 @@ raise err_ft from err from torchrl.envs.utils import step_mdp +try: + from torch.compiler import is_dynamo_compiling +except ImportError: + from torch._dynamo import is_compiling as is_dynamo_compiling + _GAMMA_LMBDA_DEPREC_ERROR = ( "Passing gamma / lambda parameters through the loss constructor " "is a deprecated feature. 
To customize your value function, " @@ -198,23 +203,37 @@ def __init__( @property def _targets(self): - return TensorDict( - {name: getattr(self.loss_module, name) for name in self._target_names}, - [], - ) + targets = self.__dict__.get("_targets_val", None) + if targets is None: + targets = self.__dict__["_targets_val"] = TensorDict( + {name: getattr(self.loss_module, name) for name in self._target_names}, + [], + ) + return targets + + @_targets.setter + def _targets(self, targets): + self.__dict__["_targets_val"] = targets @property def _sources(self): - return TensorDict( - {name: getattr(self.loss_module, name) for name in self._source_names}, - [], - ) + sources = self.__dict__.get("_sources_val", None) + if sources is None: + sources = self.__dict__["_sources_val"] = TensorDict( + {name: getattr(self.loss_module, name) for name in self._source_names}, + [], + ) + return sources + + @_sources.setter + def _sources(self, sources): + self.__dict__["_sources_val"] = sources def init_(self) -> None: if self.initialized: warnings.warn("Updated already initialized.") found_distinct = False - self._distinct = {} + self._distinct_and_params = {} for key, source in self._sources.items(True, True): if not isinstance(key, tuple): key = (key,) @@ -223,8 +242,12 @@ def init_(self) -> None: # for p_source, p_target in zip(source, target): if target.requires_grad: raise RuntimeError("the target parameter is part of a graph.") - self._distinct[key] = target.data_ptr() != source.data.data_ptr() - found_distinct = found_distinct or self._distinct[key] + self._distinct_and_params[key] = ( + target.is_leaf + and source.requires_grad + and target.data_ptr() != source.data.data_ptr() + ) + found_distinct = found_distinct or self._distinct_and_params[key] target.data.copy_(source.data) if not found_distinct: raise RuntimeError( @@ -235,6 +258,23 @@ def init_(self) -> None: f"If no target parameter is needed, do not use a target updater such as {type(self)}." ) + # filter the target_ out + def filter_target(key): + if isinstance(key, tuple): + return (filter_target(key[0]), *key[1:]) + return key[7:] + + self._sources = self._sources.select( + *[ + filter_target(key) + for (key, val) in self._distinct_and_params.items() + if val + ] + ).lock_() + self._targets = self._targets.select( + *(key for (key, val) in self._distinct_and_params.items() if val) + ).lock_() + self.initialized = True def step(self) -> None: @@ -243,19 +283,11 @@ def step(self) -> None: f"{self.__class__.__name__} must be " f"initialized (`{self.__class__.__name__}.init_()`) before calling step()" ) - for key, source in self._sources.items(True, True): - if not isinstance(key, tuple): - key = (key,) - key = ("target_" + key[0], *key[1:]) - if not self._distinct[key]: - continue - target = self._targets[key] + for key, param in self._sources.items(): + target = self._targets.get("target_{}".format(key)) if target.requires_grad: raise RuntimeError("the target parameter is part of a graph.") - if target.is_leaf: - self._step(source, target) - else: - target.copy_(source) + self._step(param, target) def _step(self, p_source: Tensor, p_target: Tensor) -> None: raise NotImplementedError @@ -301,7 +333,7 @@ def __init__( ): if eps is None and tau is None: raise RuntimeError( - "Neither eps nor tau was provided. This behaviour is deprecated.", + "Neither eps nor tau was provided. 
This behavior is deprecated.", ) eps = 0.999 if (eps is None) ^ (tau is None): @@ -321,8 +353,10 @@ def __init__( super(SoftUpdate, self).__init__(loss_module) self.eps = eps - def _step(self, p_source: Tensor, p_target: Tensor) -> None: - p_target.data.copy_(p_target.data * self.eps + p_source.data * (1 - self.eps)) + def _step( + self, p_source: Tensor | TensorDictBase, p_target: Tensor | TensorDictBase + ) -> None: + p_target.data.lerp_(p_source.data, 1 - self.eps) class HardUpdate(TargetNetUpdater): @@ -454,11 +488,17 @@ def next_state_value( return target_value -def _cache_values(fun): +def _cache_values(func): """Caches the tensordict returned by a property.""" - name = fun.__name__ + name = func.__name__ - def new_fun(self, netname=None): + @functools.wraps(func) + def new_func(self, netname=None): + if is_dynamo_compiling(): + if netname is not None: + return func(self, netname) + else: + return func(self) __dict__ = self.__dict__ _cache = __dict__.setdefault("_cache", {}) attr_name = name @@ -468,16 +508,16 @@ def new_fun(self, netname=None): out = _cache[attr_name] return out if netname is not None: - out = fun(self, netname) + out = func(self, netname) else: - out = fun(self) + out = func(self) # TODO: decide what to do with locked tds in functional calls # if is_tensor_collection(out): # out.lock_() _cache[attr_name] = out return out - return new_fun + return new_func def _vmap_func(module, *args, func=None, **kwargs): diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py index b7db2e8242e..e396b7e1fcc 100644 --- a/torchrl/objectives/value/advantages.py +++ b/torchrl/objectives/value/advantages.py @@ -16,13 +16,12 @@ from tensordict import TensorDictBase from tensordict.nn import ( dispatch, - is_functional, set_skip_existing, TensorDictModule, TensorDictModuleBase, ) from tensordict.utils import NestedKey -from torch import nn, Tensor +from torch import Tensor from torchrl._utils import RL_WARNINGS from torchrl.envs.utils import step_mdp @@ -412,18 +411,13 @@ def value_estimate( @property def is_functional(self): - if isinstance(self.value_network, nn.Module): - return is_functional(self.value_network) - elif self.value_network is None: - return None - else: - raise RuntimeError("Cannot determine if value network is functional.") + # legacy + return False @property def is_stateless(self): - if not self.is_functional: - return False - return self.value_network._is_stateless + # legacy + return False def _next_value(self, tensordict, target_params, kwargs): step_td = step_mdp(tensordict, keep_other=False) @@ -1183,17 +1177,17 @@ class GAE(ValueEstimatorBase): device (torch.device, optional): device of the module. time_dim (int, optional): the dimension corresponding to the time in the input tensordict. If not provided, defaults to the dimension - markes with the ``"time"`` name if any, and to the last dimension + marked with the ``"time"`` name if any, and to the last dimension otherwise. Can be overridden during a call to :meth:`~.value_estimate`. Negative dimensions are considered with respect to the input tensordict. - GAE will return an :obj:`"advantage"` entry containing the advange value. It will also + GAE will return an :obj:`"advantage"` entry containing the advantage value. It will also return a :obj:`"value_target"` entry with the return value that is to be used to train the value network. 
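The ``SoftUpdate._step`` rewrite above expresses Polyak averaging with a single in-place ``lerp_``; the two forms are numerically equivalent:

```python
import torch

eps = 0.999
p_source = torch.ones(3)
p_target = torch.zeros(3)

expected = p_target * eps + p_source * (1 - eps)  # classic Polyak form
p_target.data.lerp_(p_source.data, 1 - eps)       # in-place equivalent
assert torch.allclose(p_target, expected)
print(p_target)  # tensor([0.0010, 0.0010, 0.0010])
```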
Finally, if :obj:`gradient_mode` is ``True``, an additional and differentiable :obj:`"value_error"` entry will be returned, - which simple represents the difference between the return and the value network + which simply represents the difference between the return and the value network output (i.e. an additional distance loss should be applied to that signed value). .. note:: @@ -1268,7 +1262,7 @@ def forward( target params to be passed to the functional value network module. time_dim (int, optional): the dimension corresponding to the time in the input tensordict. If not provided, defaults to the dimension - markes with the ``"time"`` name if any, and to the last dimension + marked with the ``"time"`` name if any, and to the last dimension otherwise. Negative dimensions are considered with respect to the input tensordict. @@ -1316,7 +1310,7 @@ def forward( """ if tensordict.batch_dims < 1: raise RuntimeError( - "Expected input tensordict to have at least one dimensions, got " + "Expected input tensordict to have at least one dimension, got " f"tensordict.batch_size = {tensordict.batch_size}" ) reward = tensordict.get(("next", self.tensor_keys.reward)) diff --git a/torchrl/objectives/value/functional.py b/torchrl/objectives/value/functional.py index d3ad8d93ca4..ddd688610c2 100644 --- a/torchrl/objectives/value/functional.py +++ b/torchrl/objectives/value/functional.py @@ -230,7 +230,7 @@ def _fast_vec_gae( ``[*Batch x TimeSteps x F]``, with ``F`` feature dimensions. """ - # _gen_num_per_traj and _split_and_pad_sequence need + # _get_num_per_traj and _split_and_pad_sequence need # time dimension at last position done = done.transpose(-2, -1) terminated = terminated.transpose(-2, -1) diff --git a/torchrl/record/recorder.py b/torchrl/record/recorder.py index b7fb8ab4ed2..e533f9e9df9 100644 --- a/torchrl/record/recorder.py +++ b/torchrl/record/recorder.py @@ -18,7 +18,7 @@ from torchrl._utils import _can_be_pickled from torchrl.data import TensorSpec -from torchrl.data.tensor_specs import NonTensorSpec, UnboundedContinuousTensorSpec +from torchrl.data.tensor_specs import NonTensor, Unbounded from torchrl.data.utils import CloudpickleWrapper from torchrl.envs import EnvBase from torchrl.envs.transforms import ObservationTransform, Transform @@ -409,9 +409,9 @@ class PixelRenderTransform(Transform): >>> env.transform[-1].dump() The transform can be disabled using the :meth:`~torchrl.record.PixelRenderTransform.switch` method, which will - turn the rendering on if it's off or off if it's on (an argument can also be passed to control this behaviour). + turn the rendering on if it's off or off if it's on (an argument can also be passed to control this behavior). Since transforms are :class:`~torch.nn.Module` instances, :meth:`~torch.nn.Module.apply` can be used to control - this behaviour: + this behavior: >>> def switch(module): ... 
if isinstance(module, PixelRenderTransform): @@ -506,11 +506,9 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec self._call(td_in) obs = td_in.get(self.out_keys[0]) if isinstance(obs, NonTensorData): - spec = NonTensorSpec(device=obs.device, dtype=obs.dtype, shape=obs.shape) + spec = NonTensor(device=obs.device, dtype=obs.dtype, shape=obs.shape) else: - spec = UnboundedContinuousTensorSpec( - device=obs.device, dtype=obs.dtype, shape=obs.shape - ) + spec = Unbounded(device=obs.device, dtype=obs.dtype, shape=obs.shape) observation_spec[self.out_keys[0]] = spec if switch: self.switch() diff --git a/torchrl/trainers/helpers/collectors.py b/torchrl/trainers/helpers/collectors.py index b192d115a54..efdde1a1c63 100644 --- a/torchrl/trainers/helpers/collectors.py +++ b/torchrl/trainers/helpers/collectors.py @@ -19,7 +19,6 @@ from torchrl.data.postprocs import MultiStep from torchrl.envs.batched_envs import ParallelEnv from torchrl.envs.common import EnvBase -from torchrl.envs.utils import ExplorationType def sync_async_collector( @@ -304,7 +303,7 @@ def make_collector_offpolicy( "init_random_frames": cfg.init_random_frames, "split_trajs": True, # trajectories must be separated if multi-step is used - "exploration_type": ExplorationType.from_str(cfg.exploration_mode), + "exploration_type": cfg.exploration_type, } collector = collector_helper(**collector_helper_kwargs) @@ -358,7 +357,7 @@ def make_collector_onpolicy( "storing_device": cfg.collector_device, "split_trajs": True, # trajectories must be separated in online settings - "exploration_mode": cfg.exploration_mode, + "exploration_type": cfg.exploration_type, } collector = collector_helper(**collector_helper_kwargs) @@ -398,7 +397,7 @@ class OnPolicyCollectorConfig: # for each of these parallel wrappers. If env_per_collector=num_workers, no parallel wrapper is created seed: int = 42 # seed used for the environment, pytorch and numpy. - exploration_mode: str = "random" + exploration_type: str = "random" # exploration mode of the data collector. async_collection: bool = False # whether data collection should be done asynchrously. Asynchrounous data collection means diff --git a/torchrl/trainers/helpers/models.py b/torchrl/trainers/helpers/models.py index 0c9ec92cff4..a3776f78e5a 100644 --- a/torchrl/trainers/helpers/models.py +++ b/torchrl/trainers/helpers/models.py @@ -9,11 +9,7 @@ from tensordict import set_lazy_legacy from tensordict.nn import InteractionType from torch import nn -from torchrl.data.tensor_specs import ( - CompositeSpec, - DiscreteTensorSpec, - UnboundedContinuousTensorSpec, -) +from torchrl.data.tensor_specs import Categorical, Composite, Unbounded from torchrl.data.utils import DEVICE_TYPING from torchrl.envs.common import EnvBase from torchrl.envs.model_based.dreamer import DreamerEnv @@ -153,9 +149,9 @@ def make_dqn_actor( actor_class = QValueActor actor_kwargs = {} - if isinstance(action_spec, DiscreteTensorSpec): + if isinstance(action_spec, Categorical): # if action spec is modeled as categorical variable, we still need to have features equal - # to the number of possible choices and also set categorical behavioural for actors. + # to the number of possible choices and also set categorical behavioral for actors. 
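The collector helpers above now read ``cfg.exploration_type`` directly instead of translating the retired ``exploration_mode`` string; the removed call shows the mapping that used to happen:

```python
from torchrl.envs.utils import ExplorationType

# Equivalent of the removed ExplorationType.from_str(cfg.exploration_mode):
exploration_type = ExplorationType.from_str("random")
print(exploration_type)  # ExplorationType.RANDOM
```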
actor_kwargs.update({"action_space": "categorical"}) out_features = env_specs["input_spec", "full_action_spec", "action"].space.n else: @@ -182,7 +178,7 @@ def make_dqn_actor( model = actor_class( module=net, - spec=CompositeSpec(action=action_spec), + spec=Composite(action=action_spec), in_keys=[in_key], safe=True, **actor_kwargs, @@ -385,13 +381,13 @@ def _dreamer_make_actor_sim(action_key, proof_environment, actor_module): actor_module, in_keys=["state", "belief"], out_keys=["loc", "scale"], - spec=CompositeSpec( + spec=Composite( **{ - "loc": UnboundedContinuousTensorSpec( + "loc": Unbounded( proof_environment.action_spec.shape, device=proof_environment.action_spec.device, ), - "scale": UnboundedContinuousTensorSpec( + "scale": Unbounded( proof_environment.action_spec.shape, device=proof_environment.action_spec.device, ), @@ -404,7 +400,7 @@ def _dreamer_make_actor_sim(action_key, proof_environment, actor_module): default_interaction_type=InteractionType.RANDOM, distribution_class=TanhNormal, distribution_kwargs={"tanh_loc": True}, - spec=CompositeSpec(**{action_key: proof_environment.action_spec}), + spec=Composite(**{action_key: proof_environment.action_spec}), ), ) return actor_simulator @@ -436,12 +432,12 @@ def _dreamer_make_actor_real( actor_module, in_keys=["state", "belief"], out_keys=["loc", "scale"], - spec=CompositeSpec( + spec=Composite( **{ - "loc": UnboundedContinuousTensorSpec( + "loc": Unbounded( proof_environment.action_spec.shape, ), - "scale": UnboundedContinuousTensorSpec( + "scale": Unbounded( proof_environment.action_spec.shape, ), } @@ -453,9 +449,7 @@ def _dreamer_make_actor_real( default_interaction_type=InteractionType.DETERMINISTIC, distribution_class=TanhNormal, distribution_kwargs={"tanh_loc": True}, - spec=CompositeSpec( - **{action_key: proof_environment.action_spec.to("cpu")} - ), + spec=Composite(**{action_key: proof_environment.action_spec.to("cpu")}), ), ), SafeModule( @@ -536,8 +530,8 @@ def _dreamer_make_mbenv( model_based_env.set_specs_from_env(proof_environment) model_based_env = TransformedEnv(model_based_env) default_dict = { - "state": UnboundedContinuousTensorSpec(state_dim), - "belief": UnboundedContinuousTensorSpec(rssm_hidden_dim), + "state": Unbounded(state_dim), + "belief": Unbounded(rssm_hidden_dim), # "action": proof_environment.action_spec, } model_based_env.append_transform( diff --git a/torchrl/trainers/trainers.py b/torchrl/trainers/trainers.py index 247d039eb1e..62ea4a4a109 100644 --- a/torchrl/trainers/trainers.py +++ b/torchrl/trainers/trainers.py @@ -1126,7 +1126,7 @@ class Recorder(TrainerHookBase): """Recorder hook for :class:`~torchrl.trainers.Trainer`. Args: - record_interval (int): total number of optimisation steps + record_interval (int): total number of optimization steps between two calls to the recorder for testing. record_frames (int): number of frames to be recorded during testing. @@ -1145,7 +1145,7 @@ class Recorder(TrainerHookBase): Given that this instance is supposed to both explore and render the performance of the policy, it should be possible to turn off - the explorative behaviour by calling the + the explorative behavior by calling the `set_exploration_type(ExplorationType.DETERMINISTIC)` context manager. environment (EnvBase): An environment instance to be used for testing. diff --git a/tutorials/README.md b/tutorials/README.md index d774f6b7566..562c5d427a9 100644 --- a/tutorials/README.md +++ b/tutorials/README.md @@ -1,21 +1,7 @@ # Tutorials -Get a sense of TorchRL functionalities through our tutorials. 
+Get a sense of TorchRL functionalities through our [tutorials](https://pytorch.org/rl/stable/tutorials). -For an overview of TorchRL, try the [TorchRL demo](https://pytorch.org/rl/tutorials/torchrl_demo.html). +The ["Getting Started"](https://pytorch.org/rl/stable/index.html#getting-started) section will help you model your first training loop with the library! -Make sure you test the [TensorDict tutorial](https://pytorch.org/rl/tutorials/tensordict_tutorial.html) to see what TensorDict -is about and what it can do. - -To understand how to use `TensorDict` with pytorch modules, make sure to check out the [TensorDictModule tutorial](https://pytorch.org/rl/tutorials/tensordict_module.html). - -Check out the [environment tutorial](https://pytorch.org/rl/tutorials/torch_envs.html) for a deep dive in the envs -functionalities. - -Read through our short tutorial on [multi-tasking](https://pytorch.org/rl/tutorials/multi_task.html) to see how you can execute diverse -tasks in batch mode and build task-specific policies. -This tutorial is also a good example of the advanced features of TensorDict stacking and -indexing. - -Finally, the [DDPG tutorial](https://pytorch.org/rl/tutorials/coding_ddpg.html) and [DQN tutorial](https://pytorch.org/rl/tutorials/coding_dqn.html) will guide you through the steps to code -your first RL algorithms with TorchRL. +The rest of the tutorials is split in [Basic](https://pytorch.org/rl/stable/index.html#basics), [Intermediate](https://pytorch.org/rl/stable/index.html#intermediate) and [Advanced](https://pytorch.org/rl/stable/index.html#advanced) sections. diff --git a/tutorials/sphinx-tutorials/coding_ddpg.py b/tutorials/sphinx-tutorials/coding_ddpg.py index 1bf7fd57e83..13721b715e3 100644 --- a/tutorials/sphinx-tutorials/coding_ddpg.py +++ b/tutorials/sphinx-tutorials/coding_ddpg.py @@ -683,7 +683,7 @@ def get_env_stats(): ) -from torchrl.data import CompositeSpec +from torchrl.data import Composite ############################################################################### # Building the model @@ -756,7 +756,7 @@ def make_ddpg_actor( actor, distribution_class=TanhDelta, in_keys=["param"], - spec=CompositeSpec(action=proof_environment.action_spec), + spec=Composite(action=proof_environment.action_spec), ).to(device) q_net = DdpgMlpQNet() @@ -899,7 +899,7 @@ def make_recorder(actor_model_explore, transform_state_dict, record_interval): record_frames=1000, policy_exploration=actor_model_explore, environment=environment, - exploration_type=ExplorationType.MEAN, + exploration_type=ExplorationType.DETERMINISTIC, record_interval=record_interval, ) return recorder_obj diff --git a/tutorials/sphinx-tutorials/coding_dqn.py b/tutorials/sphinx-tutorials/coding_dqn.py index e9f2085d3df..59188ad21f6 100644 --- a/tutorials/sphinx-tutorials/coding_dqn.py +++ b/tutorials/sphinx-tutorials/coding_dqn.py @@ -449,7 +449,7 @@ def get_collector( policy=actor_explore, frames_per_batch=frames_per_batch, total_frames=total_frames, - # this is the default behaviour: the collector runs in ``"random"`` (or explorative) mode + # this is the default behavior: the collector runs in ``"random"`` (or explorative) mode exploration_type=ExplorationType.RANDOM, # We set the all the devices to be identical. 
@@ -672,7 +672,7 @@ def get_loss_module(actor, gamma): frame_skip=1, policy_exploration=actor_explore, environment=test_env, - exploration_type=ExplorationType.MODE, + exploration_type=ExplorationType.DETERMINISTIC, log_keys=[("next", "reward")], out_keys={("next", "reward"): "rewards"}, log_pbar=True, diff --git a/tutorials/sphinx-tutorials/coding_ppo.py b/tutorials/sphinx-tutorials/coding_ppo.py index 51229e1880d..25e72dc40f4 100644 --- a/tutorials/sphinx-tutorials/coding_ppo.py +++ b/tutorials/sphinx-tutorials/coding_ppo.py @@ -195,7 +195,7 @@ # ~~~~~~~~~~~~~~ # # At each data collection (or batch collection) we will run the optimization -# over a certain number of *epochs*, each time-consuming the entire data we just +# over a certain number of *epochs*, each time consuming the entire data we just # acquired in a nested training loop. Here, the ``sub_batch_size`` is different from the # ``frames_per_batch`` here above: recall that we are working with a "batch of data" # coming from our collector, whose size is defined by ``frames_per_batch``, and that @@ -203,7 +203,7 @@ # The size of these sub-batches is controlled by ``sub_batch_size``. # sub_batch_size = 64 # cardinality of the sub-samples gathered from the current data in the inner loop -num_epochs = 10 # optimisation steps per batch of data collected +num_epochs = 10 # optimization steps per batch of data collected clip_epsilon = ( 0.2 # clip value for PPO loss: see the equation in the intro for more context. ) @@ -651,7 +651,7 @@ # number of steps (1000, which is our ``env`` horizon). # The ``rollout`` method of the ``env`` can take a policy as argument: # it will then execute this policy at each step. - with set_exploration_type(ExplorationType.MEAN), torch.no_grad(): + with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): # execute a rollout with the trained policy eval_rollout = env.rollout(1000, policy_module) logs["eval reward"].append(eval_rollout["next", "reward"].mean().item()) diff --git a/tutorials/sphinx-tutorials/dqn_with_rnn.py b/tutorials/sphinx-tutorials/dqn_with_rnn.py index 28a9638c6f6..8931f483384 100644 --- a/tutorials/sphinx-tutorials/dqn_with_rnn.py +++ b/tutorials/sphinx-tutorials/dqn_with_rnn.py @@ -440,7 +440,7 @@ exploration_module.step(data.numel()) updater.step() - with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): rollout = env.rollout(10000, stoch_policy) traj_lens.append(rollout.get(("next", "step_count")).max().item()) diff --git a/tutorials/sphinx-tutorials/getting-started-1.py b/tutorials/sphinx-tutorials/getting-started-1.py index 437cae26c42..4e8a1b30930 100644 --- a/tutorials/sphinx-tutorials/getting-started-1.py +++ b/tutorials/sphinx-tutorials/getting-started-1.py @@ -172,7 +172,7 @@ from torchrl.envs.utils import ExplorationType, set_exploration_type -with set_exploration_type(ExplorationType.MEAN): +with set_exploration_type(ExplorationType.DETERMINISTIC): # takes the mean as action rollout = env.rollout(max_steps=10, policy=policy) with set_exploration_type(ExplorationType.RANDOM): @@ -221,7 +221,7 @@ exploration_policy = TensorDictSequential(policy, exploration_module) -with set_exploration_type(ExplorationType.MEAN): +with set_exploration_type(ExplorationType.DETERMINISTIC): # Turns off exploration rollout = env.rollout(max_steps=10, policy=exploration_policy) with set_exploration_type(ExplorationType.RANDOM):
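The getting-started-5.py hunk below is a genuine bug fix rather than a rename: the collector was handed the bare greedy policy, so the epsilon-greedy wrapper never acted during collection. A sketch of the intended wiring, reusing the tutorial's names but with hypothetical sizes:

import torch.nn as nn
from tensordict.nn import TensorDictModule, TensorDictSequential
from torchrl.collectors import SyncDataCollector
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import EGreedyModule, QValueModule

env = GymEnv("CartPole-v1")
# greedy policy: observation -> action values -> argmax action
value_net = TensorDictModule(nn.Linear(4, 2), in_keys=["observation"], out_keys=["action_value"])
policy = TensorDictSequential(value_net, QValueModule(spec=env.action_spec))
# epsilon-greedy exploration layered on top of the greedy policy
policy_explore = TensorDictSequential(policy, EGreedyModule(env.action_spec))
collector = SyncDataCollector(
    env,
    policy_explore,  # the exploration-wrapped policy, not the bare greedy one
    frames_per_batch=100,
    total_frames=-1,
    init_random_frames=100,
)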
diff --git a/tutorials/sphinx-tutorials/getting-started-5.py b/tutorials/sphinx-tutorials/getting-started-5.py index 5f95fe1e534..d355d1888c5 100644 --- a/tutorials/sphinx-tutorials/getting-started-5.py +++ b/tutorials/sphinx-tutorials/getting-started-5.py @@ -89,7 +89,7 @@ optim_steps = 10 collector = SyncDataCollector( env, - policy, + policy_explore, frames_per_batch=frames_per_batch, total_frames=-1, init_random_frames=init_rand_steps, diff --git a/tutorials/sphinx-tutorials/multiagent_competitive_ddpg.py b/tutorials/sphinx-tutorials/multiagent_competitive_ddpg.py index 77574b765e7..08b6d83bf5c 100644 --- a/tutorials/sphinx-tutorials/multiagent_competitive_ddpg.py +++ b/tutorials/sphinx-tutorials/multiagent_competitive_ddpg.py @@ -89,7 +89,7 @@ # wrapper for either PettingZoo or VMAS. # # 3. Following that, we will formulate the policy and critic networks, discussing the effects of various choices on -# parameter sharing and critic centralisation. +# parameter sharing and critic centralization. # # 4. Afterwards, we will create the sampling collector and the replay buffer. # @@ -179,7 +179,7 @@ memory_size = 1_000_000 # The replay buffer of each group can store this many frames # Training -n_optimiser_steps = 100 # Number of optimisation steps per training iteration +n_optimiser_steps = 100 # Number of optimization steps per training iteration train_batch_size = 128 # Number of frames trained in each optimiser step lr = 3e-4 # Learning rate max_grad_norm = 1.0 # Maximum norm for the gradients @@ -193,7 +193,7 @@ # ----------- # # Multi-agent environments simulate multiple agents interacting with the world. -# TorchRL API allows integrating various types of multi-agent environment flavours. +# TorchRL API allows integrating various types of multi-agent environment flavors. # In this tutorial we will focus on environments where multiple agent groups interact in parallel. # That is: at every step all agents will get an observation and take an action synchronously. # @@ -310,7 +310,7 @@ # Looking at the ``done_spec``, we can see that there are some keys that are outside of agent groups # (``"done", "terminated", "truncated"``), which do not have a leading multi-agent dimension. # These keys are shared by all agents and represent the environment global done state used for resetting. -# By default, like in this case, parallel PettingZoo environments are done when any agent is done, but this behaviour +# By default, like in this case, parallel PettingZoo environments are done when any agent is done, but this behavior # can be overridden by setting ``done_on_any`` at PettingZoo environment construction. # # To quickly access the keys for each of these values in tensordicts, we can simply ask the environment for the @@ -415,7 +415,7 @@ # Another important decision we need to make is whether we want the agents within a team to **share the policy parameters**. # On the one hand, sharing parameters means that they will all share the same policy, which will allow them to benefit from # each other's experiences. This will also result in faster training. -# On the other hand, it will make them behaviourally *homogenous*, as they will in fact share the same model. +# On the other hand, it will make them behaviorally *homogeneous*, as they will in fact share the same model. # For this example, we will enable sharing as we do not mind the homogeneity and can benefit from the computational # speed, but it is important to always think about this decision in your own problems! #
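Since the parameter-sharing trade-off discussed above comes down to a single constructor flag, a short illustrative use of MultiAgentMLP (the module referenced in the next hunk) may help; all sizes here are made up:

import torch
from torchrl.modules import MultiAgentMLP

policy_net = MultiAgentMLP(
    n_agent_inputs=8,   # per-agent observation size (hypothetical)
    n_agent_outputs=2,  # per-agent action size (hypothetical)
    n_agents=4,
    centralised=False,  # each agent only sees its own inputs
    share_params=True,  # one set of weights for all agents (homogeneous behavior)
    depth=2,
    num_cells=64,
)
obs = torch.randn(10, 4, 8)  # (batch, n_agents, n_obs_per_agent)
out = policy_net(obs)        # (batch, 4, 2); share_params=False allocates per-agent weights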
@@ -424,7 +424,7 @@ # **First**: define a neural network ``n_obs_per_agent`` -> ``n_actions_per_agents`` # # For this we use the ``MultiAgentMLP``, a TorchRL module made exactly for -# multiple agents, with much customisation available. +# multiple agents, with much customization available. # # We will define a different policy for each group and store them in a dictionary. # @@ -817,7 +817,7 @@ def process_batch(batch: TensorDictBase) -> TensorDictBase: target_updaters[group].step() # Exploration sigma anneal update - exploration_policies[group].step(current_frames) + exploration_policies[group][-1].step(current_frames) # Stop training a certain group when a condition is met (e.g., number of training iterations) if iteration == iteration_when_stop_training_evaders: @@ -903,7 +903,7 @@ def process_batch(batch: TensorDictBase) -> TensorDictBase: env_with_render = env_with_render.append_transform( VideoRecorder(logger=video_logger, tag="vmas_rendered") ) - with set_exploration_type(ExplorationType.MODE): + with set_exploration_type(ExplorationType.DETERMINISTIC): print("Rendering rollout...") env_with_render.rollout(100, policy=agents_exploration_policy) print("Saving the video...") diff --git a/tutorials/sphinx-tutorials/multiagent_ppo.py b/tutorials/sphinx-tutorials/multiagent_ppo.py index d7d906a4fb0..ec24de6cddd 100644 --- a/tutorials/sphinx-tutorials/multiagent_ppo.py +++ b/tutorials/sphinx-tutorials/multiagent_ppo.py @@ -99,7 +99,7 @@ # wrapper for the VMAS simulator. # # 3. Next, we will design the policy and the critic networks, discussing the impact of the various choices on -# parameter sharing and critic centralisation. +# parameter sharing and critic centralization. # # 4. Next, we will create the sampling collector and the replay buffer. # @@ -184,7 +184,7 @@ # ----------- # # Multi-agent environments simulate multiple agents interacting with the world. -# TorchRL API allows integrating various types of multi-agent environment flavours. +# TorchRL API allows integrating various types of multi-agent environment flavors. # Some examples include environments with shared or individual agent rewards, done flags, and observations. # For more information on how the multi-agent environments API works in TorchRL, you can check out the dedicated # :ref:`doc section `. @@ -195,7 +195,7 @@ # This means that all its state and physics # are PyTorch tensors with a first dimension representing the number of parallel environments in a batch. # This allows leveraging the Single Instruction Multiple Data (SIMD) paradigm of GPUs and significantly -# speed up parallel computation by leveraging parallelisation in GPU warps. It also means +# speeding up parallel computation by leveraging parallelization in GPU warps. It also means # that, when using it in TorchRL, both simulation and training can be run on-device, without ever passing # data to the CPU. # @@ -207,7 +207,7 @@ # avoid colliding into each other. # Agents act in a 2D continuous world with drag and elastic collisions. # Their actions are 2D continuous forces which determine their acceleration. -# The reward is composed of three terms: a collision penalisation, a reward based on the distance to the goal, and a +# The reward is composed of three terms: a collision penalization, a reward based on the distance to the goal, and a # final shared reward given when all agents reach their goal. # The distance-based term is computed as the difference in the relative distance # between an agent and its goal over two consecutive timesteps.
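The distance-based term described just above is a standard potential-based shaping reward; a toy (non-VMAS) version of the computation:

import torch

def distance_shaping(pos_prev, pos_curr, goal):
    # r_t = d_{t-1} - d_t: positive whenever the agent moved closer to its goal
    d_prev = torch.linalg.vector_norm(pos_prev - goal, dim=-1)
    d_curr = torch.linalg.vector_norm(pos_curr - goal, dim=-1)
    return d_prev - d_curr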
@@ -391,7 +391,7 @@ # **First**: define a neural network ``n_obs_per_agent`` -> ``2 * n_actions_per_agents`` # # For this we use the ``MultiAgentMLP``, a TorchRL module made exactly for -# multiple agents, with much customisation available. +# multiple agents, with much customization available. # share_parameters_policy = True diff --git a/tutorials/sphinx-tutorials/pendulum.py b/tutorials/sphinx-tutorials/pendulum.py index d25bc2cdd8a..1593d42a0ec 100644 --- a/tutorials/sphinx-tutorials/pendulum.py +++ b/tutorials/sphinx-tutorials/pendulum.py @@ -107,7 +107,7 @@ from tensordict.nn import TensorDictModule from torch import nn -from torchrl.data import BoundedTensorSpec, CompositeSpec, UnboundedContinuousTensorSpec +from torchrl.data import Bounded, Composite, Unbounded from torchrl.envs import ( CatTensors, EnvBase, @@ -128,7 +128,7 @@ # * :meth:`EnvBase._reset`, which codes for the resetting of the simulator # at a (potentially random) initial state; # * :meth:`EnvBase._step` which codes for the state transition dynamic; -# * :meth:`EnvBase._set_seed`` which implements the seeding mechanism; +# * :meth:`EnvBase._set_seed` which implements the seeding mechanism; # * the environment specs. # # Let us first describe the problem at hand: we would like to model a simple @@ -410,14 +410,14 @@ def _reset(self, tensordict): def _make_spec(self, td_params): # Under the hood, this will populate self.output_spec["observation"] - self.observation_spec = CompositeSpec( - th=BoundedTensorSpec( + self.observation_spec = Composite( + th=Bounded( low=-torch.pi, high=torch.pi, shape=(), dtype=torch.float32, ), - thdot=BoundedTensorSpec( + thdot=Bounded( low=-td_params["params", "max_speed"], high=td_params["params", "max_speed"], shape=(),
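The pendulum.py hunks here are a compact tour of the spec renames (BoundedTensorSpec -> Bounded, CompositeSpec -> Composite, UnboundedContinuousTensorSpec -> Unbounded). A standalone sketch with the new names, mirroring the specs built in this diff:

import torch
from torchrl.data import Bounded, Composite, Unbounded

observation_spec = Composite(
    th=Bounded(low=-torch.pi, high=torch.pi, shape=(), dtype=torch.float32),
    thdot=Bounded(low=-8.0, high=8.0, shape=(), dtype=torch.float32),  # 8.0 stands in for max_speed
    shape=(),
)
reward_spec = Unbounded(shape=(1,))
print(observation_spec.rand())  # a random TensorDict conforming to the spec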
@@ -433,25 +433,23 @@ def _make_spec(self, td_params): self.state_spec = self.observation_spec.clone() # action-spec will be automatically wrapped in input_spec when # `self.action_spec = spec` is called - self.action_spec = BoundedTensorSpec( + self.action_spec = Bounded( low=-td_params["params", "max_torque"], high=td_params["params", "max_torque"], shape=(1,), dtype=torch.float32, ) - self.reward_spec = UnboundedContinuousTensorSpec(shape=(*td_params.shape, 1)) + self.reward_spec = Unbounded(shape=(*td_params.shape, 1)) def make_composite_from_td(td): # custom function to convert a ``tensordict`` into a similar spec structure # of unbounded values. - composite = CompositeSpec( + composite = Composite( { key: make_composite_from_td(tensor) if isinstance(tensor, TensorDictBase) - else UnboundedContinuousTensorSpec( - dtype=tensor.dtype, device=tensor.device, shape=tensor.shape - ) + else Unbounded(dtype=tensor.dtype, device=tensor.device, shape=tensor.shape) for key, tensor in td.items() }, shape=td.shape, @@ -611,7 +609,7 @@ def __init__(self, td_params=None, seed=None, device="cpu"): env, # ``Unsqueeze`` the observations that we will concatenate UnsqueezeTransform( - unsqueeze_dim=-1, + dim=-1, in_keys=["th", "thdot"], in_keys_inv=["th", "thdot"], ), @@ -694,7 +692,7 @@ def _reset( # is of type ``Composite`` @_apply_to_composite def transform_observation_spec(self, observation_spec): - return BoundedTensorSpec( + return Bounded( low=-1, high=1, shape=observation_spec.shape, @@ -718,7 +716,7 @@ def _reset( # is of type ``Composite`` @_apply_to_composite def transform_observation_spec(self, observation_spec): - return BoundedTensorSpec( + return Bounded( low=-1, high=1, shape=observation_spec.shape, diff --git a/tutorials/sphinx-tutorials/rb_tutorial.py b/tutorials/sphinx-tutorials/rb_tutorial.py index fc3a3ae954c..f189888b804 100644 --- a/tutorials/sphinx-tutorials/rb_tutorial.py +++ b/tutorials/sphinx-tutorials/rb_tutorial.py @@ -133,7 +133,7 @@ # basic properties (such as shape and dtype) as the first batch of data that # was used to instantiate the buffer. # Passing data that does not match this requirement will either raise an -# exception or lead to some undefined behaviours. +# exception or lead to some undefined behaviors. # - The :class:`~torchrl.data.LazyMemmapStorage` works as the # :class:`~torchrl.data.LazyTensorStorage` in that it is lazy (i.e., it # expects the first batch of data to be instantiated), and it requires data diff --git a/tutorials/sphinx-tutorials/torchrl_demo.py b/tutorials/sphinx-tutorials/torchrl_demo.py index 9d25da0a4cd..84f7be715ad 100644 --- a/tutorials/sphinx-tutorials/torchrl_demo.py +++ b/tutorials/sphinx-tutorials/torchrl_demo.py @@ -170,7 +170,7 @@ # * a collection of algorithms: we do not intend to provide SOTA implementations of RL algorithms, # but we provide these algorithms only as examples of how to use the library. # -# * a research framework: modularity in TorchRL comes in two flavours. First, we try +# * a research framework: modularity in TorchRL comes in two flavors. First, we try # to build re-usable components, such that they can be easily swapped with each other. # Second, we do our best to ensure that components can be used independently of the rest # of the library. @@ -365,7 +365,7 @@ # Envs # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -from torchrl.envs.libs.gym import GymEnv, GymWrapper +from torchrl.envs.libs.gym import GymEnv, GymWrapper, set_gym_backend gym_env = gym.make("Pendulum-v1") env = GymWrapper(gym_env) @@ -434,9 +434,16 @@ from torchrl.envs import ParallelEnv + +def make_env(): + # You can control whether to use gym or gymnasium for your env + with set_gym_backend("gym"): + return GymEnv("Pendulum-v1", frame_skip=3, from_pixels=True, pixels_only=False) + + base_env = ParallelEnv( 4, - lambda: GymEnv("Pendulum-v1", frame_skip=3, from_pixels=True, pixels_only=False), + make_env, mp_start_method="fork", # This will break on Windows machines!
Remove and decorate with if __name__ == "__main__" ) env = TransformedEnv( @@ -572,10 +579,10 @@ def exec_sequence(params, data): # ------------------------------ torch.manual_seed(0) -from torchrl.data import BoundedTensorSpec +from torchrl.data import Bounded from torchrl.modules import SafeModule -spec = BoundedTensorSpec(-torch.ones(3), torch.ones(3)) +spec = Bounded(-torch.ones(3), torch.ones(3)) base_module = nn.Linear(5, 3) module = SafeModule( module=base_module, spec=spec, in_keys=["obs"], out_keys=["action"], safe=True @@ -652,14 +659,10 @@ def exec_sequence(params, data): td_module(td) print("random:", td["action"]) -with set_exploration_type(ExplorationType.MODE): +with set_exploration_type(ExplorationType.DETERMINISTIC): td_module(td) print("mode:", td["action"]) -with set_exploration_type(ExplorationType.MODE): - td_module(td) - print("mean:", td["action"]) - ############################################################################### # Using Environments and Modules # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/tutorials/sphinx-tutorials/torchrl_envs.py b/tutorials/sphinx-tutorials/torchrl_envs.py index f2ae0372db2..34189396ee9 100644 --- a/tutorials/sphinx-tutorials/torchrl_envs.py +++ b/tutorials/sphinx-tutorials/torchrl_envs.py @@ -608,7 +608,7 @@ def env_make(env_name): ############################################################################### # Transforming parallel environments # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -# There are two equivalent ways of transforming parallen environments: in each +# There are two equivalent ways of transforming parallel environments: in each # process separately, or on the main process. It is even possible to do both. # One can therefore think carefully about the transform design to leverage the # device capabilities (e.g. transforms on cuda devices) and vectorizing
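Both transform placements described in that final paragraph can be written as follows, a minimal sketch assuming a working Gym backend with pixel rendering available; the env choice and worker count are illustrative:

from torchrl.envs import ParallelEnv, ToTensorImage, TransformedEnv
from torchrl.envs.libs.gym import GymEnv

def make_env():
    return GymEnv("Pendulum-v1", from_pixels=True, pixels_only=False)

# option 1: run the transform inside each worker process
env_workers = ParallelEnv(2, lambda: TransformedEnv(make_env(), ToTensorImage()))
# option 2: run it once on the main process, over the batched output
env_main = TransformedEnv(ParallelEnv(2, make_env), ToTensorImage())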