Skip to content

Commit

Permalink
[BugFix] Fix keys for nested lazy stacks (#745)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Apr 23, 2024
1 parent cf6f133 commit 357a981
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 29 deletions.
44 changes: 17 additions & 27 deletions tensordict/_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -900,6 +900,20 @@ def _stack_onto_(
self.update_at_(td, idx)
return self

def _maybe_get_list(self, key):
vals = []
for td in self.tensordicts:
if isinstance(td, LazyStackedTensorDict):
val = td._maybe_get_list(key)
else:
val = td._get_str(key, None)
if _is_tensor_collection(type(val)):
return self._get_str(key, NO_DEFAULT)
elif val is None:
return None
vals.append(val)
return vals

@cache # noqa: B019
def _get_str(
self,
Expand Down Expand Up @@ -3400,34 +3414,10 @@ def names(self, value):
def _iter_items_lazystack(
tensordict: LazyStackedTensorDict, return_none_for_het_values: bool = False
) -> Iterator[tuple[str, CompatibleType]]:
# for key in tensordict.keys():
# try:
# value = tensordict.get(key)
# except RuntimeError as err:
# if return_none_for_het_values and re.match(
# r"Found more than one unique shape in the tensors", str(err)
# ):
# # this is a het key
# value = None
# else:
# raise err
# yield key, value
for key in tensordict.tensordicts[0].keys():
shapes = set()
values = []
is_tc = None
for td in tensordict.tensordicts:
val = td._get_str(key, None)
val_shape = getattr(val, "shape", None)
shapes.add(val_shape)
if is_tc is None:
is_tc = _is_tensor_collection(type(val))
values.append(val)
if None not in shapes:
if not is_tc:
yield key, values
else:
yield key, tensordict._get_str(key, NO_DEFAULT)
values = tensordict._maybe_get_list(key)
if values is not None:
yield key, values


_register_tensor_class(LazyStackedTensorDict)
Expand Down
6 changes: 4 additions & 2 deletions tensordict/_td.py
Original file line number Diff line number Diff line change
Expand Up @@ -3382,8 +3382,10 @@ def _iter_helper(
for key, value in self._items(tensordict):
full_key = self._combine_keys(prefix, key)
cls = value.__class__
if cls is list:
cls = value[0].__class__
while cls is list:
# For lazy stacks
value = value[0]
cls = value.__class__
is_leaf = self.is_leaf(cls)
if self.include_nested and not is_leaf:
yield from self._iter_helper(value, prefix=full_key)
Expand Down

0 comments on commit 357a981

Please sign in to comment.