
feat: support async vllm generator
fengyizhu committed Sep 7, 2024
1 parent 651093e · commit 1b21c25
Showing 11 changed files with 688 additions and 178 deletions.
166 changes: 70 additions & 96 deletions ChatTTS/core.py
@@ -1,6 +1,7 @@
import os
import logging
import tempfile
import uuid
from dataclasses import dataclass, asdict
from typing import Literal, Optional, List, Tuple, Dict, Union
from json import load
@@ -173,15 +174,17 @@ class RefineTextParams:
min_new_token: int = 0
show_tqdm: bool = True
ensure_non_empty: bool = True
manual_seed: Optional[int] = None
manual_seed: Optional[int] = 0

@dataclass(repr=False, eq=False)
class InferCodeParams(RefineTextParams):
prompt: str = "[speed_5]"
spk_emb: Optional[str] = None
spk_smp: Optional[str] = None
txt_smp: Optional[str] = None
temperature: float = 0.3
top_P: float = 1
top_K: int = 1
temperature: float = 0.01
repetition_penalty: float = 1.05
max_new_token: int = 2048
stream_batch: int = 24
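
With the new defaults top_K=1 and temperature=0.01, code sampling is effectively greedy. Callers who want the previous, more varied sampling can still override the dataclass fields; a minimal construction sketch (the top_P/top_K values below are illustrative overrides, not from this commit):

params = InferCodeParams(
    prompt="[speed_5]",
    temperature=0.3,  # the default this commit replaced
    top_P=0.7,        # illustrative override
    top_K=20,         # illustrative override
)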
@@ -193,16 +196,17 @@ def infer(
text,
stream=False,
lang=None,
skip_refine_text=False,
skip_refine_text=True,
refine_text_only=False,
use_decoder=True,
do_text_normalization=True,
do_homophone_replacement=True,
params_refine_text=RefineTextParams(),
params_infer_code=InferCodeParams(),
params_refine_text=None,
params_infer_code=None,
stream_batch_size=16
):
self.context.set(False)
res_gen = self._infer(
return self._infer(
text,
stream,
lang,
@@ -213,11 +217,8 @@ def infer(
do_homophone_replacement,
params_refine_text,
params_infer_code,
stream_batch_size
)
if stream:
return res_gen
else:
return next(res_gen)

def interrupt(self):
self.context.set(True)
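
Because _infer is now an async generator and infer returns it directly (the old stream/next branching is gone), callers consume audio chunks with async for. A minimal usage sketch, assuming a loaded Chat instance named chat (names here are illustrative):

import asyncio
import numpy as np

async def synthesize(chat, text: str) -> np.ndarray:
    chunks = []
    # infer() now always returns the async generator from _infer()
    async for wavs in chat.infer(text, stream=True, stream_batch_size=16):
        chunks.append(wavs)  # each chunk is a (n_texts, n_samples) slice
    return np.concatenate(chunks, axis=1)

# audio = asyncio.run(synthesize(chat, "hello"))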
@@ -338,18 +339,19 @@ def _load(

return self.has_loaded()

def _infer(
async def _infer(
self,
text,
stream=False,
stream=True,
lang=None,
skip_refine_text=False,
skip_refine_text=True,
refine_text_only=False,
use_decoder=True,
do_text_normalization=True,
do_homophone_replacement=True,
params_refine_text=RefineTextParams(),
params_infer_code=InferCodeParams(),
params_refine_text=None,
params_infer_code=None,
stream_batch_size=16
):

assert self.has_loaded(use_decoder=use_decoder)
@@ -383,41 +385,39 @@ def _infer(
yield text
return

if stream:
length = 0
pass_batch_count = 0
for result in self._infer_code(
length = 0
async for result in self._infer_code(
text,
stream,
self.device,
use_decoder,
params_infer_code,
stream_batch_size,
):
wavs = self._decode_to_wavs(
result.hiddens if use_decoder else result.ids,
use_decoder,
)
result.destroy()
if stream:
pass_batch_count += 1
if pass_batch_count <= params_infer_code.pass_first_n_batches:
continue
a = length
b = a + params_infer_code.stream_speed
if b > wavs.shape[1]:
b = wavs.shape[1]
new_wavs = wavs[:, a:b]
length = b
yield new_wavs

if result.finished:
yield wavs[:, length:]
else:
yield wavs
if stream:
new_wavs = wavs[:, length:]
# Identify rows with non-zero elements using np.any
# keep_rows = np.any(array != 0, axis=1)
keep_cols = np.sum(new_wavs != 0, axis=0) > 0
# Filter both rows and columns using slicing
yield new_wavs[:][:, keep_cols]
# Check whether there are silent segments; if so, take the last one, otherwise try waiting one more loop iteration
import librosa
silence_intervals = librosa.effects.split(wavs[0][length:], top_db=10)
silence_left = 0
if len(silence_intervals) == 0:
silence_left = len(wavs[0])
else:
# If there are silent segments, take the left boundary of the last one
for i in range(len(silence_intervals)):
silence_left = silence_intervals[i][0]
# If no left boundary is found, the synthesized audio is continuous and no break point was found.
if silence_left <= 0:
continue
new_wavs = wavs[:, length:length + silence_left]
length += len(new_wavs[0])
yield new_wavs
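
The streaming branch above cuts each emitted chunk at a quiet point so playback does not break mid-phoneme. Note that librosa.effects.split returns *non-silent* intervals as (start, end) sample indices, so the left edge of the last interval is the latest safe cut. A condensed sketch of that boundary search (the function name is illustrative):

import librosa
import numpy as np

def last_safe_cut(wav: np.ndarray, top_db: int = 10) -> int:
    # librosa.effects.split returns non-silent intervals, not silences;
    # samples before an interval's start are below the top_db threshold.
    intervals = librosa.effects.split(wav, top_db=top_db)
    if len(intervals) == 0:
        return len(wav)  # all silence: flush everything buffered
    return int(intervals[-1][0])  # left edge of the last non-silent interval

# cut = last_safe_cut(wavs[0][length:]); if cut > 0, emit wavs[:, length:length + cut]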

@torch.inference_mode()
def _vocos_decode(self, spec: torch.Tensor) -> np.ndarray:
@@ -456,13 +456,14 @@ def _decode_to_wavs(
return wavs

@torch.no_grad()
def _infer_code(
async def _infer_code(
self,
text: Tuple[List[str], str],
stream: bool,
device: torch.device,
return_hidden: bool,
params: InferCodeParams,
stream_batch_size: int,
):

gpt = self.gpt
@@ -503,6 +504,17 @@ def _infer_code(
repetition_penalty=params.repetition_penalty,
)

speaker_embedding_param = gpt(input_ids, text_mask)

if params.spk_emb is not None:
self.speaker.apply(
speaker_embedding_param,
params.spk_emb,
input_ids,
self.tokenizer.spk_emb_ids,
self.gpt.device_gpt,
)

if gpt.is_vllm:
from .model.velocity import SamplingParams

@@ -518,65 +530,27 @@
)
input_ids = [i.tolist() for i in input_ids]

result = gpt.llm.generate(
results_generator = gpt.llm.llm_engine.generate(
None,
sample_params,
input_ids,
uuid.uuid4(),
speaker_embedding_param,
input_ids[0]
)

token_ids = []
hidden_states = []
for i in result:
token_ids.append(torch.tensor(i.outputs[0].token_ids))
hidden_states.append(
i.outputs[0].hidden_states.to(torch.float32).to(self.device)
)

del text_mask, input_ids

return [
GPT.GenerationOutputs(
ids=token_ids,
hiddens=hidden_states,
attentions=[],
),
]

emb = self.embed(input_ids, text_mask)

del text_mask

if params.spk_emb is not None:
self.speaker.apply(
emb,
params.spk_emb,
input_ids,
self.tokenizer.spk_emb_ids,
self.gpt.device_gpt,
)

result = gpt.generate(
emb,
input_ids,
temperature=torch.tensor(temperature, device=device),
eos_token=num_code,
attention_mask=attention_mask,
max_new_token=params.max_new_token,
min_new_token=params.min_new_token,
logits_processors=(*logits_processors, *logits_warpers),
infer_text=False,
return_hidden=return_hidden,
stream=stream,
show_tqdm=params.show_tqdm,
ensure_non_empty=params.ensure_non_empty,
stream_batch=params.stream_batch,
manual_seed=params.manual_seed,
context=self.context,
)

del emb, input_ids

return result
async for i in results_generator:
token_ids = []
hidden_states = []
if len(i.outputs[0].token_ids) % stream_batch_size == 0 or i.finished:
token_ids.append(torch.tensor(i.outputs[0].token_ids))
hidden_states.append(
i.outputs[0].hidden_states.to(torch.float32).to(self.device)
)
yield GPT.GenerationOutputs(
ids=token_ids,
finished=i.finished,
hiddens=hidden_states,
attentions=[],
)
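
The loop above yields partial outputs on a fixed cadence: once whenever the accumulated token count is a multiple of stream_batch_size, plus a final yield when the engine reports completion. A toy sketch of that cadence, assuming a stand-in async token source (fake_engine is illustrative, not part of vLLM):

import asyncio

async def fake_engine(n_tokens: int):
    token_ids = []
    for t in range(n_tokens):
        token_ids.append(t)
        yield token_ids, t == n_tokens - 1  # (tokens so far, finished?)

async def demo(stream_batch_size: int = 16):
    async for token_ids, finished in fake_engine(40):
        if len(token_ids) % stream_batch_size == 0 or finished:
            print(f"emit {len(token_ids)} tokens, finished={finished}")

# asyncio.run(demo())  # emits at 16, 32, and 40 (finished)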

@torch.no_grad()
def _refine_text(
49 changes: 45 additions & 4 deletions ChatTTS/model/gpt.py
@@ -46,15 +46,16 @@ def __init__(
self.is_te_llama = False
self.is_vllm = use_vllm

self.emb_code = [ec.__call__ for ec in embed.emb_code]
self.emb_text = embed.emb_text.__call__
self.head_text = embed.head_text.__call__
self.head_code = [hc.__call__ for hc in embed.head_code]
if self.is_vllm:
return

self.llama_config = self._build_llama_config(gpt_config)

self.emb_code = [ec.__call__ for ec in embed.emb_code]
self.emb_text = embed.emb_text.__call__
self.head_text = embed.head_text.__call__
self.head_code = [hc.__call__ for hc in embed.head_code]


def from_pretrained(
self, gpt_folder: str, embed_file_path: str, experimental=False
Expand All @@ -68,6 +69,7 @@ def from_pretrained(
num_audio_tokens=self.num_audio_tokens,
num_text_tokens=self.num_text_tokens,
post_model_path=embed_file_path,
dtype="float32"
)
self.logger.info("vLLM model loaded")
return
@@ -138,6 +140,44 @@ def prepare(self, compile=False):
except RuntimeError as e:
self.logger.warning(f"compile failed: {e}. fallback to normal mode.")

def __call__(
self, input_ids: torch.Tensor, text_mask: torch.Tensor
) -> torch.Tensor:
"""
get_emb
"""
return super().__call__(input_ids, text_mask)

def forward(self, input_ids: torch.Tensor, text_mask: torch.Tensor) -> torch.Tensor:
"""
get_emb
"""
input_ids = input_ids.clone()
text_mask = text_mask.clone()
emb_text: torch.Tensor = self.emb_text(
input_ids[text_mask].narrow(1, 0, 1).squeeze_(1).to(self.device_gpt)
)

text_mask_inv = text_mask.logical_not().to(self.device_gpt)
masked_input_ids: torch.Tensor = input_ids[text_mask_inv].to(self.device_gpt)

emb_code = [
self.emb_code[i](masked_input_ids[:, i]) for i in range(self.num_vq)
]
emb_code = torch.stack(emb_code, 2).sum(2)

emb = torch.zeros(
(input_ids.shape[:-1]) + (emb_text.shape[-1],),
device=emb_text.device,
dtype=emb_text.dtype,
)
emb[text_mask] = emb_text
emb[text_mask_inv] = emb_code.to(emb.dtype)

del emb_text, emb_code, text_mask_inv

return emb
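
The forward pass scatters two embedding streams into one tensor by boolean mask: text positions receive emb_text, all other positions receive the summed code embeddings. A toy demonstration of the mask-merge pattern (shapes and values are illustrative, not from the model):

import torch

B, T, D = 2, 5, 8
text_mask = torch.tensor([[True, True, False, False, False],
                          [True, False, False, False, False]])
emb_text = torch.randn(int(text_mask.sum()), D)     # one row per text position
emb_code = torch.randn(int((~text_mask).sum()), D)  # one row per code position

emb = torch.zeros(B, T, D)
emb[text_mask] = emb_text   # boolean indexing assigns row-major
emb[~text_mask] = emb_code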

@dataclass(repr=False, eq=False)
class _GenerationInputs:
position_ids: torch.Tensor
@@ -273,6 +313,7 @@ class GenerationOutputs:
ids: List[torch.Tensor]
attentions: List[Optional[Tuple[torch.FloatTensor, ...]]]
hiddens: List[torch.Tensor]
finished: bool

def destroy(self):
del_all(self.ids)