diff --git a/pyannote/audio/augmentation/mix.py b/pyannote/audio/augmentation/mix.py index c6e811280..631d29812 100644 --- a/pyannote/audio/augmentation/mix.py +++ b/pyannote/audio/augmentation/mix.py @@ -25,7 +25,7 @@ import torch from torch import Tensor -from torch_audiomentations import Mix +from torch_audiomentations.augmentations.mix import Mix class MixSpeakerDiarization(Mix): @@ -85,7 +85,6 @@ def randomize_parameters( targets: Optional[Tensor] = None, target_rate: Optional[int] = None, ): - batch_size, num_channels, num_samples = samples.shape snr_distribution = torch.distributions.Uniform( low=torch.tensor( @@ -116,7 +115,6 @@ def randomize_parameters( batch_size, dtype=torch.int64 ) for n in range(max_num_speakers + 1): - # indices of samples with exactly n speakers samples_with_n_speakers = torch.where(num_speakers == n)[0] num_samples_with_n_speakers = len(samples_with_n_speakers)