Skip to content

Commit

Permalink
Added support for checking against static protocols
Browse files Browse the repository at this point in the history
Fixes #457.
  • Loading branch information
agronholm committed May 26, 2024
1 parent d539190 commit 241d120
Show file tree
Hide file tree
Showing 5 changed files with 199 additions and 48 deletions.
16 changes: 15 additions & 1 deletion docs/features.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ The following type checks are not yet supported in Typeguard:
* Types of values assigned to global or nonlocal variables
* Stubs defined with :func:`@overload <typing.overload>` (the implementation is checked
if instrumented)
* ``yield_from`` statements in generator functions
* ``yield from`` statements in generator functions
* ``ParamSpec`` and ``Concatenate`` are currently ignored
* Types where they are shadowed by arguments with the same name (e.g.
``def foo(x: type, type: str): ...``)
Expand Down Expand Up @@ -58,6 +58,20 @@ target function should be switched to a new one. To work around this limitation,
place :func:`@typechecked <typechecked>` at the bottom of the decorator stack, or use
the import hook instead.

Protocol checking
+++++++++++++++++

As of version 4.3.0, Typeguard can check instances and classes against Protocols,
regardless of whether they were annotated with :decorator:`typing.runtime_checkable`.

There are several limitations on the checks performed, however:

* For non-callable members, only presence is checked for; no type compatibility checks
are performed
* For methods, only the number of positional arguments are checked against, so any added
keyword-only arguments without defaults don't currently trip the checker
* Likewise, argument types are not checked for compatibility

Special considerations for ``if TYPE_CHECKING:``
------------------------------------------------

Expand Down
1 change: 1 addition & 0 deletions docs/versionhistory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ This library adheres to

**UNRELEASED**

- Added support for checking against static protocols
- Fixed some compatibility problems when running on Python 3.13
(`#460 <https://github.com/agronholm/typeguard/issues/460>`_; PR by @JelleZijlstra)
- Fixed test suite incompatibility with pytest 8.2
Expand Down
103 changes: 92 additions & 11 deletions src/typeguard/_checkers.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
Union,
)
from unittest.mock import Mock
from weakref import WeakKeyDictionary

try:
import typing_extensions
Expand Down Expand Up @@ -88,6 +89,9 @@
if sys.version_info >= (3, 9):
generic_alias_types += (types.GenericAlias,)

protocol_check_cache: WeakKeyDictionary[
type[Any], dict[type[Any], TypeCheckError | None]
] = WeakKeyDictionary()

