Skip to content

Commit

Permalink
utils: fixed loading of audio files with broken metadata encodings
Browse files Browse the repository at this point in the history
  • Loading branch information
jpc committed Apr 24, 2024
1 parent 8717e42 commit 767bbd8
Show file tree
Hide file tree
Showing 6 changed files with 434 additions and 5 deletions.
6 changes: 4 additions & 2 deletions nbs/0. Download models.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "18ef67d5",
"id": "10ecf6bb",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -71,7 +71,9 @@
" load_whisperx('medium', 'en')\n",
" load_whisperx('large-v3', 'en')\n",
" EncoderClassifier.from_hparams(source=\"speechbrain/spkrec-ecapa-voxceleb\",\n",
" savedir=expanduser(\"~/.cache/speechbrain/\"))\n"
" savedir=expanduser(\"~/.cache/speechbrain/\"))\n",
" urllib.request.urlretrieve('https://github.com/marianne-m/brouhaha-vad/raw/main/models/best/checkpoints/best.ckpt',\n",
" expanduser('~/.cache/brouhaha.ckpt'))"
]
},
{
Expand Down
258 changes: 258 additions & 0 deletions nbs/3B. Speech quality metrics extraction.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,258 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "7fdbe3b9",
"metadata": {},
"outputs": [],
"source": [
"#| default_exp extract_metrics"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6cf56fcb",
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ecbdddfd",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/conda/lib/python3.10/site-packages/pyannote/audio/core/io.py:43: UserWarning: torchaudio._backend.set_audio_backend has been deprecated. With dispatcher enabled, this function is no-op. You can remove the function call.\n",
" torchaudio.set_audio_backend(\"soundfile\")\n",
"/opt/conda/lib/python3.10/site-packages/torch_audiomentations/utils/io.py:27: UserWarning: torchaudio._backend.set_audio_backend has been deprecated. With dispatcher enabled, this function is no-op. You can remove the function call.\n",
" torchaudio.set_audio_backend(\"soundfile\")\n"
]
}
],
"source": [
"#| export\n",
"import sys\n",
"import os\n",
"from os.path import expanduser\n",
"import itertools\n",
"from pathlib import Path\n",
"\n",
"import numpy as np\n",
"import torch\n",
"import torchaudio\n",
"import torch.nn.functional as F\n",
"from torch.profiler import profile, record_function, ProfilerActivity\n",
"\n",
"from fastprogress import progress_bar\n",
"from fastcore.script import *\n",
"\n",
"from pyannote.audio import Model\n",
"from brouhaha.pipeline import RegressiveActivityDetectionPipeline\n",
"from whisperspeech import vq_stoks, utils, vad_merge\n",
"import webdataset as wds\n",
"\n",
"from whisperspeech.inference import get_compute_device"
]
},
{
"cell_type": "markdown",
"id": "e1d80d3b",
"metadata": {},
"source": [
"# Semantic token extraction\n",
"\n",
"We take a webdataset shard and extract acoustic and semantic tokens from it.\n",
"\n",
"We don't use the VAD data since the S2A should work on any random 30 second window."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5f271d55",
"metadata": {},
"outputs": [],
"source": [
"#| exporti\n",
"@call_parse\n",
"def prepare_metrics(\n",
" input:str, # audio file webdataset file path\n",
" output:str, # output shard path\n",
" n_samples:int=None, # process a limited amount of samples\n",
" \n",
"):\n",
" device = get_compute_device()\n",
"\n",
" model = Model.from_pretrained(expanduser('~/.cache/brouhaha.ckpt'), strict=False)\n",
" snr_pipeline = RegressiveActivityDetectionPipeline(segmentation=model).to(torch.device(device))\n",
" \n",
" total = n_samples if n_samples else 'noinfer'\n",
"\n",
" if total == 'noinfer':\n",
" import math, time\n",
" start = time.time()\n",
" ds = wds.WebDataset([utils.derived_name(input, 'mvad')]).decode()\n",
" total = math.ceil(sum([len(x[f'max.spk_emb.npy']) for x in ds]))\n",
" print(f\"Counting {total} batches: {time.time()-start:.2f}\")\n",
" \n",
" ds = vad_merge.chunked_audio_dataset([input], 'max').compose(\n",
" wds.to_tuple('__key__', 'rpad', 'gain_shift.npy', 'samples', 'sample_rate'),\n",
" )\n",
"\n",
" dl = wds.WebLoader(ds, num_workers=1, batch_size=None)\n",
" \n",
" with utils.AtomicTarWriter(output, throwaway=n_samples is not None) as sink:\n",
" for keys, rpad, gain_shift, samples, sr in progress_bar(dl, total=total):\n",
" with torch.no_grad():\n",
" snd = samples\n",
" if rpad > 0: snd = snd[:-rpad]\n",
" snd = (snd - gain_shift[1]) * gain_shift[0]\n",
" snd = snd.unsqueeze(0).to(device)\n",
"\n",
" res = snr_pipeline({\n",
" \"sample_rate\": sr, \"waveform\": snd\n",
" })\n",
"\n",
" s = {\n",
" \"__key__\": keys,\n",
" \"snr_c50.npy\": np.array([res['snr'].mean(), res['c50'].mean()])\n",
" }\n",
" sink.write(s)\n",
" sys.stdout.write(\"\\n\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1ac2ffde",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Automatic pdb calling has been turned ON\n"
]
}
],
"source": [
"%pdb"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1ab9f0a9",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Lightning automatically upgraded your loaded checkpoint from v1.6.5 to v2.1.0. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint ../../../.cache/brouhaha.ckpt`\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model was trained with pyannote.audio 0.0.1, yours is 3.1.1. Bad things might happen unless you revert pyannote.audio to 0.x.\n",
"Model was trained with torch 1.12.1+cu102, yours is 2.2.2+cu121. Bad things might happen unless you revert torch to 1.x.\n"
]
},
{
"data": {
"text/html": [
"\n",
"<style>\n",
" /* Turns off some styling */\n",
" progress {\n",
" /* gets rid of default border in Firefox and Opera. */\n",
" border: none;\n",
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
" background-size: auto;\n",
" }\n",
" progress:not([value]), progress:not([value])::-webkit-progress-bar {\n",
" background: repeating-linear-gradient(45deg, #7e7e7e, #7e7e7e 10px, #5c5c5c 10px, #5c5c5c 20px);\n",
" }\n",
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
" background: #F44336;\n",
" }\n",
"</style>\n"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"\n",
" <div>\n",
" <progress value='1024' class='' max='1024' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
" 100.00% [1024/1024 00:28&lt;00:00]\n",
" </div>\n",
" "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using default parameters optimized on Brouhaha\n",
"\n"
]
}
],
"source": [
"prepare_metrics('/data2/mls-polish/audio/mls_polish_train-000000.tar', '/data2/mls-polish/snr-c50/mls_polish_train-000000.tar.gz', n_samples=1024)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "68b66e6f",
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"import nbdev; nbdev.nbdev_export()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "68ace212",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "python3",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
49 changes: 48 additions & 1 deletion nbs/D. Common dataset utilities.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,53 @@
"outputs": [],
"source": [
"#| exporti\n",
"# a patch to ignore invalid utf-8 metadata\n",
"import torio.io._streaming_media_decoder\n",
"def new_parse_si(i):\n",
" media_type = i.media_type\n",
" try:\n",
" metadata = i.metadata\n",
" except UnicodeDecodeError:\n",
" metadata = {}\n",
" if media_type == \"audio\":\n",
" return torio.io._streaming_media_decoder.SourceAudioStream(\n",
" media_type=i.media_type,\n",
" codec=i.codec_name,\n",
" codec_long_name=i.codec_long_name,\n",
" format=i.format,\n",
" bit_rate=i.bit_rate,\n",
" num_frames=i.num_frames,\n",
" bits_per_sample=i.bits_per_sample,\n",
" metadata=metadata,\n",
" sample_rate=i.sample_rate,\n",
" num_channels=i.num_channels,\n",
" )\n",
" if media_type == \"video\":\n",
" return torio.io._streaming_media_decoder.SourceVideoStream(\n",
" media_type=i.media_type,\n",
" codec=i.codec_name,\n",
" codec_long_name=i.codec_long_name,\n",
" format=i.format,\n",
" bit_rate=i.bit_rate,\n",
" num_frames=i.num_frames,\n",
" bits_per_sample=i.bits_per_sample,\n",
" metadata=metadata,\n",
" width=i.width,\n",
" height=i.height,\n",
" frame_rate=i.frame_rate,\n",
" )\n",
" return torio.io._streaming_media_decoder.SourceStream(\n",
" media_type=i.media_type,\n",
" codec=i.codec_name,\n",
" codec_long_name=i.codec_long_name,\n",
" format=None,\n",
" bit_rate=None,\n",
" num_frames=None,\n",
" bits_per_sample=None,\n",
" metadata=metadata,\n",
" )\n",
"torio.io._streaming_media_decoder._parse_si = new_parse_si\n",
"\n",
"def torch_audio_opus(key, data):\n",
" \"\"\"Decode audio using the torchaudio library.\n",
"\n",
Expand All @@ -432,7 +479,7 @@
" fname = os.path.join(dirname, f\"file.{extension}\")\n",
" with open(fname, \"wb\") as stream:\n",
" stream.write(data)\n",
" return torchaudio.load(fname)"
" return torchaudio.load(fname, backend='soundfile' if extension == \"mp3\" else None)"
]
},
{
Expand Down
Loading

0 comments on commit 767bbd8

Please sign in to comment.