diff --git a/bark/generation.py b/bark/generation.py
index 54f98709..3f90e058 100644
--- a/bark/generation.py
+++ b/bark/generation.py
@@ -54,18 +54,7 @@ def autocast():
 
 SUPPORTED_LANGS = [
     ("English", "en"),
-    ("German", "de"),
-    ("Spanish", "es"),
-    ("French", "fr"),
     ("Hindi", "hi"),
-    ("Italian", "it"),
-    ("Japanese", "ja"),
-    ("Korean", "ko"),
-    ("Polish", "pl"),
-    ("Portuguese", "pt"),
-    ("Russian", "ru"),
-    ("Turkish", "tr"),
-    ("Chinese", "zh"),
 ]
 
 ALLOWED_PROMPTS = {"announcer"}
@@ -129,13 +118,25 @@ def _cast_bool_env_var(s):
     )
 
 
+import torch
+
+if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):  # SDPA signals flash-attention support
+    print("------------------------------------------------->Flash Attention is available in PyTorch.")
+    flash_attention_available = True
+else:
+    # print("------------------------------------------------->Flash Attention is NOT available in PyTorch.")
+    flash_attention_available = False
+
+
+
+
 def _grab_best_device(use_gpu=True):
     if torch.cuda.device_count() > 0 and use_gpu:
         device = "cuda"
     elif torch.backends.mps.is_available() and use_gpu and GLOBAL_ENABLE_MPS:
         device = "mps"
     else:
-        device = "cpu"
+        device = "cuda"
     return device
@@ -251,7 +252,7 @@ def _load_model(ckpt_path, device, use_small=False, model_type="text"):
 
 def _load_codec_model(device):
     model = EncodecModel.encodec_model_24khz()
-    model.set_target_bandwidth(6.0)
+    model.set_target_bandwidth(3.0)
     model.eval()
     model.to(device)
     _clear_cuda_cache()
@@ -268,7 +269,7 @@ def load_model(use_gpu=True, use_small=False, force_reload=False, model_type="te
     model_key = f"{model_type}"
     if OFFLOAD_CPU:
         models_devices[model_key] = device
-        device = "cpu"
+        device = "cuda"
     if model_key not in models or force_reload:
         ckpt_path = _get_ckpt_path(model_type, use_small=use_small)
         clean_models(model_key=model_key)
@@ -287,11 +288,11 @@ def load_codec_model(use_gpu=True, force_reload=False):
     device = _grab_best_device(use_gpu=use_gpu)
     if device == "mps":
         # encodec doesn't support mps
-        device = "cpu"
+        device = "cuda"
     model_key = "codec"
     if OFFLOAD_CPU:
         models_devices[model_key] = device
-        device = "cpu"
+        device = "cuda"
     if model_key not in models or force_reload:
         clean_models(model_key=model_key)
         model = _load_codec_model(device)
@@ -311,7 +312,7 @@ def preload_models(
     force_reload=False,
 ):
     """Load all the necessary models for the pipeline."""
-    if _grab_best_device() == "cpu" and (
+    if _grab_best_device() == "cuda" and (
         text_use_gpu or coarse_use_gpu or fine_use_gpu or codec_use_gpu
     ):
         logger.warning("No GPU being used. Careful, inference might be very slow!")
@@ -508,7 +509,7 @@ def generate_text_semantic(
         pbar.close()
         out = x.detach().cpu().numpy().squeeze()[256 + 256 + 1 :]
     if OFFLOAD_CPU:
-        model.to("cpu")
+        model.to("cuda")
     assert all(0 <= out) and all(out < SEMANTIC_VOCAB_SIZE)
     _clear_cuda_cache()
     return out
@@ -527,6 +528,10 @@ def _flatten_codebooks(arr, offset_size=CODEBOOK_SIZE):
 
 COARSE_SEMANTIC_PAD_TOKEN = 12_048
 COARSE_INFER_TOKEN = 12_050
+import torch.cuda
+import numpy as np
+import torch.nn.functional as F
+import tqdm
 
 
@@ -536,25 +541,45 @@ def generate_coarse(
     x_semantic,
     history_prompt=None,
     temp=0.7,
     top_k=None,
     top_p=None,
     silent=False,
     max_coarse_history=630,  # min 60 (faster), max 630 (more context)
-    sliding_window_len=60,
-    use_kv_caching=False,
+    sliding_window_len=120,
+    use_kv_caching=True,
+    # kv_cache_dtype = torch.bfloat16,
+    num_streams=4  # New parameter to control number of CUDA streams
 ):
-    """Generate coarse audio codes from semantic tokens."""
+    """Generate coarse audio codes from semantic tokens with CUDA stream optimization.
+
+    Args:
+        ... (existing args remain the same) ...
+        num_streams: Number of CUDA streams to use for parallel processing
+    """
+    # Original input validation
     assert (
         isinstance(x_semantic, np.ndarray)
         and len(x_semantic.shape) == 1
         and len(x_semantic) > 0
         and x_semantic.min() >= 0
        and x_semantic.max() <= SEMANTIC_VOCAB_SIZE - 1
+        and 60 <= max_coarse_history <= 630
+        and max_coarse_history + sliding_window_len <= 1024 - 256
     )
-    assert 60 <= max_coarse_history <= 630
-    assert max_coarse_history + sliding_window_len <= 1024 - 256
+
+    # Initialize CUDA streams only if CUDA is available
+    use_cuda = torch.cuda.is_available()
+    if use_cuda:
+        streams = [torch.cuda.Stream() for _ in range(num_streams)]
+    else:
+        streams = [None] * num_streams
+
     semantic_to_coarse_ratio = COARSE_RATE_HZ / SEMANTIC_RATE_HZ * N_COARSE_CODEBOOKS
     max_semantic_history = int(np.floor(max_coarse_history / semantic_to_coarse_ratio))
+
+    # History prompt processing
     if history_prompt is not None:
         history_prompt = _load_history_prompt(history_prompt)
         x_semantic_history = history_prompt["semantic_prompt"]
         x_coarse_history = history_prompt["coarse_prompt"]
+
+        # Original history prompt validation
         assert (
             isinstance(x_semantic_history, np.ndarray)
             and len(x_semantic_history.shape) == 1
@@ -572,33 +597,42 @@ def generate_coarse(
                         == round(semantic_to_coarse_ratio / N_COARSE_CODEBOOKS, 1)
                     )
                 )
             )
         )
-        x_coarse_history = _flatten_codebooks(x_coarse_history) + SEMANTIC_VOCAB_SIZE
-        # trim histories correctly
-        n_semantic_hist_provided = np.min(
-            [
-                max_semantic_history,
-                len(x_semantic_history) - len(x_semantic_history) % 2,
-                int(np.floor(len(x_coarse_history) / semantic_to_coarse_ratio)),
-            ]
+
+        # Process history using the first stream if CUDA is available
+        if use_cuda:
+            torch.cuda.synchronize()
+            with torch.cuda.stream(streams[0]):
+                x_coarse_history = _flatten_codebooks(x_coarse_history) + SEMANTIC_VOCAB_SIZE
+        else:
+            x_coarse_history = _flatten_codebooks(x_coarse_history) + SEMANTIC_VOCAB_SIZE
+
+        n_semantic_hist_provided = min(
+            max_semantic_history,
+            len(x_semantic_history) - len(x_semantic_history) % 2,
+            int(np.floor(len(x_coarse_history) / semantic_to_coarse_ratio))
         )
         n_coarse_hist_provided = int(round(n_semantic_hist_provided * semantic_to_coarse_ratio))
         x_semantic_history = x_semantic_history[-n_semantic_hist_provided:].astype(np.int32)
-        x_coarse_history = x_coarse_history[-n_coarse_hist_provided:].astype(np.int32)
-        # TODO: bit of a hack for time alignment (sounds better)
-        x_coarse_history = x_coarse_history[:-2]
+        x_coarse_history = x_coarse_history[-n_coarse_hist_provided:-2].astype(np.int32)
     else:
         x_semantic_history = np.array([], dtype=np.int32)
         x_coarse_history = np.array([], dtype=np.int32)
-    # load models if not yet exist
+
+    # Model loading and device setup
     global models
     global models_devices
     if "coarse" not in models:
         preload_models()
     model = models["coarse"]
     if OFFLOAD_CPU:
-        model.to(models_devices["coarse"])
+        if use_cuda:
+            with torch.cuda.stream(streams[0]):
+                model.to(models_devices["coarse"])
+        else:
+            model.to(models_devices["coarse"])
     device = next(model.parameters()).device
-    # start loop
+
+    # Pre-calculations
     n_steps = int(
         round(
             np.floor(len(x_semantic) * semantic_to_coarse_ratio / N_COARSE_CODEBOOKS)
@@ -606,86 +640,132 @@ def generate_coarse(
         )
     )
     assert n_steps > 0 and n_steps % N_COARSE_CODEBOOKS == 0
+
+    # Prepare input tensors
     x_semantic = np.hstack([x_semantic_history, x_semantic]).astype(np.int32)
     x_coarse = x_coarse_history.astype(np.int32)
     base_semantic_idx = len(x_semantic_history)
-    with _inference_mode():
+
+    # Move tensors to device
+    if use_cuda:
+        with torch.cuda.stream(streams[0]):
+            x_semantic_in = torch.from_numpy(x_semantic)[None].to(device)
+            x_coarse_in = torch.from_numpy(x_coarse)[None].to(device)
+            infer_token = torch.tensor([COARSE_INFER_TOKEN])[None].to(device)
+        torch.cuda.synchronize()
+    else:
         x_semantic_in = torch.from_numpy(x_semantic)[None].to(device)
         x_coarse_in = torch.from_numpy(x_coarse)[None].to(device)
+        infer_token = torch.tensor([COARSE_INFER_TOKEN])[None].to(device)
+
+    with _inference_mode():
         n_window_steps = int(np.ceil(n_steps / sliding_window_len))
         n_step = 0
-        for _ in tqdm.tqdm(range(n_window_steps), total=n_window_steps, disable=silent):
-            semantic_idx = base_semantic_idx + int(round(n_step / semantic_to_coarse_ratio))
-            # pad from right side
-            x_in = x_semantic_in[:, np.max([0, semantic_idx - max_semantic_history]) :]
-            x_in = x_in[:, :256]
-            x_in = F.pad(
-                x_in,
-                (0, 256 - x_in.shape[-1]),
-                "constant",
-                COARSE_SEMANTIC_PAD_TOKEN,
-            )
-            x_in = torch.hstack(
-                [
+
+        for window_idx in tqdm.tqdm(range(n_window_steps), total=n_window_steps, disable=silent):
+            stream_idx = window_idx % num_streams if use_cuda else 0
+
+            # Use CUDA stream if available
+            if use_cuda:
+                torch.cuda.synchronize()
+                stream_context = torch.cuda.stream(streams[stream_idx])
+            else:
+                stream_context = contextlib.nullcontext()  # contextlib is imported at the top of this module
+
+            with stream_context:
+                semantic_idx = base_semantic_idx + int(round(n_step / semantic_to_coarse_ratio))
+
+                # Prepare input window
+                x_in = x_semantic_in[:, max(0, semantic_idx - max_semantic_history):]
+                x_in = x_in[:, :256]
+                if x_in.shape[-1] < 256:
+                    x_in = F.pad(
+                        x_in,
+                        (0, 256 - x_in.shape[-1]),
+                        "constant",
+                        COARSE_SEMANTIC_PAD_TOKEN,
+                    )
+
+                x_in = torch.cat([
                     x_in,
-                    torch.tensor([COARSE_INFER_TOKEN])[None].to(device),
+                    infer_token,
                     x_coarse_in[:, -max_coarse_history:],
-                ]
-            )
-            kv_cache = None
-            for _ in range(sliding_window_len):
-                if n_step >= n_steps:
-                    continue
-                is_major_step = n_step % N_COARSE_CODEBOOKS == 0
-
-                if use_kv_caching and kv_cache is not None:
-                    x_input = x_in[:, [-1]]
-                else:
-                    x_input = x_in
+                ], dim=1)
+
+                # Process window
+                kv_cache = None
+                for _ in range(sliding_window_len):
+                    if n_step >= n_steps:
+                        continue
+
+                    is_major_step = n_step % N_COARSE_CODEBOOKS == 0
+                    x_input = x_in[:, [-1]] if (use_kv_caching and kv_cache is not None) else x_in
+
+                    # Model inference
+                    logits, kv_cache = model(x_input, use_cache=use_kv_caching, past_kv=kv_cache)
+
+                    logit_start_idx = SEMANTIC_VOCAB_SIZE + (1 - int(is_major_step)) * CODEBOOK_SIZE
+                    logit_end_idx = SEMANTIC_VOCAB_SIZE + (2 - int(is_major_step)) * CODEBOOK_SIZE
+                    relevant_logits = logits[0, 0, logit_start_idx:logit_end_idx]
+
+                    if top_p is not None:
+                        relevant_logits = relevant_logits.detach().cpu().type(torch.float32).numpy()
+                        sorted_indices = np.argsort(relevant_logits)[::-1]
+                        sorted_logits = relevant_logits[sorted_indices]
+                        cumulative_probs = np.cumsum(softmax(sorted_logits))
+                        sorted_indices_to_remove = cumulative_probs > top_p
+                        sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].copy()
+                        sorted_indices_to_remove[0] = False
+                        relevant_logits[sorted_indices[sorted_indices_to_remove]] = -np.inf
+                        relevant_logits = torch.from_numpy(relevant_logits).to(device)
+
+                    if top_k is not None:
+                        v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1)))
+                        relevant_logits[relevant_logits < v[-1]] = -float("Inf")
+
+                    probs = F.softmax(relevant_logits / temp, dim=-1)
+                    item_next = torch.multinomial(probs, num_samples=1).to(torch.int32)
+
+                    item_next += logit_start_idx
+
+                    x_coarse_in = torch.cat((x_coarse_in, item_next[None]), dim=1)
+                    x_in = torch.cat((x_in, item_next[None]), dim=1)
+
+                    del logits, relevant_logits, probs, item_next
+                    n_step += 1
+
+                del x_in
+
+            # Synchronize at the end of each window if using CUDA
+            if use_cuda:
+                torch.cuda.synchronize()
-                logits, kv_cache = model(x_input, use_cache=use_kv_caching, past_kv=kv_cache)
-                logit_start_idx = (
-                    SEMANTIC_VOCAB_SIZE + (1 - int(is_major_step)) * CODEBOOK_SIZE
-                )
-                logit_end_idx = (
-                    SEMANTIC_VOCAB_SIZE + (2 - int(is_major_step)) * CODEBOOK_SIZE
-                )
-                relevant_logits = logits[0, 0, logit_start_idx:logit_end_idx]
-                if top_p is not None:
-                    # faster to convert to numpy
-                    original_device = relevant_logits.device
-                    relevant_logits = relevant_logits.detach().cpu().type(torch.float32).numpy()
-                    sorted_indices = np.argsort(relevant_logits)[::-1]
-                    sorted_logits = relevant_logits[sorted_indices]
-                    cumulative_probs = np.cumsum(softmax(sorted_logits))
-                    sorted_indices_to_remove = cumulative_probs > top_p
-                    sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].copy()
-                    sorted_indices_to_remove[0] = False
-                    relevant_logits[sorted_indices[sorted_indices_to_remove]] = -np.inf
-                    relevant_logits = torch.from_numpy(relevant_logits)
-                    relevant_logits = relevant_logits.to(original_device)
-                if top_k is not None:
-                    v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1)))
-                    relevant_logits[relevant_logits < v[-1]] = -float("Inf")
-                probs = F.softmax(relevant_logits / temp, dim=-1)
-                item_next = torch.multinomial(probs, num_samples=1).to(torch.int32)
-                item_next += logit_start_idx
-                x_coarse_in = torch.cat((x_coarse_in, item_next[None]), dim=1)
-                x_in = torch.cat((x_in, item_next[None]), dim=1)
-                del logits, relevant_logits, probs, item_next
-                n_step += 1
-            del x_in
     del x_semantic_in
+
     if OFFLOAD_CPU:
-        model.to("cpu")
-    gen_coarse_arr = x_coarse_in.detach().cpu().numpy().squeeze()[len(x_coarse_history) :]
+        if use_cuda:
+            with torch.cuda.stream(streams[0]):
+                model.to("cuda")
+            torch.cuda.synchronize()
+        else:
+            model.to("cuda")
+
+    # Output processing
+    if use_cuda:
+        with torch.cuda.stream(streams[0]):
+            gen_coarse_arr = x_coarse_in.detach().cpu().numpy().squeeze()[len(x_coarse_history):]
+        torch.cuda.synchronize()
+    else:
+        gen_coarse_arr = x_coarse_in.detach().cpu().numpy().squeeze()[len(x_coarse_history):]
+
     del x_coarse_in
     assert len(gen_coarse_arr) == n_steps
+
     gen_coarse_audio_arr = gen_coarse_arr.reshape(-1, N_COARSE_CODEBOOKS).T - SEMANTIC_VOCAB_SIZE
-    for n in range(1, N_COARSE_CODEBOOKS):
-        gen_coarse_audio_arr[n, :] -= n * CODEBOOK_SIZE
+    offsets = np.arange(1, N_COARSE_CODEBOOKS) * CODEBOOK_SIZE
+    gen_coarse_audio_arr[1:] -= offsets[:, None]
+
     _clear_cuda_cache()
-    return gen_coarse_audio_arr
+    return gen_coarse_audio_arr
 
 
 def generate_fine(
@@ -788,7 +868,7 @@ def generate_fine(
     gen_fine_arr = in_arr.detach().cpu().numpy().squeeze().T
     del in_arr
     if OFFLOAD_CPU:
-        model.to("cpu")
+        model.to("cuda")
     gen_fine_arr = gen_fine_arr[:, n_history:]
     if n_remove_from_end > 0:
         gen_fine_arr = gen_fine_arr[:, :-n_remove_from_end]
@@ -816,5 +896,5 @@ def codec_decode(fine_tokens):
     audio_arr = out.detach().cpu().numpy().squeeze()
     del arr, emb, out
     if OFFLOAD_CPU:
-        model.to("cpu")
-    return audio_arr
+        model.to("cuda")
+    return audio_arr
\ No newline at end of file
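
Usage sketch (not part of the patch): a minimal example of driving the patched coarse stage, assuming the bark package from this diff is importable. The prompt text and the parameter values below are illustrative only.

    from bark.generation import preload_models, generate_text_semantic, generate_coarse

    preload_models()  # load text/coarse/fine/codec models onto the selected device

    # Text -> semantic tokens, then semantic -> coarse codes with the patched defaults.
    semantic_tokens = generate_text_semantic("Hello world", temp=0.7)
    coarse_tokens = generate_coarse(
        semantic_tokens,
        temp=0.7,
        sliding_window_len=120,  # new default in this patch
        use_kv_caching=True,     # new default in this patch
        num_streams=4,           # new parameter: CUDA streams cycled across sliding windows
    )
    print(coarse_tokens.shape)   # (N_COARSE_CODEBOOKS, number_of_coarse_frames)

Note that the per-window torch.cuda.synchronize() calls keep successive windows from overlapping, and the token loop inside each window remains sequential; the streams mainly scope the host-to-device transfers and model moves shown in the diff.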