diff --git a/CHANGELOG.md b/CHANGELOG.md
index 811b6a7ae..dc824b180 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -19,8 +19,8 @@
 
 ### Fixes
 
-- fix: fix clipping issue in speech separation pipeline ([@joonaskalda](https://github.com/joonaskalda/))
-- fix: fix alignment between separated sources and diarization when the diarization reference is available ([@Lebourdais](https://github.com/Lebourdais/))
+- fix(separation): fix clipping issue in speech separation pipeline ([@joonaskalda](https://github.com/joonaskalda/))
+- fix(separation): fix alignment between separated sources and diarization ([@Lebourdais](https://github.com/Lebourdais/) and [@clement-pages](https://github.com/clement-pages/))
 - fix(doc): fix link to pytorch ([@emmanuel-ferdman](https://github.com/emmanuel-ferdman/))
 
 ## Version 3.3.2 (2024-09-11)
diff --git a/pyannote/audio/pipelines/speech_separation.py b/pyannote/audio/pipelines/speech_separation.py
index a129ea7a4..0ffe42a0d 100644
--- a/pyannote/audio/pipelines/speech_separation.py
+++ b/pyannote/audio/pipelines/speech_separation.py
@@ -30,6 +30,7 @@
 from typing import Callable, Optional, Text, Tuple, Union
 
 import numpy as np
+from scipy.ndimage import binary_dilation
 import torch
 from einops import rearrange
 from pyannote.core import Annotation, SlidingWindow, SlidingWindowFeature
@@ -92,7 +93,7 @@ class SpeechSeparation(SpeakerDiarizationMixin, Pipeline):
 
     Usage
     -----
-    >>> pipeline = SpeakerDiarization()
+    >>> pipeline = SpeechSeparation()
     >>> diarization, separation = pipeline("/path/to/audio.wav")
     >>> diarization, separation = pipeline("/path/to/audio.wav", num_speakers=4)
     >>> diarization, separation = pipeline("/path/to/audio.wav", min_speakers=2, max_speakers=10)
@@ -237,7 +238,7 @@ def get_segmentations(
                 segmentations, separations = file[self.CACHED_SEGMENTATION]
             else:
                 segmentations, separations = self._segmentation(file, hook=hook)
-                file[self.CACHED_SEGMENTATION] = segmentations
+                file[self.CACHED_SEGMENTATION] = (segmentations, separations)
         else:
             segmentations, separations = self._segmentation(file, hook=hook)
 
@@ -583,7 +584,7 @@
 
         # reconstruct discrete diarization from raw hard clusters
 
-        # keep track of inactive speakers
+        # keep track of inactive speakers at chunk level
         inactive_speakers = np.sum(binarized_segmentations.data, axis=1) == 0
         # shape: (num_chunks, num_speakers)
 
@@ -594,7 +595,13 @@
             count,
         )
         discrete_diarization = self.to_diarization(discrete_diarization, count)
+        # remove file-wise inactive speakers from the diarization
+        active_speakers = np.sum(discrete_diarization, axis=0) > 0
+        # shape: (num_speakers, )
+        discrete_diarization.data = discrete_diarization.data[:, active_speakers]
+        num_frames, num_speakers = discrete_diarization.data.shape
         hook("discrete_diarization", discrete_diarization)
+
         clustered_separations = self.reconstruct(separations, hard_clusters, count)
         frame_duration = separations.sliding_window.duration / separations.data.shape[1]
         frames = SlidingWindow(step=frame_duration, duration=2 * frame_duration)
@@ -605,6 +612,17 @@
             missing=0.0,
             skip_average=True,
         )
+
+        _, num_sources = sources.data.shape
+
+        # In some cases, the maximum number of simultaneous speakers is greater than the number of clusters,
+        # implying more speakers in the diarization than sources after calling to_diarization().
+        # So we add dummy sources to match the number of speakers in the diarization.
+        sources.data = np.pad(sources.data, ((0, 0), (0, max(0, num_speakers - num_sources))))
+
+        # remove sources corresponding to file-wise inactive speakers
+        sources.data = sources.data[:, active_speakers]
+
         # zero-out sources when speaker is inactive
         # WARNING: this should be rewritten to avoid huge memory consumption
         if self.separation.leakage_removal:
@@ -614,44 +632,15 @@
                 )
             )
             if asr_collar_frames > 0:
-                for i in range(discrete_diarization.data.shape[1]):
+                for i in range(num_speakers):
                     speaker_activation = discrete_diarization.data.T[i]
-                    non_silent = np.where(speaker_activation != 0)[0]
-                    remaining_gaps = np.where(
-                        np.diff(non_silent) > 2 * asr_collar_frames
-                    )[0]
-                    remaining_zeros = [
-                        np.arange(
-                            non_silent[gap] + asr_collar_frames,
-                            non_silent[gap + 1] - asr_collar_frames,
-                        )
-                        for gap in remaining_gaps
-                    ]
-                    # edge cases of long silent regions in beginning or end of audio
-                    if non_silent[0] > asr_collar_frames:
-                        remaining_zeros = [
-                            np.arange(0, non_silent[0] - asr_collar_frames)
-                        ] + remaining_zeros
-                    if non_silent[-1] < speaker_activation.shape[0] - asr_collar_frames:
-                        remaining_zeros = remaining_zeros + [
-                            np.arange(
-                                non_silent[-1] + asr_collar_frames,
-                                speaker_activation.shape[0],
-                            )
-                        ]
-
-                    speaker_activation_with_context = np.ones(
-                        len(speaker_activation), dtype=float
-                    )
-
-                    speaker_activation_with_context[
-                        np.concatenate(remaining_zeros)
-                    ] = 0.0
-
+                    non_silent = speaker_activation != 0
+                    dilated_non_silent = binary_dilation(non_silent, [True] * (2 * asr_collar_frames))
+                    speaker_activation_with_context = dilated_non_silent.astype(np.int8)
                     discrete_diarization.data.T[i] = speaker_activation_with_context
-            num_sources = sources.data.shape[1]
+
             sources.data = (
-                sources.data * discrete_diarization.align(sources).data[:, :num_sources]
+                sources.data * discrete_diarization.align(sources).data
            )
 
         # separated sources might be scaled up/down due to SI-SDR loss used when training
@@ -701,7 +690,7 @@
 
         # re-order sources so that they match
         # the order given by diarization.labels()
         inverse_mapping = {label: index for index, label in mapping.items()}
-        source.data = sources.data[
+        sources.data = sources.data[
             :, [inverse_mapping[label] for label in diarization.labels()]
         ]
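
Note on the source/diarization alignment change: a minimal numpy sketch of the pad-then-mask idea, with made-up shapes (the real pipeline operates on `SlidingWindowFeature.data` arrays and recomputes `num_speakers` after filtering the diarization; the ordering below is simplified for clarity):

```python
import numpy as np

# Toy stand-ins for discrete_diarization.data and sources.data
# (shapes are hypothetical, for illustration only).
num_frames = 1000
diarization = np.zeros((num_frames, 3))  # 3 speaker columns
diarization[:500, 0] = 1.0               # speaker 0 talks in the first half
diarization[500:, 1] = 1.0               # speaker 1 talks in the second half
sources = np.random.rand(num_frames, 2)  # only 2 separated source channels

# file-wise activity: speaker 2 never talks and will be dropped
active_speakers = np.sum(diarization, axis=0) > 0  # [True, True, False]

# pad with silent dummy sources so the source axis matches the speaker axis,
# then drop the columns of file-wise inactive speakers from both arrays
num_speakers, num_sources = diarization.shape[1], sources.shape[1]
sources = np.pad(sources, ((0, 0), (0, max(0, num_speakers - num_sources))))
diarization = diarization[:, active_speakers]
sources = sources[:, active_speakers]
assert diarization.shape[1] == sources.shape[1] == 2
```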
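Note on the leakage-removal change: the deleted gap/edge arithmetic is equivalent to a morphological dilation of the non-silent mask. A self-contained sketch of that equivalence, with a toy activation pattern and a hypothetical collar width:

```python
import numpy as np
from scipy.ndimage import binary_dilation

# toy activation: 1 where the speaker is active, 0 elsewhere
speaker_activation = np.array([0, 0, 0, 1, 1, 1, 0, 0, 0, 0], dtype=float)
asr_collar_frames = 2  # hypothetical collar width, in frames

# dilating the boolean mask with a flat structuring element of length
# 2 * asr_collar_frames keeps roughly asr_collar_frames frames of context
# on each side of every active region; the gaps-between-regions and
# file-edge special cases handled by the old loop fall out for free
non_silent = speaker_activation != 0
dilated = binary_dilation(non_silent, [True] * (2 * asr_collar_frames))
speaker_activation_with_context = dilated.astype(np.int8)
print(speaker_activation_with_context)
```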