Skip to content

Commit

Permalink
AVRO-1737 Implement Hashable Schema
Browse files Browse the repository at this point in the history
A hashable thing can be a member of a set or a key in a dictionary.
  • Loading branch information
kojiromike committed Jul 19, 2023
1 parent 5e6cec1 commit 184a055
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 7 deletions.
44 changes: 37 additions & 7 deletions lang/py/avro/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,10 @@ def match(self, writer: "Schema") -> bool:
@return bool
"""

def __hash__(self) -> int:

Check warning

Code scanning / CodeQL

Inconsistent equality and hashing Warning

Class
Schema
implements __hash__ but does not define __eq__.
Class
NamedSchema
implements __hash__ but does not define __eq__.
Class
PrimitiveSchema
implements __hash__ but does not define __eq__.
Class
BytesDecimalSchema
implements __hash__ but does not define __eq__.
Class
FixedSchema
implements __hash__ but does not define __eq__.
Class
FixedDecimalSchema
implements __hash__ but does not define __eq__.
Class
EnumSchema
implements __hash__ but does not define __eq__.
Class
ArraySchema
implements __hash__ but does not define __eq__.
Class
MapSchema
implements __hash__ but does not define __eq__.
Class
UnionSchema
implements __hash__ but does not define __eq__.
Class
ErrorUnionSchema
implements __hash__ but does not define __eq__.
Class
RecordSchema
implements __hash__ but does not define __eq__.
Class
DateSchema
implements __hash__ but does not define __eq__.
Class
TimeMillisSchema
implements __hash__ but does not define __eq__.
Class
TimeMicrosSchema
implements __hash__ but does not define __eq__.
Class
TimestampMillisSchema
implements __hash__ but does not define __eq__.
Class
TimestampMicrosSchema
implements __hash__ but does not define __eq__.
Class
UUIDSchema
implements __hash__ but does not define __eq__.
"""Make it so a schema can be in a set or a key in a dictionary."""
return hash(self.fingerprint())

def __str__(self) -> str:
return json.dumps(self.to_json())

Expand Down Expand Up @@ -288,13 +292,6 @@ def canonical_form(self) -> str:
# The separators eliminate whitespace around commas and colons.
return json.dumps(self.to_canonical_json(), separators=(",", ":"))

@abc.abstractmethod
def __eq__(self, that: object) -> bool:
"""
Determines how two schema are compared.
Consider the mixins EqualByPropsMixin and EqualByJsonMixin
"""

def fingerprint(self, algorithm="CRC-64-AVRO") -> bytes:
"""
Generate fingerprint for supplied algorithm.
Expand Down Expand Up @@ -323,6 +320,8 @@ def fingerprint(self, algorithm="CRC-64-AVRO") -> bytes:
class NamedSchema(Schema):
"""Named Schemas specified in NAMED_TYPES."""

__hash__ = Schema.__hash__

