From c8396d962587415ad87b63047665d0ac1eb872c2 Mon Sep 17 00:00:00 2001
From: Enno Hermann
Date: Thu, 12 Sep 2024 20:52:07 +0200
Subject: [PATCH] chore(bark): remove manual download of hubert model

Bark was previously adapted to download Hubert from HuggingFace, so the
manual download is superfluous.
---
 TTS/.models.json                            | 1 -
 TTS/tts/configs/bark_config.py              | 1 -
 TTS/tts/layers/bark/hubert/kmeans_hubert.py | 2 +-
 TTS/tts/layers/bark/inference_funcs.py      | 3 +--
 TTS/tts/models/bark.py                      | 3 ---
 5 files changed, 2 insertions(+), 8 deletions(-)

diff --git a/TTS/.models.json b/TTS/.models.json
index a77ebea1cf..a5add6e34f 100644
--- a/TTS/.models.json
+++ b/TTS/.models.json
@@ -48,7 +48,6 @@
                         "https://coqui.gateway.scarf.sh/hf/bark/fine_2.pt",
                         "https://coqui.gateway.scarf.sh/hf/bark/text_2.pt",
                         "https://coqui.gateway.scarf.sh/hf/bark/config.json",
-                        "https://coqui.gateway.scarf.sh/hf/bark/hubert.pt",
                         "https://coqui.gateway.scarf.sh/hf/bark/tokenizer.pth"
                     ],
                     "default_vocoder": null,
diff --git a/TTS/tts/configs/bark_config.py b/TTS/tts/configs/bark_config.py
index 3b893558aa..b846febe85 100644
--- a/TTS/tts/configs/bark_config.py
+++ b/TTS/tts/configs/bark_config.py
@@ -96,7 +96,6 @@ def __post_init__(self):
             "coarse": os.path.join(self.CACHE_DIR, "coarse_2.pt"),
             "fine": os.path.join(self.CACHE_DIR, "fine_2.pt"),
             "hubert_tokenizer": os.path.join(self.CACHE_DIR, "tokenizer.pth"),
-            "hubert": os.path.join(self.CACHE_DIR, "hubert.pt"),
         }
         self.SMALL_REMOTE_MODEL_PATHS = {
             "text": {"path": os.path.join(self.REMOTE_BASE_URL, "text.pt")},
diff --git a/TTS/tts/layers/bark/hubert/kmeans_hubert.py b/TTS/tts/layers/bark/hubert/kmeans_hubert.py
index 9e487b1e9d..58a614cb87 100644
--- a/TTS/tts/layers/bark/hubert/kmeans_hubert.py
+++ b/TTS/tts/layers/bark/hubert/kmeans_hubert.py
@@ -40,7 +40,7 @@ class CustomHubert(nn.Module):
     or you can train your own
     """
 
-    def __init__(self, checkpoint_path, target_sample_hz=16000, seq_len_multiple_of=None, output_layer=9, device=None):
+    def __init__(self, target_sample_hz=16000, seq_len_multiple_of=None, output_layer=9, device=None):
         super().__init__()
         self.target_sample_hz = target_sample_hz
         self.seq_len_multiple_of = seq_len_multiple_of
diff --git a/TTS/tts/layers/bark/inference_funcs.py b/TTS/tts/layers/bark/inference_funcs.py
index b2875c7a83..65c7800dcf 100644
--- a/TTS/tts/layers/bark/inference_funcs.py
+++ b/TTS/tts/layers/bark/inference_funcs.py
@@ -134,10 +134,9 @@ def generate_voice(
     # generate semantic tokens
     # Load the HuBERT model
     hubert_manager = HubertManager()
-    # hubert_manager.make_sure_hubert_installed(model_path=model.config.LOCAL_MODEL_PATHS["hubert"])
     hubert_manager.make_sure_tokenizer_installed(model_path=model.config.LOCAL_MODEL_PATHS["hubert_tokenizer"])
 
-    hubert_model = CustomHubert(checkpoint_path=model.config.LOCAL_MODEL_PATHS["hubert"]).to(model.device)
+    hubert_model = CustomHubert().to(model.device)
 
     # Load the CustomTokenizer model
     tokenizer = HubertTokenizer.load_from_checkpoint(
diff --git a/TTS/tts/models/bark.py b/TTS/tts/models/bark.py
index cdfb5efae4..ced8f60ed8 100644
--- a/TTS/tts/models/bark.py
+++ b/TTS/tts/models/bark.py
@@ -243,7 +243,6 @@ def load_checkpoint(
         text_model_path=None,
         coarse_model_path=None,
         fine_model_path=None,
-        hubert_model_path=None,
         hubert_tokenizer_path=None,
         eval=False,
         strict=True,
@@ -266,13 +265,11 @@
         text_model_path = text_model_path or os.path.join(checkpoint_dir, "text_2.pt")
         coarse_model_path = coarse_model_path or os.path.join(checkpoint_dir, "coarse_2.pt")
         fine_model_path = fine_model_path or os.path.join(checkpoint_dir, "fine_2.pt")
-        hubert_model_path = hubert_model_path or os.path.join(checkpoint_dir, "hubert.pt")
         hubert_tokenizer_path = hubert_tokenizer_path or os.path.join(checkpoint_dir, "tokenizer.pth")
 
         self.config.LOCAL_MODEL_PATHS["text"] = text_model_path
         self.config.LOCAL_MODEL_PATHS["coarse"] = coarse_model_path
         self.config.LOCAL_MODEL_PATHS["fine"] = fine_model_path
-        self.config.LOCAL_MODEL_PATHS["hubert"] = hubert_model_path
         self.config.LOCAL_MODEL_PATHS["hubert_tokenizer"] = hubert_tokenizer_path
 
         self.load_bark_models()
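For context, the removals above work because the adapted CustomHubert resolves its weights through the HuggingFace Hub instead of a locally downloaded hubert.pt, so no checkpoint path needs to be threaded through the config. Below is a minimal sketch of that style of Hub-based loading; the wrapper class and the model id "facebook/hubert-base-ls960" are assumptions for illustration, not necessarily the repository's exact code.

# Sketch only: Hub-based HuBERT loading without a checkpoint_path argument.
# The wrapper name and the model id are assumptions, not the repository's code.
import torch
from torch import nn
from transformers import HubertModel  # weights are fetched and cached from the Hub


class HubHubertSketch(nn.Module):
    def __init__(self, target_sample_hz=16000, output_layer=9, device=None):
        super().__init__()
        self.target_sample_hz = target_sample_hz
        self.output_layer = output_layer
        # Downloads into the HuggingFace cache on first use; no manual hubert.pt.
        self.model = HubertModel.from_pretrained("facebook/hubert-base-ls960")
        if device is not None:
            self.to(device)

    @torch.no_grad()
    def forward(self, wav_input: torch.Tensor) -> torch.Tensor:
        # wav_input: (batch, samples) float waveform at target_sample_hz.
        outputs = self.model(wav_input, output_hidden_states=True)
        # Return hidden states from the requested intermediate transformer layer,
        # matching the output_layer=9 default used by CustomHubert.
        return outputs.hidden_states[self.output_layer]

Because from_pretrained() handles download and caching itself, neither the hubert.pt entry in .models.json nor the checkpoint_path/hubert_model_path plumbing removed by this patch is needed.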