diff --git a/tests/test_transformers.py b/tests/test_transformers.py index 17f85d1..6c8759a 100644 --- a/tests/test_transformers.py +++ b/tests/test_transformers.py @@ -31,7 +31,6 @@ def test_comments(source: str) -> None: "__version__: str = '3.14'\n", "def concat(*args: str) -> str: ...\n", "class C:\n def f(self, /) -> None: ...\n", - "raise NotImplementedError\n", ], ) def test_already_compatible(source: str) -> None: diff --git a/tests/test_visitors.py b/tests/test_visitors.py index a99fc19..0c14adb 100644 --- a/tests/test_visitors.py +++ b/tests/test_visitors.py @@ -2,6 +2,7 @@ import libcst as cst import pytest +from unpy.exceptions import StubError from unpy.visitors import StubVisitor @@ -11,6 +12,37 @@ def _visit(*lines: str) -> StubVisitor: return visitor +# stub errors + + +def test_illegal_future_import(): + # https://github.com/jorenham/unpy/issues/43 + with pytest.raises(StubError): + _visit("from __future__ import annotations") + + +@pytest.mark.parametrize( + "source", + [ + "Const: 'str' = ...", + "type Alias = 'str'", + "from typing import TypeAlias\nAlias: TypeAlias = 'str'", + "from typing import TypeAliasType\nAlias = TypeAliasType('Alias', 'str')", + "from typing import TypeVar\nT = TypeVar('T', bound='str')", + "from typing import TypeVar\nT = TypeVar('T', default='str')", + "def f(x: 'str') -> str: ...", + "def f(x: str) -> 'str': ...", + "def f[T: 'str'](x: T) -> T: ...", + "def f[T: str = 'str'](x: T) -> T: ...", + "class C[T: 'str']: ...", + "class C[T: str = 'str']: ...", + ], +) +def test_stringified_annotations(source: str): + with pytest.raises(StubError): + _visit(source) + + # imports @@ -271,9 +303,7 @@ def test_import_access_package_attr_attr() -> None: } -# baseclass -# TODO: more tests - +# baseclasses def test_baseclasses_single() -> None: visitor = _visit( diff --git a/unpy/exceptions.py b/unpy/exceptions.py new file mode 100644 index 0000000..0e6bcf7 --- /dev/null +++ b/unpy/exceptions.py @@ -0,0 +1,5 @@ +__all__ = ("StubError",) + + +class StubError(TypeError): + pass diff --git a/unpy/visitors.py b/unpy/visitors.py index 270a49d..f109108 100644 --- a/unpy/visitors.py +++ b/unpy/visitors.py @@ -1,29 +1,29 @@ import collections -import functools from typing import Final, override import libcst as cst -from libcst.metadata import Scope, ScopeProvider +import libcst.metadata as cst_meta import unpy._cst as uncst +from unpy.exceptions import StubError __all__ = ("StubVisitor",) _MODULE_BUILTINS: Final = "builtins" -_MODULE_PATHLIB: Final = "pathlib" _MODULE_TP: Final = "typing" _MODULE_TPX: Final = "typing_extensions" -_ILLEGAL_BASES: Final = ( - (_MODULE_BUILTINS, "BaseExceptionGroup"), - (_MODULE_BUILTINS, "ExceptionGroup"), - (_MODULE_BUILTINS, "_IncompleteInputError"), - (_MODULE_BUILTINS, "PythonFinalizationError"), - (_MODULE_BUILTINS, "EncodingWarning"), - (_MODULE_PATHLIB, "Path"), - (_MODULE_TP, "Any"), - (_MODULE_TPX, "Any"), -) + +def _check_annotation_expr( + node: cst.BaseExpression, + /, + name: str | None = None, +) -> None: + if isinstance(node, cst.BaseString): + error = StubError("quoted annotations should not be included in stubs") + if name: + error.add_note(f"in {name!r}") + raise error class StubVisitor(cst.CSTVisitor): # noqa: PLR0904 @@ -32,9 +32,9 @@ class StubVisitor(cst.CSTVisitor): # noqa: PLR0904 classes, and type-aliases. """ - METADATA_DEPENDENCIES = (ScopeProvider,) + METADATA_DEPENDENCIES = (cst_meta.ScopeProvider,) - _global_scope: Scope + _global_scope: cst_meta.Scope _stack_scope: Final[collections.deque[str]] _stack_attr: Final[collections.deque[cst.Attribute]] @@ -199,7 +199,12 @@ def _register_import( module: str | None = None, alias: str | None = None, ) -> str: + # this `str.removesuffix` avoids a double `.` in case of `from . import name` fqn = f"{module.removesuffix(".")}.{name}" if module else name + + if fqn.startswith("__future__"): + raise StubError("__future__ imports are useless in stubs") + alias = alias or name if self.imports.setdefault(fqn, alias) != alias: @@ -249,8 +254,8 @@ def _build_type_param( # noqa: C901 name = param.name.value default = tpar.default - if isinstance(default, cst.BaseString): - raise NotImplementedError("stringified type parameter defaults") + if default: + _check_annotation_expr(default) name_any = self.imported_from_typing_as("Any") name_object = self.imported_as(_MODULE_BUILTINS, "object") @@ -279,8 +284,6 @@ def _build_type_param( # noqa: C901 constraints: tuple[cst.BaseExpression, ...] = () if not (bound := param.bound): pass - elif isinstance(bound, cst.BaseString): - raise NotImplementedError("stringified type parameter bounds") elif ( name_any and isinstance(bound, cst.Name | cst.Attribute) @@ -308,6 +311,8 @@ def _build_type_param( # noqa: C901 constraints = tuple(cons) bound = None + else: + _check_annotation_expr(bound) if _default_any and bound is not None: # if `default=Any`, replace it the value of `bound` (`Any` is horrible) @@ -370,8 +375,8 @@ def __after_import(self, /) -> None: @override def visit_Module(self, /, node: cst.Module) -> None: - scope = self.get_metadata(ScopeProvider, node) - assert isinstance(scope, Scope) + scope = self.get_metadata(cst_meta.ScopeProvider, node) + assert isinstance(scope, cst_meta.Scope) self._global_scope = scope @override @@ -431,6 +436,10 @@ def leave_Attribute(self, /, original_node: cst.Attribute) -> None: node = self._stack_attr.pop() assert node is original_node + @override + def visit_Annotation(self, /, node: cst.Annotation) -> None: + _check_annotation_expr(node.annotation) + def __check_assign_imported(self, node: cst.Assign | cst.AnnAssign, /) -> None: if not isinstance(node.value, cst.Name | cst.Attribute): return @@ -446,10 +455,44 @@ def __check_assign_imported(self, node: cst.Assign | cst.AnnAssign, /) -> None: def visit_Assign(self, node: cst.Assign) -> None: self.__check_assign_imported(node) + # TODO(jorenham): enforce len(node.targets) == 1 + + if ( + len(node.targets) == 1 + and isinstance(target := node.targets[0].target, cst.Name) + and isinstance(node.value, cst.Call) + # typevar-likes can only "contain" annotations if they have >1 [kw]args + and len(node.value.args) > 1 + and isinstance(node.value.func, cst.Name | cst.Attribute) + and isinstance(strname := node.value.args[0].value, cst.SimpleString) + and strname.raw_value == target.value + ): + # this is (probably) a legacy typevar-like (or a manual `TypeAliasType`) + # TODO(jorenham): either warn user & register, or just disallow this + + fname = "" # set later if needed + for arg in node.value.args[1:]: + # annotations can only occur in the positional args and the + # `bound` or `default` kwargs + if arg.keyword is None or arg.keyword.value in {"bound", "default"}: + fname = fname or uncst.get_name_strict(node.value.func) + _check_annotation_expr(arg.value, f"{target.value} = {fname}(...)") + @override def visit_AnnAssign(self, node: cst.AnnAssign) -> None: self.__check_assign_imported(node) + if ( + node.value + and isinstance(node.target, cst.Name) + and isinstance(ann := node.annotation.annotation, cst.Name | cst.Attribute) + and (type_alias_name := self.imported_from_typing_as("TypeAlias")) + and uncst.get_name_strict(ann) == type_alias_name + ): + # this is a legacy `typing[_extensions].TypeAlias` + # TODO(jorenham): either warn user & register, or just disallow this + _check_annotation_expr(node.value, f"{node.target.value}") + @override def visit_AssignTarget(self, node: cst.AssignTarget) -> None: assert not self._stack_attr @@ -470,6 +513,8 @@ def visit_TypeAlias(self, /, node: cst.TypeAlias) -> None: name = node.name.value assert name not in self.type_aliases + _check_annotation_expr(node.value, f"type {name}") + if tpars := node.type_parameters: self.type_aliases[name] = self._register_type_params(name, tpars) else: @@ -482,6 +527,12 @@ def visit_TypeAlias(self, /, node: cst.TypeAlias) -> None: def visit_FunctionDef(self, /, node: cst.FunctionDef) -> None: self._stack_scope.append(node.name.value) + assert isinstance(node.body, cst.SimpleStatementSuite | cst.IndentedBlock) + if len(node.body.body) != 1 or not isinstance(node.body.body[0], cst.Ellipsis): + error = StubError("function body must contain only `...`") + qualname = ".".join(self._stack_scope) + error.add_note(qualname) + if tpars := node.type_parameters: self._register_type_params(self._stack_scope[0], tpars) @@ -492,14 +543,6 @@ def leave_FunctionDef(self, /, original_node: cst.FunctionDef) -> None: _ = self._stack_scope.pop() - @functools.cached_property - def _illegal_bases(self, /) -> frozenset[str]: - return frozenset({ - base_name - for module, name in _ILLEGAL_BASES - if (base_name := self.imported_as(module, name)) - }) - @override def visit_ClassDef(self, /, node: cst.ClassDef) -> None: name = node.name.value @@ -537,15 +580,6 @@ def visit_ClassDef(self, /, node: cst.ClassDef) -> None: # TODO: figure out the FQN if not a global name (i.e. locally scoped) bases.append(basename) - base_set = set(bases) - if base_set and (illegal := base_set & self._illegal_bases): - if len(illegal) == 1: - raise NotImplementedError(f"{illegal.pop()!r} as base class") - raise ExceptionGroup( - "unsupported base classes", - [NotImplementedError(f"{base!r} as base class") for base in illegal], - ) - if tpars := node.type_parameters: self._register_type_params(stack[0], tpars, infer_variance=True) @@ -555,3 +589,38 @@ def leave_ClassDef(self, /, original_node: cst.ClassDef) -> None: # https://github.com/jorenham/unpy/issues/46 _ = self._stack_scope.pop() + + @override + def on_visit(self, /, node: cst.CSTNode) -> bool: + if isinstance(node, cst.BaseSmallStatement): + if isinstance( + node, + cst.Del + | cst.Pass + | cst.Break + | cst.Continue + | cst.Return + | cst.Raise + | cst.Assert + | cst.Global + | cst.Nonlocal, + ): + keyword = type(node).__name__.lower() + raise StubError(f"{keyword!r} statements are useless in stubs") + elif isinstance(node, cst.BaseCompoundStatement): + if isinstance( + node, + cst.Try | cst.TryStar | cst.With | cst.For | cst.While | cst.Match, + ): + keyword = type(node).__name__.lower() + raise StubError(f"{keyword!r} statements are useless in stubs") + elif isinstance(node, cst.BaseExpression): + if isinstance(node, cst.BooleanOperation): + raise StubError("boolean operations are useless in stubs") + if isinstance(node, cst.FormattedString): + raise StubError("f-strings are useless in stubs") + if isinstance(node, cst.Lambda | cst.Await | cst.Yield): + keyword = type(node).__name__.lower() + raise StubError(f"{keyword!r} is an invalid expression") + + return super().on_visit(node)