diff --git a/pytato/array.py b/pytato/array.py index 75fb16e0e..196f49257 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -344,35 +344,65 @@ def _augment_array_dataclass( cls: type, generate_hash: bool, ) -> None: - from dataclasses import fields - attr_tuple = ", ".join(f"self.{fld.name}" - for fld in fields(cls) if fld.name != "non_equality_tags") - if attr_tuple: - attr_tuple = f"({attr_tuple},)" - else: - attr_tuple = "()" + + # {{{ hashing and hash caching if generate_hash: + from dataclasses import fields + + # Non-equality tags are automatically excluded from equality in + # EqualityComparer, and are excluded here from hashing. + attr_tuple_hash = ", ".join(f"self.{fld.name}" + for fld in fields(cls) if fld.name != "non_equality_tags") + + if attr_tuple_hash: + attr_tuple_hash = f"({attr_tuple_hash},)" + else: + attr_tuple_hash = "()" + from pytools.codegen import remove_common_indentation augment_code = remove_common_indentation( f""" + from dataclasses import fields + def {cls.__name__}_hash(self): try: return self._hash_value except AttributeError: pass - h = hash(frozenset({attr_tuple})) + h = hash(frozenset({attr_tuple_hash})) object.__setattr__(self, "_hash_value", h) return h cls.__hash__ = {cls.__name__}_hash + + # By default (when slots=False), dataclasses do not have special + # handling for pickling, thus using pickle's default behavior that + # looks at obj.__dict__. This would also pickle the cached hash, + # which may change across invocations. Here, we override the + # pickling methods such that only fields are pickled. + # See also https://github.com/python/cpython/blob/5468d219df65d4fe3335e2bcc09d2f6032a32c70/Lib/dataclasses.py#L1267-L1272 + + def _dataclass_getstate(self): + return [getattr(self, f.name) for f in fields(self)] + + + def _dataclass_setstate(self, state): + for field, value in zip(fields(self), state, strict=True): + # use setattr because dataclass may be frozen + object.__setattr__(self, field.name, value) + + cls.__getstate__ = _dataclass_getstate + cls.__setstate__ = _dataclass_setstate """) exec_dict = {"cls": cls, "_MODULE_SOURCE_CODE": augment_code} exec(compile(augment_code, f"", "exec"), exec_dict) + # }}} + # {{{ assign mapper_method mm_cls = cast(type[_HasMapperMethod], cls) diff --git a/test/test_pytato.py b/test/test_pytato.py index 46c794d3e..4a29ae13c 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -1379,6 +1379,53 @@ def dtype(self): assert _np_result_dtype(42.0, NotReallyAnArray()) == np.float64 +def test_pickling_hash(): + # See https://github.com/inducer/pytato/pull/563 for context + + # {{{ Placeholder + + p = pt.make_placeholder("p", (4, 4), int) + + assert not hasattr(p, "_hash_value") + + # Force hash creation: + hash(p) + + assert hasattr(p, "_hash_value") + + from pickle import dumps, loads + + p_new = loads(dumps(p)) + + assert not hasattr(p_new, "_hash_value") + + assert p == p_new + + # }}} + + # {{{ DataWrapper + + dw = pt.make_data_wrapper(np.zeros((4, 4), int)) + + assert not hasattr(dw, "_hash_value") + + hash(dw) + + # DataWrappers have no hash caching + assert not hasattr(dw, "_hash_value") + + dw_new = loads(dumps(dw)) + + assert dw_new.shape == dw.shape + assert dw_new.dtype == dw.dtype + assert np.all(dw_new.data == dw.data) + + # DataWrappers that are not the same object compare unequal + assert dw != dw_new + + # }}} + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1])