# Sentinel
_missing = object()
Expand Down Expand Up @@ -650,19 +654,96 @@ def check_protocol(
args: tuple[Any, ...],
memo: TypeCheckMemo,
) -> None:
# TODO: implement proper compatibility checking and support non-runtime protocols
if getattr(origin_type, "_is_runtime_protocol", False):
if not isinstance(value, origin_type):
raise TypeCheckError(
f"is not compatible with the {origin_type.__qualname__} protocol"
subject: type[Any] = value if isclass(value) else type(value)

if subject in protocol_check_cache:
result_map = protocol_check_cache[subject]
if origin_type in result_map:
if exc := result_map[origin_type]:
raise exc
else:
return

# Collect a set of methods and non-method attributes present in the protocol
ignored_attrs = set(dir(typing.Protocol)) | {
"__annotations__",
"__non_callable_proto_members__",
}
expected_methods: dict[str, tuple[Any, Any]] = {}
expected_noncallable_members: dict[str, Any] = {}
for attrname in dir(origin_type):
# Skip attributes present in typing.Protocol
if attrname in ignored_attrs:
continue

member = getattr(origin_type, attrname)
if callable(member):
signature = inspect.signature(member)
argtypes = [
(p.annotation if p.annotation is not Parameter.empty else Any)
for p in signature.parameters.values()
if p.kind is not Parameter.KEYWORD_ONLY
] or Ellipsis
return_annotation = (
signature.return_annotation
if signature.return_annotation is not Parameter.empty
else Any
)
expected_methods[attrname] = argtypes, return_annotation
else:
expected_noncallable_members[attrname] = member

for attrname, annotation in typing.get_type_hints(origin_type).items():
expected_noncallable_members[attrname] = annotation

subject_annotations = typing.get_type_hints(subject)

# Check that all required methods are present and their signatures are compatible
result_map = protocol_check_cache.setdefault(subject, {})
try:
for attrname, callable_args in expected_methods.items():
try:
method = getattr(subject, attrname)
except AttributeError:
if attrname in subject_annotations:
raise TypeCheckError(
f"is not compatible with the {origin_type.__qualname__} protocol "
f"because its {attrname!r} attribute is not a method"
) from None
else:
raise TypeCheckError(
f"is not compatible with the {origin_type.__qualname__} protocol "
f"because it has no method named {attrname!r}"
) from None

if not callable(method):
raise TypeCheckError(
f"is not compatible with the {origin_type.__qualname__} protocol "
f"because its {attrname!r} attribute is not a callable"
)

# TODO: raise exception on added keyword-only arguments without defaults
try:
check_callable(method, Callable, callable_args, memo)
except TypeCheckError as exc:
raise TypeCheckError(
f"is not compatible with the {origin_type.__qualname__} protocol "
f"because its {attrname!r} method {exc}"
) from None

# Check that all required non-callable members are present
for attrname in expected_noncallable_members:
# TODO: implement assignability checks for non-callable members
if attrname not in subject_annotations and not hasattr(subject, attrname):
raise TypeCheckError(
f"is not compatible with the {origin_type.__qualname__} protocol "
f"because it has no attribute named {attrname!r}"
)
except TypeCheckError as exc:
result_map[origin_type] = exc
raise
else:
warnings.warn(
f"Typeguard cannot check the {origin_type.__qualname__} protocol because "
f"it is a non-runtime protocol. If you would like to type check this "
f"protocol, please use @typing.runtime_checkable",
stacklevel=get_stacklevel(),
)
result_map[origin_type] = None


def check_byteslike(
Expand Down
6 changes: 4 additions & 2 deletions tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,13 @@ def method(self, a: int) -> None:


class StaticProtocol(Protocol):
def meth(self) -> None: ...
member: int

def meth(self, x: str) -> None: ...


@runtime_checkable
class RuntimeProtocol(Protocol):
member: int

def meth(self) -> None: ...
def meth(self, x: str) -> None: ...
121 changes: 87 additions & 34 deletions tests/test_checkers.py
Original file line number Diff line number Diff line change
Expand Up @@ -995,66 +995,119 @@ def test_text_real_file(self, tmp_path: Path):
check_type(f, TextIO)


@pytest.mark.parametrize(
"instantiate, annotation",
[
pytest.param(True, RuntimeProtocol, id="instance_runtime"),
pytest.param(False, Type[RuntimeProtocol], id="class_runtime"),
pytest.param(True, StaticProtocol, id="instance_static"),
pytest.param(False, Type[StaticProtocol], id="class_static"),
],
)
class TestProtocol:
def test_protocol(self):
def test_member_defaultval(self, instantiate, annotation):
class Foo:
member = 1

def meth(self) -> None:
def meth(self, x: str) -> None:
pass

check_type(Foo(), RuntimeProtocol)
check_type(Foo, Type[RuntimeProtocol])
subject = Foo() if instantiate else Foo
for _ in range(2): # Makes sure that the cache is also exercised
check_type(subject, annotation)

def test_protocol_warns_on_static(self):
def test_member_annotation(self, instantiate, annotation):
class Foo:
member = 1
member: int

def meth(self) -> None:
def meth(self, x: str) -> None:
pass

with pytest.warns(
UserWarning, match=r"Typeguard cannot check the StaticProtocol protocol.*"
) as warning:
check_type(Foo(), StaticProtocol)
subject = Foo() if instantiate else Foo
for _ in range(2):
check_type(subject, annotation)

assert warning.list[0].filename == __file__
def test_attribute_missing(self, instantiate, annotation):
class Foo:
val = 1

with pytest.warns(
UserWarning, match=r"Typeguard cannot check the StaticProtocol protocol.*"
) as warning:
check_type(Foo, Type[StaticProtocol])
def meth(self, x: str) -> None:
pass

assert warning.list[0].filename == __file__
clsname = f"{__name__}.TestProtocol.test_attribute_missing.<locals>.Foo"
subject = Foo() if instantiate else Foo
for _ in range(2):
pytest.raises(TypeCheckError, check_type, subject, annotation).match(
f"{clsname} is not compatible with the (Runtime|Static)Protocol "
f"protocol because it has no attribute named 'member'"
)

def test_fail_non_method_members(self):
def test_method_missing(self, instantiate, annotation):
class Foo:
val = 1
member: int

def meth(self) -> None:
pass
pattern = (
f"{__name__}.TestProtocol.test_method_missing.<locals>.Foo is not "
f"compatible with the (Runtime|Static)Protocol protocol because it has no "
f"method named 'meth'"
)
subject = Foo() if instantiate else Foo
for _ in range(2):
pytest.raises(TypeCheckError, check_type, subject, annotation).match(
pattern
)

def test_attribute_is_not_method_1(self, instantiate, annotation):
class Foo:
member: int
meth: str

clsname = f"{__name__}.TestProtocol.test_fail_non_method_members.<locals>.Foo"
pytest.raises(TypeCheckError, check_type, Foo(), RuntimeProtocol).match(
f"{clsname} is not compatible with the RuntimeProtocol protocol"
pattern = (
f"{__name__}.TestProtocol.test_attribute_is_not_method_1.<locals>.Foo is "
f"not compatible with the (Runtime|Static)Protocol protocol because its "
f"'meth' attribute is not a method"
)
pytest.raises(TypeCheckError, check_type, Foo, Type[RuntimeProtocol]).match(
f"class {clsname} is not compatible with the RuntimeProtocol protocol"
subject = Foo() if instantiate else Foo
for _ in range(2):
pytest.raises(TypeCheckError, check_type, subject, annotation).match(
pattern
)

def test_attribute_is_not_method_2(self, instantiate, annotation):
class Foo:
member: int
meth = "foo"

pattern = (
f"{__name__}.TestProtocol.test_attribute_is_not_method_2.<locals>.Foo is "
f"not compatible with the (Runtime|Static)Protocol protocol because its "
f"'meth' attribute is not a callable"
)
subject = Foo() if instantiate else Foo
for _ in range(2):
pytest.raises(TypeCheckError, check_type, subject, annotation).match(
pattern
)

def test_fail(self):
def test_method_signature_mismatch(self, instantiate, annotation):
class Foo:
def meth2(self) -> None:
member: int

def meth(self, x: str, y: int) -> None:
pass

pattern = (
f"{__name__}.TestProtocol.test_fail.<locals>.Foo is not compatible with "
f"the RuntimeProtocol protocol"
)
pytest.raises(TypeCheckError, check_type, Foo(), RuntimeProtocol).match(pattern)
pytest.raises(TypeCheckError, check_type, Foo, Type[RuntimeProtocol]).match(
pattern
rf"(class )?{__name__}.TestProtocol.test_method_signature_mismatch."
rf"<locals>.Foo is not compatible with the (Runtime|Static)Protocol "
rf"protocol because its 'meth' method has too many mandatory positional "
rf"arguments in its declaration; expected 2 but 3 mandatory positional "
rf"argument\(s\) declared"
)
subject = Foo() if instantiate else Foo
for _ in range(2):
pytest.raises(TypeCheckError, check_type, subject, annotation).match(
pattern
)


class TestRecursiveType:
Expand Down

0 comments on commit 241d120

Please sign in to comment.