Skip to content

Commit

Permalink
Merge pull request #78 from jorenham/disallow-stringified-annotations
Browse files Browse the repository at this point in the history
disallow quoted annotations, `__future__` imports, and runtime-only nodes
  • Loading branch information
jorenham authored Oct 2, 2024
2 parents 8f2fa92 + 6b3852d commit f8426ec
Show file tree
Hide file tree
Showing 4 changed files with 145 additions and 42 deletions.
1 change: 0 additions & 1 deletion tests/test_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ def test_comments(source: str) -> None:
"__version__: str = '3.14'\n",
"def concat(*args: str) -> str: ...\n",
"class C:\n def f(self, /) -> None: ...\n",
"raise NotImplementedError\n",
],
)
def test_already_compatible(source: str) -> None:
Expand Down
36 changes: 33 additions & 3 deletions tests/test_visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import libcst as cst
import pytest
from unpy.exceptions import StubError
from unpy.visitors import StubVisitor


Expand All @@ -11,6 +12,37 @@ def _visit(*lines: str) -> StubVisitor:
return visitor


# stub errors


def test_illegal_future_import():
# https://github.com/jorenham/unpy/issues/43
with pytest.raises(StubError):
_visit("from __future__ import annotations")


@pytest.mark.parametrize(
"source",
[
"Const: 'str' = ...",
"type Alias = 'str'",
"from typing import TypeAlias\nAlias: TypeAlias = 'str'",
"from typing import TypeAliasType\nAlias = TypeAliasType('Alias', 'str')",
"from typing import TypeVar\nT = TypeVar('T', bound='str')",
"from typing import TypeVar\nT = TypeVar('T', default='str')",
"def f(x: 'str') -> str: ...",
"def f(x: str) -> 'str': ...",
"def f[T: 'str'](x: T) -> T: ...",
"def f[T: str = 'str'](x: T) -> T: ...",
"class C[T: 'str']: ...",
"class C[T: str = 'str']: ...",
],
)
def test_stringified_annotations(source: str):
with pytest.raises(StubError):
_visit(source)


# imports


Expand Down Expand Up @@ -271,9 +303,7 @@ def test_import_access_package_attr_attr() -> None:
}


# baseclass
# TODO: more tests

# baseclasses

