Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

explicitly disallow unsupported base classes #77

Merged
merged 1 commit into from
Oct 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions tests/test_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,3 +602,35 @@ class MyPath(Path): ...
""")
with pytest.raises(NotImplementedError):
transform_source(pyi_import_from)


def test_subclass_object():
pyi_direct = _src("class OldStyle(object): ...")
with pytest.raises(NotImplementedError):
transform_source(pyi_direct)


def test_subclass_builtins_object():
pyi_direct = _src("class OldStyle(__builtins__.object): ...")
with pytest.raises(NotImplementedError):
transform_source(pyi_direct)


def test_subclass_builtins_object_import():
pyi_direct = _src("""
import builtins

class OldStyle(builtins.object): ...
""")
with pytest.raises(NotImplementedError):
transform_source(pyi_direct)


def test_subclass_builtins_object_alias():
pyi_direct = _src("""
from builtins import object as Object

class OldStyle(Object): ...
""")
with pytest.raises(NotImplementedError):
transform_source(pyi_direct)
12 changes: 7 additions & 5 deletions unpy/_stdlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,20 @@
"BaseExceptionGroup": (3, 11),
"ExceptionGroup": (3, 11),
"EncodingWarning": (3, 10),
"reveal_locals": (4, 0),
"reveal_type": (4, 0),
},
"enum": {
# TODO(jorenham): Backport to `builtins.str & enum.Enum`
"StrEnum": (3, 11),
},
}
UNSUPPORTED_BASES: Final = {
"builtins": {"object": (4, 0)},
"inspect": {"BufferFlags", (3, 12)},
"pathlib": {"Path": (3, 12)},
"typing": {"Any": (3, 11)},
"typing_extensions": {"Any": (3, 11)},
"builtins.object": (4, 0),
"inspect.BufferFlags": (3, 12),
"pathlib.Path": (3, 12),
"typing.Any": (3, 11),
"typing_extensions.Any": (3, 11),
}


Expand Down
28 changes: 27 additions & 1 deletion unpy/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import libcst as cst

import unpy._cst as uncst
from unpy._stdlib import BACKPORTS
from unpy._stdlib import BACKPORTS, UNSUPPORTED_BASES
from unpy._types import PythonVersion
from unpy.visitors import StubVisitor

Expand Down Expand Up @@ -124,6 +124,7 @@ def __init__(self, visitor: StubVisitor, /, target: PythonVersion) -> None:
self._renames = {}
self._type_alias_alignment = {}

self.__check_base_classes()
self.__collect_imports_typevars()
self.__collect_imports_backport()
self.__collect_imports_generic()
Expand Down Expand Up @@ -187,6 +188,31 @@ def _discard_import(self, module: str, name: str, /) -> str | None:

return None

def __check_base_classes(self, /) -> None:
# raise for unsupported base classes
target = self.target
visitor = self.visitor

illegal_fqn = {base for base, req in UNSUPPORTED_BASES.items() if target < req}
illegal_names = {
alias: base
for base in illegal_fqn
if (alias := visitor.imported_as(*base.rsplit(".", 1)))
}

for bases in self.visitor.class_bases.values():
for base in bases:
if base in illegal_names:
raise NotImplementedError(f"{illegal_names[base]!r} as base class")
base = base.replace("__builtins__.", "builtins.") # noqa: PLW2901
if (fqn := visitor.imports_by_alias.get(base, base)) in illegal_fqn:
raise NotImplementedError(f"{fqn!r} as base class")
if base in visitor.imports_by_ref:
module, name = visitor.imports_by_ref[base]
fqn = f"{module}.{name}" if name else module
if fqn in illegal_fqn:
raise NotImplementedError(f"{fqn!r} as base class")

def __collect_imports_typevars(self, /) -> None:
# collect the missing imports for the typevar-likes
target = self.target
Expand Down