fix: only enable load with weights_only in pytorch>=2.4 #113
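In short, the change routes torch.load through a runtime version check so that weights_only is only requested on PyTorch 2.4 and newer, where torch.serialization.add_safe_globals is available to allowlist the non-tensor objects stored in these checkpoints. A minimal sketch of the pattern applied throughout the diff below; the checkpoint path is illustrative only:

import torch
from packaging.version import Version


def is_pytorch_at_least_2_4() -> bool:
    """Return True when the installed PyTorch supports the safe-globals allowlist."""
    return Version(torch.__version__) >= Version("2.4")


# Opt in to weights_only unpickling only when the running PyTorch can register
# the safe globals these checkpoints rely on.
checkpoint = torch.load("model.pth", map_location="cpu", weights_only=is_pytorch_at_least_2_4())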

Draft · wants to merge 1 commit into base: dev

Changes from all commits
46 changes: 25 additions & 21 deletions TTS/__init__.py
@@ -1,29 +1,33 @@
-import _codecs
 import importlib.metadata
-from collections import defaultdict
 
-import numpy as np
-import torch
-
-from TTS.config.shared_configs import BaseDatasetConfig
-from TTS.tts.configs.xtts_config import XttsConfig
-from TTS.tts.models.xtts import XttsArgs, XttsAudioConfig
-from TTS.utils.radam import RAdam
+from TTS.utils.generic_utils import is_pytorch_at_least_2_4
 
 __version__ = importlib.metadata.version("coqui-tts")
 
 
-torch.serialization.add_safe_globals([dict, defaultdict, RAdam])
+if is_pytorch_at_least_2_4():
+    import _codecs
+    from collections import defaultdict
+
+    import numpy as np
+    import torch
+
+    from TTS.config.shared_configs import BaseDatasetConfig
+    from TTS.tts.configs.xtts_config import XttsConfig
+    from TTS.tts.models.xtts import XttsArgs, XttsAudioConfig
+    from TTS.utils.radam import RAdam
+
+    torch.serialization.add_safe_globals([dict, defaultdict, RAdam])
 
-# Bark
-torch.serialization.add_safe_globals(
-    [
-        np.core.multiarray.scalar,
-        np.dtype,
-        np.dtypes.Float64DType,
-        _codecs.encode,  # TODO: safe by default from Pytorch 2.5
-    ]
-)
+    # Bark
+    torch.serialization.add_safe_globals(
+        [
+            np.core.multiarray.scalar,
+            np.dtype,
+            np.dtypes.Float64DType,
+            _codecs.encode,  # TODO: safe by default from Pytorch 2.5
+        ]
+    )
 
-# XTTS
-torch.serialization.add_safe_globals([BaseDatasetConfig, XttsConfig, XttsAudioConfig, XttsArgs])
+    # XTTS
+    torch.serialization.add_safe_globals([BaseDatasetConfig, XttsConfig, XttsAudioConfig, XttsArgs])
3 changes: 2 additions & 1 deletion TTS/tts/layers/bark/load_model.py
@@ -10,6 +10,7 @@

from TTS.tts.layers.bark.model import GPT, GPTConfig
from TTS.tts.layers.bark.model_fine import FineGPT, FineGPTConfig
+from TTS.utils.generic_utils import is_pytorch_at_least_2_4

if (
torch.cuda.is_available()
@@ -118,7 +119,7 @@ def load_model(ckpt_path, device, config, model_type="text"):
logger.info(f"{model_type} model not found, downloading...")
_download(config.REMOTE_MODEL_PATHS[model_type]["path"], ckpt_path, config.CACHE_DIR)

-checkpoint = torch.load(ckpt_path, map_location=device, weights_only=True)
+checkpoint = torch.load(ckpt_path, map_location=device, weights_only=is_pytorch_at_least_2_4())
# this is a hack
model_args = checkpoint["model_args"]
if "input_vocab_size" not in model_args:
3 changes: 2 additions & 1 deletion TTS/tts/layers/tortoise/arch_utils.py
@@ -9,6 +9,7 @@
from transformers import LogitsWarper

from TTS.tts.layers.tortoise.xtransformers import ContinuousTransformerWrapper, RelativePositionBias
+from TTS.utils.generic_utils import is_pytorch_at_least_2_4


def zero_module(module):
@@ -332,7 +333,7 @@ def __init__(
self.mel_norm_file = mel_norm_file
if self.mel_norm_file is not None:
with fsspec.open(self.mel_norm_file) as f:
-self.mel_norms = torch.load(f, weights_only=True)
+self.mel_norms = torch.load(f, weights_only=is_pytorch_at_least_2_4())
else:
self.mel_norms = None

3 changes: 2 additions & 1 deletion TTS/tts/layers/tortoise/audio_utils.py
@@ -10,6 +10,7 @@
from scipy.io.wavfile import read

from TTS.utils.audio.torch_transforms import TorchSTFT
+from TTS.utils.generic_utils import is_pytorch_at_least_2_4

logger = logging.getLogger(__name__)

@@ -124,7 +125,7 @@ def load_voice(voice: str, extra_voice_dirs: List[str] = []):
voices = get_voices(extra_voice_dirs)
paths = voices[voice]
if len(paths) == 1 and paths[0].endswith(".pth"):
-return None, torch.load(paths[0], weights_only=True)
+return None, torch.load(paths[0], weights_only=is_pytorch_at_least_2_4())
else:
conds = []
for cond_path in paths:
4 changes: 3 additions & 1 deletion TTS/tts/layers/xtts/dvae.py
@@ -9,6 +9,8 @@
import torchaudio
from einops import rearrange

+from TTS.utils.generic_utils import is_pytorch_at_least_2_4

logger = logging.getLogger(__name__)


@@ -46,7 +48,7 @@ def dvae_wav_to_mel(
mel = mel_stft(wav)
mel = torch.log(torch.clamp(mel, min=1e-5))
if mel_norms is None:
-mel_norms = torch.load(mel_norms_file, map_location=device, weights_only=True)
+mel_norms = torch.load(mel_norms_file, map_location=device, weights_only=is_pytorch_at_least_2_4())
mel = mel / mel_norms.unsqueeze(0).unsqueeze(-1)
return mel

3 changes: 2 additions & 1 deletion TTS/tts/layers/xtts/hifigan_decoder.py
@@ -9,6 +9,7 @@
from torch.nn.utils.parametrize import remove_parametrizations
from trainer.io import load_fsspec

+from TTS.utils.generic_utils import is_pytorch_at_least_2_4
from TTS.vocoder.models.hifigan_generator import get_padding

logger = logging.getLogger(__name__)
@@ -328,7 +329,7 @@ def remove_weight_norm(self):
def load_checkpoint(
self, config, checkpoint_path, eval=False, cache=False
): # pylint: disable=unused-argument, redefined-builtin
-state = torch.load(checkpoint_path, map_location=torch.device("cpu"), weights_only=True)
+state = torch.load(checkpoint_path, map_location=torch.device("cpu"), weights_only=is_pytorch_at_least_2_4())
self.load_state_dict(state["model"])
if eval:
self.eval()
9 changes: 7 additions & 2 deletions TTS/tts/layers/xtts/trainer/gpt_trainer.py
@@ -19,6 +19,7 @@
from TTS.tts.layers.xtts.trainer.dataset import XTTSDataset
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.models.xtts import Xtts, XttsArgs, XttsAudioConfig
+from TTS.utils.generic_utils import is_pytorch_at_least_2_4

logger = logging.getLogger(__name__)

@@ -91,7 +92,9 @@ def __init__(self, config: Coqpit):

# load GPT if available
if self.args.gpt_checkpoint:
-gpt_checkpoint = torch.load(self.args.gpt_checkpoint, map_location=torch.device("cpu"), weights_only=True)
+gpt_checkpoint = torch.load(
+    self.args.gpt_checkpoint, map_location=torch.device("cpu"), weights_only=is_pytorch_at_least_2_4()
+)
# deal with coqui Trainer exported model
if "model" in gpt_checkpoint.keys() and "config" in gpt_checkpoint.keys():
logger.info("Coqui Trainer checkpoint detected! Converting it!")
@@ -184,7 +187,9 @@ def __init__(self, config: Coqpit):

self.dvae.eval()
if self.args.dvae_checkpoint:
-dvae_checkpoint = torch.load(self.args.dvae_checkpoint, map_location=torch.device("cpu"), weights_only=True)
+dvae_checkpoint = torch.load(
+    self.args.dvae_checkpoint, map_location=torch.device("cpu"), weights_only=is_pytorch_at_least_2_4()
+)
self.dvae.load_state_dict(dvae_checkpoint, strict=False)
logger.info("DVAE weights restored from: %s", self.args.dvae_checkpoint)
else:
4 changes: 3 additions & 1 deletion TTS/tts/layers/xtts/xtts_manager.py
@@ -1,9 +1,11 @@
import torch

+from TTS.utils.generic_utils import is_pytorch_at_least_2_4


class SpeakerManager:
def __init__(self, speaker_file_path=None):
-self.speakers = torch.load(speaker_file_path, weights_only=True)
+self.speakers = torch.load(speaker_file_path, weights_only=is_pytorch_at_least_2_4())

@property
def name_to_id(self):
8 changes: 5 additions & 3 deletions TTS/tts/models/neuralhmm_tts.py
@@ -18,7 +18,7 @@
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
-from TTS.utils.generic_utils import format_aux_input
+from TTS.utils.generic_utils import format_aux_input, is_pytorch_at_least_2_4

logger = logging.getLogger(__name__)

@@ -107,7 +107,7 @@ def update_mean_std(self, statistics_dict: Dict):

def preprocess_batch(self, text, text_len, mels, mel_len):
if self.mean.item() == 0 or self.std.item() == 1:
-statistics_dict = torch.load(self.mel_statistics_parameter_path, weights_only=True)
+statistics_dict = torch.load(self.mel_statistics_parameter_path, weights_only=is_pytorch_at_least_2_4())
self.update_mean_std(statistics_dict)

mels = self.normalize(mels)
@@ -292,7 +292,9 @@ def on_init_start(self, trainer):
"Data parameters found for: %s. Loading mel normalization parameters...",
trainer.config.mel_statistics_parameter_path,
)
-statistics = torch.load(trainer.config.mel_statistics_parameter_path, weights_only=True)
+statistics = torch.load(
+    trainer.config.mel_statistics_parameter_path, weights_only=is_pytorch_at_least_2_4()
+)
data_mean, data_std, init_transition_prob = (
statistics["mean"],
statistics["std"],
8 changes: 5 additions & 3 deletions TTS/tts/models/overflow.py
@@ -19,7 +19,7 @@
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
-from TTS.utils.generic_utils import format_aux_input
+from TTS.utils.generic_utils import format_aux_input, is_pytorch_at_least_2_4

logger = logging.getLogger(__name__)

@@ -120,7 +120,7 @@ def update_mean_std(self, statistics_dict: Dict):

def preprocess_batch(self, text, text_len, mels, mel_len):
if self.mean.item() == 0 or self.std.item() == 1:
-statistics_dict = torch.load(self.mel_statistics_parameter_path, weights_only=True)
+statistics_dict = torch.load(self.mel_statistics_parameter_path, weights_only=is_pytorch_at_least_2_4())
self.update_mean_std(statistics_dict)

mels = self.normalize(mels)
@@ -308,7 +308,9 @@ def on_init_start(self, trainer):
"Data parameters found for: %s. Loading mel normalization parameters...",
trainer.config.mel_statistics_parameter_path,
)
-statistics = torch.load(trainer.config.mel_statistics_parameter_path, weights_only=True)
+statistics = torch.load(
+    trainer.config.mel_statistics_parameter_path, weights_only=is_pytorch_at_least_2_4()
+)
data_mean, data_std, init_transition_prob = (
statistics["mean"],
statistics["std"],
19 changes: 12 additions & 7 deletions TTS/tts/models/tortoise.py
@@ -23,6 +23,7 @@
from TTS.tts.layers.tortoise.vocoder import VocConf, VocType
from TTS.tts.layers.tortoise.wav2vec_alignment import Wav2VecAlignment
from TTS.tts.models.base_tts import BaseTTS
+from TTS.utils.generic_utils import is_pytorch_at_least_2_4

logger = logging.getLogger(__name__)

@@ -171,7 +172,11 @@ def classify_audio_clip(clip, model_dir):
distribute_zero_label=False,
)
classifier.load_state_dict(
-torch.load(os.path.join(model_dir, "classifier.pth"), map_location=torch.device("cpu"), weights_only=True)
+torch.load(
+    os.path.join(model_dir, "classifier.pth"),
+    map_location=torch.device("cpu"),
+    weights_only=is_pytorch_at_least_2_4(),
+)
)
clip = clip.cpu().unsqueeze(0)
results = F.softmax(classifier(clip), dim=-1)
@@ -490,15 +495,15 @@ def get_random_conditioning_latents(self):
torch.load(
os.path.join(self.models_dir, "rlg_auto.pth"),
map_location=torch.device("cpu"),
-weights_only=True,
+weights_only=is_pytorch_at_least_2_4(),
)
)
self.rlg_diffusion = RandomLatentConverter(2048).eval()
self.rlg_diffusion.load_state_dict(
torch.load(
os.path.join(self.models_dir, "rlg_diffuser.pth"),
map_location=torch.device("cpu"),
-weights_only=True,
+weights_only=is_pytorch_at_least_2_4(),
)
)
with torch.no_grad():
@@ -885,25 +890,25 @@ def load_checkpoint(

if os.path.exists(ar_path):
# remove keys from the checkpoint that are not in the model
-checkpoint = torch.load(ar_path, map_location=torch.device("cpu"), weights_only=True)
+checkpoint = torch.load(ar_path, map_location=torch.device("cpu"), weights_only=is_pytorch_at_least_2_4())

# strict set False
# due to removed `bias` and `masked_bias` changes in Transformers
self.autoregressive.load_state_dict(checkpoint, strict=False)

if os.path.exists(diff_path):
-self.diffusion.load_state_dict(torch.load(diff_path, weights_only=True), strict=strict)
+self.diffusion.load_state_dict(torch.load(diff_path, weights_only=is_pytorch_at_least_2_4()), strict=strict)

if os.path.exists(clvp_path):
-self.clvp.load_state_dict(torch.load(clvp_path, weights_only=True), strict=strict)
+self.clvp.load_state_dict(torch.load(clvp_path, weights_only=is_pytorch_at_least_2_4()), strict=strict)

if os.path.exists(vocoder_checkpoint_path):
self.vocoder.load_state_dict(
config.model_args.vocoder.value.optionally_index(
torch.load(
vocoder_checkpoint_path,
map_location=torch.device("cpu"),
-weights_only=True,
+weights_only=is_pytorch_at_least_2_4(),
)
)
)
3 changes: 2 additions & 1 deletion TTS/tts/models/xtts.py
@@ -16,6 +16,7 @@
from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer, split_sentence
from TTS.tts.layers.xtts.xtts_manager import LanguageManager, SpeakerManager
from TTS.tts.models.base_tts import BaseTTS
+from TTS.utils.generic_utils import is_pytorch_at_least_2_4

logger = logging.getLogger(__name__)

@@ -65,7 +66,7 @@ def wav_to_mel_cloning(
mel = mel_stft(wav)
mel = torch.log(torch.clamp(mel, min=1e-5))
if mel_norms is None:
-mel_norms = torch.load(mel_norms_file, map_location=device, weights_only=True)
+mel_norms = torch.load(mel_norms_file, map_location=device, weights_only=is_pytorch_at_least_2_4())
mel = mel / mel_norms.unsqueeze(0).unsqueeze(-1)
return mel

4 changes: 3 additions & 1 deletion TTS/tts/utils/fairseq.py
@@ -1,8 +1,10 @@
import torch

+from TTS.utils.generic_utils import is_pytorch_at_least_2_4


def rehash_fairseq_vits_checkpoint(checkpoint_file):
-chk = torch.load(checkpoint_file, map_location=torch.device("cpu"), weights_only=True)["model"]
+chk = torch.load(checkpoint_file, map_location=torch.device("cpu"), weights_only=is_pytorch_at_least_2_4())["model"]
new_chk = {}
for k, v in chk.items():
if "enc_p." in k:
3 changes: 2 additions & 1 deletion TTS/tts/utils/managers.py
@@ -9,6 +9,7 @@
from TTS.config import load_config
from TTS.encoder.utils.generic_utils import setup_encoder_model
from TTS.utils.audio import AudioProcessor
+from TTS.utils.generic_utils import is_pytorch_at_least_2_4


def load_file(path: str):
@@ -17,7 +18,7 @@ def load_file(path: str):
return json.load(f)
elif path.endswith(".pth"):
with fsspec.open(path, "rb") as f:
return torch.load(f, map_location="cpu", weights_only=True)
return torch.load(f, map_location="cpu", weights_only=is_pytorch_at_least_2_4())
else:
raise ValueError("Unsupported file type")

8 changes: 8 additions & 0 deletions TTS/utils/generic_utils.py
@@ -6,6 +6,9 @@
from pathlib import Path
from typing import Dict, Optional

+import torch
+from packaging.version import Version

logger = logging.getLogger(__name__)


@@ -131,3 +134,8 @@ def setup_logger(
sh = logging.StreamHandler()
sh.setFormatter(formatter)
lg.addHandler(sh)


+def is_pytorch_at_least_2_4() -> bool:
+    """Check if the installed Pytorch version is 2.4 or higher."""
+    return Version(torch.__version__) >= Version("2.4")
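
A quick sanity check of the helper above, assuming the package is importable as TTS.utils.generic_utils; the patched version strings are illustrative only:

from unittest import mock

import TTS.utils.generic_utils as gu

# Simulate older and newer installs by patching the version string torch reports.
with mock.patch.object(gu.torch, "__version__", "2.3.1"):
    assert gu.is_pytorch_at_least_2_4() is False
with mock.patch.object(gu.torch, "__version__", "2.4.1"):
    assert gu.is_pytorch_at_least_2_4() is True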
3 changes: 2 additions & 1 deletion TTS/vc/modules/freevc/wavlm/__init__.py
@@ -5,6 +5,7 @@
import torch
from trainer.io import get_user_data_dir

+from TTS.utils.generic_utils import is_pytorch_at_least_2_4
from TTS.vc.modules.freevc.wavlm.wavlm import WavLM, WavLMConfig

logger = logging.getLogger(__name__)
@@ -26,7 +27,7 @@ def get_wavlm(device="cpu"):
logger.info("Downloading WavLM model to %s ...", output_path)
urllib.request.urlretrieve(model_uri, output_path)

-checkpoint = torch.load(output_path, map_location=torch.device(device), weights_only=True)
+checkpoint = torch.load(output_path, map_location=torch.device(device), weights_only=is_pytorch_at_least_2_4())
cfg = WavLMConfig(checkpoint["cfg"])
wavlm = WavLM(cfg).to(device)
wavlm.load_state_dict(checkpoint["model"])