Skip to content

Commit

Permalink
Raise AttributeError on attempts to access unset oneof fields
Browse files Browse the repository at this point in the history
This commit modifies `Message.__getattribute__` to raise
`AttributeError` whenever an attempt is made to access an unset `oneof`
field. This provides several benefits over the current approach:

* There is no longer any risk of `betterproto` users accidentally
  relying on values of unset fields.

* Pattern matching with `match/case` on messages containing `oneof`
  groups is now supported. The following is now possible:
  ```
  @dataclasses.dataclass(eq=False, repr=False)
  class Test(betterproto.Message):
    x: int = betterproto.int32_field(1, group="g")
    y: str = betterproto.string_field(2, group="g")

  match Test(y="text"):
    case Test(x=v):
      print("x", v)
    case Test(y=v):
      print("y", v)
  ```
  Before this commit the code above would output `x 0` instead of
  `y text`, but now the output is `y text` as expected. The reason
  this works is because an `AttributeError` in a `case` pattern does
  not propagate and instead simply skips the `case`.

* We now have a type-checkable way to deconstruct `oneof`. When running
  `mypy` for the snippet above `v` has type `int` in the first `case`
  and type `str` in the second `case`. For versions of Python that do
  not support `match/case` (before 3.10) it is now possbile to use
  `try/except/else` blocks to achieve the same result:
  ```
  t = Test(y="text")
  try:
    v0: int = t.x
  except AttributeError:
    v1: str = t.y  # `oneof` contains `y`
  else:
    pass  # `oneof` contains `x`
  ```

This is a breaking change.
  • Loading branch information
a-khabarov committed Jul 19, 2023
1 parent 098989e commit 921a2b7
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 22 deletions.
46 changes: 37 additions & 9 deletions src/betterproto/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,8 +693,28 @@ def __repr__(self) -> str:
def __getattribute__(self, name: str) -> Any:
"""
Lazily initialize default values to avoid infinite recursion for recursive
message types
message types.
Raise :class:`AttributeError` on attempts to access unset ``oneof`` fields.
"""
try:
group_current = super().__getattribute__("_group_current")
except AttributeError:
pass
else:
if name not in {"__class__", "_betterproto"}:
group = self._betterproto.oneof_group_by_field.get(name)
if group is not None and group_current[group] != name:
if sys.version_info < (3, 10):
raise AttributeError(
f"{group!r} is set to {group_current[group]}, not {name!r}"
)
else:
raise AttributeError(
f"{group!r} is set to {group_current[group]}, not {name!r}",
name=name,
obj=self,
)

