diff --git a/pyannote/audio/pipelines/speaker_diarization.py b/pyannote/audio/pipelines/speaker_diarization.py
index edfa5966c..a46425b64 100644
--- a/pyannote/audio/pipelines/speaker_diarization.py
+++ b/pyannote/audio/pipelines/speaker_diarization.py
@@ -27,7 +27,7 @@
 import math
 import textwrap
 import warnings
-from typing import Callable, Optional, Text, Union
+from typing import Callable, Optional, Text, Tuple, Union
 
 import numpy as np
 import torch
@@ -35,10 +35,11 @@
 from pyannote.core import Annotation, SlidingWindowFeature
 from pyannote.metrics.diarization import GreedyDiarizationErrorRate
 from pyannote.pipeline.parameter import ParamDict, Uniform
 
 from pyannote.audio import Audio, Inference, Model, Pipeline
 from pyannote.audio.core.io import AudioFile
-from pyannote.audio.pipelines.clustering import Clustering
+from pyannote.audio.core.task import Problem, Resolution
+from pyannote.audio.pipelines.clustering import AgglomerativeClustering, Clustering
 from pyannote.audio.pipelines.speaker_verification import PretrainedSpeakerEmbedding
 from pyannote.audio.pipelines.utils import (
     PipelineModel,
@@ -647,3 +648,402 @@ def apply(
 
     def get_metric(self) -> GreedyDiarizationErrorRate:
         return GreedyDiarizationErrorRate(**self.der_variant)
+
+
+class SpeakerDiarizationV2(SpeakerDiarizationMixin, Pipeline):
+    """Speaker diarization pipeline with joint segmentation + embedding model
+
+    Parameters
+    ----------
+    model : Model, str, or dict, optional
+        Pretrained (segmentation + embedding) model.
+        See pyannote.audio.pipelines.utils.get_model for supported format.
+    step : float, optional
+        The model is applied on a window sliding over the whole audio file.
+        `step` controls the step of this window, provided as a ratio of its
+        duration. Defaults to 0.1 (i.e. 90% overlap between two consecutive windows).
+    clustering : str, optional
+        Clustering algorithm. See pyannote.audio.pipelines.clustering.Clustering
+        for available options. Defaults to "AgglomerativeClustering".
+    batch_size : int, optional
+        Batch size used for inference. Defaults to 1.
+    use_auth_token : str, optional
+        When loading private huggingface.co models, set `use_auth_token`
+        to True or to a string containing your huggingface.co authentication
+        token that can be obtained by running `huggingface-cli login`.
+
+    Usage
+    -----
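+    # instantiate the pipeline from a joint segmentation + embedding checkpoint
+    # ("<joint-model-checkpoint>" is a placeholder, not an actual model name)
+    >>> pipeline = SpeakerDiarizationV2(model="<joint-model-checkpoint>")
+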
+    # perform (unconstrained) diarization
+    >>> diarization = pipeline("/path/to/audio.wav")
+
+    # perform diarization, targeting exactly 4 speakers
+    >>> diarization = pipeline("/path/to/audio.wav", num_speakers=4)
+
+    # perform diarization, with at least 2 speakers and at most 10 speakers
+    >>> diarization = pipeline("/path/to/audio.wav", min_speakers=2, max_speakers=10)
+
+    # perform diarization and get one representative embedding per speaker
+    >>> diarization, embeddings = pipeline("/path/to/audio.wav", return_embeddings=True)
+    >>> for s, speaker in enumerate(diarization.labels()):
+    ...     # embeddings[s] is the embedding of speaker `speaker`
+
+    """
+
+    def __init__(
+        self,
+        model: PipelineModel = None,
+        step: float = 0.1,
+        clustering: str = "AgglomerativeClustering",
+        batch_size: int = 1,
+        use_auth_token: Union[Text, None] = None,
+    ):
+        super().__init__()
+
+        self.model = model
+        model: Model = get_model(model, use_auth_token=use_auth_token)
+
+        assert len(model.specifications) == 2
+        segmentation_specifications, embedding_specifications = model.specifications
+        # check that specifications are those of a joint segmentation + embedding model
+        assert segmentation_specifications.problem == Problem.MONO_LABEL_CLASSIFICATION
+        assert segmentation_specifications.resolution == Resolution.FRAME
+        assert embedding_specifications.problem == Problem.REPRESENTATION
+        assert embedding_specifications.resolution == Resolution.CHUNK
+
+        self.step = step
+        self.klustering = clustering
+
+        duration: float = segmentation_specifications.duration
+        self._inference = Inference(
+            model,
+            duration=duration,
+            step=self.step * duration,
+            skip_aggregation=True,
+            skip_conversion=False,  # <-- output multilabel segmentation
+            batch_size=batch_size,
+        )
+
+        self.clustering = AgglomerativeClustering(metric="cosine")
+
+    @property
+    def batch_size(self) -> int:
+        return self._inference.batch_size
+
+    @batch_size.setter
+    def batch_size(self, batch_size: int):
+        self._inference.batch_size = batch_size
+
+    def default_parameters(self):
+        raise NotImplementedError()
+
+    def classes(self):
+        speaker = 0
+        while True:
+            yield f"SPEAKER_{speaker:02d}"
+            speaker += 1
+
+    @property
+    def CACHED_INFERENCE(self):
+        return "training_cache/inference"
+
+    def get_inference(self, file, hook=None) -> Tuple[SlidingWindowFeature, SlidingWindowFeature]:
+        """Apply joint segmentation + embedding model
+
+        Parameters
+        ----------
+        file : AudioFile
+        hook : callable, optional
+
+        Returns
+        -------
+        segmentations : (num_chunks, num_frames, num_speakers) SlidingWindowFeature
+        embeddings : (num_chunks, num_speakers, dimension) SlidingWindowFeature
+        """
+
+        if hook is not None:
+            hook = functools.partial(hook, "inference", None)
+
+        if self.training:
+            if self.CACHED_INFERENCE in file:
+                inference = file[self.CACHED_INFERENCE]
+            else:
+                inference = self._inference(file, hook=hook)
+                file[self.CACHED_INFERENCE] = inference
+        else:
+            inference = self._inference(file, hook=hook)
+
+        return inference
+
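+    # Illustrative sketch (not part of the pipeline): `get_inference` returns
+    # the pair of SlidingWindowFeature instances documented above, e.g.
+    #
+    #     segmentations, embeddings = pipeline.get_inference({"audio": "file.wav"})
+    #     segmentations.data.shape  # (num_chunks, num_frames, local_num_speakers)
+    #     embeddings.data.shape     # (num_chunks, local_num_speakers, dimension)
+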
+    def reconstruct(
+        self,
+        segmentations: SlidingWindowFeature,
+        hard_clusters: np.ndarray,
+        count: SlidingWindowFeature,
+    ) -> SlidingWindowFeature:
+        """Build final discrete diarization out of clustered segmentation
+
+        Parameters
+        ----------
+        segmentations : (num_chunks, num_frames, num_speakers) SlidingWindowFeature
+            Raw speaker segmentation.
+        hard_clusters : (num_chunks, num_speakers) array
+            Output of clustering step.
+        count : (total_num_frames, 1) SlidingWindowFeature
+            Instantaneous number of active speakers.
+
+        Returns
+        -------
+        discrete_diarization : SlidingWindowFeature
+            Discrete (0s and 1s) diarization.
+        """
+
+        num_chunks, num_frames, local_num_speakers = segmentations.data.shape
+
+        num_clusters = np.max(hard_clusters) + 1
+        clustered_segmentations = np.nan * np.zeros(
+            (num_chunks, num_frames, num_clusters)
+        )
+
+        for c, (cluster, (chunk, segmentation)) in enumerate(
+            zip(hard_clusters, segmentations)
+        ):
+            # cluster is (local_num_speakers, )-shaped
+            # segmentation is (num_frames, local_num_speakers)-shaped
+            for k in np.unique(cluster):
+                if k == -2:
+                    continue
+
+                # TODO: can we do better than this max here?
+                clustered_segmentations[c, :, k] = np.max(
+                    segmentation[:, cluster == k], axis=1
+                )
+
+        clustered_segmentations = SlidingWindowFeature(
+            clustered_segmentations, segmentations.sliding_window
+        )
+
+        return self.to_diarization(clustered_segmentations, count)
+
+    def apply(
+        self,
+        file: AudioFile,
+        num_speakers: Optional[int] = None,
+        min_speakers: Optional[int] = None,
+        max_speakers: Optional[int] = None,
+        return_embeddings: bool = False,
+        hook: Optional[Callable] = None,
+    ) -> Union[Annotation, Tuple[Annotation, np.ndarray]]:
+        """Apply speaker diarization
+
+        Parameters
+        ----------
+        file : AudioFile
+            Processed file.
+        num_speakers : int, optional
+            Number of speakers, when known.
+        min_speakers : int, optional
+            Minimum number of speakers. Has no effect when `num_speakers` is provided.
+        max_speakers : int, optional
+            Maximum number of speakers. Has no effect when `num_speakers` is provided.
+        return_embeddings : bool, optional
+            Return representative speaker embeddings.
+        hook : callable, optional
+            Callback called after each major step of the pipeline as follows:
+                hook(step_name,      # human-readable name of current step
+                     step_artefact,  # artifact generated by current step
+                     file=file)      # file being processed
+            Time-consuming steps call `hook` multiple times with the same `step_name`
+            and additional `completed` and `total` keyword arguments usable to track
+            progress of the current step.
+
+        Returns
+        -------
+        diarization : Annotation
+            Speaker diarization.
+        embeddings : np.ndarray, optional
+            Representative speaker embeddings such that `embeddings[i]` is the
+            speaker embedding for the i-th speaker in `diarization.labels()`.
+            Only returned when `return_embeddings` is True.
+        """
+
+        # setup hook (e.g. for debugging purposes)
+        hook = self.setup_hook(file, hook=hook)
+
+        num_speakers, min_speakers, max_speakers = self.set_num_speakers(
+            num_speakers=num_speakers,
+            min_speakers=min_speakers,
+            max_speakers=max_speakers,
+        )
+
+        inference = self.get_inference(file, hook=hook)
+        hook("inference", inference)
+        binarized_segmentations, embeddings = inference
+        # shape: (num_chunks, num_frames, local_num_speakers)
+        num_chunks, num_frames, local_num_speakers = binarized_segmentations.data.shape
+        _, _, dimension = embeddings.data.shape
+
+        # estimate frame-level number of instantaneous speakers
+        count = self.speaker_count(
+            binarized_segmentations,
+            self._inference.model.receptive_field,
+            warm_up=(0.0, 0.0),
+        )
+        hook("speaker_counting", count)
+        # shape: (num_frames, 1)
+        # dtype: int
+
+        # exit early when no speaker is ever active
+        if np.nanmax(count.data) == 0.0:
+            diarization = Annotation(uri=file["uri"])
+            if return_embeddings:
+                return diarization, np.zeros((0, dimension))
+
+            return diarization
+
+        hard_clusters, _, centroids = self.clustering(
+            embeddings=embeddings.data,
+            segmentations=binarized_segmentations,
+            num_clusters=num_speakers,
+            min_clusters=min_speakers,
+            max_clusters=max_speakers,
+            file=file,  # <== for oracle clustering
+            frames=self._inference.model.receptive_field,  # <== for oracle clustering
+        )
+        # hard_clusters: (num_chunks, num_speakers)
+        # centroids: (num_speakers, dimension)
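+        # illustrative example (values are hypothetical): with 2 chunks and
+        # 2 local speakers each, hard_clusters = [[0, 1],
+        #                                         [1, 0]]
+        # means local speaker 0 of chunk 0 and local speaker 1 of chunk 1
+        # were assigned to the same global speaker 0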
+ if ( + num_different_speakers < min_speakers + or num_different_speakers > max_speakers + ): + warnings.warn( + textwrap.dedent( + f""" + The detected number of speakers ({num_different_speakers}) is outside + the given bounds [{min_speakers}, {max_speakers}]. This can happen if the + given audio file is too short to contain {min_speakers} or more speakers. + Try to lower the desired minimal number of speakers. + """ + ) + ) + + # during counting, we could possibly overcount the number of instantaneous + # speakers due to segmentation errors, so we cap the maximum instantaneous number + # of speakers by the `max_speakers` value + count.data = np.minimum(count.data, max_speakers).astype(np.int8) + + # reconstruct discrete diarization from raw hard clusters + + # keep track of inactive speakers + inactive_speakers = np.sum(binarized_segmentations.data, axis=1) == 0 + # shape: (num_chunks, num_speakers) + + hard_clusters[inactive_speakers] = -2 + discrete_diarization = self.reconstruct( + binarized_segmentations, + hard_clusters, + count, + ) + hook("discrete_diarization", discrete_diarization) + + # convert to continuous diarization + diarization = self.to_annotation( + discrete_diarization, + min_duration_on=0.0, + min_duration_off=0.0, + ) + diarization.uri = file["uri"] + + # at this point, `diarization` speaker labels are integers + # from 0 to `num_speakers - 1`, aligned with `centroids` rows. + + if "annotation" in file and file["annotation"]: + # when reference is available, use it to map hypothesized speakers + # to reference speakers (this makes later error analysis easier + # but does not modify the actual output of the diarization pipeline) + _, mapping = self.optimal_mapping( + file["annotation"], diarization, return_mapping=True + ) + + # in case there are more speakers in the hypothesis than in + # the reference, those extra speakers are missing from `mapping`. + # we add them back here + mapping = {key: mapping.get(key, key) for key in diarization.labels()} + + else: + # when reference is not available, rename hypothesized speakers + # to human-readable SPEAKER_00, SPEAKER_01, ... + mapping = { + label: expected_label + for label, expected_label in zip(diarization.labels(), self.classes()) + } + + diarization = diarization.rename_labels(mapping=mapping) + + # at this point, `diarization` speaker labels are strings (or mix of + # strings and integers when reference is available and some hypothesis + # speakers are not present in the reference) + + if not return_embeddings: + return diarization + + # this can happen when we use OracleClustering + if centroids is None: + return diarization, None + + # The number of centroids may be smaller than the number of speakers + # in the annotation. This can happen if the number of active speakers + # obtained from `speaker_count` for some frames is larger than the number + # of clusters obtained from `clustering`. In this case, we append zero embeddings + # for extra speakers + if len(diarization.labels()) > centroids.shape[0]: + centroids = np.pad( + centroids, ((0, len(diarization.labels()) - centroids.shape[0]), (0, 0)) + ) + + # re-order centroids so that they match + # the order given by diarization.labels() + inverse_mapping = {label: index for index, label in mapping.items()} + centroids = centroids[ + [inverse_mapping[label] for label in diarization.labels()] + ] + + return diarization, centroids + + def get_metric(self) -> GreedyDiarizationErrorRate: + return GreedyDiarizationErrorRate(**self.der_variant)