Skip to content

Commit

Permalink
AVRO-2921 Add Type Hints to avro.schema
Browse files Browse the repository at this point in the history
  • Loading branch information
kojiromike committed Jul 20, 2023
1 parent ea1ed80 commit 83de021
Show file tree
Hide file tree
Showing 3 changed files with 229 additions and 237 deletions.
7 changes: 5 additions & 2 deletions lang/py/avro/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,9 +161,12 @@ def convert(value: str, field: avro.schema.Field) -> Union[int, float, str, byte


def convert_union(value: str, field: avro.schema.Field) -> Union[int, float, str, bytes, bool, None]:
for name in (s.name for s in field.type.schemas):
if not isinstance(field.type, avro.schema.UnionSchema):
raise avro.errors.UsageError(f"Expected field.type to be a Union, but it was {field.type}")
# Casts to be fixed in AVRO-3798
for name in (cast(avro.schema.NamedSchema, s).name for s in field.type.schemas):
try:
return convert(value, name)
return convert(value, cast(avro.schema.Field, name))
except ValueError:
continue
raise avro.errors.UsageError("Exhausted Union Schema without finding a match")
Expand Down
12 changes: 6 additions & 6 deletions lang/py/avro/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

import hashlib
import json
from typing import Mapping, Optional, Sequence, Union, cast
from typing import Mapping, MutableMapping, Optional, Sequence, Union, cast

import avro.errors
import avro.name
Expand All @@ -37,7 +37,7 @@


class MessageObject(TypedDict, total=False):
request: Sequence[Mapping[str, object]]
request: Sequence[MutableMapping[str, object]]
response: Union[str, object]
errors: Optional[Sequence[str]]

Expand Down Expand Up @@ -176,7 +176,7 @@ class Message:
def __init__(
self,
name: str,
request: Sequence[Mapping[str, object]],
request: Sequence[MutableMapping[str, object]],
response: Union[str, object],
errors: Optional[Sequence[str]] = None,
names: Optional[avro.name.Names] = None,
Expand Down Expand Up @@ -215,18 +215,18 @@ def to_json(self, names: Optional[avro.name.Names] = None) -> "MessageObject":
to_dump = MessageObject()
except NameError:
to_dump = {}
to_dump["request"] = self.request.to_json(names)
to_dump["request"] = cast(Sequence[MutableMapping[str, object]], self.request.to_json(names))
to_dump["response"] = self.response.to_json(names)
if self.errors:
to_dump["errors"] = self.errors.to_json(names)
to_dump["errors"] = cast(Optional[Sequence[str]], self.errors.to_json(names))

return to_dump

def __eq__(self, that: object) -> bool:
return all(hasattr(that, prop) and getattr(self, prop) == getattr(that, prop) for prop in self.__class__.__slots__)


def _parse_request(request: Sequence[Mapping[str, object]], names: avro.name.Names, validate_names: bool = True) -> avro.schema.RecordSchema:
def _parse_request(request: Sequence[MutableMapping[str, object]], names: avro.name.Names, validate_names: bool = True) -> avro.schema.RecordSchema:
if not isinstance(request, Sequence):
raise avro.errors.ProtocolParseException(f"Request property not a list: {request}")
return avro.schema.RecordSchema(None, None, request, names, "request", validate_names=validate_names)
Expand Down
Loading

0 comments on commit 83de021

Please sign in to comment.