value = super().__getattribute__(name)
if value is not PLACEHOLDER:
return value
Expand Down Expand Up @@ -761,7 +781,10 @@ def __bytes__(self) -> bytes:
"""
output = bytearray()
for field_name, meta in self._betterproto.meta_by_field_name.items():
value = getattr(self, field_name)
try:
value = getattr(self, field_name)
except AttributeError:
continue

if value is None:
# Optional items should be skipped. This is used for the Google
Expand All @@ -775,9 +798,7 @@ def __bytes__(self) -> bytes:
# Note that proto3 field presence/optional fields are put in a
# synthetic single-item oneof by protoc, which helps us ensure we
# send the value even if the value is the default zero value.
selected_in_group = (
meta.group and self._group_current[meta.group] == field_name
)
selected_in_group = bool(meta.group)

# Empty messages can still be sent on the wire if they were
# set (or received empty).
Expand Down Expand Up @@ -1016,7 +1037,12 @@ def parse(self: T, data: bytes) -> T:
parsed.wire_type, meta, field_name, parsed.value
)

current = getattr(self, field_name)
try:
current = getattr(self, field_name)
except AttributeError:
current = self._get_field_default(field_name)
setattr(self, field_name, current)

if meta.proto_type == TYPE_MAP:
# Value represents a single key/value pair entry in the map.
current[value.key] = value.value
Expand Down Expand Up @@ -1077,7 +1103,10 @@ def to_dict(
defaults = self._betterproto.default_gen
for field_name, meta in self._betterproto.meta_by_field_name.items():
field_is_repeated = defaults[field_name] is list
value = getattr(self, field_name)
try:
value = getattr(self, field_name)
except AttributeError:
value = self._get_field_default(field_name)
cased_name = casing(field_name).rstrip("_") # type: ignore
if meta.proto_type == TYPE_MESSAGE:
if isinstance(value, datetime):
Expand Down Expand Up @@ -1209,7 +1238,7 @@ def from_dict(self: T, value: Mapping[str, Any]) -> T:

if value[key] is not None:
if meta.proto_type == TYPE_MESSAGE:
v = getattr(self, field_name)
v = self._get_field_default(field_name)
cls = self._betterproto.cls_by_field[field_name]
if isinstance(v, list):
if cls == datetime:
Expand Down Expand Up @@ -1486,7 +1515,6 @@ def _validate_field_groups(cls, values):
field_name_to_meta = cls._betterproto_meta.meta_by_field_name # type: ignore

for group, field_set in group_to_one_ofs.items():

if len(field_set) == 1:
(field,) = field_set
field_name = field.name
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,10 @@ def test_bytes_are_the_same_for_oneof():

# None of these fields were explicitly set BUT they should not actually be null
# themselves
assert isinstance(message.foo, Foo)
assert isinstance(message2.foo, Foo)
assert not hasattr(message, "foo")
assert object.__getattribute__(message, "foo") == betterproto.PLACEHOLDER
assert not hasattr(message2, "foo")
assert object.__getattribute__(message2, "foo") == betterproto.PLACEHOLDER

assert isinstance(message_reference.foo, ReferenceFoo)
assert isinstance(message_reference2.foo, ReferenceFoo)
Expand Down
13 changes: 6 additions & 7 deletions tests/inputs/oneof_enum/test_oneof_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,8 @@ def test_which_one_of_returns_enum_with_default_value():
get_test_case_json_data("oneof_enum", "oneof_enum-enum-0.json")[0].json
)

assert message.move == Move(
x=0, y=0
) # Proto3 will default this as there is no null
assert not hasattr(message, "move")
assert object.__getattribute__(message, "move") == betterproto.PLACEHOLDER
assert message.signal == Signal.PASS
assert betterproto.which_one_of(message, "action") == ("signal", Signal.PASS)

Expand All @@ -33,9 +32,8 @@ def test_which_one_of_returns_enum_with_non_default_value():
message.from_json(
get_test_case_json_data("oneof_enum", "oneof_enum-enum-1.json")[0].json
)
assert message.move == Move(
x=0, y=0
) # Proto3 will default this as there is no null
assert not hasattr(message, "move")
assert object.__getattribute__(message, "move") == betterproto.PLACEHOLDER
assert message.signal == Signal.RESIGN
assert betterproto.which_one_of(message, "action") == ("signal", Signal.RESIGN)

Expand All @@ -44,5 +42,6 @@ def test_which_one_of_returns_second_field_when_set():
message = Test()
message.from_json(get_test_case_json_data("oneof_enum")[0].json)
assert message.move == Move(x=2, y=3)
assert message.signal == Signal.PASS
assert not hasattr(message, "signal")
assert object.__getattribute__(message, "signal") == betterproto.PLACEHOLDER
assert betterproto.which_one_of(message, "action") == ("move", Move(x=2, y=3))
9 changes: 5 additions & 4 deletions tests/test_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,17 +151,18 @@ class Foo(betterproto.Message):
foo.baz = "test"

# Other oneof fields should now be unset
assert foo.bar == 0
assert not hasattr(foo, "bar")
assert object.__getattribute__(foo, "bar") == betterproto.PLACEHOLDER
assert betterproto.which_one_of(foo, "group1")[0] == "baz"

foo.sub.val = 1
foo.sub = Sub(val=1)
assert betterproto.serialized_on_wire(foo.sub)

foo.abc = "test"

# Group 1 shouldn't be touched, group 2 should have reset
assert foo.sub.val == 0
assert betterproto.serialized_on_wire(foo.sub) is False
assert not hasattr(foo, "sub")
assert object.__getattribute__(foo, "sub") == betterproto.PLACEHOLDER
assert betterproto.which_one_of(foo, "group2")[0] == "abc"

# Zero value should always serialize for one-of
Expand Down

0 comments on commit 921a2b7

Please sign in to comment.