diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index fa461c82..9e3f65a8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -16,6 +16,13 @@ repos: - repo: https://github.com/PyCQA/doc8 rev: 0.10.1 hooks: - - id: doc8 + - id: doc8 additional_dependencies: - toml + + - repo: https://github.com/macisamuele/language-formatters-pre-commit-hooks + rev: v2.10.0 + hooks: + - id: pretty-format-java + args: [--autofix, --aosp] + files: ^.*\.java$ diff --git a/src/betterproto/__init__.py b/src/betterproto/__init__.py index bfda1c0b..35acd788 100644 --- a/src/betterproto/__init__.py +++ b/src/betterproto/__init__.py @@ -50,7 +50,10 @@ if TYPE_CHECKING: - from _typeshed import ReadableBuffer + from _typeshed import ( + SupportsRead, + SupportsWrite, + ) # Proto 3 data types @@ -127,6 +130,9 @@ WIRE_FIXED_64_TYPES = [TYPE_DOUBLE, TYPE_FIXED64, TYPE_SFIXED64] WIRE_LEN_DELIM_TYPES = [TYPE_STRING, TYPE_BYTES, TYPE_MESSAGE, TYPE_MAP] +# Indicator of message delimitation in streams +SIZE_DELIMITED = -1 + # Protobuf datetimes start at the Unix Epoch in 1970 in UTC. def datetime_default_gen() -> datetime: @@ -322,7 +328,7 @@ def _pack_fmt(proto_type: str) -> str: }[proto_type] -def dump_varint(value: int, stream: BinaryIO) -> None: +def dump_varint(value: int, stream: "SupportsWrite[bytes]") -> None: """Encodes a single varint and dumps it into the provided stream.""" if value < -(1 << 63): raise ValueError( @@ -531,7 +537,7 @@ def _dump_float(value: float) -> Union[float, str]: return value -def load_varint(stream: BinaryIO) -> Tuple[int, bytes]: +def load_varint(stream: "SupportsRead[bytes]") -> Tuple[int, bytes]: """ Load a single varint value from a stream. Returns the value and the raw bytes read. """ @@ -569,7 +575,7 @@ class ParsedField: raw: bytes -def load_fields(stream: BinaryIO) -> Generator[ParsedField, None, None]: +def load_fields(stream: "SupportsRead[bytes]") -> Generator[ParsedField, None, None]: while True: try: num_wire, raw = load_varint(stream) @@ -881,7 +887,7 @@ def _betterproto(self) -> ProtoClassMetadata: self.__class__._betterproto_meta = meta # type: ignore return meta - def dump(self, stream: BinaryIO) -> None: + def dump(self, stream: "SupportsWrite[bytes]", delimit: bool = False) -> None: """ Dumps the binary encoded Protobuf message to the stream. @@ -889,7 +895,11 @@ def dump(self, stream: BinaryIO) -> None: ----------- stream: :class:`BinaryIO` The stream to dump the message to. + delimit: + Whether to prefix the message with a varint declaring its size. """ + if delimit == SIZE_DELIMITED: + dump_varint(len(self), stream) for field_name, meta in self._betterproto.meta_by_field_name.items(): try: @@ -1207,7 +1217,11 @@ def _include_default_value_for_oneof( meta.group is not None and self._group_current.get(meta.group) == field_name ) - def load(self: T, stream: BinaryIO, size: Optional[int] = None) -> T: + def load( + self: T, + stream: "SupportsRead[bytes]", + size: Optional[int] = None, + ) -> T: """ Load the binary encoded Protobuf from a stream into this message instance. This returns the instance itself and is therefore assignable and chainable. @@ -1219,12 +1233,17 @@ def load(self: T, stream: BinaryIO, size: Optional[int] = None) -> T: size: :class:`Optional[int]` The size of the message in the stream. Reads stream until EOF if ``None`` is given. + Reads based on a size delimiter prefix varint if SIZE_DELIMITED is given. Returns -------- :class:`Message` The initialized message. """ + # If the message is delimited, parse the message delimiter + if size == SIZE_DELIMITED: + size, _ = load_varint(stream) + # Got some data over the wire self._serialized_on_wire = True proto_meta = self._betterproto @@ -1297,7 +1316,7 @@ def load(self: T, stream: BinaryIO, size: Optional[int] = None) -> T: return self - def parse(self: T, data: "ReadableBuffer") -> T: + def parse(self: T, data: bytes) -> T: """ Parse the binary encoded Protobuf into this message instance. This returns the instance itself and is therefore assignable and chainable. diff --git a/tests/streams/delimited_messages.in b/tests/streams/delimited_messages.in new file mode 100644 index 00000000..5993ac6f --- /dev/null +++ b/tests/streams/delimited_messages.in @@ -0,0 +1,2 @@ +•šï:bTesting•šï:bTesting +  \ No newline at end of file diff --git a/tests/streams/java/.gitignore b/tests/streams/java/.gitignore new file mode 100644 index 00000000..9b1ebba9 --- /dev/null +++ b/tests/streams/java/.gitignore @@ -0,0 +1,38 @@ +### Output ### +target/ +!.mvn/wrapper/maven-wrapper.jar +!**/src/main/**/target/ +!**/src/test/**/target/ +dependency-reduced-pom.xml +MANIFEST.MF + +### IntelliJ IDEA ### +.idea/ +*.iws +*.iml +*.ipr + +### Eclipse ### +.apt_generated +.classpath +.factorypath +.project +.settings +.springBeans +.sts4-cache + +### NetBeans ### +/nbproject/private/ +/nbbuild/ +/dist/ +/nbdist/ +/.nb-gradle/ +build/ +!**/src/main/**/build/ +!**/src/test/**/build/ + +### VS Code ### +.vscode/ + +### Mac OS ### +.DS_Store \ No newline at end of file diff --git a/tests/streams/java/pom.xml b/tests/streams/java/pom.xml new file mode 100644 index 00000000..170d2d66 --- /dev/null +++ b/tests/streams/java/pom.xml @@ -0,0 +1,94 @@ + + + 4.0.0 + + betterproto + compatibility-test + 1.0-SNAPSHOT + jar + + + 11 + 11 + UTF-8 + 3.23.4 + + + + + com.google.protobuf + protobuf-java + ${protobuf.version} + + + + + + + kr.motd.maven + os-maven-plugin + 1.7.1 + + + + + + org.apache.maven.plugins + maven-shade-plugin + 3.5.0 + + + package + + shade + + + + + betterproto.CompatibilityTest + + + + + + + + + org.apache.maven.plugins + maven-jar-plugin + 3.3.0 + + + + true + betterproto.CompatibilityTest + + + + + + + org.xolstice.maven.plugins + protobuf-maven-plugin + 0.6.1 + + + + compile + + + + + + com.google.protobuf:protoc:${protobuf.version}:exe:${os.detected.classifier} + + + + + + ${project.artifactId} + + + \ No newline at end of file diff --git a/tests/streams/java/src/main/java/betterproto/CompatibilityTest.java b/tests/streams/java/src/main/java/betterproto/CompatibilityTest.java new file mode 100644 index 00000000..908f87af --- /dev/null +++ b/tests/streams/java/src/main/java/betterproto/CompatibilityTest.java @@ -0,0 +1,41 @@ +package betterproto; + +import java.io.IOException; + +public class CompatibilityTest { + public static void main(String[] args) throws IOException { + if (args.length < 2) + throw new RuntimeException("Attempted to run without the required arguments."); + else if (args.length > 2) + throw new RuntimeException( + "Attempted to run with more than the expected number of arguments (>1)."); + + Tests tests = new Tests(args[1]); + + switch (args[0]) { + case "single_varint": + tests.testSingleVarint(); + break; + + case "multiple_varints": + tests.testMultipleVarints(); + break; + + case "single_message": + tests.testSingleMessage(); + break; + + case "multiple_messages": + tests.testMultipleMessages(); + break; + + case "infinite_messages": + tests.testInfiniteMessages(); + break; + + default: + throw new RuntimeException( + "Attempted to run with unknown argument '" + args[0] + "'."); + } + } +} diff --git a/tests/streams/java/src/main/java/betterproto/Tests.java b/tests/streams/java/src/main/java/betterproto/Tests.java new file mode 100644 index 00000000..a7c8fd57 --- /dev/null +++ b/tests/streams/java/src/main/java/betterproto/Tests.java @@ -0,0 +1,115 @@ +package betterproto; + +import betterproto.nested.NestedOuterClass; +import betterproto.oneof.Oneof; + +import com.google.protobuf.CodedInputStream; +import com.google.protobuf.CodedOutputStream; + +import java.io.FileInputStream; +import java.io.FileOutputStream; +import java.io.IOException; + +public class Tests { + String path; + + public Tests(String path) { + this.path = path; + } + + public void testSingleVarint() throws IOException { + // Read in the Python-generated single varint file + FileInputStream inputStream = new FileInputStream(path + "/py_single_varint.out"); + CodedInputStream codedInput = CodedInputStream.newInstance(inputStream); + + int value = codedInput.readUInt32(); + + inputStream.close(); + + // Write the value back to a file + FileOutputStream outputStream = new FileOutputStream(path + "/java_single_varint.out"); + CodedOutputStream codedOutput = CodedOutputStream.newInstance(outputStream); + + codedOutput.writeUInt32NoTag(value); + + codedOutput.flush(); + outputStream.close(); + } + + public void testMultipleVarints() throws IOException { + // Read in the Python-generated multiple varints file + FileInputStream inputStream = new FileInputStream(path + "/py_multiple_varints.out"); + CodedInputStream codedInput = CodedInputStream.newInstance(inputStream); + + int value1 = codedInput.readUInt32(); + int value2 = codedInput.readUInt32(); + long value3 = codedInput.readUInt64(); + + inputStream.close(); + + // Write the values back to a file + FileOutputStream outputStream = new FileOutputStream(path + "/java_multiple_varints.out"); + CodedOutputStream codedOutput = CodedOutputStream.newInstance(outputStream); + + codedOutput.writeUInt32NoTag(value1); + codedOutput.writeUInt64NoTag(value2); + codedOutput.writeUInt64NoTag(value3); + + codedOutput.flush(); + outputStream.close(); + } + + public void testSingleMessage() throws IOException { + // Read in the Python-generated single message file + FileInputStream inputStream = new FileInputStream(path + "/py_single_message.out"); + CodedInputStream codedInput = CodedInputStream.newInstance(inputStream); + + Oneof.Test message = Oneof.Test.parseFrom(codedInput); + + inputStream.close(); + + // Write the message back to a file + FileOutputStream outputStream = new FileOutputStream(path + "/java_single_message.out"); + CodedOutputStream codedOutput = CodedOutputStream.newInstance(outputStream); + + message.writeTo(codedOutput); + + codedOutput.flush(); + outputStream.close(); + } + + public void testMultipleMessages() throws IOException { + // Read in the Python-generated multi-message file + FileInputStream inputStream = new FileInputStream(path + "/py_multiple_messages.out"); + + Oneof.Test oneof = Oneof.Test.parseDelimitedFrom(inputStream); + NestedOuterClass.Test nested = NestedOuterClass.Test.parseDelimitedFrom(inputStream); + + inputStream.close(); + + // Write the messages back to a file + FileOutputStream outputStream = new FileOutputStream(path + "/java_multiple_messages.out"); + + oneof.writeDelimitedTo(outputStream); + nested.writeDelimitedTo(outputStream); + + outputStream.flush(); + outputStream.close(); + } + + public void testInfiniteMessages() throws IOException { + // Read in as many messages as are present in the Python-generated file and write them back + FileInputStream inputStream = new FileInputStream(path + "/py_infinite_messages.out"); + FileOutputStream outputStream = new FileOutputStream(path + "/java_infinite_messages.out"); + + Oneof.Test current = Oneof.Test.parseDelimitedFrom(inputStream); + while (current != null) { + current.writeDelimitedTo(outputStream); + current = Oneof.Test.parseDelimitedFrom(inputStream); + } + + inputStream.close(); + outputStream.flush(); + outputStream.close(); + } +} diff --git a/tests/streams/java/src/main/proto/betterproto/nested.proto b/tests/streams/java/src/main/proto/betterproto/nested.proto new file mode 100644 index 00000000..405a05a4 --- /dev/null +++ b/tests/streams/java/src/main/proto/betterproto/nested.proto @@ -0,0 +1,27 @@ +syntax = "proto3"; + +package nested; +option java_package = "betterproto.nested"; + +// A test message with a nested message inside of it. +message Test { + // This is the nested type. + message Nested { + // Stores a simple counter. + int32 count = 1; + } + // This is the nested enum. + enum Msg { + NONE = 0; + THIS = 1; + } + + Nested nested = 1; + Sibling sibling = 2; + Sibling sibling2 = 3; + Msg msg = 4; +} + +message Sibling { + int32 foo = 1; +} \ No newline at end of file diff --git a/tests/streams/java/src/main/proto/betterproto/oneof.proto b/tests/streams/java/src/main/proto/betterproto/oneof.proto new file mode 100644 index 00000000..ad21028c --- /dev/null +++ b/tests/streams/java/src/main/proto/betterproto/oneof.proto @@ -0,0 +1,19 @@ +syntax = "proto3"; + +package oneof; +option java_package = "betterproto.oneof"; + +message Test { + oneof foo { + int32 pitied = 1; + string pitier = 2; + } + + int32 just_a_regular_field = 3; + + oneof bar { + int32 drinks = 11; + string bar_name = 12; + } +} + diff --git a/tests/test_streams.py b/tests/test_streams.py index a1c2bbd9..3bb9aceb 100644 --- a/tests/test_streams.py +++ b/tests/test_streams.py @@ -1,6 +1,8 @@ from dataclasses import dataclass from io import BytesIO from pathlib import Path +from shutil import which +from subprocess import run from typing import Optional import pytest @@ -40,6 +42,8 @@ streams_path = Path("tests/streams/") +java = which("java") + def test_load_varint_too_long(): with BytesIO( @@ -127,6 +131,18 @@ def test_message_dump_file_multiple(tmp_path): assert test_stream.read() == exp_stream.read() +def test_message_dump_delimited(tmp_path): + with open(tmp_path / "message_dump_delimited.out", "wb") as stream: + oneof_example.dump(stream, betterproto.SIZE_DELIMITED) + oneof_example.dump(stream, betterproto.SIZE_DELIMITED) + nested_example.dump(stream, betterproto.SIZE_DELIMITED) + + with open(tmp_path / "message_dump_delimited.out", "rb") as test_stream, open( + streams_path / "delimited_messages.in", "rb" + ) as exp_stream: + assert test_stream.read() == exp_stream.read() + + def test_message_len(): assert len_oneof == len(bytes(oneof_example)) assert len(nested_example) == len(bytes(nested_example)) @@ -155,7 +171,15 @@ def test_message_load_too_small(): oneof.Test().load(stream, len_oneof - 1) -def test_message_too_large(): +def test_message_load_delimited(): + with open(streams_path / "delimited_messages.in", "rb") as stream: + assert oneof.Test().load(stream, betterproto.SIZE_DELIMITED) == oneof_example + assert oneof.Test().load(stream, betterproto.SIZE_DELIMITED) == oneof_example + assert nested.Test().load(stream, betterproto.SIZE_DELIMITED) == nested_example + assert stream.read(1) == b"" + + +def test_message_load_too_large(): with open( streams_path / "message_dump_file_single.expected", "rb" ) as stream, pytest.raises(ValueError): @@ -266,3 +290,145 @@ def test_dump_varint_positive(tmp_path): streams_path / "dump_varint_positive.expected", "rb" ) as exp_stream: assert test_stream.read() == exp_stream.read() + + +# Java compatibility tests + + +@pytest.fixture(scope="module") +def compile_jar(): + # Skip if not all required tools are present + if java is None: + pytest.skip("`java` command is absent and is required") + mvn = which("mvn") + if mvn is None: + pytest.skip("Maven is absent and is required") + + # Compile the JAR + proc_maven = run([mvn, "clean", "install", "-f", "tests/streams/java/pom.xml"]) + if proc_maven.returncode != 0: + pytest.skip( + "Maven compatibility-test.jar build failed (maybe Java version <11?)" + ) + + +jar = "tests/streams/java/target/compatibility-test.jar" + + +def run_jar(command: str, tmp_path): + return run([java, "-jar", jar, command, tmp_path], check=True) + + +def run_java_single_varint(value: int, tmp_path) -> int: + # Write single varint to file + with open(tmp_path / "py_single_varint.out", "wb") as stream: + betterproto.dump_varint(value, stream) + + # Have Java read this varint and write it back + run_jar("single_varint", tmp_path) + + # Read single varint from Java output file + with open(tmp_path / "java_single_varint.out", "rb") as stream: + returned = betterproto.load_varint(stream) + with pytest.raises(EOFError): + betterproto.load_varint(stream) + + return returned + + +def test_single_varint(compile_jar, tmp_path): + single_byte = (1, b"\x01") + multi_byte = (123456789, b"\x95\x9A\xEF\x3A") + + # Write a single-byte varint to a file and have Java read it back + returned = run_java_single_varint(single_byte[0], tmp_path) + assert returned == single_byte + + # Same for a multi-byte varint + returned = run_java_single_varint(multi_byte[0], tmp_path) + assert returned == multi_byte + + +def test_multiple_varints(compile_jar, tmp_path): + single_byte = (1, b"\x01") + multi_byte = (123456789, b"\x95\x9A\xEF\x3A") + over32 = (3000000000, b"\x80\xBC\xC1\x96\x0B") + + # Write two varints to the same file + with open(tmp_path / "py_multiple_varints.out", "wb") as stream: + betterproto.dump_varint(single_byte[0], stream) + betterproto.dump_varint(multi_byte[0], stream) + betterproto.dump_varint(over32[0], stream) + + # Have Java read these varints and write them back + run_jar("multiple_varints", tmp_path) + + # Read varints from Java output file + with open(tmp_path / "java_multiple_varints.out", "rb") as stream: + returned_single = betterproto.load_varint(stream) + returned_multi = betterproto.load_varint(stream) + returned_over32 = betterproto.load_varint(stream) + with pytest.raises(EOFError): + betterproto.load_varint(stream) + + assert returned_single == single_byte + assert returned_multi == multi_byte + assert returned_over32 == over32 + + +def test_single_message(compile_jar, tmp_path): + # Write message to file + with open(tmp_path / "py_single_message.out", "wb") as stream: + oneof_example.dump(stream) + + # Have Java read and return the message + run_jar("single_message", tmp_path) + + # Read and check the returned message + with open(tmp_path / "java_single_message.out", "rb") as stream: + returned = oneof.Test().load(stream, len(bytes(oneof_example))) + assert stream.read() == b"" + + assert returned == oneof_example + + +def test_multiple_messages(compile_jar, tmp_path): + # Write delimited messages to file + with open(tmp_path / "py_multiple_messages.out", "wb") as stream: + oneof_example.dump(stream, betterproto.SIZE_DELIMITED) + nested_example.dump(stream, betterproto.SIZE_DELIMITED) + + # Have Java read and return the messages + run_jar("multiple_messages", tmp_path) + + # Read and check the returned messages + with open(tmp_path / "java_multiple_messages.out", "rb") as stream: + returned_oneof = oneof.Test().load(stream, betterproto.SIZE_DELIMITED) + returned_nested = nested.Test().load(stream, betterproto.SIZE_DELIMITED) + assert stream.read() == b"" + + assert returned_oneof == oneof_example + assert returned_nested == nested_example + + +def test_infinite_messages(compile_jar, tmp_path): + num_messages = 5 + + # Write delimited messages to file + with open(tmp_path / "py_infinite_messages.out", "wb") as stream: + for x in range(num_messages): + oneof_example.dump(stream, betterproto.SIZE_DELIMITED) + + # Have Java read and return the messages + run_jar("infinite_messages", tmp_path) + + # Read and check the returned messages + messages = [] + with open(tmp_path / "java_infinite_messages.out", "rb") as stream: + while True: + try: + messages.append(oneof.Test().load(stream, betterproto.SIZE_DELIMITED)) + except EOFError: + break + + assert len(messages) == num_messages