diff --git a/docs/versionhistory.rst b/docs/versionhistory.rst index 1aa9762..f79a795 100644 --- a/docs/versionhistory.rst +++ b/docs/versionhistory.rst @@ -14,6 +14,8 @@ This library adheres to - Fixed ``typing`` types (``dict[str, int]``, ``List[str]``, etc.) not passing checks against ``type`` or ``Type`` (`#432 `_, PR by Yongxin Wang) +- Fixed detection of optional fields (``NotRequired[...]``) in ``TypedDict`` when using + forward references (`#424 `_) **4.1.5** (2023-09-11) diff --git a/src/typeguard/_checkers.py b/src/typeguard/_checkers.py index f79da91..2f8de6f 100644 --- a/src/typeguard/_checkers.py +++ b/src/typeguard/_checkers.py @@ -51,20 +51,20 @@ if sys.version_info >= (3, 11): from typing import ( Annotated, + NotRequired, TypeAlias, get_args, get_origin, - get_type_hints, ) SubclassableAny = Any else: from typing_extensions import ( Annotated, + NotRequired, TypeAlias, get_args, get_origin, - get_type_hints, ) from typing_extensions import Any as SubclassableAny @@ -251,22 +251,33 @@ def check_typed_dict( declared_keys = frozenset(origin_type.__annotations__) if hasattr(origin_type, "__required_keys__"): - required_keys = origin_type.__required_keys__ + required_keys = set(origin_type.__required_keys__) else: # py3.8 and lower - required_keys = declared_keys if origin_type.__total__ else frozenset() + required_keys = set(declared_keys) if origin_type.__total__ else set() - existing_keys = frozenset(value) + existing_keys = set(value) extra_keys = existing_keys - declared_keys if extra_keys: keys_formatted = ", ".join(f'"{key}"' for key in sorted(extra_keys, key=repr)) raise TypeCheckError(f"has unexpected extra key(s): {keys_formatted}") + # Detect NotRequired fields which are hidden by get_type_hints() + type_hints: dict[str, type] = {} + for key, annotation in origin_type.__annotations__.items(): + if isinstance(annotation, ForwardRef): + annotation = evaluate_forwardref(annotation, memo) + if get_origin(annotation) is NotRequired: + required_keys.discard(key) + annotation = get_args(annotation)[0] + + type_hints[key] = annotation + missing_keys = required_keys - existing_keys if missing_keys: keys_formatted = ", ".join(f'"{key}"' for key in sorted(missing_keys, key=repr)) raise TypeCheckError(f"is missing required key(s): {keys_formatted}") - for key, argtype in get_type_hints(origin_type).items(): + for key, argtype in type_hints.items(): argvalue = value.get(key, _missing) if argvalue is not _missing: try: diff --git a/tests/test_checkers.py b/tests/test_checkers.py index f9aef31..d767b4f 100644 --- a/tests/test_checkers.py +++ b/tests/test_checkers.py @@ -492,6 +492,33 @@ class DummyDict(typing_provider.TypedDict): TypeCheckError, check_type, {"x": 1, "y": 2, b"z": 3}, DummyDict ).match(r'dict has unexpected extra key\(s\): "y", "b\'z\'"') + def test_notrequired_pass(self, typing_provider): + try: + NotRequired = typing_provider.NotRequired + except AttributeError: + pytest.skip(f"'NotRequired' not found in {typing_provider.__name__!r}") + + class DummyDict(typing_provider.TypedDict): + x: int + y: "NotRequired[int]" + + check_type({"x": 8}, DummyDict) + + def test_notrequired_fail(self, typing_provider): + try: + NotRequired = typing_provider.NotRequired + except AttributeError: + pytest.skip(f"'NotRequired' not found in {typing_provider.__name__!r}") + + class DummyDict(typing_provider.TypedDict): + x: int + y: "NotRequired[int]" + + with pytest.raises( + TypeCheckError, match=r"value of key 'y' of dict is not an instance of int" + ): + check_type({"x": 1, "y": "foo"}, DummyDict) + class TestList: def test_bad_type(self):