Skip to content

Commit

Permalink
Merge pull request #210 from idiap/manager
Browse files Browse the repository at this point in the history
feat: allow both Path and strings where possible and add type hints
  • Loading branch information
eginhard authored Dec 16, 2024
2 parents cd52907 + 0df04cc commit 5165e71
Show file tree
Hide file tree
Showing 16 changed files with 223 additions and 167 deletions.
2 changes: 1 addition & 1 deletion TTS/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def list_models() -> list[str]:

def download_model_by_name(
self, model_name: str, vocoder_name: Optional[str] = None
) -> tuple[Optional[str], Optional[str], Optional[str]]:
) -> tuple[Optional[Path], Optional[Path], Optional[Path]]:
model_path, config_path, model_item = self.manager.download_model(model_name)
if "fairseq" in model_name or (model_item is not None and isinstance(model_item["model_url"], list)):
# return model directory if there are multiple files
Expand Down
6 changes: 3 additions & 3 deletions TTS/bin/synthesize.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
tts --model_info_by_name vocoder_models/en/ljspeech/hifigan_v2
```
#### Single Speaker Models
#### Single speaker models
- Run TTS with the default model (`tts_models/en/ljspeech/tacotron2-DDC`):
Expand Down Expand Up @@ -102,7 +102,7 @@
--vocoder_config_path path/to/vocoder_config.json
```
#### Multi-speaker Models
#### Multi-speaker models
- List the available speakers and choose a `<speaker_id>` among them:
Expand All @@ -125,7 +125,7 @@
--speakers_file_path path/to/speaker.json --speaker_idx <speaker_id>
```
#### Voice Conversion Models
#### Voice conversion models
```sh
tts --out_path output/path/speech.wav --model_name "<language>/<dataset>/<model_name>" \\
Expand Down
5 changes: 3 additions & 2 deletions TTS/config/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import os
import re
from typing import Dict
from typing import Any, Dict, Union

import fsspec
import yaml
Expand Down Expand Up @@ -68,7 +68,7 @@ def _process_model_name(config_dict: Dict) -> str:
return model_name


def load_config(config_path: str) -> Coqpit:
def load_config(config_path: Union[str, os.PathLike[Any]]) -> Coqpit:
"""Import `json` or `yaml` files as TTS configs. First, load the input file as a `dict` and check the model name
to find the corresponding Config class. Then initialize the Config.
Expand All @@ -81,6 +81,7 @@ def load_config(config_path: str) -> Coqpit:
Returns:
Coqpit: TTS config object.
"""
config_path = str(config_path)
config_dict = {}
ext = os.path.splitext(config_path)[1]
if ext in (".yml", ".yaml"):
Expand Down
8 changes: 4 additions & 4 deletions TTS/tts/utils/languages.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union

import fsspec
import numpy as np
Expand Down Expand Up @@ -27,8 +27,8 @@ class LanguageManager(BaseIDManager):

def __init__(
self,
language_ids_file_path: str = "",
config: Coqpit = None,
language_ids_file_path: Union[str, os.PathLike[Any]] = "",
config: Optional[Coqpit] = None,
):
super().__init__(id_file_path=language_ids_file_path)

Expand Down Expand Up @@ -76,7 +76,7 @@ def parse_ids_from_data(items: List, parse_key: str) -> Any:
def set_ids_from_data(self, items: List, parse_key: str) -> Any:
    """Unimplemented in this class: every call raises ``NotImplementedError``.

    Args:
        items (List): Data items; unused because the method is not implemented.
        parse_key (str): Key to parse; unused because the method is not implemented.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError

def save_ids_to_file(self, file_path: str) -> None:
def save_ids_to_file(self, file_path: Union[str, os.PathLike[Any]]) -> None:
"""Save language IDs to a json file.
Args:
Expand Down
47 changes: 27 additions & 20 deletions TTS/tts/utils/managers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import os
import random
from typing import Any, Dict, List, Tuple, Union

Expand All @@ -12,7 +13,8 @@
from TTS.utils.generic_utils import is_pytorch_at_least_2_4


def load_file(path: str):
def load_file(path: Union[str, os.PathLike[Any]]):
path = str(path)
if path.endswith(".json"):
with fsspec.open(path, "r") as f:
return json.load(f)
Expand All @@ -23,7 +25,8 @@ def load_file(path: str):
raise ValueError("Unsupported file type")


