Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support of the XTTS 2 in tts-server #227

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions TTS/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,15 +119,15 @@ def create_argparser() -> argparse.ArgumentParser:
use_cuda=args.use_cuda,
)

speaker_manager = getattr(synthesizer.tts_model, "speaker_manager", None)
use_multi_speaker = hasattr(synthesizer.tts_model, "num_speakers") and (
synthesizer.tts_model.num_speakers > 1 or synthesizer.tts_speakers_file is not None
)
speaker_manager = getattr(synthesizer.tts_model, "speaker_manager", None)
) or (speaker_manager is not None)

language_manager = getattr(synthesizer.tts_model, "language_manager", None)
use_multi_language = hasattr(synthesizer.tts_model, "num_languages") and (
synthesizer.tts_model.num_languages > 1 or synthesizer.tts_languages_file is not None
)
language_manager = getattr(synthesizer.tts_model, "language_manager", None)
) or (language_manager is not None)

# TODO: set this from SpeakerManager
use_gst = synthesizer.tts_config.get("use_gst", False)
Expand Down
6 changes: 4 additions & 2 deletions TTS/tts/models/xtts.py
Original file line number Diff line number Diff line change
Expand Up @@ -723,8 +723,8 @@ def get_compatible_checkpoint_state_dict(self, model_path):
def load_checkpoint(
self,
config: "XttsConfig",
checkpoint_dir: Optional[str] = None,
checkpoint_path: Optional[str] = None,
checkpoint_dir: Optional[str] = None,
vocab_path: Optional[str] = None,
eval: bool = True,
strict: bool = True,
Expand All @@ -736,15 +736,17 @@ def load_checkpoint(

Args:
config (dict): The configuration dictionary for the model.
checkpoint_dir (str, optional): The directory where the checkpoint is stored. Defaults to None.
checkpoint_path (str, optional): The path to the checkpoint file. Defaults to None.
checkpoint_dir (str, optional): The directory where the checkpoint is stored. Defaults to None.
vocab_path (str, optional): The path to the vocabulary file. Defaults to None.
eval (bool, optional): Whether to set the model to evaluation mode. Defaults to True.
strict (bool, optional): Whether to strictly enforce that the keys in the checkpoint match the keys in the model. Defaults to True.

Returns:
None
"""
if checkpoint_dir is None and checkpoint_path:
checkpoint_dir = os.path.dirname(checkpoint_path)
if checkpoint_dir is not None and Path(checkpoint_dir).is_file():
msg = f"You passed a file to `checkpoint_dir=`. Use `checkpoint_path={checkpoint_dir}` instead."
raise ValueError(msg)
Expand Down
2 changes: 1 addition & 1 deletion TTS/utils/manage.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,7 @@ def download_model(self, model_name: str) -> tuple[Path, Optional[Path], ModelIt
output_model_path = output_path
output_config_path = None
if (
model not in ["tortoise-v2", "bark"] and "fairseq" not in model_name and "xtts" not in model_name
model not in ["tortoise-v2", "bark"] and "fairseq" not in model_name and "xtts1" not in model_name
): # TODO:This is stupid but don't care for now.
output_model_path, output_config_path = self._find_files(output_path)
else:
Expand Down