diff --git a/test/inductor/test_codecache.py b/test/inductor/test_codecache.py index 70d1ae48f7cb09..b7b7f11ccea9bf 100644 --- a/test/inductor/test_codecache.py +++ b/test/inductor/test_codecache.py @@ -835,6 +835,22 @@ def uuid(self) -> Optional[Union[bytes, str]]: FxGraphCachePickler.dumps(details3), ) + def test_stable_strings(self): + """ + Test that objects containing identical strings pickle the same + even if they are not the same id. + """ + s1 = "string" + s2 = "strin" + s2 += "g" + + self.assertNotEqual(id(s1), id(s2)) + + self.assertEqual( + FxGraphCachePickler.dumps([s1, s1]), + FxGraphCachePickler.dumps([s1, s2]), + ) + def test_get_hash_for_files(self): """ Test the get_hash_for_files helper. diff --git a/torch/_functorch/_aot_autograd/autograd_cache.py b/torch/_functorch/_aot_autograd/autograd_cache.py index e5bc8bc95d3c6a..fe309bb96b2260 100644 --- a/torch/_functorch/_aot_autograd/autograd_cache.py +++ b/torch/_functorch/_aot_autograd/autograd_cache.py @@ -251,14 +251,8 @@ def _reduce_tensor(tensor): """ Reduce the tensor to a stable key for caching. """ - return ( - _ident, - ( - extract_tensor_metadata_for_cache_key( - FxGraphCachePickler._device_map, tensor - ), - ), - ) + metadata = extract_tensor_metadata_for_cache_key(tensor) + return (_ident, (metadata,)) class AOTAutogradCachePickler(FxGraphCachePickler): diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 5886ac3628d85a..e118fd20b800e5 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -506,9 +506,7 @@ def _ident(x: T) -> T: return x -def extract_tensor_metadata_for_cache_key( - device_map: Dict[torch.device, torch.device], t: Tensor -) -> TensorMetadata: +def extract_tensor_metadata_for_cache_key(t: Tensor) -> TensorMetadata: """ Extracts the tensor metadata and removes fields of the TensorMetadata that are not needed for caching @@ -517,32 +515,19 @@ def extract_tensor_metadata_for_cache_key( if not hasattr(t, "_is_inductor_static"): meta = dataclasses.replace(meta, storage_offset=0, storage_bytes=None) - # The pickle implementation avoids serializing the same object more than once. - # That behavior means the byte stream we create to hash will vary if, for example, - # we see two tensor objects with the same device, but the torch.device object is - # actually the same object vs. merely equivalent. We want to produce the same hash - # value in either situation, so we memoize the device objects and always reference - # the same object for a given device. It's possible other metadata fields deserve - # the same treatment, but so far we've only observed this issue with the device. - if meta.device not in device_map: - device_map[meta.device] = meta.device - meta = dataclasses.replace(meta, device=device_map[meta.device]) - return meta -def _reduce_fake_tensor( - device_map: Dict[torch.device, torch.device], t: Tensor -) -> Tuple[Callable[[T], T], Tuple[TensorMetadata]]: +def _reduce_fake_tensor(t: Tensor) -> Tuple[Callable[[T], T], Tuple[TensorMetadata]]: """ See FxGraphCachePickler. Custom reducer to pickle FakeTensors. """ - metadata = extract_tensor_metadata_for_cache_key(device_map, t) + metadata = extract_tensor_metadata_for_cache_key(t) return (_ident, (metadata,)) def _reduce_tensor( - device_map: Dict[torch.device, torch.device], t: Tensor + t: Tensor, ) -> Tuple[Callable[[T], T], Tuple[TensorMetadataAndValues]]: """ See FxGraphCachePickler. Custom reducer to pickle Tensors. @@ -570,7 +555,7 @@ def _reduce_tensor( f"FX graph cache handling of a large constant took {elapsed:.1}s. Please file an issue." ) - metadata = extract_tensor_metadata_for_cache_key(device_map, t) + metadata = extract_tensor_metadata_for_cache_key(t) return (_ident, (TensorMetadataAndValues(metadata, values),)) @@ -600,13 +585,9 @@ class FxGraphCachePickler(pickle.Pickler): data that allow us to compute a stable, but safe hash. """ - # See extract_tensor_metadata_for_cache_key. Whenever we extract metadata during - # pickling, we make sure devices always reference the same torch.device object. - _device_map: Dict[torch.device, torch.device] = {} - dispatch_table = copyreg.dispatch_table.copy() - dispatch_table[FakeTensor] = functools.partial(_reduce_fake_tensor, _device_map) - dispatch_table[torch.Tensor] = functools.partial(_reduce_tensor, _device_map) + dispatch_table[FakeTensor] = _reduce_fake_tensor + dispatch_table[torch.Tensor] = _reduce_tensor dispatch_table[torch.SymInt] = _reduce_symint dispatch_table[ torch.fx.experimental._backward_state.BackwardState @@ -648,7 +629,7 @@ def debug_lines(cls, inp: FxGraphHashDetails) -> List[str]: def get_str(obj: Any) -> str: if isinstance(obj, torch.Tensor): - return str(extract_tensor_metadata_for_cache_key(cls._device_map, obj)) + return str(extract_tensor_metadata_for_cache_key(obj)) elif isinstance(obj, bytes): return "" elif type(obj) in cls.dispatch_table: