Skip to content

Commit

Permalink
Merge pull request #1 from seandstewart/seandstewart/enums-and-uniont…
Browse files Browse the repository at this point in the history
…ypes

seandstewart/enums-and-uniontypes
  • Loading branch information
seandstewart authored Oct 16, 2024
2 parents 34d521a + bf4ad13 commit b8fa248
Show file tree
Hide file tree
Showing 12 changed files with 87 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/typelib/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def get_type_graph(t: type) -> graphlib.TopologicalSorter[TypeNode]:
qualname = inspection.qualname(child)
*rest, refname = qualname.split(".", maxsplit=1)
is_argument = var is not None
module = getattr(child, "__module__", None)
module = ".".join(rest) or getattr(child, "__module__", None)
if module in (None, "__main__") and rest:
module = rest[0]
is_class = inspect.isclass(child)
Expand Down
2 changes: 2 additions & 0 deletions src/typelib/marshals/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ def __call__(self, val: T) -> serdes.MarshalledValueT:
inspection.isliteral: routines.LiteralMarshaller,
# Special handler for Unions...
inspection.isuniontype: routines.UnionMarshaller,
# Special handling for Enums
inspection.isenumtype: routines.EnumMarshaller,
# Non-intersecting types (order doesn't matter here.
inspection.isdatetimetype: routines.DateTimeMarshaller,
inspection.isdatetype: routines.DateMarshaller,
Expand Down
16 changes: 16 additions & 0 deletions src/typelib/marshals/routines.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import contextlib
import datetime
import decimal
import enum
import fractions
import pathlib
import re
Expand Down Expand Up @@ -41,6 +42,7 @@
"SubscriptedMappingMarshaller",
"FixedTupleMarshaller",
"StructuredTypeMarshaller",
"EnumMarshaller",
)


Expand Down Expand Up @@ -151,6 +153,20 @@ def __call__(self, val: T) -> str:
PathT = tp.TypeVar("PathT", bound=pathlib.Path)
PathMarshaller = ToStringMarshaller[PathT]

EnumT = tp.TypeVar("EnumT", bound=enum.Enum)


class EnumMarshaller(AbstractMarshaller[EnumT], tp.Generic[EnumT]):
"""A marshaller that converts an [`enum.Enum`][] instance to its assigned value."""

def __call__(self, val: EnumT) -> serdes.MarshalledValueT:
"""Marshal an [`enum.Enum`][] instance into a [`serdes.MarshalledValueT`][].
Args:
val: The enum instance to marshal.
"""
return val.value


PatternT = tp.TypeVar("PatternT", bound=re.Pattern)

