Implement tts_iter
jwc20 committed Jul 6, 2024
1 parent 66de122 commit 66cc9c4
Showing 1 changed file with 211 additions and 48 deletions.
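Summary of the change (as visible in the diff below): melo/api.py is reformatted (double quotes, wrapped call sites), the previous tts_to_base64 is renamed to old_tts_to_base64, and a new tts_iter generator (adapted from myshell-ai/MeloTTS#88) is added that yields one float32 audio segment per sentence, each padded with roughly 50 ms of trailing silence. tts_to_base64 is then reimplemented on top of tts_iter, concatenating the yielded segments and returning the WAV as base64.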
259 changes: 211 additions & 48 deletions melo/api.py
@@ -27,22 +27,22 @@
 
 start_time = datetime.now()
 
 
 class TTS(nn.Module):
-    def __init__(self,
-                language,
-                device='auto',
-                use_hf=True,
-                config_path=None,
-                ckpt_path=None):
+    def __init__(
+        self, language, device="auto", use_hf=True, config_path=None, ckpt_path=None
+    ):
         super().__init__()
-        if device == 'auto':
-            device = 'cpu'
-            if torch.cuda.is_available(): device = 'cuda'
-            if torch.backends.mps.is_available(): device = 'mps'
-        if 'cuda' in device:
+        if device == "auto":
+            device = "cpu"
+            if torch.cuda.is_available():
+                device = "cuda"
+            if torch.backends.mps.is_available():
+                device = "mps"
+        if "cuda" in device:
             assert torch.cuda.is_available()
 
-        # config_path =
+        # config_path =
         hps = load_or_download_config(language, use_hf=use_hf, config_path=config_path)
 
         num_languages = hps.num_languages
@@ -64,16 +64,20 @@ def __init__(self,
         self.symbol_to_id = {s: i for i, s in enumerate(symbols)}
         self.hps = hps
         self.device = device
 
         # load state_dict
-        checkpoint_dict = load_or_download_model(language, device, use_hf=use_hf, ckpt_path=ckpt_path)
-        self.model.load_state_dict(checkpoint_dict['model'], strict=True)
-
-        language = language.split('_')[0]
-        self.language = 'ZH_MIX_EN' if language == 'ZH' else language # we support a ZH_MIX_EN model
+        checkpoint_dict = load_or_download_model(
+            language, device, use_hf=use_hf, ckpt_path=ckpt_path
+        )
+        self.model.load_state_dict(checkpoint_dict["model"], strict=True)
+
+        language = language.split("_")[0]
+        self.language = (
+            "ZH_MIX_EN" if language == "ZH" else language
+        )  # we support a ZH_MIX_EN model
 
     @staticmethod
-    def audio_numpy_concat(segment_data_list, sr, speed=1.):
+    def audio_numpy_concat(segment_data_list, sr, speed=1.0):
         audio_segments = []
         for segment_data in segment_data_list:
             audio_segments += segment_data.reshape(-1).tolist()
@@ -86,11 +90,24 @@ def split_sentences_into_pieces(text, language, quiet=False):
         texts = split_sentence(text, language_str=language)
         if not quiet:
             print(" > Text split to sentences.")
-            print('\n'.join(texts))
+            print("\n".join(texts))
             print(" > ===========================")
         return texts
 
-    def tts_to_file(self, text, speaker_id, output_path=None, sdp_ratio=0.2, noise_scale=0.6, noise_scale_w=0.8, speed=1.0, pbar=None, format=None, position=None, quiet=False,):
+    def tts_to_file(
+        self,
+        text,
+        speaker_id,
+        output_path=None,
+        sdp_ratio=0.2,
+        noise_scale=0.6,
+        noise_scale_w=0.8,
+        speed=1.0,
+        pbar=None,
+        format=None,
+        position=None,
+        quiet=False,
+    ):
         language = self.language
         texts = self.split_sentences_into_pieces(text, language, quiet)
         audio_list = []
@@ -104,10 +121,12 @@ def tts_to_file(self, text, speaker_id, output_path=None, sdp_ratio=0.2, noise_s
             else:
                 tx = tqdm(texts)
         for t in tx:
-            if language in ['EN', 'ZH_MIX_EN']:
-                t = re.sub(r'([a-z])([A-Z])', r'\1 \2', t)
+            if language in ["EN", "ZH_MIX_EN"]:
+                t = re.sub(r"([a-z])([A-Z])", r"\1 \2", t)
             device = self.device
-            bert, ja_bert, phones, tones, lang_ids = utils.get_text_for_tts_infer(t, language, self.hps, device, self.symbol_to_id)
+            bert, ja_bert, phones, tones, lang_ids = utils.get_text_for_tts_infer(
+                t, language, self.hps, device, self.symbol_to_id
+            )
             with torch.no_grad():
                 x_tst = phones.to(device).unsqueeze(0)
                 tones = tones.to(device).unsqueeze(0)
@@ -117,7 +136,8 @@ def tts_to_file(self, text, speaker_id, output_path=None, sdp_ratio=0.2, noise_s
                 x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
                 del phones
                 speakers = torch.LongTensor([speaker_id]).to(device)
-                audio = self.model.infer(
+                audio = (
+                    self.model.infer(
                         x_tst,
                         x_tst_lengths,
                         speakers,
@@ -128,26 +148,43 @@ def tts_to_file(self, text, speaker_id, output_path=None, sdp_ratio=0.2, noise_s
                         sdp_ratio=sdp_ratio,
                         noise_scale=noise_scale,
                         noise_scale_w=noise_scale_w,
-                        length_scale=1. / speed,
-                    )[0][0, 0].data.cpu().float().numpy()
+                        length_scale=1.0 / speed,
+                    )[0][0, 0]
+                    .data.cpu()
+                    .float()
+                    .numpy()
+                )
                 del x_tst, tones, lang_ids, bert, ja_bert, x_tst_lengths, speakers
-                #
+                #
             audio_list.append(audio)
             torch.cuda.empty_cache()
-        audio = self.audio_numpy_concat(audio_list, sr=self.hps.data.sampling_rate, speed=speed)
+        audio = self.audio_numpy_concat(
+            audio_list, sr=self.hps.data.sampling_rate, speed=speed
+        )
 
         if output_path is None:
             return audio
         else:
             if format:
-                soundfile.write(output_path, audio, self.hps.data.sampling_rate, format=format)
+                soundfile.write(
+                    output_path, audio, self.hps.data.sampling_rate, format=format
+                )
             else:
                 soundfile.write(output_path, audio, self.hps.data.sampling_rate)
 
-
-
-
-    def tts_to_base64(self, text, speaker_id, sdp_ratio=0.2, noise_scale=0.6, noise_scale_w=0.8, speed=1.0, pbar=None, format=None, position=None, quiet=False,):
+    def old_tts_to_base64(
+        self,
+        text,
+        speaker_id,
+        sdp_ratio=0.2,
+        noise_scale=0.6,
+        noise_scale_w=0.8,
+        speed=1.0,
+        pbar=None,
+        format=None,
+        position=None,
+        quiet=False,
+    ):
         language = self.language
         texts = self.split_sentences_into_pieces(text, language, quiet)
         audio_list = []
@@ -161,10 +198,12 @@ def tts_to_base64(self, text, speaker_id, sdp_ratio=0.2, noise_scale=0.6, noise_
             else:
                 tx = tqdm(texts)
         for t in tx:
-            if language in ['EN', 'ZH_MIX_EN']:
-                t = re.sub(r'([a-z])([A-Z])', r'\1 \2', t)
+            if language in ["EN", "ZH_MIX_EN"]:
+                t = re.sub(r"([a-z])([A-Z])", r"\1 \2", t)
             device = self.device
-            bert, ja_bert, phones, tones, lang_ids = utils.get_text_for_tts_infer(t, language, self.hps, device, self.symbol_to_id)
+            bert, ja_bert, phones, tones, lang_ids = utils.get_text_for_tts_infer(
+                t, language, self.hps, device, self.symbol_to_id
+            )
             with torch.no_grad():
                 x_tst = phones.to(device).unsqueeze(0)
                 tones = tones.to(device).unsqueeze(0)
@@ -174,7 +213,8 @@ def tts_to_base64(self, text, speaker_id, sdp_ratio=0.2, noise_scale=0.6, noise_
                 x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
                 del phones
                 speakers = torch.LongTensor([speaker_id]).to(device)
-                audio = self.model.infer(
+                audio = (
+                    self.model.infer(
                         x_tst,
                         x_tst_lengths,
                         speakers,
@@ -185,26 +225,149 @@ def tts_to_base64(self, text, speaker_id, sdp_ratio=0.2, noise_scale=0.6, noise_
                         sdp_ratio=sdp_ratio,
                         noise_scale=noise_scale,
                         noise_scale_w=noise_scale_w,
-                        length_scale=1. / speed,
-                    )[0][0, 0].data.cpu().float().numpy()
+                        length_scale=1.0 / speed,
+                    )[0][0, 0]
+                    .data.cpu()
+                    .float()
+                    .numpy()
+                )
                 del x_tst, tones, lang_ids, bert, ja_bert, x_tst_lengths, speakers
-                #
+                #
             audio_list.append(audio)
             torch.cuda.empty_cache()
-        audio = self.audio_numpy_concat(audio_list, sr=self.hps.data.sampling_rate, speed=speed)
+        audio = self.audio_numpy_concat(
+            audio_list, sr=self.hps.data.sampling_rate, speed=speed
+        )
 
         with io.BytesIO() as wav_buffer:
-            soundfile.write(wav_buffer, audio, self.hps.data.sampling_rate, format="WAV")
+            soundfile.write(
+                wav_buffer, audio, self.hps.data.sampling_rate, format="WAV"
+            )
             wav_buffer.seek(0)
             wav_bytes = wav_buffer.read()
 
-
         wav_base64 = base64.b64encode(wav_bytes).decode("utf-8")
         end_time = datetime.now()
         elapsed_time = end_time - start_time
 
-        return jsonable_encoder({
-            "audioContent": wav_base64,
-            "time_taken": elapsed_time
-        })
+        return jsonable_encoder(
+            {"audioContent": wav_base64, "time_taken": elapsed_time}
+        )
+
+    def tts_iter(
+        self,
+        text,
+        speaker_id,
+        sdp_ratio=0.2,
+        noise_scale=0.6,
+        noise_scale_w=0.8,
+        speed=1.0,
+        pbar=None,
+        position=None,
+        quiet=False,
+    ):
+        """
+        https://github.com/myshell-ai/MeloTTS/pull/88/files
+        """
+        language = self.language
+        texts = self.split_sentences_into_pieces(text, language, quiet)
+
+        if pbar:
+            tx = pbar(texts)
+        else:
+            if position:
+                tx = tqdm(texts, position=position)
+            elif quiet:
+                tx = texts
+            else:
+                tx = tqdm(texts)
+        for t in tx:
+            if language in ["EN", "ZH_MIX_EN"]:
+                t = re.sub(r"([a-z])([A-Z])", r"\1 \2", t)
+            device = self.device
+            bert, ja_bert, phones, tones, lang_ids = utils.get_text_for_tts_infer(
+                t, language, self.hps, device, self.symbol_to_id
+            )
+            with torch.no_grad():
+                x_tst = phones.to(device).unsqueeze(0)
+                tones = tones.to(device).unsqueeze(0)
+                lang_ids = lang_ids.to(device).unsqueeze(0)
+                bert = bert.to(device).unsqueeze(0)
+                ja_bert = ja_bert.to(device).unsqueeze(0)
+                x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
+                del phones
+                speakers = torch.LongTensor([speaker_id]).to(device)
+                audio = (
+                    self.model.infer(
+                        x_tst,
+                        x_tst_lengths,
+                        speakers,
+                        tones,
+                        lang_ids,
+                        bert,
+                        ja_bert,
+                        sdp_ratio=sdp_ratio,
+                        noise_scale=noise_scale,
+                        noise_scale_w=noise_scale_w,
+                        length_scale=1.0 / speed,
+                    )[0][0, 0]
+                    .data.cpu()
+                    .float()
+                    .numpy()
+                )
+                del x_tst, tones, lang_ids, bert, ja_bert, x_tst_lengths, speakers
+
+                audio_segments = []
+                audio_segments += audio.reshape(-1).tolist()
+                audio_segments += [0] * int(
+                    (self.hps.data.sampling_rate * 0.05) / speed
+                )
+                audio_segments = np.array(audio_segments).astype(np.float32)
+
+                yield audio_segments
+
+        torch.cuda.empty_cache()
+
+    def tts_to_base64(
+        self,
+        text,
+        speaker_id,
+        sdp_ratio=0.2,
+        noise_scale=0.6,
+        noise_scale_w=0.8,
+        speed=1.0,
+        pbar=None,
+        format=None,
+        position=None,
+        quiet=False,
+    ):
+        audio_list = []
+        for audio in self.tts_iter(
+            text,
+            speaker_id,
+            sdp_ratio,
+            noise_scale,
+            noise_scale_w,
+            speed,
+            pbar,
+            position,
+            quiet,
+        ):
+            audio_list.append(audio)
+
+        audio = np.concatenate(audio_list)
+
+        with io.BytesIO() as wav_buffer:
+            soundfile.write(
+                wav_buffer, audio, self.hps.data.sampling_rate, format="WAV"
+            )
+            wav_buffer.seek(0)
+            wav_bytes = wav_buffer.read()
+
+        wav_base64 = base64.b64encode(wav_bytes).decode("utf-8")
+        end_time = datetime.now()
+        elapsed_time = end_time - start_time
+
+        return jsonable_encoder(
+            {"audioContent": wav_base64, "time_taken": elapsed_time}
+        )
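
For context, a minimal sketch of how the new generator might be consumed (not part of this commit; the language code, the "EN-US" speaker key, and the output file name are illustrative assumptions):

    # Illustrative consumer for the new tts_iter generator. Assumes the EN
    # model and an "EN-US" entry in spk2id; adjust to your checkpoint.
    import base64

    import numpy as np
    import soundfile

    from melo.api import TTS

    model = TTS(language="EN", device="auto")
    speaker_id = model.hps.data.spk2id["EN-US"]

    text = "Hello there. This text is synthesized sentence by sentence."
    chunks = []
    for chunk in model.tts_iter(text, speaker_id, quiet=True):
        # Each chunk is a float32 numpy array for one sentence, already padded
        # with roughly 50 ms of trailing silence, so it can be handed to a
        # player or network sink as soon as it is yielded.
        chunks.append(chunk)

    # Concatenating the chunks reproduces what the new tts_to_base64 encodes.
    soundfile.write("streamed.wav", np.concatenate(chunks), model.hps.data.sampling_rate)

    # The dict returned by tts_to_base64 decodes back to WAV bytes:
    # wav_bytes = base64.b64decode(response["audioContent"])

One design note: because each yielded chunk carries its own silence padding, the new tts_to_base64 can simply np.concatenate the chunks instead of calling audio_numpy_concat, and the result matches the old per-segment padding behavior.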
