From 747fc35ff54154ddec2a5ab5661f57c28d65c591 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Thu, 6 Jun 2024 13:22:35 -0700 Subject: [PATCH] [dynamo] Support if cond on UnspecializedNNModuleVariable and add inline tests (#128158) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128158 Approved by: https://github.com/jansel ghstack dependencies: #128001, #126578 --- test/dynamo/test_inline_inbuilt_nn_modules.py | 62 +++++++++++++++++++ torch/_dynamo/symbolic_convert.py | 8 ++- 2 files changed, 69 insertions(+), 1 deletion(-) create mode 100644 test/dynamo/test_inline_inbuilt_nn_modules.py diff --git a/test/dynamo/test_inline_inbuilt_nn_modules.py b/test/dynamo/test_inline_inbuilt_nn_modules.py new file mode 100644 index 00000000000000..f7ba32bc15f3ba --- /dev/null +++ b/test/dynamo/test_inline_inbuilt_nn_modules.py @@ -0,0 +1,62 @@ +# Owner(s): ["module: dynamo"] + +from torch._dynamo import config +from torch._dynamo.testing import make_test_cls_with_patches + +try: + from . import ( + test_aot_autograd, + test_functions, + test_higher_order_ops, + test_misc, + test_modules, + # test_repros, + ) +except ImportError: + import test_aot_autograd + import test_functions + import test_higher_order_ops + import test_misc + import test_modules + + +test_classes = {} + + +def make_inline_inbuilt_nn_modules_cls(cls): + suffix = "_inline_inbuilt_nn_modules" + + cls_prefix = "InlineInbuiltNNModules" + + test_class = make_test_cls_with_patches( + cls, + cls_prefix, + suffix, + (config, "inline_inbuilt_nn_modules", True), + xfail_prop="_expected_failure_inline_inbuilt_nn_modules", + ) + + test_classes[test_class.__name__] = test_class + # REMOVING THIS LINE WILL STOP TESTS FROM RUNNING + globals()[test_class.__name__] = test_class + test_class.__module__ = __name__ + return test_class + + +tests = [ + test_misc.MiscTests, + test_functions.FunctionTests, + test_modules.NNModuleTests, + test_higher_order_ops.HigherOrderOpTests, + test_higher_order_ops.FuncTorchHigherOrderOpTests, + test_aot_autograd.AotAutogradFallbackTests, + # test_repros.ReproTests, +] +for test in tests: + make_inline_inbuilt_nn_modules_cls(test) +del test + +if __name__ == "__main__": + from torch._dynamo.test_case import run_tests + + run_tests() diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 30f28e2ab265fb..da04fdfa8584ae 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -101,7 +101,7 @@ PythonModuleVariable, UnknownVariable, ) -from .variables.nn_module import NNModuleVariable +from .variables.nn_module import NNModuleVariable, UnspecializedNNModuleVariable from .variables.tensor import supported_comparison_ops, SymNodeVariable, TensorVariable from .variables.user_defined import ( RemovableHandleVariable, @@ -414,6 +414,12 @@ def inner(self: "InstructionTranslatorBase", inst: Instruction): if push: self.push(value) self.jump(inst) + elif isinstance(value, UnspecializedNNModuleVariable): + mod = value.value + if truth_fn(mod): + if push: + self.push(value) + self.jump(inst) elif isinstance(value, UserDefinedObjectVariable): try: x = value.var_getattr(self, "__bool__")