
Commit

Added support for loading the datasets files relative to the config file path
jpc committed Mar 12, 2024
1 parent 4a047dd commit 1a5fb25
Showing 8 changed files with 150 additions and 38 deletions.
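In practice, the change lets a dataset definition live in its own file which the training command references with an `@` prefix; `parse_dataset_string` substitutes the file's contents and appends the file's parent directory as a `--cwd` argument, so every relative shard spec and exclude list inside the file is resolved against that directory. A minimal sketch of the intended flow, with a hypothetical file name and contents:

# hypothetical dataset definition saved as ../librilight/train.dataset:
#
#   librilight-small-atoks-3kbps-*.tar.gz --language en --weight 2
#
from whisperspeech.train_multi import parse_dataset_string

args = parse_dataset_string("@../librilight/train.dataset")
print(args)
# roughly: ['librilight-small-atoks-3kbps-*.tar.gz', '--language', 'en',
#           '--weight', '2', '--cwd', '../librilight']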
11 changes: 6 additions & 5 deletions nbs/4B. Multi-language semantic to acoustic token modeling.ipynb
@@ -175,12 +175,13 @@
" validation:bool=False,\n",
" exclude_files:str=None,\n",
" randomize_speakers:bool=False,\n",
" cwd:Path=None,\n",
" ):\n",
" import webdataset as wds\n",
" from . import utils, languages\n",
" from whisperspeech import utils, languages\n",
"\n",
" shards = utils.shard_glob(atoks_shard_spec)\n",
" excludes = {x for file in exclude_files.split() for x in utils.readlines(file)} if exclude_files else set()\n",
" shards = utils.shard_glob(cwd/atoks_shard_spec)\n",
" excludes = {x for file in exclude_files.split() for x in utils.readlines(cwd/file)} if exclude_files else set()\n",
" \n",
" def check_for_nan(s):\n",
" if torch.tensor(s['spk_emb.npy']).isnan().any(): print(\"found NaN:\", s['__key__'])\n",
@@ -193,7 +194,7 @@
" same_on_all_nodes = lambda urls: urls # will only be used for validation\n",
" ds = wds.WebDataset(shards, resampled=not validation, nodesplitter=same_on_all_nodes).compose(\n",
" wds.decode(),\n",
" utils.merge_in(utils.derived_dataset('maxvad-stoks', base='atoks-3kbps', suffix='', dir=stoks_shard_dir)),\n",
" utils.merge_in(utils.derived_dataset('maxvad-stoks', base='atoks-3kbps', suffix='', dir=cwd/stoks_shard_dir)),\n",
" wds.map(check_for_nan),\n",
" wds.select(lambda s: s['__key__'] not in excludes),\n",
" wds.map_dict(**{'spk_emb.npy':np.nan_to_num}), # remove nans from the speaker embedding model\n",
@@ -763,7 +764,7 @@
"\n",
"def make_model(size:str, quantizers:int=4, frozen_embeddings_model:str=None, frozen_acoustic_embeddings:bool=False, spk_width:int=None, tunables:Tunables=Tunables(), dataset=None):\n",
" from encodec.model import EncodecModel\n",
" from . import vq_stoks\n",
" from whisperspeech import vq_stoks\n",
"\n",
" amodel = EncodecModel.encodec_model_24khz() if frozen_acoustic_embeddings else None\n",
" vqmodel = vq_stoks.RQBottleneckTransformer.load_model(frozen_embeddings_model) if frozen_embeddings_model else None\n",
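The new `cwd/atoks_shard_spec` and `cwd/file` expressions rely on pathlib's `/` operator accepting a plain string on the right-hand side, and `shard_glob` (see the utils changes further down) now accepts the resulting `Path`. A minimal illustration of the composition, with hypothetical paths:

from pathlib import Path

cwd = Path('../librilight')  # directory of the dataset definition file
atoks_shard_spec = 'librilight-atoks-txts/librilight-small-atoks-3kbps-*.tar.gz'
exclude_file = 'validation-samples.txt'  # hypothetical list of sample keys to skip

# Path / str joins the relative specs onto the config directory
print(cwd / atoks_shard_spec)  # ../librilight/librilight-atoks-txts/librilight-small-atoks-3kbps-*.tar.gz
print(cwd / exclude_file)      # ../librilight/validation-samples.txt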
7 changes: 4 additions & 3 deletions nbs/5B. Multi-lang text to semantic token modeling.ipynb
@@ -158,12 +158,13 @@
" weight:float=1,\n",
" validation:bool=False,\n",
" exclude_files:str=None,\n",
" cwd:Path=None,\n",
"):\n",
" import webdataset as wds\n",
" from . import utils\n",
"\n",
" shards = utils.shard_glob(txt_shard_spec)\n",
" excludes = {x for file in exclude_files.split() for x in utils.readlines(file)} if exclude_files else set()\n",
" shards = utils.shard_glob(cwd/txt_shard_spec)\n",
" excludes = {x for file in exclude_files.split() for x in utils.readlines(cwd/file)} if exclude_files else set()\n",
" \n",
" language = languages.to_id(language)\n",
" \n",
@@ -174,7 +175,7 @@
" same_on_all_nodes = lambda urls: urls # will only be used for validation\n",
" ds = wds.WebDataset(shards, resampled=not validation, nodesplitter=same_on_all_nodes).compose(\n",
" wds.decode(),\n",
" utils.merge_in(utils.derived_dataset('eqvad-stoks', base=txt_kind, suffix='', dir=stoks_shard_dir)),\n",
" utils.merge_in(utils.derived_dataset('eqvad-stoks', base=txt_kind, suffix='', dir=cwd/stoks_shard_dir)),\n",
" # discard validation samples, select samples > .5s\n",
" wds.select(lambda s: s['__key__'] not in excludes and s['stoks.npy'].shape[-1] > 12),\n",
" tokenizer('txt', 'ttoks', length=550),\n",
19 changes: 14 additions & 5 deletions nbs/B2. Training (Lightning).ipynb
@@ -35,6 +35,7 @@
"import random\n",
"import re\n",
"from pathlib import Path\n",
"import requests\n",
"\n",
"from fastprogress import progress_bar, master_bar\n",
"import fastprogress\n",
@@ -358,13 +359,21 @@
"outputs": [],
"source": [
"#| exporti\n",
"def load_file_reference(matchobj):\n",
" with open(matchobj.group(1), 'r') as f:\n",
" return f.read().strip()\n",
"\n",
"def parse_dataset_string(s):\n",
" cwd = [None]\n",
" def load_file_reference(matchobj):\n",
" fname = matchobj.group(1)\n",
" cwd[0] = Path(fname).parent\n",
" if fname.startswith('http://') or fname.startswith('https://'):\n",
" response = requests.get(target_url)\n",
" return response.text.strip()\n",
" else:\n",
" with open(fname, 'r') as f:\n",
" return f.read().strip()\n",
" s = re.sub('@([^ ]+)', load_file_reference, s)\n",
" return shlex.split(s)"
" arg_list = shlex.split(s)\n",
" if cwd[0]: arg_list += ['--cwd', str(cwd[0])]\n",
" return arg_list"
]
},
{
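The `cwd = [None]` one-element list lets the nested `load_file_reference` callback report the referenced file's directory back to `parse_dataset_string`; mutating a list element avoids the fact that a plain assignment inside the inner function would create a new local variable. The same thing can be written with `nonlocal`, as in this equivalent sketch (the URL branch is omitted for brevity):

import re, shlex
from pathlib import Path

def parse_dataset_string_nonlocal(s):
    cwd = None
    def load_file_reference(matchobj):
        nonlocal cwd               # write to the enclosing function's variable
        fname = matchobj.group(1)
        cwd = Path(fname).parent
        with open(fname, 'r') as f:
            return f.read().strip()
    s = re.sub('@([^ ]+)', load_file_reference, s)
    arg_list = shlex.split(s)
    if cwd: arg_list += ['--cwd', str(cwd)]
    return arg_list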
90 changes: 89 additions & 1 deletion nbs/D. Common dataset utilities.ipynb
@@ -36,9 +36,11 @@
"source": [
"#| export\n",
"def shard_glob(input):\n",
" if isinstance(input, Path):\n",
" input = str(input)\n",
" if '{' in input:\n",
" return wds.shardlists.expand_urls(input)\n",
" if isinstance(input, (Path, str)):\n",
" if str:\n",
" path = Path(input)\n",
" if path.is_dir():\n",
" glob = '*.tar.gz'\n",
@@ -51,6 +53,92 @@
" return [str(x) for x in input]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d1f98923",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['../librilight/librilight-atoks-txts/librilight-small-atoks-3kbps-000000.tar.gz',\n",
" '../librilight/librilight-atoks-txts/librilight-small-atoks-3kbps-000006.tar.gz',\n",
" '../librilight/librilight-atoks-txts/librilight-small-atoks-3kbps-000004.tar.gz',\n",
" '../librilight/librilight-atoks-txts/librilight-small-atoks-3kbps-000001.tar.gz',\n",
" '../librilight/librilight-atoks-txts/librilight-small-atoks-3kbps-000003.tar.gz',\n",
" '../librilight/librilight-atoks-txts/librilight-small-atoks-3kbps-000002.tar.gz',\n",
" '../librilight/librilight-atoks-txts/librilight-small-atoks-3kbps-000005.tar.gz',\n",
" '../librilight/librilight-atoks-txts/librilight-small-atoks-3kbps-000007.tar.gz']"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"shard_glob('../librilight/librilight-atoks-txts/librilight-small-atoks-3kbps-*.tar.gz')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "528d9dc8",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['../librilight/librilight-atoks-txts/librilight-small-atoks-3kbps-000000.tar.gz',\n",
" '../librilight/librilight-atoks-txts/librilight-small-atoks-3kbps-000006.tar.gz',\n",
" '../librilight/librilight-atoks-txts/librilight-small-atoks-3kbps-000004.tar.gz',\n",
" '../librilight/librilight-atoks-txts/librilight-small-atoks-3kbps-000001.tar.gz',\n",
" '../librilight/librilight-atoks-txts/librilight-small-atoks-3kbps-000003.tar.gz',\n",
" '../librilight/librilight-atoks-txts/librilight-small-atoks-3kbps-000002.tar.gz',\n",
" '../librilight/librilight-atoks-txts/librilight-small-atoks-3kbps-000005.tar.gz',\n",
" '../librilight/librilight-atoks-txts/librilight-small-atoks-3kbps-000007.tar.gz']"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# \n",
"shard_glob(Path('../librilight/librilight-atoks-txts/librilight-small-atoks-3kbps-*.tar.gz'))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "be49b92b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['https:/huggingface.co/datasets/collabora/librilight-processed-webdataset/resolve/main/librilight-small-atoks-3kbps-000000.tar.gz',\n",
" 'https:/huggingface.co/datasets/collabora/librilight-processed-webdataset/resolve/main/librilight-small-atoks-3kbps-000001.tar.gz',\n",
" 'https:/huggingface.co/datasets/collabora/librilight-processed-webdataset/resolve/main/librilight-small-atoks-3kbps-000002.tar.gz',\n",
" 'https:/huggingface.co/datasets/collabora/librilight-processed-webdataset/resolve/main/librilight-small-atoks-3kbps-000003.tar.gz',\n",
" 'https:/huggingface.co/datasets/collabora/librilight-processed-webdataset/resolve/main/librilight-small-atoks-3kbps-000004.tar.gz',\n",
" 'https:/huggingface.co/datasets/collabora/librilight-processed-webdataset/resolve/main/librilight-small-atoks-3kbps-000005.tar.gz',\n",
" 'https:/huggingface.co/datasets/collabora/librilight-processed-webdataset/resolve/main/librilight-small-atoks-3kbps-000006.tar.gz',\n",
" 'https:/huggingface.co/datasets/collabora/librilight-processed-webdataset/resolve/main/librilight-small-atoks-3kbps-000007.tar.gz']"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# we can also specify the range and generate shard URLs\n",
"shard_glob(Path('https://huggingface.co/datasets/collabora/librilight-processed-webdataset/resolve/main/librilight-small-atoks-3kbps-{000000..000007}.tar.gz'))"
]
},
{
"cell_type": "code",
"execution_count": null,
7 changes: 4 additions & 3 deletions whisperspeech/s2a_delar_mup_wds_mlang.py
@@ -72,12 +72,13 @@ def load_dataset(
validation:bool=False,
exclude_files:str=None,
randomize_speakers:bool=False,
cwd:Path=None,
):
import webdataset as wds
from whisperspeech import utils, languages

shards = utils.shard_glob(atoks_shard_spec)
excludes = {x for file in exclude_files.split() for x in utils.readlines(file)} if exclude_files else set()
shards = utils.shard_glob(cwd/atoks_shard_spec)
excludes = {x for file in exclude_files.split() for x in utils.readlines(cwd/file)} if exclude_files else set()

def check_for_nan(s):
if torch.tensor(s['spk_emb.npy']).isnan().any(): print("found NaN:", s['__key__'])
@@ -90,7 +91,7 @@ def set_language(x):
same_on_all_nodes = lambda urls: urls # will only be used for validation
ds = wds.WebDataset(shards, resampled=not validation, nodesplitter=same_on_all_nodes).compose(
wds.decode(),
utils.merge_in(utils.derived_dataset('maxvad-stoks', base='atoks-3kbps', suffix='', dir=stoks_shard_dir)),
utils.merge_in(utils.derived_dataset('maxvad-stoks', base='atoks-3kbps', suffix='', dir=cwd/stoks_shard_dir)),
wds.map(check_for_nan),
wds.select(lambda s: s['__key__'] not in excludes),
wds.map_dict(**{'spk_emb.npy':np.nan_to_num}), # remove nans from the speaker embedding model
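Taken together, a direct call to this `load_dataset` would pass the config directory explicitly. A rough sketch — only the arguments relevant to the change are shown, the paths are hypothetical, and keyword names other than `atoks_shard_spec`, `stoks_shard_dir`, and `cwd` (which appear in the signature and body above) are assumptions:

from pathlib import Path
from whisperspeech import s2a_delar_mup_wds_mlang as s2a

ds = s2a.load_dataset(
    atoks_shard_spec='librilight-atoks-txts/librilight-small-atoks-3kbps-*.tar.gz',
    stoks_shard_dir='librilight-stoks',  # hypothetical directory of maxvad-stoks shards
    language='en',                       # assumed keyword, per the languages import
    cwd=Path('../librilight'),           # shard globs, exclude lists and the stoks dir resolve against this
)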
7 changes: 4 additions & 3 deletions whisperspeech/t2s_up_wds_mlang_enclm.py
@@ -80,12 +80,13 @@ def load_dataset(
weight:float=1,
validation:bool=False,
exclude_files:str=None,
cwd:Path=None,
):
import webdataset as wds
from . import utils

shards = utils.shard_glob(txt_shard_spec)
excludes = {x for file in exclude_files.split() for x in utils.readlines(file)} if exclude_files else set()
shards = utils.shard_glob(cwd/txt_shard_spec)
excludes = {x for file in exclude_files.split() for x in utils.readlines(cwd/file)} if exclude_files else set()

language = languages.to_id(language)

@@ -96,7 +97,7 @@ def set_language(x):
same_on_all_nodes = lambda urls: urls # will only be used for validation
ds = wds.WebDataset(shards, resampled=not validation, nodesplitter=same_on_all_nodes).compose(
wds.decode(),
utils.merge_in(utils.derived_dataset('eqvad-stoks', base=txt_kind, suffix='', dir=stoks_shard_dir)),
utils.merge_in(utils.derived_dataset('eqvad-stoks', base=txt_kind, suffix='', dir=cwd/stoks_shard_dir)),
# discard validation samples, select samples > .5s
wds.select(lambda s: s['__key__'] not in excludes and s['stoks.npy'].shape[-1] > 12),
tokenizer('txt', 'ttoks', length=550),
19 changes: 14 additions & 5 deletions whisperspeech/train_multi.py
@@ -10,6 +10,7 @@
import random
import re
from pathlib import Path
import requests

from fastprogress import progress_bar, master_bar
import fastprogress
@@ -225,13 +226,21 @@ def parse_and_call(name, fun, args, kwargs={}, log_to_wandb=True):
hyp_params['world_size'] = 1

# %% ../nbs/B2. Training (Lightning).ipynb 9
def load_file_reference(matchobj):
with open(matchobj.group(1), 'r') as f:
return f.read().strip()

def parse_dataset_string(s):
cwd = [None]
def load_file_reference(matchobj):
fname = matchobj.group(1)
cwd[0] = Path(fname).parent
if fname.startswith('http://') or fname.startswith('https://'):
            response = requests.get(fname)
return response.text.strip()
else:
with open(fname, 'r') as f:
return f.read().strip()
s = re.sub('@([^ ]+)', load_file_reference, s)
return shlex.split(s)
arg_list = shlex.split(s)
if cwd[0]: arg_list += ['--cwd', str(cwd[0])]
return arg_list

# %% ../nbs/B2. Training (Lightning).ipynb 10
from lightning.pytorch.loggers import WandbLogger
28 changes: 15 additions & 13 deletions whisperspeech/utils.py
@@ -15,9 +15,11 @@

# %% ../nbs/D. Common dataset utilities.ipynb 2
def shard_glob(input):
if isinstance(input, Path):
input = str(input)
if '{' in input:
return wds.shardlists.expand_urls(input)
if isinstance(input, (Path, str)):
    if isinstance(input, str):
path = Path(input)
if path.is_dir():
glob = '*.tar.gz'
@@ -29,7 +31,7 @@ def shard_glob(input):
raise ArgumentError("input should be either a list or a path with an optional glob specifier")
return [str(x) for x in input]

# %% ../nbs/D. Common dataset utilities.ipynb 3
# %% ../nbs/D. Common dataset utilities.ipynb 6
class join_datasets(torch.utils.data.IterableDataset):
def __init__(self, datasets):
self.datasets = datasets
Expand All @@ -46,7 +48,7 @@ def __iter__(self):
def __len__(self):
return sum([ds.total_samples for ds in self.datasets])

# %% ../nbs/D. Common dataset utilities.ipynb 6
# %% ../nbs/D. Common dataset utilities.ipynb 9
def resampler(newsr = 24000, key = 'samples_24k'):
_last_sr = None
tform = None
Expand All @@ -63,12 +65,12 @@ def _resample(samples):

return _resample

# %% ../nbs/D. Common dataset utilities.ipynb 7
# %% ../nbs/D. Common dataset utilities.ipynb 10
def derived_name(input, kind, base="audio", suffix=".gz", dir=None):
dir = Path(dir) if dir else Path(input).parent
return str(dir/(Path(input).name.replace(f"-{base}-", f"-{kind}-") + suffix))

# %% ../nbs/D. Common dataset utilities.ipynb 8
# %% ../nbs/D. Common dataset utilities.ipynb 11
def derived_dataset(kind, base='audio', suffix=".gz", decoders=[], dir=None):
def deriver(url):
url = str(derived_name(url, kind, base=base, suffix=suffix, dir=dir))
Expand All @@ -77,7 +79,7 @@ def deriver(url):
).decode(*decoders)
return deriver

# %% ../nbs/D. Common dataset utilities.ipynb 9
# %% ../nbs/D. Common dataset utilities.ipynb 12
def merge_in(dataset_fun):
"""Merge a dataset into the current one returning samples with the union of keys. Pass in a function
that takes a URL of a sample and returns a dataset for it (called every time the URL changes).
@@ -110,7 +112,7 @@ def merge_loop(main_samples):
yield news
return merge_loop

# %% ../nbs/D. Common dataset utilities.ipynb 10
# %% ../nbs/D. Common dataset utilities.ipynb 13
def split_to_chunks(stream, ikey='vad.npy', metakeys=[], pad_to_seconds=30, random_shift=False):
for s in stream:
audio, sr = s['audio']
Expand All @@ -133,11 +135,11 @@ def split_to_chunks(stream, ikey='vad.npy', metakeys=[], pad_to_seconds=30, rand
subs[k] = s[k][i]
yield subs

# %% ../nbs/D. Common dataset utilities.ipynb 11
# %% ../nbs/D. Common dataset utilities.ipynb 14
import re
import tempfile

# %% ../nbs/D. Common dataset utilities.ipynb 12
# %% ../nbs/D. Common dataset utilities.ipynb 15
def torch_audio_opus(key, data):
"""Decode audio using the torchaudio library.
Expand All @@ -156,7 +158,7 @@ def torch_audio_opus(key, data):
stream.write(data)
return torchaudio.load(fname)

# %% ../nbs/D. Common dataset utilities.ipynb 13
# %% ../nbs/D. Common dataset utilities.ipynb 16
def find_audio(stream, okey='audio', ikeys='flac;mp3;wav;ogg;opus'):
ikeys = ikeys.split(';')
for s in stream:
Expand All @@ -167,7 +169,7 @@ def find_audio(stream, okey='audio', ikeys='flac;mp3;wav;ogg;opus'):
break
# implicitly skips elements without any audio

# %% ../nbs/D. Common dataset utilities.ipynb 14
# %% ../nbs/D. Common dataset utilities.ipynb 17
def vad_dataset(shards, ikey='vad.npy', kind='vad'):
return wds.WebDataset(shards).compose(
wds.decode(torch_audio_opus),
Expand All @@ -176,7 +178,7 @@ def vad_dataset(shards, ikey='vad.npy', kind='vad'):
lambda x: split_to_chunks(x, ikey=ikey),
)

# %% ../nbs/D. Common dataset utilities.ipynb 15
# %% ../nbs/D. Common dataset utilities.ipynb 18
@contextmanager
def AtomicTarWriter(name, throwaway=False):
tmp = name+".tmp"
@@ -185,7 +187,7 @@ def AtomicTarWriter(name, throwaway=False):
if not throwaway:
os.rename(tmp, name)

# %% ../nbs/D. Common dataset utilities.ipynb 16
# %% ../nbs/D. Common dataset utilities.ipynb 19
def readlines(fname):
with open(fname) as file:
return [line.rstrip() for line in file]
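
For reference, the `dir=cwd/stoks_shard_dir` arguments in the model notebooks feed into `derived_name` above, which swaps the `-atoks-3kbps-` (or other `base`) part of each shard's filename for the derived `kind` and relocates it into the given directory. A small illustration with hypothetical paths:

from pathlib import Path
from whisperspeech.utils import derived_name

atoks_shard = '../librilight/librilight-atoks-txts/librilight-small-atoks-3kbps-000000.tar.gz'
stoks_dir = Path('../librilight') / 'librilight-stoks'   # i.e. cwd/stoks_shard_dir

print(derived_name(atoks_shard, 'maxvad-stoks', base='atoks-3kbps', suffix='', dir=stoks_dir))
# ../librilight/librilight-stoks/librilight-small-maxvad-stoks-000000.tar.gz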
