Skip to content

Commit

Permalink
[dynamo] Support if cond on UnspecializedNNModuleVariable and add inl…
Browse files Browse the repository at this point in the history
…ine tests (pytorch#128158)

Pull Request resolved: pytorch#128158
Approved by: https://github.com/jansel
ghstack dependencies: pytorch#128001, pytorch#126578
  • Loading branch information
anijain2305 authored and pytorchmergebot committed Jun 7, 2024
1 parent 5e5bbdb commit 747fc35
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 1 deletion.
62 changes: 62 additions & 0 deletions test/dynamo/test_inline_inbuilt_nn_modules.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Owner(s): ["module: dynamo"]

from torch._dynamo import config
from torch._dynamo.testing import make_test_cls_with_patches

try:
from . import (
test_aot_autograd,
test_functions,
test_higher_order_ops,
test_misc,
test_modules,
# test_repros,
)
except ImportError:
import test_aot_autograd
import test_functions
import test_higher_order_ops
import test_misc
import test_modules


test_classes = {}


def make_inline_inbuilt_nn_modules_cls(cls):
suffix = "_inline_inbuilt_nn_modules"

cls_prefix = "InlineInbuiltNNModules"

test_class = make_test_cls_with_patches(
cls,
cls_prefix,
suffix,
(config, "inline_inbuilt_nn_modules", True),
xfail_prop="_expected_failure_inline_inbuilt_nn_modules",
)

test_classes[test_class.__name__] = test_class
# REMOVING THIS LINE WILL STOP TESTS FROM RUNNING
globals()[test_class.__name__] = test_class
test_class.__module__ = __name__
return test_class


tests = [
test_misc.MiscTests,
test_functions.FunctionTests,
test_modules.NNModuleTests,
test_higher_order_ops.HigherOrderOpTests,
test_higher_order_ops.FuncTorchHigherOrderOpTests,
test_aot_autograd.AotAutogradFallbackTests,
# test_repros.ReproTests,
]
for test in tests:
make_inline_inbuilt_nn_modules_cls(test)
del test

if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

run_tests()
8 changes: 7 additions & 1 deletion torch/_dynamo/symbolic_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@
PythonModuleVariable,
UnknownVariable,
)
from .variables.nn_module import NNModuleVariable
from .variables.nn_module import NNModuleVariable, UnspecializedNNModuleVariable
from .variables.tensor import supported_comparison_ops, SymNodeVariable, TensorVariable
from .variables.user_defined import (
RemovableHandleVariable,
Expand Down Expand Up @@ -414,6 +414,12 @@ def inner(self: "InstructionTranslatorBase", inst: Instruction):
if push:
self.push(value)
self.jump(inst)
elif isinstance(value, UnspecializedNNModuleVariable):
mod = value.value
if truth_fn(mod):
if push:
self.push(value)
self.jump(inst)
elif isinstance(value, UserDefinedObjectVariable):
try:
x = value.var_getattr(self, "__bool__")
Expand Down

0 comments on commit 747fc35

Please sign in to comment.