Skip to content

Commit

Permalink
Fixed @typechecked failing to instrument functions with duplicate n…
Browse files Browse the repository at this point in the history
…ames in the same module

Fixes #355.
  • Loading branch information
agronholm committed Jul 27, 2023
1 parent f377be3 commit 99532c8
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 30 deletions.
1 change: 1 addition & 0 deletions docs/versionhistory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ This library adheres to `Semantic Versioning 2.0 <https://semver.org/#semantic-v
a method (`#362 <https://github.com/agronholm/typeguard/issues/362>`_)
- Fixed docstrings disappearing from instrumented functions
(`#359 <https://github.com/agronholm/typeguard/issues/359>`_)
- Fixed

**4.0.0** (2023-05-12)

Expand Down
70 changes: 41 additions & 29 deletions src/typeguard/_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import ast
import inspect
import sys
from collections.abc import Sequence
from functools import partial
from inspect import isclass, isfunction
from types import CodeType, FrameType, FunctionType
Expand Down Expand Up @@ -34,6 +35,25 @@ def make_cell(value: object) -> _Cell:
return (lambda: value).__closure__[0] # type: ignore[index]


def find_target_function(
new_code: CodeType, target_path: Sequence[str], firstlineno: int
) -> CodeType | None:
target_name = target_path[0]
for const in new_code.co_consts:
if isinstance(const, CodeType):
if const.co_name == target_name:
if const.co_firstlineno == firstlineno:
return const
elif len(target_path) > 1:
target_code = find_target_function(
const, target_path[1:], firstlineno
)
if target_code:
return target_code

return None


def instrument(f: T_CallableOrType) -> FunctionType | str:
if not getattr(f, "__code__", None):
return "no code associated"
Expand All @@ -50,39 +70,31 @@ def instrument(f: T_CallableOrType) -> FunctionType | str:
target_path = [item for item in f.__qualname__.split(".") if item != "<locals>"]
module_source = inspect.getsource(sys.modules[f.__module__])
module_ast = ast.parse(module_source)
instrumentor = TypeguardTransformer(target_path)
instrumentor = TypeguardTransformer(target_path, f.__code__.co_firstlineno)
instrumentor.visit(module_ast)

if global_config.debug_instrumentation and sys.version_info >= (3, 9):
# Find the matching AST node, then unparse it to source and print to stdout
level = 0
for node in ast.walk(module_ast):
if isinstance(node, (ast.ClassDef, ast.FunctionDef)):
if node.name == target_path[level]:
if level == len(target_path) - 1:
print(
f"Source code of {f.__qualname__}() after instrumentation:"
"\n----------------------------------------------",
file=sys.stderr,
)
print(ast.unparse(node), file=sys.stderr)
print(
"----------------------------------------------",
file=sys.stderr,
)
else:
level += 1
if not instrumentor.target_node or instrumentor.target_lineno is None:
return "instrumentor did not find the target function"

module_code = compile(module_ast, f.__code__.co_filename, "exec", dont_inherit=True)
new_code = module_code
for name in target_path:
for const in new_code.co_consts:
if isinstance(const, CodeType):
if const.co_name == name:
new_code = const
break
else:
return "cannot find the target function in the AST"
new_code = find_target_function(
module_code, target_path, instrumentor.target_lineno
)
if not new_code:
return "cannot find the target function in the AST"

if global_config.debug_instrumentation and sys.version_info >= (3, 9):
# Find the matching AST node, then unparse it to source and print to stdout
print(
f"Source code of {f.__qualname__}() after instrumentation:"
"\n----------------------------------------------",
file=sys.stderr,
)
print(ast.unparse(instrumentor.target_node), file=sys.stderr)
print(
"----------------------------------------------",
file=sys.stderr,
)

closure = f.__closure__
if new_code.co_freevars != f.__code__.co_freevars:
Expand Down
20 changes: 19 additions & 1 deletion src/typeguard/_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,10 +488,14 @@ def visit_Str(self, node: Str) -> Any:


class TypeguardTransformer(NodeTransformer):
def __init__(self, target_path: Sequence[str] | None = None) -> None:
def __init__(
self, target_path: Sequence[str] | None = None, target_lineno: int | None = None
) -> None:
self._target_path = tuple(target_path) if target_path else None
self._memo = self._module_memo = TransformMemo(None, None, ())
self.names_used_in_annotations: set[str] = set()
self.target_node: FunctionDef | AsyncFunctionDef | None = None
self.target_lineno = target_lineno

@contextmanager
def _use_memo(
Expand Down Expand Up @@ -664,6 +668,12 @@ def visit_FunctionDef(
with self._use_memo(node):
arg_annotations: dict[str, Any] = {}
if self._target_path is None or self._memo.path == self._target_path:
# Find line number we're supposed to match against
if node.decorator_list:
first_lineno = node.decorator_list[0].lineno
else:
first_lineno = node.lineno

for decorator in node.decorator_list.copy():
if self._memo.name_matches(decorator, "typing.overload"):
# Remove overloads entirely
Expand All @@ -678,6 +688,14 @@ def visit_FunctionDef(
kw.arg: kw.value for kw in decorator.keywords if kw.arg
}

if self.target_lineno == first_lineno:
assert self.target_node is None
self.target_node = node
if node.decorator_list and sys.version_info >= (3, 8):
self.target_lineno = node.decorator_list[0].lineno
else:
self.target_lineno = node.lineno

all_args = node.args.args + node.args.kwonlyargs
if sys.version_info >= (3, 8):
all_args.extend(node.args.posonlyargs)
Expand Down
37 changes: 37 additions & 0 deletions tests/test_typechecked.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,3 +633,40 @@ def foo(self) -> Dict[str, Any]:
return {}

A().foo()


def test_getter_setter():
"""Regression test for #355."""

@typechecked
class Foo:
def __init__(self, x: int):
self._x = x

@property
def x(self) -> int:
return self._x

@x.setter
def x(self, value: int) -> None:
self._x = value

f = Foo(1)
f.x = 2
assert f.x == 2
with pytest.raises(TypeCheckError):
f.x = "foo"


def test_duplicate_method():
class Foo:
def x(self) -> str:
return "first"

@typechecked()
def x(self, value: int) -> str: # noqa: F811
return "second"

assert Foo().x(1) == "second"
with pytest.raises(TypeCheckError):
Foo().x("wrong")

0 comments on commit 99532c8

Please sign in to comment.