Skip to content

Commit

Permalink
Support decoding uuids from bytes even in strict mode
Browse files Browse the repository at this point in the history
This seems safe to do since encoding UUIDs as 16-byte values is a
well-defined and common-enough format.
  • Loading branch information
jcrist committed Jul 28, 2023
1 parent 96359e4 commit 6bb1937
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 51 deletions.
30 changes: 7 additions & 23 deletions docs/source/supported-types.rst
Original file line number Diff line number Diff line change
Expand Up @@ -519,41 +519,25 @@ The format may be selected by passing it to ``uuid_format`` when creating an
128-bit integer representation (same as ``uuid.bytes``). This is only supported
by the MessagePack encoder.

When decoding, both ``canonical`` and ``hex`` formats are supported by all
protocols.
When decoding, any of the above formats are accepted.

.. code-block:: python
>>> enc = msgspec.json.Encoder(uuid_format="hex")
>>> enc.encode(u)
b'"c4524ac0e81e4aa8a5950aec605a659a"'
>>> uuid_hex = enc.encode(u)
>>> u.hex
"c4524ac0e81e4aa8a5950aec605a659a"
>>> uuid_hex
b'"c4524ac0e81e4aa8a5950aec605a659a"'
>>> msgspec.json.decode(b'"c4524ac0e81e4aa8a5950aec605a659a"', type=uuid.UUID)
>>> msgspec.json.decode(uuid_hex, type=uuid.UUID)
UUID('c4524ac0-e81e-4aa8-a595-0aec605a659a')
Additionally, if ``strict=False`` is specified, the ``bytes`` format may be
decoded by the MessagePack decoder. See :ref:`strict-vs-lax` for more
information.

.. code-block:: python
>>> enc = msgspec.msgpack.Encoder(uuid_format="bytes")
>>> msg = enc.encode(u)
>>> msg
b'\xc4\x10\xc4RJ\xc0\xe8\x1eJ\xa8\xa5\x95\n\xec`Ze\x9a'
>>> msgspec.msgpack.decode(msg, type=uuid.UUID)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
msgspec.ValidationError: Expected `uuid`, got `bytes`
>>> uuid_bytes = enc.encode(u)
>>> msgspec.msgpack.decode(msg, type=uuid.UUID, strict=False)
>>> msgspec.msgpack.decode(uuid_bytes, type=uuid.UUID)
UUID('c4524ac0-e81e-4aa8-a595-0aec605a659a')
Expand Down
12 changes: 9 additions & 3 deletions msgspec/_core.c
Original file line number Diff line number Diff line change
Expand Up @@ -13585,7 +13585,7 @@ mpack_decode_bin(
else if (type->types & MS_TYPE_BYTEARRAY) {
return PyByteArray_FromStringAndSize(s, size);
}
else if (!self->strict && (type->types & MS_TYPE_UUID)) {
else if (type->types & MS_TYPE_UUID) {
return ms_decode_uuid_from_bytes(s, size, path);
}

Expand Down Expand Up @@ -18992,7 +18992,10 @@ convert_bytes(
}
return PyByteArray_FromObject(obj);
}
if (!self->strict && (type->types & MS_TYPE_UUID)) {
if (
(type->types & MS_TYPE_UUID) &&
!(self->builtin_types & MS_BUILTIN_UUID)
) {
return ms_decode_uuid_from_bytes(
PyBytes_AS_STRING(obj), PyBytes_GET_SIZE(obj), path
);
Expand All @@ -19014,7 +19017,10 @@ convert_bytearray(
}
return PyBytes_FromObject(obj);
}
if (!self->strict && (type->types & MS_TYPE_UUID)) {
if (
(type->types & MS_TYPE_UUID) &&
!(self->builtin_types & MS_BUILTIN_UUID)
) {
return ms_decode_uuid_from_bytes(
PyByteArray_AS_STRING(obj), PyByteArray_GET_SIZE(obj), path
);
Expand Down
11 changes: 3 additions & 8 deletions tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3402,20 +3402,15 @@ def test_decode_uuid(self, proto, upper, hyphens):
assert res == u
assert res.is_safe == u.is_safe

def test_decode_uuid_from_bytes_lax(self):
def test_decode_uuid_from_bytes(self):
sol = uuid.uuid4()
msg = msgspec.msgpack.encode(sol.bytes)
res = msgspec.msgpack.decode(msg, type=uuid.UUID, strict=False)
res = msgspec.msgpack.decode(msg, type=uuid.UUID)
assert res == sol

bad_msg = msgspec.msgpack.encode(b"x" * 8)
with pytest.raises(msgspec.ValidationError, match="Invalid UUID bytes"):
msgspec.msgpack.decode(bad_msg, type=uuid.UUID, strict=False)

with pytest.raises(
msgspec.ValidationError, match="Expected `uuid`, got `bytes`"
):
msgspec.msgpack.decode(msg, type=uuid.UUID)
msgspec.msgpack.decode(bad_msg, type=uuid.UUID)

@pytest.mark.parametrize(
"uuid_str",
Expand Down
36 changes: 19 additions & 17 deletions tests/test_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -704,10 +704,26 @@ def test_uuid_str(self):
res = convert(str(sol), uuid.UUID)
assert res == sol

def test_uuid_str_disabled(self):
msg = str(uuid.uuid4())
@pytest.mark.parametrize("input_type", [bytes, bytearray])
def test_uuid_bytes(self, input_type):
sol = uuid.uuid4()
msg = input_type(sol.bytes)
res = convert(msg, uuid.UUID)
assert res == sol

bad_msg = input_type(b"x" * 8)
with pytest.raises(msgspec.ValidationError, match="Invalid UUID bytes"):
convert(bad_msg, type=uuid.UUID)

def test_uuid_disabled(self):
u = uuid.uuid4()

with pytest.raises(ValidationError, match="Expected `uuid`, got `str`"):
convert(msg, uuid.UUID, builtin_types=(uuid.UUID,))
convert(str(u), uuid.UUID, builtin_types=(uuid.UUID,))

for typ in [bytes, bytearray]:
with pytest.raises(ValidationError, match="Expected `uuid`, got `bytes`"):
convert(typ(u.bytes), uuid.UUID, builtin_types=(uuid.UUID,))


class TestDecimal:
Expand Down Expand Up @@ -2348,20 +2364,6 @@ def test_lax_timedelta_invalid_numeric_str(self):
with pytest.raises(ValidationError, match="Invalid"):
convert(msg, type=datetime.timedelta, strict=False)

@pytest.mark.parametrize("input_type", [bytes, bytearray])
def test_lax_uuid(self, input_type):
sol = uuid.uuid4()
msg = input_type(sol.bytes)
res = convert(msg, uuid.UUID, strict=False)
assert res == sol

bad_msg = input_type(b"x" * 8)
with pytest.raises(msgspec.ValidationError, match="Invalid UUID bytes"):
convert(bad_msg, type=uuid.UUID, strict=False)

with pytest.raises(ValidationError, match="Expected `uuid`, got `bytes`"):
convert(msg, uuid.UUID)

@pytest.mark.parametrize(
"msg, sol",
[
Expand Down

0 comments on commit 6bb1937

Please sign in to comment.