diff --git a/CHANGELOG.md b/CHANGELOG.md index 02c931d5b..f29a13450 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,18 +2,37 @@ ## develop +### Fixes + +- fix: fix support for `numpy==2.x` ([@ibevers](https://github.com/ibevers/)) + +## Version 3.3.0 (2024-06-14) + +### TL;DR + +`pyannote.audio` does [speech separation](https://hf.co/pyannote/speech-separation-ami-1.0): multi-speaker audio in, one audio channel per speaker out! + +```bash +pip install pyannote.audio[separation]==3.3.0 +``` + ### New features +- feat(task): add `PixIT` joint speaker diarization and speech separation task (with [@joonaskalda](https://github.com/joonaskalda/)) +- feat(model): add `ToTaToNet` joint speaker diarization and speech separation model (with [@joonaskalda](https://github.com/joonaskalda/)) +- feat(pipeline): add `SpeechSeparation` pipeline (with [@joonaskalda](https://github.com/joonaskalda/)) - feat(io): add option to select torchaudio `backend` ### Fixes - fix(task): fix wrong train/development split when training with (some) meta-protocols ([#1709](https://github.com/pyannote/pyannote-audio/issues/1709)) +- fix(task): fix metadata preparation with missing validation subset ([@clement-pages](https://github.com/clement-pages/)) ### Improvements - improve(io): when available, default to using `soundfile` backend - improve(pipeline): do not extract embeddings when `max_speakers` is set to 1 +- improve(pipeline): optimize memory usage of most pipelines ([#1713](https://github.com/pyannote/pyannote-audio/pull/1713) by [@benniekiss](https://github.com/benniekiss/)) ## Version 3.2.0 (2024-05-08) diff --git a/pyannote/audio/core/inference.py b/pyannote/audio/core/inference.py index 0c3e9b212..5155d55e6 100644 --- a/pyannote/audio/core/inference.py +++ b/pyannote/audio/core/inference.py @@ -526,7 +526,7 @@ def aggregate( warm_up: Tuple[float, float] = (0.0, 0.0), epsilon: float = 1e-12, hamming: bool = False, - missing: float = np.NaN, + missing: float = np.nan, skip_average: bool = False, ) -> SlidingWindowFeature: """Aggregation @@ -559,9 +559,6 @@ def aggregate( step=frames.step, ) - masks = 1 - np.isnan(scores) - scores.data = np.nan_to_num(scores.data, copy=True, nan=0.0) - # Hamming window used for overlap-add aggregation hamming_window = ( np.hamming(num_frames_per_chunk).reshape(-1, 1) @@ -613,11 +610,13 @@ def aggregate( ) # loop on the scores of sliding chunks - for (chunk, score), (_, mask) in zip(scores, masks): + for chunk, score in scores: # chunk ~ Segment # score ~ (num_frames_per_chunk, num_classes)-shaped np.ndarray # mask ~ (num_frames_per_chunk, num_classes)-shaped np.ndarray - + mask = 1 - np.isnan(score) + np.nan_to_num(score, copy=False, nan=0.0) + start_frame = frames.closest_frame(chunk.start + 0.5 * frames.duration) aggregated_output[start_frame : start_frame + num_frames_per_chunk] += ( diff --git a/pyannote/audio/core/task.py b/pyannote/audio/core/task.py index 974f43a67..04c73ab51 100644 --- a/pyannote/audio/core/task.py +++ b/pyannote/audio/core/task.py @@ -595,7 +595,9 @@ def prepare_data(self): prepared_data["metadata-labels"] = np.array(unique_labels, dtype=np.str_) unique_labels.clear() - self.prepare_validation(prepared_data) + if self.has_validation: + self.prepare_validation(prepared_data) + self.post_prepare_data(prepared_data) # save prepared data on the disk diff --git a/pyannote/audio/models/separation/ToTaToNet.py b/pyannote/audio/models/separation/ToTaToNet.py new file mode 100644 index 000000000..34fd4f34a --- /dev/null +++ 
b/pyannote/audio/models/separation/ToTaToNet.py @@ -0,0 +1,351 @@ +# MIT License +# +# Copyright (c) 2024- CNRS +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# AUTHOR: Joonas Kalda (github.com/joonaskalda) + +from functools import lru_cache +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from asteroid_filterbanks import make_enc_dec +from pyannote.core.utils.generators import pairwise + +from pyannote.audio.core.model import Model +from pyannote.audio.core.task import Task +from pyannote.audio.utils.params import merge_dict +from pyannote.audio.utils.receptive_field import ( + conv1d_num_frames, + conv1d_receptive_field_center, + conv1d_receptive_field_size, +) + +try: + from asteroid.masknn import DPRNN + from asteroid.utils.torch_utils import pad_x_to_y + + ASTEROID_IS_AVAILABLE = True +except ImportError: + ASTEROID_IS_AVAILABLE = False + + +try: + from transformers import AutoModel + + TRANSFORMERS_IS_AVAILABLE = True +except ImportError: + TRANSFORMERS_IS_AVAILABLE = False + + +class ToTaToNet(Model): + """ToTaToNet joint speaker diarization and speech separation model + + /--------------\\ + Conv1D Encoder --------+--- DPRNN --X------- Conv1D Decoder + WavLM -- upsampling --/ \\--- Avg pool -- Linear -- Classifier + + + Parameters + ---------- + sample_rate : int, optional + Audio sample rate. Defaults to 16kHz (16000). + num_channels : int, optional + Number of channels. Defaults to mono (1). + linear : dict, optional + Keyword arguments used to initialize linear layers. + See ToTaToNet.LINEAR_DEFAULTS for default values. + diar : dict, optional + Keyword arguments used to initialize the average pooling in the diarization branch. + See ToTaToNet.DIAR_DEFAULTS for default values. + encoder_decoder : dict, optional + Keyword arguments used to initialize the encoder and decoder. + See ToTaToNet.ENCODER_DECODER_DEFAULTS for default values. + dprnn : dict, optional + Keyword arguments used to initialize the DPRNN model. + See ToTaToNet.DPRNN_DEFAULTS for default values. + task : Task, optional + Task to perform. Defaults to None. + n_sources : int, optional + Number of separated sources. Defaults to 3.
+ use_wavlm : bool, optional + Whether to use the WavLM large model for feature extraction. Defaults to True. + gradient_clip_val : float, optional + Gradient clipping value. Required when fine-tuning the WavLM model and thus using two different optimizers. + Defaults to 5.0. + + References + ---------- + Joonas Kalda, Clément Pagés, Ricard Marxer, Tanel Alumäe, and Hervé Bredin. + "PixIT: Joint Training of Speaker Diarization and Speech Separation + from Real-world Multi-speaker Recordings" + Odyssey 2024. https://arxiv.org/abs/2403.02288 + """ + + ENCODER_DECODER_DEFAULTS = { + "fb_name": "free", + "kernel_size": 32, + "n_filters": 64, + "stride": 16, + } + LINEAR_DEFAULTS = {"hidden_size": 64, "num_layers": 2} + DPRNN_DEFAULTS = { + "n_repeats": 6, + "bn_chan": 128, + "hid_size": 128, + "chunk_size": 100, + "norm_type": "gLN", + "mask_act": "relu", + "rnn_type": "LSTM", + } + DIAR_DEFAULTS = {"frames_per_second": 125} + + def __init__( + self, + encoder_decoder: dict = None, + linear: Optional[dict] = None, + diar: Optional[dict] = None, + dprnn: dict = None, + sample_rate: int = 16000, + num_channels: int = 1, + task: Optional[Task] = None, + n_sources: int = 3, + use_wavlm: bool = True, + gradient_clip_val: float = 5.0, + ): + if not ASTEROID_IS_AVAILABLE: + raise ImportError( + "'asteroid' must be installed to use ToTaToNet separation. " + "`pip install pyannote-audio[separation]` should do the trick." + ) + + if not TRANSFORMERS_IS_AVAILABLE: + raise ImportError( + "'transformers' must be installed to use ToTaToNet separation. " + "`pip install pyannote-audio[separation]` should do the trick." + ) + + super().__init__(sample_rate=sample_rate, num_channels=num_channels, task=task) + + linear = merge_dict(self.LINEAR_DEFAULTS, linear) + dprnn = merge_dict(self.DPRNN_DEFAULTS, dprnn) + encoder_decoder = merge_dict(self.ENCODER_DECODER_DEFAULTS, encoder_decoder) + diar = merge_dict(self.DIAR_DEFAULTS, diar) + self.use_wavlm = use_wavlm + self.save_hyperparameters("encoder_decoder", "linear", "dprnn", "diar") + self.n_sources = n_sources + + if encoder_decoder["fb_name"] == "free": + n_feats_out = encoder_decoder["n_filters"] + elif encoder_decoder["fb_name"] == "stft": + n_feats_out = int(2 * (encoder_decoder["n_filters"] / 2 + 1)) + else: + raise ValueError("Filterbank type not recognized.") + self.encoder, self.decoder = make_enc_dec( + sample_rate=sample_rate, **self.hparams.encoder_decoder + ) + + if self.use_wavlm: + self.wavlm = AutoModel.from_pretrained("microsoft/wavlm-large") + downsampling_factor = 1 + for conv_layer in self.wavlm.feature_extractor.conv_layers: + if isinstance(conv_layer.conv, nn.Conv1d): + downsampling_factor *= conv_layer.conv.stride[0] + self.wavlm_scaling = int(downsampling_factor / encoder_decoder["stride"]) + + self.masker = DPRNN( + encoder_decoder["n_filters"] + + self.wavlm.feature_projection.projection.out_features, + out_chan=encoder_decoder["n_filters"], + n_src=n_sources, + **self.hparams.dprnn, + ) + else: + self.masker = DPRNN( + encoder_decoder["n_filters"], + out_chan=encoder_decoder["n_filters"], + n_src=n_sources, + **self.hparams.dprnn, + ) + + # diarization can use a lower resolution than separation + self.diarization_scaling = int( + sample_rate / diar["frames_per_second"] / encoder_decoder["stride"] + ) + self.average_pool = nn.AvgPool1d( + self.diarization_scaling, stride=self.diarization_scaling + ) + linaer_input_features = n_feats_out + if linear["num_layers"] > 0: + self.linear = nn.ModuleList( + [ + nn.Linear(in_features, 
out_features) + for in_features, out_features in pairwise( + [ + linaer_input_features, + ] + + [self.hparams.linear["hidden_size"]] + * self.hparams.linear["num_layers"] + ) + ] + ) + self.gradient_clip_val = gradient_clip_val + self.automatic_optimization = False + + @property + def dimension(self) -> int: + """Dimension of output""" + return 1 + + def build(self): + if self.hparams.linear["num_layers"] > 0: + self.classifier = nn.Linear(64, self.dimension) + else: + self.classifier = nn.Linear(1, self.dimension) + self.activation = self.default_activation() + + @lru_cache + def num_frames(self, num_samples: int) -> int: + """Compute number of output frames + + Parameters + ---------- + num_samples : int + Number of input samples. + + Returns + ------- + num_frames : int + Number of output frames. + """ + + equivalent_stride = ( + self.diarization_scaling * self.hparams.encoder_decoder["stride"] + ) + equivalent_kernel_size = ( + self.diarization_scaling * self.hparams.encoder_decoder["kernel_size"] + ) + + return conv1d_num_frames( + num_samples, kernel_size=equivalent_kernel_size, stride=equivalent_stride + ) + + def receptive_field_size(self, num_frames: int = 1) -> int: + """Compute size of receptive field + + Parameters + ---------- + num_frames : int, optional + Number of frames in the output signal + + Returns + ------- + receptive_field_size : int + Receptive field size. + """ + + equivalent_stride = ( + self.diarization_scaling * self.hparams.encoder_decoder["stride"] + ) + equivalent_kernel_size = ( + self.diarization_scaling * self.hparams.encoder_decoder["kernel_size"] + ) + + return conv1d_receptive_field_size( + num_frames, kernel_size=equivalent_kernel_size, stride=equivalent_stride + ) + + def receptive_field_center(self, frame: int = 0) -> int: + """Compute center of receptive field + + Parameters + ---------- + frame : int, optional + Frame index + + Returns + ------- + receptive_field_center : int + Index of receptive field center. 
+ """ + + equivalent_stride = ( + self.diarization_scaling * self.hparams.encoder_decoder["stride"] + ) + equivalent_kernel_size = ( + self.diarization_scaling * self.hparams.encoder_decoder["kernel_size"] + ) + + return conv1d_receptive_field_center( + frame, kernel_size=equivalent_kernel_size, stride=equivalent_stride + ) + + def forward(self, waveforms: torch.Tensor) -> torch.Tensor: + """Pass forward + + Parameters + ---------- + waveforms : (batch, channel, sample) + + Returns + ------- + scores : (batch, frame, classes) + sources : (batch, sample, n_sources) + """ + bsz = waveforms.shape[0] + tf_rep = self.encoder(waveforms) + if self.use_wavlm: + wavlm_rep = self.wavlm(waveforms.squeeze(1)).last_hidden_state + wavlm_rep = wavlm_rep.transpose(1, 2) + wavlm_rep = wavlm_rep.repeat_interleave(self.wavlm_scaling, dim=-1) + wavlm_rep = pad_x_to_y(wavlm_rep, tf_rep) + wavlm_rep = torch.cat((tf_rep, wavlm_rep), dim=1) + masks = self.masker(wavlm_rep) + else: + masks = self.masker(tf_rep) + # shape: (batch, nsrc, nfilters, nframes) + masked_tf_rep = masks * tf_rep.unsqueeze(1) + decoded_sources = self.decoder(masked_tf_rep) + decoded_sources = pad_x_to_y(decoded_sources, waveforms) + decoded_sources = decoded_sources.transpose(1, 2) + outputs = torch.flatten(masked_tf_rep, start_dim=0, end_dim=1) + # shape (batch * nsrc, nfilters, nframes) + outputs = self.average_pool(outputs) + outputs = outputs.transpose(1, 2) + # shape (batch, nframes, nfilters) + if self.hparams.linear["num_layers"] > 0: + for linear in self.linear: + outputs = F.leaky_relu(linear(outputs)) + if self.hparams.linear["num_layers"] == 0: + outputs = (outputs**2).sum(dim=2).unsqueeze(-1) + outputs = self.classifier(outputs) + outputs = outputs.reshape(bsz, self.n_sources, -1) + outputs = outputs.transpose(1, 2) + + return self.activation[0](outputs), decoded_sources diff --git a/pyannote/audio/models/separation/__init__.py b/pyannote/audio/models/separation/__init__.py new file mode 100644 index 000000000..a795392e9 --- /dev/null +++ b/pyannote/audio/models/separation/__init__.py @@ -0,0 +1,25 @@ +# MIT License +# +# Copyright (c) 2024- CNRS +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+ +from .ToTaToNet import ToTaToNet + +__all__ = ["ToTaToNet"] diff --git a/pyannote/audio/pipelines/__init__.py b/pyannote/audio/pipelines/__init__.py index 0c7d2f25c..443a55d60 100644 --- a/pyannote/audio/pipelines/__init__.py +++ b/pyannote/audio/pipelines/__init__.py @@ -24,6 +24,7 @@ from .overlapped_speech_detection import OverlappedSpeechDetection from .resegmentation import Resegmentation from .speaker_diarization import SpeakerDiarization +from .speech_separation import SpeechSeparation from .voice_activity_detection import VoiceActivityDetection __all__ = [ @@ -32,4 +33,5 @@ "SpeakerDiarization", "Resegmentation", "MultiLabelSegmentation", + "SpeechSeparation", ] diff --git a/pyannote/audio/pipelines/speech_separation.py b/pyannote/audio/pipelines/speech_separation.py new file mode 100644 index 000000000..45c10b9b5 --- /dev/null +++ b/pyannote/audio/pipelines/speech_separation.py @@ -0,0 +1,722 @@ +# The MIT License (MIT) +# +# Copyright (c) 2024- CNRS +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""Speech separation pipelines""" + +import functools +import itertools +import math +import textwrap +import warnings +from typing import Callable, Optional, Text, Tuple, Union + +import numpy as np +import torch +from einops import rearrange +from pyannote.core import Annotation, SlidingWindow, SlidingWindowFeature +from pyannote.metrics.diarization import GreedyDiarizationErrorRate +from pyannote.pipeline.parameter import Categorical, ParamDict, Uniform + +from pyannote.audio import Audio, Inference, Model, Pipeline +from pyannote.audio.core.io import AudioFile +from pyannote.audio.pipelines.clustering import Clustering +from pyannote.audio.pipelines.speaker_verification import PretrainedSpeakerEmbedding +from pyannote.audio.pipelines.utils import ( + PipelineModel, + SpeakerDiarizationMixin, + get_model, +) +from pyannote.audio.utils.signal import binarize + + +def batchify(iterable, batch_size: int = 32, fillvalue=None): + """Batchify iterable""" + # batchify('ABCDEFG', 3) --> ['A', 'B', 'C'] ['D', 'E', 'F'] [G, ] + args = [iter(iterable)] * batch_size + return itertools.zip_longest(*args, fillvalue=fillvalue) + + +class SpeechSeparation(SpeakerDiarizationMixin, Pipeline): + """Speech separation pipeline + + Parameters + ---------- + segmentation : Model, str, or dict, optional + Pretrained segmentation model and separation model. + See pyannote.audio.pipelines.utils.get_model for supported format. 
segmentation_step: float, optional + The segmentation model is applied on a window sliding over the whole audio file. + `segmentation_step` controls the step of this window, provided as a ratio of its + duration. Defaults to 0.1 (i.e. 90% overlap between two consecutive windows). + embedding : Model, str, or dict, optional + Pretrained embedding model. Defaults to "pyannote/embedding@2022.07". + See pyannote.audio.pipelines.utils.get_model for supported format. + embedding_exclude_overlap : bool, optional + Exclude overlapping speech regions when extracting embeddings. + Defaults (False) to use the whole speech. + clustering : str, optional + Clustering algorithm. See pyannote.audio.pipelines.clustering.Clustering + for available options. Defaults to "AgglomerativeClustering". + segmentation_batch_size : int, optional + Batch size used for speaker segmentation. Defaults to 1. + embedding_batch_size : int, optional + Batch size used for speaker embedding. Defaults to 1. + der_variant : dict, optional + Optimize for a variant of diarization error rate. + Defaults to {"collar": 0.0, "skip_overlap": False}. This is used in `get_metric` + when instantiating the metric: GreedyDiarizationErrorRate(**der_variant). + use_auth_token : str, optional + When loading private huggingface.co models, set `use_auth_token` + to True or to a string containing your huggingface.co authentication + token that can be obtained by running `huggingface-cli login`. + + Usage + ----- + >>> pipeline = SpeechSeparation() + >>> diarization, separation = pipeline("/path/to/audio.wav") + >>> diarization, separation = pipeline("/path/to/audio.wav", num_speakers=4) + >>> diarization, separation = pipeline("/path/to/audio.wav", min_speakers=2, max_speakers=10) + + Hyper-parameters + ---------------- + segmentation.min_duration_off : float + Fill intra-speaker gaps shorter than that many seconds. + segmentation.threshold : float + Mark speaker as active when their probability is higher than this. + clustering.method : {'centroid', 'average', ...} + Linkage used for agglomerative clustering. + clustering.min_cluster_size : int + Minimum cluster size. + clustering.threshold : float + Clustering threshold used to stop merging clusters. + separation.leakage_removal : bool + Zero-out sources when speaker is inactive. + separation.asr_collar : float + When using leakage removal, keep that many seconds before and after each speaker turn. + + References + ---------- + Joonas Kalda, Clément Pagés, Ricard Marxer, Tanel Alumäe, and Hervé Bredin. + "PixIT: Joint Training of Speaker Diarization and Speech Separation + from Real-world Multi-speaker Recordings" + Odyssey 2024.
https://arxiv.org/abs/2403.02288 + """ + + def __init__( + self, + segmentation: PipelineModel = None, + segmentation_step: float = 0.1, + embedding: PipelineModel = "speechbrain/spkrec-ecapa-voxceleb@5c0be3875fda05e81f3c004ed8c7c06be308de1e", + embedding_exclude_overlap: bool = False, + clustering: str = "AgglomerativeClustering", + embedding_batch_size: int = 1, + segmentation_batch_size: int = 1, + der_variant: Optional[dict] = None, + use_auth_token: Union[Text, None] = None, + ): + super().__init__() + + self.segmentation_model = segmentation + model: Model = get_model(segmentation, use_auth_token=use_auth_token) + + self.segmentation_step = segmentation_step + + self.embedding = embedding + self.embedding_batch_size = embedding_batch_size + self.embedding_exclude_overlap = embedding_exclude_overlap + + self.klustering = clustering + + self.der_variant = der_variant or {"collar": 0.0, "skip_overlap": False} + + segmentation_duration = model.specifications[0].duration + self._segmentation = Inference( + model, + duration=segmentation_duration, + step=self.segmentation_step * segmentation_duration, + skip_aggregation=True, + batch_size=segmentation_batch_size, + ) + + if self._segmentation.model.specifications[0].powerset: + self.segmentation = ParamDict( + min_duration_off=Uniform(0.0, 1.0), + ) + + else: + self.segmentation = ParamDict( + threshold=Uniform(0.1, 0.9), + min_duration_off=Uniform(0.0, 1.0), + ) + + if self.klustering == "OracleClustering": + metric = "not_applicable" + + else: + self._embedding = PretrainedSpeakerEmbedding( + self.embedding, use_auth_token=use_auth_token + ) + self._audio = Audio(sample_rate=self._embedding.sample_rate, mono="downmix") + metric = self._embedding.metric + + try: + Klustering = Clustering[clustering] + except KeyError: + raise ValueError( + f'clustering must be one of [{", ".join(list(Clustering.__members__))}]' + ) + self.clustering = Klustering.value(metric=metric) + + self.separation = ParamDict( + leakage_removal=Categorical([True, False]), + asr_collar=Uniform(0.0, 1.0), + ) + + @property + def segmentation_batch_size(self) -> int: + return self._segmentation.batch_size + + @segmentation_batch_size.setter + def segmentation_batch_size(self, batch_size: int): + self._segmentation.batch_size = batch_size + + def default_parameters(self): + raise NotImplementedError() + + def classes(self): + speaker = 0 + while True: + yield f"SPEAKER_{speaker:02d}" + speaker += 1 + + @property + def CACHED_SEGMENTATION(self): + return "training_cache/segmentation" + + def get_segmentations( + self, file, hook=None + ) -> Tuple[SlidingWindowFeature, SlidingWindowFeature]: + """Apply segmentation model + + Parameter + --------- + file : AudioFile + hook : Optional[Callable] + + Returns + ------- + segmentations : (num_chunks, num_frames, num_speakers) SlidingWindowFeature + separations : (num_chunks, num_samples, num_speakers) SlidingWindowFeature + """ + + if hook is not None: + hook = functools.partial(hook, "segmentation", None) + + if self.training: + if self.CACHED_SEGMENTATION in file: + segmentations, separations = file[self.CACHED_SEGMENTATION] + else: + segmentations, separations = self._segmentation(file, hook=hook) + file[self.CACHED_SEGMENTATION] = segmentations + else: + segmentations, separations = self._segmentation(file, hook=hook) + + return segmentations, separations + + def get_embeddings( + self, + file, + binary_segmentations: SlidingWindowFeature, + exclude_overlap: bool = False, + hook: Optional[Callable] = None, + ): + """Extract 
embeddings for each (chunk, speaker) pair + + Parameters + ---------- + file : AudioFile + binary_segmentations : (num_chunks, num_frames, num_speakers) SlidingWindowFeature + Binarized segmentation. + exclude_overlap : bool, optional + Exclude overlapping speech regions when extracting embeddings. + In case non-overlapping speech is too short, use the whole speech. + hook: Optional[Callable] + Called during embeddings after every batch to report the progress + + Returns + ------- + embeddings : (num_chunks, num_speakers, dimension) array + """ + + # when optimizing the hyper-parameters of this pipeline with frozen + # "segmentation.threshold", one can reuse the embeddings from the first trial, + # bringing a massive speed up to the optimization process (and hence allowing to use + # a larger search space). + if self.training: + # we only re-use embeddings if they were extracted based on the same value of the + # "segmentation.threshold" hyperparameter or if the segmentation model relies on + # `powerset` mode + cache = file.get("training_cache/embeddings", dict()) + if ("embeddings" in cache) and ( + self._segmentation.model.specifications[0].powerset + or (cache["segmentation.threshold"] == self.segmentation.threshold) + ): + return cache["embeddings"] + + duration = binary_segmentations.sliding_window.duration + num_chunks, num_frames, num_speakers = binary_segmentations.data.shape + + if exclude_overlap: + # minimum number of samples needed to extract an embedding + # (a lower number of samples would result in an error) + min_num_samples = self._embedding.min_num_samples + + # corresponding minimum number of frames + num_samples = duration * self._embedding.sample_rate + min_num_frames = math.ceil(num_frames * min_num_samples / num_samples) + + # zero-out frames with overlapping speech + clean_frames = 1.0 * ( + np.sum(binary_segmentations.data, axis=2, keepdims=True) < 2 + ) + clean_segmentations = SlidingWindowFeature( + binary_segmentations.data * clean_frames, + binary_segmentations.sliding_window, + ) + + else: + min_num_frames = -1 + clean_segmentations = SlidingWindowFeature( + binary_segmentations.data, binary_segmentations.sliding_window + ) + + def iter_waveform_and_mask(): + for (chunk, masks), (_, clean_masks) in zip( + binary_segmentations, clean_segmentations + ): + # chunk: Segment(t, t + duration) + # masks: (num_frames, local_num_speakers) np.ndarray + + waveform, _ = self._audio.crop( + file, + chunk, + duration=duration, + mode="pad", + ) + # waveform: (1, num_samples) torch.Tensor + + # speaker_activation_with_context may contain NaN (in case of partial stitching) + masks = np.nan_to_num(masks, nan=0.0).astype(np.float32) + clean_masks = np.nan_to_num(clean_masks, nan=0.0).astype(np.float32) + + for speaker_activation_with_context, clean_mask in zip( + masks.T, clean_masks.T + ): + # speaker_activation_with_context: (num_frames, ) np.ndarray + + if np.sum(clean_mask) > min_num_frames: + used_mask = clean_mask + else: + used_mask = speaker_activation_with_context + + yield waveform[None], torch.from_numpy(used_mask)[None] + # w: (1, 1, num_samples) torch.Tensor + # m: (1, num_frames) torch.Tensor + + batches = batchify( + iter_waveform_and_mask(), + batch_size=self.embedding_batch_size, + fillvalue=(None, None), + ) + + batch_count = math.ceil(num_chunks * num_speakers / self.embedding_batch_size) + + embedding_batches = [] + + if hook is not None: + hook("embeddings", None, total=batch_count, completed=0) + + for i, batch in enumerate(batches, 1): + waveforms, masks 
= zip(*filter(lambda b: b[0] is not None, batch)) + + waveform_batch = torch.vstack(waveforms) + # (batch_size, 1, num_samples) torch.Tensor + + mask_batch = torch.vstack(masks) + # (batch_size, num_frames) torch.Tensor + + embedding_batch: np.ndarray = self._embedding( + waveform_batch, masks=mask_batch + ) + # (batch_size, dimension) np.ndarray + + embedding_batches.append(embedding_batch) + + if hook is not None: + hook("embeddings", embedding_batch, total=batch_count, completed=i) + + embedding_batches = np.vstack(embedding_batches) + + embeddings = rearrange(embedding_batches, "(c s) d -> c s d", c=num_chunks) + + # caching embeddings for subsequent trials + # (see comments at the top of this method for more details) + if self.training: + if self._segmentation.model.specifications[0].powerset: + file["training_cache/embeddings"] = { + "embeddings": embeddings, + } + else: + file["training_cache/embeddings"] = { + "segmentation.threshold": self.segmentation.threshold, + "embeddings": embeddings, + } + + return embeddings + + def reconstruct( + self, + segmentations: SlidingWindowFeature, + hard_clusters: np.ndarray, + count: SlidingWindowFeature, + ) -> SlidingWindowFeature: + """Build clustered segmentation out of raw segmentation and clustering output + + Parameters + ---------- + segmentations : (num_chunks, num_frames, num_speakers) SlidingWindowFeature + Raw speaker segmentation. + hard_clusters : (num_chunks, num_speakers) array + Output of clustering step. + count : (total_num_frames, 1) SlidingWindowFeature + Instantaneous number of active speakers. + + Returns + ------- + clustered_segmentations : SlidingWindowFeature + Clustered speaker segmentation. + """ + + num_chunks, num_frames, local_num_speakers = segmentations.data.shape + + num_clusters = np.max(hard_clusters) + 1 + clustered_segmentations = np.nan * np.zeros( + (num_chunks, num_frames, num_clusters) + ) + + for c, (cluster, (chunk, segmentation)) in enumerate( + zip(hard_clusters, segmentations) + ): + # cluster is (local_num_speakers, )-shaped + # segmentation is (num_frames, local_num_speakers)-shaped + for k in np.unique(cluster): + if k == -2: + continue + + # TODO: can we do better than this max here? + clustered_segmentations[c, :, k] = np.max( + segmentation[:, cluster == k], axis=1 + ) + + clustered_segmentations = SlidingWindowFeature( + clustered_segmentations, segmentations.sliding_window + ) + return clustered_segmentations + + def apply( + self, + file: AudioFile, + num_speakers: Optional[int] = None, + min_speakers: Optional[int] = None, + max_speakers: Optional[int] = None, + return_embeddings: bool = False, + hook: Optional[Callable] = None, + ) -> Annotation: + """Apply speaker diarization and speech separation + + Parameters + ---------- + file : AudioFile + Processed file. + num_speakers : int, optional + Number of speakers, when known. + min_speakers : int, optional + Minimum number of speakers. Has no effect when `num_speakers` is provided. + max_speakers : int, optional + Maximum number of speakers. Has no effect when `num_speakers` is provided. + return_embeddings : bool, optional + Return representative speaker embeddings.
+ hook : callable, optional + Callback called after each major steps of the pipeline as follows: + hook(step_name, # human-readable name of current step + step_artefact, # artifact generated by current step + file=file) # file being processed + Time-consuming steps call `hook` multiple times with the same `step_name` + and additional `completed` and `total` keyword arguments usable to track + progress of current step. + + Returns + ------- + diarization : Annotation + Speaker diarization + sources : SlidingWindowFeature + Separated sources + embeddings : np.array, optional + Representative speaker embeddings such that `embeddings[i]` is the + speaker embedding for i-th speaker in diarization.labels(). + Only returned when `return_embeddings` is True. + """ + + # setup hook (e.g. for debugging purposes) + hook = self.setup_hook(file, hook=hook) + + num_speakers, min_speakers, max_speakers = self.set_num_speakers( + num_speakers=num_speakers, + min_speakers=min_speakers, + max_speakers=max_speakers, + ) + + segmentations, separations = self.get_segmentations(file, hook=hook) + hook("segmentation", segmentations) + # shape: (num_chunks, num_frames, local_num_speakers) + hook("separations", separations) + # shape: (num_chunks, nums_samples, local_num_speakers) + + # binarize segmentation + if self._segmentation.model.specifications[0].powerset: + binarized_segmentations = segmentations + else: + binarized_segmentations: SlidingWindowFeature = binarize( + segmentations, + onset=self.segmentation.threshold, + initial_state=False, + ) + + # estimate frame-level number of instantaneous speakers + count = self.speaker_count( + binarized_segmentations, + self._segmentation.model.receptive_field, + warm_up=(0.0, 0.0), + ) + hook("speaker_counting", count) + + # shape: (num_frames, 1) + # dtype: int + + # exit early when no speaker is ever active + if np.nanmax(count.data) == 0.0: + diarization = Annotation(uri=file["uri"]) + if return_embeddings: + return diarization, None, np.zeros((0, self._embedding.dimension)) + + return diarization, None + + if self.klustering == "OracleClustering" and not return_embeddings: + embeddings = None + else: + embeddings = self.get_embeddings( + file, + binarized_segmentations, + exclude_overlap=self.embedding_exclude_overlap, + hook=hook, + ) + hook("embeddings", embeddings) + # shape: (num_chunks, local_num_speakers, dimension) + + hard_clusters, _, centroids = self.clustering( + embeddings=embeddings, + segmentations=binarized_segmentations, + num_clusters=num_speakers, + min_clusters=min_speakers, + max_clusters=max_speakers, + file=file, # <== for oracle clustering + frames=self._segmentation.model.receptive_field, # <== for oracle clustering + ) + # hard_clusters: (num_chunks, num_speakers) + # centroids: (num_speakers, dimension) + + # number of detected clusters is the number of different speakers + num_different_speakers = np.max(hard_clusters) + 1 + + # detected number of speakers can still be out of bounds + # (specifically, lower than `min_speakers`), since there could be too few embeddings + # to make enough clusters with a given minimum cluster size. + if ( + num_different_speakers < min_speakers + or num_different_speakers > max_speakers + ): + warnings.warn( + textwrap.dedent( + f""" + The detected number of speakers ({num_different_speakers}) is outside + the given bounds [{min_speakers}, {max_speakers}]. This can happen if the + given audio file is too short to contain {min_speakers} or more speakers. 
+ Try to lower the desired minimal number of speakers. + """ + ) + ) + + # during counting, we could possibly overcount the number of instantaneous + # speakers due to segmentation errors, so we cap the maximum instantaneous number + # of speakers by the `max_speakers` value + count.data = np.minimum(count.data, max_speakers).astype(np.int8) + + # reconstruct discrete diarization from raw hard clusters + + # keep track of inactive speakers + inactive_speakers = np.sum(binarized_segmentations.data, axis=1) == 0 + # shape: (num_chunks, num_speakers) + + hard_clusters[inactive_speakers] = -2 + discrete_diarization = self.reconstruct( + segmentations, + hard_clusters, + count, + ) + discrete_diarization = self.to_diarization(discrete_diarization, count) + hook("discrete_diarization", discrete_diarization) + clustered_separations = self.reconstruct(separations, hard_clusters, count) + frame_duration = separations.sliding_window.duration / separations.data.shape[1] + frames = SlidingWindow(step=frame_duration, duration=2 * frame_duration) + sources = Inference.aggregate( + clustered_separations, + frames=frames, + hamming=True, + missing=0.0, + skip_average=True, + ) + # zero-out sources when speaker is inactive + # WARNING: this should be rewritten to avoid huge memory consumption + if self.separation.leakage_removal: + asr_collar_frames = int( + self._segmentation.model.num_frames( + self.separation.asr_collar * self._audio.sample_rate + ) + ) + if asr_collar_frames > 0: + for i in range(discrete_diarization.data.shape[1]): + speaker_activation = discrete_diarization.data.T[i] + non_silent = np.where(speaker_activation != 0)[0] + remaining_gaps = np.where( + np.diff(non_silent) > 2 * asr_collar_frames + )[0] + remaining_zeros = [ + np.arange( + non_silent[gap] + asr_collar_frames, + non_silent[gap + 1] - asr_collar_frames, + ) + for gap in remaining_gaps + ] + # edge cases of long silent regions in beginning or end of audio + if non_silent[0] > asr_collar_frames: + remaining_zeros = [ + np.arange(0, non_silent[0] - asr_collar_frames) + ] + remaining_zeros + if non_silent[-1] < speaker_activation.shape[0] - asr_collar_frames: + remaining_zeros = remaining_zeros + [ + np.arange( + non_silent[-1] + asr_collar_frames, + speaker_activation.shape[0], + ) + ] + + speaker_activation_with_context = np.ones( + len(speaker_activation), dtype=float + ) + + speaker_activation_with_context[ + np.concatenate(remaining_zeros) + ] = 0.0 + + discrete_diarization.data.T[i] = speaker_activation_with_context + num_sources = sources.data.shape[1] + sources.data = ( + sources.data * discrete_diarization.align(sources).data[:, :num_sources] + ) + + # convert to continuous diarization + diarization = self.to_annotation( + discrete_diarization, + min_duration_on=0.0, + min_duration_off=self.segmentation.min_duration_off, + ) + diarization.uri = file["uri"] + + # at this point, `diarization` speaker labels are integers + # from 0 to `num_speakers - 1`, aligned with `centroids` rows. + + if "annotation" in file and file["annotation"]: + # when reference is available, use it to map hypothesized speakers + # to reference speakers (this makes later error analysis easier + # but does not modify the actual output of the diarization pipeline) + _, mapping = self.optimal_mapping( + file["annotation"], diarization, return_mapping=True + ) + + # in case there are more speakers in the hypothesis than in + # the reference, those extra speakers are missing from `mapping`. 
+ # we add them back here + mapping = {key: mapping.get(key, key) for key in diarization.labels()} + + else: + # when reference is not available, rename hypothesized speakers + # to human-readable SPEAKER_00, SPEAKER_01, ... + mapping = { + label: expected_label + for label, expected_label in zip(diarization.labels(), self.classes()) + } + + diarization = diarization.rename_labels(mapping=mapping) + + # at this point, `diarization` speaker labels are strings (or mix of + # strings and integers when reference is available and some hypothesis + # speakers are not present in the reference) + + if not return_embeddings: + return diarization, sources + + # this can happen when we use OracleClustering + if centroids is None: + return diarization, sources, None + + # The number of centroids may be smaller than the number of speakers + # in the annotation. This can happen if the number of active speakers + # obtained from `speaker_count` for some frames is larger than the number + # of clusters obtained from `clustering`. In this case, we append zero embeddings + # for extra speakers + if len(diarization.labels()) > centroids.shape[0]: + centroids = np.pad( + centroids, ((0, len(diarization.labels()) - centroids.shape[0]), (0, 0)) + ) + + # re-order centroids so that they match + # the order given by diarization.labels() + inverse_mapping = {label: index for index, label in mapping.items()} + centroids = centroids[ + [inverse_mapping[label] for label in diarization.labels()] + ] + + return diarization, sources, centroids + + def get_metric(self) -> GreedyDiarizationErrorRate: + return GreedyDiarizationErrorRate(**self.der_variant) diff --git a/pyannote/audio/tasks/__init__.py b/pyannote/audio/tasks/__init__.py index 6cbba258f..517c6dd55 100644 --- a/pyannote/audio/tasks/__init__.py +++ b/pyannote/audio/tasks/__init__.py @@ -22,6 +22,7 @@ from .segmentation.multilabel import MultiLabelSegmentation # isort:skip from .segmentation.speaker_diarization import SpeakerDiarization # isort:skip +from .separation.PixIT import PixIT # isort:skip from .segmentation.voice_activity_detection import VoiceActivityDetection # isort:skip from .segmentation.overlapped_speech_detection import ( # isort:skip OverlappedSpeechDetection, @@ -41,4 +42,5 @@ "MultiLabelSegmentation", "SpeakerEmbedding", "Segmentation", + "PixIT", ] diff --git a/pyannote/audio/tasks/segmentation/mixins.py b/pyannote/audio/tasks/segmentation/mixins.py index be30828f0..cf6e3004a 100644 --- a/pyannote/audio/tasks/segmentation/mixins.py +++ b/pyannote/audio/tasks/segmentation/mixins.py @@ -401,7 +401,7 @@ def validation_step(self, batch, batch_idx: int): ) # reshape target so that there is one line per class when plotting it - y[y == 0] = np.NaN + y[y == 0] = np.nan if len(y.shape) == 2: y = y[:, :, np.newaxis] y *= np.arange(y.shape[2]) diff --git a/pyannote/audio/tasks/segmentation/speaker_diarization.py b/pyannote/audio/tasks/segmentation/speaker_diarization.py index 8a091b1f7..fb0b9b979 100644 --- a/pyannote/audio/tasks/segmentation/speaker_diarization.py +++ b/pyannote/audio/tasks/segmentation/speaker_diarization.py @@ -818,7 +818,7 @@ def validation_step(self, batch, batch_idx: int): ) # reshape target so that there is one line per class when plotting it - y[y == 0] = np.NaN + y[y == 0] = np.nan if len(y.shape) == 2: y = y[:, :, np.newaxis] y *= np.arange(y.shape[2]) diff --git a/pyannote/audio/tasks/separation/PixIT.py b/pyannote/audio/tasks/separation/PixIT.py new file mode 100644 index 000000000..cc647ee63 --- /dev/null +++ 
b/pyannote/audio/tasks/separation/PixIT.py @@ -0,0 +1,1179 @@ +# MIT License +# +# Copyright (c) 2024- CNRS +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# AUTHOR: Joonas Kalda (github.com/joonaskalda) + +import itertools +import math +import random +import warnings +from collections import Counter +from functools import partial +from typing import Dict, Literal, Optional, Sequence, Text, Union + +import numpy as np +import torch +import torch.nn.functional +from matplotlib import pyplot as plt +from pyannote.core import Segment, SlidingWindowFeature +from pyannote.database.protocol import SpeakerDiarizationProtocol +from pyannote.database.protocol.protocol import Scope, Subset +from pytorch_lightning.loggers import MLFlowLogger, TensorBoardLogger +from rich.progress import track +from torch.utils.data import DataLoader, IterableDataset +from torch_audiomentations.core.transforms_interface import BaseWaveformTransform +from torchmetrics import Metric + +from pyannote.audio.core.task import Problem, Resolution, Specifications +from pyannote.audio.tasks.segmentation.mixins import SegmentationTask, Task +from pyannote.audio.torchmetrics import ( + OptimalDiarizationErrorRate, + OptimalDiarizationErrorRateThreshold, + OptimalFalseAlarmRate, + OptimalMissedDetectionRate, + OptimalSpeakerConfusionRate, +) +from pyannote.audio.utils.loss import binary_cross_entropy +from pyannote.audio.utils.permutation import permutate +from pyannote.audio.utils.random import create_rng_for_worker + +try: + from asteroid.losses import MixITLossWrapper, multisrc_neg_sisdr + + ASTEROID_IS_AVAILABLE = True +except ImportError: + ASTEROID_IS_AVAILABLE = False + + +Subsets = list(Subset.__args__) +Scopes = list(Scope.__args__) + + +class ValDataset(IterableDataset): + """Validation dataset class + + Val dataset needs to be iterable so that mixture of mixture generation + can be performed in the same way for both training and development. + + Parameters + ---------- + task : PixIT + Task instance. 
+ """ + + def __init__(self, task: Task): + super().__init__() + self.task = task + + def __iter__(self): + return self.task.val__iter__() + + def __len__(self): + return self.task.val__len__() + + +class PixIT(SegmentationTask): + """Joint speaker diarization and speaker separation task based on PixIT + + Parameters + ---------- + protocol : SpeakerDiarizationProtocol + pyannote.database protocol + cache : str, optional + As (meta-)data preparation might take a very long time for large datasets, + it can be cached to disk for later (and faster!) re-use. + When `cache` does not exist, `Task.prepare_data()` generates training + and validation metadata from `protocol` and save them to disk. + When `cache` exists, `Task.prepare_data()` is skipped and (meta)-data + are loaded from disk. Defaults to a temporary path. + duration : float, optional + Chunks duration. Defaults to 5s. + max_speakers_per_chunk : int, optional + Maximum number of speakers per chunk (must be at least 2). + Defaults to estimating it from the training set. + max_speakers_per_frame : int, optional + Maximum number of (overlapping) speakers per frame. + Setting this value to 1 or more enables `powerset multi-class` training. + Default behavior is to use `multi-label` training. + weigh_by_cardinality: bool, optional + Weigh each powerset classes by the size of the corresponding speaker set. + In other words, {0, 1} powerset class weight is 2x bigger than that of {0} + or {1} powerset classes. Note that empty (non-speech) powerset class is + assigned the same weight as mono-speaker classes. Defaults to False (i.e. use + same weight for every class). Has no effect with `multi-label` training. + balance: Sequence[Text], optional + When provided, training samples are sampled uniformly with respect to these keys. + For instance, setting `balance` to ["database","subset"] will make sure that each + database & subset combination will be equally represented in the training samples. + weight: str, optional + When provided, use this key as frame-wise weight in loss function. + batch_size : int, optional + Number of training samples per batch. Defaults to 32. + num_workers : int, optional + Number of workers used for generating training samples. + Defaults to multiprocessing.cpu_count() // 2. + pin_memory : bool, optional + If True, data loaders will copy tensors into CUDA pinned + memory before returning them. See pytorch documentation + for more details. Defaults to False. + augmentation : BaseWaveformTransform, optional + torch_audiomentations waveform transform, used by dataloader + during training. + metric : optional + Validation metric(s). Can be anything supported by torchmetrics.MetricCollection. + Defaults to AUROC (area under the ROC curve). + separation_loss_weight : float, optional + Scaling factor between diarization and separation losses. Defaults to 0.5. + finetune_wavlm : bool, optional + If True, the WavLM feature extractor will be fine-tuned during training. + Defaults to True. + + References + ---------- + Joonas Kalda, Clément Pagés, Ricard Marxer, Tanel Alumäe, and Hervé Bredin. + "PixIT: Joint Training of Speaker Diarization and Speech Separation + from Real-world Multi-speaker Recordings" + Odyssey 2024. 
https://arxiv.org/abs/2403.02288 + """ + + def __init__( + self, + protocol: SpeakerDiarizationProtocol, + cache: Optional[Union[str, None]] = None, + duration: float = 5.0, + max_speakers_per_chunk: Optional[int] = None, + max_speakers_per_frame: Optional[int] = None, + weigh_by_cardinality: bool = False, + balance: Optional[Sequence[Text]] = None, + weight: Optional[Text] = None, + batch_size: int = 32, + num_workers: Optional[int] = None, + pin_memory: bool = False, + augmentation: Optional[BaseWaveformTransform] = None, + metric: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None, + max_num_speakers: Optional[ + int + ] = None, # deprecated in favor of `max_speakers_per_chunk`` + loss: Literal["bce", "mse"] = None, # deprecated + separation_loss_weight: float = 0.5, + finetune_wavlm: bool = True, + ): + if not ASTEROID_IS_AVAILABLE: + raise ImportError( + "'asteroid' must be installed to train separation models with PixIT . " + "`pip install pyannote-audio[separation]` should do the trick." + ) + + super().__init__( + protocol, + duration=duration, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + augmentation=augmentation, + metric=metric, + cache=cache, + ) + + if not isinstance(protocol, SpeakerDiarizationProtocol): + raise ValueError( + "SpeakerDiarization task requires a SpeakerDiarizationProtocol." + ) + + # deprecation warnings + if max_speakers_per_chunk is None and max_num_speakers is not None: + max_speakers_per_chunk = max_num_speakers + warnings.warn( + "`max_num_speakers` has been deprecated in favor of `max_speakers_per_chunk`." + ) + if loss is not None: + warnings.warn("`loss` has been deprecated and has no effect.") + + # parameter validation + if max_speakers_per_frame is not None: + raise NotImplementedError( + "Diarization is done on masks separately which is incompatible powerset training" + ) + + if batch_size % 2 != 0: + raise ValueError("`batch_size` must be divisible by 2 for PixIT") + + self.max_speakers_per_chunk = max_speakers_per_chunk + self.max_speakers_per_frame = max_speakers_per_frame + self.weigh_by_cardinality = weigh_by_cardinality + self.balance = balance + self.weight = weight + self.separation_loss_weight = separation_loss_weight + self.mixit_loss = MixITLossWrapper(multisrc_neg_sisdr, generalized=True) + self.finetune_wavlm = finetune_wavlm + + def setup(self, stage=None): + super().setup(stage) + + # estimate maximum number of speakers per chunk when not provided + if self.max_speakers_per_chunk is None: + training = self.prepared_data["audio-metadata"]["subset"] == Subsets.index( + "train" + ) + + num_unique_speakers = [] + progress_description = f"Estimating maximum number of speakers per {self.duration:g}s chunk in the training set" + for file_id in track( + np.where(training)[0], description=progress_description + ): + annotations = self.prepared_data["annotations-segments"][ + np.where( + self.prepared_data["annotations-segments"]["file_id"] == file_id + )[0] + ] + annotated_regions = self.prepared_data["annotations-regions"][ + np.where( + self.prepared_data["annotations-regions"]["file_id"] == file_id + )[0] + ] + for region in annotated_regions: + # find annotations within current region + region_start = region["start"] + region_end = region["start"] + region["duration"] + region_annotations = annotations[ + np.where( + (annotations["start"] >= region_start) + * (annotations["end"] <= region_end) + )[0] + ] + + for window_start in np.arange( + region_start, region_end - self.duration, 0.25 * 
self.duration + ): + window_end = window_start + self.duration + window_annotations = region_annotations[ + np.where( + (region_annotations["start"] <= window_end) + * (region_annotations["end"] >= window_start) + )[0] + ] + num_unique_speakers.append( + len(np.unique(window_annotations["file_label_idx"])) + ) + + # because there might a few outliers, estimate the upper bound for the + # number of speakers as the 97th percentile + + num_speakers, counts = zip(*list(Counter(num_unique_speakers).items())) + num_speakers, counts = np.array(num_speakers), np.array(counts) + + sorting_indices = np.argsort(num_speakers) + num_speakers = num_speakers[sorting_indices] + counts = counts[sorting_indices] + + ratios = np.cumsum(counts) / np.sum(counts) + + for k, ratio in zip(num_speakers, ratios): + if k == 0: + print(f" - {ratio:7.2%} of all chunks contain no speech at all.") + elif k == 1: + print(f" - {ratio:7.2%} contain 1 speaker or less") + else: + print(f" - {ratio:7.2%} contain {k} speakers or less") + + self.max_speakers_per_chunk = max( + 2, + num_speakers[np.where(ratios > 0.97)[0][0]], + ) + + print( + f"Setting `max_speakers_per_chunk` to {self.max_speakers_per_chunk}. " + f"You can override this value (or avoid this estimation step) by passing `max_speakers_per_chunk={self.max_speakers_per_chunk}` to the task constructor." + ) + + if ( + self.max_speakers_per_frame is not None + and self.max_speakers_per_frame > self.max_speakers_per_chunk + ): + raise ValueError( + f"`max_speakers_per_frame` ({self.max_speakers_per_frame}) must be smaller " + f"than `max_speakers_per_chunk` ({self.max_speakers_per_chunk})" + ) + + # now that we know about the number of speakers upper bound + # we can set task specifications + speaker_diarization = Specifications( + duration=self.duration, + resolution=Resolution.FRAME, + problem=( + Problem.MULTI_LABEL_CLASSIFICATION + if self.max_speakers_per_frame is None + else Problem.MONO_LABEL_CLASSIFICATION + ), + permutation_invariant=True, + classes=[f"speaker#{i+1}" for i in range(self.max_speakers_per_chunk)], + powerset_max_classes=self.max_speakers_per_frame, + ) + + speaker_separation = Specifications( + duration=self.duration, + resolution=Resolution.FRAME, + problem=Problem.MONO_LABEL_CLASSIFICATION, # Doesn't matter + classes=[f"speaker#{i+1}" for i in range(self.max_speakers_per_chunk)], + ) + + self.specifications = (speaker_diarization, speaker_separation) + + def prepare_chunk(self, file_id: int, start_time: float, duration: float): + """Prepare chunk + + Parameters + ---------- + file_id : int + File index + start_time : float + Chunk start time + duration : float + Chunk duration. + + Returns + ------- + sample : dict + Dictionary containing the chunk data with the following keys: + - `X`: waveform + - `y`: target as a SlidingWindowFeature instance where y.labels is + in meta.scope space. 
+ - `meta`: + - `scope`: target scope (0: file, 1: database, 2: global) + - `database`: database index + - `file`: file index + """ + + file = self.get_file(file_id) + + # get label scope + label_scope = Scopes[self.prepared_data["audio-metadata"][file_id]["scope"]] + label_scope_key = f"{label_scope}_label_idx" + + # + chunk = Segment(start_time, start_time + duration) + + sample = dict() + sample["X"], _ = self.model.audio.crop(file, chunk, duration=duration) + + # gather all annotations of current file + annotations = self.prepared_data["annotations-segments"][ + self.prepared_data["annotations-segments"]["file_id"] == file_id + ] + + # gather all annotations with non-empty intersection with current chunk + chunk_annotations = annotations[ + (annotations["start"] < chunk.end) & (annotations["end"] > chunk.start) + ] + + # discretize chunk annotations at model output resolution + step = self.model.receptive_field.step + half = 0.5 * self.model.receptive_field.duration + + start = np.maximum(chunk_annotations["start"], chunk.start) - chunk.start - half + start_idx = np.maximum(0, np.round(start / step)).astype(int) + + end = np.minimum(chunk_annotations["end"], chunk.end) - chunk.start - half + end_idx = np.round(end / step).astype(int) + + # get list and number of labels for current scope + labels = list(np.unique(chunk_annotations[label_scope_key])) + num_labels = len(labels) + + if num_labels > self.max_speakers_per_chunk: + pass + + # initial frame-level targets + num_frames = self.model.num_frames( + round(duration * self.model.hparams.sample_rate) + ) + y = np.zeros((num_frames, num_labels), dtype=np.uint8) + + # map labels to indices + mapping = {label: idx for idx, label in enumerate(labels)} + + for start, end, label in zip( + start_idx, end_idx, chunk_annotations[label_scope_key] + ): + mapped_label = mapping[label] + y[start : end + 1, mapped_label] = 1 + + sample["y"] = SlidingWindowFeature(y, self.model.receptive_field, labels=labels) + + metadata = self.prepared_data["audio-metadata"][file_id] + sample["meta"] = {key: metadata[key] for key in metadata.dtype.names} + sample["meta"]["file"] = file_id + + return sample + + def val_dataloader(self) -> DataLoader: + """Validation data loader + + Returns + ------- + DataLoader + Validation data loader. + """ + return DataLoader( + ValDataset(self), + batch_size=self.batch_size, + num_workers=self.num_workers, + pin_memory=self.pin_memory, + drop_last=True, + collate_fn=partial(self.collate_fn, stage="train"), + ) + + def val__iter__(self): + """Iterate over validation samples + + Yields + ------ + dict: + X: (time, channel) + Audio chunks. + y: (frame, ) + Frame-level targets. Note that frame < time. + `frame` is infered automagically from the + example model output. + ... 
+ """ + + # create worker-specific random number generator + rng = create_rng_for_worker(self.model) + + balance = getattr(self, "balance", None) + if balance is None: + chunks = self.val__iter__helper(rng) + + else: + # create a subchunk generator for each combination of "balance" keys + subchunks = dict() + for product in itertools.product( + [self.metadata_unique_values[key] for key in balance] + ): + filters = {key: value for key, value in zip(balance, product)} + subchunks[product] = self.val__iter__helper(rng, **filters) + + while True: + # select one subchunk generator at random (with uniform probability) + # so that it is balanced on average + if balance is not None: + chunks = subchunks[rng.choice(subchunks)] + + # generate random chunk + yield next(chunks) + + def common__iter__helper(self, split, rng: random.Random, **filters): + """Iterate over samples with optional domain filtering + + Mixtures are paired so that they have no speakers in common, come from the + same file, and the combined number of speakers is no greater than + max_speaker_per_chunk. + + Parameters + ---------- + rng : random.Random + Random number generator + filters : dict, optional + When provided (as {key: value} dict), filter files so that + only files such as file[key] == value are used for generating chunks. + + Yields + ------ + chunk : dict + Chunks. + """ + # indices of files that matches domain filters + split_files = self.prepared_data["audio-metadata"]["subset"] == Subsets.index( + split + ) + for key, value in filters.items(): + split_files &= self.prepared_data["audio-metadata"][ + key + ] == self.prepared_data["metadata"][key].index(value) + file_ids = np.where(split_files)[0] + + # turn annotated duration into a probability distribution + annotated_duration = self.prepared_data["audio-annotated"][file_ids] + cum_prob_annotated_duration = np.cumsum( + annotated_duration / np.sum(annotated_duration) + ) + + duration = self.duration + + num_chunks_per_file = getattr(self, "num_chunks_per_file", 1) + + while True: + # select one file at random (with probability proportional to its annotated duration) + file_id = file_ids[cum_prob_annotated_duration.searchsorted(rng.random())] + annotations = self.prepared_data["annotations-segments"][ + np.where( + self.prepared_data["annotations-segments"]["file_id"] == file_id + )[0] + ] + + # generate `num_chunks_per_file` chunks from this file + for _ in range(num_chunks_per_file): + # find indices of annotated regions in this file + annotated_region_indices = np.where( + self.prepared_data["annotations-regions"]["file_id"] == file_id + )[0] + + # turn annotated regions duration into a probability distribution + + cum_prob_annotated_regions_duration = np.cumsum( + self.prepared_data["annotations-regions"]["duration"][ + annotated_region_indices + ] + / np.sum( + self.prepared_data["annotations-regions"]["duration"][ + annotated_region_indices + ] + ) + ) + + # selected one annotated region at random (with probability proportional to its duration) + annotated_region_index = annotated_region_indices[ + cum_prob_annotated_regions_duration.searchsorted(rng.random()) + ] + + # select one chunk at random in this annotated region + _, region_duration, start = self.prepared_data["annotations-regions"][ + annotated_region_index + ] + start_time = rng.uniform(start, start + region_duration - duration) + + # find speakers that already appeared and all annotations that contain them + chunk_annotations = annotations[ + (annotations["start"] < start_time + duration) + & 
(annotations["end"] > start_time) + ] + previous_speaker_labels = list( + np.unique(chunk_annotations["file_label_idx"]) + ) + repeated_speaker_annotations = annotations[ + np.isin(annotations["file_label_idx"], previous_speaker_labels) + ] + + if repeated_speaker_annotations.size == 0: + # if previous chunk has 0 speakers then just sample from all annotated regions again + first_chunk = self.prepare_chunk(file_id, start_time, duration) + + # selected one annotated region at random (with probability proportional to its duration) + annotated_region_index = np.random.choice( + annotated_region_indices, p=cum_prob_annotated_regions_duration + ) + + # select one chunk at random in this annotated region + _, region_duration, start = self.prepared_data[ + "annotations-regions" + ][annotated_region_index] + start_time = rng.uniform(start, start + region_duration - duration) + + second_chunk = self.prepare_chunk(file_id, start_time, duration) + + labels = first_chunk["y"].labels + second_chunk["y"].labels + + if len(labels) <= self.max_speakers_per_chunk: + yield first_chunk + yield second_chunk + + else: + # merge segments that contain repeated speakers + merged_repeated_segments = [ + [ + repeated_speaker_annotations["start"][0], + repeated_speaker_annotations["end"][0], + ] + ] + for _, start, end, _, _, _ in repeated_speaker_annotations: + previous = merged_repeated_segments[-1] + if start <= previous[1]: + previous[1] = max(previous[1], end) + else: + merged_repeated_segments.append([start, end]) + + # find segments that don't contain repeated speakers + segments_without_repeat = [] + current_region_index = 0 + previous_time = self.prepared_data["annotations-regions"]["start"][ + annotated_region_indices[0] + ] + for segment in merged_repeated_segments: + if ( + segment[0] + > self.prepared_data["annotations-regions"]["start"][ + annotated_region_indices[current_region_index] + ] + + self.prepared_data["annotations-regions"]["duration"][ + annotated_region_indices[current_region_index] + ] + ): + current_region_index += 1 + previous_time = self.prepared_data["annotations-regions"][ + "start" + ][annotated_region_indices[current_region_index]] + + if segment[0] - previous_time > duration: + segments_without_repeat.append( + (previous_time, segment[0], segment[0] - previous_time) + ) + previous_time = segment[1] + + dtype = [("start", "f"), ("end", "f"), ("duration", "f")] + segments_without_repeat = np.array( + segments_without_repeat, dtype=dtype + ) + + if np.sum(segments_without_repeat["duration"]) != 0: + # only yield chunks if it is possible to choose the second chunk so that yielded chunks are always paired + first_chunk = self.prepare_chunk(file_id, start_time, duration) + + prob_segments_duration = segments_without_repeat[ + "duration" + ] / np.sum(segments_without_repeat["duration"]) + segment = np.random.choice( + segments_without_repeat, p=prob_segments_duration + ) + + start, end, _ = segment + new_start_time = rng.uniform(start, end - duration) + second_chunk = self.prepare_chunk( + file_id, new_start_time, duration + ) + + labels = first_chunk["y"].labels + second_chunk["y"].labels + if len(labels) <= self.max_speakers_per_chunk: + yield first_chunk + yield second_chunk + + def val__iter__helper(self, rng: random.Random, **filters): + """Iterate over validation samples with optional domain filtering + + Parameters + ---------- + rng : random.Random + Random number generator + filters : dict, optional + When provided (as {key: value} dict), filter validation files so that + only 
+
+    def val__iter__helper(self, rng: random.Random, **filters):
+        """Iterate over validation samples with optional domain filtering
+
+        Parameters
+        ----------
+        rng : random.Random
+            Random number generator
+        filters : dict, optional
+            When provided (as {key: value} dict), filter validation files so that
+            only files such that file[key] == value are used for generating chunks.
+
+        Yields
+        ------
+        chunk : dict
+            Validation chunks.
+        """
+
+        return self.common__iter__helper("development", rng, **filters)
+
+    def train__iter__helper(self, rng: random.Random, **filters):
+        """Iterate over training samples with optional domain filtering
+
+        Parameters
+        ----------
+        rng : random.Random
+            Random number generator
+        filters : dict, optional
+            When provided (as {key: value} dict), filter training files so that
+            only files such that file[key] == value are used for generating chunks.
+
+        Yields
+        ------
+        chunk : dict
+            Training chunks.
+        """
+
+        return self.common__iter__helper("train", rng, **filters)
+
+    def collate_fn(self, batch, stage="train"):
+        """Collate function used for most segmentation tasks
+
+        This function does the following:
+        * stack waveforms into a (batch_size, num_channels, num_samples) tensor batch["X"]
+        * apply augmentation when in "train" stage
+        * convert targets into a (batch_size, num_frames, num_classes) tensor batch["y"]
+        * collate any other keys that might be present in the batch using pytorch default_collate function
+
+        Parameters
+        ----------
+        batch : list of dict
+            List of training samples.
+
+        Returns
+        -------
+        batch : dict
+            Collated batch as {"X": torch.Tensor, "y": torch.Tensor} dict.
+        """
+
+        # collate X
+        collated_X = self.collate_X(batch)
+
+        # collate y
+        collated_y = self.collate_y(batch)
+
+        # collate metadata
+        collated_meta = self.collate_meta(batch)
+
+        # apply augmentation (only in "train" stage)
+        self.augmentation.train(mode=(stage == "train"))
+        augmented = self.augmentation(
+            samples=collated_X,
+            sample_rate=self.model.hparams.sample_rate,
+            targets=collated_y.unsqueeze(1),
+        )
+
+        return {
+            "X": augmented.samples,
+            "y": augmented.targets.squeeze(1),
+            "meta": collated_meta,
+        }
+
+    def collate_y(self, batch) -> torch.Tensor:
+        """
+
+        Parameters
+        ----------
+        batch : list
+            List of samples to collate.
+            "y" field is expected to be a SlidingWindowFeature.
+
+        Returns
+        -------
+        y : torch.Tensor
+            Collated target tensor of shape (num_frames, self.max_speakers_per_chunk).
+            If one chunk has more than `self.max_speakers_per_chunk` speakers, we keep
+            the max_speakers_per_chunk most talkative ones. If it has fewer, we pad with
+            zeros (artificial inactive speakers).
+        """
+
+        collated_y = []
+        for b in batch:
+            y = b["y"].data
+            num_speakers = len(b["y"].labels)
+            if num_speakers > self.max_speakers_per_chunk:
+                # sort speakers in descending talkativeness order
+                indices = np.argsort(-np.sum(y, axis=0), axis=0)
+                # keep only the most talkative speakers
+                y = y[:, indices[: self.max_speakers_per_chunk]]
+
+                # TODO: we should also sort the speaker labels in the same way
+
+            elif num_speakers < self.max_speakers_per_chunk:
+                # create inactive speakers by zero padding
+                y = np.pad(
+                    y,
+                    ((0, 0), (0, self.max_speakers_per_chunk - num_speakers)),
+                    mode="constant",
+                )
+
+            else:
+                # we have exactly the right number of speakers
+                pass
+
+            collated_y.append(y)
+
+        return torch.from_numpy(np.stack(collated_y))
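`collate_y` forces every target to exactly `max_speakers_per_chunk` columns: extra speakers are dropped in decreasing order of total activity, and missing ones are zero-padded as artificial inactive speakers. A standalone illustration on a made-up target matrix (shapes and values are illustrative only):

```python
import numpy as np

max_speakers_per_chunk = 3

# made-up (num_frames=6, num_speakers=4) binary activity target
y = np.array(
    [
        [1, 0, 1, 0],
        [1, 0, 1, 0],
        [1, 1, 0, 0],
        [0, 1, 0, 0],
        [0, 1, 0, 1],
        [0, 0, 0, 1],
    ],
    dtype=np.uint8,
)

num_speakers = y.shape[1]
if num_speakers > max_speakers_per_chunk:
    # keep the most talkative speakers (largest column sums)
    indices = np.argsort(-np.sum(y, axis=0), axis=0)
    y = y[:, indices[:max_speakers_per_chunk]]
elif num_speakers < max_speakers_per_chunk:
    # pad with artificial, always-inactive speakers
    y = np.pad(y, ((0, 0), (0, max_speakers_per_chunk - num_speakers)))

print(y.shape)  # (6, 3)
```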
+
+    def segmentation_loss(
+        self,
+        permutated_prediction: torch.Tensor,
+        target: torch.Tensor,
+        weight: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """Permutation-invariant segmentation loss
+
+        Parameters
+        ----------
+        permutated_prediction : (batch_size, num_frames, num_classes) torch.Tensor
+            Permutated speaker activity predictions.
+        target : (batch_size, num_frames, num_speakers) torch.Tensor
+            Speaker activity.
+        weight : (batch_size, num_frames, 1) torch.Tensor, optional
+            Frames weight.
+
+        Returns
+        -------
+        seg_loss : torch.Tensor
+            Permutation-invariant segmentation loss
+        """
+
+        seg_loss = binary_cross_entropy(
+            permutated_prediction, target.float(), weight=weight
+        )
+
+        return seg_loss
+
+    def create_mixtures_of_mixtures(self, mix1, mix2, target1, target2):
+        """
+        Creates mixtures of mixtures and corresponding diarization targets.
+        Keeps track of how many speakers came from each mixture in order to
+        reconstruct the original mixtures.
+
+        Parameters
+        ----------
+        mix1 : torch.Tensor
+            First mixture.
+        mix2 : torch.Tensor
+            Second mixture.
+        target1 : torch.Tensor
+            First mixture diarization targets.
+        target2 : torch.Tensor
+            Second mixture diarization targets.
+
+        Returns
+        -------
+        mom : torch.Tensor
+            Mixtures of mixtures.
+        targets : torch.Tensor
+            Diarization targets for mixtures of mixtures.
+        num_active_speakers_mix1 : torch.Tensor
+            Number of active speakers in the first mixture.
+        num_active_speakers_mix2 : torch.Tensor
+            Number of active speakers in the second mixture.
+        """
+        batch_size = mix1.shape[0]
+        mom = mix1 + mix2
+        num_active_speakers_mix1 = (target1.sum(dim=1) != 0).sum(dim=1)
+        num_active_speakers_mix2 = (target2.sum(dim=1) != 0).sum(dim=1)
+        targets = []
+        for i in range(batch_size):
+            target = torch.cat(
+                (
+                    target1[i][:, target1[i].sum(dim=0) != 0],
+                    target2[i][:, target2[i].sum(dim=0) != 0],
+                ),
+                dim=1,
+            )
+            padding_dim = (
+                target1.shape[2]
+                - num_active_speakers_mix1[i]
+                - num_active_speakers_mix2[i]
+            )
+            padding_tensor = torch.zeros(
+                (target1.shape[1], padding_dim), device=target.device
+            )
+            target = torch.cat((target, padding_tensor), dim=1)
+            targets.append(target)
+        targets = torch.stack(targets)
+
+        return mom, targets, num_active_speakers_mix1, num_active_speakers_mix2
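For each pair in a batch, the mixture of mixtures is the sum of the two waveforms, and its diarization target stacks the active speaker columns of both chunks before zero-padding back to the fixed target width. A toy sketch of that target construction (tensor sizes and values are illustrative, not taken from the patch):

```python
import torch

num_frames, max_speakers = 4, 3

# made-up diarization targets for two paired chunks (1 active speaker each)
target1 = torch.tensor(
    [[1, 0, 0], [1, 0, 0], [0, 0, 0], [0, 0, 0]], dtype=torch.float32
).unsqueeze(0)
target2 = torch.tensor(
    [[0, 1, 0], [0, 1, 0], [0, 1, 0], [0, 0, 0]], dtype=torch.float32
).unsqueeze(0)

# keep only the active (non-empty) speaker columns of each chunk ...
active1 = target1[0][:, target1[0].sum(dim=0) != 0]
active2 = target2[0][:, target2[0].sum(dim=0) != 0]

# ... concatenate them, then zero-pad back to `max_speakers` columns
mom_target = torch.cat((active1, active2), dim=1)
padding = torch.zeros((num_frames, max_speakers - mom_target.shape[1]))
mom_target = torch.cat((mom_target, padding), dim=1)

print(mom_target.shape)  # torch.Size([4, 3])
```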
+ """ + + target = batch["y"] + # (batch_size, num_frames, num_speakers) + + waveform = batch["X"] + # (batch_size, num_channels, num_samples) + + # forward pass + bsz = waveform.shape[0] + + # MoMs can't be created for batch size < 2 + if bsz < 2: + return None + # if bsz not even, then leave out last sample + if bsz % 2 != 0: + waveform = waveform[:-1] + + mix1 = waveform[0::2].squeeze(1) + mix2 = waveform[1::2].squeeze(1) + + ( + mom, + mom_target, + _, + _, + ) = self.create_mixtures_of_mixtures(mix1, mix2, target[0::2], target[1::2]) + target = torch.cat((target[0::2], target[1::2], mom_target), dim=0) + + diarization, sources = self.model(torch.cat((mix1, mix2, mom), dim=0)) + mom_sources = sources[bsz:] + + batch_size, num_frames, _ = diarization.shape + # (batch_size, num_frames, num_classes) + + # frames weight + weight_key = getattr(self, "weight", None) + weight = batch.get( + weight_key, + torch.ones(batch_size, num_frames, 1, device=self.model.device), + ) + # (batch_size, num_frames, 1) + + permutated_diarization, _ = permutate(target, diarization) + + seg_loss = self.segmentation_loss(permutated_diarization, target, weight=weight) + + separation_loss = self.mixit_loss( + mom_sources.transpose(1, 2), torch.stack((mix1, mix2)).transpose(0, 1) + ).mean() + + return ( + seg_loss, + separation_loss, + diarization, + permutated_diarization, + target, + ) + + def training_step(self, batch, batch_idx: int): + """Compute PixIT loss for training + + Parameters + ---------- + batch : (usually) dict of torch.Tensor + Current batch. + batch_idx: int + Batch index. + + Returns + ------- + loss : {str: torch.tensor} + {"loss": loss} + """ + # finetuning wavlm with a smaller learning rate requires two optimizers + # and manual gradient stepping + if self.finetune_wavlm: + wavlm_opt, rest_opt = self.model.optimizers() + wavlm_opt.zero_grad() + rest_opt.zero_grad() + + ( + seg_loss, + separation_loss, + diarization, + permutated_diarization, + target, + ) = self.common_step(batch) + self.model.log( + "loss/train/separation", + separation_loss, + on_step=False, + on_epoch=True, + prog_bar=False, + logger=True, + ) + + self.model.log( + "loss/train/segmentation", + seg_loss, + on_step=False, + on_epoch=True, + prog_bar=False, + logger=True, + ) + + loss = ( + 1 - self.separation_loss_weight + ) * seg_loss + self.separation_loss_weight * separation_loss + + # skip batch if something went wrong for some reason + if torch.isnan(loss): + return None + + self.model.log( + "loss/train", + loss, + on_step=False, + on_epoch=True, + prog_bar=False, + logger=True, + ) + + if self.finetune_wavlm: + self.model.manual_backward(loss) + self.model.clip_gradients( + wavlm_opt, + gradient_clip_val=self.model.gradient_clip_val, + gradient_clip_algorithm="norm", + ) + self.model.clip_gradients( + rest_opt, + gradient_clip_val=self.model.gradient_clip_val, + gradient_clip_algorithm="norm", + ) + wavlm_opt.step() + rest_opt.step() + + return {"loss": loss} + + def default_metric( + self, + ) -> Union[Metric, Sequence[Metric], Dict[str, Metric]]: + """Returns diarization error rate and its components""" + + return { + "DiarizationErrorRate": OptimalDiarizationErrorRate(), + "DiarizationErrorRate/Threshold": OptimalDiarizationErrorRateThreshold(), + "DiarizationErrorRate/Confusion": OptimalSpeakerConfusionRate(), + "DiarizationErrorRate/Miss": OptimalMissedDetectionRate(), + "DiarizationErrorRate/FalseAlarm": OptimalFalseAlarmRate(), + } + + # TODO: no need to compute gradient in this method + def 
+
+    # TODO: no need to compute gradient in this method
+    def validation_step(self, batch, batch_idx: int):
+        """Compute validation loss and metric
+
+        Parameters
+        ----------
+        batch : dict of torch.Tensor
+            Current batch.
+        batch_idx: int
+            Batch index.
+        """
+
+        (
+            seg_loss,
+            separation_loss,
+            diarization,
+            permutated_diarization,
+            target,
+        ) = self.common_step(batch)
+
+        self.model.log(
+            "loss/val/separation",
+            separation_loss,
+            on_step=False,
+            on_epoch=True,
+            prog_bar=False,
+            logger=True,
+        )
+
+        self.model.log(
+            "loss/val/segmentation",
+            seg_loss,
+            on_step=False,
+            on_epoch=True,
+            prog_bar=False,
+            logger=True,
+        )
+
+        loss = (
+            1 - self.separation_loss_weight
+        ) * seg_loss + self.separation_loss_weight * separation_loss
+
+        self.model.log(
+            "loss/val",
+            loss,
+            on_step=False,
+            on_epoch=True,
+            prog_bar=False,
+            logger=True,
+        )
+
+        self.model.validation_metric(
+            torch.transpose(diarization, 1, 2),
+            torch.transpose(target, 1, 2),
+        )
+
+        self.model.log_dict(
+            self.model.validation_metric,
+            on_step=False,
+            on_epoch=True,
+            prog_bar=True,
+            logger=True,
+        )
+
+        # log first batch visualization every 2^n epochs.
+        if (
+            self.model.current_epoch == 0
+            or math.log2(self.model.current_epoch) % 1 > 0
+            or batch_idx > 0
+        ):
+            return
+
+        # visualize first 9 validation samples of first batch in Tensorboard/MLflow
+
+        y = target.float().cpu().numpy()
+        y_pred = permutated_diarization.cpu().numpy()
+
+        # prepare 3 x 3 grid (or smaller if batch size is smaller)
+        num_samples = min(self.batch_size, 9)
+        nrows = math.ceil(math.sqrt(num_samples))
+        ncols = math.ceil(num_samples / nrows)
+        fig, axes = plt.subplots(
+            nrows=2 * nrows, ncols=ncols, figsize=(8, 5), squeeze=False
+        )
+
+        # reshape target so that there is one line per class when plotting it
+        y[y == 0] = np.nan
+        if len(y.shape) == 2:
+            y = y[:, :, np.newaxis]
+        y *= np.arange(y.shape[2])
+
+        # plot each sample
+        for sample_idx in range(num_samples):
+            # find where in the grid it should be plotted
+            row_idx = sample_idx // nrows
+            col_idx = sample_idx % ncols
+
+            # plot target
+            ax_ref = axes[row_idx * 2 + 0, col_idx]
+            sample_y = y[sample_idx]
+            ax_ref.plot(sample_y)
+            ax_ref.set_xlim(0, len(sample_y))
+            ax_ref.set_ylim(-1, sample_y.shape[1])
+            ax_ref.get_xaxis().set_visible(False)
+            ax_ref.get_yaxis().set_visible(False)
+
+            # plot predictions
+            ax_hyp = axes[row_idx * 2 + 1, col_idx]
+            sample_y_pred = y_pred[sample_idx]
+            ax_hyp.plot(sample_y_pred)
+            ax_hyp.set_ylim(-0.1, 1.1)
+            ax_hyp.set_xlim(0, len(sample_y))
+            ax_hyp.get_xaxis().set_visible(False)
+
+        plt.tight_layout()
+
+        for logger in self.model.loggers:
+            if isinstance(logger, TensorBoardLogger):
+                logger.experiment.add_figure("samples", fig, self.model.current_epoch)
+            elif isinstance(logger, MLFlowLogger):
+                logger.experiment.log_figure(
+                    run_id=logger.run_id,
+                    figure=fig,
+                    artifact_file=f"samples_epoch{self.model.current_epoch}.png",
+                )
+
+        plt.close(fig)
diff --git a/pyannote/audio/tasks/separation/__init__.py b/pyannote/audio/tasks/separation/__init__.py
new file mode 100644
index 000000000..759cf1321
--- /dev/null
+++ b/pyannote/audio/tasks/separation/__init__.py
@@ -0,0 +1,21 @@
+# MIT License
+#
+# Copyright (c) 2024- CNRS
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
diff --git a/setup.cfg b/setup.cfg
index 5eb8f619d..1ae4b0d6c 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -51,6 +51,9 @@ dev =
 cli =
     hydra-core >=1.1,<1.2
    typer >= 0.4.0,<0.5.0
+separation =
+    transformers >= 4.39.1
+    asteroid >=0.7.0
 
 [options.entry_points]
diff --git a/version.txt b/version.txt
index 944880fa1..15a279981 100644
--- a/version.txt
+++ b/version.txt
@@ -1 +1 @@
-3.2.0
+3.3.0