Skip to content

Commit

Permalink
move some testing helper functions to a new optype.inspect module
Browse files Browse the repository at this point in the history
  • Loading branch information
jorenham committed Jun 23, 2024
1 parent 9419d6a commit 9b51280
Show file tree
Hide file tree
Showing 4 changed files with 143 additions and 112 deletions.
130 changes: 130 additions & 0 deletions optype/inspect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
from __future__ import annotations

import inspect
import itertools
import sys
from typing import TYPE_CHECKING, Any, cast, get_args as _get_args


if sys.version_info >= (3, 13):
from typing import TypeAliasType, is_protocol
else:
from typing_extensions import TypeAliasType, is_protocol


if TYPE_CHECKING:
from types import ModuleType


__all__ = (
'get_args',
'get_protocol_members',
'get_protocols',
'is_runtime_protocol',
)


def is_runtime_protocol(cls: type, /) -> bool:
"""
Check if `cls` is a `typing[_extensions].Protocol` that's decorated with
`typing[_extensions].runtime_checkable`.
"""
return is_protocol(cls) and getattr(cls, '_is_runtime_protocol', False)


def get_args(tp: TypeAliasType | type | str | Any, /) -> tuple[Any, ...]:
"""
A less broken implementation of `typing[_extensions].get_args()` that
- also works for type aliases defined with the PEP 695 `type` keyword, and
- recursively flattens nested of union'ed type aliases.
"""
args = _get_args(tp.__value__ if isinstance(tp, TypeAliasType) else tp)
return tuple(itertools.chain(*(get_args(arg) or [arg] for arg in args)))


def get_protocol_members(cls: type, /) -> frozenset[str]:
"""
A variant of `typing[_extensions].get_protocol_members()` that
- doesn't hide `__dict__` or `__annotations__`,
- doesn't add a `__hash__` if there's an `__eq__` method, and
- doesn't include methods of base types from different module.
"""
if not is_protocol(cls):
msg = f'{cls!r} is not a protocol'
raise TypeError(msg)

module_blacklist = {'typing', 'typing_extensions'}
annotations, module = cls.__annotations__, cls.__module__
members = annotations.keys() | {
name for name, v in vars(cls).items()
if (
name != '__new__'
and callable(v)
and (
v.__module__ == module
or (
# Fun fact: Each `@overload` returns the same dummy
# function; so there's no reference your wrapped method :).
# Oh and BTW; `typing.get_overloads` only works on the
# non-overloaded method...
# Oh, you mean the one that # you shouldn't define within
# a `typing.Protocol`?
# Yes exactly! Anyway, good luck searching for the
# undocumented and ever-changing dark corner of the
# `typing` internals. I'm sure it must be there somewhere!
# Oh yea if you can't find it, try `typing_extensions`.
# Oh, still can't find it? Did you try ALL THE VERSIONS?
#
# ...anyway, the only thing we know here, is the name of
# an overloaded method. But we have no idea how many of
# them there *were*, let alone their signatures.
v.__module__ in module_blacklist
and v.__name__ == '_overload_dummy'
)
)
) or (
isinstance(v, property)
and v.fget
and v.fget.__module__ == module
)
}

# this hack here is plagiarized from the (often incorrect)
# `typing_extensions.get_protocol_members`.
# Maybe the `typing.get_protocol_member`s` that's coming in 3.13 will
# won't be as broken. I have little hope though...
members |= cast(
set[str],
getattr(cls, '__protocol_attrs__', None) or set(),
)

# sometimes __protocol_attrs__ hallicunates some non-existing dunders.
# the `getattr_static` avoids potential descriptor magic
members = {
member for member in members
if member in annotations
or inspect.getattr_static(cls, member) is not None
# or getattr(cls, member) is not None
}

# also include any of the parents
for supercls in cls.mro()[1:]:
if is_protocol(supercls):
members |= get_protocol_members(supercls)

return frozenset(members)


def get_protocols(
module: ModuleType,
/,
private: bool = False,
) -> frozenset[type]:
"""Return the protocol types within the given module."""
return frozenset({
cls for name in dir(module)
if (private or not name.startswith('_'))
and is_protocol(cls := getattr(module, name))
})
115 changes: 7 additions & 108 deletions tests/helpers.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
from __future__ import annotations

import inspect
import sys
from typing import TYPE_CHECKING, Any, cast, get_args as _get_args
from typing import TYPE_CHECKING


if sys.version_info >= (3, 13):
from typing import TypeAliasType, is_protocol
from typing import is_protocol
else:
from typing_extensions import TypeAliasType, is_protocol

from typing_extensions import is_protocol

if TYPE_CHECKING:
from types import ModuleType
Expand All @@ -19,122 +17,23 @@

__all__ = (
'get_callable_members',
'get_protocol_members',
'get_protocols',
'is_dunder',
'is_protocol',
'pascamel_to_snake',
)


def is_runtime_protocol(cls: type, /) -> bool:
"""Check if `cls` is a `@runtime_checkable` `typing.Protocol`."""
return is_protocol(cls) and getattr(cls, '_is_runtime_protocol', False)


def get_args(tp: TypeAliasType | type | str | Any, /) -> tuple[Any, ...]:
def get_callable_members(module: ModuleType, /) -> frozenset[str]:
"""
A less broken `typing.get_args()` that also works for type aliases defined
with the PEP695 `type` keyword, and recurses nested type aliases.
Return the public callables of a module, that aren't protocols, and
"""
args = _get_args(tp.__value__ if isinstance(tp, TypeAliasType) else tp)
args_flat: list[Any] = []
for arg in args:
# recurse nested
args_flat.extend(get_args(arg) or [arg])

return tuple(args_flat)


def get_protocol_members(cls: type, /) -> frozenset[str]:
"""
A variant of `typing_extensions.get_protocol_members()` that doesn't
hide e.g. `__dict__` and `__annotations__`, or adds `__hash__` if there's
an `__eq__` method.
Does not return method names of base classes defined in another module.
"""
assert is_protocol(cls)

module = cls.__module__
annotations = cls.__annotations__

members = annotations.keys() | {
name for name, v in vars(cls).items()
if (
name != '__new__'
and callable(v)
and (
v.__module__ == module
or (
# Fun fact: Each `@overload` returns the same dummy
# function; so there's no reference your wrapped method :).
# Oh and BTW; `typing.get_overloads` only works on the
# non-overloaded method...
# Oh, you mean the one that # you shouldn't define within
# a `typing.Protocol`?
# Yes exactly! Anyway, good luck searching for the
# undocumented and ever-changing dark corner of the
# `typing` internals. I'm sure it must be there somewhere!
# Oh yea if you can't find it, try `typing_extensions`.
# Oh, still can't find it? Did you try ALL THE VERSIONS?
#
# ...anyway, the only thing we know here, is the name of
# an overloaded method. But we have no idea how many of
# them there *were*, let alone their signatures.
v.__module__.startswith('typing')
and v.__name__ == '_overload_dummy'
)
)
) or (
isinstance(v, property)
and v.fget
and v.fget.__module__ == module
)
}

# this hack here is plagiarized from the (often incorrect)
# `typing_extensions.get_protocol_members`.
# Maybe the `typing.get_protocol_member`s` that's coming in 3.13 will
# won't be as broken. I have little hope though...
members |= cast(
set[str],
getattr(cls, '__protocol_attrs__', None) or set(),
)

# sometimes __protocol_attrs__ hallicunates some non-existing dunders.
# the `getattr_static` avoids potential descriptor magic
members = {
member for member in members
if member in annotations
or inspect.getattr_static(cls, member) is not None
# or getattr(cls, member) is not None
}

# also include any of the parents
for supercls in cls.mro()[1:]:
if is_protocol(supercls):
members |= get_protocol_members(supercls)

return frozenset(members)


def get_protocols(module: ModuleType) -> frozenset[type]:
"""Return the public protocol types within the given module."""
return frozenset({
cls for name in dir(module)
if not name.startswith('_')
and is_protocol(cls := getattr(module, name))
})


def get_callable_members(module: ModuleType) -> frozenset[str]:
"""Return the public protocol types within the given module."""
module_blacklist = {'typing', 'typing_extensions'}
return frozenset({
name for name in dir(module)
if not name.startswith('_')
and callable(cls := getattr(module, name))
and not is_protocol(cls)
and getattr(module, name).__module__ != 'typing'
and getattr(module, name).__module__ not in module_blacklist
})


Expand Down
2 changes: 1 addition & 1 deletion tests/numpy/test_sctypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import optype.numpy as onp
from optype.numpy import _sctype # noqa: PLC2701 # pyright: ignore[reportPrivateUsage]
from ..helpers import get_args # noqa: TID252
from optype.inspect import get_args


_TEMPORAL = 'Timedelta64', 'Datetime64'
Expand Down
8 changes: 5 additions & 3 deletions tests/test_protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,16 @@
import optype._do
import optype._does
import optype._has
from optype.inspect import (
get_protocol_members,
get_protocols,
is_runtime_protocol,
)

from .helpers import (
get_callable_members,
get_protocol_members,
get_protocols,
is_dunder,
is_protocol,
is_runtime_protocol,
pascamel_to_snake,
)

Expand Down

0 comments on commit 9b51280

Please sign in to comment.