Commit

init
Daniel Mann committed May 23, 2024
1 parent 1264482 commit 55a6fd6
Showing 1 changed file with 54 additions and 0 deletions.
54 changes: 54 additions & 0 deletions tests/test_masking.py
@@ -0,0 +1,54 @@
import torch
from torch import nn

from i6_models.parts.frontend.common import mask_pool, get_same_padding, apply_same_padding

def test_masking():
    T = 2
    kernel_size = 3
    stride = 2
    padding = get_same_padding(kernel_size)

    # tensor as batch with different sequence lengths:
    # the sequence at index t - 1 has length t, for t = 1 .. batch_T
    batch_T = 100
    sequence_mask = torch.tensor(
        [[True] * t + [False] * (batch_T - t) for t in range(1, batch_T + 1)]
    )

    out_sequence_mask = mask_pool(
        sequence_mask,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
    )

    # what a convolution does to one sequence of length T in [B, F, T] format
    B, F = (1, 1)
    x = torch.ones((B, F, T))
    conv = nn.Conv1d(
        in_channels=F,
        out_channels=F,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
    )
    out = conv(x)
    out_len = out.shape[-1]

    # pooled mask for the sequence of length T within the batch
    idx = T - 1  # the sequence at this index has length T by the construction above
    out_mask = out_sequence_mask[idx, :]

    # we expect True for the length of the sequence and False otherwise
    mask_len = len(torch.where(out_mask)[0])
    assert out_len == mask_len, (
        f"Actual out length of the sequence {out_len=}"
        f" and the length of the mask {mask_len=} are not equal where"
        f" {out=} and {out_mask=}."
    )
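
For context on the final assertion: with "same" padding, the Conv1d output length follows the standard formula (T + 2 * padding - kernel_size) // stride + 1, and the pooled mask is expected to contain exactly that many True entries. A minimal sketch of this arithmetic, assuming get_same_padding(kernel_size) returns (kernel_size - 1) // 2:

def expected_out_len(T: int, kernel_size: int, stride: int) -> int:
    # standard Conv1d output-length formula
    padding = (kernel_size - 1) // 2  # assumed behaviour of get_same_padding
    return (T + 2 * padding - kernel_size) // stride + 1

# with the values used in the test, a single frame survives the pooling
assert expected_out_len(T=2, kernel_size=3, stride=2) == 1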
