Skip to content

Commit

Permalink
Add support for streaming delimited messages
Browse files Browse the repository at this point in the history
This allows developers to easily dump and load multiple messages
from a stream in a way that is compatible with official
protobuf implementations (such as Java's
`MessageLite#writeDelimitedTo(...)`).
  • Loading branch information
JoshuaLeivers committed Oct 16, 2023
1 parent 6b36b9b commit 690798c
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 8 deletions.
33 changes: 26 additions & 7 deletions src/betterproto/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,10 @@


if TYPE_CHECKING:
from _typeshed import ReadableBuffer
from _typeshed import (
ReadableBuffer,
WriteableBuffer,
)


# Proto 3 data types
Expand Down Expand Up @@ -127,6 +130,9 @@
WIRE_FIXED_64_TYPES = [TYPE_DOUBLE, TYPE_FIXED64, TYPE_SFIXED64]
WIRE_LEN_DELIM_TYPES = [TYPE_STRING, TYPE_BYTES, TYPE_MESSAGE, TYPE_MAP]

# Indicator of message delimitation in streams
SIZE_DELIMITED = -1


# Protobuf datetimes start at the Unix Epoch in 1970 in UTC.
def datetime_default_gen() -> datetime:
Expand Down Expand Up @@ -322,7 +328,7 @@ def _pack_fmt(proto_type: str) -> str:
}[proto_type]


def dump_varint(value: int, stream: BinaryIO) -> None:
def dump_varint(value: int, stream: "WriteableBuffer") -> None:
"""Encodes a single varint and dumps it into the provided stream."""
if value < -(1 << 63):
raise ValueError(
Expand Down Expand Up @@ -531,7 +537,7 @@ def _dump_float(value: float) -> Union[float, str]:
return value


def load_varint(stream: BinaryIO) -> Tuple[int, bytes]:
def load_varint(stream: "ReadableBuffer") -> Tuple[int, bytes]:
"""
Load a single varint value from a stream. Returns the value and the raw bytes read.
"""
Expand Down Expand Up @@ -569,7 +575,7 @@ class ParsedField:
raw: bytes


def load_fields(stream: BinaryIO) -> Generator[ParsedField, None, None]:
def load_fields(stream: "ReadableBuffer") -> Generator[ParsedField, None, None]:
while True:
try:
num_wire, raw = load_varint(stream)
Expand Down Expand Up @@ -881,15 +887,19 @@ def _betterproto(self) -> ProtoClassMetadata:
self.__class__._betterproto_meta = meta # type: ignore
return meta

def dump(self, stream: BinaryIO) -> None:
def dump(self, stream: "WriteableBuffer", delimit: bool = False) -> None:
"""
Dumps the binary encoded Protobuf message to the stream.
Parameters
-----------
stream: :class:`BinaryIO`
The stream to dump the message to.
delimit:
Whether to prefix the message with a varint declaring its size.
"""
if delimit == SIZE_DELIMITED:
dump_varint(len(self), stream)

for field_name, meta in self._betterproto.meta_by_field_name.items():
try:
Expand Down Expand Up @@ -1207,7 +1217,11 @@ def _include_default_value_for_oneof(
meta.group is not None and self._group_current.get(meta.group) == field_name
)

def load(self: T, stream: BinaryIO, size: Optional[int] = None) -> T:
def load(
self: T,
stream: "ReadableBuffer",
size: Optional[int] = None,
) -> T:
"""
Load the binary encoded Protobuf from a stream into this message instance. This
returns the instance itself and is therefore assignable and chainable.
Expand All @@ -1219,12 +1233,17 @@ def load(self: T, stream: BinaryIO, size: Optional[int] = None) -> T:
size: :class:`Optional[int]`
The size of the message in the stream.
Reads stream until EOF if ``None`` is given.
Reads based on a size delimiter prefix varint if SIZE_DELIMITED is given.
Returns
--------
:class:`Message`
The initialized message.
"""
# If the message is delimited, parse the message delimiter
if size == SIZE_DELIMITED:
size, _ = load_varint(stream)

# Got some data over the wire
self._serialized_on_wire = True
proto_meta = self._betterproto
Expand Down Expand Up @@ -1297,7 +1316,7 @@ def load(self: T, stream: BinaryIO, size: Optional[int] = None) -> T:

return self

def parse(self: T, data: "ReadableBuffer") -> T:
def parse(self: T, data: bytes) -> T:
"""
Parse the binary encoded Protobuf into this message instance. This
returns the instance itself and is therefore assignable and chainable.
Expand Down
2 changes: 2 additions & 0 deletions tests/streams/delimited_messages.in
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
���:bTesting���:bTesting
 
22 changes: 21 additions & 1 deletion tests/test_streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,18 @@ def test_message_dump_file_multiple(tmp_path):
assert test_stream.read() == exp_stream.read()


def test_message_dump_delimited(tmp_path):
with open(tmp_path / "message_dump_delimited.out", "wb") as stream:
oneof_example.dump(stream, betterproto.SIZE_DELIMITED)
oneof_example.dump(stream, betterproto.SIZE_DELIMITED)
nested_example.dump(stream, betterproto.SIZE_DELIMITED)

with open(tmp_path / "message_dump_delimited.out", "rb") as test_stream, open(
streams_path / "delimited_messages.in", "rb"
) as exp_stream:
assert test_stream.read() == exp_stream.read()


def test_message_len():
assert len_oneof == len(bytes(oneof_example))
assert len(nested_example) == len(bytes(nested_example))
Expand Down Expand Up @@ -155,7 +167,15 @@ def test_message_load_too_small():
oneof.Test().load(stream, len_oneof - 1)


def test_message_too_large():
def test_message_load_delimited():
with open(streams_path / "delimited_messages.in", "rb") as stream:
assert oneof.Test().load(stream, betterproto.SIZE_DELIMITED) == oneof_example
assert oneof.Test().load(stream, betterproto.SIZE_DELIMITED) == oneof_example
assert nested.Test().load(stream, betterproto.SIZE_DELIMITED) == nested_example
assert stream.read(1) == b""


def test_message_load_too_large():
with open(
streams_path / "message_dump_file_single.expected", "rb"
) as stream, pytest.raises(ValueError):
Expand Down

0 comments on commit 690798c

Please sign in to comment.