Skip to content

Commit

Permalink
merged the transformers
Browse files Browse the repository at this point in the history
  • Loading branch information
jorenham committed Sep 15, 2024
1 parent 79cfb97 commit 429ef6c
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 78 deletions.
116 changes: 46 additions & 70 deletions unpy/_pep695.py → unpy/_py311.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
type _AnyDef = cst.ClassDef | cst.FunctionDef
type _NodeFlat[N: cst.CSTNode, FN: cst.CSTNode] = N | cst.FlattenSentinel[FN]

__all__ = "PY311Collector", "PY311Transformer"

_TYPING_MODULES: Final = "typing", "typing_extensions"


Expand All @@ -17,7 +19,6 @@ def bool_expr(value: bool, /) -> cst.Name:


def str_expr(value: str, /) -> cst.SimpleString:
# TODO: Configurable quote style
return cst.SimpleString(f"'{value}'")


Expand Down Expand Up @@ -198,7 +199,7 @@ def _workaround_libcst_runtime_typecheck_bug[F: Callable[..., object]](f: F, /)
return f


class PEP695Collector(cst.CSTVisitor):
class PY311Collector(cst.CSTVisitor):
"""
Collect all PEP-695 type-parameters & required imports in the module's functions,
classes, and type-aliases.
Expand Down Expand Up @@ -349,11 +350,16 @@ def leave_FunctionDef(self, /, original_node: cst.FunctionDef) -> None:
assert name == original_node.name.value


class TypeAliasTransformer(m.MatcherDecoratableTransformer):
class PY311Transformer(m.MatcherDecoratableTransformer):
_MATCH_TYPING_IMPORT: ClassVar = m.ImportFrom(
m.Name("typing") | m.Name("typing_extensions"),
)

_stack: collections.deque[str]

current_imports: frozenset[_TypingModule]
current_imports_from: dict[_TypingModule, dict[str, str]]
missing_imports_from: dict[_TypingModule, set[str]]
missing_tvars: dict[str, list[cst.Assign]]

def __init__(
Expand All @@ -362,14 +368,39 @@ def __init__(
*,
current_imports: set[_TypingModule],
current_imports_from: dict[_TypingModule, dict[str, str]],
missing_imports_from: dict[_TypingModule, set[str]],
missing_tvars: dict[str, list[cst.Assign]],
) -> None:
self._stack = collections.deque()
self.current_imports = frozenset(current_imports)
self.current_imports_from = current_imports_from
self.missing_tvars = missing_tvars
self.missing_imports_from = missing_imports_from
super().__init__()

@property
def _del_from_typing(self) -> set[str]:
missing_tpx = self.missing_imports_from["typing_extensions"]
return {
name
for name, as_ in self.current_imports_from["typing"].items()
if name == as_ and name in missing_tpx
}

@property
def _add_from_typing(self) -> set[str]:
return (
self.missing_imports_from["typing"]
- set(self.current_imports_from["typing"])
- self.missing_imports_from["typing_extensions"]
)

@property
def _add_from_typing_extensions(self) -> set[str]:
return self.missing_imports_from["typing_extensions"] - set(
self.current_imports_from["typing_extensions"],
)

@m.call_if_inside(m.Module([m.ZeroOrMore(m.SimpleStatementLine())]))
@m.leave(m.SimpleStatementLine([m.TypeAlias()]))
@_workaround_libcst_runtime_typecheck_bug
Expand Down Expand Up @@ -402,7 +433,6 @@ def leave_ClassDef(
original_node: cst.ClassDef,
updated_node: cst.ClassDef,
) -> _NodeFlat[cst.ClassDef, cst.BaseStatement]:
# TODO: subscript if `Protocol` is a base class, or add `Generic` as base
name = self._stack.pop()
assert name == updated_node.name.value

Expand Down Expand Up @@ -452,71 +482,8 @@ def leave_FunctionDef(

return self._prepend_tvars(updated_node)

@classmethod
def from_collector(cls, collector: PEP695Collector, /) -> Self:
return cls(
current_imports=collector.current_imports,
current_imports_from=collector.current_imports_from,
missing_tvars=collector.missing_tvars,
)


class TypingImportTransformer(m.MatcherDecoratableTransformer):
_TYPING_MODULES: ClassVar = m.Name("typing") | m.Name("typing_extensions")

current_imports: frozenset[_TypingModule]
current_imports_from: dict[_TypingModule, dict[str, str]]
missing_imports_from: dict[_TypingModule, set[str]]

def __init__(
self,
/,
*,
current_imports: set[_TypingModule],
current_imports_from: dict[_TypingModule, dict[str, str]],
missing_imports_from: dict[_TypingModule, set[str]],
) -> None:
self.current_imports = frozenset(current_imports)
self.current_imports_from = current_imports_from
self.missing_imports_from = missing_imports_from

super().__init__()

@property
def _del_from_typing(self) -> set[str]:
"""
The current `typing` imports that should be imported from `typing_extensions`
instead.
Todo:
Remove aliases as well (requires renaming references).
"""
missing_tpx = self.missing_imports_from["typing_extensions"]
return {
name
for name, as_ in self.current_imports_from["typing"].items()
if name == as_ and name in missing_tpx
}

@property
def _add_from_typing(self) -> set[str]:
"""The `typing` imports that are missing."""
# return self._req_typing - self._cur_typing - self._req_typing_extensions
return (
self.missing_imports_from["typing"]
- set(self.current_imports_from["typing"])
- self.missing_imports_from["typing_extensions"]
)

@property
def _add_from_typing_extensions(self) -> set[str]:
"""The `typing_extensions` imports that are missing."""
return self.missing_imports_from["typing_extensions"] - set(
self.current_imports_from["typing_extensions"],
)

@m.call_if_inside(m.SimpleStatementLine([m.OneOf(m.ImportFrom(_TYPING_MODULES))]))
@m.leave(m.ImportFrom(_TYPING_MODULES))
@m.call_if_inside(m.SimpleStatementLine([m.OneOf(_MATCH_TYPING_IMPORT)]))
@m.leave(_MATCH_TYPING_IMPORT)
def transform_typing_import(
self,
/,
Expand Down Expand Up @@ -633,9 +600,18 @@ def _new_import_statement_index(module_node: cst.Module) -> int:
return i_insert

@classmethod
def from_collector(cls, /, collector: PEP695Collector) -> Self:
def from_collector(cls, collector: PY311Collector, /) -> Self:
return cls(
current_imports=collector.current_imports,
current_imports_from=collector.current_imports_from,
missing_imports_from=collector.missing_imports_from,
missing_tvars=collector.missing_tvars,
)


def transform(original: cst.Module, /) -> cst.Module:
collector = PY311Collector()
_ = original.visit(collector)

transformer = PY311Transformer.from_collector(collector)
return original.visit(transformer)
10 changes: 2 additions & 8 deletions unpy/convert.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,9 @@
import libcst as cst

from ._pep695 import PEP695Collector, TypeAliasTransformer, TypingImportTransformer
from ._py311 import transform as transform_py311

__all__ = ("convert",)


def convert(source: str, /) -> str:
return (
cst.parse_module(source)
.visit(collector := PEP695Collector())
.visit(TypeAliasTransformer.from_collector(collector))
.visit(TypingImportTransformer.from_collector(collector))
.code
)
return transform_py311(cst.parse_module(source)).code

0 comments on commit 429ef6c

Please sign in to comment.