diff --git a/TTS/__init__.py b/TTS/__init__.py
index 64c7369bc0..8e93c9b5db 100644
--- a/TTS/__init__.py
+++ b/TTS/__init__.py
@@ -1,29 +1,33 @@
-import _codecs
 import importlib.metadata
-from collections import defaultdict
-import numpy as np
-import torch
-
-from TTS.config.shared_configs import BaseDatasetConfig
-from TTS.tts.configs.xtts_config import XttsConfig
-from TTS.tts.models.xtts import XttsArgs, XttsAudioConfig
-from TTS.utils.radam import RAdam
+
+from TTS.utils.generic_utils import is_pytorch_at_least_2_4
 
 __version__ = importlib.metadata.version("coqui-tts")
 
-torch.serialization.add_safe_globals([dict, defaultdict, RAdam])
+if is_pytorch_at_least_2_4():
+    import _codecs
+    from collections import defaultdict
+
+    import numpy as np
+    import torch
+
+    from TTS.config.shared_configs import BaseDatasetConfig
+    from TTS.tts.configs.xtts_config import XttsConfig
+    from TTS.tts.models.xtts import XttsArgs, XttsAudioConfig
+    from TTS.utils.radam import RAdam
+
+    torch.serialization.add_safe_globals([dict, defaultdict, RAdam])
 
-# Bark
-torch.serialization.add_safe_globals(
-    [
-        np.core.multiarray.scalar,
-        np.dtype,
-        np.dtypes.Float64DType,
-        _codecs.encode,  # TODO: safe by default from Pytorch 2.5
-    ]
-)
+    # Bark
+    torch.serialization.add_safe_globals(
+        [
+            np.core.multiarray.scalar,
+            np.dtype,
+            np.dtypes.Float64DType,
+            _codecs.encode,  # TODO: safe by default from Pytorch 2.5
+        ]
+    )
 
-# XTTS
-torch.serialization.add_safe_globals([BaseDatasetConfig, XttsConfig, XttsAudioConfig, XttsArgs])
+    # XTTS
+    torch.serialization.add_safe_globals([BaseDatasetConfig, XttsConfig, XttsAudioConfig, XttsArgs])
diff --git a/TTS/tts/layers/bark/load_model.py b/TTS/tts/layers/bark/load_model.py
index 7785aab845..72eca30ac6 100644
--- a/TTS/tts/layers/bark/load_model.py
+++ b/TTS/tts/layers/bark/load_model.py
@@ -10,6 +10,7 @@
 
 from TTS.tts.layers.bark.model import GPT, GPTConfig
 from TTS.tts.layers.bark.model_fine import FineGPT, FineGPTConfig
+from TTS.utils.generic_utils import is_pytorch_at_least_2_4
 
 if (
     torch.cuda.is_available()
@@ -118,7 +119,7 @@ def load_model(ckpt_path, device, config, model_type="text"):
             logger.info(f"{model_type} model not found, downloading...")
         _download(config.REMOTE_MODEL_PATHS[model_type]["path"], ckpt_path, config.CACHE_DIR)
 
-    checkpoint = torch.load(ckpt_path, map_location=device, weights_only=True)
+    checkpoint = torch.load(ckpt_path, map_location=device, weights_only=is_pytorch_at_least_2_4())
     # this is a hack
     model_args = checkpoint["model_args"]
     if "input_vocab_size" not in model_args:
diff --git a/TTS/tts/layers/tortoise/arch_utils.py b/TTS/tts/layers/tortoise/arch_utils.py
index f4dbcc8054..52c2526695 100644
--- a/TTS/tts/layers/tortoise/arch_utils.py
+++ b/TTS/tts/layers/tortoise/arch_utils.py
@@ -9,6 +9,7 @@
 from transformers import LogitsWarper
 
 from TTS.tts.layers.tortoise.xtransformers import ContinuousTransformerWrapper, RelativePositionBias
+from TTS.utils.generic_utils import is_pytorch_at_least_2_4
 
 
 def zero_module(module):
@@ -332,7 +333,7 @@ def __init__(
         self.mel_norm_file = mel_norm_file
         if self.mel_norm_file is not None:
             with fsspec.open(self.mel_norm_file) as f:
-                self.mel_norms = torch.load(f, weights_only=True)
+                self.mel_norms = torch.load(f, weights_only=is_pytorch_at_least_2_4())
         else:
             self.mel_norms = None
diff --git a/TTS/tts/layers/tortoise/audio_utils.py b/TTS/tts/layers/tortoise/audio_utils.py
index 94c2bae6fa..4f299a8fd9 100644
--- a/TTS/tts/layers/tortoise/audio_utils.py
+++ b/TTS/tts/layers/tortoise/audio_utils.py
@@ -10,6 +10,7 @@
 from scipy.io.wavfile import read
 
 from TTS.utils.audio.torch_transforms import TorchSTFT
+from TTS.utils.generic_utils import is_pytorch_at_least_2_4
 
 logger = logging.getLogger(__name__)
 
@@ -124,7 +125,7 @@ def load_voice(voice: str, extra_voice_dirs: List[str] = []):
     voices = get_voices(extra_voice_dirs)
     paths = voices[voice]
     if len(paths) == 1 and paths[0].endswith(".pth"):
-        return None, torch.load(paths[0], weights_only=True)
+        return None, torch.load(paths[0], weights_only=is_pytorch_at_least_2_4())
     else:
         conds = []
         for cond_path in paths:
diff --git a/TTS/tts/layers/xtts/dvae.py b/TTS/tts/layers/xtts/dvae.py
index 58f91785a1..73970fb0bf 100644
--- a/TTS/tts/layers/xtts/dvae.py
+++ b/TTS/tts/layers/xtts/dvae.py
@@ -9,6 +9,8 @@
 import torchaudio
 from einops import rearrange
 
+from TTS.utils.generic_utils import is_pytorch_at_least_2_4
+
 logger = logging.getLogger(__name__)
 
 
@@ -46,7 +48,7 @@ def dvae_wav_to_mel(
     mel = mel_stft(wav)
     mel = torch.log(torch.clamp(mel, min=1e-5))
     if mel_norms is None:
-        mel_norms = torch.load(mel_norms_file, map_location=device, weights_only=True)
+        mel_norms = torch.load(mel_norms_file, map_location=device, weights_only=is_pytorch_at_least_2_4())
     mel = mel / mel_norms.unsqueeze(0).unsqueeze(-1)
     return mel
 
diff --git a/TTS/tts/layers/xtts/hifigan_decoder.py b/TTS/tts/layers/xtts/hifigan_decoder.py
index 09bd06dfde..5ef0030b8b 100644
--- a/TTS/tts/layers/xtts/hifigan_decoder.py
+++ b/TTS/tts/layers/xtts/hifigan_decoder.py
@@ -9,6 +9,7 @@
 from torch.nn.utils.parametrize import remove_parametrizations
 from trainer.io import load_fsspec
 
+from TTS.utils.generic_utils import is_pytorch_at_least_2_4
 from TTS.vocoder.models.hifigan_generator import get_padding
 
 logger = logging.getLogger(__name__)
@@ -328,7 +329,7 @@ def remove_weight_norm(self):
     def load_checkpoint(
         self, config, checkpoint_path, eval=False, cache=False
     ):  # pylint: disable=unused-argument, redefined-builtin
-        state = torch.load(checkpoint_path, map_location=torch.device("cpu"), weights_only=True)
+        state = torch.load(checkpoint_path, map_location=torch.device("cpu"), weights_only=is_pytorch_at_least_2_4())
         self.load_state_dict(state["model"])
         if eval:
             self.eval()
diff --git a/TTS/tts/layers/xtts/trainer/gpt_trainer.py b/TTS/tts/layers/xtts/trainer/gpt_trainer.py
index f1aa6f8cd0..9d9edd5758 100644
--- a/TTS/tts/layers/xtts/trainer/gpt_trainer.py
+++ b/TTS/tts/layers/xtts/trainer/gpt_trainer.py
@@ -19,6 +19,7 @@
 from TTS.tts.layers.xtts.trainer.dataset import XTTSDataset
 from TTS.tts.models.base_tts import BaseTTS
 from TTS.tts.models.xtts import Xtts, XttsArgs, XttsAudioConfig
+from TTS.utils.generic_utils import is_pytorch_at_least_2_4
 
 logger = logging.getLogger(__name__)
 
@@ -91,7 +92,9 @@ def __init__(self, config: Coqpit):
 
         # load GPT if available
         if self.args.gpt_checkpoint:
-            gpt_checkpoint = torch.load(self.args.gpt_checkpoint, map_location=torch.device("cpu"), weights_only=True)
+            gpt_checkpoint = torch.load(
+                self.args.gpt_checkpoint, map_location=torch.device("cpu"), weights_only=is_pytorch_at_least_2_4()
+            )
             # deal with coqui Trainer exported model
             if "model" in gpt_checkpoint.keys() and "config" in gpt_checkpoint.keys():
                 logger.info("Coqui Trainer checkpoint detected! Converting it!")
@@ -184,7 +187,9 @@ def __init__(self, config: Coqpit):
             self.dvae.eval()
 
         if self.args.dvae_checkpoint:
-            dvae_checkpoint = torch.load(self.args.dvae_checkpoint, map_location=torch.device("cpu"), weights_only=True)
+            dvae_checkpoint = torch.load(
+                self.args.dvae_checkpoint, map_location=torch.device("cpu"), weights_only=is_pytorch_at_least_2_4()
+            )
             self.dvae.load_state_dict(dvae_checkpoint, strict=False)
             logger.info("DVAE weights restored from: %s", self.args.dvae_checkpoint)
         else:
diff --git a/TTS/tts/layers/xtts/xtts_manager.py b/TTS/tts/layers/xtts/xtts_manager.py
index 5a3c47aead..8156b35f0d 100644
--- a/TTS/tts/layers/xtts/xtts_manager.py
+++ b/TTS/tts/layers/xtts/xtts_manager.py
@@ -1,9 +1,11 @@
 import torch
 
+from TTS.utils.generic_utils import is_pytorch_at_least_2_4
+
 
 class SpeakerManager:
     def __init__(self, speaker_file_path=None):
-        self.speakers = torch.load(speaker_file_path, weights_only=True)
+        self.speakers = torch.load(speaker_file_path, weights_only=is_pytorch_at_least_2_4())
 
     @property
     def name_to_id(self):
diff --git a/TTS/tts/models/neuralhmm_tts.py b/TTS/tts/models/neuralhmm_tts.py
index 49c48c2bd4..de5401aac7 100644
--- a/TTS/tts/models/neuralhmm_tts.py
+++ b/TTS/tts/models/neuralhmm_tts.py
@@ -18,7 +18,7 @@
 from TTS.tts.utils.speakers import SpeakerManager
 from TTS.tts.utils.text.tokenizer import TTSTokenizer
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
-from TTS.utils.generic_utils import format_aux_input
+from TTS.utils.generic_utils import format_aux_input, is_pytorch_at_least_2_4
 
 logger = logging.getLogger(__name__)
 
@@ -107,7 +107,7 @@ def update_mean_std(self, statistics_dict: Dict):
 
     def preprocess_batch(self, text, text_len, mels, mel_len):
         if self.mean.item() == 0 or self.std.item() == 1:
-            statistics_dict = torch.load(self.mel_statistics_parameter_path, weights_only=True)
+            statistics_dict = torch.load(self.mel_statistics_parameter_path, weights_only=is_pytorch_at_least_2_4())
             self.update_mean_std(statistics_dict)
 
         mels = self.normalize(mels)
@@ -292,7 +292,9 @@ def on_init_start(self, trainer):
                 "Data parameters found for: %s. Loading mel normalization parameters...",
                 trainer.config.mel_statistics_parameter_path,
             )
-            statistics = torch.load(trainer.config.mel_statistics_parameter_path, weights_only=True)
+            statistics = torch.load(
+                trainer.config.mel_statistics_parameter_path, weights_only=is_pytorch_at_least_2_4()
+            )
             data_mean, data_std, init_transition_prob = (
                 statistics["mean"],
                 statistics["std"],
diff --git a/TTS/tts/models/overflow.py b/TTS/tts/models/overflow.py
index 4c0f341be3..b72f4877cf 100644
--- a/TTS/tts/models/overflow.py
+++ b/TTS/tts/models/overflow.py
@@ -19,7 +19,7 @@
 from TTS.tts.utils.speakers import SpeakerManager
 from TTS.tts.utils.text.tokenizer import TTSTokenizer
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
-from TTS.utils.generic_utils import format_aux_input
+from TTS.utils.generic_utils import format_aux_input, is_pytorch_at_least_2_4
 
 logger = logging.getLogger(__name__)
 
@@ -120,7 +120,7 @@ def update_mean_std(self, statistics_dict: Dict):
 
     def preprocess_batch(self, text, text_len, mels, mel_len):
         if self.mean.item() == 0 or self.std.item() == 1:
-            statistics_dict = torch.load(self.mel_statistics_parameter_path, weights_only=True)
+            statistics_dict = torch.load(self.mel_statistics_parameter_path, weights_only=is_pytorch_at_least_2_4())
             self.update_mean_std(statistics_dict)
 
         mels = self.normalize(mels)
@@ -308,7 +308,9 @@ def on_init_start(self, trainer):
                 "Data parameters found for: %s. Loading mel normalization parameters...",
                 trainer.config.mel_statistics_parameter_path,
             )
-            statistics = torch.load(trainer.config.mel_statistics_parameter_path, weights_only=True)
+            statistics = torch.load(
+                trainer.config.mel_statistics_parameter_path, weights_only=is_pytorch_at_least_2_4()
+            )
             data_mean, data_std, init_transition_prob = (
                 statistics["mean"],
                 statistics["std"],
diff --git a/TTS/tts/models/tortoise.py b/TTS/tts/models/tortoise.py
index 98e79d0cf1..01629b5d2a 100644
--- a/TTS/tts/models/tortoise.py
+++ b/TTS/tts/models/tortoise.py
@@ -23,6 +23,7 @@
 from TTS.tts.layers.tortoise.vocoder import VocConf, VocType
 from TTS.tts.layers.tortoise.wav2vec_alignment import Wav2VecAlignment
 from TTS.tts.models.base_tts import BaseTTS
+from TTS.utils.generic_utils import is_pytorch_at_least_2_4
 
 logger = logging.getLogger(__name__)
 
@@ -171,7 +172,11 @@ def classify_audio_clip(clip, model_dir):
         distribute_zero_label=False,
     )
     classifier.load_state_dict(
-        torch.load(os.path.join(model_dir, "classifier.pth"), map_location=torch.device("cpu"), weights_only=True)
+        torch.load(
+            os.path.join(model_dir, "classifier.pth"),
+            map_location=torch.device("cpu"),
+            weights_only=is_pytorch_at_least_2_4(),
+        )
     )
     clip = clip.cpu().unsqueeze(0)
     results = F.softmax(classifier(clip), dim=-1)
@@ -490,7 +495,7 @@ def get_random_conditioning_latents(self):
                 torch.load(
                     os.path.join(self.models_dir, "rlg_auto.pth"),
                     map_location=torch.device("cpu"),
-                    weights_only=True,
+                    weights_only=is_pytorch_at_least_2_4(),
                 )
             )
             self.rlg_diffusion = RandomLatentConverter(2048).eval()
@@ -498,7 +503,7 @@ def get_random_conditioning_latents(self):
                 torch.load(
                     os.path.join(self.models_dir, "rlg_diffuser.pth"),
                     map_location=torch.device("cpu"),
-                    weights_only=True,
+                    weights_only=is_pytorch_at_least_2_4(),
                 )
             )
         with torch.no_grad():
@@ -885,17 +890,17 @@ def load_checkpoint(
 
         if os.path.exists(ar_path):
             # remove keys from the checkpoint that are not in the model
-            checkpoint = torch.load(ar_path, map_location=torch.device("cpu"), weights_only=True)
+            checkpoint = torch.load(ar_path, map_location=torch.device("cpu"), weights_only=is_pytorch_at_least_2_4())
 
             # strict set False
             # due to removed `bias` and `masked_bias` changes in Transformers
             self.autoregressive.load_state_dict(checkpoint, strict=False)
 
         if os.path.exists(diff_path):
-            self.diffusion.load_state_dict(torch.load(diff_path, weights_only=True), strict=strict)
+            self.diffusion.load_state_dict(torch.load(diff_path, weights_only=is_pytorch_at_least_2_4()), strict=strict)
 
         if os.path.exists(clvp_path):
-            self.clvp.load_state_dict(torch.load(clvp_path, weights_only=True), strict=strict)
+            self.clvp.load_state_dict(torch.load(clvp_path, weights_only=is_pytorch_at_least_2_4()), strict=strict)
 
         if os.path.exists(vocoder_checkpoint_path):
             self.vocoder.load_state_dict(
@@ -903,7 +908,7 @@ def load_checkpoint(
                     torch.load(
                         vocoder_checkpoint_path,
                         map_location=torch.device("cpu"),
-                        weights_only=True,
+                        weights_only=is_pytorch_at_least_2_4(),
                     )
                 )
             )
diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py
index 0b7652e450..ef2cebee3c 100644
--- a/TTS/tts/models/xtts.py
+++ b/TTS/tts/models/xtts.py
@@ -16,6 +16,7 @@
 from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer, split_sentence
 from TTS.tts.layers.xtts.xtts_manager import LanguageManager, SpeakerManager
 from TTS.tts.models.base_tts import BaseTTS
+from TTS.utils.generic_utils import is_pytorch_at_least_2_4
 
 logger = logging.getLogger(__name__)
 
@@ -65,7 +66,7 @@ def wav_to_mel_cloning(
     mel = mel_stft(wav)
     mel = torch.log(torch.clamp(mel, min=1e-5))
     if mel_norms is None:
-        mel_norms = torch.load(mel_norms_file, map_location=device, weights_only=True)
+        mel_norms = torch.load(mel_norms_file, map_location=device, weights_only=is_pytorch_at_least_2_4())
     mel = mel / mel_norms.unsqueeze(0).unsqueeze(-1)
     return mel
 
diff --git a/TTS/tts/utils/fairseq.py b/TTS/tts/utils/fairseq.py
index 6eb1905d96..20907a0532 100644
--- a/TTS/tts/utils/fairseq.py
+++ b/TTS/tts/utils/fairseq.py
@@ -1,8 +1,10 @@
 import torch
 
+from TTS.utils.generic_utils import is_pytorch_at_least_2_4
+
 
 def rehash_fairseq_vits_checkpoint(checkpoint_file):
-    chk = torch.load(checkpoint_file, map_location=torch.device("cpu"), weights_only=True)["model"]
+    chk = torch.load(checkpoint_file, map_location=torch.device("cpu"), weights_only=is_pytorch_at_least_2_4())["model"]
     new_chk = {}
     for k, v in chk.items():
         if "enc_p." in k:
diff --git a/TTS/tts/utils/managers.py b/TTS/tts/utils/managers.py
index 6f72581c08..6a2f7df67b 100644
--- a/TTS/tts/utils/managers.py
+++ b/TTS/tts/utils/managers.py
@@ -9,6 +9,7 @@
 from TTS.config import load_config
 from TTS.encoder.utils.generic_utils import setup_encoder_model
 from TTS.utils.audio import AudioProcessor
+from TTS.utils.generic_utils import is_pytorch_at_least_2_4
 
 
 def load_file(path: str):
@@ -17,7 +18,7 @@ def load_file(path: str):
             return json.load(f)
     elif path.endswith(".pth"):
         with fsspec.open(path, "rb") as f:
-            return torch.load(f, map_location="cpu", weights_only=True)
+            return torch.load(f, map_location="cpu", weights_only=is_pytorch_at_least_2_4())
     else:
         raise ValueError("Unsupported file type")
 
diff --git a/TTS/utils/generic_utils.py b/TTS/utils/generic_utils.py
index 91f8844262..3ee285232f 100644
--- a/TTS/utils/generic_utils.py
+++ b/TTS/utils/generic_utils.py
@@ -6,6 +6,9 @@
 from pathlib import Path
 from typing import Dict, Optional
 
+import torch
+from packaging.version import Version
+
 logger = logging.getLogger(__name__)
 
 
@@ -131,3 +134,8 @@ def setup_logger(
     sh = logging.StreamHandler()
     sh.setFormatter(formatter)
     lg.addHandler(sh)
+
+
+def is_pytorch_at_least_2_4() -> bool:
+    """Check if the installed Pytorch version is 2.4 or higher."""
+    return Version(torch.__version__) >= Version("2.4")
diff --git a/TTS/vc/modules/freevc/wavlm/__init__.py b/TTS/vc/modules/freevc/wavlm/__init__.py
index 528fade772..4046e137f5 100644
--- a/TTS/vc/modules/freevc/wavlm/__init__.py
+++ b/TTS/vc/modules/freevc/wavlm/__init__.py
@@ -5,6 +5,7 @@
 import torch
 from trainer.io import get_user_data_dir
 
+from TTS.utils.generic_utils import is_pytorch_at_least_2_4
 from TTS.vc.modules.freevc.wavlm.wavlm import WavLM, WavLMConfig
 
 logger = logging.getLogger(__name__)
@@ -26,7 +27,7 @@ def get_wavlm(device="cpu"):
         logger.info("Downloading WavLM model to %s ...", output_path)
         urllib.request.urlretrieve(model_uri, output_path)
 
-    checkpoint = torch.load(output_path, map_location=torch.device(device), weights_only=True)
+    checkpoint = torch.load(output_path, map_location=torch.device(device), weights_only=is_pytorch_at_least_2_4())
     cfg = WavLMConfig(checkpoint["cfg"])
     wavlm = WavLM(cfg).to(device)
     wavlm.load_state_dict(checkpoint["model"])
diff --git a/pyproject.toml b/pyproject.toml
index d8aab49417..37cdd6aa69 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -47,7 +47,7 @@ dependencies = [
     "numpy>=1.25.2,<2.0",
     "cython>=3.0.0",
     "scipy>=1.11.2",
-    "torch>=2.4",
+    "torch>=2.1",
     "torchaudio",
     "soundfile>=0.12.0",
     "librosa>=0.10.1",
@@ -76,6 +76,9 @@ dependencies = [
     "spacy[ja]>=3,<3.8",
 ]
 
+[tool.uv.sources]
+coqui-tts-trainer = { git = "https://github.com/idiap/coqui-ai-Trainer", branch = "weights-only" }
+
 [project.optional-dependencies]
 # Development dependencies
 dev = [
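
Note (not part of the patch): every call site above swaps a hard-coded weights_only=True for weights_only=is_pytorch_at_least_2_4(). The reason is that the custom classes stored in these checkpoints (RAdam, the Xtts configs, numpy scalars) can only be allow-listed via torch.serialization.add_safe_globals, which first shipped in PyTorch 2.4; on older versions the call falls back to the pre-2.6 default of a full pickle load, which requires a trusted checkpoint. A minimal sketch of the gate, using a hypothetical model.pth path:

    import torch
    from packaging.version import Version

    def is_pytorch_at_least_2_4() -> bool:
        """Check if the installed PyTorch version is 2.4 or higher."""
        return Version(torch.__version__) >= Version("2.4")

    # torch >= 2.4: restricted, safe unpickling of allow-listed globals;
    # older torch: plain pickle load, matching the old default behavior.
    checkpoint = torch.load("model.pth", map_location="cpu", weights_only=is_pytorch_at_least_2_4())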