Skip to content

Commit

Permalink
Merge pull request speechbrain#2727 from asumagic/pretrainer-no-collect-dir
Browse files Browse the repository at this point in the history
(branch name continued from the truncated title above: pretrainer-no-collect-dir)

Make `collect_in` optional for `Pretrainer`, disable it by default
  • Loading branch information
asumagic authored Oct 25, 2024
2 parents 1bb368a + 7168e0c commit fd0cd20
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 32 deletions.
4 changes: 2 additions & 2 deletions speechbrain/inference/classifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ def classify_batch(self, wavs, wav_lens=None):
text_lab = self.hparams.label_encoder.decode_torch(index)
return out_probs, score, index, text_lab

def classify_file(self, path, savedir="audio_cache"):
def classify_file(self, path, savedir=None):
"""Classifies the given audiofile into the given set of labels.
Arguments
Expand All @@ -297,7 +297,7 @@ def classify_file(self, path, savedir="audio_cache"):
fl,
source=source,
savedir=savedir,
local_strategy=LocalStrategy.NO_LINK,
local_strategy=LocalStrategy.SYMLINK,
)

batch, fs_file = torchaudio.load(path)
Expand Down
8 changes: 4 additions & 4 deletions speechbrain/inference/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def foreign_class(
use_auth_token=False,
download_only=False,
huggingface_cache_dir=None,
local_strategy: LocalStrategy = LocalStrategy.NO_LINK,
local_strategy: LocalStrategy = LocalStrategy.SYMLINK,
**kwargs,
):
"""Fetch and load an interface from an outside source
Expand Down Expand Up @@ -278,7 +278,7 @@ def _prepare_modules(self, freeze_params):
for p in self.mods.parameters():
p.requires_grad = False

def load_audio(self, path, savedir="."):
def load_audio(self, path, savedir=None):
"""Load an audio file with this model's input spec
When using a speech model, it is important to use the same type of data,
Expand All @@ -293,7 +293,7 @@ def load_audio(self, path, savedir="."):
fl,
source=source,
savedir=savedir,
local_strategy=LocalStrategy.NO_LINK,
local_strategy=LocalStrategy.SYMLINK,
)
signal, sr = torchaudio.load(str(path), channels_first=False)
return self.audio_normalizer(signal, sr)
Expand Down Expand Up @@ -405,7 +405,7 @@ def from_hparams(
download_only=False,
huggingface_cache_dir=None,
overrides_must_match=True,
local_strategy: LocalStrategy = LocalStrategy.NO_LINK,
local_strategy: LocalStrategy = LocalStrategy.SYMLINK,
**kwargs,
):
"""Fetch and load based from outside source based on HyperPyYAML file
Expand Down
4 changes: 2 additions & 2 deletions speechbrain/inference/interpretability.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def interpret_batch(self, wavs):

return x_int_sound_domain, text_lab

def interpret_file(self, path, savedir="audio_cache"):
def interpret_file(self, path, savedir=None):
"""Classifies the given audiofile into the given set of labels.
It also provides the interpretation in the audio domain.
Expand All @@ -157,7 +157,7 @@ def interpret_file(self, path, savedir="audio_cache"):
fl,
source=source,
savedir=savedir,
local_strategy=LocalStrategy.NO_LINK,
local_strategy=LocalStrategy.SYMLINK,
)

batch, fs_file = torchaudio.load(path)
Expand Down
4 changes: 2 additions & 2 deletions speechbrain/inference/separation.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def separate_batch(self, mix):
est_source = est_source[:, :T_origin, :]
return est_source

def separate_file(self, path, savedir="audio_cache"):
def separate_file(self, path, savedir=None):
"""Separate sources from file.
Arguments
Expand All @@ -101,7 +101,7 @@ def separate_file(self, path, savedir="audio_cache"):
fl,
source=source,
savedir=savedir,
local_strategy=LocalStrategy.NO_LINK,
local_strategy=LocalStrategy.SYMLINK,
)

batch, fs_file = torchaudio.load(path)
Expand Down
2 changes: 0 additions & 2 deletions speechbrain/inference/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,6 @@ class GPTResponseGenerator(ResponseGenerator):
>>> tmpdir = getfixture("tmpdir")
>>> res_gen_model = GPTResponseGenerator.from_hparams(source="speechbrain/MultiWOZ-GPT-Response_Generation",
... savedir="tmpdir",
... pymodule_file="custom.py") # doctest: +SKIP
>>> response = res_gen_model.generate_response("I want to book a table for dinner") # doctest: +SKIP
"""
Expand Down Expand Up @@ -350,7 +349,6 @@ class Llama2ResponseGenerator(ResponseGenerator):
>>> tmpdir = getfixture("tmpdir")
>>> res_gen_model = Llama2ResponseGenerator.from_hparams(source="speechbrain/MultiWOZ-Llama2-Response_Generation",
... savedir="tmpdir",
... pymodule_file="custom.py") # doctest: +SKIP
>>> response = res_gen_model.generate_response("I want to book a table for dinner") # doctest: +SKIP
"""
Expand Down
94 changes: 74 additions & 20 deletions speechbrain/utils/parameter_transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
"""

import pathlib
import platform
import warnings

from speechbrain.utils.checkpoints import (
DEFAULT_LOAD_HOOKS,
Expand All @@ -27,13 +29,20 @@
class Pretrainer:
"""Orchestrates pretraining
First collects parameter file symlinks into the given directory. Then
calls load hooks for each of those parameter files.
First optionally collects files from some source (local directory,
HuggingFace repository, base URL), into the `collect_in` directory, if
specified.
Then, calls load hooks for each of those files.
Arguments
---------
collect_in : str or Path
Path to directory where the parameter file symlinks are collected.
collect_in : str or Path, optional
Path to directory where the files are to be collected.
If `None`, then files will be referred to from cache or directly, if
possible (URLs will fail). There will not be a centralized target
directory with all the files.
loadables : mapping
Mapping from loadable key to object. This connects the keys to
the actual object instances.
Expand All @@ -56,14 +65,16 @@ class Pretrainer:

def __init__(
self,
collect_in="./model_checkpoints",
collect_in=None,
loadables=None,
paths=None,
custom_hooks=None,
conditions=None,
):
self.loadables = {}
self.collect_in = pathlib.Path(collect_in)

self.set_collect_in(collect_in)

if loadables is not None:
self.add_loadables(loadables)
self.paths = {}
Expand All @@ -79,7 +90,7 @@ def __init__(

def set_collect_in(self, path):
"""Change the collecting path"""
self.collect_in = pathlib.Path(path)
self.collect_in = pathlib.Path(path) if path is not None else None

def add_loadables(self, loadables):
"""Update the loadables dict from the given mapping.
Expand Down Expand Up @@ -174,7 +185,7 @@ def collect_files(
self,
default_source=None,
use_auth_token=False,
local_strategy: LocalStrategy = LocalStrategy.NO_LINK,
local_strategy: LocalStrategy = LocalStrategy.SYMLINK,
):
"""Fetches parameters from known paths with fallback default_source
Expand All @@ -190,14 +201,16 @@ def collect_files(
---------
default_source : str or Path or FetchSource
This is used for each loadable which doesn't have a path already
specified. If the loadable has key "asr", then the file to look for is
default_source/asr.ckpt
specified.
e.g. if the loadable has key `"asr"`, then the file to look for is
`<default_source>/asr.ckpt`
use_auth_token : bool (default: False)
If true Huggingface's auth_token will be used to load private models from the HuggingFace Hub,
default is False because the majority of models are public.
local_strategy : speechbrain.utils.fetching.LocalStrategy
The fetching strategy to use, which controls the behavior of remote file
fetching with regards to symlinking and copying.
Ignored if a `collect_in` directory was not specified.
See :func:`speechbrain.utils.fetching.fetch` for further details.
Returns
Expand All @@ -207,10 +220,25 @@ def collect_files(
parameters can be loaded. This is not used in this class, but
can possibly be helpful.
"""
logger.debug(
f"Collecting files (or symlinks) for pretraining in {self.collect_in}."
)
self.collect_in.mkdir(exist_ok=True)

if self.collect_in is not None:
logger.debug(
f"Collecting files (or symlinks) for pretraining in {self.collect_in}."
)
self.collect_in.mkdir(exist_ok=True)

if (
platform.system() == "Windows"
and local_strategy == LocalStrategy.SYMLINK
):
warnings.warn(
"Requested Pretrainer collection using symlinks on Windows. This might not work; see `LocalStrategy` documentation. Consider unsetting `collect_in` in Pretrainer to avoid symlinking altogether."
)
else:
logger.debug(
"Fetching files for pretraining (no collection directory set)"
)

loadable_paths = {}
for name in self.loadables:
if not self.is_loadable(name):
Expand Down Expand Up @@ -238,11 +266,32 @@ def collect_files(
"local_strategy": local_strategy,
}

path = None

def run_fetch(**kwargs):
"""Very basic local wrapper to fetch to store the path in a
local of collect_files
Arguments
---------
**kwargs : dict
Arguments to forward to fetch"""
nonlocal path
path = fetch(**kwargs)

# run fetch() on the main process, potentially performing downloading
# which we do NOT want to happen concurrently.
#
# then, if there are any non-main processes, run fetch() on them to
# resolve the path.
#
# path needs to be available only if it is a local source w/o symlink
run_on_main(fetch, kwargs=fetch_kwargs)

# we need the path; regardless of rank
path = fetch(**fetch_kwargs)
run_on_main(
run_fetch,
kwargs=fetch_kwargs,
post_func=run_fetch,
post_kwargs=fetch_kwargs,
)

loadable_paths[name] = path
if isinstance(source, FetchSource):
Expand Down Expand Up @@ -285,13 +334,18 @@ def load_collected(self):
if not self.is_loadable(name):
continue
filename = name + PARAMFILE_EXT
paramfiles[name] = self.collect_in / filename

if name in self.is_local:
logger.debug(
f"Redirecting (loading from local path): {paramfiles[name]} -> {self.paths[name]}"
f"Redirecting (loading from local path): {name} -> {self.paths[name]}"
)
paramfiles[name] = self.paths[name]
elif self.collect_in is not None:
paramfiles[name] = self.collect_in / filename
else:
raise ValueError(
f'Pretrainer has never collected `{name}`, did you forget a call to `collect_files`? Could not fall back to `collect_in`, as it was not specified (default is no longer "model_checkpoints").'
)
self._call_load_hooks(paramfiles)

def _call_load_hooks(self, paramfiles):
Expand Down

0 comments on commit fd0cd20

Please sign in to comment.