diff --git a/src/marshmallow/fields.py b/src/marshmallow/fields.py index 383b34ef6..45a5572e3 100644 --- a/src/marshmallow/fields.py +++ b/src/marshmallow/fields.py @@ -11,6 +11,7 @@ import math import typing import warnings +from enum import Enum from collections.abc import Mapping as _Mapping from marshmallow import validate, utils, class_registry, types @@ -59,6 +60,8 @@ "IPInterface", "IPv4Interface", "IPv6Interface", + "EnumSymbol", + "EnumValue", "Method", "Function", "Str", @@ -1855,6 +1858,78 @@ class IPv6Interface(IPInterface): DESERIALIZATION_CLASS = ipaddress.IPv6Interface +class EnumSymbol(String): + """An Enum field (de)serializing enum members by symbol (name) as string. + + :param enum Enum: Enum class + + .. versionadded:: 3.18.0 + """ + + default_error_messages = { + "unknown": "Must be one of: {choices}.", + } + + def __init__(self, enum: type[Enum], **kwargs): + self.enum = enum + self.choices = ", ".join(enum.__members__) + super().__init__(**kwargs) + + def _serialize(self, value, attr, obj, **kwargs): + if value is None: + return None + return value.name + + def _deserialize(self, value, attr, data, **kwargs): + value = super()._deserialize(value, attr, data, **kwargs) + try: + return getattr(self.enum, value) + except AttributeError as exc: + raise self.make_error("unknown", choices=self.choices) from exc + + +class EnumValue(Field): + """An Enum field (de)serializing enum members by value. + + A Field must be provided to (de)serialize the value. + + :param cls_or_instance: Field class or instance. + :param enum Enum: Enum class + + .. versionadded:: 3.18.0 + """ + + default_error_messages = { + "unknown": "Must be one of: {choices}.", + } + + def __init__(self, cls_or_instance: Field | type, enum: type[Enum], **kwargs): + super().__init__(**kwargs) + try: + self.field = resolve_field_instance(cls_or_instance) + except FieldInstanceResolutionError as error: + raise ValueError( + "The enum field must be a subclass or instance of " + "marshmallow.base.FieldABC." + ) from error + self.enum = enum + self.choices = ", ".join( + [str(self.field._serialize(m.value, None, None)) for m in enum] + ) + + def _serialize(self, value, attr, obj, **kwargs): + if value is None: + return None + return self.field._serialize(value.value, attr, obj, **kwargs) + + def _deserialize(self, value, attr, data, **kwargs): + value = self.field._deserialize(value, attr, data, **kwargs) + try: + return self.enum(value) + except ValueError as exc: + raise self.make_error("unknown", choices=self.choices) from exc + + class Method(Field): """A field that takes the value returned by a `Schema` method. diff --git a/tests/base.py b/tests/base.py index 597d3f944..4dbbcd5fc 100644 --- a/tests/base.py +++ b/tests/base.py @@ -1,6 +1,8 @@ """Test utilities and fixtures.""" +import functools import datetime as dt import uuid +from enum import Enum, IntEnum import simplejson @@ -12,6 +14,25 @@ central = pytz.timezone("US/Central") +class GenderEnum(IntEnum): + male = 1 + female = 2 + non_binary = 3 + + +class HairColorEnum(Enum): + black = "black hair" + brown = "brown hair" + blond = "blond hair" + red = "red hair" + + +class DateEnum(Enum): + date_1 = dt.date(2004, 2, 29) + date_2 = dt.date(2008, 2, 29) + date_3 = dt.date(2012, 2, 29) + + ALL_FIELDS = [ fields.String, fields.Integer, @@ -33,8 +54,12 @@ fields.IPInterface, fields.IPv4Interface, fields.IPv6Interface, + functools.partial(fields.EnumSymbol, GenderEnum), + functools.partial(fields.EnumValue, fields.String, HairColorEnum), + functools.partial(fields.EnumValue, fields.Integer, GenderEnum), ] + ##### Custom asserts ##### @@ -69,7 +94,8 @@ def __init__( birthdate=None, birthtime=None, balance=100, - sex="male", + sex=GenderEnum.male, + hair_color=HairColorEnum.black, employer=None, various_data=None, ): @@ -86,8 +112,8 @@ def __init__( self.email = email self.balance = balance self.registered = registered - self.hair_colors = ["black", "brown", "blond", "redhead"] - self.sex_choices = ("male", "female") + self.hair_colors = list(HairColorEnum.__members__) + self.sex_choices = list(GenderEnum.__members__) self.finger_count = 10 self.uid = uuid.uuid1() self.time_registered = time_registered or dt.time(1, 23, 45, 6789) @@ -95,6 +121,7 @@ def __init__( self.birthtime = birthtime or dt.time(0, 1, 2, 3333) self.activation_date = dt.date(2013, 12, 11) self.sex = sex + self.hair_color = hair_color self.employer = employer self.relatives = [] self.various_data = various_data or { @@ -180,7 +207,7 @@ class UserSchema(Schema): birthtime = fields.Time() activation_date = fields.Date() since_created = fields.TimeDelta() - sex = fields.Str(validate=validate.OneOf(["male", "female"])) + sex = fields.Str(validate=validate.OneOf(list(GenderEnum.__members__))) various_data = fields.Dict() class Meta: diff --git a/tests/test_deserialization.py b/tests/test_deserialization.py index 4ee5ba5b1..b922d6a56 100644 --- a/tests/test_deserialization.py +++ b/tests/test_deserialization.py @@ -10,7 +10,15 @@ from marshmallow.exceptions import ValidationError from marshmallow.validate import Equal -from tests.base import assert_date_equal, assert_time_equal, central, ALL_FIELDS +from tests.base import ( + assert_date_equal, + assert_time_equal, + central, + ALL_FIELDS, + GenderEnum, + HairColorEnum, + DateEnum, +) class TestDeserializingNone: @@ -1089,6 +1097,57 @@ def test_invalid_ipv6interface_deserialization(self, in_value): assert excinfo.value.args[0] == "Not a valid IPv6 interface." + def test_enumsymbol_field_deserialization(self): + field = fields.EnumSymbol(GenderEnum) + assert field.deserialize("male") == GenderEnum.male + + def test_enumsymbol_field_invalid_value(self): + field = fields.EnumSymbol(GenderEnum) + with pytest.raises( + ValidationError, match="Must be one of: male, female, non_binary." + ): + field.deserialize("dummy") + + def test_enumsymbol_field_not_string(self): + field = fields.EnumSymbol(GenderEnum) + with pytest.raises(ValidationError, match="Not a valid string."): + field.deserialize(12) + + def test_enumvalue_field_deserialization(self): + field = fields.EnumValue(fields.String, HairColorEnum) + assert field.deserialize("black hair") == HairColorEnum.black + field = fields.EnumValue(fields.Integer, GenderEnum) + assert field.deserialize(1) == GenderEnum.male + field = fields.EnumValue(fields.Date(format="%d/%m/%Y"), DateEnum) + assert field.deserialize("29/02/2004") == DateEnum.date_1 + + def test_enumvalue_field_invalid_value(self): + field = fields.EnumValue(fields.String, HairColorEnum) + with pytest.raises( + ValidationError, + match="Must be one of: black hair, brown hair, blond hair, red hair.", + ): + field.deserialize("dummy") + field = fields.EnumValue(fields.Integer, GenderEnum) + with pytest.raises(ValidationError, match="Must be one of: 1, 2, 3."): + field.deserialize(12) + field = fields.EnumValue(fields.Date(format="%d/%m/%Y"), DateEnum) + with pytest.raises( + ValidationError, match="Must be one of: 29/02/2004, 29/02/2008, 29/02/2012." + ): + field.deserialize("28/02/2004") + + def test_enumvalue_field_wrong_type(self): + field = fields.EnumValue(fields.String, HairColorEnum) + with pytest.raises(ValidationError, match="Not a valid string."): + field.deserialize(12) + field = fields.EnumValue(fields.Integer, GenderEnum) + with pytest.raises(ValidationError, match="Not a valid integer."): + field.deserialize("dummy") + field = fields.EnumValue(fields.Date(format="%d/%m/%Y"), DateEnum) + with pytest.raises(ValidationError, match="Not a valid date."): + field.deserialize("30/02/2004") + def test_deserialization_function_must_be_callable(self): with pytest.raises(TypeError): fields.Function(lambda x: None, deserialize="notvalid") diff --git a/tests/test_serialization.py b/tests/test_serialization.py index 41eed7098..8e751e5ce 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -11,7 +11,7 @@ from marshmallow import Schema, fields, missing as missing_ -from tests.base import User, ALL_FIELDS, central +from tests.base import User, ALL_FIELDS, central, GenderEnum, HairColorEnum, DateEnum class DateTimeList: @@ -255,6 +255,22 @@ def test_ipv6_interface_field(self, user): == ipv6interface_exploded_string ) + def test_enumsymbol_field_serialization(self, user): + user.sex = GenderEnum.male + field = fields.EnumSymbol(GenderEnum) + assert field.serialize("sex", user) == "male" + + def test_enumvalue_field_serialization(self, user): + user.hair_color = HairColorEnum.black + field = fields.EnumValue(fields.String, HairColorEnum) + assert field.serialize("hair_color", user) == "black hair" + user.sex = GenderEnum.male + field = fields.EnumValue(fields.Integer, GenderEnum) + assert field.serialize("sex", user) == 1 + user.some_date = DateEnum.date_1 + field = fields.EnumValue(fields.Date(format="%d/%m/%Y"), DateEnum) + assert field.serialize("some_date", user) == "29/02/2004" + def test_decimal_field(self, user): user.m1 = 12 user.m2 = "12.355"