diff --git a/docs/versionhistory.rst b/docs/versionhistory.rst index c6fd657..28b92a3 100644 --- a/docs/versionhistory.rst +++ b/docs/versionhistory.rst @@ -13,6 +13,8 @@ This library adheres to (`#486 `_) - Fixed ``TypeError`` when checking a class against ``type[Self]`` (`#481 `_) +- Fixed checking of protocols on the class level (against ``type[SomeProtocol]``) + (`#498 `_) **4.4.0** (2024-10-27) diff --git a/src/typeguard/_checkers.py b/src/typeguard/_checkers.py index 8166bf2..1ab6ee2 100644 --- a/src/typeguard/_checkers.py +++ b/src/typeguard/_checkers.py @@ -635,10 +635,8 @@ def check_io( raise TypeCheckError("is not an I/O object") -def check_signature_compatible( - subject_callable: Callable[..., Any], protocol: type, attrname: str -) -> None: - subject_sig = inspect.signature(subject_callable) +def check_signature_compatible(subject: type, protocol: type, attrname: str) -> None: + subject_sig = inspect.signature(getattr(subject, attrname)) protocol_sig = inspect.signature(getattr(protocol, attrname)) protocol_type: typing.Literal["instance", "class", "static"] = "instance" subject_type: typing.Literal["instance", "class", "static"] = "instance" @@ -652,12 +650,12 @@ def check_signature_compatible( protocol_type = "class" # Check if the subject-side method is a class method or static method - if inspect.ismethod(subject_callable) and inspect.isclass( - subject_callable.__self__ - ): - subject_type = "class" - elif not hasattr(subject_callable, "__self__"): - subject_type = "static" + if attrname in subject.__dict__: + descriptor = subject.__dict__[attrname] + if isinstance(descriptor, staticmethod): + subject_type = "static" + elif isinstance(descriptor, classmethod): + subject_type = "class" if protocol_type == "instance" and subject_type != "instance": raise TypeCheckError( @@ -714,6 +712,10 @@ def check_signature_compatible( if protocol_type == "instance": protocol_args.pop(0) + # Remove the "self" parameter from the subject arguments to match + if subject_type == "instance": + subject_args.pop(0) + for protocol_arg, subject_arg in zip_longest(protocol_args, subject_args): if protocol_arg is None: if subject_arg.default is Parameter.empty: @@ -818,8 +820,9 @@ def check_protocol( # TODO: implement assignability checks for parameter and return value # annotations + subject = value if isclass(value) else value.__class__ try: - check_signature_compatible(subject_member, origin_type, attrname) + check_signature_compatible(subject, origin_type, attrname) except TypeCheckError as exc: raise TypeCheckError( f"is not compatible with the {origin_type.__qualname__} " diff --git a/tests/test_checkers.py b/tests/test_checkers.py index d2bcd81..526e94f 100644 --- a/tests/test_checkers.py +++ b/tests/test_checkers.py @@ -1069,7 +1069,11 @@ def test_raises_for_non_member(self, subject: object, predicate_type: type) -> N class TestProtocol: - def test_success(self, typing_provider: Any) -> None: + @pytest.mark.parametrize( + "instantiate", + [pytest.param(True, id="instance"), pytest.param(False, id="class")], + ) + def test_success(self, typing_provider: Any, instantiate: bool) -> None: class MyProtocol(Protocol): member: int @@ -1129,7 +1133,10 @@ def my_static_method(cls, x: int, y: str) -> None: def my_class_method(x: int, y: str) -> None: pass - check_type(Foo(), MyProtocol) + if instantiate: + check_type(Foo(), MyProtocol) + else: + check_type(Foo, type[MyProtocol]) @pytest.mark.parametrize("has_member", [True, False]) def test_member_checks(self, has_member: bool) -> None: