Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Track "unwrapped" types during routine resolution #9

Merged
merged 1 commit into from
Oct 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 20 additions & 18 deletions src/typelib/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,6 @@ def static_order(
"""
# We want to leverage the cache if possible, hence the recursive call.
# Shouldn't actually recurse more than once or twice.
if inspection.istypealiastype(t):
value = inspection.unwrap(t)
return static_order(value)
if isinstance(t, (str, refs.ForwardRef)):
ref = refs.forwardref(t) if isinstance(t, str) else t
t = refs.evaluate(ref)
Expand Down Expand Up @@ -80,8 +77,6 @@ def itertypes(
[`static_order`][typelib.graph.static_order] instead of
[`itertypes`][typelib.graph.itertypes].
"""
if inspection.istypealiastype(t):
t = inspection.unwrap(t)
if isinstance(t, (str, refs.ForwardRef)): # pragma: no cover
ref = refs.forwardref(t) if isinstance(t, str) else t
t = refs.evaluate(ref)
Expand Down Expand Up @@ -113,33 +108,31 @@ def get_type_graph(t: type) -> graphlib.TopologicalSorter[TypeNode]:
resolve one level deep on each attempt, otherwise we will find ourselves stuck
in a closed loop which never terminates (infinite recursion).
"""
if inspection.istypealiastype(t):
t = inspection.unwrap(t)

graph: graphlib.TopologicalSorter = graphlib.TopologicalSorter()
root = TypeNode(t)
u = inspection.unwrap(t)
root = TypeNode(t, u)
stack = collections.deque([root])
visited = {root.type}
while stack:
parent = stack.popleft()
if inspection.isliteral(parent.type):
parent_unwrapped = inspection.unwrap(parent.type)
if inspection.isliteral(parent_unwrapped):
graph.add(parent)
continue

predecessors = []
for var, child in _level(parent.type):
for var, child in _level(parent_unwrapped):
# If no type was provided, there's no reason to do further processing.
if child in (constants.empty, typing.Any):
continue
if inspection.istypealiastype(child):
child = inspection.unwrap(child)

unwrapped = inspection.unwrap(child)
# Only subscripted generics or non-stdlib types can be cyclic.
# i.e., we may get `str` or `datetime` any number of times,
# that's not cyclic, so we can just add it to the graph.
is_visited = child in visited
is_subscripted = inspection.issubscriptedgeneric(child)
is_stdlib = inspection.isstdlibtype(child)
is_visited = child in visited or unwrapped in visited
is_subscripted = inspection.issubscriptedgeneric(unwrapped)
is_stdlib = inspection.isstdlibtype(unwrapped)
can_be_cyclic = is_subscripted or is_stdlib is False
# We detected a cyclic type,
# wrap in a ForwardRef and don't add it to the stack
Expand All @@ -155,10 +148,13 @@ def get_type_graph(t: type) -> graphlib.TopologicalSorter[TypeNode]:
ref = refs.forwardref(
refname, is_argument=is_argument, module=module, is_class=is_class
)
node = TypeNode(ref, var=var, cyclic=True)
uref = refs.forwardref(
unwrapped, is_argument=is_argument, module=module, is_class=is_class
)
node = TypeNode(ref, uref, var=var, cyclic=True)
# Otherwise, add the type to the stack and track that it's been seen.
else:
node = TypeNode(type=child, var=var)
node = TypeNode(type=child, unwrapped=unwrapped, var=var)
visited.add(node.type)
stack.append(node)
# Flag the type as a "predecessor" of the parent type.
Expand All @@ -177,11 +173,17 @@ class TypeNode:

type: typing.Any
"""The type annotation for this node."""
unwrapped: typing.Any | None = None
"""The unwrapped type annotation for this node."""
var: str | None = None
"""The variable or parameter name associated to the type annotation for this node."""
cyclic: bool = dataclasses.field(default=False, hash=False, compare=False)
"""Whether this type annotation is cyclic."""

def __post_init__(self):
if self.unwrapped is None:
self.unwrapped = self.type


def _level(t: typing.Any) -> typing.Iterable[tuple[str | None, type]]:
args = inspection.args(t)
Expand Down
9 changes: 6 additions & 3 deletions src/typelib/marshals/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def marshaller(
root = nodes[-1]
for node in nodes:
context[node.type] = _get_unmarshaller(node, context=context)
context[node.unwrapped] = context[node.type]

return context[root.type]

Expand All @@ -66,10 +67,12 @@ def _get_unmarshaller( # type: ignore[return]
return context[node.type]

for check, unmarshaller_cls in _HANDLERS.items():
if check(node.type):
return unmarshaller_cls(node.type, context=context, var=node.var)
if check(node.unwrapped):
return unmarshaller_cls(node.unwrapped, context=context, var=node.var)

return routines.StructuredTypeMarshaller(node.type, context=context, var=node.var)
return routines.StructuredTypeMarshaller(
node.unwrapped, context=context, var=node.var
)


class DelayedMarshaller(routines.AbstractMarshaller[T]):
Expand Down
2 changes: 1 addition & 1 deletion src/typelib/py/inspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,7 +573,7 @@ def isoptionaltype(obj: type[_OT]) -> compat.TypeIs[type[tp.Optional[_OT]]]:
tname = name(origin(obj))
nullarg = next((a for a in args if a in (type(None), None)), ...)
isoptional = tname == "Optional" or (
nullarg is not ... and tname in ("Union", "Uniontype", "Literal")
nullarg is not ... and tname in ("Union", "UnionType", "Literal")
)
return isoptional

Expand Down
9 changes: 6 additions & 3 deletions src/typelib/unmarshals/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def unmarshaller(
root = nodes[-1]
for node in nodes:
context[node.type] = _get_unmarshaller(node, context=context)
context[node.unwrapped] = context[node.type]

return context[root.type]

Expand All @@ -60,10 +61,12 @@ def _get_unmarshaller( # type: ignore[return]
return context[node.type]

for check, unmarshaller_cls in _HANDLERS.items():
if check(node.type):
return unmarshaller_cls(node.type, context=context, var=node.var)
if check(node.unwrapped):
return unmarshaller_cls(node.unwrapped, context=context, var=node.var)

return routines.StructuredTypeUnmarshaller(node.type, context=context, var=node.var)
return routines.StructuredTypeUnmarshaller(
node.unwrapped, context=context, var=node.var
)


class DelayedUnmarshaller(routines.AbstractUnmarshaller[T]):
Expand Down
5 changes: 5 additions & 0 deletions tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,8 @@ class NestedTypeAliasType:
RecursiveAlias = compat.TypeAliasType(
"RecursiveAlias", "dict[str, RecursiveAlias | ValueAlias]"
)

ScalarValue = compat.TypeAliasType("ScalarValue", "int | float | str | bool | None")
Record = compat.TypeAliasType(
"Record", "dict[str, list[Record] | list[ScalarValue] | Record | ScalarValue]"
)
9 changes: 2 additions & 7 deletions tests/unit/marshals/test_routines.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

import pytest

from typelib import graph
from typelib.marshals import routines

from tests import models
Expand Down Expand Up @@ -405,12 +404,8 @@ def test_fixed_tuple_marshaller(
@pytest.mark.suite(
context=dict(
given_context={
graph.TypeNode(int, var="value"): routines.IntegerMarshaller(
int, {}, var="value"
),
graph.TypeNode(str, var="field"): routines.StringMarshaller(
str, {}, var="field"
),
int: routines.IntegerMarshaller(int, {}, var="value"),
str: routines.StringMarshaller(str, {}, var="field"),
},
expected_output=dict(field="data", value=1),
),
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ class NoTypes:
given_type=models.NestedTypeAliasType,
expected_nodes=[
graph.TypeNode(type=int),
graph.TypeNode(type=list[int], var="alias"),
graph.TypeNode(type=models.ListAlias, unwrapped=list[int], var="alias"),
graph.TypeNode(type=NestedTypeAliasType),
],
),
Expand Down
Loading