diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index df3ec80..6cbe7dd 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -10,7 +10,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.8", "3.9", "3.10", "3.11", "3.12", "3.13", pypy-3.10] + python-version: ["3.9", "3.10", "3.11", "3.12", "3.13", pypy-3.10] runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f042567..02f80cf 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.6.0 + rev: v5.0.0 hooks: - id: check-case-conflict - id: check-merge-conflict @@ -14,14 +14,14 @@ repos: - id: trailing-whitespace - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.4.5 + rev: v0.7.2 hooks: - id: ruff args: [--fix, --show-fixes] - id: ruff-format - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.10.0 + rev: v1.13.0 hooks: - id: mypy additional_dependencies: [ "typing_extensions" ] diff --git a/docs/features.rst b/docs/features.rst index 6537737..3141456 100644 --- a/docs/features.rst +++ b/docs/features.rst @@ -62,15 +62,11 @@ Protocol checking +++++++++++++++++ As of version 4.3.0, Typeguard can check instances and classes against Protocols, -regardless of whether they were annotated with :decorator:`typing.runtime_checkable`. +regardless of whether they were annotated with +:func:`@runtime_checkable `. -There are several limitations on the checks performed, however: - -* For non-callable members, only presence is checked for; no type compatibility checks - are performed -* For methods, only the number of positional arguments are checked against, so any added - keyword-only arguments without defaults don't currently trip the checker -* Likewise, argument types are not checked for compatibility +The only current limitation is that argument annotations are not checked for +compatibility, however this should be covered by static type checkers pretty well. Special considerations for ``if TYPE_CHECKING:`` ------------------------------------------------ diff --git a/docs/versionhistory.rst b/docs/versionhistory.rst index afedbaa..0bba890 100644 --- a/docs/versionhistory.rst +++ b/docs/versionhistory.rst @@ -4,6 +4,34 @@ Version history This library adheres to `Semantic Versioning 2.0 `_. +**UNRELEASED** + +- Dropped Python 3.8 support +- Changed the signature of ``typeguard_ignore()`` to be compatible with + ``typing.no_type_check()`` (PR by @jolaf) +- Avoid creating reference cycles when type checking uniontypes and classes +- Fixed checking of variable assignments involving tuple unpacking + (`#486 `_) +- Fixed ``TypeError`` when checking a class against ``type[Self]`` + (`#481 `_) +- Fixed checking of protocols on the class level (against ``type[SomeProtocol]``) + (`#498 `_) +- Fixed ``Self`` checks in instance/class methods that have positional-only arguments +- Fixed explicit checks of PEP 604 unions against ``types.UnionType`` + (`#467 `_) +- Fixed checks against annotations wrapped in ``NotRequired`` not being run unless the + ``NotRequired`` is a forward reference + (`#454 `_) + +**4.4.0** (2024-10-27) + +- Added proper checking for method signatures in protocol checks + (`#465 `_) +- Fixed basic support for intersection protocols + (`#490 `_; PR by @antonagestam) +- Fixed protocol checks running against the class of an instance and not the instance + itself (this produced wrong results for non-method member checks) + **4.3.0** (2024-05-27) - Added support for checking against static protocols diff --git a/pyproject.toml b/pyproject.toml index cf93b98..7c89494 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,13 +17,13 @@ classifiers = [ "License :: OSI Approved :: MIT License", "Programming Language :: Python", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", ] -requires-python = ">= 3.8" +requires-python = ">= 3.9" dependencies = [ "importlib_metadata >= 3.6; python_version < '3.10'", "typing_extensions >= 4.10.0", @@ -80,36 +80,34 @@ src = ["src"] [tool.ruff.lint] extend-select = [ - "W", # pycodestyle warnings + "B0", # flake8-bugbear "I", # isort "PGH", # pygrep-hooks "UP", # pyupgrade - "B0", # flake8-bugbear + "W", # pycodestyle warnings ] ignore = [ "S307", "B008", + "UP006", + "UP035", ] [tool.mypy] -python_version = "3.9" +python_version = "3.11" strict = true pretty = true [tool.tox] -legacy_tox_ini = """ -[tox] -envlist = pypy3, py38, py39, py310, py311, py312, py313 +env_list = ["py39", "py310", "py311", "py312", "py313"] skip_missing_interpreters = true -minversion = 4.0 -[testenv] -extras = test -commands = coverage run -m pytest {posargs} -package = editable +[tool.tox.env_run_base] +commands = [["coverage", "run", "-m", "pytest", { replace = "posargs", extend = true }]] +package = "editable" +extras = ["test"] -[testenv:docs] -extras = doc -package = editable -commands = sphinx-build -W -n docs build/sphinx -""" +[tool.tox.env.docs] +depends = [] +extras = ["doc"] +commands = [["sphinx-build", "-W", "-n", "docs", "build/sphinx"]] diff --git a/src/typeguard/_checkers.py b/src/typeguard/_checkers.py index 67dd5ad..5e34036 100644 --- a/src/typeguard/_checkers.py +++ b/src/typeguard/_checkers.py @@ -6,24 +6,24 @@ import types import typing import warnings +from collections.abc import Mapping, MutableMapping, Sequence from enum import Enum from inspect import Parameter, isclass, isfunction from io import BufferedIOBase, IOBase, RawIOBase, TextIOBase +from itertools import zip_longest from textwrap import indent from typing import ( IO, AbstractSet, + Annotated, Any, BinaryIO, Callable, Dict, ForwardRef, List, - Mapping, - MutableMapping, NewType, Optional, - Sequence, Set, TextIO, Tuple, @@ -32,12 +32,8 @@ Union, ) from unittest.mock import Mock -from weakref import WeakKeyDictionary -try: - import typing_extensions -except ImportError: - typing_extensions = None # type: ignore[assignment] +import typing_extensions # Must use this because typing.is_typeddict does not recognize # TypedDict from typing_extensions, and as of version 4.12.0 @@ -52,7 +48,6 @@ if sys.version_info >= (3, 11): from typing import ( - Annotated, NotRequired, TypeAlias, get_args, @@ -61,14 +56,13 @@ SubclassableAny = Any else: + from typing_extensions import Any as SubclassableAny from typing_extensions import ( - Annotated, NotRequired, TypeAlias, get_args, get_origin, ) - from typing_extensions import Any as SubclassableAny if sys.version_info >= (3, 10): from importlib.metadata import entry_points @@ -85,13 +79,11 @@ ] checker_lookup_functions: list[TypeCheckLookupCallback] = [] -generic_alias_types: tuple[type, ...] = (type(List), type(List[Any])) -if sys.version_info >= (3, 9): - generic_alias_types += (types.GenericAlias,) - -protocol_check_cache: WeakKeyDictionary[ - type[Any], dict[type[Any], TypeCheckError | None] -] = WeakKeyDictionary() +generic_alias_types: tuple[type, ...] = ( + type(List), + type(List[Any]), + types.GenericAlias, +) # Sentinel _missing = object() @@ -271,9 +263,10 @@ def check_typed_dict( for key, annotation in origin_type.__annotations__.items(): if isinstance(annotation, ForwardRef): annotation = evaluate_forwardref(annotation, memo) - if get_origin(annotation) is NotRequired: - required_keys.discard(key) - annotation = get_args(annotation)[0] + + if get_origin(annotation) is NotRequired: + required_keys.discard(key) + annotation = get_args(annotation)[0] type_hints[key] = annotation @@ -430,6 +423,7 @@ def check_union( ) finally: del errors # avoid creating ref cycle + raise TypeCheckError(f"did not match any element in the union:\n{formatted_errors}") @@ -439,17 +433,24 @@ def check_uniontype( args: tuple[Any, ...], memo: TypeCheckMemo, ) -> None: + if not args: + return check_instance(value, types.UnionType, (), memo) + errors: dict[str, TypeCheckError] = {} - for type_ in args: - try: - check_type_internal(value, type_, memo) - return - except TypeCheckError as exc: - errors[get_type_name(type_)] = exc + try: + for type_ in args: + try: + check_type_internal(value, type_, memo) + return + except TypeCheckError as exc: + errors[get_type_name(type_)] = exc + + formatted_errors = indent( + "\n".join(f"{key}: {error}" for key, error in errors.items()), " " + ) + finally: + del errors # avoid creating ref cycle - formatted_errors = indent( - "\n".join(f"{key}: {error}" for key, error in errors.items()), " " - ) raise TypeCheckError(f"did not match any element in the union:\n{formatted_errors}") @@ -472,28 +473,33 @@ def check_class( if expected_class is Any: return + elif expected_class is typing_extensions.Self: + check_self(value, get_origin(expected_class), get_args(expected_class), memo) elif getattr(expected_class, "_is_protocol", False): check_protocol(value, expected_class, (), memo) elif isinstance(expected_class, TypeVar): check_typevar(value, expected_class, (), memo, subclass_check=True) elif get_origin(expected_class) is Union: errors: dict[str, TypeCheckError] = {} - for arg in get_args(expected_class): - if arg is Any: - return + try: + for arg in get_args(expected_class): + if arg is Any: + return - try: - check_class(value, type, (arg,), memo) - return - except TypeCheckError as exc: - errors[get_type_name(arg)] = exc - else: - formatted_errors = indent( - "\n".join(f"{key}: {error}" for key, error in errors.items()), " " - ) - raise TypeCheckError( - f"did not match any element in the union:\n{formatted_errors}" - ) + try: + check_class(value, type, (arg,), memo) + return + except TypeCheckError as exc: + errors[get_type_name(arg)] = exc + else: + formatted_errors = indent( + "\n".join(f"{key}: {error}" for key, error in errors.items()), " " + ) + raise TypeCheckError( + f"did not match any element in the union:\n{formatted_errors}" + ) + finally: + del errors # avoid creating ref cycle elif not issubclass(value, expected_class): # type: ignore[arg-type] raise TypeCheckError(f"is not a subclass of {qualified_name(expected_class)}") @@ -548,15 +554,8 @@ def check_typevar( ) -if typing_extensions is None: - - def _is_literal_type(typ: object) -> bool: - return typ is typing.Literal - -else: - - def _is_literal_type(typ: object) -> bool: - return typ is typing.Literal or typ is typing_extensions.Literal +def _is_literal_type(typ: object) -> bool: + return typ is typing.Literal or typ is typing_extensions.Literal def check_literal( @@ -648,102 +647,199 @@ def check_io( raise TypeCheckError("is not an I/O object") -def check_protocol( - value: Any, - origin_type: Any, - args: tuple[Any, ...], - memo: TypeCheckMemo, -) -> None: - subject: type[Any] = value if isclass(value) else type(value) +def check_signature_compatible(subject: type, protocol: type, attrname: str) -> None: + subject_sig = inspect.signature(getattr(subject, attrname)) + protocol_sig = inspect.signature(getattr(protocol, attrname)) + protocol_type: typing.Literal["instance", "class", "static"] = "instance" + subject_type: typing.Literal["instance", "class", "static"] = "instance" + + # Check if the protocol-side method is a class method or static method + if attrname in protocol.__dict__: + descriptor = protocol.__dict__[attrname] + if isinstance(descriptor, staticmethod): + protocol_type = "static" + elif isinstance(descriptor, classmethod): + protocol_type = "class" + + # Check if the subject-side method is a class method or static method + if attrname in subject.__dict__: + descriptor = subject.__dict__[attrname] + if isinstance(descriptor, staticmethod): + subject_type = "static" + elif isinstance(descriptor, classmethod): + subject_type = "class" + + if protocol_type == "instance" and subject_type != "instance": + raise TypeCheckError( + f"should be an instance method but it's a {subject_type} method" + ) + elif protocol_type != "instance" and subject_type == "instance": + raise TypeCheckError( + f"should be a {protocol_type} method but it's an instance method" + ) - if subject in protocol_check_cache: - result_map = protocol_check_cache[subject] - if origin_type in result_map: - if exc := result_map[origin_type]: - raise exc - else: - return + expected_varargs = any( + param + for param in protocol_sig.parameters.values() + if param.kind is Parameter.VAR_POSITIONAL + ) + has_varargs = any( + param + for param in subject_sig.parameters.values() + if param.kind is Parameter.VAR_POSITIONAL + ) + if expected_varargs and not has_varargs: + raise TypeCheckError("should accept variable positional arguments but doesn't") - # Collect a set of methods and non-method attributes present in the protocol - ignored_attrs = set(dir(typing.Protocol)) | { - "__annotations__", - "__non_callable_proto_members__", - } - expected_methods: dict[str, tuple[Any, Any]] = {} - expected_noncallable_members: dict[str, Any] = {} - for attrname in dir(origin_type): - # Skip attributes present in typing.Protocol - if attrname in ignored_attrs: - continue + protocol_has_varkwargs = any( + param + for param in protocol_sig.parameters.values() + if param.kind is Parameter.VAR_KEYWORD + ) + subject_has_varkwargs = any( + param + for param in subject_sig.parameters.values() + if param.kind is Parameter.VAR_KEYWORD + ) + if protocol_has_varkwargs and not subject_has_varkwargs: + raise TypeCheckError("should accept variable keyword arguments but doesn't") + + # Check that the callable has at least the expect amount of positional-only + # arguments (and no extra positional-only arguments without default values) + if not has_varargs: + protocol_args = [ + param + for param in protocol_sig.parameters.values() + if param.kind + in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD) + ] + subject_args = [ + param + for param in subject_sig.parameters.values() + if param.kind + in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD) + ] + + # Remove the "self" parameter from the protocol arguments to match + if protocol_type == "instance": + protocol_args.pop(0) + + # Remove the "self" parameter from the subject arguments to match + if subject_type == "instance": + subject_args.pop(0) + + for protocol_arg, subject_arg in zip_longest(protocol_args, subject_args): + if protocol_arg is None: + if subject_arg.default is Parameter.empty: + raise TypeCheckError("has too many mandatory positional arguments") + + break + + if subject_arg is None: + raise TypeCheckError("has too few positional arguments") + + if ( + protocol_arg.kind is Parameter.POSITIONAL_OR_KEYWORD + and subject_arg.kind is Parameter.POSITIONAL_ONLY + ): + raise TypeCheckError( + f"has an argument ({subject_arg.name}) that should not be " + f"positional-only" + ) - member = getattr(origin_type, attrname) - if callable(member): - signature = inspect.signature(member) - argtypes = [ - (p.annotation if p.annotation is not Parameter.empty else Any) - for p in signature.parameters.values() - if p.kind is not Parameter.KEYWORD_ONLY - ] or Ellipsis - return_annotation = ( - signature.return_annotation - if signature.return_annotation is not Parameter.empty - else Any + if ( + protocol_arg.kind is Parameter.POSITIONAL_OR_KEYWORD + and protocol_arg.name != subject_arg.name + ): + raise TypeCheckError( + f"has a positional argument ({subject_arg.name}) that should be " + f"named {protocol_arg.name!r} at this position" + ) + + protocol_kwonlyargs = { + param.name: param + for param in protocol_sig.parameters.values() + if param.kind is Parameter.KEYWORD_ONLY + } + subject_kwonlyargs = { + param.name: param + for param in subject_sig.parameters.values() + if param.kind is Parameter.KEYWORD_ONLY + } + if not subject_has_varkwargs: + # Check that the signature has at least the required keyword-only arguments, and + # no extra mandatory keyword-only arguments + if missing_kwonlyargs := [ + param.name + for param in protocol_kwonlyargs.values() + if param.name not in subject_kwonlyargs + ]: + raise TypeCheckError( + "is missing keyword-only arguments: " + ", ".join(missing_kwonlyargs) ) - expected_methods[attrname] = argtypes, return_annotation - else: - expected_noncallable_members[attrname] = member - for attrname, annotation in typing.get_type_hints(origin_type).items(): - expected_noncallable_members[attrname] = annotation + if not protocol_has_varkwargs: + if extra_kwonlyargs := [ + param.name + for param in subject_kwonlyargs.values() + if param.default is Parameter.empty + and param.name not in protocol_kwonlyargs + ]: + raise TypeCheckError( + "has mandatory keyword-only arguments not present in the protocol: " + + ", ".join(extra_kwonlyargs) + ) - subject_annotations = typing.get_type_hints(subject) - # Check that all required methods are present and their signatures are compatible - result_map = protocol_check_cache.setdefault(subject, {}) - try: - for attrname, callable_args in expected_methods.items(): +def check_protocol( + value: Any, + origin_type: Any, + args: tuple[Any, ...], + memo: TypeCheckMemo, +) -> None: + origin_annotations = typing.get_type_hints(origin_type) + for attrname in sorted(typing_extensions.get_protocol_members(origin_type)): + if (annotation := origin_annotations.get(attrname)) is not None: try: - method = getattr(subject, attrname) + subject_member = getattr(value, attrname) except AttributeError: - if attrname in subject_annotations: - raise TypeCheckError( - f"is not compatible with the {origin_type.__qualname__} protocol " - f"because its {attrname!r} attribute is not a method" - ) from None - else: - raise TypeCheckError( - f"is not compatible with the {origin_type.__qualname__} protocol " - f"because it has no method named {attrname!r}" - ) from None - - if not callable(method): raise TypeCheckError( - f"is not compatible with the {origin_type.__qualname__} protocol " - f"because its {attrname!r} attribute is not a callable" - ) + f"is not compatible with the {origin_type.__qualname__} " + f"protocol because it has no attribute named {attrname!r}" + ) from None - # TODO: raise exception on added keyword-only arguments without defaults try: - check_callable(method, Callable, callable_args, memo) + check_type_internal(subject_member, annotation, memo) except TypeCheckError as exc: raise TypeCheckError( - f"is not compatible with the {origin_type.__qualname__} protocol " - f"because its {attrname!r} method {exc}" + f"is not compatible with the {origin_type.__qualname__} " + f"protocol because its {attrname!r} attribute {exc}" + ) from None + elif callable(getattr(origin_type, attrname)): + try: + subject_member = getattr(value, attrname) + except AttributeError: + raise TypeCheckError( + f"is not compatible with the {origin_type.__qualname__} " + f"protocol because it has no method named {attrname!r}" ) from None - # Check that all required non-callable members are present - for attrname in expected_noncallable_members: - # TODO: implement assignability checks for non-callable members - if attrname not in subject_annotations and not hasattr(subject, attrname): + if not callable(subject_member): raise TypeCheckError( - f"is not compatible with the {origin_type.__qualname__} protocol " - f"because it has no attribute named {attrname!r}" + f"is not compatible with the {origin_type.__qualname__} " + f"protocol because its {attrname!r} attribute is not a callable" ) - except TypeCheckError as exc: - result_map[origin_type] = exc - raise - else: - result_map[origin_type] = None + + # TODO: implement assignability checks for parameter and return value + # annotations + subject = value if isclass(value) else value.__class__ + try: + check_signature_compatible(subject, origin_type, attrname) + except TypeCheckError as exc: + raise TypeCheckError( + f"is not compatible with the {origin_type.__qualname__} " + f"protocol because its {attrname!r} method {exc}" + ) from None def check_byteslike( @@ -768,8 +864,7 @@ def check_self( if isclass(value): if not issubclass(value, memo.self_type): raise TypeCheckError( - f"is not an instance of the self type " - f"({qualified_name(memo.self_type)})" + f"is not a subclass of the self type ({qualified_name(memo.self_type)})" ) elif not isinstance(value, memo.self_type): raise TypeCheckError( @@ -786,16 +881,6 @@ def check_paramspec( pass # No-op for now -def check_instanceof( - value: Any, - origin_type: Any, - args: tuple[Any, ...], - memo: TypeCheckMemo, -) -> None: - if not isinstance(value, origin_type): - raise TypeCheckError(f"is not an instance of {qualified_name(origin_type)}") - - def check_type_internal( value: Any, annotation: Any, @@ -905,6 +990,13 @@ def check_type_internal( type: check_class, Type: check_class, Union: check_union, + # On some versions of Python, these may simply be re-exports from "typing", + # but exactly which Python versions is subject to change. + # It's best to err on the safe side and just always specify these. + typing_extensions.Literal: check_literal, + typing_extensions.LiteralString: check_literal_string, + typing_extensions.Self: check_self, + typing_extensions.TypeGuard: check_typeguard, } if sys.version_info >= (3, 10): origin_type_checkers[types.UnionType] = check_uniontype @@ -913,16 +1005,6 @@ def check_type_internal( origin_type_checkers.update( {typing.LiteralString: check_literal_string, typing.Self: check_self} ) -if typing_extensions is not None: - # On some Python versions, these may simply be re-exports from typing, - # but exactly which Python versions is subject to change, - # so it's best to err on the safe side - # and update the dictionary on all Python versions - # if typing_extensions is installed - origin_type_checkers[typing_extensions.Literal] = check_literal - origin_type_checkers[typing_extensions.LiteralString] = check_literal_string - origin_type_checkers[typing_extensions.Self] = check_self - origin_type_checkers[typing_extensions.TypeGuard] = check_typeguard def builtin_checker_lookup( diff --git a/src/typeguard/_decorators.py b/src/typeguard/_decorators.py index cf32533..a6c20cb 100644 --- a/src/typeguard/_decorators.py +++ b/src/typeguard/_decorators.py @@ -16,20 +16,18 @@ from ._transformer import TypeguardTransformer from ._utils import Unset, function_name, get_stacklevel, is_method_of, unset +T_CallableOrType = TypeVar("T_CallableOrType", bound=Callable[..., Any]) + if TYPE_CHECKING: from typeshed.stdlib.types import _Cell - _F = TypeVar("_F") - - def typeguard_ignore(f: _F) -> _F: + def typeguard_ignore(arg: T_CallableOrType) -> T_CallableOrType: """This decorator is a noop during static type-checking.""" - return f + return arg else: from typing import no_type_check as typeguard_ignore # noqa: F401 -T_CallableOrType = TypeVar("T_CallableOrType", bound=Callable[..., Any]) - def make_cell(value: object) -> _Cell: return (lambda: value).__closure__[0] # type: ignore[index] @@ -218,7 +216,7 @@ def typechecked( ) = None if isinstance(target, (classmethod, staticmethod)): wrapper_class = target.__class__ - target = target.__func__ + target = target.__func__ # type: ignore[assignment] retval = instrument(target) if isinstance(retval, str): diff --git a/src/typeguard/_functions.py b/src/typeguard/_functions.py index 2849785..ca21c14 100644 --- a/src/typeguard/_functions.py +++ b/src/typeguard/_functions.py @@ -2,6 +2,7 @@ import sys import warnings +from collections.abc import Sequence from typing import Any, Callable, NoReturn, TypeVar, Union, overload from . import _suppression @@ -242,59 +243,53 @@ def check_yield_type( def check_variable_assignment( - value: object, varname: str, annotation: Any, memo: TypeCheckMemo + value: Any, targets: Sequence[list[tuple[str, Any]]], memo: TypeCheckMemo ) -> Any: if _suppression.type_checks_suppressed: return value - try: - check_type_internal(value, annotation, memo) - except TypeCheckError as exc: - qualname = qualified_name(value, add_class_prefix=True) - exc.append_path_element(f"value assigned to {varname} ({qualname})") - if memo.config.typecheck_fail_callback: - memo.config.typecheck_fail_callback(exc, memo) - else: - raise - - return value - + value_to_return = value + for target in targets: + star_variable_index = next( + (i for i, (varname, _) in enumerate(target) if varname.startswith("*")), + None, + ) + if star_variable_index is not None: + value_to_return = list(value) + remaining_vars = len(target) - 1 - star_variable_index + end_index = len(value_to_return) - remaining_vars + values_to_check = ( + value_to_return[:star_variable_index] + + [value_to_return[star_variable_index:end_index]] + + value_to_return[end_index:] + ) + elif len(target) > 1: + values_to_check = value_to_return = [] + iterator = iter(value) + for _ in target: + try: + values_to_check.append(next(iterator)) + except StopIteration: + raise ValueError( + f"not enough values to unpack (expected {len(target)}, got " + f"{len(values_to_check)})" + ) from None -def check_multi_variable_assignment( - value: Any, targets: list[dict[str, Any]], memo: TypeCheckMemo -) -> Any: - if max(len(target) for target in targets) == 1: - iterated_values = [value] - else: - iterated_values = list(value) - - if not _suppression.type_checks_suppressed: - for expected_types in targets: - value_index = 0 - for ann_index, (varname, expected_type) in enumerate( - expected_types.items() - ): - if varname.startswith("*"): - varname = varname[1:] - keys_left = len(expected_types) - 1 - ann_index - next_value_index = len(iterated_values) - keys_left - obj: object = iterated_values[value_index:next_value_index] - value_index = next_value_index + else: + values_to_check = [value] + + for val, (varname, annotation) in zip(values_to_check, target): + try: + check_type_internal(val, annotation, memo) + except TypeCheckError as exc: + qualname = qualified_name(val, add_class_prefix=True) + exc.append_path_element(f"value assigned to {varname} ({qualname})") + if memo.config.typecheck_fail_callback: + memo.config.typecheck_fail_callback(exc, memo) else: - obj = iterated_values[value_index] - value_index += 1 + raise - try: - check_type_internal(obj, expected_type, memo) - except TypeCheckError as exc: - qualname = qualified_name(obj, add_class_prefix=True) - exc.append_path_element(f"value assigned to {varname} ({qualname})") - if memo.config.typecheck_fail_callback: - memo.config.typecheck_fail_callback(exc, memo) - else: - raise - - return iterated_values[0] if len(iterated_values) == 1 else iterated_values + return value_to_return def warn_on_error(exc: TypeCheckError, memo: TypeCheckMemo) -> None: diff --git a/src/typeguard/_importhook.py b/src/typeguard/_importhook.py index 8590540..0d1c627 100644 --- a/src/typeguard/_importhook.py +++ b/src/typeguard/_importhook.py @@ -3,14 +3,14 @@ import ast import sys import types -from collections.abc import Callable, Iterable +from collections.abc import Callable, Iterable, Sequence from importlib.abc import MetaPathFinder from importlib.machinery import ModuleSpec, SourceFileLoader from importlib.util import cache_from_source, decode_source from inspect import isclass from os import PathLike from types import CodeType, ModuleType, TracebackType -from typing import Sequence, TypeVar +from typing import TypeVar from unittest.mock import patch from ._config import global_config diff --git a/src/typeguard/_transformer.py b/src/typeguard/_transformer.py index 13ac363..25696a5 100644 --- a/src/typeguard/_transformer.py +++ b/src/typeguard/_transformer.py @@ -28,7 +28,6 @@ If, Import, ImportFrom, - Index, List, Load, LShift, @@ -389,9 +388,7 @@ def visit_BinOp(self, node: BinOp) -> Any: union_name = self.transformer._get_import("typing", "Union") return Subscript( value=union_name, - slice=Index( - Tuple(elts=[node.left, node.right], ctx=Load()), ctx=Load() - ), + slice=Tuple(elts=[node.left, node.right], ctx=Load()), ctx=Load(), ) @@ -410,24 +407,18 @@ def visit_Subscript(self, node: Subscript) -> Any: # The subscript of typing(_extensions).Literal can be any arbitrary string, so # don't try to evaluate it as code if node.slice: - if isinstance(node.slice, Index): - # Python 3.8 - slice_value = node.slice.value # type: ignore[attr-defined] - else: - slice_value = node.slice - - if isinstance(slice_value, Tuple): + if isinstance(node.slice, Tuple): if self._memo.name_matches(node.value, *annotated_names): # Only treat the first argument to typing.Annotated as a potential # forward reference items = cast( typing.List[expr], - [self.visit(slice_value.elts[0])] + slice_value.elts[1:], + [self.visit(node.slice.elts[0])] + node.slice.elts[1:], ) else: items = cast( typing.List[expr], - [self.visit(item) for item in slice_value.elts], + [self.visit(item) for item in node.slice.elts], ) # If this is a Union and any of the items is Any, erase the entire @@ -450,7 +441,7 @@ def visit_Subscript(self, node: Subscript) -> Any: if item is None: items[index] = self.transformer._get_import("typing", "Any") - slice_value.elts = items + node.slice.elts = items else: self.generic_visit(node) @@ -472,12 +463,6 @@ def visit_Name(self, node: Name) -> Any: if self._memo.is_ignored_name(node): return None - if sys.version_info < (3, 9): - for typename, substitute in self.type_substitutions.items(): - if self._memo.name_matches(node, typename): - new_node = self.transformer._get_import(*substitute) - return copy_location(new_node, node) - return node def visit_Call(self, node: Call) -> Any: @@ -548,18 +533,10 @@ def _use_memo( return_annotation, *generator_names ): if isinstance(return_annotation, Subscript): - annotation_slice = return_annotation.slice - - # Python < 3.9 - if isinstance(annotation_slice, Index): - annotation_slice = ( - annotation_slice.value # type: ignore[attr-defined] - ) - - if isinstance(annotation_slice, Tuple): - items = annotation_slice.elts + if isinstance(return_annotation.slice, Tuple): + items = return_annotation.slice.elts else: - items = [annotation_slice] + items = [return_annotation.slice] if len(items) > 0: new_memo.yield_annotation = self._convert_annotation( @@ -723,7 +700,7 @@ def visit_FunctionDef( else: self.target_lineno = node.lineno - all_args = node.args.args + node.args.kwonlyargs + node.args.posonlyargs + all_args = node.args.posonlyargs + node.args.args + node.args.kwonlyargs # Ensure that any type shadowed by the positional or keyword-only # argument names are ignored in this function @@ -748,21 +725,14 @@ def visit_FunctionDef( if node.args.vararg: annotation_ = self._convert_annotation(node.args.vararg.annotation) if annotation_: - if sys.version_info >= (3, 9): - container = Name("tuple", ctx=Load()) - else: - container = self._get_import("typing", "Tuple") - - subscript_slice: Tuple | Index = Tuple( + container = Name("tuple", ctx=Load()) + subscript_slice = Tuple( [ annotation_, Constant(Ellipsis), ], ctx=Load(), ) - if sys.version_info < (3, 9): - subscript_slice = Index(subscript_slice, ctx=Load()) - arg_annotations[node.args.vararg.arg] = Subscript( container, subscript_slice, ctx=Load() ) @@ -770,11 +740,7 @@ def visit_FunctionDef( if node.args.kwarg: annotation_ = self._convert_annotation(node.args.kwarg.annotation) if annotation_: - if sys.version_info >= (3, 9): - container = Name("dict", ctx=Load()) - else: - container = self._get_import("typing", "Dict") - + container = Name("dict", ctx=Load()) subscript_slice = Tuple( [ Name("str", ctx=Load()), @@ -782,9 +748,6 @@ def visit_FunctionDef( ], ctx=Load(), ) - if sys.version_info < (3, 9): - subscript_slice = Index(subscript_slice, ctx=Load()) - arg_annotations[node.args.kwarg.arg] = Subscript( container, subscript_slice, ctx=Load() ) @@ -863,19 +826,20 @@ def visit_FunctionDef( isinstance(decorator, Name) and decorator.id == "classmethod" ): + arglist = node.args.posonlyargs or node.args.args memo_kwargs["self_type"] = Name( - id=node.args.args[0].arg, ctx=Load() + id=arglist[0].arg, ctx=Load() ) break else: - if node.args.args: + if arglist := node.args.posonlyargs or node.args.args: if node.name == "__new__": memo_kwargs["self_type"] = Name( - id=node.args.args[0].arg, ctx=Load() + id=arglist[0].arg, ctx=Load() ) else: memo_kwargs["self_type"] = Attribute( - Name(id=node.args.args[0].arg, ctx=Load()), + Name(id=arglist[0].arg, ctx=Load()), "__class__", ctx=Load(), ) @@ -1044,12 +1008,25 @@ def visit_AnnAssign(self, node: AnnAssign) -> Any: func_name = self._get_import( "typeguard._functions", "check_variable_assignment" ) + targets_arg = List( + [ + List( + [ + Tuple( + [Constant(node.target.id), annotation], + ctx=Load(), + ) + ], + ctx=Load(), + ) + ], + ctx=Load(), + ) node.value = Call( func_name, [ node.value, - Constant(node.target.id), - annotation, + targets_arg, self._memo.get_memo_name(), ], [], @@ -1067,7 +1044,7 @@ def visit_Assign(self, node: Assign) -> Any: # Only instrument function-local assignments if isinstance(self._memo.node, (FunctionDef, AsyncFunctionDef)): - targets: list[dict[Constant, expr | None]] = [] + preliminary_targets: list[list[tuple[Constant, expr | None]]] = [] check_required = False for target in node.targets: elts: Sequence[expr] @@ -1078,63 +1055,63 @@ def visit_Assign(self, node: Assign) -> Any: else: continue - annotations_: dict[Constant, expr | None] = {} + annotations_: list[tuple[Constant, expr | None]] = [] for exp in elts: prefix = "" if isinstance(exp, Starred): exp = exp.value prefix = "*" + path: list[str] = [] + while isinstance(exp, Attribute): + path.insert(0, exp.attr) + exp = exp.value + if isinstance(exp, Name): - self._memo.ignored_names.add(exp.id) - name = prefix + exp.id + if not path: + self._memo.ignored_names.add(exp.id) + + path.insert(0, exp.id) + name = prefix + ".".join(path) annotation = self._memo.variable_annotations.get(exp.id) if annotation: - annotations_[Constant(name)] = annotation + annotations_.append((Constant(name), annotation)) check_required = True else: - annotations_[Constant(name)] = None + annotations_.append((Constant(name), None)) - targets.append(annotations_) + preliminary_targets.append(annotations_) if check_required: # Replace missing annotations with typing.Any - for item in targets: - for key, expression in item.items(): + targets: list[list[tuple[Constant, expr]]] = [] + for items in preliminary_targets: + target_list: list[tuple[Constant, expr]] = [] + targets.append(target_list) + for key, expression in items: if expression is None: - item[key] = self._get_import("typing", "Any") + target_list.append((key, self._get_import("typing", "Any"))) + else: + target_list.append((key, expression)) - if len(targets) == 1 and len(targets[0]) == 1: - func_name = self._get_import( - "typeguard._functions", "check_variable_assignment" - ) - target_varname = next(iter(targets[0])) - node.value = Call( - func_name, - [ - node.value, - target_varname, - targets[0][target_varname], - self._memo.get_memo_name(), - ], - [], - ) - elif targets: - func_name = self._get_import( - "typeguard._functions", "check_multi_variable_assignment" - ) - targets_arg = List( - [ - Dict(keys=list(target), values=list(target.values())) - for target in targets - ], - ctx=Load(), - ) - node.value = Call( - func_name, - [node.value, targets_arg, self._memo.get_memo_name()], - [], - ) + func_name = self._get_import( + "typeguard._functions", "check_variable_assignment" + ) + targets_arg = List( + [ + List( + [Tuple([name, ann], ctx=Load()) for name, ann in target], + ctx=Load(), + ) + for target in targets + ], + ctx=Load(), + ) + node.value = Call( + func_name, + [node.value, targets_arg, self._memo.get_memo_name()], + [], + ) return node @@ -1195,12 +1172,20 @@ def visit_AugAssign(self, node: AugAssign) -> Any: operator_call = Call( operator_func, [Name(node.target.id, ctx=Load()), node.value], [] ) + targets_arg = List( + [ + List( + [Tuple([Constant(node.target.id), annotation], ctx=Load())], + ctx=Load(), + ) + ], + ctx=Load(), + ) check_call = Call( self._get_import("typeguard._functions", "check_variable_assignment"), [ operator_call, - Constant(node.target.id), - annotation, + targets_arg, self._memo.get_memo_name(), ], [], diff --git a/src/typeguard/_union_transformer.py b/src/typeguard/_union_transformer.py index 19617e6..1c296d3 100644 --- a/src/typeguard/_union_transformer.py +++ b/src/typeguard/_union_transformer.py @@ -8,26 +8,16 @@ from ast import ( BinOp, BitOr, - Index, Load, Name, NodeTransformer, Subscript, + Tuple, fix_missing_locations, parse, ) -from ast import Tuple as ASTTuple from types import CodeType -from typing import Any, Dict, FrozenSet, List, Set, Tuple, Union - -type_substitutions = { - "dict": Dict, - "list": List, - "tuple": Tuple, - "set": Set, - "frozenset": FrozenSet, - "Union": Union, -} +from typing import Any class UnionTransformer(NodeTransformer): @@ -39,9 +29,7 @@ def visit_BinOp(self, node: BinOp) -> Any: if isinstance(node.op, BitOr): return Subscript( value=self.union_name, - slice=Index( - ASTTuple(elts=[node.left, node.right], ctx=Load()), ctx=Load() - ), + slice=Tuple(elts=[node.left, node.right], ctx=Load()), ctx=Load(), ) diff --git a/src/typeguard/_utils.py b/src/typeguard/_utils.py index 9bcc841..e8f9b03 100644 --- a/src/typeguard/_utils.py +++ b/src/typeguard/_utils.py @@ -35,7 +35,7 @@ def evaluate_forwardref(forwardref: ForwardRef, memo: TypeCheckMemo) -> Any: ) def evaluate_forwardref(forwardref: ForwardRef, memo: TypeCheckMemo) -> Any: - from ._union_transformer import compile_type_hint, type_substitutions + from ._union_transformer import compile_type_hint if not forwardref.__forward_evaluated__: forwardref.__forward_code__ = compile_type_hint(forwardref.__forward_arg__) @@ -47,8 +47,6 @@ def evaluate_forwardref(forwardref: ForwardRef, memo: TypeCheckMemo) -> Any: # Try again, with the type substitutions (list -> List etc.) in place new_globals = memo.globals.copy() new_globals.setdefault("Union", Union) - if sys.version_info < (3, 9): - new_globals.update(type_substitutions) return forwardref._evaluate( new_globals, memo.locals or new_globals, *evaluate_extra_args diff --git a/tests/__init__.py b/tests/__init__.py index f28f2c2..b48bd69 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -6,10 +6,8 @@ List, NamedTuple, NewType, - Protocol, TypeVar, Union, - runtime_checkable, ) T_Foo = TypeVar("T_Foo") @@ -44,16 +42,3 @@ class Parent: class Child(Parent): def method(self, a: int) -> None: pass - - -class StaticProtocol(Protocol): - member: int - - def meth(self, x: str) -> None: ... - - -@runtime_checkable -class RuntimeProtocol(Protocol): - member: int - - def meth(self, x: str) -> None: ... diff --git a/tests/test_checkers.py b/tests/test_checkers.py index f8b21d6..23e01aa 100644 --- a/tests/test_checkers.py +++ b/tests/test_checkers.py @@ -1,5 +1,6 @@ import collections.abc import sys +import types from contextlib import nullcontext from functools import partial from io import BytesIO, StringIO @@ -7,6 +8,7 @@ from typing import ( IO, AbstractSet, + Annotated, Any, AnyStr, BinaryIO, @@ -16,14 +18,17 @@ Dict, ForwardRef, FrozenSet, + Iterable, Iterator, List, Literal, Mapping, MutableMapping, Optional, + Protocol, Sequence, Set, + Sized, TextIO, Tuple, Type, @@ -32,6 +37,7 @@ ) import pytest +from typing_extensions import LiteralString from typeguard import ( CollectionCheckStrategy, @@ -51,8 +57,6 @@ Employee, JSONType, Parent, - RuntimeProtocol, - StaticProtocol, TChild, TIntStr, TParent, @@ -62,23 +66,15 @@ ) if sys.version_info >= (3, 11): - from typing import LiteralString - SubclassableAny = Any else: from typing_extensions import Any as SubclassableAny - from typing_extensions import LiteralString if sys.version_info >= (3, 10): from typing import Concatenate, ParamSpec, TypeGuard else: from typing_extensions import Concatenate, ParamSpec, TypeGuard -if sys.version_info >= (3, 9): - from typing import Annotated -else: - from typing_extensions import Annotated - P = ParamSpec("P") @@ -512,7 +508,8 @@ def test_notrequired_pass(self, typing_provider): class DummyDict(typing_provider.TypedDict): x: int - y: "NotRequired[int]" + y: NotRequired[int] + z: "NotRequired[int]" check_type({"x": 8}, DummyDict) @@ -524,13 +521,19 @@ def test_notrequired_fail(self, typing_provider): class DummyDict(typing_provider.TypedDict): x: int - y: "NotRequired[int]" + y: NotRequired[int] + z: "NotRequired[int]" with pytest.raises( TypeCheckError, match=r"value of key 'y' of dict is not an instance of int" ): check_type({"x": 1, "y": "foo"}, DummyDict) + with pytest.raises( + TypeCheckError, match=r"value of key 'z' of dict is not an instance of int" + ): + check_type({"x": 1, "y": 6, "z": "foo"}, DummyDict) + def test_is_typeddict(self, typing_provider): # Ensure both typing.TypedDict and typing_extensions.TypedDict are recognized class DummyDict(typing_provider.TypedDict): @@ -821,8 +824,6 @@ def test_union_fail(self, annotation, value): reason="Test relies on CPython's reference counting behavior", ) def test_union_reference_leak(self): - leaked = True - class Leak: def __del__(self): nonlocal leaked @@ -832,19 +833,74 @@ def inner1(): leak = Leak() # noqa: F841 check_type(b"asdf", Union[str, bytes]) + leaked = True inner1() assert not leaked + def inner2(): + leak = Leak() # noqa: F841 + check_type(b"asdf", Union[bytes, str]) + leaked = True + inner2() + assert not leaked - def inner2(): + def inner3(): leak = Leak() # noqa: F841 with pytest.raises(TypeCheckError, match="any element in the union:"): check_type(1, Union[str, bytes]) + leaked = True + inner3() + assert not leaked + + @pytest.mark.skipif( + sys.implementation.name != "cpython", + reason="Test relies on CPython's reference counting behavior", + ) + @pytest.mark.skipif(sys.version_info < (3, 10), reason="UnionType requires 3.10") + def test_uniontype_reference_leak(self): + class Leak: + def __del__(self): + nonlocal leaked + leaked = False + + def inner1(): + leak = Leak() # noqa: F841 + check_type(b"asdf", str | bytes) + + leaked = True + inner1() + assert not leaked + + def inner2(): + leak = Leak() # noqa: F841 + check_type(b"asdf", bytes | str) + + leaked = True inner2() assert not leaked + def inner3(): + leak = Leak() # noqa: F841 + with pytest.raises(TypeCheckError, match="any element in the union:"): + check_type(1, Union[str, bytes]) + + leaked = True + inner3() + assert not leaked + + @pytest.mark.skipif(sys.version_info < (3, 10), reason="UnionType requires 3.10") + def test_raw_uniontype_success(self): + check_type(str | int, types.UnionType) + + @pytest.mark.skipif(sys.version_info < (3, 10), reason="UnionType requires 3.10") + def test_raw_uniontype_fail(self): + with pytest.raises( + TypeCheckError, match=r"class str is not an instance of \w+\.UnionType$" + ): + check_type(str, types.UnionType) + class TestTypevar: def test_bound(self): @@ -940,9 +996,7 @@ def test_union_typevar(self): @pytest.mark.parametrize("check_against", [type, Type[Any]]) def test_generic_aliase(self, check_against): - if sys.version_info >= (3, 9): - check_type(dict[str, str], check_against) - + check_type(dict[str, str], check_against) check_type(Dict, check_against) check_type(Dict[str, str], check_against) @@ -995,119 +1049,325 @@ def test_text_real_file(self, tmp_path: Path): check_type(f, TextIO) -@pytest.mark.parametrize( - "instantiate, annotation", - [ - pytest.param(True, RuntimeProtocol, id="instance_runtime"), - pytest.param(False, Type[RuntimeProtocol], id="class_runtime"), - pytest.param(True, StaticProtocol, id="instance_static"), - pytest.param(False, Type[StaticProtocol], id="class_static"), - ], -) +class TestIntersectingProtocol: + SIT = TypeVar("SIT", covariant=True) + + class SizedIterable( + Sized, + Iterable[SIT], + Protocol[SIT], + ): ... + + @pytest.mark.parametrize( + "subject, predicate_type", + ( + pytest.param( + (), + SizedIterable, + id="empty_tuple_unspecialized", + ), + pytest.param( + range(2), + SizedIterable, + id="range", + ), + pytest.param( + (), + SizedIterable[int], + id="empty_tuple_int_specialized", + ), + pytest.param( + (1, 2, 3), + SizedIterable[int], + id="tuple_int_specialized", + ), + pytest.param( + ("1", "2", "3"), + SizedIterable[str], + id="tuple_str_specialized", + ), + ), + ) + def test_valid_member_passes(self, subject: object, predicate_type: type) -> None: + for _ in range(2): # Makes sure that the cache is also exercised + check_type(subject, predicate_type) + + xfail_nested_protocol_checks = pytest.mark.xfail( + reason="false negative due to missing support for nested protocol checks", + ) + + @pytest.mark.parametrize( + "subject, predicate_type", + ( + pytest.param( + (1 for _ in ()), + SizedIterable, + id="generator", + ), + pytest.param( + range(2), + SizedIterable[str], + marks=xfail_nested_protocol_checks, + id="range_str_specialized", + ), + pytest.param( + (1, 2, 3), + SizedIterable[str], + marks=xfail_nested_protocol_checks, + id="int_tuple_str_specialized", + ), + pytest.param( + ("1", "2", "3"), + SizedIterable[int], + marks=xfail_nested_protocol_checks, + id="str_tuple_int_specialized", + ), + ), + ) + def test_raises_for_non_member(self, subject: object, predicate_type: type) -> None: + with pytest.raises(TypeCheckError): + check_type(subject, predicate_type) + + class TestProtocol: - def test_member_defaultval(self, instantiate, annotation): + @pytest.mark.parametrize( + "instantiate", + [pytest.param(True, id="instance"), pytest.param(False, id="class")], + ) + def test_success(self, typing_provider: Any, instantiate: bool) -> None: + class MyProtocol(Protocol): + member: int + + def noargs(self) -> None: + pass + + def posonlyargs(self, a: int, b: str, /) -> None: + pass + + def posargs(self, a: int, b: str, c: float = 2.0) -> None: + pass + + def varargs(self, *args: Any) -> None: + pass + + def varkwargs(self, **kwargs: Any) -> None: + pass + + def varbothargs(self, *args: Any, **kwargs: Any) -> None: + pass + + @staticmethod + def my_static_method(x: int, y: str) -> None: + pass + + @classmethod + def my_class_method(cls, x: int, y: str) -> None: + pass + class Foo: member = 1 - def meth(self, x: str) -> None: + def noargs(self, x: int = 1) -> None: pass - subject = Foo() if instantiate else Foo - for _ in range(2): # Makes sure that the cache is also exercised - check_type(subject, annotation) + def posonlyargs(self, a: int, b: str, c: float = 2.0, /) -> None: + pass - def test_member_annotation(self, instantiate, annotation): - class Foo: + def posargs(self, *args: Any) -> None: + pass + + def varargs(self, *args: Any, kwarg: str = "foo") -> None: + pass + + def varkwargs(self, **kwargs: Any) -> None: + pass + + def varbothargs(self, *args: Any, **kwargs: Any) -> None: + pass + + # These were intentionally reversed, as this is OK for mypy + @classmethod + def my_static_method(cls, x: int, y: str) -> None: + pass + + @staticmethod + def my_class_method(x: int, y: str) -> None: + pass + + if instantiate: + check_type(Foo(), MyProtocol) + else: + check_type(Foo, type[MyProtocol]) + + @pytest.mark.parametrize("has_member", [True, False]) + def test_member_checks(self, has_member: bool) -> None: + class MyProtocol(Protocol): member: int + class Foo: + def __init__(self, member: int): + if member: + self.member = member + + if has_member: + check_type(Foo(1), MyProtocol) + else: + pytest.raises(TypeCheckError, check_type, Foo(0), MyProtocol).match( + f"^{qualified_name(Foo)} is not compatible with the " + f"{MyProtocol.__qualname__} protocol because it has no attribute named " + f"'member'" + ) + + def test_missing_method(self) -> None: + class MyProtocol(Protocol): + def meth(self) -> None: + pass + + class Foo: + pass + + pytest.raises(TypeCheckError, check_type, Foo(), MyProtocol).match( + f"^{qualified_name(Foo)} is not compatible with the " + f"{MyProtocol.__qualname__} protocol because it has no method named " + f"'meth'" + ) + + def test_too_many_posargs(self) -> None: + class MyProtocol(Protocol): + def meth(self) -> None: + pass + + class Foo: def meth(self, x: str) -> None: pass - subject = Foo() if instantiate else Foo - for _ in range(2): - check_type(subject, annotation) + pytest.raises(TypeCheckError, check_type, Foo(), MyProtocol).match( + f"^{qualified_name(Foo)} is not compatible with the " + f"{MyProtocol.__qualname__} protocol because its 'meth' method has too " + f"many mandatory positional arguments" + ) + + def test_wrong_posarg_name(self) -> None: + class MyProtocol(Protocol): + def meth(self, x: str) -> None: + pass - def test_attribute_missing(self, instantiate, annotation): class Foo: - val = 1 + def meth(self, y: str) -> None: + pass + pytest.raises(TypeCheckError, check_type, Foo(), MyProtocol).match( + rf"^{qualified_name(Foo)} is not compatible with the " + rf"{MyProtocol.__qualname__} protocol because its 'meth' method has a " + rf"positional argument \(y\) that should be named 'x' at this position" + ) + + def test_too_few_posargs(self) -> None: + class MyProtocol(Protocol): def meth(self, x: str) -> None: pass - clsname = f"{__name__}.TestProtocol.test_attribute_missing..Foo" - subject = Foo() if instantiate else Foo - for _ in range(2): - pytest.raises(TypeCheckError, check_type, subject, annotation).match( - f"{clsname} is not compatible with the (Runtime|Static)Protocol " - f"protocol because it has no attribute named 'member'" - ) + class Foo: + def meth(self) -> None: + pass + + pytest.raises(TypeCheckError, check_type, Foo(), MyProtocol).match( + f"^{qualified_name(Foo)} is not compatible with the " + f"{MyProtocol.__qualname__} protocol because its 'meth' method has too " + f"few positional arguments" + ) + + def test_no_varargs(self) -> None: + class MyProtocol(Protocol): + def meth(self, *args: Any) -> None: + pass - def test_method_missing(self, instantiate, annotation): class Foo: - member: int + def meth(self) -> None: + pass - pattern = ( - f"{__name__}.TestProtocol.test_method_missing..Foo is not " - f"compatible with the (Runtime|Static)Protocol protocol because it has no " - f"method named 'meth'" + pytest.raises(TypeCheckError, check_type, Foo(), MyProtocol).match( + f"^{qualified_name(Foo)} is not compatible with the " + f"{MyProtocol.__qualname__} protocol because its 'meth' method should " + f"accept variable positional arguments but doesn't" ) - subject = Foo() if instantiate else Foo - for _ in range(2): - pytest.raises(TypeCheckError, check_type, subject, annotation).match( - pattern - ) - def test_attribute_is_not_method_1(self, instantiate, annotation): + def test_no_kwargs(self) -> None: + class MyProtocol(Protocol): + def meth(self, **kwargs: Any) -> None: + pass + class Foo: - member: int - meth: str + def meth(self) -> None: + pass - pattern = ( - f"{__name__}.TestProtocol.test_attribute_is_not_method_1..Foo is " - f"not compatible with the (Runtime|Static)Protocol protocol because its " - f"'meth' attribute is not a method" + pytest.raises(TypeCheckError, check_type, Foo(), MyProtocol).match( + f"^{qualified_name(Foo)} is not compatible with the " + f"{MyProtocol.__qualname__} protocol because its 'meth' method should " + f"accept variable keyword arguments but doesn't" ) - subject = Foo() if instantiate else Foo - for _ in range(2): - pytest.raises(TypeCheckError, check_type, subject, annotation).match( - pattern - ) - def test_attribute_is_not_method_2(self, instantiate, annotation): + def test_missing_kwarg(self) -> None: + class MyProtocol(Protocol): + def meth(self, *, x: str) -> None: + pass + class Foo: - member: int - meth = "foo" + def meth(self) -> None: + pass - pattern = ( - f"{__name__}.TestProtocol.test_attribute_is_not_method_2..Foo is " - f"not compatible with the (Runtime|Static)Protocol protocol because its " - f"'meth' attribute is not a callable" + pytest.raises(TypeCheckError, check_type, Foo(), MyProtocol).match( + f"^{qualified_name(Foo)} is not compatible with the " + f"{MyProtocol.__qualname__} protocol because its 'meth' method is " + f"missing keyword-only arguments: x" ) - subject = Foo() if instantiate else Foo - for _ in range(2): - pytest.raises(TypeCheckError, check_type, subject, annotation).match( - pattern - ) - def test_method_signature_mismatch(self, instantiate, annotation): + def test_extra_kwarg(self) -> None: + class MyProtocol(Protocol): + def meth(self) -> None: + pass + class Foo: - member: int + def meth(self, *, x: str) -> None: + pass + + pytest.raises(TypeCheckError, check_type, Foo(), MyProtocol).match( + f"^{qualified_name(Foo)} is not compatible with the " + f"{MyProtocol.__qualname__} protocol because its 'meth' method has " + f"mandatory keyword-only arguments not present in the protocol: x" + ) - def meth(self, x: str, y: int) -> None: + def test_instance_staticmethod_mismatch(self) -> None: + class MyProtocol(Protocol): + @staticmethod + def meth() -> None: pass - pattern = ( - rf"(class )?{__name__}.TestProtocol.test_method_signature_mismatch." - rf".Foo is not compatible with the (Runtime|Static)Protocol " - rf"protocol because its 'meth' method has too many mandatory positional " - rf"arguments in its declaration; expected 2 but 3 mandatory positional " - rf"argument\(s\) declared" + class Foo: + def meth(self) -> None: + pass + + pytest.raises(TypeCheckError, check_type, Foo(), MyProtocol).match( + f"^{qualified_name(Foo)} is not compatible with the " + f"{MyProtocol.__qualname__} protocol because its 'meth' method should " + f"be a static method but it's an instance method" + ) + + def test_instance_classmethod_mismatch(self) -> None: + class MyProtocol(Protocol): + @classmethod + def meth(cls) -> None: + pass + + class Foo: + def meth(self) -> None: + pass + + pytest.raises(TypeCheckError, check_type, Foo(), MyProtocol).match( + f"^{qualified_name(Foo)} is not compatible with the " + f"{MyProtocol.__qualname__} protocol because its 'meth' method should " + f"be a class method but it's an instance method" ) - subject = Foo() if instantiate else Foo - for _ in range(2): - pytest.raises(TypeCheckError, check_type, subject, annotation).match( - pattern - ) class TestRecursiveType: diff --git a/tests/test_importhook.py b/tests/test_importhook.py index b0a214a..39d8968 100644 --- a/tests/test_importhook.py +++ b/tests/test_importhook.py @@ -64,5 +64,6 @@ def test_debug_instrumentation(monkeypatch, capsys): monkeypatch.setattr("typeguard.config.debug_instrumentation", True) import_dummymodule() out, err = capsys.readouterr() - assert f"Source code of '{dummy_module_path}' after instrumentation:" in err + path_str = str(dummy_module_path) + assert f"Source code of {path_str!r} after instrumentation:" in err assert "class DummyClass" in err diff --git a/tests/test_transformer.py b/tests/test_transformer.py index 15cf9d4..2e18a5a 100644 --- a/tests/test_transformer.py +++ b/tests/test_transformer.py @@ -1,16 +1,11 @@ import sys -from ast import parse +from ast import parse, unparse from textwrap import dedent import pytest from typeguard._transformer import TypeguardTransformer -if sys.version_info >= (3, 9): - from ast import unparse -else: - pytest.skip("Requires Python 3.9 or newer", allow_module_level=True) - def test_arguments_only() -> None: node = parse( @@ -559,6 +554,35 @@ def foo(self, x: int) -> int: ) +def test_method_posonlyargs() -> None: + node = parse( + dedent( + """ + class Foo: + def foo(self, x: int, /, y: str) -> int: + return x + """ + ) + ) + TypeguardTransformer(["Foo", "foo"]).visit(node) + assert ( + unparse(node) + == dedent( + """ + class Foo: + + def foo(self, x: int, /, y: str) -> int: + from typeguard import TypeCheckMemo + from typeguard._functions import check_argument_types, \ +check_return_type + memo = TypeCheckMemo(globals(), locals(), self_type=self.__class__) + check_argument_types('Foo.foo', {'x': (x, int), 'y': (y, str)}, memo) + return check_return_type('Foo.foo', x, int, memo) + """ + ).strip() + ) + + def test_classmethod() -> None: node = parse( dedent( @@ -590,6 +614,38 @@ def foo(cls, x: int) -> int: ) +def test_classmethod_posonlyargs() -> None: + node = parse( + dedent( + """ + class Foo: + @classmethod + def foo(cls, x: int, /, y: str) -> int: + return x + """ + ) + ) + TypeguardTransformer(["Foo", "foo"]).visit(node) + assert ( + unparse(node) + == dedent( + """ + class Foo: + + @classmethod + def foo(cls, x: int, /, y: str) -> int: + from typeguard import TypeCheckMemo + from typeguard._functions import check_argument_types, \ +check_return_type + memo = TypeCheckMemo(globals(), locals(), self_type=cls) + check_argument_types('Foo.foo', {'x': (x, int), 'y': (y, str)}, \ +memo) + return check_return_type('Foo.foo', x, int, memo) + """ + ).strip() + ) + + def test_staticmethod() -> None: node = parse( dedent( @@ -972,7 +1028,7 @@ def foo(x: Any) -> None: def foo(x: Any) -> None: memo = TypeCheckMemo(globals(), locals()) y: FooBar = x - z: list[FooBar] = check_variable_assignment([y], 'z', list, \ + z: list[FooBar] = check_variable_assignment([y], [[('z', list)]], \ memo) """ ).strip() @@ -1150,7 +1206,8 @@ def foo() -> None: def foo() -> None: memo = TypeCheckMemo(globals(), locals()) - x: int = check_variable_assignment(otherfunc(), 'x', int, memo) + x: int = check_variable_assignment(otherfunc(), [[('x', int)]], \ +memo) """ ).strip() ) @@ -1166,27 +1223,20 @@ def foo(*args: int) -> None: ) TypeguardTransformer().visit(node) - if sys.version_info < (3, 9): - extra_import = "from typing import Tuple\n" - tuple_type = "Tuple" - else: - extra_import = "" - tuple_type = "tuple" - assert ( unparse(node) == dedent( - f""" + """ from typeguard import TypeCheckMemo from typeguard._functions import check_argument_types, \ check_variable_assignment - {extra_import} + def foo(*args: int) -> None: memo = TypeCheckMemo(globals(), locals()) - check_argument_types('foo', {{'args': (args, \ -{tuple_type}[int, ...])}}, memo) - args = check_variable_assignment((5,), 'args', \ -{tuple_type}[int, ...], memo) + check_argument_types('foo', {'args': (args, \ +tuple[int, ...])}, memo) + args = check_variable_assignment((5,), \ +[[('args', tuple[int, ...])]], memo) """ ).strip() ) @@ -1202,27 +1252,20 @@ def foo(**kwargs: int) -> None: ) TypeguardTransformer().visit(node) - if sys.version_info < (3, 9): - extra_import = "from typing import Dict\n" - dict_type = "Dict" - else: - extra_import = "" - dict_type = "dict" - assert ( unparse(node) == dedent( - f""" + """ from typeguard import TypeCheckMemo from typeguard._functions import check_argument_types, \ check_variable_assignment - {extra_import} + def foo(**kwargs: int) -> None: memo = TypeCheckMemo(globals(), locals()) - check_argument_types('foo', {{'kwargs': (kwargs, \ -{dict_type}[str, int])}}, memo) - kwargs = check_variable_assignment({{'a': 5}}, 'kwargs', \ -{dict_type}[str, int], memo) + check_argument_types('foo', {'kwargs': (kwargs, \ +dict[str, int])}, memo) + kwargs = check_variable_assignment({'a': 5}, \ +[[('kwargs', dict[str, int])]], memo) """ ).strip() ) @@ -1251,8 +1294,8 @@ def foo() -> None: def foo() -> None: memo = TypeCheckMemo(globals(), locals()) - x: int | str = check_variable_assignment(otherfunc(), 'x', \ -Union_[int, str], memo) + x: int | str = check_variable_assignment(otherfunc(), \ +[[('x', Union_[int, str])]], memo) """ ).strip() ) @@ -1275,15 +1318,15 @@ def foo() -> None: == dedent( f""" from typeguard import TypeCheckMemo - from typeguard._functions import check_multi_variable_assignment + from typeguard._functions import check_variable_assignment from typing import Any def foo() -> None: memo = TypeCheckMemo(globals(), locals()) x: int z: bytes - {target} = check_multi_variable_assignment(otherfunc(), \ -[{{'x': int, 'y': Any, 'z': bytes}}], memo) + {target} = check_variable_assignment(otherfunc(), \ +[[('x', int), ('y', Any), ('z', bytes)]], memo) """ ).strip() ) @@ -1306,15 +1349,80 @@ def foo() -> None: == dedent( f""" from typeguard import TypeCheckMemo - from typeguard._functions import check_multi_variable_assignment + from typeguard._functions import check_variable_assignment + from typing import Any + + def foo() -> None: + memo = TypeCheckMemo(globals(), locals()) + x: int + z: bytes + {target} = check_variable_assignment(otherfunc(), \ +[[('x', int), ('*y', Any), ('z', bytes)]], memo) + """ + ).strip() + ) + + def test_complex_multi_assign(self) -> None: + node = parse( + dedent( + """ + def foo() -> None: + x: int + z: bytes + all = x, *y, z = otherfunc() + """ + ) + ) + TypeguardTransformer().visit(node) + target = "x, *y, z" if sys.version_info >= (3, 11) else "(x, *y, z)" + assert ( + unparse(node) + == dedent( + f""" + from typeguard import TypeCheckMemo + from typeguard._functions import check_variable_assignment from typing import Any def foo() -> None: memo = TypeCheckMemo(globals(), locals()) x: int z: bytes - {target} = check_multi_variable_assignment(otherfunc(), \ -[{{'x': int, '*y': Any, 'z': bytes}}], memo) + all = {target} = check_variable_assignment(otherfunc(), \ +[[('all', Any)], [('x', int), ('*y', Any), ('z', bytes)]], memo) + """ + ).strip() + ) + + def test_unpacking_assign_to_self(self) -> None: + node = parse( + dedent( + """ + class Foo: + + def foo(self) -> None: + x: int + (x, self.y) = 1, 'test' + """ + ) + ) + TypeguardTransformer().visit(node) + target = "x, self.y" if sys.version_info >= (3, 11) else "(x, self.y)" + assert ( + unparse(node) + == dedent( + f""" + from typeguard import TypeCheckMemo + from typeguard._functions import check_variable_assignment + from typing import Any + + class Foo: + + def foo(self) -> None: + memo = TypeCheckMemo(globals(), locals(), \ +self_type=self.__class__) + x: int + {target} = check_variable_assignment((1, 'test'), \ +[[('x', int), ('self.y', Any)]], memo) """ ).strip() ) @@ -1340,7 +1448,7 @@ def foo(x: int) -> None: def foo(x: int) -> None: memo = TypeCheckMemo(globals(), locals()) check_argument_types('foo', {'x': (x, int)}, memo) - x = check_variable_assignment(6, 'x', int, memo) + x = check_variable_assignment(6, [[('x', int)]], memo) """ ).strip() ) @@ -1441,7 +1549,8 @@ def foo() -> None: def foo() -> None: memo = TypeCheckMemo(globals(), locals()) x: int - x = check_variable_assignment({function}(x, 6), 'x', int, memo) + x = check_variable_assignment({function}(x, 6), [[('x', int)]], \ +memo) """ ).strip() ) @@ -1490,7 +1599,7 @@ def foo(x: int) -> None: def foo(x: int) -> None: memo = TypeCheckMemo(globals(), locals()) check_argument_types('foo', {'x': (x, int)}, memo) - x = check_variable_assignment(iadd(x, 6), 'x', int, memo) + x = check_variable_assignment(iadd(x, 6), [[('x', int)]], memo) """ ).strip() ) diff --git a/tests/test_typechecked.py b/tests/test_typechecked.py index dbb516f..d56f3ae 100644 --- a/tests/test_typechecked.py +++ b/tests/test_typechecked.py @@ -456,6 +456,29 @@ def method(cls, another: Self) -> None: rf"test_classmethod_arg_invalid\.\.Foo\)" ) + def test_self_type_valid(self): + class Foo: + @typechecked + def method(cls, subclass: type[Self]) -> None: + pass + + class Bar(Foo): + pass + + Foo().method(Bar) + + def test_self_type_invalid(self): + class Foo: + @typechecked + def method(cls, subclass: type[Self]) -> None: + pass + + pytest.raises(TypeCheckError, Foo().method, int).match( + rf'argument "subclass" \(class int\) is not a subclass of the self type ' + rf"\({__name__}\.{self.__class__.__name__}\." + rf"test_self_type_invalid\.\.Foo\)" + ) + class TestMock: def test_mock_argument(self): @@ -619,9 +642,8 @@ def foo(x: int) -> None: ) assert process.returncode == expected_return_code if process.returncode == 1: - assert process.stderr.endswith( - b'typeguard.TypeCheckError: argument "x" (str) is not an instance of ' - b"int\n" + assert process.stderr.strip().endswith( + b'typeguard.TypeCheckError: argument "x" (str) is not an instance of int' ) diff --git a/tests/test_union_transformer.py b/tests/test_union_transformer.py index dc45679..e6dcd25 100644 --- a/tests/test_union_transformer.py +++ b/tests/test_union_transformer.py @@ -1,13 +1,17 @@ import typing -from typing import Callable +from typing import Callable, Union import pytest from typing_extensions import Literal -from typeguard._union_transformer import compile_type_hint, type_substitutions +from typeguard._union_transformer import compile_type_hint -eval_globals = {"Callable": Callable, "Literal": Literal, "typing": typing} -eval_globals.update(type_substitutions) +eval_globals = { + "Callable": Callable, + "Literal": Literal, + "typing": typing, + "Union": Union, +} @pytest.mark.parametrize( @@ -15,12 +19,12 @@ [ ["str | int", "Union[str, int]"], ["str | int | bytes", "Union[str, int, bytes]"], - ["str | Union[int | bytes, set]", "Union[str, int, bytes, Set]"], + ["str | Union[int | bytes, set]", "Union[str, int, bytes, set]"], ["str | int | Callable[..., bytes]", "Union[str, int, Callable[..., bytes]]"], ["str | int | Callable[[], bytes]", "Union[str, int, Callable[[], bytes]]"], [ "str | int | Callable[[], bytes | set]", - "Union[str, int, Callable[[], Union[bytes, Set]]]", + "Union[str, int, Callable[[], Union[bytes, set]]]", ], ["str | int | Literal['foo']", "Union[str, int, Literal['foo']]"], ["str | int | Literal[-1]", "Union[str, int, Literal[-1]]"], @@ -29,11 +33,6 @@ 'str | int | Literal["It\'s a string \'\\""]', "Union[str, int, Literal['It\\'s a string \\'\"']]", ], - [ - "typing.Tuple | typing.List | Literal[-1]", - "Union[Tuple, List, Literal[-1]]", - ], - ["tuple[int, ...]", "Tuple[int, ...]"], ], ) def test_union_transformer(inputval: str, expected: str) -> None: