Skip to content

Commit

Permalink
fix: Handle recursive and cyclic TypeAliasType
Browse files Browse the repository at this point in the history
- Fixes an issue where a direct recursive or indirect reference to a `TypeAliasType` would produce an unusable type-hint and fail to resolve the forward reference.
  • Loading branch information
seandstewart committed Oct 30, 2024
1 parent b3ec9e5 commit 94c8fa3
Show file tree
Hide file tree
Showing 7 changed files with 46 additions and 17 deletions.
9 changes: 5 additions & 4 deletions src/typelib/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ def static_order(
# We want to leverage the cache if possible, hence the recursive call.
# Shouldn't actually recurse more than once or twice.
if inspection.istypealiastype(t):
return static_order(t.__value__)
value = inspection.unwrap(t)
return static_order(value)
if isinstance(t, (str, refs.ForwardRef)):
ref = refs.forwardref(t) if isinstance(t, str) else t
t = refs.evaluate(ref)
Expand Down Expand Up @@ -80,7 +81,7 @@ def itertypes(
[`itertypes`][typelib.graph.itertypes].
"""
if inspection.istypealiastype(t):
t = t.__value__
t = inspection.unwrap(t)
if isinstance(t, (str, refs.ForwardRef)): # pragma: no cover
ref = refs.forwardref(t) if isinstance(t, str) else t
t = refs.evaluate(ref)
Expand Down Expand Up @@ -113,7 +114,7 @@ def get_type_graph(t: type) -> graphlib.TopologicalSorter[TypeNode]:
in a closed loop which never terminates (infinite recursion).
"""
if inspection.istypealiastype(t):
t = t.__value__
t = inspection.unwrap(t)

graph: graphlib.TopologicalSorter = graphlib.TopologicalSorter()
root = TypeNode(t)
Expand All @@ -131,7 +132,7 @@ def get_type_graph(t: type) -> graphlib.TopologicalSorter[TypeNode]:
if child in (constants.empty, typing.Any):
continue
if inspection.istypealiastype(child):
child = child.__value__
child = inspection.unwrap(child)

# Only subscripted generics or non-stdlib types can be cyclic.
# i.e., we may get `str` or `datetime` any number of times,
Expand Down
8 changes: 4 additions & 4 deletions src/typelib/marshals/routines.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ def __init__(self, t: type[UnionT], context: ContextT, *, var: str | None = None
var: A variable name for the indicated type annotation (unused, optional).
"""
super().__init__(t, context, var=var)
self.stack = inspection.args(t)
self.stack = inspection.args(t, evaluate=True)
self.nullable = inspection.isoptionaltype(t)
self.ordered_routines = [self.context[typ] for typ in self.stack]

Expand Down Expand Up @@ -352,7 +352,7 @@ def __init__(self, t: type[MappingT], context: ContextT, *, var: str | None = No
var: A variable name for the indicated type annotation (unused, optional).
"""
super().__init__(t, context, var=var)
key_t, value_t = inspection.args(t)
key_t, value_t = inspection.args(t, evaluate=True)
self.keys = context[key_t]
self.values = context[value_t]

Expand Down Expand Up @@ -387,7 +387,7 @@ def __init__(
"""
super().__init__(t=t, context=context, var=var)
# supporting tuple[str, ...]
(value_t, *_) = inspection.args(t)
(value_t, *_) = inspection.args(t, evaluate=True)
self.values = context[value_t]

def __call__(self, val: IterableT) -> MarshalledIterableT:
Expand Down Expand Up @@ -423,7 +423,7 @@ def __init__(
var: A variable name for the indicated type annotation (unused, optional).
"""
super().__init__(t, context, var=var)
self.stack = inspection.args(t)
self.stack = inspection.args(t, evaluate=True)
self.ordered_routines = [self.context[vt] for vt in self.stack]

def __call__(self, val: compat.TupleT) -> MarshalledIterableT:
Expand Down
18 changes: 15 additions & 3 deletions src/typelib/py/inspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def _check_generics(hint: tp.Any):
}


def args(annotation: tp.Any) -> tp.Tuple[tp.Any, ...]:
def args(annotation: tp.Any, *, evaluate: bool = False) -> tp.Tuple[tp.Any, ...]:
"""Get the args supplied to an annotation, normalizing [`typing.TypeVar`][].
Note:
Expand Down Expand Up @@ -184,6 +184,9 @@ def args(annotation: tp.Any) -> tp.Tuple[tp.Any, ...]:
if not a:
a = getattr(annotation, "__args__", a)

if evaluate:
a = (*(refs.evaluate(r) for r in a),)

return (*_normalize_typevars(*a),)


Expand Down Expand Up @@ -1485,18 +1488,27 @@ def istypealiastype(t: tp.Any) -> compat.TypeIs[compat.TypeAliasType]:

@compat.cache
def unwrap(t: tp.Any) -> tp.Any:
while True:
lt = None
while lt is not t:
if should_unwrap(t):
lt = t
t = t.__args__[0]
continue
if istypealiastype(t):
t = t.__value__
tv = t.__value__
if issubclass(type(tv), str):
return refs.forwardref(tv, module=t.__module__)
lt = t
t = tv
continue

if hasattr(t, "__supertype__"):
lt = t
t = t.__supertype__
continue

return t
return t


def _safe_issubclass(__cls: type, __class_or_tuple: type | tuple[type, ...]) -> bool:
Expand Down
12 changes: 6 additions & 6 deletions src/typelib/unmarshals/routines.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,7 +628,7 @@ def __init__(self, t: type[LiteralT], context: ContextT, *, var: str | None = No
var: A variable name for the indicated type annotation (unused, optional).
"""
super().__init__(t, context, var=var)
self.values = inspection.args(t)
self.values = inspection.args(t, evaluate=True)

def __call__(self, val: tp.Any) -> LiteralT:
if val in self.values:
Expand Down Expand Up @@ -677,7 +677,7 @@ def __init__(self, t: type[UnionT], context: ContextT, *, var: str | None = None
var: A variable name for the indicated type annotation (unused, optional).
"""
super().__init__(t, context, var=var)
self.stack = inspection.args(t)
self.stack = inspection.args(t, evaluate=True)
if inspection.isoptionaltype(t):
self.stack = (self.stack[-1], *self.stack[:-1])

Expand Down Expand Up @@ -746,7 +746,7 @@ def __init__(self, t: type[MappingT], context: ContextT, *, var: str | None = No
var: A variable name for the indicated type annotation (unused, optional).
"""
super().__init__(t, context, var=var)
key_t, value_t = inspection.args(t)
key_t, value_t = inspection.args(t, evaluate=True)
self.keys = context[key_t]
self.values = context[value_t]

Expand Down Expand Up @@ -804,7 +804,7 @@ def __init__(
"""
super().__init__(t=t, context=context, var=var)
# supporting tuple[str, ...]
(value_t, *_) = inspection.args(t)
(value_t, *_) = inspection.args(t, evaluate=True)
self.values = context[value_t]

def __call__(self, val: tp.Any) -> IterableT:
Expand Down Expand Up @@ -857,7 +857,7 @@ def __init__(
var: A variable name for the indicated type annotation (unused, optional).
"""
super().__init__(t, context, var=var)
(value_t,) = inspection.args(t)
(value_t,) = inspection.args(t, evaluate=True)
self.values = context[value_t]

def __call__(self, val: tp.Any) -> IteratorT:
Expand Down Expand Up @@ -916,7 +916,7 @@ def __init__(
var: A variable name for the indicated type annotation (unused, optional).
"""
super().__init__(t, context, var=var)
self.stack = inspection.args(t)
self.stack = inspection.args(t, evaluate=True)
self.ordered_routines = [self.context[vt] for vt in self.stack]

def __call__(self, val: tp.Any) -> compat.TupleT:
Expand Down
6 changes: 6 additions & 0 deletions tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,9 @@ class ChildIntersect:
@dataclasses.dataclass
class NestedTypeAliasType:
alias: ListAlias


ValueAlias = compat.TypeAliasType("ValueAlias", int)
RecursiveAlias = compat.TypeAliasType(
"RecursiveAlias", "dict[str, RecursiveAlias | ValueAlias]"
)
5 changes: 5 additions & 0 deletions tests/unit/marshals/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,11 @@
given_input=models.NestedTypeAliasType(alias=[1]),
expected_output={"alias": [1]},
),
recursive_alias=dict(
given_type=models.RecursiveAlias,
given_input={"cycle": {"cycle": {"cycle": 1}}},
expected_output={"cycle": {"cycle": {"cycle": 1}}},
),
)
def test_marshal(given_type, given_input, expected_output):
# When
Expand Down
5 changes: 5 additions & 0 deletions tests/unit/unmarshals/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,11 @@
given_input={"alias": ["1"]},
expected_output=models.NestedTypeAliasType(alias=[1]),
),
recursive_alias=dict(
given_type=models.RecursiveAlias,
given_input={"cycle": {"cycle": {"cycle": "1"}}},
expected_output={"cycle": {"cycle": {"cycle": 1}}},
),
)
def test_unmarshal(given_type, given_input, expected_output):
# When
Expand Down

0 comments on commit 94c8fa3

Please sign in to comment.