diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json new file mode 100644 index 00000000..29a749c7 --- /dev/null +++ b/.devcontainer/devcontainer.json @@ -0,0 +1,26 @@ +{ + "name": "jcrist/msgspec", + "image": "mcr.microsoft.com/devcontainers/python:3.12-bookworm", + "postCreateCommand": "scripts/install.sh", + "customizations": { + "vscode": { + "extensions": [ + "charliermarsh.ruff", + "ms-python.python", + "ms-vscode.cpptools" + ], + "settings": { + "C_Cpp.default.includePath": [ + "/usr/local/include/**" + ], + "C_Cpp.formatting": "disabled", + "python.testing.pytestArgs": [ + "-v", + "tests/" + ], + "python.testing.pytestEnabled": true, + "python.testing.unittestEnabled": false + } + } + } +} \ No newline at end of file diff --git a/docs/source/constraints.rst b/docs/source/constraints.rst index b2ec764f..c73b25c5 100644 --- a/docs/source/constraints.rst +++ b/docs/source/constraints.rst @@ -76,7 +76,7 @@ The following constraints are supported: Numeric Constraints ------------------- -These constraints are valid on `int` or `float` types: +These constraints are valid on `int`, `float`, or `decimal.Decimal` types: - ``ge``: The value must be greater than or equal to ``ge``. - ``gt``: The value must be greater than ``gt``. @@ -88,6 +88,8 @@ These constraints are valid on `int` or `float` types: >>> import msgspec + >>> from decimal import Decimal + >>> from typing import Annotated >>> msgspec.json.decode(b'-1', type=Annotated[int, msgspec.Meta(ge=0)]) @@ -95,16 +97,24 @@ These constraints are valid on `int` or `float` types: File "", line 1, in msgspec.ValidationError: Expected `int` >= 0 -.. warning:: + >>> msgspec.json.decode(b'0.3', type=Annotated[Decimal, msgspec.Meta(multiple_of=Decimal('0.1'))]) + Decimal('0.3') + +.. note:: While ``multiple_of`` works on ``float`` types, we don't recommend - specifying *non-integral* ``multiple_of`` constraints, as they may be - erroneously marked as invalid due to floating point precision issues. For - example, annotating a ``float`` type with ``multiple_of=10`` is fine, but - ``multiple_of=0.1`` may lead to issues. See `this GitHub issue + specifying *non-integral* ``multiple_of`` constraints on them, + as they may be erroneously marked as invalid due to floating point + precision issues. For example, annotating a ``float`` type with + ``multiple_of=10`` is fine, but ``multiple_of=0.1`` may lead to issues. + See `this GitHub issue `_ for more details. + To address this issue, ``msgspec`` supports specifying ``multiple_of`` + constraints with `decimal.Decimal` types, that offer arbitrary precision + arithmetic. + String Constraints ------------------ diff --git a/msgspec/__init__.pyi b/msgspec/__init__.pyi index 86a1051a..70c036cd 100644 --- a/msgspec/__init__.pyi +++ b/msgspec/__init__.pyi @@ -1,4 +1,5 @@ import enum +from decimal import Decimal from typing import ( Any, Callable, @@ -115,11 +116,11 @@ class Meta: def __init__( self, *, - gt: Union[int, float, None] = None, - ge: Union[int, float, None] = None, - lt: Union[int, float, None] = None, - le: Union[int, float, None] = None, - multiple_of: Union[int, float, None] = None, + gt: Union[int, float, Decimal, None] = None, + ge: Union[int, float, Decimal, None] = None, + lt: Union[int, float, Decimal, None] = None, + le: Union[int, float, Decimal, None] = None, + multiple_of: Union[int, float, Decimal, None] = None, pattern: Union[str, None] = None, min_length: Union[int, None] = None, max_length: Union[int, None] = None, @@ -130,11 +131,11 @@ class Meta: extra_json_schema: Union[dict, None] = None, extra: Union[dict, None] = None, ): ... - gt: Final[Union[int, float, None]] - ge: Final[Union[int, float, None]] - lt: Final[Union[int, float, None]] - le: Final[Union[int, float, None]] - multiple_of: Final[Union[int, float, None]] + gt: Final[Union[int, float, Decimal, None]] + ge: Final[Union[int, float, Decimal, None]] + lt: Final[Union[int, float, Decimal, None]] + le: Final[Union[int, float, Decimal, None]] + multiple_of: Final[Union[int, float, Decimal, None]] pattern: Final[Union[str, None]] min_length: Final[Union[int, None]] max_length: Final[Union[int, None]] diff --git a/msgspec/_core.c b/msgspec/_core.c index 0f287814..bec17393 100644 --- a/msgspec/_core.c +++ b/msgspec/_core.c @@ -1618,12 +1618,20 @@ ensure_is_finite_numeric(PyObject *val, const char *param, bool positive) { } } else { - PyErr_Format( - PyExc_TypeError, - "`%s` must be an int or float, got %.200s", - param, Py_TYPE(val)->tp_name - ); - return false; + MsgspecState *mod = msgspec_get_global_state(); + if (PyObject_IsInstance(val, mod->DecimalType)) { + PyObject *as_py_float = PyNumber_Float(val); + x = PyFloat_AS_DOUBLE(as_py_float); + Py_DECREF(as_py_float); + } + else { + PyErr_Format( + PyExc_TypeError, + "`%s` must be an int or float or Decimal, got %.200s", + param, Py_TYPE(val)->tp_name + ); + return false; + } } if (positive && x <= 0) { PyErr_Format(PyExc_ValueError, "`%s` must be > 0", param); @@ -1833,6 +1841,11 @@ Meta_new(PyTypeObject *type, PyObject *args, PyObject *kwargs) { static int Meta_traverse(Meta *self, visitproc visit, void *arg) { + Py_VISIT(self->gt); + Py_VISIT(self->ge); + Py_VISIT(self->lt); + Py_VISIT(self->le); + Py_VISIT(self->multiple_of); Py_VISIT(self->regex); Py_VISIT(self->examples); Py_VISIT(self->extra_json_schema); @@ -2682,63 +2695,68 @@ AssocList_Sort(AssocList* list) { *************************************************************************/ /* Types */ -#define MS_TYPE_ANY (1ull << 0) -#define MS_TYPE_NONE (1ull << 1) -#define MS_TYPE_BOOL (1ull << 2) -#define MS_TYPE_INT (1ull << 3) -#define MS_TYPE_FLOAT (1ull << 4) -#define MS_TYPE_STR (1ull << 5) -#define MS_TYPE_BYTES (1ull << 6) -#define MS_TYPE_BYTEARRAY (1ull << 7) -#define MS_TYPE_MEMORYVIEW (1ull << 8) -#define MS_TYPE_DATETIME (1ull << 9) -#define MS_TYPE_DATE (1ull << 10) -#define MS_TYPE_TIME (1ull << 11) -#define MS_TYPE_TIMEDELTA (1ull << 12) -#define MS_TYPE_UUID (1ull << 13) -#define MS_TYPE_DECIMAL (1ull << 14) -#define MS_TYPE_EXT (1ull << 15) -#define MS_TYPE_STRUCT (1ull << 16) -#define MS_TYPE_STRUCT_ARRAY (1ull << 17) -#define MS_TYPE_STRUCT_UNION (1ull << 18) -#define MS_TYPE_STRUCT_ARRAY_UNION (1ull << 19) -#define MS_TYPE_ENUM (1ull << 20) -#define MS_TYPE_INTENUM (1ull << 21) -#define MS_TYPE_CUSTOM (1ull << 22) -#define MS_TYPE_CUSTOM_GENERIC (1ull << 23) -#define MS_TYPE_DICT ((1ull << 24) | (1ull << 25)) -#define MS_TYPE_LIST (1ull << 26) -#define MS_TYPE_SET (1ull << 27) -#define MS_TYPE_FROZENSET (1ull << 28) -#define MS_TYPE_VARTUPLE (1ull << 29) -#define MS_TYPE_FIXTUPLE (1ull << 30) -#define MS_TYPE_INTLITERAL (1ull << 31) -#define MS_TYPE_STRLITERAL (1ull << 32) -#define MS_TYPE_TYPEDDICT (1ull << 33) -#define MS_TYPE_DATACLASS (1ull << 34) -#define MS_TYPE_NAMEDTUPLE (1ull << 35) +#define MS_TYPE_ANY (1ull << 0) +#define MS_TYPE_NONE (1ull << 1) +#define MS_TYPE_BOOL (1ull << 2) +#define MS_TYPE_INT (1ull << 3) +#define MS_TYPE_FLOAT (1ull << 4) +#define MS_TYPE_STR (1ull << 5) +#define MS_TYPE_BYTES (1ull << 6) +#define MS_TYPE_BYTEARRAY (1ull << 7) +#define MS_TYPE_MEMORYVIEW (1ull << 8) +#define MS_TYPE_DATETIME (1ull << 9) +#define MS_TYPE_DATE (1ull << 10) +#define MS_TYPE_TIME (1ull << 11) +#define MS_TYPE_TIMEDELTA (1ull << 12) +#define MS_TYPE_UUID (1ull << 13) +#define MS_TYPE_DECIMAL (1ull << 14) +#define MS_TYPE_EXT (1ull << 15) +#define MS_TYPE_STRUCT (1ull << 16) +#define MS_TYPE_STRUCT_ARRAY (1ull << 17) +#define MS_TYPE_STRUCT_UNION (1ull << 18) +#define MS_TYPE_STRUCT_ARRAY_UNION (1ull << 19) +#define MS_TYPE_ENUM (1ull << 20) +#define MS_TYPE_INTENUM (1ull << 21) +#define MS_TYPE_CUSTOM (1ull << 22) +#define MS_TYPE_CUSTOM_GENERIC (1ull << 23) +#define MS_TYPE_DICT ((1ull << 24) | (1ull << 25)) +#define MS_TYPE_LIST (1ull << 26) +#define MS_TYPE_SET (1ull << 27) +#define MS_TYPE_FROZENSET (1ull << 28) +#define MS_TYPE_VARTUPLE (1ull << 29) +#define MS_TYPE_FIXTUPLE (1ull << 30) +#define MS_TYPE_INTLITERAL (1ull << 31) +#define MS_TYPE_STRLITERAL (1ull << 32) +#define MS_TYPE_TYPEDDICT (1ull << 33) +#define MS_TYPE_DATACLASS (1ull << 34) +#define MS_TYPE_NAMEDTUPLE (1ull << 35) /* Constraints */ -#define MS_CONSTR_INT_MIN (1ull << 42) -#define MS_CONSTR_INT_MAX (1ull << 43) -#define MS_CONSTR_INT_MULTIPLE_OF (1ull << 44) -#define MS_CONSTR_FLOAT_GT (1ull << 45) -#define MS_CONSTR_FLOAT_GE (1ull << 46) -#define MS_CONSTR_FLOAT_LT (1ull << 47) -#define MS_CONSTR_FLOAT_LE (1ull << 48) -#define MS_CONSTR_FLOAT_MULTIPLE_OF (1ull << 49) -#define MS_CONSTR_STR_REGEX (1ull << 50) -#define MS_CONSTR_STR_MIN_LENGTH (1ull << 51) -#define MS_CONSTR_STR_MAX_LENGTH (1ull << 52) -#define MS_CONSTR_BYTES_MIN_LENGTH (1ull << 53) -#define MS_CONSTR_BYTES_MAX_LENGTH (1ull << 54) -#define MS_CONSTR_ARRAY_MIN_LENGTH (1ull << 55) -#define MS_CONSTR_ARRAY_MAX_LENGTH (1ull << 56) -#define MS_CONSTR_MAP_MIN_LENGTH (1ull << 57) -#define MS_CONSTR_MAP_MAX_LENGTH (1ull << 58) -#define MS_CONSTR_TZ_AWARE (1ull << 59) -#define MS_CONSTR_TZ_NAIVE (1ull << 60) +#define MS_CONSTR_INT_MIN (1ull << 37) +#define MS_CONSTR_INT_MAX (1ull << 38) +#define MS_CONSTR_INT_MULTIPLE_OF (1ull << 39) +#define MS_CONSTR_FLOAT_GT (1ull << 40) +#define MS_CONSTR_FLOAT_GE (1ull << 41) +#define MS_CONSTR_FLOAT_LT (1ull << 42) +#define MS_CONSTR_FLOAT_LE (1ull << 43) +#define MS_CONSTR_FLOAT_MULTIPLE_OF (1ull << 44) +#define MS_CONSTR_STR_REGEX (1ull << 45) +#define MS_CONSTR_STR_MIN_LENGTH (1ull << 46) +#define MS_CONSTR_STR_MAX_LENGTH (1ull << 47) +#define MS_CONSTR_BYTES_MIN_LENGTH (1ull << 48) +#define MS_CONSTR_BYTES_MAX_LENGTH (1ull << 49) +#define MS_CONSTR_ARRAY_MIN_LENGTH (1ull << 50) +#define MS_CONSTR_ARRAY_MAX_LENGTH (1ull << 51) +#define MS_CONSTR_MAP_MIN_LENGTH (1ull << 52) +#define MS_CONSTR_MAP_MAX_LENGTH (1ull << 53) +#define MS_CONSTR_TZ_AWARE (1ull << 54) +#define MS_CONSTR_TZ_NAIVE (1ull << 55) +#define MS_CONSTR_DECIMAL_GT (1ull << 56) +#define MS_CONSTR_DECIMAL_GE (1ull << 57) +#define MS_CONSTR_DECIMAL_LT (1ull << 58) +#define MS_CONSTR_DECIMAL_LE (1ull << 59) +#define MS_CONSTR_DECIMAL_MULTIPLE_OF (1ull << 60) /* Extra flag bit, used by TypedDict/dataclass implementations */ -#define MS_EXTRA_FLAG (1ull << 63) +#define MS_EXTRA_FLAG (1ull << 63) /* A TypeNode encodes information about all types at the same hierarchy in the * type tree. They can encode both single types (`int`) and unions of types @@ -2778,6 +2796,11 @@ AssocList_Sort(AssocList* list) { * S | ARRAY_MAX_LENGTH | * S | MAP_MIN_LENGTH | * S | MAP_MAX_LENGTH | + * D | DECIMAL_GT | + * D | DECIMAL_GE | + * D | DECIMAL_LT | + * D | DECIMAL_LE | + * D | DECIMAL_MULTIPLE_OF | * T | FIXTUPLE [size, types ...] | * */ @@ -2807,10 +2830,16 @@ AssocList_Sort(AssocList* list) { #define SLOT_19 MS_CONSTR_ARRAY_MAX_LENGTH #define SLOT_20 MS_CONSTR_MAP_MIN_LENGTH #define SLOT_21 MS_CONSTR_MAP_MAX_LENGTH +#define SLOT_22 MS_CONSTR_DECIMAL_GT +#define SLOT_23 MS_CONSTR_DECIMAL_GE +#define SLOT_24 MS_CONSTR_DECIMAL_LT +#define SLOT_25 MS_CONSTR_DECIMAL_LE +#define SLOT_26 MS_CONSTR_DECIMAL_MULTIPLE_OF /* Common groups */ #define MS_INT_CONSTRS (SLOT_08 | SLOT_09 | SLOT_10) #define MS_FLOAT_CONSTRS (SLOT_11 | SLOT_12 | SLOT_13) +#define MS_DECIMAL_CONSTRS (SLOT_22 | SLOT_23 | SLOT_24 | SLOT_25 | SLOT_26) #define MS_STR_CONSTRS (SLOT_05 | SLOT_14 | SLOT_15) #define MS_BYTES_CONSTRS (SLOT_16 | SLOT_17) #define MS_ARRAY_CONSTRS (SLOT_18 | SLOT_19) @@ -3177,13 +3206,76 @@ TypeNode_get_constr_map_max_length(TypeNode *type) { return type->details[i].py_ssize_t; } +static MS_INLINE PyObject * +TypeNode_get_constr_decimal_gt(TypeNode *type) { + Py_ssize_t i = ms_popcount( + type->types & ( + SLOT_00 | SLOT_01 | SLOT_02 | SLOT_03 | SLOT_04 | SLOT_05 | SLOT_06 | SLOT_07 | + SLOT_08 | SLOT_09 | SLOT_10 | SLOT_11 | SLOT_12 | SLOT_13 | SLOT_14 | SLOT_15 | + SLOT_16 | SLOT_17 | SLOT_18 | SLOT_19 | SLOT_20 | SLOT_21 + ) + ); + return type->details[i].pointer; +} + +static MS_INLINE PyObject * +TypeNode_get_constr_decimal_ge(TypeNode *type) { + Py_ssize_t i = ms_popcount( + type->types & ( + SLOT_00 | SLOT_01 | SLOT_02 | SLOT_03 | SLOT_04 | SLOT_05 | SLOT_06 | SLOT_07 | + SLOT_08 | SLOT_09 | SLOT_10 | SLOT_11 | SLOT_12 | SLOT_13 | SLOT_14 | SLOT_15 | + SLOT_16 | SLOT_17 | SLOT_18 | SLOT_19 | SLOT_20 | SLOT_21 | SLOT_22 + ) + ); + return type->details[i].pointer; +} + +static MS_INLINE PyObject * +TypeNode_get_constr_decimal_lt(TypeNode *type) { + Py_ssize_t i = ms_popcount( + type->types & ( + SLOT_00 | SLOT_01 | SLOT_02 | SLOT_03 | SLOT_04 | SLOT_05 | SLOT_06 | SLOT_07 | + SLOT_08 | SLOT_09 | SLOT_10 | SLOT_11 | SLOT_12 | SLOT_13 | SLOT_14 | SLOT_15 | + SLOT_16 | SLOT_17 | SLOT_18 | SLOT_19 | SLOT_20 | SLOT_21 | SLOT_22 | SLOT_23 + ) + ); + return type->details[i].pointer; +} + +static MS_INLINE PyObject * +TypeNode_get_constr_decimal_le(TypeNode *type) { + Py_ssize_t i = ms_popcount( + type->types & ( + SLOT_00 | SLOT_01 | SLOT_02 | SLOT_03 | SLOT_04 | SLOT_05 | SLOT_06 | SLOT_07 | + SLOT_08 | SLOT_09 | SLOT_10 | SLOT_11 | SLOT_12 | SLOT_13 | SLOT_14 | SLOT_15 | + SLOT_16 | SLOT_17 | SLOT_18 | SLOT_19 | SLOT_20 | SLOT_21 | SLOT_22 | SLOT_23 | + SLOT_24 + ) + ); + return type->details[i].pointer; +} + +static MS_INLINE PyObject * +TypeNode_get_constr_decimal_multiple_of(TypeNode *type) { + Py_ssize_t i = ms_popcount( + type->types & ( + SLOT_00 | SLOT_01 | SLOT_02 | SLOT_03 | SLOT_04 | SLOT_05 | SLOT_06 | SLOT_07 | + SLOT_08 | SLOT_09 | SLOT_10 | SLOT_11 | SLOT_12 | SLOT_13 | SLOT_14 | SLOT_15 | + SLOT_16 | SLOT_17 | SLOT_18 | SLOT_19 | SLOT_20 | SLOT_21 | SLOT_22 | SLOT_23 | + SLOT_24 | SLOT_25 + ) + ); + return type->details[i].pointer; +} + static MS_INLINE void TypeNode_get_fixtuple(TypeNode *type, Py_ssize_t *offset, Py_ssize_t *size) { Py_ssize_t i = ms_popcount( type->types & ( SLOT_00 | SLOT_01 | SLOT_02 | SLOT_03 | SLOT_04 | SLOT_05 | SLOT_06 | SLOT_07 | SLOT_08 | SLOT_09 | SLOT_10 | SLOT_11 | SLOT_12 | SLOT_13 | SLOT_14 | SLOT_15 | - SLOT_16 | SLOT_17 | SLOT_18 | SLOT_19 | SLOT_20 | SLOT_21 + SLOT_16 | SLOT_17 | SLOT_18 | SLOT_19 | SLOT_20 | SLOT_21 | SLOT_22 | SLOT_23 | + SLOT_24 | SLOT_25 | SLOT_26 ) ); *size = type->details[i].py_ssize_t; @@ -3311,7 +3403,7 @@ typenode_simple_repr(TypeNode *self) { if (!strbuilder_extend_literal(&builder, "uuid")) return NULL; } if (self->types & MS_TYPE_DECIMAL) { - if (!strbuilder_extend_literal(&builder, "decimal")) return NULL; + if (!strbuilder_extend_literal(&builder, "Decimal")) return NULL; } if (self->types & MS_TYPE_EXT) { if (!strbuilder_extend_literal(&builder, "ext")) return NULL; @@ -3381,6 +3473,11 @@ typedef struct { double c_float_min; double c_float_max; double c_float_multiple_of; + PyObject *py_decimal_gt; + PyObject *py_decimal_ge; + PyObject *py_decimal_lt; + PyObject *py_decimal_le; + PyObject *py_decimal_multiple_of; PyObject *c_str_regex; Py_ssize_t c_str_min_length; Py_ssize_t c_str_max_length; @@ -3462,12 +3559,13 @@ constraints_update(Constraints *self, Meta *meta, PyObject *type) { enum constraint_kind { CK_INT = 0, CK_FLOAT = 1, - CK_STR = 2, - CK_BYTES = 3, - CK_TIME = 4, - CK_ARRAY = 5, - CK_MAP = 6, - CK_OTHER = 7, + CK_DECIMAL = 2, + CK_STR = 3, + CK_BYTES = 4, + CK_TIME = 5, + CK_ARRAY = 6, + CK_MAP = 7, + CK_OTHER = 8, }; static int @@ -3482,8 +3580,17 @@ err_invalid_constraint(const char *name, const char *kind, PyObject *obj) { static bool _constr_as_i64(PyObject *obj, int64_t *target, int offset) { + MsgspecState *mod = msgspec_get_global_state(); + int64_t x; int overflow; - int64_t x = PyLong_AsLongLongAndOverflow(obj, &overflow); + if (PyObject_IsInstance(obj, mod->DecimalType)) { + PyObject *as_py_long = PyNumber_Long(obj); + x = PyLong_AsLongLongAndOverflow(as_py_long, &overflow); + Py_DECREF(as_py_long); + } + else { + x = PyLong_AsLongLongAndOverflow(obj, &overflow); + } if (overflow != 0) { PyErr_SetString( PyExc_ValueError, @@ -3541,6 +3648,32 @@ _constr_as_py_ssize_t(PyObject *obj, Py_ssize_t *target) { return true; } +static PyObject * +_py_float_to_decimal(PyObject *py_float, MsgspecState *mod) { + /* Render as the nearest IEEE754 double before calling Decimal */ + double val = PyFloat_AsDouble(py_float); + char buf[24]; + int size = write_f64(val, buf, false); + PyObject *str = PyUnicode_New(size, 127); + if (str == NULL) return NULL; + memcpy(ascii_get_buffer(str), buf, size); + PyObject *out = CALL_ONE_ARG(mod->DecimalType, str); + Py_DECREF(str); + return out; +} + +static bool +_constr_as_decimal(PyObject *obj, PyObject **target) { + MsgspecState *mod = msgspec_get_global_state(); + PyObject *decimal = PyFloat_Check(obj) ? + _py_float_to_decimal(obj, mod) : + CALL_ONE_ARG(mod->DecimalType, obj); + if (decimal == NULL) return false; + *target = decimal; + Py_INCREF(decimal); + return true; +} + static int typenode_collect_constraints( TypeNodeCollectState *state, @@ -3553,7 +3686,7 @@ typenode_collect_constraints( if (constraints_is_empty(constraints)) return 0; /* Check that the constraints are valid for the corresponding type */ - if (kind != CK_INT && kind != CK_FLOAT) { + if (kind != CK_INT && kind != CK_FLOAT && kind != CK_DECIMAL) { if (constraints->gt != NULL) return err_invalid_constraint("gt", "numeric", obj); if (constraints->ge != NULL) return err_invalid_constraint("ge", "numeric", obj); if (constraints->lt != NULL) return err_invalid_constraint("lt", "numeric", obj); @@ -3616,6 +3749,28 @@ typenode_collect_constraints( if (!_constr_as_f64(constraints->multiple_of, &(state->c_float_multiple_of), 0)) return -1; } } + else if (kind == CK_DECIMAL) { + if (constraints->gt != NULL) { + state->types |= MS_CONSTR_DECIMAL_GT; + if (!_constr_as_decimal(constraints->gt, &(state->py_decimal_gt))) return -1; + } + else if (constraints->ge != NULL) { + state->types |= MS_CONSTR_DECIMAL_GE; + if (!_constr_as_decimal(constraints->ge, &(state->py_decimal_ge))) return -1; + } + if (constraints->lt != NULL) { + state->types |= MS_CONSTR_DECIMAL_LT; + if (!_constr_as_decimal(constraints->lt, &(state->py_decimal_lt))) return -1; + } + else if (constraints->le != NULL) { + state->types |= MS_CONSTR_DECIMAL_LE; + if (!_constr_as_decimal(constraints->le, &(state->py_decimal_le))) return -1; + } + if (constraints->multiple_of != NULL) { + state->types |= MS_CONSTR_DECIMAL_MULTIPLE_OF; + if (!_constr_as_decimal(constraints->multiple_of, &(state->py_decimal_multiple_of))) return -1; + } + } else if (kind == CK_STR) { if (constraints->regex != NULL) { state->types |= MS_CONSTR_STR_REGEX; @@ -3706,7 +3861,12 @@ typenode_from_collect_state(TypeNodeCollectState *state) { MS_CONSTR_ARRAY_MIN_LENGTH | MS_CONSTR_ARRAY_MAX_LENGTH | MS_CONSTR_MAP_MIN_LENGTH | - MS_CONSTR_MAP_MAX_LENGTH + MS_CONSTR_MAP_MAX_LENGTH | + MS_CONSTR_DECIMAL_GT | + MS_CONSTR_DECIMAL_GE | + MS_CONSTR_DECIMAL_LT | + MS_CONSTR_DECIMAL_LE | + MS_CONSTR_DECIMAL_MULTIPLE_OF ) ); if (state->types & MS_TYPE_FIXTUPLE) { @@ -3834,6 +3994,21 @@ typenode_from_collect_state(TypeNodeCollectState *state) { Py_INCREF(state->c_str_regex); out->details[e_ind++].pointer = state->c_str_regex; } + if (state->types & MS_CONSTR_DECIMAL_GT) { + out->details[e_ind++].pointer = state->py_decimal_gt; + } + if (state->types & MS_CONSTR_DECIMAL_GE) { + out->details[e_ind++].pointer = state->py_decimal_ge; + } + if (state->types & MS_CONSTR_DECIMAL_LT) { + out->details[e_ind++].pointer = state->py_decimal_lt; + } + if (state->types & MS_CONSTR_DECIMAL_LE) { + out->details[e_ind++].pointer = state->py_decimal_le; + } + if (state->types & MS_CONSTR_DECIMAL_MULTIPLE_OF) { + out->details[e_ind++].pointer = state->py_decimal_multiple_of; + } if (state->dict_key_obj != NULL) { TypeNode *temp = TypeNode_Convert(state->dict_key_obj); if (temp == NULL) goto error; @@ -4034,7 +4209,7 @@ typenode_collect_check_invariants(TypeNodeCollectState *state) { PyExc_TypeError, "Type unions may not contain more than one str-like type (`str`, " "`Enum`, `Literal[str values]`, `datetime`, `date`, `time`, `timedelta`, " - "`uuid`, `decimal`, `bytes`, `bytearray`) - type `%R` is not supported", + "`uuid`, `Decimal`, `bytes`, `bytearray`) - type `%R` is not supported", state->context ); return -1; @@ -4582,6 +4757,11 @@ typenode_collect_clear_state(TypeNodeCollectState *state) { Py_CLEAR(state->literal_int_lookup); Py_CLEAR(state->literal_str_values); Py_CLEAR(state->literal_str_lookup); + Py_CLEAR(state->py_decimal_gt); + Py_CLEAR(state->py_decimal_ge); + Py_CLEAR(state->py_decimal_lt); + Py_CLEAR(state->py_decimal_le); + Py_CLEAR(state->py_decimal_multiple_of); Py_CLEAR(state->c_str_regex); } @@ -4803,6 +4983,10 @@ typenode_collect_type(TypeNodeCollectState *state, PyObject *obj) { state->types |= MS_TYPE_FLOAT; kind = CK_FLOAT; } + else if (t == state->mod->DecimalType) { + state->types |= MS_TYPE_DECIMAL; + kind = CK_DECIMAL; + } else if (t == (PyObject *)(&PyUnicode_Type)) { state->types |= MS_TYPE_STR; kind = CK_STR; @@ -4836,9 +5020,6 @@ typenode_collect_type(TypeNodeCollectState *state, PyObject *obj) { else if (t == state->mod->UUIDType) { state->types |= MS_TYPE_UUID; } - else if (t == state->mod->DecimalType) { - state->types |= MS_TYPE_DECIMAL; - } else if (t == (PyObject *)(&Ext_Type)) { state->types |= MS_TYPE_EXT; } @@ -9968,6 +10149,75 @@ ms_check_float_constraints(PyObject *obj, TypeNode *type, PathNode *path) { return _ms_check_float_constraints(obj, type, path); } +static MS_NOINLINE void +_err_decimal_constraint(const char *msg, PyObject *c, PathNode *path) { + ms_raise_validation_error(path, "Expected `Decimal` %s %S%U", msg, c); +} + +static bool +ms_passes_decimal_constraints( + PyObject *obj, TypeNode *type, PathNode *path +) { + MsgspecState *mod = msgspec_get_global_state(); + if (type->types & MS_CONSTR_DECIMAL_GT) { + PyObject *c = TypeNode_get_constr_decimal_gt(type); + int ok = PyObject_RichCompareBool(obj, c, Py_GT); + if (MS_UNLIKELY(ok != 1)) { + _err_decimal_constraint(">", c, path); + return false; + } + } + if (type->types & MS_CONSTR_DECIMAL_GE) { + PyObject *c = TypeNode_get_constr_decimal_ge(type); + int ok = PyObject_RichCompareBool(obj, c, Py_GE); + if (MS_UNLIKELY(ok != 1)) { + _err_decimal_constraint(">=", c, path); + return false; + } + } + if (type->types & MS_CONSTR_DECIMAL_LT) { + PyObject *c = TypeNode_get_constr_decimal_lt(type); + int ok = PyObject_RichCompareBool(obj, c, Py_LT); + if (MS_UNLIKELY(ok != 1)) { + _err_decimal_constraint("<", c, path); + return false; + } + } + if (type->types & MS_CONSTR_DECIMAL_LE) { + PyObject *c = TypeNode_get_constr_decimal_le(type); + int ok = PyObject_RichCompareBool(obj, c, Py_LE); + if (MS_UNLIKELY(ok != 1)) { + _err_decimal_constraint("<=", c, path); + return false; + } + } + if (MS_UNLIKELY(type->types & MS_CONSTR_DECIMAL_MULTIPLE_OF)) { + PyObject *c = TypeNode_get_constr_decimal_multiple_of(type); + PyObject *modulo = PyNumber_Remainder(obj, c); + PyObject *zero = PyLong_FromLong(0L); + int ok = PyObject_RichCompareBool(modulo, zero, Py_EQ); + Py_DECREF(modulo); + if (MS_UNLIKELY(ok != 1)) { + _err_decimal_constraint("that's a multiple of", c, path); + return false; + } + } + return true; +} + +static MS_NOINLINE PyObject * +_ms_check_decimal_constraints(PyObject *obj, TypeNode *type, PathNode *path) { + if (ms_passes_decimal_constraints(obj, type, path)) return obj; + Py_DECREF(obj); + return NULL; +} + +static MS_INLINE PyObject * +ms_check_decimal_constraints(PyObject *obj, TypeNode *type, PathNode *path) { + if (MS_LIKELY(!(type->types & MS_DECIMAL_CONSTRS))) return obj; + return _ms_check_decimal_constraints(obj, type, path); +} + static MS_NOINLINE bool _err_py_ssize_t_constraint(const char *msg, Py_ssize_t c, PathNode *path) { ms_raise_validation_error(path, msg, c); @@ -11360,25 +11610,54 @@ ms_decode_uuid_from_bytes(const char *buf, Py_ssize_t size, PathNode *path) { *************************************************************************/ static PyObject * -ms_decode_decimal_from_pyobj(PyObject *str, PathNode *path, MsgspecState *mod) { +_ms_decode_constr_decimal_from_pyobj( + PyObject *obj, TypeNode *type, PathNode *path, MsgspecState *mod +) { + PyObject *out = CALL_ONE_ARG(mod->DecimalType, obj); + if (out == NULL) return NULL; + if (!ms_passes_decimal_constraints(out, type, path)) { + Py_DECREF(out); + return NULL; + } + return out; +} + +static PyObject * +ms_decode_decimal_from_pyobj( + PyObject *obj, TypeNode *type, PathNode *path, MsgspecState *mod +) { if (mod == NULL) { mod = msgspec_get_global_state(); } - return CALL_ONE_ARG(mod->DecimalType, str); + if (MS_UNLIKELY(type->types & MS_DECIMAL_CONSTRS)) { + return _ms_decode_constr_decimal_from_pyobj(obj, type, path, mod); + } + return CALL_ONE_ARG(mod->DecimalType, obj); } static PyObject * -ms_decode_decimal_from_pystr(PyObject *str, PathNode *path, MsgspecState *mod) { - PyObject *out = ms_decode_decimal_from_pyobj(str, path, mod); +ms_decode_decimal_from_pystr( + PyObject *str, TypeNode *type, PathNode *path, MsgspecState *mod +) { + PyObject *out = ms_decode_decimal_from_pyobj(str, type, path, mod); if (out == NULL) { - ms_error_with_path("Invalid decimal string%U", path); + if (mod == NULL) { + mod = msgspec_get_global_state(); + } + bool validation_error = PyErr_ExceptionMatches(mod->ValidationError); + if (!validation_error) ms_error_with_path("Invalid decimal string%U", path); } return out; } static PyObject * ms_decode_decimal( - const char *view, Py_ssize_t size, bool is_ascii, PathNode *path, MsgspecState *mod + const char *view, + Py_ssize_t size, + bool is_ascii, + TypeNode *type, + PathNode *path, + MsgspecState *mod ) { PyObject *str; @@ -11391,43 +11670,45 @@ ms_decode_decimal( str = PyUnicode_DecodeUTF8(view, size, NULL); if (str == NULL) return NULL; } - PyObject *out = ms_decode_decimal_from_pystr(str, path, mod); + PyObject *out = ms_decode_decimal_from_pystr(str, type, path, mod); Py_DECREF(str); return out; } static PyObject * -ms_decode_decimal_from_int64(int64_t x, PathNode *path) { +ms_decode_decimal_from_int64(int64_t x, TypeNode *type, PathNode *path) { PyObject *temp = PyLong_FromLongLong(x); if (temp == NULL) return NULL; - PyObject *out = ms_decode_decimal_from_pyobj(temp, path, NULL); + PyObject *out = ms_decode_decimal_from_pyobj(temp, type, path, NULL); Py_DECREF(temp); return out; } static PyObject * -ms_decode_decimal_from_uint64(uint64_t x, PathNode *path) { +ms_decode_decimal_from_uint64(uint64_t x, TypeNode *type, PathNode *path) { PyObject *temp = PyLong_FromUnsignedLongLong(x); if (temp == NULL) return NULL; - PyObject *out = ms_decode_decimal_from_pyobj(temp, path, NULL); + PyObject *out = ms_decode_decimal_from_pyobj(temp, type, path, NULL); Py_DECREF(temp); return out; } static PyObject * -ms_decode_decimal_from_float(double val, PathNode *path, MsgspecState *mod) { +ms_decode_decimal_from_float( + double val, TypeNode *type, PathNode *path, MsgspecState *mod +) { if (MS_LIKELY(isfinite(val))) { /* For finite values, render as the nearest IEEE754 double in string * form, then call decimal.Decimal to parse */ char buf[24]; int n = write_f64(val, buf, false); - return ms_decode_decimal(buf, n, true, path, mod); + return ms_decode_decimal(buf, n, true, type, path, mod); } else { /* For nonfinite values, convert to float obj and go through python */ PyObject *temp = PyFloat_FromDouble(val); if (temp == NULL) return NULL; - PyObject *out = ms_decode_decimal_from_pyobj(temp, path, mod); + PyObject *out = ms_decode_decimal_from_pyobj(temp, type, path, mod); Py_DECREF(temp); return out; } @@ -11520,7 +11801,7 @@ ms_post_decode_int64( return ms_decode_float(x, type, path); } else if (type->types & MS_TYPE_DECIMAL) { - return ms_decode_decimal_from_int64(x, path); + return ms_decode_decimal_from_int64(x, type, path); } else if (!strict) { if (type->types & MS_TYPE_BOOL) { @@ -11551,7 +11832,7 @@ ms_post_decode_uint64( return ms_decode_float(x, type, path); } else if (type->types & MS_TYPE_DECIMAL) { - return ms_decode_decimal_from_uint64(x, path); + return ms_decode_decimal_from_uint64(x, type, path); } else if (!strict) { if (type->types & MS_TYPE_BOOL) { @@ -11744,7 +12025,7 @@ parse_number_nonfinite( ) ) { return ms_decode_decimal( - (char *)start, pend - start, true, path, NULL + (char *)start, pend - start, true, type, path, NULL ); } if (is_negative) { @@ -11961,7 +12242,7 @@ parse_number_inline( ) ) { return ms_decode_decimal( - (char *)start, p - start, true, path, NULL + (char *)start, p - start, true, type, path, NULL ); } else if (MS_UNLIKELY(float_hook != NULL && type->types & MS_TYPE_ANY)) { @@ -14948,7 +15229,7 @@ mpack_decode_float(DecoderState *self, double x, TypeNode *type, PathNode *path) return ms_decode_float(x, type, path); } else if (type->types & MS_TYPE_DECIMAL) { - return ms_decode_decimal_from_float(x, path, NULL); + return ms_decode_decimal_from_float(x, type, path, NULL); } else if (!self->strict) { if (type->types & MS_TYPE_INT) { @@ -15002,7 +15283,7 @@ mpack_decode_str(DecoderState *self, Py_ssize_t size, TypeNode *type, PathNode * return ms_decode_uuid_from_str(s, size, path); } else if (MS_UNLIKELY(type->types & MS_TYPE_DECIMAL)) { - return ms_decode_decimal(s, size, false, path, NULL); + return ms_decode_decimal(s, size, false, type, path, NULL); } return ms_validation_error("str", type, path); @@ -17157,7 +17438,7 @@ json_decode_string(JSONDecoderState *self, TypeNode *type, PathNode *path) { return ms_decode_uuid_from_str(view, size, path); } else if (MS_UNLIKELY(type->types & MS_TYPE_DECIMAL)) { - return ms_decode_decimal(view, size, is_ascii, path, NULL); + return ms_decode_decimal(view, size, is_ascii, type, path, NULL); } else if ( MS_UNLIKELY(type->types & @@ -20174,7 +20455,7 @@ convert_int( type->types & MS_TYPE_DECIMAL && !(self->builtin_types & MS_BUILTIN_DECIMAL) ) { - return ms_decode_decimal_from_pyobj(obj, path, self->mod); + return ms_decode_decimal_from_pyobj(obj, type, path, self->mod); } return convert_int_uncommon(self, obj, type, path); } @@ -20192,7 +20473,7 @@ convert_float( && !(self->builtin_types & MS_BUILTIN_DECIMAL) ) { return ms_decode_decimal_from_float( - PyFloat_AS_DOUBLE(obj), path, self->mod + PyFloat_AS_DOUBLE(obj), type, path, self->mod ); } else if (!self->strict) { @@ -20291,7 +20572,7 @@ convert_str_uncommon( (type->types & MS_TYPE_DECIMAL) && !(self->builtin_types & MS_BUILTIN_DECIMAL) ) { - return ms_decode_decimal_from_pystr(obj, path, self->mod); + return ms_decode_decimal_from_pystr(obj, type, path, self->mod); } else if ( (type->types & MS_TYPE_BYTES) @@ -20478,6 +20759,7 @@ convert_decimal( ConvertState *self, PyObject *obj, TypeNode *type, PathNode *path ) { if (type->types & MS_TYPE_DECIMAL) { + if (!ms_passes_decimal_constraints(obj, type, path)) return NULL; Py_INCREF(obj); return obj; } @@ -20488,7 +20770,7 @@ convert_decimal( Py_DECREF(temp); return out; } - return ms_validation_error("decimal", type, path); + return ms_validation_error("Decimal", type, path); } @@ -22196,6 +22478,7 @@ PyInit__core(void) temp_module = PyImport_ImportModule("decimal"); if (temp_module == NULL) return NULL; st->DecimalType = PyObject_GetAttrString(temp_module, "Decimal"); + Py_DECREF(temp_module); if (st->DecimalType == NULL) return NULL; /* Get the re.compile function */ diff --git a/scripts/install.sh b/scripts/install.sh new file mode 100755 index 00000000..b8092d80 --- /dev/null +++ b/scripts/install.sh @@ -0,0 +1,13 @@ +#!/bin/bash + +set -e + +VENV=.venv +if [ ! -d "$VENV" ]; then + python -m venv "$VENV" +fi +source "$VENV"/bin/activate + +pip install -e .[dev,test,doc] + +pre-commit install diff --git a/tests/test_constraints.py b/tests/test_constraints.py index 6dc86110..72715cdb 100644 --- a/tests/test_constraints.py +++ b/tests/test_constraints.py @@ -1,6 +1,8 @@ import datetime +import decimal import math import re +from decimal import Decimal from typing import Dict, List, Union import pytest @@ -45,6 +47,14 @@ def sign(x): return out +def decimal_nextafter(x, towards): + """Not 100% accurate either. Assumes the Decimal value is + in the range between -10 and 10.""" + precision = decimal.getcontext().prec + epsilon = Decimal(10) ** -(precision - 1) + return x + epsilon if towards > 0 else x - epsilon + + FIELDS = { "gt": 0, "ge": 0, @@ -179,15 +189,16 @@ def test_field_equality(self, field): def test_numeric_fields(self, field): Meta(**{field: 1}) Meta(**{field: 2.5}) + Meta(**{field: Decimal("3.3")}) with pytest.raises( - TypeError, match=f"`{field}` must be an int or float, got str" + TypeError, match=f"`{field}` must be an int or float or Decimal, got str" ): Meta(**{field: "bad"}) with pytest.raises(ValueError, match=f"`{field}` must be finite"): Meta(**{field: float("inf")}) - @pytest.mark.parametrize("val", [0, 0.0]) + @pytest.mark.parametrize("val", [0, 0.0, Decimal(0)]) def test_multiple_of_bounds(self, val): with pytest.raises(ValueError, match=r"`multiple_of` must be > 0"): Meta(multiple_of=val) @@ -329,6 +340,20 @@ class Ex(msgspec.Struct): with pytest.raises(msgspec.ValidationError, match=err_msg): dec.decode(proto.encode(Ex(x))) + def test_bound_to_decimal(self, proto): + class Ex(msgspec.Struct): + x: Annotated[int, Meta(le=Decimal(1))] + + dec = proto.Decoder(Ex) + + for x in [-2, 0, 1]: + assert dec.decode(proto.encode(Ex(x))).x == x + + err_msg = r"Expected `int` <= 1 - at `\$.x`" + for x in [2, 3]: + with pytest.raises(msgspec.ValidationError, match=err_msg): + dec.decode(proto.encode(Ex(x))) + def test_multiple_of(self, proto): good = [-(2**64), -2, 0, 2, 40, 2**63 + 2, 2**65] bad = [1, -1, 2**63 + 1, 2**65 + 1] @@ -350,6 +375,20 @@ class Ex(msgspec.Struct): with pytest.raises(msgspec.ValidationError, match=err_msg): dec.decode(proto.encode(Ex(x))) + def test_multiple_of_decimal(self, proto): + class Ex(msgspec.Struct): + x: Annotated[int, Meta(multiple_of=Decimal(2))] + + dec = proto.Decoder(Ex) + + for x in [-2, 0, 2, 4]: + assert dec.decode(proto.encode(Ex(x))).x == x + + err_msg = r"Expected `int` that's a multiple of 2 - at `\$.x`" + for x in [-1, 1, 3]: + with pytest.raises(msgspec.ValidationError, match=err_msg): + dec.decode(proto.encode(Ex(x))) + @pytest.mark.parametrize( "meta, good, bad", [ @@ -429,6 +468,20 @@ class Ex(msgspec.Struct): with pytest.raises(msgspec.ValidationError, match=err_msg): dec.decode(proto.encode(Ex(x))) + def test_bound_to_decimal(self, proto): + class Ex(msgspec.Struct): + x: Annotated[float, Meta(gt=Decimal("1.3"))] + + dec = proto.Decoder(Ex) + + for x in [1.5, 2, 3]: + assert dec.decode(proto.encode(Ex(x))).x == x + + err_msg = r"Expected `float` > 1.3 - at `\$.x`" + for x in [0, 1.2, -2, -0.5]: + with pytest.raises(msgspec.ValidationError, match=err_msg): + dec.decode(proto.encode(Ex(x))) + def test_multiple_of(self, proto): """multipleOf for floats will always have precisions issues. This check just ensures that _some_ cases work. See @@ -448,27 +501,169 @@ class Ex(msgspec.Struct): with pytest.raises(msgspec.ValidationError, match=err_msg): dec.decode(proto.encode(Ex(x))) + def test_multiple_of_decimal(self, proto): + class Ex(msgspec.Struct): + x: Annotated[float, Meta(multiple_of=Decimal("0.1"))] + + dec = proto.Decoder(Ex) + + for x in [0, 0.0, 0.1, -0.1, 0.2, -0.2]: + assert dec.decode(proto.encode(Ex(x))).x == x + + err_msg = r"Expected `float` that's a multiple of 0.1 - at `\$.x`" + for x in [0.01, -0.15]: + with pytest.raises(msgspec.ValidationError, match=err_msg): + dec.decode(proto.encode(Ex(x))) + + +class TestDecimalConstraints: + def get_bounds_cases(self, name, bound): + def ceilp1(x): + return math.ceil(x + 1) + + def floorm1(x): + return math.floor(x - 1) + + if name.startswith("g"): + good_dir = math.inf + good_round = ceilp1 + bad_round = floorm1 + else: + good_dir = -math.inf + good_round = floorm1 + bad_round = ceilp1 + + if name.endswith("e"): + good = bound + bad = decimal_nextafter(bound, -good_dir) + else: + good = decimal_nextafter(bound, good_dir) + bad = bound + good_cases = [good, good_round(good), float(good_round(good))] + bad_cases = [bad, bad_round(bad), float(bad_round(bad))] + + op = ">" if name.startswith("g") else "<" + if name.endswith("e"): + op += "=" + + return good_cases, bad_cases, op + + @pytest.mark.parametrize("name", ["ge", "gt", "le", "lt"]) + @pytest.mark.parametrize("bound", [Decimal("3.3"), Decimal("-2.2"), Decimal(4)]) + def test_bounds(self, proto, name, bound): + class Ex(msgspec.Struct): + x: Annotated[Decimal, Meta(**{name: bound})] + + dec = proto.Decoder(Ex) + + good, bad, op = self.get_bounds_cases(name, bound) + + for x in good: + assert dec.decode(proto.encode(Ex(x))).x == x + + err_msg = rf"Expected `Decimal` {op} {bound} - at `\$.x`" + for x in bad: + with pytest.raises(msgspec.ValidationError, match=err_msg): + dec.decode(proto.encode(Ex(x))) + + def test_bound_to_int(self, proto): + class Ex(msgspec.Struct): + x: Annotated[Decimal, Meta(le=2)] + + dec = proto.Decoder(Ex) + + for x in map(Decimal, [0, 1, -1.5, 2]): + assert dec.decode(proto.encode(Ex(x))).x == x + + err_msg = r"Expected `Decimal` <= 2 - at `\$.x`" + for x in map(Decimal, [2.2, 3]): + with pytest.raises(msgspec.ValidationError, match=err_msg): + dec.decode(proto.encode(Ex(x))) + + def test_bound_to_float(self, proto): + class Ex(msgspec.Struct): + x: Annotated[Decimal, Meta(ge=1.3)] + + dec = proto.Decoder(Ex) + + for x in map(Decimal, [1.3, 2, 3]): + assert dec.decode(proto.encode(Ex(x))).x == x + + err_msg = r"Expected `Decimal` >= 1.3 - at `\$.x`" + for x in map(Decimal, [0, 1.2, -2, -0.5]): + with pytest.raises(msgspec.ValidationError, match=err_msg): + dec.decode(proto.encode(Ex(x))) + + def test_multiple_of(self, proto): + class Ex(msgspec.Struct): + x: Annotated[Decimal, Meta(multiple_of=Decimal("5.3"))] + + dec = proto.Decoder(Ex) + + for x in map(Decimal, [0, "-15.9", "10.6", 106, -53]): + assert dec.decode(proto.encode(Ex(x))).x == x + + err_msg = r"Expected `Decimal` that's a multiple of 5.3 - at `\$.x`" + for x in map(Decimal, ["0.01", "-0.15", 5]): + with pytest.raises(msgspec.ValidationError, match=err_msg): + dec.decode(proto.encode(Ex(x))) + + def test_multiple_of_int(self, proto): + class Ex(msgspec.Struct): + x: Annotated[Decimal, Meta(multiple_of=2)] + + dec = proto.Decoder(Ex) + + for x in map(Decimal, [0, 2, 4, -6]): + assert dec.decode(proto.encode(Ex(x))).x == x + + err_msg = r"Expected `Decimal` that's a multiple of 2 - at `\$.x`" + for x in map(Decimal, [1, -3, "0.5", "-2.5"]): + with pytest.raises(msgspec.ValidationError, match=err_msg): + dec.decode(proto.encode(Ex(x))) + + def test_multiple_of_float(self, proto): + """Just as above, this check just ensures that _some_ cases work.""" + + class Ex(msgspec.Struct): + x: Annotated[Decimal, Meta(multiple_of=0.1)] + + dec = proto.Decoder(Ex) + + for x in map(Decimal, [0, "0.1", "-0.1", "0.2", "-0.2"]): + assert dec.decode(proto.encode(Ex(x))).x == x + + err_msg = r"Expected `Decimal` that's a multiple of 0.1 - at `\$.x`" + for x in map(Decimal, ["0.01", "-0.15"]): + with pytest.raises(msgspec.ValidationError, match=err_msg): + dec.decode(proto.encode(Ex(x))) + @pytest.mark.parametrize( "meta, good, bad", [ - (Meta(ge=0.0, le=10.0, multiple_of=2.0), [0, 2.0, 10], [-2, 11, 3]), - (Meta(ge=0.0, multiple_of=2.0), [0, 2, 10.0], [-2, 3]), - (Meta(le=10.0, multiple_of=2.0), [-2.0, 10.0], [11.0, 3.0]), - (Meta(ge=0.0, le=10.0), [0.0, 2.0, 10.0], [-1.0, 11.5, 11]), + ( + Meta(ge=Decimal(0), le=10, multiple_of=Decimal("2.5")), + [0, 5, "7.5"], + [-1, 1, 11], + ), + (Meta(ge=0, multiple_of=2), [0, 2, 2**63 + 2], [-2, 2**63 + 1]), + (Meta(le=Decimal(0), multiple_of=2), [0, -(2**63)], [-1.5, 2, 2**63]), + (Meta(ge=0, le=10), [0, 10, 0.2], [-1, 11]), + (Meta(gt=0, lt=10), [1, 2, 9.8], [-1, 0, 10]), ], ) def test_combinations(self, proto, meta, good, bad): class Ex(msgspec.Struct): - x: Annotated[float, meta] + x: Annotated[Decimal, meta] dec = proto.Decoder(Ex) - for x in good: + for x in map(Decimal, good): assert dec.decode(proto.encode(Ex(x))).x == x - for x in bad: + for x in map(Decimal, bad): with pytest.raises(msgspec.ValidationError): - assert dec.decode(proto.encode(Ex(x))) + dec.decode(proto.encode(Ex(x))) class TestStrConstraints: diff --git a/tests/test_convert.py b/tests/test_convert.py index 95d496e3..1f065d21 100644 --- a/tests/test_convert.py +++ b/tests/test_convert.py @@ -1,5 +1,4 @@ import datetime -import decimal import enum import gc import math @@ -8,6 +7,7 @@ from base64 import b64encode from collections.abc import MutableMapping from dataclasses import dataclass, field +from decimal import Decimal from typing import ( Any, Dict, @@ -23,7 +23,7 @@ ) import pytest -from utils import temp_module, max_call_depth +from utils import max_call_depth, temp_module import msgspec from msgspec import Meta, Struct, ValidationError, convert, to_builtins @@ -294,7 +294,7 @@ def test_unsupported_output_type(self): (datetime.time(12, 34), "time"), (datetime.date(2022, 1, 2), "date"), (uuid.uuid4(), "uuid"), - (decimal.Decimal("1.5"), "decimal"), + (Decimal("1.5"), "Decimal"), ([1], "array"), ((1,), "array"), ({"a": 1}, "object"), @@ -461,19 +461,19 @@ class Ex(Struct): convert({"x": x}, Ex) def test_float_from_decimal(self): - res = convert(decimal.Decimal("1.5"), float) + res = convert(Decimal("1.5"), float) assert res == 1.5 assert type(res) is float @uses_annotated def test_constr_float_from_decimal(self): typ = Annotated[float, Meta(ge=0)] - res = convert(decimal.Decimal("1.5"), typ) + res = convert(Decimal("1.5"), typ) assert res == 1.5 assert type(res) is float with pytest.raises(ValidationError, match="Expected `float` >= 0.0"): - convert(decimal.Decimal("-1.5"), typ) + convert(Decimal("-1.5"), typ) class TestStr: @@ -740,37 +740,84 @@ class UUID2(uuid.UUID): class TestDecimal: def test_decimal_wrong_type(self): - with pytest.raises(ValidationError, match="Expected `decimal`, got `array`"): - convert([], decimal.Decimal) + with pytest.raises(ValidationError, match="Expected `Decimal`, got `array`"): + convert([], Decimal) def test_decimal_builtin(self): - x = decimal.Decimal("1.5") - assert convert(x, decimal.Decimal) is x + x = Decimal("1.5") + assert convert(x, Decimal) is x def test_decimal_str(self): - sol = decimal.Decimal("1.5") - res = convert("1.5", decimal.Decimal) + sol = Decimal("1.5") + res = convert("1.5", Decimal) assert res == sol - assert type(res) is decimal.Decimal + assert type(res) is Decimal @pytest.mark.parametrize("val", [1.3, float("nan"), float("inf"), float("-inf")]) def test_decimal_float(self, val): - sol = decimal.Decimal(str(val)) - res = convert(val, decimal.Decimal) + sol = Decimal(str(val)) + res = convert(val, Decimal) assert str(res) == str(sol) # compare strs to support NaN - assert type(res) is decimal.Decimal + assert type(res) is Decimal @pytest.mark.parametrize("val", [0, 1234, -1234]) def test_decimal_int(self, val): - sol = decimal.Decimal(val) - res = convert(val, decimal.Decimal) + sol = Decimal(val) + res = convert(val, Decimal) assert res == sol - assert type(res) is decimal.Decimal + assert type(res) is Decimal @pytest.mark.parametrize("val, typ", [("1.5", "str"), (123, "int"), (1.3, "float")]) def test_decimal_conversion_disabled(self, val, typ): - with pytest.raises(ValidationError, match=f"Expected `decimal`, got `{typ}`"): - convert(val, decimal.Decimal, builtin_types=(decimal.Decimal,)) + with pytest.raises(ValidationError, match=f"Expected `Decimal`, got `{typ}`"): + convert(val, Decimal, builtin_types=(Decimal,)) + + @pytest.mark.parametrize( + "meta, good, bad", + [ + ( + Meta(ge=0, le=10, multiple_of=Decimal("2.5")), + [0, "7.5", 10], + ["-2.5", 11, 3], + ), + (Meta(ge=Decimal(0), multiple_of=2), [0, 2, 10], [-2, 3]), + (Meta(le=Decimal(10), multiple_of=2), [-2, 10], [11, 3]), + (Meta(ge=0, le=10), [0, 2, 10], [-1, "11.5", 11]), + ], + ) + @uses_annotated + def test_decimal_constrs(self, meta, good, bad): + class Ex(Struct): + x: Annotated[Decimal, meta] + + for x in map(Decimal, good): + assert convert({"x": x}, Ex).x == x + + for x in map(Decimal, bad): + with pytest.raises(ValidationError): + convert({"x": x}, Ex) + + @uses_annotated + def test_constr_decimal_from_float(self): + typ = Annotated[Decimal, Meta(ge=0)] + res = convert(1.5, typ) + assert res == Decimal(1.5) + assert type(res) is Decimal + + with pytest.raises(ValidationError, match="Expected `Decimal` >= 0"): + convert(-1, typ) + + @uses_annotated + def test_constr_decimal_from_str(self): + typ = Annotated[Decimal, Meta(multiple_of=Decimal("5.3"))] + res = convert("-15.9", typ) + assert res == Decimal("-15.9") + assert type(res) is Decimal + + with pytest.raises( + ValidationError, match="Expected `Decimal` that's a multiple of 5.3" + ): + convert("-1", typ) class TestExt: