[BugFix] Proper masks for padding with custom pad value
ghstack-source-id: 0580f89ce9bbaf5a13bab33f9c9b8f5a9e9df96f
Pull Request resolved: #1185
vmoens committed Jan 15, 2025
1 parent 0013e38 commit bbf773b
Showing 2 changed files with 21 additions and 11 deletions.
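
pad_sequence stacks a list of TensorDicts into one TensorDict, padding each entry along pad_dim up to the longest sequence; with return_mask=True it also writes boolean masks marking which entries are real data. As the diff below shows, the masks were previously padded with the same user-supplied padding_value as the data; this commit pads mask entries with False regardless of padding_value. A minimal sketch of the intended behavior (not part of the commit; it assumes the public tensordict.pad_sequence API and the default "masks" key):

import torch
from tensordict import TensorDict, pad_sequence

# Two tensordicts of different lengths along dim 0.
list_td = [
    TensorDict({"a": torch.ones(2)}, [2]),
    TensorDict({"a": torch.ones(4)}, [4]),
]

# Pad with a custom (nonzero) value and request masks.
padded = pad_sequence(list_td, pad_dim=0, return_mask=True, padding_value=-1)

print(padded["a"])           # padded entries hold -1
print(padded["masks", "a"])  # boolean mask: True for real data, False for padding
assert padded["masks", "a"].dtype == torch.bool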
tensordict/functional.py: 5 additions & 2 deletions
@@ -156,17 +156,19 @@ def pad_sequence(
             "plase convert the tensorclasses to TensorDicts first."
         )
 
-    masks_key = "masks"
     if not isinstance(return_mask, bool):
         masks_key = unravel_key(return_mask)
         return_mask = True
+    else:
+        masks_key = "masks"
 
     # check that all tensordict match
     update_batch_size = True
     max_seq_length = float("-inf")
     keys = _check_keys(list_of_tensordicts, leaves_only=True, include_nested=True)
     list_of_dicts = [{} for _ in range(len(list_of_tensordicts))]
     keys_copy = list(keys)
+    mask_keys = []
     for i, td in enumerate(list_of_tensordicts):
         if is_tensorclass(td):
             td = td._tensordict
@@ -197,6 +199,7 @@ def pad_sequence(
 
             if return_mask:
                 mask_key = unravel_key((masks_key, key))
+                mask_keys.append(mask_key)
                 list_of_dicts[i][mask_key] = torch.ones(mask_shape, dtype=torch.bool)
                 keys_copy.append(mask_key)
 
@@ -229,7 +232,7 @@ def pad_sequence(
             torch.nn.utils.rnn.pad_sequence(
                 [d[key].transpose(0, pos_pad_dim) for d in list_of_dicts],
                 batch_first=True,
-                padding_value=padding_value,
+                padding_value=padding_value if key not in mask_keys else False,
             ).transpose(1, pos_pad_dim + 1),
             inplace=True,
         )
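The key change is the new mask_keys list: when the entry being padded is one of the generated boolean masks, the call to torch.nn.utils.rnn.pad_sequence now receives padding_value=False instead of the user-supplied value. A standalone sketch of why this matters (assumes only torch; the mask tensors here are stand-ins for the ones built with torch.ones above):

import torch
from torch.nn.utils.rnn import pad_sequence as rnn_pad

# Stand-ins for the per-tensordict boolean masks.
masks = [torch.ones(2, dtype=torch.bool), torch.ones(4, dtype=torch.bool)]

# Old behavior: the user-supplied padding_value leaked into the masks.
# A truthy value such as 1 marks padded steps as valid data.
leaky = rnn_pad(masks, batch_first=True, padding_value=1)
# tensor([[True, True, True,  True],   <- last two should be False
#         [True, True, True,  True]])

# Fixed behavior: masks are always padded with False.
fixed = rnn_pad(masks, batch_first=True, padding_value=False)
# tensor([[True, True, False, False],
#         [True, True, True,  True]])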
test/test_tensordict.py: 16 additions & 9 deletions
@@ -1942,7 +1942,8 @@ class Sample:
         assert d.b == ["asd", "efg"]
 
     @pytest.mark.parametrize("make_mask", [True, ("bibbidi", "bobbidi", "boo"), False])
-    def test_pad_sequence_pad_dim0(self, make_mask):
+    @pytest.mark.parametrize("pad_val", [0, -1])
+    def test_pad_sequence_pad_dim0(self, make_mask, pad_val):
         pad_dim = 0
         list_td = [
             TensorDict(
@@ -1953,7 +1954,9 @@ def test_pad_sequence_pad_dim0(self, make_mask):
                 [4],
             ),
         ]
-        padded_td = pad_sequence(list_td, pad_dim=pad_dim, return_mask=make_mask)
+        padded_td = pad_sequence(
+            list_td, pad_dim=pad_dim, return_mask=make_mask, padding_value=pad_val
+        )
         assert padded_td.shape == torch.Size(
             [2, 4]
         )  # check the shape of the padded tensordict
@@ -1966,30 +1969,34 @@ def test_pad_sequence_pad_dim0(self, make_mask):
         assert padded_td["a"].shape == torch.Size(
             [2, 4, 8, 8]
         )  # check the shape of the padded tensor
-        assert torch.all(padded_td["a"][0, 2:, :, :] == 0)  # check the padding
+        assert torch.all(padded_td["a"][0, 2:, :, :] == pad_val)  # check the padding
         assert padded_td["b", "c"].shape == torch.Size(
             [2, 4, 3]
         )  # check the shape of the padded tensor
-        assert torch.all(padded_td["b", "c"][0, 2:, :] == 0)  # check the padding
+        assert torch.all(padded_td["b", "c"][0, 2:, :] == pad_val)  # check the padding
         if make_mask:
             masks_key = "masks"
             if not isinstance(make_mask, bool):
                 masks_key = make_mask
             padded_td_without_masks = pad_sequence(
-                list_td, pad_dim=pad_dim, return_mask=False
+                list_td, pad_dim=pad_dim, return_mask=False, padding_value=pad_val
             )
             assert masks_key in padded_td.keys(True)
             assert set(
                 padded_td_without_masks.keys(include_nested=True, leaves_only=True)
             ) == set(padded_td[masks_key].keys(include_nested=True, leaves_only=True))
             assert not padded_td[masks_key, "a"].all()
             assert padded_td[masks_key, "a"].ndim == pad_dim + 2
-            assert (padded_td["a"][padded_td[masks_key, "a"]] != 0).all()
-            assert (padded_td["a"][~padded_td[masks_key, "a"]] == 0).all()
+            assert (padded_td["a"][padded_td[masks_key, "a"]] != pad_val).all()
+            assert (padded_td["a"][~padded_td[masks_key, "a"]] == pad_val).all()
             assert not padded_td[masks_key, "b", "c"].all()
             assert padded_td[masks_key, "b", "c"].ndim == pad_dim + 2
-            assert (padded_td["b", "c"][padded_td[masks_key, "b", "c"]] != 0).all()
-            assert (padded_td["b", "c"][~padded_td[masks_key, "b", "c"]] == 0).all()
+            assert (
+                padded_td["b", "c"][padded_td[masks_key, "b", "c"]] != pad_val
+            ).all()
+            assert (
+                padded_td["b", "c"][~padded_td[masks_key, "b", "c"]] == pad_val
+            ).all()
         else:
             assert "masks" not in padded_td.keys()

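The test now runs with pad_val in {0, -1}, checking both that data entries take the custom pad value and that the masks select exactly the non-padded entries. It also still covers the variant where return_mask is itself a (nested) key under which the masks are stored; a hedged sketch of that form, reusing list_td from the sketch above:

# return_mask may be a key instead of a bool; the masks then live under it
# (here the nested key from the test's parametrization).
padded = pad_sequence(
    list_td, pad_dim=0, return_mask=("bibbidi", "bobbidi", "boo"), padding_value=-1
)
assert padded["bibbidi", "bobbidi", "boo", "a"].dtype == torch.bool
assert not padded["bibbidi", "bobbidi", "boo", "a"].all()  # padded steps are False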
