Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

disallow quoted annotations, __future__ imports, and runtime-only nodes #78

Merged
merged 4 commits into from
Oct 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion tests/test_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ def test_comments(source: str) -> None:
"__version__: str = '3.14'\n",
"def concat(*args: str) -> str: ...\n",
"class C:\n def f(self, /) -> None: ...\n",
"raise NotImplementedError\n",
],
)
def test_already_compatible(source: str) -> None:
Expand Down
36 changes: 33 additions & 3 deletions tests/test_visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import libcst as cst
import pytest
from unpy.exceptions import StubError
from unpy.visitors import StubVisitor


Expand All @@ -11,6 +12,37 @@ def _visit(*lines: str) -> StubVisitor:
return visitor


# stub errors


def test_illegal_future_import():
# https://github.com/jorenham/unpy/issues/43
with pytest.raises(StubError):
_visit("from __future__ import annotations")


@pytest.mark.parametrize(
"source",
[
"Const: 'str' = ...",
"type Alias = 'str'",
"from typing import TypeAlias\nAlias: TypeAlias = 'str'",
"from typing import TypeAliasType\nAlias = TypeAliasType('Alias', 'str')",
"from typing import TypeVar\nT = TypeVar('T', bound='str')",
"from typing import TypeVar\nT = TypeVar('T', default='str')",
"def f(x: 'str') -> str: ...",
"def f(x: str) -> 'str': ...",
"def f[T: 'str'](x: T) -> T: ...",
"def f[T: str = 'str'](x: T) -> T: ...",
"class C[T: 'str']: ...",
"class C[T: str = 'str']: ...",
],
)
def test_stringified_annotations(source: str):
with pytest.raises(StubError):
_visit(source)


# imports


Expand Down Expand Up @@ -271,9 +303,7 @@ def test_import_access_package_attr_attr() -> None:
}


# baseclass
# TODO: more tests

# baseclasses

