diff --git a/CHANGELOG.md b/CHANGELOG.md
index 4b1944781..6b1b96e36 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -2,12 +2,20 @@

 ## develop

+### Breaking changes
+
+- BREAKING(task): drop support for `multilabel` training in `SpeakerDiarization` task
+- BREAKING(task): drop support for `warm_up` option in `SpeakerDiarization` task
+- BREAKING(task): drop support for `weigh_by_cardinality` option in `SpeakerDiarization` task
+- BREAKING(task): drop support for `vad_loss` option in `SpeakerDiarization` task
+
 ### New features

-- feat: add support for `k-means` clustering
-- feat: add `"hidden"` option to `ProgressHook`
-- feat: add `FilterByNumberOfSpeakers` protocol files filter
+- feat(clustering): add support for `k-means` clustering
 - feat(model): add `wav2vec_frozen` option to freeze/unfreeze `wav2vec` in `SSeRiouSS` architecture
+- feat(task): add support for manual optimization in `SpeakerDiarization` task
+- feat(utils): add `hidden` option to `ProgressHook`
+- feat(utils): add `FilterByNumberOfSpeakers` protocol files filter

 ### Fixes

diff --git a/pyannote/audio/tasks/segmentation/speaker_diarization.py b/pyannote/audio/tasks/segmentation/speaker_diarization.py
index fb0b9b979..244615652 100644
--- a/pyannote/audio/tasks/segmentation/speaker_diarization.py
+++ b/pyannote/audio/tasks/segmentation/speaker_diarization.py
@@ -23,7 +23,7 @@
 import math
 import warnings
 from collections import Counter
-from typing import Dict, Literal, Optional, Sequence, Text, Tuple, Union
+from typing import Dict, Literal, Optional, Sequence, Text, Union

 import numpy as np
 import torch
@@ -43,14 +43,9 @@
     DiarizationErrorRate,
     FalseAlarmRate,
     MissedDetectionRate,
-    OptimalDiarizationErrorRate,
-    OptimalDiarizationErrorRateThreshold,
-    OptimalFalseAlarmRate,
-    OptimalMissedDetectionRate,
-    OptimalSpeakerConfusionRate,
     SpeakerConfusionRate,
 )
-from pyannote.audio.utils.loss import binary_cross_entropy, mse_loss, nll_loss
+from pyannote.audio.utils.loss import nll_loss
 from pyannote.audio.utils.permutation import permutate
 from pyannote.audio.utils.powerset import Powerset

@@ -78,21 +73,7 @@ class SpeakerDiarization(SegmentationTask):
         Maximum number of speakers per chunk (must be at least 2).
         Defaults to estimating it from the training set.
     max_speakers_per_frame : int, optional
-        Maximum number of (overlapping) speakers per frame.
-        Setting this value to 1 or more enables `powerset multi-class` training.
-        Default behavior is to use `multi-label` training.
-    weigh_by_cardinality: bool, optional
-        Weigh each powerset classes by the size of the corresponding speaker set.
-        In other words, {0, 1} powerset class weight is 2x bigger than that of {0}
-        or {1} powerset classes. Note that empty (non-speech) powerset class is
-        assigned the same weight as mono-speaker classes. Defaults to False (i.e. use
-        same weight for every class). Has no effect with `multi-label` training.
-    warm_up : float or (float, float), optional
-        Use that many seconds on the left- and rightmost parts of each chunk
-        to warm up the model. While the model does process those left- and right-most
-        parts, only the remaining central part of each chunk is used for computing the
-        loss during training, and for aggregating scores during inference.
-        Defaults to 0. (i.e. no warm-up).
+        Maximum number of (overlapping) speakers per frame. Defaults to 2.
     balance: Sequence[Text], optional
         When provided, training samples are sampled uniformly with respect to these keys.
         For instance, setting `balance` to ["database","subset"] will make sure that each
@@ -111,42 +92,34 @@ class SpeakerDiarization(SegmentationTask):
     augmentation : BaseWaveformTransform, optional
         torch_audiomentations waveform transform, used by dataloader during training.
-    vad_loss : {"bce", "mse"}, optional
-        Add voice activity detection loss.
-        Cannot be used in conjunction with `max_speakers_per_frame`.
     metric : optional
         Validation metric(s). Can be anything supported by torchmetrics.MetricCollection.
-        Defaults to AUROC (area under the ROC curve).
+        Defaults to DiarizationErrorRate and its components.

     References
     ----------
+    Alexis Plaquet and Hervé Bredin
+    "Powerset multi-class cross entropy loss for neural speaker diarization"
+    Proc. Interspeech 2023
+
     Hervé Bredin and Antoine Laurent
     "End-To-End Speaker Segmentation for Overlap-Aware Resegmentation."
     Proc. Interspeech 2021
-
-    Zhihao Du, Shiliang Zhang, Siqi Zheng, and Zhijie Yan
-    "Speaker Embedding-aware Neural Diarization: an Efficient Framework for Overlapping
-    Speech Diarization in Meeting Scenarios"
-    https://arxiv.org/abs/2203.09767
-
     """

     def __init__(
         self,
         protocol: SpeakerDiarizationProtocol,
         cache: Optional[Union[str, None]] = None,
-        duration: float = 2.0,
+        duration: float = 10.0,
         max_speakers_per_chunk: Optional[int] = None,
-        max_speakers_per_frame: Optional[int] = None,
-        weigh_by_cardinality: bool = False,
-        warm_up: Union[float, Tuple[float, float]] = 0.0,
+        max_speakers_per_frame: int = 2,
         balance: Optional[Sequence[Text]] = None,
         weight: Optional[Text] = None,
         batch_size: int = 32,
         num_workers: Optional[int] = None,
         pin_memory: bool = False,
         augmentation: Optional[BaseWaveformTransform] = None,
-        vad_loss: Literal["bce", "mse"] = None,
         metric: Union[Metric, Sequence[Metric], Dict[str, Metric]] = None,
         max_num_speakers: Optional[
             int
         ] = None,  # deprecated in favor of `max_speakers_per_chunk``
         loss: Literal["bce", "mse"] = None,  # deprecated
@@ -156,7 +129,6 @@ def __init__(
         super().__init__(
             protocol,
             duration=duration,
-            warm_up=warm_up,
             batch_size=batch_size,
             num_workers=num_workers,
             pin_memory=pin_memory,
@@ -180,22 +152,15 @@ def __init__(
             warnings.warn("`loss` has been deprecated and has no effect.")

         # parameter validation
-        if max_speakers_per_frame is not None:
-            if max_speakers_per_frame < 1:
-                raise ValueError(
-                    f"`max_speakers_per_frame` must be 1 or more (you used {max_speakers_per_frame})."
-                )
-            if vad_loss is not None:
-                raise ValueError(
-                    "`vad_loss` cannot be used jointly with `max_speakers_per_frame`"
-                )
+        if max_speakers_per_frame < 1:
+            raise ValueError(
+                f"`max_speakers_per_frame` must be 1 or more (you used {max_speakers_per_frame})."
+            )

         self.max_speakers_per_chunk = max_speakers_per_chunk
         self.max_speakers_per_frame = max_speakers_per_frame
-        self.weigh_by_cardinality = weigh_by_cardinality
         self.balance = balance
         self.weight = weight
-        self.vad_loss = vad_loss

     def setup(self, stage=None):
         super().setup(stage)
@@ -276,10 +241,7 @@ def setup(self, stage=None):
                 f"You can override this value (or avoid this estimation step) by passing `max_speakers_per_chunk={self.max_speakers_per_chunk}` to the task constructor."
             )

-        if (
-            self.max_speakers_per_frame is not None
-            and self.max_speakers_per_frame > self.max_speakers_per_chunk
-        ):
+        if self.max_speakers_per_frame > self.max_speakers_per_chunk:
             raise ValueError(
                 f"`max_speakers_per_frame` ({self.max_speakers_per_frame}) must be smaller "
                 f"than `max_speakers_per_chunk` ({self.max_speakers_per_chunk})"
@@ -288,24 +250,20 @@ def setup(self, stage=None):
         # now that we know about the number of speakers upper bound
         # we can set task specifications
         self.specifications = Specifications(
-            problem=Problem.MULTI_LABEL_CLASSIFICATION
-            if self.max_speakers_per_frame is None
-            else Problem.MONO_LABEL_CLASSIFICATION,
+            problem=Problem.MONO_LABEL_CLASSIFICATION,
             resolution=Resolution.FRAME,
             duration=self.duration,
             min_duration=self.min_duration,
-            warm_up=self.warm_up,
             classes=[f"speaker#{i+1}" for i in range(self.max_speakers_per_chunk)],
             powerset_max_classes=self.max_speakers_per_frame,
             permutation_invariant=True,
         )

     def setup_loss_func(self):
-        if self.specifications.powerset:
-            self.model.powerset = Powerset(
-                len(self.specifications.classes),
-                self.specifications.powerset_max_classes,
-            )
+        self.model.powerset = Powerset(
+            len(self.specifications.classes),
+            self.specifications.powerset_max_classes,
+        )

     def prepare_chunk(self, file_id: int, start_time: float, duration: float):
         """Prepare chunk
@@ -440,85 +398,9 @@ def collate_y(self, batch) -> torch.Tensor:
         return torch.from_numpy(np.stack(collated_y))

-    def segmentation_loss(
-        self,
-        permutated_prediction: torch.Tensor,
-        target: torch.Tensor,
-        weight: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        """Permutation-invariant segmentation loss
-
-        Parameters
-        ----------
-        permutated_prediction : (batch_size, num_frames, num_classes) torch.Tensor
-            Permutated speaker activity predictions.
-        target : (batch_size, num_frames, num_speakers) torch.Tensor
-            Speaker activity.
-        weight : (batch_size, num_frames, 1) torch.Tensor, optional
-            Frames weight.
-
-        Returns
-        -------
-        seg_loss : torch.Tensor
-            Permutation-invariant segmentation loss
-        """
-
-        if self.specifications.powerset:
-            # `clamp_min` is needed to set non-speech weight to 1.
-            class_weight = (
-                torch.clamp_min(self.model.powerset.cardinality, 1.0)
-                if self.weigh_by_cardinality
-                else None
-            )
-            seg_loss = nll_loss(
-                permutated_prediction,
-                torch.argmax(target, dim=-1),
-                class_weight=class_weight,
-                weight=weight,
-            )
-        else:
-            seg_loss = binary_cross_entropy(
-                permutated_prediction, target.float(), weight=weight
-            )
-
-        return seg_loss
-
-    def voice_activity_detection_loss(
-        self,
-        permutated_prediction: torch.Tensor,
-        target: torch.Tensor,
-        weight: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        """Voice activity detection loss
-
-        Parameters
-        ----------
-        permutated_prediction : (batch_size, num_frames, num_classes) torch.Tensor
-            Speaker activity predictions.
-        target : (batch_size, num_frames, num_speakers) torch.Tensor
-            Speaker activity.
-        weight : (batch_size, num_frames, 1) torch.Tensor, optional
-            Frames weight.
-
-        Returns
-        -------
-        vad_loss : torch.Tensor
-            Voice activity detection loss.
-        """
-
-        vad_prediction, _ = torch.max(permutated_prediction, dim=2, keepdim=True)
-        # (batch_size, num_frames, 1)
-
-        vad_target, _ = torch.max(target.float(), dim=2, keepdim=False)
-        # (batch_size, num_frames)
-
-        if self.vad_loss == "bce":
-            loss = binary_cross_entropy(vad_prediction, vad_target, weight=weight)
-
-        elif self.vad_loss == "mse":
-            loss = mse_loss(vad_prediction, vad_target, weight=weight)
-
-        return loss
+    @property
+    def automatic_optimization(self):
+        return self.model.automatic_optimization

     def training_step(self, batch, batch_idx: int):
         """Compute permutation-invariant segmentation loss
@@ -566,76 +448,41 @@ def training_step(self, batch, batch_idx: int):
         )
         # (batch_size, num_frames, 1)

-        # warm-up
-        warm_up_left = round(self.warm_up[0] / self.duration * num_frames)
-        weight[:, :warm_up_left] = 0.0
-        warm_up_right = round(self.warm_up[1] / self.duration * num_frames)
-        weight[:, num_frames - warm_up_right :] = 0.0
-
-        if self.specifications.powerset:
-            multilabel = self.model.powerset.to_multilabel(prediction)
-            permutated_target, _ = permutate(multilabel, target)
-            permutated_target_powerset = self.model.powerset.to_powerset(
-                permutated_target.float()
-            )
-            seg_loss = self.segmentation_loss(
-                prediction, permutated_target_powerset, weight=weight
-            )
+        multilabel = self.model.powerset.to_multilabel(prediction)
+        permutated_target, _ = permutate(multilabel, target)
+        permutated_target_powerset = self.model.powerset.to_powerset(
+            permutated_target.float()
+        )

-        else:
-            permutated_prediction, _ = permutate(target, prediction)
-            seg_loss = self.segmentation_loss(
-                permutated_prediction, target, weight=weight
-            )
+        loss = nll_loss(
+            prediction,
+            torch.argmax(permutated_target_powerset, dim=-1),
+            weight=weight,
+        )

         self.model.log(
-            "loss/train/segmentation",
-            seg_loss,
+            "loss/train",
+            loss,
             on_step=False,
             on_epoch=True,
             prog_bar=False,
             logger=True,
         )

-        if self.vad_loss is None:
-            vad_loss = 0.0
+        if not self.model.automatic_optimization:
+            optimizers = self.model.optimizers()
+            for optimizer in optimizers:
+                optimizer.zero_grad()

-        else:
-            # TODO: vad_loss probably does not make sense in powerset mode
-            # because first class (empty set of labels) does exactly this...
-            if self.specifications.powerset:
-                vad_loss = self.voice_activity_detection_loss(
-                    prediction, permutated_target_powerset, weight=weight
-                )
+            self.model.manual_backward(loss)

-            else:
-                vad_loss = self.voice_activity_detection_loss(
-                    permutated_prediction, target, weight=weight
+            for optimizer in optimizers:
+                self.model.clip_gradients(
+                    optimizer,
+                    gradient_clip_val=5.0,
+                    gradient_clip_algorithm="norm",
                 )
-
-            self.model.log(
-                "loss/train/vad",
-                vad_loss,
-                on_step=False,
-                on_epoch=True,
-                prog_bar=False,
-                logger=True,
-            )
-
-        loss = seg_loss + vad_loss
-
-        # skip batch if something went wrong for some reason
-        if torch.isnan(loss):
-            return None
-
-        self.model.log(
-            "loss/train",
-            loss,
-            on_step=False,
-            on_epoch=True,
-            prog_bar=False,
-            logger=True,
-        )
+                optimizer.step()

         return {"loss": loss}

@@ -644,20 +491,11 @@ def default_metric(
     ) -> Union[Metric, Sequence[Metric], Dict[str, Metric]]:
         """Returns diarization error rate and its components"""

-        if self.specifications.powerset:
-            return {
-                "DiarizationErrorRate": DiarizationErrorRate(0.5),
-                "DiarizationErrorRate/Confusion": SpeakerConfusionRate(0.5),
-                "DiarizationErrorRate/Miss": MissedDetectionRate(0.5),
-                "DiarizationErrorRate/FalseAlarm": FalseAlarmRate(0.5),
-            }
-
         return {
-            "DiarizationErrorRate": OptimalDiarizationErrorRate(),
-            "DiarizationErrorRate/Threshold": OptimalDiarizationErrorRateThreshold(),
-            "DiarizationErrorRate/Confusion": OptimalSpeakerConfusionRate(),
-            "DiarizationErrorRate/Miss": OptimalMissedDetectionRate(),
-            "DiarizationErrorRate/FalseAlarm": OptimalFalseAlarmRate(),
+            "DiarizationErrorRate": DiarizationErrorRate(0.5),
+            "DiarizationErrorRate/Confusion": SpeakerConfusionRate(0.5),
+            "DiarizationErrorRate/Miss": MissedDetectionRate(0.5),
+            "DiarizationErrorRate/FalseAlarm": FalseAlarmRate(0.5),
         }

     # TODO: no need to compute gradient in this method
@@ -695,66 +533,19 @@ def validation_step(self, batch, batch_idx: int):
         )
         # (batch_size, num_frames, 1)

-        # warm-up
-        warm_up_left = round(self.warm_up[0] / self.duration * num_frames)
-        weight[:, :warm_up_left] = 0.0
-        warm_up_right = round(self.warm_up[1] / self.duration * num_frames)
-        weight[:, num_frames - warm_up_right :] = 0.0
-
-        if self.specifications.powerset:
-            multilabel = self.model.powerset.to_multilabel(prediction)
-            permutated_target, _ = permutate(multilabel, target)
-
-            # FIXME: handle case where target have too many speakers?
-            # since we don't need
-            permutated_target_powerset = self.model.powerset.to_powerset(
-                permutated_target.float()
-            )
-            seg_loss = self.segmentation_loss(
-                prediction, permutated_target_powerset, weight=weight
-            )
-
-        else:
-            permutated_prediction, _ = permutate(target, prediction)
-            seg_loss = self.segmentation_loss(
-                permutated_prediction, target, weight=weight
-            )
+        multilabel = self.model.powerset.to_multilabel(prediction)
+        permutated_target, _ = permutate(multilabel, target)

-        self.model.log(
-            "loss/val/segmentation",
-            seg_loss,
-            on_step=False,
-            on_epoch=True,
-            prog_bar=False,
-            logger=True,
+        # FIXME: handle case where target has too many speakers?
+        permutated_target_powerset = self.model.powerset.to_powerset(
+            permutated_target.float()
         )

-        if self.vad_loss is None:
-            vad_loss = 0.0
-
-        else:
-            # TODO: vad_loss probably does not make sense in powerset mode
-            # because first class (empty set of labels) does exactly this...
-            if self.specifications.powerset:
-                vad_loss = self.voice_activity_detection_loss(
-                    prediction, permutated_target_powerset, weight=weight
-                )
-
-            else:
-                vad_loss = self.voice_activity_detection_loss(
-                    permutated_prediction, target, weight=weight
-                )
-
-            self.model.log(
-                "loss/val/vad",
-                vad_loss,
-                on_step=False,
-                on_epoch=True,
-                prog_bar=False,
-                logger=True,
-            )
-
-        loss = seg_loss + vad_loss
+        loss = nll_loss(
+            prediction,
+            torch.argmax(permutated_target_powerset, dim=-1),
+            weight=weight,
+        )

         self.model.log(
             "loss/val",
@@ -765,24 +556,10 @@ def validation_step(self, batch, batch_idx: int):
             logger=True,
         )

-        if self.specifications.powerset:
-            self.model.validation_metric(
-                torch.transpose(
-                    multilabel[:, warm_up_left : num_frames - warm_up_right], 1, 2
-                ),
-                torch.transpose(
-                    target[:, warm_up_left : num_frames - warm_up_right], 1, 2
-                ),
-            )
-        else:
-            self.model.validation_metric(
-                torch.transpose(
-                    prediction[:, warm_up_left : num_frames - warm_up_right], 1, 2
-                ),
-                torch.transpose(
-                    target[:, warm_up_left : num_frames - warm_up_right], 1, 2
-                ),
-            )
+        self.model.validation_metric(
+            torch.transpose(multilabel, 1, 2),
+            torch.transpose(target, 1, 2),
+        )

         self.model.log_dict(
             self.model.validation_metric,
@@ -802,12 +579,8 @@ def validation_step(self, batch, batch_idx: int):

         # visualize first 9 validation samples of first batch in Tensorboard/MLflow
-        if self.specifications.powerset:
-            y = permutated_target.float().cpu().numpy()
-            y_pred = multilabel.cpu().numpy()
-        else:
-            y = target.float().cpu().numpy()
-            y_pred = permutated_prediction.cpu().numpy()
+        y = permutated_target.float().cpu().numpy()
+        y_pred = multilabel.cpu().numpy()

         # prepare 3 x 3 grid (or smaller if batch size is smaller)
         num_samples = min(self.batch_size, 9)
@@ -841,10 +614,6 @@ def validation_step(self, batch, batch_idx: int):
             # plot predictions
             ax_hyp = axes[row_idx * 2 + 1, col_idx]
             sample_y_pred = y_pred[sample_idx]
-            ax_hyp.axvspan(0, warm_up_left, color="k", alpha=0.5, lw=0)
-            ax_hyp.axvspan(
-                num_frames - warm_up_right, num_frames, color="k", alpha=0.5, lw=0
-            )
             ax_hyp.plot(sample_y_pred)
             ax_hyp.set_ylim(-0.1, 1.1)
             ax_hyp.set_xlim(0, len(sample_y))
@@ -865,7 +634,7 @@ def validation_step(self, batch, batch_idx: int):
         plt.close(fig)


-def main(protocol: str, subset: str = "test", model: str = "pyannote/segmentation"):
+def evaluate(protocol: str, subset: str = "test", model: str = "pyannote/segmentation"):
    """Evaluate a segmentation model"""

    from pyannote.database import FileFinder, get_protocol
@@ -903,5 +672,4 @@ def progress_hook(completed: Optional[int] = None, total: Optional[int] = None):

 if __name__ == "__main__":
     import typer
-
-    typer.run(main)
+    typer.run(evaluate)
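Usage note (not part of the patch): the sketch below illustrates one way the reworked task could be driven after this change. It assumes a hypothetical protocol name ("MyDatabase.SpeakerDiarization.MyProtocol") registered with pyannote.database and the bundled PyanNet architecture; the constructor arguments mirror the new defaults introduced above (`duration=10.0`, `max_speakers_per_frame=2`), and the last two lines show one way to opt into the new manual optimization path, since `training_step` now branches on `model.automatic_optimization`.

import pytorch_lightning as pl

from pyannote.audio.models.segmentation import PyanNet
from pyannote.audio.tasks import SpeakerDiarization
from pyannote.database import FileFinder, get_protocol

# hypothetical protocol name: replace with one available in your own setup
protocol = get_protocol(
    "MyDatabase.SpeakerDiarization.MyProtocol",
    preprocessors={"audio": FileFinder()},
)

# powerset multi-class training is now the only mode; `duration` defaults to
# 10 seconds and `max_speakers_per_frame` to 2 (both shown explicitly here)
task = SpeakerDiarization(
    protocol, duration=10.0, max_speakers_per_frame=2, batch_size=32
)
model = PyanNet(task=task)

# optional: turn off Lightning's automatic optimization so that training_step
# takes the manual path (zero_grad / manual_backward / clip_gradients / step)
model.automatic_optimization = False

pl.Trainer(max_epochs=1).fit(model)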
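A related note on the renamed evaluation helper: `main` is now `evaluate`, so the module can still be used both programmatically and as a script. A minimal sketch under the same assumptions (hypothetical protocol name, default `pyannote/segmentation` checkpoint):

# direct call of the renamed helper (was `main`); protocol name is hypothetical
from pyannote.audio.tasks.segmentation.speaker_diarization import evaluate

evaluate(
    "MyDatabase.SpeakerDiarization.MyProtocol",
    subset="test",
    model="pyannote/segmentation",
)

Run as a script, typer exposes the same entry point, e.g. `python speaker_diarization.py MyDatabase.SpeakerDiarization.MyProtocol --subset test --model pyannote/segmentation`.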