From 4080ed582a8db34419761a1b91cff54b65f0ea3f Mon Sep 17 00:00:00 2001 From: jorenham Date: Tue, 24 Sep 2024 05:48:01 +0200 Subject: [PATCH] preliminary support for backporting to Python 3.12 and Python 3.10 --- unpy/_types.py | 2 +- unpy/transformers.py | 20 ++++++++++++++++++-- unpy/visitors.py | 3 --- 3 files changed, 19 insertions(+), 6 deletions(-) diff --git a/unpy/_types.py b/unpy/_types.py index 160bdff..763db34 100644 --- a/unpy/_types.py +++ b/unpy/_types.py @@ -29,7 +29,7 @@ class PythonVersion(tuple[int, ...], enum.ReprEnum): # noqa: SLOT001 # PY39 = (3, 9) - # PY310 = (3, 10) + PY310 = (3, 10) PY311 = (3, 11) PY312 = (3, 12) # PY313 = (3, 13) diff --git a/unpy/transformers.py b/unpy/transformers.py index 219f17d..69d82bd 100644 --- a/unpy/transformers.py +++ b/unpy/transformers.py @@ -136,6 +136,15 @@ def visit_Module(self, /, node: cst.Module) -> None: visitor = self.visitor target = self.target + if target >= (3, 12): + # PEP 695 makes these redundant + for tvar_name in ["TypeVar", "TypeVarTuple", "ParamSpec"]: + visitor.imports_add.discard(f"typing.{tvar_name}") + if target >= (3, 13): + # PEP 695 & PEP 696 make these redundant + visitor.imports_add.clear() + visitor.imports_del.clear() + for fqn in visitor.imports: if "." not in fqn or fqn.startswith(".") or fqn.endswith("*"): continue @@ -256,6 +265,11 @@ def leave_FunctionDef( ) -> cst.FunctionDef | cst.FlattenSentinel[cst.BaseStatement]: self._stack.pop() + if not (tpars := updated_node.type_parameters): + return updated_node + if self.target >= (3, 12) and not any(tpar.default for tpar in tpars.params): + return updated_node + updated_node = _remove_tpars(updated_node) return updated_node if self._stack else self._prepend_tvars(updated_node) @@ -274,6 +288,8 @@ def leave_ClassDef( if not (tpars := original_node.type_parameters): return updated_node + if self.target >= (3, 12) and not any(tpar.default for tpar in tpars.params): + return updated_node base_args = updated_node.bases tpar_names = (tpar.param.name for tpar in tpars.params) @@ -304,6 +320,8 @@ def leave_ClassDef( expr_generic = parse_name(name_generic or "Generic") new_bases.insert(i, cst.Arg(parse_subscript(expr_generic, *tpar_names))) + visitor.desire_import("typing", "Generic", has_backport=True) + updated_node = updated_node.with_changes(type_parameters=None, bases=new_bases) stack.pop() @@ -374,6 +392,4 @@ def transform_source( *, target: PythonVersion = PythonVersion.PY311, ) -> str: - if target != PythonVersion.PY311: - raise NotImplementedError(f"Python {target}") return transform_module(cst.parse_module(source), target=target).code diff --git a/unpy/visitors.py b/unpy/visitors.py index f491f86..d5f62b4 100644 --- a/unpy/visitors.py +++ b/unpy/visitors.py @@ -482,9 +482,6 @@ def visit_ClassDef(self, /, node: cst.ClassDef) -> None: ) if tpars := node.type_parameters: - if self.imported_from_typing_as("Protocol") not in base_set: - self.desire_import("typing", "Generic", has_backport=True) - self._register_type_params(stack[0], tpars, variant=True) @override