Skip to content

Commit

Permalink
Merge pull request #77 from jorenham/disallowed-bases
Browse files Browse the repository at this point in the history
explicitly disallow unsupported base classes
  • Loading branch information
jorenham authored Oct 1, 2024
2 parents f916b85 + e547c3f commit 8f2fa92
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 6 deletions.
32 changes: 32 additions & 0 deletions tests/test_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,3 +602,35 @@ class MyPath(Path): ...
""")
with pytest.raises(NotImplementedError):
transform_source(pyi_import_from)


def test_subclass_object():
pyi_direct = _src("class OldStyle(object): ...")
with pytest.raises(NotImplementedError):
transform_source(pyi_direct)


def test_subclass_builtins_object():
pyi_direct = _src("class OldStyle(__builtins__.object): ...")
with pytest.raises(NotImplementedError):
transform_source(pyi_direct)


def test_subclass_builtins_object_import():
pyi_direct = _src("""
import builtins
class OldStyle(builtins.object): ...
""")
with pytest.raises(NotImplementedError):
transform_source(pyi_direct)


def test_subclass_builtins_object_alias():
pyi_direct = _src("""
from builtins import object as Object
class OldStyle(Object): ...
""")
with pytest.raises(NotImplementedError):
transform_source(pyi_direct)
12 changes: 7 additions & 5 deletions unpy/_stdlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,20 @@
"BaseExceptionGroup": (3, 11),
"ExceptionGroup": (3, 11),
"EncodingWarning": (3, 10),
"reveal_locals": (4, 0),
"reveal_type": (4, 0),
},
"enum": {
# TODO(jorenham): Backport to `builtins.str & enum.Enum`
"StrEnum": (3, 11),
},
}
UNSUPPORTED_BASES: Final = {
"builtins": {"object": (4, 0)},
"inspect": {"BufferFlags", (3, 12)},
"pathlib": {"Path": (3, 12)},
"typing": {"Any": (3, 11)},
"typing_extensions": {"Any": (3, 11)},
"builtins.object": (4, 0),
"inspect.BufferFlags": (3, 12),
"pathlib.Path": (3, 12),
"typing.Any": (3, 11),
"typing_extensions.Any": (3, 11),
}


Expand Down
28 changes: 27 additions & 1 deletion unpy/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import libcst as cst

import unpy._cst as uncst
from unpy._stdlib import BACKPORTS
from unpy._stdlib import BACKPORTS, UNSUPPORTED_BASES
from unpy._types import PythonVersion
from unpy.visitors import StubVisitor

Expand Down Expand Up @@ -124,6 +124,7 @@ def __init__(self, visitor: StubVisitor, /, target: PythonVersion) -> None:
self._renames = {}
self._type_alias_alignment = {}

self.__check_base_classes()
self.__collect_imports_typevars()
self.__collect_imports_backport()
self.__collect_imports_generic()
Expand Down Expand Up @@ -187,6 +188,31 @@ def _discard_import(self, module: str, name: str, /) -> str | None:

return None

def __check_base_classes(self, /) -> None:
# raise for unsupported base classes
target = self.target
visitor = self.visitor

illegal_fqn = {base for base, req in UNSUPPORTED_BASES.items() if target < req}
illegal_names = {
alias: base
for base in illegal_fqn
if (alias := visitor.imported_as(*base.rsplit(".", 1)))
}

for bases in self.visitor.class_bases.values():
for base in bases:
if base in illegal_names:
raise NotImplementedError(f"{illegal_names[base]!r} as base class")
base = base.replace("__builtins__.", "builtins.") # noqa: PLW2901
if (fqn := visitor.imports_by_alias.get(base, base)) in illegal_fqn:
raise NotImplementedError(f"{fqn!r} as base class")
if base in visitor.imports_by_ref:
module, name = visitor.imports_by_ref[base]
fqn = f"{module}.{name}" if name else module
if fqn in illegal_fqn:
raise NotImplementedError(f"{fqn!r} as base class")

def __collect_imports_typevars(self, /) -> None:
# collect the missing imports for the typevar-likes
target = self.target
Expand Down

0 comments on commit 8f2fa92

Please sign in to comment.