diff --git a/CHANGELOG.md b/CHANGELOG.md index bd685a51..d61be807 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - The `safe_run` method did unintentionally double-wrap the run method, if it already had a `make_action_safe` decorator. This is now fixed. +### Fixed +- Under certain conditions hashing of an object defined in the `__main__` module failed. + This release implements a workaround for this issue, that should hopefully resolve most cases. + ## [0.12.0] - 2022-11-15 ### Added diff --git a/tpcp/_hash.py b/tpcp/_hash.py index 941e648b..05101ecc 100644 --- a/tpcp/_hash.py +++ b/tpcp/_hash.py @@ -22,12 +22,42 @@ def memoize(self, obj): return def hash(self, obj, return_digest=True): + """Get hash while handling some edgecases. + + Namely, this implementation fixes the following issues: + + - Because we skip memoization, we need to handle the case where the object is self-referential. + We just catch the error and raise a more descriptive error message. + - We need to handle the case where the object is defined in the `__main__` module. + For some reason, this can lead to pickle issues. + Based on some information I found, this should not happen, but it still does... + To fix it, we detect, when an object is defined in `__main__` and temporarily add it to the "real" module + representing the main function. + Afterwards we do some cleanup. + Not sure if really required, but it seems to work. + Overall very hacky, but I don't see a better way to fix this. + + """ + modules_modified = [] + if getattr(obj, "__module__", None) == "__main__": + try: + name = obj.__name__ + except AttributeError: + name = obj.__class__.__name__ + mod = sys.modules["__main__"] + if not hasattr(mod, name): + modules_modified.append((mod, name)) + setattr(mod, name, obj) try: return super().hash(obj, return_digest) except RecursionError as e: raise ValueError( "The custom hasher used in tpcp does not support hashing of self-referential objects." ) from e + finally: + # Remove all new entries made to the main module. + for mod, name in modules_modified: + delattr(mod, name) class NoMemoizeNumpyHasher(NumpyHasher, NoMemoizeHasher):