Skip to content

Commit

Permalink
Fixed checking of protocols on the class level
Browse files Browse the repository at this point in the history
Fixes #498.
  • Loading branch information
agronholm committed Nov 3, 2024
1 parent 121efd5 commit 28dafec
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 13 deletions.
2 changes: 2 additions & 0 deletions docs/versionhistory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ This library adheres to
(`#486 <https://github.com/agronholm/typeguard/pull/486>`_)
- Fixed ``TypeError`` when checking a class against ``type[Self]``
(`#481 <https://github.com/agronholm/typeguard/pull/481>`_)
- Fixed checking of protocols on the class level (against ``type[SomeProtocol]``)
(`#498 <https://github.com/agronholm/typeguard/pull/498>`_)

**4.4.0** (2024-10-27)

Expand Down
25 changes: 14 additions & 11 deletions src/typeguard/_checkers.py
Original file line number Diff line number Diff line change
Expand Up @@ -635,10 +635,8 @@ def check_io(
raise TypeCheckError("is not an I/O object")


def check_signature_compatible(
subject_callable: Callable[..., Any], protocol: type, attrname: str
) -> None:
subject_sig = inspect.signature(subject_callable)
def check_signature_compatible(subject: type, protocol: type, attrname: str) -> None:
subject_sig = inspect.signature(getattr(subject, attrname))
protocol_sig = inspect.signature(getattr(protocol, attrname))
protocol_type: typing.Literal["instance", "class", "static"] = "instance"
subject_type: typing.Literal["instance", "class", "static"] = "instance"
Expand All @@ -652,12 +650,12 @@ def check_signature_compatible(
protocol_type = "class"

# Check if the subject-side method is a class method or static method
if inspect.ismethod(subject_callable) and inspect.isclass(
subject_callable.__self__
):
subject_type = "class"
elif not hasattr(subject_callable, "__self__"):
subject_type = "static"
if attrname in subject.__dict__:
descriptor = subject.__dict__[attrname]
if isinstance(descriptor, staticmethod):
subject_type = "static"
elif isinstance(descriptor, classmethod):
subject_type = "class"

if protocol_type == "instance" and subject_type != "instance":
raise TypeCheckError(
Expand Down Expand Up @@ -714,6 +712,10 @@ def check_signature_compatible(
if protocol_type == "instance":
protocol_args.pop(0)

# Remove the "self" parameter from the subject arguments to match
if subject_type == "instance":
subject_args.pop(0)

for protocol_arg, subject_arg in zip_longest(protocol_args, subject_args):
if protocol_arg is None:
if subject_arg.default is Parameter.empty:
Expand Down Expand Up @@ -818,8 +820,9 @@ def check_protocol(

# TODO: implement assignability checks for parameter and return value
# annotations
subject = value if isclass(value) else value.__class__
try:
check_signature_compatible(subject_member, origin_type, attrname)
check_signature_compatible(subject, origin_type, attrname)
except TypeCheckError as exc:
raise TypeCheckError(
f"is not compatible with the {origin_type.__qualname__} "
Expand Down
11 changes: 9 additions & 2 deletions tests/test_checkers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1069,7 +1069,11 @@ def test_raises_for_non_member(self, subject: object, predicate_type: type) -> N


class TestProtocol:
def test_success(self, typing_provider: Any) -> None:
@pytest.mark.parametrize(
"instantiate",
[pytest.param(True, id="instance"), pytest.param(False, id="class")],
)
def test_success(self, typing_provider: Any, instantiate: bool) -> None:
class MyProtocol(Protocol):
member: int

Expand Down Expand Up @@ -1129,7 +1133,10 @@ def my_static_method(cls, x: int, y: str) -> None:
def my_class_method(x: int, y: str) -> None:
pass

check_type(Foo(), MyProtocol)
if instantiate:
check_type(Foo(), MyProtocol)
else:
check_type(Foo, type[MyProtocol])

@pytest.mark.parametrize("has_member", [True, False])
def test_member_checks(self, has_member: bool) -> None:
Expand Down

0 comments on commit 28dafec

Please sign in to comment.