def test_baseclasses_single() -> None:
visitor = _visit(
Expand Down
5 changes: 5 additions & 0 deletions unpy/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
__all__ = ("StubError",)


class StubError(TypeError):
pass
145 changes: 107 additions & 38 deletions unpy/visitors.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,29 @@
import collections
import functools
from typing import Final, override

import libcst as cst
from libcst.metadata import Scope, ScopeProvider
import libcst.metadata as cst_meta

import unpy._cst as uncst
from unpy.exceptions import StubError

__all__ = ("StubVisitor",)

_MODULE_BUILTINS: Final = "builtins"
_MODULE_PATHLIB: Final = "pathlib"
_MODULE_TP: Final = "typing"
_MODULE_TPX: Final = "typing_extensions"

_ILLEGAL_BASES: Final = (
(_MODULE_BUILTINS, "BaseExceptionGroup"),
(_MODULE_BUILTINS, "ExceptionGroup"),
(_MODULE_BUILTINS, "_IncompleteInputError"),
(_MODULE_BUILTINS, "PythonFinalizationError"),
(_MODULE_BUILTINS, "EncodingWarning"),
(_MODULE_PATHLIB, "Path"),
(_MODULE_TP, "Any"),
(_MODULE_TPX, "Any"),
)

def _check_annotation_expr(
node: cst.BaseExpression,
/,
name: str | None = None,
) -> None:
if isinstance(node, cst.BaseString):
error = StubError("quoted annotations should not be included in stubs")
if name:
error.add_note(f"in {name!r}")
raise error


class StubVisitor(cst.CSTVisitor): # noqa: PLR0904
Expand All @@ -32,9 +32,9 @@ class StubVisitor(cst.CSTVisitor): # noqa: PLR0904
classes, and type-aliases.
"""

METADATA_DEPENDENCIES = (ScopeProvider,)
METADATA_DEPENDENCIES = (cst_meta.ScopeProvider,)

_global_scope: Scope
_global_scope: cst_meta.Scope

_stack_scope: Final[collections.deque[str]]
_stack_attr: Final[collections.deque[cst.Attribute]]
Expand Down Expand Up @@ -199,7 +199,12 @@ def _register_import(
module: str | None = None,
alias: str | None = None,
) -> str:
# this `str.removesuffix` avoids a double `.` in case of `from . import name`
fqn = f"{module.removesuffix(".")}.{name}" if module else name

if fqn.startswith("__future__"):
raise StubError("__future__ imports are useless in stubs")

alias = alias or name

if self.imports.setdefault(fqn, alias) != alias:
Expand Down Expand Up @@ -249,8 +254,8 @@ def _build_type_param( # noqa: C901
name = param.name.value
default = tpar.default

if isinstance(default, cst.BaseString):
raise NotImplementedError("stringified type parameter defaults")
if default:
_check_annotation_expr(default)

name_any = self.imported_from_typing_as("Any")
name_object = self.imported_as(_MODULE_BUILTINS, "object")
Expand Down Expand Up @@ -279,8 +284,6 @@ def _build_type_param( # noqa: C901
constraints: tuple[cst.BaseExpression, ...] = ()
if not (bound := param.bound):
pass
elif isinstance(bound, cst.BaseString):
raise NotImplementedError("stringified type parameter bounds")
elif (
name_any
and isinstance(bound, cst.Name | cst.Attribute)
Expand Down Expand Up @@ -308,6 +311,8 @@ def _build_type_param( # noqa: C901

constraints = tuple(cons)
bound = None
else:
_check_annotation_expr(bound)

if _default_any and bound is not None:
# if `default=Any`, replace it the value of `bound` (`Any` is horrible)
Expand Down Expand Up @@ -370,8 +375,8 @@ def __after_import(self, /) -> None:

@override
def visit_Module(self, /, node: cst.Module) -> None:
scope = self.get_metadata(ScopeProvider, node)
assert isinstance(scope, Scope)
scope = self.get_metadata(cst_meta.ScopeProvider, node)
assert isinstance(scope, cst_meta.Scope)
self._global_scope = scope

@override
Expand Down Expand Up @@ -431,6 +436,10 @@ def leave_Attribute(self, /, original_node: cst.Attribute) -> None:
node = self._stack_attr.pop()
assert node is original_node

@override
def visit_Annotation(self, /, node: cst.Annotation) -> None:
_check_annotation_expr(node.annotation)

def __check_assign_imported(self, node: cst.Assign | cst.AnnAssign, /) -> None:
if not isinstance(node.value, cst.Name | cst.Attribute):
return
Expand All @@ -446,10 +455,44 @@ def __check_assign_imported(self, node: cst.Assign | cst.AnnAssign, /) -> None:
def visit_Assign(self, node: cst.Assign) -> None:
self.__check_assign_imported(node)

# TODO(jorenham): enforce len(node.targets) == 1

if (
len(node.targets) == 1
and isinstance(target := node.targets[0].target, cst.Name)
and isinstance(node.value, cst.Call)
# typevar-likes can only "contain" annotations if they have >1 [kw]args
and len(node.value.args) > 1
and isinstance(node.value.func, cst.Name | cst.Attribute)
and isinstance(strname := node.value.args[0].value, cst.SimpleString)
and strname.raw_value == target.value
):
# this is (probably) a legacy typevar-like (or a manual `TypeAliasType`)
# TODO(jorenham): either warn user & register, or just disallow this

fname = "" # set later if needed
for arg in node.value.args[1:]:
# annotations can only occur in the positional args and the
# `bound` or `default` kwargs
if arg.keyword is None or arg.keyword.value in {"bound", "default"}:
fname = fname or uncst.get_name_strict(node.value.func)
_check_annotation_expr(arg.value, f"{target.value} = {fname}(...)")

@override
def visit_AnnAssign(self, node: cst.AnnAssign) -> None:
self.__check_assign_imported(node)

if (
node.value
and isinstance(node.target, cst.Name)
and isinstance(ann := node.annotation.annotation, cst.Name | cst.Attribute)
and (type_alias_name := self.imported_from_typing_as("TypeAlias"))
and uncst.get_name_strict(ann) == type_alias_name
):
# this is a legacy `typing[_extensions].TypeAlias`
# TODO(jorenham): either warn user & register, or just disallow this
_check_annotation_expr(node.value, f"{node.target.value}")

@override
def visit_AssignTarget(self, node: cst.AssignTarget) -> None:
assert not self._stack_attr
Expand All @@ -470,6 +513,8 @@ def visit_TypeAlias(self, /, node: cst.TypeAlias) -> None:
name = node.name.value
assert name not in self.type_aliases

_check_annotation_expr(node.value, f"type {name}")

if tpars := node.type_parameters:
self.type_aliases[name] = self._register_type_params(name, tpars)
else:
Expand All @@ -482,6 +527,12 @@ def visit_TypeAlias(self, /, node: cst.TypeAlias) -> None:
def visit_FunctionDef(self, /, node: cst.FunctionDef) -> None:
self._stack_scope.append(node.name.value)

assert isinstance(node.body, cst.SimpleStatementSuite | cst.IndentedBlock)
if len(node.body.body) != 1 or not isinstance(node.body.body[0], cst.Ellipsis):
error = StubError("function body must contain only `...`")
qualname = ".".join(self._stack_scope)
error.add_note(qualname)

if tpars := node.type_parameters:
self._register_type_params(self._stack_scope[0], tpars)

Expand All @@ -492,14 +543,6 @@ def leave_FunctionDef(self, /, original_node: cst.FunctionDef) -> None:

_ = self._stack_scope.pop()

@functools.cached_property
def _illegal_bases(self, /) -> frozenset[str]:
return frozenset({
base_name
for module, name in _ILLEGAL_BASES
if (base_name := self.imported_as(module, name))
})

@override
def visit_ClassDef(self, /, node: cst.ClassDef) -> None:
name = node.name.value
Expand Down Expand Up @@ -537,15 +580,6 @@ def visit_ClassDef(self, /, node: cst.ClassDef) -> None:
# TODO: figure out the FQN if not a global name (i.e. locally scoped)
bases.append(basename)

base_set = set(bases)
if base_set and (illegal := base_set & self._illegal_bases):
if len(illegal) == 1:
raise NotImplementedError(f"{illegal.pop()!r} as base class")
raise ExceptionGroup(
"unsupported base classes",
[NotImplementedError(f"{base!r} as base class") for base in illegal],
)

if tpars := node.type_parameters:
self._register_type_params(stack[0], tpars, infer_variance=True)

Expand All @@ -555,3 +589,38 @@ def leave_ClassDef(self, /, original_node: cst.ClassDef) -> None:
# https://github.com/jorenham/unpy/issues/46

_ = self._stack_scope.pop()

@override
def on_visit(self, /, node: cst.CSTNode) -> bool:
if isinstance(node, cst.BaseSmallStatement):
if isinstance(
node,
cst.Del
| cst.Pass
| cst.Break
| cst.Continue
| cst.Return
| cst.Raise
| cst.Assert
| cst.Global
| cst.Nonlocal,
):
keyword = type(node).__name__.lower()
raise StubError(f"{keyword!r} statements are useless in stubs")
elif isinstance(node, cst.BaseCompoundStatement):
if isinstance(
node,
cst.Try | cst.TryStar | cst.With | cst.For | cst.While | cst.Match,
):
keyword = type(node).__name__.lower()
raise StubError(f"{keyword!r} statements are useless in stubs")
elif isinstance(node, cst.BaseExpression):
if isinstance(node, cst.BooleanOperation):
raise StubError("boolean operations are useless in stubs")
if isinstance(node, cst.FormattedString):
raise StubError("f-strings are useless in stubs")
if isinstance(node, cst.Lambda | cst.Await | cst.Yield):
keyword = type(node).__name__.lower()
raise StubError(f"{keyword!r} is an invalid expression")

return super().on_visit(node)