def save_file(obj: Any, path: str):
def save_file(obj: Any, path: Union[str, os.PathLike[Any]]):
path = str(path)
if path.endswith(".json"):
with fsspec.open(path, "w") as f:
json.dump(obj, f, indent=4)
Expand All @@ -39,20 +42,20 @@ class BaseIDManager:
It defines common `ID` manager specific functions.
"""

def __init__(self, id_file_path: str = ""):
def __init__(self, id_file_path: Union[str, os.PathLike[Any]] = ""):
self.name_to_id = {}

if id_file_path:
self.load_ids_from_file(id_file_path)

@staticmethod
def _load_json(json_file_path: str) -> Dict:
with fsspec.open(json_file_path, "r") as f:
def _load_json(json_file_path: Union[str, os.PathLike[Any]]) -> Dict:
with fsspec.open(str(json_file_path), "r") as f:
return json.load(f)

@staticmethod
def _save_json(json_file_path: str, data: dict) -> None:
with fsspec.open(json_file_path, "w") as f:
def _save_json(json_file_path: Union[str, os.PathLike[Any]], data: dict) -> None:
with fsspec.open(str(json_file_path), "w") as f:
json.dump(data, f, indent=4)

def set_ids_from_data(self, items: List, parse_key: str) -> None:
Expand All @@ -63,15 +66,15 @@ def set_ids_from_data(self, items: List, parse_key: str) -> None:
"""
self.name_to_id = self.parse_ids_from_data(items, parse_key=parse_key)

def load_ids_from_file(self, file_path: str) -> None:
def load_ids_from_file(self, file_path: Union[str, os.PathLike[Any]]) -> None:
"""Set IDs from a file.
Args:
file_path (str): Path to the file.
"""
self.name_to_id = load_file(file_path)

def save_ids_to_file(self, file_path: str) -> None:
def save_ids_to_file(self, file_path: Union[str, os.PathLike[Any]]) -> None:
"""Save IDs to a json file.
Args:
Expand Down Expand Up @@ -130,10 +133,10 @@ class EmbeddingManager(BaseIDManager):

def __init__(
self,
embedding_file_path: Union[str, List[str]] = "",
id_file_path: str = "",
encoder_model_path: str = "",
encoder_config_path: str = "",
embedding_file_path: Union[Union[str, os.PathLike[Any]], list[Union[str, os.PathLike[Any]]]] = "",
id_file_path: Union[str, os.PathLike[Any]] = "",
encoder_model_path: Union[str, os.PathLike[Any]] = "",
encoder_config_path: Union[str, os.PathLike[Any]] = "",
use_cuda: bool = False,
):
super().__init__(id_file_path=id_file_path)
Expand Down Expand Up @@ -176,7 +179,7 @@ def embedding_names(self):
"""Get embedding names."""
return list(self.embeddings_by_names.keys())

def save_embeddings_to_file(self, file_path: str) -> None:
def save_embeddings_to_file(self, file_path: Union[str, os.PathLike[Any]]) -> None:
"""Save embeddings to a json file.
Args:
Expand All @@ -185,7 +188,7 @@ def save_embeddings_to_file(self, file_path: str) -> None:
save_file(self.embeddings, file_path)

@staticmethod
def read_embeddings_from_file(file_path: str):
def read_embeddings_from_file(file_path: Union[str, os.PathLike[Any]]):
"""Load embeddings from a json file.
Args:
Expand All @@ -204,7 +207,7 @@ def read_embeddings_from_file(file_path: str):
embeddings_by_names[x["name"]].append(x["embedding"])
return name_to_id, clip_ids, embeddings, embeddings_by_names

def load_embeddings_from_file(self, file_path: str) -> None:
def load_embeddings_from_file(self, file_path: Union[str, os.PathLike[Any]]) -> None:
"""Load embeddings from a json file.
Args:
Expand All @@ -214,7 +217,7 @@ def load_embeddings_from_file(self, file_path: str) -> None:
file_path
)

def load_embeddings_from_list_of_files(self, file_paths: List[str]) -> None:
def load_embeddings_from_list_of_files(self, file_paths: list[Union[str, os.PathLike[Any]]]) -> None:
"""Load embeddings from a list of json files and don't allow duplicate keys.
Args:
Expand Down Expand Up @@ -313,7 +316,9 @@ def get_random_embedding(self) -> Any:
def get_clips(self) -> List:
    """Return the keys of ``self.embeddings`` as a sorted list."""
    clip_keys = self.embeddings.keys()
    return sorted(clip_keys)

def init_encoder(self, model_path: str, config_path: str, use_cuda=False) -> None:
def init_encoder(
self, model_path: Union[str, os.PathLike[Any]], config_path: Union[str, os.PathLike[Any]], use_cuda=False
) -> None:
"""Initialize a speaker encoder model.
Args:
Expand All @@ -325,11 +330,13 @@ def init_encoder(self, model_path: str, config_path: str, use_cuda=False) -> Non
self.encoder_config = load_config(config_path)
self.encoder = setup_encoder_model(self.encoder_config)
self.encoder_criterion = self.encoder.load_checkpoint(
self.encoder_config, model_path, eval=True, use_cuda=use_cuda, cache=True
self.encoder_config, str(model_path), eval=True, use_cuda=use_cuda, cache=True
)
self.encoder_ap = AudioProcessor(**self.encoder_config.audio)

