diff --git a/CHANGELOG.md b/CHANGELOG.md index 992a432c0..100ef7278 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,13 +4,14 @@ ### New features +- feat: add support for `k-means` clustering - feat: add `"hidden"` option to `ProgressHook` +- feat: add `FilterByNumberOfSpeakers` protocol files filter ### Fixes - fix: fix clipping issue in speech separation pipeline ([@joonaskalda](https://github.com/joonaskalda/)) - ## Version 3.3.2 (2024-09-11) ### Fixes diff --git a/pyannote/audio/pipelines/clustering.py b/pyannote/audio/pipelines/clustering.py index cd4b38935..18b3b527c 100644 --- a/pyannote/audio/pipelines/clustering.py +++ b/pyannote/audio/pipelines/clustering.py @@ -22,7 +22,6 @@ """Clustering pipelines""" - import random from enum import Enum from typing import Optional, Tuple @@ -35,6 +34,7 @@ from scipy.cluster.hierarchy import fcluster, linkage from scipy.optimize import linear_sum_assignment from scipy.spatial.distance import cdist +from sklearn.cluster import KMeans from pyannote.audio.core.io import AudioFile from pyannote.audio.pipelines.utils import oracle_segmentation @@ -264,8 +264,8 @@ def __call__( train_clusters = self.cluster( train_embeddings, - min_clusters, - max_clusters, + min_clusters=min_clusters, + max_clusters=max_clusters, num_clusters=num_clusters, ) @@ -298,6 +298,8 @@ class AgglomerativeClustering(BaseClustering): Minimum cluster size """ + expects_num_clusters: bool = False + def __init__( self, metric: str = "cosine", @@ -321,8 +323,8 @@ def __init__( def cluster( self, embeddings: np.ndarray, - min_clusters: int, - max_clusters: int, + min_clusters: Optional[int] = None, + max_clusters: Optional[int] = None, num_clusters: Optional[int] = None, ): """ @@ -471,9 +473,78 @@ def cluster( return clusters +class KMeansClustering(BaseClustering): + """KMeans clustering + + Parameters + ---------- + metric : {"cosine", "euclidean"}, optional + Distance metric to use. Defaults to "cosine". + + Hyper-parameters + ---------------- + None + """ + + expects_num_clusters: bool = True + + def __init__( + self, + metric: str = "cosine", + ): + if metric not in ["cosine", "euclidean"]: + raise ValueError( + f"Unsupported metric: {metric}. Must be 'cosine' or 'euclidean'." + ) + + super().__init__(metric=metric) + + def cluster( + self, + embeddings: np.ndarray, + min_clusters: Optional[int] = None, + max_clusters: Optional[int] = None, + num_clusters: Optional[int] = None, + ): + """Perform KMeans clustering + + Parameters + ---------- + embeddings : (num_embeddings, dimension) array + Embeddings + num_clusters : int, optional + Expected number of clusters. + + Returns + ------- + clusters : (num_embeddings, ) array + 0-indexed cluster indices. + """ + + if num_clusters is None: + raise ValueError("`num_clusters` must be provided.") + + num_embeddings, _ = embeddings.shape + if num_embeddings < num_clusters: + # one cluster per embedding as int + return np.arange(num_embeddings, dtype=np.int32) + + # unit-normalize embeddings to use 'euclidean' distance + if self.metric == "cosine": + with np.errstate(divide="ignore", invalid="ignore"): + embeddings /= np.linalg.norm(embeddings, axis=-1, keepdims=True) + + # perform Kmeans clustering + return KMeans( + n_clusters=num_clusters, n_init=3, random_state=42, copy_x=False + ).fit_predict(embeddings) + + class OracleClustering(BaseClustering): """Oracle clustering""" + expects_num_clusters: bool = True + def __call__( self, embeddings: Optional[np.ndarray] = None, @@ -558,4 +629,5 @@ def __call__( class Clustering(Enum): AgglomerativeClustering = AgglomerativeClustering + KMeansClustering = KMeansClustering OracleClustering = OracleClustering diff --git a/pyannote/audio/pipelines/speaker_diarization.py b/pyannote/audio/pipelines/speaker_diarization.py index edfa5966c..e0d43e30c 100644 --- a/pyannote/audio/pipelines/speaker_diarization.py +++ b/pyannote/audio/pipelines/speaker_diarization.py @@ -27,7 +27,7 @@ import math import textwrap import warnings -from typing import Callable, Optional, Text, Union +from typing import Callable, Mapping, Optional, Text, Union import numpy as np import torch @@ -45,6 +45,7 @@ SpeakerDiarizationMixin, get_model, ) +from pyannote.audio.pipelines.utils.diarization import set_num_speakers from pyannote.audio.utils.signal import binarize @@ -177,6 +178,8 @@ def __init__( ) self.clustering = Klustering.value(metric=metric) + self._expects_num_speakers = self.clustering.expects_num_clusters + @property def segmentation_batch_size(self) -> int: return self._segmentation.batch_size @@ -469,12 +472,25 @@ def apply( # setup hook (e.g. for debugging purposes) hook = self.setup_hook(file, hook=hook) - num_speakers, min_speakers, max_speakers = self.set_num_speakers( + num_speakers, min_speakers, max_speakers = set_num_speakers( num_speakers=num_speakers, min_speakers=min_speakers, max_speakers=max_speakers, ) + # when using KMeans clustering (or equivalent), the number of speakers must + # be provided alongside the audio file. also, during pipeline training, we + # infer the number of speakers from the reference annotation to avoid the + # pipeline complaining about missing number of speakers. + if self._expects_num_speakers and num_speakers is None: + if isinstance(file, Mapping) and "annotation" in file: + num_speakers = len(file["annotation"].labels()) + + else: + raise ValueError( + f"num_speakers must be provided when using {self.klustering} clustering" + ) + segmentations = self.get_segmentations(file, hook=hook) hook("segmentation", segmentations) # shape: (num_chunks, num_frames, local_num_speakers) diff --git a/pyannote/audio/pipelines/speech_separation.py b/pyannote/audio/pipelines/speech_separation.py index 3dac1b1f5..dacb637b1 100644 --- a/pyannote/audio/pipelines/speech_separation.py +++ b/pyannote/audio/pipelines/speech_separation.py @@ -45,6 +45,7 @@ SpeakerDiarizationMixin, get_model, ) +from pyannote.audio.pipelines.utils.diarization import set_num_speakers from pyannote.audio.utils.signal import binarize @@ -489,7 +490,7 @@ def apply( # setup hook (e.g. for debugging purposes) hook = self.setup_hook(file, hook=hook) - num_speakers, min_speakers, max_speakers = self.set_num_speakers( + num_speakers, min_speakers, max_speakers = set_num_speakers( num_speakers=num_speakers, min_speakers=min_speakers, max_speakers=max_speakers, diff --git a/pyannote/audio/pipelines/utils/diarization.py b/pyannote/audio/pipelines/utils/diarization.py index 5a0f8f675..2130fb961 100644 --- a/pyannote/audio/pipelines/utils/diarization.py +++ b/pyannote/audio/pipelines/utils/diarization.py @@ -31,7 +31,44 @@ from pyannote.audio.utils.signal import Binarize -# TODO: move to dedicated module +def set_num_speakers( + num_speakers: Optional[int] = None, + min_speakers: Optional[int] = None, + max_speakers: Optional[int] = None, +): + """Validate number of speakers + + Parameters + ---------- + num_speakers : int, optional + Number of speakers. + min_speakers : int, optional + Minimum number of speakers. + max_speakers : int, optional + Maximum number of speakers. + + Returns + ------- + num_speakers : int or None + min_speakers : int + max_speakers : int or np.inf + """ + + # override {min|max}_num_speakers by num_speakers when available + min_speakers = num_speakers or min_speakers or 1 + max_speakers = num_speakers or max_speakers or np.inf + + if min_speakers > max_speakers: + raise ValueError( + f"min_speakers must be smaller than (or equal to) max_speakers " + f"(here: min_speakers={min_speakers:g} and max_speakers={max_speakers:g})." + ) + if min_speakers == max_speakers: + num_speakers = min_speakers + + return num_speakers, min_speakers, max_speakers + + class SpeakerDiarizationMixin: """Defines a bunch of methods common to speaker diarization pipelines""" @@ -58,20 +95,11 @@ def set_num_speakers( min_speakers : int max_speakers : int or np.inf """ - - # override {min|max}_num_speakers by num_speakers when available - min_speakers = num_speakers or min_speakers or 1 - max_speakers = num_speakers or max_speakers or np.inf - - if min_speakers > max_speakers: - raise ValueError( - f"min_speakers must be smaller than (or equal to) max_speakers " - f"(here: min_speakers={min_speakers:g} and max_speakers={max_speakers:g})." - ) - if min_speakers == max_speakers: - num_speakers = min_speakers - - return num_speakers, min_speakers, max_speakers + return set_num_speakers( + num_speakers=num_speakers, + min_speakers=min_speakers, + max_speakers=max_speakers, + ) @staticmethod def optimal_mapping( diff --git a/pyannote/audio/utils/protocol.py b/pyannote/audio/utils/protocol.py index bca0e5942..ba716d867 100644 --- a/pyannote/audio/utils/protocol.py +++ b/pyannote/audio/utils/protocol.py @@ -21,8 +21,10 @@ # SOFTWARE. from functools import partial +from typing import Optional import torchaudio +from pyannote.core import Annotation from pyannote.database import FileFinder, Protocol, get_annotated from pyannote.database.protocol import SpeakerVerificationProtocol @@ -62,14 +64,12 @@ def check_protocol(protocol: Protocol) -> Protocol: # does protocol provide audio keys? if "audio" not in file: - if "waveform" in file: if "sample_rate" not in file: msg = f'Protocol {protocol.name} provides audio with "waveform" key but is missing a "sample_rate" key.' raise ValueError(msg) else: - file_finder = FileFinder() try: _ = file_finder(file) @@ -90,7 +90,6 @@ def check_protocol(protocol: Protocol) -> Protocol: print(msg) if "waveform" not in file and "torchaudio.info" not in file: - # use soundfile when available (it usually is faster than ffmpeg for getting info) backends = ( torchaudio.list_audio_backends() @@ -107,7 +106,6 @@ def check_protocol(protocol: Protocol) -> Protocol: print(msg) if "annotated" not in file: - if "duration" not in file: protocol.preprocessors["duration"] = get_duration @@ -143,3 +141,45 @@ def check_protocol(protocol: Protocol) -> Protocol: } return protocol, checks + + +class FilterByNumberOfSpeakers: + """Filter files based on the number of speakers + + Note + ---- + Always returns True if `current_file` does not have an "annotation" key. + + """ + + def __init__( + self, + num_speakers: Optional[int] = None, + min_speakers: Optional[int] = None, + max_speakers: Optional[int] = None, + ): + from pyannote.audio.pipelines.utils.diarization import set_num_speakers + + self.num_speakers, self.min_speakers, self.max_speakers = set_num_speakers( + num_speakers=num_speakers, + min_speakers=min_speakers, + max_speakers=max_speakers, + ) + + def __call__(self, current_file: dict) -> bool: + if "annotation" not in current_file: + return True + + annotation: Annotation = current_file["annotation"] + num_speakers: int = len(annotation.labels()) + + if self.num_speakers is not None and self.num_speakers != num_speakers: + return False + + if self.min_speakers is not None and self.min_speakers > num_speakers: + return False + + if self.max_speakers is not None and self.max_speakers < num_speakers: + return False + + return True