Skip to content

Commit

Permalink
pass cache_dir to speech_separation pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
benniekiss authored and bgmt committed Jul 6, 2024
1 parent 2eb7526 commit 0733e91
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions pyannote/audio/pipelines/speech_separation.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import math
import textwrap
import warnings
from pathlib import Path
from typing import Callable, Optional, Text, Tuple, Union

import numpy as np
Expand Down Expand Up @@ -88,6 +89,9 @@ class SpeechSeparation(SpeakerDiarizationMixin, Pipeline):
When loading private huggingface.co models, set `use_auth_token`
to True or to a string containing your hugginface.co authentication
token that can be obtained by running `huggingface-cli login`
cache_dir: Path or str, optional
Path to model cache directory. Defaults to content of PYANNOTE_CACHE
environment variable, or "~/.cache/torch/pyannote" when unset.
Usage
-----
Expand Down Expand Up @@ -132,11 +136,12 @@ def __init__(
segmentation_batch_size: int = 1,
der_variant: Optional[dict] = None,
use_auth_token: Union[Text, None] = None,
cache_dir: Union[Path, Text, None] = None,
):
super().__init__()

self.segmentation_model = segmentation
model: Model = get_model(segmentation, use_auth_token=use_auth_token)
model: Model = get_model(segmentation, use_auth_token=use_auth_token, cache_dir=cache_dir)

self.segmentation_step = segmentation_step

Expand Down Expand Up @@ -173,7 +178,7 @@ def __init__(

else:
self._embedding = PretrainedSpeakerEmbedding(
self.embedding, use_auth_token=use_auth_token
self.embedding, use_auth_token=use_auth_token, cache_dir=cache_dir
)
self._audio = Audio(sample_rate=self._embedding.sample_rate, mono="downmix")
metric = self._embedding.metric
Expand Down

0 comments on commit 0733e91

Please sign in to comment.