Skip to content

Commit

Permalink
Merge pull request #80 from jorenham/nested-classvar-final
Browse files Browse the repository at this point in the history
backport nested `ClassVar` and `Final` on `python<3.13`
  • Loading branch information
jorenham authored Oct 2, 2024
2 parents a43b160 + c3b7336 commit f20fdfc
Show file tree
Hide file tree
Showing 5 changed files with 185 additions and 49 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ potential goals of `unpy`:
- [x] `re.PatternError` => `re.error`
- Typing
- [x] `types.CapsuleType` => `typing_extensions.CapsuleType`
- [ ] `typing.{ClassVar,Final}` => `typing_extensions.{ClassVar,Final}` when
- [x] `typing.{ClassVar,Final}` => `typing_extensions.{ClassVar,Final}` when
nested (python/cpython#89547)
- Python 3.12 => 3.11
- [PEP 698][PEP698]
Expand Down
64 changes: 64 additions & 0 deletions tests/test_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,3 +633,67 @@ class OldStyle(Object): ...
""")
with pytest.raises(NotImplementedError):
transform_source(pyi_direct)


def test_nested_ClassVar_Final_TN():
pyi_in = pyi_expect = _src("""
from typing import ClassVar, Final
class C:
a: ClassVar[int] = 0
b: Final[int]
""")
pyi_out = transform_source(pyi_in)
assert pyi_out == pyi_expect


def test_nested_ClassVar_Final_TP():
pyi_in = _src("""
from typing import ClassVar, Final
class C:
a: ClassVar[Final[int]] = 1
""")
pyi_expect = _src("""
from typing_extensions import ClassVar, Final
class C:
a: ClassVar[Final[int]] = 1
""")
pyi_out = transform_source(pyi_in)
assert pyi_out == pyi_expect


def test_nested_ClassVar_Final_TP_inv():
pyi_in = _src("""
from typing import ClassVar, Final
class C:
a: Final[ClassVar[int]] = -1
""")
pyi_expect = _src("""
from typing_extensions import ClassVar, Final
class C:
a: Final[ClassVar[int]] = -1
""")
pyi_out = transform_source(pyi_in)
assert pyi_out == pyi_expect


def test_nested_ClassVar_Final_TP_indirect():
pyi_in = _src("""
import typing as tp
class C:
a: tp.ClassVar[tp.Final[int]] = 1
""")
pyi_expect = _src("""
import typing as tp
from typing_extensions import ClassVar, Final
class C:
a: ClassVar[Final[int]] = 1
""")
pyi_out = transform_source(pyi_in)
assert pyi_out == pyi_expect
28 changes: 28 additions & 0 deletions tests/test_visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,3 +327,31 @@ def test_baseclasses_single() -> None:

# type params
# TODO

# nested ClassVar and Final
def test_nested_classvar_final() -> None:
visitor_tn = _visit(
"from typing import ClassVar, Final",
"class C:",
" a: ClassVar[int] = 0",
" b: Final[int]",
)
visitor_tp = _visit(
"from typing import ClassVar, Final",
"class C:",
" a: ClassVar[Final[int]] = 1",
)
visitor_tp_inv = _visit(
"from typing import ClassVar, Final",
"class C:",
" a: Final[ClassVar[int]] = -1",
)
visitor_tp_indirect = _visit(
"import typing as tp",
"class C:",
" a: tp.ClassVar[tp.Final[int]] = 1",
)
assert not visitor_tn.nested_classvar_final
assert visitor_tp.nested_classvar_final
assert visitor_tp_inv.nested_classvar_final
assert visitor_tp_indirect.nested_classvar_final
15 changes: 14 additions & 1 deletion unpy/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,7 @@ def __collection_import_backport_single(self, fqn: str, alias: str, /) -> None:

def __collect_imports_backport(self, /) -> None:
# collect the imports that should replaced with a `typing_extensions` backport
target = self.target
visitor = self.visitor

for fqn, alias in visitor.imports.items():
Expand All @@ -271,14 +272,26 @@ def __collect_imports_backport(self, /) -> None:
if (
(backports := BACKPORTS.get(module))
and name in backports
and self.target < backports[name][2]
and target < backports[name][2]
):
new_module, new_name, _ = backports[name]

new_ref = self._require_import(new_module, new_name, has_backport=False)
if ref != new_ref:
self._renames[ref] = new_ref

if visitor.nested_classvar_final and target < (3, 13):
# nested `ClassVar` and `Final` require Python >= 3.13
for name in ["ClassVar", "Final"]:
name_old = visitor.imported_from_typing_as(name)
assert name_old

self._discard_import(_MODULE_TP, name)
name_new = self._require_import(_MODULE_TPX, name)

if name_old != name_new:
self._renames[name_old] = name_new

def __collect_imports_type_aliases(self, /) -> None:
# collect the imports for `TypeAlias` and/or `TypeAliasType`
aligned = self._type_alias_alignment
Expand Down
125 changes: 78 additions & 47 deletions unpy/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ class StubVisitor(cst.CSTVisitor): # noqa: PLR0904
# {class_qualname: [class_qualname, ...]}
class_bases: dict[str, list[str]]

nested_classvar_final: bool

def __init__(self, /) -> None:
self._stack_scope = collections.deque()
self._stack_attr = collections.deque()
Expand All @@ -81,6 +83,8 @@ def __init__(self, /) -> None:
# TODO(jorenham): refactor this as metadata
self.class_bases = {}

self.nested_classvar_final = False

super().__init__()

@property
Expand Down Expand Up @@ -373,6 +377,52 @@ def __after_import(self, /) -> None:
assert self._in_import
self._in_import = False

def __check_assign_imported(self, node: cst.Assign | cst.AnnAssign, /) -> None:
if not isinstance(node.value, cst.Name | cst.Attribute):
return
if (name := uncst.get_name_strict(node.value)) not in self.imports_by_alias:
return

# TODO(jorenham): support multiple import aliases
# TODO(jorenham): support creating an import alias by assignment
fqn = self.imports_by_alias[name]
raise NotImplementedError(f"multiple import aliases for {fqn!r}")

@override
def on_visit(self, /, node: cst.CSTNode) -> bool:
if isinstance(node, cst.BaseSmallStatement):
if isinstance(
node,
cst.Del
| cst.Pass
| cst.Break
| cst.Continue
| cst.Return
| cst.Raise
| cst.Assert
| cst.Global
| cst.Nonlocal,
):
keyword = type(node).__name__.lower()
raise StubError(f"{keyword!r} statements are useless in stubs")
elif isinstance(node, cst.BaseCompoundStatement):
if isinstance(
node,
cst.Try | cst.TryStar | cst.With | cst.For | cst.While | cst.Match,
):
keyword = type(node).__name__.lower()
raise StubError(f"{keyword!r} statements are useless in stubs")
elif isinstance(node, cst.BaseExpression):
if isinstance(node, cst.BooleanOperation):
raise StubError("boolean operations are useless in stubs")
if isinstance(node, cst.FormattedString):
raise StubError("f-strings are useless in stubs")
if isinstance(node, cst.Lambda | cst.Await | cst.Yield):
keyword = type(node).__name__.lower()
raise StubError(f"{keyword!r} is an invalid expression")

return super().on_visit(node)

@override
def visit_Module(self, /, node: cst.Module) -> None:
scope = self.get_metadata(cst_meta.ScopeProvider, node)
Expand Down Expand Up @@ -440,17 +490,6 @@ def leave_Attribute(self, /, original_node: cst.Attribute) -> None:
def visit_Annotation(self, /, node: cst.Annotation) -> None:
_check_annotation_expr(node.annotation)

def __check_assign_imported(self, node: cst.Assign | cst.AnnAssign, /) -> None:
if not isinstance(node.value, cst.Name | cst.Attribute):
return
if (name := uncst.get_name_strict(node.value)) not in self.imports_by_alias:
return

# TODO(jorenham): support multiple import aliases
# TODO(jorenham): support creating an import alias by assignment
fqn = self.imports_by_alias[name]
raise NotImplementedError(f"multiple import aliases for {fqn!r}")

@override
def visit_Assign(self, node: cst.Assign) -> None:
self.__check_assign_imported(node)
Expand Down Expand Up @@ -482,17 +521,44 @@ def visit_Assign(self, node: cst.Assign) -> None:
def visit_AnnAssign(self, node: cst.AnnAssign) -> None:
self.__check_assign_imported(node)

ann = node.annotation.annotation

if (
node.value
and isinstance(node.target, cst.Name)
and isinstance(ann := node.annotation.annotation, cst.Name | cst.Attribute)
and isinstance(ann, cst.Name | cst.Attribute)
and (type_alias_name := self.imported_from_typing_as("TypeAlias"))
and uncst.get_name_strict(ann) == type_alias_name
):
# this is a legacy `typing[_extensions].TypeAlias`
# TODO(jorenham): either warn user & register, or just disallow this
_check_annotation_expr(node.value, f"{node.target.value}")

# check for nested `ClassVar` and `Final`
if (
not self.nested_classvar_final
and isinstance(ann, cst.Subscript)
and isinstance(ann.value, cst.Name | cst.Attribute)
and len(ann.slice) == 1
and isinstance(index := ann.slice[0].slice, cst.Index)
and isinstance(ann_sub := index.value, cst.Subscript)
and isinstance(ann_sub.value, cst.Name | cst.Attribute)
and len(ann_sub.slice) == 1
and isinstance(ann.slice[0].slice, cst.Index)
and (name_classvar := self.imported_from_typing_as("ClassVar"))
and (name_final := self.imported_from_typing_as("Final"))
):
# this is something of the form `_: _[_[_]] = ...`
name_outer = uncst.get_name_strict(ann.value)
name_inner = uncst.get_name_strict(ann_sub.value)
names_typing = {name_classvar, name_final}
if (
name_outer != name_inner
and name_outer in names_typing
and name_inner in names_typing
):
self.nested_classvar_final = True

@override
def visit_AssignTarget(self, node: cst.AssignTarget) -> None:
assert not self._stack_attr
Expand Down Expand Up @@ -593,38 +659,3 @@ def leave_ClassDef(self, /, original_node: cst.ClassDef) -> None:
# https://github.com/jorenham/unpy/issues/46

_ = self._stack_scope.pop()

@override
def on_visit(self, /, node: cst.CSTNode) -> bool:
if isinstance(node, cst.BaseSmallStatement):
if isinstance(
node,
cst.Del
| cst.Pass
| cst.Break
| cst.Continue
| cst.Return
| cst.Raise
| cst.Assert
| cst.Global
| cst.Nonlocal,
):
keyword = type(node).__name__.lower()
raise StubError(f"{keyword!r} statements are useless in stubs")
elif isinstance(node, cst.BaseCompoundStatement):
if isinstance(
node,
cst.Try | cst.TryStar | cst.With | cst.For | cst.While | cst.Match,
):
keyword = type(node).__name__.lower()
raise StubError(f"{keyword!r} statements are useless in stubs")
elif isinstance(node, cst.BaseExpression):
if isinstance(node, cst.BooleanOperation):
raise StubError("boolean operations are useless in stubs")
if isinstance(node, cst.FormattedString):
raise StubError("f-strings are useless in stubs")
if isinstance(node, cst.Lambda | cst.Await | cst.Yield):
keyword = type(node).__name__.lower()
raise StubError(f"{keyword!r} is an invalid expression")

return super().on_visit(node)

0 comments on commit f20fdfc

Please sign in to comment.