From e547c3f6f29e0020648e49558b25104c65401537 Mon Sep 17 00:00:00 2001 From: jorenham Date: Wed, 2 Oct 2024 00:03:45 +0200 Subject: [PATCH] explicitly disallow unsupported base classes --- tests/test_transformers.py | 32 ++++++++++++++++++++++++++++++++ unpy/_stdlib.py | 12 +++++++----- unpy/transformers.py | 28 +++++++++++++++++++++++++++- 3 files changed, 66 insertions(+), 6 deletions(-) diff --git a/tests/test_transformers.py b/tests/test_transformers.py index a982fbe..17f85d1 100644 --- a/tests/test_transformers.py +++ b/tests/test_transformers.py @@ -602,3 +602,35 @@ class MyPath(Path): ... """) with pytest.raises(NotImplementedError): transform_source(pyi_import_from) + + +def test_subclass_object(): + pyi_direct = _src("class OldStyle(object): ...") + with pytest.raises(NotImplementedError): + transform_source(pyi_direct) + + +def test_subclass_builtins_object(): + pyi_direct = _src("class OldStyle(__builtins__.object): ...") + with pytest.raises(NotImplementedError): + transform_source(pyi_direct) + + +def test_subclass_builtins_object_import(): + pyi_direct = _src(""" + import builtins + + class OldStyle(builtins.object): ... + """) + with pytest.raises(NotImplementedError): + transform_source(pyi_direct) + + +def test_subclass_builtins_object_alias(): + pyi_direct = _src(""" + from builtins import object as Object + + class OldStyle(Object): ... + """) + with pytest.raises(NotImplementedError): + transform_source(pyi_direct) diff --git a/unpy/_stdlib.py b/unpy/_stdlib.py index a2aadf7..66485b3 100644 --- a/unpy/_stdlib.py +++ b/unpy/_stdlib.py @@ -32,6 +32,8 @@ "BaseExceptionGroup": (3, 11), "ExceptionGroup": (3, 11), "EncodingWarning": (3, 10), + "reveal_locals": (4, 0), + "reveal_type": (4, 0), }, "enum": { # TODO(jorenham): Backport to `builtins.str & enum.Enum` @@ -39,11 +41,11 @@ }, } UNSUPPORTED_BASES: Final = { - "builtins": {"object": (4, 0)}, - "inspect": {"BufferFlags", (3, 12)}, - "pathlib": {"Path": (3, 12)}, - "typing": {"Any": (3, 11)}, - "typing_extensions": {"Any": (3, 11)}, + "builtins.object": (4, 0), + "inspect.BufferFlags": (3, 12), + "pathlib.Path": (3, 12), + "typing.Any": (3, 11), + "typing_extensions.Any": (3, 11), } diff --git a/unpy/transformers.py b/unpy/transformers.py index 0b7ab08..4ab6ef6 100644 --- a/unpy/transformers.py +++ b/unpy/transformers.py @@ -13,7 +13,7 @@ import libcst as cst import unpy._cst as uncst -from unpy._stdlib import BACKPORTS +from unpy._stdlib import BACKPORTS, UNSUPPORTED_BASES from unpy._types import PythonVersion from unpy.visitors import StubVisitor @@ -124,6 +124,7 @@ def __init__(self, visitor: StubVisitor, /, target: PythonVersion) -> None: self._renames = {} self._type_alias_alignment = {} + self.__check_base_classes() self.__collect_imports_typevars() self.__collect_imports_backport() self.__collect_imports_generic() @@ -187,6 +188,31 @@ def _discard_import(self, module: str, name: str, /) -> str | None: return None + def __check_base_classes(self, /) -> None: + # raise for unsupported base classes + target = self.target + visitor = self.visitor + + illegal_fqn = {base for base, req in UNSUPPORTED_BASES.items() if target < req} + illegal_names = { + alias: base + for base in illegal_fqn + if (alias := visitor.imported_as(*base.rsplit(".", 1))) + } + + for bases in self.visitor.class_bases.values(): + for base in bases: + if base in illegal_names: + raise NotImplementedError(f"{illegal_names[base]!r} as base class") + base = base.replace("__builtins__.", "builtins.") # noqa: PLW2901 + if (fqn := visitor.imports_by_alias.get(base, base)) in illegal_fqn: + raise NotImplementedError(f"{fqn!r} as base class") + if base in visitor.imports_by_ref: + module, name = visitor.imports_by_ref[base] + fqn = f"{module}.{name}" if name else module + if fqn in illegal_fqn: + raise NotImplementedError(f"{fqn!r} as base class") + def __collect_imports_typevars(self, /) -> None: # collect the missing imports for the typevar-likes target = self.target