wip: add pipeline working with joint model
hbredin committed Jul 8, 2024
1 parent f484033 commit 8608a1c
Showing 1 changed file with 375 additions and 2 deletions.
377 changes: 375 additions & 2 deletions pyannote/audio/pipelines/speaker_diarization.py
@@ -27,18 +27,20 @@
import math
import textwrap
import warnings
from typing import Callable, Optional, Text, Union
from typing import Callable, Optional, Text, Tuple, Union

import numpy as np
import torch
from einops import rearrange
from pyannote.core import Annotation, SlidingWindowFeature
from pyannote.metrics.diarization import GreedyDiarizationErrorRate
from pyannote.pipeline.parameter import ParamDict, Uniform
from sklearn.cluster import AgglomerativeClustering

from pyannote.audio import Audio, Inference, Model, Pipeline
from pyannote.audio.core.io import AudioFile
from pyannote.audio.pipelines.clustering import Clustering
from pyannote.audio.core.task import Problem, Resolution
from pyannote.audio.pipelines.clustering import AgglomerativeClustering, Clustering
from pyannote.audio.pipelines.speaker_verification import PretrainedSpeakerEmbedding
from pyannote.audio.pipelines.utils import (
PipelineModel,
@@ -647,3 +649,374 @@ def apply(

def get_metric(self) -> GreedyDiarizationErrorRate:
return GreedyDiarizationErrorRate(**self.der_variant)


class SpeakerDiarizationV2(SpeakerDiarizationMixin, Pipeline):
"""Speaker diarization pipeline with joint segmentation + embedding model
Parameters
----------
model : Model, str, or dict, optional
Pretrained (segmentation + embedding) model.
See pyannote.audio.pipelines.utils.get_model for supported format.
step: float, optional
The model is applied on a window sliding over the whole audio file.
`step` controls the step of this window, provided as a ratio of its
        duration. Defaults to 0.1 (i.e. 90% overlap between two consecutive windows).
clustering : str, optional
Clustering algorithm. See pyannote.audio.pipelines.clustering.Clustering
for available options. Defaults to "AgglomerativeClustering".
batch_size : int, optional
Batch size used for inference. Defaults to 1.
use_auth_token : str, optional
        When loading private huggingface.co models, set `use_auth_token`
        to True or to a string containing your huggingface.co authentication
        token, which can be obtained by running `huggingface-cli login`.

    Usage
-----
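    # instantiate the pipeline from a joint (segmentation + embedding) checkpoint
    # (hypothetical checkpoint name, for illustration only)
    >>> pipeline = SpeakerDiarizationV2(model="pyannote/joint-diarization")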
# perform (unconstrained) diarization
>>> diarization = pipeline("/path/to/audio.wav")
    # perform diarization, targeting exactly 4 speakers
>>> diarization = pipeline("/path/to/audio.wav", num_speakers=4)
# perform diarization, with at least 2 speakers and at most 10 speakers
>>> diarization = pipeline("/path/to/audio.wav", min_speakers=2, max_speakers=10)
# perform diarization and get one representative embedding per speaker
>>> diarization, embeddings = pipeline("/path/to/audio.wav", return_embeddings=True)
>>> for s, speaker in enumerate(diarization.labels()):
... # embeddings[s] is the embedding of speaker `speaker`
"""

def __init__(
self,
model: PipelineModel = None,
step: float = 0.1,
clustering: str = "AgglomerativeClustering",
batch_size: int = 1,
use_auth_token: Union[Text, None] = None,
):
super().__init__()

self.model = model
model: Model = get_model(model, use_auth_token=use_auth_token)

assert len(model.specifications) == 2
segmentation_specifications, embedding_specifications = model.specifications
# TODO: check that specs are correct
assert segmentation_specifications.problem == Problem.MONO_LABEL_CLASSIFICATION
assert segmentation_specifications.resolution == Resolution.FRAME
assert embedding_specifications.problem == Problem.REPRESENTATION
assert embedding_specifications.resolution == Resolution.CHUNK
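        # i.e. the joint model must expose two heads: a frame-resolution
        # segmentation head and a chunk-resolution embedding head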

self.step = step
self.klustering = clustering

duration: float = segmentation_specifications.duration
self._inference = Inference(
model,
duration=duration,
step=self.step * duration,
skip_aggregation=True,
skip_conversion=False, # <-- output multilabel segmentation
batch_size=batch_size,
)
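        # worked example (illustrative): for a model trained on 10s chunks
        # (duration=10.0) and step=0.1, inference hops 10.0 * 0.1 = 1s at a
        # time, i.e. 90% overlap between two consecutive windows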

self.clustering = AgglomerativeClustering(metric="cosine")

@property
def batch_size(self) -> int:
return self._inference.batch_size

@batch_size.setter
def batch_size(self, batch_size: int):
self._inference.batch_size = batch_size

def default_parameters(self):
raise NotImplementedError()

def classes(self):
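        """Generate human-readable speaker labels: SPEAKER_00, SPEAKER_01, ..."""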
speaker = 0
while True:
yield f"SPEAKER_{speaker:02d}"
speaker += 1

@property
def CACHED_INFERENCE(self):
return "training_cache/inference"

    def get_inference(self, file, hook=None) -> Tuple[SlidingWindowFeature, SlidingWindowFeature]:
        """Apply joint segmentation + embedding model

        Parameters
        ----------
        file : AudioFile
            Processed file.
        hook : Optional[Callable]
            Callback used to report inference progress.

        Returns
        -------
segmentations : (num_chunks, num_frames, num_speakers) SlidingWindowFeature
embeddings : (num_chunks, num_speakers, dimension) SlidingWindowFeature
"""

if hook is not None:
hook = functools.partial(hook, "inference", None)

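        # while the pipeline itself is being trained (i.e. its hyper-parameters
        # are being optimized), model output does not change between runs, so
        # it is computed once and cached into the file dictionary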
if self.training:
if self.CACHED_INFERENCE in file:
inference = file[self.CACHED_INFERENCE]
else:
inference = self._inference(file, hook=hook)
file[self.CACHED_INFERENCE] = inference
else:
inference = self._inference(file, hook=hook)

return inference

def reconstruct(
self,
segmentations: SlidingWindowFeature,
hard_clusters: np.ndarray,
count: SlidingWindowFeature,
) -> SlidingWindowFeature:
"""Build final discrete diarization out of clustered segmentation
Parameters
----------
segmentations : (num_chunks, num_frames, num_speakers) SlidingWindowFeature
Raw speaker segmentation.
hard_clusters : (num_chunks, num_speakers) array
Output of clustering step.
count : (total_num_frames, 1) SlidingWindowFeature
            Instantaneous number of active speakers.

        Returns
-------
discrete_diarization : SlidingWindowFeature
Discrete (0s and 1s) diarization.
"""

num_chunks, num_frames, local_num_speakers = segmentations.data.shape

num_clusters = np.max(hard_clusters) + 1
clustered_segmentations = np.nan * np.zeros(
(num_chunks, num_frames, num_clusters)
)
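        # clusters that never appear in a given chunk keep their NaN entries
        # (expected to be ignored during downstream aggregation)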

for c, (cluster, (chunk, segmentation)) in enumerate(
zip(hard_clusters, segmentations)
):
# cluster is (local_num_speakers, )-shaped
# segmentation is (num_frames, local_num_speakers)-shaped
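            # cluster index -2 marks a local speaker flagged as inactive in
            # `apply` (no active frame in its chunk) and is skipped below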
for k in np.unique(cluster):
if k == -2:
continue

# TODO: can we do better than this max here?
clustered_segmentations[c, :, k] = np.max(
segmentation[:, cluster == k], axis=1
)

clustered_segmentations = SlidingWindowFeature(
clustered_segmentations, segmentations.sliding_window
)

return self.to_diarization(clustered_segmentations, count)

def apply(
self,
file: AudioFile,
num_speakers: Optional[int] = None,
min_speakers: Optional[int] = None,
max_speakers: Optional[int] = None,
return_embeddings: bool = False,
hook: Optional[Callable] = None,
) -> Union[Annotation, Tuple[Annotation, np.ndarray]]:
"""Apply speaker diarization
Parameters
----------
file : AudioFile
Processed file.
num_speakers : int, optional
Number of speakers, when known.
min_speakers : int, optional
Minimum number of speakers. Has no effect when `num_speakers` is provided.
max_speakers : int, optional
Maximum number of speakers. Has no effect when `num_speakers` is provided.
return_embeddings : bool, optional
Return representative speaker embeddings.
hook : callable, optional
            Callback called after each major step of the pipeline as follows:
                hook(step_name,      # human-readable name of current step
                     step_artefact,  # artifact generated by current step
                     file=file)      # file being processed
            Time-consuming steps call `hook` multiple times with the same
            `step_name` and additional `completed` and `total` keyword
            arguments usable to track progress of the current step.

        Returns
-------
diarization : Annotation
Speaker diarization
        embeddings : np.ndarray, optional
            Representative speaker embeddings such that `embeddings[i]` is the
            speaker embedding of the i-th speaker in `diarization.labels()`.
            Only returned when `return_embeddings` is True.
"""

# setup hook (e.g. for debugging purposes)
hook = self.setup_hook(file, hook=hook)

num_speakers, min_speakers, max_speakers = self.set_num_speakers(
num_speakers=num_speakers,
min_speakers=min_speakers,
max_speakers=max_speakers,
)

inference = self.get_inference(file, hook=hook)
hook("inference", inference)
binarized_segmentations, embeddings = inference
# shape: (num_chunks, num_frames, local_num_speakers)
num_chunks, num_frames, local_num_speakers = binarized_segmentations.data.shape
_, _, dimension = embeddings.data.shape

# estimate frame-level number of instantaneous speakers
count = self.speaker_count(
binarized_segmentations,
self._inference.model.receptive_field,
warm_up=(0.0, 0.0),
)
hook("speaker_counting", count)
# shape: (num_frames, 1)
# dtype: int
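        # e.g. count.data starting with [[0], [1], [2]] means no active speaker
        # in the first frame, one in the second, and two in the third
        # (illustrative values)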

# exit early when no speaker is ever active
if np.nanmax(count.data) == 0.0:
diarization = Annotation(uri=file["uri"])
if return_embeddings:
return diarization, np.zeros((0, dimension))

return diarization

hard_clusters, _, centroids = self.clustering(
embeddings=embeddings.data,
segmentations=binarized_segmentations,
num_clusters=num_speakers,
min_clusters=min_speakers,
max_clusters=max_speakers,
file=file, # <== for oracle clustering
frames=self._inference.model.receptive_field, # <== for oracle clustering
)
# hard_clusters: (num_chunks, num_speakers)
# centroids: (num_speakers, dimension)
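        # e.g. hard_clusters[c] == [0, 2, 0] assigns the three local speakers
        # of chunk `c` to global clusters 0, 2, and 0 (illustrative values)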

# number of detected clusters is the number of different speakers
num_different_speakers = np.max(hard_clusters) + 1

# detected number of speakers can still be out of bounds
# (specifically, lower than `min_speakers`), since there could be too few embeddings
# to make enough clusters with a given minimum cluster size.
if (
num_different_speakers < min_speakers
or num_different_speakers > max_speakers
):
warnings.warn(
textwrap.dedent(
f"""
The detected number of speakers ({num_different_speakers}) is outside
the given bounds [{min_speakers}, {max_speakers}]. This can happen if the
given audio file is too short to contain {min_speakers} or more speakers.
                    Try lowering the desired minimum number of speakers.
"""
)
)

        # segmentation errors may lead to overcounting the number of
        # instantaneous speakers, so we cap it at the `max_speakers` value
count.data = np.minimum(count.data, max_speakers).astype(np.int8)

# reconstruct discrete diarization from raw hard clusters

# keep track of inactive speakers
inactive_speakers = np.sum(binarized_segmentations.data, axis=1) == 0
# shape: (num_chunks, num_speakers)

hard_clusters[inactive_speakers] = -2
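        # (-2 is a sentinel value that `reconstruct` knows to skip)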
discrete_diarization = self.reconstruct(
binarized_segmentations,
hard_clusters,
count,
)
hook("discrete_diarization", discrete_diarization)

# convert to continuous diarization
diarization = self.to_annotation(
discrete_diarization,
min_duration_on=0.0,
min_duration_off=0.0,
)
diarization.uri = file["uri"]

# at this point, `diarization` speaker labels are integers
# from 0 to `num_speakers - 1`, aligned with `centroids` rows.

if "annotation" in file and file["annotation"]:
# when reference is available, use it to map hypothesized speakers
# to reference speakers (this makes later error analysis easier
# but does not modify the actual output of the diarization pipeline)
_, mapping = self.optimal_mapping(
file["annotation"], diarization, return_mapping=True
)

# in case there are more speakers in the hypothesis than in
# the reference, those extra speakers are missing from `mapping`.
# we add them back here
mapping = {key: mapping.get(key, key) for key in diarization.labels()}

else:
# when reference is not available, rename hypothesized speakers
# to human-readable SPEAKER_00, SPEAKER_01, ...
mapping = {
label: expected_label
for label, expected_label in zip(diarization.labels(), self.classes())
}

diarization = diarization.rename_labels(mapping=mapping)

# at this point, `diarization` speaker labels are strings (or mix of
# strings and integers when reference is available and some hypothesis
# speakers are not present in the reference)

if not return_embeddings:
return diarization

# this can happen when we use OracleClustering
if centroids is None:
return diarization, None

# The number of centroids may be smaller than the number of speakers
# in the annotation. This can happen if the number of active speakers
# obtained from `speaker_count` for some frames is larger than the number
# of clusters obtained from `clustering`. In this case, we append zero embeddings
# for extra speakers
if len(diarization.labels()) > centroids.shape[0]:
centroids = np.pad(
centroids, ((0, len(diarization.labels()) - centroids.shape[0]), (0, 0))
)
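        # e.g. with 3 speaker labels but only 2 centroids, one all-zero
        # embedding row is appended so that `centroids[i]` exists for every
        # speaker (illustrative values)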

# re-order centroids so that they match
# the order given by diarization.labels()
inverse_mapping = {label: index for index, label in mapping.items()}
centroids = centroids[
[inverse_mapping[label] for label in diarization.labels()]
]
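        # e.g. if mapping == {0: "SPEAKER_01", 1: "SPEAKER_00"}, rows 0 and 1
        # of `centroids` are swapped so that row i corresponds to the i-th
        # label in diarization.labels() (illustrative values)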

return diarization, centroids

def get_metric(self) -> GreedyDiarizationErrorRate:
return GreedyDiarizationErrorRate(**self.der_variant)
