diff --git a/optype/inspect.py b/optype/inspect.py new file mode 100644 index 0000000..7ae61c1 --- /dev/null +++ b/optype/inspect.py @@ -0,0 +1,130 @@ +from __future__ import annotations + +import inspect +import itertools +import sys +from typing import TYPE_CHECKING, Any, cast, get_args as _get_args + + +if sys.version_info >= (3, 13): + from typing import TypeAliasType, is_protocol +else: + from typing_extensions import TypeAliasType, is_protocol + + +if TYPE_CHECKING: + from types import ModuleType + + +__all__ = ( + 'get_args', + 'get_protocol_members', + 'get_protocols', + 'is_runtime_protocol', +) + + +def is_runtime_protocol(cls: type, /) -> bool: + """ + Check if `cls` is a `typing[_extensions].Protocol` that's decorated with + `typing[_extensions].runtime_checkable`. + """ + return is_protocol(cls) and getattr(cls, '_is_runtime_protocol', False) + + +def get_args(tp: TypeAliasType | type | str | Any, /) -> tuple[Any, ...]: + """ + A less broken implementation of `typing[_extensions].get_args()` that + + - also works for type aliases defined with the PEP 695 `type` keyword, and + - recursively flattens nested of union'ed type aliases. + """ + args = _get_args(tp.__value__ if isinstance(tp, TypeAliasType) else tp) + return tuple(itertools.chain(*(get_args(arg) or [arg] for arg in args))) + + +def get_protocol_members(cls: type, /) -> frozenset[str]: + """ + A variant of `typing[_extensions].get_protocol_members()` that + + - doesn't hide `__dict__` or `__annotations__`, + - doesn't add a `__hash__` if there's an `__eq__` method, and + - doesn't include methods of base types from different module. + """ + if not is_protocol(cls): + msg = f'{cls!r} is not a protocol' + raise TypeError(msg) + + module_blacklist = {'typing', 'typing_extensions'} + annotations, module = cls.__annotations__, cls.__module__ + members = annotations.keys() | { + name for name, v in vars(cls).items() + if ( + name != '__new__' + and callable(v) + and ( + v.__module__ == module + or ( + # Fun fact: Each `@overload` returns the same dummy + # function; so there's no reference your wrapped method :). + # Oh and BTW; `typing.get_overloads` only works on the + # non-overloaded method... + # Oh, you mean the one that # you shouldn't define within + # a `typing.Protocol`? + # Yes exactly! Anyway, good luck searching for the + # undocumented and ever-changing dark corner of the + # `typing` internals. I'm sure it must be there somewhere! + # Oh yea if you can't find it, try `typing_extensions`. + # Oh, still can't find it? Did you try ALL THE VERSIONS? + # + # ...anyway, the only thing we know here, is the name of + # an overloaded method. But we have no idea how many of + # them there *were*, let alone their signatures. + v.__module__ in module_blacklist + and v.__name__ == '_overload_dummy' + ) + ) + ) or ( + isinstance(v, property) + and v.fget + and v.fget.__module__ == module + ) + } + + # this hack here is plagiarized from the (often incorrect) + # `typing_extensions.get_protocol_members`. + # Maybe the `typing.get_protocol_member`s` that's coming in 3.13 will + # won't be as broken. I have little hope though... + members |= cast( + set[str], + getattr(cls, '__protocol_attrs__', None) or set(), + ) + + # sometimes __protocol_attrs__ hallicunates some non-existing dunders. + # the `getattr_static` avoids potential descriptor magic + members = { + member for member in members + if member in annotations + or inspect.getattr_static(cls, member) is not None + # or getattr(cls, member) is not None + } + + # also include any of the parents + for supercls in cls.mro()[1:]: + if is_protocol(supercls): + members |= get_protocol_members(supercls) + + return frozenset(members) + + +def get_protocols( + module: ModuleType, + /, + private: bool = False, +) -> frozenset[type]: + """Return the protocol types within the given module.""" + return frozenset({ + cls for name in dir(module) + if (private or not name.startswith('_')) + and is_protocol(cls := getattr(module, name)) + }) diff --git a/tests/helpers.py b/tests/helpers.py index 3935bc4..165964f 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -1,15 +1,13 @@ from __future__ import annotations -import inspect import sys -from typing import TYPE_CHECKING, Any, cast, get_args as _get_args +from typing import TYPE_CHECKING if sys.version_info >= (3, 13): - from typing import TypeAliasType, is_protocol + from typing import is_protocol else: - from typing_extensions import TypeAliasType, is_protocol - + from typing_extensions import is_protocol if TYPE_CHECKING: from types import ModuleType @@ -19,122 +17,23 @@ __all__ = ( 'get_callable_members', - 'get_protocol_members', - 'get_protocols', 'is_dunder', 'is_protocol', 'pascamel_to_snake', ) -def is_runtime_protocol(cls: type, /) -> bool: - """Check if `cls` is a `@runtime_checkable` `typing.Protocol`.""" - return is_protocol(cls) and getattr(cls, '_is_runtime_protocol', False) - - -def get_args(tp: TypeAliasType | type | str | Any, /) -> tuple[Any, ...]: +def get_callable_members(module: ModuleType, /) -> frozenset[str]: """ - A less broken `typing.get_args()` that also works for type aliases defined - with the PEP695 `type` keyword, and recurses nested type aliases. + Return the public callables of a module, that aren't protocols, and """ - args = _get_args(tp.__value__ if isinstance(tp, TypeAliasType) else tp) - args_flat: list[Any] = [] - for arg in args: - # recurse nested - args_flat.extend(get_args(arg) or [arg]) - - return tuple(args_flat) - - -def get_protocol_members(cls: type, /) -> frozenset[str]: - """ - A variant of `typing_extensions.get_protocol_members()` that doesn't - hide e.g. `__dict__` and `__annotations__`, or adds `__hash__` if there's - an `__eq__` method. - Does not return method names of base classes defined in another module. - """ - assert is_protocol(cls) - - module = cls.__module__ - annotations = cls.__annotations__ - - members = annotations.keys() | { - name for name, v in vars(cls).items() - if ( - name != '__new__' - and callable(v) - and ( - v.__module__ == module - or ( - # Fun fact: Each `@overload` returns the same dummy - # function; so there's no reference your wrapped method :). - # Oh and BTW; `typing.get_overloads` only works on the - # non-overloaded method... - # Oh, you mean the one that # you shouldn't define within - # a `typing.Protocol`? - # Yes exactly! Anyway, good luck searching for the - # undocumented and ever-changing dark corner of the - # `typing` internals. I'm sure it must be there somewhere! - # Oh yea if you can't find it, try `typing_extensions`. - # Oh, still can't find it? Did you try ALL THE VERSIONS? - # - # ...anyway, the only thing we know here, is the name of - # an overloaded method. But we have no idea how many of - # them there *were*, let alone their signatures. - v.__module__.startswith('typing') - and v.__name__ == '_overload_dummy' - ) - ) - ) or ( - isinstance(v, property) - and v.fget - and v.fget.__module__ == module - ) - } - - # this hack here is plagiarized from the (often incorrect) - # `typing_extensions.get_protocol_members`. - # Maybe the `typing.get_protocol_member`s` that's coming in 3.13 will - # won't be as broken. I have little hope though... - members |= cast( - set[str], - getattr(cls, '__protocol_attrs__', None) or set(), - ) - - # sometimes __protocol_attrs__ hallicunates some non-existing dunders. - # the `getattr_static` avoids potential descriptor magic - members = { - member for member in members - if member in annotations - or inspect.getattr_static(cls, member) is not None - # or getattr(cls, member) is not None - } - - # also include any of the parents - for supercls in cls.mro()[1:]: - if is_protocol(supercls): - members |= get_protocol_members(supercls) - - return frozenset(members) - - -def get_protocols(module: ModuleType) -> frozenset[type]: - """Return the public protocol types within the given module.""" - return frozenset({ - cls for name in dir(module) - if not name.startswith('_') - and is_protocol(cls := getattr(module, name)) - }) - - -def get_callable_members(module: ModuleType) -> frozenset[str]: - """Return the public protocol types within the given module.""" + module_blacklist = {'typing', 'typing_extensions'} return frozenset({ name for name in dir(module) if not name.startswith('_') and callable(cls := getattr(module, name)) and not is_protocol(cls) - and getattr(module, name).__module__ != 'typing' + and getattr(module, name).__module__ not in module_blacklist }) diff --git a/tests/numpy/test_sctypes.py b/tests/numpy/test_sctypes.py index c761bc5..d255a2c 100644 --- a/tests/numpy/test_sctypes.py +++ b/tests/numpy/test_sctypes.py @@ -6,7 +6,7 @@ import optype.numpy as onp from optype.numpy import _sctype # noqa: PLC2701 # pyright: ignore[reportPrivateUsage] -from ..helpers import get_args # noqa: TID252 +from optype.inspect import get_args _TEMPORAL = 'Timedelta64', 'Datetime64' diff --git a/tests/test_protocols.py b/tests/test_protocols.py index 8945a5b..f877e2f 100644 --- a/tests/test_protocols.py +++ b/tests/test_protocols.py @@ -5,14 +5,16 @@ import optype._do import optype._does import optype._has +from optype.inspect import ( + get_protocol_members, + get_protocols, + is_runtime_protocol, +) from .helpers import ( get_callable_members, - get_protocol_members, - get_protocols, is_dunder, is_protocol, - is_runtime_protocol, pascamel_to_snake, )