From ca47cc500b19c16be681c116cd5a443a2e701531 Mon Sep 17 00:00:00 2001 From: Alexander Khabarov Date: Tue, 11 Jul 2023 11:57:35 +0100 Subject: [PATCH] Raise `AttributeError` on attempts to access unset `oneof` fields This commit modifies `Message.__getattribute__` to raise `AttributeError` whenever an attempt is made to access an unset `oneof` field. This provides several benefits over the current approach: * There is no longer any risk of `betterproto` users accidentally relying on values of unset fields. * Pattern matching with `match/case` on messages containing `oneof` groups is now supported. The following is now possible: ``` @dataclasses.dataclass(eq=False, repr=False) class Test(betterproto.Message): x: int = betterproto.int32_field(1, group="g") y: str = betterproto.string_field(2, group="g") match Test(y="text"): case Test(x=v): print("x", v) case Test(y=v): print("y", v) ``` Before this commit the code above would output `x 0` instead of `y text`, but now the output is `y text` as expected. The reason this works is because an `AttributeError` in a `case` pattern does not propagate and instead simply skips the `case`. * We now have a type-checkable way to deconstruct `oneof`. When running `mypy` for the snippet above `v` has type `int` in the first `case` and type `str` in the second `case`. For versions of Python that do not support `match/case` (before 3.10) it is now possbile to use `try/except/else` blocks to achieve the same result: ``` t = Test(y="text") try: v0: int = t.x except AttributeError: v1: str = t.y # `oneof` contains `y` else: pass # `oneof` contains `x` ``` This is a breaking change. --- src/betterproto/__init__.py | 39 ++++++++++++++----- .../test_google_impl_behavior_equivalence.py | 6 ++- tests/inputs/oneof_enum/test_oneof_enum.py | 13 +++---- tests/test_features.py | 9 +++-- 4 files changed, 45 insertions(+), 22 deletions(-) diff --git a/src/betterproto/__init__.py b/src/betterproto/__init__.py index f22a8f7cb..8bc754781 100644 --- a/src/betterproto/__init__.py +++ b/src/betterproto/__init__.py @@ -693,8 +693,21 @@ def __repr__(self) -> str: def __getattribute__(self, name: str) -> Any: """ Lazily initialize default values to avoid infinite recursion for recursive - message types + message types. + Raise :class:`AttributeError` on attempts to access unset ``oneof`` fields. """ + try: + group_current = super().__getattribute__("_group_current") + except AttributeError: + pass + else: + if name not in {"__class__", "_betterproto"}: + group = self._betterproto.oneof_group_by_field.get(name) + if group is not None and group_current[group] != name: + raise AttributeError( + f"'{self.__class__.__name__}.{group}' is set to '{group_current[group]}', not '{name}'" + ) + value = super().__getattribute__(name) if value is not PLACEHOLDER: return value @@ -761,7 +774,10 @@ def __bytes__(self) -> bytes: """ output = bytearray() for field_name, meta in self._betterproto.meta_by_field_name.items(): - value = getattr(self, field_name) + try: + value = getattr(self, field_name) + except AttributeError: + continue if value is None: # Optional items should be skipped. This is used for the Google @@ -775,9 +791,7 @@ def __bytes__(self) -> bytes: # Note that proto3 field presence/optional fields are put in a # synthetic single-item oneof by protoc, which helps us ensure we # send the value even if the value is the default zero value. - selected_in_group = ( - meta.group and self._group_current[meta.group] == field_name - ) + selected_in_group = bool(meta.group) # Empty messages can still be sent on the wire if they were # set (or received empty). @@ -1016,7 +1030,12 @@ def parse(self: T, data: bytes) -> T: parsed.wire_type, meta, field_name, parsed.value ) - current = getattr(self, field_name) + try: + current = getattr(self, field_name) + except AttributeError: + current = self._get_field_default(field_name) + setattr(self, field_name, current) + if meta.proto_type == TYPE_MAP: # Value represents a single key/value pair entry in the map. current[value.key] = value.value @@ -1077,7 +1096,10 @@ def to_dict( defaults = self._betterproto.default_gen for field_name, meta in self._betterproto.meta_by_field_name.items(): field_is_repeated = defaults[field_name] is list - value = getattr(self, field_name) + try: + value = getattr(self, field_name) + except AttributeError: + value = self._get_field_default(field_name) cased_name = casing(field_name).rstrip("_") # type: ignore if meta.proto_type == TYPE_MESSAGE: if isinstance(value, datetime): @@ -1209,7 +1231,7 @@ def from_dict(self: T, value: Mapping[str, Any]) -> T: if value[key] is not None: if meta.proto_type == TYPE_MESSAGE: - v = getattr(self, field_name) + v = self._get_field_default(field_name) cls = self._betterproto.cls_by_field[field_name] if isinstance(v, list): if cls == datetime: @@ -1486,7 +1508,6 @@ def _validate_field_groups(cls, values): field_name_to_meta = cls._betterproto_meta.meta_by_field_name # type: ignore for group, field_set in group_to_one_ofs.items(): - if len(field_set) == 1: (field,) = field_set field_name = field.name diff --git a/tests/inputs/google_impl_behavior_equivalence/test_google_impl_behavior_equivalence.py b/tests/inputs/google_impl_behavior_equivalence/test_google_impl_behavior_equivalence.py index 476d20e3b..dd2a9f53e 100644 --- a/tests/inputs/google_impl_behavior_equivalence/test_google_impl_behavior_equivalence.py +++ b/tests/inputs/google_impl_behavior_equivalence/test_google_impl_behavior_equivalence.py @@ -50,8 +50,10 @@ def test_bytes_are_the_same_for_oneof(): # None of these fields were explicitly set BUT they should not actually be null # themselves - assert isinstance(message.foo, Foo) - assert isinstance(message2.foo, Foo) + assert not hasattr(message, "foo") + assert object.__getattribute__(message, "foo") == betterproto.PLACEHOLDER + assert not hasattr(message2, "foo") + assert object.__getattribute__(message2, "foo") == betterproto.PLACEHOLDER assert isinstance(message_reference.foo, ReferenceFoo) assert isinstance(message_reference2.foo, ReferenceFoo) diff --git a/tests/inputs/oneof_enum/test_oneof_enum.py b/tests/inputs/oneof_enum/test_oneof_enum.py index 7e287d4a4..e54fa3859 100644 --- a/tests/inputs/oneof_enum/test_oneof_enum.py +++ b/tests/inputs/oneof_enum/test_oneof_enum.py @@ -18,9 +18,8 @@ def test_which_one_of_returns_enum_with_default_value(): get_test_case_json_data("oneof_enum", "oneof_enum-enum-0.json")[0].json ) - assert message.move == Move( - x=0, y=0 - ) # Proto3 will default this as there is no null + assert not hasattr(message, "move") + assert object.__getattribute__(message, "move") == betterproto.PLACEHOLDER assert message.signal == Signal.PASS assert betterproto.which_one_of(message, "action") == ("signal", Signal.PASS) @@ -33,9 +32,8 @@ def test_which_one_of_returns_enum_with_non_default_value(): message.from_json( get_test_case_json_data("oneof_enum", "oneof_enum-enum-1.json")[0].json ) - assert message.move == Move( - x=0, y=0 - ) # Proto3 will default this as there is no null + assert not hasattr(message, "move") + assert object.__getattribute__(message, "move") == betterproto.PLACEHOLDER assert message.signal == Signal.RESIGN assert betterproto.which_one_of(message, "action") == ("signal", Signal.RESIGN) @@ -44,5 +42,6 @@ def test_which_one_of_returns_second_field_when_set(): message = Test() message.from_json(get_test_case_json_data("oneof_enum")[0].json) assert message.move == Move(x=2, y=3) - assert message.signal == Signal.PASS + assert not hasattr(message, "signal") + assert object.__getattribute__(message, "signal") == betterproto.PLACEHOLDER assert betterproto.which_one_of(message, "action") == ("move", Move(x=2, y=3)) diff --git a/tests/test_features.py b/tests/test_features.py index 940cd51c8..f314b8795 100644 --- a/tests/test_features.py +++ b/tests/test_features.py @@ -151,17 +151,18 @@ class Foo(betterproto.Message): foo.baz = "test" # Other oneof fields should now be unset - assert foo.bar == 0 + assert not hasattr(foo, "bar") + assert object.__getattribute__(foo, "bar") == betterproto.PLACEHOLDER assert betterproto.which_one_of(foo, "group1")[0] == "baz" - foo.sub.val = 1 + foo.sub = Sub(val=1) assert betterproto.serialized_on_wire(foo.sub) foo.abc = "test" # Group 1 shouldn't be touched, group 2 should have reset - assert foo.sub.val == 0 - assert betterproto.serialized_on_wire(foo.sub) is False + assert not hasattr(foo, "sub") + assert object.__getattribute__(foo, "sub") == betterproto.PLACEHOLDER assert betterproto.which_one_of(foo, "group2")[0] == "abc" # Zero value should always serialize for one-of