From 53525e11a3bb1af74f3bd8b3bb62c41408828fd2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Thu, 27 Jul 2023 15:42:56 +0300 Subject: [PATCH] Fixed `@typechecked` failing to instrument functions with duplicate names in the same module Fixes #355. --- docs/versionhistory.rst | 3 ++ src/typeguard/_decorators.py | 70 ++++++++++++++++++++--------------- src/typeguard/_transformer.py | 21 ++++++++++- tests/test_typechecked.py | 37 ++++++++++++++++++ 4 files changed, 100 insertions(+), 31 deletions(-) diff --git a/docs/versionhistory.rst b/docs/versionhistory.rst index 38e07da..33f2762 100644 --- a/docs/versionhistory.rst +++ b/docs/versionhistory.rst @@ -12,6 +12,9 @@ This library adheres to `Semantic Versioning 2.0 `_) - Fixed docstrings disappearing from instrumented functions (`#359 `_) +- Fixed ``@typechecked`` failing to instrument functions when there are more than one + function within the same scope + (`#355 `_) **4.0.0** (2023-05-12) diff --git a/src/typeguard/_decorators.py b/src/typeguard/_decorators.py index 31932d7..53f254f 100644 --- a/src/typeguard/_decorators.py +++ b/src/typeguard/_decorators.py @@ -3,6 +3,7 @@ import ast import inspect import sys +from collections.abc import Sequence from functools import partial from inspect import isclass, isfunction from types import CodeType, FrameType, FunctionType @@ -34,6 +35,25 @@ def make_cell(value: object) -> _Cell: return (lambda: value).__closure__[0] # type: ignore[index] +def find_target_function( + new_code: CodeType, target_path: Sequence[str], firstlineno: int +) -> CodeType | None: + target_name = target_path[0] + for const in new_code.co_consts: + if isinstance(const, CodeType): + if const.co_name == target_name: + if const.co_firstlineno == firstlineno: + return const + elif len(target_path) > 1: + target_code = find_target_function( + const, target_path[1:], firstlineno + ) + if target_code: + return target_code + + return None + + def instrument(f: T_CallableOrType) -> FunctionType | str: if not getattr(f, "__code__", None): return "no code associated" @@ -50,39 +70,31 @@ def instrument(f: T_CallableOrType) -> FunctionType | str: target_path = [item for item in f.__qualname__.split(".") if item != ""] module_source = inspect.getsource(sys.modules[f.__module__]) module_ast = ast.parse(module_source) - instrumentor = TypeguardTransformer(target_path) + instrumentor = TypeguardTransformer(target_path, f.__code__.co_firstlineno) instrumentor.visit(module_ast) - if global_config.debug_instrumentation and sys.version_info >= (3, 9): - # Find the matching AST node, then unparse it to source and print to stdout - level = 0 - for node in ast.walk(module_ast): - if isinstance(node, (ast.ClassDef, ast.FunctionDef)): - if node.name == target_path[level]: - if level == len(target_path) - 1: - print( - f"Source code of {f.__qualname__}() after instrumentation:" - "\n----------------------------------------------", - file=sys.stderr, - ) - print(ast.unparse(node), file=sys.stderr) - print( - "----------------------------------------------", - file=sys.stderr, - ) - else: - level += 1 + if not instrumentor.target_node or instrumentor.target_lineno is None: + return "instrumentor did not find the target function" module_code = compile(module_ast, f.__code__.co_filename, "exec", dont_inherit=True) - new_code = module_code - for name in target_path: - for const in new_code.co_consts: - if isinstance(const, CodeType): - if const.co_name == name: - new_code = const - break - else: - return "cannot find the target function in the AST" + new_code = find_target_function( + module_code, target_path, instrumentor.target_lineno + ) + if not new_code: + return "cannot find the target function in the AST" + + if global_config.debug_instrumentation and sys.version_info >= (3, 9): + # Find the matching AST node, then unparse it to source and print to stdout + print( + f"Source code of {f.__qualname__}() after instrumentation:" + "\n----------------------------------------------", + file=sys.stderr, + ) + print(ast.unparse(instrumentor.target_node), file=sys.stderr) + print( + "----------------------------------------------", + file=sys.stderr, + ) closure = f.__closure__ if new_code.co_freevars != f.__code__.co_freevars: diff --git a/src/typeguard/_transformer.py b/src/typeguard/_transformer.py index 2df160d..32d284e 100644 --- a/src/typeguard/_transformer.py +++ b/src/typeguard/_transformer.py @@ -488,10 +488,14 @@ def visit_Str(self, node: Str) -> Any: class TypeguardTransformer(NodeTransformer): - def __init__(self, target_path: Sequence[str] | None = None) -> None: + def __init__( + self, target_path: Sequence[str] | None = None, target_lineno: int | None = None + ) -> None: self._target_path = tuple(target_path) if target_path else None self._memo = self._module_memo = TransformMemo(None, None, ()) self.names_used_in_annotations: set[str] = set() + self.target_node: FunctionDef | AsyncFunctionDef | None = None + self.target_lineno = target_lineno @contextmanager def _use_memo( @@ -664,6 +668,12 @@ def visit_FunctionDef( with self._use_memo(node): arg_annotations: dict[str, Any] = {} if self._target_path is None or self._memo.path == self._target_path: + # Find line number we're supposed to match against + if node.decorator_list: + first_lineno = node.decorator_list[0].lineno + else: + first_lineno = node.lineno + for decorator in node.decorator_list.copy(): if self._memo.name_matches(decorator, "typing.overload"): # Remove overloads entirely @@ -678,6 +688,14 @@ def visit_FunctionDef( kw.arg: kw.value for kw in decorator.keywords if kw.arg } + if self.target_lineno == first_lineno: + assert self.target_node is None + self.target_node = node + if node.decorator_list and sys.version_info >= (3, 8): + self.target_lineno = node.decorator_list[0].lineno + else: + self.target_lineno = node.lineno + all_args = node.args.args + node.args.kwonlyargs if sys.version_info >= (3, 8): all_args.extend(node.args.posonlyargs) @@ -924,7 +942,6 @@ def visit_Yield(self, node: Yield) -> Yield | Call: self._memo.has_yield_expressions = True self.generic_visit(node) - self.generic_visit(node) if ( self._memo.yield_annotation and self._memo.should_instrument diff --git a/tests/test_typechecked.py b/tests/test_typechecked.py index 258d2e5..dbb516f 100644 --- a/tests/test_typechecked.py +++ b/tests/test_typechecked.py @@ -633,3 +633,40 @@ def foo(self) -> Dict[str, Any]: return {} A().foo() + + +def test_getter_setter(): + """Regression test for #355.""" + + @typechecked + class Foo: + def __init__(self, x: int): + self._x = x + + @property + def x(self) -> int: + return self._x + + @x.setter + def x(self, value: int) -> None: + self._x = value + + f = Foo(1) + f.x = 2 + assert f.x == 2 + with pytest.raises(TypeCheckError): + f.x = "foo" + + +def test_duplicate_method(): + class Foo: + def x(self) -> str: + return "first" + + @typechecked() + def x(self, value: int) -> str: # noqa: F811 + return "second" + + assert Foo().x(1) == "second" + with pytest.raises(TypeCheckError): + Foo().x("wrong")