Skip to content

Commit

Permalink
fix(separation): fix #1735 and #1747
Browse files Browse the repository at this point in the history
  • Loading branch information
clement-pages authored Nov 11, 2024
1 parent cf3e2b2 commit 7d84f61
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 41 deletions.
4 changes: 2 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@

### Fixes

- fix: fix clipping issue in speech separation pipeline ([@joonaskalda](https://github.com/joonaskalda/))
- fix: fix alignment between separated sources and diarization when the diarization reference is available ([@Lebourdais](https://github.com/Lebourdais/))
- fix(separation): fix clipping issue in speech separation pipeline ([@joonaskalda](https://github.com/joonaskalda/))
- fix(separation): fix alignment between separated sources and diarization ([@Lebourdais](https://github.com/Lebourdais/) and [@clement-pages](https://github.com/clement-pages/))
- fix(doc): fix link to pytorch ([@emmanuel-ferdman](https://github.com/emmanuel-ferdman/))

## Version 3.3.2 (2024-09-11)
Expand Down
67 changes: 28 additions & 39 deletions pyannote/audio/pipelines/speech_separation.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from typing import Callable, Optional, Text, Tuple, Union

import numpy as np
from scipy.ndimage import binary_dilation
import torch
from einops import rearrange
from pyannote.core import Annotation, SlidingWindow, SlidingWindowFeature
Expand Down Expand Up @@ -92,7 +93,7 @@ class SpeechSeparation(SpeakerDiarizationMixin, Pipeline):
Usage
-----
>>> pipeline = SpeakerDiarization()
>>> pipeline = SpeechSeparation()
>>> diarization, separation = pipeline("/path/to/audio.wav")
>>> diarization, separation = pipeline("/path/to/audio.wav", num_speakers=4)
>>> diarization, separation = pipeline("/path/to/audio.wav", min_speakers=2, max_speakers=10)
Expand Down Expand Up @@ -237,7 +238,7 @@ def get_segmentations(
segmentations, separations = file[self.CACHED_SEGMENTATION]
else:
segmentations, separations = self._segmentation(file, hook=hook)
file[self.CACHED_SEGMENTATION] = segmentations
file[self.CACHED_SEGMENTATION] = (segmentations, separations)
else:
segmentations, separations = self._segmentation(file, hook=hook)

Expand Down Expand Up @@ -583,7 +584,7 @@ def apply(

# reconstruct discrete diarization from raw hard clusters

# keep track of inactive speakers
# keep track of inactive speakers at chunk level
inactive_speakers = np.sum(binarized_segmentations.data, axis=1) == 0
# shape: (num_chunks, num_speakers)

Expand All @@ -594,7 +595,13 @@ def apply(
count,
)
discrete_diarization = self.to_diarization(discrete_diarization, count)
# remove file-wise inactive speakers from the diarization
active_speakers = np.sum(discrete_diarization, axis=0) > 0
# shape: (num_speakers, )
discrete_diarization.data = discrete_diarization.data[:, active_speakers]
num_frames, num_speakers = discrete_diarization.data.shape
hook("discrete_diarization", discrete_diarization)

clustered_separations = self.reconstruct(separations, hard_clusters, count)
frame_duration = separations.sliding_window.duration / separations.data.shape[1]
frames = SlidingWindow(step=frame_duration, duration=2 * frame_duration)
Expand All @@ -605,6 +612,17 @@ def apply(
missing=0.0,
skip_average=True,
)

_, num_sources = sources.data.shape

# In some cases, maximum num of simultaneous speakers is greater than num of clusters,
# implying a num of speakers in the diarization greater than num of sources after calling
# to_diarization(). So we add dummy sources to match the number of speakers in diarization.
sources.data = np.pad(sources.data, ((0, 0), (0, max(0, num_speakers - num_sources))))

# remove sources corresponding to file-wise inactive speakers
sources.data = sources.data[:, active_speakers]

# zero-out sources when speaker is inactive
# WARNING: this should be rewritten to avoid huge memory consumption
if self.separation.leakage_removal:
Expand All @@ -614,44 +632,15 @@ def apply(
)
)
if asr_collar_frames > 0:
for i in range(discrete_diarization.data.shape[1]):
for i in range(num_speakers):
speaker_activation = discrete_diarization.data.T[i]
non_silent = np.where(speaker_activation != 0)[0]
remaining_gaps = np.where(
np.diff(non_silent) > 2 * asr_collar_frames
)[0]
remaining_zeros = [
np.arange(
non_silent[gap] + asr_collar_frames,
non_silent[gap + 1] - asr_collar_frames,
)
for gap in remaining_gaps
]
# edge cases of long silent regions in beginning or end of audio
if non_silent[0] > asr_collar_frames:
remaining_zeros = [
np.arange(0, non_silent[0] - asr_collar_frames)
] + remaining_zeros
if non_silent[-1] < speaker_activation.shape[0] - asr_collar_frames:
remaining_zeros = remaining_zeros + [
np.arange(
non_silent[-1] + asr_collar_frames,
speaker_activation.shape[0],
)
]

speaker_activation_with_context = np.ones(
len(speaker_activation), dtype=float
)

speaker_activation_with_context[
np.concatenate(remaining_zeros)
] = 0.0

non_silent = speaker_activation != 0
dilated_non_silent = binary_dilation(non_silent, [True] * (2 * asr_collar_frames))
speaker_activation_with_context = dilated_non_silent.astype(np.int8)
discrete_diarization.data.T[i] = speaker_activation_with_context
num_sources = sources.data.shape[1]

sources.data = (
sources.data * discrete_diarization.align(sources).data[:, :num_sources]
sources.data * discrete_diarization.align(sources).data
)

# separated sources might be scaled up/down due to SI-SDR loss used when training
Expand Down Expand Up @@ -701,7 +690,7 @@ def apply(
# re-order sources so that they match
# the order given by diarization.labels()
inverse_mapping = {label: index for index, label in mapping.items()}
source.data = sources.data[
sources.data = sources.data[
:, [inverse_mapping[label] for label in diarization.labels()]
]

Expand Down

0 comments on commit 7d84f61

Please sign in to comment.