diff --git a/src/betterproto/__init__.py b/src/betterproto/__init__.py index bfda1c0bb..9e363f0a3 100644 --- a/src/betterproto/__init__.py +++ b/src/betterproto/__init__.py @@ -50,7 +50,10 @@ if TYPE_CHECKING: - from _typeshed import ReadableBuffer + from _typeshed import ( + ReadableBuffer, + WriteableBuffer, + ) # Proto 3 data types @@ -127,6 +130,9 @@ WIRE_FIXED_64_TYPES = [TYPE_DOUBLE, TYPE_FIXED64, TYPE_SFIXED64] WIRE_LEN_DELIM_TYPES = [TYPE_STRING, TYPE_BYTES, TYPE_MESSAGE, TYPE_MAP] +# Indicator of message delimitation in streams +SIZE_DELIMITED = -1 + # Protobuf datetimes start at the Unix Epoch in 1970 in UTC. def datetime_default_gen() -> datetime: @@ -322,7 +328,7 @@ def _pack_fmt(proto_type: str) -> str: }[proto_type] -def dump_varint(value: int, stream: BinaryIO) -> None: +def dump_varint(value: int, stream: "WriteableBuffer") -> None: """Encodes a single varint and dumps it into the provided stream.""" if value < -(1 << 63): raise ValueError( @@ -531,7 +537,7 @@ def _dump_float(value: float) -> Union[float, str]: return value -def load_varint(stream: BinaryIO) -> Tuple[int, bytes]: +def load_varint(stream: "ReadableBuffer") -> Tuple[int, bytes]: """ Load a single varint value from a stream. Returns the value and the raw bytes read. """ @@ -569,7 +575,7 @@ class ParsedField: raw: bytes -def load_fields(stream: BinaryIO) -> Generator[ParsedField, None, None]: +def load_fields(stream: "ReadableBuffer") -> Generator[ParsedField, None, None]: while True: try: num_wire, raw = load_varint(stream) @@ -881,7 +887,7 @@ def _betterproto(self) -> ProtoClassMetadata: self.__class__._betterproto_meta = meta # type: ignore return meta - def dump(self, stream: BinaryIO) -> None: + def dump(self, stream: "WriteableBuffer", delimit: bool = False) -> None: """ Dumps the binary encoded Protobuf message to the stream. @@ -889,7 +895,11 @@ def dump(self, stream: BinaryIO) -> None: ----------- stream: :class:`BinaryIO` The stream to dump the message to. + delimit: + Whether to prefix the message with a varint declaring its size. """ + if delimit == SIZE_DELIMITED: + dump_varint(len(self), stream) for field_name, meta in self._betterproto.meta_by_field_name.items(): try: @@ -1207,7 +1217,11 @@ def _include_default_value_for_oneof( meta.group is not None and self._group_current.get(meta.group) == field_name ) - def load(self: T, stream: BinaryIO, size: Optional[int] = None) -> T: + def load( + self: T, + stream: "ReadableBuffer", + size: Optional[int] = None, + ) -> T: """ Load the binary encoded Protobuf from a stream into this message instance. This returns the instance itself and is therefore assignable and chainable. @@ -1219,12 +1233,17 @@ def load(self: T, stream: BinaryIO, size: Optional[int] = None) -> T: size: :class:`Optional[int]` The size of the message in the stream. Reads stream until EOF if ``None`` is given. + Reads based on a size delimiter prefix varint if SIZE_DELIMITED is given. Returns -------- :class:`Message` The initialized message. """ + # If the message is delimited, parse the message delimiter + if size == SIZE_DELIMITED: + size, _ = load_varint(stream) + # Got some data over the wire self._serialized_on_wire = True proto_meta = self._betterproto @@ -1297,7 +1316,7 @@ def load(self: T, stream: BinaryIO, size: Optional[int] = None) -> T: return self - def parse(self: T, data: "ReadableBuffer") -> T: + def parse(self: T, data: bytes) -> T: """ Parse the binary encoded Protobuf into this message instance. This returns the instance itself and is therefore assignable and chainable. diff --git a/tests/streams/delimited_messages.in b/tests/streams/delimited_messages.in new file mode 100644 index 000000000..5993ac6f8 --- /dev/null +++ b/tests/streams/delimited_messages.in @@ -0,0 +1,2 @@ +•šï:bTesting•šï:bTesting +  \ No newline at end of file diff --git a/tests/test_streams.py b/tests/test_streams.py index a1c2bbd98..f41ac45ec 100644 --- a/tests/test_streams.py +++ b/tests/test_streams.py @@ -127,6 +127,18 @@ def test_message_dump_file_multiple(tmp_path): assert test_stream.read() == exp_stream.read() +def test_message_dump_delimited(tmp_path): + with open(tmp_path / "message_dump_delimited.out", "wb") as stream: + oneof_example.dump(stream, betterproto.SIZE_DELIMITED) + oneof_example.dump(stream, betterproto.SIZE_DELIMITED) + nested_example.dump(stream, betterproto.SIZE_DELIMITED) + + with open(tmp_path / "message_dump_delimited.out", "rb") as test_stream, open( + streams_path / "delimited_messages.in", "rb" + ) as exp_stream: + assert test_stream.read() == exp_stream.read() + + def test_message_len(): assert len_oneof == len(bytes(oneof_example)) assert len(nested_example) == len(bytes(nested_example)) @@ -155,7 +167,15 @@ def test_message_load_too_small(): oneof.Test().load(stream, len_oneof - 1) -def test_message_too_large(): +def test_message_load_delimited(): + with open(streams_path / "delimited_messages.in", "rb") as stream: + assert oneof.Test().load(stream, betterproto.SIZE_DELIMITED) == oneof_example + assert oneof.Test().load(stream, betterproto.SIZE_DELIMITED) == oneof_example + assert nested.Test().load(stream, betterproto.SIZE_DELIMITED) == nested_example + assert stream.read(1) == b"" + + +def test_message_load_too_large(): with open( streams_path / "message_dump_file_single.expected", "rb" ) as stream, pytest.raises(ValueError):