Skip to content

Commit

Permalink
Fixed detection of optional TypedDict fields when NotRequired is a fo…
Browse files Browse the repository at this point in the history
…rward reference

Fixes #424.
  • Loading branch information
agronholm committed Mar 23, 2024
1 parent 2df6f4a commit 750c719
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 6 deletions.
2 changes: 2 additions & 0 deletions docs/versionhistory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ This library adheres to
- Fixed ``typing`` types (``dict[str, int]``, ``List[str]``, etc.) not passing checks
against ``type`` or ``Type``
(`#432 <https://github.com/agronholm/typeguard/issues/432>`_, PR by Yongxin Wang)
- Fixed detection of optional fields (``NotRequired[...]``) in ``TypedDict`` when using
forward references (`#424 <https://github.com/agronholm/typeguard/issues/424>`_)

**4.1.5** (2023-09-11)

Expand Down
23 changes: 17 additions & 6 deletions src/typeguard/_checkers.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,20 +51,20 @@
if sys.version_info >= (3, 11):
from typing import (
Annotated,
NotRequired,
TypeAlias,
get_args,
get_origin,
get_type_hints,
)

SubclassableAny = Any
else:
from typing_extensions import (
Annotated,
NotRequired,
TypeAlias,
get_args,
get_origin,
get_type_hints,
)
from typing_extensions import Any as SubclassableAny

Expand Down Expand Up @@ -251,22 +251,33 @@ def check_typed_dict(

declared_keys = frozenset(origin_type.__annotations__)
if hasattr(origin_type, "__required_keys__"):
required_keys = origin_type.__required_keys__
required_keys = set(origin_type.__required_keys__)
else: # py3.8 and lower
required_keys = declared_keys if origin_type.__total__ else frozenset()
required_keys = set(declared_keys) if origin_type.__total__ else set()

existing_keys = frozenset(value)
existing_keys = set(value)
extra_keys = existing_keys - declared_keys
if extra_keys:
keys_formatted = ", ".join(f'"{key}"' for key in sorted(extra_keys, key=repr))
raise TypeCheckError(f"has unexpected extra key(s): {keys_formatted}")

# Detect NotRequired fields which are hidden by get_type_hints()
type_hints: dict[str, type] = {}
for key, annotation in origin_type.__annotations__.items():
if isinstance(annotation, ForwardRef):
annotation = evaluate_forwardref(annotation, memo)
if get_origin(annotation) is NotRequired:
required_keys.discard(key)
annotation = get_args(annotation)[0]

type_hints[key] = annotation

missing_keys = required_keys - existing_keys
if missing_keys:
keys_formatted = ", ".join(f'"{key}"' for key in sorted(missing_keys, key=repr))
raise TypeCheckError(f"is missing required key(s): {keys_formatted}")

for key, argtype in get_type_hints(origin_type).items():
for key, argtype in type_hints.items():
argvalue = value.get(key, _missing)
if argvalue is not _missing:
try:
Expand Down
27 changes: 27 additions & 0 deletions tests/test_checkers.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,33 @@ class DummyDict(typing_provider.TypedDict):
TypeCheckError, check_type, {"x": 1, "y": 2, b"z": 3}, DummyDict
).match(r'dict has unexpected extra key\(s\): "y", "b\'z\'"')

def test_notrequired_pass(self, typing_provider):
try:
NotRequired = typing_provider.NotRequired
except AttributeError:
pytest.skip(f"'NotRequired' not found in {typing_provider.__name__!r}")

class DummyDict(typing_provider.TypedDict):
x: int
y: "NotRequired[int]"

check_type({"x": 8}, DummyDict)

def test_notrequired_fail(self, typing_provider):
try:
NotRequired = typing_provider.NotRequired
except AttributeError:
pytest.skip(f"'NotRequired' not found in {typing_provider.__name__!r}")

class DummyDict(typing_provider.TypedDict):
x: int
y: "NotRequired[int]"

with pytest.raises(
TypeCheckError, match=r"value of key 'y' of dict is not an instance of int"
):
check_type({"x": 1, "y": "foo"}, DummyDict)


class TestList:
def test_bad_type(self):
Expand Down

0 comments on commit 750c719

Please sign in to comment.