Make Avro Python schemas hashable.

EqualByJsonMixin and EqualByPropsMixin (lang/py/avro/schema.py) each gain a
__hash__ that returns hash(self.fingerprint()), defined next to their __eq__.
HashableTestCase (lang/py/avro/test/test_schema.py) parses every entry of
VALID_EXAMPLES and fails if hashing raises an "unhashable type" TypeError.

NOTE(review): line breaks and indentation were rebuilt from the hunk line
counts and Python syntax; confirm context whitespace against the tree (or use
`git apply --ignore-whitespace`). The last hunk's range is trimmed to the
context lines visible here (-1434,4 +1459,5). Docstring wording on added lines
was tightened, so the post-image blob id on the index lines is no longer exact.

diff --git a/lang/py/avro/schema.py b/lang/py/avro/schema.py
index f852e146a3e..2672d736eb9 100644
--- a/lang/py/avro/schema.py
+++ b/lang/py/avro/schema.py
@@ -50,6 +50,7 @@
 from functools import reduce
 from pathlib import Path
 from typing import (
+    Callable,
     FrozenSet,
     List,
     Mapping,
@@ -197,9 +198,11 @@ def other_props(self) -> Mapping[str, object]:
         return get_other_props(self.props, self._reserved_properties)
 
 
-class EqualByJsonMixin:
+class EqualByJsonMixin(collections.abc.Hashable):
     """A mixin that defines equality as equal if the json deserializations are equal."""
 
+    fingerprint: Callable[..., bytes]
+
     def __eq__(self, that: object) -> bool:
         try:
             that_obj = json.loads(str(that))
@@ -207,13 +210,29 @@ def __eq__(self, that: object) -> bool:
             return False
         return cast(bool, json.loads(str(self)) == that_obj)
 
+    def __hash__(self) -> int:
+        """Hash by fingerprint so a schema can be a set member or a dictionary key.
+
+        NB: A class that defines __eq__ without __hash__ gets __hash__ = None, so both must be defined in the same class.
+        """
+        return hash(self.fingerprint())
+
 
-class EqualByPropsMixin(PropertiesMixin):
+class EqualByPropsMixin(collections.abc.Hashable, PropertiesMixin):
     """A mixin that defines equality as equal if the props are equal."""
 
+    fingerprint: Callable[..., bytes]
+
     def __eq__(self, that: object) -> bool:
         return hasattr(that, "props") and self.props == getattr(that, "props")
 
+    def __hash__(self) -> int:
+        """Hash by fingerprint so a schema can be a set member or a dictionary key.
+
+        NB: A class that defines __eq__ without __hash__ gets __hash__ = None, so both must be defined in the same class.
+        """
+        return hash(self.fingerprint())
+
 
 class CanonicalPropertiesMixin(PropertiesMixin):
     """A Mixin that provides canonical properties to Schema and Field types."""
diff --git a/lang/py/avro/test/test_schema.py b/lang/py/avro/test/test_schema.py
index 668ca8258f2..6423185eff9 100644
--- a/lang/py/avro/test/test_schema.py
+++ b/lang/py/avro/test/test_schema.py
@@ -890,6 +890,31 @@ def parse_invalid(self):
             self.test_schema.parse()
 
 
+class HashableTestCase(unittest.TestCase):
+    """Ensure that schemas are hashable.
+
+    While hashability is implemented with parsing canonical form fingerprinting,
+    this test should be kept distinct to avoid coupling."""
+
+    def __init__(self, test_schema):
+        """Ignore the normal signature for unittest.TestCase because we are generating
+        many test cases from this one class. This is safe as long as the autoloader
+        ignores this class. The autoloader will ignore this class as long as it has
+        no methods starting with `test_`.
+        """
+        super().__init__("parse_and_hash")
+        self.test_schema = test_schema
+
+    def parse_and_hash(self):
+        """Ensure that every schema can be hashed."""
+        try:
+            hash(self.test_schema.parse())
+        except TypeError as e:
+            if "unhashable type" in str(e):
+                self.fail(f"{self.test_schema} is not hashable")
+            raise
+
+
 class RoundTripParseTestCase(unittest.TestCase):
     """Enable generating round-trip parse test cases over all the valid test schema."""
 
@@ -1434,4 +1459,5 @@ def load_tests(loader, default_tests, pattern):
     suite.addTests(OtherAttributesTestCase(ex) for ex in OTHER_PROP_EXAMPLES)
     suite.addTests(loader.loadTestsFromTestCase(CanonicalFormTestCase))
     suite.addTests(FingerprintTestCase(ex[0], ex[1]) for ex in FINGERPRINT_EXAMPLES)
+    suite.addTests(HashableTestCase(ex) for ex in VALID_EXAMPLES)
     return suite