Skip to content

Commit

Permalink
Add StringEnum and IntegerEnum
Browse files Browse the repository at this point in the history
  • Loading branch information
lafrech committed Jul 20, 2022
1 parent 5d569ad commit 82d27e7
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 41 deletions.
62 changes: 43 additions & 19 deletions src/marshmallow/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@
"IPv4Interface",
"IPv6Interface",
"Enum",
"StringEnum",
"IntegerEnum",
"Method",
"Function",
"Str",
Expand Down Expand Up @@ -1856,49 +1858,71 @@ class IPv6Interface(IPInterface):
DESERIALIZATION_CLASS = ipaddress.IPv6Interface


class Enum(Field):
class Enum(String):

default_error_messages = {
"invalid": "Not a valid string.",
"unknown": "Must be one of: {choices}",
"unknown": "Must be one of: {choices}.",
}

def __init__(
self,
enum,
by_value=False,
*args,
**kwargs,
):
self.enum = enum
self.by_value = by_value
self.choices = ", ".join([str(e.value if by_value else e.name) for e in enum])
self.choices = ", ".join(enum.__members__)
super().__init__(*args, **kwargs)

def _serialize(self, value, attr, obj, **kwargs):
if value is None:
return None
return value.value if self.by_value else value.name
return value.name

def _deserialize(self, value, attr, data, **kwargs):
if self.by_value:
return self._deserialize_by_value(value, attr, data)
else:
return self._deserialize_by_name(value, attr, data)
value = super()._deserialize(value, attr, data, **kwargs)
try:
return getattr(self.enum, value)
except AttributeError as exc:
raise self.make_error("unknown", choices=self.choices) from exc


class TypedEnum:
"""Base class for typed Enum fields"""

default_error_messages = {
"unknown": "Must be one of: {choices}.",
}

def __init__(
self,
enum,
*args,
**kwargs,
):
self.enum = enum
self.choices = ", ".join([str(m.value) for m in enum])
super().__init__(*args, **kwargs)

def _deserialize_by_value(self, value, attr, data):
def _serialize(self, value, attr, obj, **kwargs):
if value is None:
return None
return value.value

def _deserialize(self, value, attr, data, **kwargs):
value = super()._deserialize(value, attr, data, **kwargs)
try:
return self.enum(value)
except ValueError as exc:
raise self.make_error("unknown", choices=self.choices) from exc

def _deserialize_by_name(self, value, attr, data):
if not isinstance(value, (str, bytes)):
raise self.make_error("invalid")
try:
return getattr(self.enum, value)
except AttributeError as exc:
raise self.make_error("unknown", choices=self.choices) from exc

class StringEnum(TypedEnum, String):
"""String Enum"""


class IntegerEnum(TypedEnum, Integer):
"""Integer Enum"""


