[BE][Easy] improve submodule discovery for torch.ao type annotations (pytorch#144680)

Pull Request resolved: pytorch#144680
Approved by: https://github.com/Skylion007
XuehaiPan authored and pytorchmergebot committed Jan 13, 2025
1 parent c40d917 commit bee84e8
Showing 3 changed files with 32 additions and 7 deletions.
18 changes: 15 additions & 3 deletions torch/ao/__init__.py
@@ -1,17 +1,29 @@
-# mypy: allow-untyped-defs
 # torch.ao is a package with a lot of interdependencies.
 # We will use lazy import to avoid cyclic dependencies here.
 
+from typing import TYPE_CHECKING as _TYPE_CHECKING
+
+
+if _TYPE_CHECKING:
+    from types import ModuleType
+
+    from torch.ao import (  # noqa: TC004
+        nn as nn,
+        ns as ns,
+        pruning as pruning,
+        quantization as quantization,
+    )
+
 
 __all__ = [
     "nn",
     "ns",
-    "quantization",
     "pruning",
+    "quantization",
 ]
 
 
-def __getattr__(name):
+def __getattr__(name: str) -> "ModuleType":
     if name in __all__:
         import importlib
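
For context (not part of this commit): the module-level __getattr__ above implements PEP 562 lazy loading, so a listed submodule is imported only on first attribute access. A hedged sketch of the observable runtime behavior follows; whether a given submodule is already in sys.modules depends on what `import torch` loads eagerly, so the first printed value is indicative only.

import sys

import torch.ao

# Before the first attribute access, the lazily exported submodule may not
# be loaded yet (indicative only; eager imports elsewhere in torch can
# change this).
print("torch.ao.ns" in sys.modules)

# Attribute access is routed through torch.ao.__getattr__ (PEP 562), which
# imports the submodule on demand because "ns" is listed in __all__.
_ = torch.ao.ns
print("torch.ao.ns" in sys.modules)  # True: the submodule is now cached

# Names outside __all__ fall through to a trailing `raise AttributeError`
# (that line is visible in the torch/ao/nn/__init__.py hunk below).
try:
    torch.ao.does_not_exist
except AttributeError as exc:
    print(exc)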
19 changes: 16 additions & 3 deletions torch/ao/nn/__init__.py
@@ -1,10 +1,21 @@
-# mypy: allow-untyped-defs
 # We are exposing all subpackages to the end-user.
 # Because of possible inter-dependency, we want to avoid
 # the cyclic imports, thus implementing lazy version
 # as per https://peps.python.org/pep-0562/
 
-import importlib
+from typing import TYPE_CHECKING as _TYPE_CHECKING
+
+
+if _TYPE_CHECKING:
+    from types import ModuleType
+
+    from torch.ao.nn import (  # noqa: TC004
+        intrinsic as intrinsic,
+        qat as qat,
+        quantizable as quantizable,
+        quantized as quantized,
+        sparse as sparse,
+    )
 
 
 __all__ = [
@@ -16,7 +27,9 @@
 ]
 
 
-def __getattr__(name):
+def __getattr__(name: str) -> "ModuleType":
     if name in __all__:
+        import importlib
+
         return importlib.import_module("." + name, __name__)
     raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
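
For reference (also not part of the commit), the shape both __init__.py files now share can be sketched for a hypothetical package mypkg with submodules core and utils (assumed to exist as mypkg/core.py and mypkg/utils.py). The TYPE_CHECKING block is evaluated only by static analyzers, so it cannot reintroduce a runtime import cycle, while the PEP 562 __getattr__ defers the real imports to first attribute access.

# mypkg/__init__.py -- illustrative sketch only; mypkg, core, and utils are
# hypothetical stand-ins for torch.ao and its subpackages.
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by type checkers: gives mypkg.core / mypkg.utils their real
    # module types without importing anything at runtime. The `x as x` form
    # marks the names as explicit re-exports.
    from types import ModuleType

    from mypkg import core as core, utils as utils

__all__ = ["core", "utils"]


def __getattr__(name: str) -> "ModuleType":
    # PEP 562: called only when `name` is not found normally, so each listed
    # submodule is imported on first access rather than at package import.
    if name in __all__:
        import importlib

        return importlib.import_module("." + name, __name__)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")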
2 changes: 1 addition & 1 deletion torch/ao/nn/quantized/modules/rnn.py
@@ -49,7 +49,7 @@ def from_float(cls, *args, **kwargs):
 
     @classmethod
     def from_observed(cls, other):
-        assert isinstance(other, cls._FLOAT_MODULE)
+        assert isinstance(other, cls._FLOAT_MODULE)  # type: ignore[has-type]
         converted = torch.ao.quantization.convert(
             other, inplace=False, remove_qconfig=True
         )
