From c5abce152eebf7cc55b3ea8c38a010823335a626 Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Fri, 18 Oct 2024 12:29:33 +0200 Subject: [PATCH] fix hash --- sisyphus/hash.py | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/sisyphus/hash.py b/sisyphus/hash.py index 1263304..33d5d91 100644 --- a/sisyphus/hash.py +++ b/sisyphus/hash.py @@ -42,6 +42,7 @@ def short_hash(obj, length=12, chars="0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdef _BasicSeqTypes: Tuple[type, ...] = (list, tuple) _BasicSetTypes: Tuple[type, ...] = (set, frozenset) _BasicDictTypes: Tuple[type, ...] = (dict,) +_BasicTypesCombined: Tuple[type, ...] = _BasicTypes + _BasicSeqTypes + _BasicSetTypes + _BasicDictTypes def get_object_state(obj): @@ -56,8 +57,16 @@ def get_object_state(obj): # so we keep consistent to the behavior of sis_hash_helper. if obj is None: return None - if isinstance(obj, _BasicTypes + _BasicSeqTypes + _BasicSetTypes + _BasicDictTypes): - return obj + if isinstance(obj, _BasicTypesCombined): + for type_ in _BasicTypesCombined: + if isinstance(obj, type_): + if type(obj) is type_: + return obj + else: + # This is a derived type. E.g. consider a namedtuple or np.float. + # We want to return the basic type, to break any potential recursion. + return type(obj) + assert False, f"should not get here, obj {obj!r} type {type(obj)!r}" if isfunction(obj) or isclass(obj): return obj.__module__, obj.__qualname__ @@ -118,13 +127,17 @@ def sis_hash_helper(obj): byte_list.append(obj) elif obj is None: pass - elif isinstance(obj, _BasicTypes): + # Note: Using type(obj) in _Types instead of isinstance(obj, _Types) + # because of historical reasons (and we cannot change this now). + # For derived types (e.g. namedtuple, np.float), it is then handled by get_object_state. + # That's why the handling of get_object_state for those types is important. + elif type(obj) in _BasicTypes: byte_list.append(repr(obj).encode()) - elif isinstance(obj, _BasicSeqTypes): + elif type(obj) in _BasicSeqTypes: byte_list += map(sis_hash_helper, obj) - elif isinstance(obj, _BasicSetTypes): + elif type(obj) in _BasicSetTypes: byte_list += sorted(map(sis_hash_helper, obj)) - elif isinstance(obj, _BasicDictTypes): + elif type(obj) in _BasicDictTypes: # sort items to ensure they are always in the same order byte_list += sorted(map(sis_hash_helper, obj.items())) elif isfunction(obj):