From 9b28df2ace73fd5ca9ac5847849269a83af39934 Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 22 Oct 2024 04:10:38 -0700 Subject: [PATCH 1/4] add workflow: dnsmos --- lhotse/workflows/__init__.py | 1 + lhotse/workflows/dnsmos.py | 207 +++++++++++++++++++++++++++++++++++ 2 files changed, 208 insertions(+) create mode 100644 lhotse/workflows/dnsmos.py diff --git a/lhotse/workflows/__init__.py b/lhotse/workflows/__init__.py index ccce27bf6..953aa803e 100644 --- a/lhotse/workflows/__init__.py +++ b/lhotse/workflows/__init__.py @@ -1,4 +1,5 @@ from .activity_detection import * +from .dnsmos import annotate_dnsmos from .forced_alignment import align_with_torchaudio from .meeting_simulation import * from .whisper import annotate_with_whisper diff --git a/lhotse/workflows/dnsmos.py b/lhotse/workflows/dnsmos.py new file mode 100644 index 000000000..be8e5923e --- /dev/null +++ b/lhotse/workflows/dnsmos.py @@ -0,0 +1,207 @@ +import logging +import os +from concurrent.futures.thread import ThreadPoolExecutor +from typing import Generator, List, Optional, Union + +import numpy as np +from tqdm import tqdm + +from lhotse import CutSet, MonoCut, RecordingSet, SupervisionSegment +from lhotse.utils import fastcopy, is_module_available, resumable_download + + +class ComputeScore: + def __init__(self, primary_model_path) -> None: + import onnxruntime as ort + + self.onnx_sess = ort.InferenceSession(primary_model_path) + self.SAMPLING_RATE = 16000 + self.INPUT_LENGTH = 9.01 + + def audio_melspec( + self, audio, n_mels=120, frame_size=320, hop_length=160, sr=16000, to_db=True + ): + import librosa + + mel_spec = librosa.feature.melspectrogram( + y=audio, sr=sr, n_fft=frame_size + 1, hop_length=hop_length, n_mels=n_mels + ) + if to_db: + mel_spec = (librosa.power_to_db(mel_spec, ref=np.max) + 40) / 40 + return mel_spec.T + + def get_polyfit_val(self, sig, bak, ovr, is_personalized_MOS): + if is_personalized_MOS: + p_ovr = np.poly1d([-0.00533021, 0.005101, 1.18058466, -0.11236046]) + p_sig = np.poly1d([-0.01019296, 0.02751166, 1.19576786, -0.24348726]) + p_bak = np.poly1d([-0.04976499, 0.44276479, -0.1644611, 0.96883132]) + else: + p_ovr = np.poly1d([-0.06766283, 1.11546468, 0.04602535]) + p_sig = np.poly1d([-0.08397278, 1.22083953, 0.0052439]) + p_bak = np.poly1d([-0.13166888, 1.60915514, -0.39604546]) + + sig_poly = p_sig(sig) + bak_poly = p_bak(bak) + ovr_poly = p_ovr(ovr) + + return sig_poly, bak_poly, ovr_poly + + def __call__(self, manifest, is_personalized_MOS): + fs = self.SAMPLING_RATE + audio = manifest.resample(fs).load_audio() + len_samples = int(self.INPUT_LENGTH * fs) + while len(audio) < len_samples: + audio = np.append(audio, audio) + + num_hops = int(np.floor(len(audio) / fs) - self.INPUT_LENGTH) + 1 + hop_len_samples = fs + predicted_mos_sig_seg = [] + predicted_mos_bak_seg = [] + predicted_mos_ovr_seg = [] + + for idx in range(num_hops): + audio_seg = audio[ + int(idx * hop_len_samples) : int( + (idx + self.INPUT_LENGTH) * hop_len_samples + ) + ] + if len(audio_seg) < len_samples: + continue + + input_features = np.array(audio_seg).astype("float32")[np.newaxis, :] + oi = {"input_1": input_features} + mos_sig_raw, mos_bak_raw, mos_ovr_raw = self.onnx_sess.run(None, oi)[0][0] + mos_sig, mos_bak, mos_ovr = self.get_polyfit_val( + mos_sig_raw, mos_bak_raw, mos_ovr_raw, is_personalized_MOS + ) + predicted_mos_sig_seg.append(mos_sig) + predicted_mos_bak_seg.append(mos_bak) + predicted_mos_ovr_seg.append(mos_ovr) + + return manifest, { + "OVRL": np.mean(predicted_mos_ovr_seg), + "SIG": np.mean(predicted_mos_sig_seg), + "BAK": np.mean(predicted_mos_bak_seg), + } + + +def download_model( + is_personalized_MOS: bool = False, + download_root: Optional[str] = None, +) -> str: + download_root = download_root if download_root is not None else "/tmp" + url = ( + "https://github.com/microsoft/DNS-Challenge/blob/master/DNSMOS/pDNSMOS/sig_bak_ovr.onnx" + if is_personalized_MOS + else "https://github.com/microsoft/DNS-Challenge/blob/master/DNSMOS/DNSMOS/sig_bak_ovr.onnx" + ) + filename = os.path.join(download_root, "sig_bak_ovr.onnx") + resumable_download(url, filename=filename) + return filename + + +def annotate_dnsmos( + manifest: Union[RecordingSet, CutSet], + is_personalized_MOS: bool = False, + download_root: Optional[str] = None, +) -> Generator[MonoCut, None, None]: + """ + Use Microsoft DNSMOS P.835 prediction model to annotate either RECORDINGS_MANIFEST, RECORDINGS_DIR, or CUTS_MANIFEST. + It will predict DNSMOS P.835 score including SIG, NAK, and OVRL. + + See the original repo for more details: https://github.com/microsoft/DNS-Challenge/tree/master/DNSMOS + + :param manifest: a ``RecordingSet`` or ``CutSet`` object. + :param is_personalized_MOS: flag to indicate if personalized MOS score is needed or regular. + :param download_root: if specified, the model will be downloaded to this directory. Otherwise, + it will be downloaded to /tmp. + :return: a generator of cuts (use ``CutSet.open_writer()`` to write them). + """ + assert is_module_available("librosa"), ( + "This function expects librosa to be installed. " + "You can install it via 'pip install librosa'" + ) + + assert is_module_available("onnxruntime"), ( + "This function expects onnxruntime to be installed. " + "You can install it via 'pip install onnxruntime'" + ) + + if isinstance(manifest, RecordingSet): + yield from _annotate_recordings( + manifest, + is_personalized_MOS, + download_root, + ) + elif isinstance(manifest, CutSet): + yield from _annotate_cuts( + manifest, + is_personalized_MOS, + download_root, + ) + else: + raise ValueError("The ``manifest`` must be either a RecordingSet or a CutSet.") + + +def _annotate_recordings( + recordings: RecordingSet, + is_personalized_MOS: bool = False, + download_root: Optional[str] = None, +): + """ + Helper function that annotates a RecordingSet with DNSMOS P.835 prediction model. + """ + primary_model_path = download_model(is_personalized_MOS, download_root) + compute_score = ComputeScore(primary_model_path) + + with ThreadPoolExecutor() as ex: + futures = [] + for recording in tqdm(recordings, desc="Distributing tasks"): + if recording.num_channels > 1: + logging.warning( + f"Skipping recording '{recording.id}'. It has {recording.num_channels} channels, " + f"but we currently only support mono input." + ) + continue + futures.append(ex.submit(compute_score, recording, is_personalized_MOS)) + + for future in tqdm(futures, desc="Processing"): + recording, result = future.result() + supervision = SupervisionSegment( + id=recording.id, + recording_id=recording.id, + start=0, + duration=recording.durations, + ) + cut = MonoCut( + recording=recording, supervisions=[supervision], custom=result + ) + yield cut + + +def _annotate_cuts( + cuts: CutSet, + is_personalized_MOS: bool = False, + download_root: Optional[str] = None, +): + """ + Helper function that annotates a CutSet with DNSMOS P.835 prediction model. + """ + primary_model_path = download_model(is_personalized_MOS, download_root) + compute_score = ComputeScore(primary_model_path) + + with ThreadPoolExecutor() as ex: + futures = [] + for cut in tqdm(cuts, desc="Distributing tasks"): + if cut.num_channels > 1: + logging.warning( + f"Skipping cut '{cut.id}'. It has {cut.num_channels} channels, " + f"but we currently only support mono input." + ) + continue + futures.append(ex.submit(compute_score, cut, is_personalized_MOS)) + + for future in tqdm(futures, desc="Processing"): + cut, result = future.result() + new_cut = fastcopy(cut, custom=result) + yield new_cut From ab9a632649df6892b3e41f3c8295ea5f187302c0 Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 22 Oct 2024 05:46:49 -0700 Subject: [PATCH 2/4] add cli for dnsmos workflow --- lhotse/bin/modes/workflows.py | 79 +++++++++++++++++++++++++++++++++++ 1 file changed, 79 insertions(+) diff --git a/lhotse/bin/modes/workflows.py b/lhotse/bin/modes/workflows.py index dec178c7c..1f91bb429 100644 --- a/lhotse/bin/modes/workflows.py +++ b/lhotse/bin/modes/workflows.py @@ -569,3 +569,82 @@ def activity_detection( supervisions.to_file(str(sups_path)) print("Results saved to:", str(sups_path), sep="\n") + + +@workflows.command() +@click.argument("out_cuts", type=click.Path(allow_dash=True)) +@click.option( + "-m", + "--recordings-manifest", + type=click.Path(exists=True, dir_okay=False, allow_dash=True), + help="Path to an existing recording manifest.", +) +@click.option( + "-r", + "--recordings-dir", + type=click.Path(exists=True, file_okay=False), + help="Directory with recordings. We will create a RecordingSet for it automatically.", +) +@click.option( + "-c", + "--cuts-manifest", + type=click.Path(exists=True, dir_okay=False, allow_dash=True), + help="Path to an existing cuts manifest.", +) +@click.option( + "-e", + "--extension", + default="wav", + help="Audio file extension to search for. Used with RECORDINGS_DIR.", +) +@click.option( + "-p", + "--is_personalized_MOS", + default="False", + help="Flag to indicate if personalized MOS score is needed or regular.", +) +@click.option("-j", "--jobs", default=1, help="Number of jobs for audio scanning.") +def annotate_dnsmos( + out_cuts: str, + recordings_manifest: Optional[str], + recordings_dir: Optional[str], + cuts_manifest: Optional[str], + extension: str, + is_personalized_MOS: str, + jobs: int, +): + """ + Use Microsoft DNSMOS P.835 prediction model to annotate either RECORDINGS_MANIFEST, RECORDINGS_DIR, or CUTS_MANIFEST. + It will predict DNSMOS P.835 score including SIG, NAK, and OVRL. + + See the original repo for more details: https://github.com/microsoft/DNS-Challenge/tree/master/DNSMOS + + RECORDINGS_MANIFEST, RECORDINGS_DIR, and CUTS_MANIFEST are mutually exclusive. If CUTS_MANIFEST + is provided, its supervisions will be overwritten with the results of the inference. + """ + from lhotse import annotate_dnsmos as annotate_dnsmos_ + + assert exactly_one_not_null(recordings_manifest, recordings_dir, cuts_manifest), ( + "Options RECORDINGS_MANIFEST, RECORDINGS_DIR, and CUTS_MANIFEST are mutually exclusive " + "and at least one is required." + ) + + if recordings_manifest is not None: + manifest = RecordingSet.from_file(recordings_manifest) + elif recordings_dir is not None: + manifest = RecordingSet.from_dir( + recordings_dir, pattern=f"*.{extension}", num_jobs=jobs + ) + else: + manifest = CutSet.from_file(cuts_manifest).to_eager() + + with CutSet.open_writer(out_cuts) as writer: + for cut in tqdm( + annotate_dnsmos_( + manifest, + is_personalized_MOS=is_personalized_MOS, + ), + total=len(manifest), + desc="Annotating with Whisper", + ): + writer.write(cut, flush=True) From 6e829296ec5ff77f4ba100d9f6c12ddb87a9c163 Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 22 Oct 2024 06:06:16 -0700 Subject: [PATCH 3/4] fix and test --- lhotse/bin/modes/workflows.py | 8 +++--- lhotse/workflows/dnsmos.py | 46 ++++++++++++++++++++--------------- 2 files changed, 30 insertions(+), 24 deletions(-) diff --git a/lhotse/bin/modes/workflows.py b/lhotse/bin/modes/workflows.py index 1f91bb429..bd2ef664b 100644 --- a/lhotse/bin/modes/workflows.py +++ b/lhotse/bin/modes/workflows.py @@ -599,7 +599,7 @@ def activity_detection( ) @click.option( "-p", - "--is_personalized_MOS", + "--is-personalized-mos", default="False", help="Flag to indicate if personalized MOS score is needed or regular.", ) @@ -610,7 +610,7 @@ def annotate_dnsmos( recordings_dir: Optional[str], cuts_manifest: Optional[str], extension: str, - is_personalized_MOS: str, + is_personalized_mos: str, jobs: int, ): """ @@ -642,9 +642,9 @@ def annotate_dnsmos( for cut in tqdm( annotate_dnsmos_( manifest, - is_personalized_MOS=is_personalized_MOS, + is_personalized_mos=is_personalized_mos, ), total=len(manifest), - desc="Annotating with Whisper", + desc="Annotating with DNSMOS P.835 prediction model", ): writer.write(cut, flush=True) diff --git a/lhotse/workflows/dnsmos.py b/lhotse/workflows/dnsmos.py index be8e5923e..1270212c3 100644 --- a/lhotse/workflows/dnsmos.py +++ b/lhotse/workflows/dnsmos.py @@ -30,8 +30,8 @@ def audio_melspec( mel_spec = (librosa.power_to_db(mel_spec, ref=np.max) + 40) / 40 return mel_spec.T - def get_polyfit_val(self, sig, bak, ovr, is_personalized_MOS): - if is_personalized_MOS: + def get_polyfit_val(self, sig, bak, ovr, is_personalized_mos): + if is_personalized_mos: p_ovr = np.poly1d([-0.00533021, 0.005101, 1.18058466, -0.11236046]) p_sig = np.poly1d([-0.01019296, 0.02751166, 1.19576786, -0.24348726]) p_bak = np.poly1d([-0.04976499, 0.44276479, -0.1644611, 0.96883132]) @@ -46,7 +46,7 @@ def get_polyfit_val(self, sig, bak, ovr, is_personalized_MOS): return sig_poly, bak_poly, ovr_poly - def __call__(self, manifest, is_personalized_MOS): + def __call__(self, manifest, is_personalized_mos): fs = self.SAMPLING_RATE audio = manifest.resample(fs).load_audio() len_samples = int(self.INPUT_LENGTH * fs) @@ -72,7 +72,7 @@ def __call__(self, manifest, is_personalized_MOS): oi = {"input_1": input_features} mos_sig_raw, mos_bak_raw, mos_ovr_raw = self.onnx_sess.run(None, oi)[0][0] mos_sig, mos_bak, mos_ovr = self.get_polyfit_val( - mos_sig_raw, mos_bak_raw, mos_ovr_raw, is_personalized_MOS + mos_sig_raw, mos_bak_raw, mos_ovr_raw, is_personalized_mos ) predicted_mos_sig_seg.append(mos_sig) predicted_mos_bak_seg.append(mos_bak) @@ -86,14 +86,14 @@ def __call__(self, manifest, is_personalized_MOS): def download_model( - is_personalized_MOS: bool = False, + is_personalized_mos: bool = False, download_root: Optional[str] = None, ) -> str: download_root = download_root if download_root is not None else "/tmp" url = ( - "https://github.com/microsoft/DNS-Challenge/blob/master/DNSMOS/pDNSMOS/sig_bak_ovr.onnx" - if is_personalized_MOS - else "https://github.com/microsoft/DNS-Challenge/blob/master/DNSMOS/DNSMOS/sig_bak_ovr.onnx" + "https://github.com/microsoft/DNS-Challenge/raw/refs/heads/master/DNSMOS/pDNSMOS/sig_bak_ovr.onnx" + if is_personalized_mos + else "https://github.com/microsoft/DNS-Challenge/raw/refs/heads/master/DNSMOS/DNSMOS/sig_bak_ovr.onnx" ) filename = os.path.join(download_root, "sig_bak_ovr.onnx") resumable_download(url, filename=filename) @@ -102,7 +102,7 @@ def download_model( def annotate_dnsmos( manifest: Union[RecordingSet, CutSet], - is_personalized_MOS: bool = False, + is_personalized_mos: bool = False, download_root: Optional[str] = None, ) -> Generator[MonoCut, None, None]: """ @@ -112,7 +112,7 @@ def annotate_dnsmos( See the original repo for more details: https://github.com/microsoft/DNS-Challenge/tree/master/DNSMOS :param manifest: a ``RecordingSet`` or ``CutSet`` object. - :param is_personalized_MOS: flag to indicate if personalized MOS score is needed or regular. + :param is_personalized_mos: flag to indicate if personalized MOS score is needed or regular. :param download_root: if specified, the model will be downloaded to this directory. Otherwise, it will be downloaded to /tmp. :return: a generator of cuts (use ``CutSet.open_writer()`` to write them). @@ -130,13 +130,13 @@ def annotate_dnsmos( if isinstance(manifest, RecordingSet): yield from _annotate_recordings( manifest, - is_personalized_MOS, + is_personalized_mos, download_root, ) elif isinstance(manifest, CutSet): yield from _annotate_cuts( manifest, - is_personalized_MOS, + is_personalized_mos, download_root, ) else: @@ -145,13 +145,13 @@ def annotate_dnsmos( def _annotate_recordings( recordings: RecordingSet, - is_personalized_MOS: bool = False, + is_personalized_mos: bool = False, download_root: Optional[str] = None, ): """ Helper function that annotates a RecordingSet with DNSMOS P.835 prediction model. """ - primary_model_path = download_model(is_personalized_MOS, download_root) + primary_model_path = download_model(is_personalized_mos, download_root) compute_score = ComputeScore(primary_model_path) with ThreadPoolExecutor() as ex: @@ -163,7 +163,7 @@ def _annotate_recordings( f"but we currently only support mono input." ) continue - futures.append(ex.submit(compute_score, recording, is_personalized_MOS)) + futures.append(ex.submit(compute_score, recording, is_personalized_mos)) for future in tqdm(futures, desc="Processing"): recording, result = future.result() @@ -171,23 +171,29 @@ def _annotate_recordings( id=recording.id, recording_id=recording.id, start=0, - duration=recording.durations, + duration=recording.duration, ) cut = MonoCut( - recording=recording, supervisions=[supervision], custom=result + id=recording.id, + start=0, + duration=recording.duration, + channel=0, + recording=recording, + supervisions=[supervision], + custom=result, ) yield cut def _annotate_cuts( cuts: CutSet, - is_personalized_MOS: bool = False, + is_personalized_mos: bool = False, download_root: Optional[str] = None, ): """ Helper function that annotates a CutSet with DNSMOS P.835 prediction model. """ - primary_model_path = download_model(is_personalized_MOS, download_root) + primary_model_path = download_model(is_personalized_mos, download_root) compute_score = ComputeScore(primary_model_path) with ThreadPoolExecutor() as ex: @@ -199,7 +205,7 @@ def _annotate_cuts( f"but we currently only support mono input." ) continue - futures.append(ex.submit(compute_score, cut, is_personalized_MOS)) + futures.append(ex.submit(compute_score, cut, is_personalized_mos)) for future in tqdm(futures, desc="Processing"): cut, result = future.result() From d4865087d848ccffe272267477ab572f6c4bd695 Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 22 Oct 2024 06:32:09 -0700 Subject: [PATCH 4/4] fix --- lhotse/bin/modes/workflows.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lhotse/bin/modes/workflows.py b/lhotse/bin/modes/workflows.py index bd2ef664b..ef629e4fe 100644 --- a/lhotse/bin/modes/workflows.py +++ b/lhotse/bin/modes/workflows.py @@ -600,7 +600,7 @@ def activity_detection( @click.option( "-p", "--is-personalized-mos", - default="False", + default=False, help="Flag to indicate if personalized MOS score is needed or regular.", ) @click.option("-j", "--jobs", default=1, help="Number of jobs for audio scanning.")