From dda0f9e14c890204021159f944ae76eb0fd415a8 Mon Sep 17 00:00:00 2001
From: Jason Ansel
Date: Thu, 24 Aug 2023 11:55:47 -0700
Subject: [PATCH] Update README.md and remove some files

---
 .isort.cfg | 5 -
 Makefile | 180 -
 README.md | 4 -
 pytest.ini | 5 -
 test/__init__.py | 0
 test/dynamo/__init__.py | 0
 test/dynamo/mock_modules/__init__.py | 0
 test/dynamo/mock_modules/mock_module1.py | 2 -
 test/dynamo/mock_modules/mock_module2.py | 19 -
 test/dynamo/mock_modules/mock_module3.py | 7 -
 test/dynamo/test_aot_autograd.py | 139 -
 test/dynamo/test_aot_cudagraphs.py | 209 --
 test/dynamo/test_distributed.py | 287 --
 test/dynamo/test_dynamic_shapes.py | 34 -
 test/dynamo/test_export.py | 1429 --------
 test/dynamo/test_functions.py | 676 ----
 test/dynamo/test_global.py | 233 --
 test/dynamo/test_global_declaration.py | 4 -
 test/dynamo/test_misc.py | 2740 ---------------
 test/dynamo/test_model_output.py | 166 -
 test/dynamo/test_modules.py | 891 -----
 test/dynamo/test_no_fake_tensors.py | 33 -
 test/dynamo/test_nops.py | 72 -
 test/dynamo/test_optimizations.py | 209 --
 test/dynamo/test_optimizers.py | 103 -
 test/dynamo/test_python_autograd.py | 293 --
 test/dynamo/test_repros.py | 1718 ---------
 test/dynamo/test_skip_non_tensor.py | 113 -
 test/dynamo/test_subgraphs.py | 534 ---
 test/dynamo/test_unspec.py | 228 --
 test/dynamo/test_verify_correctness.py | 175 -
 test/inductor/__init__.py | 0
 test/inductor/cpp/.gitignore | 13 -
 test/inductor/cpp/CMakeLists.txt | 47 -
 test/inductor/cpp/test.sh | 7 -
 test/inductor/cpp/test_cpp_prefix.cpp | 21 -
 test/inductor/test_torchinductor.py | 4061 ----------------------
 37 files changed, 14657 deletions(-)
 delete mode 100644 .isort.cfg
 delete mode 100644 Makefile
 delete mode 100644 pytest.ini
 delete mode 100644 test/__init__.py
 delete mode 100644 test/dynamo/__init__.py
 delete mode 100644 test/dynamo/mock_modules/__init__.py
 delete mode 100644 test/dynamo/mock_modules/mock_module1.py
 delete mode 100644 test/dynamo/mock_modules/mock_module2.py
 delete mode 100644 test/dynamo/mock_modules/mock_module3.py
 delete mode 100644 test/dynamo/test_aot_autograd.py
 delete mode 100644 test/dynamo/test_aot_cudagraphs.py
 delete mode 100644 test/dynamo/test_distributed.py
 delete mode 100644 test/dynamo/test_dynamic_shapes.py
 delete mode 100644 test/dynamo/test_export.py
 delete mode 100644 test/dynamo/test_functions.py
 delete mode 100644 test/dynamo/test_global.py
 delete mode 100644 test/dynamo/test_global_declaration.py
 delete mode 100644 test/dynamo/test_misc.py
 delete mode 100644 test/dynamo/test_model_output.py
 delete mode 100644 test/dynamo/test_modules.py
 delete mode 100644 test/dynamo/test_no_fake_tensors.py
 delete mode 100644 test/dynamo/test_nops.py
 delete mode 100644 test/dynamo/test_optimizations.py
 delete mode 100644 test/dynamo/test_optimizers.py
 delete mode 100644 test/dynamo/test_python_autograd.py
 delete mode 100644 test/dynamo/test_repros.py
 delete mode 100644 test/dynamo/test_skip_non_tensor.py
 delete mode 100644 test/dynamo/test_subgraphs.py
 delete mode 100644 test/dynamo/test_unspec.py
 delete mode 100644 test/dynamo/test_verify_correctness.py
 delete mode 100644 test/inductor/__init__.py
 delete mode 100644 test/inductor/cpp/.gitignore
 delete mode 100644 test/inductor/cpp/CMakeLists.txt
 delete mode 100755 test/inductor/cpp/test.sh
 delete mode 100644 test/inductor/cpp/test_cpp_prefix.cpp
 delete mode 100644 test/inductor/test_torchinductor.py

diff --git a/.isort.cfg b/.isort.cfg
deleted file mode 100644
index 1b26dca583..0000000000
--- a/.isort.cfg
+++ /dev/null
@@ -1,5 +0,0 @@
-[settings] -profile=black -src_paths=test,torchdynamo -force_single_line=True -known_first_party=torchdynamo,torchinductor diff --git a/Makefile b/Makefile deleted file mode 100644 index b70e5270fd..0000000000 --- a/Makefile +++ /dev/null @@ -1,180 +0,0 @@ -.PHONY: default develop test torchbench format lint setup clean - -PY_FILES := $(wildcard *.py) $(wildcard torchdynamo/*.py) $(wildcard torchdynamo/*/*.py) \ - $(wildcard torchinductor/*.py) $(wildcard torchinductor/*/*.py) \ - $(wildcard benchmarks/*.py) $(wildcard benchmarks/*/*.py) \ - $(wildcard test/*.py) $(wildcard test/*/*.py) \ - $(wildcard .circleci/*.py) $(wildcard tools/*.py) -C_FILES := $(wildcard torchdynamo/*.c torchdynamo/*.cpp) -CLANG_TIDY ?= clang-tidy-10 -CLANG_FORMAT ?= clang-format-10 -PIP ?= python -m pip - -# versions used in CI -# Also update the "Install nightly binaries" section of the README when updating these -PYTORCH_VERSION ?= dev20221017 -TRITON_VERSION ?= db3aa1d1fb2bb536752a71d9e0f03cf6a86ddf65 - - -default: develop - -develop: - python setup.py develop - -test: develop - pytest test -o log_cli=False - -torchbench: develop - python benchmarks/torchbench.py --fast - -overhead: develop - python benchmarks/torchbench.py --overhead - -format: - isort $(PY_FILES) - black $(PY_FILES) - -lint: - black --check --diff $(PY_FILES) - isort --check --diff $(PY_FILES) - flake8 $(PY_FILES) - -lint-deps: - grep -E '(black|flake8|isort|click|torch|mypy)' requirements.txt | xargs $(PIP) install - -setup_lint: lint-deps - -setup: - $(PIP) install -r requirements.txt - -setup_nightly: - $(PIP) install ninja - $(PIP) install --pre torch==1.14.0.$(PYTORCH_VERSION) --extra-index-url https://download.pytorch.org/whl/nightly/cpu - $(PIP) install -r requirements.txt - -setup_nightly_gpu: - conda install -y -c pytorch magma-cuda116 cudatoolkit=11.6 -c conda-forge - $(PIP) install --pre torch==1.14.0.$(PYTORCH_VERSION) \ - torchvision==0.15.0.$(PYTORCH_VERSION) \ - torchtext==0.14.0.$(PYTORCH_VERSION) \ - --extra-index-url https://download.pytorch.org/whl/nightly/cu116 - $(PIP) install ninja - $(PIP) install -U "git+https://github.com/openai/triton@$(TRITON_VERSION)#subdirectory=python" - $(PIP) install -r requirements.txt - -clean: - python setup.py clean - rm -rf build torchdynamo.egg-info torchdynamo/*.so __pycache__ .pytest_cache .benchmarks *.csv dist - -clone-deps: - (cd .. 
\ - && (test -e pytorch || git clone --recursive https://github.com/pytorch/pytorch pytorch) \ - && (test -e torchvision || git clone --recursive https://github.com/pytorch/vision torchvision) \ - && (test -e torchtext || git clone --recursive https://github.com/pytorch/text torchtext) \ - && (test -e detectron2 || git clone --recursive https://github.com/facebookresearch/detectron2) \ - && (test -e torchbenchmark || git clone --recursive https://github.com/pytorch/benchmark torchbenchmark) \ - && (test -e triton || git clone --recursive https://github.com/openai/triton.git) \ - ) - -pull-deps: - (cd ../pytorch && git pull && git submodule update --init --recursive) - (cd ../torchvision && git pull && git submodule update --init --recursive) - (cd ../torchtext && git pull && git submodule update --init --recursive) - (cd ../detectron2 && git pull && git submodule update --init --recursive) - (cd ../torchbenchmark && git pull && git submodule update --init --recursive) - (cd ../triton && git checkout master && git pull && git checkout $(TRITON_VERSION) && git submodule update --init --recursive) - -build-deps: clone-deps - # conda env remove --name torchdynamo - # conda create --name torchdynamo -y python=3.8 - # conda activate torchdynamo - conda install -y astunparse numpy scipy ninja pyyaml mkl mkl-include setuptools cmake \ - cffi typing_extensions future six requests dataclasses protobuf numba cython scikit-learn - conda install -y -c pytorch magma-cuda116 - conda install -y -c conda-forge librosa - - make setup && $(PIP) uninstall -y torch - (cd ../pytorch && python setup.py clean && python setup.py develop) - (cd ../torchvision && python setup.py clean && python setup.py develop) - (cd ../torchtext && python setup.py clean && python setup.py develop) - (cd ../detectron2 && python setup.py clean && python setup.py develop) - (cd ../torchbenchmark && python install.py --continue_on_fail) - (cd ../triton/python && python setup.py clean && python setup.py develop) - make setup_lint - python setup.py develop - -baseline-cpu: develop - rm -f baseline_*.csv - python benchmarks/torchbench.py -n50 --overhead - python benchmarks/torchbench.py -n50 --speedup-ts - python benchmarks/torchbench.py -n50 --speedup-sr - python benchmarks/torchbench.py -n50 --speedup-onnx - paste -d, baseline_ts.csv baseline_sr.csv baseline_onnx.csv > baseline_all.csv - -baseline-gpu: develop - rm -f baseline_*.csv - python benchmarks/torchbench.py -dcuda -n100 --overhead - python benchmarks/torchbench.py -dcuda -n100 --speedup-ts && mv baseline_ts.csv baseline_nnc.csv - python benchmarks/torchbench.py -dcuda -n100 --speedup-ts --nvfuser && mv baseline_ts.csv baseline_nvfuser.csv - python benchmarks/torchbench.py -dcuda -n100 --speedup-trt - python benchmarks/torchbench.py -dcuda -n100 --speedup-onnx - paste -d, baseline_nnc.csv baseline_nvfuser.csv baseline_trt.csv baseline_onnx.csv > baseline_all.csv - -gpu-inductor-cudagraphs-fp32: develop - rm -f inductor.csv baseline_cudagraphs.csv baseline_cg_nvfuser.csv baseline_cg_nnc.csv inductor_gpu_cudagraphs_fp32.csv - python benchmarks/torchbench.py -dcuda --inductor-settings --float32 -n50 --inductor - python benchmarks/torchbench.py -dcuda --inductor-settings --float32 -n50 --backend=cudagraphs - mv speedup_cudagraphs.csv baseline_cudagraphs.csv - python benchmarks/torchbench.py -dcuda --inductor-settings --float32 -n50 --backend=cudagraphs_ts --nvfuser - mv speedup_cudagraphs_ts.csv baseline_cg_nvfuser.csv - python benchmarks/torchbench.py -dcuda --inductor-settings 
--float32 -n50 --backend=cudagraphs_ts - mv speedup_cudagraphs_ts.csv baseline_cg_nnc.csv - paste -d, inductor.csv baseline_cudagraphs.csv baseline_cg_nvfuser.csv baseline_cg_nnc.csv > inductor_gpu_cudagraphs_fp32.csv - -gpu-inductor-cudagraphs-fp16: develop - rm -f inductor.csv baseline_cudagraphs.csv baseline_cg_nvfuser.csv baseline_cg_nnc.csv inductor_gpu_cudagraphs_fp16.csv - python benchmarks/torchbench.py -dcuda --inductor-settings --float16 -n50 --inductor - python benchmarks/torchbench.py -dcuda --inductor-settings --float16 -n50 --backend=cudagraphs - mv speedup_cudagraphs.csv baseline_cudagraphs.csv - python benchmarks/torchbench.py -dcuda --inductor-settings --float16 -n50 --backend=cudagraphs_ts --nvfuser - mv speedup_cudagraphs_ts.csv baseline_cg_nvfuser.csv - python benchmarks/torchbench.py -dcuda --inductor-settings --float16 -n50 --backend=cudagraphs_ts - mv speedup_cudagraphs_ts.csv baseline_cg_nnc.csv - paste -d, inductor.csv baseline_cudagraphs.csv baseline_cg_nvfuser.csv baseline_cg_nnc.csv > inductor_gpu_cudagraphs_fp16.csv - -gpu-inductor-dynamic: develop - rm -f inductor.csv baseline_nvfuser.csv baseline_nnc.csv inductor_gpu_dynamic.csv - python benchmarks/torchbench.py -dcuda --inductor-settings --float32 -n50 --inductor-dynamic - python benchmarks/torchbench.py -dcuda --inductor-settings --float32 -n50 --backend=ts --nvfuser - mv speedup_ts.csv baseline_nvfuser.csv - python benchmarks/torchbench.py -dcuda --inductor-settings --float32 -n50 --backend=ts - mv speedup_ts.csv baseline_nnc.csv - paste -d, inductor.csv baseline_nvfuser.csv baseline_nnc.csv > inductor_gpu_dynamic.csv - -cpu-inductor: develop - rm -f inductor.csv speedup_ts.csv cpu_mt_inductor.csv - python torchbench.py --inductor-settings --fast --inductor - python torchbench.py --inductor-settings --fast --backend=ts - paste -d, inductor.csv speedup_ts.csv > cpu_mt_inductor.csv - -cpu-inductor-seq: develop - rm -f inductor.csv speedup_ts.csv cpu_1t_inductor.csv - taskset 1 python benchmarks/torchbench.py --inductor-settings --fast --inductor --threads=1 - taskset 1 python benchmarks/torchbench.py --inductor-settings --fast --backend=ts --threads=1 - paste -d, inductor.csv speedup_ts.csv > cpu_1t_inductor.csv - -gpu-inductor-bw-fp16: develop - rm -f inductor.csv speedup_aot_nvfuser.csv speedup_aot_cudagraphs.csv - python benchmarks/torchbench.py --training -dcuda --inductor-settings --float16 -n100 --backend=aot_nvfuser --nvfuser - python benchmarks/torchbench.py --training -dcuda --inductor-settings --float16 -n100 --backend=aot_cudagraphs - python benchmarks/torchbench.py --training -dcuda --inductor-settings --float16 -n100 --inductor - paste -d, inductor.csv speedup_aot_nvfuser.csv speedup_aot_cudagraphs.csv > inductor_bw_fp16.csv - -gpu-inductor-bw-fp32: develop - rm -f inductor.csv speedup_aot_nvfuser.csv speedup_aot_cudagraphs.csv - python benchmarks/torchbench.py --training -dcuda --inductor-settings --float32 -n100 --backend=aot_nvfuser --nvfuser - python benchmarks/torchbench.py --training -dcuda --inductor-settings --float32 -n100 --backend=aot_cudagraphs - python benchmarks/torchbench.py --training -dcuda --inductor-settings --float32 -n100 --inductor - paste -d, inductor.csv speedup_aot_nvfuser.csv speedup_aot_cudagraphs.csv > inductor_bw_fp32.csv - - diff --git a/README.md b/README.md index 58945044b3..418f622b39 100644 --- a/README.md +++ b/README.md @@ -6,10 +6,6 @@ We have moved TorchDynamo to - `import torchdynamo` is now `import torch._dynamo` - `import torchinductor` is now `import 
torch._inductor` -This repository still contains: -- An alias to the new location -- Issues: we will continue using this project for issue tracking - For Documentation: https://pytorch.org/docs/stable/dynamo/index.html ## License diff --git a/pytest.ini b/pytest.ini deleted file mode 100644 index b5d952ae69..0000000000 --- a/pytest.ini +++ /dev/null @@ -1,5 +0,0 @@ -[pytest] -testpaths = - test -log_cli = False -log_cli_level = INFO diff --git a/test/__init__.py b/test/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/test/dynamo/__init__.py b/test/dynamo/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/test/dynamo/mock_modules/__init__.py b/test/dynamo/mock_modules/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/test/dynamo/mock_modules/mock_module1.py b/test/dynamo/mock_modules/mock_module1.py deleted file mode 100644 index c4bd2bf4f9..0000000000 --- a/test/dynamo/mock_modules/mock_module1.py +++ /dev/null @@ -1,2 +0,0 @@ -def method1(a, b): - return a + b diff --git a/test/dynamo/mock_modules/mock_module2.py b/test/dynamo/mock_modules/mock_module2.py deleted file mode 100644 index 7fe8979709..0000000000 --- a/test/dynamo/mock_modules/mock_module2.py +++ /dev/null @@ -1,19 +0,0 @@ -# from . import mock_module3 -import torch - -from . import mock_module3 - - -class Class1: - def __init__(self, x, y): - self.x = x - self.y = y - - def method2(self, x): - return mock_module3.method1([], x) - - -def method1(x, y): - torch.ones(1, 1) - x.append(y) - return x diff --git a/test/dynamo/mock_modules/mock_module3.py b/test/dynamo/mock_modules/mock_module3.py deleted file mode 100644 index 8af77a237a..0000000000 --- a/test/dynamo/mock_modules/mock_module3.py +++ /dev/null @@ -1,7 +0,0 @@ -import torch - - -def method1(x, y): - torch.ones(1, 1) - x.append(y) - return x diff --git a/test/dynamo/test_aot_autograd.py b/test/dynamo/test_aot_autograd.py deleted file mode 100644 index 9e02099862..0000000000 --- a/test/dynamo/test_aot_autograd.py +++ /dev/null @@ -1,139 +0,0 @@ -# Owner(s): ["module: dynamo"] -import functools - -import torch - -import torchdynamo -import torchdynamo.test_case -from torchdynamo.optimizations.training import is_aot_autograd_safe_to_run -from torchdynamo.testing import rand_strided - - -def compiler_safe_fn(gm, example_inputs, is_safe): - is_safe[0] = is_aot_autograd_safe_to_run(gm, example_inputs) - return gm.forward - - -class AotAutogradFallbackTests(torchdynamo.test_case.TestCase): - def test_LSTM(self): - # https://github.com/pytorch/torchdynamo/issues/1147 - class Repro(torch.nn.Module): - def __init__(self): - super().__init__() - self.self_mod_model_lstm_lstm = torch.nn.LSTM( - 64, 64, num_layers=2, bidirectional=True - ) - - def forward(self, permute: torch.Tensor): - self_mod_model_lstm_lstm = self.self_mod_model_lstm_lstm(permute) - return (self_mod_model_lstm_lstm,) - - is_safe = [True] - mod = Repro() - compiler_fn = functools.partial(compiler_safe_fn, is_safe=is_safe) - aot_mod = torchdynamo.optimize(compiler_fn)(mod) - - args = [((92, 4, 64), (1, 5888, 92), torch.float32, "cpu", False)] - args = [ - rand_strided(sh, st, dt, dev).requires_grad_(rg) - for (sh, st, dt, dev, rg) in args - ] - - aot_mod(*args) - self.assertTrue(not is_safe[0]) - - def test_mutation(self): - # https://github.com/pytorch/torchdynamo/issues/1301 - def fn(param, y): - prev_grad = torch.is_grad_enabled() - try: - torch.set_grad_enabled(False) - param.add_(y) - finally: - 
torch.set_grad_enabled(prev_grad) - return y - - y = torch.randn(4) - x = torch.nn.Parameter(torch.randn(4)) - is_safe = [True] - compiler_fn = functools.partial(compiler_safe_fn, is_safe=is_safe) - aot_fn = torchdynamo.optimize(compiler_fn)(fn) - aot_fn(x, y) - self.assertTrue(not is_safe[0]) - - def test_mutation1(self): - def fn(_stack0: torch.Tensor, diagonal_chunked_attention_scores: torch.Tensor): - getitem = diagonal_chunked_attention_scores[ - ( - slice(None, None, None), - slice(None, None, None), - slice(None, 256, None), - slice(None, 257, None), - ) - ] - _stack0[ - ( - slice(None, None, None), - slice(None, -1, None), - slice(None, None, None), - slice(256, None, None), - ) - ] = getitem - view = _stack0.view(1, 12, 1024, 513) - return (view,) - - x = torch.randn(torch.Size([12, 4, 256, 513])) - y = torch.randn(torch.Size([12, 3, 512, 513])) - is_safe = [True] - compiler_fn = functools.partial(compiler_safe_fn, is_safe=is_safe) - aot_fn = torchdynamo.optimize(compiler_fn)(fn) - aot_fn(x, y) - self.assertTrue(not is_safe[0]) - - def test_negative_testing_mutation(self): - def fn(_stack0: torch.Tensor, diagonal_chunked_attention_scores: torch.Tensor): - getitem = diagonal_chunked_attention_scores[ - ( - slice(None, None, None), - slice(None, None, None), - slice(None, 256, None), - slice(None, 257, None), - ) - ] - _stack0 = torch.sin(_stack0) - _stack0[ - ( - slice(None, None, None), - slice(None, -1, None), - slice(None, None, None), - slice(256, None, None), - ) - ] = getitem - view = _stack0.view(1, 12, 1024, 513) - return (view,) - - x = torch.randn(torch.Size([12, 4, 256, 513])) - y = torch.randn(torch.Size([12, 3, 512, 513])) - is_safe = [True] - compiler_fn = functools.partial(compiler_safe_fn, is_safe=is_safe) - aot_fn = torchdynamo.optimize(compiler_fn)(fn) - aot_fn(x, y) - self.assertTrue(is_safe[0]) - - def test_negative_testing(self): - def fn(x, y): - return torch.sin(x).add_(y) - - y = torch.randn(4) - x = torch.randn(4) - is_safe = [True] - compiler_fn = functools.partial(compiler_safe_fn, is_safe=is_safe) - aot_fn = torchdynamo.optimize(compiler_fn)(fn) - aot_fn(x, y) - self.assertTrue(is_safe[0]) - - -if __name__ == "__main__": - from torchdynamo.test_case import run_tests - - run_tests() diff --git a/test/dynamo/test_aot_cudagraphs.py b/test/dynamo/test_aot_cudagraphs.py deleted file mode 100644 index 2392a08478..0000000000 --- a/test/dynamo/test_aot_cudagraphs.py +++ /dev/null @@ -1,209 +0,0 @@ -# Owner(s): ["module: cuda graphs"] - -import functools -import unittest -from unittest.mock import patch - -import torch -from torch.testing._internal.common_utils import TEST_WITH_ROCM - -import torchdynamo -import torchdynamo.test_case -import torchdynamo.testing -from torchdynamo.testing import same - - -def composed(*decs): - def deco(f): - for dec in reversed(decs): - f = dec(f) - return f - - return deco - - -def assert_aot_autograd_counter(ok=True): - def deco(f): - @functools.wraps(f) - def wrap(self, *args, **kwargs): - torchdynamo.utils.counters.clear() - r = f(self, *args, **kwargs) - c_ok = torchdynamo.utils.counters["aot_autograd"]["ok"] - c_not_ok = torchdynamo.utils.counters["aot_autograd"]["not_ok"] - if ok: - self.assertGreater(c_ok, 0) - self.assertEqual(c_not_ok, 0) - else: - self.assertEqual(c_ok, 0) - self.assertGreater(c_not_ok, 0) - return r - - return wrap - - return deco - - -def patch_all(ok=True): - return composed( - unittest.skipIf(TEST_WITH_ROCM, "ROCm not supported"), - patch("torchdynamo.config.verify_correctness", True), - 
assert_aot_autograd_counter(ok), - ) - - -N_ITERS = 5 - - -@unittest.skipIf(not torch.cuda.is_available(), "these tests require cuda") -class TestAotCudagraphs(torchdynamo.test_case.TestCase): - @patch_all() - def test_basic(self): - def model(x, y): - return (x + y) * y - - @torchdynamo.optimize("aot_cudagraphs") - def fn(x, y): - for i in range(N_ITERS): - loss = model(x, y).sum() - loss.backward() - - x = torch.randn(3, device="cuda", requires_grad=True) - y = torch.randn(3, device="cuda") - fn(x, y) - - @patch_all() - def test_dtoh(self): - def model(x, y): - a = x + y - b = a.cpu() * 3 - return b - - @torchdynamo.optimize("aot_cudagraphs") - def fn(x, y): - for i in range(N_ITERS): - loss = model(x, y).sum() - loss.backward() - - x = torch.randn(3, device="cuda", requires_grad=True) - y = torch.randn(3, device="cuda") - fn(x, y) - - @patch_all() - def test_htod(self): - def model(x, y): - a = x + y - return a * 3 - - @torchdynamo.optimize("aot_cudagraphs") - def fn(x, y): - for i in range(N_ITERS): - loss = model(x, y).sum() - loss.backward() - - x = torch.randn(3, device="cuda", requires_grad=True) - y = torch.randn((), device="cpu") - fn(x, y) - - @patch("functorch._src.config.use_functionalize", True) - @patch_all(ok=False) # input mutation not supported yet - def test_mutate_input(self): - def model(x, y): - y.add_(3) - return x * y - - @torchdynamo.optimize("aot_cudagraphs") - def fn(x, y): - for i in range(N_ITERS): - with self.subTest(i): - y_orig = y.clone() - loss = model(x, y).sum() - self.assertTrue(same(y, y_orig + 3)) - loss.backward() - - x = torch.randn(3, device="cuda", requires_grad=True) - y = torch.randn(3, device="cuda") - fn(x, y) - - @patch_all() - def test_mutate_constant(self): - def model(x, y): - c = torch.tensor(1) - c.add_(2) - return x * y * 0 + c - - @torchdynamo.optimize("aot_cudagraphs") - def fn(x, y): - for i in range(N_ITERS): - with self.subTest(i): - loss = model(x, y).sum() - self.assertTrue(same(loss, torch.tensor(3.0, device="cuda"))) - loss.backward() - - x = torch.randn(1, device="cuda", requires_grad=True) - y = torch.randn(1, device="cuda") - fn(x, y) - - @patch_all() - def test_factory(self): - def model(y): - x = torch.zeros(3, device="cuda:0") - x.add_(3) - return x * y - - @torchdynamo.optimize("aot_cudagraphs") - def fn(y): - for i in range(N_ITERS): - with self.subTest(i): - loss = model(y).sum() - loss.backward() - - y = torch.randn(3, device="cuda:0", requires_grad=True) - fn(y) - - @patch("functorch._src.config.use_functionalize", True) - @patch_all() - def test_mutated_metadata(self): - # more tortured example at - # https://github.com/pytorch/pytorch/issues/81385 - def model(x): - x = x.clone() - x.resize_(20) - x.fill_(2) - return x - - @torchdynamo.optimize("aot_cudagraphs") - def fn(x): - for i in range(N_ITERS): - with self.subTest(i): - rx = model(x) - self.assertTrue(same(rx, torch.full((20,), 2.0, device="cuda:0"))) - - x = torch.empty(0, device="cuda:0") - fn(x) - - @patch("functorch._src.config.use_functionalize", True) - @patch_all() - def test_dead_fill(self): - def model(x): - x = x.clone() - y = x[0:0] - x.fill_(2) - y.fill_(3) - return x, y - - @torchdynamo.optimize("aot_cudagraphs") - def fn(x): - for i in range(N_ITERS): - with self.subTest(i): - rx, ry = model(x) - self.assertTrue(same(rx, torch.full((20,), 2.0, device="cuda:0"))) - self.assertTrue(same(ry, torch.empty(0, device="cuda:0"))) - - x = torch.empty(20, device="cuda:0") - fn(x) - - -if __name__ == "__main__": - from torchdynamo.test_case import 
run_tests - - run_tests() diff --git a/test/dynamo/test_distributed.py b/test/dynamo/test_distributed.py deleted file mode 100644 index 754396950e..0000000000 --- a/test/dynamo/test_distributed.py +++ /dev/null @@ -1,287 +0,0 @@ -# Owner(s): ["module: dynamo"] -import os -import unittest -from unittest.mock import patch - -import pytest -import torch -import torch.distributed as dist -from torch import nn - -import torchdynamo -import torchdynamo.test_case -from torchdynamo import config -from torchdynamo.testing import same - - -class ToyModel(nn.Module): - def __init__(self, in_feat=10, hidden_feat=5000, num_hidden=2, out_feat=5): - super().__init__() - self.net = nn.Sequential( - *[nn.Linear(in_feat, hidden_feat), nn.ReLU()] - + [nn.Linear(5000, 5000), nn.ReLU()] * num_hidden - + [nn.Linear(5000, 5), nn.ReLU()] - ) - - def forward(self, inputs): - return self.net(inputs) - - -class CheckSplitsCompiler: - def __init__(self): - self.compiler_called = 0 - - def compile_fn(self, gm, example_inputs): - self.compiler_called += 1 - return gm - - -def skip_if_no_active_ddp(): - from torch.nn.parallel import DistributedDataParallel as DDP - - if not hasattr(DDP, "_get_active_ddp_module"): - raise unittest.SkipTest("requires pytorch landing in parallel") - - -@pytest.mark.skip("Module hangs in PyTorch CI") -class TestDistributed(torchdynamo.test_case.TestCase): - """ - Test harness initializes dist process group - """ - - @classmethod - def setUpClass(cls): - super().setUpClass() - # _exit_stack is set up in TestCase - cls._exit_stack.enter_context( - patch.dict( - os.environ, - { - "MASTER_ADDR": "localhost", - "MASTER_PORT": "12355", - }, - ) - ) - cls.rank = 0 - cls.device = f"cpu:{cls.rank}" - cls.device_ids = None if "cpu" in cls.device else [cls.rank] - dist.init_process_group("gloo", rank=cls.rank, world_size=1) - - @classmethod - def tearDownClass(cls): - dist.destroy_process_group() - super().tearDownClass() - - def get_model(self): - m = ToyModel().to(self.device) - inputs = torch.randn(20, 10).to(self.device) - outputs = m(inputs) - return m, inputs, outputs - - @patch.object(config, "optimize_ddp", False) - def test_ddp_baseline_aot_eager(self): - from torch.nn.parallel import DistributedDataParallel as DDP - - m, inputs, correct_outputs = self.get_model() - ddp_m = DDP(m, device_ids=self.device_ids) - ddp_m = torchdynamo.optimize("aot_eager")(ddp_m) - outputs = ddp_m(inputs) - self.assertTrue(same(correct_outputs, outputs)) - - @patch.object(config, "optimize_ddp", False) - def test_ddp_baseline_inductor(self): - from torch.nn.parallel import DistributedDataParallel as DDP - - m, inputs, correct_outputs = self.get_model() - ddp_m = DDP(m, device_ids=self.device_ids) - ddp_m = torchdynamo.optimize("inductor")(ddp_m) - outputs = ddp_m(inputs) - self.assertTrue(same(correct_outputs, outputs)) - - # can't run with gloo (no support for _allgather_base) and nccl not available in CI - @pytest.mark.xfail - @patch.object(config, "optimize_ddp", False) - def test_fsdp_baseline_aot_eager(self): - from torch.distributed.fsdp import FullyShardedDataParallel as FSDP - - m, inputs, correct_outputs = self.get_model() - fsdp_m = FSDP(m, device_id=self.device_ids[0] if self.device_ids else None) - fsdp_m = torchdynamo.optimize("aot_eager")(fsdp_m) - outputs = fsdp_m(inputs) - self.assertTrue(same(correct_outputs, outputs)) - - # hangs/crashes with inductor currently - @pytest.mark.skip - @patch.object(config, "optimize_ddp", False) - def test_fsdp_baseline_inductor(self): - from torch.distributed.fsdp 
import FullyShardedDataParallel as FSDP - - m, inputs, correct_outputs = self.get_model() - fsdp_m = FSDP(m, device_id=self.device_ids[0] if self.device_ids else None) - fsdp_m = torchdynamo.optimize("inductor")(fsdp_m) - outputs = fsdp_m(inputs) - self.assertTrue(same(correct_outputs, outputs)) - - @patch.object(config, "optimize_ddp", True) - def test_graph_split(self): - """ - Just ensures that the appropriate number of splits happen (based on - bucket size and model parameters) - verifies the number of times - the user-provided compiler is called by the DDPOptimizer which is - doing the graph splitting - """ - from torch.nn.parallel import DistributedDataParallel as DDP - - skip_if_no_active_ddp() - - m, inputs, correct_outputs = self.get_model() - ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=25) - - check_splits_compiler = CheckSplitsCompiler() - - @torchdynamo.optimize(check_splits_compiler.compile_fn) - def opt_fn(inputs): - return ddp_m(inputs) - - opt_outputs = opt_fn(inputs) - self.assertTrue(same(correct_outputs, opt_outputs)) - self.assertEqual(check_splits_compiler.compiler_called, 3) - - # hangs/crashes with inductor currently - @pytest.mark.skip - @patch.object(config, "optimize_ddp", True) - def test_graph_split_inductor(self): - """ - Same as above, but using inductor backend. - We observed issues with inductor/fx interface in the past. - """ - from torch.nn.parallel import DistributedDataParallel as DDP - - skip_if_no_active_ddp() - m, inputs, correct_outputs = self.get_model() - ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=25) - - @torchdynamo.optimize("inductor") - def opt_fn(inputs): - return ddp_m(inputs) - - opt_outputs = opt_fn(inputs) - self.assertTrue(same(correct_outputs, opt_outputs)) - - @patch.object(config, "optimize_ddp", True) - def test_no_split(self): - """ - Ensures the DDPOptimizer returns a correct, compiled module without - introducing graph splits. (Based on model parmeters fitting in the bucket) - """ - from torch.nn.parallel import DistributedDataParallel as DDP - - skip_if_no_active_ddp() - m, inputs, correct_outputs = self.get_model() - ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=250) - - check_splits_compiler = CheckSplitsCompiler() - - @torchdynamo.optimize(check_splits_compiler.compile_fn) - def opt_fn(inputs): - return ddp_m(inputs) - - opt_outputs = opt_fn(inputs) - self.assertTrue(same(correct_outputs, opt_outputs)) - self.assertEqual(check_splits_compiler.compiler_called, 1) - - @patch.object(config, "optimize_ddp", True) - def test_aot_autograd(self): - """ - Explicitly check AotAutograd family of compilers work, - since they require example inputs propagated between graph splits. 
- """ - from torch.nn.parallel import DistributedDataParallel as DDP - - skip_if_no_active_ddp() - m, inputs, correct_outputs = self.get_model() - ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=25) - - @torchdynamo.optimize("aot_eager") - def opt_fn(inputs): - return ddp_m(inputs) - - opt_outputs = opt_fn(inputs) - opt_outputs.sum().backward() - self.assertTrue(same(correct_outputs, opt_outputs)) - - @patch.object(config, "optimize_ddp", True) - def test_custom_layer(self): - """ - Just ensures that the appropriate number of splits happen (based on - bucket size and model parameters) - verifies the number of times - the user-provided compiler is called by the DDPOptimizer which is - doing the graph splitting - """ - from torch.nn.parallel import DistributedDataParallel as DDP - - skip_if_no_active_ddp() - - class MyCustomLinear(torch.nn.Module): - def __init__(self): - super(MyCustomLinear, self).__init__() - self.weight = nn.Parameter(torch.randn(512, 512)) - - def forward(self, x): - return torch.mm(x, self.weight.t()) - - class MyLinear(torch.nn.Module): - def __init__(self): - super(MyLinear, self).__init__() - self.linear = torch.nn.Linear(512, 512) - - def forward(self, x): - return self.linear(x) - - class MyModule(torch.nn.Module): - def __init__(self): - super(MyModule, self).__init__() - mods = [ - (MyLinear(), torch.nn.ReLU()), - # sandwitch the custom in the middle so it comes before and after - (MyCustomLinear(), torch.nn.ReLU()), - (MyLinear(), torch.nn.ReLU()), - ] - self.seq = torch.nn.Sequential(*[x for items in mods for x in items]) - - def forward(self, x): - return self.seq(x) - - m = MyModule().to(self.device) - inputs = torch.randn((512, 512)).to(self.device) - correct_outputs = m(inputs) - ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=1) - - check_splits_compiler = CheckSplitsCompiler() - - @torchdynamo.optimize(check_splits_compiler.compile_fn) - def opt_fn(inputs): - return ddp_m(inputs) - - opt_outputs = opt_fn(inputs) - self.assertTrue(same(correct_outputs, opt_outputs)) - self.assertEqual(check_splits_compiler.compiler_called, 3) - - def test_empty_graph(self): - def fn(): - get_world_size = torch.distributed.distributed_c10d.get_world_size() - return (get_world_size,) - - opt_fn = torchdynamo.optimize("inductor")(fn) - res = None - try: - res = opt_fn()[0] - except Exception: - pass - self.assertEqual(res, 1) - - -# TODO(jansel): debug issues running this in CI -# if __name__ == "__main__": -# from torchdynamo.testing import run_tests -# run_tests() diff --git a/test/dynamo/test_dynamic_shapes.py b/test/dynamo/test_dynamic_shapes.py deleted file mode 100644 index b9159c466b..0000000000 --- a/test/dynamo/test_dynamic_shapes.py +++ /dev/null @@ -1,34 +0,0 @@ -# Owner(s): ["module: dynamo"] - -from torchdynamo.testing import make_test_cls_with_patches - -try: - from . import test_functions - from . import test_misc - from . import test_modules - from . import test_repros - from . 
import test_unspec -except ImportError: - import test_functions - import test_misc - import test_modules - import test_repros - import test_unspec - - -def make_dynamic_cls(cls): - return make_test_cls_with_patches( - cls, "DynamicShapes", "_dynamic_shapes", ("dynamic_shapes", True) - ) - - -DynamicShapesFunctionTests = make_dynamic_cls(test_functions.FunctionTests) -DynamicShapesMiscTests = make_dynamic_cls(test_misc.MiscTests) -DynamicShapesReproTests = make_dynamic_cls(test_repros.ReproTests) -DynamicShapesNNModuleTests = make_dynamic_cls(test_modules.NNModuleTests) -DynamicShapesUnspecTests = make_dynamic_cls(test_unspec.UnspecTests) - -if __name__ == "__main__": - from torchdynamo.test_case import run_tests - - run_tests() diff --git a/test/dynamo/test_export.py b/test/dynamo/test_export.py deleted file mode 100644 index 6044f7dac7..0000000000 --- a/test/dynamo/test_export.py +++ /dev/null @@ -1,1429 +0,0 @@ -# Owner(s): ["module: dynamo"] -from unittest.mock import patch - -import torch -import torch.utils._pytree as pytree -from torch.fx.experimental.proxy_tensor import make_fx - -import torchdynamo.test_case -import torchdynamo.testing - - -class ExportTests(torchdynamo.test_case.TestCase): - # TODO(voz): Refactor to a shared test function. - # The tests in this file are a little redundant, - # They all take a func, run it with eager, then export it, then compare - def test_export(self): - def pre_attention_state_ops(input, mems, state): - lc_key = state[0] - lc_val = state[1] - bar = [] - for i in range(0, 4): - bar2 = [] - for j in range(0, 3): - bar2.append( - lc_key + lc_val + torch.tensor([0.1, 0.25, 0.4, 0.5, 0.1]) - ) - bar.append(bar2) - - return bar - - def func(): - mems = torch.tensor([[[1.8364, 0.2724, -1.4917, -0.4367, 0.8640]]]) - state = [ - torch.tensor([[[1.0517, 0.3848, -0.6472, 0.0823, 0.9116]]]), - torch.tensor([[[1.0517, 0.3848, -0.6472, 0.0823, 0.9116]]]), - ] - i = torch.tensor( - [ - [0.0313, -0.1487, -0.3846, -0.5321], - [-1.7073, 1.3331, -0.0890, -1.4935], - [-0.8314, -0.1862, -0.5935, 1.5232], - ] - ) - return pre_attention_state_ops(i, mems, state) - - opt_func = torchdynamo.optimize("eager", nopython=True)(func) - real_result = opt_func() - - torchdynamo.reset() - - exported = torchdynamo.export(func) - out_graph = exported[0] - - dynamo_result = out_graph() - self.assertTrue(torchdynamo.utils.same(real_result, dynamo_result)) - - def test_export_mismatched_out(self): - def func(x): - y = x + 1 - return ([x, x], (y, y)) - - opt_func = torchdynamo.optimize("eager", nopython=True)(func) - real_result = opt_func(torch.tensor([[[1.3737, 0.1]]])) - - torchdynamo.reset() - - exported = torchdynamo.export(func, torch.tensor([[[1.3737, 0.1]]])) - out_graph = exported[0] - - dynamo_result = out_graph(torch.tensor([[[1.3737, 0.1]]])) - - self.assertTrue(torchdynamo.utils.same(real_result, dynamo_result)) - - def test_export_graph_bypass(self): - inp = [ - torch.tensor([0.1, 0.1]), - torch.tensor([0.2, 0.2]), - torch.tensor([0.3, 0.3]), - ] - - def func(x): - first = x[2] - second = x[2] - return first * second - - opt_func = torchdynamo.optimize("eager", nopython=True)(func) - real_result = opt_func(inp) - - torchdynamo.reset() - - exported = torchdynamo.export(func, inp) - out_graph = exported[0] - flat_input, _ = pytree.tree_flatten(inp) - - dynamo_result = out_graph(*flat_input) - - self.assertTrue(torchdynamo.utils.same(real_result, dynamo_result)) - - def test_list_unpack(self): - inp = [ - torch.tensor([0.1, 0.1]), - torch.tensor([0.2, 0.2]), - 
torch.tensor([0.3, 0.3]), - ] - - def func(x): - first = x[2] - second = x[2] - return x[0], first * second, x[1], x[2] - - opt_func = torchdynamo.optimize("eager", nopython=True)(func) - real_result = opt_func(inp) - - torchdynamo.reset() - - exported = torchdynamo.export(func, inp) - out_graph = exported[0] - flat_input, _ = pytree.tree_flatten(inp) - - dynamo_result = out_graph(*flat_input) - - self.assertTrue(torchdynamo.utils.same(real_result, dynamo_result)) - - def test_export_mismatched_out_2(self): - def func(x): - y = x + 1 - return ([x, x], (y, y)) - - opt_func = torchdynamo.optimize("eager", nopython=True)(func) - real_result = opt_func(torch.tensor([[[1.3737, 0.1]]])) - - torchdynamo.reset() - - exported = torchdynamo.export(func, torch.tensor([[[1.3737, 0.1]]])) - out_graph = exported[0] - - dynamo_result = out_graph(torch.tensor([[[1.3737, 0.1]]])) - - self.assertTrue(torchdynamo.utils.same(real_result, dynamo_result)) - - def test_export_graph_with_list(self): - inp = [ - torch.tensor([0.1, 0.1]), - torch.tensor([0.2, 0.2]), - torch.tensor([0.3, 0.3]), - torch.tensor([0.4, 0.4]), - ] - - def func(x): - first = x[2] - second = x[2] - return first * second, x - - opt_func = torchdynamo.optimize("eager", nopython=True)(func) - real_result = opt_func(inp) - - torchdynamo.reset() - - exported = torchdynamo.export(func, inp) - out_graph = exported[0] - flat_input, _ = pytree.tree_flatten(inp) - - dynamo_result = out_graph(*flat_input) - - self.assertTrue(torchdynamo.utils.same(real_result, dynamo_result)) - - def test_export_graph_with_complex_reorder(self): - inp = [ - torch.tensor([0.1, 0.1]), - torch.tensor([0.2, 0.2]), - torch.tensor([0.3, 0.3]), - torch.tensor([0.4, 0.4]), - ] - - def func(x): - first = x[0] - second = x[1] - third = x[2] - return third, first, second, first * second, first * third - - opt_func = torchdynamo.optimize("eager", nopython=True)(func) - real_result = opt_func(inp) - - torchdynamo.reset() - - exported = torchdynamo.export(func, inp) - out_graph = exported[0] - flat_input, _ = pytree.tree_flatten(inp) - - dynamo_result = out_graph(*flat_input) - - self.assertTrue(torchdynamo.utils.same(real_result, dynamo_result)) - - def test_dupes(self): - inp = torch.tensor([0.1, 0.1]) - - def func(x): - y = x + 1 - return y, y - - opt_func = torchdynamo.optimize("eager", nopython=True)(func) - real_result = opt_func(inp) - - torchdynamo.reset() - - exported = torchdynamo.export(func, inp) - out_graph = exported[0] - flat_input, _ = pytree.tree_flatten(inp) - - dynamo_result = out_graph(*flat_input) - - self.assertTrue(torchdynamo.utils.same(real_result, dynamo_result)) - - def test_dupes_2(self): - inp = torch.tensor([0.1, 0.1]) - - def func(x): - y = x + 1 - return y, y - - opt_func = torchdynamo.optimize("eager", nopython=True)(func) - real_result = opt_func(inp) - - torchdynamo.reset() - - exported = torchdynamo.export(func, inp) - out_graph = exported[0] - flat_input, _ = pytree.tree_flatten(inp) - - dynamo_result = out_graph(*flat_input) - - self.assertTrue(torchdynamo.utils.same(real_result, dynamo_result)) - - def test_dupes_and_bypass(self): - inp = torch.tensor([0.1, 0.1]) - inp2 = torch.tensor([0.4, 0.4]) - inps = [inp, inp2] - - def func(x, z): - y = x + 1 - return y, y, z - - opt_func = torchdynamo.optimize("eager", nopython=True)(func) - real_result = opt_func(*inps) - - torchdynamo.reset() - - exported = torchdynamo.export(func, *inps) - out_graph = exported[0] - flat_input, _ = pytree.tree_flatten(inps) - - dynamo_result = out_graph(*flat_input) 
- - self.assertTrue(torchdynamo.utils.same(real_result, dynamo_result)) - - def test_dupes_and_bypass_with_non_tensor_arg(self): - inp = torch.tensor([0.1, 0.1]) - inp2 = torch.tensor([0.1, 0.1]) - inp3 = 4 - inps = [inp, inp2, inp3] - - def func(x, z, k): - y = x + k - return y, y, z - - opt_func = torchdynamo.optimize("eager", nopython=True)(func) - real_result = opt_func(*inps) - - torchdynamo.reset() - - exported = torchdynamo.export(func, *inps) - out_graph = exported[0] - flat_input, _ = pytree.tree_flatten(inps) - - dynamo_result = out_graph(*flat_input) - - self.assertTrue(torchdynamo.utils.same(real_result, dynamo_result)) - - def test_dupes_and_bypass_reorder_with_non_tensor_arg(self): - inp = torch.tensor([0.1, 0.1]) - inp2 = torch.tensor([0.1, 0.1]) - inp3 = 4 - inps = [inp, inp2, inp3] - - def func(x, z, k): - y = x + k - return z, y, y - - opt_func = torchdynamo.optimize("eager", nopython=True)(func) - real_result = opt_func(*inps) - - torchdynamo.reset() - - exported = torchdynamo.export(func, *inps) - out_graph = exported[0] - flat_input, _ = pytree.tree_flatten(inps) - - dynamo_result = out_graph(*flat_input) - - self.assertTrue(torchdynamo.utils.same(real_result, dynamo_result)) - - @patch.object(torchdynamo.config, "capture_scalar_outputs", True) - def test_dupes_and_bypass_with_non_tensor_output(self): - inp = torch.tensor([0.1, 0.1]) - inp2 = torch.tensor([0.1, 0.1]) - inp3 = 4 - inps = [inp, inp2, inp3] - - def func(x, z, k): - y = x + k - return y[0].item(), y, z - - opt_func = torchdynamo.optimize("eager", nopython=True)(func) - real_result = opt_func(*inps) - - torchdynamo.reset() - - exported = torchdynamo.export(func, *inps) - out_graph = exported[0] - flat_input, _ = pytree.tree_flatten(inps) - - dynamo_result = out_graph(*flat_input) - - self.assertTrue(torchdynamo.utils.same(real_result, dynamo_result)) - - def test_zeroes_in_and_out_different_shape_on_test(self): - inp = torch.zeros(10) - inp2 = torch.zeros(10) - inp3 = torch.zeros(10) - inps = [inp, inp2, inp3] - - inps_rand = [torch.randn(10), torch.randn(10), torch.randn(10)] - - def func(a, b, c): - return [[a], [b, c], [a + b], [[c + c]]] - - opt_func = torchdynamo.optimize("eager", nopython=True)(func) - real_result = opt_func(*inps_rand) - - torchdynamo.reset() - - exported = torchdynamo.export(func, *inps) - out_graph = exported[0] - flat_input, _ = pytree.tree_flatten(inps_rand) - - dynamo_result = out_graph(*flat_input) - - self.assertTrue(torchdynamo.utils.same(real_result, dynamo_result)) - - @patch.object(torchdynamo.config, "capture_scalar_outputs", True) - def test_zeroes_in_new_shape_scalar_out(self): - inp = torch.zeros(10) - inp2 = torch.zeros(10) - inp3 = torch.zeros(10) - inps = [inp, inp2, inp3] - - inps_rand = [torch.randn(10), torch.randn(10), torch.randn(10)] - - def func(a, b, c): - return a[0].item() + b[0].item() + c[0].item() - - opt_func = torchdynamo.optimize("eager", nopython=True)(func) - real_result = opt_func(*inps_rand) - - torchdynamo.reset() - - exported = torchdynamo.export(func, *inps) - out_graph = exported[0] - flat_input, _ = pytree.tree_flatten(inps_rand) - - dynamo_result = out_graph(*flat_input) - - self.assertTrue(torchdynamo.utils.same(real_result, dynamo_result)) - - @patch.object(torchdynamo.config, "capture_scalar_outputs", True) - def test_zeroes_in_new_shape_scalar_out_permute(self): - inp = torch.zeros(10) - inp2 = torch.zeros(10) - inp3 = torch.zeros(10) - inps = [inp, inp2, inp3] - - inps_rand = [torch.randn(10), torch.randn(10), torch.randn(10)] - - def 
func(a, b, c): - return b[0].item() + c[0].item() + a[0].item() + a[0].item() - - opt_func = torchdynamo.optimize("eager", nopython=True)(func) - real_result = opt_func(*inps_rand) - - torchdynamo.reset() - - exported = torchdynamo.export(func, *inps) - out_graph = exported[0] - flat_input, _ = pytree.tree_flatten(inps_rand) - - dynamo_result = out_graph(*flat_input) - - self.assertTrue(torchdynamo.utils.same(real_result, dynamo_result)) - - @patch.object(torchdynamo.config, "capture_scalar_outputs", True) - def test_zeroes_in_new_shape_scalar_out_permute_dupe_and_bypass(self): - inp = torch.zeros(10) - inp2 = torch.zeros(10) - inp3 = torch.zeros(10) - inps = [inp, inp2, inp3] - - inps_rand = [torch.randn(10), torch.randn(10), torch.randn(10)] - - def func(a, b, c): - return a, b[0].item() + c[0].item() + a[0].item() + a[0].item(), a - - opt_func = torchdynamo.optimize("eager", nopython=True)(func) - real_result = opt_func(*inps_rand) - - torchdynamo.reset() - - exported = torchdynamo.export(func, *inps) - out_graph = exported[0] - flat_input, _ = pytree.tree_flatten(inps_rand) - - dynamo_result = out_graph(*flat_input) - - self.assertTrue(torchdynamo.utils.same(real_result, dynamo_result)) - - def test_func_return(self): - inp = torch.zeros(10) - inp2 = torch.zeros(10) - inp3 = torch.zeros(10) - inps = [inp, inp2, inp3] - - inps_rand = [torch.randn(10), torch.randn(10), torch.randn(10)] - - def func(a, b, c): - x = a + b + c - - def func2(y): - return x * y - - return func2(x) - - opt_func = torchdynamo.optimize("eager", nopython=True)(func) - real_result = opt_func(*inps_rand) - - torchdynamo.reset() - - exported = torchdynamo.export(func, *inps) - out_graph = exported[0] - flat_input, _ = pytree.tree_flatten(inps_rand) - - dynamo_result = out_graph(*flat_input) - - self.assertTrue(torchdynamo.utils.same(real_result, dynamo_result)) - - def test_dict_return(self): - inp = torch.zeros(10) - inp2 = torch.zeros(10) - inp3 = torch.zeros(10) - inps = [inp, inp2, inp3] - - inps_rand = [torch.randn(10), torch.randn(10), torch.randn(10)] - - def func(a, b, c): - x = a + b + c - return {"a": x} - - opt_func = torchdynamo.optimize("eager", nopython=True)(func) - real_result = opt_func(*inps_rand) - - torchdynamo.reset() - - exported = torchdynamo.export(func, *inps) - out_graph = exported[0] - flat_input, _ = pytree.tree_flatten(inps_rand) - - dynamo_result = out_graph(*flat_input) - - self.assertTrue(torchdynamo.utils.same(real_result, dynamo_result)) - - def test_export_with_aten_graph(self): - def pre_attention_state_ops(input, mems, state): - lc_key = state[0] - lc_val = state[1] - bar = [] - for i in range(0, 4): - bar2 = [] - for j in range(0, 3): - bar2.append( - lc_key + lc_val + torch.tensor([0.1, 0.25, 0.4, 0.5, 0.1]) - ) - bar.append(bar2) - - return bar - - def func(): - mems = torch.tensor([[[1.8364, 0.2724, -1.4917, -0.4367, 0.8640]]]) - state = [ - torch.tensor([[[1.0517, 0.3848, -0.6472, 0.0823, 0.9116]]]), - torch.tensor([[[1.0517, 0.3848, -0.6472, 0.0823, 0.9116]]]), - ] - i = torch.tensor( - [ - [0.0313, -0.1487, -0.3846, -0.5321], - [-1.7073, 1.3331, -0.0890, -1.4935], - [-0.8314, -0.1862, -0.5935, 1.5232], - ] - ) - return pre_attention_state_ops(i, mems, state) - - opt_func = torchdynamo.optimize("eager", nopython=True)(func) - real_result = opt_func() - - torchdynamo.reset() - - exported = torchdynamo.export(func, aten_graph=True) - out_graph = exported[0] - - dynamo_result = out_graph() - self.assertTrue(torchdynamo.utils.same(real_result, dynamo_result)) - - def 
test_export_mismatched_out_with_aten_graph(self): - def func(x): - y = x + 1 - return ([x, x], (y, y)) - - opt_func = torchdynamo.optimize("eager", nopython=True)(func) - real_result = opt_func(torch.tensor([[[1.3737, 0.1]]])) - - torchdynamo.reset() - - exported = torchdynamo.export( - func, torch.tensor([[[1.3737, 0.1]]]), aten_graph=True - ) - out_graph = exported[0] - - dynamo_result = out_graph(torch.tensor([[[1.3737, 0.1]]])) - - self.assertTrue(torchdynamo.utils.same(real_result, dynamo_result)) - - def test_export_graph_bypass_with_aten_graph(self): - inp = [ - torch.tensor([0.1, 0.1]), - torch.tensor([0.2, 0.2]), - torch.tensor([0.3, 0.3]), - ] - - def func(x): - first = x[2] - second = x[2] - return first * second - - opt_func = torchdynamo.optimize("eager", nopython=True)(func) - real_result = opt_func(inp) - - torchdynamo.reset() - - exported = torchdynamo.export(func, inp, aten_graph=True) - out_graph = exported[0] - flat_input, _ = pytree.tree_flatten(inp) - - dynamo_result = out_graph(*flat_input) - - self.assertTrue(torchdynamo.utils.same(real_result, dynamo_result)) - - def test_list_unpack_with_aten_graph(self): - inp = [ - torch.tensor([0.1, 0.1]), - torch.tensor([0.2, 0.2]), - torch.tensor([0.3, 0.3]), - ] - - def func(x): - first = x[2] - second = x[2] - return x[0], first * second, x[1], x[2] - - opt_func = torchdynamo.optimize("eager", nopython=True)(func) - real_result = opt_func(inp) - - torchdynamo.reset() - - exported = torchdynamo.export(func, inp, aten_graph=True) - out_graph = exported[0] - flat_input, _ = pytree.tree_flatten(inp) - - dynamo_result = out_graph(*flat_input) - - self.assertTrue(torchdynamo.utils.same(real_result, dynamo_result)) - - def test_export_mismatched_out_2_with_aten_graph(self): - def func(x): - y = x + 1 - return ([x, x], (y, y)) - - opt_func = torchdynamo.optimize("eager", nopython=True)(func) - real_result = opt_func(torch.tensor([[[1.3737, 0.1]]])) - - torchdynamo.reset() - - exported = torchdynamo.export( - func, torch.tensor([[[1.3737, 0.1]]]), aten_graph=True - ) - out_graph = exported[0] - - dynamo_result = out_graph(torch.tensor([[[1.3737, 0.1]]])) - - self.assertTrue(torchdynamo.utils.same(real_result, dynamo_result)) - - def test_export_graph_with_list_with_aten_graph(self): - inp = [ - torch.tensor([0.1, 0.1]), - torch.tensor([0.2, 0.2]), - torch.tensor([0.3, 0.3]), - torch.tensor([0.4, 0.4]), - ] - - def func(x): - first = x[2] - second = x[2] - return first * second, x - - opt_func = torchdynamo.optimize("eager", nopython=True)(func) - real_result = opt_func(inp) - - torchdynamo.reset() - - exported = torchdynamo.export(func, inp, aten_graph=True) - out_graph = exported[0] - flat_input, _ = pytree.tree_flatten(inp) - - dynamo_result = out_graph(*flat_input) - - self.assertTrue(torchdynamo.utils.same(real_result, dynamo_result)) - - def test_export_graph_with_complex_reorder_with_aten_graph(self): - inp = [ - torch.tensor([0.1, 0.1]), - torch.tensor([0.2, 0.2]), - torch.tensor([0.3, 0.3]), - torch.tensor([0.4, 0.4]), - ] - - def func(x): - first = x[0] - second = x[1] - third = x[2] - return third, first, second, first * second, first * third - - opt_func = torchdynamo.optimize("eager", nopython=True)(func) - real_result = opt_func(inp) - - torchdynamo.reset() - - exported = torchdynamo.export(func, inp, aten_graph=True) - out_graph = exported[0] - flat_input, _ = pytree.tree_flatten(inp) - - dynamo_result = out_graph(*flat_input) - - self.assertTrue(torchdynamo.utils.same(real_result, dynamo_result)) - - def 
test_dupes_with_aten_graph(self): - inp = torch.tensor([0.1, 0.1]) - - def func(x): - y = x + 1 - return y, y - - opt_func = torchdynamo.optimize("eager", nopython=True)(func) - real_result = opt_func(inp) - - torchdynamo.reset() - - exported = torchdynamo.export(func, inp, aten_graph=True) - out_graph = exported[0] - flat_input, _ = pytree.tree_flatten(inp) - - dynamo_result = out_graph(*flat_input) - - self.assertTrue(torchdynamo.utils.same(real_result, dynamo_result)) - - def test_dupes_2_with_aten_graph(self): - inp = torch.tensor([0.1, 0.1]) - - def func(x): - y = x + 1 - return y, y - - opt_func = torchdynamo.optimize("eager", nopython=True)(func) - real_result = opt_func(inp) - - torchdynamo.reset() - - exported = torchdynamo.export(func, inp, aten_graph=True) - out_graph = exported[0] - flat_input, _ = pytree.tree_flatten(inp) - - dynamo_result = out_graph(*flat_input) - - self.assertTrue(torchdynamo.utils.same(real_result, dynamo_result)) - - def test_dupes_and_bypass_with_aten_graph(self): - inp = torch.tensor([0.1, 0.1]) - inp2 = torch.tensor([0.4, 0.4]) - inps = [inp, inp2] - - def func(x, z): - y = x + 1 - return y, y, z - - opt_func = torchdynamo.optimize("eager", nopython=True)(func) - real_result = opt_func(*inps) - - torchdynamo.reset() - - exported = torchdynamo.export(func, *inps, aten_graph=True) - out_graph = exported[0] - flat_input, _ = pytree.tree_flatten(inps) - - dynamo_result = out_graph(*flat_input) - - self.assertTrue(torchdynamo.utils.same(real_result, dynamo_result)) - - def test_dupes_and_bypass_with_non_tensor_arg_with_aten_graph(self): - inp = torch.tensor([0.1, 0.1]) - inp2 = torch.tensor([0.1, 0.1]) - inp3 = 4 - inps = [inp, inp2, inp3] - - def func(x, z, k): - y = x + k - return y, y, z - - opt_func = torchdynamo.optimize("eager", nopython=True)(func) - real_result = opt_func(*inps) - - torchdynamo.reset() - - exported = torchdynamo.export(func, *inps, aten_graph=True) - out_graph = exported[0] - flat_input, _ = pytree.tree_flatten(inps) - - dynamo_result = out_graph(*flat_input) - - self.assertTrue(torchdynamo.utils.same(real_result, dynamo_result)) - - def test_dupes_and_bypass_reorder_with_non_tensor_arg_with_aten_graph(self): - inp = torch.tensor([0.1, 0.1]) - inp2 = torch.tensor([0.1, 0.1]) - inp3 = 4 - inps = [inp, inp2, inp3] - - def func(x, z, k): - y = x + k - return z, y, y - - opt_func = torchdynamo.optimize("eager", nopython=True)(func) - real_result = opt_func(*inps) - - torchdynamo.reset() - - exported = torchdynamo.export(func, *inps, aten_graph=True) - out_graph = exported[0] - flat_input, _ = pytree.tree_flatten(inps) - - dynamo_result = out_graph(*flat_input) - - self.assertTrue(torchdynamo.utils.same(real_result, dynamo_result)) - - @patch.object(torchdynamo.config, "capture_scalar_outputs", True) - def test_dupes_and_bypass_with_non_tensor_output_with_aten_graph(self): - inp = torch.tensor([0.1, 0.1]) - inp2 = torch.tensor([0.1, 0.1]) - inp3 = 4 - inps = [inp, inp2, inp3] - - def func(x, z, k): - y = x + k - return y[0].item(), y, z - - opt_func = torchdynamo.optimize("eager", nopython=True)(func) - real_result = opt_func(*inps) - - torchdynamo.reset() - - exported = torchdynamo.export(func, *inps) - out_graph = exported[0] - flat_input, _ = pytree.tree_flatten(inps) - - dynamo_result = out_graph(*flat_input) - - self.assertTrue(torchdynamo.utils.same(real_result, dynamo_result)) - - def test_zeroes_in_and_out_different_shape_on_test_with_aten_graph(self): - inp = torch.zeros(10) - inp2 = torch.zeros(10) - inp3 = torch.zeros(10) - 
inps = [inp, inp2, inp3] - - inps_rand = [torch.randn(10), torch.randn(10), torch.randn(10)] - - def func(a, b, c): - return [[a], [b, c], [a + b], [[c + c]]] - - opt_func = torchdynamo.optimize("eager", nopython=True)(func) - real_result = opt_func(*inps_rand) - - torchdynamo.reset() - - exported = torchdynamo.export(func, *inps, aten_graph=True) - out_graph = exported[0] - flat_input, _ = pytree.tree_flatten(inps_rand) - - dynamo_result = out_graph(*flat_input) - - self.assertTrue(torchdynamo.utils.same(real_result, dynamo_result)) - - def test_func_return_with_aten_graph(self): - inp = torch.zeros(10) - inp2 = torch.zeros(10) - inp3 = torch.zeros(10) - inps = [inp, inp2, inp3] - - inps_rand = [torch.randn(10), torch.randn(10), torch.randn(10)] - - def func(a, b, c): - x = a + b + c - - def func2(y): - return x * y - - return func2(x) - - opt_func = torchdynamo.optimize("eager", nopython=True)(func) - real_result = opt_func(*inps_rand) - - torchdynamo.reset() - - exported = torchdynamo.export(func, *inps, aten_graph=True) - out_graph = exported[0] - flat_input, _ = pytree.tree_flatten(inps_rand) - - dynamo_result = out_graph(*flat_input) - - self.assertTrue(torchdynamo.utils.same(real_result, dynamo_result)) - - def test_dict_return_with_aten_graph(self): - inp = torch.zeros(10) - inp2 = torch.zeros(10) - inp3 = torch.zeros(10) - inps = [inp, inp2, inp3] - - inps_rand = [torch.randn(10), torch.randn(10), torch.randn(10)] - - def func(a, b, c): - x = a + b + c - return {"a": x} - - opt_func = torchdynamo.optimize("eager", nopython=True)(func) - real_result = opt_func(*inps_rand) - - torchdynamo.reset() - - exported = torchdynamo.export(func, *inps, aten_graph=True) - out_graph = exported[0] - flat_input, _ = pytree.tree_flatten(inps_rand) - - dynamo_result = out_graph(*flat_input) - - self.assertTrue(torchdynamo.utils.same(real_result, dynamo_result)) - - def test_export_with_stack_trace(self): - inp = torch.tensor([0.1, 0.1]) - linear = torch.nn.Linear(2, 2) - - def func(x): - x = x + 1 - y = x.t() - y = y.relu() - y = linear(y) - return y - - exported = torchdynamo.export(func, inp, aten_graph=False) - out_graph = exported[0] - - for node in out_graph.graph.nodes: - if node.op not in {"placeholder", "output"}: - self.assertTrue(node.stack_trace is not None) - - torchdynamo.reset() - - exported = torchdynamo.export(func, inp, aten_graph=True) - out_graph = exported[0] - for node in out_graph.graph.nodes: - if node.op == "call_function": - self.assertTrue(node.stack_trace is not None) - - def test_export_compare_optimize_with_make_fx(self): - inp = torch.tensor([0.1, 0.1]) - linear = torch.nn.Linear(2, 2) - - def func(x): - x = x + 1 - y = x.t() - y = y.relu() - y = linear(y) - return y - - exported = torchdynamo.export(func, inp, aten_graph=True) - out_graph = exported[0] - export_result = out_graph(inp) - - torchdynamo.reset() - - def compiler(gm, sample_inputs): - aten_gm = make_fx(gm)(*sample_inputs) - - self.assertEqual(len(aten_gm.graph.nodes), len(out_graph.graph.nodes)) - for node1, node2 in zip(aten_gm.graph.nodes, out_graph.graph.nodes): - self.assertEqual(node1.op, node2.op) - if node1.op == "call_function": - self.assertEqual(node1.target, node2.target) - self.assertEqual(len(node1.args), len(node2.args)) - for arg1, arg2 in zip(node1.args, node2.args): - self.assertEqual(type(arg1), type(arg2)) - - return aten_gm.forward - - opt_func = torchdynamo.optimize(compiler, nopython=True)(func) - make_fx_result = opt_func(inp) - - 
self.assertTrue(torchdynamo.utils.same(make_fx_result, export_result)) - - @patch.object(torchdynamo.config, "fake_tensor_propagation", False) - def test_export_with_constant_method_on_module(self): - class MyModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.param = torch.nn.Parameter(torch.rand(4, 2)) - self.linear = torch.nn.Linear(2, 2) - - @torchdynamo.assume_constant_result - def helper_fn(self, x): - return torch.nonzero(x) - - def forward(self, x): - y = torch.sin(x) - x = self.linear(x) - y = self.helper_fn(x) - return y - - module = MyModule() - real_result = module(torch.tensor([[1.0, 0], [0, 0]])) - module = MyModule() - graph, _ = torchdynamo.export(module, torch.tensor([[0.0, 0], [0, 0]])) - result = graph(torch.tensor([[1.0, 0.0], [0, 0]])) - self.assertTrue(torchdynamo.utils.same(result, real_result)) - result = graph(torch.tensor([[1, 0], [0.25, 0.25]])) - self.assertTrue(torchdynamo.utils.same(result, real_result)) - - @patch.object(torchdynamo.config, "fake_tensor_propagation", False) - def test_export_with_constant_method_on_module_invoke_twice(self): - class MyModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.param = torch.nn.Parameter(torch.rand(4, 2)) - self.linear = torch.nn.Linear(2, 2) - - @torchdynamo.assume_constant_result - def helper_fn(self, x): - return torch.nonzero(x) - - def forward(self, x): - y = torch.sin(x) - x = self.linear(x) - y = self.helper_fn(x) + self.helper_fn(x) - return y - - module = MyModule() - real_result = module(torch.tensor([[1.0, 0], [0, 0]])) - module = MyModule() - graph, _ = torchdynamo.export(module, torch.tensor([[0.0, 0], [0, 0]])) - result = graph(torch.tensor([[1.0, 0.0], [0, 0]])) - self.assertTrue(torchdynamo.utils.same(result, real_result)) - result = graph(torch.tensor([[1, 0], [0.25, 0.25]])) - self.assertTrue(torchdynamo.utils.same(result, real_result)) - - @patch.object(torchdynamo.config, "fake_tensor_propagation", False) - def test_export_with_constant_free_function(self): - @torchdynamo.assume_constant_result - def helper_fn(x): - return torch.nonzero(x) - - class MyModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.param = torch.nn.Parameter(torch.rand(4, 2)) - self.linear = torch.nn.Linear(2, 2) - - @torchdynamo.assume_constant_result - def helper_fn(self, x): - return torch.nonzero(x) - - def forward(self, x): - y = torch.sin(x) - x = self.linear(x) - y = helper_fn(x) + self.helper_fn(x) - return y - - module = MyModule() - real_result = module(torch.tensor([[1.0, 0], [0, 0]])) - module = MyModule() - graph, _ = torchdynamo.export(module, torch.tensor([[0.0, 0], [0, 0]])) - result = graph(torch.tensor([[1.0, 0.0], [0, 0]])) - self.assertTrue(torchdynamo.utils.same(result, real_result)) - result = graph(torch.tensor([[1, 0], [0.25, 0.25]])) - self.assertTrue(torchdynamo.utils.same(result, real_result)) - - @patch.object(torchdynamo.config, "fake_tensor_propagation", False) - def test_export_with_constant_free_function_and_class_method(self): - @torchdynamo.assume_constant_result - def helper_fn(x): - return torch.nonzero(x) - - class MyModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.param = torch.nn.Parameter(torch.rand(4, 2)) - self.linear = torch.nn.Linear(2, 2) - - def forward(self, x): - y = torch.sin(x) - x = self.linear(x) - y = helper_fn(x) - return y - - module = MyModule() - real_result = module(torch.tensor([[1.0, 0], [0, 0]])) - module = MyModule() - graph, _ = torchdynamo.export(module, 
torch.tensor([[0.0, 0], [0, 0]])) - result = graph(torch.tensor([[1.0, 0.0], [0, 0]])) - self.assertTrue(torchdynamo.utils.same(result, real_result)) - result = graph(torch.tensor([[1, 0], [0.25, 0.25]])) - self.assertTrue(torchdynamo.utils.same(result, real_result)) - - @patch.object(torchdynamo.config, "fake_tensor_propagation", False) - def test_export_with_constant_free_function_and_class_method_multiarg(self): - @torchdynamo.assume_constant_result - def helper_fn(x): - return torch.nonzero(x) - - class MyModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.param = torch.nn.Parameter(torch.rand(4, 2)) - self.linear = torch.nn.Linear(2, 2) - - def forward(self, x, z): - y = torch.sin(x) - x = self.linear(x) - y = helper_fn(x) + helper_fn(z) - return y - - module = MyModule() - real_result = module( - torch.tensor([[1.0, 0], [0, 0]]), torch.tensor([[1.0, 0], [0, 0]]) - ) - module = MyModule() - graph, _ = torchdynamo.export( - module, torch.tensor([[0.0, 0], [0, 0]]), torch.tensor([[1.0, 0], [0, 0]]) - ) - result = graph( - torch.tensor([[1.0, 0.0], [0, 0]]), torch.tensor([[1.0, 0.0], [0, 0]]) - ) - self.assertTrue(torchdynamo.utils.same(result, real_result)) - result = graph( - torch.tensor([[1, 0], [0.25, 0.25]]), torch.tensor([[1, 0], [0.25, 0.25]]) - ) - self.assertTrue(torchdynamo.utils.same(result, real_result)) - - @patch.object(torchdynamo.config, "fake_tensor_propagation", False) - def test_export_with_constant_free_function_and_class_method_multiarg_diff(self): - @torchdynamo.assume_constant_result - def helper_fn(x): - return torch.nonzero(x) - - class MyModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x, z): - y = helper_fn(x) + helper_fn(z) - return y - - module = MyModule() - real_result = module( - torch.tensor([[1.0, 0], [0, 0]]), torch.tensor([[1.0, 0], [0, 0]]) - ) - module = MyModule() - graph, _ = torchdynamo.export( - module, torch.tensor([[0.0, 0], [0, 0]]), torch.tensor([[0.0, 0], [0.5, 0]]) - ) - result = graph( - torch.tensor([[1.0, 0.0], [0, 0]]), torch.tensor([[0.0, 1.0], [0, 0]]) - ) - self.assertTrue(torchdynamo.utils.same(result, real_result)) - result = graph( - torch.tensor([[1, 0], [0.25, 0.25]]), - torch.tensor([[0.33, 0.33], [0.25, 0.25]]), - ) - self.assertTrue(torchdynamo.utils.same(result, real_result)) - - @patch.object(torchdynamo.config, "fake_tensor_propagation", False) - def test_export_with_constant_tuple_nonzero(self): - class MyModule(torch.nn.Module): - @torchdynamo.assume_constant_result - def helper_fn(self, x): - return (torch.nonzero(x), torch.nonzero(x)) - - def forward(self, x): - y = torch.tensor([0.5]) - elements = self.helper_fn(x) - all_y = [] - for element in elements: - for item in element: - all_y.append(y * item) - return all_y - - module = MyModule() - real_result = module(torch.tensor([1.0, 1.0])) - graph, guards = torchdynamo.export(module, torch.tensor([1.0, 1.0])) - - # Tensor input can be almost anything here, and the result will capture what we - # made constant at compile time. 
- result = graph(torch.tensor([[[1.0, 0], [0, 0]], [[1.0, 0], [0, 0]]])) - self.assertTrue(torchdynamo.utils.same(result, real_result)) - - @patch.object(torchdynamo.config, "fake_tensor_propagation", False) - def test_export_with_constant_list_nonzero(self): - class MyModule(torch.nn.Module): - @torchdynamo.assume_constant_result - def helper_fn(self, x): - return [torch.nonzero(x), torch.nonzero(x)] - - def forward(self, x): - y = torch.tensor([0.5]) - elements = self.helper_fn(x) - all_y = [] - for element in elements: - for item in element: - all_y.append(y * item) - return all_y - - module = MyModule() - real_result = module(torch.tensor([1.0, 1.0])) - graph, guards = torchdynamo.export(module, torch.tensor([1.0, 1.0])) - - # Tensor input can be almost anything here, and the result will capture what we - # made constant at compile time. - result = graph(torch.tensor([[[1.0, 0], [0, 0]], [[1.0, 0], [0, 0]]])) - self.assertTrue(torchdynamo.utils.same(result, real_result)) - - @patch.object(torchdynamo.config, "fake_tensor_propagation", False) - def test_export_with_constant_list_nonzero_free_function(self): - @torchdynamo.assume_constant_result - def helper_fn(x): - return [torch.nonzero(x), torch.nonzero(x)] - - class MyModule(torch.nn.Module): - def forward(self, x): - y = torch.tensor([0.5]) - elements = helper_fn(x) - all_y = [] - for element in elements: - for item in element: - all_y.append(y * item) - return all_y - - module = MyModule() - real_result = module(torch.tensor([1.0, 1.0])) - graph, guards = torchdynamo.export(module, torch.tensor([1.0, 1.0])) - - # Tensor input can be almost anything here, and the result will capture what we - # made constant at compile time. - result = graph(torch.tensor([[[1.0, 0], [0, 0]], [[1.0, 0], [0, 0]]])) - self.assertTrue(torchdynamo.utils.same(result, real_result)) - - @patch.object(torchdynamo.config, "fake_tensor_propagation", False) - def test_export_with_constant_dict_values(self): - class MyModule(torch.nn.Module): - @torchdynamo.assume_constant_result - def helper_fn(self, x): - return {"x": x, "x^2": x * x} - - def forward(self, x): - y = torch.tensor([0.5]) - elements = self.helper_fn(x) - y = y * elements["x"] - y = y * elements["x^2"] - return y - - module = MyModule() - real_result = module(torch.tensor([2.0, 2.0])) - graph, guards = torchdynamo.export(module, torch.tensor([2.0, 2.0])) - - # Tensor input can be almost anything here, and the result will capture what we - # made constant at compile time. 
- result = graph(torch.tensor([[[1.0, 0], [0, 0]], [[1.0, 0], [0, 0]]])) - self.assertTrue(torchdynamo.utils.same(result, real_result)) - - @patch.object(torchdynamo.config, "fake_tensor_propagation", False) - def test_export_with_constant_none_control_flow(self): - class MyModule(torch.nn.Module): - @torchdynamo.assume_constant_result - def helper_fn(self, x): - if x.item() < 0: - return None - else: - return x - - def forward(self, x): - y = torch.tensor([0.5]) - x = self.helper_fn(x) - if x is None: - return y - return y * x - - module = MyModule() - real_result = module(torch.tensor([-1])) - - # X is negative, so .item() < 0, which means we return y - self.assertEqual(real_result, torch.tensor([0.5])) - - graph, guards = torchdynamo.export(module, torch.tensor([-1])) - result = graph(torch.tensor([2])) - # X is positive, but we compiled helper_fn to return None, so it will still return y - self.assertTrue(torchdynamo.utils.same(result, real_result)) - - @patch.object(torchdynamo.config, "fake_tensor_propagation", False) - def test_export_with_constant_not_none_control_flow(self): - class MyModule(torch.nn.Module): - @torchdynamo.assume_constant_result - def helper_fn(self, x): - if x.item() < 0: - return None - else: - return x - - def forward(self, x): - y = torch.tensor([0.5]) - x = self.helper_fn(x) - if x is None: - return y - return y * x - - module = MyModule() - real_result = module(torch.tensor([2])) - - # X is positive, so .item() > 0, which means we return y * x - self.assertEqual(real_result, torch.tensor([1.0])) - - graph, guards = torchdynamo.export(module, torch.tensor([2])) - result = graph(torch.tensor([-0.5])) - # X is negative, but we compiled helper_fn to return x, so it will still return y * x - self.assertTrue(torchdynamo.utils.same(result, real_result)) - - @patch.object(torchdynamo.config, "fake_tensor_propagation", False) - def test_export_with_constant_none_control_flow_free_func(self): - @torchdynamo.assume_constant_result - def helper_fn(x): - if x.item() < 0: - return None - else: - return x - - class MyModule(torch.nn.Module): - def forward(self, x): - y = torch.tensor([0.5]) - x = helper_fn(x) - if x is None: - return y - return y * x - - module = MyModule() - real_result = module(torch.tensor([-1])) - - # X is negative, so .item() < 0, which means we return y - self.assertEqual(real_result, torch.tensor([0.5])) - - graph, guards = torchdynamo.export(module, torch.tensor([-1])) - result = graph(torch.tensor([2])) - # X is positive, but we compiled helper_fn to return None, so it will still return y - self.assertTrue(torchdynamo.utils.same(result, real_result)) - - @patch.object(torchdynamo.config, "fake_tensor_propagation", False) - def test_export_with_constant_not_none_control_flow_pos(self): - class MyModule(torch.nn.Module): - @torchdynamo.assume_constant_result - def helper_fn(self, x): - if x.item() < 0: - return None - else: - return x - - def forward(self, x): - y = torch.tensor([0.5]) - x = self.helper_fn(x) - if x is None: - return y - return y * x - - module = MyModule() - real_result = module(torch.tensor([2])) - - # X is positive, so .item() > 0, which means we return y * x - self.assertEqual(real_result, torch.tensor([1.0])) - - graph, guards = torchdynamo.export(module, torch.tensor([2])) - result = graph(torch.tensor([-0.5])) - # X is negative, but we compiled helper_fn to return x, so it will still return y * x - self.assertTrue(torchdynamo.utils.same(result, real_result)) - - @patch.object(torchdynamo.config, "fake_tensor_propagation", 
False) - def test_export_with_constant_not_none_control_flow_free_func(self): - @torchdynamo.assume_constant_result - def helper_fn(x): - if x.item() < 0: - return None - else: - return x - - class MyModule(torch.nn.Module): - def forward(self, x): - y = torch.tensor([0.5]) - x = helper_fn(x) - if x is None: - return y - return y * x - - module = MyModule() - real_result = module(torch.tensor([2])) - - # X is positive, so .item() > 0, which means we return y * x - self.assertEqual(real_result, torch.tensor([1.0])) - - graph, guards = torchdynamo.export(module, torch.tensor([2])) - result = graph(torch.tensor([-0.5])) - # X is negative, but we compiled helper_fn to return x, so it will still return y * x - self.assertTrue(torchdynamo.utils.same(result, real_result)) - - @patch.object(torchdynamo.config, "fake_tensor_propagation", False) - def test_export_with_constant_not_return_const(self): - class MyModule(torch.nn.Module): - @torchdynamo.assume_constant_result - def helper_fn(self, x): - return self.val - - def forward(self, x): - y = torch.tensor([0.5]) - x = self.helper_fn(x) - if x == "A": - return y - return -1 - - module = MyModule() - module.val = "A" - resA = module(torch.tensor([2])) - graph, guards = torchdynamo.export(module, torch.tensor([2])) - module.val = "B" - resB = graph(torch.tensor([2])) - self.assertTrue(torchdynamo.utils.same(resA, resB)) - - def test_export_decomp(self): - def f(x): - return x.t() + x.t() - - def nop(x): - return x.cos() - - graph, _ = torchdynamo.export( - f, - (torch.randn(5)), - aten_graph=True, - decomposition_table={torch.ops.aten.t.default: nop}, - ) - self.assertEqual( - len([n for n in graph.graph.nodes if n.target == torch.ops.aten.t.default]), - 0, - ) - - graph, _ = torchdynamo.export( - f, (torch.randn(5)), aten_graph=True, decomposition_table=None - ) - self.assertEqual( - len([n for n in graph.graph.nodes if n.target == torch.ops.aten.t.default]), - 2, - ) - - def test_export_decomp_asserts_bad_args(self): - def f(x): - return x.t() + x.t() - - def nop(x): - return x.cos() - - with self.assertRaises(AssertionError): - graph, _ = torchdynamo.export( - f, - (torch.randn(5)), - aten_graph=False, - decomposition_table={torch.ops.aten.t.default: nop}, - ) - - def test_export_decomp_asserts_bad_args_mode(self): - def f(x): - return x.t() + x.t() - - def nop(x): - return x.cos() - - with self.assertRaises(AssertionError): - graph, _ = torchdynamo.export( - f, (torch.randn(5)), aten_graph=False, tracing_mode="symbolic" - ) - - -if __name__ == "__main__": - from torchdynamo.test_case import run_tests - - run_tests() diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py deleted file mode 100644 index 9d80eb52d9..0000000000 --- a/test/dynamo/test_functions.py +++ /dev/null @@ -1,676 +0,0 @@ -# Owner(s): ["module: dynamo"] -# flake8: noqa -import collections -import functools -import inspect -import itertools -import operator -from typing import Any - -import torch -from torch import sub -from torch.nn import functional as F - -import torchdynamo.test_case -import torchdynamo.testing -from torchdynamo.testing import requires_static_shapes - -tensor_for_import_testing = torch.ones(10, 10) -d = torch.ones(10, 10) -e = torch.nn.Linear(10, 10) -flag = True - - -def constant3(a, b): - return a - b + (1.0 + 2) - - -def func_with_default(a, b, some_default_arg=True): - if some_default_arg: - return a - b - - -def make_test(fn): - nargs = len(inspect.signature(fn).parameters) - - def test_fn(self): - return 
torchdynamo.testing.standard_test(self, fn=fn, nargs=nargs) - - return test_fn - - -@torch.jit.script_if_tracing -def inline_script_if_tracing(x): - return x + 1.2 - - -@torch.jit.ignore -def inline_ignore(x): - return x + 3.4 - - -@torch.jit.unused -def inline_unused(x): - return x + 5.6 - - -class FunctionTests(torchdynamo.test_case.TestCase): - @make_test - def test_inline_jit_annotations(x): - x = inline_script_if_tracing(x) - x = inline_ignore(x) - x = inline_unused(x) - return - - @make_test - def test_add(a, b): - return a + b - - @make_test - def test_is_not_null(a, b): - if a is not None and b is not None: - return a + b - - @make_test - def test_constant1(a, b, c): - return a - b * c + 1.0 - - @make_test - def test_constant2(a, b, c): - return a - b * c + 1 - - @make_test - def test_constant3(a): - b = 1 - c = 2 - d = 3 - return b + c - d + a - - @make_test - def test_constant4(a, b): - c = 2 - d = 3 - if c > d: - return a - b - return b - a - - @make_test - def test_finfo(a, b): - if torch.iinfo(torch.int32).bits == 32: - return torch.finfo(a.dtype).min * b - - @make_test - def test_globalfn(a, b): - return sub(a, b) - - @make_test - def test_viatorch(a, b): - return torch.sub(a, b) - - @make_test - def test_viamethod(a, b): - return a.sub(b) - - @make_test - def test_indirect1(a, b): - t = a.sub - return t(b) - - @make_test - def test_indirect2(a, b): - t = a.sub - args = (b,) - return t(*args) - - @make_test - def test_indirect3(a, b): - t = a.sub - args = (b,) - kwargs = {} - return t(*args, **kwargs) - - @make_test - def test_methodcall1(a, b, c): - return constant3(a, b) * c - - @make_test - def test_methodcall2(a, b): - return constant3(a=b, b=a) + 1 - - @make_test - def test_methodcall3(a, b): - return constant3(a, b=1.0) + b - - @make_test - def test_device_constant(a): - return a + torch.ones(1, device=torch.device("cpu")) - - @make_test - def test_tuple1(a, b): - args = (a, b) - return sub(*args) - - @make_test - def test_tuple2(a, b): - args = [a, b] - return sub(*args) - - @make_test - def test_is_in_onnx_export(x, y): - if torch.onnx.is_in_onnx_export(): - return x - 1 - else: - return y + 1 - - @make_test - def test_is_fx_tracing(x, y): - if torch.fx._symbolic_trace.is_fx_tracing(): - return x - 1 - else: - return y + 1 - - @make_test - def test_listarg1(a, b): - return torch.cat([a, b]) - - @make_test - def test_listarg2(a, b): - return torch.cat((a, b), dim=0) - - @make_test - def test_listarg3(a, b): - kwargs = {"tensors": (a, b), "dim": 0} - return torch.cat(**kwargs) - - @make_test - def test_listarg4(a, b): - return torch.cat(tensors=[a, b], dim=0) - - @make_test - def test_listarg5(a, b): - args = [(a, b)] - kwargs = {"dim": 0} - return torch.cat(*args, **kwargs) - - @make_test - def test_slice1(a): - return a[5] - - @make_test - def test_slice2(a): - return a[:5] - - @make_test - def test_slice3(a): - return a[5:] - - @make_test - def test_slice4(a): - return a[2:5] - - @make_test - def test_slice5(a): - return a[::2] - - @make_test - def test_slice6(a): - return torch.unsqueeze(a, 0)[:, 2:] - - @make_test - def test_unpack1(a): - a, b = a[:5], a[5:] - return a - b - - @make_test - def test_unpack2(a): - packed = [a[:5], a[5:]] - a, b = packed - return a - b - - @make_test - def test_unpack3(a): - packed = (a[:5], a[5:]) - a, b = packed - return a - b - - @make_test - def test_fn_with_self_set(a, b): - # avg_pool2d is an odd one with __self__ set - return F.avg_pool2d( - torch.unsqueeze(a, 0) * torch.unsqueeze(b, 1), kernel_size=2, padding=1 - ) - - 
@make_test - def test_return_tuple1(a, b): - return (a - b, b - a, a, b) - - @make_test - def test_globalvar(a, b): - return a - b + d - - @make_test - def test_globalmodule(x): - return e(x) - - @make_test - def test_inline_with_default(a, b, c): - return func_with_default(a, b) * c - - @make_test - def test_inner_function(x): - def fn(x): - return torch.add(x, x) - - return fn(x) - - @make_test - def test_transpose_for_scores(x): - new_x_shape = x.size()[:-1] + (2, 5) - x = x.view(*new_x_shape) - return x.permute(0, 2, 1) - - @make_test - def test_return_tuple2(x): - return (torch.add(x, x), x) - - @make_test - def test_load_global_bool(x): - if flag: - return torch.add(x, x) - else: - return x - - @make_test - def test_len_tensor(x): - z = len(x) - return torch.add(x, z) - - @make_test - def test_len_constant_list(x): - z = len([1, 2, 3]) - return torch.add(x, z) - - @make_test - def test_len_constant_dict(x): - z = len({"foo": "bar"}) - return torch.add(x, z) - - @make_test - def test_dict_copy(x): - z = dict({"foo": x + 1}) - return z - - @make_test - def test_len_constant_misc_iterables(x): - a = len((1, 2, 3)) - b = len("test str") - c = a + b - return torch.add(x, c) - - @make_test - def test_float(x): - y = float(1.2) - y += float("1.2") - return torch.add(x, y) - - @make_test - def test_dtype(x): - if x.dtype == torch.float32: - return x + 1 - - @make_test - def test_device(x): - if not x.is_cuda: - return x + 1 - - @make_test - def test_ndim(x): - if x.ndim == 2 and x.ndimension() == 2 and x.dim() == 2: - return x + 1 - - @make_test - def test_is_sparse(x): - if not x.is_sparse: - return x + 1 - - @requires_static_shapes - @make_test - def test_shape1(x): - if x.shape[0] == 10: - return x + 1 - - @requires_static_shapes - @make_test - def test_shape2(x): - if x.size(1) == 10: - return x + 1 - - @make_test - def test_del(a, b): - c = a + 1 - d = c + 2 - del c, a - return b + d - - @requires_static_shapes - @make_test - def test_chunks1(x): - chunk_size = 5 - assert x.shape[0] % chunk_size == 0 - assert x.shape[0] // chunk_size == 2 - return x[:chunk_size] - x[chunk_size:] - - @make_test - def test_import1(x, y): - import torch - from torch import sub - - return sub(torch.add(x, y), y) - - @make_test - def test_return_dict(x, y): - z = [x + y, y, False] - return {"x": x, "z": z, "a": x, "b": z, "c": x} - - @make_test - def test_return_dict2(x, y): - tmp = {"x": x} - tmp["z"] = [x + y, y] - tmp["y"] = y - tmp["z"].append(False) - return tmp - - @make_test - def test_funcdef_closure(x, y): - x = x + y + 1.0 - - def inner(z): - nonlocal x, y - y = x + z + 20.0 - x = y + z + 10.0 - - inner(2.0) - inner(3.0) - - return x, y - - @make_test - def test_module_constant(x, y): - r = x + y - for i in range(torchdynamo.testing.three): - r = r / y - return r - - @make_test - def test_inline_softmax(x, y): - # This is common in sme huggingface models - return torch.nn.Softmax(dim=-1)(x + y * 2) - - @make_test - def test_dtype_compare(a, b): - if a.dtype == torch.float16: - return a + 10 - if a.dtype == torch.float32: - return a - b * 32 - - @make_test - def test_build_list_unpack(a, b): - it1 = (x + 1 for x in (a, b)) - it2 = (x - 1 for x in (a, b)) - return torch.cat([*it1, *it2], dim=-1) - - @make_test - def test_tensor_len(a, b): - return a + b + len(a) + b.__len__() - - @make_test - def test_pop(a, b): - ll = [a, b] - ll.append(a + 1) - ll.extend( - [ - b + 2, - a + b, - ] - ) - ll.pop(-1) - ll.pop(0) - ll.pop() - v1, v2 = ll - return v1 - v2 - - @make_test - def test_list_convert(a, b): - 
ll = [a + 2, b] - ll = tuple(ll) - tmp = b + 3 - ll = list(ll) - v1, v2 = ll - return v1 - v2 + tmp - - @make_test - def test_list_add(a, b): - l1 = (a, b) - l2 = () # being a LOAD_CONST in the bytecode - l3 = l1 + l2 - return l3[0] + l3[1] - - @make_test - def test_startswith(a, b): - x = a + b - if "foobar".startswith("foo") and "test" in constant3.__module__: - x = x + 1 - return x - - @make_test - def test_dict_ops(a, b): - tmp = {"a": a + 1, "b": b + 2} - v = tmp.pop("b") + tmp.get("a") + tmp.get("missing", 3) + tmp.pop("missing", 4) - tmp.update({"d": 3}) - tmp["c"] = v + tmp["d"] - if "c" in tmp and "missing" not in tmp: - return tmp["c"] - tmp["a"] + len(tmp) - - def test_dict_param_keys(self): - a_param = torch.nn.Parameter(torch.ones([4, 4])) - - def fn(a): - tmp = {"a": a, a_param: 3} - return tmp["a"] + tmp[a_param] - - test = make_test(fn) - test(self) - - def test_default_dict(self): - dd = collections.defaultdict(dict) - param = torch.nn.Parameter(torch.ones([2, 2])) - - def fn(x): - dd["a"] = x + 1 - dd[param] = 123 - dd["c"] = x * 2 - return dd["b"], dd - - test = make_test(fn) - test(self) - - @make_test - def test_min_max(a, b): - c = a + b - a = a.sum() - b = b.sum() - a = min(max(a, 0), 1) - b = max(0, min(1, b)) - return max(a, b) - min(a, b) + c - - @make_test - def test_map_sum(a, b, c, d): - return sum(map(lambda x: x + 1, [a, b, c, d])) - - @make_test - def test_reduce(a, b, c, d): - return functools.reduce(operator.add, [a, b, c, d]) - - @make_test - def test_tuple_contains(a, b): - v1 = "a" - v2 = "b" - v3 = "c" - vals1 = (v1, v2, v3) - vals2 = ("d", "e", "f") - if "a" in vals1 and "b" not in vals2: - return a + b - return a - b - - @make_test - def test_tuple_iadd(a, b): - output = (a, b) - output += (a + b, a - b) - return output - - @make_test - def test_unpack_ex1(x): - output = (x, x + 1, x + 2, x + 3) - a, b, *cd = output - return a - b / cd[0] - - @make_test - def test_unpack_ex2(x): - output = (x, x + 1, x + 2, x + 3) - *ab, c, d = output - return c - d / ab[0] - - @make_test - def test_unpack_ex3(x): - output = (x, x + 1, x + 2, x + 3) - a, *bc, d = output - return a - d / bc[0] - - @make_test - def test_const_tuple_add1(x): - output = (x, x + 1, x + 2, x + 3) - output = () + output + () - return output[2] + output[3] - - @make_test - def test_const_tuple_add2(x): - output = (x, x + 1, x + 2, x + 3) - output = (None,) + output + (None,) - return output[2] + output[3] - - @make_test - def test_list_truth(a, b): - tmp = [1, 2, 3] - if tmp: - return a + b - else: - return a - b - - @make_test - def test_list_reversed(a, b): - tmp = [a + 1, a + 2, a + 3] - return a + b + next(iter(reversed(tmp))) - - @make_test - def test_list_clear(a, b): - tmp = [a + 1, a + 2] - tmp.clear() - tmp.append(a + b) - return tmp - - @make_test - def test_islice_chain(a, b): - tmp1 = [a + 1, a + 2] - tmp2 = [a + 3, a + 4] - a, b = list(itertools.islice(itertools.chain(tmp1, tmp2), 1, 3)) - c = next(itertools.islice(tmp1, 1, None)) - return a - b / c - - @make_test - def test_is_quantized(a, b): - if not a.is_quantized: - return a + b - - @make_test - def test_fstrings1(a, b): - x = 1.229 - tmp = f"{x:.2f} bar" - if tmp.startswith("1.23"): - return a + b - - @requires_static_shapes - @make_test - def test_fstrings2(x): - tmp = f"{x.shape[0]} bar" - if tmp.startswith("10"): - return x + 1 - - @make_test - def test_fstrings3(x): - tmp = f"{x.__class__.__name__} foo" - if tmp.startswith("Tensor"): - return x + 1 - - @requires_static_shapes - @make_test - def 
test_tensor_new_with_size(x): - y = torch.rand(5, 8) - z = x.new(y.size()) - assert z.size() == y.size() - - @requires_static_shapes - @make_test - def test_tensor_new_with_shape(x): - y = torch.rand(5, 8) - z = x.new(y.shape) - assert z.size() == y.size() - - @make_test - def test_jit_annotate(x): - y = torch.jit.annotate(Any, x + 1) - return y + 2 - - @requires_static_shapes - @make_test - def test_is_contiguous_memory_format(tensor): - if torch.jit.is_scripting(): - return None - elif tensor.is_contiguous(memory_format=torch.contiguous_format): - return tensor + 1 - - @make_test - def test_list_slice_assignment(x): - m = [1, 2, 3, 4] - m[1:] = [6] * (len(m) - 1) - return x + 1 - - # # This is to test the new syntax for pattern matching - # # ("match ... case ...") added on python 3.10. - # # Uncomment these test cases if you run on 3.10+ - # @make_test - # def test_match_sequence(a): - # point = (5, 8) - # match point: - # case (0, 0): - # return a - # case (0, y): - # return a - y - # case (x, 0): - # return a + x - # case (x, y): - # return a + x - y - - # @make_test - # def test_match_mapping_and_match_keys(x): - # param = {"a": 0.5} - # match param: - # case {"a": param}: - # return x * param - # case {"b": param}: - # return x / param - - -if __name__ == "__main__": - from torchdynamo.test_case import run_tests - - run_tests() diff --git a/test/dynamo/test_global.py b/test/dynamo/test_global.py deleted file mode 100644 index e9e571c90b..0000000000 --- a/test/dynamo/test_global.py +++ /dev/null @@ -1,233 +0,0 @@ -# Owner(s): ["module: dynamo"] -import torch - -import torchdynamo.test_case -import torchdynamo.testing -from torchdynamo.testing import same - -try: - from . import test_global_declaration -except ImportError: - import test_global_declaration - - -class Pair(object): # noqa: B903 - def __init__(self, x, y): - self.x = x - self.y = y - - -def Foo(): - return Pair(1, 1) - - -g_counter = 1 -g_list = [0, 1, 2] -g_dict = {"a": 0, "b": 1} -g_object = Foo() -g_tensor = torch.zeros(10) - - -_name: int = 0 - - -def fresh_name() -> str: - """create a new unique name for a variable: v0, v1, v2""" - global _name - r = f"v{_name}" - _name += 1 - return r - - -def reset_name(): - global _name - _name = 0 - - -class TestGlobals(torchdynamo.test_case.TestCase): - def test_store_global_1(self): - def fn(x): - global g_counter - val = x + g_counter - g_counter += 1 - return val - - x = torch.randn(10) - cnts = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize(cnts)(fn) - res1 = opt_fn(x) - res2 = fn(x) - self.assertTrue(same(res2 - res1, torch.ones(10))) - - def test_store_global_2(self): - def fn(x): - global g_counter - val = x + g_counter - g_counter += 1 - g_counter += 1 - return val - - x = torch.randn(10) - cnts = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize(cnts)(fn) - res1 = opt_fn(x) - """Wrap the second call with torchdynamo as well""" - opt_fn = torchdynamo.optimize(cnts)(fn) - res2 = opt_fn(x) - self.assertTrue(same(res2 - res1, 2 * torch.ones(10))) - - def test_store_global_new(self): - def fn(x): - # Test create a new global - global g_counter_new - g_counter_new = x + 1 - return x + g_counter_new - - x = torch.randn(10) - cnts = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize(cnts)(fn) - res1 = opt_fn(x) - self.assertTrue(same(res1, x + x + 1)) - - def test_store_global_list(self): - def fn(x): - global g_list - val = x + g_list[1] - """ - Strictly speaking, we are not testing STORE_GLOBAL - here, since 
STORE_SUBSCR is actually used to store. - """ - g_list[1] += 1 - return val - - x = torch.randn(10) - cnts = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize(cnts)(fn) - res1 = opt_fn(x) - res2 = fn(x) - self.assertTrue(same(res2 - res1, torch.ones(10))) - - def test_store_global_list_2(self): - def fn(x): - global g_list - val = x + g_list[1] - g_list = [x + 1 for x in g_list] - return val - - x = torch.randn(10) - cnts = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize(cnts)(fn) - res1 = opt_fn(x) - res2 = fn(x) - self.assertTrue(same(res2 - res1, torch.ones(10))) - - def test_store_global_dict(self): - def fn(x): - global g_dict - val = x + g_dict["b"] - """ - Strictly speaking, we are not testing STORE_GLOBAL - here, since STORE_SUBSCR is actually used to store. - """ - g_dict["b"] += 1 - return val - - x = torch.randn(10) - cnts = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize(cnts)(fn) - res1 = opt_fn(x) - res2 = fn(x) - self.assertTrue(same(res2 - res1, torch.ones(10))) - - def test_store_global_dict_2(self): - def fn(x): - global g_dict - g_dict = {key: value + 1 for key, value in g_dict.items()} - val = x + g_dict["b"] - return val - - x = torch.randn(10) - cnts = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize(cnts)(fn) - res1 = opt_fn(x) - res2 = fn(x) - self.assertTrue(same(res2 - res1, torch.ones(10))) - - def test_store_global_object(self): - def fn(x): - global g_object - val = x + g_object.y - g_object.y += 1 - return val - - x = torch.randn(10) - cnts = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize(cnts)(fn) - res1 = opt_fn(x) - res2 = fn(x) - self.assertTrue(same(res2 - res1, torch.ones(10))) - - def test_store_global_cross_file(self): - def fn(x): - val = x + test_global_declaration.g_tensor_export - test_global_declaration.g_tensor_export = ( - test_global_declaration.g_tensor_export + 1 - ) - return val - - x = torch.randn(10) - cnts = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize(cnts)(fn) - res1 = opt_fn(x) - res2 = fn(x) - self.assertTrue(same(res2 - res1, torch.ones(10))) - - def test_store_global_inline_1(self): - # Borrowed from test_python_autograd.py - class Variable: - def __init__(self, value: torch.Tensor, name: str = None): - self.value = value - self.name = name or fresh_name() - - def fn(a, b): - a = Variable(a) - b = Variable(b) - return a.value + b.value, a.name + b.name - - a = torch.randn(10) - b = torch.randn(10) - cnts = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize(cnts)(fn) - v0, s0 = opt_fn(a, b) - self.assertEqual(s0, "v0v1") - reset_name() - - def test_store_global_inline_2(self): - # Borrowed from test_python_autograd.py - class Variable: - def __init__(self, value: torch.Tensor, name: str = None): - self.value = value - self.name = name or fresh_name() - - @staticmethod - def constant(value: torch.Tensor, name: str = None): - return Variable(value, name) - - def fn(a, b): - a = Variable.constant(a) - b = Variable.constant(b) - return a.value + b.value, a.name + b.name - - a = torch.randn(10) - b = torch.randn(10) - cnts = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize(cnts)(fn) - v0, s0 = opt_fn(a, b) - self.assertEqual(s0, "v0v1") - reset_name() - - -if __name__ == "__main__": - from torchdynamo.test_case import run_tests - - run_tests() diff --git a/test/dynamo/test_global_declaration.py b/test/dynamo/test_global_declaration.py deleted file mode 100644 index 
95995ca80a..0000000000 --- a/test/dynamo/test_global_declaration.py +++ /dev/null @@ -1,4 +0,0 @@ -# Owner(s): ["module: dynamo"] -import torch - -g_tensor_export = torch.ones(10) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py deleted file mode 100644 index e0ad25efe5..0000000000 --- a/test/dynamo/test_misc.py +++ /dev/null @@ -1,2740 +0,0 @@ -# Owner(s): ["module: dynamo"] -import collections -import copy -import dataclasses -import dis -import enum -import logging -import math -import os -import sys -import typing -import unittest -import weakref -from unittest.mock import patch - -import numpy as np -import torch -import torch.onnx.operators -from torch.testing._internal.jit_utils import JitTestCase - -import torchdynamo.test_case -import torchdynamo.testing -from torchdynamo import bytecode_transformation -from torchdynamo import graph_break -from torchdynamo.testing import CompileCounter -from torchdynamo.testing import requires_static_shapes -from torchdynamo.testing import same -from torchdynamo.testing import unsupported - -mytuple = collections.namedtuple("mytuple", ["a", "b", "ab"]) - - -def my_custom_function(x): - return x + 1 - - -class MiscTests(torchdynamo.test_case.TestCase): - def test_boolarg(self): - def boolarg(aa, bb, flag): - if flag: - return aa - bb - else: - return bb - aa - - a = torch.randn(10, 10) - b = torch.randn(10, 10) - correct1 = boolarg(a, b, True) - correct2 = boolarg(a, b, False) - correct3 = boolarg(a, b, None) - counter = CompileCounter() - opt_boolarg = torchdynamo.optimize_assert(counter)(boolarg) - val1 = opt_boolarg(a, b, True) - val2 = opt_boolarg(a, b, False) - val3 = opt_boolarg(a, b, None) - val4 = opt_boolarg(a, b, True) - self.assertTrue(same(val1, correct1)) - self.assertTrue(same(val2, correct2)) - self.assertTrue(same(val3, correct3)) - self.assertTrue(same(val4, correct1)) - self.assertEqual(counter.frame_count, 3) - - def test_callpacked(self): - def call_packed(args): - a, b, c = args - return a - b * c - - counter = CompileCounter() - a = torch.randn(10, 10) - b = torch.randn(10, 10) - c = torch.randn(10, 10) - correct = call_packed([a, b, c]) - opt_call_packed = torchdynamo.optimize_assert(counter)(call_packed) - val1 = opt_call_packed([a, b, c]) - val2 = opt_call_packed((a, b, c)) - val3 = opt_call_packed([a, b, c]) - val4 = opt_call_packed((a, b, c)) - self.assertTrue(same(val1, correct)) - self.assertTrue(same(val2, correct)) - self.assertTrue(same(val3, correct)) - self.assertTrue(same(val4, correct)) - self.assertEqual(counter.frame_count, 2) - - def test_raises(self): - def fn(a, b, c, cls): - x = a + b - c * 10 - raise cls(str(x)) - - counter = CompileCounter() - a = torch.randn(10, 10) - b = torch.randn(10, 10) - c = torch.randn(10, 10) - opt_fn = torchdynamo.optimize(counter)(fn) - self.assertRaises(AssertionError, lambda: opt_fn(a, b, c, AssertionError)) - self.assertEqual(counter.frame_count, 1) - self.assertEqual(counter.op_count, 3) - - def test_inplace(self): - def inplace1(a, b): - o = torch.empty((10, 10)) - o.copy_(a) - o -= b - return o - - torchdynamo.testing.standard_test(self, inplace1, 2, expected_ops=3) - - def test_unpack4(self): - def unpack4(a, b): - a = a[:5, :] - b = b[:5, :] - x, y = a.size() - o = torch.empty((x, y)) - o.copy_(a / b) - return o - - torchdynamo.testing.standard_test( - self, unpack4, 2, expected_ops=5, expected_ops_dynamic=8 - ) - - def test_unpack5(self): - def unpack5(a, b): - a = a[:5, :] - b = b[:5, :] - x, y = a.shape - o = torch.empty((x, y)) - o.copy_(a / b) - 
return o - - torchdynamo.testing.standard_test( - self, unpack5, 2, expected_ops=5, expected_ops_dynamic=8 - ) - - def test_matmul1(self): - def matmul_op1(a, b): - return a @ b - - # TODO(jansel): FX doesn't support this, should add upstream support - torchdynamo.testing.standard_test(self, matmul_op1, 2, expected_ops=1) - - def test_builtin_isinstance(self): - def fn(x): - t = torch.arange(1, 3) - a = isinstance(x, torch.Tensor) - b = isinstance(t, torch.Tensor) - c = isinstance(x, int) - d = isinstance(3, int) - e = isinstance([1, 2, 3], list) - f = isinstance({"foo": 1, "bar": 2}, dict) - res = [a, b, c, d, e, f] - # Can't run yet due to other unimplemented instructions - # res += [isinstance(torch.nn.LazyLinear(2, 3), torch.nn.Linear)] - return res - - torchdynamo.testing.standard_test(self, fn, 1, expected_ops=1) - - def test_fold(self): - def fn(a): - return a + math.sqrt(63) - - torchdynamo.testing.standard_test(self, fn, 1, expected_ops=1) - - def test_shape_unpack(self): - def fn(x): - a, b = x.size() - return x * b - - i = torch.randn(5, 10) - r1 = fn(i) - opt_fn = torchdynamo.optimize("eager")(fn) - r2 = opt_fn(i) - self.assertTrue(same(r1, r2)) - - def test_empty_list(self): - def fn(x, ll): - if len(ll) == 0 and not ll and ll is not None: - return x + 1 - - i = torch.randn(5, 10) - r1 = fn(i, []) - opt_fn = torchdynamo.optimize("eager")(fn) - r2 = opt_fn(i, []) - r3 = opt_fn(i, tuple()) - self.assertTrue(same(r1, r2)) - self.assertTrue(same(r1, r3)) - - def test_config_obj(self): - class Cfg: - def __init__(self): - self.val = 0.5 - self.count = 3 - - def fn(x, cfg): - for i in range(cfg.count): - x = x + cfg.val - return x - - cfg1 = Cfg() - cfg1.val = 1.0 - cfg2 = Cfg() - v = torch.zeros(1) - cnts = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize(cnts)(fn) - v = opt_fn(v, cfg1) # 3 - v = opt_fn(v, cfg2) # 4.5 - cfg2.count = 1 - v = opt_fn(v, cfg2) # 5 - cfg2.val = 2.0 - v = opt_fn(v, cfg2) # 7 - self.assertEqual(v[0], 7) - self.assertEqual(cnts.op_count, 8) - - def test_config_getattr_default(self): - class Cfg: - def __init__(self): - self.val = 0.5 - self.count = 10 - - def fn(x, cfg): - if getattr(cfg, "just_add_7", False): - return x + 7 - for i in range(cfg.count): - x = x + cfg.val - return x - - cfg1 = Cfg() - v = torch.zeros(1) - cnts = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize(cnts)(fn) - self.assertEqual(opt_fn(v, cfg1)[0], 5) - self.assertEqual(opt_fn(v, cfg1)[0], 5) - cfg1.just_add_7 = True - self.assertEqual(opt_fn(v, cfg1)[0], 7) - self.assertEqual(opt_fn(v, cfg1)[0], 7) - cfg1.just_add_7 = False - self.assertEqual(opt_fn(v, cfg1)[0], 5) - self.assertEqual(opt_fn(v, cfg1)[0], 5) - self.assertEqual(cnts.frame_count, 3) - - def test_size_input(self): - def fn(x, s): - a, b = s - return x + (a - b) - - v = torch.zeros(10, 20) - cnts = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize(cnts)(fn) - self.assertEqual(opt_fn(v, v.size())[0, 0], -10) - self.assertEqual(opt_fn(v, (10, 20))[0, 0], -10) - self.assertEqual(opt_fn(v, [10, 20])[0, 0], -10) - self.assertEqual(cnts.op_count, 2) - - def test_cell_output1(self): - out = None - - def fn(a, b): - nonlocal out - out = a + b * 10 - - v = torch.Tensor([100]) - cnts = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize(cnts)(fn) - self.assertIsNone(opt_fn(v, v)) - self.assertEqual(out[0], 1100) - self.assertEqual(cnts.op_count, 2) - - def test_cell_output2(self): - out = None - - def fn(a, b): - nonlocal out - c = unsupported(a, b) - 
out = a + b * 10 + c - - v = torch.Tensor([100]) - cnts = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize(cnts)(fn) - self.assertIsNone(opt_fn(v, v)) - self.assertEqual(out[0], 1200) - self.assertEqual(cnts.op_count, 3) - - def test_return_nested_function(self): - out = None - - def fn(a, b): - nonlocal out - c = a + b - d = a + 1.0 - - def fn2(f: int = 7, g: float = 9.0): - nonlocal out - out = a + b * 10 - return c * f - d * g - - return fn2 - - v1 = torch.Tensor([100]) - v2 = torch.Tensor([200]) - cnts = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize(cnts)(fn) - opt_fn_ret = torchdynamo.optimize(cnts)(opt_fn(v1, v2)) - self.assertEqual(opt_fn_ret(1.5)[0], -459) - self.assertEqual(out[0], 2100) - self.assertEqual(cnts.frame_count, 2) - self.assertEqual(cnts.op_count, 7) - - def test_tensor_dict1(self): - def fn(inputs): - return inputs["a"] - inputs["b"] * 1.5 - - v1 = torch.Tensor([100]) - v2 = torch.Tensor([200]) - cnts = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize(cnts)(fn) - self.assertEqual(opt_fn({"a": v1, "b": v2})[0], -200) - self.assertEqual(cnts.frame_count, 1) - self.assertEqual(cnts.op_count, 2) - - def test_tensor_dict2(self): - def fn1(inputs): - total = torch.zeros(1) - for k, v in inputs.items(): - total += v - return total - - def fn2(inputs): - total = torch.zeros(1) - for v in inputs.values(): - total += v - return total - - def fn3(inputs): - total = torch.zeros(1) - for k in inputs.keys(): - total += inputs[k] - return total - - v1 = torch.Tensor([100]) - v2 = torch.Tensor([200]) - cnts = torchdynamo.testing.CompileCounter() - opt_fn1 = torchdynamo.optimize(cnts)(fn1) - opt_fn2 = torchdynamo.optimize(cnts)(fn2) - opt_fn3 = torchdynamo.optimize(cnts)(fn3) - self.assertEqual(opt_fn1({"a": v1, "b": v2})[0], 300) - self.assertEqual(opt_fn2({"a": v1, "b": v2})[0], 300) - self.assertEqual(opt_fn3({"a": v1, "b": v2})[0], 300) - self.assertEqual(cnts.frame_count, 3) - self.assertEqual(cnts.op_count, 9) - - def test_dictcomp(self): - def fn1(inputs): - return {k: v + 1 for k, v in inputs.items()} - - v1 = torch.Tensor([100]) - v2 = torch.Tensor([200]) - cnts = torchdynamo.testing.CompileCounter() - opt_fn1 = torchdynamo.optimize(cnts)(fn1) - self.assertEqual(opt_fn1({"a": v1, "b": v2})["a"], 101) - self.assertEqual(opt_fn1({"a": v1, "b": v2})["b"], 201) - self.assertEqual(cnts.frame_count, 1) - self.assertEqual(cnts.op_count, 2) - - def test_listcomp(self): - def fn2(inputs): - return torch.sum(torch.cat([v + 1 for k, v in inputs.items()], 0)) - - v1 = torch.Tensor([100]) - v2 = torch.Tensor([200]) - cnts = torchdynamo.testing.CompileCounter() - opt_fn2 = torchdynamo.optimize(cnts)(fn2) - self.assertEqual(opt_fn2({"a": v1, "b": v2}), 302) - self.assertEqual(cnts.frame_count, 1) - self.assertEqual(cnts.op_count, 4) - - def test_is_floating_point(self): - def fn(a, b): - x = a + 1.0 - if torch.is_floating_point(b): - x = x + b - return x + 2.0 - - return torchdynamo.testing.standard_test(self, fn=fn, nargs=2, expected_ops=3) - - def test_is_floating_point2(self): - def fn(a, b): - x = a + 1.0 - if b.is_floating_point(): - x = x + b - return x + 2.0 - - return torchdynamo.testing.standard_test(self, fn=fn, nargs=2, expected_ops=3) - - def test_is_tensor(self): - def fn(a, b): - x = a + 1.0 - if torch.is_tensor(b): - x = x + b - return x + 2.0 - - return torchdynamo.testing.standard_test(self, fn=fn, nargs=2, expected_ops=3) - - def test_numel(self): - def fn(a): - return a + a.numel() + torch.numel(a) - - return 
torchdynamo.testing.standard_test( - self, fn=fn, nargs=1, expected_ops=2, expected_ops_dynamic=4 - ) - - def test_pair(self): - def fn(a): - return ( - torch.zeros(torch.nn.modules.utils._pair(a.size())) - + a - + torch.ones(torch.nn.modules.utils._ntuple(3)(3)).sum() - ) - - return torchdynamo.testing.standard_test( - self, fn=fn, nargs=1, expected_ops=5, expected_ops_dynamic=8 - ) - - @patch.object(torchdynamo.config, "capture_scalar_outputs", True) - def test_tensor_item_capture(self): - def fn(a, b): - return (a + b).sum().item() - - v1 = torch.randn((10, 10)) - v2 = torch.randn((10, 10)) - correct = fn(v1, v2) - cnts = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize((cnts))(fn) - self.assertEqual(opt_fn(v1, v2), correct) - self.assertEqual(cnts.frame_count, 1) - self.assertEqual(cnts.op_count, 3) - - @patch.object(torchdynamo.config, "capture_scalar_outputs", False) - def test_tensor_item_no_capture(self): - def fn(a, b): - return (a + b).sum().item() - - v1 = torch.randn((10, 10)) - v2 = torch.randn((10, 10)) - correct = fn(v1, v2) - cnts = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize((cnts))(fn) - self.assertEqual(opt_fn(v1, v2), correct) - self.assertEqual(cnts.frame_count, 1) - self.assertEqual(cnts.op_count, 2) - - def test_namedtuple1(self): - def fn(a, b): - tmp = mytuple(a, b, a + b) - return mytuple(tmp.a, tmp[1], tmp.ab + b) - - v1 = torch.Tensor([10]) - v2 = torch.Tensor([20]) - cnts = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize(cnts)(fn) - self.assertEqual(opt_fn(v1, v2).ab, 50) - self.assertEqual(cnts.frame_count, 1) - self.assertEqual(cnts.op_count, 2) - - def test_namedtuple2(self): - def fn(packed): - a, b, c = packed - if hasattr(packed, "b"): - b = packed.b + 1 - c = packed[2] - return a + b + c - - v1 = torch.Tensor([1]) - v2 = torch.Tensor([2]) - v3 = torch.Tensor([3]) - cnts = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize(cnts)(fn) - self.assertEqual(opt_fn(mytuple(v1, v2, v3))[0], 7) - self.assertEqual(cnts.frame_count, 1) - self.assertEqual(cnts.op_count, 3) - - def test_range_input(self): - def fn(a, rng): - x = a - for i in rng: - x = x + i - return x - - def fn1(a): - return fn(a, rng=range(3)) - - return torchdynamo.testing.standard_test(self, fn=fn1, nargs=1, expected_ops=3) - - def test_no_grad(self): - def fn1(a, b): - x = a + 1 - # redundant no_grad should get ignored - with torch.no_grad(): - x = x + b - x = x + 2 - return x - - def fn2(a, b): - x = a + 1 - with torch.set_grad_enabled(False): - x = x + b - x = x + 2 - return x - - def fn3(a, b): - x = a + 1 - with torch.enable_grad(): - x = x + b - x = x + 2 - return x - - def fn4(a, b): - x = a + 1 - with torch.set_grad_enabled(True): - if torch.is_grad_enabled(): - x = x + b - x = x + 2 - return x - - with torch.no_grad(): - torchdynamo.testing.standard_test(self, fn=fn1, nargs=2, expected_ops=5) - torchdynamo.testing.standard_test(self, fn=fn2, nargs=2, expected_ops=5) - torchdynamo.testing.standard_test(self, fn=fn3, nargs=2, expected_ops=5) - torchdynamo.testing.standard_test(self, fn=fn4, nargs=2, expected_ops=5) - with torch.enable_grad(): - torchdynamo.testing.standard_test(self, fn=fn1, nargs=2, expected_ops=5) - torchdynamo.testing.standard_test(self, fn=fn2, nargs=2, expected_ops=5) - torchdynamo.testing.standard_test(self, fn=fn3, nargs=2, expected_ops=5) - torchdynamo.testing.standard_test(self, fn=fn4, nargs=2, expected_ops=5) - - def test_grad_mode_guard(self): - def fn(a, b): - prev_grad = 
torch.is_grad_enabled() - torch.set_grad_enabled(False) - a = a + 1 - a.tolist() # graph break - ret = a + b - torch.set_grad_enabled(prev_grad) - return ret - - a = torch.randn([3, 4]) - b = torch.randn([3, 4]) - cnts = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize(cnts)(fn) - for _ in range(10): - opt_fn(a, b) - self.assertEqual(cnts.frame_count, 2) - - def test_build_tuple_unpack(self): - def fn1(a, b, c): - return a - b / c - - def fn2(a, b, c): - tmp1 = (a,) - tmp2 = (b, c) - args = (*tmp1, *tmp2) - return fn1(*args) - - def fn3(a, *args): - return fn1(a, *args) - - torchdynamo.testing.standard_test(self, fn=fn2, nargs=3, expected_ops=2) - torchdynamo.testing.standard_test(self, fn=fn3, nargs=3, expected_ops=2) - - def test_list_mul(self): - def fn(count): - head_mask = count * [None] * count - return head_mask - - cnts = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize(cnts)(fn) - self.assertEqual(opt_fn(2), [None] * 4) - self.assertEqual(cnts.frame_count, 0) - self.assertEqual(cnts.op_count, 0) - - def test_user_getattr1(self): - class MyConfig(dict): - def __getattr__(self, name): - return self[name] - - def fn(cfg, x, y): - return x + y + cfg.offset - - x = torch.randn(10) - cfg = MyConfig(offset=5) - cnts = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize(cnts)(fn) - self.assertTrue(same(opt_fn(cfg, x, x), 2 * x + 5)) - self.assertEqual(cnts.frame_count, 1) - self.assertEqual(cnts.op_count, 2) - - def test_user_getattr2(self): - class MyConfig: - defined_on_class = 1 - - def __init__(self): - self.defined_on_object = 2 - - def __getattr__(self, name): - return 3 - - def fn(cfg, x): - return x + cfg.defined_on_class - cfg.defined_on_object + cfg.not_defined - - x = torch.randn(10) - cfg = MyConfig() - cnts = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize(cnts)(fn) - self.assertTrue(same(opt_fn(cfg, x), x + 1 - 2 + 3)) - self.assertEqual(cnts.frame_count, 1) - self.assertEqual(cnts.op_count, 3) - - def test_user_property(self): - class MyConfig: - @property - def prop5(self): - return 5 - - def fn(cfg, x, y): - return x + y + cfg.prop5 - - x = torch.randn(10) - cfg = MyConfig() - cnts = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize(cnts)(fn) - self.assertTrue(same(opt_fn(cfg, x, x), 2 * x + 5)) - self.assertEqual(cnts.frame_count, 1) - self.assertEqual(cnts.op_count, 2) - - def test_dataclass_fields(self): - @dataclasses.dataclass - class MyDataClass: - a: torch.Tensor - b: torch.Tensor = None - c: torch.Tensor = None - d: torch.Tensor = None - e: torch.Tensor = None - - def fn(obj): - class_fields = dataclasses.fields(obj) - assert len(class_fields) - assert all(field.default is None for field in class_fields[1:]) - other_fields_are_none = all( - getattr(obj, field.name) is None for field in class_fields[1:] - ) - assert not other_fields_are_none - - total = getattr(obj, class_fields[0].name) - for field in class_fields[1:]: - v = getattr(obj, field.name) - if v is not None: - total += v - - return total - - obj1 = MyDataClass(torch.randn(10), torch.randn(10), torch.randn(10)) - obj2 = MyDataClass(torch.randn(10), e=torch.randn(10)) - correct1 = fn(obj1) - correct2 = fn(obj2) - - cnts = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize(cnts)(fn) - self.assertTrue(same(opt_fn(obj1), correct1)) - self.assertEqual(cnts.frame_count, 1) - self.assertEqual(cnts.op_count, 2) - - torchdynamo.reset() - cnts = torchdynamo.testing.CompileCounter() - opt_fn = 
torchdynamo.optimize(cnts)(fn) - self.assertTrue(same(opt_fn(obj2), correct2)) - self.assertEqual(cnts.frame_count, 1) - self.assertEqual(cnts.op_count, 1) - - @requires_static_shapes - def test_tensor_build_list_unpack(self): - def fn(x): - # seen in fastNLP_Bert - return torch.cat([*x], dim=-1) - - val = torch.randn([1, 1, 473, 768]) - correct = fn(val) - cnts = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize(cnts)(fn) - self.assertTrue(same(opt_fn(val), correct)) - self.assertEqual(cnts.frame_count, 1) - self.assertEqual(cnts.op_count, 2) - - def test_numpy_int_constant(self): - def fn(x, a, b): - return x + (a % b) - - args = [torch.randn(10), 4096, np.int64(8)] - correct = fn(*args) - cnts = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize(cnts)(fn) - self.assertTrue(same(opt_fn(*args), correct)) - self.assertTrue(same(opt_fn(*args), correct)) - self.assertEqual(cnts.frame_count, 1) - self.assertEqual(cnts.op_count, 2) - - def test_dict_mutation_side_effect(self): - def fn(d): - d["c"] = d["a"] + d.pop("b") - return d - - args1 = {"a": torch.randn(10), "b": torch.randn(10)} - args2 = dict(args1) - assert fn(args1) is args1 - cnts = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize(cnts)(fn) - self.assertIs(opt_fn(args2), args2) - self.assertTrue(same(args1, args2)) - self.assertEqual(cnts.frame_count, 1) - self.assertEqual(cnts.op_count, 1) - - def test_module_deepcopy(self): - m1 = torch.nn.Sequential( - torch.nn.Linear(10, 10), - torch.nn.ReLU(), - torch.nn.Linear(10, 10), - torch.nn.ReLU(), - ) - m2 = torch.nn.Sequential( - torch.nn.Linear(10, 10), - torch.nn.ReLU(), - torch.nn.Linear(10, 10), - torch.nn.ReLU(), - ) - - def fn(m, x): - m_copy = copy.deepcopy(m) - return m_copy(x) - - v = torch.randn(10) - correct1 = fn(m1, v) - correct2 = fn(m2, v) - cnts = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize(cnts)(fn) - for _ in range(10): - self.assertTrue(same(opt_fn(m1, v), correct1)) - for _ in range(10): - self.assertTrue(same(opt_fn(m2, v), correct2)) - self.assertEqual(cnts.frame_count, 1) - self.assertEqual(cnts.op_count, 4) - - def test_type_copy(self): - def fn(seq): - a, b = seq - return type(seq)([a + 1, b + 2, a + b]) - - args1 = [torch.randn(10), torch.randn(10)] - args2 = (torch.randn(10), torch.randn(10)) - correct1 = fn(args1) - correct2 = fn(args2) - cnts = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize(cnts)(fn) - self.assertTrue(same(opt_fn(args1), correct1)) - self.assertTrue(same(opt_fn(args2), correct2)) - self.assertIsInstance(opt_fn(args1), list) - self.assertIsInstance(opt_fn(args2), tuple) - self.assertEqual(cnts.frame_count, 2) - self.assertEqual(cnts.op_count, 6) - - def test_setattr_mutation1(self): - class MyObj: # noqa: B903 - def __init__(self, a, b): - self.a = a - self.b = b - - def fn(obj): - obj.c = obj.a * obj.b + 1 - obj.b = obj.a * obj.c + 2 - obj.a = obj.b * obj.c + 3 - obj.c = obj.a * obj.b + 4 - obj.b = obj.a * obj.c + 5 - obj.a = obj.b * obj.c + 6 - return obj - - x1 = torch.randn(10) - x2 = torch.randn(10) - obj1 = MyObj(x1, x2) - obj2 = MyObj(x1, x2) - fn(obj2) - cnts = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize(cnts)(fn) - self.assertIs(opt_fn(obj1), obj1) - self.assertTrue(same(obj1.a, obj2.a)) - self.assertTrue(same(obj1.b, obj2.b)) - self.assertTrue(same(obj1.c, obj2.c)) - self.assertEqual(cnts.frame_count, 1) - self.assertEqual(cnts.op_count, 12) - - def test_setattr_mutation2(self): - class MyObj: - def 
__init__(self, x): - self.a = x + 1 - self.b = x + 2 - - def fn(x): - x = x / 3.0 - obj = MyObj(x) - obj.c = obj.a * obj.b + 1 - obj.b = obj.a * obj.c + 2 - obj.a = obj.b * obj.c + 3 - return obj - - x1 = torch.randn(10) - obj2 = fn(x1) - - cnts = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize(cnts)(fn) - obj1 = opt_fn(x1) - self.assertTrue(same(obj1.a, obj2.a)) - self.assertTrue(same(obj1.b, obj2.b)) - self.assertTrue(same(obj1.c, obj2.c)) - self.assertEqual(cnts.frame_count, 1) - self.assertEqual(cnts.op_count, 9) - - def test_setattr_mutation3(self): - # TODO(jansel): dead code eliminate the object creation - class MyObj: - def __init__(self, x): - super().__init__() - self.a = x + 1 - self.b = x + 2 - - def fn(x): - x = x / 3.0 - obj = MyObj(x) - obj.c = obj.a * obj.b + 1 - obj.b = obj.a * obj.c + 2 - obj.a = obj.b * obj.c + 3 - return obj.a, obj.b, obj.c - - x1 = torch.randn(10) - obj2 = fn(x1) - - cnts = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize(cnts)(fn) - obj1 = opt_fn(x1) - self.assertTrue(same(obj1, obj2)) - self.assertEqual(cnts.frame_count, 1) - self.assertEqual(cnts.op_count, 9) - - def test_user_defined_class_name(self): - class MyClassFoo: - pass - - def fn1(a, b, c): - tmp = MyClassFoo() - if tmp.__class__.__name__ == "MyClassFoo": - return a - b / c - - torchdynamo.testing.standard_test(self, fn=fn1, nargs=3) - - def test_manual_seed(self): - def fn(a, b): - x = a + b - torch.manual_seed(9000) - return x + 1 - - torchdynamo.testing.standard_test(self, fn=fn, nargs=2, expected_ops=3) - - def test_usr_cls_staticmethod(self): - class Foo: - @staticmethod - def bar(a, b): - return a + b - - def fn(a, b): - return Foo.bar(a, b) - 1 - - torchdynamo.testing.standard_test(self, fn=fn, nargs=2) - - def test_usr_cls_classmethod(self): - class Foo: - @classmethod - def bar(cls, a, b): - return a + b - - def fn(a, b): - return Foo.bar(a, b) - 1 - - torchdynamo.testing.standard_test(self, fn=fn, nargs=2) - - def test_dunder_methods(self): - class Foo: - def __init__(self, val): - super().__init__() - self.val = val - - def __add__(self, other): - return Foo(self.val + other.val) - - def __mul__(self, other): - return Foo(self.val * other.val) - - def __truediv__(self, other): - return Foo(self.val / other.val) - - def __sub__(self, other): - return Foo(self.val - other.val) - - def fn(a, b, c): - return Foo(a) + Foo(b) * Foo(c) / Foo(a) - Foo(b) - - torchdynamo.testing.standard_test(self, fn=fn, nargs=3, expected_ops=4) - - def test_function_annotation(self): - class Variable: - pass - - def fn(x): - x = x / 3.0 - - def inner(y: typing.List[Variable]): - return x + 1 - - return inner - - x1 = torch.randn(10) - obj2 = fn(x1)([]) - - cnts = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize_assert(cnts)(fn) - opt_fn_inner = torchdynamo.optimize_assert(cnts)(opt_fn(x1)) - obj1 = opt_fn_inner([]) - self.assertTrue(same(obj1, obj2)) - self.assertEqual(cnts.frame_count, 2) - self.assertEqual(cnts.op_count, 2) - - def test_nested_closure(self): - v0 = torch.randn(10) - - def fn1(): - v1 = torch.randn(10) - - def fn2(*args, **kwargs): - assert len(args) == 1 - assert len(kwargs) == 1 - v2 = torch.randn(10) + args[0] + kwargs["b"] - - def fn3(v3=torch.randn(10)): - def fn4(): - return v0 + v1 + v2 + v3 + 1 - - return fn4 - - return fn3 - - return fn2(1, b=2)() - - cnts = torchdynamo.testing.CompileCounter() - opt_fn1 = torchdynamo.optimize_assert(cnts)(fn1) - tmp1 = torchdynamo.optimize_assert(cnts)(opt_fn1()) - tmp2 = 
torchdynamo.optimize_assert(cnts)(opt_fn1()) - self.assertTrue(tmp1().shape, (10,)) - self.assertTrue(same(tmp1(), tmp1())) - self.assertFalse(same(tmp1(), tmp2())) - self.assertEqual(cnts.frame_count, 2) - self.assertEqual(cnts.op_count, 9) - - def test_nested_closure_mutation(self): - def fn1(): - v1 = torch.randn(10) - - def fn2(): - v2 = torch.randn(10) - - def fn3(): - nonlocal v1, v2 - v1 += 1 - v2 += 2 - return v1 + v2 - - return fn3 - - rv = fn2() - rv() - rv() - return rv - - torch.manual_seed(9000) - counter1 = fn1() - result1 = [counter1(), counter1(), counter1()] - - torch.manual_seed(9000) - cnts = torchdynamo.testing.CompileCounter() - opt_fn1 = torchdynamo.optimize_assert(cnts)(fn1) - counter2 = torchdynamo.optimize_assert(cnts)(opt_fn1()) - result2 = [counter2(), counter2(), counter2()] - result1.append(counter1()) - result2.append(counter2()) - - self.assertTrue(same(result1, result2)) - self.assertEqual(cnts.frame_count, 2) - self.assertEqual(cnts.op_count, 11) - - def test_write_to_closures_in_inlining(self): - out = [] - for use_dynamo in [False, True]: - - def make_counter(): - x = torch.randn(10) - - def counter(): - nonlocal x - x = x + 1 - return x - - return counter - - torch.manual_seed(0) - counter = make_counter() - if not use_dynamo: - out.append(counter() + counter()) - else: - cnts = torchdynamo.testing.CompileCounter() - - @torchdynamo.optimize(cnts, nopython=True) - def fn(counter): - return counter() + counter() - - out.append(fn(counter)) - self.assertEqual(cnts.frame_count, 1) - self.assertEqual(cnts.op_count, 3) - self.assertFalse(same(counter() + counter(), out[-1])) - - self.assertTrue(same(out[0], out[1])) - - def test_top_package_import(self): - def fn(x): - import torch.fx - - assert not isinstance(x, torch.fx.Proxy) - return torch.sin(x) - - x = torch.randn(4, 5) - ref = fn(x) - cnts = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize_assert(cnts)(fn) - res = opt_fn(x) - self.assertTrue(same(ref, res)) - - def test_optimize_on_module(self): - class MockModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.relu = torch.nn.ReLU() - - def custom_member(self): - # Just for checking that Dynamo returned mod object can redirect - # to this method - pass - - def forward(self, x): - return self.relu(x) - - cnts1 = torchdynamo.testing.CompileCounter() - mod = MockModule() - optimized_mod = torchdynamo.optimize(cnts1, nopython=True)(mod) - - a = torch.randn(10) - ref = mod(a) - res = optimized_mod(a) - - optimized_mod.custom_member() - - self.assertTrue(same(ref, res)) - - def test_nested_optimize_decorator(self): - cnts2 = torchdynamo.testing.CompileCounter() - cnts3 = torchdynamo.testing.CompileCounter() - - @torchdynamo.run() - def fn1(x): - return torch.sin(x) * 10 - - @torchdynamo.optimize(cnts2, nopython=True) - def fn2(x): - return fn1(x) + 1 - - @torchdynamo.optimize(cnts3, nopython=True) - def fn3(x): - return torch.relu(fn2(x)) - - fn3(torch.randn(4, 5)) - self.assertEqual(cnts2.frame_count, 0) - self.assertEqual(cnts3.frame_count, 1) - self.assertEqual(cnts3.op_count, 4) - - def test_nested_optimize_run(self): - cnts = torchdynamo.testing.CompileCounter() - - @torchdynamo.optimize(cnts, nopython=True) - def fn(x): - return torch.relu(torch.cos(x) + torch.sin(x)) - - fn(torch.randn(4)) - self.assertEqual(cnts.frame_count, 1) - - fn(torch.randn(4, 4)) - self.assertEqual(cnts.frame_count, 2) - - # Test that run works on a decorated fn - fn = torchdynamo.run(fn) - fn(torch.randn(4, 4, 4)) - 
self.assertEqual(cnts.frame_count, 2) - - def test_nested_optimize(self): - cnts1 = torchdynamo.testing.CompileCounter() - cnts2 = torchdynamo.testing.CompileCounter() - - def fn(x): - return torch.relu(torch.cos(x) + torch.sin(x)) - - fn1 = torchdynamo.optimize(cnts1, nopython=True)(fn) - fn2 = torchdynamo.optimize(cnts2, nopython=True)(fn1) - - # The first optimize in the nesting should be ignored - fn2(torch.randn(4)) - self.assertEqual(cnts2.frame_count, 1) - self.assertEqual(cnts1.frame_count, 0) - - # Since the fn code object is already compiled, calling fn1 should - # directly call the compiled_fn callable. - torchdynamo.run()(fn1)(torch.randn(4)) - self.assertEqual(cnts1.frame_count, 0) - - # Test same behavior by reversing the calls - torchdynamo.reset() - cnts1 = torchdynamo.testing.CompileCounter() - cnts2 = torchdynamo.testing.CompileCounter() - fn1 = torchdynamo.optimize(cnts1, nopython=True)(fn) - fn2 = torchdynamo.optimize(cnts2, nopython=True)(fn1) - fn1(torch.randn(4)) - self.assertEqual(cnts1.frame_count, 1) - torchdynamo.run()(fn2)(torch.randn(4)) - self.assertEqual(cnts2.frame_count, 0) - - def test_nested_disable_decorator(self): - cnts = torchdynamo.testing.CompileCounter() - - @torchdynamo.disable() - def fn1(x): - return torch.sin(x) * 10 - - @torchdynamo.optimize(cnts) - def fn2(x): - x = x + 1 - x = x + 1 - x = fn1(x) # graph break - x = x + 1 - x = x + 1 - return x - - @torchdynamo.optimize(cnts, nopython=True) - def fn3(x): - return fn2(x) - - fn2(torch.randn(4, 5)) - self.assertEqual(cnts.frame_count, 2) - self.assertEqual(cnts.op_count, 4) - - try: - fn3(torch.randn(4, 5)) - self.assertFalse(True) - except torchdynamo.exc.Unsupported as e: - self.assertIn("call torch._dynamo.disable() wrapped function", str(e)) - - def test_graph_break(self): - cnts = torchdynamo.testing.CompileCounter() - - @torchdynamo.optimize(cnts) - def fn(x): - x = torch.cos(x) - x = torch.cos(x) - torchdynamo.graph_break() - x = torch.cos(x) - x = torch.cos(x) - graph_break() - x = torch.cos(x) - x = torch.cos(x) - return x - - fn(torch.randn(4, 5)) - self.assertEqual(cnts.frame_count, 3) - self.assertEqual(cnts.op_count, 6) - - def test_torch_size(self): - cnts = torchdynamo.testing.CompileCounter() - - def fn(x): - output_size = torch.Size([10, 10]) - x = x.view(*output_size) - return (x,) - - x = torch.randn(100, requires_grad=True) - x_clone = x.clone() - ref = fn(x) - - opt_fn = torchdynamo.optimize(cnts, nopython=True)(fn) - res = opt_fn(x_clone) - - self.assertTrue(same(ref, res)) - - def test_torch_seed(self): - cnts = torchdynamo.testing.CompileCounter() - - def fn(x): - attention_seed = int(torch.seed() % sys.maxsize) - torch.manual_seed(attention_seed) - return (x,) - - x = torch.randn(100, requires_grad=True) - ref = fn(x) - - opt_fn = torchdynamo.optimize(cnts, nopython=True)(fn) - res = opt_fn(x) - - self.assertTrue(same(ref, res)) - - def test_is_tensor_like(self): - cnts = torchdynamo.testing.CompileCounter() - - def f(x): - if torch.overrides.is_tensor_like(x): - return (x * 2,) - return (torch.ones(10) + x,) - - x = torch.randn(10) - ref0 = f(x) - ref1 = f(4) - opt_f = torchdynamo.optimize(cnts, nopython=True)(f) - res0 = opt_f(x) - res1 = opt_f(4) - self.assertTrue(same(ref0, res0)) - self.assertTrue(same(ref1, res1)) - - def test_version_ci(self): - # temporary test to check that the ci torch version is set correctly - self.assertTrue(hasattr(torch, "_subclasses")) - - @unittest.skipIf(not torch.cuda.is_available(), "requires cuda") - def test_rand(self): - cnts = 
torchdynamo.testing.CompileCounter() - device = "cuda" - - def fn(): - return torch.randn(10, device=device) - - torch.manual_seed(10) - ref_run1 = fn() - - torch.manual_seed(10) - ref_run2 = fn() - self.assertTrue(same(ref_run1, ref_run2)) - - torch.manual_seed(10) - opt_fn = torchdynamo.optimize(cnts, nopython=True)(fn) - res = opt_fn() - - self.assertTrue(same(res, ref_run1)) - - def test_slice_input(self): - cnts = torchdynamo.testing.CompileCounter() - - def getitem(a, idx): - if isinstance(idx, slice): - return ( - torch.zeros(1), - a[idx] - + [ - 100, - ], - ) - else: - return (torch.zeros(1), a[idx]) - - layers = list(range(10)) - ref0 = getitem(layers, slice(0, 2, 1)) - ref1 = getitem(layers, 2) - ref2 = getitem(layers, slice(3, 8, 2)) - opt_getitem = torchdynamo.optimize(cnts, nopython=True)(getitem) - res0 = opt_getitem(layers, slice(0, 2, 1)) - res1 = opt_getitem(layers, 2) - res2 = opt_getitem(layers, slice(3, 8, 2)) - - self.assertTrue(ref0 == res0) - self.assertTrue(ref1 == res1) - self.assertTrue(ref2 == res2) - - def test_grad(self): - cnts = torchdynamo.testing.CompileCounter() - - def fn(a, b): - out = a * b - out.sum().backward() - real_out = torch.sigmoid(a.grad + b) - return real_out - - inps = [torch.randn(4, requires_grad=True) for _ in range(2)] - for inp in inps: - inp.grad = None - ref = fn(*inps) - - for inp in inps: - inp.grad = None - opt_fn = torchdynamo.optimize(cnts)(fn) - res = opt_fn(*inps) - - self.assertTrue(same(ref, res)) - - @unittest.skipIf(sys.version_info < (3, 10), "use linetable when python >= 3.10") - def test_linetable_writer(self): - def fn(): - a = 10 - b = 20 - c = a + b - f = "linetable_writer" - return f"Test if {f} generates correct co_linetable: {c}" - - inst = dis.get_instructions(fn) - result = bytecode_transformation.assemble(inst, fn.__code__.co_firstlineno) - self.assertTrue(result[1] == fn.__code__.co_linetable) - - @unittest.skipIf(sys.version_info >= (3, 10), "use lnotab when python < 3.10") - def test_lnotab_writer(self): - def fn(): - a = 10 - b = 20 - c = a + b - f = "lnotab_writer" - return f"Test if {f} generates correct co_lnotab: {c}" - - inst = dis.get_instructions(fn) - result = bytecode_transformation.assemble(inst, fn.__code__.co_firstlineno) - self.assertTrue(result[1] == fn.__code__.co_lnotab) - - def test_torch_profiler(self): - # wrap torch.profiler.* as ProfilerContextWrapperVariable and do nothing - def fn(x): - y = x**2 - with torch.profiler.profile(): - y = y + 2 - with torch.profiler.record_function("my_function"): - z = y**3 - z.tolist() # graph break - z = z + 1 - return z - - x = torch.randn((2, 2), requires_grad=True) - ref = fn(x) - cnts = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize(cnts)(fn) - res = opt_fn(x) - self.assertTrue(same(ref, res)) - self.assertEqual(cnts.frame_count, 2) - - def test_autograd_profiler(self): - # wrap torch.autograd.profiler.* as ProfilerContextWrapperVariable and do nothing - def fn(x): - y = x**2 - with torch.autograd.profiler.profile(): - y = y + 2 - with torch.autograd.profiler.record_function("my_function"): - z = y**3 - z.tolist() # graph break - z = z + 1 - return z - - x = torch.randn((2, 2), requires_grad=True) - ref = fn(x) - cnts = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize(cnts)(fn) - res = opt_fn(x) - self.assertTrue(same(ref, res)) - self.assertEqual(cnts.frame_count, 2) - - def test_python_slice(self): - def f1(input): - y = 0 - for i, x in enumerate(input[2:], 1): - y = y + x - return y - - def f2(input): - y = 
0 - for i, x in enumerate(input.shape[2:], 1): - y = y + x - return y - - cnts = torchdynamo.testing.CompileCounter() - opt_f1 = torchdynamo.optimize(cnts)(f1) - opt_f2 = torchdynamo.optimize(cnts)(f2) - res1 = opt_f1([1, 2, 3, 5]) - res2 = opt_f2(torch.rand([2, 3, 4, 5])) - - self.assertEqual(res1, 8) - self.assertEqual(res2, 9) - - def test_const_dict_variable_python_type(self): - from torchdynamo.variables import ConstDictVariable - - d1 = {"a": 10, "b": 20} - d2 = collections.OrderedDict([("x", 12), ("y", 22)]) - self.assertEqual(ConstDictVariable(d1, dict).python_type(), dict) - self.assertEqual( - ConstDictVariable(d2, collections.OrderedDict).python_type(), - collections.OrderedDict, - ) - - def test_builtin_subclasses_as_method_on_class_type(self): - class Foo: - def __init__(self, name): - self.ame_ = name - - def get_name(self): - return "Foo " + self.name_ - - class Bar(Foo): - def __init__(self, name): - self.name_ = name - - def get_name(self): - return "Bar " + self.name_ - - class Baz(Foo): - def __init__(self, name): # noqa: B903 - self.name_ = name - - def get_name(self): - return "Baz " + self.name_ - - subs_of_foo_reg = Foo.__subclasses__() - - counter = CompileCounter() - - @torchdynamo.optimize_assert(counter) - def fn(): - return Foo.__subclasses__() - - subs_of_foo_optim = fn() - - self.assertEqual(len(subs_of_foo_reg), 2) - self.assertEqual(subs_of_foo_reg, subs_of_foo_optim) - - def test_builtin_subclasses_as_method_on_var(self): - class Foo: - def __init__(self, name): - self.name_ = name - - def get_name(self): - return "Foo " + self.name_ - - class Bar(Foo): - def __init__(self, name): - self.name_ = name - - def get_name(self): - return "Bar " + self.name_ - - class Baz(Bar): - def __init__(self, name): - self.name_ = name - - def get_name(self): - return "Baz " + self.name_ - - subs_of_foo_reg = Foo.__subclasses__() - sub_of_foo_subclass_var_reg = subs_of_foo_reg[0].__subclasses__() - - sub_of_foo_subclass_var_optim = list() - counter = CompileCounter() - - @torchdynamo.optimize_assert(counter) - def fn(): - return Foo.__subclasses__() - - @torchdynamo.optimize_assert(counter) - def fn_single(subs_of_foo_optim): - return subs_of_foo_optim[0].__subclasses__() - - subs_of_foo_optim = fn() - sub_of_foo_subclass_var_optim = fn_single(subs_of_foo_optim) - - self.assertEqual(len(sub_of_foo_subclass_var_optim), 1) - self.assertEqual(sub_of_foo_subclass_var_optim, sub_of_foo_subclass_var_reg) - - def test_enum_no_graphbreaks(self): - class Foo(enum.Enum): - FOO = 0 - BAR = 1 - - def fn(x, foo): - if foo is Foo.FOO: - x = torch.add(x, 1.0) - x = torch.mul(x, 1.0) - return x - - x = torch.randn(1) - cnts = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize(cnts, nopython=True)(fn) - opt_fn(x, Foo.FOO) - self.assertEqual(cnts.op_count, 2) - - torchdynamo.reset() - cnts = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize(cnts, nopython=True)(fn) - opt_fn(x, Foo.BAR) - self.assertEqual(cnts.op_count, 1) - - def test_id_of_nn_module(self): - class M(torch.nn.Module): - def forward(self, x, ref_id): - self_id = id(self) - if self_id == ref_id: - x = torch.mul(x, 1.0) - x = torch.add(x, 1.0) - return x - - m = M().eval() - data = torch.randn(1) - cnts = torchdynamo.testing.CompileCounter() - correct_ref_id = id(m) - opt_m = torchdynamo.optimize(cnts, nopython=True)(m) - opt_m(data, correct_ref_id) - self.assertEqual(cnts.op_count, 2) - - torchdynamo.reset() - cnts = torchdynamo.testing.CompileCounter() - incorrect_ref_id = id(m) + 1 - 
opt_m = torchdynamo.optimize(cnts, nopython=True)(m) - opt_m(data, incorrect_ref_id) - self.assertEqual(cnts.op_count, 1) - - def test_inline_func_jump_on_tensor_condition(self): - def f1(input): - if input == 0: - return input + 1 - else: - return input + 2 - - def f2(input): - return f1(input) - - cnts = torchdynamo.testing.CompileCounter() - opt_f2 = torchdynamo.optimize(cnts)(f2) - res1 = opt_f2(torch.tensor([1.0])) - res2 = opt_f2(torch.tensor([0.0])) - - self.assertEqual(res1, 3) - self.assertEqual(res2, 1) - - def test_frozenset_torch_func_contains(self): - funcs = frozenset([torch.add]) - - def fn(x, func): - if func in funcs: - x = torch.add(x, 1.0) - x = torch.mul(x, 1.0) - return x - - x = torch.randn(1) - cnts = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize(cnts, nopython=True)(fn) - opt_fn(x, torch.add) - self.assertEqual(cnts.op_count, 2) - - torchdynamo.reset() - cnts = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize(cnts, nopython=True)(fn) - opt_fn(x, torch.mul) - self.assertEqual(cnts.op_count, 1) - - @patch.object(torchdynamo.config, "fake_tensor_propagation", True) - def test_unsupported_fake_tensor(self): - def f(x): - return torch.quantize_per_tensor(x, 0.1, 10, torch.quint8) - - x = torch.randn(2, 2) - cnts = torchdynamo.testing.CompileCounter() - opt_f = torchdynamo.optimize(cnts)(f) - opt_f(x) - self.assertEqual(cnts.op_count, 0) - - torchdynamo.reset() - with patch.object(torchdynamo.config, "fake_tensor_propagation", False): - opt_f = torchdynamo.optimize_assert(torchdynamo.testing.CompileCounter())(f) - opt_f(x) - - def test_inline_list_mutation(self): - def f1(x): - x.append(torch.ones(8)) - return x - - def f2(): - x = [torch.ones(6)] - f1(x) - return x - - res1 = f2() - cnts = torchdynamo.testing.CompileCounter() - opt_f2 = torchdynamo.optimize(cnts)(f2) - res2 = opt_f2() - self.assertTrue(same(res1, res2)) - - def test_inline_dict_mutation(self): - def f1(d): - d["c"] = d["a"] + d.pop("b") - return d - - def f2(): - d = {"a": torch.ones(5), "b": torch.ones(5)} - f1(d) - return d - - res1 = f2() - cnts = torchdynamo.testing.CompileCounter() - opt_f2 = torchdynamo.optimize(cnts)(f2) - res2 = opt_f2() - self.assertTrue(same(res1, res2)) - - def test_recursive_inline_list_mutation(self): - def f1(x, y): - x.append(torch.tensor([1.1])) - y.append(torch.tensor([1.2])) - return x, y - - def f2(x, y): - x.append(torch.tensor([2.1])) - y.append(torch.tensor([2.2])) - f1(x, y) - return x, y - - def f3(x): - x.append(torch.tensor([3.1])) - y = [torch.tensor([3.2])] - f2(x, y) - return x, y - - def f4(): - x = [torch.tensor([4.1])] - return f3(x) - - res1 = f4() - cnts = torchdynamo.testing.CompileCounter() - opt_f4 = torchdynamo.optimize(cnts)(f4) - res2 = opt_f4() - self.assertTrue(same(res1, res2)) - - def test_disallow_in_graph(self): - cnts = torchdynamo.testing.CompileCounter() - - @torchdynamo.optimize(cnts) - def fn(a): - x = torch.add(a, 1) - x = torch.add(x, 1) - x = torch.sub(x, 1) - x = torch.add(x, 1) - x = torch.add(x, 1) - return x - - torchdynamo.disallow_in_graph(torch.sub) - fn(torch.randn(10)) - torchdynamo.allow_in_graph(torch.sub) - - # check for graph break on sub - self.assertEqual(cnts.frame_count, 2) - self.assertEqual(cnts.op_count, 4) - - def test_allow_in_graph(self): - cnts = torchdynamo.testing.CompileCounter() - - @torchdynamo.optimize(cnts) - def fn(a): - x = torch.add(a, 1) - x = torch.add(x, 1) - x = my_custom_function(x) - x = torch.add(x, 1) - x = torch.add(x, 1) - return x - - 
torchdynamo.allow_in_graph(my_custom_function) - fn(torch.randn(10)) - torchdynamo.disallow_in_graph(my_custom_function) - - # check for no graph break - self.assertEqual(cnts.frame_count, 1) - self.assertEqual(cnts.op_count, 5) - - def test_sample_input(self): - from torch.testing._internal.common_methods_invocations import SampleInput - - def fn(sample): - if isinstance(sample.input, torch.Tensor): - return sample.input * 2 - return torch.zeros(()) - - sample = SampleInput(torch.ones(2)) - ref = fn(sample) - - opt_fn = torchdynamo.optimize("eager")(fn) - res = opt_fn(sample) - - self.assertTrue(same(ref, res)) - - def test_release_input_memory(self): - x = torch.rand([4]) - x_ref = weakref.ref(x) - - cnts = torchdynamo.testing.CompileCounter() - - @torchdynamo.optimize(cnts) - def foo(x): - return x + x - - out = foo(x) - self.assertTrue(same(out, x + x)) - del x - self.assertIs(x_ref(), None) - - def test_release_module_memory(self): - - mod = torch.nn.Linear(10, 10) - x = torch.rand([10, 10]) - mod_weight_ref = weakref.ref(mod.weight) - mod_ref = weakref.ref(mod) - - # Modules that are passed into torchdynamo optimized functions - # will normally be held onto through the generated GraphModule, - # which contains the modules. remove the reference in this backend - # and test that no additional references are being held. - class NoLeakBackend: - def __call__(self, gm: torch.fx.GraphModule, example_inputs): - gm.mod = None - - def foo(*args, **kwargs): - return (1,) - - return foo - - no_leak_backend = NoLeakBackend() - - @torchdynamo.optimize(no_leak_backend) - def foo(mod, x): - return mod(x) - - foo(mod, x) - del mod - del x - self.assertIsNone(mod_ref(), None) - self.assertIsNone(mod_weight_ref(), None) - - def test_update_locals_and_stack_uses_shared_cache(self): - def fn(x): - perm = [0, 3, 5] - perm = list(range(min(perm))) + perm - perm.extend(i for i in range(x.dim()) if i not in perm) - return perm - - x = torch.rand([2, 2, 2, 2, 2, 2]) - res1 = fn(x) - cnts = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize(cnts)(fn) - res2 = opt_fn(x) - self.assertTrue(same(res1, res2)) - - def test_dict_reconstruct_keeps_original_order(self): - def fn(): - modules = collections.OrderedDict([("act", torch.nn.ReLU())]) - module_dict = torch.nn.ModuleDict(modules) - - next_modules = {"fc4": torch.nn.Linear(5, 6), "act3": torch.nn.Sigmoid()} - modules.update(next_modules.items()) - module_dict.update(next_modules) - return modules, module_dict - - cnts = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize(cnts)(fn) - modules, module_dict = opt_fn() - - self.assertEqual(len(module_dict), len(modules)) - for k1, m2 in zip(modules, module_dict.children()): - self.assertTrue(modules[k1] is m2) - - def test_side_effects_codegen_update_mutated(self): - # codegen to update mutated variables with side effect - # should after stack value's codegen - def f1(x): - alist = [x] - alist.append(x + 1) - alist[0].sum().item() # graph break - res = alist.pop() - res.sum().item() # graph break - return res - - def f2(a, b): - d = {"a": a + 1, "b": b + 2} - x = d.pop("b") - x.sum().item() # graph break - y = d["a"] + x - y.sum().item() # graph break - d["c"] = y - return d - - x = torch.rand([2, 3]) - a = torch.rand([5, 6]) - b = torch.rand([5, 6]) - res11 = f1(x) - res21 = f2(a, b) - cnts = torchdynamo.testing.CompileCounter() - opt_f1 = torchdynamo.optimize(cnts)(f1) - opt_f2 = torchdynamo.optimize(cnts)(f2) - res12 = opt_f1(x) - res22 = opt_f2(a, b) - 
self.assertTrue(same(res11, res12)) - self.assertTrue(same(res21, res22)) - - def test_list_append_return_none(self): - def fn(x): - alist = [] - blist = alist.append(x + 1) - return alist, blist - - x = torch.tensor([2.3]) - res = fn(x) - cnts = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize(cnts)(fn) - res2 = opt_fn(x) - self.assertEqual(res, res2) - - def test_tensor_types(self): - def fn(dtype, tensor_type): - x = torch.empty(4, dtype=dtype) - assert isinstance(x, tensor_type) - - opt_fn = torchdynamo.optimize("eager")(fn) - opt_fn(torch.float32, torch.FloatTensor) - opt_fn(torch.float64, torch.DoubleTensor) - opt_fn(torch.float16, torch.HalfTensor) - opt_fn(torch.bfloat16, torch.BFloat16Tensor) - opt_fn(torch.uint8, torch.ByteTensor) - opt_fn(torch.int8, torch.CharTensor) - opt_fn(torch.int64, torch.LongTensor) - opt_fn(torch.int, torch.IntTensor) - opt_fn(torch.int16, torch.ShortTensor) - opt_fn(torch.bool, torch.BoolTensor) - - def test_nan(self): - def f(x, n): - return x * 2 + n - - x = torch.randn(4) - n = float("nan") - - cnts = torchdynamo.testing.CompileCounter() - opt_f = torchdynamo.optimize(cnts)(f) - opt_f(x, n) - opt_f(x, n) - self.assertEqual(cnts.frame_count, 1) - - @patch.object(torchdynamo.config, "capture_scalar_outputs", True) - def test_item(self): - class MyMod(torch.nn.Module): - def forward(self, x): - z = torch.max(x) - return z.int().item() - - x = torch.tensor([[10.6763, 11.7445, -2.2369]]) - model = MyMod() - y = torchdynamo.optimize("eager", nopython=True)(model)(x) - - self.assertEqual(y, 11) - - @patch.object(torchdynamo.config, "capture_scalar_outputs", True) - def test_item_changes(self): - class MyMod(torch.nn.Module): - def forward(self, x): - z = torch.max(x) - return z.int().item() - - x = torch.tensor([[10.6763, 11.7445, -2.2369]]) - model = MyMod() - opt_model = torchdynamo.optimize("eager", nopython=True)(model) - y = opt_model(x) - z = opt_model(torch.tensor([[y - 5, y + 10, y + 50]])) - - self.assertEqual(y, 11) - self.assertEqual(z, 61) - - @patch.object(torchdynamo.config, "capture_scalar_outputs", True) - def test_item_changes_new_shape(self): - class MyMod(torch.nn.Module): - def forward(self, x): - z = torch.max(x) - return z.int().item() - - x = torch.tensor([[10.6763, 11.7445, -2.2369]]) - model = MyMod() - opt_model = torchdynamo.optimize("eager", nopython=True)(model) - y = opt_model(x) - z = opt_model(torch.tensor([[y - 5, y + 50], [y + 5, y - 50]])) - - self.assertEqual(y, 11) - self.assertEqual(z, 61) - - def test_cross_entropy_loss_fancy_ctor(self): - output = None - rand_5 = torch.randn(5) - rand_3_5 = torch.randn(3, 5) - target = torch.empty(3, dtype=torch.long).random_(5) - - loss = torch.nn.CrossEntropyLoss( - weight=rand_5, reduce=False, label_smoothing=0.5 - ) - opt_loss = torchdynamo.optimize("eager", nopython=True)(loss) - input = rand_3_5 - dynamo_output = opt_loss(input, target) - - loss = torch.nn.CrossEntropyLoss( - weight=rand_5, reduce=False, label_smoothing=0.5 - ) - input = rand_3_5 - output = loss(input, target) - - self.assertTrue(torch.allclose(dynamo_output, output)) - - def test_cross_entropy_loss_simple_ctor(self): - output = None - rand_3_5 = torch.randn(3, 5) - target = torch.empty(3, dtype=torch.long).random_(5) - - loss = torch.nn.CrossEntropyLoss() - opt_loss = torchdynamo.optimize("eager", nopython=True)(loss) - input = rand_3_5 - dynamo_output = opt_loss(input, target) - - loss = torch.nn.CrossEntropyLoss() - input = rand_3_5 - output = loss(input, target) - - 
self.assertTrue(torch.allclose(dynamo_output, output)) - - def test_large_reduction_list(self): - dtype = torch.float32 - device = "cpu" - - def check_sum_all(tensor: torch.Tensor) -> None: - pylist = tensor.reshape(-1).tolist() - self.assertTrue(same(tensor.sum(), torch.tensor(sum(pylist)))) - - check_sum_all(torch.randn(200000, dtype=dtype, device=device)) - - @patch.object(torchdynamo.config, "raise_on_backend_error", True) - def test_raise_on_backend_error(self): - def my_compiler(gm, _): - raise RuntimeError("duck!") - - @torchdynamo.optimize(my_compiler) - def fn(a, b): - return a + b / (a - b) - - self.assertRaises( - torchdynamo.exc.BackendCompilerFailed, - lambda: fn(torch.randn(10), torch.randn(10)), - ) - - def test_named_parameters(self): - n_embd = 768 - block_size = 128 - vocab_size = 65 - embd_pdrop = 0.1 - - class MyModel2(torch.nn.Module): - def __init__(self): - super().__init__() - self.tok_emb = torch.nn.Embedding(vocab_size, n_embd) - self.pos_emb = torch.nn.Parameter(torch.zeros(1, block_size, n_embd)) - self.drop = torch.nn.Dropout(embd_pdrop) - - def forward(self, x): - return x - - class MyModel(torch.nn.Module): - def __init__(self): - super().__init__() - self.tok_emb = torch.nn.Embedding(vocab_size, n_embd) - self.pos_emb = torch.nn.Parameter(torch.zeros(1, block_size, n_embd)) - self.drop = torch.nn.Dropout(embd_pdrop) - self.submod2 = MyModel2() - - def forward(self, x): - return x - - # Regular - params = [] - mod = MyModel() - actual_params = list(mod.named_parameters()) - - @torchdynamo.optimize("eager", nopython=True) - def fn(): - return list(mod.named_parameters()) - - params = fn() - - self.assertEqual(len(actual_params), len(params)) - for idx in range(len(params)): - k_a, v_a = actual_params[idx] - k, v = params[idx] - self.assertEqual(k_a, k) - self.assertTrue(torch.allclose(v_a, v)) - - # Prefix - params = [] - mod = MyModel() - actual_params = list(mod.named_parameters(prefix="foo")) - - @torchdynamo.optimize("eager", nopython=True) - def fn1(): - return list(mod.named_parameters(prefix="foo")) - - params = fn1() - - self.assertEqual(len(actual_params), len(params)) - for idx in range(len(params)): - k_a, v_a = actual_params[idx] - k, v = params[idx] - self.assertEqual(k_a, k) - self.assertTrue(torch.allclose(v_a, v)) - - def test_module_complex_iter(self): - n_embd = 768 - block_size = 128 - vocab_size = 65 - embd_pdrop = 0.1 - - class FakeGPT(torch.nn.Module): - def __init__(self): - super().__init__() - self.tok_emb = torch.nn.Embedding(vocab_size, n_embd) - self.pos_emb = torch.nn.Parameter(torch.zeros(1, block_size, n_embd)) - self.drop = torch.nn.Dropout(embd_pdrop) - self.ln_f = torch.nn.LayerNorm(n_embd) - self.head = torch.nn.Linear(n_embd, vocab_size, bias=False) - - self.block_size = block_size - self.names = [] - - def forward(self, idx, targets=None): - from torch.nn import functional as F - - b, t = idx.size() - assert ( - t <= self.block_size - ), "Cannot forward, model block size is exhausted." 
- - # forward the GPT model - token_embeddings = self.tok_emb( - idx - ) # each index maps to a (learnable) vector - position_embeddings = self.pos_emb[ - :, :t, : - ] # each position maps to a (learnable) vector - x = self.drop(token_embeddings + position_embeddings) - x = self.blocks(x) - x = self.ln_f(x) - logits = self.head(x) - - # if we are given some desired targets also calculate the loss - loss = None - if targets is not None: - loss = F.cross_entropy( - logits.view(-1, logits.size(-1)), targets.view(-1) - ) - - return logits, loss - - def foo(self, memo=None, prefix="", remove_duplicate=False): - for mn, m in self.named_modules( - memo=memo, prefix=prefix, remove_duplicate=remove_duplicate - ): - for pn, p in self.named_parameters(): - fpn = "%s.%s" % (mn, pn) if mn else pn - self.names.append(fpn) - - # Test plain recurse - model_a = FakeGPT() - model_a.foo() - a_names = model_a.names - - model_b = FakeGPT() - opt_model_b = torchdynamo.optimize("eager", nopython=True)(model_b) - opt_model_b.foo() - - self.assertEqual(a_names, model_b.names) - - # Test with prefix - model_a = FakeGPT() - model_a.foo(prefix="abc") - a_names = model_a.names - - model_b = FakeGPT() - opt_model_b = torchdynamo.optimize("eager", nopython=True)(model_b) - opt_model_b.foo(prefix="abc") - - self.assertEqual(a_names, model_b.names) - - def test_numpy_variable_isinstance(self): - def fn(x, m): - if isinstance(m, np.ndarray): - return x + 1 - else: - return x - 1 - - x = torch.tensor([2.3]) - m = np.array([1, 2, 3]) - ref = fn(x, m) - cnts = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize(cnts)(fn) - res = opt_fn(x, m) - self.assertEqual(ref, res) - - def test_tensor_dot_grad_no_graph_break(self): - def fn(a, b): - y = 3 * a**3 - b**2 - y.backward(gradient=torch.tensor([1.0, 1.0])) - b.grad.zero_() - return a.grad, b.grad - - a = torch.tensor([2.0, 3.0], requires_grad=True) - b = torch.tensor([6.0, 4.0], requires_grad=True) - cnts = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize(cnts)(fn) - _, b_grad = opt_fn(a, b) - self.assertTrue(same(b_grad, torch.tensor([0.0, 0.0]))) - self.assertEqual(cnts.frame_count, 2) - - def test_torch_nn_parameter_isinstance(self): - def fn(x): - a = torch.nn.Parameter(torch.rand(2, 3)) - if isinstance(a, torch.Tensor): - return x + 1 - else: - return x - 1 - - x = torch.tensor([2.5]) - ref = fn(x) - opt_fn = torchdynamo.optimize("eager")(fn) - res = opt_fn(x) - self.assertEqual(ref, res) - - def test_change_backends(self): - @torchdynamo.optimize("eager", nopython=True) - def fn1(): - return x + 1 - - @torchdynamo.optimize("ts") - def fn2(): - return x + 2 - - @torchdynamo.optimize("eager", nopython=False) - def fn3(): - return x + 1 - - x = torch.tensor([3, 5]) - - fn1() - fn1() - fn3() - self.assertRaises(torchdynamo.exc.ResetRequired, fn2) - fn1() - torchdynamo.reset() - fn2() - fn2() - self.assertRaises(torchdynamo.exc.ResetRequired, fn1) - self.assertRaises(torchdynamo.exc.ResetRequired, fn3) - fn2() - - def test_dynamo_min_operator_with_shape(self): - @torchdynamo.optimize("eager", nopython=True) - def f(x, a): - return min(x.shape[0], a) - - result = f(torch.ones(6), 3) - self.assertEqual(result, 3) - - @patch.object(torchdynamo.config, "dynamic_shapes", True) - def test_onnx_shape_as_tensor(self): - @torchdynamo.optimize("eager", nopython=True) - def f(x): - return 1 + torch._shape_as_tensor(x)[0] - - gm, _ = torchdynamo.export(f, torch.ones(6)) - - input_one_dim = torch.ones(6) - input_two_dims = torch.ones(7, 4) - 
self.assertEqual(f(input_one_dim), 7) - self.assertEqual(f(input_two_dims), 8) - self.assertEqual(f(input_two_dims), 8) - - @torchdynamo.optimize("eager", nopython=True) - def f_onnx(x): - return 1 + torch.onnx.operators.shape_as_tensor(x)[0] - - self.assertEqual(f_onnx(input_one_dim), 7) - self.assertEqual(f_onnx(input_two_dims), 8) - self.assertEqual(f_onnx(input_two_dims), 8) - - def test_cond(self): - from functorch.experimental.cond import cond - - def true_fn(x): - return x.sin() - - def false_fn(x): - return x.cos() - - def f(pred, x): - return cond(pred, true_fn, false_fn, [x]) - - opt_fn = torchdynamo.optimize("eager")(f) - a = opt_fn(torch.tensor(False), torch.tensor([0.25, 0.25])) - self.assertTrue(same(torch.cos(torch.tensor([0.25, 0.25])), a)) - b = opt_fn(torch.tensor(True), torch.tensor([0.25, 0.25])) - self.assertTrue(same(torch.sin(torch.tensor([0.25, 0.25])), b)) - - def test_cond_nested(self): - from functorch.experimental.cond import cond - - def true_fn_nested(x): - return x * 10 - - def false_fn_nested(x): - return x * -1 - - def true_fn(pred2, x): - return x.sin() - - def false_fn(pred2, x): - return x + cond(pred2, true_fn_nested, false_fn_nested, [x]) - - def f(pred, pred2, x): - return cond(pred, true_fn, false_fn, [pred2, x]) - - cc = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize(cc)(f) - true_true_sin = opt_fn( - torch.tensor(True), torch.tensor(True), torch.tensor([0.25, 0.25]) - ) - self.assertTrue(same(torch.sin(torch.tensor([0.25, 0.25])), true_true_sin)) - - true_false_sin = opt_fn( - torch.tensor(True), torch.tensor(False), torch.tensor([0.25, 0.25]) - ) - self.assertTrue(same(torch.sin(torch.tensor([0.25, 0.25])), true_false_sin)) - - false_true_sum_mult = opt_fn( - torch.tensor(False), torch.tensor(True), torch.tensor([0.25, 0.25]) - ) - self.assertTrue( - same(torch.tensor([2.75, 2.75]), false_true_sum_mult) - ) # * 10 then add x - - false_false_sum_neg = opt_fn( - torch.tensor(False), torch.tensor(False), torch.tensor([0.25, 0.25]) - ) - self.assertTrue( - same(torch.tensor([0.0, 0.0]), false_false_sum_neg) - ) # * -1 then add x - self.assertTrue(cc.frame_count, 2) - - @patch.object(torchdynamo.config, "fake_tensor_propagation", False) - def test_cond_nested_fake_tensor_off(self): - from functorch.experimental.cond import cond - - def true_fn_nested(x): - return x * 10 - - def false_fn_nested(x): - return x * -1 - - def true_fn(pred2, x): - return x.sin() - - def false_fn(pred2, x): - return x + cond(pred2, true_fn_nested, false_fn_nested, [x]) - - def f(pred, pred2, x): - return cond(pred, true_fn, false_fn, [pred2, x]) - - cc = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize(cc)(f) - true_true_sin = opt_fn( - torch.tensor(True), torch.tensor(True), torch.tensor([0.25, 0.25]) - ) - self.assertTrue(same(torch.sin(torch.tensor([0.25, 0.25])), true_true_sin)) - - true_false_sin = opt_fn( - torch.tensor(True), torch.tensor(False), torch.tensor([0.25, 0.25]) - ) - self.assertTrue(same(torch.sin(torch.tensor([0.25, 0.25])), true_false_sin)) - - false_true_sum_mult = opt_fn( - torch.tensor(False), torch.tensor(True), torch.tensor([0.25, 0.25]) - ) - self.assertTrue( - same(torch.tensor([2.75, 2.75]), false_true_sum_mult) - ) # * 10 then add x - - false_false_sum_neg = opt_fn( - torch.tensor(False), torch.tensor(False), torch.tensor([0.25, 0.25]) - ) - self.assertTrue( - same(torch.tensor([0.0, 0.0]), false_false_sum_neg) - ) # * -1 then add x - self.assertTrue(cc.frame_count, 1) - - 
@patch.object(torchdynamo.config, "fake_tensor_propagation", False) - def test_cond_export(self): - from functorch.experimental.cond import cond - - def true_fn_nested(x): - return x * 10 - - def false_fn_nested(x): - return x * -1 - - def true_fn(pred2, x): - return x.sin() - - def false_fn(pred2, x): - return x + cond(pred2, true_fn_nested, false_fn_nested, [x]) - - def f(pred, pred2, x): - return cond(pred, true_fn, false_fn, [pred2, x]) - - graph, guard = torchdynamo.export( - f, torch.tensor(False), torch.tensor(True), torch.tensor([0.25, 0.25]) - ) - true_true_sin = graph( - torch.tensor(True), torch.tensor(True), torch.tensor([0.25, 0.25]) - ) - self.assertTrue(same(torch.sin(torch.tensor([0.25, 0.25])), true_true_sin)) - - true_false_sin = graph( - torch.tensor(True), torch.tensor(False), torch.tensor([0.25, 0.25]) - ) - self.assertTrue(same(torch.sin(torch.tensor([0.25, 0.25])), true_false_sin)) - - false_true_sum_mult = graph( - torch.tensor(False), torch.tensor(True), torch.tensor([0.25, 0.25]) - ) - self.assertTrue( - same(torch.tensor([2.75, 2.75]), false_true_sum_mult) - ) # * 10 then add x - - false_false_sum_neg = graph( - torch.tensor(False), torch.tensor(False), torch.tensor([0.25, 0.25]) - ) - self.assertTrue( - same(torch.tensor([0.0, 0.0]), false_false_sum_neg) - ) # * -1 then add x - - @patch.object(torchdynamo.config, "fake_tensor_propagation", False) - def test_cond_export_single_arg(self): - from functorch.experimental.cond import cond - - def true_fn(x): - return x - - def false_fn(x): - return x.sin() - - def f(pred, x): - return cond(pred, true_fn, false_fn, [x]) - - graph, guard = torchdynamo.export( - f, torch.tensor(False), torch.tensor([0.25, 0.25]) - ) - true_mirror = graph(torch.tensor(True), torch.tensor([0.25, 0.25])) - self.assertTrue(same(torch.tensor([0.25, 0.25]), true_mirror)) - true_mirror_2 = graph(torch.tensor(True), torch.tensor([0.33, 0.33, 0.33])) - self.assertTrue(same(torch.tensor([0.33, 0.33, 0.33]), true_mirror_2)) - - false_sin = graph(torch.tensor(False), torch.tensor([0.5, 0.5])) - self.assertTrue(same(torch.sin(torch.tensor([0.5, 0.5])), false_sin)) - - def test_disable_optimize(self): - cnt = torchdynamo.testing.CompileCounter() - - @torchdynamo.optimize(cnt, disable=True) - def f1(x): - return x + 1 - - f1(torch.ones(6)) - self.assertEqual(cnt.frame_count, 0) - - @torchdynamo.optimize(cnt, disable=True) - def f2(x): - return x + 1 - - f2(torch.ones(6)) - self.assertEqual(cnt.frame_count, 0) - - with patch.dict(os.environ, {"TORCHDYNAMO_DISABLE": "1"}): - - @torchdynamo.optimize(cnt) - def f3(x): - return x + 1 - - f3(torch.ones(6)) - self.assertEqual(cnt.frame_count, 0) - - def test_config_log_level(self): - @torchdynamo.optimize("eager") - def fn(a, b): - return a + b - - with self.assertLogs(logger="torch._dynamo", level=logging.DEBUG) as log: - torchdynamo.config.log_level = logging.DEBUG - fn(torch.randn(10), torch.randn(10)) - cur_len = len(log) - self.assertGreater(cur_len, 0) - - torchdynamo.config.log_level = logging.WARNING - fn(torch.randn(10), torch.randn(10)) - self.assertEqual(cur_len, len(log)) - - @unittest.skip("disabled") - def test_duplicate_graph_break_warning(self): - @torchdynamo.optimize("eager") - def f1(a, b): - f2(a, b) - - def f2(a, b): - c = a + b - print("break") - return a + b + c - - @torchdynamo.optimize("eager") - def g1(a, b): - g2(a, b) - - def g2(a, b): - c = a + b - print("break") - return a + b + c - - def count_graph_break_msgs(msgs): - return sum(msg.find("Graph break") != -1 for msg in msgs) - 
- with self.assertLogs(logger="torch._dynamo", level=logging.WARNING) as log: - torchdynamo.config.verbose = True - f1(torch.randn(10), torch.randn(10)) - self.assertGreater(count_graph_break_msgs(log.output), 1) - - with self.assertLogs(logger="torch._dynamo", level=logging.WARNING) as log: - torchdynamo.config.verbose = False - g1(torch.randn(10), torch.randn(10)) - self.assertEqual(count_graph_break_msgs(log.output), 1) - - def test_inplace_param_update(self): - def fn(param, y): - prev_grad = torch.is_grad_enabled() - try: - torch.set_grad_enabled(False) - torch.set_grad_enabled(True) - torch.set_grad_enabled(False) - param.add_(y) - finally: - torch.set_grad_enabled(prev_grad) - - y = torch.randn(4) - x = torch.nn.Parameter(torch.randn(4)) - fn(x, y) - - cnts = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize(cnts, nopython=True)(fn) - opt_fn(x, y) - self.assertEqual(cnts.frame_count, 1) - self.assertEqual(cnts.op_count, 5) - - @unittest.skipIf(not torch.cuda.is_available(), "requires cuda") - def test_autocast(self): - if not torch.cuda.is_bf16_supported(): - raise unittest.SkipTest("requires bf16") - - class MyModule(torch.nn.Module): - def forward(self, x): - a_float32 = torch.rand((8, 8), device="cuda") - b_float32 = torch.rand((8, 8), device="cuda") - d_float32 = torch.rand((8, 8), device="cuda") - - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - e_float16 = torch.mm(a_float32, b_float32) - f_float16 = torch.mm(d_float32, e_float16) - return f_float16 - - module = MyModule() - real = module(torch.tensor([0.5])) - real_device = real.device - real_dtype = real.dtype - - graph, guards = torchdynamo.export(module, torch.tensor([[0.0, 0], [0, 0]])) - exported = graph(torch.tensor([0.5])) - self.assertEqual(exported.device, real_device) - self.assertEqual(exported.dtype, real_dtype) - - self.assertEqual(exported.device.type, "cuda") - self.assertEqual(exported.device.index, 0) - self.assertEqual(exported.dtype, torch.bfloat16) - - def test_autocast_cpu(self): - class MyModule(torch.nn.Module): - def forward(self, x): - a_float32 = torch.rand((8, 8), device="cpu") - b_float32 = torch.rand((8, 8), device="cpu") - d_float32 = torch.rand((8, 8), device="cpu") - - with torch.autocast(device_type="cpu", dtype=torch.bfloat16): - e_float16 = torch.mm(a_float32, b_float32) - f_float16 = torch.mm(d_float32, e_float16) - return f_float16 - - module = MyModule() - real = module(torch.tensor([0.5])) - real_device = real.device - real_dtype = real.dtype - - graph, guards = torchdynamo.export(module, torch.tensor([[0.0, 0], [0, 0]])) - exported = graph(torch.tensor([0.5])) - self.assertEqual(exported.device, real_device) - self.assertEqual(exported.dtype, real_dtype) - - self.assertEqual(exported.device.type, "cpu") - self.assertEqual(exported.dtype, torch.bfloat16) - - @unittest.skipIf(not torch.cuda.is_available(), "requires cuda") - def test_autocast_float64(self): - class MyModule(torch.nn.Module): - def forward(self, x): - a_float32 = torch.rand((8, 8), device="cuda") - b_float32 = torch.rand((8, 8), device="cuda") - d_float32 = torch.rand((8, 8), device="cuda") - - with torch.autocast(device_type="cuda", dtype=torch.float64): - e_float64 = torch.mm(a_float32, b_float32) - f_float64 = torch.mm(d_float32, e_float64) - return f_float64 - - module = MyModule() - real = module(torch.tensor([0.5])) - real_device = real.device - real_dtype = real.dtype - - graph, guards = torchdynamo.export(module, torch.tensor([[0.0, 0], [0, 0]])) - exported = 
graph(torch.tensor([0.5])) - self.assertEqual(exported.device, real_device) - self.assertEqual(exported.dtype, real_dtype) - - self.assertEqual(exported.device.index, 0) - self.assertEqual(exported.dtype, torch.float64) - - @unittest.skipIf(not torch.cuda.is_available(), "requires cuda") - def test_autocast_device(self): - class MyModule(torch.nn.Module): - def forward(self, x): - a_float32 = torch.rand((8, 8), device="cuda") - b_float32 = torch.rand((8, 8), device="cuda") - d_float32 = torch.rand((8, 8), device="cuda") - - with torch.autocast(device_type="cuda"): - e_float64 = torch.mm(a_float32, b_float32) - f_float64 = torch.mm(d_float32, e_float64) - return f_float64 - - module = MyModule() - real = module(torch.tensor([0.5])) - real_device = real.device - real_dtype = real.dtype - - graph, guards = torchdynamo.export(module, torch.tensor([[0.0, 0], [0, 0]])) - exported = graph(torch.tensor([0.5])) - self.assertEqual(exported.device, real_device) - self.assertEqual(exported.dtype, real_dtype) - - self.assertEqual(exported.device.index, 0) - self.assertEqual(exported.dtype, torch.torch.float16) - - def test_generate_tensor_from_list_of_numpy_primitive_type(self): - # Test sth like torch.LongTensor(list(np.int64, np.int64, ...)) - def fn(): - x = np.array([1, 2, 3, 4, 5, 6], dtype=np.int64) - y = [x[0], x[2], x[4]] - z = torch.LongTensor(y) - return z - - ref = fn() - opt_fn = torchdynamo.optimize("eager")(fn) - res = opt_fn() - self.assertTrue(same(ref, res)) - - def test_autograd_function_equivalence(self): - m1 = Module1() - - @torchdynamo.optimize("eager", nopython=True) - def f1(): - return m1(torch.ones(2, 3)) - - self.assertTrue(torch.allclose(f1(), torch.tensor([2.0]))) - - m2 = Module2() - - @torchdynamo.optimize("eager", nopython=True) - def f2(): - return m2(torch.ones(2, 3)) - - self.assertTrue(torch.allclose(f2(), torch.tensor([2.0]))) - - def test_object_classmethod(self): - class C: - @classmethod - def fn(cls, x): - return x + x - - @torchdynamo.optimize("eager", nopython=True) - def f(): - return C().fn(torch.ones(2, 3)) - - self.assertTrue(torch.allclose(f(), torch.tensor([2.0]))) - - def test_object_staticmethod(self): - class C: - @staticmethod - def fn(x): - return x + x - - @torchdynamo.optimize("eager", nopython=True) - def f(): - return C().fn(torch.ones(2, 3)) - - self.assertTrue(torch.allclose(f(), torch.tensor([2.0]))) - - def test_user_function_variable_supports_enum_argument(self): - class Foo(enum.Enum): - FOO = 0 - BAR = 1 - - def gn(x, y=Foo.FOO): - if y is Foo.FOO: - return x - else: - return x + 1 - - def fn(x): - return gn(x) - - x = torch.randn(2, 3) - ref = fn(x) - opt_fn = torchdynamo.optimize("eager", nopython=True)(fn) - res = opt_fn(x) - self.assertTrue(torch.allclose(ref, res)) - - def test_repro_graph_breaks_in__get_item_by_idx(self): - class Mod(torch.nn.Module): - def __init__(self): - super().__init__() - self.mod = torch.nn.Sequential( - torch.nn.Linear(3, 3), torch.nn.Linear(3, 3) - ) - - def forward(self, x): - return self.mod[0](x) - - m = Mod() - graph, _ = torchdynamo.export(m, torch.randn(3, 3)) - - -class CustomFunc(torch.autograd.Function): - @staticmethod - def forward(ctx, foo): - return foo + foo - - @staticmethod - def backward(ctx, grad_output): - return grad_output - - -class Module1(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, foo): - return CustomFunc().apply(foo) - - -class Module2(torch.nn.Module): - def __init__(self): - super().__init__() - self.fn = CustomFunc.apply - - def 
forward(self, foo): - return self.fn(foo) - - -class TestTracer(JitTestCase): - def test_jit_save(self): - def fn(): - class Foo(torch.nn.Module): - def __init__(self): - super(Foo, self).__init__() - self.a = 3 - - @torch.jit.export - def __getstate__(self): - return (3, self.training) - - @torch.jit.export - def __setstate__(self, state): - self.a = state[0] - self.training = state[1] - - def forward(self, x): - return x + self.a - - f = Foo() - - return torch.jit.trace(f, (torch.rand(3, 4),)) - - fn() - opt_fn = torchdynamo.optimize("eager")(fn) - opt_fn() - - -if __name__ == "__main__": - from torchdynamo.test_case import run_tests - - run_tests() diff --git a/test/dynamo/test_model_output.py b/test/dynamo/test_model_output.py deleted file mode 100644 index 2653c09c72..0000000000 --- a/test/dynamo/test_model_output.py +++ /dev/null @@ -1,166 +0,0 @@ -# Owner(s): ["module: dynamo"] -import dataclasses -import unittest.mock - -import torch - -import torchdynamo.test_case -import torchdynamo.testing -from torchdynamo.testing import same - -try: - from transformers import modeling_outputs - from transformers.configuration_utils import PretrainedConfig - from transformers.file_utils import ModelOutput - from transformers.modeling_outputs import BaseModelOutput -except ImportError: - modeling_outputs = None - - -def maybe_skip(fn): - if modeling_outputs is None: - return unittest.skip("requires HuggingFace")(fn) - return fn - - -class TestHFPretrained(torchdynamo.test_case.TestCase): - @maybe_skip - def test_pretrained(self): - def fn(a, tmp): - if tmp.return_dict: - return a + torch.ones(2) * tmp.max_length - return a - - x = torch.randn(2) - tmp = PretrainedConfig(return_dict=True, max_length=20) - ref = fn(x, tmp) - opt_fn = torchdynamo.optimize("eager", nopython=True)(fn) - res = opt_fn(x, tmp) - self.assertTrue(same(ref, res)) - - -class TestModelOutput(torchdynamo.test_case.TestCase): - @maybe_skip - def test_mo_create(self): - def fn(a, b): - tmp = BaseModelOutput(a + 1, attentions=b + 3) - return tmp - - torchdynamo.testing.standard_test(self, fn=fn, nargs=2, expected_ops=2) - - @maybe_skip - def test_mo_assign(self): - def fn(a, b): - tmp = BaseModelOutput(last_hidden_state=b + 3) - tmp.hidden_states = a + 7 - tmp["attentions"] = a + b + 6 - return tmp - - args = [torch.randn(10), torch.randn(10)] - obj1 = fn(*args) - - cnts = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize_assert(cnts)(fn) - obj2 = opt_fn(*args) - self.assertTrue(same(obj1.last_hidden_state, obj2.last_hidden_state)) - self.assertTrue(same(obj1.hidden_states, obj2.hidden_states)) - self.assertTrue(same(obj1.attentions, obj2.attentions)) - self.assertEqual(cnts.frame_count, 1) - self.assertEqual(cnts.op_count, 4) - - def _common(self, fn, op_count): - args = [ - BaseModelOutput( - last_hidden_state=torch.randn(10), attentions=torch.randn(10) - ) - ] - obj1 = fn(*args) - cnts = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize_assert(cnts)(fn) - obj2 = opt_fn(*args) - self.assertTrue(same(obj1, obj2)) - self.assertEqual(cnts.frame_count, 1) - self.assertEqual(cnts.op_count, op_count) - - @maybe_skip - def test_mo_getattr(self): - def fn(obj: BaseModelOutput): - x = obj.last_hidden_state * 10 - if obj.hidden_states is not None: - x += obj.hidden_states - if obj.attentions is not None: - x += obj.attentions - return x - - self._common(fn, 2) - - @maybe_skip - def test_mo_getitem(self): - def fn(obj: BaseModelOutput): - x = obj["last_hidden_state"] * 10 - if "hidden_stats" in obj: 
- x += obj["hidden_states"] - if "attentions" in obj: - x += obj["attentions"] - return x - - self._common(fn, 2) - - @maybe_skip - def test_mo_tuple(self): - def fn(obj: BaseModelOutput): - a, b = obj.to_tuple() - return a + b * 10 - - self._common(fn, 2) - - @maybe_skip - def test_mo_index(self): - def fn(obj: BaseModelOutput): - return obj[0] * 10 + obj[1] - - self._common(fn, 2) - - @maybe_skip - def test_mo_init(self): - @dataclasses.dataclass - class MyDataClass(ModelOutput): - a: torch.Tensor - b: torch.Tensor = None - c: torch.Tensor = None - d: torch.Tensor = None - e: torch.Tensor = None - - def fn(obj): - class_fields = dataclasses.fields(obj) - assert len(class_fields) - assert all(field.default is None for field in class_fields[1:]) - other_fields_are_none = all( - getattr(obj, field.name) is None for field in class_fields[1:] - ) - assert not other_fields_are_none - - total = getattr(obj, class_fields[0].name) - for field in class_fields[1:]: - v = getattr(obj, field.name) - if v is not None: - total += v - - return total - - tensors = [torch.randn(10), torch.randn(10), torch.randn(10)] - obj1 = MyDataClass(*tensors) - correct1 = fn(obj1) - - obj2 = MyDataClass(*tensors) - cnts = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize(cnts)(fn) - self.assertTrue(same(opt_fn(obj2), correct1)) - self.assertEqual(cnts.frame_count, 1) - self.assertEqual(cnts.op_count, 2) - - -if __name__ == "__main__": - from torchdynamo.test_case import run_tests - - run_tests() diff --git a/test/dynamo/test_modules.py b/test/dynamo/test_modules.py deleted file mode 100644 index 4055337b40..0000000000 --- a/test/dynamo/test_modules.py +++ /dev/null @@ -1,891 +0,0 @@ -# Owner(s): ["module: dynamo"] - -from copy import deepcopy -from unittest.mock import patch - -import torch -from torch.nn import functional as F -from torch.nn.modules.lazy import LazyModuleMixin -from torch.nn.parameter import Parameter -from torch.nn.parameter import UninitializedParameter - -import torchdynamo.test_case -import torchdynamo.testing -from torchdynamo.eval_frame import unsupported -from torchdynamo.mutation_guard import GenerationTracker -from torchdynamo.testing import same - -try: - from . 
import test_functions -except ImportError: - import test_functions - - -class BasicModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear1 = torch.nn.Linear(10, 10) - self.scale = torch.randn(1, 10) - - def forward(self, x): - return F.relu(self.linear1(x)) * self.scale - - -class FnMember(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear1 = torch.nn.Linear(10, 10) - self.activation = F.relu - - def forward(self, x): - x = self.linear1(x) - if self.activation: - x = self.activation(x) - return x - - -class FnMemberCmp(torch.nn.Module): - def __init__(self, activation): - super().__init__() - self.linear1 = torch.nn.Linear(10, 10) - self.activation = activation - - def forward(self, x): - x = self.linear1(x) - if self.activation is not None: - x = self.activation(x) - if self.activation is None: - x = torch.sigmoid(x) - return x - - -class SubmoduleExample(torch.nn.Module): - def __init__(self): - super().__init__() - self.layer1 = BasicModule() - self.layer2 = BasicModule() - self.scale = torch.randn(1, 10) - - def forward(self, x): - x = self.layer1(x) - x = self.layer2(x) - return x * self.scale - - -class IsTrainingCheck(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear1 = torch.nn.Linear(10, 10) - self.linear2 = torch.nn.Linear(10, 10) - self.train(True) - - def forward(self, x): - if self.training: - mod = self.linear1 - else: - mod = self.linear2 - return F.relu(mod(x)) - - -class IsEvalCheck(IsTrainingCheck): - def __init__(self): - super().__init__() - self.train(False) - - -class ModuleMethodCall(torch.nn.Module): - def __init__(self): - super().__init__() - self.layer1 = BasicModule() - self.layer2 = BasicModule() - self.scale = torch.randn(1, 10) - - def call_and_scale(self, mod, x): - x = mod(x) - return x * self.scale - - def forward(self, x): - x1 = self.call_and_scale(self.layer1, x) - x2 = self.call_and_scale(self.layer2, x) - return x1 + x2 - - -class UnsupportedMethodCall(torch.nn.Module): - def __init__(self): - super().__init__() - self.layer1 = BasicModule() - self.scale = torch.randn(1, 10) - - def call_and_scale(self, mod, x): - x = mod(x) - x = x * self.scale - return unsupported(x, x) - - def forward(self, x): - x1 = self.call_and_scale(self.layer1, x) - return x + x1 - - -class UnsupportedModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.layer1 = BasicModule() - self.scale = torch.randn(1, 10) - - def forward(self, x): - x = self.layer1(x) * self.scale - return unsupported(x, x) - - -class UnsupportedModuleCall(torch.nn.Module): - def __init__(self): - super().__init__() - self.mod = UnsupportedModule() - - def forward(self, x): - return 1 + self.mod(x * 1.5) - - -class ModuleStaticMethodCall(torch.nn.Module): - def __init__(self): - super().__init__() - self.layer1 = BasicModule() - self.layer2 = BasicModule() - self.scale = torch.randn(1, 10) - - @staticmethod - def call_and_scale(scale, mod, x): - x = mod(x) - return x * scale - - def forward(self, x): - x1 = self.call_and_scale(self.scale, self.layer1, x) - x2 = self.call_and_scale(self.scale, self.layer2, x) - return x1 + x2 - - -class ModuleClassMethodCall(torch.nn.Module): - def __init__(self): - super().__init__() - self.layer1 = BasicModule() - self.layer2 = BasicModule() - self.scale = torch.randn(1, 10) - - @classmethod - def call_and_scale(cls, scale, mod, x): - x = mod(x) - return x * scale - - def forward(self, x): - x1 = self.call_and_scale(self.scale, self.layer1, x) - x2 = 
self.call_and_scale(self.scale, self.layer2, x) - return x1 + x2 - - -class ModuleProperty(torch.nn.Module): - def __init__(self): - super().__init__() - self.scale = torch.randn(1, 10) - - @property - def scale_alias(self): - return self.scale - - def forward(self, x): - return x * self.scale_alias - - -class ConstLoop(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear1 = torch.nn.Linear(10, 10) - self.count = 3 - - def forward(self, x): - for i in range(self.count): - x = torch.sigmoid(self.linear1(x)) - return x - - -class ViaModuleCall(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear1 = torch.nn.Linear(10, 10) - - def forward(self, x): - return test_functions.constant3(torch.sigmoid(self.linear1(x)), x) - - -class IsNoneLayer(torch.nn.Module): - def __init__(self): - super().__init__() - self.layer1 = torch.nn.Linear(10, 10) - self.layer2 = None - self.train(True) - - def forward(self, x): - if self.layer1 is not None: - x = self.layer1(x) - if self.layer2 is not None: - x = self.layer2(x) - return x - - -class LayerList(torch.nn.Module): - def __init__(self): - super().__init__() - self.layers = [ - torch.nn.Linear(10, 10), - torch.nn.ReLU(), - torch.nn.Linear(10, 10), - ] - - def forward(self, x): - for layer in self.layers: - x = layer(x) - return x - - -class ModuleList(torch.nn.Module): - def __init__(self): - super().__init__() - self.layers = torch.nn.ModuleList( - [ - torch.nn.Linear(10, 10), - torch.nn.ReLU(), - torch.nn.Linear(10, 10), - torch.nn.ReLU(), - ] - ) - - def forward(self, x): - for i in range(len(self.layers)): - x = self.layers[i](x) - - for layer in self.layers: - x = layer(x) - - for layer, val in zip(self.layers, (x, x, x, x)): - x = layer(x) + val - - for layer, val in zip(self.layers, (1, 2, 3, 4)): - x = layer(x) + val - - for idx, layer in enumerate(self.layers): - x = layer(x) * idx - - for idx, layer in enumerate(self.layers[::-1]): - x = layer(x) * idx - - return x - - -class ModuleDict(torch.nn.Module): - def __init__(self): - super().__init__() - self.layers = torch.nn.ModuleDict( - { - "0": torch.nn.Linear(10, 10), - } - ) - - def forward(self, x): - # TODO(future PR): handle more logic - x = self.layers["0"](x) - return x - - -class TensorList(torch.nn.Module): - def __init__(self): - super().__init__() - self.layers = ( - torch.randn((1, 10)), - torch.randn((10, 1)), - torch.randn((1, 10)), - torch.randn((10, 1)), - ) - - def forward(self, x): - for layer in self.layers: - x = x * layer - return x - - -class Children(torch.nn.Module): - def __init__(self): - super().__init__() - self.l1 = torch.nn.Linear(10, 10) - self.l2 = torch.nn.ReLU() - self.l3 = torch.nn.Linear(10, 10) - self.l4 = torch.nn.ReLU() - - def forward(self, x): - for block in self.children(): - x = block(x) - return x - - -class IntArg(torch.nn.Module): - def __init__(self): - super().__init__() - self.layer1 = torch.nn.Linear(10, 10) - - def forward(self, x, offset=1): - x = F.relu(self.layer1(x)) + offset - return x - - -class Seq(torch.nn.Module): - def __init__(self): - super().__init__() - self.layers = torch.nn.Sequential( - torch.nn.Linear(10, 10), - torch.nn.ReLU(), - torch.nn.Linear(10, 10), - torch.nn.ReLU(), - ) - - def forward(self, x): - return self.layers(x) - - -class Cfg: - def __init__(self): - self.val = 0.5 - self.count = 3 - - -class CfgModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.cfg = Cfg() - self.layer = torch.nn.Linear(10, 10) - - def forward(self, x): - for i in 
range(self.cfg.count): - x = self.layer(x + self.cfg.val) - return x - - -class StringMember(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear1 = torch.nn.Linear(10, 10) - self.mode = "some_string" - - def forward(self, x): - if self.mode == "some_string": - return F.relu(self.linear1(x)) - - -class _Block(torch.nn.Module): - def forward(self, x): - return 1.5 * torch.cat(x, 1) - - -class _DenseBlock(torch.nn.ModuleDict): - _version = 2 - - def __init__( - self, - num_layers: int = 3, - ) -> None: - super().__init__() - for i in range(num_layers): - self.add_module("denselayer%d" % (i + 1), _Block()) - - def forward(self, init_features): - features = [init_features] - for name, layer in self.items(): - new_features = layer(features) - features.append(new_features) - return torch.cat(features, 1) - - -class DenseNetBlocks(torch.nn.Module): - def __init__(self): - super().__init__() - self.layers = _DenseBlock() - - def forward(self, x): - return self.layers(x) - - -class MaterializedModule(torch.nn.Module): - """Once the below lazy module is initialized with its first input, - it is transformed into this module.""" - - param: Parameter - - def __init__(self): - super().__init__() - self.register_parameter("param", None) - - def forward(self, x): - return x - - -class LazyModule(LazyModuleMixin, MaterializedModule): - param: UninitializedParameter - cls_to_become = MaterializedModule - - def __init__(self): - super().__init__() - self.param = UninitializedParameter() - - def initialize_parameters(self, x): - self.param.materialize(x.shape) - - -def requires_grad1(module: torch.nn.Module, recurse: bool = False) -> bool: - requires_grad = any([p.requires_grad for p in module.parameters(recurse)]) - return requires_grad - - -def requires_grad2(module: torch.nn.Module, recurse: bool = False) -> bool: - requires_grad = any(p.requires_grad for p in module.parameters(recurse)) - return requires_grad - - -class ParametersModule1(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear1 = torch.nn.Linear(10, 10) - self.scale = torch.nn.Parameter(torch.randn(1, 10)) - - def forward(self, x): - if not requires_grad1(self): - return F.relu(self.linear1(x)) * self.scale - else: - return x + 1 - - -class ParametersModule2(ParametersModule1): - def forward(self, x): - if not requires_grad2(self): - return F.relu(self.linear1(x)) * self.scale - else: - return x + 1 - - -class ParametersModule3(ParametersModule1): - def forward(self, x): - ones = torch.ones(10, dtype=next(self.parameters()).dtype) - return F.relu(self.linear1(x)) * self.scale + ones - - -class SuperModule(BasicModule): - def forward(self, x): - x = super().forward(x) - return x + 10.0 - - -class ComplicatedSuperParent(torch.nn.Module): - @classmethod - def custom_add(cls, x): - x = x + x - return x - - -class SuperChildCallsClassMethod(ComplicatedSuperParent): - @classmethod - def child_func(cls, x): - x = super().custom_add(x) - return x - - def forward(self, x): - x = self.child_func(x) - return x - - -class HasAttrModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.scale = torch.nn.Parameter(torch.randn(1, 10)) - - def forward(self, x): - x = F.relu(x) - if hasattr(self, "scale"): - x *= self.scale - if hasattr(self, "scale2"): - x *= self.scale2 - return x - - -class EnumValues(torch.nn.ModuleDict): - def __init__( - self, - num_layers: int = 3, - ) -> None: - super().__init__() - for i in range(num_layers): - self.add_module("denselayer%d" % (i + 1), _Block()) - - def 
forward(self, init_features): - features = [init_features] - for idx, layer in enumerate(self.values()): - new_features = layer(features) - features.append(new_features) - return torch.cat(features, 1) - - -class CallForwardDirectly(torch.nn.Module): - def __init__(self): - super().__init__() - self.layer1 = BasicModule() - self.layer2 = torch.nn.Linear(10, 10) - - def forward(self, x): - x = self.layer1.forward(x) - x = self.layer2.forward(x) - return x - - -class ModuleNameString(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear1 = torch.nn.Linear(10, 10) - - def forward(self, x): - if self.__class__.__name__ == "ABC": - return 10 - if self.linear1.__class__.__name__ == "Linear": - return F.relu(self.linear1(x) + 10) - return 11 - - -class SelfMutatingModule(torch.nn.Module): - def __init__(self, layer): - super().__init__() - self.layer = layer - self.counter = 0 - - def forward(self, x): - result = self.layer(x) + self.counter - self.counter += 1 - return F.relu(result) - - -class ModuleAttributePrecedenceBase(torch.nn.Module): - def __init__(self): - super().__init__() - - def linear(self, x): - return x * 2.0 - - -class ModuleAttributePrecedence(ModuleAttributePrecedenceBase): - def __init__(self): - super().__init__() - self.activation = torch.nn.ReLU() - self.linear = torch.nn.Linear(10, 10) - self.initializer = torch.ones([10, 10]) - self.scale = 0.5 - - def activation(self, x): - return x * 1.2 - - def initializer(self): - return torch.zeros([10, 10]) - - def scale(self): - return 2.0 - - def forward(self, x): - # object attribute takes precedence unless it's a nn.Module - return self.activation(self.linear(self.initializer + x)) * self.scale - - -def make_test(fn, expected_ops=None): - def test_fn(self): - return torchdynamo.testing.standard_test( - self, fn=fn, nargs=1, expected_ops=expected_ops - ) - - fn.eval() - return test_fn - - -class NNModuleTests(torchdynamo.test_case.TestCase): - test_seq = make_test(Seq()) - test_basicmodule1 = make_test(BasicModule()) - test_basicmodule2 = make_test(BasicModule()) - test_submodules1 = make_test(SubmoduleExample()) - test_submodules2 = make_test(SubmoduleExample()) - test_modulemethod1 = make_test(ModuleMethodCall()) - test_modulemethod2 = make_test(ModuleMethodCall()) - test_module_static_method = make_test(ModuleStaticMethodCall()) - test_fnmember = make_test(FnMember()) - test_fnmembercmp1 = make_test(FnMemberCmp(F.relu)) - test_fnmembercmp2 = make_test(FnMemberCmp(None)) - test_constloop = make_test(ConstLoop()) - test_istraining1 = make_test(IsTrainingCheck()) - test_istraining2 = make_test(IsTrainingCheck()) - test_iseval1 = make_test(IsEvalCheck()) - test_iseval2 = make_test(IsEvalCheck()) - test_viamodulecall = make_test(ViaModuleCall()) - test_isnonelayer = make_test(IsNoneLayer()) - test_layerlist = make_test(LayerList()) - test_tensorlist = make_test(TensorList()) - test_intarg = make_test(IntArg()) - test_cfgmod = make_test(CfgModule()) - test_stringmember = make_test(StringMember()) - test_modulelist = make_test(ModuleList()) - test_moduledict = make_test(ModuleDict()) - test_super1 = make_test(SuperModule()) - test_super_class_method = make_test(SuperChildCallsClassMethod()) - test_children = make_test(Children()) - test_densenet = make_test(DenseNetBlocks()) - test_parameters1 = make_test(ParametersModule1()) - test_parameters2 = make_test(ParametersModule2()) - test_parameters3 = make_test(ParametersModule3(), expected_ops=5) - test_hasattr = make_test(HasAttrModule()) - test_enumvalues = 
make_test(EnumValues()) - test_module_class_method = make_test(ModuleClassMethodCall()) - test_module_property = make_test(ModuleProperty()) - test_forward_directly = make_test(CallForwardDirectly()) - test_module_name_string = make_test(ModuleNameString()) - test_module_attribute_precedence = make_test(ModuleAttributePrecedence()) - - def test_unsupportedmethod(self): - m = UnsupportedMethodCall() - i = torch.randn(10) - cnt = torchdynamo.testing.CompileCounter() - opt_m = torchdynamo.optimize(cnt)(m) - r = opt_m(i) - self.assertTrue(torchdynamo.testing.same(r, m(i))) - self.assertEqual(cnt.op_count, 5) - - def test_unsupportedmodule(self): - m = UnsupportedModuleCall() - i = torch.randn(10) - cnt = torchdynamo.testing.CompileCounter() - opt_m = torchdynamo.optimize(cnt)(m) - r = opt_m(i) - self.assertTrue(torchdynamo.testing.same(r, m(i))) - self.assertEqual(cnt.op_count, 6) - - def test_self_mutating1(self): - m1 = torch.nn.Linear(10, 10) - m2 = SelfMutatingModule(m1) - m3 = SelfMutatingModule(m1) - m4 = SelfMutatingModule(m1) - i = torch.randn(10) - out2 = [m2(i), m2(i), m2(i)] - cnt = torchdynamo.testing.CompileCounter() - opt_m3 = torchdynamo.optimize_assert(cnt)(m3) - opt_m4 = torchdynamo.optimize_assert(cnt)(m4) - out3 = [opt_m3(i), opt_m3(i), opt_m3(i)] - out4 = [opt_m4(i), opt_m4(i), opt_m4(i)] - self.assertTrue(torchdynamo.testing.same(out2, out3)) - self.assertTrue(torchdynamo.testing.same(out2, out4)) - self.assertEqual(cnt.frame_count, 3) - - @patch.object(torchdynamo.config, "raise_on_ctx_manager_usage", False) - def test_generation_tag(self): - cnt = torchdynamo.testing.CompileCounter() - - # guarantee that we have installed - # the generation tagging function - with torchdynamo.optimize_assert(cnt): - pass - - m1 = torch.nn.Linear(10, 10) - prev_generation = GenerationTracker.get_generation_value(m1) - cur_generation = prev_generation + 1 - - with torchdynamo.optimize_assert(cnt): - m2 = torch.nn.Linear(10, 10) - - self.assertEqual(GenerationTracker.get_generation_value(m1), prev_generation) - self.assertEqual(GenerationTracker.get_generation_value(m2), cur_generation) - # check that newly constructed instances - # also have the same generation (even if copied from an old instance) - m3 = deepcopy(m1) - self.assertEqual(GenerationTracker.get_generation_value(m3), cur_generation) - - def test_simple_torch_function(self): - def foo(x): - # function call, twice to test wrapping - x = F.sigmoid(x) - x = F.sigmoid(x) - # method call, twice to test wrapping - x = x.sigmoid() - x = x.sigmoid() - return x - - class TensorProxy(torch.Tensor): - @classmethod - def __torch_function__(cls, func, types, args=(), kwargs=None): - return super().__torch_function__(func, types, args, kwargs) - - torchdynamo.config.traceable_tensor_subclasses.add(TensorProxy) - - x = torch.randn(1).as_subclass(TensorProxy) - cnt = torchdynamo.testing.CompileCounter() - out1 = foo(x) - opt_foo = torchdynamo.optimize(cnt, nopython=True)(foo) - out2 = opt_foo(x) - - self.assertEqual(cnt.op_count, 4) - self.assertTrue(torchdynamo.testing.same(out1, out2)) - - torchdynamo.config.traceable_tensor_subclasses.remove(TensorProxy) - - def test_torch_function_with_closure(self): - def run(): - - counter = 0 - - def foo(x): - # function call, twice to test wrapping - x = F.sigmoid(x) - x = F.sigmoid(x) - # method call, twice to test wrapping - x = x.sigmoid() - x = x.sigmoid() - return x - - class TensorProxy(torch.Tensor): - @classmethod - def __torch_function__(cls, func, types, args=(), kwargs=None): - nonlocal 
counter - # for now, only support reads from closure cells - # TODO(future PR): support writes as well - counter + 1 - return super().__torch_function__(func, types, args, kwargs) - - torchdynamo.config.traceable_tensor_subclasses.add(TensorProxy) - - x = torch.randn(1).as_subclass(TensorProxy) - x = torch.randn(1) - cnt = torchdynamo.testing.CompileCounter() - out1 = foo(x) - opt_foo = torchdynamo.optimize(cnt, nopython=True)(foo) - out2 = opt_foo(x) - - self.assertEqual(cnt.op_count, 4) - self.assertTrue(torchdynamo.testing.same(out1, out2)) - - torchdynamo.config.traceable_tensor_subclasses.remove(TensorProxy) - - run() - - @patch.object(torchdynamo.config, "raise_on_ctx_manager_usage", False) - def test_nn_moduledict_contains(self): - class M(torch.nn.Module): - def __init__(self, module_dict): - super().__init__() - self.module_dict = module_dict - - def forward(self, x): - if "foo" in self.module_dict: - x = torch.mul(x, 1.0) - x = torch.add(x, 1.0) - return x - - module_dict = torch.nn.ModuleDict({"foo": torch.nn.Conv2d(1, 1, 1)}) - m = M(module_dict) - data = torch.randn(1) - out1 = m(data) - cnt = torchdynamo.testing.CompileCounter() - opt_m = torchdynamo.optimize(cnt, nopython=True)(m) - out2 = opt_m(data) - self.assertEqual(cnt.op_count, 2) - self.assertTrue(torchdynamo.testing.same(out1, out2)) - - module_dict = torch.nn.ModuleDict({"bar": torch.nn.Conv2d(1, 1, 1)}) - m = M(module_dict) - data = torch.randn(1) - out1 = m(data) - cnt = torchdynamo.testing.CompileCounter() - torchdynamo.reset() - opt_m = torchdynamo.optimize(cnt, nopython=True)(m) - out2 = opt_m(data) - - self.assertEqual(cnt.op_count, 1) - self.assertTrue(torchdynamo.testing.same(out1, out2)) - - module_dict = torch.nn.ModuleDict({"cat": torch.nn.Conv2d(1, 1, 1)}) - pre = m(data) - cnt.clear() - - with torchdynamo.optimize(cnt, nopython=False): - opt_pre = m(data) - m = M(module_dict) - data = torch.randn(1) - out1 = m(data) - - out_post = m(data) - self.assertEqual(cnt.frame_count, 1) - self.assertEqual(cnt.op_count, 1) - self.assertTrue(torchdynamo.testing.same(pre, opt_pre)) - self.assertTrue(torchdynamo.testing.same(out1, out_post)) - - def test_lazy_module(self): - input_shape = (16, 3, 6, 7, 8) - - cnt = torchdynamo.testing.CompileCounter() - module = LazyModule() - - def test_static_module(): - input = torch.ones(*input_shape) - module(input) - - opt_test_static_module = torchdynamo.optimize(cnt)(test_static_module) - opt_test_static_module() - - self.assertTrue( - isinstance(module, MaterializedModule), - "Module should be transformed to an instance of MaterializedModule.", - ) - self.assertEqual(module.param.shape, input_shape) - - # test when mapped to UnspecializedNNModule - module = LazyModule() - - def test_unspecialized(): - nonlocal module - module = LazyModule() - input = torch.ones(*input_shape) - module(input) - - opt_test_unspecialized = torchdynamo.optimize(cnt)(test_unspecialized) - opt_test_unspecialized() - - self.assertTrue( - isinstance(module, MaterializedModule), - "Module should be transformed to an instance of MaterializedModule.", - ) - self.assertEqual(module.param.shape, input_shape) - - # test with a static module in torch.* - module = torch.nn.modules.LazyBatchNorm3d( - affine=False, track_running_stats=False - ) - - cnt = torchdynamo.testing.CompileCounter() - - torchdynamo.reset() - - def test_torch_static(): - input = torch.ones(*input_shape) - return module(input) # fully materialized - - opt_test_torch_static = torchdynamo.optimize(cnt)(test_torch_static) - 
opt_test_torch_static() - out = opt_test_torch_static() - - self.assertTrue(same(out, module(torch.ones(*input_shape)))) - - self.assertTrue( - isinstance(module, torch.nn.modules.batchnorm.BatchNorm3d), - "Module should be transformed to an instance of BatchNorm3d.", - ) - self.assertEqual(cnt.frame_count, 1, "No guards should have triggered.") - - -if __name__ == "__main__": - from torchdynamo.test_case import run_tests - - run_tests() diff --git a/test/dynamo/test_no_fake_tensors.py b/test/dynamo/test_no_fake_tensors.py deleted file mode 100644 index 6cc3573464..0000000000 --- a/test/dynamo/test_no_fake_tensors.py +++ /dev/null @@ -1,33 +0,0 @@ -# Owner(s): ["module: dynamo"] -from torchdynamo.testing import make_test_cls_with_patches - -try: - from . import test_functions - from . import test_misc - from . import test_modules - from . import test_repros - from . import test_unspec -except ImportError: - import test_functions - import test_misc - import test_modules - import test_repros - import test_unspec - - -def make_no_fake_cls(cls): - return make_test_cls_with_patches( - cls, "NoFakeTensors", "_no_fake_tensors", ("fake_tensor_propagation", False) - ) - - -NoFakeTensorsFunctionTests = make_no_fake_cls(test_functions.FunctionTests) -NoFakeTensorsMiscTests = make_no_fake_cls(test_misc.MiscTests) -NoFakeTensorsReproTests = make_no_fake_cls(test_repros.ReproTests) -NoFakeTensorsNNModuleTests = make_no_fake_cls(test_modules.NNModuleTests) -NoFakeTensorsUnspecTests = make_no_fake_cls(test_unspec.UnspecTests) - -if __name__ == "__main__": - from torchdynamo.test_case import run_tests - - run_tests() diff --git a/test/dynamo/test_nops.py b/test/dynamo/test_nops.py deleted file mode 100644 index 797b1cfd5b..0000000000 --- a/test/dynamo/test_nops.py +++ /dev/null @@ -1,72 +0,0 @@ -# Owner(s): ["module: dynamo"] -import torch - -import torchdynamo.test_case -import torchdynamo.testing -from torchdynamo import eval_frame - -c = 10 - - -def fn1(a, b): - return a + b - c - - -def fn2(a, b): - x = 0 - y = 1 - - def modify(): - nonlocal x - x += a + b + c - - for _ in range(2): - modify() - - return x + y - - -def fn3(): - yield 1 - yield 2 - - -with_debug_nops = eval_frame._optimize_catch_errors( - torchdynamo.testing.debug_insert_nops -) - - -class NopTests(torchdynamo.test_case.TestCase): - @with_debug_nops - def test1(self): - self.assertEqual(fn1(1, 2), -7) - self.assertEqual(fn1(1, 2), -7) - - @with_debug_nops - def test2(self): - self.assertEqual(fn2(1, 2), 27) - self.assertEqual(fn2(1, 2), 27) - - @with_debug_nops - def test3(self): - t = fn3() - self.assertEqual(next(t), 1) - self.assertEqual(next(t), 2) - self.assertRaises(StopIteration, lambda: next(t)) - - def test_extended_args(self): - too_many_adds = "+".join(["a", "b"] * 256) - source = ( - f"lambda a, b: ({too_many_adds}+a if a.sum() > 0 else {too_many_adds} - b)" - ) - fn = eval(source) - a = torch.ones(1) - b = torch.ones(1) - fn = with_debug_nops(fn) - self.assertEqual(fn(a, b).sum(), 513) - - -if __name__ == "__main__": - from torchdynamo.test_case import run_tests - - run_tests() diff --git a/test/dynamo/test_optimizations.py b/test/dynamo/test_optimizations.py deleted file mode 100644 index da0978a12f..0000000000 --- a/test/dynamo/test_optimizations.py +++ /dev/null @@ -1,209 +0,0 @@ -# Owner(s): ["module: dynamo"] -import importlib -import json -import os -import unittest -from unittest.mock import patch - -import torch - -import torchdynamo -import torchdynamo.test_case -from torchdynamo.optimizations import backends -from 
torchdynamo.optimizations.analysis import has_mutation -from torchdynamo.optimizations.log_args import conv_args_analysis -from torchdynamo.optimizations.normalize import Inplacifier -from torchdynamo.optimizations.normalize import normalize -from torchdynamo.testing import same - - -def has_onnxruntime(): - try: - importlib.import_module("onnxruntime") - return True - except ImportError: - return False - - -def has_ipex(): - try: - importlib.import_module("intel_extension_for_pytorch") - return True - except ImportError: - return False - - -def has_functorch(): - try: - importlib.import_module("functorch") - return True - except ImportError: - return False - - -class Seq(torch.nn.Module): - def __init__(self): - super().__init__() - self.layers = torch.nn.Sequential( - torch.nn.Linear(10, 10), - torch.nn.ReLU(), - torch.nn.Linear(10, 10), - torch.nn.Sigmoid(), - ) - - def forward(self, x): - return self.layers(x) - - -class Conv_Bn_Relu(torch.nn.Module): - def __init__(self, in_channels, out_channels, **kwargs): - super(Conv_Bn_Relu, self).__init__() - self.conv = torch.nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) - self.bn = torch.nn.BatchNorm2d(out_channels, eps=0.001) - self.relu = torch.nn.ReLU() - - def forward(self, x): - return self.relu(self.bn(self.conv(x))) - - -class TestOptimizations(torchdynamo.test_case.TestCase): - def test_inplacifier(self): - gm = torch.fx.symbolic_trace(Seq()) - normalize(gm) - Inplacifier(gm).inplacify() - gm.recompile() - code = gm.code.replace(" ", "") - self.assertIn("inplace=True", code) - self.assertIn("out=linear_1", code) - - def test_has_mutation(self): - gm = torch.fx.symbolic_trace(Seq()) - self.assertFalse(has_mutation(gm, torch.rand([10, 10]))) - - class Mutating(torch.nn.Module): - def __init__(self): - super(Mutating, self).__init__() - - def forward(self, arg): - return arg.add_(1) - - gm = torch.fx.symbolic_trace(Mutating()) - self.assertTrue(has_mutation(gm, torch.rand([10, 1, 1, 1]))) - - def test_has_mutation_factory(self): - def fn(): - x = torch.empty(2) - x.fill_(2) - return x - - def compiler_fn(graph, example_inputs): - self.assertTrue(has_mutation(graph, example_inputs)) - return graph - - opt_fn = torchdynamo.optimize(compiler_fn)(fn) - opt_fn() - - def test_example_inputs(self): - def fn(a, bc, d): - b, c = bc - return a / d - b / c - - def compiler_fn(graph, example_inputs): - nonlocal r1 - r1 = graph(*example_inputs)[0] - return graph.forward - - a = torch.empty(2).fill_(1) - b = torch.empty(2).fill_(2) - c = torch.empty(2).fill_(3) - d = 4 - r1 = None - r2 = fn(a, (b, c), d) - opt_fn = torchdynamo.optimize_assert(compiler_fn)(fn) - r3 = opt_fn(a, (b, c), d) - - self.assertIsNotNone(r1) - self.assertTrue(same(r1, r2)) - self.assertTrue(same(r1, r3)) - - @patch.object(torchdynamo.config, "fake_tensor_propagation", False) - @unittest.skipIf(not has_functorch(), "requires functorch") - def test_log_conv_args(self): - model = Conv_Bn_Relu(3, 32, kernel_size=3, stride=1) - model = model.to(memory_format=torch.channels_last) - model = model.eval() - input = torch.randn(8, 3, 64, 64).contiguous(memory_format=torch.channels_last) - r1 = model(input) - # check tmp/conv_args.json exists and has keys as arg names - filename = "tmp/conv_args.json" - if os.path.exists(filename): - os.remove(filename) - opt_model = torchdynamo.optimize(conv_args_analysis)(model) - with torch.no_grad(): - r2 = opt_model(input) - self.assertTrue(same(r1, r2.float(), tol=0.1)) - self.assertTrue(os.path.exists(filename)) - with open(filename) as 
f: - args_dict = json.load(f) - self.assertIn("convolution", args_dict.keys()) - conv_args_dict = args_dict["convolution"] - self.assertIn("input", conv_args_dict.keys()) - self.assertIn("weight", conv_args_dict.keys()) - self.assertIn("bias", conv_args_dict.keys()) - self.assertIn("stride", conv_args_dict.keys()) - self.assertIn("padding", conv_args_dict.keys()) - self.assertIn("dilation", conv_args_dict.keys()) - self.assertIn("transposed", conv_args_dict.keys()) - self.assertIn("output_padding", conv_args_dict.keys()) - self.assertIn("groups", conv_args_dict.keys()) - os.remove(filename) - - @unittest.skipIf(not has_ipex(), "requires ipex") - def test_ipex_fp32(self): - model = Conv_Bn_Relu(3, 32, kernel_size=3, stride=1) - model = model.to(memory_format=torch.channels_last) - model = model.eval() - input = torch.randn(8, 3, 64, 64).contiguous(memory_format=torch.channels_last) - r1 = model(input) - opt_model = torchdynamo.optimize(backends.ipex_fp32)(model) - with torch.no_grad(): - r2 = opt_model(input) - self.assertTrue(same(r1, r2)) - self.assertEqual(r2.dtype, torch.float32) - - @unittest.skipIf(not has_ipex(), "requires ipex") - def test_ipex_bf16(self): - model = Conv_Bn_Relu(3, 32, kernel_size=3, stride=1) - model = model.to(memory_format=torch.channels_last) - model = model.eval() - input = torch.randn(8, 3, 64, 64).contiguous(memory_format=torch.channels_last) - r1 = model(input) - opt_model = torchdynamo.optimize(backends.ipex_bf16)(model) - with torch.no_grad(), torch.cpu.amp.autocast(): - r2 = opt_model(input) - self.assertTrue(same(r1, r2.float(), tol=0.1)) - self.assertEqual(r2.dtype, torch.bfloat16) - - -class NormalizeIRTests(torchdynamo.test_case.TestCase): - @unittest.skipIf(not has_functorch(), "requires functorch") - def test_inplace_normalize(self): - def fn(a, b): - x = torch.cos(a) - x += b - return torch.sin(x) - - a = torch.randn(10) - b = torch.randn(10).to(torch.float64) - - ref = fn(a, b) - - optimized_fn = torchdynamo.optimize("aot_eager")(fn) - res = optimized_fn(a, b) - self.assertTrue(same(ref, res)) - - -if __name__ == "__main__": - from torchdynamo.test_case import run_tests - - run_tests() diff --git a/test/dynamo/test_optimizers.py b/test/dynamo/test_optimizers.py deleted file mode 100644 index 19a9655d74..0000000000 --- a/test/dynamo/test_optimizers.py +++ /dev/null @@ -1,103 +0,0 @@ -# Owner(s): ["module: dynamo"] - -import inspect -import unittest - -import torch - -import torchdynamo -import torchdynamo.test_case -import torchdynamo.testing - -input = torch.ones([10, 10]) -model = torch.nn.Sequential(*[torch.nn.Linear(10, 10) for _ in range(2)]) -model(input).sum().backward() - - -def make_test(optim_cls, exp_frame_cnt=1, closure=None, **kwargs): - opt = optim_cls(model.parameters(), **kwargs) - - def test_fn(self): - nonlocal opt - - counter = torchdynamo.testing.CompileCounter() - - if closure is not None: - - def fn(): - opt.step(closure) - - else: - fn = opt.step - - opt_fn = torchdynamo.optimize(counter)(fn) - opt_fn() - - self.assertEqual(counter.frame_count, exp_frame_cnt) - - return test_fn - - -class OptimizerTests(torchdynamo.test_case.TestCase): - @classmethod - def setUpClass(cls): - super().setUpClass() - # needed until pytorch assertion is changed to enable Adam - # to be called with capturable=True - cls._exit_stack.enter_context( - unittest.mock.patch.object( - torchdynamo.config, "capture_scalar_outputs", True - ) - ) - cls._exit_stack.enter_context( - unittest.mock.patch.object( - torchdynamo.config, "fake_tensor_propagation", 
False - ) - ) - cls._exit_stack.enter_context( - unittest.mock.patch.object( - torchdynamo.config, "raise_on_assertion_error", True - ) - ) - - test_sgd = make_test(torch.optim.SGD, lr=0.01) - # lgbfs has data-dependent control and internally iterates - # calling the closure - # TODO mlazos: re-enable once we have latest pytorch with FakeTensor fix #497 - # test_lbfgs = make_test( - # torch.optim.LBFGS, exp_frame_cnt=3, closure=lambda: model(input).sum() - # ) - # RAdam has data-dependent control which breaks the graph - test_radam = make_test(torch.optim.RAdam, exp_frame_cnt=1) - - # ASGD has a small optimization that avoids averaging - # This will fully capture the graph once that optimization is removed - # NB: in python versions < 3.8, we don't capture graphs when breaks - # occur in a loop - - # Fails without fake tensor: - # TypeError: clamp() received an invalid combination of arguments - got (float, min=int) - # test_asgd = make_test( - # torch.optim.ASGD, exp_frame_cnt=(0 if sys.version_info < (3, 8) else 6) - # ) - - -# exclude SparseAdam because other areas of the stack don't support it yet -# the others are handled specially above -exclude = set(["SGD", "Optimizer", "SparseAdam", "LBFGS", "RAdam", "ASGD"]) -optimizers = [ - opt - for opt in torch.optim.__dict__.values() - if inspect.isclass(opt) - and issubclass(opt, torch.optim.Optimizer) - and opt.__name__ not in exclude -] - - -for opt in optimizers: - setattr(OptimizerTests, "test_" + opt.__name__.lower(), make_test(opt)) - -if __name__ == "__main__": - from torchdynamo.test_case import run_tests - - run_tests() diff --git a/test/dynamo/test_python_autograd.py b/test/dynamo/test_python_autograd.py deleted file mode 100644 index a99a77a714..0000000000 --- a/test/dynamo/test_python_autograd.py +++ /dev/null @@ -1,293 +0,0 @@ -# Owner(s): ["module: dynamo"] -from typing import Callable -from typing import Dict -from typing import List -from typing import NamedTuple -from typing import Optional - -import torch - -import torchdynamo -from torchdynamo.test_case import TestCase -from torchdynamo.test_case import run_tests -from torchdynamo.testing import CompileCounter -from torchdynamo.testing import same - -""" -This is an example of a pure-python version of autograd implemented by -@zdevito. It represents a rather challenging test case for TorchDynamo -to push the limits of what it can do. -""" - - -_name: int = 0 - - -def fresh_name() -> str: - """create a new unique name for a variable: v0, v1, v2""" - global _name - r = f"v{_name}" - _name += 1 - return r - - -class Variable: - def __init__(self, value: torch.Tensor, name: str = None): - self.value = value - self.name = name or fresh_name() - - # We need to start with some tensors whose values were not computed - # inside the autograd. This function constructs leaf nodes. 
- @staticmethod - def constant(value: torch.Tensor, name: str = None): - return Variable(value, name) - - def __repr__(self): - return repr(self.value) - - # This performs a pointwise multiplication of a Variable, tracking gradients - def __mul__(self, rhs: "Variable") -> "Variable": - # defined later in the notebook - return operator_mul(self, rhs) - - def __add__(self, rhs: "Variable") -> "Variable": - return operator_add(self, rhs) - - def sum(self, name: Optional[str] = None) -> "Variable": - return operator_sum(self, name) - - def expand(self, sizes: List[int]) -> "Variable": - return operator_expand(self, sizes) - - -class TapeEntry(NamedTuple): - # names of the inputs to the original computation - inputs: List[str] - # names of the outputs of the original computation - outputs: List[str] - # apply chain rule - propagate: "Callable[List[Variable], List[Variable]]" - - -gradient_tape: List[TapeEntry] = [] - - -def reset_tape(): - gradient_tape.clear() - global _name - _name = 0 - - -def grad(L, desired_results: List[Variable]) -> List[Variable]: - # this map holds dL/dX for all values X - dL_d: Dict[str, Variable] = {} - # It starts by initializing the 'seed' dL/dL, which is 1 - dL_d[L.name] = Variable(torch.ones(())) - # print(f'd{L.name} ------------------------') - - # look up dL_dentries. If a variable is never used to compute the loss, - # we consider its gradient None, see the note below about zeros for more information. - def gather_grad(entries: List[str]): - return [dL_d[entry] if entry in dL_d else None for entry in entries] - - # propagate the gradient information backward - for entry in reversed(gradient_tape): - dL_doutputs = gather_grad(entry.outputs) - if all(dL_doutput is None for dL_doutput in dL_doutputs): - # optimize for the case where some gradient pathways are zero. See - # The note below for more details. - continue - - # perform chain rule propagation specific to each compute - dL_dinputs = entry.propagate(dL_doutputs) - - # Accululate the gradient produced for each input. - # Each use of a variable produces some gradient dL_dinput for that - # use. The multivariate chain rule tells us it is safe to sum - # all the contributions together. 
- for input, dL_dinput in zip(entry.inputs, dL_dinputs): - if input not in dL_d: - dL_d[input] = dL_dinput - else: - dL_d[input].value += dL_dinput.value - - # print some information to understand the values of each intermediate - # for name, value in dL_d.items(): - # print(f'd{L.name}_d{name} = {value.name}') - # print(f'------------------------') - - return gather_grad(desired.name for desired in desired_results) - - -def operator_mul(self: Variable, rhs: Variable) -> Variable: - if isinstance(rhs, float) and rhs == 1.0: - # peephole optimization - return self - - # define forward - r = Variable(self.value * rhs.value) - # print(f'{r.name} = {self.name} * {rhs.name}') - - # record what the inputs and outputs of the op were - inputs = [self.name, rhs.name] - outputs = [r.name] - - # define backprop - def propagate(dL_doutputs: List[Variable]): - (dL_dr,) = dL_doutputs - - dr_dself = rhs # partial derivative of r = self*rhs - dr_drhs = self # partial derivative of r = self*rhs - - # chain rule propagation from outputs to inputs of multiply - dL_dself = dL_dr * dr_dself - dL_drhs = dL_dr * dr_drhs - dL_dinputs = [dL_dself, dL_drhs] - return dL_dinputs - - # finally, we record the compute we did on the tape - gradient_tape.append(TapeEntry(inputs=inputs, outputs=outputs, propagate=propagate)) - return r - - -def operator_add(self: Variable, rhs: Variable) -> Variable: - # Add follows a similar pattern to Mul, but it doesn't end up - # capturing any variables. - r = Variable(self.value + rhs.value) - # print(f'{r.name} = {self.name} + {rhs.name}') - - def propagate(dL_doutputs: List[Variable]): - (dL_dr,) = dL_doutputs - dr_dself = 1.0 - dr_drhs = 1.0 - dL_dself = dL_dr * dr_dself - dL_drhs = dL_dr * dr_drhs - return [dL_dself, dL_drhs] - - gradient_tape.append( - TapeEntry(inputs=[self.name, rhs.name], outputs=[r.name], propagate=propagate) - ) - return r - - -def operator_sum(self: Variable, name: Optional[str]) -> "Variable": - r = Variable(torch.sum(self.value), name=name) - # print(f'{r.name} = {self.name}.sum()') - - def propagate(dL_doutputs: List[Variable]): - (dL_dr,) = dL_doutputs - size = self.value.size() - return [dL_dr.expand(*size)] - - gradient_tape.append( - TapeEntry(inputs=[self.name], outputs=[r.name], propagate=propagate) - ) - return r - - -def operator_expand(self: Variable, sizes: List[int]) -> "Variable": - assert self.value.dim() == 0 # only works for scalars - r = Variable(self.value.expand(sizes)) - # print(f'{r.name} = {self.name}.expand({sizes})') - - def propagate(dL_doutputs: List[Variable]): - (dL_dr,) = dL_doutputs - return [dL_dr.sum()] - - gradient_tape.append( - TapeEntry(inputs=[self.name], outputs=[r.name], propagate=propagate) - ) - return r - - -def simple(a, b): - t = a + b - return t * b - - -class TestPythonAutograd(TestCase): - def _common(self, fn, expected_ops): - args1 = [torch.randn(10), torch.randn(10)] - args2 = [torch.randn(10), torch.randn(10)] - cnt = CompileCounter() - fn_dynamo = torchdynamo.optimize_assert(cnt)(fn) - reset_tape() - res1 = fn_dynamo(*args1) - reset_tape() - res2 = fn_dynamo(*args2) - reset_tape() - self.assertTrue(same(res1, fn(*args1))) - reset_tape() - self.assertTrue(same(res2, fn(*args2))) - reset_tape() - self.assertEqual(cnt.frame_count, 1) - self.assertEqual(cnt.op_count, expected_ops) - - def test_forwards1(self): - def fn(a, b): - a = Variable.constant(a, name="a") - b = Variable.constant(b, name="b") - loss = simple(a, b).sum() - return loss - - self._common(fn, 3) - - def test_forwards2(self): - def fn(a, b): 
- reset_tape() - a = Variable.constant(a, name="a") - b = Variable.constant(b, name="b") - loss = simple(a, b).sum() - reset_tape() - return loss - - self._common(fn, 3) - - def test_backwards1(self): - def fn(a, b): - a = Variable.constant(a, name="a") - b = Variable.constant(b, name="b") - loss = simple(a, b).sum() - return grad(loss, [a, b]) - - self._common(fn, 8) - - def test_backwards2(self): - def fn(a, b): - reset_tape() - a = Variable.constant(a, name="a") - b = Variable.constant(b, name="b") - loss = simple(a, b).sum() - res = grad(loss, [a, b]) - reset_tape() - return res - - self._common(fn, 8) - - def test_split(self): - v1 = Variable.constant(torch.randn(10), name="a") - v2 = Variable.constant(torch.randn(10), name="b") - cnt = CompileCounter() - - def forward(a, b): - return simple(a, b).sum() - - reset_tape() - loss1 = forward(v1, v2) - grad1 = grad(loss1, [v1, v2]) - - reset_tape() - opt_forward = torchdynamo.optimize_assert(cnt)(forward) - opt_grad = torchdynamo.optimize_assert(cnt)(grad) - loss2 = opt_forward(v1, v2) - # force two frames - grad2 = opt_grad(loss2, [v1, v2]) - - self.assertTrue(same(loss1, loss2)) - self.assertTrue(same(grad1, grad2)) - self.assertEqual(cnt.frame_count, 2) - self.assertEqual(cnt.op_count, 8) - - -if __name__ == "__main__": - run_tests() diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py deleted file mode 100644 index 0ab457abb9..0000000000 --- a/test/dynamo/test_repros.py +++ /dev/null @@ -1,1718 +0,0 @@ -# Owner(s): ["module: dynamo"] -import collections -import copy -import inspect -import itertools -import random -import unittest -from abc import ABC -from collections import namedtuple -from copy import deepcopy -from typing import List -from unittest.mock import patch - -import numpy as np -import torch -from torch import nn -from torch.nn import functional as F - -import torchdynamo.test_case -import torchdynamo.testing -import torchdynamo.utils -from torchdynamo.debug_utils import same_two_models -from torchdynamo.testing import rand_strided -from torchdynamo.testing import requires_static_shapes -from torchdynamo.testing import same - -try: - import torch._refs - - HAS_REFS = True -except ImportError: - HAS_REFS = False - - -def ifdyn(count1, count2): - if torchdynamo.config.dynamic_shapes: - return count1 - else: - return count2 - - -def has_detectron2(): - try: - from detectron2.layers.mask_ops import _paste_masks_tensor_shape - - return _paste_masks_tensor_shape is not None - except ImportError: - return False - - -def _do_paste_mask(masks, boxes, img_h: int, img_w: int, skip_empty: bool = True): - # from detectron2 mask_ops.py - - device = masks.device - - if skip_empty and not torch.jit.is_scripting(): - x0_int, y0_int = torch.clamp(boxes.min(dim=0).values.floor()[:2] - 1, min=0).to( - dtype=torch.int32 - ) - x1_int = torch.clamp(boxes[:, 2].max().ceil() + 1, max=img_w).to( - dtype=torch.int32 - ) - y1_int = torch.clamp(boxes[:, 3].max().ceil() + 1, max=img_h).to( - dtype=torch.int32 - ) - else: - x0_int, y0_int = 0, 0 - x1_int, y1_int = img_w, img_h - x0, y0, x1, y1 = torch.split(boxes, 1, dim=1) # each is Nx1 - - N = masks.shape[0] - - img_y = torch.arange(y0_int, y1_int, device=device, dtype=torch.float32) + 0.5 - img_x = torch.arange(x0_int, x1_int, device=device, dtype=torch.float32) + 0.5 - img_y = (img_y - y0) / (y1 - y0) * 2 - 1 - img_x = (img_x - x0) / (x1 - x0) * 2 - 1 - # img_x, img_y have shapes (N, w), (N, h) - - gx = img_x[:, None, :].expand(N, img_y.size(1), img_x.size(1)) - gy = img_y[:, :, 
None].expand(N, img_y.size(1), img_x.size(1)) - grid = torch.stack([gx, gy], dim=3) - - if not torch.jit.is_scripting(): - if not masks.dtype.is_floating_point: - masks = masks.float() - img_masks = F.grid_sample(masks, grid.to(masks.dtype), align_corners=False) - - if skip_empty and not torch.jit.is_scripting(): - return img_masks[:, 0], (slice(y0_int, y1_int), slice(x0_int, x1_int)) - else: - return img_masks[:, 0], () - - -def cat(tensors, dim=0): - # from detectron2 wrappers.py - assert isinstance(tensors, (list, tuple)) - if len(tensors) == 1: - return tensors[0] - return torch.cat(tensors, dim) - - -def shapes_to_tensor(x, device=None): - # from detectron2 wrappers.py - if torch.jit.is_scripting(): - return torch.as_tensor(x, device=device) - if torch.jit.is_tracing(): - assert all( - [isinstance(t, torch.Tensor) for t in x] - ), "Shape should be tensor during tracing!" - # as_tensor should not be used in tracing because it records a constant - ret = torch.stack(x) - if ret.device != device: # avoid recording a hard-coded device if not necessary - ret = ret.to(device=device) - return ret - return torch.as_tensor(x, device=device) - - -class Boxes: - # from detectron2 poolers.py - def __init__(self, tensor: torch.Tensor): - """ - Args: - tensor (Tensor[float]): a Nx4 matrix. Each row is (x1, y1, x2, y2). - """ - device = ( - tensor.device if isinstance(tensor, torch.Tensor) else torch.device("cpu") - ) - tensor = torch.as_tensor(tensor, dtype=torch.float32, device=device) - if tensor.numel() == 0: - # Use reshape, so we don't end up creating a new tensor that does not depend on - # the inputs (and consequently confuses jit) - tensor = tensor.reshape((-1, 4)).to(dtype=torch.float32, device=device) - assert tensor.dim() == 2 and tensor.size(-1) == 4, tensor.size() - self.tensor = tensor - - def __len__(self) -> int: - return self.tensor.shape[0] - - @property - def device(self): - return self.tensor.device - - -def convert_boxes_to_pooler_format(box_lists): - # from detectron2 structures.py - boxes = torch.cat([x.tensor for x in box_lists], dim=0) - # __len__ returns Tensor in tracing. 
- sizes = shapes_to_tensor([x.__len__() for x in box_lists], device=boxes.device) - indices = torch.repeat_interleave( - torch.arange(len(box_lists), dtype=boxes.dtype, device=boxes.device), sizes - ) - return cat([indices[:, None], boxes], dim=1) - - -ReformerBackwardOutput = namedtuple( - "ReformerBackwardOutput", - ["attn_output", "hidden_states", "grad_attn_output", "grad_hidden_states"], -) -ReformerEncoderOutput = namedtuple( - "ReformerEncoderOutput", - ["hidden_states", "all_hidden_states", "all_attentions", "past_buckets_states"], -) - - -class _ReversibleFunction(torch.autograd.Function): - # taken from modeling_reformer.py in huggingface - @staticmethod - def forward( - ctx, - hidden_states, - layers, - attention_mask, - head_mask, - num_hashes, - all_hidden_states, - all_attentions, - past_buckets_states, - use_cache, - orig_sequence_length, - output_hidden_states, - output_attentions, - ): - all_buckets = () - - # split duplicated tensor - hidden_states, attn_output = torch.chunk(hidden_states, 2, dim=-1) - - for layer_id, (layer, layer_head_mask) in enumerate(zip(layers, head_mask)): - if output_hidden_states is True: - all_hidden_states.append(hidden_states) - - attn_output = layer(attn_output) - - # Add last layer - if output_hidden_states is True: - all_hidden_states.append(hidden_states) - - # attach params to ctx for backward - ctx.save_for_backward(attn_output.detach(), hidden_states.detach()) - ctx.layers = layers - ctx.all_buckets = all_buckets - ctx.head_mask = head_mask - ctx.attention_mask = attention_mask - - # Concatenate 2 RevNet outputs - return torch.cat([attn_output, hidden_states], dim=-1) - - @staticmethod - def backward(ctx, grad_hidden_states): - grad_attn_output, grad_hidden_states = torch.chunk( - grad_hidden_states, 2, dim=-1 - ) - - # retrieve params from ctx for backward - attn_output, hidden_states = ctx.saved_tensors - - # create tuple - output = ReformerBackwardOutput( - attn_output=attn_output, - hidden_states=hidden_states, - grad_attn_output=grad_attn_output, - grad_hidden_states=grad_hidden_states, - ) - - # free memory - del grad_attn_output, grad_hidden_states, attn_output, hidden_states - - layers = ctx.layers - all_buckets = ctx.all_buckets - head_mask = ctx.head_mask - attention_mask = ctx.attention_mask - - for idx, layer in enumerate(layers[::-1]): - # pop last buckets from stack - buckets = all_buckets[-1] - all_buckets = all_buckets[:-1] - - # backprop - output = layer.backward_pass( - next_attn_output=output.attn_output, - hidden_states=output.hidden_states, - grad_attn_output=output.grad_attn_output, - grad_hidden_states=output.grad_hidden_states, - head_mask=head_mask[len(layers) - idx - 1], - attention_mask=attention_mask, - buckets=buckets, - ) - - assert all_buckets == (), "buckets have to be empty after backpropagation" - grad_hidden_states = torch.cat( - [output.grad_attn_output, output.grad_hidden_states], dim=-1 - ) - - # num of return vars has to match num of forward() args - # return gradient for hidden_states arg and None for other args - return ( - grad_hidden_states, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - ) - - -class ReformerEncoder(torch.nn.Module): - def __init__(self): - super().__init__() - self.dropout = 0.5 - self.layer_norm = torch.nn.LayerNorm(512, eps=1.0e-12) - self.layers = [torch.nn.Linear(256, 256)] - - def forward( - self, - hidden_states, - attention_mask=None, - head_mask=[None] * 6, - num_hashes=None, - use_cache=False, - orig_sequence_length=64, - 
output_hidden_states=False, - output_attentions=False, - ): - # hidden_states and attention lists to be filled if wished - all_hidden_states = [] - all_attentions = [] - past_buckets_states = [((None), (None)) for i in range(len(self.layers))] - - # concat same tensor for reversible ResNet - hidden_states = torch.cat([hidden_states, hidden_states], dim=-1) - hidden_states = _ReversibleFunction.apply( - hidden_states, - self.layers, - attention_mask, - head_mask, - num_hashes, - all_hidden_states, - all_attentions, - past_buckets_states, - use_cache, - orig_sequence_length, - output_hidden_states, - output_attentions, - ) - - # Apply layer norm to concatenated hidden states - hidden_states = self.layer_norm(hidden_states) - - # Apply dropout - hidden_states = torch.nn.functional.dropout( - hidden_states, p=self.dropout, training=self.training - ) - - return ReformerEncoderOutput( - hidden_states=hidden_states, - all_hidden_states=all_hidden_states, - all_attentions=all_attentions, - past_buckets_states=past_buckets_states, - ) - - -def longformer_chunk(hidden_states, window_overlap=256): - """convert into overlapping chunks. Chunk size = 2w, overlap size = w""" - - # non-overlapping chunks of size = 2w - hidden_states = hidden_states.view( - hidden_states.size(0), - hidden_states.size(1) // (window_overlap * 2), - window_overlap * 2, - hidden_states.size(2), - ) - - # use `as_strided` to make the chunks overlap with an overlap size = window_overlap - chunk_size = list(hidden_states.size()) - chunk_size[1] = chunk_size[1] * 2 - 1 - - chunk_stride = list(hidden_states.stride()) - chunk_stride[1] = chunk_stride[1] // 2 - return hidden_states.as_strided(size=chunk_size, stride=chunk_stride) - - -class PartialT5(torch.nn.Module): - # Highly simplified T5Attention prefix - def __init__(self): - super(PartialT5, self).__init__() - self.q = torch.nn.Linear(512, 512) - self.k = torch.nn.Linear(512, 512) - self.v = torch.nn.Linear(512, 512) - - def forward( - self, - hidden_states, - key_value_states=None, - past_key_value=None, - query_length=None, - ): - batch_size, seq_length = hidden_states.shape[:2] - - real_seq_length = seq_length - - if past_key_value is not None: - assert ( - len(past_key_value) == 2 - ), f"past_key_value should have 2 past states: keys and values. 
Got { len(past_key_value)} past states" - real_seq_length += ( - past_key_value[0].shape[2] if query_length is None else query_length - ) - - def shape(states): - """projection""" - return states.view(batch_size, -1, 8, 64).transpose(1, 2) - - def project(hidden_states, proj_layer, key_value_states, past_key_value): - """projects hidden states correctly to key/query states""" - if key_value_states is None: - # self-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(hidden_states)) - elif past_key_value is None: - # cross-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(key_value_states)) - - if past_key_value is not None: - if key_value_states is None: - # self-attn - # (batch_size, n_heads, key_length, dim_per_head) - hidden_states = torch.cat([past_key_value, hidden_states], dim=2) - else: - # cross-attn - hidden_states = past_key_value - return hidden_states - - # get query states - query_states = shape( - self.q(hidden_states) - ) # (batch_size, n_heads, seq_length, dim_per_head) - - # get key/value states - key_states = project( - hidden_states, - self.k, - key_value_states, - past_key_value[0] if past_key_value is not None else None, - ) - value_states = project( - hidden_states, - self.v, - key_value_states, - past_key_value[1] if past_key_value is not None else None, - ) - - # compute scores - scores = torch.matmul(query_states, key_states.transpose(3, 2)) - - # (truncated here ) - return scores, value_states - - -class ChunkReformerFeedForward(torch.nn.Module): - # simplified from HF modeling_reformer.py - def __init__(self): - super().__init__() - self.layer_norm = torch.nn.LayerNorm(256, eps=1e-12) - self.dense = torch.nn.Linear(256, 256) - self.output = torch.nn.Linear(256, 256) - - def forward(self, attention_output): - return apply_chunking_to_forward( - self.forward_chunk, - attention_output + 1, - ) - - def forward_chunk(self, hidden_states): - hidden_states = self.layer_norm(hidden_states) - hidden_states = self.dense(hidden_states) - return self.output(hidden_states) - - -def apply_chunking_to_forward(forward_fn, *input_tensors): - # simplified from HF model_utils.py - assert len(input_tensors) > 0 - tensor_shape = input_tensors[0].shape[1] - assert all(input_tensor.shape[1] == tensor_shape for input_tensor in input_tensors) - num_args_in_forward_chunk_fn = len(inspect.signature(forward_fn).parameters) - if num_args_in_forward_chunk_fn != len(input_tensors): - raise ValueError() - - return forward_fn(*input_tensors) - - -class FakeMamlInner(torch.nn.Module): - def __init__(self): - super(FakeMamlInner, self).__init__() - self.linear = torch.nn.Linear(784, 5) - - def forward(self, x, ignored=None, bn_training=False): - return self.linear(x.view(x.shape[0], -1)) - - -class PartialMaml(torch.nn.Module): - # Highly simplified version of maml.meta.Meta.finetuning - def __init__(self): - super(PartialMaml, self).__init__() - self.net = FakeMamlInner() - self.update_step_test = 10 - self.update_lr = 0.4 - - def forward(self, x_spt, y_spt, x_qry, y_qry): - querysz = x_qry.size(0) - - corrects = [0 for _ in range(self.update_step_test + 1)] - - # in order to not ruin the state of running_mean/variance and bn_weight/bias - # we finetunning on the copied model instead of self.net - net = deepcopy(self.net) - - # 1. 
run the i-th task and compute loss for k=0 - logits = net(x_spt) - loss = F.cross_entropy(logits, y_spt) - grad = torch.autograd.grad(loss, net.parameters()) - fast_weights = list( - map(lambda p: p[1] - self.update_lr * p[0], zip(grad, net.parameters())) - ) - - # this is the loss and accuracy before first update - with torch.no_grad(): - # [setsz, nway] - logits_q = net(x_qry, net.parameters(), bn_training=True) - # [setsz] - pred_q = F.softmax(logits_q, dim=1).argmax(dim=1) - # scalar - correct = torch.eq(pred_q, y_qry).sum().item() - corrects[0] = corrects[0] + correct - - # this is the loss and accuracy after the first update - with torch.no_grad(): - # [setsz, nway] - logits_q = net(x_qry, fast_weights, bn_training=True) - # [setsz] - pred_q = F.softmax(logits_q, dim=1).argmax(dim=1) - # scalar - correct = torch.eq(pred_q, y_qry).sum().item() - corrects[1] = corrects[1] + correct - - del net - - accs = torch.tensor(corrects) / querysz - - return accs - - -class ModelOutput(collections.OrderedDict): - """based on file_utils.py in HuggingFace""" - - def __getitem__(self, k): - if isinstance(k, str): - inner_dict = {k: v for (k, v) in self.items()} - return inner_dict[k] - else: - return self.to_tuple()[k] - - def __setattr__(self, name, value): - if name in self.keys() and value is not None: - # Don't call self.__setitem__ to avoid recursion errors - super().__setitem__(name, value) - super().__setattr__(name, value) - - def __setitem__(self, key, value): - # Will raise a KeyException if needed - super().__setitem__(key, value) - # Don't call self.__setattr__ to avoid recursion errors - super().__setattr__(key, value) - - def to_tuple(self): - return tuple(self[k] for k in self.keys()) - - -def create_rand_mask_from_inputs( - from_blocked_mask, - to_blocked_mask, - rand_attn, - num_attention_heads, - num_rand_blocks, - batch_size, - from_seq_length, - from_block_size, -): - """taken from HF modeling_big_bird.py""" - num_windows = from_seq_length // from_block_size - 2 - rand_mask = torch.stack( - [p1[i1.flatten()] for p1, i1 in zip(to_blocked_mask, rand_attn)] - ) - rand_mask = rand_mask.view( - batch_size, num_attention_heads, num_windows, num_rand_blocks * from_block_size - ) - rand_mask = torch.einsum("blq,bhlk->bhlqk", from_blocked_mask[:, 1:-1], rand_mask) - return rand_mask - - -class SequentialAppendList(torch.nn.Sequential): - """from timm/models/vovnet.py""" - - def __init__(self, *args): - super(SequentialAppendList, self).__init__(*args) - - def forward(self, x: torch.Tensor, concat_list: List[torch.Tensor]) -> torch.Tensor: - for i, module in enumerate(self): - if i == 0: - concat_list.append(module(x)) - else: - concat_list.append(module(concat_list[-1])) - x = torch.cat(concat_list, dim=1) - return x, concat_list - - -class BatchNormAct2d(torch.nn.BatchNorm2d): - """Taken from timm""" - - def __init__( - self, - num_features, - eps=1e-5, - momentum=0.1, - affine=True, - track_running_stats=True, - act_layer=torch.nn.ReLU, - inplace=True, - ): - super(BatchNormAct2d, self).__init__( - num_features, - eps=eps, - momentum=momentum, - affine=affine, - track_running_stats=track_running_stats, - ) - self.act = act_layer(inplace=inplace) - - @torch.jit.ignore - def _forward_python(self, x): - return super().forward(x) - - def forward(self, x): - if torch.jit.is_scripting(): - x = self._forward_jit(x) - else: - x = self._forward_python(x) - x = self.act(x) - return x - - -def get_parameter_dtype(parameter): - """from huggingface model_utils.py""" - try: - return 
next(parameter.parameters()).dtype - except StopIteration: - # For nn.DataParallel compatibility in PyTorch 1.5 - - def find_tensor_attributes(module): - tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] - return tuples - - gen = parameter._named_members(get_members_fn=find_tensor_attributes) - first_tuple = next(gen) - return first_tuple[1].dtype - - -class DummyConfig: - attn_layers = ["local", "lsh", "local", "lsh", "local", "lsh"] - lsh_attn_chunk_length = 64 - local_attn_chunk_length = 64 - - -def _get_min_chunk_len(config): - """from hf_Reformer""" - attn_types = config.attn_layers - attn_types_set = set(attn_types) - if len(attn_types_set) == 1 and attn_types[0] == "lsh": - return config.lsh_attn_chunk_length - elif len(attn_types_set) == 1 and attn_types[0] == "local": - return config.local_attn_chunk_length - elif len(attn_types_set) == 2 and attn_types_set == set(["lsh", "local"]): - return min(config.lsh_attn_chunk_length, config.local_attn_chunk_length) - else: - raise NotImplementedError( - f"Only attn layer types 'lsh' and 'local' exist, but `config.attn_layers`: {config.attn_layers}. Select " - "attn layer types from ['lsh', 'local'] only." - ) - - -def _stable_argsort(vector, dim): - """from hf_Reformer""" - # this function scales the vector so that torch.argsort is stable. - # torch.argsort is not stable on its own - scale_offset = torch.arange(vector.shape[dim], device=vector.device).view(1, 1, -1) - scale_offset = scale_offset.expand(vector.shape) - scaled_vector = vector.shape[dim] * vector + (scale_offset % vector.shape[dim]) - return torch.argsort(scaled_vector, dim=dim) - - -def _get_sorted_bucket_idx_and_undo_sorted_bucket_idx(buckets): - """from hf_Reformer""" - # no gradients are needed - with torch.no_grad(): - # hash-based sort - sorted_bucket_idx = _stable_argsort(buckets, dim=-1) - - # create simple indices to scatter to, to have undo sort - indices = ( - torch.arange(sorted_bucket_idx.shape[-1], device=buckets.device) - .view(1, 1, -1) - .expand(sorted_bucket_idx.shape) - ) - - # get undo sort - undo_sorted_bucket_idx = sorted_bucket_idx.new(*sorted_bucket_idx.size()) - undo_sorted_bucket_idx.scatter_(-1, sorted_bucket_idx, indices) - - return sorted_bucket_idx, undo_sorted_bucket_idx - - -class FeedForwardLayer(nn.Module): - def __init__(self, d_model, dim_feedforward, activation, dropout) -> None: - super(FeedForwardLayer, self).__init__() - self.linear1 = nn.Linear(d_model, dim_feedforward) - self.activation = activation - self.dropout1 = nn.Dropout(dropout) - self.linear2 = nn.Linear(dim_feedforward, d_model) - self.dropout2 = nn.Dropout(dropout) - - def forward(self, x): - return self.dropout2( - self.linear2(self.dropout1(self.activation(self.linear1(x)))) - ) - - -class TransformerEncoderLayer(nn.Module): - def __init__( - self, - d_model, - nhead, - dim_feedforward=2048, - dropout=0.1, - activation=nn.ReLU(), - layer_norm_eps=1e-5, - ): - super(TransformerEncoderLayer, self).__init__() - self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) - self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps) - self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps) - self.dropout = nn.Dropout(dropout) - self.ff_block = FeedForwardLayer(d_model, dim_feedforward, activation, dropout) - - def forward(self, src, src_mask=None, src_key_padding_mask=None): - x = src - x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask)) - x = self.norm2(x + self._ff_block(x)) - return x - - # self-attention block - def 
_sa_block(self, x, attn_mask, key_padding_mask): - x = self.self_attn( - x, - x, - x, - attn_mask=attn_mask, - key_padding_mask=key_padding_mask, - need_weights=False, - )[0] - return self.dropout(x) - - # feed forward block - def _ff_block(self, x): - return self.ff_block(x) - - -class TestModule(torch.nn.Module): - def inner_fn(self, left, right): - return tuple(left) == tuple(right) - - def fn(self, tensor): - if type(tensor) is int: - return False - - torch.add(tensor, tensor) - return self.inner_fn(tensor.shape, (1, 2, 3)) - - -class ReproTests(torchdynamo.test_case.TestCase): - def test_do_paste_mask(self): - torchdynamo.utils.counters.clear() - opt__do_paste_mask = torchdynamo.optimize(torchdynamo.testing.CompileCounter())( - _do_paste_mask - ) - opt__do_paste_mask( - torch.randn(1, 1, 28, 28), - torch.tensor([[0.0, 1, 2, 4]]) * 1, - 427, - 640, - True, - ) - opt__do_paste_mask( - torch.randn(1, 1, 28, 28), - torch.tensor([[0.0, 1, 2, 4]]) * 2, - 427, - 640, - True, - ) - opt__do_paste_mask( - torch.randn(1, 1, 28, 28), - torch.tensor([[0.0, 1, 2, 4]]) * 3, - 612, - 612, - True, - ) - opt__do_paste_mask( - torch.randn(1, 1, 28, 28), - torch.tensor([[0.0, 1, 2, 4]]) * 4, - 612, - 612, - True, - ) - opt__do_paste_mask( - torch.randn(1, 1, 28, 28), - torch.tensor([[0.0, 1, 2, 4]]) * 2, - 427, - 640, - False, - ) - - self.assertGreaterEqual(torchdynamo.utils.counters["frames"]["ok"], 3) - # Graph break because of dynamic slicing - self.assertEqual( - torchdynamo.utils.counters["frames"]["total"], - torchdynamo.utils.counters["frames"]["ok"] + 1, - ) - - @patch.object(torchdynamo.config, "fake_tensor_propagation", True) - def test_convert_boxes_to_pooler_format(self): - boxes1 = [ - Boxes(torch.arange(0, 8).reshape((2, 4))), - Boxes(torch.arange(8, 16).reshape((2, 4))), - ] - boxes2 = [ - Boxes(torch.arange(16, 20).reshape((1, 4))), - Boxes(torch.arange(20, 24).reshape((1, 4))), - ] - correct1 = convert_boxes_to_pooler_format(boxes1) - correct2 = convert_boxes_to_pooler_format(boxes2) - fn = convert_boxes_to_pooler_format - cnt = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize(cnt)(fn) - self.assertTrue(same(opt_fn(boxes1), correct1)) - self.assertTrue(same(opt_fn(boxes2), correct2)) - - # repeat_interleave is a dynamic shape operator we do not execute/ - # In the future, we could reduce the frame_count down to 1 - # by guarding on the exact values of `Tensor repeats` arg - self.assertEqual(cnt.frame_count, ifdyn(2, 4)) - self.assertEqual(cnt.op_count, ifdyn(9, 10)) - - def test_boxes_len(self): - def fn(boxes): - return len(boxes) + boxes.__len__() + boxes.tensor - - boxes1 = Boxes(torch.arange(0, 8).reshape((2, 4))) - cnt = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize_assert(cnt)(fn) - self.assertTrue(same(opt_fn(boxes1), boxes1.tensor + 4.0)) - - self.assertEqual(cnt.frame_count, 1) - self.assertEqual(cnt.op_count, ifdyn(6, 1)) - - def _reformer(self, nopython): - input = torch.randn([1, 64, 256]) - model = ReformerEncoder() - torch.manual_seed(1337) - correct = copy.deepcopy(model)(input) - cnt = torchdynamo.testing.CompileCounter() - torch.manual_seed(1337) - opt_model = torchdynamo.optimize(cnt, nopython=nopython)(model) - self.assertTrue(same(opt_model(input), correct)) - return cnt - - def test_reformer_eval(self): - with torch.no_grad(): - cnt = self._reformer(nopython=True) - self.assertEqual(cnt.frame_count, 1) - self.assertEqual(cnt.op_count, 10) - - def test_reformer_train(self): - with torch.enable_grad(): - cnt = 
self._reformer(nopython=False) - # cant inline torch.autograd.Function means graph break - self.assertEqual(cnt.frame_count, 4) - self.assertEqual(cnt.op_count, 10) - - def test_longformer_chunk(self): - input1 = torch.randn([1, 4096, 1]) - input2 = torch.randn([12, 4096, 64]) - correct1 = longformer_chunk(input1) - correct2 = longformer_chunk(input2) - fn = longformer_chunk - cnt = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize_assert(cnt)(fn) - self.assertTrue(same(opt_fn(input1), correct1)) - self.assertTrue(same(opt_fn(input2), correct2)) - self.assertTrue(same(opt_fn(input1), correct1)) - self.assertTrue(same(opt_fn(input2), correct2)) - - self.assertEqual(cnt.frame_count, ifdyn(1, 2)) - self.assertEqual(cnt.op_count, ifdyn(19, 4)) - - def test_hf_t5_forward(self): - input = torch.randn([1, 2048, 512]) - model = PartialT5() - correct = model(input) - cnt = torchdynamo.testing.CompileCounter() - opt_model = torchdynamo.optimize_assert(cnt)(model) - self.assertTrue(same(opt_model(input), correct)) - - self.assertEqual(cnt.frame_count, 1) - self.assertEqual(cnt.op_count, ifdyn(13, 11)) - - def test_slicing_dynamic_shape(self): - def fn(y): - x = torch.ones(8) - idx = y[0] - out = x[idx:] - return (out + 3) * 5 - - counter = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize(counter)(fn) - out = opt_fn(torch.ones(10, dtype=torch.long)) - # idx should be 1 -> slicing off [1:] of 8 elem tensor - self.assertEqual(list(out.shape), [7]) - - expected_ops = ifdyn(5, 4) - expected_frame = ifdyn(1, 2) - - self.assertEqual(expected_ops, expected_ops) - self.assertEqual(expected_frame, expected_frame) - - self.assertEqual(list(opt_fn(torch.tensor([4])).shape), [4]) - - def test_slicing_dynamic_shape_setitem(self): - def fn(input_lengths: torch.Tensor, new_ones_1): - getitem_13 = input_lengths[3] - new_ones_1[(3, slice(getitem_13, None, None))] = 0 - setitem_13 = new_ones_1 - return (setitem_13,) - - x = torch.randn(10).to(dtype=torch.int64) - y = torch.randn(10, 204) - ref = fn(x, y) - opt_fn = torchdynamo.optimize("aot_eager")(fn) - res = opt_fn(x, y) - self.assertTrue(same(ref, res)) - - @requires_static_shapes - def test_chunk_reformer_ff(self): - input = torch.randn([1, 4096, 256]) - model = ChunkReformerFeedForward() - correct = model(input) - cnt = torchdynamo.testing.CompileCounter() - opt_model = torchdynamo.optimize_assert(cnt)(model) - self.assertTrue(same(opt_model(input), correct)) - - self.assertEqual(cnt.frame_count, 1) - self.assertEqual(cnt.op_count, 4) - - # see: https://github.com/pytorch/pytorch/issues/80067 - @patch.object(torchdynamo.config, "fake_tensor_propagation", False) - @patch.object(torchdynamo.config, "capture_scalar_outputs", True) - def test_maml_item_capture(self): - a = torch.randn(5, 1, 28, 28) - b = torch.zeros(5, dtype=torch.int64) - c = torch.randn(75, 1, 28, 28) - d = torch.zeros(75, dtype=torch.int64) - model = PartialMaml() - correct = model(a, b, c, d) - cnt = torchdynamo.testing.CompileCounter() - opt_model = torchdynamo.optimize(cnt)(model) - for _ in range(10): - self.assertTrue(same(opt_model(a, b, c, d), correct)) - - self.assertEqual(cnt.frame_count, ifdyn(3, 2)) - # TODO(jansel): figure out why op count depends on imports - self.assertIn(cnt.op_count, (36, 35, 29, 28)) - - # see: https://github.com/pytorch/pytorch/issues/80067 - @patch.object(torchdynamo.config, "fake_tensor_propagation", False) - @patch.object(torchdynamo.config, "capture_scalar_outputs", False) - def test_maml_no_item_capture(self): - a = 
torch.randn(5, 1, 28, 28) - b = torch.zeros(5, dtype=torch.int64) - c = torch.randn(75, 1, 28, 28) - d = torch.zeros(75, dtype=torch.int64) - model = PartialMaml() - correct = model(a, b, c, d) - cnt = torchdynamo.testing.CompileCounter() - opt_model = torchdynamo.optimize(cnt)(model) - for _ in range(10): - self.assertTrue(same(opt_model(a, b, c, d), correct)) - - self.assertEqual(cnt.frame_count, ifdyn(5, 4)) - # TODO(jansel): figure out why op count depends on imports - self.assertIn(cnt.op_count, (31, 36, 35, 29, 28)) - - def test_hf_model_output(self): - ex = ModelOutput(a=torch.randn(10), b=torch.randn(10), c=torch.randn(10)) - - def fn1(x): - return x["a"] + 1 - - def fn2(x): - return x.a + 1 - - def fn3(x): - return x.to_tuple()[0] + 1 - - def fn4(x): - return x[0] + 1 - - cnt = torchdynamo.testing.CompileCounter() - for fn in (fn1, fn2, fn3, fn4): - cnt.clear() - opt_fn = torchdynamo.optimize_assert(cnt)(fn) - self.assertTrue(same(opt_fn(ex), ex.a + 1)) - self.assertEqual(cnt.frame_count, 1) - self.assertEqual(cnt.op_count, 1) - - @requires_static_shapes - def test_create_rand_mask_from_inputs(self): - args = [ - torch.randn([1, 64, 64]), - torch.randn([1, 64, 64]), - torch.zeros([1, 12, 62, 3], dtype=torch.int64), - 12, - 3, - 1, - 4096, - 64, - ] - correct = create_rand_mask_from_inputs(*args) - fn = create_rand_mask_from_inputs - - cnt = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize_assert(cnt)(fn) - self.assertTrue(same(opt_fn(*args), correct)) - self.assertEqual(cnt.frame_count, 1) - self.assertEqual(cnt.op_count, 8) - - def test_rng_state(self): - def fn(): - state = torch.get_rng_state() - before = torch.rand(1000) - torch.set_rng_state(state) - after = torch.rand(1000) - return before, after - - cnt = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize(cnt)(fn) - - before, after = opt_fn() - self.assertTrue(same(before, after)) - self.assertEqual(cnt.frame_count, 1) - self.assertEqual(cnt.op_count, 4) # rand, rand - graph, _ = torchdynamo.export(fn) - - def test_seq_append_list(self): - x = torch.randn(4, 10) - model = SequentialAppendList( - torch.nn.Linear(10, 10), - torch.nn.ReLU(), - torch.nn.Linear(10, 10), - torch.nn.ReLU(), - ) - # this one is tricky because it mutates the list provided as an input - l1 = [x] - l2 = [x] - correct, _ = model(x, l1) - cnt = torchdynamo.testing.CompileCounter() - opt_model = torchdynamo.optimize_assert(cnt)(model) - result, l3 = opt_model(x, l2) - self.assertTrue(same(result, correct)) - self.assertTrue(same(l1, l2)) - self.assertIs(l2, l3) - self.assertEqual(cnt.frame_count, 1) - self.assertEqual(cnt.op_count, 5) - - def test_batch_norm_act(self): - a = torch.randn(5, 1, 28, 28) - model = BatchNormAct2d(1).eval() - correct = model(a) - cnt = torchdynamo.testing.CompileCounter() - if not torchdynamo.config.specialize_int_float: - # _local_scalar_dense causes graph break w 0-dim tensor - opt_model = torchdynamo.optimize(cnt)(model) - self.assertTrue(same(opt_model(a), correct)) - return - - opt_model = torchdynamo.optimize_assert(cnt)(model) - self.assertTrue(same(opt_model(a), correct)) - self.assertEqual(cnt.frame_count, 1) - self.assertEqual(cnt.op_count, 2) - - def test_get_parameter_dtype(self): - model = SequentialAppendList( - torch.nn.Linear(10, 10), - torch.nn.ReLU(), - ) - - def fn(model, x): - return x + torch.randn(10, dtype=get_parameter_dtype(model)) - - cnt = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize_assert(cnt)(fn) - self.assertEqual(opt_fn(model, 
torch.randn(10)).dtype, torch.float32) - self.assertEqual(cnt.frame_count, 1) - self.assertEqual(cnt.op_count, 2) - - @patch.object(torchdynamo.config, "fake_tensor_propagation", True) - def test_nn_parameter(self): - def test_fn(): - a = torch.nn.Parameter(torch.randn(5, 5)) - # Checks that TensorVariable stores the type information correctly - self.assertTrue(isinstance(a, torch.nn.Parameter)) - return a - - cnt = torchdynamo.testing.CompileCounter() - opt_test_fn = torchdynamo.optimize(cnt)(test_fn) - out = opt_test_fn() - self.assertTrue(isinstance(out, torch.nn.Parameter)) - - def test_Size(self): - def test_fn(): - a = torch.randn(4) - x = torch.Size([1, 2, 3]) - # Checks that SizeVariable return torch.Size object - assert isinstance(x, torch.Size) - # Causes graph breaks and checks reconstruction of SizeVariable - # object - self.assertIsInstance(x, torch.Size) - return a - - cnt = torchdynamo.testing.CompileCounter() - opt_test_fn = torchdynamo.optimize(cnt)(test_fn) - opt_test_fn() - - def test_indexing_with_list(self): - def test_fn(): - def run_test(tensor, *idx): - npt = tensor.numpy() - assert npt[idx].shape == tensor[idx].shape - - x = torch.arange(0, 10) - cases = [ - [None, None], - [1, None], - ] - - for case in cases: - run_test(x, *case) - - return torch.randn(4) - - cnt = torchdynamo.testing.CompileCounter() - opt_test_fn = torchdynamo.optimize(cnt)(test_fn) - opt_test_fn() - - def test_reformer_min_chunk_len(self): - def fn(cfg): - t = torch.empty(10) - t.fill_(_get_min_chunk_len(cfg)) - return t[0] - - cfg = DummyConfig() - cnt = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize_assert(cnt)(fn) - self.assertEqual(opt_fn(cfg), 64) - self.assertEqual(cnt.frame_count, 1) - self.assertEqual(cnt.op_count, 3) - - def test_reformer_sorting(self): - x = torch.zeros([1, 12, 4096], dtype=torch.int64) - correct = _get_sorted_bucket_idx_and_undo_sorted_bucket_idx(x) - fn = _get_sorted_bucket_idx_and_undo_sorted_bucket_idx - - cnt = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize_assert(cnt)(fn) - self.assertTrue(same(opt_fn(x), correct)) - self.assertEqual(cnt.frame_count, 1) - self.assertEqual(cnt.op_count, ifdyn(28, 14)) - - def test_recursive_map(self): - # https://github.com/pytorch/torchdynamo/issues/132 - def _recursive_map(struct, batch_dim=0): - for k, v in struct.items(): - if v is not None: - if isinstance(v, dict): - _recursive_map(v) - else: - struct[k] = v - - def toy_example(a, b, v): - x = a / (torch.abs(a) + 1) - if v is not None: - _recursive_map(v) - return x * b - - cnt = torchdynamo.testing.CompileCounter() - opt_toy_example = torchdynamo.optimize(cnt)(toy_example) - opt_toy_example( - torch.randn(10), - torch.randn(10), - {"layer0": {"memory_keys": torch.randn(10)}}, - ) - self.assertEqual(cnt.frame_count, 1) - self.assertEqual(cnt.op_count, 4) - - def test_issue175(self): - n_heads = 2 - d_model = 64 - model = TransformerEncoderLayer(d_model, n_heads) - inp = torch.randn(1, d_model) - cnt = torchdynamo.testing.CompileCounter() - opt_model = torchdynamo.optimize(cnt, nopython=True)(model) - opt_model(inp) - opt_model(inp) - self.assertEqual(cnt.frame_count, 1) - self.assertEqual(cnt.op_count, 12) - - def test_exec_import(self): - def fn1(): - exec("import math") - - def fn2(): - try: - math.sqrt(4) - return False - except NameError: - return True - - def fn3(): - fn1() - return fn2() - - self.assertTrue(fn3()) - opt_fn3 = torchdynamo.optimize("eager")(fn3) - self.assertTrue(opt_fn3()) - - def 
test_exec_wildcard_import(self): - # Test that globals are not carried over from frame to frame - def fn1(): - exec("from torch import *") - - def fn2(): - x = torch.zeros(4) - for i in range(5): - x = x + i - return x - - def fn3(): - fn1() - return fn2() - - ref = fn3() - opt_fn3 = torchdynamo.optimize("eager")(fn3) - res = opt_fn3() - self.assertTrue(same(ref, res)) - - def test_with_on_graph_break_inst(self): - def reversible(x): - print("Hello world") # Cause graph break so inline fails - return torch.sin(torch.cos(x)) - - def fn(x): - with torch.enable_grad(): - a = torch.sin(x) - b = reversible(a) - c = torch.sigmoid(b) - c.sum().backward() - return x.grad - - x = torch.randn(3, requires_grad=True) - x.grad = None - with torch.no_grad(): - ref = fn(x) - - x.grad = None - opt_fn = torchdynamo.optimize("eager")(fn) - with torch.no_grad(): - res = opt_fn(x) - self.assertTrue(same(ref, res)) - - def test_abc_setattr(self): - # tests that we correctly bail out of __setattr__ calls - - # TODO: does not ensure ABC classes are correctly inferred as ClassVariables - # (doesn't test the fix for 'super()') - - class BaseModule(torch.nn.Module, ABC): - def blah(self, x): - return x + 1 - - class Derived(BaseModule): - def __setattr__(self, name, value) -> None: - super().__setattr__(name, value) - - def forward(self, x): - # expect a graph break on __setattr__ - self.foo = 0 - return self.blah(x) - - def blah(self, x): - return super().blah(x) - - x = torch.randn(3, requires_grad=True) - mod = Derived() - opt_mod = torchdynamo.optimize("eager")(mod) - opt_mod(x) - - self.assertGreaterEqual(torchdynamo.utils.counters["frames"]["ok"], 3) - self.assertGreaterEqual(torchdynamo.utils.counters["frames"]["total"], 3) - - def test_guard_fail_tensor_bool(self): - @torchdynamo.skip - def fn(): - condition_shape = (5, 5) - dtypes = (torch.bool,) - shapes = ( - (), - (5,), - (1, 5), - ) - - tensors = list( - [ - torch.empty(shape, dtype=dtype).fill_(17) - for shape, dtype in itertools.product(shapes, dtypes) - ] - ) - - x_vals = (5.0, *tensors) - y_vals = (6.0, *tensors) - - @torchdynamo.disable - def get_expected(condition, x, y): - x_np = x.cpu().numpy() if isinstance(x, torch.Tensor) else x - y_np = y.cpu().numpy() if isinstance(y, torch.Tensor) else y - return torch.from_numpy( - np.where(condition.cpu().numpy(), x_np, y_np) - ).to(common_dtype) - - for x, y in zip(x_vals, y_vals): - condition = torch.empty(*condition_shape, dtype=torch.bool).bernoulli_() - common_dtype = torch.result_type(x, y) - - def check_equal(condition, x, y): - # NumPy aggressively promotes to double, hence cast to output to correct dtype - expected = get_expected(condition, x, y) - result = torch.where(condition, x, y) - assert torch.allclose(expected, result) - - check_equal(condition, x, y) - check_equal(condition, y, x) - - fn() - opt_fn = torchdynamo.optimize("eager")(fn) - opt_fn() - - def test_guard_fail_nested_tuple(self): - def fn(args): - return torch.ones(()), args[0] * 2 - - # This adds a tensor check on args[1][0] and args[1][1] - args1 = (torch.ones(1), (torch.ones(1), torch.ones(1))) - args2 = (torch.ones(1), torch.ones(1)) - opt_fn = torchdynamo.optimize("eager")(fn) - ref = opt_fn(args1) - res = opt_fn(args2) - - self.assertTrue(same(ref, res)) - - def test_numpy_list(self): - @torchdynamo.disable - def rand_gen(): - return list(np.array([random.randint(5, 10) for _ in range(10)])) - - def fn(x): - random_list = rand_gen() - z = torch.LongTensor(random_list) - return x * z - - x = torch.ones(10) * 2 - - 
random.seed(0) - ref0 = fn(x) - ref1 = fn(x) - - random.seed(0) - opt_fn = torchdynamo.optimize("eager")(fn) - res0 = opt_fn(x) - res1 = opt_fn(x) - - self.assertTrue(same(ref0, res0)) - self.assertTrue(same(ref1, res1)) - - @unittest.skipIf(not HAS_REFS, "requires recent PT version") - @unittest.expectedFailure - def test_primtorch(self): - @torchdynamo.optimize("eager", nopython=True) - def fn(x): - torch._refs.abs(x) - - fn(torch.randn(3)) - - @unittest.skipIf( - not isinstance(torch.ops.aten.abs, torch._ops.OpOverloadPacket), - "old pt doesn't work", - ) - def test_torch_ops_aten(self): - # Picked an op that doesn't show up in the default list - @torchdynamo.optimize("eager", nopython=True) - def fn(x): - return torch.ops.aten.absolute(x) - - fn(torch.randn(3)) - - def test_guard_ordering_shape_fail(self): - # If a function which takes a tensor has an inner function which - # is compiled and generates a guard on its shape, - # they are evaluated in the wrong order. So if on a subsequent call - # an int is passed instead of a tensor, guard evaluation will crash - # with a "no attribute: shape" error - m = TestModule() - opt_m = torchdynamo.optimize("eager")(m) - opt_m.fn(torch.ones((5, 5))) - opt_m.fn(-3) - - def test_tensor_isinstance_tuple(self): - @torchdynamo.optimize("eager") - def fn(): - t = torch.ones(5, 5) - if not isinstance(t, (int, torch.Tensor)): - msg = str.format( - "{0} is not an instance of {1}", - type(t), - (int, torch.Tensor), - ) - raise ValueError(msg) - return True - - fn() - - def test_isinstance_dtype(self): - @torchdynamo.optimize("eager", nopython=True) - def fn(x): - isinstance(torch.bfloat16, torch.dtype) - return x - - fn(torch.randn(3)) - - def test_isinstance_storage(self): - @torchdynamo.optimize("eager") - def fn(x): - f = bytearray([0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x10, 0x40]) - bools = torch.BoolStorage.from_buffer(f, "big") - self.assertTrue(isinstance(bools, torch.BoolStorage)) - return x - - fn(torch.randn(3)) - - def test_dict_list_values(self): - def inner_fn(args): - return [x[1].shape for x in args] - - @torchdynamo.optimize("eager") - def fn(tensors): - return inner_fn(zip(itertools.count(), tensors["args"])) - - fn({"args": [torch.ones(5, 5), torch.ones(5, 6), torch.ones(5, 7)]}) - fn({"args": [torch.ones(5, 5)]}) - - def test_dict_iter(self): - class MyMod(torch.nn.Module): - def forward(self, x): - z = {"my": 1, "const": 2, "dict": 3, "variable": 4} - tot = 0 - for key in z: - tot += z[key] - - return tot - - x = torch.tensor([0]) - model = MyMod() - opt_model = torchdynamo.optimize("eager", nopython=True)(model) - y = opt_model(x) - - self.assertEqual(y, 10) - - def test_sort_out(self): - - dtype = torch.float32 - device = "cpu" - - def fn(): - tensor = torch.randn((3, 5), dtype=dtype, device=device)[:, 0] - values1 = torch.tensor(0, dtype=dtype, device=device) - indices1 = torch.tensor(0, dtype=torch.long, device=device) - torch.sort(tensor, out=(values1, indices1)) - self.assertEqual(values1.stride(), (1,)) - self.assertEqual(indices1.stride(), (1,)) - - fn() - opt_fn = torchdynamo.optimize("eager")(fn) - opt_fn() - - def test_sigmoid_out(self): - - dtype = torch.float32 - device = "cpu" - - def fn(): - inp = torch.randn((3, 5), dtype=dtype, device=device) - out1 = torch.tensor(0, dtype=dtype, device=device) - torch.sigmoid(inp, out=out1) - self.assertEqual(out1.numel(), 15) - - fn() - opt_fn = torchdynamo.optimize("eager")(fn) - opt_fn() - - def test_slice_into_list_mutable(self): - class Mod(torch.nn.Module): - def forward(self, 
listy): - x = listy[3:5] - for i in range(10): - z = torch.abs(torch.randn(10)) + 1 - x[0] = z - return x - - m = Mod() - listy = [torch.randn(10)] * 10 - - cnt = torchdynamo.testing.CompileCounter() - opt_m = torchdynamo.optimize(cnt, nopython=True)(m) - opt_m.forward(listy) - - self.assertEqual(cnt.frame_count, 1) - - def test_vdd_duplicate_error(self): - def fn(a, dt): - keys = list(dt._jt_dict.keys()) - p = torch.cos(dt._jt_dict[keys[0]]._value) - q = torch.sin(a) - r = torch.sigmoid(dt._jt_dict[keys[0]]._value) - return p + q + r - - class Value: - def __init__(self): - self._value = torch.randn(4) - - class Sample: - def __init__(self): - self._jt_dict = {} - self._jt_dict["POSITION_ID"] = Value() - - a = torch.randn(4) - sample = Sample() - - ref = fn(a, sample) - - optimized_fn = torchdynamo.optimize("eager", nopython=True)(fn) - res = optimized_fn(a, sample) - - self.assertTrue(same(ref, res)) - - @patch.object(torchdynamo.config, "fake_tensor_propagation", False) - def test_specialized_stride(self): - def f(): - e = torch.empty(4) - x = e[::2] - return x.stride() - - self.assertEqual(f(), torchdynamo.optimize("eager")(f)()) - - @unittest.skipIf(not has_detectron2(), "requires detectron2") - def test_multi_import(self): - @torchdynamo.optimize("eager", nopython=True) - def to_bitmasks(boxes): - from detectron2.layers.mask_ops import _paste_masks_tensor_shape - from detectron2.layers.mask_ops import paste_masks_in_image - - if ( - paste_masks_in_image is not None - and _paste_masks_tensor_shape is not None - ): - return boxes + 1 - - self.assertTrue((to_bitmasks(torch.zeros(10)) == torch.ones(10)).all()) - - def test_multi_dot_import(self): - def fn1(x): - return torch.sin(x) - - def fn(x): - import torch.fx - - _ = torch.fx.symbolic_trace(fn1) - return x * 2 - - x = torch.randn(10) - fn(x) - cnt = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize(cnt)(fn) - opt_fn(x) - self.assertEqual(cnt.frame_count, 1) - - def test_relative_import(self): - try: - from . import test_functions as _ # noqa: F401 - - def fn(x): - from .test_functions import tensor_for_import_testing - - return x * 2 * tensor_for_import_testing - - except ImportError: - - def fn(x): - from test_functions import tensor_for_import_testing - - return x * 2 * tensor_for_import_testing - - x = torch.randn(10) - fn(x) - cnt = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize(cnt, nopython=True)(fn) - opt_fn(x) - self.assertEqual(cnt.frame_count, 1) - - def test_relative_import_no_modulename(self): - try: - from . import test_functions as _ # noqa: F401 - - def fn(x): - from . 
import test_functions - - return x * 2 * test_functions.tensor_for_import_testing - - except ImportError: - - def fn(x): - import test_functions - - return x * 2 * test_functions.tensor_for_import_testing - - x = torch.randn(10) - fn(x) - cnt = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize(cnt, nopython=True)(fn) - opt_fn(x) - self.assertEqual(cnt.frame_count, 1) - - # This doesn't work without fake tensors but I don't care - @patch.object(torchdynamo.config, "fake_tensor_propagation", True) - def test_issue1466_size_aot_autograd(self): - def fn(x): - # do a tensor op and a size compute - y = x * 2 - x_size = x.size() - # trigger a graph break - print("arf") - # use the tensor op and size compute - z = y.view(x_size) + 1 - return z - - x = torch.randn(2, 3, requires_grad=True) - ref = fn(x) - opt_fn = torchdynamo.optimize("aot_eager")(fn) - res = opt_fn(x) - self.assertTrue(same(ref, res)) - - def test_ellipsis(self): - class Repro(torch.nn.Module): - def __init__(self): - super().__init__() - self.lnorm = torch.nn.LayerNorm( - (256,), eps=1e-06, elementwise_affine=True - ) - self.linear = torch.nn.Linear( - in_features=256, out_features=256, bias=True - ) - - def forward(self, cat_10): - lnorm = self.lnorm(cat_10) - getitem_64 = lnorm[ - (slice(None, None, None), slice(0, 1, None), Ellipsis) - ] - linear = self.linear(getitem_64) - return (linear,) - - args = [torch.randn(2, 197, 256)] - - mod = Repro() - opt_mod = torchdynamo.optimize("eager", nopython=True)(mod) - - self.assertTrue(same(mod(*args), opt_mod(*args))) - - def test_reinplacing(self): - class MockModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.self_layoutlm_embeddings_x_position_embeddings = ( - torch.nn.Embedding(1024, 768) - ) - self.self_layoutlm_embeddings_y_position_embeddings = ( - torch.nn.Embedding(1024, 768) - ) - - def forward(self, getitem_1, getitem_2, add): - self_layoutlm_embeddings_x_position_embeddings = ( - self.self_layoutlm_embeddings_x_position_embeddings(getitem_1) - ) - self_layoutlm_embeddings_y_position_embeddings = ( - self.self_layoutlm_embeddings_y_position_embeddings(getitem_2) - ) - add_1 = add + self_layoutlm_embeddings_x_position_embeddings - add_2 = add_1 + self_layoutlm_embeddings_y_position_embeddings - return (add_2,) - - mod = MockModule() - opt_mod = torchdynamo.optimize("aot_inductor_debug")(mod) - - args = [ - ((2, 512), (2048, 4), torch.int64, "cpu", False), - ((2, 512), (2048, 4), torch.int64, "cpu", False), - ((2, 512, 768), (393216, 768, 1), torch.float32, "cpu", True), - ] - args = [ - rand_strided(sh, st, dt, dev).requires_grad_(rg) - for (sh, st, dt, dev, rg) in args - ] - self.assertTrue(same_two_models(mod, opt_mod, args)) - - -if __name__ == "__main__": - from torchdynamo.test_case import run_tests - - run_tests() diff --git a/test/dynamo/test_skip_non_tensor.py b/test/dynamo/test_skip_non_tensor.py deleted file mode 100644 index 23c109481b..0000000000 --- a/test/dynamo/test_skip_non_tensor.py +++ /dev/null @@ -1,113 +0,0 @@ -# Owner(s): ["module: dynamo"] -from unittest.mock import patch - -import torch - -import torchdynamo -import torchdynamo.test_case -from torchdynamo.testing import CompileCounter - - -class SkipNonTensorTests(torchdynamo.test_case.TestCase): - def test_add_tensor1(self): - def fn(a, b): - return a + b - - counter = CompileCounter() - x = torch.randn(4) - y = 5 - opt_fn = torchdynamo.optimize_assert(counter)(fn) - opt_fn(x, y) - - assert counter.op_count == 1 - - def test_add_tensor2(self): - def fn(a, b): 
- return torch.add(a, b) - - counter = CompileCounter() - - x = torch.randn(4) - y = 5 - opt_fn = torchdynamo.optimize_assert(counter)(fn) - opt_fn(x, y) - - assert counter.op_count == 1 - - def test_add_tensor_list(self): - def fn(lst): - return lst[0] + lst[1] - - counter = CompileCounter() - x = torch.randn(4) - y = 5 - opt_fn = torchdynamo.optimize_assert(counter)(fn) - opt_fn([x, y]) - - assert counter.op_count == 1 - - def test_add_tensor_dict(self): - def fn(dt): - return dt["a"] + dt["b"] - - counter = CompileCounter() - x = torch.randn(4) - y = 5 - opt_fn = torchdynamo.optimize_assert(counter)(fn) - opt_fn({"a": x, "b": y}) - - assert counter.op_count == 1 - - def test_add_skip(self): - def fn(a, b): - return a + b - - counter = CompileCounter() - opt_fn = torchdynamo.optimize_assert(counter)(fn) - x = 4 - y = 5 - opt_fn(x, y) - - assert counter.op_count == 0 - - @patch.object(torchdynamo.config, "raise_on_ctx_manager_usage", False) - def test_recursive_list(self): - def fn(x): - return x - - counter = CompileCounter() - - x = [] - x.append(x) - with torchdynamo.optimize_assert(counter): - fn(x) - - assert counter.op_count == 0 - - @patch.object(torchdynamo.config, "raise_on_ctx_manager_usage", False) - def test_custom_list(self): - def fn(x): - return x[0] + x[1] - - counter = CompileCounter() - - class Foo(list): - def __iter__(self): - raise Exception() - - def __len__(self): - raise Exception() - - x = Foo() - x.append(torch.randn(4)) - x.append(torch.randn(4)) - with torchdynamo.optimize_assert(counter): - fn(x) - - assert counter.op_count == 0 - - -if __name__ == "__main__": - from torchdynamo.test_case import run_tests - - run_tests() diff --git a/test/dynamo/test_subgraphs.py b/test/dynamo/test_subgraphs.py deleted file mode 100644 index 8f96a90033..0000000000 --- a/test/dynamo/test_subgraphs.py +++ /dev/null @@ -1,534 +0,0 @@ -# Owner(s): ["module: dynamo"] -import unittest -from unittest.mock import patch - -import torch - -import torchdynamo.test_case -import torchdynamo.testing -from torchdynamo import config -from torchdynamo.testing import unsupported -from torchdynamo.utils import disable_cache_limit - -globalmod = torch.nn.ReLU() - - -def indirectly_unsupported(a, b): - c = a + b - return unsupported(a, c) - - -class SubGraphTests(torchdynamo.test_case.TestCase): - def _common(self, fn, frame_count, op_count): - torchdynamo.reset() - v1 = torch.ones(10) - v2 = torch.ones(10) * -2.0 - correct1 = fn(v1, v2) - correct2 = fn(v2, v1) - cnt = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize(cnt)(fn) - r1 = opt_fn(v1, v2) - r2 = opt_fn(v2, v1) - self.assertTrue(torchdynamo.testing.same(r1, correct1)) - self.assertTrue(torchdynamo.testing.same(r2, correct2)) - self.assertEqual(cnt.frame_count, frame_count) - self.assertEqual(cnt.op_count, op_count) - - def test_control_flow1(self): - def fn(a, b): - c1 = a - b - c2 = b - a - if c1.sum() > c2.sum(): - return c1 - else: - return c2 - - self._common(fn, 1, 5) - - def test_control_flow2(self): - def fn(a, b): - if a.sum() > b.sum(): - return 1 - else: - return 2 - - self._common(fn, 1, 3) - - def test_control_flow3(self): - def fn(a, b): - c1 = a - b - c2 = b - a - m = globalmod - if c1.sum() > c2.sum(): - return m(c1) - else: - return m(c2) - - self._common(fn, 3, 7) - - def test_control_flow4(self): - def fn(a, b): - tmp1 = a.sum() > b.sum() and a.sum() > 0 - if tmp1: - return 1 - else: - return 2 - - self._common(fn, 3, 5) - - def test_control_flow5(self): - def fn(a, b): - tmp1 = a.sum() > b.sum() and 
a.sum() > 0 - tmp2 = a.sum() < b.sum() or b.sum() > 0 - if tmp1 and tmp2: - return 1, tmp1, tmp2 - else: - return 2, tmp1, tmp2 - - self._common(fn, 6, 13) - - def test_capi_call1(self): - def fn(a, b): - c1 = a - b - c2 = b - a - return unsupported(c1, c2) - - self._common(fn, 1, 2) - - def test_capi_call2(self): - def fn(a, b): - c1 = a - b - c2 = b - a - return a - (b - unsupported(c1, c2)) - - self._common(fn, 2, 4) - - def test_capi_call3(self): - def fn(a, b): - c1 = a - b - c2 = b - a - return torchdynamo.testing.unsupported(c1, c2) - - self._common(fn, 1, 2) - - def test_indirect_unsupported1(self): - def fn(a, b): - c1 = a - b - c2 = b - a - return indirectly_unsupported(c1, c2) - - self._common(fn, 2, 3) - - def test_indirect_unsupported2(self): - def fn(a, b): - local_const1 = 7 - local_const2 = 22 - c1 = a - b - c2 = b - a - return local_const1 / (local_const2 - indirectly_unsupported(c1, c2)) - - self._common(fn, 3, 5) - - def test_indirect_unsupported3(self): - def fn(a, b): - args = [a - b, b - a] - return indirectly_unsupported(*args) - - self._common(fn, 2, 3) - - def test_stack_state1(self): - def fn(a, b): - t1 = 1.23 * a - t2 = 4.56 * a - c1 = a - b - c2 = b - a - return t1 / (t2 - unsupported(c1, c2)) - - self._common(fn, 2, 6) - - def test_stack_state2(self): - def fn(a, b): - t1 = 1.23 * a - t2 = 4.56 * a - c1 = a - b - c2 = b - a - return t1 / (t2 - indirectly_unsupported(c1, c2)) - - self._common(fn, 3, 7) - - def test_multigraph(self): - def fn(a, b): - x = a + b - x = x / 2.0 - if x.sum() < 0: - return x * -1.0 - return x - - self._common(fn, 2, 5) - - def test_extended_args(self): - too_many_adds = "+".join(["a", "b"] * 256) - source = ( - f"lambda a, b: ({too_many_adds}+a if a.sum() > 0 else {too_many_adds} - b)" - ) - self._common(eval(source), 3, 1026) - - def test_resume1(self): - def fn(a, b): - x = a + b - x = x / 2.0 - x = x + 2.0 - x = unsupported(x, a) - x = x + 2.0 - x = x + 2.0 - x = x + 2.0 - return x - - self._common(fn, 2, 6) - - def test_resume2(self): - def fn(a, b): - x = a + b - x = x / 2.0 - x = x + 2.0 - x = indirectly_unsupported(x, a) - x = x + 2.0 - x = x + 2.0 - x = x + 2.0 - return x - - self._common(fn, 3, 7) - - def test_resume3(self): - def fn(a, b): - x = a + b - x = x / 2.0 - x = x + 2.0 - x = indirectly_unsupported(x, b=a) - x = x + 2.0 - x = x + 2.0 - x = x + 2.0 - return x - - self._common(fn, 3, 7) - - def test_resume4(self): - def fn(a, b): - x = a + b - x = x / 2.0 - x = x + 2.0 - x = indirectly_unsupported(a=x, b=a) - x = x + 2.0 - x = x + 2.0 - x = x + 2.0 - return x - - self._common(fn, 3, 7) - - def test_resume5(self): - def fn(a, b): - x = a + b - x = x / 2.0 - x = x + 2.0 - print(x) - x = x + 2.0 - x = x + 2.0 - x = x + 2.0 - return x - - self._common(fn, 2, 6) - - def test_start1(self): - def fn(a, b): - print(a) - x = a + b - x = x + 2.0 - x = x + 2.0 - return x - - self._common(fn, 1, 3) - - def test_start2(self): - def fn(a, b): - x = indirectly_unsupported(a, b) - x = x + 2.0 - x = x + 2.0 - x = x + 2.0 - return x - - self._common(fn, 2, 4) - - def test_start3(self): - def fn(a, b): - x = unsupported(a, b) - x = x + 2.0 - x = x + 2.0 - x = x + 2.0 - return x - - self._common(fn, 1, 3) - - def test_start4(self): - def fn(a, b, check): - if check: - return a + b + 10 - else: - return a + b - 10 - - v1 = torch.randn(10) - v2 = torch.randn(10) - f = torch.zeros(1, dtype=torch.int32) - t = torch.ones(1, dtype=torch.int32) - correct1 = fn(v1, v2, t) - correct2 = fn(v1, v2, f) - cnt = torchdynamo.testing.CompileCounter() - 
opt_fn = torchdynamo.optimize(cnt)(fn) - r1 = opt_fn(v1, v2, t) - r2 = opt_fn(v1, v2, f) - self.assertTrue(torchdynamo.testing.same(r1, correct1)) - self.assertTrue(torchdynamo.testing.same(r2, correct2)) - self.assertEqual(cnt.frame_count, 3) - self.assertEqual(cnt.op_count, 4) - - def test_resume_freevars(self): - c1 = torch.randn(10) - c2 = torch.randn(10) - - def fn(a, b): - x = a + b + (c1 - c2) - x = unsupported(x, x) - return x + (c1 - c2) - - self._common(fn, 2, 5) - - def test_restore_state(self): - def fn(a, b): - len_ = len - x = a + b - x = torch.add(unsupported(x, x), 1) - return a * x + len_(b) - - if config.dynamic_shapes: - self._common(fn, 2, 5) - else: - self._common(fn, 2, 4) - - def test_restore_range(self): - def fn(a, b): - x = a + b - rng = range(3, 8, 2) - x = unsupported(x, x) - for i in rng: - x = x + i - return x - - self._common(fn, 2, 4) - - def test_restore_range_iter(self): - def fn(a, b): - x = a + b - rng = iter(range(3, 8, 2)) - x = unsupported(x, x) - x += next(rng) - return x, list(rng) - - self._common(fn, 2, 2) - - def test_pop_after_resume(self): - def fn(a, b): - tmp = [a + 1, b + 2, a + b] - x = a - x = unsupported(x, x) - for i in range(3): - x += tmp.pop(-1) - return x - - self._common(fn, 2, 6) - - @disable_cache_limit() - def test_dynamic_shapes(self): - def fn(a, b): - return a - b * 10 - - torchdynamo.reset() - cnt_static = torchdynamo.testing.CompileCounter() - with patch("torchdynamo.config.dynamic_shapes", False): - opt_fn = torchdynamo.optimize(cnt_static)(fn) - for i in range(10): - opt_fn(torch.randn(i), torch.randn(i)) - self.assertEqual(cnt_static.frame_count, 10) - - torchdynamo.reset() - cnt_dynamic = torchdynamo.testing.CompileCounter() - with patch("torchdynamo.config.dynamic_shapes", True): - opt_fn = torchdynamo.optimize(cnt_dynamic)(fn) - for i in range(10): - opt_fn(torch.randn(i), torch.randn(i)) - # just one graph now rather than 10 - self.assertEqual(cnt_dynamic.frame_count, 1) - - @patch.object(torchdynamo.config, "capture_scalar_outputs", True) - def test_no_graph_break_on_item(self): - def fn(a, b): - x = a + b - 1.5 - x = x.sum() - x.item() - x = x / (a + b) - return x - - self._common(fn, 1, 6) - - @patch.object(torchdynamo.config, "capture_scalar_outputs", False) - def test_graph_break_on_item(self): - def fn(a, b): - x = a + b - 1.5 - x = x.sum() - x.item() - x = x / (a + b) - return x - - self._common(fn, 2, 5) - - def test_resume_paths_join(self): - def fn(x, c1, c2, c3): - x = x + 1 - if c1: - x = x + 2 - x = x + 3 - if c2: - x = x + 4 - x = x + 5 - if c3: - x = x + 6 - return x + 7 - - v1 = torch.randn(10) - t = torch.Tensor([True]) - f = torch.Tensor([False]) - cnt = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize(cnt)(fn) - for a in (t, f): - for b in (t, f): - for c in (t, f): - opt_fn(v1, a, b, c) - - # checking here we don't create 2^n graphs - self.assertEqual(cnt.frame_count, 7) - self.assertEqual(cnt.op_count, 10) - - def test_resume_with_no_grad1(self): - def fn(a, b): - x = a + b - with torch.no_grad(): - x = x + 1 - x.sum().tolist() # graph break - x = x + 2 - x = x + 3 - return x - - self._common(fn, 2, 9) - torchdynamo.reset() - with torch.no_grad(): - self._common(fn, 2, 9) - - def test_resume_with_no_grad2(self): - def fn(a, b): - x = a + b - with torch.no_grad(): - x = x + 1 - x.sum().tolist() # graph break - x = x + 2 - x.sum().tolist() # graph break - x = x + 3 - x = x + 4 - return x - - self._common(fn, 3, 13) - - def test_resume_with_no_grad3(self): - def fn(a, b): - x = a 
+ b - with torch.no_grad(): - with torch.no_grad(): - x = x + 1 - with torch.enable_grad(): - x.sum().tolist() # graph break - x = x[0] + 2 - x = x + 3 - x = x + 4 - return x - - self._common(fn, 2, 19) - - def test_resume_tuple_iterator(self): - def fn(a, b): - x = a + b - it = iter(tuple(range(10))) - x = x + next(it) - x = x + next(it) - x = x + next(it) - x = unsupported(x, x) - x = x + next(it) - x = x + next(it) - x = x + next(it) - x = x + next(it) - return x - - self._common(fn, 2, 8) - - def test_tuple_iterator_return(self): - def fn(x): - it = iter(tuple(range(10))) - x = x + next(it) - x = x + next(it) - x = unsupported(x, x) - x = x + next(it) - x = x + next(it) - x = unsupported(x, x) - x = x + next(it) - x = x + next(it) - return x, it - - v1 = torch.randn(10) - v2, it2 = fn(v1) - cnt = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize(cnt)(fn) - v3, it3 = opt_fn(v1) - v4, it4 = opt_fn(v1) - self.assertEqual(v2.tolist(), v3.tolist()) - self.assertEqual(v2.tolist(), v4.tolist()) - self.assertEqual(list(it2), list(it3)) - self.assertEqual(cnt.frame_count, 3) - self.assertEqual(cnt.op_count, 6) - - @unittest.skip("not working yet") - def test_tuple_iterator_mutate(self): - def fn(x, it): - x = x + next(it) - x = x + next(it) - x = x + next(it) - x = x + next(it) - return x - - v1 = torch.randn(10) - it1 = iter(tuple(range(10))) - cnt = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize(cnt)(fn) - self.assertEqual(opt_fn(v1, it1).tolist(), (v1 + 1 + 2 + 3).tolist()) - self.assertEqual(list(it1), [4, 5, 6, 7, 8, 9]) - - def test_enumerate_not_break_graph(self): - def fn(a, b): - for i, x in enumerate(a.shape): - b = b + x - for i, x in enumerate(b.shape, 8): - b = b + x * i - return b - - self._common(fn, 1, 2) - - -if __name__ == "__main__": - from torchdynamo.test_case import run_tests - - run_tests() diff --git a/test/dynamo/test_unspec.py b/test/dynamo/test_unspec.py deleted file mode 100644 index 429cb53180..0000000000 --- a/test/dynamo/test_unspec.py +++ /dev/null @@ -1,228 +0,0 @@ -# Owner(s): ["module: dynamo"] -import functools -import random -import unittest -from unittest.mock import patch - -import numpy as np -import torch - -import torchdynamo.test_case -import torchdynamo.testing -from torchdynamo.testing import same - -try: - from . import test_modules - from . 
import test_repros -except ImportError: - import test_modules - import test_repros - - -def make_unspec_fn(fn): - @functools.wraps(fn) - def _fn(*args, **kwargs): - with patch.object(torchdynamo.config, "specialize_int_float", False): - return fn(*args, **kwargs) - - return _fn - - -def make_unspec_cls(cls): - class UnspecTest(cls): - pass - - UnspecTest.__name__ = f"Unspec{cls.__name__}" - - for name in dir(cls): - if name.startswith("test_"): - fn = getattr(cls, name) - if not callable(fn): - continue - new_name = f"{name}_unspec" - fn = make_unspec_fn(fn) - fn.__name__ = new_name - setattr(UnspecTest, name, None) - setattr(UnspecTest, new_name, fn) - - return UnspecTest - - -UnspecReproTests = make_unspec_cls(test_repros.ReproTests) -UnspecNNModuleTests = make_unspec_cls(test_modules.NNModuleTests) - - -@patch.object(torchdynamo.config, "specialize_int_float", False) -class UnspecTests(torchdynamo.test_case.TestCase): - def test_numpy_correctness(self): - def fn(x, y, z): - xy = [x + y, y, False] - np_x = x.numpy() - np_y = y.numpy() - return { - "x": x, - "z": z, - "a": np_y.sum(), - "b": xy, - "c": np_y[0][0] / 68, - "d": np_x.sum(), - }, x + np_y.sum() + z - - x = torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float64) - y = torch.ones([2, 2], dtype=torch.int64) - z = np.int64(12) - res1 = fn(x, y, z) - cnts = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize(cnts)(fn) - res2 = opt_fn(x, y, z) - self.assertTrue(same(res1, res2)) - - def test_no_recompilations(self): - # no recompilations if passing on different numpy int values - def fn(x, y): - return {"a": x + 1, "b": y / 2} - - x = torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float64) - cnts = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize(cnts)(fn) - for i in range(10): - opt_fn(x, np.int64(i)) - self.assertEqual(cnts.frame_count, 1) - self.assertEqual(cnts.op_count, 2) - - def test_builtin_max_min(self): - # test unspecialized primitive max/min - def fn(x, y, z): - return z + 1, max(x, y), min(x - 4, y) - - x = np.int64(12) - y = 10 - z = torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float64) - res1 = fn(x, y, z) - cnts = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize(cnts)(fn) - res2 = opt_fn(x, y, z) - self.assertTrue(same(res1, res2)) - - def test_feed_random_values_into_graph_only(self): - def fn(shape): - torch.manual_seed(123) - x = torch.randn(shape, device="cpu") * random.randint(30, 100) - return x - - shape = [2, 3] - random.seed(1) - res1 = fn(shape) - cnts = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize(cnts)(fn) - random.seed(1) - res2 = opt_fn(shape) - - self.assertTrue(same(res1, res2)) - - def test_random_values_with_graph_break(self): - def fn(x): - r1 = random.random() - y = x + random.uniform(10, 20) - y.sum().item() - r2 = random.randint(2, 18) # no graph output in this frame - y.sum().item() - return y + r1, r2 - - x = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) - random.seed(1) - res1 = fn(x) - cnts = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize(cnts)(fn) - random.seed(1) - res2 = opt_fn(x) - self.assertTrue(same(res1, res2)) - - @patch.object(torchdynamo.config, "fake_tensor_propagation", False) - def test_multiple_consecutive_random_calls_before_graph(self): - def fn(x): - dim1 = random.randrange(start=0, stop=5) - dim2 = random.randrange(start=0, stop=5) - dim3 = random.randrange(start=0, stop=5) - y = torch.rand(dim1, dim2, dim3) - return x + 2, y - - x = torch.tensor([[1.0, 2.0], 
[3.0, 4.0]]) - random.seed(1) - res1 = fn(x) - cnts = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize(cnts)(fn) - random.seed(1) - res2 = opt_fn(x) - self.assertTrue(same(res1, res2)) - - def test_random_call_with_while_loop(self): - def fn(x): - dim1 = random.randrange(start=0, stop=3) - dim2 = dim1 - while dim1 == dim2: - dim2 = random.randrange(start=0, stop=3) - return x * 2 - - x = torch.randn(4) - random.seed(1) - res1 = fn(x) - opt_fn = torchdynamo.optimize("eager")(fn) - random.seed(1) - res2 = opt_fn(x) - self.assertTrue(same(res1, res2)) - - def test_builtin_getitem(self): - # builtin getitem args[0] is python list and args[1] is unspec - def fn(x, idx): - return (torch.zeros(idx), x[idx], x[idx:]) - - x = list(range(50)) - ref = fn(x, 48) # 48 is unspecialized - cnts = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize(cnts)(fn) - res = opt_fn(x, 48) - self.assertTrue(same(ref, res)) - - @unittest.skipIf(not torch.cuda.is_available(), "requires cuda") - def test_builtin_functions_on_cuda(self): - def fn(x, scaler): - m = torch.nn.ReLU() - y = m(x) * scaler - return y - - x = torch.randn([3, 6], device="cuda") - scaler = 0.23 # 0.23 is unspecialized - ref = fn(x, scaler) - cnts = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize(cnts)(fn) - res = opt_fn(x, scaler) - self.assertTrue(same(ref, res)) - self.assertEqual(ref.device, res.device) - - def test_unspec_float_precision(self): - def fn(image, scale_factor): - image = torch.nn.functional.interpolate( - image[None], - size=None, - scale_factor=scale_factor, - mode="bilinear", - recompute_scale_factor=True, - align_corners=False, - )[0] - - return image.shape - - x = torch.rand([3, 427, 640]) - scale_factor = 1.873536229133606 - ref = fn(x, scale_factor) - cnts = torchdynamo.testing.CompileCounter() - opt_fn = torchdynamo.optimize(cnts)(fn) - res = opt_fn(x, scale_factor) - self.assertTrue(same(ref, res)) - - -if __name__ == "__main__": - from torchdynamo.test_case import run_tests - - run_tests() diff --git a/test/dynamo/test_verify_correctness.py b/test/dynamo/test_verify_correctness.py deleted file mode 100644 index cdf41f9680..0000000000 --- a/test/dynamo/test_verify_correctness.py +++ /dev/null @@ -1,175 +0,0 @@ -# Owner(s): ["module: dynamo"] -import importlib -import operator -import unittest -from unittest.mock import patch - -import torch - -import torchdynamo -import torchdynamo.config as config -import torchdynamo.test_case -from torchdynamo.optimizations import backends -from torchdynamo.testing import same - - -def has_onnxruntime(): - try: - importlib.import_module("onnxruntime") - return True - except ImportError: - return False - - -def has_ipex(): - try: - importlib.import_module("intel_extension_for_pytorch") - return True - except ImportError: - return False - - -class Seq(torch.nn.Module): - def __init__(self): - super().__init__() - self.layers = torch.nn.Sequential( - torch.nn.Linear(10, 10), - torch.nn.ReLU(), - torch.nn.Linear(10, 10), - torch.nn.Sigmoid(), - ) - - def forward(self, x): - return self.layers(x) - - -class Conv_Bn_Relu(torch.nn.Module): - def __init__(self, in_channels, out_channels, **kwargs): - super(Conv_Bn_Relu, self).__init__() - self.conv = torch.nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) - self.bn = torch.nn.BatchNorm2d(out_channels, eps=0.001) - self.relu = torch.nn.ReLU() - - def forward(self, x): - return self.relu(self.bn(self.conv(x))) - - -def toy_example(a, b): - x = a / (torch.abs(a) + 1) - if b.sum() 
< 0: - b = b * -1 - return x * b - - -def transform(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: - for node in gm.graph.nodes: - # Checks if we're calling a function (i.e: - # operator.add) - if node.op == "call_function": - # The target attribute is the function - # that call_function calls. - if node.target == operator.mul: - node.target = operator.add - - gm.graph.lint() # Does some checks to make sure the - # Graph is well-formed. - - gm.recompile() - return gm - - -class TestVerifyCorrectness(torchdynamo.test_case.TestCase): - @patch.object(config, "verify_correctness", True) - def test_example_inputs(self): - def fn(a, bc, d): - b, c = bc - return a / d - b / c - - def compiler_fn(graph, example_inputs): - nonlocal r1 - r1 = graph(*example_inputs)[0] - return graph.forward - - a = torch.empty(2).fill_(1) - b = torch.empty(2).fill_(2) - c = torch.empty(2).fill_(3) - d = 4 - r1 = None - r2 = fn(a, (b, c), d) - opt_fn = torchdynamo.optimize_assert(compiler_fn)(fn) - r3 = opt_fn(a, (b, c), d) - - self.assertIsNotNone(r1) - self.assertTrue(same(r1, r2)) - self.assertTrue(same(r1, r3)) - - @patch.object(config, "verify_correctness", True) - def test_nnc(self): - s = Seq() - i = torch.randn(10) - r1 = s(i) - opt_s = torchdynamo.optimize("nnc")(s) - r2 = opt_s(i) - self.assertTrue(same(r1, r2)) - - @patch.object(config, "verify_correctness", True) - def test_incorrect_verify_true(self): - """ - If a bad optimization return a graph that - is not functionally equal to the original graph; - When config.verify_correctness=True, it will - check the correctness of outputs and raise an error - """ - i1 = torch.randn(10) - i2 = torch.randn(10) - - def incorrect_compile_fn(gm, example_inputs): - return transform(gm).forward - - toy_example(i1, i2) - try: - opt_toy_example = torchdynamo.optimize(incorrect_compile_fn)(toy_example) - opt_toy_example(i1, i2) - except RuntimeError: - pass - else: - self.fail("expected failure") - - @patch.object(config, "verify_correctness", False) - def test_incorrect_verify_false(self): - """ - The bad optimization return a graph that - is not functionally equal to the original graph; - When config.verify_correctness=False, wrong outputs - will return - """ - i1 = torch.randn(10) - i2 = torch.randn(10) - - def incorrect_compile_fn(gm, example_inputs): - return transform(gm).forward - - r1 = toy_example(i1, i2) - opt_toy_example = torchdynamo.optimize(incorrect_compile_fn)(toy_example) - r2 = opt_toy_example(i1, i2) - self.assertTrue(not same(r1, r2)) - - @unittest.skipIf(not has_ipex(), "requires ipex") - @patch.object(config, "verify_correctness", True) - def test_ipex_fp32(self): - model = Conv_Bn_Relu(3, 32, kernel_size=3, stride=1) - model = model.to(memory_format=torch.channels_last) - model = model.eval() - input = torch.randn(8, 3, 64, 64).contiguous(memory_format=torch.channels_last) - r1 = model(input) - opt_model = torchdynamo.optimize(backends.ipex_fp32)(model) - with torch.no_grad(): - r2 = opt_model(input) - self.assertTrue(same(r1, r2)) - self.assertEqual(r2.dtype, torch.float32) - - -if __name__ == "__main__": - from torchdynamo.test_case import run_tests - - run_tests() diff --git a/test/inductor/__init__.py b/test/inductor/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/test/inductor/cpp/.gitignore b/test/inductor/cpp/.gitignore deleted file mode 100644 index 37b0b62a96..0000000000 --- a/test/inductor/cpp/.gitignore +++ /dev/null @@ -1,13 +0,0 @@ -CMakeLists.txt.user -CMakeCache.txt -CMakeFiles -CMakeScripts -Testing 
-Makefile -cmake_install.cmake -install_manifest.txt -compile_commands.json -CTestTestfile.cmake -_deps -lib -bin diff --git a/test/inductor/cpp/CMakeLists.txt b/test/inductor/cpp/CMakeLists.txt deleted file mode 100644 index cc4954fc89..0000000000 --- a/test/inductor/cpp/CMakeLists.txt +++ /dev/null @@ -1,47 +0,0 @@ -project(my-project LANGUAGES C CXX) - -# Build output setup -set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/test/lib) -set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/test/lib) -set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/test/bin) - -# TODO(voz): Fix hack below -# Start hack -list(APPEND policies_new CMP0079) - -foreach(policy ${policies_new}) - if(POLICY ${policy}) - cmake_policy(SET ${policy} NEW) - endif() -endforeach() -# End hack - -################################ -# GTest -################################ -project(googletest-git NONE) - -include(FetchContent) -FetchContent_Declare( - googletest - GIT_REPOSITORY https://github.com/google/googletest.git - GIT_TAG release-1.12.1 -) - -set(gtest_force_shared_crt ON CACHE BOOL "" FORCE) -set(BUILD_GMOCK OFF CACHE BOOL "" FORCE) -set(BUILD_GTEST ON CACHE BOOL "" FORCE) - -FetchContent_MakeAvailable(googletest) - - - -################################ -# Tests -################################ - -# TODO(voz): This is a little assumptive of just this one test, rewrite with real dir includes -include_directories(${ATEN_INCLUDE}) -add_executable(test_cpp_prefix test_cpp_prefix.cpp ../../torchinductor/codegen/cpp_prefix.h) -target_link_libraries(test_cpp_prefix gtest gtest_main) -add_test(NAME test_cpp_prefix COMMAND test_cpp_prefix) diff --git a/test/inductor/cpp/test.sh b/test/inductor/cpp/test.sh deleted file mode 100755 index 3fd42414c8..0000000000 --- a/test/inductor/cpp/test.sh +++ /dev/null @@ -1,7 +0,0 @@ -#!/bin/bash -set -euo pipefail -IFS=$'\n\t' - -cmake . 
-DATEN_INCLUDE:PATH=$(python -c "import torch; from torch.utils import cpp_extension; print(cpp_extension.include_paths()[0])") -make -./test/bin/test_cpp_prefix diff --git a/test/inductor/cpp/test_cpp_prefix.cpp b/test/inductor/cpp/test_cpp_prefix.cpp deleted file mode 100644 index 08d379fe3a..0000000000 --- a/test/inductor/cpp/test_cpp_prefix.cpp +++ /dev/null @@ -1,21 +0,0 @@ -#include "../../torchinductor/codegen/cpp_prefix.h" -#include - -TEST(testCppPrefix, testAtomicAddInt) { - int x = 0; - atomic_add(&x, 100); - EXPECT_EQ(x, 100); -} - -TEST(testCppPrefix, testAtomicAddFloat) { - float x = 0.0f; - atomic_add(&x, 100.0f); - EXPECT_EQ(x, 100.0f); -} - -TEST(testCppPrefix, testAtomicAddI64) { - int64_t x = 0.0; - int64_t y = 100.0; - atomic_add(&x, y); - EXPECT_EQ(x, 100); -} diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py deleted file mode 100644 index d3e720b418..0000000000 --- a/test/inductor/test_torchinductor.py +++ /dev/null @@ -1,4061 +0,0 @@ -# Owner(s): ["module: inductor"] -import contextlib -import dataclasses -import functools -import importlib -import os -import random -import sys -import unittest -import weakref -from unittest.mock import patch - -import torch -from torch.fx.experimental.proxy_tensor import make_fx -from torch.nn import functional as F -from torch.testing._internal.common_utils import TEST_WITH_ASAN -from torch.testing._internal.common_utils import TEST_WITH_ROCM -from torch.testing._internal.common_utils import TestCase as TorchTestCase -from torch.utils._python_dispatch import TorchDispatchMode -from torch.utils._pytree import tree_flatten -from torch.utils._pytree import tree_unflatten - -import torchdynamo -from torchdynamo.debug_utils import same_two_models -from torchdynamo.testing import rand_strided -from torchdynamo.testing import same - -try: - import sympy - - importlib.import_module("functorch") - importlib.import_module("filelock") - - from functorch.compile import config as functorch_config - from torch._decomp import get_decompositions - - import torchinductor.config - from torchinductor import config - from torchinductor.compile_fx import compile_fx - from torchinductor.ir import IndexingDiv - from torchinductor.ir import ModularIndexing - from torchinductor.sizevars import SizeVarAllocator - from torchinductor.utils import has_torchvision_roi_align - from torchinductor.utils import has_triton - from torchinductor.utils import timed - - # This will only pass on pytorch builds newer than roughly 5/15/2022 - assert get_decompositions([torch.ops.aten.trace]) - # Requires functorch - from torchinductor.compile_fx import compile_fx_inner -except (ImportError, AssertionError) as e: - sys.stderr.write(f"{type(e)}: {e}\n") - if __name__ == "__main__": - sys.exit(0) - raise unittest.SkipTest("requires sympy/functorch/filelock") - - -HAS_CPU = False -try: - from subprocess import CalledProcessError - - from torchinductor.codecache import CppCodeCache - - CppCodeCache.load("") - HAS_CPU = True -except ( - CalledProcessError, - OSError, - torchinductor.exc.InvalidCxxCompiler, - torchinductor.exc.CppCompileError, -): - pass - -aten = torch.ops.aten - -HAS_CUDA = has_triton() -requires_cuda = functools.partial(unittest.skipIf, not HAS_CUDA, "requires cuda") - -torchinductor.config.triton.autotune = False # too slow - - -def requires_decomp(fn): - """Decorator to disable test if a decomp is missing""" - - def wrap_test(test): - @functools.wraps(test) - def maybe_test(*args, **kwargs): - if 
len(get_decompositions([fn])) == 0: - raise unittest.SkipTest(f"requires decomp for {fn.__name__}") - return test(*args, **kwargs) - - return maybe_test - - return wrap_test - - -class TestCase(TorchTestCase): - @classmethod - def setUpClass(cls): - super().setUpClass() - cls._stack = contextlib.ExitStack() - cls._stack.enter_context(patch.object(config, "debug", True)) - cls._stack.enter_context(patch.object(config.cpp, "min_chunk_size", 1)) - - @classmethod - def tearDownClass(cls): - cls._stack.close() - super().tearDownClass() - - -class ToTuple(torch.nn.Module): - def forward(self, x): - return (x,) - - -@dataclasses.dataclass -class InputGen: - n: int - device: str - - def dense(self): - return torch.randn((self.n, self.n), device=self.device) - - def transposed(self): - return self.dense().transpose(0, 1) - - def strided(self): - return torch.randn((self.n * 2, self.n * 3), device=self.device)[ - self.n :, self.n :: 2 - ] - - def broadcast1(self): - return torch.randn((self.n,), device=self.device) - - def broadcast2(self): - return torch.randn((1, self.n, 1), device=self.device) - - def broadcast3(self): - return torch.randn((1,), device=self.device) - - def double(self): - return torch.randn((self.n, self.n), device=self.device, dtype=torch.double) - - def int(self): - return torch.arange(self.n, device=self.device, dtype=torch.int32) - - -def compute_grads(args, kwrags, results, grads): - def gather_leaf_tensors(args, kwargs): - args, _ = tree_flatten(args) - kwargs, _ = tree_flatten(kwargs) - args = args + kwargs - leaf_tensors = [ - arg for arg in args if isinstance(arg, torch.Tensor) and arg.requires_grad - ] - return leaf_tensors - - flat_results, _ = tree_flatten(results) - flat_diff_results = [r for r in flat_results if r.requires_grad] - assert len(flat_diff_results) > 0 - - leaf_tensors = gather_leaf_tensors(args, kwrags) - assert len(leaf_tensors) > 0 - return torch.autograd.grad( - flat_diff_results, - leaf_tensors, - grads, - allow_unused=True, - retain_graph=True, - ) - - -@patch.object(torchinductor.config.triton, "cudagraphs", False) -@patch("torchdynamo.config.raise_on_backend_error", True) -def check_model( - self: TestCase, - model, - example_inputs, - kwargs=None, - *, - atol=None, - rtol=None, - check_lowp=True, - exact_dtype=True, - nopython=True, - copy_to_cuda=True, - reference_in_float=True, - assert_equal=True, - check_gradient=False, -): - kwargs = kwargs or {} - torchdynamo.reset() - - ref_inputs = example_inputs - ref_kwargs = kwargs - has_lowp_args = False - - if reference_in_float: - # check_lowp is ignored here, it's kept just to be able to call `common` with extra arg - def upcast_fn(x): - nonlocal has_lowp_args - if isinstance(x, torch.Tensor) and ( - x.dtype == torch.float16 or x.dtype == torch.bfloat16 - ): - has_lowp_args = True - return x.float() - else: - return x - - ref_inputs = list(map(upcast_fn, example_inputs)) - ref_kwargs = {k: upcast_fn(v) for k, v in kwargs.items()} - if has_lowp_args: - if hasattr(model, "to"): - model = model.to(torch.float) - - torch.manual_seed(0) - - correct = model(*ref_inputs, **ref_kwargs) - # downcast the model back if needed - if reference_in_float and has_lowp_args: - if hasattr(model, "to"): - model = model.to(torch.half) - - torchinductor.metrics.reset() - - called = False - - def compile_fx_wrapper(model_, example_inputs_): - nonlocal called - called = True - return compile_fx(model_, example_inputs_) - - def run(*ex, **kwargs): - return model(*ex, **kwargs) - - run = 
torchdynamo.optimize(compile_fx_wrapper, nopython=nopython)(run) - - torch.manual_seed(0) - actual = run(*example_inputs, **kwargs) - # if not called: - # exp = torchdynamo.explain(run, *example_inputs) - # print("Explain:", exp[0]) - # for graph in exp[2]: - # print("Graph", graph) - assert called, "Ran graph without calling compile_fx" - assert type(actual) == type(correct) - - correct_flat, correct_spec = tree_flatten(correct) - actual_flat, _ = tree_flatten(actual) - if reference_in_float: - correct_flat = tuple( - y.to(x.dtype) - if isinstance(y, torch.Tensor) and y.dtype.is_floating_point - else y - for x, y in zip(actual_flat, correct_flat) - ) - correct = tree_unflatten(correct_flat, correct_spec) - - if assert_equal: - self.assertEqual( - actual, - correct, - atol=atol, - rtol=rtol, - equal_nan=True, - exact_dtype=exact_dtype, - ) - else: - for correct_val, actual_val in zip(correct_flat, actual_flat): - if isinstance(correct_val, torch.Tensor): - assert correct_val.device == actual_val.device - assert correct_val.size() == actual_val.size() - assert correct_val.stride() == actual_val.stride() - assert correct_val.layout == actual_val.layout - if exact_dtype: - assert correct_val.dtype == actual_val.dtype - - if check_gradient: - - # generate random unit norm gradients - grads = [ - torch.rand(r.shape, device=r.device, dtype=r.dtype) - for r in correct_flat - if r.requires_grad - ] - for g in grads: - g /= g.norm() - - correct_grad = compute_grads(ref_inputs, ref_kwargs, correct, grads) - actual_grad = compute_grads(example_inputs, kwargs, actual, grads) - - self.assertEqual( - actual_grad, - correct_grad, - atol=atol, - rtol=rtol, - equal_nan=True, - exact_dtype=exact_dtype, - ) - - torchdynamo.reset() - - -@patch.object(torchinductor.config.triton, "cudagraphs", False) -def check_model_cuda( - self: TestCase, - model, - example_inputs, - kwargs=None, - *, - atol=None, - rtol=None, - check_lowp=True, - exact_dtype=True, - nopython=True, - copy_to_cuda=True, - reference_in_float=True, - assert_equal=True, - check_gradient=False, -): - kwargs = kwargs or {} - if hasattr(model, "to"): - model = model.to("cuda") - - def copy_fn(x): - # preserve strides of the input on the device - if not isinstance(x, torch.Tensor): - return x - return torch.empty_strided( - x.size(), x.stride(), device="cuda", dtype=x.dtype - ).copy_(x) - - if copy_to_cuda: - example_inputs = tuple(copy_fn(x) for x in example_inputs) - - check_model( - self, - model, - example_inputs, - kwargs, - atol=atol, - rtol=rtol, - exact_dtype=exact_dtype, - nopython=nopython, - reference_in_float=reference_in_float, - assert_equal=assert_equal, - check_gradient=check_gradient, - ) - - if check_lowp: - - def downcast_fn(x): - if not isinstance(x, torch.Tensor) or not x.dtype == torch.float: - return x - return torch.empty_strided( - x.size(), x.stride(), device="cuda", dtype=torch.half - ).copy_(x) - - example_inputs = list(map(downcast_fn, example_inputs)) - if hasattr(model, "to"): - model = model.to(torch.half) - check_model( - self, - model, - example_inputs, - kwargs, - atol=atol, - rtol=rtol, - exact_dtype=exact_dtype, - nopython=nopython, - reference_in_float=reference_in_float, - assert_equal=assert_equal, - check_gradient=check_gradient, - ) - - -class SweepInputs2: - input_gen_types1 = [ - "dense", - "transposed", - "strided", - "broadcast1", - "broadcast2", - "broadcast3", - "double", - "int", - ] - input_gen_types2 = input_gen_types1 - gen = None - - @staticmethod - def kernel(a, b): - return (a + b,) - - 
@classmethod - def gen_template(cls, name1, name2): - def test(self): - check_model( - self, - cls.kernel, - ( - getattr(cls.gen, name1)(), - getattr(cls.gen, name2)(), - ), - ) - - test.__name__ = f"test_{cls.gen.device}_{name1}_{name2}" - setattr(cls, test.__name__, test) - - @classmethod - def populate(cls): - for name1 in cls.input_gen_types1: - for name2 in cls.input_gen_types2: - cls.gen_template(name1, name2) - - -class SweepInputsCpuTest(SweepInputs2, TestCase): - gen = InputGen(10, "cpu") - - -SweepInputsCpuTest.populate() - - -class TestIndexingSimplification(TorchTestCase): - def test_indexing_simplification(self): - sizevars = SizeVarAllocator() - i0 = sympy.Symbol("i0") - i1 = sympy.Symbol("i1") - i2 = sympy.Symbol("i2") - r3 = sympy.Symbol("r3") - - var_ranges = {i0: 3136, i1: 64, i2: 32, r3: 3} - expr = ( - 128 * i2 - + ModularIndexing(i1, 1, 64) - + 64 * ModularIndexing(i1 + 64 * r3, 64, 2) - ) - # check that `i1//64` is removed when i1 is always less than 64, - # and the next simplificaton doesn't happen - self.assertEqual( - sizevars.simplify_with_ranges(expr, var_ranges), - i1 + 128 * i2 + 64 * ModularIndexing(r3, 1, 2), - ) - # all the modular indexing should be removed when the body cant be larger than the modulus - var_ranges[r3] = 2 - self.assertEqual( - sizevars.simplify_with_ranges(expr, var_ranges), i1 + 128 * i2 + 64 * r3 - ) - - # small terms should be kept if the rest is not guaranteed to be divisible - self.assertEqual( - sizevars.simplify_with_ranges(IndexingDiv(r3 + i2 + i1, 32), var_ranges), - IndexingDiv(r3 + i2 + i1, 32), - ) - - expr = ModularIndexing(2 * i2 + r3, 1, 64) - # modular indexing is removed if base is smaller than modulo - self.assertEqual(sizevars.simplify_with_ranges(expr, var_ranges), 2 * i2 + r3) - - # check the same thing but with symbolic divisor - self.assertEqual(IndexingDiv(r3 * i0, r3), i0) - self.assertEqual(ModularIndexing(r3 * i0, r3, 10), ModularIndexing(i0, 1, 10)) - - # (10*i) % 10 is always zero and should get optimized away - self.assertEqual( - ModularIndexing(i0 + i1 * 10, 1, 10), ModularIndexing(i0, 1, 10) - ) - - # ((20*i)//2) % 10 is always zero and should get optimized away - self.assertEqual( - ModularIndexing(i0 + i1 * 20, 2, 10), ModularIndexing(i0, 2, 10) - ) - - # the same things happens with symbolic divisor - self.assertEqual( - ModularIndexing(i0 + i1 * i2 * r3, i2, r3), ModularIndexing(i0, i2, r3) - ) - - # Constant fold from divisor into base - self.assertEqual(ModularIndexing(i0 * 4, 2, 10), ModularIndexing(i0 * 2, 1, 10)) - self.assertEqual(IndexingDiv(i0 * 4, 2), i0 * 2) - - # Nested modular indexing is correctly simplified - var_ranges = {"i1": 13, "i2": 121} - expr = ModularIndexing(ModularIndexing(121 * i1 + i2, 1, 784), 1, 28) - self.assertEqual(sizevars.simplify_with_ranges(expr, var_ranges), expr) - expr = ModularIndexing(ModularIndexing(121 * i1 + i2, 1, 784) + 1, 1, 28) - self.assertEqual(sizevars.simplify_with_ranges(expr, var_ranges), expr) - var_ranges = {"i2": 784} - expr = ModularIndexing(ModularIndexing(i2, 1, 28), 7, 4) - expected = IndexingDiv(ModularIndexing(i2, 1, 28), 7) - self.assertEqual(sizevars.simplify_with_ranges(expr, var_ranges), expected) - expr = ModularIndexing(ModularIndexing(i2, 1, 28) + 1, 7, 4) - self.assertEqual(sizevars.simplify_with_ranges(expr, var_ranges), expr) - - def test_indexing_join(self): - sizevars = SizeVarAllocator() - i0 = sympy.Symbol("i0") - i1 = sympy.Symbol("i1") - i2 = sympy.Symbol("i2") - - # join two ModularIndexing calls into one larger one when 
possible - expr1 = ModularIndexing(i0, 1, 32) + 32 * ModularIndexing(i0, 32, 4) - self.assertEqual( - sizevars.simplify_with_ranges(expr1, {}), ModularIndexing(i0, 1, 128) - ) - - # it should also work with a scale - self.assertEqual( - sizevars.simplify_with_ranges(2 * expr1, {}), - 2 * ModularIndexing(i0, 1, 128), - ) - - # it should work when divisor is not 1 - expr2 = ModularIndexing(i0, 3, 32) + 32 * ModularIndexing(i0, 32 * 3, 4) - simplified = sizevars.simplify_with_ranges(expr2, {}) - self.assertEqual(simplified, ModularIndexing(i0, 3, 128)) - self.assertEqual(expr2.subs({i0: 39485}), simplified.subs({i0: 39485})) - - # it should not happen in this case as the modulus is wrong - expr3 = ModularIndexing(i0, 1, 30) + 32 * ModularIndexing(i0, 32, 4) - self.assertEqual(sizevars.simplify_with_ranges(expr3, {}), expr3) - - # check that it also works with a modulus>1 - expr4 = ModularIndexing(i0, 10, i1) + i1 * ModularIndexing(i0, i1 * 10, i2) - res0 = expr4.subs({i0: 24056, i1: 13, i2: 19}) - simplified = sizevars.simplify_with_ranges(expr4, {}) - res1 = simplified.subs({i0: 24056, i1: 13, i2: 19}) - self.assertEqual(res0, res1) - self.assertEqual(simplified, ModularIndexing(i0, 10, i1 * i2)) - - # and also works with an offset - self.assertEqual( - sizevars.simplify_with_ranges(expr4 + 10, {}), - ModularIndexing(i0, 10, i1 * i2) + 10, - ) - - # works for ModularIndexing + IndexingDiv - expr5 = 197 * IndexingDiv(i0, 197) + ModularIndexing(i0, 1, 197) - simplified = sizevars.simplify_with_ranges(expr5, {}) - self.assertEqual(simplified, i0) - self.assertEqual(expr5.subs({i0: 39485}), simplified.subs({i0: 39485})) - - # works with a scale - self.assertEqual( - sizevars.simplify_with_ranges(2 * expr5, {}), - 2 * i0, - ) - - # divisor != 1 - expr6 = 197 * IndexingDiv(i0, 197 * 3) + ModularIndexing(i0, 3, 197) - simplified = sizevars.simplify_with_ranges(expr6, {}) - self.assertEqual(simplified, IndexingDiv(i0, 3)) - self.assertEqual(expr6.subs({i0: 39485}), simplified.subs({i0: 39485})) - - -class CommonTemplate: - @classmethod - def install(my_cls, other_cls, suffix): # noqa: B902 - for name, value in my_cls.__dict__.items(): - if name.startswith("test_"): - setattr(other_cls, f"{name}_{suffix}", value) - - def test_bool(self): - def fn(a, b): - return ( - a + b, - a * b, - a & b, - a | b, - a ^ b, - torch.logical_and(a, b), - torch.logical_or(a, b), - torch.logical_not(a), - torch.sign(b), - ) - - self.common( - fn, - ( - torch.tensor([True, False, True, False]), - torch.tensor([False, False, True, True]), - ), - ) - - def test_add_const_int(self): - def fn(a): - return (a + 1,) - - self.common(fn, (torch.randn(32),)) - - def test_add_const_float(self): - def fn(a): - return (a + 1.5,) - - self.common(fn, (torch.randn(32),)) - - def test_add_inplace_permuted(self): - def fn(x, y): - return x.add_(y) - - x = torch.ones([2, 12, 13, 17]).transpose(1, 2) - y = torch.randn([2, 13, 1, 17]) - - self.common(fn, (x, y)) - - def test_abs(self): - def fn(a): - return (a / (torch.abs(a) + 1),) - - self.common(fn, (torch.randn(17),)) - - def test_sgn(self): - def fn(a): - return torch.sgn(a), torch.sgn(a + 1) - 1 - - self.common(fn, [torch.linspace(-10, 10, 41)]) - - def test_max_min(self): - def fn(a, b): - return (torch.maximum(a, b), torch.minimum(a, b)) - - self.common(fn, (torch.randn(8), torch.randn(8))) - - def test_horizonal_fusion1(self): - def fn(a, b, c): - return (a + b, a - c, b * c) - - self.common( - fn, (torch.randn(8, 16, 16), torch.randn(8, 16, 16), torch.randn(1, 16, 1)) - ) - - def 
test_horizonal_fusion2(self): - def fn(a, b, c): - return a + 1, b + 2, c + 3 - - self.common(fn, (torch.randn(8, 16, 8), torch.randn(8, 16), torch.randn(16, 8))) - - def test_vertical_fusion1(self): - def fn(sa, ct, p): - # From torchbench.pyhpc_equation_of_state - v17 = -3.087032500374211e-7 - v18 = -1.988366587925593e-8 - v19 = -1.061519070296458e-11 - v20 = 1.550932729220080e-10 - t15 = v19 * ct - t19 = v17 + ct * (v18 + t15) + v20 * sa - t20 = 1.0 / t19 - t128 = t19 * p - return t20 + t128 - - self.common( - fn, - ( - torch.randn(204, 204, 26), - torch.randn(204, 204, 26), - torch.randn(26), - ), - ) - self.assertEqual(torchinductor.metrics.generated_kernel_count, 1) - - def test_sum1(self): - def fn(a, b): - return ((a + b).sum(-1),) - - self.common(fn, (torch.randn(8, 8), torch.randn(8, 8))) - - def test_sum2(self): - def fn(a, b): - return ((a + b).sum([1, 2]), (a + b).sum(-1)) - - self.common(fn, (torch.randn(8, 9, 3, 21), torch.randn(8, 9, 3, 21))) - - def test_sum3(self): - def fn(a, b): - r1 = a + b - r2 = r1.sum(-1) - r3 = torch.squeeze(b) + 10 - return (r1, r2, r3) - - # Mismatched elements: 2 / 10 (20.0%) - # Greatest absolute difference: 0.0029296875 at index (8,) (up to 1e-05 allowed) - # Greatest relative difference: 0.0017482517482517483 at index (6,) (up to 0.001 allowed) - self.common(fn, (torch.randn(10, 10), torch.randn(1, 10)), atol=1e-5, rtol=2e-3) - - def test_sum4(self): - def fn(a): - b = a + 1 - c = b.sum(-1) - d = c + 3 - e = d.sum(-1) - f = e + 5 - return (f, e, d, c, b) - - self.common(fn, (torch.randn(1, 16, 8, 8),)) - - def test_sum5(self): - def fn(a): - b = a + 1 - c = b.sum(-1) - d = c + 3 - e = d.sum(-1) - f = e + 5 - return (f,) - - self.common(fn, (torch.randn(1, 17, 8, 9),)) - - def test_reduction1(self): - def fn(a): - return (a.sum(), a.max(), a.min(), a.argmax(), a.argmin()) - - self.common(fn, (torch.tensor([float("-inf"), 0.0, float("inf")]),)) - - def test_reduction2(self): - def fn(a): - # FIXME: a.argmax - return (a.sum(), a.max(), a.min(), a.argmin()) - - self.common(fn, (torch.full((4,), float("inf")),)) - - def test_reduction3(self): - def fn(a): - # FIXME: a.argmin - return (a.sum(), a.max(), a.min(), a.argmax()) - - self.common(fn, (torch.full((4,), float("-inf")),)) - - @patch.object(config, "dynamic_shapes", False) - def test_unroll_small_reduction(self): - def fn(x): - val1, index1 = x.min(-1) - val2, index2 = x.max(-1) - return ( - val1, - index1, - val2, - index2, - x.sum(-1), - (x > 1).any(-1), - (x > 0).all(-1), - x.argmin(-1), - x.argmax(-1), - x.amin(-1), - x.amax(-1), - ) - - with patch.object(config, "unroll_reductions_threshold", 8): - # small sized reductions will get unrolled - self.common(fn, (torch.randn(8, 3),)) - torchdynamo.reset() - with patch.object(config, "unroll_reductions_threshold", 1): - # make sure things also work if they aren't unrolled - self.common(fn, (torch.randn(8, 3),)) - - def test_multilayer_low_prec(self): - # fp16 nyi for cpu - if self.device == "cpu": - raise unittest.SkipTest("requires CUDA") - - def fn(a): - return torch.mean(a) - - self.common(fn, ((torch.rand((10, 3, 352, 352), dtype=torch.float16),))) - - def test_expanded_reduction(self): - def fn(x, y): - z = x * y - return z.sum((0, 1)) - - self.common(fn, (torch.randn(2, 197, 256), torch.randn(2, 1, 256))) - - def test_min_max_reduction(self): - def fn(a, b): - return ((a + b).max(), (a + b).min(), torch.amax(a + 1, keepdim=True)) - - self.common(fn, (torch.randn(8, 8), torch.randn(8, 8))) - - def test_sum_int(self): - def fn(x): - 
return 2 * x.sum(-1) + x.sum() - - dtypes = torch.bool, torch.uint8, torch.int - inps = [torch.randint(2, (64,), dtype=dtype) for dtype in dtypes] - for i in inps: - self.common(fn, (i,), check_lowp=False) - - def test_sum_dtype(self): - def fn(x): - return x * x.sum(-1, dtype=torch.double) + x.sum(dtype=torch.double) - - self.common(fn, (torch.ones(32, 32) * 70,)) - - def test_clamp(self): - def fn(a, b): - return (a.clamp(-0.1, 0.1), b.clamp(0), torch.clamp(a + b, max=0)) - - self.common(fn, (torch.randn(8, 8), torch.randn(8, 8))) - - def test_arange1(self): - def fn(x): - rng1 = torch.arange(8 * 8, dtype=torch.float32, device=x.device).view(8, 8) - rng2 = torch.arange(10, 18, device=x.device) - tmp = x * rng1 - return tmp, tmp + rng2 - - self.common(fn, (torch.randn(8, 8),)) - - def test_arange2(self): - def fn(x): - rng1 = torch.arange(8, device=x.device) - return (x + rng1,) - - self.common(fn, (torch.randint(4, (8, 8)),), check_lowp=False) - - def test_arange3(self): - def fn(x): - return x + torch.ops.aten.arange.start_step( - 0, 53, 4, dtype=torch.int64, device=x.device - ) - - self.common(fn, (torch.randn(14),)) - - def test_arange4(self): - def fn(x): - return x - torch.arange(512, -512, -1.0, device=x.device) - - self.common(fn, (torch.randn(1024),)) - - def test_linspace(self): - def fn(x): - return torch.linspace(0.125, 0.875, 7, device=x.device) + x - - self.common(fn, (torch.randn(1, 7),)) - - def test_tensor1(self): - def fn(x): - return torch.tensor([1], device=x.device) + x, torch.tensor( - 5, device=x.device - ) - - self.common(fn, (torch.randn(10),)) - - def test_tensor2(self): - def fn(x): - return torch.tensor(list(range(2, 40, 2)), device=x.device) + x - - self.common(fn, (torch.randn(1),)) - - def test_tensor3(self): - def fn(x): - return ( - torch.tensor([], device=x.device), - torch.tensor([1, 2], device=x.device) + 1, - torch.tensor([1, 2, 3], device=x.device) + 2, - torch.tensor([1, 2, 3, 4], device=x.device) + x, - ) - - self.common(fn, [torch.randn(4)]) - - def test_views1(self): - def fn1(x, y): - return (x.view(size2) + y,) - - def fn2(x, y): - return ((x + 1).view(size2) + y,) - - views = [ - ([5 * 7], [5, 7]), - ([2 * 3 * 4 * 5 * 6 * 7], [2, 3, 4, 5, 6, 7]), - ([2 * 3, 4, 5, 6 * 7], [2, 3, 4, 5, 6, 7]), - ([10 * 5, 20], [10, 5, 20]), - ([1, 10, 1], [10]), - ([10, 1, 10, 1, 10], [10, 100]), - ([2, 2, 2, 2], [4, 4]), - ] - for size1, size2 in views: - self.common(fn1, (torch.randn(size1), torch.randn(size2))) - self.common(fn2, (torch.randn(size1), torch.randn(size2))) - - for size2, size1 in views: - self.common(fn1, (torch.randn(size1), torch.randn(size2))) - self.common(fn2, (torch.randn(size1), torch.randn(size2))) - - def test_views2(self): - def fn1(x): - return (x.view(size2) + 1,) - - def fn2(x): - return ((x * 2).view(size2) + 1,) - - for size1, size2 in [ - ([2, 2, 2, 2], [4, -1]), - ([10, 1, 10, 1, 10], [-1, 100]), - ([10 * 5, 20], [10, -1, 20]), - ]: - self.common(fn1, (torch.randn(size1),)) - self.common(fn2, (torch.randn(size1),)) - - def test_views3(self): - # example taken from hf_BigBird - def forward(arg1, arg2): - index = torch.ops.aten.index(arg1, [arg2]) - view_1 = torch.ops.aten.view(index, [1, 2232, 64]) - view_2 = torch.ops.aten.view(view_1, [1, 12, 62, 192]) - return view_2 - - self.common( - forward, - ( - rand_strided((64, 64), (64, 1), torch.float32), - rand_strided((2232,), (1,), torch.int64), - ), - ) - - def test_relu(self): - def fn(a, b): - return (torch.relu(a), torch.relu(a + b) / 10) - - self.common(fn, (torch.randn(8, 8), 
torch.randn(8, 8))) - - def test_exp(self): - def fn(a, b): - return (torch.exp(a), torch.exp(a + b)) - - self.common(fn, (torch.randn(8, 8), torch.randn(8, 8))) - - def test_sigmoid(self): - def fn(a, b): - return (torch.sigmoid(a), torch.sigmoid(a + b)) - - self.common(fn, (torch.randn(8, 8), torch.randn(8, 8))) - - def test_round(self): - def fn(a, b): - return torch.round(a), torch.round(b + 1), torch.round(a, decimals=2) - - # without manual_seed, there is some chance this test fails due to: - # https://github.com/openai/triton/issues/530 - torch.manual_seed(0) - - # with *100 we are always getting a number exactly at .5 which we don't do right in half - self.common(fn, (torch.randn(8, 8) * 100, torch.randn(8, 8) * 10)) - - def test_round_correctness(self): - if self.device == "cuda": - raise unittest.SkipTest("need to debug tl.libdevice on A100/V100") - - def fn(a): - return torch.round(a) - - self.common( - fn, - [torch.arange(-10, 10, 0.1, dtype=torch.float64)], - check_lowp=False, - ) - - def test_silu(self): - def fn(a): - return (torch.nn.functional.silu(a),) - - self.common(fn, (torch.randn(8, 8),)) - - # TODO(voz): Re-enable this test ASAP https://github.com/pytorch/pytorch/issues/82763 - @unittest.skip("Skipping due to op bugs") - def test_nan_to_num(self): - def fn(a): - return ( - torch.nan_to_num(a), - torch.nan_to_num(a, nan=3.0), - torch.nan_to_num(a, nan=None), - torch.nan_to_num(a, posinf=4.0), - torch.nan_to_num(a, neginf=5.0), - torch.nan_to_num(a, nan=3.0, posinf=4.0, neginf=5.0), - ) - - self.common( - fn, - (torch.tensor((float("nan"), float("inf"), float("-inf"), 1.0)),), - check_lowp=False, # a much more elaborate test is required to match finfo max's for float and half - ) - - def test_div1(self): - def fn(a, b): - return ( - aten.div(a, b, rounding_mode=None), - aten.div(a, b, rounding_mode="floor"), - aten.div(a, b, rounding_mode="trunc"), - a / b, - a // b, - ) - - self.common(fn, (torch.randn(8, 8) * 100, torch.randn(8, 8) * 100)) - - def test_div2(self): - def fn(a, b): - return ( - aten.div(a, b, rounding_mode=None), - aten.div(a, b, rounding_mode="floor"), - aten.div(a, b, rounding_mode="trunc"), - a / b, - a // b, - ) - - self.common(fn, (torch.randint(-100, 100, [8, 8]), 100 * torch.randn(8, 8))) - - def test_div3(self): - def fn(a, b): - return ( - aten.div(a, b, rounding_mode=None), - aten.div(a, b, rounding_mode="floor"), - aten.div(a, b, rounding_mode="trunc"), - a / b, - a // b, - ) - - a = torch.randint(1, 100, [8, 8]) - self.common(fn, (a * 2, a)) - - def test_div4(self): - def fn(a, b): - return ( - aten.div(a, b, rounding_mode=None), - aten.div(a, b, rounding_mode="floor"), - aten.div(a, b, rounding_mode="trunc"), - a / b, - a // b, - ) - - self.common( - fn, - (torch.randint(-100, 0, [8, 8]), torch.randint(1, 10, [8, 8])), - ) - - def test_div5(self): - def fn(a, b): - return ( - aten.div(a, b, rounding_mode=None), - aten.div(a, b, rounding_mode="floor"), - aten.div(a, b, rounding_mode="trunc"), - a / b, - a // b, - ) - - # divide a scalar - self.common(fn, (torch.randint(-100, 0, [8, 8]), 16)) - - def test_div6(self): - def fn(a, b): - return ( - aten.div(a, b, rounding_mode=None), - aten.div(a, b, rounding_mode="floor"), - aten.div(a, b, rounding_mode="trunc"), - a / b, - a // b, - ) - - # treat boolean as integer - self.common( - fn, - (torch.ones([8, 8], dtype=torch.bool), torch.randint(-100, -1, [8, 8])), - ) - - def test_div7(self): - def fn(a, b): - return ( - aten.div(a, b, rounding_mode=None), - aten.div(a, b, rounding_mode="floor"), 
- aten.div(a, b, rounding_mode="trunc"), - a / b, - a // b, - ) - - self.common( - fn, - ( - torch.randint(2**32, 2**40, [100, 100]), - torch.randint(-10, -1, [100, 100]), - ), - ) - - def test_div8(self): - def fn(a, b): - return ( - aten.div(a, b, rounding_mode=None), - aten.div(a, b, rounding_mode="floor"), - aten.div(a, b, rounding_mode="trunc"), - a / b, - a // b, - ) - - self.common(fn, (1024, 100)) - - def test_both_scalars(self): - def fn(a, b): - return ( - aten.add(a, b), - aten.add(b, a), - aten.sub(a, b), - aten.sub(b, a), - aten.mul(a, b), - aten.mul(b, a), - ) - - self.common(fn, (4, 3.3), reference_in_float=False) - - def test_sum_keepdims(self): - def fn(a, b): - return (torch.sum(a + b, -1, keepdim=True),) - - self.common(fn, (torch.randn(8, 8), torch.randn(8, 8))) - - def test_softmax(self): - def fn(a, b): - return (torch.softmax(a + b, -1), torch.softmax(a, 0), torch.softmax(b, 1)) - - self.common(fn, (torch.randn(8, 8), torch.randn(8, 8))) - - def test_log_softmax(self): - def fn(a, b): - return (F.log_softmax(a + b, -1), F.log_softmax(a, 0), F.log_softmax(b, 1)) - - self.common(fn, (torch.randn(8, 8), torch.randn(8, 8))) - - def test_transpose(self): - def fn(a, b): - return ( - torch.t(a) + b, - torch.transpose(b * 2, 0, 1) + 10, - ) - - self.common(fn, (torch.randn(8, 8), torch.randn(8, 8))) - - def test_permute(self): - def fn(a): - return ( - torch.permute(a + 1, [2, 1, 4, 0, 3]) + 2, - torch.permute(a, [2, 1, 4, 0, 3]) + 2, - ) - - self.common(fn, (torch.randn(2, 2, 2, 2, 2),)) - - def test_expand(self): - def fn(a): - return ( - (a + 1).expand(3, 4, 2, 3, 2) + 2, - a.expand(2, 1, 2, 3, 2) + 2, - ), a.expand(2, -1, 5, -1) - - self.common(fn, (torch.randn(2, 1, 2),)) - - def test_squeeze1(self): - def fn(a): - return ((a + 1).squeeze() + 2, a.squeeze() + 2) - - self.common(fn, (torch.randn(1, 2, 1, 2, 2, 1, 1),)) - - def test_squeeze2(self): - def fn(a): - return ((a + 1).squeeze(-1).squeeze(2) + 2, a.squeeze(0) + 2) - - self.common(fn, (torch.randn(1, 2, 1, 2, 2, 2, 1),)) - - def test_simplify_loops(self): - def fn(a, b): - return a + b - - self.common( - fn, - ( - torch.randn(2, 3, 4, 5, 6), - torch.randn(4, 2, 3, 5, 6).permute(1, 2, 0, 3, 4), - ), - ) - - def test_unsqueeze(self): - def fn(a): - return ( - torch.unsqueeze(a + 1, -1) + 2, - torch.unsqueeze(a, 2) + 2, - torch.unsqueeze(a + 1, 0) + 2, - torch.unsqueeze(a, -2) + 2, - ) - - self.common( - fn, - ( - torch.randn( - 2, - 2, - 2, - 2, - ), - ), - ) - - def test_unsqueeze_inplace(self): - def fn(a): - tmp1 = a + 1 - aten.unsqueeze_(tmp1, 2) - tmp2 = aten.unsqueeze_(a + 1, 0) + 2 - return (tmp1, tmp2) - - self.common( - fn, - ( - torch.randn( - 2, - 2, - 2, - 2, - ), - ), - ) - - def test_addmm(self): - def fn(a, b, c): - return (torch.addmm(a + 1, b + 2, c + 3) + 4,) - - self.common( - fn, - ( - torch.randn(8, 8), - torch.randn(8, 8), - torch.randn(8, 8), - ), - ) - - def test_linear1(self): - mod = torch.nn.Sequential( - torch.nn.Linear(8, 16), - torch.nn.Sigmoid(), - ToTuple(), - ) - self.common(mod, (torch.randn(2, 8),)) - - def test_linear2(self): - mod = torch.nn.Sequential( - torch.nn.Linear(8, 8), - torch.nn.ReLU(), - torch.nn.Linear(8, 8), - torch.nn.ReLU(), - torch.nn.Linear(8, 8), - torch.nn.ReLU(), - torch.nn.Linear(8, 8), - torch.nn.ReLU(), - ) - self.common(mod, (torch.randn(2, 8),)) - - def test_bmm1(self): - def fn(a, b): - return ( - torch.bmm(a, b), - torch.bmm(a + 1, b + 2) + 3, - ) - - self.common( - fn, - ( - torch.randn(2, 8, 8), - torch.randn(2, 8, 8), - ), - check_lowp=False, - ) - 
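# A short worked example of the three rounding modes exercised by the test_div1-test_div8
# cases above (torch.div matches the aten.div calls used there): rounding_mode=None is
# true division, "trunc" rounds toward zero, and "floor" rounds toward -inf, which is
# what Python's // does.
import torch

a = torch.tensor([7.0, -7.0])
b = torch.tensor([2.0, 2.0])
print(torch.div(a, b))                         # tensor([ 3.5000, -3.5000])
print(torch.div(a, b, rounding_mode="trunc"))  # tensor([ 3., -3.])
print(torch.div(a, b, rounding_mode="floor"))  # tensor([ 3., -4.])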
self.common( - fn, - ( - torch.randn(1, 16, 8), - torch.randn(1, 8, 10), - ), - check_lowp=False, - ) - - def test_bmm2(self): - def fn(a, b): - return torch.bmm(a.permute(0, 2, 1), b) - - self.common( - fn, - ( - torch.randn(1, 8, 8), - torch.randn(1, 8, 8), - ), - check_lowp=False, - ) - - def test_gather1(self): - def fn(a, b): - return ( - torch.gather(a.expand([4, 5, 10, 6]), 3, b + 1), - torch.gather(a.expand([4, 5, 10, 6]), -1, b + 1), - ) - - self.common( - fn, - ( - torch.randn([1, 1, 10, 6]), - torch.randint(5, [4, 5, 10, 1], dtype=torch.int64), - ), - ) - - def test_gather2(self): - # 0d tensor - def fn(a, b): - return torch.gather(a, 0, b) + torch.gather(a, -1, b) - - x = torch.tensor(123) - y = torch.tensor(0) - self.assertEqual(fn(x, y), x + x) - - def test_slice1(self): - def fn(a): - return ( - a[:, :10, 0] + a[:, 10:, 0], - (a + 1)[:, :10, 0] + (a + 1)[:, 10:, 0], - ) - - self.common( - fn, - (torch.randn([2, 20, 2]),), - ) - - def test_slice2(self): - def fn(a): - return ( - a[:-1, ::2, -1] + a[-1:, 1::2, -2], - (a + 1)[:-1, ::2, -1] + (a + 2)[-1:, 1::2, -2], - ) - - self.common( - fn, - (torch.randn([2, 20, 2]),), - ) - - def test_split_with_sizes(self): - def fn(a, sizes): - return [t + 1.0 for t in torch.split(a * 2.0, sizes, -1)] - - self.common(fn, (torch.randn(2, 2, 10), [3, 3, 4])) - self.common(fn, (torch.randn(2, 2, 10), [4, 3, 3])) - self.common(fn, (torch.randn(2, 2, 10), [1, 2, 3, 4])) - - def test_split(self): - def fn(a): - t = torch.split(a, 3, -1) - return (t[0], t[1], t[2], t[3]) - - def fn2(a): - return fn(a + 1) - - self.common( - fn, - (torch.randn([2, 2, 10]),), - ) - - self.common( - fn2, - (torch.randn([2, 2, 10]),), - ) - - def test_to_dtype(self): - def fn(a, b): - return ( - aten._to_copy(a, dtype=6), - aten._to_copy(b + 1, dtype=6), - aten.to(b, torch.float64), - aten.to(b, torch.bool), - ) - - self.common( - fn, - ( - torch.randn([2, 2, 10]), - torch.randn([2, 2, 10], dtype=torch.float64), - ), - ) - - @requires_cuda() - def test_to_device(self): - def fn(a): - if a.device.type == "cpu": - return aten._to_copy(a, device=torch.device("cuda"), dtype=6, layout=0) - else: - return aten._to_copy(a, device=torch.device("cpu"), dtype=6, layout=0) - - self.common( - fn, - (torch.randn([2, 2, 10]),), - ) - - @requires_cuda() - def test_to_device_constant(self): - def fn(a): - d1 = a.device.type - if d1 == "cpu": - d2 = "cuda" - else: - d2 = "cpu" - - const1 = torch.as_tensor(list(range(64)), device=d2) - return ( - torch.arange(10, device=d2).to(d1) + a, - const1.to(d1), - (const1 + 1).to(d1), - ) - - self.common( - fn, - (torch.randn([10]),), - ) - - @requires_cuda() - def test_multi_device(self): - def fn(x): - x = x + 1 - x = x + 2 - x = x.cuda() - x = x + 3 - x = x + 4 - x = x.cpu() - x = x + 5 - x = x + 6 - x = x.cuda() - x = x + 7 - x = x + 8 - x = x.cpu() - x = x + 9 - x = x + 10 - return x - - self.common( - fn, - (torch.randn([2, 2, 10]),), - check_lowp=False, # cpu doesn't understand fp16, and there are explicit .cpu() calls - ) - - def test_unbind(self): - def fn(a): - return torch.unbind(a), torch.unbind(a, -1) - - self.common( - fn, - (torch.randn([4, 4, 4]),), - ) - - def test_convolution1(self): - m = torch.nn.Sequential( - torch.nn.Conv2d(5, 6, [3, 3]), - torch.nn.ReLU(), - ToTuple(), - ) - - self.common( - m, - (torch.randn([2, 5, 16, 16]),), - # Mismatched elements: 10 / 2352 (0.4%) - # Greatest absolute difference: 5.7220458984375e-05 at index (0, 3, 12, 12) (up to 1e-05 allowed) - # Greatest relative difference: 0.06512477175897748 
at index (0, 4, 11, 9) (up to 0.001 allowed) - atol=6e-5, - rtol=0.001, - ) - - def test_convolution2(self): - def fn(x, w, b): - # transposed conv - return (aten.convolution(x, w, b, [4], [0], [1], True, [0], 1),) - - self.common( - fn, - ( - torch.randn([2, 32, 90]), - torch.randn([32, 16, 8]), - torch.randn([16]), - ), - check_lowp=False, - ) - - @unittest.skipIf(HAS_CUDA, "only support cpu channels_last") - def test_conv2d_channels_last(self): - m = torch.nn.Sequential( - torch.nn.Conv2d(3, 3, 1, 1), - ToTuple(), - ) - # only weight is channels_last - self.common( - m.to(memory_format=torch.channels_last), - (torch.randn([2, 3, 16, 16]),), - ) - # only activation is channels_last - self.common( - m, - (torch.randn([2, 3, 16, 16]).to(memory_format=torch.channels_last),), - ) - # activation and weight are all channels_last - self.common( - m.to(memory_format=torch.channels_last), - (torch.randn([2, 3, 16, 16]).to(memory_format=torch.channels_last),), - ) - - @unittest.skipIf(HAS_CUDA, "only support cpu channels_last") - def test_conv3d_channels_last(self): - m = torch.nn.Sequential( - torch.nn.Conv3d(3, 3, 1, 1), - ToTuple(), - ) - # only weight is channels_last - self.common( - m.to(memory_format=torch.channels_last_3d), - (torch.randn([2, 3, 16, 16, 16]),), - ) - # only activation is channels_last - self.common( - m, - (torch.randn([2, 3, 16, 16, 16]).to(memory_format=torch.channels_last_3d),), - ) - # activation and weight are all channels_last - self.common( - m.to(memory_format=torch.channels_last_3d), - (torch.randn([2, 3, 16, 16, 16]).to(memory_format=torch.channels_last_3d),), - ) - - def test_adaptive_avg_pool2d1(self): - def fn(x): - return aten._adaptive_avg_pool2d(x, (6, 6)), aten._adaptive_avg_pool2d( - x + 1, (2, 5) - ) - - self.common( - fn, - (torch.randn(2, 4, 16, 16),), - ) - - # lowering to avg_pool2d case - self.common( - fn, - (torch.randn(2, 4, 3, 3),), - ) - - # no-op case - self.common( - fn, - (torch.randn(2, 4, 6, 6),), - ) - - def test_max_pool2d1(self): - def fn(x): - return aten.max_pool2d_with_indices(x, [3, 3], [2, 2]) - - self.common( - fn, - (torch.randn(2, 4, 16, 16),), - ) - - def test_max_pool2d2(self): - def fn(x): - return aten.max_pool2d_with_indices(x, [3, 3], [2, 2]) - - self.common( - fn, - (torch.randn([16, 64, 55, 55]),), - ) - - def test_max_pool2d3(self): - def fn(x): - # with padding - return aten.max_pool2d_with_indices(x, [3, 3], [2, 2], [1, 1]) - - self.common( - fn, - (-torch.arange(1 * 8 * 8, dtype=torch.float32).view(1, 1, 8, 8),), - ) - - def test_max_pool2d4(self): - def fn(x): - # with padding - return aten.max_pool2d_with_indices(x, [3, 3], [2, 2], [0, 0], [1, 1], True) - - self.common( - fn, - (torch.randn([2, 8, 111, 111]),), - ) - - def test_max_pool2d5(self): - def fn(x): - return aten.max_pool2d_with_indices(x, [3, 3], []) - - self.common( - fn, - (torch.randn([16, 64, 55, 55]),), - ) - - def test_avg_pool2d1(self): - def fn(x): - return aten.avg_pool2d(x, [3, 3], [2, 2]) - - self.common( - fn, - (torch.randn(2, 4, 16, 16),), - ) - - def test_avg_pool2d2(self): - def fn(x): - return aten.avg_pool2d(x, [3, 3], [2, 2]) - - self.common( - fn, - (torch.randn([16, 64, 55, 55]),), - ) - - def test_avg_pool2d3(self): - def fn(x): - return aten.avg_pool2d(x, [3, 3], [2, 2], [1, 1]) - - self.common( - fn, - (-torch.arange(1 * 8 * 8, dtype=torch.float32).view(1, 1, 8, 8),), - ) - - def test_avg_pool2d4(self): - def fn(x): - return aten.avg_pool2d(x, [3, 3], [2, 2], [0, 0], True) - - self.common( - fn, - (torch.randn([2, 8, 111, 111]),), 
- ) - - def test_avg_pool2d5(self): - def fn(x): - return aten.avg_pool2d(x, [3, 3], [2, 2], [1, 1], count_include_pad=False) - - self.common( - fn, - (-torch.arange(1 * 8 * 8, dtype=torch.float32).view(1, 1, 8, 8),), - ) - - def test_avg_pool2d6(self): - def fn(x): - return aten.avg_pool2d(x, [3, 3], [2, 2], [1, 1], divisor_override=3) - - self.common( - fn, - (-torch.arange(1 * 8 * 8, dtype=torch.float32).view(1, 1, 8, 8),), - ) - - def test_alexnet_prefix(self): - def forward(arg6, arg7, arg16): - convolution = torch.ops.aten.convolution( - arg16, arg7, arg6, [4, 4], [2, 2], [1, 1], False, [0, 0], 1 - ) - relu = torch.ops.aten.relu(convolution) - max_pool2d_with_indices = torch.ops.aten.max_pool2d_with_indices( - relu, [3, 3], [2, 2] - ) - getitem = max_pool2d_with_indices[0] - return (getitem,) - - self.common( - forward, - ( - rand_strided((64,), (1,), torch.float32, "cpu"), - rand_strided((64, 3, 11, 11), (363, 121, 11, 1), torch.float32, "cpu"), - rand_strided( - (16, 3, 224, 224), (150528, 50176, 224, 1), torch.float32, "cpu" - ), - ), - # Mismatched elements: 127 / 746496 (0.0%) - # Greatest absolute difference: 0.0009765625 at index (1, 62, 7, 16) (up to 1e-05 allowed) - # Greatest relative difference: 0.05187467899332306 at index (14, 18, 11, 0) (up to 0.001 allowed) - atol=1e-3, - rtol=0.001, - ) - - def test_elu(self): - def fn(x): - return aten.elu(x, 1.6732632423543772, 1.0507009873554805) + 2, aten.elu( - x + 1, 2, 3, 4 - ) - - self.common( - fn, - (torch.randn([16, 16]),), - ) - - def test_tanh(self): - def fn(x): - return aten.tanh(x) + 2, aten.tanh(x + 1) - - self.common( - fn, - (torch.randn([16, 16]),), - ) - - def test_lgamma(self): - def fn(x): - return aten.lgamma(x) + 2, aten.cos(x + 1) - - self.common( - fn, - (torch.randn([16, 16]),), - ) - - def test_cos(self): - def fn(x): - return aten.cos(x) + 2, aten.cos(x + 1) - - self.common( - fn, - (torch.randn([16, 16]),), - ) - - def test_sin(self): - def fn(x): - return aten.sin(x) + 2, aten.sin(x + 1) - - self.common( - fn, - (torch.randn([16, 16]),), - ) - - def test_repeat(self): - def fn(x): - return ( - x.repeat(2, 2, 3, 1), - x.repeat(8, 1, 1, 1), - x.repeat(2, 1, 1, 1, 1, 1), - ) - - self.common( - fn, - (torch.randn([1, 2, 4, 8]),), - ) - - def test_embedding(self): - m = torch.nn.Sequential( - torch.nn.Embedding(10, 4, padding_idx=0), - torch.nn.ReLU(), - ToTuple(), - ) - - self.common( - m, - (torch.randint(10, [2, 8]),), - ) - - def test_mean(self): - def fn(x): - return ( - x.mean(), - x.mean(-1), - torch.mean(x, -2, keepdim=True), - x.mean([0, 1]), - ) - - self.common( - fn, - (torch.randn([1, 2, 4, 8]),), - ) - - def test_var_mean(self): - def fn(x): - return ( - *torch.var_mean(x, -1), - *torch.var_mean(x, [1, 3]), - ) - - self.common( - fn, - (torch.randn([1, 2, 4, 8]),), - ) - - @patch.object(config, "pick_loop_orders", True) - def test_transposed_propagates(self): - @torchdynamo.optimize("inductor", nopython=True) - def fn(x, y): - return x + y - - a = torch.randn(1, 4, 4, 4, device=self.device).permute(0, 2, 3, 1) - b = torch.randn(4, 4, 4, device=self.device).permute(1, 2, 0) - c = fn(a, b) - self.assertEqual(a.stride(), c.stride()) - self.assertEqual(c.stride()[2], 1) - - @requires_cuda() - @patch.object(config.triton, "convolution", "triton") - @patch.object(config.triton, "dense_indexing", "True") - def test_triton_conv(self): - @torchdynamo.optimize("inductor", nopython=True) - def triton_conv( - x, - w, - bias, - stride, - padding, - dilation, - groups, - ): - y = torch.conv2d(x, w, bias, 
stride, padding, dilation, groups) - return y - - stride, padding, dilation, groups = (1, 1), (0, 0), (1, 1), 1 - dtype = torch.float32 - x = torch.randn((32, 128, 32, 32), dtype=dtype, device=self.device) - w = torch.randn((32, 128, 1, 1), dtype=dtype, device=self.device) - bias = torch.randn((32), dtype=dtype, device=self.device) - - y = triton_conv(x, w, bias, stride, padding, dilation, groups) - y_correct = torch.conv2d(x, w, bias, stride, padding, dilation, groups) - self.assertTrue(same(y, y_correct, cos_similarity=True, tol=0.1)) - - @requires_cuda() - @patch.object(config.triton, "convolution", "autotune") - @patch.object(config.triton, "dense_indexing", "True") - def test_conv_autotune(self): - @torchdynamo.optimize("inductor", nopython=True) - def triton_conv( - x, - w, - bias, - stride, - padding, - dilation, - groups, - ): - y = torch.conv2d(x, w, bias, stride, padding, dilation, groups) - return y - - stride, padding, dilation, groups = (1, 1), (0, 0), (1, 1), 1 - dtype = torch.float32 - x = torch.randn((32, 128, 32, 32), dtype=dtype, device=self.device) - w = torch.randn((32, 128, 1, 1), dtype=dtype, device=self.device) - bias = torch.randn((32), dtype=dtype, device=self.device) - - y = triton_conv(x, w, bias, stride, padding, dilation, groups) - y_correct = torch.conv2d(x, w, bias, stride, padding, dilation, groups) - self.assertTrue(same(y, y_correct, cos_similarity=True, tol=0.1)) - - @patch.object(config.triton, "mm", "triton") - def test_triton_mm2(self): - @torchdynamo.optimize("inductor", nopython=True) - def fn(x, y): - return torch.relu(torch.mm(x, y)) - - N = 1024 - a = torch.randn([N, N], device=self.device, dtype=torch.float32) - b = torch.randn([N, N], device=self.device, dtype=torch.float32) - c1 = torch.relu(torch.mm(a, b)) - torchinductor.metrics.reset() - c = fn(a, b) - assert torch.allclose(c1, c, atol=1e-3, rtol=1e-3) - if self.device == "cuda": - assert torchinductor.metrics.generated_kernel_count == 1 - - def test_std(self): - def fn(x): - return ( - torch.var(x, True), - torch.var(x, False), - torch.var(x, -1, True), - torch.var(x, -1, False), - torch.std(x, False), - torch.std(x, [0, 1], True), - torch.std(x, [0, 1], False), - torch.std(x, -2, True, keepdim=True), - ) - - self.common( - fn, - (torch.randn([2, 4, 4, 8]),), - ) - - def test_embedding_bag(self): - def fn(w, i, o): - return aten._embedding_bag(w, i, o, False, 0, False, None) - - self.common( - fn, - (torch.randn([10, 4]), torch.randint(10, [8]), torch.tensor([0, 2, 6])), - ) - - def test_batch_norm_2d(self): - m = torch.nn.Sequential( - torch.nn.BatchNorm2d(10), - torch.nn.ReLU(), - ) - m.eval() - self.common(m, (torch.randn([2, 10, 8, 8]),), check_lowp=False) - self.common( - m, - (torch.randn([3, 10, 16, 16]),), - check_lowp=False, # too painful to match types of bn model - ) - - def test_layer_norm(self): - m = torch.nn.Sequential( - torch.nn.LayerNorm(32), - torch.nn.ReLU(), - ) - m.eval() - self.common(m, (torch.randn([16, 32]),), check_lowp=False) - if self.device != "cpu": - self.assertEqual(torchinductor.metrics.generated_kernel_count, 1) - - def test_move_arange(self): - def fn(x): - return torch.arange(len(x), device="cpu").to(x.device) + x - - self.common(fn, (torch.randn([32]),), check_lowp=False) - # if we have a copy there will be more than 1 kernel - self.assertEqual(torchinductor.metrics.generated_kernel_count, 1) - - def test_leaky_relu(self): - def fn(x): - return aten.leaky_relu(x, 0.2) + 2, aten.leaky_relu(x + 1) - - self.common( - fn, - (torch.randn([16, 16]),), - ) - - 
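# A hedged sketch of the generated_kernel_count pattern used by test_vertical_fusion1,
# test_triton_mm2 and test_layer_norm above: reset the counter, run a compiled function
# once, then inspect how many Inductor kernels were emitted. It reuses the torchdynamo /
# torchinductor modules this file already depends on; the exact count depends on backend
# and config, so it is printed rather than asserted.
import torch
import torchdynamo
import torchinductor.metrics

torchdynamo.reset()
torchinductor.metrics.reset()

@torchdynamo.optimize("inductor", nopython=True)
def pointwise_chain(x):
    return ((x + 1).relu() * 2).sin()  # elementwise chain that Inductor should fuse

pointwise_chain(torch.randn(64))
print(torchinductor.metrics.generated_kernel_count)  # typically 1 when fully fused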
def test_gelu(self): - def fn(x): - return aten.gelu(x) + 2, aten.gelu(x + 1) - - self.common( - fn, - (torch.randn([16, 16]),), - ) - - def test_clone(self): - def fn(x): - return aten.clone(x) + 2, aten.clone(x + 1) - - self.common( - fn, - (torch.randn([16, 16]),), - ) - - def test_masked_fill(self): - def fn(mask, value): - return aten.masked_fill(value, mask, -10000.0) + 2, aten.masked_fill( - value / 2.0, torch.logical_not(mask), 667 - ) - - self.common( - fn, - ( - torch.randint(0, 1, [1, 16], dtype=torch.bool), - torch.randn([16, 16]), - ), - ) - - def test_masked_fill_promotion(self): - def fn(mask, value): - return aten.masked_fill(value, mask, torch.tensor(3.5)) - - opt_fn = torchdynamo.optimize("inductor")(fn) - for inp in ( - torch.randn( - [16, 16], - dtype=torch.float16 if self.device == "cuda" else torch.float32, - device=self.device, - ), - torch.randint(16, (16, 16), device=self.device), - ): - - inputs = ( - torch.randint(0, 1, [1, 16], dtype=torch.bool, device=self.device), - inp, - ) - self.assertEqual(fn(*inputs), opt_fn(*inputs)) - - def test_fill1(self): - def fn(x): - tmp = torch.ones_like(x) - return tmp, aten.fill.Scalar(tmp, 2) - - self.common( - fn, - (torch.randn([16, 16]),), - ) - - def test_fill2(self): - def fn(x): - tmp = torch.ones_like(x) - return tmp, aten.fill.Tensor(tmp, torch.tensor(3.0)) - - self.common( - fn, - (torch.randn([16, 16]),), - ) - - def test_pow1(self): - def fn(x): - return [aten.pow(x, e) for e in range(-8, 9)] - - self.common( - fn, - (torch.randn([16, 16]),), - ) - - def test_pow2(self): - def fn(x): - return aten.pow(1000, x), aten.pow(x, 1000) - - self.common( - fn, - (torch.randn([16, 16]),), - # Mismatched elements: 9 / 256 (3.5%) - # Greatest absolute difference: 2.491354329061828e+28 at index (6, 6) (up to 1e-05 allowed) - # Greatest relative difference: 2.9793410720160818e-05 at index (4, 5) (up to 1.3e-06 allowed) - atol=1e-5, - rtol=3e-05, - ) - - def test_glu(self): - def fn(x): - return aten.glu(x, -1), aten.glu(x, 1), aten.glu(x, 2) - - self.common( - fn, - (torch.randn([8, 16, 8, 8]),), - ) - - def test_cat(self): - def fn(a): - tmp = a * 2 - return ( - torch.cat((a, a[:, :4] + 1, a + 2), -1), - torch.cat((tmp, tmp), 0), - torch.cat((tmp, tmp.double()), 0), - ) - - self.common( - fn, - (torch.randn([8, 16]),), - ) - - def test_cat_upcasting(self): - def fn(arg4_1, slice_7): - cat_1 = aten.cat.default([arg4_1, slice_7], 1) - return (cat_1,) - - self.common( - fn, - ( - torch.randn([8, 16], dtype=torch.float32), - torch.randn([8, 20], dtype=torch.float16), - ), - ) - - def test_cat_extern_kernel(self): - def fn(x1, x2, x3, x4): - x = torch.mm(x2, x3) - s = torch.narrow(x, 1, 0, 100) - x = torch.mm(s, x4) - c = torch.cat((x, x1), 1) - return (c,) - - self.common( - fn, - ( - torch.randn(256, 256), - torch.randn(256, 1024), - torch.randn(1024, 1600), - torch.randn(100, 256), - ), - check_lowp=False, # accuracy issues with relatively large matmuls - ) - - def test_stack(self): - def fn(a, b): - return torch.stack( - [ - a.expand(12, 16), - b.expand(12, 16), - ], - 2, - ) - - self.common(fn, (torch.randn([1, 16]), torch.randn([12, 1]))) - - def test_hardtanh(self): - def fn(x): - return F.hardtanh(x), F.hardtanh(x + 1), F.hardtanh(x - 1) - - self.common( - fn, - (torch.randn([64]),), - ) - - def test_hardsigmoid(self): - def fn(x): - return F.hardsigmoid(x), F.hardsigmoid(x + 3), F.hardsigmoid(x - 3) - - self.common( - fn, - (torch.randn([64]),), - ) - - def test_hardswish(self): - def fn(x): - return F.hardswish(x), 
F.hardswish(x + 3), F.hardswish(x - 3) - - self.common( - fn, - (torch.randn([64]),), - ) - - def test_rsqrt(self): - def fn(x): - return torch.rsqrt(x), torch.rsqrt(x + 1) - 2 - - self.common( - fn, - (torch.randn([64]),), - ) - - def test_flip(self): - def fn(x): - return torch.flip(x, (-1,)), torch.flip(x, (0, 2)) - 2 - - self.common( - fn, - (torch.randn([1, 2, 6, 6]),), - ) - - def test_signbit(self): - def fn(x): - return torch.signbit(x), ~torch.signbit(-x) & 1 - - self.common( - fn, - (torch.randn([1, 2, 6, 6]),), - ) - - def test_fmod(self): - def fn(a, b): - return torch.fmod(a, b), torch.fmod(3.0 * a, b) - 2.0 - - shape = [1, 2, 6, 6] - self.common(fn, (torch.randn(shape), torch.randn(shape))) - - def test_log2(self): - def fn(x): - return torch.log2(x), torch.log2(x + 1) - 2 - - self.common( - fn, - (torch.randn([64]) + 10,), - ) - - def test_logsumexp(self): - def fn(x): - return torch.logsumexp(x, -1), torch.logsumexp(x, 0) - 2 - - self.common( - fn, - (torch.randn([8, 8]) + 10,), - ) - - def test_log_fp64(self): - def fn(x): - return torch.log(x), torch.log2(x) - - self.common( - fn, - (torch.randn([1024], dtype=torch.float64) + 10,), - ) - - def test_bitwise(self): - def fn(x, y): - return ( - torch.bitwise_not(x), - torch.bitwise_or(x, y), - torch.bitwise_xor(x, y), - torch.bitwise_and(x, y), - ) - - self.common( - fn, - ( - torch.randint(0, 2**30, [64], dtype=torch.int32), - torch.randint(0, 2**30, [64], dtype=torch.int32), - ), - ) - - def test_bitwise2(self): - # again with bool types - def fn(x, y): - return ( - torch.bitwise_not(x), - torch.bitwise_or(x, y), - torch.bitwise_xor(x, y), - torch.bitwise_and(x, y), - ) - - self.common( - fn, - ( - torch.randint(0, 2, (2, 20), dtype=torch.bool), - torch.randint(0, 2, (2, 20), dtype=torch.bool), - ), - ) - - def test_inf(self): - def fn(a): - return a + float("inf"), a + float("-inf"), a * -float("inf") - - self.common(fn, (torch.randn(8),)) - - def test_remainder(self): - def fn(a, b): - return ( - torch.remainder(a, b), - torch.remainder(a + 1, b - 1), - torch.remainder(a - 1, b + 1), - ) - - self.common(fn, (torch.randn(64), torch.randn(64))) - - def test_zeros(self): - def fn(a): - return ( - a + 1, - torch.zeros( - (1, 8, 64, 64), - dtype=torch.float32, - device=a.device, - ), - torch.zeros( - 1, - 8, - 64, - 64, - dtype=torch.float32, - device=a.device, - ), - torch.zeros(2, 3, names=None), - a + torch.ones(8, device=a.device), - torch.full((2, 3), 3.1416, device=a.device), - ) - - self.common(fn, (torch.randn(8),)) - - def test_new_ones(self): - def fn(a): - return ( - aten.new_ones( - a, [], device=a.device, dtype=6, layout=0, pin_memory=False - ), - aten.new_zeros( - a, [], device=a.device, dtype=6, layout=0, pin_memory=False - ), - ) - - self.common(fn, (torch.randn(8),)) - - def test_full_like(self): - def fn(a): - return torch.full_like(a, 7.777) - 1 - - self.common(fn, (torch.randn(8),)) - - def test_index1(self): - def fn(a, b, c): - return aten.index(a, [b, c]) - - self.common( - fn, - ( - torch.randn(8, 8, 12), - torch.tensor([0, 0, 2, 2], dtype=torch.int64), - torch.tensor([3, 4, 4, 3], dtype=torch.int64), - ), - ) - self.common( - fn, - ( - torch.randn(8, 8, 12), - torch.tensor([[0, 0, 2, 2]], dtype=torch.int64), - torch.tensor([[3], [4], [4], [3]], dtype=torch.int64), - ), - ) - - def test_index2(self): - def fn(a, b): - return ( - aten.index(a, [b]), - aten.index(a, [None, b]), - ) - - self.common( - fn, - ( - torch.randn(8, 8, 8), - torch.tensor([[0, 0, 2, 2]], dtype=torch.int64), - ), - ) - - def 
test_index_select(self): - def fn(a, b): - return ( - torch.index_select(a, 0, b), - torch.index_select(a, 1, b), - torch.index_select(torch.index_select(a, 2, b), 1, b), - ) - - for ind_dtype in (torch.int32, torch.int64): - self.common( - fn, - ( - torch.randn(8, 8, 8), - torch.tensor([0, 0, 2, 1], dtype=ind_dtype), - ), - ) - - # https://github.com/pytorch/torchdynamo/issues/467 - @patch.object(torchdynamo.config, "fake_tensor_propagation", False) - def test_cudnn_rnn(self): - if self.device == "cpu": - raise unittest.SkipTest("requires CUDA") - - def fn( - a0, - b0, - b1, - b2, - b3, - b4, - b5, - b6, - b7, - b8, - b9, - b10, - b11, - b12, - b13, - b14, - b15, - a3, - a4, - a5, - ): - a1 = [ - b0, - b1, - b2, - b3, - b4, - b5, - b6, - b7, - b8, - b9, - b10, - b11, - b12, - b13, - b14, - b15, - ] - return aten._cudnn_rnn( - a0, - a1, - 4, - a3, - a4, - a5, - 2, - 2048, - 0, - 2, - False, - 0.0, - False, - True, - [], - None, - ) - - self.common( - fn, - ( - torch.randn([92, 8, 2048]), - torch.randn([8192, 2048]), - torch.randn([8192, 2048]), - torch.randn([8192]), - torch.randn([8192]), - torch.randn([8192, 2048]), - torch.randn([8192, 2048]), - torch.randn([8192]), - torch.randn([8192]), - torch.randn([8192, 4096]), - torch.randn([8192, 2048]), - torch.randn([8192]), - torch.randn([8192]), - torch.randn([8192, 4096]), - torch.randn([8192, 2048]), - torch.randn([8192]), - torch.randn([8192]), - torch.randn([167837696]), - torch.randn([4, 8, 2048]), - torch.randn([4, 8, 2048]), - ), - check_lowp=False, # difference in rnn is too large between half and float inputs - ) - - def test_upsample_nearest2d(self): - def fn(a): - return ( - aten.upsample_nearest2d(a, [74, 76], None), - aten.upsample_nearest2d(a, [70, 75], None), - aten.upsample_nearest2d(a, [45, 74], None), - aten.upsample_nearest2d(a, [36, 39], None), - aten.upsample_nearest2d(a, None, [2.0, 2.0]), - ) - - self.common(fn, (torch.randn([2, 4, 37, 38]),)) - - def test_upsample_nearest2d_backward(self): - func = torch.ops.aten.upsample_nearest2d_backward.vec - - def fn(a): - return ( - func( - a, output_size=[6, 12], input_size=[3, 3, 3, 6], scale_factors=None - ), - func( - a, output_size=[6, 12], input_size=[3, 3, 4, 5], scale_factors=None - ), - func( - a, output_size=[6, 12], input_size=[3, 3, 2, 8], scale_factors=None - ), - func( - a, output_size=[6, 12], input_size=[3, 3, 2, 8], scale_factors=None - ), - func( - a, output_size=[6, 12], input_size=[3, 3, 4, 7], scale_factors=None - ), - ) - - self.common(fn, (torch.randn([3, 3, 6, 12]),)) - - def test_upsample_bilinear2d_a(self): - def fn(a): - return ( - aten.upsample_bilinear2d(a, [45, 45], False, None), - aten.upsample_bilinear2d(a, None, True, [2.0, 2.0]), - ) - - self.common(fn, (torch.randn([2, 4, 37, 38]),)) - - def test_upsample_bilinear2d_b(self): - def fn(a): - return aten.upsample_bilinear2d(a, None, True, [2.0, 2.0]) - - self.common( - fn, - [ - torch.randn([1, 2, 40, 59]), - ], - ) - - def test_reflection_pad2d(self): - def fn(a): - return ( - aten.reflection_pad2d(a, [1, 1, 1, 1]), - aten.reflection_pad2d(a, [1, 2, 3, 4]), - ) - - self.common( - fn, (torch.randint(0, 999, size=[1, 1, 8, 8], dtype=torch.float32),) - ) - - def test_reflection_pad2d_backward(self): - def template(size, padding): - def fn(grad_output, x): - return aten.reflection_pad2d_backward(grad_output, x, padding) - - x = torch.randint(0, 999, size=size, dtype=torch.float32) - result = aten.reflection_pad2d(x, padding) - grad_output = torch.randn_like(result) - - self.common(fn, (grad_output, 
x)) - - template([1, 1, 8, 8], [0, 0, 0, 0]) - template([1, 1, 8, 8], [1, 1, 1, 1]) - template([1, 1, 8, 8], [1, 2, 3, 4]) - - def test_grid_sampler_2d(self): - def fn(a, b): - return ( - aten.grid_sampler_2d(a, b, 0, 0, True), - aten.grid_sampler_2d(a, b, 0, 1, False), - ) - - self.common( - fn, - ( - torch.randn([4, 3, 352, 352], dtype=torch.float32), - torch.rand([4, 352, 352, 2], dtype=torch.float32) * 2 - 1, - ), - check_lowp=False, - # Mismatched elements: 154697 / 1486848 (10.4%) - # Greatest absolute difference: 0.0001976490020751953 at index (0, 0, 101, 243) (up to 1e-05 allowed) - # Greatest relative difference: 7.332530120481928 at index (1, 1, 258, 301) (up to 1.3e-06 allowed) - atol=0.0002, - rtol=1.3e-06, - ) - - def test_upsample_bicubic2d(self): - def fn(a): - return ( - aten.upsample_bicubic2d(a, (128, 128), True), - aten.upsample_bicubic2d(a, (128, 256), False), - ) - - # Mismatched elements: 10 / 196608 (0.0%) - # Greatest absolute difference: 1.3869255781173706e-05 at index (2, 1, 88, 65) (up to 1e-05 allowed) - # Greatest relative difference: 0.0033082996811011046 at index (3, 1, 88, 91) (up to 1.3e-06 allowed) - self.common( - fn, - (torch.randn([4, 3, 64, 32], dtype=torch.float32),), - atol=2e-5, - rtol=1e-3, - ) - - def test_sort(self): - def fn(a): - return torch.sort(a) - - self.common( - fn, (torch.randint(0, 999, size=[1, 1, 8, 8], dtype=torch.float32),) - ) - - def test_topk(self): - def fn(a): - return torch.topk(a, 2, -1) - - self.common( - fn, (torch.randint(0, 999, size=[1, 1, 8, 8], dtype=torch.float32),) - ) - - def test_long_tensor(self): - def fn(a): - return ( - torch.LongTensor([294]).to(a.device) - a, - torch.as_tensor([295]).to(a.device) + a, - ) - - self.common(fn, (torch.randint(0, 999, size=[8, 8]),)) - - def test_constant_pad_1d(self): - def fn(a): - return ( - aten.constant_pad_nd(a, [0, 1], 6.0), - aten.constant_pad_nd(a, [2, 3], 99.0), - ) - - self.common(fn, (torch.randint(0, 999, size=[2, 16, 31], dtype=torch.float32),)) - - def test_constant_pad_2d(self): - def fn(a): - return ( - aten.constant_pad_nd(a, [1, 1, 1, 1], 6.0), - aten.constant_pad_nd(a, [1, 2, 3, 4], 99.0), - ) - - self.common( - fn, (torch.randint(0, 999, size=[1, 1, 8, 8], dtype=torch.float32),) - ) - - def test_constant_pad_3d(self): - def fn(a): - return ( - aten.constant_pad_nd(a, [1, 2, 3, 4, 5, 6], 6.0), - aten.constant_pad_nd(a, [0, 0, 3, 4, 0, 0], 6.0), - ) - - self.common( - fn, (torch.randint(0, 999, size=[2, 4, 4, 4], dtype=torch.float32),) - ) - - def test_l1_loss(self): - def fn(a, b): - return torch.nn.functional.l1_loss(a, b), torch.nn.functional.mse_loss(a, b) - - self.common( - fn, - ( - torch.randn([2, 3, 16, 16]), - torch.randn([2, 3, 16, 16]), - ), - check_lowp=False, - ) - - def test_triu(self): - def fn(a): - return aten.triu(a, 1), aten.triu(a, 0), aten.triu(a, 2) - - self.common(fn, (torch.randn([2, 10, 10]),)) - - def test_no_op_reduction(self): - def fn(a): - return a.sum(-1), torch.amax(a + 1, 1, keepdim=True) - - self.common(fn, (torch.randn([8, 1, 1]),)) - - def test_inplace_add(self): - @torchdynamo.optimize("inductor") - def fn(x, y): - return x.add_(y) - - inputs = ( - rand_strided((4, 4), (4, 1), device=self.device), - rand_strided((4, 4), (4, 1), device=self.device), - ) - inp_clone = inputs[0].clone() - out = fn(*inputs) - self.assertTrue(same(out, inp_clone + inputs[1])) - self.assertTrue(out is inputs[0]) - - def test_inplace_mixed_dtype_ops(self): - @torchdynamo.optimize("inductor") - def fn(x, y): - z = x + y.float() - w = z.add_(y) - 
return w.mul_(y) - - inputs = ( - rand_strided((4, 4), (4, 1), device=self.device, dtype=torch.float), - rand_strided((4, 4), (4, 1), device=self.device, dtype=torch.double), - ) - out = fn(*inputs) - out_eager = (inputs[0] + inputs[1].float()).add_(inputs[1]).mul_(inputs[1]) - self.assertTrue(same(out, out_eager)) - - @patch.object(config.triton, "cudagraphs", True) - def test_strided_inputs(self): - @torchdynamo.optimize("inductor") - def fn(x, y): - return x + y - - inputs = ( - rand_strided((8, 16), (32, 2), device=self.device), - rand_strided((8, 16), (16, 1), device=self.device), - ) - self.assertTrue(same(fn(*inputs), inputs[0] + inputs[1])) - - @patch.object(config.triton, "cudagraphs", True) - @patch.object(functorch_config, "use_fake_tensor", True) - def test_input_mutation1(self): - def fn(a): - b = a + 1 - a.copy_(b) - c = a + 2 - return a * b / c - - arg1 = torch.randn(64, device=self.device) - arg2 = arg1.clone() - arg3 = torch.randn(64, device=self.device) - arg4 = arg3.clone() - correct1 = fn(arg1) - correct2 = fn(arg3) - opt_fn = torchdynamo.optimize_assert(compile_fx)(fn) - actual1 = opt_fn(arg2) - actual2 = opt_fn(arg4) - - self.assertTrue(same(actual1, correct1)) - self.assertTrue(same(actual2, correct2)) - self.assertTrue(same(arg1, arg2)) - self.assertTrue(same(arg3, arg4)) - - @patch.object(functorch_config, "use_fake_tensor", True) - def test_input_mutation2(self): - def fn(a): - b = a + 1 - a.view(64).copy_(torch.tensor([66.0], device=a.device)) - c = a + 2 - return b, c - - arg1 = torch.randn([1, 64], device=self.device) - arg2 = arg1.clone() - correct1 = fn(arg1) - opt_fn = torchdynamo.optimize_assert(compile_fx)(fn) - actual1 = opt_fn(arg2) - - self.assertTrue(same(actual1, correct1)) - self.assertTrue(same(arg1, arg2)) - - @patch.object(functorch_config, "use_fake_tensor", True) - def test_input_mutation3(self): - def fn(a): - a += 1 - a *= 2 - aten.sigmoid_(a) - a = a.view(64) - a += 3 - a *= 4 - aten.relu_(a) - return a - - arg1 = torch.randn([1, 64], device=self.device) - arg2 = arg1.clone() - correct1 = fn(arg1) - opt_fn = torchdynamo.optimize_assert(compile_fx)(fn) - actual1 = opt_fn(arg2) - - self.assertTrue(same(actual1, correct1)) - self.assertTrue(same(arg1, arg2)) - - def test_input_mutation4(self): - def fn(a): - torch.relu_(a) - return a - - arg1 = torch.randn([1, 64], device=self.device) - arg2 = arg1.clone() - correct1 = fn(arg1) - opt_fn = torchdynamo.optimize_assert(compile_fx)(fn) - actual1 = opt_fn(arg2) - - self.assertTrue(same(actual1, correct1)) - self.assertTrue(same(arg1, arg2)) - - @patch.object(functorch_config, "use_fake_tensor", True) - def test_slice_mutation1(self): - def fn(a): - x = torch.zeros_like(a) - b = x + 1 - x[:, 3] = 3.0 - c = torch.clone(x) - x[4, :] = 4.0 - d = x + 1 - return x, b, c, d - - self.common(fn, (torch.randn([8, 8]),)) - - @patch.object(functorch_config, "use_fake_tensor", True) - def test_slice_mutation2(self): - def fn(a): - a[:, 20:40] = a[:, 20:40] + 1 - a[:, 2:11] = a[:, 1:10] + 2 - - arg1 = torch.randn([1, 64], device=self.device) - arg2 = arg1.clone() - fn(arg1) - opt_fn = torchdynamo.optimize_assert(compile_fx)(fn) - opt_fn(arg2) - - self.assertTrue(same(arg1, arg2)) - - def test_indirect_load_broadcast(self): - def fn(in_ptr0, in_ptr1, in_ptr2): - return torch.gather(in_ptr1, 0, in_ptr2) + in_ptr0 - - arg190 = rand_strided((32, 21), (1, 32), device=self.device, dtype=torch.int64) - arg190.fill_(0) - arg111 = rand_strided( - (9521, 512), (512, 1), device=self.device, dtype=torch.float32 - ) - 
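# A minimal sketch of the clone-and-compare pattern behind the test_input_mutation*
# cases above: because fn mutates its argument in place, the eager and compiled runs
# each get their own clone, and both the return value and the mutated input are checked.
import torch
import torchdynamo

def fn(a):
    b = a + 1
    a.copy_(b)  # in-place mutation of the input
    return a * b

arg_eager = torch.randn(64)
arg_compiled = arg_eager.clone()

expected = fn(arg_eager)
actual = torchdynamo.optimize("inductor")(fn)(arg_compiled)

torch.testing.assert_close(actual, expected)
torch.testing.assert_close(arg_compiled, arg_eager)  # the mutation must match too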
self.common( - fn, - ( - torch.randn(32, 1), - arg111, - arg190, - ), - ) - - @unittest.skipIf(not has_torchvision_roi_align(), "requirs torchvision") - def test_roi_align(self): - def fn(a, b): - return torch.ops.torchvision.roi_align(a, b, 0.25, 7, 7, 2, False) - - self.common(fn, (torch.zeros([4, 256, 296, 304]), torch.zeros([2292, 5]))) - - @requires_decomp(aten.nll_loss_forward) - def test_nll_loss_forward(self): - def fn(a, b): - return aten.nll_loss_forward(a, b, None, 1, -100) - - self.common( - fn, - ( - torch.randn([5, 5]), - torch.zeros([5], dtype=torch.int64), - ), - ) - - def test_isinf(self): - def fn(x): - return x.isinf(), x.isnan() - - self.common( - fn, [torch.tensor([1, float("inf"), 2, float("-inf"), float("nan")])] - ) - self.common( - fn, - [ - torch.tensor( - [1, float("inf"), 2, float("-inf"), float("nan")], - dtype=torch.float64, - ) - ], - ) - - def test_any(self): - def fn(x): - return ( - x.any(-1), - x.isinf().any(), - torch.all(x.isinf(), dim=0), - torch.all(torch.logical_not(x.isinf())), - ) - - self.common(fn, [-torch.rand(64)]) - tmp = torch.randn(16, 8) - tmp[1, 1] = float("inf") - self.common(fn, [tmp]) - - def test_inplace_activations(self): - def fn(x): - a = aten.hardswish_(x + 1) - b = aten.hardtanh_(x + 1) - c = aten.leaky_relu_(x + 1) - d = aten.silu_(x + 1) - e = aten.log1p(x + 1) - f = aten.masked_fill_(x + 1, torch.zeros_like(x, dtype=torch.bool), 99.0) - h = aten.masked_fill_(x + 1, torch.ones_like(x, dtype=torch.bool), 99.0) - return (a, b, c, d, e, f, h) - - self.common(fn, [torch.randn(64) * 10]) - - def test_baddbmm(self): - def fn(a, b, c): - return aten.baddbmm(a, b, c) - - self.common( - fn, - [ - torch.randn(6, 1, 100), - torch.randn(6, 128, 64), - torch.randn(6, 64, 100), - ], - # Mismatched elements: 1212 / 76800 (1.6%) - # Greatest absolute difference: 0.001953125 at index (0, 0, 93) (up to 1e-05 allowed) - # Greatest relative difference: 1.0 at index (3, 19, 4) (up to 0.001 allowed) - atol=0.002, - rtol=0.001, - ) - - @patch.object(config.triton, "max_tiles", 2) - def test_fuse_tiled(self): - def fn(a, b, c): - return a + b, c + 1 - - self.common( - fn, [torch.randn(128, 1), torch.randn(1, 128), torch.randn(128, 128)] - ) - - def test_expand_as(self): - def fn(a, b): - return aten.expand_as(a, b), aten.expand_as(a + 1, b + 1) + 1 - - self.common( - fn, - [ - torch.randn(6, 1, 100), - torch.randn(6, 128, 100), - ], - ) - - def test_index_put1(self): - def fn(a, b, c): - return ( - torch.index_put(a, [b], c), - torch.index_put_(a + 1, [b + 1], c + 1) + 1, - ) - - self.common( - fn, - [ - torch.randn([800, 256, 7, 7]), - torch.randperm(601), - torch.randn([601, 256, 7, 7]), - ], - ) - self.common( - fn, [torch.randn(1024, 4, 2), torch.arange(4), torch.randn(4, 1, 1)] - ) - - def test_index_put2(self): - def fn(a, b, c): - return torch.index_put(a, [b], c, True) - - self.common( - fn, - [ - torch.randn([100, 256, 7, 7]), - torch.randint(0, 100, size=[600], dtype=torch.int64), - torch.randn([600, 256, 7, 7]), - ], - # workaround for https://github.com/openai/triton/issues/558 - check_lowp=False, - ) - - def test_index_put3(self): - def fn(a, b, c): - torch.ops.aten.index_put_(a, (None, b, None), c) - a1 = a + 1 - torch.ops.aten.index_put_(a1, (None, b + 1, None), c + 1) - return (a, a1) - - self.common( - fn, - [ - torch.randn([1024, 4, 2]), - torch.arange(3), - torch.randn([1024, 1, 2]), - ], - ) - - def test_index_put_as_masked_fill(self): - def fn(a, b, c, d): - a = a.clone() - torch.ops.aten.index_put_(a, [b], c, d) - return a - - 
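# A small standalone check of the equivalence that test_index_put_as_masked_fill
# (continued just below) relies on: with a boolean index and a 0-d value tensor,
# index_put_ with accumulate=False writes the scalar wherever the mask is True,
# i.e. it behaves like masked_fill_; with accumulate=True it adds instead.
import torch

a1 = torch.zeros(4)
a2 = a1.clone()
mask = torch.tensor([True, False, True, False])

torch.ops.aten.index_put_(a1, [mask], torch.tensor(7.0), False)  # accumulate=False
a2.masked_fill_(mask, 7.0)

assert torch.equal(a1, a2)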
self.common( - fn, - ( - torch.randn([1024, 4, 2]), - torch.randn([1024, 4, 2]) > 0, - torch.randn([]), - False, - ), - ) - - self.common( - fn, - ( - torch.randn([1024, 4, 2]), - torch.randn([1024, 4, 2]) > 0, - torch.randn([]), - True, - ), - ) - - def test_index_put_fallback1(self): - def fn(a, b, c, d): - a = a.clone() - torch.ops.aten.index_put_(a, [b], c, d) - return a - - self.common( - fn, - ( - torch.randn([3]), - torch.as_tensor([True, True, False]), - torch.randn([2]), - False, - ), - ) - - self.common( - fn, - ( - torch.randn([3]), - torch.as_tensor([True, True, False]), - torch.randn([2]), - True, - ), - ) - - def test_index_put_fallback2(self): - def fn(a, b, c, d, e): - a = a.clone() - torch.ops.aten.index_put_(a, [None, b, c], d, e) - return a - - self.common( - fn, - ( - torch.randn([1, 2, 3]), - torch.as_tensor([0, 1]), - torch.as_tensor([True, True, False]), - torch.randn([]), - False, - ), - ) - self.common( - fn, - ( - torch.randn([1, 2, 3]), - torch.as_tensor([0, 1]), - torch.as_tensor([True, True, False]), - torch.randn([]), - True, - ), - ) - - @patch.object(config, "fallback_random", True) - def test_bernoulli1(self): - def fn(a): - b = torch.empty_like(a) - return aten.bernoulli_(b), b - - self.common( - fn, - [ - torch.randn([100]), - ], - ) - - def test_bernoulli2(self): - def fn(a): - return aten.bernoulli(a) - - self.common( - fn, - [torch.tensor([1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0])], - ) - - def test_narrow(self): - def fn(x): - return aten.narrow(x, 1, 10, 16), aten.narrow(x + 2, 0, 10, 16) + 1 - - self.common(fn, [torch.randn(64, 64)]) - - def test_as_strided(self): - def fn(x): - return ( - aten.as_strided(x, (8, 8, 64), (8 * 64, 64, 1), 0), - aten.as_strided(x + 1, (8, 8, 64), (8 * 64, 64, 1), 0) + 2, - ) - - self.common(fn, [torch.randn(64, 64)]) - - def test_select_scatter(self): - def fn(x, a, b): - return ( - aten.select_scatter(x, a, 1, 0), - aten.select_scatter(x, b, 0, 1), - ) - - self.common( - fn, - [ - torch.randn(8, 197, 38), - torch.randn(8, 38), - torch.randn(197, 38), - ], - ) - - def test_slice_scatter(self): - def fn(x, a): - return ( - aten.slice_scatter(x, a, 2, 10, -10), - aten.slice_scatter(x, a[:, :, :40], 2, 10, -10, 2), - ) - - self.common( - fn, - [ - torch.randn(4, 8, 100), - torch.randn(4, 8, 80), - ], - ) - - def test_slice_scatter2(self): - def fn(a, b): - return aten.slice_scatter(a, b, 0, 0, 9223372036854775807) - - self.common( - fn, - [ - torch.randn([8, 197, 384]), - torch.randn([8, 197, 384]), - ], - ) - - def test_scatter1(self): - def fn(a, dim, index, b): - return aten.scatter(a, dim, index, b) - - self.common( - fn, - [ - torch.zeros(2, 3), - -1, - torch.tensor([[0]]), - torch.ones(2, 3), - ], - ) - - def test_scatter2(self): - def fn(a, dim, index, b): - return aten.scatter.reduce(a, dim, index, b, reduce="add") - - self.common( - fn, - [ - torch.zeros(64, 512), - 0, - torch.zeros((64, 512), dtype=torch.int64), - torch.ones(64, 512), - ], - ) - - def test_scatter3(self): - def fn(a, dim, index, b): - return aten.scatter(a, dim, index, b, reduce="add") - - self.common( - fn, - [ - torch.randn(5, 29, 13), - 2, - torch.tensor([[[3, 5, 7, 9]]]), - 0.8, # src can be a scalar - ], - # Mismatched elements: 1 / 1885 (0.1%) - # Greatest absolute difference: 0.00018310546875 at index (0, 0, 3) (up to 1e-05 allowed) - # Greatest relative difference: 0.0022371364653243847 at index (0, 0, 3) (up to 0.001 allowed) - atol=2e-4, - rtol=1e-3, - ) - - def test_scatter4(self): - def fn(x, ind, src): - return torch.scatter(x, 0, 
ind, src) - - self.common( - fn, - (torch.randn(196, 992), torch.randint(196, (1, 992)), torch.randn(1, 992)), - ) - - @unittest.skip("Flaky test, needs debugging") - def test_scatter_add1(self): - def fn(a, dim, index, b): - return aten.scatter_add(a, dim, index, b) - - self.common( - fn, - [ - torch.randn(2, 3), - 0, - torch.tensor([[0]]), - torch.randn(2, 3), - ], - ) - - def test_scatter_add2(self): - def fn(a, dim, index, b): - return aten.scatter_add(a, dim, index, b) - - self.common( - fn, - [ - torch.randn(2, 3), - 0, - torch.tensor([[0, 0, 0], [1, 1, 1]]), - torch.randn(2, 3), - ], - ) - - def test_scatter_add3(self): - def fn(a, dim, index, b): - return aten.scatter_add(a, dim, index, b) - - self.common( - fn, - [ - torch.randn(5, 29, 13), - 2, - torch.tensor([[[3, 5, 7, 9]]]), - torch.randn(1, 1, 10), - ], - ) - - def test_scatter_reduce1(self): - def fn(a, dim, index, b): - return aten.scatter_reduce(a, dim, index, b, "sum") - - self.common( - fn, - [ - torch.randn(5, 29, 13), - 2, - torch.tensor([[[3, 5, 7, 9]]]), - torch.randn(1, 1, 10), - ], - ) - - def test_scatter_reduce2(self): - def fn(a, dim, index, b): - return aten.scatter_reduce(a, dim, index, b, "sum", include_self=False) - - self.common( - fn, - [ - torch.randn(2, 3), - 0, - torch.zeros((2, 3), dtype=torch.int64), - torch.randn(2, 3), - ], - ) - - # issue #1150 - def test_dense_mask_index(self): - def fn(x, y): - y = torch.ops.aten.select.int(y, 0, 2) - z = x * y - return z.sum() - - self.common(fn, [torch.randn(102400), torch.randn(3)]) - - def test_new_empty_strided(self): - def fn(a): - return aten.new_empty_strided(a, [1, 128, 128], [16384, 128, 1]).fill_(123) - - self.common(fn, [torch.randn(55)]) - - @patch.object(torchinductor.config.triton, "cudagraphs", True) - def test_dropout(self): - random.seed(1234) - torch.manual_seed(1234) - - @torchdynamo.optimize("inductor") - def fn(a): - return torch.nn.functional.dropout(a, 0.5, True) - - x = torch.ones(1000, device=self.device, dtype=torch.float32) - result = fn(x) - self.assertTrue(400 < result.nonzero().shape[0] < 600) - self.assertTrue(0.9 < result.mean().item() < 1.1) - - def test_dropout_deterministic(self): - @torchdynamo.optimize("inductor") - def fn(a): - return torch.nn.functional.dropout(a, 0.55, True) - - for cg in (False, True): - with patch.object(torchinductor.config.triton, "cudagraphs", cg): - torchdynamo.reset() - - x = torch.ones(1024, device=self.device, dtype=torch.float32) - - torch.manual_seed(1234) - a0 = fn(x).clone() - a1 = fn(x).clone() - a2 = fn(x).clone() - - torch.manual_seed(1234) - b0 = fn(x).clone() - b1 = fn(x).clone() - b2 = fn(x).clone() - - # same seed, same values - self.assertTrue(torch.allclose(a0, b0)) - self.assertTrue(torch.allclose(a1, b1)) - self.assertTrue(torch.allclose(a2, b2)) - - # different calls, different values - self.assertFalse(torch.allclose(a0, a1)) - self.assertFalse(torch.allclose(a1, a2)) - - def test_rand_like_deterministic(self): - @torchdynamo.optimize("inductor") - def fn(a): - return torch.rand_like(a), torch.rand_like(a) - - x = torch.ones(1024, device=self.device, dtype=torch.float32) - - torch.manual_seed(1234) - a0 = fn(x)[0].clone() - a1 = fn(x)[0].clone() - a2 = fn(x)[0].clone() - - torch.manual_seed(1234) - b0 = fn(x)[0].clone() - b1 = fn(x)[0].clone() - b2 = fn(x)[0].clone() - - # same seed, same values - self.assertTrue(torch.allclose(a0, b0)) - self.assertTrue(torch.allclose(a1, b1)) - self.assertTrue(torch.allclose(a2, b2)) - - # different calls, different values - 
self.assertFalse(torch.allclose(a0, a1)) - self.assertFalse(torch.allclose(a1, a2)) - - c, d = fn(x) - self.assertFalse(torch.allclose(c, d)) - self.assertTrue((c >= 0).all()) - self.assertTrue((c < 1).all()) - self.assertTrue((d >= 0).all()) - self.assertTrue((d < 1).all()) - - def test_max_pool2d_with_indices_backward(self): - def fn(a, b, c): - return aten.max_pool2d_with_indices_backward( - a, b, [2, 2], [2, 2], [0, 0], [1, 1], False, c - ) - - x = torch.randn([2, 4, 18, 14]) - result, indices = aten.max_pool2d_with_indices( - x, - [2, 2], - [2, 2], - [0, 0], - [1, 1], - False, - ) - - self.common( - fn, - [ - torch.randn_like(result), - x, - indices, - ], - ) - - def test_max_pool2d_with_indices_backward2(self): - def fn(a, b, c): - return aten.max_pool2d_with_indices_backward( - a, b, [3, 3], [2, 2], [1, 1], [1, 1], True, c - ) - - x = torch.randn([2, 4, 40, 56]) - result, indices = aten.max_pool2d_with_indices( - x, - [3, 3], - [2, 2], - [1, 1], - [1, 1], - True, - ) - - self.common( - fn, - [ - torch.randn_like(result), - x, - indices, - ], - ) - - # From https://github.com/pytorch/torchdynamo/issues/1200 - def test_max_pool2d_with_indices_backward3(self): - def fn(a, b, c): - return aten.max_pool2d_with_indices_backward( - a, b, [1, 1], [2, 2], [0, 0], [1, 1], False, c - ) - - x = torch.randn([32, 256, 37, 38]) - result, indices = aten.max_pool2d_with_indices( - x, - [1, 1], - [2, 2], - 0, - 1, - False, - ) - self.common( - fn, - [ - torch.randn_like(result), - x, - indices, - ], - ) - - def test_avg_pool2d_backward(self): - def fn(a, b): - return aten.avg_pool2d_backward( - a, - b, - [2, 2], - [2, 2], - [0, 0], - True, - False, - None, - ) - - self.common( - fn, - [ - torch.randn([2, 4, 7, 7]), - torch.randn([2, 4, 14, 14]), - ], - ) - - def test_avg_pool2d_backward2(self): - def fn(a, b): - return aten.avg_pool2d_backward( - a, - b, - [3, 3], - [1, 1], - [1, 1], - True, - False, - None, - ) - - self.common( - fn, - [ - torch.randn([1, 1, 20, 15]), - torch.randn([1, 1, 20, 15]), - ], - ) - - def test_avg_pool2d_backward3(self): - def fn(a, b): - return aten.avg_pool2d_backward( - a, - b, - [1, 1], - [2, 2], - [0, 0], - False, - False, - None, - ) - - self.common( - fn, - [ - torch.randn([1, 2016, 11, 11]), - torch.randn([1, 2016, 21, 21]), - ], - ) - - def test_mm_views(self): - def fn(a, b): - return torch.mm(a.view(32, 32), b.view(32, 32)) - - self.common( - fn, - ( - torch.randn([32, 32]).transpose(0, 1), - torch.randn([1, 32, 32]).transpose(0, 1), - ), - check_lowp=False, - ) - expected_kernel = 0 - # codegen mm kernel from template - if config.triton.mm != "aten" and self.device == "cuda": - expected_kernel = 1 - if config.triton.mm == "autotune": - self.assertLessEqual( - torchinductor.metrics.generated_kernel_count, expected_kernel - ) - self.assertEqual(torchinductor.metrics.generated_kernel_count, expected_kernel) - - @patch.object(config.triton, "cudagraphs", False) - def test_lowmem_dropout1(self): - n = 100000 - weight = torch.ones( - n, device=self.device, dtype=torch.float32, requires_grad=True - ) - ones = torch.ones(n, device=self.device, dtype=torch.float32) - - @torchdynamo.optimize_assert("inductor") - def run(x, train=True): - return F.dropout(x * weight, 0.33, train) - - def check(r, g): - rmean = r.mean().item() - gmean = g.mean().item() - rcount = len(r.nonzero()) - gcount = len(g.nonzero()) - - # dropped elements should match - self.assertTrue(same(r.nonzero(), g.nonzero())) - self.assertEqual(rcount, gcount) - - # dropped should be close to 0.33 - 
self.assertGreater(rcount, 0.64 * n) - self.assertGreater(0.68 * n, rcount) - - self.assertAlmostEqual(rmean, gmean) - self.assertAlmostEqual(rmean, 1.0, places=2) - - r1 = run(ones, train=False) - r1.sum().backward() - g1 = weight.grad.clone() - # eval mode should be all ones - self.assertTrue(same(r1, torch.ones_like(r1))) - self.assertTrue(same(g1, torch.ones_like(g1))) - - torch.manual_seed(1234) - weight.grad.zero_() - r2 = run(ones) - r2.sum().backward() - g2 = weight.grad.clone() - check(r2, g2) - - torch.manual_seed(1234) - weight.grad.zero_() - r3 = run(ones) - r3.sum().backward() - g3 = weight.grad.clone() - check(r3, g3) - - # second run is same result as first - self.assertTrue(same(r2, r3)) - self.assertTrue(same(g2, g3)) - - def test_lowmem_dropout2(self): - m = torch.nn.Sequential( - torch.nn.Linear(32, 32, bias=False), - torch.nn.Dropout(), - torch.nn.Linear(32, 32, bias=False), - torch.nn.Dropout(), - ).to(self.device) - - @torchdynamo.optimize_assert("inductor") - def run(x): - return m(x) - - torchinductor.metrics.generated_kernel_count = 0 - result = run(torch.randn([8, 32], device=self.device)) - result.sum().backward() - - expected_kernel = 4 - if config.triton.mm != "aten" and self.device == "cuda": - # fwd: 2 * (mm+dropout) kernels = 2 kernels - # bwd: dropout + (mm) + 2 * (mm+dropout) kernels = 4 kernels - # expect 2 + 4 = 6 kernels - expected_kernel = 6 - if config.triton.mm == "autotune": - self.assertLessEqual( - torchinductor.metrics.generated_kernel_count, expected_kernel - ) - self.assertEqual(torchinductor.metrics.generated_kernel_count, expected_kernel) - - def test_roll(self): - def fn(a): - return ( - aten.roll(a, [-3, 10], [1, 2]), - aten.roll(a, [5]), - ) - - self.common( - fn, - [ - torch.randn([2, 56, 56, 16]), - ], - ) - - def test_argmax_argmin1(self): - def fn(x): - return (aten.argmax(x), aten.argmin(x)) - - self.common( - fn, - [ - torch.randn([8, 256, 256]), - ], - ) - - def test_argmax_argmin2(self): - def fn(x): - return ( - aten.argmax(x, 0), - aten.argmin(x, 0), - aten.argmax(x, 1), - aten.argmin(x, 1), - ) - - self.common( - fn, - [ - torch.randn([144, 144]), - ], - # Mismatched elements: 1 / 144 (0.7%) - # Greatest absolute difference: 26 at index (71,) - # Greatest relative difference: 0.4126984179019928 at index (71,) - atol=1e-5, - rtol=0.5, - ) - - @unittest.skip( - """ - FIXME: In the case of having equally max/min elements, our implementation returns - the last index instead of the first one - """ - ) - def test_argmax_argmin3(self): - def fn(x): - return ( - aten.argmax(x, 0), - aten.argmin(x, 0), - aten.argmax(x, -1), - aten.argmin(x, -1), - ) - - self.common( - fn, - [torch.randint(0, 5, [10, 10])], - ) - - def test_vdd_clamp(self): - def fn(x): - return torch.clamp_min(x, 3) - - self.common( - fn, - [ - torch.randn([16], requires_grad=True) * 10, - ], - ) - - def test_tmp_not_defined_issue1(self): - def forward( - primals_3, - primals_4, - add_tensor, - convert_element_type_default, - div_default, - reciprocal_default, - ): - var_default = torch.ops.prims.var.default( - convert_element_type_default, [2], correction=0 - ) - sub_tensor = torch.ops.aten.sub.Tensor(add_tensor, div_default) - mul_tensor_1 = torch.ops.aten.mul.Tensor(sub_tensor, reciprocal_default) - mul_tensor_2 = torch.ops.aten.mul.Tensor(mul_tensor_1, primals_3) - add_tensor_2 = torch.ops.aten.add.Tensor(mul_tensor_2, primals_4) - convert_element_type_default_1 = ( - torch.ops.prims.convert_element_type.default( - add_tensor_2, torch.float32 - ) - ) - 
convert_element_type_default_2 = ( - torch.ops.prims.convert_element_type.default( - convert_element_type_default_1, torch.float32 - ) - ) - var_default_1 = torch.ops.prims.var.default( - convert_element_type_default_2, [2], correction=0 - ) - broadcast_in_dim_default_2 = torch.ops.prims.broadcast_in_dim.default( - var_default_1, [1, 512, 1], [0, 1] - ) - sum_default_1 = torch.ops.prims.sum.default( - convert_element_type_default_2, [2] - ) - add_tensor_3 = torch.ops.aten.add.Tensor(broadcast_in_dim_default_2, 1e-05) - return (var_default, sum_default_1, add_tensor_3) - - inps = [ - (torch.Size([1024]), torch.float32), - (torch.Size([1024]), torch.float32), - (torch.Size([1, 512, 1024]), torch.float32), - (torch.Size([1, 512, 1024]), torch.float32), - (torch.Size([1, 512, 1]), torch.float32), - (torch.Size([1, 512, 1]), torch.float32), - ] - inps = [torch.randn(shape, dtype=dtype) for (shape, dtype) in inps] - self.common(forward, inps, atol=1e-05, rtol=2e-05) - - @unittest.skipIf( - TEST_WITH_ASAN - or os.environ.get("BUILD_ENVIRONMENT", "").startswith("parallelnative"), - "TODO: debug this with asan", - ) - def test_tmp_not_defined_issue2(self): - def forward(arg38_1, arg81_1, getitem_17, new_zeros_default_4): - div_tensor_7 = torch.ops.aten.div.Tensor(getitem_17, arg81_1) - mul_tensor_24 = torch.ops.aten.mul.Tensor(div_tensor_7, arg38_1) - sum_default_7 = torch.ops.aten.sum.default(mul_tensor_24) - return (new_zeros_default_4, sum_default_7) - - args = [ - ((1, 88, 40, 40), (140800, 1600, 40, 1), torch.float32), - ((), (), torch.float32), - ((1, 88, 40, 40), (140800, 1600, 40, 1), torch.float32), - ((3,), (1,), torch.float32), - ] - args = [rand_strided(shape, stride, dtype) for shape, stride, dtype in args] - self.common(forward, args) - - def test_misaligned_address_issue1(self): - def forward(sub_tensor_1, unsqueeze_default): - gather_default = torch.ops.aten.gather.default( - sub_tensor_1, 1, unsqueeze_default - ) - return gather_default - - args = [ - ((1, 1000), (1000, 1), torch.float32), - ((1, 1), (1, 1), torch.int64), - ] - args = [rand_strided(shape, stride, dtype) for shape, stride, dtype in args] - self.common(forward, args) - - def test_invalid_operand_issue1(self): - def forward(arg0_1, arg1_1, arg3_1, squeeze, view_1, slice_1): - slice_scatter = torch.ops.aten.slice_scatter.default( - slice_1, arg3_1, 1, 1, 9223372036854775807 - ) - slice_scatter_1 = torch.ops.aten.slice_scatter.default( - arg1_1, slice_scatter, 0, 0, 9223372036854775807 - ) - slice_2 = torch.ops.aten.slice.Tensor( - slice_scatter_1, 0, 0, 9223372036854775807 - ) - select_scatter = torch.ops.aten.select_scatter.default( - slice_2, squeeze, 1, 0 - ) - slice_scatter_2 = torch.ops.aten.slice_scatter.default( - slice_scatter_1, select_scatter, 0, 0, 9223372036854775807 - ) - view = torch.ops.aten.view.default(slice_scatter_2, [-1, 128]) - embedding = torch.ops.aten.embedding.default(arg0_1, view, 1) - return [embedding, view_1] - - args = [ - ((50005, 768), (768, 1), torch.float32), - ((8, 128), (128, 1), torch.int64), - ((8, 127), (127, 1), torch.int64), - ((8,), (1,), torch.int64), - ((1024,), (1,), torch.int64), - ((8, 128), (128, 1), torch.int64), - ] - args = [rand_strided(shape, stride, dtype) for shape, stride, dtype in args] - self.common(forward, args) - - def test_sizehint_issue1(self): - def forward(x): - return torch.nn.functional.unfold( - x, kernel_size=[4, 4], dilation=1, padding=0, stride=[4, 4] - ) - - args = [((2, 24, 56, 56), (75264, 3136, 56, 1), torch.float32, False)] - args = [ - 
rand_strided(sh, st, dt).requires_grad_(rg) for (sh, st, dt, rg) in args - ] - self.common(forward, args) - - @unittest.skip("https://github.com/pytorch/torchdynamo/issues/1297") - @patch.object(torchinductor.config.triton, "cudagraphs", False) - def test_symbolic(self): - def f(x): - x = x.cos() - x = x.view(x.shape[0] * 2, -1) - return (x,) - - traced = make_fx(f, tracing_mode="symbolic")( - torch.randn(8, 4, device=self.device) - ) - compiled = compile_fx_inner(traced, [torch.randn(8, 4, device=self.device)]) - - out = compiled([torch.randn(8, 4, device=self.device)]) - self.assertEqual(out[0].shape, (16, 2)) - - out = compiled([torch.randn(12, 4, device=self.device)]) - self.assertEqual(out[0].shape, (24, 2)) - - @requires_cuda() - @patch.object(config.triton, "cudagraphs", False) - def test_unspec_inputs(self): - def fn(x, y): - return x + y - - inputs = ( - rand_strided((2, 3), (3, 1), device="cuda"), - rand_strided((), (), device="cpu"), - ) - self.assertTrue(same(fn(*inputs), inputs[0] + inputs[1])) - - @requires_cuda() - @patch.object(config.triton, "cudagraphs", True) - def test_unspec_inputs_cudagraphs(self): - def fn(x, y): - return x + y - - inputs = ( - rand_strided((2, 3), (3, 1), device="cuda"), - rand_strided((), (), device="cpu"), - ) - self.assertTrue(same(fn(*inputs), inputs[0] + inputs[1])) - - @patch.object(config.triton, "mm", "aten") - def test_list_clearing(self): - - if self.device == "cpu": - contexts = [contextlib.nullcontext] - else: - contexts = [ - contextlib.nullcontext, - lambda: patch.object(config.triton, "cudagraphs", True), - ] - - for context in contexts: - with context(): - inps = [ - torch.rand([5, 5]).to(self.device), - torch.rand([5, 5]).to(self.device), - ] - inp_refs = [weakref.ref(inp) for inp in inps] - - def fn(x, y): - a = x + y - return (a @ a,) - - fn_fx = make_fx(fn)(inps[0], inps[1]) - fn_compiled = compile_fx_inner(fn_fx, inps) - - test_self = self - matmul_seen = False - - class TestRefMode(TorchDispatchMode): - def __torch_dispatch__(self, func, types, args=(), kwargs=None): - kwargs = kwargs if kwargs else {} - - nonlocal inps - nonlocal inp_refs - nonlocal test_self - nonlocal matmul_seen - - # by matmul, inputs should be deallocated - if func is aten.mm.out: - matmul_seen = True - test_self.assertEqual(len(inps), 0) - test_self.assertIsNone(inp_refs[0]()) - test_self.assertIsNone(inp_refs[1]()) - - return func(*args, **kwargs) - - with TestRefMode(): - fn_compiled(inps) - - # for some reason, TorchDispatch doesnt capture the - # cuda mm call (even without cudagraphs) - if self.device == "cpu": - self.assertTrue(matmul_seen) - else: - self.assertEqual(len(inps), 0) - - -if HAS_CPU: - - class CpuTests(TestCase): - common = check_model - device = "cpu" - - CommonTemplate.install(CpuTests, "cpu") - - class CPUReproTests(TestCase): - def test_inplace_squeeze_needed(self): - mod = torch.nn.Sequential( - torch.nn.Linear(10, 10), - torch.nn.LayerNorm(10), - torch.nn.ReLU(), - ).eval() - - @torchdynamo.optimize("inductor") - def fn(x): - return mod(x) - - v = torch.randn(10) - result = fn(v) - assert same(result, mod(v)) - - def test_inplace_add_alpha(self): - def fn(x, y): - aten.add_.Tensor(x, y, alpha=0.55) - return (x,) - - x1 = torch.zeros(10) - x2 = torch.zeros(10) - x3 = torch.zeros(10) - y = torch.randn(10) - fn_fx = make_fx(fn)(x1, y) - fn_compiled = compile_fx_inner(fn_fx, [x1, y]) - fn(x2, y) - fn_compiled([x3, y]) - assert same(x2, x3) - - def test_no_op_squeeze(self): - @torchdynamo.optimize("inductor") - def forward(arg0_1): - 
return torch.ops.aten.squeeze.dim(arg0_1, 1) - - x = torch.randn((10, 20)) - assert same(x, forward(x)) - - def test_parallel_num_threads(self): - @torchdynamo.optimize("inductor") - def fn(x1, x2): - return x1 + x2 - - @contextlib.contextmanager - def set_num_threads(num_threads): - orig_num_threads = torch.get_num_threads() - torch.set_num_threads(num_threads) - yield - torch.set_num_threads(orig_num_threads) - - x1 = torch.randn((10, 20)) - x2 = torch.randn((10, 20)) - with set_num_threads(1): - assert same(x1 + x2, fn(x1, x2)) - with set_num_threads(4): - assert same(x1 + x2, fn(x1, x2)) - - @patch("torch.cuda.is_available", lambda: False) - def test_timed_cpu_only(self): - timed(lambda: torch.randn(10), ()) - - -if HAS_CUDA: - - class SweepInputsCudaTest(SweepInputs2, TestCase): - gen = InputGen(10, "cuda") - - SweepInputsCudaTest.populate() - - class CudaTests(TestCase): - common = check_model_cuda - device = "cuda" - - def test_simplify_dims(self): - def fn(a): - return (a + 1,) - - self.common( - fn, (torch.randn(2, 3, 10, 5, 6, device="cuda")[:, :, 2::2, :, :],) - ) - - CommonTemplate.install(CudaTests, "cuda") - - class CudaReproTests(TestCase): - def test_index_put_issue(self): - def forward( - self, - arg76_1, - expand_default, - full_like_default, - _to_copy_default_67, - zeros, - ): - sum_sym_int_19 = torch.ops.aten.sum(_to_copy_default_67, [0], True) - view_default_57 = torch.ops.aten.view.default( - sum_sym_int_19, [512, 768] - ) - where_self = torch.ops.aten.where.self( - expand_default, view_default_57, full_like_default - ) - clone_default_12 = torch.ops.aten.clone.default(zeros) - index_put__default = torch.ops.aten.index_put_.default( - clone_default_12, [arg76_1], where_self, True - ) - return (index_put__default,) - - inps = [ - (torch.Size([512]), torch.int64), - (torch.Size([512, 768]), torch.bool), - (torch.Size([512, 768]), torch.float16), - (torch.Size([4, 512, 768]), torch.float16), - (torch.Size([512, 768]), torch.float16), - ] - inps = [torch.zeros(())] + [ - torch.ones(shape, dtype=dtype, device="cuda") for (shape, dtype) in inps - ] - mod = make_fx(forward)(*inps) - compiled = compile_fx_inner(mod, inps) - compiled(inps) - - @patch.object(config, "fallback_random", True) - def test_dtype_factory_issue(self): - def forward(): - randn = torch.ops.aten.randn.default( - [12, 64, 1, 64], - dtype=torch.float32, - device=torch.device(type="cuda", index=0), - pin_memory=False, - ) - unsqueeze_default_2 = torch.ops.aten.unsqueeze.default(randn, -1) - return (unsqueeze_default_2,) - - mod = make_fx(forward)() - compiled = compile_fx_inner(mod, ()) - assert compiled([])[0].device.type == "cuda" - - @patch.object(config.triton, "cudagraphs", True) - def test_expanded_inputs_cudagraphs(self): - @torchdynamo.optimize("inductor") - def fn(x, y): - return x + y - - inputs = ( - rand_strided((5, 5, 5, 5), (0, 5, 0, 1), device="cuda"), - rand_strided((5, 5, 5, 5), (0, 5, 0, 1), device="cuda"), - ) - self.assertTrue(same(fn(*inputs), inputs[0] + inputs[1])) - - @patch.object(config, "size_asserts", False) - @patch.object(config.triton, "cudagraphs", True) - def test_expanded_inputs_cudagraphs_no_size_asserts(self): - @torchdynamo.optimize("inductor") - def fn(x, y): - return x + y - - inputs = ( - rand_strided((5, 5, 5, 5), (0, 5, 0, 1), device="cuda"), - rand_strided((5, 5, 5, 5), (0, 5, 0, 1), device="cuda"), - ) - self.assertTrue(same(fn(*inputs), inputs[0] + inputs[1])) - - def test_accuracy_issue1(self): - class Repro(torch.nn.Module): - def __init__(self): - 
super().__init__() - self.linear = torch.nn.Linear( - in_features=768, out_features=2, bias=True - ) - - def forward(self, start_positions: torch.Tensor, x: torch.Tensor): - linear = self.linear(x) - split = linear.split(1, dim=-1) - getitem = split[0] - squeeze = getitem.squeeze(-1) - clamp = start_positions.clamp(0, 128) - cross_entropy = torch.nn.functional.cross_entropy( - squeeze, clamp, None, None, 128, None, "mean", 0.0 - ) - return cross_entropy - - mod = Repro().cuda() - opt_mod = torchdynamo.optimize("inductor")(mod) - mod.eval() - opt_mod.eval() - - args = [ - ((1,), (1,), torch.int64, "cuda", False), - ((1, 128, 768), (98304, 768, 1), torch.float32, "cuda", True), - ] - args = [ - rand_strided(sh, st, dt, dev).requires_grad_(rg) - for (sh, st, dt, dev, rg) in args - ] - with torch.cuda.amp.autocast(enabled=False): - assert same_two_models(mod, opt_mod, args), "Dynamo failed" - - -if __name__ == "__main__": - from torchdynamo.test_case import run_tests - - if (HAS_CPU or HAS_CUDA) and not TEST_WITH_ROCM: - run_tests(needs="filelock")
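
The tests deleted above all reduce to the same eager-vs-compiled comparison that the file's check_model / check_model_cuda helpers perform. A minimal standalone sketch of that pattern follows (not part of the diff); it assumes the standalone torchdynamo/torchinductor packages this repository provided before the upstream merge, and the function fn plus the tolerances are illustrative rather than taken from the deleted file:

    import torch
    import torchdynamo

    def fn(a, b):
        # toy pointwise graph; stands in for the fn arguments the deleted tests pass
        return (a + b).relu()

    # same entry point the removed tests use, e.g. @torchdynamo.optimize("inductor")
    compiled_fn = torchdynamo.optimize("inductor")(fn)

    args = (torch.randn(8, 8), torch.randn(8, 8))
    # compare compiled output against eager, with explicit tolerances in the spirit of
    # the numerically sensitive cases above (e.g. baddbmm, scatter with reduce="add")
    torch.testing.assert_close(compiled_fn(*args), fn(*args), atol=1e-4, rtol=1e-4)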