From a4c8a5a14fd4bd1459d247ae07c5563abe60af93 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Thu, 5 Sep 2024 01:12:55 +0900 Subject: [PATCH 1/2] optimize: revert default device to cpu to satisfy non-cuda users --- ChatTTS/core.py | 7 ++++--- ChatTTS/model/dvae.py | 4 ++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/ChatTTS/core.py b/ChatTTS/core.py index fb52bdea5..c38ad8957 100644 --- a/ChatTTS/core.py +++ b/ChatTTS/core.py @@ -272,7 +272,7 @@ def _load( vq_config=asdict(self.config.dvae.vq), dim=self.config.dvae.decoder.idim, coef=coef, - device=self.device, + device=device, ) .to(device) .eval() @@ -289,8 +289,8 @@ def _load( self.config.embed.num_text_tokens, self.config.embed.num_vq, ) - embed.from_pretrained(embed_path, device=self.device) - self.embed = embed.to(self.device) + embed.from_pretrained(embed_path, device=device) + self.embed = embed.to(device) self.logger.log(logging.INFO, "embed loaded.") gpt = GPT( @@ -318,6 +318,7 @@ def _load( decoder_config=asdict(self.config.decoder), dim=self.config.decoder.idim, coef=coef, + device=device, ) .to(device) .eval() diff --git a/ChatTTS/model/dvae.py b/ChatTTS/model/dvae.py index 3a966eaed..7e6b62a83 100644 --- a/ChatTTS/model/dvae.py +++ b/ChatTTS/model/dvae.py @@ -179,7 +179,7 @@ def __init__( hop_length=256, n_mels=100, padding: Literal["center", "same"] = "center", - device: torch.device = torch.device("cuda"), + device: torch.device = torch.device("cpu"), ): super().__init__() self.device = device @@ -213,7 +213,7 @@ def __init__( vq_config: Optional[dict] = None, dim=512, coef: Optional[str] = None, - device: torch.device = torch.device("cuda"), + device: torch.device = torch.device("cpu"), ): super().__init__() if coef is None: From 8fcc0cd6ae162ff8f2d65a2b355aaafb47d7e9e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Thu, 5 Sep 2024 01:15:51 +0900 Subject: [PATCH 2/2] fix(colab): zero shot import --- examples/ipynb/colab.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ipynb/colab.ipynb b/examples/ipynb/colab.ipynb index 0c6ad32e5..ab943fcce 100644 --- a/examples/ipynb/colab.ipynb +++ b/examples/ipynb/colab.ipynb @@ -355,7 +355,7 @@ "metadata": {}, "outputs": [], "source": [ - "from tools.audio import load_audio\n", + "from ChatTTS.tools.audio import load_audio\n", "\n", "spk_smp = chat.sample_audio_speaker(load_audio(\"sample.mp3\", 24000))\n", "print(spk_smp) # save it in order to load the speaker without sample audio next time\n",