From 6bb1937f936c2a7be3ba0e1fb85179ffc85bf828 Mon Sep 17 00:00:00 2001 From: Jim Crist-Harif Date: Thu, 27 Jul 2023 23:18:06 -0500 Subject: [PATCH] Support decoding uuids from bytes even in strict mode This seems safe to do since encoding UUIDs as 16-byte values is a well-defined and common-enough format. --- docs/source/supported-types.rst | 30 +++++++----------------------- msgspec/_core.c | 12 +++++++++--- tests/test_common.py | 11 +++-------- tests/test_convert.py | 36 +++++++++++++++++++----------------- 4 files changed, 38 insertions(+), 51 deletions(-) diff --git a/docs/source/supported-types.rst b/docs/source/supported-types.rst index 1eb6b93f..1903a540 100644 --- a/docs/source/supported-types.rst +++ b/docs/source/supported-types.rst @@ -519,41 +519,25 @@ The format may be selected by passing it to ``uuid_format`` when creating an 128-bit integer representation (same as ``uuid.bytes``). This is only supported by the MessagePack encoder. -When decoding, both ``canonical`` and ``hex`` formats are supported by all -protocols. +When decoding, any of the above formats are accepted. .. code-block:: python >>> enc = msgspec.json.Encoder(uuid_format="hex") - >>> enc.encode(u) - b'"c4524ac0e81e4aa8a5950aec605a659a"' + >>> uuid_hex = enc.encode(u) - >>> u.hex - "c4524ac0e81e4aa8a5950aec605a659a" + >>> uuid_hex + b'"c4524ac0e81e4aa8a5950aec605a659a"' - >>> msgspec.json.decode(b'"c4524ac0e81e4aa8a5950aec605a659a"', type=uuid.UUID) + >>> msgspec.json.decode(uuid_hex, type=uuid.UUID) UUID('c4524ac0-e81e-4aa8-a595-0aec605a659a') -Additionally, if ``strict=False`` is specified, the ``bytes`` format may be -decoded by the MessagePack decoder. See :ref:`strict-vs-lax` for more -information. - -.. 
code-block:: python - >>> enc = msgspec.msgpack.Encoder(uuid_format="bytes") - >>> msg = enc.encode(u) - - >>> msg - b'\xc4\x10\xc4RJ\xc0\xe8\x1eJ\xa8\xa5\x95\n\xec`Ze\x9a' - - >>> msgspec.msgpack.decode(msg, type=uuid.UUID) - Traceback (most recent call last): - File "", line 1, in - msgspec.ValidationError: Expected `uuid`, got `bytes` + >>> uuid_bytes = enc.encode(u) - >>> msgspec.msgpack.decode(msg, type=uuid.UUID, strict=False) + >>> msgspec.msgpack.decode(uuid_bytes, type=uuid.UUID) UUID('c4524ac0-e81e-4aa8-a595-0aec605a659a') diff --git a/msgspec/_core.c b/msgspec/_core.c index 2d46c836..e6223767 100644 --- a/msgspec/_core.c +++ b/msgspec/_core.c @@ -13585,7 +13585,7 @@ mpack_decode_bin( else if (type->types & MS_TYPE_BYTEARRAY) { return PyByteArray_FromStringAndSize(s, size); } - else if (!self->strict && (type->types & MS_TYPE_UUID)) { + else if (type->types & MS_TYPE_UUID) { return ms_decode_uuid_from_bytes(s, size, path); } @@ -18992,7 +18992,10 @@ convert_bytes( } return PyByteArray_FromObject(obj); } - if (!self->strict && (type->types & MS_TYPE_UUID)) { + if ( + (type->types & MS_TYPE_UUID) && + !(self->builtin_types & MS_BUILTIN_UUID) + ) { return ms_decode_uuid_from_bytes( PyBytes_AS_STRING(obj), PyBytes_GET_SIZE(obj), path ); @@ -19014,7 +19017,10 @@ convert_bytearray( } return PyBytes_FromObject(obj); } - if (!self->strict && (type->types & MS_TYPE_UUID)) { + if ( + (type->types & MS_TYPE_UUID) && + !(self->builtin_types & MS_BUILTIN_UUID) + ) { return ms_decode_uuid_from_bytes( PyByteArray_AS_STRING(obj), PyByteArray_GET_SIZE(obj), path ); diff --git a/tests/test_common.py b/tests/test_common.py index 4dfec13e..e20e8a6d 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -3402,20 +3402,15 @@ def test_decode_uuid(self, proto, upper, hyphens): assert res == u assert res.is_safe == u.is_safe - def test_decode_uuid_from_bytes_lax(self): + def test_decode_uuid_from_bytes(self): sol = uuid.uuid4() msg = msgspec.msgpack.encode(sol.bytes) - 
res = msgspec.msgpack.decode(msg, type=uuid.UUID, strict=False) + res = msgspec.msgpack.decode(msg, type=uuid.UUID) assert res == sol bad_msg = msgspec.msgpack.encode(b"x" * 8) with pytest.raises(msgspec.ValidationError, match="Invalid UUID bytes"): - msgspec.msgpack.decode(bad_msg, type=uuid.UUID, strict=False) - - with pytest.raises( - msgspec.ValidationError, match="Expected `uuid`, got `bytes`" - ): - msgspec.msgpack.decode(msg, type=uuid.UUID) + msgspec.msgpack.decode(bad_msg, type=uuid.UUID) @pytest.mark.parametrize( "uuid_str", diff --git a/tests/test_convert.py b/tests/test_convert.py index edc51ded..bf50422f 100644 --- a/tests/test_convert.py +++ b/tests/test_convert.py @@ -704,10 +704,26 @@ def test_uuid_str(self): res = convert(str(sol), uuid.UUID) assert res == sol - def test_uuid_str_disabled(self): - msg = str(uuid.uuid4()) + @pytest.mark.parametrize("input_type", [bytes, bytearray]) + def test_uuid_bytes(self, input_type): + sol = uuid.uuid4() + msg = input_type(sol.bytes) + res = convert(msg, uuid.UUID) + assert res == sol + + bad_msg = input_type(b"x" * 8) + with pytest.raises(msgspec.ValidationError, match="Invalid UUID bytes"): + convert(bad_msg, type=uuid.UUID) + + def test_uuid_disabled(self): + u = uuid.uuid4() + with pytest.raises(ValidationError, match="Expected `uuid`, got `str`"): - convert(msg, uuid.UUID, builtin_types=(uuid.UUID,)) + convert(str(u), uuid.UUID, builtin_types=(uuid.UUID,)) + + for typ in [bytes, bytearray]: + with pytest.raises(ValidationError, match="Expected `uuid`, got `bytes`"): + convert(typ(u.bytes), uuid.UUID, builtin_types=(uuid.UUID,)) class TestDecimal: @@ -2348,20 +2364,6 @@ def test_lax_timedelta_invalid_numeric_str(self): with pytest.raises(ValidationError, match="Invalid"): convert(msg, type=datetime.timedelta, strict=False) - @pytest.mark.parametrize("input_type", [bytes, bytearray]) - def test_lax_uuid(self, input_type): - sol = uuid.uuid4() - msg = input_type(sol.bytes) - res = convert(msg, uuid.UUID, 
strict=False) - assert res == sol - - bad_msg = input_type(b"x" * 8) - with pytest.raises(msgspec.ValidationError, match="Invalid UUID bytes"): - convert(bad_msg, type=uuid.UUID, strict=False) - - with pytest.raises(ValidationError, match="Expected `uuid`, got `bytes`"): - convert(msg, uuid.UUID) - @pytest.mark.parametrize( "msg, sol", [