Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Jan 8, 2025
1 parent 4393bb8 commit ea86a4b
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 5 deletions.
7 changes: 3 additions & 4 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def new_func(self, *others, **kwargs):
if other is None:
continue
if (isinstance(other, torch.Tensor) and other.ndim) or (
is_tensor_collection(other)
_is_tensor_collection(type(other))
and other.ndim
and other.shape != self.shape
):
Expand All @@ -228,9 +228,8 @@ def new_func(self, *others, **kwargs):
others_map = []
shape = self.shape
self_expand = self
shape = torch.broadcast_shapes(
shape, *[other.shape for other in others if other is not None]
)
shapes = [shape, *[other.shape for other in others if other is not None]]
shape = torch.broadcast_shapes(*shapes)
if shape != self_expand.shape:
self_expand = self_expand.expand(shape)
for other in others:
Expand Down
2 changes: 1 addition & 1 deletion test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -7184,7 +7184,7 @@ def test_unflatten_keys(self, td_name, device, inplace, separator):
td_unflatten = td_flatten.unflatten_keys(
inplace=inplace, separator=separator
)
assert (td == td_unflatten).all()
assert (td == td.empty(recurse=True).update(td_unflatten)).all()
if inplace:
assert td is td_unflatten

Expand Down

0 comments on commit ea86a4b

Please sign in to comment.