From 83de0214b7354c83bdeb6bb97bdac3568eebc92f Mon Sep 17 00:00:00 2001 From: "Michael A. Smith" Date: Sun, 16 Jul 2023 09:14:39 -0400 Subject: [PATCH] AVRO-2921 Add Type Hints to avro.schema --- lang/py/avro/__main__.py | 7 +- lang/py/avro/protocol.py | 12 +- lang/py/avro/schema.py | 447 +++++++++++++++++++-------------------- 3 files changed, 229 insertions(+), 237 deletions(-) diff --git a/lang/py/avro/__main__.py b/lang/py/avro/__main__.py index 44fda88252b..342d51e434a 100755 --- a/lang/py/avro/__main__.py +++ b/lang/py/avro/__main__.py @@ -161,9 +161,12 @@ def convert(value: str, field: avro.schema.Field) -> Union[int, float, str, byte def convert_union(value: str, field: avro.schema.Field) -> Union[int, float, str, bytes, bool, None]: - for name in (s.name for s in field.type.schemas): + if not isinstance(field.type, avro.schema.UnionSchema): + raise avro.errors.UsageError(f"Expected field.type to be a Union, but it was {field.type}") + # Casts to be fixed in AVRO-3798 + for name in (cast(avro.schema.NamedSchema, s).name for s in field.type.schemas): try: - return convert(value, name) + return convert(value, cast(avro.schema.Field, name)) except ValueError: continue raise avro.errors.UsageError("Exhausted Union Schema without finding a match") diff --git a/lang/py/avro/protocol.py b/lang/py/avro/protocol.py index 5477fc45f05..d575a86704c 100644 --- a/lang/py/avro/protocol.py +++ b/lang/py/avro/protocol.py @@ -25,7 +25,7 @@ import hashlib import json -from typing import Mapping, Optional, Sequence, Union, cast +from typing import Mapping, MutableMapping, Optional, Sequence, Union, cast import avro.errors import avro.name @@ -37,7 +37,7 @@ class MessageObject(TypedDict, total=False): - request: Sequence[Mapping[str, object]] + request: Sequence[MutableMapping[str, object]] response: Union[str, object] errors: Optional[Sequence[str]] @@ -176,7 +176,7 @@ class Message: def __init__( self, name: str, - request: Sequence[Mapping[str, object]], + request: Sequence[MutableMapping[str, object]], response: Union[str, object], errors: Optional[Sequence[str]] = None, names: Optional[avro.name.Names] = None, @@ -215,10 +215,10 @@ def to_json(self, names: Optional[avro.name.Names] = None) -> "MessageObject": to_dump = MessageObject() except NameError: to_dump = {} - to_dump["request"] = self.request.to_json(names) + to_dump["request"] = cast(Sequence[MutableMapping[str, object]], self.request.to_json(names)) to_dump["response"] = self.response.to_json(names) if self.errors: - to_dump["errors"] = self.errors.to_json(names) + to_dump["errors"] = cast(Optional[Sequence[str]], self.errors.to_json(names)) return to_dump @@ -226,7 +226,7 @@ def __eq__(self, that: object) -> bool: return all(hasattr(that, prop) and getattr(self, prop) == getattr(that, prop) for prop in self.__class__.__slots__) -def _parse_request(request: Sequence[Mapping[str, object]], names: avro.name.Names, validate_names: bool = True) -> avro.schema.RecordSchema: +def _parse_request(request: Sequence[MutableMapping[str, object]], names: avro.name.Names, validate_names: bool = True) -> avro.schema.RecordSchema: if not isinstance(request, Sequence): raise avro.errors.ProtocolParseException(f"Request property not a list: {request}") return avro.schema.RecordSchema(None, None, request, names, "request", validate_names=validate_names) diff --git a/lang/py/avro/schema.py b/lang/py/avro/schema.py index 018f74debe6..faad854e03d 100644 --- a/lang/py/avro/schema.py +++ b/lang/py/avro/schema.py @@ -57,6 +57,7 @@ MutableMapping, Optional, Sequence, + Tuple, Union, cast, ) @@ -65,6 +66,12 @@ import avro.errors from avro.name import Name, Names, validate_basename +# +# Types +# +PropertiesType = MutableMapping[str, object] +SchemaDescriptionType = Union[str, PropertiesType, Sequence["SchemaDescriptionType"]] + # # Constants # @@ -168,10 +175,10 @@ class PropertiesMixin: """A mixin that provides basic properties.""" _reserved_properties: Sequence[str] = () - _props: Optional[MutableMapping[str, object]] = None + _props: Optional[PropertiesType] = None @property - def props(self) -> MutableMapping[str, object]: + def props(self) -> PropertiesType: if self._props is None: self._props = {} return self._props @@ -192,7 +199,7 @@ def check_props(self, other: "PropertiesMixin", props: Sequence[str]) -> bool: return all(getattr(self, prop) == getattr(other, prop) for prop in props) @property - def other_props(self) -> Mapping[str, object]: + def other_props(self) -> PropertiesType: """Dictionary of non-reserved properties""" return get_other_props(self.props, self._reserved_properties) @@ -237,7 +244,7 @@ class CanonicalPropertiesMixin(PropertiesMixin): """A Mixin that provides canonical properties to Schema and Field types.""" @property - def canonical_properties(self) -> Mapping[str, object]: + def canonical_properties(self) -> PropertiesType: props = self.props return collections.OrderedDict((key, props[key]) for key in CANONICAL_FIELD_ORDER if key in props) @@ -247,7 +254,7 @@ class Schema(abc.ABC, CanonicalPropertiesMixin): _reserved_properties = SCHEMA_RESERVED_PROPS - def __init__(self, type_: str, other_props: Optional[Mapping[str, object]] = None, validate_names: bool = True) -> None: + def __init__(self, type_: str, other_props: Optional[PropertiesType] = None, validate_names: bool = True) -> None: if not isinstance(type_, str): raise avro.errors.SchemaParseException("Schema type must be a string.") if type_ not in avro.constants.VALID_TYPES: @@ -269,7 +276,7 @@ def __str__(self) -> str: return json.dumps(self.to_json()) @abc.abstractmethod - def to_json(self, names: Optional[Names] = None) -> object: + def to_json(self, names: Optional[Names] = None) -> SchemaDescriptionType: """ Converts the schema object into its AVRO specification representation. @@ -293,7 +300,7 @@ def validate(self, datum: object) -> Optional["Schema"]: """ @abc.abstractmethod - def to_canonical_json(self, names: Optional[Names] = None) -> object: + def to_canonical_json(self, names: Optional[Names] = None) -> SchemaDescriptionType: """ Converts the schema object into its Canonical Form http://avro.apache.org/docs/current/spec.html#Parsing+Canonical+Form+for+Schemas @@ -347,7 +354,7 @@ def __init__( name: str, namespace: Optional[str] = None, names: Optional[Names] = None, - other_props: Optional[Mapping[str, object]] = None, + other_props: Optional[PropertiesType] = None, validate_names: bool = True, ) -> None: super().__init__(type_, other_props, validate_names=validate_names) @@ -369,21 +376,21 @@ def __init__( # Store full name as calculated from name, namespace self._fullname = new_name.fullname - def name_ref(self, names): + def name_ref(self, names: Names) -> str: return self.name if self.namespace == names.default_namespace else self.fullname # read-only properties @property - def name(self): - return self.get_prop("name") + def name(self) -> str: + return cast(str, self.get_prop("name")) @property - def namespace(self): - return self.get_prop("namespace") + def namespace(self) -> str: + return cast(str, self.get_prop("namespace")) @property - def fullname(self): - return self._fullname + def fullname(self) -> str: + return cast(str, self._fullname) # @@ -392,7 +399,9 @@ def fullname(self): class LogicalSchema: - def __init__(self, logical_type): + logical_type: str + + def __init__(self, logical_type: str) -> None: self.logical_type = logical_type @@ -402,7 +411,7 @@ def __init__(self, logical_type): class DecimalLogicalSchema(LogicalSchema): - def __init__(self, precision, scale=0, max_precision=0): + def __init__(self, precision: int, scale: int = 0, max_precision: int = 0) -> None: if not isinstance(precision, int) or precision <= 0: raise avro.errors.IgnoredLogicalType(f"Invalid decimal precision {precision}. Must be a positive integer.") @@ -421,7 +430,18 @@ def __init__(self, precision, scale=0, max_precision=0): class Field(CanonicalPropertiesMixin, EqualByJsonMixin): _reserved_properties: Sequence[str] = FIELD_RESERVED_PROPS - def __init__(self, type_, name, has_default, default=None, order=None, names=None, doc=None, other_props=None, validate_names: bool = True): + def __init__( + self, + type_: str, + name: str, + has_default: bool, + default: Optional[object] = None, + order: Optional[str] = None, + names: Optional[Names] = None, + doc: Optional[str] = None, + other_props: Optional[PropertiesType] = None, + validate_names: bool = True, + ) -> None: if not name: raise avro.errors.SchemaParseException("Fields must have a non-empty name.") if not isinstance(name, str): @@ -432,7 +452,7 @@ def __init__(self, type_, name, has_default, default=None, order=None, names=Non self.props.update(other_props or {}) if isinstance(type_, str) and names is not None and names.has_name(type_, None): - type_schema = names.get_name(type_, None) + type_schema: Schema = cast(NamedSchema, names.get_name(type_, None)) else: try: type_schema = make_avsc_object(type_, names, validate_names=validate_names) @@ -453,39 +473,32 @@ def __init__(self, type_, name, has_default, default=None, order=None, names=Non # read-only properties @property - def default(self): + def default(self) -> Optional[object]: return self.get_prop("default") @property - def has_default(self): + def has_default(self) -> bool: return self._has_default @property - def order(self): - return self.get_prop("order") + def order(self) -> Optional[str]: + order = self.get_prop("order") + return None if order is None else cast(str, order) @property - def doc(self): - return self.get_prop("doc") + def doc(self) -> Optional[str]: + return cast(Optional[str], self.get_prop("doc")) - def __str__(self): + def __str__(self) -> str: return json.dumps(self.to_json()) - def to_json(self, names=None): + def to_json(self, names: Optional[Names] = None) -> SchemaDescriptionType: names = names or Names(validate_names=self.validate_names) + return {**self.props, "type": self.type.to_json(names)} - to_dump = self.props.copy() - to_dump["type"] = self.type.to_json(names) - - return to_dump - - def to_canonical_json(self, names=None): + def to_canonical_json(self, names: Optional[Names] = None) -> SchemaDescriptionType: names = names or Names(validate_names=self.validate_names) - - to_dump = self.canonical_properties - to_dump["type"] = self.type.to_canonical_json(names) - - return to_dump + return {**self.canonical_properties, "type": self.type.to_canonical_json(names)} # @@ -496,7 +509,7 @@ def to_canonical_json(self, names=None): class PrimitiveSchema(EqualByPropsMixin, Schema): """Valid primitive types are in PRIMITIVE_TYPES.""" - _validators = { + _validators: Mapping[str, Callable[[object], bool]] = { "null": lambda x: x is None, "boolean": lambda x: isinstance(x, bool), "string": lambda x: isinstance(x, str), @@ -507,7 +520,7 @@ class PrimitiveSchema(EqualByPropsMixin, Schema): "double": lambda x: isinstance(x, (int, float)), } - def __init__(self, type, other_props=None): + def __init__(self, type: str, other_props: Optional[PropertiesType] = None) -> None: # Ensure valid ctor args if type not in avro.constants.PRIMITIVE_TYPES: raise avro.errors.AvroException(f"{type} is not a valid primitive type.") @@ -517,7 +530,7 @@ def __init__(self, type, other_props=None): self.fullname = type - def match(self, writer): + def match(self, writer: "Schema") -> bool: """Return True if the current schema (as reader) matches the writer schema. @arg writer: the schema to match against @@ -533,16 +546,13 @@ def match(self, writer): }, }.get(writer.type, False) - def to_json(self, names=None): - if len(self.props) == 1: - return self.fullname - else: - return self.props + def to_json(self, names: Optional[Names] = None) -> SchemaDescriptionType: + return self.fullname if len(self.props) == 1 else self.props - def to_canonical_json(self, names=None): + def to_canonical_json(self, names: Optional[Names] = None) -> SchemaDescriptionType: return self.fullname if len(self.props) == 1 else self.canonical_properties - def validate(self, datum): + def validate(self, datum: object) -> Optional[Schema]: """Return self if datum is a valid representation of this type of primitive schema, else None @arg datum: The data to be checked for validity according to this schema @@ -558,7 +568,7 @@ def validate(self, datum): class BytesDecimalSchema(PrimitiveSchema, DecimalLogicalSchema): - def __init__(self, precision, scale=0, other_props=None): + def __init__(self, precision: int, scale: int = 0, other_props: Optional[PropertiesType] = None) -> None: DecimalLogicalSchema.__init__(self, precision, scale, max_precision=((1 << 31) - 1)) PrimitiveSchema.__init__(self, "bytes", other_props) self.set_prop("precision", precision) @@ -566,17 +576,17 @@ def __init__(self, precision, scale=0, other_props=None): # read-only properties @property - def precision(self): - return self.get_prop("precision") + def precision(self) -> int: + return cast(int, self.get_prop("precision")) @property - def scale(self): - return self.get_prop("scale") + def scale(self) -> int: + return cast(int, self.get_prop("scale")) - def to_json(self, names=None): + def to_json(self, names: Optional[Names] = None) -> SchemaDescriptionType: return self.props - def validate(self, datum): + def validate(self, datum: object) -> Optional[Schema]: """Return self if datum is a Decimal object, else None.""" return self if isinstance(datum, decimal.Decimal) else None @@ -585,7 +595,15 @@ def validate(self, datum): # Complex Types (non-recursive) # class FixedSchema(EqualByPropsMixin, NamedSchema): - def __init__(self, name, namespace, size, names=None, other_props=None, validate_names: bool = True): + def __init__( + self, + name: str, + namespace: Optional[str], + size: int, + names: Optional[Names] = None, + other_props: Optional[PropertiesType] = None, + validate_names: bool = True, + ) -> None: # Ensure valid ctor args if not isinstance(size, int) or size < 0: fail_msg = "Fixed Schema requires a valid positive integer for size property." @@ -599,10 +617,10 @@ def __init__(self, name, namespace, size, names=None, other_props=None, validate # read-only properties @property - def size(self): - return self.get_prop("size") + def size(self) -> int: + return cast(int, self.get_prop("size")) - def match(self, writer): + def match(self, writer: "Schema") -> bool: """Return True if the current schema (as reader) matches the writer schema. @arg writer: the schema to match against @@ -610,22 +628,22 @@ def match(self, writer): """ return self.type == writer.type and self.check_props(writer, ["fullname", "size"]) - def to_json(self, names=None): + def to_json(self, names: Optional[Names] = None) -> SchemaDescriptionType: names = names or Names(validate_names=self.validate_names) if self.fullname in names.names: return self.name_ref(names) names.names[self.fullname] = self - return names.prune_namespace(self.props) + return names.prune_namespace(dict(self.props)) - def to_canonical_json(self, names=None): + def to_canonical_json(self, names: Optional[Names] = None) -> SchemaDescriptionType: to_dump = self.canonical_properties to_dump["name"] = self.fullname return to_dump - def validate(self, datum): + def validate(self, datum: object) -> Optional[Schema]: """Return self if datum is a valid representation of this schema, else None.""" return self if isinstance(datum, bytes) and len(datum) == self.size else None @@ -638,15 +656,15 @@ def validate(self, datum): class FixedDecimalSchema(FixedSchema, DecimalLogicalSchema): def __init__( self, - size, - name, - precision, - scale=0, - namespace=None, - names=None, - other_props=None, + size: int, + name: str, + precision: int, + scale: int = 0, + namespace: Optional[str] = None, + names: Optional[Names] = None, + other_props: Optional[PropertiesType] = None, validate_names: bool = True, - ): + ) -> None: max_precision = int(math.floor(math.log10(2) * (8 * size - 1))) DecimalLogicalSchema.__init__(self, precision, scale, max_precision) FixedSchema.__init__(self, name, namespace, size, names, other_props, validate_names=validate_names) @@ -655,17 +673,17 @@ def __init__( # read-only properties @property - def precision(self): - return self.get_prop("precision") + def precision(self) -> int: + return cast(int, self.get_prop("precision")) @property - def scale(self): - return self.get_prop("scale") + def scale(self) -> int: + return cast(int, self.get_prop("scale")) - def to_json(self, names=None): + def to_json(self, names: Optional[Names] = None) -> SchemaDescriptionType: return self.props - def validate(self, datum): + def validate(self, datum: object) -> Optional[Schema]: """Return self if datum is a Decimal object, else None.""" return self if isinstance(datum, decimal.Decimal) else None @@ -676,9 +694,9 @@ def __init__( name: str, namespace: str, symbols: Sequence[str], - names: Optional[avro.name.Names] = None, + names: Optional[Names] = None, doc: Optional[str] = None, - other_props: Optional[Mapping[str, object]] = None, + other_props: Optional[PropertiesType] = None, validate_enum_symbols: bool = True, validate_names: bool = True, ) -> None: @@ -716,10 +734,10 @@ def symbols(self) -> Sequence[str]: raise Exception @property - def doc(self): - return self.get_prop("doc") + def doc(self) -> Optional[str]: + return cast(Optional[str], self.get_prop("doc")) - def match(self, writer): + def match(self, writer: "Schema") -> bool: """Return True if the current schema (as reader) matches the writer schema. @arg writer: the schema to match against @@ -727,27 +745,25 @@ def match(self, writer): """ return self.type == writer.type and self.check_props(writer, ["fullname"]) - def to_json(self, names=None): + def to_json(self, names: Optional[Names] = None) -> SchemaDescriptionType: names = names or Names(validate_names=self.validate_names) if self.fullname in names.names: return self.name_ref(names) names.names[self.fullname] = self - return names.prune_namespace(self.props) + return names.prune_namespace(dict(self.props)) - def to_canonical_json(self, names=None): + def to_canonical_json(self, names: Optional[Names] = None) -> SchemaDescriptionType: names_as_json = self.to_json(names) if isinstance(names_as_json, str): - to_dump = self.fullname - else: - to_dump = self.canonical_properties - to_dump["name"] = self.fullname - + return self.fullname + to_dump = self.canonical_properties + to_dump["name"] = self.fullname return to_dump - def validate(self, datum): + def validate(self, datum: object) -> Optional[Schema]: """Return self if datum is a valid member of this Enum, else None.""" return self if datum in self.symbols else None @@ -758,66 +774,57 @@ def validate(self, datum): class ArraySchema(EqualByJsonMixin, Schema): - def __init__(self, items, names=None, other_props=None, validate_names: bool = True): + def __init__(self, items: str, names: Optional[Names] = None, other_props: Optional[PropertiesType] = None, validate_names: bool = True) -> None: # Call parent ctor Schema.__init__(self, "array", other_props, validate_names=validate_names) # Add class members - if isinstance(items, str) and names.has_name(items, None): - items_schema = names.get_name(items, None) + if isinstance(items, str) and names and names.has_name(items, None): + items_schema: Schema = cast(NamedSchema, names.get_name(items, None)) else: try: items_schema = make_avsc_object(items, names, validate_names=self.validate_names) except avro.errors.SchemaParseException as e: - fail_msg = f"Items schema ({items}) not a valid Avro schema: {e} (known names: {names.names.keys()})" - raise avro.errors.SchemaParseException(fail_msg) + known_names = names.names.keys() if names else () + raise avro.errors.SchemaParseException(f"Items schema ({items}) not a valid Avro schema: {e} (known names: {known_names}") self.set_prop("items", items_schema) # read-only properties @property - def items(self): - return self.get_prop("items") + def items(self) -> Schema: + return cast(Schema, self.get_prop("items")) - def match(self, writer): + def match(self, writer: "Schema") -> bool: """Return True if the current schema (as reader) matches the writer schema. @arg writer: the schema to match against @return bool """ - return self.type == writer.type and self.items.check_props(writer.items, ["type"]) + return self.type == writer.type and self.items.check_props(cast("ArraySchema", writer).items, ["type"]) - def to_json(self, names=None): + def to_json(self, names: Optional[Names] = None) -> SchemaDescriptionType: names = names or Names(validate_names=self.validate_names) + item_schema = cast(Schema, self.get_prop("items")) + return {**self.props, "items": item_schema.to_json(names)} - to_dump = self.props.copy() - item_schema = self.get_prop("items") - to_dump["items"] = item_schema.to_json(names) - - return to_dump - - def to_canonical_json(self, names=None): + def to_canonical_json(self, names: Optional[Names] = None) -> SchemaDescriptionType: names = names or Names(validate_names=self.validate_names) + return {**self.canonical_properties, "items": self.items.to_canonical_json(names)} - to_dump = self.canonical_properties - item_schema = self.get_prop("items") - to_dump["items"] = item_schema.to_canonical_json(names) - - return to_dump - - def validate(self, datum): + def validate(self, datum: object) -> Optional[Schema]: """Return self if datum is a valid representation of this schema, else None.""" return self if isinstance(datum, list) else None class MapSchema(EqualByJsonMixin, Schema): - def __init__(self, values, names=None, other_props=None, validate_names: bool = True): + def __init__(self, values: str, names: Optional[Names] = None, other_props: Optional[PropertiesType] = None, validate_names: bool = True) -> None: # Call parent ctor Schema.__init__(self, "map", other_props, validate_names=validate_names) # Add class members - if isinstance(values, str) and names.has_name(values, None): - values_schema = names.get_name(values, None) + if isinstance(values, str) and names and names.has_name(values, None): + values_schema: Schema = cast(NamedSchema, names.get_name(values, None)) else: try: values_schema = make_avsc_object(values, names, validate_names=self.validate_names) @@ -830,34 +837,26 @@ def __init__(self, values, names=None, other_props=None, validate_names: bool = # read-only properties @property - def values(self): - return self.get_prop("values") + def values(self) -> Schema: + return cast(Schema, self.get_prop("values")) - def match(self, writer): + def match(self, writer: "Schema") -> bool: """Return True if the current schema (as reader) matches the writer schema. @arg writer: the schema to match against @return bool """ - return writer.type == self.type and self.values.check_props(writer.values, ["type"]) + return writer.type == self.type and self.values.check_props(cast(MapSchema, writer).values, ["type"]) - def to_json(self, names=None): + def to_json(self, names: Optional[Names] = None) -> SchemaDescriptionType: names = names or Names(validate_names=self.validate_names) + return {**self.props, "values": self.values.to_json(names)} - to_dump = self.props.copy() - to_dump["values"] = self.get_prop("values").to_json(names) - - return to_dump - - def to_canonical_json(self, names=None): + def to_canonical_json(self, names: Optional[Names] = None) -> SchemaDescriptionType: names = names or Names(validate_names=self.validate_names) + return {**self.canonical_properties, "values": self.values.to_canonical_json(names)} - to_dump = self.canonical_properties - to_dump["values"] = self.get_prop("values").to_canonical_json(names) - - return to_dump - - def validate(self, datum): + def validate(self, datum: object) -> Optional[Schema]: """Return self if datum is a valid representation of this schema, else None.""" return self if isinstance(datum, dict) and all(isinstance(key, str) for key in datum) else None @@ -867,7 +866,7 @@ class UnionSchema(EqualByJsonMixin, Schema): names is a dictionary of schema objects """ - def __init__(self, schemas, names=None, validate_names: bool = True): + def __init__(self, schemas: Sequence[SchemaDescriptionType], names: Optional[Names] = None, validate_names: bool = True) -> None: # Ensure valid ctor args if not isinstance(schemas, list): fail_msg = "Union schema requires a list of schemas." @@ -879,8 +878,8 @@ def __init__(self, schemas, names=None, validate_names: bool = True): # Add class members schema_objects: List[Schema] = [] for schema in schemas: - if isinstance(schema, str) and names.has_name(schema, None): - new_schema = names.get_name(schema, None) + if isinstance(schema, str) and names and names.has_name(schema, None): + new_schema: Schema = cast(NamedSchema, names.get_name(schema, None)) else: try: new_schema = make_avsc_object(schema, names, validate_names=self.validate_names) @@ -901,10 +900,10 @@ def __init__(self, schemas, names=None, validate_names: bool = True): # read-only properties @property - def schemas(self): + def schemas(self) -> Sequence[Schema]: return self._schemas - def match(self, writer): + def match(self, writer: "Schema") -> bool: """Return True if the current schema (as reader) matches the writer schema. @arg writer: the schema to match against @@ -912,31 +911,26 @@ def match(self, writer): """ return writer.type in {"union", "error_union"} or any(s.match(writer) for s in self.schemas) - def to_json(self, names=None): + def to_json(self, names: Optional[Names] = None) -> SchemaDescriptionType: names = names or Names(validate_names=self.validate_names) + return [schema.to_json(names) for schema in self.schemas] - to_dump = [] - for schema in self.schemas: - to_dump.append(schema.to_json(names)) - - return to_dump - - def to_canonical_json(self, names=None): + def to_canonical_json(self, names: Optional[Names] = None) -> SchemaDescriptionType: names = names or Names(validate_names=self.validate_names) return [schema.to_canonical_json(names) for schema in self.schemas] - def validate(self, datum): + def validate(self, datum: object) -> Optional[Schema]: """Return the first branch schema of which datum is a valid example, else None.""" return next((branch for branch in self.schemas if branch.validate(datum) is not None), None) class ErrorUnionSchema(UnionSchema): - def __init__(self, schemas, names=None, validate_names: bool = True): + def __init__(self, schemas: Sequence[str], names: Optional[Names] = None, validate_names: bool = True) -> None: # Prepend "string" to handle system errors - UnionSchema.__init__(self, ["string"] + schemas, names, validate_names) + UnionSchema.__init__(self, ["string", *schemas], names, validate_names) - def to_json(self, names=None): + def to_json(self, names: Optional[Names] = None) -> SchemaDescriptionType: names = names or Names(validate_names=self.validate_names) to_dump = [] @@ -951,21 +945,21 @@ def to_json(self, names=None): class RecordSchema(EqualByJsonMixin, NamedSchema): @staticmethod - def make_field_objects(field_data: Sequence[Mapping[str, object]], names: avro.name.Names, validate_names: bool = True) -> Sequence[Field]: + def make_field_objects(field_data: Sequence[MutableMapping[str, object]], names: Names, validate_names: bool = True) -> Sequence[Field]: """We're going to need to make message parameters too.""" field_objects = [] field_names = [] for field in field_data: if not callable(getattr(field, "get", None)): raise avro.errors.SchemaParseException(f"Not a valid field: {field}") - type = field.get("type") - name = field.get("name") + type = cast(str, field.get("type")) + name = cast(str, field.get("name")) # null values can have a default value of None has_default = "default" in field default = field.get("default") - order = field.get("order") - doc = field.get("doc") + order = cast(Optional[str], field.get("order")) + doc = cast(Optional[str], field.get("doc")) other_props = get_other_props(field, FIELD_RESERVED_PROPS) new_field = Field(type, name, has_default, default, order, names, doc, other_props, validate_names=validate_names) # make sure field name has not been used yet @@ -976,7 +970,7 @@ def make_field_objects(field_data: Sequence[Mapping[str, object]], names: avro.n field_objects.append(new_field) return field_objects - def match(self, writer): + def match(self, writer: "Schema") -> bool: """Return True if the current schema (as reader) matches the other schema. @arg writer: the schema to match against @@ -986,18 +980,19 @@ def match(self, writer): def __init__( self, - name, - namespace, - fields, - names=None, - schema_type="record", - doc=None, - other_props=None, + name: Optional[str], + namespace: Optional[str], + fields: Sequence[MutableMapping[str, object]], + names: Optional[Names] = None, + schema_type: str = "record", + doc: Optional[str] = None, + other_props: Optional[PropertiesType] = None, validate_names: bool = True, - ): + ) -> None: # Ensure valid ctor args + # (Should we continue to do manual runtime type checking?) if fields is None: - fail_msg = "Record schema requires a non-empty fields property." + fail_msg = "Record schema requires a non-empty fields property." # type: ignore raise avro.errors.SchemaParseException(fail_msg) elif not isinstance(fields, list): fail_msg = "Fields property must be a list of Avro schemas." @@ -1006,8 +1001,10 @@ def __init__( # Call parent ctor (adds own name to namespace, too) if schema_type == "request": Schema.__init__(self, schema_type, other_props) - else: + elif name: NamedSchema.__init__(self, schema_type, name, namespace, names, other_props, validate_names=validate_names) + else: + raise avro.errors.InvalidName("Attempted to create a record with no name") names = names or Names(validate_names=self.validate_names) if schema_type == "record": @@ -1025,21 +1022,18 @@ def __init__( # read-only properties @property - def fields(self): - return self.get_prop("fields") + def fields(self) -> Sequence[Field]: + return cast(Sequence[Field], self.get_prop("fields")) @property - def doc(self): - return self.get_prop("doc") + def doc(self) -> Optional[str]: + return cast(Optional[str], self.get_prop("doc")) @property - def fields_dict(self): - fields_dict = {} - for field in self.fields: - fields_dict[field.name] = field - return fields_dict + def fields_dict(self) -> MutableMapping[str, Field]: + return {field.name: field for field in self.fields} - def to_json(self, names=None): + def to_json(self, names: Optional[Names] = None) -> SchemaDescriptionType: names = names or Names(validate_names=self.validate_names) # Request records don't have names @@ -1048,15 +1042,10 @@ def to_json(self, names=None): if self.fullname in names.names: return self.name_ref(names) - else: - names.names[self.fullname] = self - - to_dump = names.prune_namespace(self.props.copy()) - to_dump["fields"] = [f.to_json(names) for f in self.fields] - - return to_dump + names.names[self.fullname] = self + return names.prune_namespace({**self.props, "fields": [f.to_json(names) for f in self.fields]}) - def to_canonical_json(self, names=None): + def to_canonical_json(self, names: Optional[Names] = None) -> SchemaDescriptionType: names = names or Names(validate_names=self.validate_names) if self.type == "request": @@ -1073,7 +1062,7 @@ def to_canonical_json(self, names=None): return to_dump - def validate(self, datum): + def validate(self, datum: object) -> Optional[Schema]: """Return self if datum is a valid representation of this schema, else None""" return self if isinstance(datum, dict) and {f.name for f in self.fields}.issuperset(datum.keys()) else None @@ -1084,14 +1073,14 @@ def validate(self, datum): class DateSchema(LogicalSchema, PrimitiveSchema): - def __init__(self, other_props=None): + def __init__(self, other_props: Optional[PropertiesType] = None) -> None: LogicalSchema.__init__(self, avro.constants.DATE) PrimitiveSchema.__init__(self, "int", other_props) - def to_json(self, names=None): + def to_json(self, names: Optional[Names] = None) -> SchemaDescriptionType: return self.props - def validate(self, datum): + def validate(self, datum: object) -> Optional[Schema]: """Return self if datum is a valid date object, else None.""" return self if isinstance(datum, datetime.date) else None @@ -1102,14 +1091,14 @@ def validate(self, datum): class TimeMillisSchema(LogicalSchema, PrimitiveSchema): - def __init__(self, other_props=None): + def __init__(self, other_props: Optional[PropertiesType] = None) -> None: LogicalSchema.__init__(self, avro.constants.TIME_MILLIS) PrimitiveSchema.__init__(self, "int", other_props) - def to_json(self, names=None): + def to_json(self, names: Optional[Names] = None) -> SchemaDescriptionType: return self.props - def validate(self, datum): + def validate(self, datum: object) -> Optional[Schema]: """Return self if datum is a valid representation of this schema, else None.""" return self if isinstance(datum, datetime.time) else None @@ -1120,14 +1109,14 @@ def validate(self, datum): class TimeMicrosSchema(LogicalSchema, PrimitiveSchema): - def __init__(self, other_props=None): + def __init__(self, other_props: Optional[PropertiesType] = None) -> None: LogicalSchema.__init__(self, avro.constants.TIME_MICROS) PrimitiveSchema.__init__(self, "long", other_props) - def to_json(self, names=None): + def to_json(self, names: Optional[Names] = None) -> SchemaDescriptionType: return self.props - def validate(self, datum): + def validate(self, datum: object) -> Optional[Schema]: """Return self if datum is a valid representation of this schema, else None.""" return self if isinstance(datum, datetime.time) else None @@ -1138,14 +1127,14 @@ def validate(self, datum): class TimestampMillisSchema(LogicalSchema, PrimitiveSchema): - def __init__(self, other_props=None): + def __init__(self, other_props: Optional[PropertiesType] = None) -> None: LogicalSchema.__init__(self, avro.constants.TIMESTAMP_MILLIS) PrimitiveSchema.__init__(self, "long", other_props) - def to_json(self, names=None): + def to_json(self, names: Optional[Names] = None) -> SchemaDescriptionType: return self.props - def validate(self, datum): + def validate(self, datum: object) -> Optional[Schema]: return self if isinstance(datum, datetime.datetime) and _is_timezone_aware_datetime(datum) else None @@ -1155,14 +1144,14 @@ def validate(self, datum): class TimestampMicrosSchema(LogicalSchema, PrimitiveSchema): - def __init__(self, other_props=None): + def __init__(self, other_props: Optional[PropertiesType] = None) -> None: LogicalSchema.__init__(self, avro.constants.TIMESTAMP_MICROS) PrimitiveSchema.__init__(self, "long", other_props) - def to_json(self, names=None): + def to_json(self, names: Optional[Names] = None) -> SchemaDescriptionType: return self.props - def validate(self, datum): + def validate(self, datum: object) -> Optional[Schema]: return self if isinstance(datum, datetime.datetime) and _is_timezone_aware_datetime(datum) else None @@ -1172,16 +1161,16 @@ def validate(self, datum): class UUIDSchema(LogicalSchema, PrimitiveSchema): - def __init__(self, other_props=None): + def __init__(self, other_props: Optional[PropertiesType] = None) -> None: LogicalSchema.__init__(self, avro.constants.UUID) PrimitiveSchema.__init__(self, "string", other_props) - def to_json(self, names=None): + def to_json(self, names: Optional[Names] = None) -> SchemaDescriptionType: return self.props - def validate(self, datum): + def validate(self, datum: object) -> Optional[Schema]: try: - uuid.UUID(datum) + uuid.UUID(datum) # type: ignore except (ValueError, TypeError): return None @@ -1193,7 +1182,7 @@ def validate(self, datum): # -def get_other_props(all_props: Mapping[str, object], reserved_props: Sequence[str]) -> Mapping[str, object]: +def get_other_props(all_props: PropertiesType, reserved_props: Sequence[str]) -> PropertiesType: """ Retrieve the non-reserved properties from a dictionary of properties @args reserved_props: The set of reserved properties to exclude @@ -1201,14 +1190,16 @@ def get_other_props(all_props: Mapping[str, object], reserved_props: Sequence[st return {k: v for k, v in all_props.items() if k not in reserved_props} -def make_bytes_decimal_schema(other_props): +def make_bytes_decimal_schema(other_props: PropertiesType) -> BytesDecimalSchema: """Make a BytesDecimalSchema from just other_props.""" - return BytesDecimalSchema(other_props.get("precision"), other_props.get("scale", 0), other_props) + precision = cast(int, other_props.get("precision")) + scale = cast(int, other_props.get("scale", 0)) + return BytesDecimalSchema(precision, scale, other_props) -def make_logical_schema(logical_type, type_, other_props): +def make_logical_schema(logical_type: str, type_: str, other_props: PropertiesType) -> Optional[Schema]: """Map the logical types to the appropriate literal type and schema class.""" - logical_types = { + logical_types: Mapping[Tuple[str, str], Callable[..., Optional[Schema]]] = { (avro.constants.DATE, "int"): DateSchema, (avro.constants.DECIMAL, "bytes"): make_bytes_decimal_schema, # The fixed decimal schema is handled later by returning None now. @@ -1236,9 +1227,7 @@ def make_logical_schema(logical_type, type_, other_props): return None -def make_avsc_object( - json_data: object, names: Optional[avro.name.Names] = None, validate_enum_symbols: bool = True, validate_names: bool = True -) -> Schema: +def make_avsc_object(json_data: object, names: Optional[Names] = None, validate_enum_symbols: bool = True, validate_names: bool = True) -> Schema: """ Build Avro Schema from data parsed out of JSON string. @@ -1249,26 +1238,26 @@ def make_avsc_object( # JSON object (non-union) if callable(getattr(json_data, "get", None)): - json_data = cast(Mapping, json_data) - type_ = json_data.get("type") + json_data = cast(MutableMapping[str, object], json_data) + type_ = cast(str, json_data.get("type")) other_props = get_other_props(json_data, SCHEMA_RESERVED_PROPS) - logical_type = json_data.get("logicalType") + logical_type = cast(str, json_data.get("logicalType")) if logical_type: logical_schema = make_logical_schema(logical_type, type_, other_props or {}) if logical_schema is not None: - return cast(Schema, logical_schema) + return logical_schema if type_ in avro.constants.NAMED_TYPES: name = json_data.get("name") if not isinstance(name, str): raise avro.errors.SchemaParseException(f"Name {name} must be a string, but it is {type(name)}.") - namespace = json_data.get("namespace", names.default_namespace) + namespace = cast(str, json_data.get("namespace", names.default_namespace)) if type_ == "fixed": - size = json_data.get("size") + size = cast(int, json_data.get("size")) if logical_type == "decimal": - precision = json_data.get("precision") - scale = json_data.get("scale", 0) + precision = cast(int, json_data.get("precision")) + scale = cast(int, json_data.get("scale", 0)) try: return FixedDecimalSchema(size, name, precision, scale, namespace, names, other_props, validate_names) except avro.errors.IgnoredLogicalType as warning: @@ -1281,11 +1270,11 @@ def make_avsc_object( for symbol in symbols: if not isinstance(symbol, str): raise avro.errors.SchemaParseException(f"Enum symbols must be a sequence of strings, but one symbol is a {type(symbol)}") - doc = json_data.get("doc") + doc = cast(Optional[str], json_data.get("doc")) return EnumSchema(name, namespace, symbols, names, doc, other_props, validate_enum_symbols, validate_names) if type_ in ["record", "error"]: - fields = json_data.get("fields") - doc = json_data.get("doc") + fields = cast(Sequence[MutableMapping[str, object]], json_data.get("fields")) + doc = cast(Optional[str], json_data.get("doc")) return RecordSchema(name, namespace, fields, names, type_, doc, other_props, validate_names) raise avro.errors.SchemaParseException(f"Unknown Named Type: {type_}") @@ -1294,13 +1283,13 @@ def make_avsc_object( if type_ in avro.constants.VALID_TYPES: if type_ == "array": - items = json_data.get("items") + items = cast(str, json_data.get("items")) return ArraySchema(items, names, other_props, validate_names) elif type_ == "map": - values = json_data.get("values") + values = cast(str, json_data.get("values")) return MapSchema(values, names, other_props, validate_names) elif type_ == "error_union": - declared_errors = json_data.get("declared_errors") + declared_errors = cast(List[str], json_data.get("declared_errors")) return ErrorUnionSchema(declared_errors, names, validate_names) else: raise avro.errors.SchemaParseException(f"Unknown Valid Type: {type_}") @@ -1313,7 +1302,7 @@ def make_avsc_object( return UnionSchema(json_data, names, validate_names=validate_names) # JSON string (primitive) elif json_data in avro.constants.PRIMITIVE_TYPES: - return PrimitiveSchema(json_data) + return PrimitiveSchema(cast(str, json_data)) # not for us! fail_msg = f"Could not make an Avro Schema object from {json_data}" raise avro.errors.SchemaParseException(fail_msg)