From 5f50683d7c383b8c640f3339f939af1286a8bfea Mon Sep 17 00:00:00 2001 From: jorenham Date: Wed, 2 Oct 2024 05:26:35 +0200 Subject: [PATCH] disallow module-level `__dir__` and `__getattr__` functions --- tests/test_visitors.py | 14 +++++++++++++- unpy/visitors.py | 10 +++++++--- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/tests/test_visitors.py b/tests/test_visitors.py index 0c14adb..5069ded 100644 --- a/tests/test_visitors.py +++ b/tests/test_visitors.py @@ -38,7 +38,19 @@ def test_illegal_future_import(): "class C[T: str = 'str']: ...", ], ) -def test_stringified_annotations(source: str): +def test_illegal_stringified_annotations(source: str): + with pytest.raises(StubError): + _visit(source) + + +@pytest.mark.parametrize( + "source", + [ + "def __dir__() -> list[str]: ...", + "def __getattr__(name: str, /) -> object: ...", + ], +) +def test_illegal_special_functions_at_module_lvl(source: str): with pytest.raises(StubError): _visit(source) diff --git a/unpy/visitors.py b/unpy/visitors.py index f109108..72822e6 100644 --- a/unpy/visitors.py +++ b/unpy/visitors.py @@ -525,16 +525,20 @@ def visit_TypeAlias(self, /, node: cst.TypeAlias) -> None: @override def visit_FunctionDef(self, /, node: cst.FunctionDef) -> None: - self._stack_scope.append(node.name.value) + stack = self._stack_scope + stack.append(name := node.name.value) + + if len(stack) == 1 and name in {"__getattr__", "__dir__"}: + raise StubError(f"module-level {name}() cannot be used in a stub") assert isinstance(node.body, cst.SimpleStatementSuite | cst.IndentedBlock) if len(node.body.body) != 1 or not isinstance(node.body.body[0], cst.Ellipsis): error = StubError("function body must contain only `...`") - qualname = ".".join(self._stack_scope) + qualname = ".".join(stack) error.add_note(qualname) if tpars := node.type_parameters: - self._register_type_params(self._stack_scope[0], tpars) + self._register_type_params(stack[0], tpars) @override def leave_FunctionDef(self, /, original_node: cst.FunctionDef) -> None: