Skip to content

Commit

Permalink
feat: add huggingface model hub integration
Browse files Browse the repository at this point in the history
  • Loading branch information
julien-c authored Jan 5, 2021
1 parent 78a16a2 commit 61be512
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 4 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ __pycache__/
# Distribution / packaging
.Python
env/
.env/
build/
develop-eggs/
dist/
Expand Down
2 changes: 1 addition & 1 deletion pyannote/audio/core/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def __init__(
self.model = (
model
if isinstance(model, Model)
else load_from_checkpoint(Path(model), strict=False)
else load_from_checkpoint(Path(model), map_location=device, strict=False)
)

if window not in ["sliding", "whole"]:
Expand Down
42 changes: 39 additions & 3 deletions pyannote/audio/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,23 +20,32 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import os
import warnings
from dataclasses import dataclass
from functools import cached_property
from importlib import import_module
from pathlib import Path
from typing import Any, Dict, List, Optional, Text, Tuple, Union
from urllib.parse import urlparse

import pytorch_lightning as pl
import torch
import torch.nn as nn
from huggingface_hub import cached_download, hf_hub_url
from pytorch_lightning.utilities.cloud_io import load as pl_load
from semver import VersionInfo

from pyannote.audio import __version__
from pyannote.audio.core.io import Audio
from pyannote.audio.core.task import Problem, Scale, Task, TaskSpecification

CACHE_DIR = os.getenv(
"PYANNOTE_CACHE",
os.path.expanduser("~/.cache/torch/pyannote"),
)
HF_PYTORCH_WEIGHTS_NAME = "pytorch_model.bin"


@dataclass
class ModelIntrospection:
Expand Down Expand Up @@ -628,7 +637,8 @@ def load_from_checkpoint(
Parameters
----------
checkpoint_path : Path or str
Path to checkpoint. This can also be a URL.
Path to checkpoint, or a remote URL, or a model identifier from
the huggingface.co model hub.
map_location: optional
If your checkpoint saved a GPU model and you now load on CPUs
or a different number of GPUs, use this to map to the new setup.
Expand Down Expand Up @@ -661,8 +671,34 @@ def load_from_checkpoint(
if hparams_file is not None:
hparams_file = str(hparams_file)

# resolve the checkpoint_path to
# something that pl will handle
if os.path.isfile(checkpoint_path):
path_for_pl = checkpoint_path
elif urlparse(checkpoint_path).scheme in ("http", "https"):
path_for_pl = checkpoint_path
else:
# Finally, let's try to find it on Hugging Face model hub
# e.g. julien-c/voice-activity-detection is a valid model id
# and julien-c/voice-activity-detection@main supports specifying a commit/branch/tag.
if "@" in checkpoint_path:
model_id = checkpoint_path.split("@")[0]
revision = checkpoint_path.split("@")[1]
else:
model_id = checkpoint_path
revision = None
url = hf_hub_url(
model_id=model_id, filename=HF_PYTORCH_WEIGHTS_NAME, revision=revision
)
path_for_pl = cached_download(
url=url,
library_name="pyannote",
library_version=__version__,
cache_dir=CACHE_DIR,
)

# obtain model class from the checkpoint
checkpoint = pl_load(checkpoint_path, map_location=map_location)
checkpoint = pl_load(path_for_pl, map_location=map_location)

module_name: str = checkpoint["pyannote.audio"]["model"]["module"]
module = import_module(module_name)
Expand All @@ -671,7 +707,7 @@ def load_from_checkpoint(
Klass: Model = getattr(module, class_name)

return Klass.load_from_checkpoint(
checkpoint_path,
path_for_pl,
map_location=map_location,
hparams_file=hparams_file,
strict=strict,
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
asteroid-filterbanks == 0.1.0
einops >= 0.3.0
huggingface_hub
hydra-core >= 1.0.4
librosa >= 0.8
pyannote.core >= 4.1
Expand Down
7 changes: 7 additions & 0 deletions tests/inference_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from pyannote.core import SlidingWindowFeature
from pyannote.database import FileFinder, get_protocol

HF_SAMPLE_MODEL_ID = "julien-c/voice-activity-detection"


@pytest.fixture()
def trained():
Expand Down Expand Up @@ -93,3 +95,8 @@ def test_multi_seg_infer():
for attr in ["vad", "scd", "osd"]:
assert attr in scores
assert isinstance(scores[attr], SlidingWindowFeature)


def test_hf_download():
inference = Inference(HF_SAMPLE_MODEL_ID, device="cpu")
assert isinstance(inference, Inference)

0 comments on commit 61be512

Please sign in to comment.