def test_baseclasses_single() -> None:
visitor = _visit(
Expand Down
5 changes: 5 additions & 0 deletions unpy/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
__all__ = ("StubError",)


class StubError(TypeError):
pass
145 changes: 107 additions & 38 deletions unpy/visitors.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,29 @@
import collections
import functools
from typing import Final, override

import libcst as cst
from libcst.metadata import Scope, ScopeProvider
import libcst.metadata as cst_meta

import unpy._cst as uncst
from unpy.exceptions import StubError

__all__ = ("StubVisitor",)

_MODULE_BUILTINS: Final = "builtins"
_MODULE_PATHLIB: Final = "pathlib"
_MODULE_TP: Final = "typing"
_MODULE_TPX: Final = "typing_extensions"

_ILLEGAL_BASES: Final = (
(_MODULE_BUILTINS, "BaseExceptionGroup"),
(_MODULE_BUILTINS, "ExceptionGroup"),
(_MODULE_BUILTINS, "_IncompleteInputError"),
(_MODULE_BUILTINS, "PythonFinalizationError"),
(_MODULE_BUILTINS, "EncodingWarning"),
(_MODULE_PATHLIB, "Path"),
(_MODULE_TP, "Any"),
(_MODULE_TPX, "Any"),
)

def _check_annotation_expr(
node: cst.BaseExpression,
/,
name: str | None = None,
) -> None:
if isinstance(node, cst.BaseString):
error = StubError("quoted annotations should not be included in stubs")
if name:
error.add_note(f"in {name!r}")
raise error


class StubVisitor(cst.CSTVisitor): # noqa: PLR0904
Expand All @@ -32,9 +32,9 @@ class StubVisitor(cst.CSTVisitor): # noqa: PLR0904
classes, and type-aliases.
"""

METADATA_DEPENDENCIES = (ScopeProvider,)
METADATA_DEPENDENCIES = (cst_meta.ScopeProvider,)

_global_scope: Scope
_global_scope: cst_meta.Scope

_stack_scope: Final[collections.deque[str]]
_stack_attr: Final[collections.deque[cst.Attribute]]
Expand Down Expand Up @@ -199,7 +199,12 @@ def _register_import(
module: str | None = None,
alias: str | None = None,
) -> str:
# this `str.removesuffix` avoids a double `.` in case of `from . import name`
fqn = f"{module.removesuffix(".")}.{name}" if module else name

if fqn.startswith("__future__"):
raise StubError("__future__ imports are useless in stubs")

alias = alias or name

if self.imports.setdefault(fqn, alias) != alias:
Expand Down Expand Up @@ -249,8 +254,8 @@ def _build_type_param( # noqa: C901
name = param.name.value
default = tpar.default

if isinstance(default, cst.BaseString):
raise NotImplementedError("stringified type parameter defaults")
if default:
_check_annotation_expr(default)

name_any = self.imported_from_typing_as("Any")
name_object = self.imported_as(_MODULE_BUILTINS, "object")
Expand Down Expand Up @@ -279,8 +284,6 @@ def _build_type_param( # noqa: C901
constraints: tuple[cst.BaseExpression, ...] = ()
if not (bound := param.bound):
pass
elif isinstance(bound, cst.BaseString):
raise NotImplementedError("stringified type parameter bounds")
elif (
name_any
and isinstance(bound, cst.Name | cst.Attribute)
Expand Down Expand Up @@ -308,6 +311,8 @@ def _build_type_param( # noqa: C901

constraints = tuple(cons)
bound = None
else:
_check_annotation_expr(bound)

if _default_any and bound is not None:
# if `default=Any`, replace it the value of `bound` (`Any` is horrible)
Expand Down Expand Up @@ -370,8 +375,8 @@ def __after_import(self, /) -> None:

@override
def visit_Module(self, /, node: cst.Module) -> None:
scope = self.get_metadata(ScopeProvider, node)
assert isinstance(scope, Scope)
scope = self.get_metadata(cst_meta.ScopeProvider, node)
assert isinstance(scope, cst_meta.Scope)
self._global_scope = scope

@override
Expand Down Expand Up @@ -431,6 +436,10 @@ def leave_Attribute(self, /, original_node: cst.Attribute) -> None:
node = self._stack_attr.pop()
assert node is original_node

@override
def visit_Annotation(self, /, node: cst.Annotation) -> None:
_check_annotation_expr(node.annotation)

def __check_assign_imported(self, node: cst.Assign | cst.AnnAssign, /) -> None:
if not isinstance(node.value, cst.Name | cst.Attribute):
return
Expand All @@ -446,10 +455,44 @@ def __check_assign_imported(self, node: cst.Assign | cst.AnnAssign, /) -> None:
def visit_Assign(self, node: cst.Assign) -> None:
self.__check_assign_imported(node)

# TODO(jorenham): enforce len(node.targets) == 1

if (
len(node.targets) == 1
and isinstance(target := node.targets[0].target, cst.Name)
and isinstance(node.value, cst.Call)
# typevar-likes can only "contain" annotations if they have >1 [kw]args
and len(node.value.args) > 1
and isinstance(node.value.func, cst.Name | cst.Attribute)
and isinstance(strname := node.value.args[0].value, cst.SimpleString)
and strname.raw_value == target.value
):
# this is (probably) a legacy typevar-like (or a manual `TypeAliasType`)
# TODO(jorenham): either warn user & register, or just disallow this

fname = "" # set later if needed
for arg in node.value.args[1:]:
# annotations can only occur in the positional args and the
# `bound` or `default` kwargs
if arg.keyword is None or arg.keyword.value in {"bound", "default"}:
fname = fname or uncst.get_name_strict(node.value.func)
_check_annotation_expr(arg.value, f"{target.value} = {fname}(...)")

@override
def visit_AnnAssign(self, node: cst.AnnAssign) -> None:
self.__check_assign_imported(node)

if (
node.value
and isinstance(node.target, cst.Name)
and isinstance(ann := node.annotation.annotation, cst.Name | cst.Attribute)
and (type_alias_name := self.imported_from_typing_as("TypeAlias"))
and uncst.get_name_strict(ann) == type_alias_name
):
# this is a legacy `typing[_extensions].TypeAlias`
# TODO(jorenham): either warn user & register, or just disallow this
_check_annotation_expr(node.value, f"{node.target.value}")

@override
def visit_AssignTarget(self, node: cst.AssignTarget) -> None:
assert not self._stack_attr
Expand All @@ -470,6 +513,8 @@ def visit_TypeAlias(self, /, node: cst.TypeAlias) -> None:
name = node.name.value
assert name not in self.type_aliases

_check_annotation_expr(node.value, f"type {name}")

if tpars := node.type_parameters:
self.type_aliases[name] = self._register_type_params(name, tpars)
else:
Expand All @@ -482,6 +527,12 @@ def visit_TypeAlias(self, /, node: cst.TypeAlias) -> None:
def visit_FunctionDef(self, /, node: cst.FunctionDef) -> None:
self._stack_scope.append(node.name.value)

assert isinstance(node.body, cst.SimpleStatementSuite | cst.IndentedBlock)
if len(node.body.body) != 1 or not isinstance(node.body.body[0], cst.Ellipsis):
error = StubError("function body must contain only `...`")
qualname = ".".join(self._stack_scope)
error.add_note(qualname)

if tpars := node.type_parameters:
self._register_type_params(self._stack_scope[0], tpars)

Expand All @@ -492,14 +543,6 @@ def leave_FunctionDef(self, /, original_node: cst.FunctionDef) -> None:

_ = self._stack_scope.pop()

@functools.cached_property
def _illegal_bases(self, /) -> frozenset[str]:
return frozenset({
base_name
for module, name in _ILLEGAL_BASES
if (base_name := self.imported_as(module, name))
})

@override
def visit_ClassDef(self, /, node: cst.ClassDef) -> None:
name = node.name.value
Expand Down Expand Up @@ -537,15 +580,6 @@ def visit_ClassDef(self, /, node: cst.ClassDef) -> None:
# TODO: figure out the FQN if not a global name (i.e. locally scoped)
bases.append(basename)

base_set = set(bases)
if base_set and (illegal := base_set & self._illegal_bases):
if len(illegal) == 1:
raise NotImplementedError(f"{illegal.pop()!r} as base class")
raise ExceptionGroup(
"unsupported base classes",
[NotImplementedError(f"{base!r} as base class") for base in illegal],
)

if tpars := node.type_parameters:
self._register_type_params(stack[0], tpars, infer_variance=True)

Expand All @@ -555,3 +589,38 @@ def leave_ClassDef(self, /, original_node: cst.ClassDef) -> None:
# https://github.com/jorenham/unpy/issues/46

_ = self._stack_scope.pop()

@override
def on_visit(self, /, node: cst.CSTNode) -> bool:
if isinstance(node, cst.BaseSmallStatement):
if isinstance(
node,
cst.Del
| cst.Pass
| cst.Break
| cst.Continue
| cst.Return
| cst.Raise
| cst.Assert
| cst.Global
| cst.Nonlocal,
):
keyword = type(node).__name__.lower()
raise StubError(f"{keyword!r} statements are useless in stubs")
elif isinstance(node, cst.BaseCompoundStatement):
if isinstance(
node,
cst.Try | cst.TryStar | cst.With | cst.For | cst.While | cst.Match,
):
keyword = type(node).__name__.lower()
raise StubError(f"{keyword!r} statements are useless in stubs")
elif isinstance(node, cst.BaseExpression):
if isinstance(node, cst.BooleanOperation):
raise StubError("boolean operations are useless in stubs")
if isinstance(node, cst.FormattedString):
raise StubError("f-strings are useless in stubs")
if isinstance(node, cst.Lambda | cst.Await | cst.Yield):
keyword = type(node).__name__.lower()
raise StubError(f"{keyword!r} is an invalid expression")

return super().on_visit(node)

0 comments on commit f8426ec

Please sign in to comment.