From bee84e88f8c5bbb35c9c8502db31d867dcb25ffb Mon Sep 17 00:00:00 2001
From: Xuehai Pan
Date: Mon, 13 Jan 2025 21:39:17 +0800
Subject: [PATCH] [BE][Easy] improve submodule discovery for `torch.ao` type annotations (#144680)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144680
Approved by: https://github.com/Skylion007
---
 torch/ao/__init__.py                 | 18 +++++++++++++++---
 torch/ao/nn/__init__.py              | 19 ++++++++++++++++---
 torch/ao/nn/quantized/modules/rnn.py |  2 +-
 3 files changed, 32 insertions(+), 7 deletions(-)

diff --git a/torch/ao/__init__.py b/torch/ao/__init__.py
index 9c54493da25e85..ac866b5073deb7 100644
--- a/torch/ao/__init__.py
+++ b/torch/ao/__init__.py
@@ -1,17 +1,29 @@
-# mypy: allow-untyped-defs
 # torch.ao is a package with a lot of interdependencies.
 # We will use lazy import to avoid cyclic dependencies here.
+from typing import TYPE_CHECKING as _TYPE_CHECKING
+
+
+if _TYPE_CHECKING:
+    from types import ModuleType
+
+    from torch.ao import (  # noqa: TC004
+        nn as nn,
+        ns as ns,
+        pruning as pruning,
+        quantization as quantization,
+    )
+
 
 
 __all__ = [
     "nn",
     "ns",
-    "quantization",
     "pruning",
+    "quantization",
 ]
 
 
-def __getattr__(name):
+def __getattr__(name: str) -> "ModuleType":
     if name in __all__:
         import importlib
 
diff --git a/torch/ao/nn/__init__.py b/torch/ao/nn/__init__.py
index a60b90d88b9039..7439c22d66882d 100644
--- a/torch/ao/nn/__init__.py
+++ b/torch/ao/nn/__init__.py
@@ -1,10 +1,21 @@
-# mypy: allow-untyped-defs
 # We are exposing all subpackages to the end-user.
 # Because of possible inter-dependency, we want to avoid
 # the cyclic imports, thus implementing lazy version
 # as per https://peps.python.org/pep-0562/
-import importlib
+from typing import TYPE_CHECKING as _TYPE_CHECKING
+
+
+if _TYPE_CHECKING:
+    from types import ModuleType
+
+    from torch.ao.nn import (  # noqa: TC004
+        intrinsic as intrinsic,
+        qat as qat,
+        quantizable as quantizable,
+        quantized as quantized,
+        sparse as sparse,
+    )
 
 
 __all__ = [
@@ -16,7 +27,9 @@
 ]
 
 
-def __getattr__(name):
+def __getattr__(name: str) -> "ModuleType":
     if name in __all__:
+        import importlib
+
         return importlib.import_module("." + name, __name__)
     raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
diff --git a/torch/ao/nn/quantized/modules/rnn.py b/torch/ao/nn/quantized/modules/rnn.py
index ac5c2d55e1c243..24b17ca2d62bd5 100644
--- a/torch/ao/nn/quantized/modules/rnn.py
+++ b/torch/ao/nn/quantized/modules/rnn.py
@@ -49,7 +49,7 @@ def from_float(cls, *args, **kwargs):
 
     @classmethod
     def from_observed(cls, other):
-        assert isinstance(other, cls._FLOAT_MODULE)
+        assert isinstance(other, cls._FLOAT_MODULE)  # type: ignore[has-type]
         converted = torch.ao.quantization.convert(
             other, inplace=False, remove_qconfig=True
         )
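
For context, the pattern this patch applies is the PEP 562 lazy import: submodules are imported on first attribute access via a module-level __getattr__, while an `if TYPE_CHECKING:` block re-exports them so static type checkers and IDEs can still discover the `torch.ao` submodules without importing them at runtime. A minimal sketch of the same pattern for a hypothetical package `mypkg` with submodules `foo` and `bar` (the package and submodule names are illustrative, not taken from the patch):

    # mypkg/__init__.py -- illustrative sketch only, not part of the patch above
    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        from types import ModuleType

        # Only read by type checkers and IDEs; never executed at runtime,
        # so it cannot reintroduce the cyclic imports the lazy scheme avoids.
        from mypkg import bar as bar, foo as foo

    __all__ = ["bar", "foo"]


    def __getattr__(name: str) -> "ModuleType":
        # PEP 562: called only when `name` is not found as a regular attribute,
        # e.g. on the first access to `mypkg.foo`.
        if name in __all__:
            import importlib

            return importlib.import_module("." + name, __name__)
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

With this layout, `import mypkg` stays cheap, `mypkg.foo` triggers the real import the first time it is accessed, and the string annotation on __getattr__ means ModuleType is only needed when type checking.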