def compute_embedding_from_clip(self, wav_file: Union[str, List[str]]) -> list:
def compute_embedding_from_clip(
self, wav_file: Union[Union[str, os.PathLike[Any]], List[Union[str, os.PathLike[Any]]]]
) -> list:
"""Compute a embedding from a given audio file.
Args:
Expand Down
10 changes: 5 additions & 5 deletions TTS/tts/utils/speakers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import logging
import os
from typing import Any, Dict, List, Union
from typing import Any, Dict, List, Optional, Union

import fsspec
import numpy as np
Expand Down Expand Up @@ -56,11 +56,11 @@ class SpeakerManager(EmbeddingManager):

def __init__(
self,
data_items: List[List[Any]] = None,
data_items: Optional[list[list[Any]]] = None,
d_vectors_file_path: str = "",
speaker_id_file_path: str = "",
encoder_model_path: str = "",
encoder_config_path: str = "",
speaker_id_file_path: Union[str, os.PathLike[Any]] = "",
encoder_model_path: Union[str, os.PathLike[Any]] = "",
encoder_config_path: Union[str, os.PathLike[Any]] = "",
use_cuda: bool = False,
):
super().__init__(
Expand Down
9 changes: 6 additions & 3 deletions TTS/utils/audio/numpy_transforms.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
import os
from io import BytesIO
from typing import Optional
from typing import Any, Optional, Union

import librosa
import numpy as np
Expand Down Expand Up @@ -406,7 +407,9 @@ def rms_volume_norm(*, x: np.ndarray, db_level: float = -27.0, **kwargs) -> np.n
return rms_norm(wav=x, db_level=db_level)


def load_wav(*, filename: str, sample_rate: Optional[int] = None, resample: bool = False, **kwargs) -> np.ndarray:
def load_wav(
*, filename: Union[str, os.PathLike[Any]], sample_rate: Optional[int] = None, resample: bool = False, **kwargs
) -> np.ndarray:
"""Read a wav file using Librosa and optionally resample, silence trim, volume normalize.
Resampling slows down loading the file significantly. Therefore it is recommended to resample the file before.
Expand Down Expand Up @@ -434,7 +437,7 @@ def load_wav(*, filename: str, sample_rate: Optional[int] = None, resample: bool
def save_wav(
*,
wav: np.ndarray,
path: str,
path: Union[str, os.PathLike[Any]],
sample_rate: int,
pipe_out=None,
do_rms_norm: bool = False,
Expand Down
9 changes: 6 additions & 3 deletions TTS/utils/audio/processor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
from typing import Optional
import os
from typing import Any, Optional, Union

import librosa
import numpy as np
Expand Down Expand Up @@ -548,7 +549,7 @@ def sound_norm(x: np.ndarray) -> np.ndarray:
return volume_norm(x=x)

### save and load ###
def load_wav(self, filename: str, sr: Optional[int] = None) -> np.ndarray:
def load_wav(self, filename: Union[str, os.PathLike[Any]], sr: Optional[int] = None) -> np.ndarray:
"""Read a wav file using Librosa and optionally resample, silence trim, volume normalize.
Resampling slows down loading the file significantly. Therefore it is recommended to resample the file before.
Expand All @@ -575,7 +576,9 @@ def load_wav(self, filename: str, sr: Optional[int] = None) -> np.ndarray:
x = rms_volume_norm(x=x, db_level=self.db_level)
return x

def save_wav(self, wav: np.ndarray, path: str, sr: Optional[int] = None, pipe_out=None) -> None:
def save_wav(
self, wav: np.ndarray, path: Union[str, os.PathLike[Any]], sr: Optional[int] = None, pipe_out=None
) -> None:
"""Save a waveform to a file using Scipy.
Args:
Expand Down
7 changes: 6 additions & 1 deletion TTS/utils/generic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging
import re
from pathlib import Path
from typing import Callable, Dict, Optional, TypeVar, Union
from typing import Any, Callable, Dict, Optional, TypeVar, Union

import torch
from packaging.version import Version
Expand Down Expand Up @@ -133,3 +133,8 @@ def setup_logger(
def is_pytorch_at_least_2_4() -> bool:
    """Report whether the installed PyTorch version is 2.4 or newer.

    Returns:
        bool: ``True`` when ``torch.__version__`` parses to a version that
        compares greater than or equal to ``2.4``.
    """
    installed = Version(torch.__version__)
    required = Version("2.4")
    return installed >= required


def optional_to_str(x: Optional[Any]) -> str:
    """Return ``str(x)``, or an empty string when ``x`` is ``None``.

    Args:
        x: Any value, or ``None``.

    Returns:
        str: ``""`` for ``None``; otherwise the result of ``str(x)``.
    """
    if x is not None:
        return str(x)
    return ""
Loading

0 comments on commit 5165e71

Please sign in to comment.