From 4f18ed1325d008d140a435640c470e965abb473b Mon Sep 17 00:00:00 2001 From: Joshua Leivers Date: Mon, 16 Oct 2023 11:59:33 +0100 Subject: [PATCH] Add support for streaming delimited messages (#529) * Add support for streaming delimited messages This allows developers to easily dump and load multiple messages from a stream in a way that is compatible with official protobuf implementations (such as Java's `MessageLite#writeDelimitedTo(...)`). * Add Java compatibility tests for streaming These tests stream data such as messages to output files, have a Java binary read them and then write them back using the `protobuf-java` functions, and then read them back in on the Python side to check that the returned data is as expected. This checks that the official Java implementation (and so any other matching implementations) can properly parse outputs from Betterproto, and vice-versa, ensuring compatibility in these functions between the two. * Replace `xxxxableBuffer` with `SupportsXxxx` --- .pre-commit-config.yaml | 9 +- src/betterproto/__init__.py | 33 +++- tests/streams/delimited_messages.in | 2 + tests/streams/java/.gitignore | 38 ++++ tests/streams/java/pom.xml | 94 ++++++++++ .../java/betterproto/CompatibilityTest.java | 41 +++++ .../java/src/main/java/betterproto/Tests.java | 115 ++++++++++++ .../src/main/proto/betterproto/nested.proto | 27 +++ .../src/main/proto/betterproto/oneof.proto | 19 ++ tests/test_streams.py | 168 +++++++++++++++++- 10 files changed, 537 insertions(+), 9 deletions(-) create mode 100644 tests/streams/delimited_messages.in create mode 100644 tests/streams/java/.gitignore create mode 100644 tests/streams/java/pom.xml create mode 100644 tests/streams/java/src/main/java/betterproto/CompatibilityTest.java create mode 100644 tests/streams/java/src/main/java/betterproto/Tests.java create mode 100644 tests/streams/java/src/main/proto/betterproto/nested.proto create mode 100644 tests/streams/java/src/main/proto/betterproto/oneof.proto diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index fa461c82..9e3f65a8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -16,6 +16,13 @@ repos: - repo: https://github.com/PyCQA/doc8 rev: 0.10.1 hooks: - - id: doc8 + - id: doc8 additional_dependencies: - toml + + - repo: https://github.com/macisamuele/language-formatters-pre-commit-hooks + rev: v2.10.0 + hooks: + - id: pretty-format-java + args: [--autofix, --aosp] + files: ^.*\.java$ diff --git a/src/betterproto/__init__.py b/src/betterproto/__init__.py index bfda1c0b..35acd788 100644 --- a/src/betterproto/__init__.py +++ b/src/betterproto/__init__.py @@ -50,7 +50,10 @@ if TYPE_CHECKING: - from _typeshed import ReadableBuffer + from _typeshed import ( + SupportsRead, + SupportsWrite, + ) # Proto 3 data types @@ -127,6 +130,9 @@ WIRE_FIXED_64_TYPES = [TYPE_DOUBLE, TYPE_FIXED64, TYPE_SFIXED64] WIRE_LEN_DELIM_TYPES = [TYPE_STRING, TYPE_BYTES, TYPE_MESSAGE, TYPE_MAP] +# Indicator of message delimitation in streams +SIZE_DELIMITED = -1 + # Protobuf datetimes start at the Unix Epoch in 1970 in UTC. def datetime_default_gen() -> datetime: @@ -322,7 +328,7 @@ def _pack_fmt(proto_type: str) -> str: }[proto_type] -def dump_varint(value: int, stream: BinaryIO) -> None: +def dump_varint(value: int, stream: "SupportsWrite[bytes]") -> None: """Encodes a single varint and dumps it into the provided stream.""" if value < -(1 << 63): raise ValueError( @@ -531,7 +537,7 @@ def _dump_float(value: float) -> Union[float, str]: return value -def load_varint(stream: BinaryIO) -> Tuple[int, bytes]: +def load_varint(stream: "SupportsRead[bytes]") -> Tuple[int, bytes]: """ Load a single varint value from a stream. Returns the value and the raw bytes read. """ @@ -569,7 +575,7 @@ class ParsedField: raw: bytes -def load_fields(stream: BinaryIO) -> Generator[ParsedField, None, None]: +def load_fields(stream: "SupportsRead[bytes]") -> Generator[ParsedField, None, None]: while True: try: num_wire, raw = load_varint(stream) @@ -881,7 +887,7 @@ def _betterproto(self) -> ProtoClassMetadata: self.__class__._betterproto_meta = meta # type: ignore return meta - def dump(self, stream: BinaryIO) -> None: + def dump(self, stream: "SupportsWrite[bytes]", delimit: bool = False) -> None: """ Dumps the binary encoded Protobuf message to the stream. @@ -889,7 +895,11 @@ def dump(self, stream: BinaryIO) -> None: ----------- stream: :class:`BinaryIO` The stream to dump the message to. + delimit: + Whether to prefix the message with a varint declaring its size. """ + if delimit == SIZE_DELIMITED: + dump_varint(len(self), stream) for field_name, meta in self._betterproto.meta_by_field_name.items(): try: @@ -1207,7 +1217,11 @@ def _include_default_value_for_oneof( meta.group is not None and self._group_current.get(meta.group) == field_name ) - def load(self: T, stream: BinaryIO, size: Optional[int] = None) -> T: + def load( + self: T, + stream: "SupportsRead[bytes]", + size: Optional[int] = None, + ) -> T: """ Load the binary encoded Protobuf from a stream into this message instance. This returns the instance itself and is therefore assignable and chainable. @@ -1219,12 +1233,17 @@ def load(self: T, stream: BinaryIO, size: Optional[int] = None) -> T: size: :class:`Optional[int]` The size of the message in the stream. Reads stream until EOF if ``None`` is given. + Reads based on a size delimiter prefix varint if SIZE_DELIMITED is given. Returns -------- :class:`Message` The initialized message. """ + # If the message is delimited, parse the message delimiter + if size == SIZE_DELIMITED: + size, _ = load_varint(stream) + # Got some data over the wire self._serialized_on_wire = True proto_meta = self._betterproto @@ -1297,7 +1316,7 @@ def load(self: T, stream: BinaryIO, size: Optional[int] = None) -> T: return self - def parse(self: T, data: "ReadableBuffer") -> T: + def parse(self: T, data: bytes) -> T: """ Parse the binary encoded Protobuf into this message instance. This returns the instance itself and is therefore assignable and chainable. diff --git a/tests/streams/delimited_messages.in b/tests/streams/delimited_messages.in new file mode 100644 index 00000000..5993ac6f --- /dev/null +++ b/tests/streams/delimited_messages.in @@ -0,0 +1,2 @@ +•šï:bTesting•šï:bTesting +  \ No newline at end of file diff --git a/tests/streams/java/.gitignore b/tests/streams/java/.gitignore new file mode 100644 index 00000000..9b1ebba9 --- /dev/null +++ b/tests/streams/java/.gitignore @@ -0,0 +1,38 @@ +### Output ### +target/ +!.mvn/wrapper/maven-wrapper.jar +!**/src/main/**/target/ +!**/src/test/**/target/ +dependency-reduced-pom.xml +MANIFEST.MF + +### IntelliJ IDEA ### +.idea/ +*.iws +*.iml +*.ipr + +### Eclipse ### +.apt_generated +.classpath +.factorypath +.project +.settings +.springBeans +.sts4-cache + +### NetBeans ### +/nbproject/private/ +/nbbuild/ +/dist/ +/nbdist/ +/.nb-gradle/ +build/ +!**/src/main/**/build/ +!**/src/test/**/build/ + +### VS Code ### +.vscode/ + +### Mac OS ### +.DS_Store \ No newline at end of file diff --git a/tests/streams/java/pom.xml b/tests/streams/java/pom.xml new file mode 100644 index 00000000..170d2d66 --- /dev/null +++ b/tests/streams/java/pom.xml @@ -0,0 +1,94 @@ + + + 4.0.0 + + betterproto + compatibility-test + 1.0-SNAPSHOT + jar + + + 11 + 11 + UTF-8 + 3.23.4 + + + + + com.google.protobuf + protobuf-java + ${protobuf.version} + + + + + + + kr.motd.maven + os-maven-plugin + 1.7.1 + + + + + + org.apache.maven.plugins + maven-shade-plugin + 3.5.0 + + + package + + shade + + + + + betterproto.CompatibilityTest + + + + + + + + + org.apache.maven.plugins + maven-jar-plugin + 3.3.0 + + + + true + betterproto.CompatibilityTest + + + + + + + org.xolstice.maven.plugins + protobuf-maven-plugin + 0.6.1 + + + + compile + + + + + + com.google.protobuf:protoc:${protobuf.version}:exe:${os.detected.classifier} + + + + + + ${project.artifactId} + + + \ No newline at end of file diff --git a/tests/streams/java/src/main/java/betterproto/CompatibilityTest.java b/tests/streams/java/src/main/java/betterproto/CompatibilityTest.java new file mode 100644 index 00000000..908f87af --- /dev/null +++ b/tests/streams/java/src/main/java/betterproto/CompatibilityTest.java @@ -0,0 +1,41 @@ +package betterproto; + +import java.io.IOException; + +public class CompatibilityTest { + public static void main(String[] args) throws IOException { + if (args.length < 2) + throw new RuntimeException("Attempted to run without the required arguments."); + else if (args.length > 2) + throw new RuntimeException( + "Attempted to run with more than the expected number of arguments (>1)."); + + Tests tests = new Tests(args[1]); + + switch (args[0]) { + case "single_varint": + tests.testSingleVarint(); + break; + + case "multiple_varints": + tests.testMultipleVarints(); + break; + + case "single_message": + tests.testSingleMessage(); + break; + + case "multiple_messages": + tests.testMultipleMessages(); + break; + + case "infinite_messages": + tests.testInfiniteMessages(); + break; + + default: + throw new RuntimeException( + "Attempted to run with unknown argument '" + args[0] + "'."); + } + } +} diff --git a/tests/streams/java/src/main/java/betterproto/Tests.java b/tests/streams/java/src/main/java/betterproto/Tests.java new file mode 100644 index 00000000..a7c8fd57 --- /dev/null +++ b/tests/streams/java/src/main/java/betterproto/Tests.java @@ -0,0 +1,115 @@ +package betterproto; + +import betterproto.nested.NestedOuterClass; +import betterproto.oneof.Oneof; + +import com.google.protobuf.CodedInputStream; +import com.google.protobuf.CodedOutputStream; + +import java.io.FileInputStream; +import java.io.FileOutputStream; +import java.io.IOException; + +public class Tests { + String path; + + public Tests(String path) { + this.path = path; + } + + public void testSingleVarint() throws IOException { + // Read in the Python-generated single varint file + FileInputStream inputStream = new FileInputStream(path + "/py_single_varint.out"); + CodedInputStream codedInput = CodedInputStream.newInstance(inputStream); + + int value = codedInput.readUInt32(); + + inputStream.close(); + + // Write the value back to a file + FileOutputStream outputStream = new FileOutputStream(path + "/java_single_varint.out"); + CodedOutputStream codedOutput = CodedOutputStream.newInstance(outputStream); + + codedOutput.writeUInt32NoTag(value); + + codedOutput.flush(); + outputStream.close(); + } + + public void testMultipleVarints() throws IOException { + // Read in the Python-generated multiple varints file + FileInputStream inputStream = new FileInputStream(path + "/py_multiple_varints.out"); + CodedInputStream codedInput = CodedInputStream.newInstance(inputStream); + + int value1 = codedInput.readUInt32(); + int value2 = codedInput.readUInt32(); + long value3 = codedInput.readUInt64(); + + inputStream.close(); + + // Write the values back to a file + FileOutputStream outputStream = new FileOutputStream(path + "/java_multiple_varints.out"); + CodedOutputStream codedOutput = CodedOutputStream.newInstance(outputStream); + + codedOutput.writeUInt32NoTag(value1); + codedOutput.writeUInt64NoTag(value2); + codedOutput.writeUInt64NoTag(value3); + + codedOutput.flush(); + outputStream.close(); + } + + public void testSingleMessage() throws IOException { + // Read in the Python-generated single message file + FileInputStream inputStream = new FileInputStream(path + "/py_single_message.out"); + CodedInputStream codedInput = CodedInputStream.newInstance(inputStream); + + Oneof.Test message = Oneof.Test.parseFrom(codedInput); + + inputStream.close(); + + // Write the message back to a file + FileOutputStream outputStream = new FileOutputStream(path + "/java_single_message.out"); + CodedOutputStream codedOutput = CodedOutputStream.newInstance(outputStream); + + message.writeTo(codedOutput); + + codedOutput.flush(); + outputStream.close(); + } + + public void testMultipleMessages() throws IOException { + // Read in the Python-generated multi-message file + FileInputStream inputStream = new FileInputStream(path + "/py_multiple_messages.out"); + + Oneof.Test oneof = Oneof.Test.parseDelimitedFrom(inputStream); + NestedOuterClass.Test nested = NestedOuterClass.Test.parseDelimitedFrom(inputStream); + + inputStream.close(); + + // Write the messages back to a file + FileOutputStream outputStream = new FileOutputStream(path + "/java_multiple_messages.out"); + + oneof.writeDelimitedTo(outputStream); + nested.writeDelimitedTo(outputStream); + + outputStream.flush(); + outputStream.close(); + } + + public void testInfiniteMessages() throws IOException { + // Read in as many messages as are present in the Python-generated file and write them back + FileInputStream inputStream = new FileInputStream(path + "/py_infinite_messages.out"); + FileOutputStream outputStream = new FileOutputStream(path + "/java_infinite_messages.out"); + + Oneof.Test current = Oneof.Test.parseDelimitedFrom(inputStream); + while (current != null) { + current.writeDelimitedTo(outputStream); + current = Oneof.Test.parseDelimitedFrom(inputStream); + } + + inputStream.close(); + outputStream.flush(); + outputStream.close(); + } +} diff --git a/tests/streams/java/src/main/proto/betterproto/nested.proto b/tests/streams/java/src/main/proto/betterproto/nested.proto new file mode 100644 index 00000000..405a05a4 --- /dev/null +++ b/tests/streams/java/src/main/proto/betterproto/nested.proto @@ -0,0 +1,27 @@ +syntax = "proto3"; + +package nested; +option java_package = "betterproto.nested"; + +// A test message with a nested message inside of it. +message Test { + // This is the nested type. + message Nested { + // Stores a simple counter. + int32 count = 1; + } + // This is the nested enum. + enum Msg { + NONE = 0; + THIS = 1; + } + + Nested nested = 1; + Sibling sibling = 2; + Sibling sibling2 = 3; + Msg msg = 4; +} + +message Sibling { + int32 foo = 1; +} \ No newline at end of file diff --git a/tests/streams/java/src/main/proto/betterproto/oneof.proto b/tests/streams/java/src/main/proto/betterproto/oneof.proto new file mode 100644 index 00000000..ad21028c --- /dev/null +++ b/tests/streams/java/src/main/proto/betterproto/oneof.proto @@ -0,0 +1,19 @@ +syntax = "proto3"; + +package oneof; +option java_package = "betterproto.oneof"; + +message Test { + oneof foo { + int32 pitied = 1; + string pitier = 2; + } + + int32 just_a_regular_field = 3; + + oneof bar { + int32 drinks = 11; + string bar_name = 12; + } +} + diff --git a/tests/test_streams.py b/tests/test_streams.py index a1c2bbd9..3bb9aceb 100644 --- a/tests/test_streams.py +++ b/tests/test_streams.py @@ -1,6 +1,8 @@ from dataclasses import dataclass from io import BytesIO from pathlib import Path +from shutil import which +from subprocess import run from typing import Optional import pytest @@ -40,6 +42,8 @@ streams_path = Path("tests/streams/") +java = which("java") + def test_load_varint_too_long(): with BytesIO( @@ -127,6 +131,18 @@ def test_message_dump_file_multiple(tmp_path): assert test_stream.read() == exp_stream.read() +def test_message_dump_delimited(tmp_path): + with open(tmp_path / "message_dump_delimited.out", "wb") as stream: + oneof_example.dump(stream, betterproto.SIZE_DELIMITED) + oneof_example.dump(stream, betterproto.SIZE_DELIMITED) + nested_example.dump(stream, betterproto.SIZE_DELIMITED) + + with open(tmp_path / "message_dump_delimited.out", "rb") as test_stream, open( + streams_path / "delimited_messages.in", "rb" + ) as exp_stream: + assert test_stream.read() == exp_stream.read() + + def test_message_len(): assert len_oneof == len(bytes(oneof_example)) assert len(nested_example) == len(bytes(nested_example)) @@ -155,7 +171,15 @@ def test_message_load_too_small(): oneof.Test().load(stream, len_oneof - 1) -def test_message_too_large(): +def test_message_load_delimited(): + with open(streams_path / "delimited_messages.in", "rb") as stream: + assert oneof.Test().load(stream, betterproto.SIZE_DELIMITED) == oneof_example + assert oneof.Test().load(stream, betterproto.SIZE_DELIMITED) == oneof_example + assert nested.Test().load(stream, betterproto.SIZE_DELIMITED) == nested_example + assert stream.read(1) == b"" + + +def test_message_load_too_large(): with open( streams_path / "message_dump_file_single.expected", "rb" ) as stream, pytest.raises(ValueError): @@ -266,3 +290,145 @@ def test_dump_varint_positive(tmp_path): streams_path / "dump_varint_positive.expected", "rb" ) as exp_stream: assert test_stream.read() == exp_stream.read() + + +# Java compatibility tests + + +@pytest.fixture(scope="module") +def compile_jar(): + # Skip if not all required tools are present + if java is None: + pytest.skip("`java` command is absent and is required") + mvn = which("mvn") + if mvn is None: + pytest.skip("Maven is absent and is required") + + # Compile the JAR + proc_maven = run([mvn, "clean", "install", "-f", "tests/streams/java/pom.xml"]) + if proc_maven.returncode != 0: + pytest.skip( + "Maven compatibility-test.jar build failed (maybe Java version <11?)" + ) + + +jar = "tests/streams/java/target/compatibility-test.jar" + + +def run_jar(command: str, tmp_path): + return run([java, "-jar", jar, command, tmp_path], check=True) + + +def run_java_single_varint(value: int, tmp_path) -> int: + # Write single varint to file + with open(tmp_path / "py_single_varint.out", "wb") as stream: + betterproto.dump_varint(value, stream) + + # Have Java read this varint and write it back + run_jar("single_varint", tmp_path) + + # Read single varint from Java output file + with open(tmp_path / "java_single_varint.out", "rb") as stream: + returned = betterproto.load_varint(stream) + with pytest.raises(EOFError): + betterproto.load_varint(stream) + + return returned + + +def test_single_varint(compile_jar, tmp_path): + single_byte = (1, b"\x01") + multi_byte = (123456789, b"\x95\x9A\xEF\x3A") + + # Write a single-byte varint to a file and have Java read it back + returned = run_java_single_varint(single_byte[0], tmp_path) + assert returned == single_byte + + # Same for a multi-byte varint + returned = run_java_single_varint(multi_byte[0], tmp_path) + assert returned == multi_byte + + +def test_multiple_varints(compile_jar, tmp_path): + single_byte = (1, b"\x01") + multi_byte = (123456789, b"\x95\x9A\xEF\x3A") + over32 = (3000000000, b"\x80\xBC\xC1\x96\x0B") + + # Write two varints to the same file + with open(tmp_path / "py_multiple_varints.out", "wb") as stream: + betterproto.dump_varint(single_byte[0], stream) + betterproto.dump_varint(multi_byte[0], stream) + betterproto.dump_varint(over32[0], stream) + + # Have Java read these varints and write them back + run_jar("multiple_varints", tmp_path) + + # Read varints from Java output file + with open(tmp_path / "java_multiple_varints.out", "rb") as stream: + returned_single = betterproto.load_varint(stream) + returned_multi = betterproto.load_varint(stream) + returned_over32 = betterproto.load_varint(stream) + with pytest.raises(EOFError): + betterproto.load_varint(stream) + + assert returned_single == single_byte + assert returned_multi == multi_byte + assert returned_over32 == over32 + + +def test_single_message(compile_jar, tmp_path): + # Write message to file + with open(tmp_path / "py_single_message.out", "wb") as stream: + oneof_example.dump(stream) + + # Have Java read and return the message + run_jar("single_message", tmp_path) + + # Read and check the returned message + with open(tmp_path / "java_single_message.out", "rb") as stream: + returned = oneof.Test().load(stream, len(bytes(oneof_example))) + assert stream.read() == b"" + + assert returned == oneof_example + + +def test_multiple_messages(compile_jar, tmp_path): + # Write delimited messages to file + with open(tmp_path / "py_multiple_messages.out", "wb") as stream: + oneof_example.dump(stream, betterproto.SIZE_DELIMITED) + nested_example.dump(stream, betterproto.SIZE_DELIMITED) + + # Have Java read and return the messages + run_jar("multiple_messages", tmp_path) + + # Read and check the returned messages + with open(tmp_path / "java_multiple_messages.out", "rb") as stream: + returned_oneof = oneof.Test().load(stream, betterproto.SIZE_DELIMITED) + returned_nested = nested.Test().load(stream, betterproto.SIZE_DELIMITED) + assert stream.read() == b"" + + assert returned_oneof == oneof_example + assert returned_nested == nested_example + + +def test_infinite_messages(compile_jar, tmp_path): + num_messages = 5 + + # Write delimited messages to file + with open(tmp_path / "py_infinite_messages.out", "wb") as stream: + for x in range(num_messages): + oneof_example.dump(stream, betterproto.SIZE_DELIMITED) + + # Have Java read and return the messages + run_jar("infinite_messages", tmp_path) + + # Read and check the returned messages + messages = [] + with open(tmp_path / "java_infinite_messages.out", "rb") as stream: + while True: + try: + messages.append(oneof.Test().load(stream, betterproto.SIZE_DELIMITED)) + except EOFError: + break + + assert len(messages) == num_messages