Expand Down
9 changes: 8 additions & 1 deletion src/typelib/py/inspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,13 @@ def isbuiltintype(

@compat.cache
def isstdlibtype(obj: type) -> compat.TypeIs[type[STDLibtypeT]]:
if isoptionaltype(obj):
nargs = tp.get_args(obj)[:-1]
return all(isstdlibtype(a) for a in nargs)
if isuniontype(obj):
args = tp.get_args(obj)
return all(isstdlibtype(a) for a in args)

return (
resolve_supertype(obj) in STDLIB_TYPES
or resolve_supertype(type(obj)) in STDLIB_TYPES
Expand Down Expand Up @@ -903,7 +910,7 @@ def isenumtype(obj: type) -> compat.TypeIs[type[enum.Enum]]:
>>> isenumtype(FooNum)
True
"""
return issubclass(obj, enum.Enum)
return _safe_issubclass(obj, enum.Enum)


@compat.cache
Expand Down
2 changes: 2 additions & 0 deletions src/typelib/unmarshals/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ def __call__(self, val: tp.Any) -> T:
inspection.isliteral: routines.LiteralUnmarshaller,
# Special handler for Unions...
inspection.isuniontype: routines.UnionUnmarshaller,
# Special handling for Enums
inspection.isenumtype: routines.EnumUnmarshaller,
# Non-intersecting types (order doesn't matter here.
inspection.isdatetimetype: routines.DateTimeUnmarshaller,
inspection.isdatetype: routines.DateUnmarshaller,
Expand Down
7 changes: 6 additions & 1 deletion src/typelib/unmarshals/routines.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import contextlib
import datetime
import decimal
import enum
import fractions
import numbers
import pathlib
Expand Down Expand Up @@ -46,6 +47,7 @@
"SubscriptedMappingUnmarshaller",
"FixedTupleUnmarshaller",
"StructuredTypeUnmarshaller",
"EnumUnmarshaller",
)


Expand Down Expand Up @@ -537,7 +539,7 @@ def __call__(self, val: tp.Any) -> UUIDT:


class PatternUnmarshaller(AbstractUnmarshaller[PatternT], tp.Generic[PatternT]):
"""Unmarshaller that converts an input to a[`re.Pattern`][].
"""Unmarshaller that converts an input to a [`re.Pattern`][].
Note:
You can't instantiate a [`re.Pattern`][] directly, so we don't have a good
Expand Down Expand Up @@ -596,6 +598,9 @@ def __call__(self, val: tp.Any) -> T:
MappingUnmarshaller = CastUnmarshaller[tp.Mapping]
IterableUnmarshaller = CastUnmarshaller[tp.Iterable]

EnumT = tp.TypeVar("EnumT", bound=enum.Enum)
EnumUnmarshaller = CastUnmarshaller[EnumT]


LiteralT = tp.TypeVar("LiteralT")

Expand Down
13 changes: 13 additions & 0 deletions tests/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

import dataclasses
import datetime
import enum
import typing


Expand Down Expand Up @@ -49,3 +51,14 @@ class NTuple(typing.NamedTuple):
class TDict(typing.TypedDict):
field: str
value: int


class GivenEnum(enum.Enum):
one = "one"


@dataclasses.dataclass
class UnionSTDLib:
timestamp: datetime.datetime | None = None
date_time: datetime.datetime | None = None
intstr: int | str = 0
5 changes: 5 additions & 0 deletions tests/unit/marshals/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,11 @@
),
expected_output={"indirect": {"cycle": {"indirect": {"cycle": None}}}},
),
enum_type=dict(
given_type=models.GivenEnum,
given_input=models.GivenEnum.one,
expected_output=models.GivenEnum.one.value,
),
)
def test_marshal(given_type, given_input, expected_output):
# When
Expand Down
11 changes: 11 additions & 0 deletions tests/unit/marshals/test_routines.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,3 +449,14 @@ def test_invalid_union():
# When/Then
with pytest.raises(expected_exception):
given_marshaller(given_value)


def test_enum_unmarshaller():
# Given
given_unmarshaller = routines.EnumMarshaller(models.GivenEnum, {})
given_value = models.GivenEnum.one
expected_value = models.GivenEnum.one.value
# When
unmarshalled = given_unmarshaller(given_value)
# Then
assert unmarshalled == expected_value
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,18 @@
given_input='["1", "2"]',
expected_output=[1, 2],
),
enum_type=dict(
given_type=models.GivenEnum,
given_input="one",
expected_output=models.GivenEnum.one,
),
union_std_lib=dict(
given_type=models.UnionSTDLib,
given_input={"timestamp": 0},
expected_output=models.UnionSTDLib(
timestamp=datetime.datetime.fromtimestamp(0, datetime.timezone.utc)
),
),
)
def test_unmarshal(given_type, given_input, expected_output):
# When
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -820,3 +820,14 @@ def test_invalid_union():
# When/Then
with pytest.raises(expected_exception):
given_unmarshaller(given_value)


def test_enum_unmarshaller():
# Given
given_unmarshaller = routines.EnumUnmarshaller(models.GivenEnum, {})
given_value = models.GivenEnum.one.value
expected_value = models.GivenEnum.one
# When
unmarshalled = given_unmarshaller(given_value)
# Then
assert unmarshalled == expected_value

0 comments on commit b8fa248

Please sign in to comment.