class Method(Field):
Expand Down
15 changes: 13 additions & 2 deletions tests/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Test utilities and fixtures."""
import datetime as dt
import uuid
from enum import IntEnum
from enum import Enum, IntEnum

import simplejson

Expand All @@ -19,6 +19,13 @@ class GenderEnum(IntEnum):
non_binary = 3


class HairColorEnum(Enum):
black = "black hair"
brown = "brown hair"
blond = "blond hair"
red = "red hair"


ALL_FIELDS = [
fields.String,
fields.Integer,
Expand All @@ -41,6 +48,8 @@ class GenderEnum(IntEnum):
fields.IPv4Interface,
fields.IPv6Interface,
lambda **x: fields.Enum(GenderEnum, **x),
lambda **x: fields.StringEnum(HairColorEnum, **x),
lambda **x: fields.IntegerEnum(GenderEnum, **x),
]


Expand Down Expand Up @@ -79,6 +88,7 @@ def __init__(
birthtime=None,
balance=100,
sex=GenderEnum.male,
hair_color=HairColorEnum.black,
employer=None,
various_data=None,
):
Expand All @@ -95,7 +105,7 @@ def __init__(
self.email = email
self.balance = balance
self.registered = registered
self.hair_colors = ["black", "brown", "blond", "redhead"]
self.hair_colors = list(HairColorEnum.__members__)
self.sex_choices = list(GenderEnum.__members__)
self.finger_count = 10
self.uid = uuid.uuid1()
Expand All @@ -104,6 +114,7 @@ def __init__(
self.birthtime = birthtime or dt.time(0, 1, 2, 3333)
self.activation_date = dt.date(2013, 12, 11)
self.sex = sex
self.hair_color = hair_color
self.employer = employer
self.relatives = []
self.various_data = various_data or {
Expand Down
56 changes: 41 additions & 15 deletions tests/test_deserialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
central,
ALL_FIELDS,
GenderEnum,
HairColorEnum,
)


Expand Down Expand Up @@ -1095,28 +1096,53 @@ def test_invalid_ipv6interface_deserialization(self, in_value):

assert excinfo.value.args[0] == "Not a valid IPv6 interface."

@pytest.mark.parametrize("by_value,value", ((True, 1), (False, "male")))
def test_enum_field_deserialization(self, by_value, value):
field = fields.Enum(GenderEnum, by_value=by_value)
assert field.deserialize(value) == GenderEnum.male
def test_enum_field_deserialization(self):
field = fields.Enum(GenderEnum)
assert field.deserialize("male") == GenderEnum.male

@pytest.mark.parametrize(
"by_value,exc_str",
(
(True, "Must be one of: 1, 2, 3"),
(False, "Must be one of: male, female, non_binary"),
),
)
def test_enum_field_invalid_value(self, by_value, exc_str):
field = fields.Enum(GenderEnum, by_value=by_value)
with pytest.raises(ValidationError, match=exc_str):
def test_enum_field_invalid_value(self):
field = fields.Enum(GenderEnum)
with pytest.raises(
ValidationError, match="Must be one of: male, female, non_binary."
):
field.deserialize("dummy")

def test_enum_field_by_name_not_string(self):
def test_enum_field_not_string(self):
field = fields.Enum(GenderEnum)
with pytest.raises(ValidationError, match="Not a valid string."):
field.deserialize(12)

def test_stringenum_field_deserialization(self):
field = fields.StringEnum(HairColorEnum)
assert field.deserialize("black hair") == HairColorEnum.black

def test_stringenum_field_invalid_value(self):
field = fields.StringEnum(HairColorEnum)
with pytest.raises(
ValidationError,
match="Must be one of: black hair, brown hair, blond hair, red hair.",
):
field.deserialize("dummy")

def test_stringenum_field_not_string(self):
field = fields.StringEnum(HairColorEnum)
with pytest.raises(ValidationError, match="Not a valid string."):
field.deserialize(12)

def test_integerenum_field_deserialization(self):
field = fields.IntegerEnum(GenderEnum)
assert field.deserialize(1) == GenderEnum.male

def test_integerenum_field_invalid_value(self):
field = fields.IntegerEnum(GenderEnum)
with pytest.raises(ValidationError, match="Must be one of: 1, 2, 3."):
field.deserialize(12)

def test_integerenum_field_not_integer(self):
field = fields.IntegerEnum(GenderEnum)
with pytest.raises(ValidationError, match="Not a valid integer."):
field.deserialize("dummy")

def test_deserialization_function_must_be_callable(self):
with pytest.raises(TypeError):
fields.Function(lambda x: None, deserialize="notvalid")
Expand Down
19 changes: 14 additions & 5 deletions tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from marshmallow import Schema, fields, missing as missing_

from tests.base import User, ALL_FIELDS, central, GenderEnum
from tests.base import User, ALL_FIELDS, central, GenderEnum, HairColorEnum


class DateTimeList:
Expand Down Expand Up @@ -255,11 +255,20 @@ def test_ipv6_interface_field(self, user):
== ipv6interface_exploded_string
)

@pytest.mark.parametrize("by_value,value", ((True, 1), (False, "male")))
def test_enum_field_serialization(self, user, by_value, value):
def test_enum_field_serialization(self, user):
user.sex = GenderEnum.male
field = fields.Enum(GenderEnum, by_value=by_value)
assert field.serialize("sex", user) == value
field = fields.Enum(GenderEnum)
assert field.serialize("sex", user) == "male"

def test_stringenum_field_serialization(self, user):
user.hair_color = HairColorEnum.black
field = fields.StringEnum(HairColorEnum)
assert field.serialize("hair_color", user) == "black hair"

def test_integerenum_field_serialization(self, user):
user.sex = GenderEnum.male
field = fields.IntegerEnum(GenderEnum)
assert field.serialize("sex", user) == 1

def test_decimal_field(self, user):
user.m1 = 12
Expand Down

0 comments on commit 82d27e7

Please sign in to comment.