diff --git a/unpy/_pep695.py b/unpy/_py311.py similarity index 92% rename from unpy/_pep695.py rename to unpy/_py311.py index 47bb7cf..2ac4db1 100644 --- a/unpy/_pep695.py +++ b/unpy/_py311.py @@ -9,6 +9,8 @@ type _AnyDef = cst.ClassDef | cst.FunctionDef type _NodeFlat[N: cst.CSTNode, FN: cst.CSTNode] = N | cst.FlattenSentinel[FN] +__all__ = "PY311Collector", "PY311Transformer" + _TYPING_MODULES: Final = "typing", "typing_extensions" @@ -17,7 +19,6 @@ def bool_expr(value: bool, /) -> cst.Name: def str_expr(value: str, /) -> cst.SimpleString: - # TODO: Configurable quote style return cst.SimpleString(f"'{value}'") @@ -198,7 +199,7 @@ def _workaround_libcst_runtime_typecheck_bug[F: Callable[..., object]](f: F, /) return f -class PEP695Collector(cst.CSTVisitor): +class PY311Collector(cst.CSTVisitor): """ Collect all PEP-695 type-parameters & required imports in the module's functions, classes, and type-aliases. @@ -349,11 +350,16 @@ def leave_FunctionDef(self, /, original_node: cst.FunctionDef) -> None: assert name == original_node.name.value -class TypeAliasTransformer(m.MatcherDecoratableTransformer): +class PY311Transformer(m.MatcherDecoratableTransformer): + _MATCH_TYPING_IMPORT: ClassVar = m.ImportFrom( + m.Name("typing") | m.Name("typing_extensions"), + ) + _stack: collections.deque[str] current_imports: frozenset[_TypingModule] current_imports_from: dict[_TypingModule, dict[str, str]] + missing_imports_from: dict[_TypingModule, set[str]] missing_tvars: dict[str, list[cst.Assign]] def __init__( @@ -362,14 +368,39 @@ def __init__( *, current_imports: set[_TypingModule], current_imports_from: dict[_TypingModule, dict[str, str]], + missing_imports_from: dict[_TypingModule, set[str]], missing_tvars: dict[str, list[cst.Assign]], ) -> None: self._stack = collections.deque() self.current_imports = frozenset(current_imports) self.current_imports_from = current_imports_from self.missing_tvars = missing_tvars + self.missing_imports_from = missing_imports_from super().__init__() + @property + def _del_from_typing(self) -> set[str]: + missing_tpx = self.missing_imports_from["typing_extensions"] + return { + name + for name, as_ in self.current_imports_from["typing"].items() + if name == as_ and name in missing_tpx + } + + @property + def _add_from_typing(self) -> set[str]: + return ( + self.missing_imports_from["typing"] + - set(self.current_imports_from["typing"]) + - self.missing_imports_from["typing_extensions"] + ) + + @property + def _add_from_typing_extensions(self) -> set[str]: + return self.missing_imports_from["typing_extensions"] - set( + self.current_imports_from["typing_extensions"], + ) + @m.call_if_inside(m.Module([m.ZeroOrMore(m.SimpleStatementLine())])) @m.leave(m.SimpleStatementLine([m.TypeAlias()])) @_workaround_libcst_runtime_typecheck_bug @@ -402,7 +433,6 @@ def leave_ClassDef( original_node: cst.ClassDef, updated_node: cst.ClassDef, ) -> _NodeFlat[cst.ClassDef, cst.BaseStatement]: - # TODO: subscript if `Protocol` is a base class, or add `Generic` as base name = self._stack.pop() assert name == updated_node.name.value @@ -452,71 +482,8 @@ def leave_FunctionDef( return self._prepend_tvars(updated_node) - @classmethod - def from_collector(cls, collector: PEP695Collector, /) -> Self: - return cls( - current_imports=collector.current_imports, - current_imports_from=collector.current_imports_from, - missing_tvars=collector.missing_tvars, - ) - - -class TypingImportTransformer(m.MatcherDecoratableTransformer): - _TYPING_MODULES: ClassVar = m.Name("typing") | m.Name("typing_extensions") - - current_imports: frozenset[_TypingModule] - current_imports_from: dict[_TypingModule, dict[str, str]] - missing_imports_from: dict[_TypingModule, set[str]] - - def __init__( - self, - /, - *, - current_imports: set[_TypingModule], - current_imports_from: dict[_TypingModule, dict[str, str]], - missing_imports_from: dict[_TypingModule, set[str]], - ) -> None: - self.current_imports = frozenset(current_imports) - self.current_imports_from = current_imports_from - self.missing_imports_from = missing_imports_from - - super().__init__() - - @property - def _del_from_typing(self) -> set[str]: - """ - The current `typing` imports that should be imported from `typing_extensions` - instead. - - Todo: - Remove aliases as well (requires renaming references). - """ - missing_tpx = self.missing_imports_from["typing_extensions"] - return { - name - for name, as_ in self.current_imports_from["typing"].items() - if name == as_ and name in missing_tpx - } - - @property - def _add_from_typing(self) -> set[str]: - """The `typing` imports that are missing.""" - # return self._req_typing - self._cur_typing - self._req_typing_extensions - return ( - self.missing_imports_from["typing"] - - set(self.current_imports_from["typing"]) - - self.missing_imports_from["typing_extensions"] - ) - - @property - def _add_from_typing_extensions(self) -> set[str]: - """The `typing_extensions` imports that are missing.""" - return self.missing_imports_from["typing_extensions"] - set( - self.current_imports_from["typing_extensions"], - ) - - @m.call_if_inside(m.SimpleStatementLine([m.OneOf(m.ImportFrom(_TYPING_MODULES))])) - @m.leave(m.ImportFrom(_TYPING_MODULES)) + @m.call_if_inside(m.SimpleStatementLine([m.OneOf(_MATCH_TYPING_IMPORT)])) + @m.leave(_MATCH_TYPING_IMPORT) def transform_typing_import( self, /, @@ -633,9 +600,18 @@ def _new_import_statement_index(module_node: cst.Module) -> int: return i_insert @classmethod - def from_collector(cls, /, collector: PEP695Collector) -> Self: + def from_collector(cls, collector: PY311Collector, /) -> Self: return cls( current_imports=collector.current_imports, current_imports_from=collector.current_imports_from, missing_imports_from=collector.missing_imports_from, + missing_tvars=collector.missing_tvars, ) + + +def transform(original: cst.Module, /) -> cst.Module: + collector = PY311Collector() + _ = original.visit(collector) + + transformer = PY311Transformer.from_collector(collector) + return original.visit(transformer) diff --git a/unpy/convert.py b/unpy/convert.py index abf8648..3a629b8 100644 --- a/unpy/convert.py +++ b/unpy/convert.py @@ -1,15 +1,9 @@ import libcst as cst -from ._pep695 import PEP695Collector, TypeAliasTransformer, TypingImportTransformer +from ._py311 import transform as transform_py311 __all__ = ("convert",) def convert(source: str, /) -> str: - return ( - cst.parse_module(source) - .visit(collector := PEP695Collector()) - .visit(TypeAliasTransformer.from_collector(collector)) - .visit(TypingImportTransformer.from_collector(collector)) - .code - ) + return transform_py311(cst.parse_module(source)).code