[fx graph cache] FxGraphPickler: Remove hack to stabilize device string hashes (pytorch#138681)

Summary: With the fast pickling mode, we don't need the custom hack for replacing device strings in tensors. This was previously needed because, e.g., two "cuda" strings pickle differently depending on whether they are the same object or merely equal.
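For illustration, the underlying pickle behavior can be reproduced with the standard library alone. This is a minimal sketch (the `dumps` helper below is illustrative, not the FxGraphCachePickler API): the default pickler memoizes objects, so the byte stream depends on object identity, while fast mode disables the memo so equal strings always serialize to equal bytes.

```python
import io
import pickle

def dumps(obj: object, fast: bool) -> bytes:
    """Pickle obj, optionally disabling memoization via fast mode."""
    buf = io.BytesIO()
    pickler = pickle.Pickler(buf)
    pickler.fast = fast  # fast mode disables the pickler's memo table
    pickler.dump(obj)
    return buf.getvalue()

s1 = "string"
s2 = "strin"
s2 += "g"  # equal to s1, but a distinct object
assert s1 == s2 and id(s1) != id(s2)

# Default mode: [s1, s1] emits a memo reference for the repeated object,
# so the bytes differ from [s1, s2] even though the lists compare equal.
assert dumps([s1, s1], fast=False) != dumps([s1, s2], fast=False)

# Fast mode: no memo, so equal contents always produce equal bytes.
assert dumps([s1, s1], fast=True) == dumps([s1, s2], fast=True)
```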

Test Plan:
The new test fails with fast mode commented out, but succeeds when enabled:
`python test/inductor/test_codecache.py -k test_stable_strings`

Pull Request resolved: pytorch#138681
Approved by: https://github.com/oulgen
masnesral authored and pytorchmergebot committed Oct 28, 2024
1 parent 3b0f393 commit ad93357
Showing 3 changed files with 26 additions and 35 deletions.
test/inductor/test_codecache.py (16 additions, 0 deletions)

```diff
@@ -835,6 +835,22 @@ def uuid(self) -> Optional[Union[bytes, str]]:
             FxGraphCachePickler.dumps(details3),
         )
 
+    def test_stable_strings(self):
+        """
+        Test that objects containing identical strings pickle the same
+        even if they are not the same id.
+        """
+        s1 = "string"
+        s2 = "strin"
+        s2 += "g"
+
+        self.assertNotEqual(id(s1), id(s2))
+
+        self.assertEqual(
+            FxGraphCachePickler.dumps([s1, s1]),
+            FxGraphCachePickler.dumps([s1, s2]),
+        )
+
     def test_get_hash_for_files(self):
         """
         Test the get_hash_for_files helper.
```
torch/_functorch/_aot_autograd/autograd_cache.py (2 additions, 8 deletions)

```diff
@@ -251,14 +251,8 @@ def _reduce_tensor(tensor):
     """
     Reduce the tensor to a stable key for caching.
     """
-    return (
-        _ident,
-        (
-            extract_tensor_metadata_for_cache_key(
-                FxGraphCachePickler._device_map, tensor
-            ),
-        ),
-    )
+    metadata = extract_tensor_metadata_for_cache_key(tensor)
+    return (_ident, (metadata,))
 
 
 class AOTAutogradCachePickler(FxGraphCachePickler):
```
torch/_inductor/codecache.py (8 additions, 27 deletions)

```diff
@@ -506,9 +506,7 @@ def _ident(x: T) -> T:
     return x
 
 
-def extract_tensor_metadata_for_cache_key(
-    device_map: Dict[torch.device, torch.device], t: Tensor
-) -> TensorMetadata:
+def extract_tensor_metadata_for_cache_key(t: Tensor) -> TensorMetadata:
     """
     Extracts the tensor metadata and removes fields of the TensorMetadata
     that are not needed for caching
@@ -517,32 +515,19 @@
     if not hasattr(t, "_is_inductor_static"):
         meta = dataclasses.replace(meta, storage_offset=0, storage_bytes=None)
 
-    # The pickle implementation avoids serializing the same object more than once.
-    # That behavior means the byte stream we create to hash will vary if, for example,
-    # we see two tensor objects with the same device, but the torch.device object is
-    # actually the same object vs. merely equivalent. We want to produce the same hash
-    # value in either situation, so we memoize the device objects and always reference
-    # the same object for a given device. It's possible other metadata fields deserve
-    # the same treatment, but so far we've only observed this issue with the device.
-    if meta.device not in device_map:
-        device_map[meta.device] = meta.device
-    meta = dataclasses.replace(meta, device=device_map[meta.device])
-
     return meta
 
 
-def _reduce_fake_tensor(
-    device_map: Dict[torch.device, torch.device], t: Tensor
-) -> Tuple[Callable[[T], T], Tuple[TensorMetadata]]:
+def _reduce_fake_tensor(t: Tensor) -> Tuple[Callable[[T], T], Tuple[TensorMetadata]]:
     """
     See FxGraphCachePickler. Custom reducer to pickle FakeTensors.
     """
-    metadata = extract_tensor_metadata_for_cache_key(device_map, t)
+    metadata = extract_tensor_metadata_for_cache_key(t)
     return (_ident, (metadata,))
 
 
 def _reduce_tensor(
-    device_map: Dict[torch.device, torch.device], t: Tensor
+    t: Tensor,
 ) -> Tuple[Callable[[T], T], Tuple[TensorMetadataAndValues]]:
     """
     See FxGraphCachePickler. Custom reducer to pickle Tensors.
@@ -570,7 +555,7 @@ def _reduce_tensor(
             f"FX graph cache handling of a large constant took {elapsed:.1}s. Please file an issue."
         )
 
-    metadata = extract_tensor_metadata_for_cache_key(device_map, t)
+    metadata = extract_tensor_metadata_for_cache_key(t)
     return (_ident, (TensorMetadataAndValues(metadata, values),))
@@ -600,13 +585,9 @@ class FxGraphCachePickler(pickle.Pickler):
     data that allow us to compute a stable, but safe hash.
     """
 
-    # See extract_tensor_metadata_for_cache_key. Whenever we extract metadata during
-    # pickling, we make sure devices always reference the same torch.device object.
-    _device_map: Dict[torch.device, torch.device] = {}
-
     dispatch_table = copyreg.dispatch_table.copy()
-    dispatch_table[FakeTensor] = functools.partial(_reduce_fake_tensor, _device_map)
-    dispatch_table[torch.Tensor] = functools.partial(_reduce_tensor, _device_map)
+    dispatch_table[FakeTensor] = _reduce_fake_tensor
+    dispatch_table[torch.Tensor] = _reduce_tensor
     dispatch_table[torch.SymInt] = _reduce_symint
     dispatch_table[
         torch.fx.experimental._backward_state.BackwardState
@@ -648,7 +629,7 @@ def debug_lines(cls, inp: FxGraphHashDetails) -> List[str]:
 
         def get_str(obj: Any) -> str:
             if isinstance(obj, torch.Tensor):
-                return str(extract_tensor_metadata_for_cache_key(cls._device_map, obj))
+                return str(extract_tensor_metadata_for_cache_key(obj))
             elif isinstance(obj, bytes):
                 return "<bytes>"
             elif type(obj) in cls.dispatch_table:
```
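As background on the dispatch-table changes above, here is a minimal, self-contained sketch of the pattern: a `pickle.Pickler` subclass with a class-level `dispatch_table` mapping a type to a custom reducer. Once the reducers take only the object being pickled, the `functools.partial` indirection is unnecessary. All names here (`Point`, `_reduce_point`, `StablePickler`) are illustrative, not part of the PyTorch codebase.

```python
import copyreg
import io
import pickle
from typing import Any, Callable, Tuple

class Point:
    """Stand-in for a type (like FakeTensor) that needs a custom reducer."""
    def __init__(self, x: int, y: int) -> None:
        self.x, self.y = x, y

def _ident(x: Any) -> Any:
    return x

def _reduce_point(p: Point) -> Tuple[Callable[[Any], Any], Tuple[Any]]:
    # Reduce to stable metadata only; object identity and other volatile
    # state never reach the byte stream used for hashing.
    return (_ident, ((p.x, p.y),))

class StablePickler(pickle.Pickler):
    # Class-level dispatch table: no functools.partial needed once the
    # reducer takes only the object being pickled.
    dispatch_table = copyreg.dispatch_table.copy()
    dispatch_table[Point] = _reduce_point

    @classmethod
    def dumps(cls, obj: Any) -> bytes:
        with io.BytesIO() as stream:
            pickler = cls(stream)
            pickler.fast = True  # disable memoization for identity-stable bytes
            pickler.dump(obj)
            return stream.getvalue()

# Equal points pickle to identical bytes regardless of object identity.
assert StablePickler.dumps(Point(1, 2)) == StablePickler.dumps(Point(1, 2))
```

The `fast = True` setting mirrors the fast pickling mode referenced in the summary: it trades memoization (and with it, support for self-referential objects) for byte streams that do not depend on object identity.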
