Skip to content

Commit

Permalink
Merge branch 'master' into types
Browse files Browse the repository at this point in the history
  • Loading branch information
Gobot1234 authored Mar 24, 2024
2 parents 13b2c30 + 126b256 commit c84071e
Show file tree
Hide file tree
Showing 9 changed files with 139 additions and 49 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,19 @@ jobs:
fail-fast: false
matrix:
os: [Ubuntu, MacOS, Windows]
python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"]
python-version: ['3.7', '3.8', '3.9', '3.10', '3.11', '3.12']
steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}

- name: Get full Python version
id: full-python-version
shell: bash
run: echo ::set-output name=version::$(python -c "import sys; print('-'.join(str(v) for v in sys.version_info))")
run: echo "version=$(python -c "import sys; print('-'.join(str(v) for v in sys.version_info))")" >> "$GITHUB_OUTPUT"

- name: Install poetry
shell: bash
Expand Down
23 changes: 16 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,22 @@ message Test {
}
```

You can use `betterproto.which_one_of(message, group_name)` to determine which of the fields was set. It returns a tuple of the field name and value, or a blank string and `None` if unset.
On Python 3.10 and later, you can use a `match` statement to access the provided one-of field, which supports type-checking:

```py
test = Test()
match test:
case Test(on=value):
print(value) # value: bool
case Test(count=value):
print(value) # value: int
case Test(name=value):
print(value) # value: str
case _:
print("No value provided")
```

You can also use `betterproto.which_one_of(message, group_name)` to determine which of the fields was set. It returns a tuple of the field name and value, or a blank string and `None` if unset.

```py
>>> test = Test()
Expand All @@ -292,17 +307,11 @@ You can use `betterproto.which_one_of(message, group_name)` to determine which o
>>> test.count = 57
>>> betterproto.which_one_of(test, "foo")
["count", 57]
>>> test.on
False

# Default (zero) values also work.
>>> test.name = ""
>>> betterproto.which_one_of(test, "foo")
["name", ""]
>>> test.count
0
>>> test.on
False
```

Again this is a little different than the official Google code generator:
Expand Down
53 changes: 21 additions & 32 deletions benchmarks/benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,32 +6,32 @@

@dataclass
class TestMessage(betterproto.Message):
foo: int = betterproto.uint32_field(0)
bar: str = betterproto.string_field(1)
baz: float = betterproto.float_field(2)
foo: int = betterproto.uint32_field(1)
bar: str = betterproto.string_field(2)
baz: float = betterproto.float_field(3)


@dataclass
class TestNestedChildMessage(betterproto.Message):
str_key: str = betterproto.string_field(0)
bytes_key: bytes = betterproto.bytes_field(1)
bool_key: bool = betterproto.bool_field(2)
float_key: float = betterproto.float_field(3)
int_key: int = betterproto.uint64_field(4)
str_key: str = betterproto.string_field(1)
bytes_key: bytes = betterproto.bytes_field(2)
bool_key: bool = betterproto.bool_field(3)
float_key: float = betterproto.float_field(4)
int_key: int = betterproto.uint64_field(5)


@dataclass
class TestNestedMessage(betterproto.Message):
foo: TestNestedChildMessage = betterproto.message_field(0)
bar: TestNestedChildMessage = betterproto.message_field(1)
baz: TestNestedChildMessage = betterproto.message_field(2)
foo: TestNestedChildMessage = betterproto.message_field(1)
bar: TestNestedChildMessage = betterproto.message_field(2)
baz: TestNestedChildMessage = betterproto.message_field(3)


@dataclass
class TestRepeatedMessage(betterproto.Message):
foo_repeat: List[str] = betterproto.string_field(0)
bar_repeat: List[int] = betterproto.int64_field(1)
baz_repeat: List[bool] = betterproto.bool_field(2)
foo_repeat: List[str] = betterproto.string_field(1)
bar_repeat: List[int] = betterproto.int64_field(2)
baz_repeat: List[bool] = betterproto.bool_field(3)


class BenchMessage:
Expand All @@ -44,25 +44,14 @@ def setup(self):
self.instance_filled_bytes = bytes(self.instance_filled)
self.instance_filled_nested = TestNestedMessage(
TestNestedChildMessage("foo", bytearray(b"test1"), True, 0.1234, 500),
TestNestedChildMessage("bar", bytearray(b"test2"), True, 3.1415, -302),
TestNestedChildMessage("bar", bytearray(b"test2"), True, 3.1415, 302),
TestNestedChildMessage("baz", bytearray(b"test3"), False, 1e5, 300),
)
self.instance_filled_nested_bytes = bytes(self.instance_filled_nested)
self.instance_filled_repeated = TestRepeatedMessage(
[
"test1",
"test2",
"test3",
"test4",
"test5",
"test6",
"test7",
"test8",
"test9",
"test10",
],
[2, -100, 0, 500000, 600, -425678, 1000000000, -300, 1, -694214214466],
[True, False, False, False, True, True, False, True, False, False],
[f"test{i}" for i in range(1_000)],
[(i - 500) ** 3 for i in range(1_000)],
[i % 2 == 0 for i in range(1_000)],
)
self.instance_filled_repeated_bytes = bytes(self.instance_filled_repeated)

Expand All @@ -71,9 +60,9 @@ def time_overhead(self):

@dataclass
class Message(betterproto.Message):
foo: int = betterproto.uint32_field(0)
bar: str = betterproto.string_field(1)
baz: float = betterproto.float_field(2)
foo: int = betterproto.uint32_field(1)
bar: str = betterproto.string_field(2)
baz: float = betterproto.float_field(3)

def time_instantiation(self):
"""Time instantiation"""
Expand Down
23 changes: 22 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ jinja2 = { version = ">=3.0.3", optional = true }
python-dateutil = "^2.8"
isort = {version = "^5.11.5", optional = true}
typing-extensions = "^4.7.1"
betterproto-rust-codec = { version = "0.1.0", optional = true }

[tool.poetry.group.dev]
optional = true
Expand Down Expand Up @@ -52,6 +53,7 @@ protoc-gen-python_betterproto = "betterproto.plugin:main"

[tool.poetry.extras]
compiler = ["black", "isort", "jinja2"]
rust-codec = ["betterproto-rust-codec"]


# Dev workflow tasks
Expand Down
19 changes: 18 additions & 1 deletion src/betterproto/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -732,7 +732,7 @@ def __post_init__(self) -> None:
group_current.setdefault(meta.group)

value = self.__raw_get(field_name)
if value != PLACEHOLDER and not (meta.optional and value is None):
if value is not PLACEHOLDER and not (meta.optional and value is None):
# Found a non-sentinel value
all_sentinel = False

Expand Down Expand Up @@ -1862,6 +1862,23 @@ def _validate_field_groups(cls, values):
{}
) # HACK to avoid typing.get_type_hints breaking because we have to manually pass globals

# monkey patch (de-)serialization functions of class `Message`
# with functions from `betterproto-rust-codec` if available
try:
import betterproto_rust_codec

def __parse_patch(self: T, data: bytes) -> T:
betterproto_rust_codec.deserialize(self, data)
return self

def __bytes_patch(self) -> bytes:
return betterproto_rust_codec.serialize(self)

Message.parse = __parse_patch
Message.__bytes__ = __bytes_patch
except ModuleNotFoundError:
pass


def serialized_on_wire(message: Message) -> bool:
"""
Expand Down
31 changes: 31 additions & 0 deletions src/betterproto/lib/google/protobuf/__init__.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 1 addition & 4 deletions src/betterproto/plugin/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,10 +143,7 @@ def get_comment(
pad = " " * indent
for sci_loc in proto_file.source_code_info.location:
if list(sci_loc.path) == path and sci_loc.leading_comments:
lines = textwrap.wrap(
sci_loc.leading_comments.strip().replace("\n", ""), width=79 - indent
)

lines = sci_loc.leading_comments.strip().split("\n")
# This is a field, message, enum, service, or method
if len(lines) == 1 and len(lines[0]) < 79 - indent - 6:
lines[0] = lines[0].strip('"')
Expand Down
24 changes: 24 additions & 0 deletions tests/test_struct.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import json

from betterproto.lib.google.protobuf import Struct


def test_struct_roundtrip():
data = {
"foo": "bar",
"baz": None,
"quux": 123,
"zap": [1, {"two": 3}, "four"],
}
data_json = json.dumps(data)

struct_from_dict = Struct().from_dict(data)
assert struct_from_dict.fields == data
assert struct_from_dict.to_dict() == data
assert struct_from_dict.to_json() == data_json

struct_from_json = Struct().from_json(data_json)
assert struct_from_json.fields == data
assert struct_from_json.to_dict() == data
assert struct_from_json == struct_from_dict
assert struct_from_json.to_json() == data_json

0 comments on commit c84071e

Please sign in to comment.