From a425ba599d93db96338dca86cd1e0e6c9fe34d2d Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Sat, 14 Dec 2024 00:28:01 +0100 Subject: [PATCH 1/2] feat: allow both Path and strings where possible and add type hints --- TTS/api.py | 2 +- TTS/config/__init__.py | 5 +- TTS/tts/utils/languages.py | 8 +- TTS/tts/utils/managers.py | 47 ++++--- TTS/tts/utils/speakers.py | 10 +- TTS/utils/audio/numpy_transforms.py | 9 +- TTS/utils/audio/processor.py | 9 +- TTS/utils/generic_utils.py | 7 +- TTS/utils/manage.py | 199 +++++++++++++++------------- TTS/utils/synthesizer.py | 58 ++++---- tests/zoo_tests/test_models.py | 13 +- 11 files changed, 204 insertions(+), 163 deletions(-) diff --git a/TTS/api.py b/TTS/api.py index 83189482cb..7720530823 100644 --- a/TTS/api.py +++ b/TTS/api.py @@ -157,7 +157,7 @@ def list_models() -> list[str]: def download_model_by_name( self, model_name: str, vocoder_name: Optional[str] = None - ) -> tuple[Optional[str], Optional[str], Optional[str]]: + ) -> tuple[Optional[Path], Optional[Path], Optional[Path]]: model_path, config_path, model_item = self.manager.download_model(model_name) if "fairseq" in model_name or (model_item is not None and isinstance(model_item["model_url"], list)): # return model directory if there are multiple files diff --git a/TTS/config/__init__.py b/TTS/config/__init__.py index 5103f200b0..e5f40c0296 100644 --- a/TTS/config/__init__.py +++ b/TTS/config/__init__.py @@ -1,7 +1,7 @@ import json import os import re -from typing import Dict +from typing import Any, Dict, Union import fsspec import yaml @@ -68,7 +68,7 @@ def _process_model_name(config_dict: Dict) -> str: return model_name -def load_config(config_path: str) -> Coqpit: +def load_config(config_path: Union[str, os.PathLike[Any]]) -> Coqpit: """Import `json` or `yaml` files as TTS configs. First, load the input file as a `dict` and check the model name to find the corresponding Config class. Then initialize the Config. @@ -81,6 +81,7 @@ def load_config(config_path: str) -> Coqpit: Returns: Coqpit: TTS config object. """ + config_path = str(config_path) config_dict = {} ext = os.path.splitext(config_path)[1] if ext in (".yml", ".yaml"): diff --git a/TTS/tts/utils/languages.py b/TTS/tts/utils/languages.py index f134daf58e..c72de2d4e6 100644 --- a/TTS/tts/utils/languages.py +++ b/TTS/tts/utils/languages.py @@ -1,5 +1,5 @@ import os -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union import fsspec import numpy as np @@ -27,8 +27,8 @@ class LanguageManager(BaseIDManager): def __init__( self, - language_ids_file_path: str = "", - config: Coqpit = None, + language_ids_file_path: Union[str, os.PathLike[Any]] = "", + config: Optional[Coqpit] = None, ): super().__init__(id_file_path=language_ids_file_path) @@ -76,7 +76,7 @@ def parse_ids_from_data(items: List, parse_key: str) -> Any: def set_ids_from_data(self, items: List, parse_key: str) -> Any: raise NotImplementedError - def save_ids_to_file(self, file_path: str) -> None: + def save_ids_to_file(self, file_path: Union[str, os.PathLike[Any]]) -> None: """Save language IDs to a json file. Args: diff --git a/TTS/tts/utils/managers.py b/TTS/tts/utils/managers.py index 6a2f7df67b..3a715dd75d 100644 --- a/TTS/tts/utils/managers.py +++ b/TTS/tts/utils/managers.py @@ -1,4 +1,5 @@ import json +import os import random from typing import Any, Dict, List, Tuple, Union @@ -12,7 +13,8 @@ from TTS.utils.generic_utils import is_pytorch_at_least_2_4 -def load_file(path: str): +def load_file(path: Union[str, os.PathLike[Any]]): + path = str(path) if path.endswith(".json"): with fsspec.open(path, "r") as f: return json.load(f) @@ -23,7 +25,8 @@ def load_file(path: str): raise ValueError("Unsupported file type") -def save_file(obj: Any, path: str): +def save_file(obj: Any, path: Union[str, os.PathLike[Any]]): + path = str(path) if path.endswith(".json"): with fsspec.open(path, "w") as f: json.dump(obj, f, indent=4) @@ -39,20 +42,20 @@ class BaseIDManager: It defines common `ID` manager specific functions. """ - def __init__(self, id_file_path: str = ""): + def __init__(self, id_file_path: Union[str, os.PathLike[Any]] = ""): self.name_to_id = {} if id_file_path: self.load_ids_from_file(id_file_path) @staticmethod - def _load_json(json_file_path: str) -> Dict: - with fsspec.open(json_file_path, "r") as f: + def _load_json(json_file_path: Union[str, os.PathLike[Any]]) -> Dict: + with fsspec.open(str(json_file_path), "r") as f: return json.load(f) @staticmethod - def _save_json(json_file_path: str, data: dict) -> None: - with fsspec.open(json_file_path, "w") as f: + def _save_json(json_file_path: Union[str, os.PathLike[Any]], data: dict) -> None: + with fsspec.open(str(json_file_path), "w") as f: json.dump(data, f, indent=4) def set_ids_from_data(self, items: List, parse_key: str) -> None: @@ -63,7 +66,7 @@ def set_ids_from_data(self, items: List, parse_key: str) -> None: """ self.name_to_id = self.parse_ids_from_data(items, parse_key=parse_key) - def load_ids_from_file(self, file_path: str) -> None: + def load_ids_from_file(self, file_path: Union[str, os.PathLike[Any]]) -> None: """Set IDs from a file. Args: @@ -71,7 +74,7 @@ def load_ids_from_file(self, file_path: str) -> None: """ self.name_to_id = load_file(file_path) - def save_ids_to_file(self, file_path: str) -> None: + def save_ids_to_file(self, file_path: Union[str, os.PathLike[Any]]) -> None: """Save IDs to a json file. Args: @@ -130,10 +133,10 @@ class EmbeddingManager(BaseIDManager): def __init__( self, - embedding_file_path: Union[str, List[str]] = "", - id_file_path: str = "", - encoder_model_path: str = "", - encoder_config_path: str = "", + embedding_file_path: Union[Union[str, os.PathLike[Any]], list[Union[str, os.PathLike[Any]]]] = "", + id_file_path: Union[str, os.PathLike[Any]] = "", + encoder_model_path: Union[str, os.PathLike[Any]] = "", + encoder_config_path: Union[str, os.PathLike[Any]] = "", use_cuda: bool = False, ): super().__init__(id_file_path=id_file_path) @@ -176,7 +179,7 @@ def embedding_names(self): """Get embedding names.""" return list(self.embeddings_by_names.keys()) - def save_embeddings_to_file(self, file_path: str) -> None: + def save_embeddings_to_file(self, file_path: Union[str, os.PathLike[Any]]) -> None: """Save embeddings to a json file. Args: @@ -185,7 +188,7 @@ def save_embeddings_to_file(self, file_path: str) -> None: save_file(self.embeddings, file_path) @staticmethod - def read_embeddings_from_file(file_path: str): + def read_embeddings_from_file(file_path: Union[str, os.PathLike[Any]]): """Load embeddings from a json file. Args: @@ -204,7 +207,7 @@ def read_embeddings_from_file(file_path: str): embeddings_by_names[x["name"]].append(x["embedding"]) return name_to_id, clip_ids, embeddings, embeddings_by_names - def load_embeddings_from_file(self, file_path: str) -> None: + def load_embeddings_from_file(self, file_path: Union[str, os.PathLike[Any]]) -> None: """Load embeddings from a json file. Args: @@ -214,7 +217,7 @@ def load_embeddings_from_file(self, file_path: str) -> None: file_path ) - def load_embeddings_from_list_of_files(self, file_paths: List[str]) -> None: + def load_embeddings_from_list_of_files(self, file_paths: list[Union[str, os.PathLike[Any]]]) -> None: """Load embeddings from a list of json files and don't allow duplicate keys. Args: @@ -313,7 +316,9 @@ def get_random_embedding(self) -> Any: def get_clips(self) -> List: return sorted(self.embeddings.keys()) - def init_encoder(self, model_path: str, config_path: str, use_cuda=False) -> None: + def init_encoder( + self, model_path: Union[str, os.PathLike[Any]], config_path: Union[str, os.PathLike[Any]], use_cuda=False + ) -> None: """Initialize a speaker encoder model. Args: @@ -325,11 +330,13 @@ def init_encoder(self, model_path: str, config_path: str, use_cuda=False) -> Non self.encoder_config = load_config(config_path) self.encoder = setup_encoder_model(self.encoder_config) self.encoder_criterion = self.encoder.load_checkpoint( - self.encoder_config, model_path, eval=True, use_cuda=use_cuda, cache=True + self.encoder_config, str(model_path), eval=True, use_cuda=use_cuda, cache=True ) self.encoder_ap = AudioProcessor(**self.encoder_config.audio) - def compute_embedding_from_clip(self, wav_file: Union[str, List[str]]) -> list: + def compute_embedding_from_clip( + self, wav_file: Union[Union[str, os.PathLike[Any]], List[Union[str, os.PathLike[Any]]]] + ) -> list: """Compute a embedding from a given audio file. Args: diff --git a/TTS/tts/utils/speakers.py b/TTS/tts/utils/speakers.py index 5229af81c5..89c56583f5 100644 --- a/TTS/tts/utils/speakers.py +++ b/TTS/tts/utils/speakers.py @@ -1,7 +1,7 @@ import json import logging import os -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Optional, Union import fsspec import numpy as np @@ -56,11 +56,11 @@ class SpeakerManager(EmbeddingManager): def __init__( self, - data_items: List[List[Any]] = None, + data_items: Optional[list[list[Any]]] = None, d_vectors_file_path: str = "", - speaker_id_file_path: str = "", - encoder_model_path: str = "", - encoder_config_path: str = "", + speaker_id_file_path: Union[str, os.PathLike[Any]] = "", + encoder_model_path: Union[str, os.PathLike[Any]] = "", + encoder_config_path: Union[str, os.PathLike[Any]] = "", use_cuda: bool = False, ): super().__init__( diff --git a/TTS/utils/audio/numpy_transforms.py b/TTS/utils/audio/numpy_transforms.py index 9c83009b0f..0cba7fc8a8 100644 --- a/TTS/utils/audio/numpy_transforms.py +++ b/TTS/utils/audio/numpy_transforms.py @@ -1,6 +1,7 @@ import logging +import os from io import BytesIO -from typing import Optional +from typing import Any, Optional, Union import librosa import numpy as np @@ -406,7 +407,9 @@ def rms_volume_norm(*, x: np.ndarray, db_level: float = -27.0, **kwargs) -> np.n return rms_norm(wav=x, db_level=db_level) -def load_wav(*, filename: str, sample_rate: Optional[int] = None, resample: bool = False, **kwargs) -> np.ndarray: +def load_wav( + *, filename: Union[str, os.PathLike[Any]], sample_rate: Optional[int] = None, resample: bool = False, **kwargs +) -> np.ndarray: """Read a wav file using Librosa and optionally resample, silence trim, volume normalize. Resampling slows down loading the file significantly. Therefore it is recommended to resample the file before. @@ -434,7 +437,7 @@ def load_wav(*, filename: str, sample_rate: Optional[int] = None, resample: bool def save_wav( *, wav: np.ndarray, - path: str, + path: Union[str, os.PathLike[Any]], sample_rate: int, pipe_out=None, do_rms_norm: bool = False, diff --git a/TTS/utils/audio/processor.py b/TTS/utils/audio/processor.py index 1d8fed8e39..bf07333aea 100644 --- a/TTS/utils/audio/processor.py +++ b/TTS/utils/audio/processor.py @@ -1,5 +1,6 @@ import logging -from typing import Optional +import os +from typing import Any, Optional, Union import librosa import numpy as np @@ -548,7 +549,7 @@ def sound_norm(x: np.ndarray) -> np.ndarray: return volume_norm(x=x) ### save and load ### - def load_wav(self, filename: str, sr: Optional[int] = None) -> np.ndarray: + def load_wav(self, filename: Union[str, os.PathLike[Any]], sr: Optional[int] = None) -> np.ndarray: """Read a wav file using Librosa and optionally resample, silence trim, volume normalize. Resampling slows down loading the file significantly. Therefore it is recommended to resample the file before. @@ -575,7 +576,9 @@ def load_wav(self, filename: str, sr: Optional[int] = None) -> np.ndarray: x = rms_volume_norm(x=x, db_level=self.db_level) return x - def save_wav(self, wav: np.ndarray, path: str, sr: Optional[int] = None, pipe_out=None) -> None: + def save_wav( + self, wav: np.ndarray, path: Union[str, os.PathLike[Any]], sr: Optional[int] = None, pipe_out=None + ) -> None: """Save a waveform to a file using Scipy. Args: diff --git a/TTS/utils/generic_utils.py b/TTS/utils/generic_utils.py index 087ae7d0e1..d7397f673d 100644 --- a/TTS/utils/generic_utils.py +++ b/TTS/utils/generic_utils.py @@ -4,7 +4,7 @@ import logging import re from pathlib import Path -from typing import Callable, Dict, Optional, TypeVar, Union +from typing import Any, Callable, Dict, Optional, TypeVar, Union import torch from packaging.version import Version @@ -133,3 +133,8 @@ def setup_logger( def is_pytorch_at_least_2_4() -> bool: """Check if the installed Pytorch version is 2.4 or higher.""" return Version(torch.__version__) >= Version("2.4") + + +def optional_to_str(x: Optional[Any]) -> str: + """Convert input to string, using empty string if input is None.""" + return "" if x is None else str(x) diff --git a/TTS/utils/manage.py b/TTS/utils/manage.py index 38fcfd60e9..b33243ffa9 100644 --- a/TTS/utils/manage.py +++ b/TTS/utils/manage.py @@ -6,17 +6,35 @@ import zipfile from pathlib import Path from shutil import copyfile, rmtree -from typing import Dict, Tuple +from typing import Any, Optional, TypedDict, Union import fsspec import requests from tqdm import tqdm from trainer.io import get_user_data_dir +from typing_extensions import Required from TTS.config import load_config, read_json_with_comments logger = logging.getLogger(__name__) + +class ModelItem(TypedDict, total=False): + model_name: Required[str] + model_type: Required[str] + description: str + license: str + author: str + contact: str + commit: Optional[str] + model_hash: str + tos_required: bool + default_vocoder: Optional[str] + model_url: Union[str, list[str]] + github_rls_url: Union[str, list[str]] + hf_url: list[str] + + LICENSE_URLS = { "cc by-nc-nd 4.0": "https://creativecommons.org/licenses/by-nc-nd/4.0/", "mpl": "https://www.mozilla.org/en-US/MPL/2.0/", @@ -40,19 +58,24 @@ class ModelManager(object): home path. Args: - models_file (str): path to .model.json file. Defaults to None. - output_prefix (str): prefix to `tts` to download models. Defaults to None + models_file (str or Path): path to .model.json file. Defaults to None. + output_prefix (str or Path): prefix to `tts` to download models. Defaults to None progress_bar (bool): print a progress bar when donwloading a file. Defaults to False. """ - def __init__(self, models_file=None, output_prefix=None, progress_bar=False): + def __init__( + self, + models_file: Optional[Union[str, os.PathLike[Any]]] = None, + output_prefix: Optional[Union[str, os.PathLike[Any]]] = None, + progress_bar: bool = False, + ) -> None: super().__init__() self.progress_bar = progress_bar if output_prefix is None: self.output_prefix = get_user_data_dir("tts") else: - self.output_prefix = os.path.join(output_prefix, "tts") - self.models_dict = None + self.output_prefix = Path(output_prefix) / "tts" + self.models_dict = {} if models_file is not None: self.read_models_file(models_file) else: @@ -60,7 +83,7 @@ def __init__(self, models_file=None, output_prefix=None, progress_bar=False): path = Path(__file__).parent / "../.models.json" self.read_models_file(path) - def read_models_file(self, file_path): + def read_models_file(self, file_path: Union[str, os.PathLike[Any]]) -> None: """Read .models.json as a dict Args: @@ -68,7 +91,7 @@ def read_models_file(self, file_path): """ self.models_dict = read_json_with_comments(file_path) - def _list_models(self, model_type, model_count=0): + def _list_models(self, model_type: str, model_count: int = 0) -> list[str]: logger.info("") logger.info("Name format: type/language/dataset/model") model_list = [] @@ -83,13 +106,13 @@ def _list_models(self, model_type, model_count=0): model_count += 1 return model_list - def _list_for_model_type(self, model_type): + def _list_for_model_type(self, model_type: str) -> list[str]: models_name_list = [] model_count = 1 models_name_list.extend(self._list_models(model_type, model_count)) return models_name_list - def list_models(self): + def list_models(self) -> list[str]: models_name_list = [] model_count = 1 for model_type in self.models_dict: @@ -97,7 +120,7 @@ def list_models(self): models_name_list.extend(model_list) return models_name_list - def log_model_details(self, model_type, lang, dataset, model): + def log_model_details(self, model_type: str, lang: str, dataset: str, model: str) -> None: logger.info("Model type: %s", model_type) logger.info("Language supported: %s", lang) logger.info("Dataset used: %s", dataset) @@ -112,7 +135,7 @@ def log_model_details(self, model_type, lang, dataset, model): self.models_dict[model_type][lang][dataset][model]["default_vocoder"], ) - def model_info_by_idx(self, model_query): + def model_info_by_idx(self, model_query: str) -> None: """Print the description of the model from .models.json file using model_query_idx Args: @@ -144,7 +167,7 @@ def model_info_by_idx(self, model_query): model_type, lang, dataset, model = model_name_list[model_query_idx - 1].split("/") self.log_model_details(model_type, lang, dataset, model) - def model_info_by_full_name(self, model_query_name): + def model_info_by_full_name(self, model_query_name: str) -> None: """Print the description of the model from .models.json file using model_full_name Args: @@ -165,35 +188,35 @@ def model_info_by_full_name(self, model_query_name): return self.log_model_details(model_type, lang, dataset, model) - def list_tts_models(self): + def list_tts_models(self) -> list[str]: """Print all `TTS` models and return a list of model names Format is `language/dataset/model` """ return self._list_for_model_type("tts_models") - def list_vocoder_models(self): + def list_vocoder_models(self) -> list[str]: """Print all the `vocoder` models and return a list of model names Format is `language/dataset/model` """ return self._list_for_model_type("vocoder_models") - def list_vc_models(self): + def list_vc_models(self) -> list[str]: """Print all the voice conversion models and return a list of model names Format is `language/dataset/model` """ return self._list_for_model_type("voice_conversion_models") - def list_langs(self): + def list_langs(self) -> None: """Print all the available languages""" logger.info("Name format: type/language") for model_type in self.models_dict: for lang in self.models_dict[model_type]: logger.info(" %s/%s", model_type, lang) - def list_datasets(self): + def list_datasets(self) -> None: """Print all the datasets""" logger.info("Name format: type/language/dataset") for model_type in self.models_dict: @@ -202,7 +225,7 @@ def list_datasets(self): logger.info(" %s/%s/%s", model_type, lang, dataset) @staticmethod - def print_model_license(model_item: Dict): + def print_model_license(model_item: ModelItem) -> None: """Print the license of a model Args: @@ -217,27 +240,27 @@ def print_model_license(model_item: Dict): else: logger.info("Model's license - No license information available") - def _download_github_model(self, model_item: Dict, output_path: str): + def _download_github_model(self, model_item: ModelItem, output_path: Path) -> None: if isinstance(model_item["github_rls_url"], list): self._download_model_files(model_item["github_rls_url"], output_path, self.progress_bar) else: self._download_zip_file(model_item["github_rls_url"], output_path, self.progress_bar) - def _download_hf_model(self, model_item: Dict, output_path: str): + def _download_hf_model(self, model_item: ModelItem, output_path: Path) -> None: if isinstance(model_item["hf_url"], list): self._download_model_files(model_item["hf_url"], output_path, self.progress_bar) else: self._download_zip_file(model_item["hf_url"], output_path, self.progress_bar) - def download_fairseq_model(self, model_name, output_path): + def download_fairseq_model(self, model_name: str, output_path: Path) -> None: URI_PREFIX = "https://dl.fbaipublicfiles.com/mms/tts/" _, lang, _, _ = model_name.split("/") model_download_uri = os.path.join(URI_PREFIX, f"{lang}.tar.gz") self._download_tar_file(model_download_uri, output_path, self.progress_bar) @staticmethod - def set_model_url(model_item: Dict): - model_item["model_url"] = None + def set_model_url(model_item: ModelItem) -> ModelItem: + model_item["model_url"] = "" if "github_rls_url" in model_item: model_item["model_url"] = model_item["github_rls_url"] elif "hf_url" in model_item: @@ -248,18 +271,18 @@ def set_model_url(model_item: Dict): model_item["model_url"] = "https://huggingface.co/coqui/" return model_item - def _set_model_item(self, model_name): + def _set_model_item(self, model_name: str) -> tuple[ModelItem, str, str, Optional[str]]: # fetch model info from the dict if "fairseq" in model_name: model_type, lang, dataset, model = model_name.split("/") - model_item = { + model_item: ModelItem = { + "model_name": model_name, "model_type": "tts_models", "license": "CC BY-NC 4.0", "default_vocoder": None, "author": "fairseq", "description": "this model is released by Meta under Fairseq repo. Visit https://github.com/facebookresearch/fairseq/tree/main/examples/mms for more info.", } - model_item["model_name"] = model_name elif "xtts" in model_name and len(model_name.split("/")) != 4: # loading xtts models with only model name (e.g. xtts_v2.0.2) # check model name has the version number with regex @@ -273,6 +296,8 @@ def _set_model_item(self, model_name): dataset = "multi-dataset" model = model_name model_item = { + "model_name": model_name, + "model_type": model_type, "default_vocoder": None, "license": "CPML", "contact": "info@coqui.ai", @@ -297,9 +322,9 @@ def _set_model_item(self, model_name): return model_item, model_full_name, model, md5hash @staticmethod - def ask_tos(model_full_path): + def ask_tos(model_full_path: Path) -> bool: """Ask the user to agree to the terms of service""" - tos_path = os.path.join(model_full_path, "tos_agreed.txt") + tos_path = model_full_path / "tos_agreed.txt" print(" > You must confirm the following:") print(' | > "I have purchased a commercial license from Coqui: licensing@coqui.ai"') print(' | > "Otherwise, I agree to the terms of the non-commercial CPML: https://coqui.ai/cpml" - [y/n]') @@ -311,7 +336,7 @@ def ask_tos(model_full_path): return False @staticmethod - def tos_agreed(model_item, model_full_path): + def tos_agreed(model_item: ModelItem, model_full_path: Path) -> bool: """Check if the user has agreed to the terms of service""" if "tos_required" in model_item and model_item["tos_required"]: tos_path = os.path.join(model_full_path, "tos_agreed.txt") @@ -320,12 +345,12 @@ def tos_agreed(model_item, model_full_path): return False return True - def create_dir_and_download_model(self, model_name, model_item, output_path): - os.makedirs(output_path, exist_ok=True) + def create_dir_and_download_model(self, model_name: str, model_item: ModelItem, output_path: Path) -> None: + output_path.mkdir(exist_ok=True, parents=True) # handle TOS if not self.tos_agreed(model_item, output_path): if not self.ask_tos(output_path): - os.rmdir(output_path) + output_path.rmdir() raise Exception(" [!] You must agree to the terms of service to use this model.") logger.info("Downloading model to %s", output_path) try: @@ -342,7 +367,7 @@ def create_dir_and_download_model(self, model_name, model_item, output_path): raise e self.print_model_license(model_item=model_item) - def check_if_configs_are_equal(self, model_name, model_item, output_path): + def check_if_configs_are_equal(self, model_name: str, model_item: ModelItem, output_path: Path) -> None: with fsspec.open(self._find_files(output_path)[1], "r", encoding="utf-8") as f: config_local = json.load(f) remote_url = None @@ -358,7 +383,7 @@ def check_if_configs_are_equal(self, model_name, model_item, output_path): logger.info("%s is already downloaded however it has been changed. Redownloading it...", model_name) self.create_dir_and_download_model(model_name, model_item, output_path) - def download_model(self, model_name): + def download_model(self, model_name: str) -> tuple[Path, Optional[Path], ModelItem]: """Download model files given the full model name. Model name is in the format 'type/language/dataset/model' @@ -374,12 +399,12 @@ def download_model(self, model_name): """ model_item, model_full_name, model, md5sum = self._set_model_item(model_name) # set the model specific output path - output_path = os.path.join(self.output_prefix, model_full_name) - if os.path.exists(output_path): + output_path = Path(self.output_prefix) / model_full_name + if output_path.is_dir(): if md5sum is not None: - md5sum_file = os.path.join(output_path, "hash.md5") - if os.path.isfile(md5sum_file): - with open(md5sum_file, mode="r") as f: + md5sum_file = output_path / "hash.md5" + if md5sum_file.is_file(): + with md5sum_file.open() as f: if not f.read() == md5sum: logger.info("%s has been updated, clearing model cache...", model_name) self.create_dir_and_download_model(model_name, model_item, output_path) @@ -407,12 +432,14 @@ def download_model(self, model_name): model not in ["tortoise-v2", "bark"] and "fairseq" not in model_name and "xtts" not in model_name ): # TODO:This is stupid but don't care for now. output_model_path, output_config_path = self._find_files(output_path) + else: + output_config_path = output_model_path / "config.json" # update paths in the config.json self._update_paths(output_path, output_config_path) return output_model_path, output_config_path, model_item @staticmethod - def _find_files(output_path: str) -> Tuple[str, str]: + def _find_files(output_path: Path) -> tuple[Path, Path]: """Find the model and config files in the output path Args: @@ -423,11 +450,11 @@ def _find_files(output_path: str) -> Tuple[str, str]: """ model_file = None config_file = None - for file_name in os.listdir(output_path): - if file_name in ["model_file.pth", "model_file.pth.tar", "model.pth", "checkpoint.pth"]: - model_file = os.path.join(output_path, file_name) - elif file_name == "config.json": - config_file = os.path.join(output_path, file_name) + for f in output_path.iterdir(): + if f.name in ["model_file.pth", "model_file.pth.tar", "model.pth", "checkpoint.pth"]: + model_file = f + elif f.name == "config.json": + config_file = f if model_file is None: raise ValueError(" [!] Model file not found in the output path") if config_file is None: @@ -435,7 +462,7 @@ def _find_files(output_path: str) -> Tuple[str, str]: return model_file, config_file @staticmethod - def _find_speaker_encoder(output_path: str) -> str: + def _find_speaker_encoder(output_path: Path) -> Optional[Path]: """Find the speaker encoder file in the output path Args: @@ -445,24 +472,24 @@ def _find_speaker_encoder(output_path: str) -> str: str: path to the speaker encoder file """ speaker_encoder_file = None - for file_name in os.listdir(output_path): - if file_name in ["model_se.pth", "model_se.pth.tar"]: - speaker_encoder_file = os.path.join(output_path, file_name) + for f in output_path.iterdir(): + if f.name in ["model_se.pth", "model_se.pth.tar"]: + speaker_encoder_file = f return speaker_encoder_file - def _update_paths(self, output_path: str, config_path: str) -> None: + def _update_paths(self, output_path: Path, config_path: Path) -> None: """Update paths for certain files in config.json after download. Args: output_path (str): local path the model is downloaded to. config_path (str): local config.json path. """ - output_stats_path = os.path.join(output_path, "scale_stats.npy") - output_d_vector_file_path = os.path.join(output_path, "speakers.json") - output_d_vector_file_pth_path = os.path.join(output_path, "speakers.pth") - output_speaker_ids_file_path = os.path.join(output_path, "speaker_ids.json") - output_speaker_ids_file_pth_path = os.path.join(output_path, "speaker_ids.pth") - speaker_encoder_config_path = os.path.join(output_path, "config_se.json") + output_stats_path = output_path / "scale_stats.npy" + output_d_vector_file_path = output_path / "speakers.json" + output_d_vector_file_pth_path = output_path / "speakers.pth" + output_speaker_ids_file_path = output_path / "speaker_ids.json" + output_speaker_ids_file_pth_path = output_path / "speaker_ids.pth" + speaker_encoder_config_path = output_path / "config_se.json" speaker_encoder_model_path = self._find_speaker_encoder(output_path) # update the scale_path.npy file path in the model config.json @@ -487,10 +514,10 @@ def _update_paths(self, output_path: str, config_path: str) -> None: self._update_path("model_args.speaker_encoder_config_path", speaker_encoder_config_path, config_path) @staticmethod - def _update_path(field_name, new_path, config_path): + def _update_path(field_name: str, new_path: Optional[Path], config_path: Path) -> None: """Update the path in the model config.json for the current environment after download""" - if new_path and os.path.exists(new_path): - config = load_config(config_path) + if new_path is not None and new_path.is_file(): + config = load_config(str(config_path)) field_names = field_name.split(".") if len(field_names) > 1: # field name points to a sub-level field @@ -515,7 +542,7 @@ def _update_path(field_name, new_path, config_path): config.save_json(config_path) @staticmethod - def _download_zip_file(file_url, output_folder, progress_bar): + def _download_zip_file(file_url: str, output_folder: Path, progress_bar: bool) -> None: """Download the github releases""" # download the file r = requests.get(file_url, stream=True) @@ -525,7 +552,7 @@ def _download_zip_file(file_url, output_folder, progress_bar): block_size = 1024 # 1 Kibibyte if progress_bar: ModelManager.tqdm_progress = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) - temp_zip_name = os.path.join(output_folder, file_url.split("/")[-1]) + temp_zip_name = output_folder / file_url.split("/")[-1] with open(temp_zip_name, "wb") as file: for data in r.iter_content(block_size): if progress_bar: @@ -533,24 +560,24 @@ def _download_zip_file(file_url, output_folder, progress_bar): file.write(data) with zipfile.ZipFile(temp_zip_name) as z: z.extractall(output_folder) - os.remove(temp_zip_name) # delete zip after extract + temp_zip_name.unlink() # delete zip after extract except zipfile.BadZipFile: logger.exception("Bad zip file - %s", file_url) raise zipfile.BadZipFile # pylint: disable=raise-missing-from # move the files to the outer path for file_path in z.namelist(): - src_path = os.path.join(output_folder, file_path) - if os.path.isfile(src_path): - dst_path = os.path.join(output_folder, os.path.basename(file_path)) + src_path = output_folder / file_path + if src_path.is_file(): + dst_path = output_folder / os.path.basename(file_path) if src_path != dst_path: copyfile(src_path, dst_path) # remove redundant (hidden or not) folders for file_path in z.namelist(): - if os.path.isdir(os.path.join(output_folder, file_path)): - rmtree(os.path.join(output_folder, file_path)) + if (output_folder / file_path).is_dir(): + rmtree(output_folder / file_path) @staticmethod - def _download_tar_file(file_url, output_folder, progress_bar): + def _download_tar_file(file_url: str, output_folder: Path, progress_bar: bool) -> None: """Download the github releases""" # download the file r = requests.get(file_url, stream=True) @@ -560,7 +587,7 @@ def _download_tar_file(file_url, output_folder, progress_bar): block_size = 1024 # 1 Kibibyte if progress_bar: ModelManager.tqdm_progress = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) - temp_tar_name = os.path.join(output_folder, file_url.split("/")[-1]) + temp_tar_name = output_folder / file_url.split("/")[-1] with open(temp_tar_name, "wb") as file: for data in r.iter_content(block_size): if progress_bar: @@ -569,43 +596,37 @@ def _download_tar_file(file_url, output_folder, progress_bar): with tarfile.open(temp_tar_name) as t: t.extractall(output_folder) tar_names = t.getnames() - os.remove(temp_tar_name) # delete tar after extract + temp_tar_name.unlink() # delete tar after extract except tarfile.ReadError: logger.exception("Bad tar file - %s", file_url) raise tarfile.ReadError # pylint: disable=raise-missing-from # move the files to the outer path - for file_path in os.listdir(os.path.join(output_folder, tar_names[0])): - src_path = os.path.join(output_folder, tar_names[0], file_path) - dst_path = os.path.join(output_folder, os.path.basename(file_path)) + for file_path in (output_folder / tar_names[0]).iterdir(): + src_path = file_path + dst_path = output_folder / file_path.name if src_path != dst_path: copyfile(src_path, dst_path) # remove the extracted folder - rmtree(os.path.join(output_folder, tar_names[0])) + rmtree(output_folder / tar_names[0]) @staticmethod - def _download_model_files(file_urls, output_folder, progress_bar): + def _download_model_files( + file_urls: list[str], output_folder: Union[str, os.PathLike[Any]], progress_bar: bool + ) -> None: """Download the github releases""" + output_folder = Path(output_folder) for file_url in file_urls: # download the file r = requests.get(file_url, stream=True) # extract the file - bease_filename = file_url.split("/")[-1] - temp_zip_name = os.path.join(output_folder, bease_filename) + base_filename = file_url.split("/")[-1] + file_path = output_folder / base_filename total_size_in_bytes = int(r.headers.get("content-length", 0)) block_size = 1024 # 1 Kibibyte - with open(temp_zip_name, "wb") as file: + with open(file_path, "wb") as f: if progress_bar: ModelManager.tqdm_progress = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) for data in r.iter_content(block_size): if progress_bar: ModelManager.tqdm_progress.update(len(data)) - file.write(data) - - @staticmethod - def _check_dict_key(my_dict, key): - if key in my_dict.keys() and my_dict[key] is not None: - if not isinstance(key, str): - return True - if isinstance(key, str) and len(my_dict[key]) > 0: - return True - return False + f.write(data) diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py index a9b9feffc1..52f5a86de5 100644 --- a/TTS/utils/synthesizer.py +++ b/TTS/utils/synthesizer.py @@ -2,7 +2,7 @@ import os import time from pathlib import Path -from typing import List +from typing import Any, List, Optional, Union import numpy as np import pysbd @@ -16,6 +16,7 @@ from TTS.tts.utils.synthesis import synthesis, transfer_voice, trim_silence from TTS.utils.audio import AudioProcessor from TTS.utils.audio.numpy_transforms import save_wav +from TTS.utils.generic_utils import optional_to_str from TTS.vc.configs.openvoice_config import OpenVoiceConfig from TTS.vc.models import setup_model as setup_vc_model from TTS.vc.models.openvoice import OpenVoice @@ -29,18 +30,18 @@ class Synthesizer(nn.Module): def __init__( self, *, - tts_checkpoint: str = "", - tts_config_path: str = "", - tts_speakers_file: str = "", - tts_languages_file: str = "", - vocoder_checkpoint: str = "", - vocoder_config: str = "", - encoder_checkpoint: str = "", - encoder_config: str = "", - vc_checkpoint: str = "", - vc_config: str = "", - model_dir: str = "", - voice_dir: str = None, + tts_checkpoint: Optional[Union[str, os.PathLike[Any]]] = None, + tts_config_path: Optional[Union[str, os.PathLike[Any]]] = None, + tts_speakers_file: Optional[Union[str, os.PathLike[Any]]] = None, + tts_languages_file: Optional[Union[str, os.PathLike[Any]]] = None, + vocoder_checkpoint: Optional[Union[str, os.PathLike[Any]]] = None, + vocoder_config: Optional[Union[str, os.PathLike[Any]]] = None, + encoder_checkpoint: Optional[Union[str, os.PathLike[Any]]] = None, + encoder_config: Optional[Union[str, os.PathLike[Any]]] = None, + vc_checkpoint: Optional[Union[str, os.PathLike[Any]]] = None, + vc_config: Optional[Union[str, os.PathLike[Any]]] = None, + model_dir: Optional[Union[str, os.PathLike[Any]]] = None, + voice_dir: Optional[Union[str, os.PathLike[Any]]] = None, use_cuda: bool = False, ) -> None: """General 🐸 TTS interface for inference. It takes a tts and a vocoder @@ -66,16 +67,17 @@ def __init__( use_cuda (bool, optional): enable/disable cuda. Defaults to False. """ super().__init__() - self.tts_checkpoint = tts_checkpoint - self.tts_config_path = tts_config_path - self.tts_speakers_file = tts_speakers_file - self.tts_languages_file = tts_languages_file - self.vocoder_checkpoint = vocoder_checkpoint - self.vocoder_config = vocoder_config - self.encoder_checkpoint = encoder_checkpoint - self.encoder_config = encoder_config - self.vc_checkpoint = vc_checkpoint - self.vc_config = vc_config + self.tts_checkpoint = optional_to_str(tts_checkpoint) + self.tts_config_path = optional_to_str(tts_config_path) + self.tts_speakers_file = optional_to_str(tts_speakers_file) + self.tts_languages_file = optional_to_str(tts_languages_file) + self.vocoder_checkpoint = optional_to_str(vocoder_checkpoint) + self.vocoder_config = optional_to_str(vocoder_config) + self.encoder_checkpoint = optional_to_str(encoder_checkpoint) + self.encoder_config = optional_to_str(encoder_config) + self.vc_checkpoint = optional_to_str(vc_checkpoint) + self.vc_config = optional_to_str(vc_config) + model_dir = optional_to_str(model_dir) self.use_cuda = use_cuda self.tts_model = None @@ -89,18 +91,18 @@ def __init__( self.d_vector_dim = 0 self.seg = self._get_segmenter("en") self.use_cuda = use_cuda - self.voice_dir = voice_dir + self.voice_dir = optional_to_str(voice_dir) if self.use_cuda: assert torch.cuda.is_available(), "CUDA is not availabe on this machine." if tts_checkpoint: - self._load_tts(tts_checkpoint, tts_config_path, use_cuda) + self._load_tts(self.tts_checkpoint, self.tts_config_path, use_cuda) if vocoder_checkpoint: - self._load_vocoder(vocoder_checkpoint, vocoder_config, use_cuda) + self._load_vocoder(self.vocoder_checkpoint, self.vocoder_config, use_cuda) - if vc_checkpoint and model_dir is None: - self._load_vc(vc_checkpoint, vc_config, use_cuda) + if vc_checkpoint and model_dir == "": + self._load_vc(self.vc_checkpoint, self.vc_config, use_cuda) if model_dir: if "fairseq" in model_dir: diff --git a/tests/zoo_tests/test_models.py b/tests/zoo_tests/test_models.py index f38880b51f..461b4fbe12 100644 --- a/tests/zoo_tests/test_models.py +++ b/tests/zoo_tests/test_models.py @@ -1,5 +1,4 @@ #!/usr/bin/env python3` -import glob import os import shutil @@ -30,22 +29,22 @@ def run_models(offset=0, step=1): print(f"\n > Run - {model_name}") model_path, _, _ = manager.download_model(model_name) if "tts_models" in model_name: - local_download_dir = os.path.dirname(model_path) + local_download_dir = model_path.parent # download and run the model - speaker_files = glob.glob(local_download_dir + "/speaker*") - language_files = glob.glob(local_download_dir + "/language*") + speaker_files = list(local_download_dir.glob("speaker*")) + language_files = list(local_download_dir.glob("language*")) speaker_arg = "" language_arg = "" if len(speaker_files) > 0: # multi-speaker model - if "speaker_ids" in speaker_files[0]: + if "speaker_ids" in speaker_files[0].stem: speaker_manager = SpeakerManager(speaker_id_file_path=speaker_files[0]) - elif "speakers" in speaker_files[0]: + elif "speakers" in speaker_files[0].stem: speaker_manager = SpeakerManager(d_vectors_file_path=speaker_files[0]) speakers = list(speaker_manager.name_to_id.keys()) if len(speakers) > 1: speaker_arg = f'--speaker_idx "{speakers[0]}"' - if len(language_files) > 0 and "language_ids" in language_files[0]: + if len(language_files) > 0 and "language_ids" in language_files[0].stem: # multi-lingual model language_manager = LanguageManager(language_ids_file_path=language_files[0]) languages = language_manager.language_names From 0df04cc259c7094f2b0f64841da634045b3f6894 Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Sat, 14 Dec 2024 15:52:13 +0100 Subject: [PATCH 2/2] docs: add notes about xtts fine-tuning --- TTS/bin/synthesize.py | 6 +++--- docs/source/faq.md | 8 +++++++- docs/source/training/finetuning.md | 3 +++ docs/source/training/index.md | 3 +++ docs/source/tutorial_for_nervous_beginners.md | 3 +++ 5 files changed, 19 insertions(+), 4 deletions(-) diff --git a/TTS/bin/synthesize.py b/TTS/bin/synthesize.py index 5fce93b7f4..47b442e266 100755 --- a/TTS/bin/synthesize.py +++ b/TTS/bin/synthesize.py @@ -34,7 +34,7 @@ tts --model_info_by_name vocoder_models/en/ljspeech/hifigan_v2 ``` -#### Single Speaker Models +#### Single speaker models - Run TTS with the default model (`tts_models/en/ljspeech/tacotron2-DDC`): @@ -102,7 +102,7 @@ --vocoder_config_path path/to/vocoder_config.json ``` -#### Multi-speaker Models +#### Multi-speaker models - List the available speakers and choose a `` among them: @@ -125,7 +125,7 @@ --speakers_file_path path/to/speaker.json --speaker_idx ``` -#### Voice Conversion Models +#### Voice conversion models ```sh tts --out_path output/path/speech.wav --model_name "//" \\ diff --git a/docs/source/faq.md b/docs/source/faq.md index 1dd5c1847b..a0eb5bbee4 100644 --- a/docs/source/faq.md +++ b/docs/source/faq.md @@ -16,13 +16,19 @@ We tried to collect common issues and questions we receive about 🐸TTS. It is - If you need faster models, consider SpeedySpeech, GlowTTS or AlignTTS. Keep in mind that SpeedySpeech requires a pre-trained Tacotron or Tacotron2 model to compute text-to-speech alignments. ## How can I train my own `tts` model? + +```{note} XTTS has separate fine-tuning scripts, see [here](models/xtts.md#training). +``` + 0. Check your dataset with notebooks in [dataset_analysis](https://github.com/idiap/coqui-ai-TTS/tree/main/notebooks/dataset_analysis) folder. Use [this notebook](https://github.com/idiap/coqui-ai-TTS/blob/main/notebooks/dataset_analysis/CheckSpectrograms.ipynb) to find the right audio processing parameters. A better set of parameters results in a better audio synthesis. 1. Write your own dataset `formatter` in `datasets/formatters.py` or [format](datasets/formatting_your_dataset) your dataset as one of the supported datasets, like LJSpeech. A `formatter` parses the metadata file and converts a list of training samples. 2. If you have a dataset with a different alphabet than English, you need to set your own character list in the ```config.json```. - - If you use phonemes for training and your language is supported [here](https://github.com/rhasspy/gruut#supported-languages), you don't need to set your character list. + - If you use phonemes for training and your language is supported by + [Espeak](https://github.com/espeak-ng/espeak-ng/blob/master/docs/languages.md) + or [Gruut](https://github.com/rhasspy/gruut#supported-languages), you don't need to set your character list. - You can use `TTS/bin/find_unique_chars.py` to get characters used in your dataset. 3. Write your own text cleaner in ```utils.text.cleaners```. It is not always necessary, except when you have a different alphabet or language-specific requirements. diff --git a/docs/source/training/finetuning.md b/docs/source/training/finetuning.md index 1fe54fbcde..fa2ed34a54 100644 --- a/docs/source/training/finetuning.md +++ b/docs/source/training/finetuning.md @@ -29,6 +29,9 @@ them and fine-tune it for your own dataset. This will help you in two main ways: ## Steps to fine-tune a 🐸 TTS model +```{note} XTTS has separate fine-tuning scripts, see [here](../models/xtts.md#training). +``` + 1. Setup your dataset. You need to format your target dataset in a certain way so that 🐸TTS data loader will be able to load it for the diff --git a/docs/source/training/index.md b/docs/source/training/index.md index bb76a705df..b09f9cadcb 100644 --- a/docs/source/training/index.md +++ b/docs/source/training/index.md @@ -8,3 +8,6 @@ The following pages show you how to train and fine-tune Coqui models: training_a_model finetuning ``` + +Also see the [XTTS page](../models/xtts.md#training) if you want to fine-tune +that model. diff --git a/docs/source/tutorial_for_nervous_beginners.md b/docs/source/tutorial_for_nervous_beginners.md index a8a64410c4..5e5eac0e0a 100644 --- a/docs/source/tutorial_for_nervous_beginners.md +++ b/docs/source/tutorial_for_nervous_beginners.md @@ -29,6 +29,9 @@ CLI, server or Python API. ## Training a `tts` Model +```{note} XTTS has separate fine-tuning scripts, see [here](models/xtts.md#training). +``` + A breakdown of a simple script that trains a GlowTTS model on the LJspeech dataset. For a more in-depth guide to training and fine-tuning also see [this page](training/index.md).