From a05be1e199079fa895c6560faba69065d8a3298e Mon Sep 17 00:00:00 2001 From: Sean Stewart Date: Wed, 30 Oct 2024 13:23:20 -0400 Subject: [PATCH] fix: Track "unwrapped" types during routine resolution - Fixes an issue where a `TypeAliasType` would not correctly resolve in a highly nested, recursive context. --- src/typelib/graph.py | 38 +++++++++++++++------------- src/typelib/marshals/api.py | 9 ++++--- src/typelib/py/inspection.py | 2 +- src/typelib/unmarshals/api.py | 9 ++++--- tests/models.py | 5 ++++ tests/unit/marshals/test_routines.py | 9 ++----- tests/unit/test_graph.py | 2 +- 7 files changed, 41 insertions(+), 33 deletions(-) diff --git a/src/typelib/graph.py b/src/typelib/graph.py index 31c93ee..656dc69 100644 --- a/src/typelib/graph.py +++ b/src/typelib/graph.py @@ -50,9 +50,6 @@ def static_order( """ # We want to leverage the cache if possible, hence the recursive call. # Shouldn't actually recurse more than once or twice. - if inspection.istypealiastype(t): - value = inspection.unwrap(t) - return static_order(value) if isinstance(t, (str, refs.ForwardRef)): ref = refs.forwardref(t) if isinstance(t, str) else t t = refs.evaluate(ref) @@ -80,8 +77,6 @@ def itertypes( [`static_order`][typelib.graph.static_order] instead of [`itertypes`][typelib.graph.itertypes]. """ - if inspection.istypealiastype(t): - t = inspection.unwrap(t) if isinstance(t, (str, refs.ForwardRef)): # pragma: no cover ref = refs.forwardref(t) if isinstance(t, str) else t t = refs.evaluate(ref) @@ -113,33 +108,31 @@ def get_type_graph(t: type) -> graphlib.TopologicalSorter[TypeNode]: resolve one level deep on each attempt, otherwise we will find ourselves stuck in a closed loop which never terminates (infinite recursion). """ - if inspection.istypealiastype(t): - t = inspection.unwrap(t) - graph: graphlib.TopologicalSorter = graphlib.TopologicalSorter() - root = TypeNode(t) + u = inspection.unwrap(t) + root = TypeNode(t, u) stack = collections.deque([root]) visited = {root.type} while stack: parent = stack.popleft() - if inspection.isliteral(parent.type): + parent_unwrapped = inspection.unwrap(parent.type) + if inspection.isliteral(parent_unwrapped): graph.add(parent) continue predecessors = [] - for var, child in _level(parent.type): + for var, child in _level(parent_unwrapped): # If no type was provided, there's no reason to do further processing. if child in (constants.empty, typing.Any): continue - if inspection.istypealiastype(child): - child = inspection.unwrap(child) + unwrapped = inspection.unwrap(child) # Only subscripted generics or non-stdlib types can be cyclic. # i.e., we may get `str` or `datetime` any number of times, # that's not cyclic, so we can just add it to the graph. - is_visited = child in visited - is_subscripted = inspection.issubscriptedgeneric(child) - is_stdlib = inspection.isstdlibtype(child) + is_visited = child in visited or unwrapped in visited + is_subscripted = inspection.issubscriptedgeneric(unwrapped) + is_stdlib = inspection.isstdlibtype(unwrapped) can_be_cyclic = is_subscripted or is_stdlib is False # We detected a cyclic type, # wrap in a ForwardRef and don't add it to the stack @@ -155,10 +148,13 @@ def get_type_graph(t: type) -> graphlib.TopologicalSorter[TypeNode]: ref = refs.forwardref( refname, is_argument=is_argument, module=module, is_class=is_class ) - node = TypeNode(ref, var=var, cyclic=True) + uref = refs.forwardref( + unwrapped, is_argument=is_argument, module=module, is_class=is_class + ) + node = TypeNode(ref, uref, var=var, cyclic=True) # Otherwise, add the type to the stack and track that it's been seen. else: - node = TypeNode(type=child, var=var) + node = TypeNode(type=child, unwrapped=unwrapped, var=var) visited.add(node.type) stack.append(node) # Flag the type as a "predecessor" of the parent type. @@ -177,11 +173,17 @@ class TypeNode: type: typing.Any """The type annotation for this node.""" + unwrapped: typing.Any | None = None + """The unwrapped type annotation for this node.""" var: str | None = None """The variable or parameter name associated to the type annotation for this node.""" cyclic: bool = dataclasses.field(default=False, hash=False, compare=False) """Whether this type annotation is cyclic.""" + def __post_init__(self): + if self.unwrapped is None: + self.unwrapped = self.type + def _level(t: typing.Any) -> typing.Iterable[tuple[str | None, type]]: args = inspection.args(t) diff --git a/src/typelib/marshals/api.py b/src/typelib/marshals/api.py index 8c5cb23..f30a693 100644 --- a/src/typelib/marshals/api.py +++ b/src/typelib/marshals/api.py @@ -54,6 +54,7 @@ def marshaller( root = nodes[-1] for node in nodes: context[node.type] = _get_unmarshaller(node, context=context) + context[node.unwrapped] = context[node.type] return context[root.type] @@ -66,10 +67,12 @@ def _get_unmarshaller( # type: ignore[return] return context[node.type] for check, unmarshaller_cls in _HANDLERS.items(): - if check(node.type): - return unmarshaller_cls(node.type, context=context, var=node.var) + if check(node.unwrapped): + return unmarshaller_cls(node.unwrapped, context=context, var=node.var) - return routines.StructuredTypeMarshaller(node.type, context=context, var=node.var) + return routines.StructuredTypeMarshaller( + node.unwrapped, context=context, var=node.var + ) class DelayedMarshaller(routines.AbstractMarshaller[T]): diff --git a/src/typelib/py/inspection.py b/src/typelib/py/inspection.py index ca01d80..224225e 100644 --- a/src/typelib/py/inspection.py +++ b/src/typelib/py/inspection.py @@ -573,7 +573,7 @@ def isoptionaltype(obj: type[_OT]) -> compat.TypeIs[type[tp.Optional[_OT]]]: tname = name(origin(obj)) nullarg = next((a for a in args if a in (type(None), None)), ...) isoptional = tname == "Optional" or ( - nullarg is not ... and tname in ("Union", "Uniontype", "Literal") + nullarg is not ... and tname in ("Union", "UnionType", "Literal") ) return isoptional diff --git a/src/typelib/unmarshals/api.py b/src/typelib/unmarshals/api.py index ded29f6..330c779 100644 --- a/src/typelib/unmarshals/api.py +++ b/src/typelib/unmarshals/api.py @@ -48,6 +48,7 @@ def unmarshaller( root = nodes[-1] for node in nodes: context[node.type] = _get_unmarshaller(node, context=context) + context[node.unwrapped] = context[node.type] return context[root.type] @@ -60,10 +61,12 @@ def _get_unmarshaller( # type: ignore[return] return context[node.type] for check, unmarshaller_cls in _HANDLERS.items(): - if check(node.type): - return unmarshaller_cls(node.type, context=context, var=node.var) + if check(node.unwrapped): + return unmarshaller_cls(node.unwrapped, context=context, var=node.var) - return routines.StructuredTypeUnmarshaller(node.type, context=context, var=node.var) + return routines.StructuredTypeUnmarshaller( + node.unwrapped, context=context, var=node.var + ) class DelayedUnmarshaller(routines.AbstractUnmarshaller[T]): diff --git a/tests/models.py b/tests/models.py index 83668e8..39aed48 100644 --- a/tests/models.py +++ b/tests/models.py @@ -99,3 +99,8 @@ class NestedTypeAliasType: RecursiveAlias = compat.TypeAliasType( "RecursiveAlias", "dict[str, RecursiveAlias | ValueAlias]" ) + +ScalarValue = compat.TypeAliasType("ScalarValue", "int | float | str | bool | None") +Record = compat.TypeAliasType( + "Record", "dict[str, list[Record] | list[ScalarValue] | Record | ScalarValue]" +) diff --git a/tests/unit/marshals/test_routines.py b/tests/unit/marshals/test_routines.py index a6c34bf..16a7cff 100644 --- a/tests/unit/marshals/test_routines.py +++ b/tests/unit/marshals/test_routines.py @@ -10,7 +10,6 @@ import pytest -from typelib import graph from typelib.marshals import routines from tests import models @@ -405,12 +404,8 @@ def test_fixed_tuple_marshaller( @pytest.mark.suite( context=dict( given_context={ - graph.TypeNode(int, var="value"): routines.IntegerMarshaller( - int, {}, var="value" - ), - graph.TypeNode(str, var="field"): routines.StringMarshaller( - str, {}, var="field" - ), + int: routines.IntegerMarshaller(int, {}, var="value"), + str: routines.StringMarshaller(str, {}, var="field"), }, expected_output=dict(field="data", value=1), ), diff --git a/tests/unit/test_graph.py b/tests/unit/test_graph.py index 5c410fa..2630988 100644 --- a/tests/unit/test_graph.py +++ b/tests/unit/test_graph.py @@ -94,7 +94,7 @@ class NoTypes: given_type=models.NestedTypeAliasType, expected_nodes=[ graph.TypeNode(type=int), - graph.TypeNode(type=list[int], var="alias"), + graph.TypeNode(type=models.ListAlias, unwrapped=list[int], var="alias"), graph.TypeNode(type=NestedTypeAliasType), ], ),