def __init__(
self,
type_: str,
Expand Down Expand Up @@ -478,6 +477,8 @@ def to_canonical_json(self, names=None):
class PrimitiveSchema(EqualByPropsMixin, Schema):
"""Valid primitive types are in PRIMITIVE_TYPES."""

__hash__ = Schema.__hash__

_validators = {
"null": lambda x: x is None,
"boolean": lambda x: isinstance(x, bool),
Expand Down Expand Up @@ -540,6 +541,8 @@ def validate(self, datum):


class BytesDecimalSchema(PrimitiveSchema, DecimalLogicalSchema):
__hash__ = Schema.__hash__

def __init__(self, precision, scale=0, other_props=None):
DecimalLogicalSchema.__init__(self, precision, scale, max_precision=((1 << 31) - 1))
PrimitiveSchema.__init__(self, "bytes", other_props)
Expand Down Expand Up @@ -567,6 +570,8 @@ def validate(self, datum):
# Complex Types (non-recursive)
#
class FixedSchema(EqualByPropsMixin, NamedSchema):
__hash__ = Schema.__hash__

def __init__(self, name, namespace, size, names=None, other_props=None, validate_names: bool = True):
# Ensure valid ctor args
if not isinstance(size, int) or size < 0:
Expand Down Expand Up @@ -618,6 +623,8 @@ def validate(self, datum):


class FixedDecimalSchema(FixedSchema, DecimalLogicalSchema):
__hash__ = Schema.__hash__

def __init__(
self,
size,
Expand Down Expand Up @@ -653,6 +660,8 @@ def validate(self, datum):


class EnumSchema(EqualByPropsMixin, NamedSchema):
__hash__ = Schema.__hash__

def __init__(
self,
name: str,
Expand Down Expand Up @@ -740,6 +749,8 @@ def validate(self, datum):


class ArraySchema(EqualByJsonMixin, Schema):
__hash__ = Schema.__hash__

def __init__(self, items, names=None, other_props=None, validate_names: bool = True):
# Call parent ctor
Schema.__init__(self, "array", other_props, validate_names=validate_names)
Expand Down Expand Up @@ -793,6 +804,8 @@ def validate(self, datum):


class MapSchema(EqualByJsonMixin, Schema):
__hash__ = Schema.__hash__

def __init__(self, values, names=None, other_props=None, validate_names: bool = True):
# Call parent ctor
Schema.__init__(self, "map", other_props, validate_names=validate_names)
Expand Down Expand Up @@ -845,6 +858,7 @@ def validate(self, datum):


class UnionSchema(EqualByJsonMixin, Schema):
__hash__ = Schema.__hash__
"""
names is a dictionary of schema objects
"""
Expand Down Expand Up @@ -916,6 +930,8 @@ def validate(self, datum):


class ErrorUnionSchema(UnionSchema):
__hash__ = Schema.__hash__

def __init__(self, schemas, names=None, validate_names: bool = True):
# Prepend "string" to handle system errors
UnionSchema.__init__(self, ["string"] + schemas, names, validate_names)
Expand All @@ -934,6 +950,8 @@ def to_json(self, names=None):


class RecordSchema(EqualByJsonMixin, NamedSchema):
__hash__ = Schema.__hash__

@staticmethod
def make_field_objects(field_data: Sequence[Mapping[str, object]], names: avro.name.Names, validate_names: bool = True) -> Sequence[Field]:
"""We're going to need to make message parameters too."""
Expand Down Expand Up @@ -1068,6 +1086,8 @@ def validate(self, datum):


class DateSchema(LogicalSchema, PrimitiveSchema):
__hash__ = Schema.__hash__

def __init__(self, other_props=None):
LogicalSchema.__init__(self, avro.constants.DATE)
PrimitiveSchema.__init__(self, "int", other_props)
Expand All @@ -1086,6 +1106,8 @@ def validate(self, datum):


class TimeMillisSchema(LogicalSchema, PrimitiveSchema):
__hash__ = Schema.__hash__

def __init__(self, other_props=None):
LogicalSchema.__init__(self, avro.constants.TIME_MILLIS)
PrimitiveSchema.__init__(self, "int", other_props)
Expand All @@ -1104,6 +1126,8 @@ def validate(self, datum):


class TimeMicrosSchema(LogicalSchema, PrimitiveSchema):
__hash__ = Schema.__hash__

def __init__(self, other_props=None):
LogicalSchema.__init__(self, avro.constants.TIME_MICROS)
PrimitiveSchema.__init__(self, "long", other_props)
Expand All @@ -1122,6 +1146,8 @@ def validate(self, datum):


class TimestampMillisSchema(LogicalSchema, PrimitiveSchema):
__hash__ = Schema.__hash__

def __init__(self, other_props=None):
LogicalSchema.__init__(self, avro.constants.TIMESTAMP_MILLIS)
PrimitiveSchema.__init__(self, "long", other_props)
Expand All @@ -1139,6 +1165,8 @@ def validate(self, datum):


class TimestampMicrosSchema(LogicalSchema, PrimitiveSchema):
__hash__ = Schema.__hash__

def __init__(self, other_props=None):
LogicalSchema.__init__(self, avro.constants.TIMESTAMP_MICROS)
PrimitiveSchema.__init__(self, "long", other_props)
Expand All @@ -1156,6 +1184,8 @@ def validate(self, datum):


class UUIDSchema(LogicalSchema, PrimitiveSchema):
__hash__ = Schema.__hash__

def __init__(self, other_props=None):
LogicalSchema.__init__(self, avro.constants.UUID)
PrimitiveSchema.__init__(self, "string", other_props)
Expand Down
26 changes: 26 additions & 0 deletions lang/py/avro/test/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -890,6 +890,31 @@ def parse_invalid(self):
self.test_schema.parse()


class HashableTestCase(unittest.TestCase):
"""Ensure that Schema are hashable.
While hashability is implemented with parsing canonical form fingerprinting,
this test should be kept distinct to avoid coupling."""

def __init__(self, test_schema):
"""Ignore the normal signature for unittest.TestCase because we are generating
many test cases from this one class. This is safe as long as the autoloader
ignores this class. The autoloader will ignore this class as long as it has
no methods starting with `test_`.
"""
super().__init__("parse_and_hash")
self.test_schema = test_schema

def parse_and_hash(self):
"""Ensure that every schema can be hashed."""
try:
hash(self.test_schema.parse())
except TypeError as e:
if "unhashable type" in str(e):
self.fail(f"{self.test_schema} is not hashable")
raise


class RoundTripParseTestCase(unittest.TestCase):
"""Enable generating round-trip parse test cases over all the valid test schema."""

Expand Down Expand Up @@ -1434,6 +1459,7 @@ def load_tests(loader, default_tests, pattern):
suite.addTests(OtherAttributesTestCase(ex) for ex in OTHER_PROP_EXAMPLES)
suite.addTests(loader.loadTestsFromTestCase(CanonicalFormTestCase))
suite.addTests(FingerprintTestCase(ex[0], ex[1]) for ex in FINGERPRINT_EXAMPLES)
suite.addTests(HashableTestCase(ex) for ex in VALID_EXAMPLES)
return suite


Expand Down

0 comments on commit 184a055

Please sign in to comment.