Skip to content

Commit

Permalink
[BugFix] Fix pickling of weakrefs (#597)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Dec 13, 2023
1 parent e3353f1 commit 4071e30
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 6 deletions.
11 changes: 10 additions & 1 deletion tensordict/_td.py
Original file line number Diff line number Diff line change
Expand Up @@ -1758,7 +1758,8 @@ def __getstate__(self):
result = {
key: val
for key, val in self.__dict__.items()
if key not in ("_last_op", "_cache", "__last_op_queue")
if key
not in ("_last_op", "_cache", "__last_op_queue", "__lock_parents_weakrefs")
}
return result

Expand All @@ -1768,6 +1769,14 @@ def __setstate__(self, state):
self._cache = None
self.__last_op_queue = None
self._last_op = None
if self._is_locked:
# this can cause avoidable overhead, as we will be locking the leaves
# then locking their parent, and the parent of the parent, every
# time re-locking tensordicts that have already been locked.
# To avoid this, we should lock only at the root, but it isn't easy
# to spot what the root is...
self._is_locked = False
self.lock_()

# some custom methods for efficiency
def items(
Expand Down
20 changes: 16 additions & 4 deletions tensordict/persistent.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ class PersistentTensorDict(TensorDictBase):
"""

_td_dim_names = None
LOCKING = None

def __init__(
self,
Expand All @@ -147,7 +148,7 @@ def __init__(
if backend != "h5":
raise NotImplementedError
if filename is not None and group is None:
self.file = h5py.File(filename, mode)
self.file = h5py.File(filename, mode, locking=self.LOCKING)
elif group is not None:
self.file = group
else:
Expand Down Expand Up @@ -202,7 +203,7 @@ def from_dict(cls, input_dict, filename, batch_size=None, device=None, **kwargs)
A :class:`PersistentTensorDict` instance linked to the newly created file.
"""
file = h5py.File(filename, "w")
file = h5py.File(filename, "w", locking=cls.LOCKING)
_has_batch_size = True
if batch_size is None:
if is_tensor_collection(input_dict):
Expand Down Expand Up @@ -931,7 +932,7 @@ def clone(self, recurse: bool = True, newfile=None) -> PersistentTensorDict:
)
tmpfile = tempfile.NamedTemporaryFile()
newfile = tmpfile.name
f_dest = h5py.File(newfile, "w")
f_dest = h5py.File(newfile, "w", locking=self.LOCKING)
f_src = self.file
for key in self.keys(include_nested=True, leaves_only=True):
key = self._process_key(key)
Expand Down Expand Up @@ -974,14 +975,25 @@ def __getstate__(self):
state["file"] = None
state["filename"] = filename
state["group_name"] = group_name
state["__lock_parents_weakrefs"] = None
return state

def __setstate__(self, state):
state["file"] = h5py.File(state["filename"], mode=state["mode"])
state["file"] = h5py.File(
state["filename"], mode=state["mode"], locking=self.LOCKING
)
if state["group_name"] != "/":
state["file"] = state["file"][state["group_name"]]
del state["group_name"]
self.__dict__.update(state)
if self._is_locked:
# this can cause avoidable overhead, as we will be locking the leaves
# then locking their parent, and the parent of the parent, every
# time re-locking tensordicts that have already been locked.
# To avoid this, we should lock only at the root, but it isn't easy
# to spot what the root is...
self._is_locked = False
self.lock_()

def _add_batch_dim(self, *, in_dim, vmap_level):
raise RuntimeError("Persistent tensordicts cannot be used with vmap.")
Expand Down
3 changes: 2 additions & 1 deletion test/_utils_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,8 @@ def nested_stacked_td(self, device):
batch_size=[4, 3, 2, 1],
device=device,
)
return torch.stack(list(td.unbind(1)), 1)
# we need to clone to avoid passing views of other tensors
return torch.stack([_td.clone() for _td in td.unbind(1)], 1)

for device in get_available_devices():
TYPES_DEVICES += [["nested_stacked_td", device]]
Expand Down
28 changes: 28 additions & 0 deletions test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -6428,6 +6428,34 @@ def test_dense_stack_tds(stack_dim, nested_stack_dim):
TestTensorDictsBase.TYPES_DEVICES,
)
class TestTensorDictMP(TestTensorDictsBase):
# Tests sharing a locked tensordict
@staticmethod
def worker_lock(td, q):
    """Child-process check that a locked tensordict is still locked on arrival.

    Asserts that ``td`` reports ``is_locked``, that every tensor-collection
    value yielded by ``td.values(True)`` is locked and has a truthy
    ``_lock_parents_weakrefs``, and that ``td`` itself has a falsy
    ``_lock_parents_weakrefs``. On success the string ``"succeeded"`` is put
    on ``q`` so the parent process can observe the outcome.

    Args:
        td: the tensordict received by the worker.
        q: queue used to report success back to the caller.
    """
    assert td.is_locked
    # NOTE(review): the positional ``True`` is presumably ``include_nested``
    # -- confirm against the ``values`` signature.
    for entry in td.values(True):
        if not is_tensor_collection(entry):
            continue
        assert entry.is_locked
        assert entry._lock_parents_weakrefs
    assert not td._lock_parents_weakrefs
    q.put("succeeded")

def test_sharing_locked_td(self, td_name, device):
    """Send a locked tensordict to a subprocess and check it is still locked.

    The tensordict named ``td_name`` is built on ``device``, locked with
    ``lock_()`` and passed to :meth:`worker_lock` running in a separate
    process. The worker reports ``"succeeded"`` through a queue once all of
    its assertions passed.

    Args:
        td_name: name of the method on ``self`` that builds the tensordict.
        device: device passed to that builder method.
    """
    # Decide whether to skip before building the tensordict, so skipped
    # cases do not construct it for nothing.
    if td_name in ("sub_td", "sub_td2"):
        pytest.skip("cannot lock sub-tds")
    if td_name in ("td_h5",):
        pytest.skip("h5 files should not be opened across different processes.")
    td = getattr(self, td_name)(device)
    q = mp.Queue(1)
    # Create and start the process *outside* the try block: if either step
    # fails, the original error propagates as-is instead of being masked by
    # a failure in the cleanup code (unbound ``p`` or joining a process that
    # never started).
    p = mp.Process(target=self.worker_lock, args=(td.lock_(), q))
    p.start()
    try:
        assert q.get(timeout=30) == "succeeded"
    finally:
        # Bound the wait so a hung worker cannot block the test run forever;
        # terminate it if it is still alive after the grace period.
        p.join(timeout=30)
        if p.is_alive():
            p.terminate()
            p.join()

@staticmethod
def add1(x):
return x + 1
Expand Down

0 comments on commit 4071e30

Please sign in to comment.