Skip to content

Commit

Permalink
feat: add support for k-means clustering (#1774)
Browse files Browse the repository at this point in the history
  • Loading branch information
hbredin authored Oct 17, 2024
1 parent 5ae7a48 commit 1d36685
Show file tree
Hide file tree
Showing 6 changed files with 186 additions and 28 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@

### New features

- feat: add support for `k-means` clustering
- feat: add `"hidden"` option to `ProgressHook`
- feat: add `FilterByNumberOfSpeakers` protocol files filter

### Fixes

- fix: fix clipping issue in speech separation pipeline ([@joonaskalda](https://github.com/joonaskalda/))


## Version 3.3.2 (2024-09-11)

### Fixes
Expand Down
82 changes: 77 additions & 5 deletions pyannote/audio/pipelines/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

"""Clustering pipelines"""


import random
from enum import Enum
from typing import Optional, Tuple
Expand All @@ -35,6 +34,7 @@
from scipy.cluster.hierarchy import fcluster, linkage
from scipy.optimize import linear_sum_assignment
from scipy.spatial.distance import cdist
from sklearn.cluster import KMeans

from pyannote.audio.core.io import AudioFile
from pyannote.audio.pipelines.utils import oracle_segmentation
Expand Down Expand Up @@ -264,8 +264,8 @@ def __call__(

train_clusters = self.cluster(
train_embeddings,
min_clusters,
max_clusters,
min_clusters=min_clusters,
max_clusters=max_clusters,
num_clusters=num_clusters,
)

Expand Down Expand Up @@ -298,6 +298,8 @@ class AgglomerativeClustering(BaseClustering):
Minimum cluster size
"""

expects_num_clusters: bool = False

def __init__(
self,
metric: str = "cosine",
Expand All @@ -321,8 +323,8 @@ def __init__(
def cluster(
self,
embeddings: np.ndarray,
min_clusters: int,
max_clusters: int,
min_clusters: Optional[int] = None,
max_clusters: Optional[int] = None,
num_clusters: Optional[int] = None,
):
"""
Expand Down Expand Up @@ -471,9 +473,78 @@ def cluster(
return clusters


class KMeansClustering(BaseClustering):
"""KMeans clustering
Parameters
----------
metric : {"cosine", "euclidean"}, optional
Distance metric to use. Defaults to "cosine".
Hyper-parameters
----------------
None
"""

expects_num_clusters: bool = True

def __init__(
self,
metric: str = "cosine",
):
if metric not in ["cosine", "euclidean"]:
raise ValueError(
f"Unsupported metric: {metric}. Must be 'cosine' or 'euclidean'."
)

super().__init__(metric=metric)

def cluster(
self,
embeddings: np.ndarray,
min_clusters: Optional[int] = None,
max_clusters: Optional[int] = None,
num_clusters: Optional[int] = None,
):
"""Perform KMeans clustering
Parameters
----------
embeddings : (num_embeddings, dimension) array
Embeddings
num_clusters : int, optional
Expected number of clusters.
Returns
-------
clusters : (num_embeddings, ) array
0-indexed cluster indices.
"""

if num_clusters is None:
raise ValueError("`num_clusters` must be provided.")

num_embeddings, _ = embeddings.shape
if num_embeddings < num_clusters:
# one cluster per embedding as int
return np.arange(num_embeddings, dtype=np.int32)

# unit-normalize embeddings to use 'euclidean' distance
if self.metric == "cosine":
with np.errstate(divide="ignore", invalid="ignore"):
embeddings /= np.linalg.norm(embeddings, axis=-1, keepdims=True)

# perform Kmeans clustering
return KMeans(
n_clusters=num_clusters, n_init=3, random_state=42, copy_x=False
).fit_predict(embeddings)


class OracleClustering(BaseClustering):
"""Oracle clustering"""

expects_num_clusters: bool = True

def __call__(
self,
embeddings: Optional[np.ndarray] = None,
Expand Down Expand Up @@ -558,4 +629,5 @@ def __call__(

class Clustering(Enum):
AgglomerativeClustering = AgglomerativeClustering
KMeansClustering = KMeansClustering
OracleClustering = OracleClustering
20 changes: 18 additions & 2 deletions pyannote/audio/pipelines/speaker_diarization.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import math
import textwrap
import warnings
from typing import Callable, Optional, Text, Union
from typing import Callable, Mapping, Optional, Text, Union

import numpy as np
import torch
Expand All @@ -45,6 +45,7 @@
SpeakerDiarizationMixin,
get_model,
)
from pyannote.audio.pipelines.utils.diarization import set_num_speakers
from pyannote.audio.utils.signal import binarize


Expand Down Expand Up @@ -177,6 +178,8 @@ def __init__(
)
self.clustering = Klustering.value(metric=metric)

self._expects_num_speakers = self.clustering.expects_num_clusters

@property
def segmentation_batch_size(self) -> int:
return self._segmentation.batch_size
Expand Down Expand Up @@ -469,12 +472,25 @@ def apply(
# setup hook (e.g. for debugging purposes)
hook = self.setup_hook(file, hook=hook)

num_speakers, min_speakers, max_speakers = self.set_num_speakers(
num_speakers, min_speakers, max_speakers = set_num_speakers(
num_speakers=num_speakers,
min_speakers=min_speakers,
max_speakers=max_speakers,
)

# when using KMeans clustering (or equivalent), the number of speakers must
# be provided alongside the audio file. also, during pipeline training, we
# infer the number of speakers from the reference annotation to avoid the
# pipeline complaining about missing number of speakers.
if self._expects_num_speakers and num_speakers is None:
if isinstance(file, Mapping) and "annotation" in file:
num_speakers = len(file["annotation"].labels())

else:
raise ValueError(
f"num_speakers must be provided when using {self.klustering} clustering"
)

segmentations = self.get_segmentations(file, hook=hook)
hook("segmentation", segmentations)
# shape: (num_chunks, num_frames, local_num_speakers)
Expand Down
3 changes: 2 additions & 1 deletion pyannote/audio/pipelines/speech_separation.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
SpeakerDiarizationMixin,
get_model,
)
from pyannote.audio.pipelines.utils.diarization import set_num_speakers
from pyannote.audio.utils.signal import binarize


Expand Down Expand Up @@ -489,7 +490,7 @@ def apply(
# setup hook (e.g. for debugging purposes)
hook = self.setup_hook(file, hook=hook)

num_speakers, min_speakers, max_speakers = self.set_num_speakers(
num_speakers, min_speakers, max_speakers = set_num_speakers(
num_speakers=num_speakers,
min_speakers=min_speakers,
max_speakers=max_speakers,
Expand Down
58 changes: 43 additions & 15 deletions pyannote/audio/pipelines/utils/diarization.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,44 @@
from pyannote.audio.utils.signal import Binarize


# TODO: move to dedicated module
def set_num_speakers(
num_speakers: Optional[int] = None,
min_speakers: Optional[int] = None,
max_speakers: Optional[int] = None,
):
"""Validate number of speakers
Parameters
----------
num_speakers : int, optional
Number of speakers.
min_speakers : int, optional
Minimum number of speakers.
max_speakers : int, optional
Maximum number of speakers.
Returns
-------
num_speakers : int or None
min_speakers : int
max_speakers : int or np.inf
"""

# override {min|max}_num_speakers by num_speakers when available
min_speakers = num_speakers or min_speakers or 1
max_speakers = num_speakers or max_speakers or np.inf

if min_speakers > max_speakers:
raise ValueError(
f"min_speakers must be smaller than (or equal to) max_speakers "
f"(here: min_speakers={min_speakers:g} and max_speakers={max_speakers:g})."
)
if min_speakers == max_speakers:
num_speakers = min_speakers

return num_speakers, min_speakers, max_speakers


class SpeakerDiarizationMixin:
"""Defines a bunch of methods common to speaker diarization pipelines"""

Expand All @@ -58,20 +95,11 @@ def set_num_speakers(
min_speakers : int
max_speakers : int or np.inf
"""

# override {min|max}_num_speakers by num_speakers when available
min_speakers = num_speakers or min_speakers or 1
max_speakers = num_speakers or max_speakers or np.inf

if min_speakers > max_speakers:
raise ValueError(
f"min_speakers must be smaller than (or equal to) max_speakers "
f"(here: min_speakers={min_speakers:g} and max_speakers={max_speakers:g})."
)
if min_speakers == max_speakers:
num_speakers = min_speakers

return num_speakers, min_speakers, max_speakers
return set_num_speakers(
num_speakers=num_speakers,
min_speakers=min_speakers,
max_speakers=max_speakers,
)

@staticmethod
def optimal_mapping(
Expand Down
48 changes: 44 additions & 4 deletions pyannote/audio/utils/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@
# SOFTWARE.

from functools import partial
from typing import Optional

import torchaudio
from pyannote.core import Annotation
from pyannote.database import FileFinder, Protocol, get_annotated
from pyannote.database.protocol import SpeakerVerificationProtocol

Expand Down Expand Up @@ -62,14 +64,12 @@ def check_protocol(protocol: Protocol) -> Protocol:

# does protocol provide audio keys?
if "audio" not in file:

if "waveform" in file:
if "sample_rate" not in file:
msg = f'Protocol {protocol.name} provides audio with "waveform" key but is missing a "sample_rate" key.'
raise ValueError(msg)

else:

file_finder = FileFinder()
try:
_ = file_finder(file)
Expand All @@ -90,7 +90,6 @@ def check_protocol(protocol: Protocol) -> Protocol:
print(msg)

if "waveform" not in file and "torchaudio.info" not in file:

# use soundfile when available (it usually is faster than ffmpeg for getting info)
backends = (
torchaudio.list_audio_backends()
Expand All @@ -107,7 +106,6 @@ def check_protocol(protocol: Protocol) -> Protocol:
print(msg)

if "annotated" not in file:

if "duration" not in file:
protocol.preprocessors["duration"] = get_duration

Expand Down Expand Up @@ -143,3 +141,45 @@ def check_protocol(protocol: Protocol) -> Protocol:
}

return protocol, checks


class FilterByNumberOfSpeakers:
"""Filter files based on the number of speakers
Note
----
Always returns True if `current_file` does not have an "annotation" key.
"""

def __init__(
self,
num_speakers: Optional[int] = None,
min_speakers: Optional[int] = None,
max_speakers: Optional[int] = None,
):
from pyannote.audio.pipelines.utils.diarization import set_num_speakers

self.num_speakers, self.min_speakers, self.max_speakers = set_num_speakers(
num_speakers=num_speakers,
min_speakers=min_speakers,
max_speakers=max_speakers,
)

def __call__(self, current_file: dict) -> bool:
if "annotation" not in current_file:
return True

annotation: Annotation = current_file["annotation"]
num_speakers: int = len(annotation.labels())

if self.num_speakers is not None and self.num_speakers != num_speakers:
return False

if self.min_speakers is not None and self.min_speakers > num_speakers:
return False

if self.max_speakers is not None and self.max_speakers < num_speakers:
return False

return True

0 comments on commit 1d36685

Please sign in to comment.