From 9c15fb5925b389886b187ac9d3e6d29c9760af45 Mon Sep 17 00:00:00 2001 From: Judyxujj Date: Thu, 23 May 2024 13:15:28 +0200 Subject: [PATCH] Update i6_models/primitives/mixup.py Co-authored-by: Albert Zeyer --- i6_models/primitives/mixup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/i6_models/primitives/mixup.py b/i6_models/primitives/mixup.py index 9fbeb441..bfbdb42f 100644 --- a/i6_models/primitives/mixup.py +++ b/i6_models/primitives/mixup.py @@ -77,8 +77,8 @@ def get_random(self, b_dim: int, t_dim: int, max_num_mixup: int, n_mask: torch.t start_indicies, n_mask ) # [B, M'] (M' denotes sum of num_mixup over the batch) - idx = torch.arange(t_dim) - idx = torch.unsqueeze(idx, dim=-1) + start_indicies_flat # [T, M'] + idx = torch.arange(t_dim) # [T] + idx = idx[:, None] + start_indicies_flat[None, :] # [T, M'] mixup_values = self.cache[idx] # [T, M', F] return mixup_values