diff --git a/README.md b/README.md index 246d0f5..84a2bc4 100644 --- a/README.md +++ b/README.md @@ -298,7 +298,7 @@ potential goals of `unpy`: - [x] `re.PatternError` => `re.error` - Typing - [x] `types.CapsuleType` => `typing_extensions.CapsuleType` - - [ ] `typing.{ClassVar,Final}` => `typing_extensions.{ClassVar,Final}` when + - [x] `typing.{ClassVar,Final}` => `typing_extensions.{ClassVar,Final}` when nested (python/cpython#89547) - Python 3.12 => 3.11 - [PEP 698][PEP698] diff --git a/tests/test_transformers.py b/tests/test_transformers.py index 6c8759a..1cac680 100644 --- a/tests/test_transformers.py +++ b/tests/test_transformers.py @@ -633,3 +633,67 @@ class OldStyle(Object): ... """) with pytest.raises(NotImplementedError): transform_source(pyi_direct) + + +def test_nested_ClassVar_Final_TN(): + pyi_in = pyi_expect = _src(""" + from typing import ClassVar, Final + + class C: + a: ClassVar[int] = 0 + b: Final[int] + """) + pyi_out = transform_source(pyi_in) + assert pyi_out == pyi_expect + + +def test_nested_ClassVar_Final_TP(): + pyi_in = _src(""" + from typing import ClassVar, Final + + class C: + a: ClassVar[Final[int]] = 1 + """) + pyi_expect = _src(""" + from typing_extensions import ClassVar, Final + + class C: + a: ClassVar[Final[int]] = 1 + """) + pyi_out = transform_source(pyi_in) + assert pyi_out == pyi_expect + + +def test_nested_ClassVar_Final_TP_inv(): + pyi_in = _src(""" + from typing import ClassVar, Final + + class C: + a: Final[ClassVar[int]] = -1 + """) + pyi_expect = _src(""" + from typing_extensions import ClassVar, Final + + class C: + a: Final[ClassVar[int]] = -1 + """) + pyi_out = transform_source(pyi_in) + assert pyi_out == pyi_expect + + +def test_nested_ClassVar_Final_TP_indirect(): + pyi_in = _src(""" + import typing as tp + + class C: + a: tp.ClassVar[tp.Final[int]] = 1 + """) + pyi_expect = _src(""" + import typing as tp + from typing_extensions import ClassVar, Final + + class C: + a: ClassVar[Final[int]] = 1 + """) + pyi_out = transform_source(pyi_in) + assert pyi_out == pyi_expect diff --git a/tests/test_visitors.py b/tests/test_visitors.py index 5069ded..e80f461 100644 --- a/tests/test_visitors.py +++ b/tests/test_visitors.py @@ -327,3 +327,31 @@ def test_baseclasses_single() -> None: # type params # TODO + +# nested ClassVar and Final +def test_nested_classvar_final() -> None: + visitor_tn = _visit( + "from typing import ClassVar, Final", + "class C:", + " a: ClassVar[int] = 0", + " b: Final[int]", + ) + visitor_tp = _visit( + "from typing import ClassVar, Final", + "class C:", + " a: ClassVar[Final[int]] = 1", + ) + visitor_tp_inv = _visit( + "from typing import ClassVar, Final", + "class C:", + " a: Final[ClassVar[int]] = -1", + ) + visitor_tp_indirect = _visit( + "import typing as tp", + "class C:", + " a: tp.ClassVar[tp.Final[int]] = 1", + ) + assert not visitor_tn.nested_classvar_final + assert visitor_tp.nested_classvar_final + assert visitor_tp_inv.nested_classvar_final + assert visitor_tp_indirect.nested_classvar_final diff --git a/unpy/transformers.py b/unpy/transformers.py index 4ab6ef6..96193e5 100644 --- a/unpy/transformers.py +++ b/unpy/transformers.py @@ -251,6 +251,7 @@ def __collection_import_backport_single(self, fqn: str, alias: str, /) -> None: def __collect_imports_backport(self, /) -> None: # collect the imports that should replaced with a `typing_extensions` backport + target = self.target visitor = self.visitor for fqn, alias in visitor.imports.items(): @@ -271,7 +272,7 @@ def __collect_imports_backport(self, /) -> None: if ( (backports := BACKPORTS.get(module)) and name in backports - and self.target < backports[name][2] + and target < backports[name][2] ): new_module, new_name, _ = backports[name] @@ -279,6 +280,18 @@ def __collect_imports_backport(self, /) -> None: if ref != new_ref: self._renames[ref] = new_ref + if visitor.nested_classvar_final and target < (3, 13): + # nested `ClassVar` and `Final` require Python >= 3.13 + for name in ["ClassVar", "Final"]: + name_old = visitor.imported_from_typing_as(name) + assert name_old + + self._discard_import(_MODULE_TP, name) + name_new = self._require_import(_MODULE_TPX, name) + + if name_old != name_new: + self._renames[name_old] = name_new + def __collect_imports_type_aliases(self, /) -> None: # collect the imports for `TypeAlias` and/or `TypeAliasType` aligned = self._type_alias_alignment diff --git a/unpy/visitors.py b/unpy/visitors.py index 72822e6..db807dd 100644 --- a/unpy/visitors.py +++ b/unpy/visitors.py @@ -59,6 +59,8 @@ class StubVisitor(cst.CSTVisitor): # noqa: PLR0904 # {class_qualname: [class_qualname, ...]} class_bases: dict[str, list[str]] + nested_classvar_final: bool + def __init__(self, /) -> None: self._stack_scope = collections.deque() self._stack_attr = collections.deque() @@ -81,6 +83,8 @@ def __init__(self, /) -> None: # TODO(jorenham): refactor this as metadata self.class_bases = {} + self.nested_classvar_final = False + super().__init__() @property @@ -373,6 +377,52 @@ def __after_import(self, /) -> None: assert self._in_import self._in_import = False + def __check_assign_imported(self, node: cst.Assign | cst.AnnAssign, /) -> None: + if not isinstance(node.value, cst.Name | cst.Attribute): + return + if (name := uncst.get_name_strict(node.value)) not in self.imports_by_alias: + return + + # TODO(jorenham): support multiple import aliases + # TODO(jorenham): support creating an import alias by assignment + fqn = self.imports_by_alias[name] + raise NotImplementedError(f"multiple import aliases for {fqn!r}") + + @override + def on_visit(self, /, node: cst.CSTNode) -> bool: + if isinstance(node, cst.BaseSmallStatement): + if isinstance( + node, + cst.Del + | cst.Pass + | cst.Break + | cst.Continue + | cst.Return + | cst.Raise + | cst.Assert + | cst.Global + | cst.Nonlocal, + ): + keyword = type(node).__name__.lower() + raise StubError(f"{keyword!r} statements are useless in stubs") + elif isinstance(node, cst.BaseCompoundStatement): + if isinstance( + node, + cst.Try | cst.TryStar | cst.With | cst.For | cst.While | cst.Match, + ): + keyword = type(node).__name__.lower() + raise StubError(f"{keyword!r} statements are useless in stubs") + elif isinstance(node, cst.BaseExpression): + if isinstance(node, cst.BooleanOperation): + raise StubError("boolean operations are useless in stubs") + if isinstance(node, cst.FormattedString): + raise StubError("f-strings are useless in stubs") + if isinstance(node, cst.Lambda | cst.Await | cst.Yield): + keyword = type(node).__name__.lower() + raise StubError(f"{keyword!r} is an invalid expression") + + return super().on_visit(node) + @override def visit_Module(self, /, node: cst.Module) -> None: scope = self.get_metadata(cst_meta.ScopeProvider, node) @@ -440,17 +490,6 @@ def leave_Attribute(self, /, original_node: cst.Attribute) -> None: def visit_Annotation(self, /, node: cst.Annotation) -> None: _check_annotation_expr(node.annotation) - def __check_assign_imported(self, node: cst.Assign | cst.AnnAssign, /) -> None: - if not isinstance(node.value, cst.Name | cst.Attribute): - return - if (name := uncst.get_name_strict(node.value)) not in self.imports_by_alias: - return - - # TODO(jorenham): support multiple import aliases - # TODO(jorenham): support creating an import alias by assignment - fqn = self.imports_by_alias[name] - raise NotImplementedError(f"multiple import aliases for {fqn!r}") - @override def visit_Assign(self, node: cst.Assign) -> None: self.__check_assign_imported(node) @@ -482,10 +521,12 @@ def visit_Assign(self, node: cst.Assign) -> None: def visit_AnnAssign(self, node: cst.AnnAssign) -> None: self.__check_assign_imported(node) + ann = node.annotation.annotation + if ( node.value and isinstance(node.target, cst.Name) - and isinstance(ann := node.annotation.annotation, cst.Name | cst.Attribute) + and isinstance(ann, cst.Name | cst.Attribute) and (type_alias_name := self.imported_from_typing_as("TypeAlias")) and uncst.get_name_strict(ann) == type_alias_name ): @@ -493,6 +534,31 @@ def visit_AnnAssign(self, node: cst.AnnAssign) -> None: # TODO(jorenham): either warn user & register, or just disallow this _check_annotation_expr(node.value, f"{node.target.value}") + # check for nested `ClassVar` and `Final` + if ( + not self.nested_classvar_final + and isinstance(ann, cst.Subscript) + and isinstance(ann.value, cst.Name | cst.Attribute) + and len(ann.slice) == 1 + and isinstance(index := ann.slice[0].slice, cst.Index) + and isinstance(ann_sub := index.value, cst.Subscript) + and isinstance(ann_sub.value, cst.Name | cst.Attribute) + and len(ann_sub.slice) == 1 + and isinstance(ann.slice[0].slice, cst.Index) + and (name_classvar := self.imported_from_typing_as("ClassVar")) + and (name_final := self.imported_from_typing_as("Final")) + ): + # this is something of the form `_: _[_[_]] = ...` + name_outer = uncst.get_name_strict(ann.value) + name_inner = uncst.get_name_strict(ann_sub.value) + names_typing = {name_classvar, name_final} + if ( + name_outer != name_inner + and name_outer in names_typing + and name_inner in names_typing + ): + self.nested_classvar_final = True + @override def visit_AssignTarget(self, node: cst.AssignTarget) -> None: assert not self._stack_attr @@ -593,38 +659,3 @@ def leave_ClassDef(self, /, original_node: cst.ClassDef) -> None: # https://github.com/jorenham/unpy/issues/46 _ = self._stack_scope.pop() - - @override - def on_visit(self, /, node: cst.CSTNode) -> bool: - if isinstance(node, cst.BaseSmallStatement): - if isinstance( - node, - cst.Del - | cst.Pass - | cst.Break - | cst.Continue - | cst.Return - | cst.Raise - | cst.Assert - | cst.Global - | cst.Nonlocal, - ): - keyword = type(node).__name__.lower() - raise StubError(f"{keyword!r} statements are useless in stubs") - elif isinstance(node, cst.BaseCompoundStatement): - if isinstance( - node, - cst.Try | cst.TryStar | cst.With | cst.For | cst.While | cst.Match, - ): - keyword = type(node).__name__.lower() - raise StubError(f"{keyword!r} statements are useless in stubs") - elif isinstance(node, cst.BaseExpression): - if isinstance(node, cst.BooleanOperation): - raise StubError("boolean operations are useless in stubs") - if isinstance(node, cst.FormattedString): - raise StubError("f-strings are useless in stubs") - if isinstance(node, cst.Lambda | cst.Await | cst.Yield): - keyword = type(node).__name__.lower() - raise StubError(f"{keyword!r} is an invalid expression") - - return super().on_visit(node)