-
Notifications
You must be signed in to change notification settings - Fork 210
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
utils: fixed loading of audio files with broken metadata encodings
- Loading branch information
Showing
6 changed files
with
434 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,258 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "7fdbe3b9", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"#| default_exp extract_metrics" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "6cf56fcb", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"%load_ext autoreload\n", | ||
"%autoreload 2" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "ecbdddfd", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"/opt/conda/lib/python3.10/site-packages/pyannote/audio/core/io.py:43: UserWarning: torchaudio._backend.set_audio_backend has been deprecated. With dispatcher enabled, this function is no-op. You can remove the function call.\n", | ||
" torchaudio.set_audio_backend(\"soundfile\")\n", | ||
"/opt/conda/lib/python3.10/site-packages/torch_audiomentations/utils/io.py:27: UserWarning: torchaudio._backend.set_audio_backend has been deprecated. With dispatcher enabled, this function is no-op. You can remove the function call.\n", | ||
" torchaudio.set_audio_backend(\"soundfile\")\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"#| export\n", | ||
"import sys\n", | ||
"import os\n", | ||
"from os.path import expanduser\n", | ||
"import itertools\n", | ||
"from pathlib import Path\n", | ||
"\n", | ||
"import numpy as np\n", | ||
"import torch\n", | ||
"import torchaudio\n", | ||
"import torch.nn.functional as F\n", | ||
"from torch.profiler import profile, record_function, ProfilerActivity\n", | ||
"\n", | ||
"from fastprogress import progress_bar\n", | ||
"from fastcore.script import *\n", | ||
"\n", | ||
"from pyannote.audio import Model\n", | ||
"from brouhaha.pipeline import RegressiveActivityDetectionPipeline\n", | ||
"from whisperspeech import vq_stoks, utils, vad_merge\n", | ||
"import webdataset as wds\n", | ||
"\n", | ||
"from whisperspeech.inference import get_compute_device" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "e1d80d3b", | ||
"metadata": {}, | ||
"source": [ | ||
"# Semantic token extraction\n", | ||
"\n", | ||
"We take a webdataset shard and extract acoustic and semantic tokens from it.\n", | ||
"\n", | ||
"We don't use the VAD data since the S2A should work on any random 30 second window." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "5f271d55", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"#| exporti\n", | ||
"@call_parse\n", | ||
"def prepare_metrics(\n", | ||
" input:str, # audio file webdataset file path\n", | ||
" output:str, # output shard path\n", | ||
" n_samples:int=None, # process a limited amount of samples\n", | ||
" \n", | ||
"):\n", | ||
" device = get_compute_device()\n", | ||
"\n", | ||
" model = Model.from_pretrained(expanduser('~/.cache/brouhaha.ckpt'), strict=False)\n", | ||
" snr_pipeline = RegressiveActivityDetectionPipeline(segmentation=model).to(torch.device(device))\n", | ||
" \n", | ||
" total = n_samples if n_samples else 'noinfer'\n", | ||
"\n", | ||
" if total == 'noinfer':\n", | ||
" import math, time\n", | ||
" start = time.time()\n", | ||
" ds = wds.WebDataset([utils.derived_name(input, 'mvad')]).decode()\n", | ||
" total = math.ceil(sum([len(x[f'max.spk_emb.npy']) for x in ds]))\n", | ||
" print(f\"Counting {total} batches: {time.time()-start:.2f}\")\n", | ||
" \n", | ||
" ds = vad_merge.chunked_audio_dataset([input], 'max').compose(\n", | ||
" wds.to_tuple('__key__', 'rpad', 'gain_shift.npy', 'samples', 'sample_rate'),\n", | ||
" )\n", | ||
"\n", | ||
" dl = wds.WebLoader(ds, num_workers=1, batch_size=None)\n", | ||
" \n", | ||
" with utils.AtomicTarWriter(output, throwaway=n_samples is not None) as sink:\n", | ||
" for keys, rpad, gain_shift, samples, sr in progress_bar(dl, total=total):\n", | ||
" with torch.no_grad():\n", | ||
" snd = samples\n", | ||
" if rpad > 0: snd = snd[:-rpad]\n", | ||
" snd = (snd - gain_shift[1]) * gain_shift[0]\n", | ||
" snd = snd.unsqueeze(0).to(device)\n", | ||
"\n", | ||
" res = snr_pipeline({\n", | ||
" \"sample_rate\": sr, \"waveform\": snd\n", | ||
" })\n", | ||
"\n", | ||
" s = {\n", | ||
" \"__key__\": keys,\n", | ||
" \"snr_c50.npy\": np.array([res['snr'].mean(), res['c50'].mean()])\n", | ||
" }\n", | ||
" sink.write(s)\n", | ||
" sys.stdout.write(\"\\n\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "1ac2ffde", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Automatic pdb calling has been turned ON\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"%pdb" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "1ab9f0a9", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"Lightning automatically upgraded your loaded checkpoint from v1.6.5 to v2.1.0. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint ../../../.cache/brouhaha.ckpt`\n" | ||
] | ||
}, | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Model was trained with pyannote.audio 0.0.1, yours is 3.1.1. Bad things might happen unless you revert pyannote.audio to 0.x.\n", | ||
"Model was trained with torch 1.12.1+cu102, yours is 2.2.2+cu121. Bad things might happen unless you revert torch to 1.x.\n" | ||
] | ||
}, | ||
{ | ||
"data": { | ||
"text/html": [ | ||
"\n", | ||
"<style>\n", | ||
" /* Turns off some styling */\n", | ||
" progress {\n", | ||
" /* gets rid of default border in Firefox and Opera. */\n", | ||
" border: none;\n", | ||
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n", | ||
" background-size: auto;\n", | ||
" }\n", | ||
" progress:not([value]), progress:not([value])::-webkit-progress-bar {\n", | ||
" background: repeating-linear-gradient(45deg, #7e7e7e, #7e7e7e 10px, #5c5c5c 10px, #5c5c5c 20px);\n", | ||
" }\n", | ||
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n", | ||
" background: #F44336;\n", | ||
" }\n", | ||
"</style>\n" | ||
], | ||
"text/plain": [ | ||
"<IPython.core.display.HTML object>" | ||
] | ||
}, | ||
"metadata": {}, | ||
"output_type": "display_data" | ||
}, | ||
{ | ||
"data": { | ||
"text/html": [ | ||
"\n", | ||
" <div>\n", | ||
" <progress value='1024' class='' max='1024' style='width:300px; height:20px; vertical-align: middle;'></progress>\n", | ||
" 100.00% [1024/1024 00:28<00:00]\n", | ||
" </div>\n", | ||
" " | ||
], | ||
"text/plain": [ | ||
"<IPython.core.display.HTML object>" | ||
] | ||
}, | ||
"metadata": {}, | ||
"output_type": "display_data" | ||
}, | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Using default parameters optimized on Brouhaha\n", | ||
"\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"prepare_metrics('/data2/mls-polish/audio/mls_polish_train-000000.tar', '/data2/mls-polish/snr-c50/mls_polish_train-000000.tar.gz', n_samples=1024)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "68b66e6f", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"#| hide\n", | ||
"import nbdev; nbdev.nbdev_export()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "68ace212", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "python3", | ||
"language": "python", | ||
"name": "python3" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.