diff --git a/lhotse/features/kaldi/layers.py b/lhotse/features/kaldi/layers.py index 4d974b804..e2b271556 100644 --- a/lhotse/features/kaldi/layers.py +++ b/lhotse/features/kaldi/layers.py @@ -776,7 +776,7 @@ def _get_strided_batch( if npad_right >= 0: pad_right = torch.flip(waveform[:, -npad_right:], (1,)) else: - pad_right = torch.zeros(0, dtype=waveform.dtype) + pad_right = torch.zeros(0, dtype=waveform.dtype, device=waveform.device) waveform = torch.cat((pad_left, waveform, pad_right), dim=1) strides = (